X-Git-Url: https://scm.lunaixsky.com/lunaix-os.git/blobdiff_plain/7804c2dae30700296c3205aaf7f546f491999bf4..b60166b327a9108b07e3069fa6568a451529ffd9:/lunaix-os/scripts/expand.py diff --git a/lunaix-os/scripts/expand.py b/lunaix-os/scripts/expand.py index eac983f..31e06da 100644 --- a/lunaix-os/scripts/expand.py +++ b/lunaix-os/scripts/expand.py @@ -1,71 +1,46 @@ import jinja2 import re import argparse -import sys +import math import json from pathlib import Path from abc import ABC, abstractmethod -class ControlAction(ABC): - def __init__(self, record) -> None: - self.__record = record - self._parse(record) - - @staticmethod - def create(field, value): - if field == "$range": - return RangeAction(value) - else: - return value - - @abstractmethod - def _parse(self, record): +class Preprocessor: + reHex = re.compile(r"^0x([0-9a-fA-F]+)$") + reGranuel = re.compile(r"^(?P[0-9]+)@(?P.+)$") + reMacroRef = re.compile(r"^\*(?P[a-zA-z0-9]+)$") + reInt = re.compile(r"^[0-9]+$") + def __init__(self) -> None: pass - @abstractmethod - def action(self, data, param): - pass - - def get(self): - return self.__record - -class RangeAction(ControlAction): - def __init__(self, record) -> None: - self.__ranged_component = re.compile(r"^(?P[0-9]+)$|^(?P[0-9]+)..(?P[0-9]+)$") - super().__init__(record) - - def _parse(self, record): - if not (record.startswith('[') and record.endswith(']')): - raise Exception(f"'{record}' is not valid range expression") - record = record.strip('[]') + @staticmethod + def expand_str(s: str, param_dict): + if Preprocessor.reInt.match(s) is not None: + return int(s) - self.__value=[] - for component in record.split(','): - component = component.strip() - mo = self.__ranged_component.match(component) - if mo is None: - raise Exception(f"value '{component}' is not valid range component") - - mo = mo.groupdict() - if mo["index"] is not None: - self.__value.append(int(mo['index'])) - else: - start = int(mo['start']) - end = int(mo['end']) - self.__value += [x for x in range(start, end + 1)] - - self.__value = sorted(self.__value) - - def action(self, data, param): - return super().action(data, param) + mo = Preprocessor.reHex.match(s) + if mo is not None: + return int(s, 16) + + mo = Preprocessor.reGranuel.match(s) + if mo is not None: + mg = mo.groupdict() + num = int(mg['num']) + granuel = param_dict["granule"][mg['g']] + return num * granuel + + mo = Preprocessor.reMacroRef.match(s) + if mo is not None: + mg = mo.groupdict() + return param_dict[mg['var']] - def get(self): - return self.__value - + return s.format_map(param_dict) class DataObject(ABC): def __init__(self, name, record): self.key = name + self._record = record self.user_field = {} self.ctrl_field = {} self._parse(record) @@ -76,8 +51,8 @@ class DataObject(ABC): @staticmethod def create(key, record): - if not isinstance(record, dict): - return record + if PrimitiveType.can_create(record): + return PrimitiveType(record) name = key t = name if "$type" not in record else record['$type'] @@ -97,27 +72,128 @@ class DataObject(ABC): return Condition(record) elif t == "data": return PlainOldObject(name, record) + elif t == "define": + return VariableDeclarationObject(record) + elif t == "memory_map": + return MemoryMapObject(record) else: return RawObject(name, record) - @abstractmethod def _parse(self, record): for k, v in record.items(): if k.startswith("$"): - self.ctrl_field[k.strip("$")] = ControlAction.create(k, v) + self.ctrl_field[k.strip("$")] = FieldType.create(k, v) elif k.startswith("@"): self.ctrl_field[k.strip("@")] = DataObject.create(k, v) else: self.user_field[k] = DataObject.create(k, v) - @abstractmethod def expand(self, param={}): - obj = { **self.user_field } + obj2 = {} for f in self.ctrl_field.values(): if not isinstance(f, DataObject): continue - obj.update(**f.expand(param)) - return obj + obj2.update(**f.expand(param)) + + obj = {} + _param = {**param, **obj2} + for k, v in self.user_field.items(): + if isinstance(v, DataObject): + obj[k] = v.expand(_param) + else: + obj[k] = v + + return {**obj, **obj2} + + +class FieldType: + def __init__(self, record) -> None: + self._record = record + self._parse(record) + + @staticmethod + def create(field, value): + if field == "$range": + return RangeType(value) + else: + return value + + @abstractmethod + def _parse(self, record): + pass + + @abstractmethod + def get(self, param): + pass + + def getraw(self): + return self.__record + +class PrimitiveType(DataObject): + def __init__(self, record) -> None: + super().__init__("", record) + + @staticmethod + def can_create(value): + return type(value) in (str, int, bool) + + def _parse(self, record): + if type(record) not in (str, int, bool): + raise Exception(f"{type(self).__name__} require primitive type input") + self.val = record + + if type(record) == str: + self.__get_fn = self.__process_str + else: + self.__get_fn = lambda x: self.val + + def __process_str(self, param): + return Preprocessor.expand_str(self.val, param) + + def expand(self, param={}): + return self.__get_fn(param) + +class RangeType(FieldType): + def __init__(self, record) -> None: + self.__ranged_component = re.compile(r"^(?P[^.]+)$|^(?P.+?)\.\.(?P.+)$") + super().__init__(record) + + def _parse(self, record): + return super()._parse(record) + + def get(self, param): + record = self._record.strip('[]') + + self.__value=[] + for component in record.split(','): + component = component.strip() + mo = self.__ranged_component.match(component) + if mo is None: + raise Exception(f"value '{component}' is not valid range component") + + mo = mo.groupdict() + if mo["index"] is not None: + self.__value.append(Preprocessor.expand_str(mo['index'], param)) + else: + start = Preprocessor.expand_str(mo['start'], param) + end = Preprocessor.expand_str(mo['end'], param) + self.__value += [x for x in range(start, end + 1)] + return self.__value + + def getraw(self): + return self._record + +class VariableDeclarationObject(DataObject): + def __init__(self, record): + super().__init__("", record) + + def _parse(self, record): + return super()._parse(record) + + def expand(self, param={}): + obj = super().expand(param) + param.update(**obj) + return {} class Condition(DataObject): def __init__(self, record): @@ -130,21 +206,145 @@ class Condition(DataObject): if "true" not in self.ctrl_field: raise Exception("condition body must contains 'True' handling case") - self.__range_lst = self.ctrl_field["range"].get() def expand(self, param={}): + self.__range_lst = self.ctrl_field["range"].get(param) if param["index"] in self.__range_lst: return self.ctrl_field["true"].expand(param) elif "else" in self.ctrl_field: return self.ctrl_field["else"].expand(param) return {} + +class ArrayObject(DataObject): + def __init__(self, record, + nested_array = False, + el_factory = lambda x: DataObject.create("", x)): + self._el_factory = el_factory + self._nested_array = nested_array + + super().__init__("", record) + + def _parse(self, record): + if not isinstance(record, list): + raise Exception(f"{type(self).__name__} require array input") + + self.content = [] + for x in record: + self.content.append(self._el_factory(x)) + + def expand(self, param={}): + result = [] + for x in self.content: + obj = x.expand(param) + if isinstance(obj, list) and not self._nested_array: + result += [*obj] + else: + result.append(obj) + + return result + +class MemoryMapObject(DataObject): + class GranuleObject(DataObject): + def __init__(self, record): + super().__init__("", record) + + def _parse(self, record): + self.__granules = {} + for k, v in record.items(): + self.__granules[k] = DataObject.create(k, v) + + def expand(self, param={}): + granules = {} + for k, v in self.__granules.items(): + val = v.expand(param) + + if not isinstance(val, int): + raise Exception("The granule definition must be either integer or int-castable string") + + granules[k] = val + + return {**granules} + + def __init__(self, record): + super().__init__("", record) + + def _parse(self, record): + for k, v in record.items(): + if k.startswith("$"): + self.ctrl_field[k.strip("$")] = FieldType.create(k, v) + elif k.startswith("@"): + self.ctrl_field[k.strip("@")] = DataObject.create(k, v) + + if "granule" in record: + self.__g = MemoryMapObject.GranuleObject(record["granule"]) + + if "regions" in record: + self.__regions = ArrayObject(record["regions"]) + + if "width" in record: + self.__width = DataObject.create("width", record["width"]) + + def __process(self, start_addr, idx, regions): + if idx >= len(regions): + raise Exception("Unbounded region definition") + + e = regions[idx] + + if "start" not in e: + ne = regions[idx + 1] + if "start" not in ne or "size" not in e: + e["start"] = start_addr + else: + self.__process(start_addr + e["size"], idx + 1, regions) + e["start"] = ne['start'] - e["size"] + + if "block" in e: + b = e["block"] - 1 + e["start"] = (e["start"] + b) & ~b + + if e["start"] < start_addr: + raise Exception(f"starting addr {hex(e['start'])} overrlapping with {hex(start_addr)}") + + start_addr = e["start"] + + if "size" not in e: + self.__process(start_addr, idx + 1, regions) + ne = regions[idx + 1] + e["size"] = ne['start'] - start_addr + + return start_addr + e["size"] + + def expand(self, param={}): + super().expand(param) + + g = self.__g.expand(param) + + param["granule"] = g + + width = self.__width.expand(param) + if not isinstance(width, int): + raise Exception("'width' attribute must be integer") + + regions = self.__regions.expand(param) + + start_addr = 0 + for i in range(len(regions)): + start_addr = self.__process(start_addr, i, regions) + + if math.log2(start_addr) > width: + raise Exception("memory size larger than speicified address width") + + return { + "granule": g, + "regions": regions + } class ForEachIndexObject(DataObject): def __init__(self, name, record): super().__init__(name, record) - self.conditions = [] + self.steps = [] for k, v in record.items(): - self.conditions.append(DataObject.create(k, v)) + self.steps.append(DataObject.create(k, v)) def _parse(self, record): super()._parse(record) @@ -153,7 +353,7 @@ class ForEachIndexObject(DataObject): if "index" not in param: raise Exception(f"'{type(self).__name__}' require parameter 'index'.") obj = {} - for cond in self.conditions: + for cond in self.steps: obj.update(**cond.expand(param)) return obj @@ -175,9 +375,7 @@ class PlainOldObject(DataObject): return super()._parse(record) def expand(self, param={}): - return { - self.key: super().expand(param) - } + return super().expand(param) class RangedObject(DataObject): def __init__(self, name, record): @@ -191,15 +389,15 @@ class RangedObject(DataObject): raise Exception("RangedObject with ranged type must have 'range' field defined") out_lst = [] - indices = self.ctrl_field["range"].get() + indices = self.ctrl_field["range"].get(param) for i in indices: param["index"] = i - self.user_field["index"] = i out_lst.append(super().expand(param)) - return { - self.key: out_lst - } + return out_lst + +def aligned(v, a): + return v & ~(a - 1) class TemplateExpander: def __init__(self, template_path, project_path, arch) -> None: @@ -207,6 +405,11 @@ class TemplateExpander: self.tbase_path = template_path.joinpath(arch) self.pbase_path = project_path + self.__helper_fns = { + "align": aligned, + "hex": lambda x: hex(x) + } + self.__load_config() self.__load_mappings() @@ -219,7 +422,7 @@ class TemplateExpander: obj = json.loads(cfg_file.read_text()) for k, v in obj.items(): o = DataObject.create(k, v).expand() - self.data.update(**o) + self.data[k] = o def __load_mappings(self): self.mapping = {} @@ -229,6 +432,11 @@ class TemplateExpander: with mapping_file.open() as f: for l in f: + l = l.strip() + + if not l: + continue + src, dest = l.split("->") src = src.strip() @@ -237,23 +445,33 @@ class TemplateExpander: self.mapping[src] = dest.strip() - def render(self): + def render(self, selected = []): for k, v in self.mapping.items(): src: Path = self.tbase_path.joinpath(k) dest: Path = self.pbase_path.joinpath(v) + if (k not in selected): + continue + if not src.is_file(): continue + + if not dest.parent.exists(): + dest.parent.mkdir(parents=True) + self.data["template"] = k template = jinja2.Template(src.read_text(), trim_blocks=True) + template.globals.update(**self.__helper_fns) out = template.render(data=self.data) dest.write_text(out) print(f"rendered: {k} -> {v}") +import pprint def main(): parser = argparse.ArgumentParser() + parser.add_argument("selects", nargs="*") parser.add_argument("--arch", default='i386') parser.add_argument("-twd", "--template_dir", default=str(Path.cwd())) parser.add_argument("-pwd", "--project_dir", default=str(Path.cwd())) @@ -261,7 +479,9 @@ def main(): args = parser.parse_args() expander = TemplateExpander(Path(args.template_dir), Path(args.project_dir), args.arch) - expander.render() + + expander.render(args.selects) + # pprint.pprint(expander.data) if __name__ == "__main__": main() \ No newline at end of file