前段时间又开始写爬虫了,在整理代码时找到了很久之前写的一个小工具:一个基于路径查询Json数据的工具。
当时还没有ChatGPT,是纯手工打造的匠心代码😋,最近又找出来重构更新了一下,所以我干脆写篇文章来记录一下设计思路。

之前在采集某个平台的数据的时候,发现从接口获取到的数据杂乱无章,一层套一层,各种格式也很混乱,如果要从这些数据中获取到指定的字段,就需要套大量的.get().get().get()...最终导致代码也很混乱,可维护性降低。

基于这个需求,我当时想着搞一个可以方便快捷的获取某个字段的工具。(当时还不知道有Json Path这种东西,也没想着去搜一搜,可能就是码瘾犯了就想写点东西吧...)不过虽然这篇文章的标题叫“基于路径的JSON解析工具实现”,但实际上查询json数据的功能本身不太重要,本文重点是如何对查询语句进行解析。

注意:


目前来说该模块(几乎)所有功能Json Path都可以实现,所以如果有此类需求请直接使用更加标准化的Json Path,该模块用来研究和学习即可。

基本功能

先来了解一下这个模块可以实现哪些功能,然后基于这些功能说明一下具体是如何进行实现的。

下面使用的查询语句都是以这段代码结构为例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
    # 用于测试的示例JSON数据
    test_data = {
        "root": {
            "root_key": "root_value",
            "child": [
                ["first", "second"],
                ["third", "fourth"]
            ],
            "key.01": "value",
            "key[02]": "value2",
            ".": "pass",
            "exception": {"error": "trigger"},
            "data": [{"id": 1}, {"id": 2}, {"id": 3}],
            "info": {"list": ["first_list_item", {"details": "detail_value"}]},
            "empty": [],
            "items": [
                {
                    "value": 10,
                    "sub_value": 100
                },
                {
                    "value": 20,
                    "sub_value": 200
                },
                {
                    "value": 30,
                    "sub_value": 300
                }
            ],
            "array": [['a', 'b', 'c'], ['d', 'e', 'f']],
            "dictionary": {"key": "value", "invalid": None},
            "list": [
                {"id": 1, "name": "value1", "sub_id": "A", "sub_list": [5,6,7,8]},
                {"id": 2, "name": "value2", "sub_id": "A", "sub_list": [1,2,3,4]},
                {"id": 3, "name": "value3", "sub_id": "B", "sub_list": [9,10,11,12]},
                {"id": 2, "name": "value4", "sub_id": "B", "key": True, "sub_list": [5,6,7,8]}
            ],
            "number_list": [1,2,3,4,5,6,7,8,9]
        }
    }

基本路径查询

首先是最简单的基本路径查询,使用点表示法(.)或字典键访问语法,用于访问名称符合标准标识符规则(字母、数字、下划线开头,不包含特殊字符)的键。

比如对如上数据,如果你想查询root字典下的root_key的内容,使用以下查询语句:

1
"root.root_key"

或者使用Python的字典键访问语法:

1
2
3
4
5
6
'root["root_key"]'
# 或者使用单引号
"root['root_key']"
# 需要注意Python字符串引号嵌套的问题
".['root']['root_key']"    # 连续方括号查询,第一级使用方括号查询需要以点符号开头,其作用同下一条
"*.['root']['root_key']"   # 以通配符开头,使用通配符返回当前层级所有数据

上述几种方法效果相同,都可以得到查询结果: root_value

对于有特殊符号的查询路径,需要使用方括号查询,如下:

1
2
3
4
"root['.']"
# 以及
"root['key.01']"
"root['key[02]']"

原本是实现了转义符\对点操作符和方括号进行转义的,比如 root.key\.01 这种语法,但在重构1后这种语法不再被支持了。

列表索引和切片

除了对字典键的查询外,还支持对列表的索引和切片,索引和切片必须通过方括号来使用,而不能使用点表示法。

比如对number_list进行索引,如果需要查询下标为2的元素,直接使用如下语句即可:

1
'root.number_list[2]'

支持Python语法的基础切片模式,比如如下语法,这些就不多说了:

1
2
3
4
5
6
7
8
9
10
11
12
"root.number_list[1:4]"    # 基本切片
"root.number_list[::2]"    # 步长为2
"root.number_list[::-1]"    # 反向切片
"root.number_list[-3:]"    # 负数索引
"root.number_list[:3]"    # 省略start
"root.number_list[3:]"    # 省略end
"root.number_list[1:6:2]"    # 完整切片语法
"root.number_list[10:20]"    # 超出范围的切片
"root.number_list[-10:-5]"    # 修正期望值
"root.number_list[5:2:-1]"    # 反向步长切片
"root.number_list[2:5:0]"    # 步长为0(非法)
"root.number_list[2:5:1.5]"    # 非整数步长(非法)

基本条件查询

除了上述基本查询外,还支持基于比较符、逻辑运算符的条件查询,通过构建查询表达式来获取数据,比如:

1
2
3
4
5
"root.list['sub_id'=='A'].sub_list"
'root.list["id"<3].name'
'root.list["id"==2&&"name"=="value4"].sub_list'
'root.list["id"==2||"id"==3].sub_id'
'root.list[("id"==2||"id"==3)].sub_id'

"root.list["id"==2&&"name"=="value4"].sub_list" 为例,这段语句与json path中 $.root.list[?(@.id == 2 && @.name == "value4")] 功能一致,也就是查询list下id为2,name为value4的sub_list值,查询结果应该是 [[5, 6, 7, 8]] 注意是一个二维数组,一层是list的数组,一层是sub_list的数组。

其他几条语句类比即可。

同时,可用使用括号来指定运算优先级,如下语句:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
data = {
    "list": [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20],
    "items": [
        {"type": "B", "name": "value1", "age": 10},
        {"type": "A", "name": "value2", "age": 20},
        {"type": "B", "name": "value3", "age": 30},
        {"type": "A", "name": "value4", "age": 40},
        {"type": "B", "name": "value5", "age": 50},
    ]
}

path = "items[('type'=='A' && 'age'<25) || ('age' >= 40 && 'type'=='B')].name"

# 这条语句的查询结果应该是['value2', 'value5']

脚本2、函数调用及变量引用

为了配合条件查询,实现更复杂和更灵活的查询情况,我引入了脚本和变量功能。具体用法是:使用@符号声明一段脚本函数调用,使用$符号对变量进行调用,如下语句:

1
path = "@get_max(1, $a)"

上面的语句中用到了get_max函数和a变量,但是在使用前需要先通过脚本管理器对函数和变量进行定义:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from dictquerier import script_manager, query_json

@script_manager.register(name="get_max")
def get_max(*args):
    return max(args)

script_manager.define("a", 6)

data = {
    "list": [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
}

path = "list[@get_max(3, 5, $a)]"
print(query_json(data, path))

上面的代码执行输出结果为7,解析器会先获取$a的值(也就是6),然后执行@get_max(3, 5, 6),得到结果6,最后获取list[6]得到7。

使用@script_manager.register()对函数进行注册,可以通过传递name参数来命名调用名,如果留空name,则会自动获取函数名称作为调用名。

执行函数时,会先检测已经通过@script_manager.register()进行注册的函数,如果没有找到对应的注册函数,则会继续寻找Python内置函数或尝试对函数进行导入,比如如下代码:

1
2
3
4
5
6
7
8
9
from dictquerier import script_manager, query_json

# data 不重要,没有用到
data = {
    "list": [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
}

path = "@random.randint(1, 10)"
print(query_json(data, path))

这段代码会尝试导入python的random模块,然后调用randint,最后返回一个1到10之间的随机整数;因为random是python的内置包,所以即使该脚本即使没有通过脚本管理器注册,也可以使用。

实现原理

接下来详细说明一下这个解析器是如何实现的。

早期实现方案

在最早期时,因为只是为了做一些简单的路径查询,只支持了点路径语法 "path.to.data",因为语法复杂度不高,所以直接手搓了一个简易的有限状态机3,简单来说就是逐字母扫描,扫描到的每个字母都保存在缓冲区中,如果遇到点符号(".")则将缓冲区中的字符串作为一个整体(我称其为元素),然后清空缓冲区继续扫描,直到完成整个查询语句的处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 伪代码,并非完全逻辑
def parse_path(path_str: str) -> list[str]:
    buffer = ""       # 临时存储当前路径元素的字符
    elements = []     # 存储解析后的路径元素列表
   
    for char in path_str:
        if char == ".":
            # 遇到分隔符时将缓冲区内容存入列表
            elements.append(buffer)
            buffer = ""
        else:
            # 非分隔符时累积字符到缓冲区
            buffer += char
   
    # 处理最后一个元素(点号后的剩余内容)
    if buffer:
        elements.append(buffer)
   
    return elements

使用这种方式开发简单,速度快,在O(1)的时间复杂度的情况下就可以完成整个语句的扫描。在扫描完成后,再循环所有元素,直接使用items[element]的方式进行获取。

但是很快这种方法就无法覆盖需求了。一般来说,从接口请求回来的数据不太可能都只有一条,所以我需要引入索引的方式对列表进行处理。于是我对基础语法进行了扩展,引入了方括号来进行下标索引,语法格式为 "path.to.items[index]",其中index为一个整数。为了处理这种结构,对循环体进行了修改,引入了对方括号的检测。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# 伪代码,并非完全逻辑
def parse_path(path_str: str) -> list[str]:
    buffer = ""       # 临时存储当前路径元素的字符
    elements = []     # 存储解析后的路径元素列表
   
    total_length = len(path_str)
    char_index = 0
    while char_index < total_length:
        char = path_str[char_index]
        if char == ".":
            if buffer:
                elements.append(buffer)
                buffer = ""
        elif char == "[":
            elements.append(buffer)
            buffer = ""
            char_index += 1
            while path_str[char_index] != "]":
                buffer += path_str[char_index]
                char_index += 1
            elements.append(int(buffer))
            buffer = ""
        else:
            buffer += char
        char_index += 1

    # 处理最后一个元素(点号后的剩余内容)
    if buffer:
        elements.append(buffer)
   
    return elements

最后在查询时,在使用items[element]时先判断element是一个整数还是字符串,如果是字符串就是用item.get(),如果是整数索引就使用items[element]进行下标查询。

完成整数索引后,我又遇到了新的需求。某组数据中同时存在有效数据和无效数据,数据是否有效是通过某个字段来进行标记的,所以我需要增加一种过滤查询的方法,以实现按照指定条件进行查询,语法为 "path.to.items['type'=='A']"。同时既然有等于(==),那也应该引入>、<、!=、>=、<=等比较操作符。按照这种语法,解析器就需要进行更加复杂的构造。

为了方便对形如'type'=='A'的表达式进行管理,我创建了一个类专门用于表示表达式。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class Opreater(Enum):
    """
    Operator 枚举类
    用于定义支持的操作符类型,便于后续表达式解析和判断。
    """

    EQUAL = "=="
    NOT_EQUAL = "!="
    LESS_THAN = "<"
    GREATER_THAN = ">"
    LESS_EQUAL = "<="
    GREATER_EQUAL = ">="
    LOGICAL_AND = "&&"
    LOGICAL_OR = "||"
    SLICE = "slice"

    def __str__(self):
        return self.value

class Expression:
    operator_funcs = {
        Opreater.EQUAL: lambda x, y: x == y,
        Opreater.LESS_THAN: lambda x, y: x < y,
        Opreater.GREATER_THAN: lambda x, y: x > y,
        Opreater.LESS_EQUAL: lambda x, y: x <= y,
        Opreater.GREATER_EQUAL: lambda x, y: x >= y,
        Opreater.NOT_EQUAL: lambda x, y: x != y,
    }
    complex_operator_funcs = {
        Opreater.LOGICAL_AND: lambda x, y: x.right.operate(x.left.operate(y)),
        Opreater.LOGICAL_OR: lambda x, y: x.left.operate(y) + x.right.operate(y),
    }
   
    def __init__(self, key=None, operator=None, value=None, left=None, right=None):
        self.key = key
        self.operator = operator
        self.value = value
        self.left: Expression = left
        self.right: Expression = right
       
    def __repr__(self):
        if self.key:
            return f'Expression({self.key} {self.operator} {self.value})'
        elif self.value is not None:
            return f'Expression(value={self.value})'
        else:
            return f'({self.left} {self.operator} {self.right})'

我将运算符分为三类,一类是比较运算符,一类是逻辑运算符,还有一类是特殊的切片。对于这三类运算符,实际进行的运算操作也不同。对于比较标识符,其输入是需要处理的数据的key,以及目标值,输出是一组符合规则的数据,而逻辑运算符的输入则是两组数据,输出的数据是对这两组数据进行取交集和取并集的合集运算。至于切片运算符,则是和python自带的切片一样,这个就不多说了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
class Opreater(Enum):
    ...
    def get_expression(self):
        if self.key:
            value = f'"{self.value}"' if isinstance(self.value, str) else self.value
            return f'"{self.key}" {self.operator} {value}'
        elif self.value is not None:
            return self.value
        else:
            return f'({self.left.get_expression()} {self.operator} {self.right.get_expression()})'

    def operate(self, data):
        # 基本运算符
        if self.operator in self.complex_operator_funcs:
            return self.complex_operator_funcs[self.operator](self, data)
       
        # 逻辑操作符
        if self.operator in self.operator_funcs:
            itempack_list = []
            for item_pack in data:
                result = query_json(item_pack, self.key, no_path_exception=True)
               
                if self.operator_funcs[self.operator](result, self.value):
                    itempack_list.append(item_pack)
            return itempack_list
       
        # 切片索引
        if self.operator == Opreater.SLICE:
            return data[self.value]
       
        else:
            raise SyntaxError(f"不支持的运算符: {self.operator}")

    @staticmethod
    def parse(expression: str) -> "Expression":
        expression = expression.strip()
        if expression.startswith("(") and expression.endswith(")"):
            return Expression.parse(expression[1:-1])

        index, operator = Expression.find_outer_operator(expression)
        if index != -1:
            return Expression(left=Expression.parse(expression[:index].strip()),
                              operator=operator,
                              right=Expression.parse(expression[index+len(operator.value):].strip()))

        # 处理比较表达式
        match = re.match(r'("[^"]+"|\'[^\']+\')\s*(==|!=|<=|>=|<|>)\s*(\d+|"[^"]+"|\'[^\']+\'|true|True|False|false|None|null)$', expression)
        if match:
            key, operator_str, value = match.groups()
            key = key.strip("'").strip('"')
            value = value.strip("'").strip('"')
            if value.isdigit():
                value = int(value)
            elif value.startswith('"') or value.startswith("'"):
                value = value[1:-1]  # 去除引号
           
            # 将字符串运算符转换为枚举
            operator = None
            for op_enum in Opreater:
                if op_enum.value == operator_str:
                    operator = op_enum
                    break
           
            if operator is None:
                raise SyntaxError(f"未知的运算符: {operator_str}")
               
            return Expression(key=key, operator=operator, value=value)

        # 处理简单文本或数字索引
        # 简单数字索引
        if expression.isdigit():
            return Expression(value=int(expression))
       
        # 双引号文本字符串
        if re.match(r'^"[^"]*"$', expression):
            return Expression(value=expression.strip('"'))
       
        # 单引号文本字符串
        if re.match(r"^'[^']*'$", expression):
            return Expression(value=expression.strip("'"))
       
        # 简单文本
        if re.match(r'^\w+$', expression):
            return Expression(value=expression)
       
        # 单个通配符
        if expression == "*":
            return Expression(value="*")

        # 处理复杂切片索引
        try:
            tree = ast.parse(f"__slice_check_{''.join(random.choices(string.ascii_letters + string.digits, k=8))}__[{expression}]", mode='eval')
        except Exception as e:
            # 这里是最后一步解析,如果到这一步都解析失败,说明是无效表达式
            raise SyntaxError(f"无效表达式: {expression}")
        if isinstance(tree, ast.Expression) or isinstance(tree.body, ast.Subscript):
            slice_node = tree.body.slice
            if isinstance(slice_node, ast.Slice):
                start = ast.literal_eval(slice_node.lower) if slice_node.lower else None
                end = ast.literal_eval(slice_node.upper) if slice_node.upper else None
                step = ast.literal_eval(slice_node.step) if slice_node.step else None
               
                if not (isinstance(start, int) or start is None) or not (isinstance(end, int) or end is None) or not (isinstance(step, int) or step is None):
                    raise ValueError(f"切片索引必须是整数或None: [{expression}]")
               
                if step == 0:
                    raise ValueError(f"切片步长不能为0: [{expression}]")
               
                return Expression(operator=Opreater.SLICE, value=slice(start, end, step))
           
            if isinstance(slice_node, ast.Tuple):
                # TODO 实现扩展索引(类似list[1,2:3]这种)
                raise NotImplementedError("扩展索引暂未实现")

       
        raise SyntaxError(f"无效表达式: {expression}")

    @staticmethod
    def find_outer_operator(expression):
        """查找表达式中的外层运算符

        Args:
            expression (str): 表达式字符串

        Returns:
            tuple: 包含外层运算符的索引和运算符,如果索引为-1,则表示没有找到外层运算符
        """

        depth = 0
        last_operator_index = -1
        operator = None
        escaped = False
        for i, char in enumerate(expression):
            if char == '\' and not escaped:  # 这里实际上是两个\,但是代码渲染似乎有些问题,不管几个\都只显示一个
                escaped = True
                continue
            if char == '
(' and not escaped:
                depth += 1
            elif char == '
)' and not escaped:
                depth -= 1
            elif depth == 0 and i + 1 < len(expression):
                # 检查是否是复杂运算符(&&, ||)
                for complex_op in [Opreater.LOGICAL_AND, Opreater.LOGICAL_OR]:
                    if expression[i:i+len(complex_op.value)] == complex_op.value:
                        last_operator_index = i
                        operator = complex_op
                        break
            escaped = False
        return last_operator_index, operator

其中在处理切片的部分,还偷懒用到了python的抽象语法树(ast模块,后续我们会实现自己的ast模块),这大大减少了手动解析切片的工作量。同时我们需要修改读取字符的循环,让其将表达式从方括号中分离出来,并传输给表达式类进行解析。

最终的表达式类结构是一个标准的树形结构,最外层的叶子节点应该始终是比较表达式,而中间节点应该始终是逻辑表达式4

随着语法和解析功能的扩展,问题也逐渐显现出来。使用这种方式虽然可以高效处理数据(不过实际上大多数情况使用差别并不大),但是每次引入新的语法就需要更改循环体,并手动增加处理逻辑。这种方式维护性极差,并且随着语法复杂化,状态转移会急剧增长。

为了解决这个问题,我尝试对代码进行了一次重构,完全放弃之前的处理方式,参考一般解释性语言的源代码处理方式5,引入词法、语句分析和抽象语法树(AST)的概念,这将帮助我们能够更好的进行语法和功能的扩展,以及方便地排除错误。

词法分析

词法分析是整个语句处理流程的第一步,其目的是将源代码(查询语句)转换成一组Token。

Token的结构为:

1
2
3
4
5
6
7
8
9
class Token:
    def __init__(self, type: TokenType, value, column=None, line=None):
        self.type = type
        self.value = value
        self.line = line
        self.column = column
   
    def __repr__(self):
        return f'Token({self.type}, {repr(self.value)}, at {self.column} line {self.line})'

其中使用TokenType来标记Token的类型,TokenType枚举如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class TokenType(Enum):
    """
    词法分析器的token类型
    """

    VARSIGN    = ("$", r"\$")                           # 变量符号,这个符号一般来说后面只能跟NAME
    SCRIPTSIGN = ("@", r"@")                            # 脚本符号,这个符号一般来说后面只能跟NAME
    DOT        = (".", r"\.")                           # .
    WHITESPACE = ("whitespace", r"\s+")                 # 空白符,暂时没用,因为在解析时会跳过
    OP         = ("op", r"==|!=|>=|<=|>|<|&&|\|\||[+\-*/<>]")    # 操作符,比如 ==, >, <, *, &&, ||等
    NAME       = ("name", r"[a-zA-Z_][a-zA-Z0-9_]*")    # 标识符
    NUMBER     = ("number", r"\d+(\.\d+)?([eE][+-]?\d+)?")        # 整数或浮点
    STRING     = ("string", r""""(?:\\.|[^"\\])*"|'(?:\\.|[^\'])*'""") # 引号字符串
    LBRACK     = ("[", r"\[")                           # [
    RBRACK     = ("]", r"\]")                           # ]
    LPAREN     = ("(", r"\(")                           # (
    RPAREN     = (")", r"\)")                           # )
    ASSIGN     = ("=", r"=")                            # =
    COLON      = (":", r":")                            # :
    COMMA      = (",", r",")                            # ,
    END        = ("EOF", r"$^")                         # 结束
    UNKNOWN    = ("UNKNOWN", r".")                      # 未知字符
   
    def __init__(self, literal, pattern):
        self._literal = literal
        self._pattern = pattern

    @property
    def literal(self):
        return self._literal

    @property
    def pattern(self):
        return self._pattern
   
    def __repr__(self) -> str:
        return f"TokenType(literal={self.literal}, pattern={self.pattern})"

然后需要实现一个词法分析器,将语句字符串转换成由Token类组成的序列,具体思路是:首先对一些特殊的语句进行处理,比如如果语句以点操作符开头,则自动在点操作符前补"*"(通配符);然后将TokenType中各Token类型的pattern生成一个正则组,对语句进行正则匹配。其中,由于END属于控制类符,无需从语句中进行匹配,所以正则规则设置的"$^"6,属于一个永远不会被匹配的占位符。使用finditer方法对规则组进行匹配,同时在匹配时维护行号7和列号,并跳过空格8

比如对于一段基本上涵盖了目前全部情况(相对来说)较为复杂的语句:

1
path = ".datas['items'].length['var' != $snum || ("id" >= 3 * 5 && 'age'< @max(10, num2=10 + 5))].n[:15:2]"

其转换成Token的结果为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
Token(TokenType.OP, '*', at 0 line 1)
Token(TokenType.DOT, '.', at 1 line 1)
Token(TokenType.NAME, 'datas', at 6 line 1)    
Token(TokenType.LBRACK, '[', at 7 line 1)      
Token(TokenType.STRING, "'items'", at 14 line 1)
Token(TokenType.RBRACK, ']', at 15 line 1)      
Token(TokenType.DOT, '.', at 16 line 1)
Token(TokenType.NAME, 'length', at 22 line 1)  
Token(TokenType.LBRACK, '[', at 23 line 1)      
Token(TokenType.STRING, "'var'", at 28 line 1)  
Token(TokenType.OP, '!=', at 31 line 1)
Token(TokenType.VARSIGN, '$', at 33 line 1)
Token(TokenType.NAME, 'snum', at 37 line 1)
Token(TokenType.OP, '||', at 40 line 1)
Token(TokenType.LPAREN, '(', at 42 line 1)
Token(TokenType.STRING, '"id"', at 46 line 1)
Token(TokenType.OP, '>=', at 49 line 1)
Token(TokenType.NUMBER, '3', at 51 line 1)
Token(TokenType.OP, '*', at 53 line 1)
Token(TokenType.NUMBER, '5', at 55 line 1)
Token(TokenType.OP, '&&', at 58 line 1)
Token(TokenType.STRING, "'age'", at 64 line 1)
Token(TokenType.OP, '<', at 65 line 1)
Token(TokenType.SCRIPTSIGN, '@', at 67 line 1)
Token(TokenType.NAME, 'max', at 70 line 1)
Token(TokenType.LPAREN, '(', at 71 line 1)
Token(TokenType.NUMBER, '10', at 73 line 1)
Token(TokenType.COMMA, ',', at 74 line 1)
Token(TokenType.NAME, 'num2', at 79 line 1)
Token(TokenType.ASSIGN, '=', at 80 line 1)
Token(TokenType.NUMBER, '10', at 82 line 1)
Token(TokenType.OP, '+', at 84 line 1)
Token(TokenType.NUMBER, '5', at 86 line 1)
Token(TokenType.RPAREN, ')', at 87 line 1)
Token(TokenType.RPAREN, ')', at 88 line 1)
Token(TokenType.RBRACK, ']', at 89 line 1)
Token(TokenType.DOT, '.', at 90 line 1)
Token(TokenType.NAME, 'n', at 91 line 1)
Token(TokenType.LBRACK, '[', at 92 line 1)
Token(TokenType.COLON, ':', at 93 line 1)
Token(TokenType.NUMBER, '15', at 95 line 1)
Token(TokenType.COLON, ':', at 96 line 1)
Token(TokenType.NUMBER, '2', at 97 line 1)
Token(TokenType.RBRACK, ']', at 98 line 1)
Token(TokenType.END, 'EOF', at 99 line 1)

这段表达式一共100个字符,45个Tokens。

语法分析

完成词法分析后,需要进行语法分析,并将词法分析得到的Token序列转换成一种带有语法规则的特定结构。

在早期实现方案中,并没有生成语法规则的步骤,而是按照元素逐个解释,这种实现方法对于语法扩展来说极差,实现任意新的语法都需要对代码进行大量修改;而在这个模块中,Token序列最后会被转换成抽象语法树。

早期的Python9使用手写 LL(1)10 语法分析器,后来引入 match-case、完善的类型提示等越来越多的语法后,使用 LL(1) 难以进行维护,所以在 Python 3.9 以后将解析器替换成了PEG11。相比于LL(1),PEG更加灵活对于复杂语法的开发和维护更加容易,语法表达能力也更加强大。

Python 中的pegen可以通过语法规则生成C语言的处理代码,大多数语法都在cpython/Grammar/python.gram文件中,我之前尝试过给Python增加do while语法功能,就是通过修改该文件来进行语法定义的,不过很可惜当时没有写文章记录修改过程。

运行do while语法,但vscode显然不太认识

不过我的GitHub仓库还留有修改后的源代码(同样的,当时还没有ChatGPT,属于纯手工匠心代码了):我称其为knpython,具体的语法定义如下:

1
2
3
4
5
6
7
8
9
10
11
# do while 语法规则
do_while_stmt[stmt_ty]:
    | invalid_do_stmt
    | 'do' ':' b=block 'while' a=named_expression NEWLINE c=[else_block] { _PyAST_DoWhile(a, b, c, EXTRA) }

# do while 错误检测子规则
invalid_do_stmt:
    | 'do' ':' block 'while' named_expression NEWLINE { RAISE_SYNTAX_ERROR("expected NEWLINE after 'while' condition") }
    | a='do' ':' NEWLINE !INDENT {
        RAISE_INDENTATION_ERROR("expected an indented block after 'do' statement on line %d", a->lineno) }
    | 'do' block 'while' named_expression { RAISE_SYNTAX_ERROR("expected ':' after 'do'") }

具体的内容就不展开解释了。

实际上,不管是LL(1)还是PEG,实际都是一种语法类别,用于对代码语法进行描述,正如上面的 do while 描述代码正是PEG的描述语法,最终这段语法会被生成实际解析的代码。

不过在这篇文章中,我们不会用到PEG(因为实现起来太复杂了,没必要,后续可能单独写文章来实现),我们选择手动实现一个递归下降解析器

递归下降解析器

与 LL(1) 和 PEG 不同,递归下降解析器不是语法类型,而是一种解析器实现策略。递归下降解析器将每个语法规则写成一个函数,函数之间递归调用,每个函数试图匹配输入的一部分,如果匹配失败就立即报错。LL(1) 和 PEG12 都可以通过递归下降解析器来实现。

正如其名,递归下降解析器使用递归13的方法来对语句进行处理,在本文实现的解析器中,递归调用处理流程如下:

1
2
3
解析逻辑表达式(||, &&) -> 解析比较表达式(>, <, >=, <=, ==, !=) ->
 解析加减法表达式(+, -) -> 解析乘除法表达式(*, /) ->
 解析路径表达式(obj.key 或 obj[index]) -> 解析基本表达式(变量、脚本调用、字面量等)

不过由于是递归调用,所以实际上处理流程和上面的箭头是相反的,也就是说是先处理基本表达式,最后处理逻辑表达式。同时在这个调用链的流程中也体现出了整个语句的处理优先级,越往里的越优先处理,比如在处理路径前先解析变量和脚本调用,在处理加减法前先处理乘除法等等。

首先是一些基本操作,比如消耗当前Token、预览下一个Token、检查Token类型等工具方法,实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class Parser:
    """
    递归下降解析器,将Token序列解析为抽象语法树
    """

    def __init__(self, tokens: Iterator[Token]):
        self.tokens = [token for token in tokens if token.type != TokenType.WHITESPACE]  # 过滤掉所有空白符
        self.current = 0
        self.current_token = self.tokens[0] if self.tokens else None

    def parse(self) -> ASTNode:
        """
        解析入口,解析完整表达式并返回AST根节点
        """

        if not self.tokens:
            raise SyntaxError("没有可解析的令牌")
       
        # 常规解析
        result = self.expr()
       
        # 确保所有令牌都已解析(除了END)
        if self.current_token and self.current_token.type != TokenType.END:
            self.error(f"解析结束后仍有未处理的令牌: {self.current_token}")
           
        return result

    def error(self, message: str):
        """抛出语法错误异常"""
        line = self.current_token.line if self.current_token else "未知"
        column = self.current_token.column if self.current_token else "未知"
        raise SyntaxError(f"{message},在行 {line} 列 {column}")

    def advance(self, offset: int = 1):
        """前进到下一个令牌"""
        self.current += offset
        if self.current < len(self.tokens):
            self.current_token = self.tokens[self.current]
        else:
            self.current_token = None

    def peek(self, offset: int = 1) -> Optional[Token]:
        """预览后面的令牌,不消耗当前令牌"""
        pos = self.current + offset
        if 0 <= pos < len(self.tokens):
            return self.tokens[pos]
        return None

    def match(self, *types: TokenType) -> bool:
        """检查当前令牌是否匹配指定类型之一"""
        if self.current_token and self.current_token.type in types:
            self.advance()
            return True
        return False

    def expect(self, *types: TokenType) -> Token:
        """期望当前令牌是指定类型之一,否则报错"""
        if self.current_token and self.current_token.type in types:
            token = self.current_token
            self.advance()
            return token
       
        expected = " 或 ".join(t.literal for t in types)
        got = self.current_token.type.literal if self.current_token else "EOF"
        self.error(f"期望 {expected},但得到了 {got}")

在这段代码中使用到了一个类型 ASTNode,这是专门用于描述AST节点的类,不同节点有不同的结构和参数。下面的代码是所有AST类,包括标识符、数字、字符串、函数、变量等,都是一些结构定义,没有什么逻辑代码,就不展开说明了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class ASTNode:
    """
    抽象语法树节点基类
    """
   
    def __init__(self, type_: str, line: Optional[int] = None, column: Optional[int] = None) -> None:
        self.type: str = type_
        self.line: Optional[int] = line
        self.column: Optional[int] = column

    def __repr__(self) -> str:
        return f"{self.type}({self.__dict__})"

class NameNode(ASTNode):
    """
    标识符节点
    """

    def __init__(self, name: str, line: Optional[int] = None, column: Optional[int] = None) -> None:
        super().__init__(self.__class__.__name__, line, column)
        self.name: str = name

class NumberNode(ASTNode):
    """
    数字节点
    """

    def __init__(self, value: str, line: Optional[int] = None, column: Optional[int] = None) -> None:
        super().__init__(self.__class__.__name__, line, column)
        # 处理带负号的数字字符串
        self.value: Union[int, float] = float(value) if '.' in value or 'e' in value.lower() else int(value)
       
class StringNode(ASTNode):
    """
    字符串节点
    """

    def __init__(self, value: str, line: Optional[int] = None, column: Optional[int] = None) -> None:
        super().__init__(self.__class__.__name__, line, column)
        self.value: str = value
       


class VarRefNode(ASTNode):
    """
    变量引用节点
    """

    def __init__(self, name: NameNode, line: Optional[int] = None, column: Optional[int] = None) -> None:
        super().__init__(self.__class__.__name__, line, column)
        self.name: NameNode = name

class ScriptCallNode(ASTNode):
    """
    脚本调用节点
    """

    def __init__(self, module: NameNode, name: NameNode, args: List[ASTNode], kwargs: Dict[str, ASTNode], line: Optional[int] = None, column: Optional[int] = None) -> None:
        super().__init__(self.__class__.__name__, line, column)
        self.module: List[NameNode] = module
        self.name: NameNode = name
        self.args: List[ASTNode] = args
        self.kwargs: Dict[str, ASTNode] = kwargs

基本表达式解析

接下来是处理语句的逻辑,我们从内向外实现,首先是基本表达式的解析:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
class Parser:
    ...
   
    def primary(self) -> ASTNode:
        """解析基本表达式(变量、脚本调用、字面量等)"""
        if not self.current_token:
            self.error("意外结束,期望表达式")
       
        token = self.current_token
       
        # 检查是否是单独的通配符(*)
        if token.type == TokenType.OP and token.value == '*':
            line, column = token.line, token.column
            self.advance()
            # 创建一个特殊的通配符节点
            # 这里使用一个空的名字,不依赖于特定键名
            root_node = NameNode('', line=line, column=column)
            # 添加一个标记,表示这是根级别的通配符
            root_node._is_root_wildcard = True
            return KeyNode(root_node, '*', is_wildcard=True, line=line, column=column)
       
        # 检查负号开头的负数
        if token.type == TokenType.OP and token.value == '-':
            line, column = token.line, token.column
            self.advance()
           
            # 确保负号后面跟着数字
            if self.current_token and self.current_token.type == TokenType.NUMBER:
                number_token = self.current_token
                value = "-" + number_token.value
                self.advance()
                return NumberNode(value, line, column)
            else:
                # 如果不是负数,回退到上一个token(负号),让正常的二元运算处理
                self.current -= 1
                self.current_token = token
       
        # 变量引用: $name
        if token.type == TokenType.VARSIGN:
            line, column = token.line, token.column
            self.advance()
            name_token = self.expect(TokenType.NAME)
            name_node = NameNode(name_token.value, name_token.line, name_token.column)
            return VarRefNode(name_node, line, column)
           
        # 脚本调用: @func(args)
        elif token.type == TokenType.SCRIPTSIGN:
            line, column = token.line, token.column
            self.advance()

            module_node = []
            name_node = None
           
            # 判断是否是"模块.函数"的形式
            while not (self.current_token and self.current_token.type == TokenType.LPAREN):
                next_token = self.peek()
                if next_token and next_token.type == TokenType.DOT:
                    module_node.append(NameNode(self.current_token.value, self.current_token.line, self.current_token.column))
                    # 前进两步,跳过"."
                    self.advance(2)
                else:
                    # 路径最后一级走此分支
                    name_node = NameNode(self.current_token.value, self.current_token.line, self.current_token.column)
                    self.advance()
                    break
           
            if name_node is None:
                self.error(f"脚本解析失败: 未找到函数名")
           
            # 解析参数列表
            args = []
            kwargs = {}
            if self.current_token and self.current_token.type == TokenType.LPAREN:
                self.advance()
               
                # 解析参数
                if self.current_token and self.current_token.type != TokenType.RPAREN:
                    # 解析第一个参数
                    if self.current_token and self.current_token.type == TokenType.NAME:
                        # 关键词参数
                        key_node = self.expr()
                        self.expect(TokenType.ASSIGN)
                        kwargs[key_node] = self.expr()
                    else:
                        # 位置参数
                        args.append(self.expr())
                   
                    # 解析逗号分隔的后续参数
                    while self.current_token and self.current_token.type == TokenType.COMMA:
                        self.advance()
                        if self.current_token and self.current_token.type == TokenType.NAME:
                            # 关键词参数
                            key_node = self.expr()
                            self.expect(TokenType.ASSIGN)
                            kwargs[key_node] = self.expr()
                        else:
                            # 位置参数
                            args.append(self.expr())
                           
                # 确保参数列表以右括号结束
                self.expect(TokenType.RPAREN)
           
            return ScriptCallNode(module_node, name_node, args, kwargs, line, column)
           
        elif token.type == TokenType.NAME:
            # 标识符或键名: name
            name = token.value
            line, column = token.line, token.column
            self.advance()
            return NameNode(name, line, column)
           
        elif token.type == TokenType.NUMBER:
            # 数字字面量: 123, 45.67
            value = token.value
            line, column = token.line, token.column
            self.advance()
           
            return NumberNode(value, line, column)
           
        elif token.type == TokenType.STRING:
            # 字符串字面量: "text", 'text'
            value = self._parse_string_literal(token.value)
            line, column = token.line, token.column
            self.advance()
            return StringNode(value, line, column)
           
        elif token.type == TokenType.LPAREN:
            # 括号表达式: (expr)
            self.advance()
            # 取括号内的内容,按照一段完整语句进行递归处理
            expr = self.expr()
            self.expect(TokenType.RPAREN)
            return expr
           
        else:
            self.error(f"意外的令牌类型: {token.type.literal}")

    def _parse_string_literal(self, string_literal: str) -> str:
        """解析字符串字面量,去除引号并处理转义字符"""
        # 去除开头和结尾的引号
        inside = string_literal[1:-1]
       
        # 简单处理转义字符
        result = ""
        i = 0
        while i < len(inside):
            if inside[i] == '\' and i + 1 < len(inside):
                # 处理转义字符
                if inside[i+1] in ('
"', "'", '\'):
                    result += inside[i+1]
                    i += 2
                else:
                    result += inside[i:i+2]
                    i += 2
            else:
                result += inside[i]
                i += 1
               
        return result

在这段代码中,首先会判断当前token的类型,并按照每种类型进行分别处理,并返回对应的最终的节点类型。比如单纯的字面量,当判断token的类型为 TokenType.NUMBER 或 TokenType.STRING 时,就会直接返回对应的 NumberNode 和 StringNode;如果是括号表达式,就会让括号中的语句进行下一层递归处理;如果是操作符,则会先判断是否是通配符("*")或者负号("-"),对于通配符,会创建一个没有名称(NameNode的name为空)的 KeyNode,并且对节点标记 node._is_root_wildcard = True。之所以需要标记 _is_root_wildcard,主要是为了与路径中的通配符进行区分(路径中的通配符会在后面的路径解析中处理),在这个模块中,通配符用于占位根节点或在路径中获取所有数据,比如上面提到的例子 "*.['root']['root_key']" 这里就是对这种情况的特殊处理。

这一步处理中,需要展开说明的是变量和函数的解析,最主要的是函数处理。在这个模块中,变量是以$开头的Name Token,其中$符在进行tokenization时会被识别为 TokenType.VARSIGN,在进行语法解析时,检测到 VARSIGN 就会对后面的 NameToken 进行处理,VARSIGN Token 和后续的 Name Token 会被定义为 VarRefNode。同样脚本(函数)也会进行类似处理,只是会多一步变量解析。在解析脚本时,会先判断脚本是否是”模块.函数()“的格式14。如果是这种格式,则会先把所有的模块分离出来,最后一级处理成函数名,最后在创建脚本调用节点 ScriptCallNode 时会分别传入模块和函数名。同样变量也需要分两种类型处理,一种是仅包含参数值,比如”func(15)“;一种是包含形参名的变量,比如”func(num=15)“,检测到左括号 TokenType.LPAREN 时,会进入参数处理的状态,然后判断 Token 类型,如果是 Name Token 则将其作为带形参名的变量处理,否则直接作为普通参数处理。完成一个参数处理后判断后续是逗号 TokenType.COMMA 还是右括号 TokenType.RPAREN,如果是逗号则继续处理下一个参数,否则完成参数的处理。

路径解析

完成基本表达式的解析后,进入路径解析,并将基本表达式创建的节点作为left(语法树的左分支)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
class Parser:
    ...

    def path(self) -> ASTNode:
        """解析路径表达式 (obj.key 或 obj[index])"""
        left = self.primary()
       
        def _parse_slice_parts() -> tuple:
            """
            解析切片的end和step部分
           
            返回:
                tuple: (end, step) - 切片的end和step部分
            """

            # 检查是否是第二个冒号,如 [start?::step]
            if self.current_token and self.current_token.type == TokenType.COLON:
                # 是 [start?::step] 形式,end为None
                end = None
                self.advance()  # 跳过第二个冒号
                # 解析step,如果有的话
                step = self.expr() if self.current_token and self.current_token.type != TokenType.RBRACK else None
            else:
                # [start?:end] 形式
                end = self.expr() if self.current_token and self.current_token.type != TokenType.RBRACK else None
               
                # 检查是否有第二个冒号,如 [start?:end?:]
                if self.current_token and self.current_token.type == TokenType.COLON:
                    self.advance()  # 跳过第二个冒号
                    step = self.expr() if self.current_token and self.current_token.type != TokenType.RBRACK else None
                else:
                    step = None
           
            return end, step
       
        while self.current_token:
            if self.current_token.type == TokenType.DOT:
                # 字典键访问: obj.key
                self.advance()
               
                # 检查通配符 obj.*
                if self.current_token and self.current_token.type == TokenType.OP and self.current_token.value == '*':
                    line, column = self.current_token.line, self.current_token.column
                    self.advance()
                    left = KeyNode(left, '*', is_wildcard=True, line=line, column=column)
                # 键名可能是NAME或STRING
                elif self.current_token.type == TokenType.NAME:
                    key_name = self.current_token.value
                    line, column = self.current_token.line, self.current_token.column
                    self.advance()
                    left = KeyNode(left, key_name, line=line, column=column)
                elif self.current_token.type == TokenType.STRING:
                    key_name = self._parse_string_literal(self.current_token.value)
                    line, column = self.current_token.line, self.current_token.column
                    self.advance()
                    left = KeyNode(left, key_name, line=line, column=column)
                else:
                    self.error(f"键访问后期望标识符或字符串,但得到了 {self.current_token.type.literal}")
           
            elif self.current_token.type == TokenType.LBRACK:
                # 索引访问: obj[index]
                line, column = self.current_token.line, self.current_token.column
                self.advance()
               
                # 检查是否是通配符索引 obj[*]
                if self.current_token and self.current_token.type == TokenType.OP and self.current_token.value == '*':
                    wildcard_token = self.current_token
                    self.advance()
                    # 确保通配符后面是右括号
                    self.expect(TokenType.RBRACK)
                    # 创建一个特殊的StringNode表示通配符
                    wildcard_node = StringNode('*', wildcard_token.line, wildcard_token.column)
                    left = IndexNode(left, wildcard_node, line, column)
                else:
                    # 检查是否是切片 obj[start:end:step]
                    if self.current_token and self.current_token.type == TokenType.COLON:
                        # 是 [:end:step] 形式,start为None
                        start = None
                        self.advance()  # 跳过第一个冒号
                        end, step = _parse_slice_parts()
                        self.expect(TokenType.RBRACK)
                        left = SliceNode(left, start, end, step, line, column)
                    else:
                        # 先解析第一个表达式
                        index_expr = self.expr()
                       
                        # 检查是否有冒号,表示这是切片的开始
                        if self.current_token and self.current_token.type == TokenType.COLON:
                            # 是 [start:end:step] 形式
                            start = index_expr
                            self.advance()  # 跳过冒号
                            end, step = _parse_slice_parts()
                            self.expect(TokenType.RBRACK)
                            left = SliceNode(left, start, end, step, line, column)
                        else:
                            # 普通索引访问 obj[index]
                            self.expect(TokenType.RBRACK)
                            left = IndexNode(left, index_expr, line, column)
           
            else:
                # 不是路径访问操作,跳出循环
                break
               
        return left

对于路径的解析,一共有两种情况:一是点操作符,二是方括号。对于点操作符,首先会检查是否是通配符"*",这里的通配符和根节点的通配符功能相同,处理方式也大致一样,之所以分两步处理只是因为通配符所处未知不同,在根目录的通配符需要进行一些路径兼容处理。随后判断是否是Name或者String,分别对应 "path.'to'.data",其中path和data是Name,'to'是String,分两种判断只是对路径中字符串进行一种兼容。而两种类型的Token的处理方式也完全一致,只是String会调用_parse_string_literal来处理掉前后的引号和内部转义符。最后全部处理成 keyNode 作为键节点返回。

对于方括号的处理,方括号中可能有两种情况,一是索引节点 IndexNode,二是切片节点 SliceNode。索引节点不一定单纯指数字索引,也可能包含表达式等,所有表达式节点都会称为 IndexNode 的子节点,后面可以通过分析AST结构来进一步帮助理解,目前先理解成”除切片外所有方括号都算IndexNode“即可。

代码占比比较大的部分是对切片的处理,不过虽然看起来很复杂,实际上也就是暴力涵盖多种情况而已,这还得感谢Python支持切片省略的写法比如 [1::2]、[::3]、[:-1] 等等各种特殊情况都得考虑到。

二元运算

接下来是加减乘除四则运算以及比较表达式和逻辑表达式。按照运算法则,乘除法作为高级运算其优先级应该高于加减法,完成数值运算后进行比较运算,最后进行逻辑运算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class Parser:
    ...

    def expr(self) -> ASTNode:
        """解析逻辑表达式 (||, &&)"""
        left = self.comparison()
       
        while (self.current_token and self.current_token.type == TokenType.OP and
               self.current_token.value in (Operator.LOGICAL_OR.value, Operator.LOGICAL_AND.value)):
            # 将操作符字符串转换为Operator枚举
            op_value = self.current_token.value
            op = Operator.LOGICAL_OR if op_value == Operator.LOGICAL_OR.value else Operator.LOGICAL_AND
            line, column = self.current_token.line, self.current_token.column
            self.advance()
            right = self.comparison()
            left = BinaryOpNode(left, op, right, line, column)
           
        return left

    def comparison(self) -> ASTNode:
        """解析比较表达式 (>, <, >=, <=, ==, !=)"""
        left = self.addition()
       
        while (self.current_token and self.current_token.type == TokenType.OP and
               self.current_token.value in (
                   Operator.GREATER_THAN.value,
                   Operator.LESS_THAN.value,
                   Operator.GREATER_EQUAL.value,
                   Operator.LESS_EQUAL.value,
                   Operator.EQUAL.value,
                   Operator.NOT_EQUAL.value
               )):
            # 将操作符字符串转换为Operator枚举
            op_value = self.current_token.value
            op = next(op for op in Operator if op.value == op_value)
            line, column = self.current_token.line, self.current_token.column
            self.advance()
            right = self.addition()
            left = BinaryOpNode(left, op, right, line, column)
           
        return left

    def addition(self) -> ASTNode:
        """解析加减法表达式 (+, -)"""
        left = self.multiplication()
       
        while (self.current_token and self.current_token.type == TokenType.OP and
               self.current_token.value in (Operator.PLUS.value, Operator.MINUS.value)):
            # 将操作符字符串转换为Operator枚举
            op_value = self.current_token.value
            op = Operator.PLUS if op_value == Operator.PLUS.value else Operator.MINUS
            line, column = self.current_token.line, self.current_token.column
            self.advance()
            right = self.multiplication()
            left = BinaryOpNode(left, op, right, line, column)
           
        return left

    def multiplication(self) -> ASTNode:
        """解析乘除法表达式 (*, /)"""
        left = self.path()
       
        while (self.current_token and self.current_token.type == TokenType.OP and
               self.current_token.value in (Operator.MULTIPLY.value, Operator.DIVIDE.value)):
            # 将操作符字符串转换为Operator枚举
            op_value = self.current_token.value
            op = Operator.MULTIPLY if op_value == Operator.MULTIPLY.value else Operator.DIVIDE
            line, column = self.current_token.line, self.current_token.column
            self.advance()
            right = self.path()
            left = BinaryOpNode(left, op, right, line, column)
           
        return left

这四个函数操作几乎完全一致,都是建立一个 BinaryOpNode ,左边是当前待运算的值,右边是递归结果,op则是具体的操作符。这四个函数在构造AST的阶段都没有什么复杂操作,就不展开介绍了。

最后消耗完所有Token后会检测END Token,如果不是END Token则会报令牌处理未完成的错误。

为了方便查看AST的结构,我们实现一个类似 Python 中 ast.dump() 功能的方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class Parser:
    ...

    @staticmethod
    def dump_ast(node, annotate_fields=True, indent=None, level=0):
        """
        将AST节点转换为字符串表示
       
        Args:
            node: 要转换的AST节点
            annotate_fields: 是否注释字段
            indent: 缩进级别
            level: 当前节点层级
        """

        def is_ast_node(obj):
            return hasattr(obj, '__class__') and hasattr(obj, '__dict__') and isinstance(obj, ASTNode)

        def format_node(node, level):
            pad = ' ' * (indent * level) if indent else ''
            next_pad = ' ' * (indent * (level + 1)) if indent else ''
            cls_name = node.__class__.__name__
            fields = [(k, v) for k, v in node.__dict__.items() if not k.startswith('_')]
            if not fields:
                return f"{cls_name}()"
           
            # 根据indent决定分隔符
            sep = '' if indent is None else '\n'
           
            lines = [f"{cls_name}("]
            for i, (k, v) in enumerate(fields):
                if isinstance(v, list):
                    if not v:
                        value_str = '[]'
                    else:
                        list_sep = ', ' if indent is None else ',\n'
                        list_pad = '' if indent is None else next_pad + (' ' * indent)
                        value_str = list_sep.join(
                            list_pad + (format_node(item, level + 2) if is_ast_node(item) else repr(item))
                            for item in v
                        )
                        if indent is not None:
                            value_str = '[\n' + value_str + f'\n{next_pad}]'
                        else:
                            value_str = '[' + value_str + ']'
                elif isinstance(v, dict):
                    if not v:
                        value_str = '{}'
                    else:
                        value_str = list_sep.join(
                            list_pad + (format_node(k, level + 2) if is_ast_node(k) else repr(k)) + ': ' + (format_node(v, level + 2) if is_ast_node(v) else repr(v))
                            for k, v in v.items()
                        )
                        if indent is not None:
                            value_str = '{\n' + value_str + f'\n{next_pad}'+ '}'
                        else:
                            value_str = '{' + value_str + '}'
                elif is_ast_node(v):
                    value_str = format_node(v, level + 1)
                else:
                    value_str = repr(v)
                   
                if annotate_fields:
                    lines.append(f"{next_pad}{k}={value_str},")
                else:
                    lines.append(f"{next_pad}{value_str},")
                   
            lines[-1] = lines[-1].rstrip(',')
            lines.append(f"{pad})")
            return sep.join(lines)
        return format_node(node, level)
1
path = ".datas['items'].length['var' != $snum || ("id" >= 3 * 5 && 'age'< @max(10, num2=10 + 5))].n[:15:2]"

还是以这段语句为例,将其使用 Parser 解析成AST,再使用 dump_ast 格式化输出如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
SliceNode(
    type='SliceNode',
    line=1,
    column=92,
    obj=KeyNode(
        type='KeyNode',
        line=1,
        column=91,
        obj=IndexNode(
            type='IndexNode',
            line=1,
            column=23,
            obj=KeyNode(
                type='KeyNode',
                line=1,
                column=22,
                obj=IndexNode(
                    type='IndexNode',
                    line=1,
                    column=7,
                    obj=KeyNode(
                        type='KeyNode',
                        line=1,
                        column=6,
                        obj=KeyNode(
                            type='KeyNode',
                            line=1,
                            column=0,
                            obj=NameNode(
                                type='NameNode',
                                line=1,
                                column=0,
                                name=''
                            ),
                            key='*',
                            is_wildcard=True
                        ),
                        key='datas',
                        is_wildcard=False
                    ),
                    index=StringNode(
                        type='StringNode',
                        line=1,
                        column=14,
                        value='items'
                    )
                ),
                key='length',
                is_wildcard=False
            ),
            index=BinaryOpNode(
                type='BinaryOpNode',
                line=1,
                column=40,
                left=BinaryOpNode(
                    type='BinaryOpNode',
                    line=1,
                    column=31,
                    left=StringNode(
                        type='StringNode',
                        line=1,
                        column=28,
                        value='var'
                    ),
                    op=Operator(value=!=),
                    right=VarRefNode(
                        type='VarRefNode',
                        line=1,
                        column=33,
                        name=NameNode(
                            type='NameNode',
                            line=1,
                            column=37,
                            name='snum'
                        )
                    )
                ),
                op=Operator(value=||),
                right=BinaryOpNode(
                    type='BinaryOpNode',
                    line=1,
                    column=58,
                    left=BinaryOpNode(
                        type='BinaryOpNode',
                        line=1,
                        column=49,
                        left=StringNode(
                            type='StringNode',
                            line=1,
                            column=46,
                            value='id'
                        ),
                        op=Operator(value=>=),
                        right=BinaryOpNode(
                            type='BinaryOpNode',
                            line=1,
                            column=53,
                            left=NumberNode(
                                type='NumberNode',
                                line=1,
                                column=51,
                                value=3
                            ),
                            op=Operator(value=*),
                            right=NumberNode(
                                type='NumberNode',
                                line=1,
                                column=55,
                                value=5
                            )
                        )
                    ),
                    op=Operator(value=&&),
                    right=BinaryOpNode(
                        type='BinaryOpNode',
                        line=1,
                        column=65,
                        left=StringNode(
                            type='StringNode',
                            line=1,
                            column=64,
                            value='age'
                        ),
                        op=Operator(value=<),
                        right=ScriptCallNode(
                            type='ScriptCallNode',
                            line=1,
                            column=67,
                            module=[],
                            name=NameNode(
                                type='NameNode',
                                line=1,
                                column=70,
                                name='max'
                            ),
                            args=[
                                NumberNode(
                                    type='NumberNode',
                                    line=1,
                                    column=73,
                                    value=10
                                )
                            ],
                            kwargs={
                                NameNode(
                                    type='NameNode',
                                    line=1,
                                    column=79,
                                    name='num2'
                                ): BinaryOpNode(
                                    type='BinaryOpNode',
                                    line=1,
                                    column=84,
                                    left=NumberNode(
                                        type='NumberNode',
                                        line=1,
                                        column=82,
                                        value=10
                                    ),
                                    op=Operator(value=+),
                                    right=NumberNode(
                                        type='NumberNode',
                                        line=1,
                                        column=86,
                                        value=5
                                    )
                                )
                            }
                        )
                    )
                )
            )
        ),
        key='n',
        is_wildcard=False
    ),
    start=None,
    end=NumberNode(
        type='NumberNode',
        line=1,
        column=95,
        value=15
    ),
    step=NumberNode(
        type='NumberNode',
        line=1,
        column=97,
        value=2
    )
)

可以看到,该结构与语句的结构相反,语句中最前面的元素(点操作符,会被转换成通配符,上面有提到过)反而在结构的最里层,而结构的最外层 SliceNode 切片节点则是语句的最后一个元素,这样也方便我们在运算时以从后往前的顺序进行解析。如果觉得AST树太长看起来不方便,可以结合结构中的 column(列号)辅助定位。

AST执行

接下来就是最后一步,也是最重要的一步,我们需要执行AST,来实现获取数据的功能。

模块使用访问者模式15来设计,使用这种结构可以更好地让节点的定义和逻辑操作解耦,防止在加入过多功能后节点变得更为臃肿。使用访问者模式,AST节点就只负责进行数据表示,而所有的操作将由访问者来实现。

首先在节点基类中增加通用的访问者入口,使用这种方式无需向每个节点实现访问者入口方法,在访问者调用accept时会自动生成访问方法名。

1
2
3
4
5
6
7
class ASTNode:
    ...

    def accept(self, visitor):
        method_name = f'visit_{self.__class__.__name__}'
        method = getattr(visitor, method_name, visitor.generic_visit)
        return method(self)

接下来是访问者基类,所有访问者类继承于这个类,不过目前访问者还只有 Evaluator 执行器这一个类,但是如果后续如果要增加类型检测或代码优化等正对于某个节点功能的,集成该类实现对应的访问者即可。

1
2
3
4
5
6
7
8
9
10
class ASTVisitor:
    """AST访问者基类"""
   
    def generic_visit(self, node):
        """默认访问方法"""
        raise NotImplementedError(f"未实现节点类型 {node.__class__.__name__} 的访问方法")
       
    def visit(self, node):
        """访问节点入口方法"""
        return node.accept(self)

接下来就是具体的执行器,执行器中每个节点都有一个visit方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Evaluator(ASTVisitor):
    """
    执行器,用于执行AST节点
    """

    def __init__(self, data):
        self.data = data
        self.context = {}

    def query(self, ast_root: ASTNode):
        """查询入口方法"""
        # 标记当前是根查询
        self.context['is_root_query'] = True
        result = self.visit(ast_root)
        return result

接下来就分别来介绍这些方法。

NameNode、NumberNode、StringNode

这三种节点作为字面量节点,处理逻辑比较简单,大多数情况直接返回值即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class Evaluator(ASTVisitor):
    ...

    def visit_NameNode(self, node: NameNode):
        if self.context.get('get_literal', False):
            # 脚本和变量操作时直接获取字面量名称
            self.context['get_literal'] = False
            return node.name
       
        # 如果是根查询,从self.data中获取对应键的值
        if self.context.get('is_root_query', False):
            name = node.name
            # 取消根查询标记,避免后续节点也被当作根查询
            self.context['is_root_query'] = False
           
            # 如果是根级别的通配符,返回整个数据
            if getattr(node, '_is_root_wildcard', False):
                return self.data
               
            # 从数据中获取对应的值
            if isinstance(self.data, dict) and name in self.data:
                return self.data[name]
            return None
       
        if 'current_item' in self.context:
            raise NameError(f"名称 '{node.name}' 未定义,位于 {node.line} 行 {node.column} 列")
       
        # 非根查询,直接返回节点名称
        return node.name
       
    def visit_NumberNode(self, node: NumberNode):
        return node.value
       
    def visit_StringNode(self, node: StringNode):
        """
        访问字符串节点
        当在条件过滤上下文中时,尝试从当前项中获取对应键的值
        """

        value = node.value
       
        # 检查是否在条件过滤上下文中
        if 'current_item' in self.context:
            current_item = self.context['current_item']
            # 如果当前项是字典且包含该键
            if isinstance(current_item, dict) and value in current_item:
                return current_item[value]
       
        # 普通字符串
        return value

不过还是有一些特殊情况需要单独处理的。比如在NameNode中,会先检查上下文中 get_literal 是否为True,当变量和脚本的NameNode调用visit方法时,该值会被置为True,此时只需要直接获取NameNode的字面量值作为参数名或脚本名即可。

如果 get_literal 不是True,则会检测是否是根查询,因为正常情况下,根一定是NameNode16,即便是以点操作符开头,也会在tokenization时被处理成"*.xxx"的格式。在检测到根查询,且查询值为通配符时,就会直接返回整个数据,否则从数据中获取对应key为name的值。

如果不是根查询,首先会检测 self.context 中是否有 current_item,context用于保存上下文,如果遇到方括号的表达式,会将表达式评估的对象(被过滤的对象)保存到context中,在 NameNode 中如果 context 有 current_item,说明该NameNode在方括号中,而目前语法设计中,NameNode不能出现在方括号中17。NumberNode 没有特殊处理,直接返回具体值。StringNone 会先判断是否包含方括号中,如果是则作为键名进行数据访问18,否则直接返回字符串值。

VarRefNode、ScriptCallNode

变量引用和脚本调用都会用到脚本管理器,所以就一起讲了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class Evaluator(ASTVisitor):
    ...

    def visit_VarRefNode(self, node: VarRefNode):
        self.context['get_literal'] = True
       
        var_name = self.visit(node.name)
       
        # 首先从脚本管理器中获取
        var = script_manager.get(var_name)
       
        # 如果脚本管理器中没有获取到,则从数据中获取
        if not var:
            var = self.data.get(var_name)
       
        return var

    def visit_ScriptCallNode(self, node: ScriptCallNode):
        self.context['get_literal'] = True
        func_name = self.visit(node.name)
        module_path = ".".join([self.visit(module) for module in node.module])
       
        if not script_manager.check_script(name=func_name, path=module_path):
            raise ValueError(f"未定义的函数: {func_name}, 确保函数在运行前已注册")
           
           
        # 求值所有参数
        args = [self.visit(arg) for arg in node.args]
        kwargs = {self.visit(key): self.visit(value) for key, value in node.kwargs.items()}

        # 调用函数
        return script_manager.run(name=func_name, path=module_path, args=args, kwargs=kwargs)

实际上最开始设想的变量引用功能是可以直接获取到代码上下文中定义的变量,比如在执行语句前定义了 num = 15,然后在语句中通过$num进行调用。后来发现这种方式行不通,python的globals()作用域仅限于单个模块,在其他模块想要获取要么需要通过传参,要么只能通过importlib来动态导入,而动态导入会严重影响性能,所以最后决定的方案是使用脚本管理器来对变量进行管理,要使用某个变量需要先在脚本管理器中进行注册。

不过这种方法有一个缺点,就是使用这种方法跟使用 path = f"@func(1, {var})" 这种 f-string 没区别了...所以现在这个变量功能似乎没有什么太大的用处,后续如果想到更好的方法再进行优化吧。

在执行 VarRefNode 和 ScriptCallNode 时都要先取消 is_root_query 标记,否则当变量或脚本作为根节点运行时,在处理 NameNode 时会出问题。

变量节点首先会通过脚本管理器的get方法尝试获取变量,如果脚本管理器中没有获取到变量值(包括变量值为None的情况),就会将$作为整个数据根19,并且可以通过语句进行索引,比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from dictquerier import query_json, script_manager

@script_manager.register()
def echo(content):
    print(f"echo传入的值: {content}")
    return content

data = {
    "items": {
        "inner": {
            "keyname": ["data", "data2"],
            "data": {
                "key": "value"
            },
            "data2": {
                "key": "value2"
            }
        },
    }
}

path = "items.inner[@echo($items[inner].keyname[1])].key"

result = query_json(data, path)
print(f"最终结果: {result}")

# 输出:
# echo传入的值: data2
# 最终结果: value2

接下来是脚本调用,同样也需要先取消根查询标记。在进行调用前,首先会对 ScriptCallNode 中的函数名称进行处理,使用点对所有module进行连接,然后调用脚本管理器的check_script检查是否是可调用脚本,如果是可调用脚本,则将所有的普通参数处理成list,kwargs参数处理成dict,然后调用脚本管理器的run方法运行脚本。

脚本管理器

变量的定义存取和脚本的调用都依赖于脚本管理器进行。先说变量功能,目前来说变量暂时还只是对数据进行一个存储,没有任何逻辑处理,具体实现的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class ScriptManager:
    def __init__(self):
        self.scripts = {}
        self.variables = {}
       
        # 缓存数据
        self._module_cache = {}
        self._function_cache = {}
       
        # 调用状态统计
        self._stats = {
            'hits': 0,
            'misses': 0,
            'total_calls': 0,
        }

    ...

    def define(self, var_name, var_value):
        self.variables[var_name] = var_value
       
    def get(self, var_name):
        return self.variables.get(var_name)

脚本管理器中比较主要的是脚本管理部分(废话了),脚本分为注册、卸载、检查、调用三个主要操作。先说注册,使用register装饰器对函数进行标记以注册,注册的脚本会被以键值对的形式存到脚本管理器的 self.scripts 中,其中键是脚本名称,值是对应方法。注册功能的实现非常简单:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class ScriptManager:
    ...

    def register(self, name: str = None):
        """
        注册脚本

        Args:
            name (str, optional): 自定义脚本调用名,可选
        Returns:
        """

        def decorator(func):
            key = name or func.__name__
            self.scripts[key] = func
            return func
        return decorator

对应的是卸载功能:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class ScriptManager:
    ...

    def unregister(self, name: str) -> bool:
        """
        卸载已注册的脚本
       
        Args:
            name (str): 要卸载的脚本名称
        Returns:
            bool: 卸载是否成功
        """

        if name in self.scripts:
            # 从scripts字典中移除
            del self.scripts[name]
           
            # 清除相关缓存
            self.clear_specific_cache(name)
           
            return True
        return False

    def clear_specific_cache(self, name: str, path: str = None):
        """
        清理特定脚本的缓存
       
        Args:
            name (str): 脚本名
            path (str, optional): 模块路径
        """

        cache_key = self._get_cache_key(name, path)
       
        if cache_key in self._function_cache:
            del self._function_cache[cache_key]
           
        # 如果有path,尝试清理模块缓存
        if path:
            parts = path.split('.')
            module_path = parts[0]  # 只清理顶级模块
           
            if module_path in self._module_cache:
                del self._module_cache[module_path]

    def _get_cache_key(self, name: str, path: str = None) -> str:
        """生成缓存键"""
        return f"{path}:{name}" if path else name

在卸载的同时会清除掉对应的已缓存的方法。

目前脚本管理器还只支持同步函数的注册和调用,对于需要进行一些IO请求或一些耗时较长的函数来说不太友好,后续可以尝试引入异步函数注册和调用的方法,需要重新设计一下异步函数的查询调用语法。

接下来就是脚本的检查和调用部分,因为调用前会先执行检查,所以这两部分就一起讲了。除了自行注册的脚本外,脚本管理器还支持调用 Python 自带的模块和内置方法,具体调用流程是:

由AI生成的 mermaid 流程图