eac983f31d7f0f37caaf865b74cd9b9c9b413daf
[lunaix-os.git] / lunaix-os / scripts / expand.py
1 import jinja2
2 import re
3 import argparse
4 import sys
5 import json
6 from pathlib import Path
7 from abc import ABC, abstractmethod
8
class ControlAction(ABC):
    """Base class for the ``$``-prefixed control directives of a record.

    The constructor remembers the raw record and hands it to the subclass's
    ``_parse``; ``get`` returns the record (subclasses may return a parsed
    form instead).
    """

    def __init__(self, record) -> None:
        # Store the untouched record first, then let the subclass parse it.
        self.__record = record
        self._parse(record)

    @staticmethod
    def create(field, value):
        """Wrap *value* in the action class for control key *field*.

        Only ``$range`` has an action class (RangeAction); the value of any
        other control key is returned unchanged.
        """
        if field != "$range":
            return value
        return RangeAction(value)

    @abstractmethod
    def _parse(self, record):
        """Parse *record* into subclass state (called from ``__init__``)."""

    @abstractmethod
    def action(self, data, param):
        """Apply the directive; the base class defines no behaviour."""

    def get(self):
        """Return the raw record this action was constructed from."""
        return self.__record
31
class RangeAction(ControlAction):
    """``$range`` directive: a bracketed list of indices and inclusive spans.

    For example ``"[0, 3..5, 8]"`` parses to ``[0, 3, 4, 5, 8]``. ``get()``
    returns the parsed indices sorted ascending, each index appearing once.
    """

    # A component is a single index ("7") or an inclusive span ("3..5").
    # The dots must be escaped: a bare ``..`` matches any two characters and
    # accepted inputs such as "12-5" as the span 1..5. Compiled once at class
    # level so it also exists before ControlAction.__init__ calls _parse.
    _COMPONENT = re.compile(r"^(?P<index>[0-9]+)$|^(?P<start>[0-9]+)\.\.(?P<end>[0-9]+)$")

    def _parse(self, record):
        """Parse *record* (e.g. ``"[0, 2..4]"``) into a sorted index list.

        Raises:
            Exception: if *record* is not a string wrapped in ``[``/``]``, or
                a comma-separated component is neither ``N`` nor ``N..M``.
        """
        if not (isinstance(record, str) and record.startswith('[') and record.endswith(']')):
            raise Exception(f"'{record}' is not valid range expression")
        # Remove exactly one bracket on each side; str.strip('[]') would also
        # swallow extra brackets such as in "[[1]]".
        body = record[1:-1]

        indices = set()
        for component in body.split(','):
            component = component.strip()
            mo = self._COMPONENT.match(component)
            if mo is None:
                raise Exception(f"value '{component}' is not valid range component")

            if mo["index"] is not None:
                indices.add(int(mo["index"]))
            else:
                # Spans are inclusive of both ends.
                indices.update(range(int(mo["start"]), int(mo["end"]) + 1))

        # De-duplicate so overlapping components (e.g. "[1, 1..3]") do not
        # make consumers such as RangedObject emit the same index twice.
        self.__value = sorted(indices)

    def action(self, data, param):
        return super().action(data, param)

    def get(self):
        """Return the sorted, de-duplicated list of parsed indices."""
        return self.__value
64     
65
class DataObject(ABC):
    """One node of the config tree loaded from ``config.json``.

    ``_parse`` splits the keys of a record into two tables:

    * ``ctrl_field`` - ``$``-prefixed keys (wrapped by ControlAction.create)
      and ``@``-prefixed keys (wrapped by DataObject.create), stored with the
      prefix stripped;
    * ``user_field`` - every other key; dict values are wrapped by
      DataObject.create, anything else is kept as-is.

    Subclasses decide how ``expand(param)`` turns the node into plain data.
    """

    # Sentinel that lets ``create`` recognise the one-argument call
    # ``create(record)``; ``None`` cannot be used because a JSON null is a
    # legitimate record value.
    _NO_RECORD = object()

    def __init__(self, name, record):
        self.key = name
        self.user_field = {}
        self.ctrl_field = {}
        self._parse(record)

    @staticmethod
    def create(key, record=_NO_RECORD):
        """Build the node for *record* stored under *key*.

        May also be called as ``create(record)``, which uses an empty key.

        Non-dict records are returned unchanged. A dict whose key does not
        start with ``@`` and that has no ``$type`` becomes a PlainOldObject.
        Otherwise the type (``$type``, else the key) with ``@`` stripped
        selects the class: ``list`` -> RangedObject, ``foreach`` ->
        ForEachIndexObject, ``case_range_index`` -> Condition, ``data`` ->
        PlainOldObject, anything else -> RawObject. A ``$name`` entry
        overrides the node's name.
        """
        if record is DataObject._NO_RECORD:
            key, record = "", key

        if not isinstance(record, dict):
            return record

        name = key
        t = name if "$type" not in record else record['$type']

        if "$name" in record:
            name = record["$name"].strip()

        if not key.startswith('@') and "$type" not in record:
            return PlainOldObject(name, record)

        t = t.strip("@")
        if t == "list":
            return RangedObject(name, record)
        elif t == "foreach":
            return ForEachIndexObject(name, record)
        elif t == "case_range_index":
            return Condition(record)
        elif t == "data":
            return PlainOldObject(name, record)
        else:
            return RawObject(name, record)

    @abstractmethod
    def _parse(self, record):
        """Sort the entries of *record* into ``ctrl_field`` / ``user_field``."""
        for k, v in record.items():
            if k.startswith("$"):
                self.ctrl_field[k.strip("$")] = ControlAction.create(k, v)
            elif k.startswith("@"):
                self.ctrl_field[k.strip("@")] = DataObject.create(k, v)
            else:
                self.user_field[k] = DataObject.create(k, v)

    @abstractmethod
    def expand(self, param=None):
        """Return this node's fields as a plain dict.

        Plain user values are copied; nested objects, whether user-level
        dicts or ``@`` directives, are expanded with *param* and merged in.
        ``ctrl_field`` entries are merged last, so their keys win on clashes.
        Non-object control values (e.g. a parsed ``$range``) are skipped.
        """
        if param is None:
            param = {}

        obj = {}
        for k, v in self.user_field.items():
            if isinstance(v, DataObject):
                # _parse wrapped this nested dict; expand it instead of
                # leaking the wrapper object into the template data.
                obj.update(v.expand(param))
            else:
                obj[k] = v

        for f in self.ctrl_field.values():
            if isinstance(f, DataObject):
                obj.update(f.expand(param))
        return obj
121
class Condition(DataObject):
    """``case_range_index`` directive: pick a branch by the current index.

    The record must provide ``$range`` (the indices that select the ``@true``
    branch) and ``@true``; ``@else`` is optional.
    """

    def __init__(self, record):
        super().__init__("", record)

    def _parse(self, record):
        super()._parse(record)
        if "range" not in self.ctrl_field:
            raise Exception("condition body must contains valid range case")
        if "true" not in self.ctrl_field:
            raise Exception("condition body must contains 'True' handling case")

        # Stored as a frozenset: expand() runs a membership test for every
        # index of the enclosing list.
        self.__range_lst = frozenset(self.ctrl_field["range"].get())

    def expand(self, param=None):
        """Expand ``@true`` if ``param["index"]`` is in range, else ``@else``.

        Returns ``{}`` when the index is out of range and there is no
        ``@else`` branch.

        Raises:
            Exception: if *param* has no ``"index"`` entry.
        """
        if param is None:
            param = {}
        # Same explicit error as ForEachIndexObject instead of a bare KeyError.
        if "index" not in param:
            raise Exception(f"'{type(self).__name__}' require parameter 'index'.")

        if param["index"] in self.__range_lst:
            return self.ctrl_field["true"].expand(param)
        if "else" in self.ctrl_field:
            return self.ctrl_field["else"].expand(param)
        return {}
141
class ForEachIndexObject(DataObject):
    """``foreach`` directive: expand every child for the current index.

    Child results are merged into one flat dict in record order, so a later
    child wins on a key clash. Expansion requires ``param["index"]`` (for
    example as supplied by an enclosing RangedObject).
    """

    def __init__(self, name, record):
        super().__init__(name, record)
        # Reuse the children _parse() has just built instead of constructing
        # each one a second time into a separate object tree. Record order is
        # preserved so merge precedence is unchanged. ``$`` keys ($type,
        # $name, ...) are control metadata consumed by create/_parse, not
        # children, and previously crashed expand().
        self.conditions = []
        for k in record:
            if k.startswith("$"):
                continue
            if k.startswith("@"):
                self.conditions.append(self.ctrl_field[k.strip("@")])
            else:
                self.conditions.append(self.user_field[k])

    def _parse(self, record):
        super()._parse(record)

    def expand(self, param=None):
        """Merge the expansions of all children for ``param["index"]``.

        Raises:
            Exception: if *param* has no ``"index"``, or a child is a plain
                value that cannot be expanded.
        """
        if param is None:
            param = {}
        if "index" not in param:
            raise Exception(f"'{type(self).__name__}' require parameter 'index'.")

        obj = {}
        for cond in self.conditions:
            if not isinstance(cond, DataObject):
                raise Exception(f"'{type(self).__name__}' cannot expand plain value {cond!r}")
            obj.update(cond.expand(param))
        return obj
159     
class RawObject(DataObject):
    """Anonymous node whose expanded fields merge directly into its parent.

    DataObject.create falls back to this class for types it does not
    recognise (typically dict values under ``@`` keys such as ``@true``).
    The constructor is inherited unchanged from DataObject.
    """

    def _parse(self, record):
        # No type-specific parsing: the generic field split is enough.
        return super()._parse(record)

    def expand(self, param={}):
        merged = super().expand(param)
        return merged
169
class PlainOldObject(DataObject):
    """Named data node: expands to ``{key: <its expanded fields>}``.

    The constructor is inherited unchanged from DataObject.
    """

    def _parse(self, record):
        # No type-specific parsing: the generic field split is enough.
        return super()._parse(record)

    def expand(self, param={}):
        fields = super().expand(param)
        return {self.key: fields}
181
class RangedObject(DataObject):
    """``list`` directive: expand the object once per index in ``$range``.

    Expands to ``{key: [item, ...]}``. Each item holds the user fields, an
    ``"index"`` entry, and the nested directives expanded for that index.
    """

    def _parse(self, record):
        super()._parse(record)

    def expand(self, param=None):
        """Return ``{key: [...]}`` with one expanded item per range index.

        Raises:
            Exception: if the record has no ``$range``.
        """
        if "range" not in self.ctrl_field:
            raise Exception("RangedObject with ranged type must have 'range' field defined")

        base_param = {} if param is None else param
        out_lst = []
        for i in self.ctrl_field["range"].get():
            # Give children a per-item copy instead of writing "index" into
            # the caller's dict (or a shared default): a nested list used to
            # leave its last index behind for the outer item's later siblings.
            item_param = {**base_param, "index": i}
            # Expose the index as a field of each generated item.
            self.user_field["index"] = i
            out_lst.append(super().expand(item_param))

        return {self.key: out_lst}
203     
class TemplateExpander:
    """Render one architecture's Jinja templates into the project tree.

    ``<template_path>/<arch>/config.json`` provides the template data (exposed
    to templates as ``data``); ``<template_path>/<arch>/mappings`` lists one
    ``template -> destination`` pair per line, with the template relative to
    the arch directory and the destination relative to *project_path*.
    """

    def __init__(self, template_path, project_path, arch) -> None:
        self.arch = arch
        self.tbase_path = template_path.joinpath(arch)
        self.pbase_path = project_path

        self.__load_config()
        self.__load_mappings()

    def __load_config(self):
        """Expand each top-level ``config.json`` entry into ``self.data``.

        Raises:
            Exception: if ``config.json`` does not exist.
        """
        self.data = {}
        cfg_file: Path = self.tbase_path.joinpath("config.json")
        if not cfg_file.is_file():
            raise Exception(f"config not found. ({cfg_file})")

        # JSON text is UTF-8; don't depend on the platform's default encoding.
        obj = json.loads(cfg_file.read_text(encoding="utf-8"))
        for k, v in obj.items():
            node = DataObject.create(k, v)
            if isinstance(node, DataObject):
                self.data.update(node.expand())
            else:
                # Scalars and lists are template data as-is; calling
                # .expand() on them raised AttributeError.
                self.data[k] = node

    def __load_mappings(self):
        """Read ``template -> destination`` pairs from the ``mappings`` file.

        Blank lines are skipped; surrounding whitespace is trimmed.

        Raises:
            Exception: if the file is missing or a template is listed twice.
            ValueError: if a non-blank line does not contain exactly one ``->``.
        """
        self.mapping = {}
        mapping_file: Path = self.tbase_path.joinpath("mappings")
        if not mapping_file.is_file():
            raise Exception(f"mappings not found. ({mapping_file})")

        with mapping_file.open(encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                if not line.strip():
                    continue

                parts = line.split("->")
                if len(parts) != 2:
                    raise ValueError(f"malformed mapping at line {lineno}: {line.strip()!r}")
                src, dest = parts[0].strip(), parts[1].strip()

                if src in self.mapping:
                    raise Exception(f"repeating entry ({src})")

                self.mapping[src] = dest

    def render(self):
        """Render every mapped template that exists and write its output.

        Mappings whose template file is missing are skipped silently.
        """
        for k, v in self.mapping.items():
            src: Path = self.tbase_path.joinpath(k)
            dest: Path = self.pbase_path.joinpath(v)
            if not src.is_file():
                continue

            template = jinja2.Template(src.read_text(encoding="utf-8"), trim_blocks=True)
            out = template.render(data=self.data)

            dest.write_text(out, encoding="utf-8")

            print(f"rendered: {k} -> {v}")
253
254
def main():
    """Command-line entry point: render the templates of one architecture.

    Both directories default to the current working directory.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--arch", default='i386')
    here = str(Path.cwd())
    cli.add_argument("-twd", "--template_dir", default=here)
    cli.add_argument("-pwd", "--project_dir", default=here)

    opts = cli.parse_args()

    TemplateExpander(
        Path(opts.template_dir),
        Path(opts.project_dir),
        opts.arch,
    ).render()


if __name__ == "__main__":
    main()