Rewrite the lunabuild toolchain with enhanced feature (#60)
[lunaix-os.git] / lunaix-os / scripts / build-tools / lcfg2 / ast_validator.py
1 import ast
2 import inspect
3 import textwrap
4
5 from typing import Callable
6 from lib.utils import Schema, ConfigASTVisitor, SourceLogger
7 from .common import Schema, ConfigNodeError
8
class Rule:
    """A single validation rule bound to one AST node type.

    Wraps a predicate ``fn(rules, reducer, node)``; a truthy return means the
    node satisfies the rule. The predicate's docstring (if any) is used as
    the help text included in the violation warning.
    """

    def __init__(self, t, v, name, fn):
        """
        :param t:    AST node class this rule is registered for.
        :param v:    optional variant matcher exposing ``match(astn)``; a
                     falsy value means the rule applies to every node of
                     type ``t``.
        :param name: short rule name used in diagnostics.
        :param fn:   predicate called as ``fn(rules, reducer, node)``.
        """
        self.type = t
        self.__name = name
        self.__var = v
        self.__fn = fn
        # inspect.getdoc() returns None when fn has no docstring; fall back to
        # an empty help message instead of crashing on ``None.strip()``.
        doc = inspect.getdoc(fn) or ""
        self.__help_msg = textwrap.dedent(doc.strip())

    def match_variant(self, astn):
        """Return True if ``astn`` matches this rule's variant (or no variant is set)."""
        if not self.__var:
            return True
        return self.__var.match(astn)

    def invoke(self, reducer, node):
        """Run the predicate against ``node`` and warn when it fails.

        ``reducer`` must expose ``_rules`` (passed to the predicate as its
        first argument) and ``_cfgn`` (the config node the warning is
        attributed to).
        """
        if self.__fn(reducer._rules, reducer, node):
            return

        # Violations are deliberately reported as warnings rather than being
        # raised as ConfigNodeError, so validation continues past them.
        SourceLogger.warn(reducer._cfgn, node,
                          f"rule violation: {self.__name}: {self.__help_msg}")
31  
32
def rule(ast_type: type, variant: Schema, name: str):
    """Build a decorator that turns a predicate function into a ``Rule``.

    The decorated function is replaced by a ``Rule`` instance targeting
    ``ast_type``, filtered by ``variant``, and reported under ``name``.
    """
    def _make_rule(fn: Callable):
        # Capture the decorator arguments and wrap the predicate.
        built = Rule(ast_type, variant, name, fn)
        return built

    return _make_rule
37
class RuleCollection:
    """Registry of ``Rule`` objects found as attributes of this object.

    On construction, every attribute that is a ``Rule`` (for example, one
    produced by the ``@rule(...)`` decorator in a subclass) is collected and
    grouped by the AST node type it targets.
    """

    def __init__(self):
        self.__rules = {}

        def _is_rule(member):
            return isinstance(member, Rule)

        # getmembers() yields (name, value) pairs; only the value matters here.
        for _member_name, found in inspect.getmembers(self, _is_rule):
            self.__rules.setdefault(found.type, []).append(found)

    def execute(self, reducer, node):
        """Invoke each rule registered for ``type(node)`` whose variant matches.

        Lookup is by exact node type, so rules registered for a base class
        are not applied to instances of subclasses.
        """
        for candidate in self.__rules.get(type(node)) or ():
            if candidate.match_variant(node):
                candidate.invoke(reducer, node)
59
class NodeValidator(ast.NodeTransformer):
    """Walk an AST and apply a rule collection to every visited node.

    No ``visit_<Type>`` handlers are defined here, so the inherited
    ``NodeTransformer`` dispatch visits children and returns each node as-is.
    """

    def __init__(self, all_rules):
        """:param all_rules: object providing ``execute(reducer, node)``,
        such as a RuleCollection."""
        super().__init__()
        # Rule.invoke reads this back as ``reducer._rules``.
        self._rules = all_rules

    def validate(self, cfgn, astn):
        """Check ``astn`` against the rules, attributing findings to ``cfgn``."""
        # Rule.invoke reads this back as ``reducer._cfgn`` when warning.
        self._cfgn = cfgn
        self.visit(astn)

    def visit(self, node):
        """Apply the rules to ``node``, then continue the default traversal."""
        collection = self._rules
        collection.execute(self, node)
        outcome = super().visit(node)
        return outcome