import ast
-from lib.utils import Schema
+from lib.utils import Schema, SourceLogger
from .lazy import Lazy
-from .common import NodeProperty
+from .common import NodeProperty, NodeInverseDependency
+
class RewriteRule:
    """AST shape patterns used by ConfigNodeASTRewriter to recognise directives.

    Each pattern is a ``lib.utils.Schema``; the rewriter tests a node against
    one with ``pattern == node`` / ``pattern != node``.
    """

    # A call of a bare name used as a marker, e.g. `require(<expr>)` or
    # `when(<expr>)`; on a match the rewriter reads `func.id` and `args[0]`.
    # NOTE(review): `args=[ast.expr]` presumably means "exactly one positional
    # argument" under Schema's list semantics — confirm in lib.utils.
    MaybeBuiltin = Schema(
        ast.Call,
        func=Schema(ast.Name),
        args=[ast.expr])

    # One `when` clause: `<name> is <constant>` or `<name> is not <constant>`.
    # NOTE(review): `left=ast.Name` is a bare type here while `MaybeBuiltin`
    # wraps it as `Schema(ast.Name)`; presumably Schema accepts both forms —
    # confirm.
    WhenTogglerItem = Schema(
        ast.Compare,
        left=ast.Name,
        ops=[Schema.Union(ast.Is, ast.IsNot)],
        comparators=[ast.Constant])

    # The argument of `when(...)`: either a single clause or an `and`-chain
    # of clauses.
    # NOTE(review): Schema.List(...) presumably matches a list of any length
    # whose elements all match the inner pattern — confirm in lib.utils.
    WhenToggler = Schema(
        Schema.Union(
            WhenTogglerItem,
            Schema(ast.BoolOp,
                   op=ast.And,
                   values=Schema.List(WhenTogglerItem))))
class ConfigNodeASTRewriter(ast.NodeTransformer):
    """Rewrite the AST of one config node's source.

    Driven through :meth:`rewrite`, which takes an ``ast.Module`` and:

    * replaces attribute accesses that ``Lazy.from_astn`` resolves to a key
      with a ``__lzLut__[<key>]`` lookup;
    * consumes ``require(<expr>)`` statements into
      ``NodeProperty.Dependency`` for the node and drops them;
    * consumes ``when(<cond>)`` statements: each condition is validated, its
      ``is`` / ``is not`` clauses become ``==`` / ``!=`` against the
      looked-up value, and all conditions are OR-ed into one expression that
      is appended to the module as a trailing ``return``;
    * rewrites every other ``is`` operator to ``==``.

    NOTE(review): only ``is`` is rewritten globally (``visit_Is``); ``is not``
    outside ``when(...)`` is left as-is — confirm this asymmetry is intended.
    """

    def __init__(self, cfg_node):
        super().__init__()
        # Config node whose source is being rewritten.
        self.__cfg_node = cfg_node
        # Guard accumulated from every `when(...)` seen so far (an ast.expr),
        # or None while no `when(...)` directive has been encountered.
        self.__when_expr = None

    def __put_linkage(self, to_node, item):
        """Record on node `to_node` that this node refers to it via `item`.

        `item` is stored as source text (``ast.unparse``), so this must be
        called before the clause is rewritten.
        """
        node = self.__cfg_node._env.get_node(to_node)
        link = NodeProperty.Linkage[node]

        if not link:
            link = NodeInverseDependency()
            NodeProperty.Linkage[node] = link

        link.add_linkage(self.__cfg_node._name, ast.unparse(item))

    def __subscript_accessor(self, name, ctx, token):
        """Build ``__lzLut__[name]`` with context `ctx`, positioned at `token`."""
        accessor = ast.Subscript(
            value=ast.Name("__lzLut__", ctx=ast.Load()),
            slice=ast.Constant(name),
            ctx=ctx,
        )
        # Keep the replaced node's source position (the pre-refactor accessor
        # set lineno/col_offset explicitly) so diagnostics point at config
        # source instead of nowhere.
        return ast.copy_location(accessor, token)

    def __gen_accessor(self, orig):
        """Replace `orig` with a ``__lzLut__`` lookup if Lazy resolves a key."""
        key = Lazy.from_astn(self.__cfg_node, orig)
        if not key:
            return self.generic_visit(orig)
        return self.__subscript_accessor(key, orig.ctx, orig)

    def __gen_dependency(self, node):
        """AND the `require(...)` argument `node` into this node's dependency."""
        # NOTE(review): in the visible text `dep_expr` is read on the
        # right-hand side before any assignment (UnboundLocalError as written).
        # Lines of this method appear to be elided by the diff this was taken
        # from; restore them from the real file rather than trusting this body.
        cfgn = self.__cfg_node
        dep_expr = ast.BoolOp(ast.And(), [dep_expr, node])
        NodeProperty.Dependency[cfgn] = dep_expr

    def __gen_when_expr(self, node):
        """Fold one ``when(<cond>)`` condition into the accumulated guard.

        `node` must be a clause ``<name> is <const>`` / ``<name> is not
        <const>`` or an ``and``-chain of such clauses; otherwise
        ``cfgn.config_error(...)`` is raised. Each clause is rewritten in place
        to compare ``__lzLut__[<key>]`` (key from ``Lazy.from_type``) using
        ``==`` / ``!=``, and a linkage to the referenced node is recorded.
        The resulting conjunction is OR-ed with earlier ``when(...)`` guards.
        """
        cfgn = self.__cfg_node

        if RewriteRule.WhenToggler != node:
            raise cfgn.config_error(
                f"when(...): invalid expression: {ast.unparse(node)}")

        if RewriteRule.WhenTogglerItem == node:
            and_list = [node]
        else:
            and_list = list(node.values)

        for item in and_list:
            # Validate before reading `item.ops` / `item.left`: a clause of
            # the wrong shape may not have those fields at all.
            if RewriteRule.WhenTogglerItem != item:
                raise cfgn.config_error(
                    f"when(...): non-trivial subclause : {ast.unparse(item)}")

            name = Lazy.from_type(cfgn, Lazy.NodeValue, item.left.id)
            acc = self.__subscript_accessor(name, ast.Load(), item.left)

            # Must run before the clause is mutated below: the linkage stores
            # the clause's original source text.
            self.__put_linkage(item.left.id, item)

            # `is` / `is not` become `==` / `!=` on the looked-up value.
            if isinstance(item.ops[0], ast.Is):
                operator = ast.Eq()
            else:
                operator = ast.NotEq()

            item.left = acc
            item.ops = [operator]

        # The leading `True` presumably ensures the BoolOp always has at least
        # two operands (compile() rejects a BoolOp with fewer), even when the
        # condition is a single clause.
        current = ast.BoolOp(
            op=ast.And(),
            values=[ast.Constant(True), *and_list])

        expr = self.__when_expr
        if expr is not None:
            assert isinstance(expr, ast.expr)
            current = ast.BoolOp(op=ast.Or(), values=[expr, current])

        self.__when_expr = current

    def visit_Attribute(self, node):
        return self.__gen_accessor(node)

    def visit_Expr(self, node):
        """Consume `require(...)` / `when(...)` marker statements.

        Matched markers are processed and removed from the tree (``None``);
        any other expression statement is visited normally.
        """
        val = node.value
        if RewriteRule.MaybeBuiltin != val:
            return self.generic_visit(node)

        # Process marker functions
        name = val.func.id
        if name == "require":
            self.__gen_dependency(val.args[0])
        elif name == "when":
            self.__gen_when_expr(val.args[0])
        else:
            return self.generic_visit(node)
        return None

    def visit_Return(self, node):
        """Neutralise `return` statements once a `when(...)` has been seen.

        A warning is emitted because the trailing return appended by
        :meth:`rewrite` is meant to decide the result.
        NOTE(review): this is order-dependent — a `return` that appears
        before the first `when(...)` directive is kept unchanged.
        """
        if self.__when_expr is None:
            return self.generic_visit(node)

        SourceLogger.warn(self.__cfg_node, node,
                          "mixed use of `return` and `when` directive. "
                          "`when` have higher precedence than `return`. "
                          "consider remove `return` to avoid confusion")
        # Substitute `pass` instead of deleting the statement: removing the
        # only statement of a block (e.g. `if c: return x`) would leave an
        # empty body, which does not compile.
        return ast.copy_location(ast.Pass(), node)

    def visit_Is(self, node):
        """Rewrite every `is` comparison operator to `==`."""
        return ast.Eq()

    def rewrite(self, node):
        """Rewrite module `node` (mutated in place) and return it.

        If any `when(...)` directive was seen, ``return <combined guard>`` is
        appended to the module body.
        NOTE(review): a module-level `return` only compiles once the caller
        moves this body into a function — presumably it does; confirm.
        """
        assert isinstance(node, ast.Module)
        node = self.visit(node)

        expr = self.__when_expr
        if expr is not None:
            node.body.append(ast.Return(expr, lineno=0, col_offset=0))

        # Nodes synthesised during the rewrite (BoolOp/Constant/Name wrappers)
        # carry no source positions and compile() rejects such trees; fill in
        # whatever is missing from the nearest located parent.
        return ast.fix_missing_locations(node)