2-setup_gdt.md (#22)
[lunaix-os.git] / lunaix-os / scripts / expand.py
1 import jinja2
2 import re
3 import argparse
4 import math
5 import json
6 from pathlib import Path
7 from abc import ABC, abstractmethod
8
9 class Preprocessor:
10     reHex = re.compile(r"^0x([0-9a-fA-F]+)$")
11     reGranuel = re.compile(r"^(?P<num>[0-9]+)@(?P<g>.+)$")
12     reMacroRef = re.compile(r"^\*(?P<var>[a-zA-z0-9]+)$")
13     reInt = re.compile(r"^[0-9]+$")
14     def __init__(self) -> None:
15         pass
16
17     @staticmethod
18     def expand_str(s: str, param_dict):
19         if Preprocessor.reInt.match(s) is not None:
20             return int(s)
21         
22         mo = Preprocessor.reHex.match(s)
23         if mo is not None:
24             return int(s, 16)
25         
26         mo = Preprocessor.reGranuel.match(s)
27         if mo is not None:
28             mg = mo.groupdict()
29             num = int(mg['num'])
30             granuel = param_dict["granule"][mg['g']]
31             return num * granuel
32         
33         mo = Preprocessor.reMacroRef.match(s)
34         if mo is not None:
35             mg = mo.groupdict()
36             return param_dict[mg['var']]
37
38         return s.format_map(param_dict)
39
40 class DataObject(ABC):
41     def __init__(self, name, record):
42         self.key = name
43         self._record = record
44         self.user_field = {}
45         self.ctrl_field = {}
46         self._parse(record)
47
48     @staticmethod
49     def create(record):
50         return DataObject.create("", record)
51
52     @staticmethod
53     def create(key, record):
54         if PrimitiveType.can_create(record):
55             return PrimitiveType(record)
56         
57         name = key
58         t = name if "$type" not in record else record['$type']
59         
60         if "$name" in record:
61             name = record["$name"].strip()
62
63         if not key.startswith('@') and "$type" not in record:
64             return PlainOldObject(name, record)
65         
66         t = t.strip("@")
67         if t == "list":
68             return RangedObject(name, record)
69         elif t == "foreach":
70             return ForEachIndexObject(name, record)
71         elif t == "case_range_index":
72             return Condition(record)
73         elif t == "data":
74             return PlainOldObject(name, record)
75         elif t == "define":
76             return VariableDeclarationObject(record)
77         elif t == "memory_map":
78             return MemoryMapObject(record)
79         else:
80             return RawObject(name, record)
81
82     def _parse(self, record):
83         for k, v in record.items():
84             if k.startswith("$"):
85                 self.ctrl_field[k.strip("$")] = FieldType.create(k, v)
86             elif k.startswith("@"):
87                 self.ctrl_field[k.strip("@")] = DataObject.create(k, v)
88             else:
89                 self.user_field[k] = DataObject.create(k, v)
90
91     def expand(self, param={}):
92         obj2 = {}
93         for f in self.ctrl_field.values():
94             if not isinstance(f, DataObject):
95                 continue
96             obj2.update(**f.expand(param))
97
98         obj = {}
99         _param = {**param, **obj2}
100         for k, v in self.user_field.items():
101             if isinstance(v, DataObject):
102                 obj[k] = v.expand(_param)
103             else:
104                 obj[k] = v
105
106         return {**obj, **obj2}
107     
108
109 class FieldType:
110     def __init__(self, record) -> None:
111         self._record = record
112         self._parse(record)
113
114     @staticmethod
115     def create(field, value):
116         if field == "$range":
117             return RangeType(value)
118         else:
119             return value
120     
121     @abstractmethod
122     def _parse(self, record):
123         pass
124
125     @abstractmethod
126     def get(self, param):
127         pass
128
129     def getraw(self):
130         return self.__record
131
132 class PrimitiveType(DataObject):
133     def __init__(self, record) -> None:
134         super().__init__("", record)
135
136     @staticmethod
137     def can_create(value):
138         return type(value) in (str, int, bool)
139
140     def _parse(self, record):
141         if type(record) not in (str, int, bool):
142             raise Exception(f"{type(self).__name__} require primitive type input")
143         self.val = record
144
145         if type(record) == str:
146             self.__get_fn = self.__process_str
147         else:
148             self.__get_fn = lambda x: self.val
149
150     def __process_str(self, param):
151         return Preprocessor.expand_str(self.val, param)
152     
153     def expand(self, param={}):
154         return self.__get_fn(param)
155
156 class RangeType(FieldType):
157     def __init__(self, record) -> None:
158         self.__ranged_component = re.compile(r"^(?P<index>[^.]+)$|^(?P<start>.+?)\.\.(?P<end>.+)$")
159         super().__init__(record)
160
161     def _parse(self, record):
162         return super()._parse(record)
163     
164     def get(self, param):
165         record = self._record.strip('[]')
166         
167         self.__value=[]
168         for component in record.split(','):
169             component = component.strip()
170             mo = self.__ranged_component.match(component)
171             if mo is None:
172                 raise Exception(f"value '{component}' is not valid range component")
173             
174             mo = mo.groupdict()
175             if mo["index"] is not None:
176                 self.__value.append(Preprocessor.expand_str(mo['index'], param))
177             else:
178                 start = Preprocessor.expand_str(mo['start'], param)
179                 end = Preprocessor.expand_str(mo['end'], param)
180                 self.__value += [x for x in range(start, end + 1)]
181         return self.__value
182
183     def getraw(self):
184         return self._record
185     
186 class VariableDeclarationObject(DataObject):
187     def __init__(self, record):
188         super().__init__("", record)
189     
190     def _parse(self, record):
191         return super()._parse(record)
192     
193     def expand(self, param={}):
194         obj = super().expand(param)
195         param.update(**obj)
196         return {}
197
198 class Condition(DataObject):
199     def __init__(self, record):
200         super().__init__("", record)
201     
202     def _parse(self, record):
203         super()._parse(record)
204         if "range" not in self.ctrl_field:
205             raise Exception("condition body must contains valid range case")
206         if "true" not in self.ctrl_field:
207             raise Exception("condition body must contains 'True' handling case")
208         
209     
210     def expand(self, param={}):
211         self.__range_lst = self.ctrl_field["range"].get(param)
212         if param["index"] in self.__range_lst:
213             return self.ctrl_field["true"].expand(param)
214         elif "else" in self.ctrl_field:
215             return self.ctrl_field["else"].expand(param)
216         return {}
217     
218 class ArrayObject(DataObject):
219     def __init__(self, record, 
220                  nested_array = False, 
221                  el_factory = lambda x: DataObject.create("", x)):
222         self._el_factory = el_factory
223         self._nested_array = nested_array
224
225         super().__init__("", record)
226     
227     def _parse(self, record):
228         if not isinstance(record, list):
229             raise Exception(f"{type(self).__name__} require array input")
230         
231         self.content = []
232         for x in record:
233             self.content.append(self._el_factory(x))
234     
235     def expand(self, param={}):
236         result = []
237         for x in self.content:
238             obj = x.expand(param)
239             if isinstance(obj, list) and not self._nested_array:
240                 result += [*obj]
241             else:
242                 result.append(obj)
243         
244         return result
245     
246 class MemoryMapObject(DataObject):
247     class GranuleObject(DataObject):
248         def __init__(self, record):
249             super().__init__("", record)
250             
251         def _parse(self, record):
252             self.__granules = {}
253             for k, v in record.items():
254                 self.__granules[k] = DataObject.create(k, v)
255
256         def expand(self, param={}):
257             granules = {}
258             for k, v in self.__granules.items():
259                 val = v.expand(param)
260
261                 if not isinstance(val, int):
262                     raise Exception("The granule definition must be either integer or int-castable string")
263                 
264                 granules[k] = val
265                 
266             return {**granules}
267         
268     def __init__(self, record):
269         super().__init__("", record)
270
271     def _parse(self, record):
272         for k, v in record.items():
273             if k.startswith("$"):
274                 self.ctrl_field[k.strip("$")] = FieldType.create(k, v)
275             elif k.startswith("@"):
276                 self.ctrl_field[k.strip("@")] = DataObject.create(k, v)
277         
278         if "granule" in record:
279             self.__g = MemoryMapObject.GranuleObject(record["granule"])
280
281         if "regions" in record:
282             self.__regions = ArrayObject(record["regions"])
283
284         if "width" in record:
285             self.__width = DataObject.create("width", record["width"])
286
287     def __process(self, start_addr, idx, regions):
288         if idx >= len(regions):
289             raise Exception("Unbounded region definition")
290         
291         e = regions[idx]
292
293         if "start" not in e:
294             ne = regions[idx + 1]
295             if "start" not in ne or "size" not in e:
296                 e["start"] = start_addr
297             else:
298                 self.__process(start_addr + e["size"], idx + 1, regions)
299                 e["start"] = ne['start'] - e["size"]
300
301         if "block" in e:
302             b = e["block"] - 1
303             e["start"] = (e["start"] + b) & ~b
304
305         if e["start"] < start_addr:
306             raise Exception(f"starting addr {hex(e['start'])} overrlapping with {hex(start_addr)}")
307         
308         start_addr = e["start"]
309         
310         if "size" not in e:
311             self.__process(start_addr, idx + 1, regions)
312             ne = regions[idx + 1]
313             e["size"] = ne['start'] - start_addr
314         
315         return start_addr + e["size"]
316     
317     def expand(self, param={}):
318         super().expand(param)
319
320         g = self.__g.expand(param)
321
322         param["granule"] = g
323
324         width = self.__width.expand(param)
325         if not isinstance(width, int):
326             raise Exception("'width' attribute must be integer")
327
328         regions = self.__regions.expand(param)
329         
330         start_addr = 0
331         for i in range(len(regions)):
332             start_addr = self.__process(start_addr, i, regions)
333         
334         if math.log2(start_addr) > width:
335             raise Exception("memory size larger than speicified address width")
336
337         return {
338             "granule": g,
339             "regions": regions
340         }
341
342 class ForEachIndexObject(DataObject):
343     def __init__(self, name, record):
344         super().__init__(name, record)
345         self.steps = []
346         for k, v in record.items():
347             self.steps.append(DataObject.create(k, v))
348
349     def _parse(self, record):
350         super()._parse(record)
351     
352     def expand(self, param={}):
353         if "index" not in param:
354             raise Exception(f"'{type(self).__name__}' require parameter 'index'.")
355         obj = {}
356         for cond in self.steps:
357             obj.update(**cond.expand(param))
358         return obj
359     
360 class RawObject(DataObject):
361     def __init__(self, name, record):
362         super().__init__(name, record)
363     
364     def _parse(self, record):
365         return super()._parse(record)
366
367     def expand(self, param={}):
368         return super().expand(param)
369
370 class PlainOldObject(DataObject):
371     def __init__(self, name, record):
372         super().__init__(name, record)
373
374     def _parse(self, record):
375         return super()._parse(record)
376
377     def expand(self, param={}):
378         return super().expand(param)
379
380 class RangedObject(DataObject):
381     def __init__(self, name, record):
382         super().__init__(name, record)
383     
384     def _parse(self, record):
385         super()._parse(record)
386
387     def expand(self, param={}):
388         if "range" not in self.ctrl_field:
389             raise Exception("RangedObject with ranged type must have 'range' field defined")
390         
391         out_lst = []
392         indices = self.ctrl_field["range"].get(param)
393         for i in indices:
394             param["index"] = i
395             out_lst.append(super().expand(param))
396         
397         return out_lst
398     
399 def aligned(v, a):
400     return v & ~(a - 1)
401     
402 class TemplateExpander:
403     def __init__(self, template_path, project_path, arch) -> None:
404         self.arch = arch
405         self.tbase_path = template_path.joinpath(arch)
406         self.pbase_path = project_path
407
408         self.__helper_fns = {
409             "align": aligned,
410             "hex": lambda x: hex(x)
411         }
412
413         self.__load_config()
414         self.__load_mappings()
415
416     def __load_config(self):
417         self.data = {}
418         cfg_file: Path = self.tbase_path.joinpath("config.json")
419         if not cfg_file.is_file():
420             raise Exception(f"config not found. ({cfg_file})")
421         
422         obj = json.loads(cfg_file.read_text())
423         for k, v in obj.items():
424             o = DataObject.create(k, v).expand()
425             self.data[k] = o
426
427     def __load_mappings(self):
428         self.mapping = {}
429         mapping_file: Path = self.tbase_path.joinpath("mappings")
430         if not mapping_file.is_file():
431             raise Exception(f"config not found. ({mapping_file})")
432         
433         with mapping_file.open() as f:
434             for l in f:
435                 l = l.strip()
436
437                 if not l:
438                     continue
439
440                 src, dest = l.split("->")
441                 src = src.strip()
442
443                 if src in self.mapping:
444                     raise Exception(f"repeating entry ({src})")
445                 
446                 self.mapping[src] = dest.strip()
447
448     def render(self):
449         for k, v in self.mapping.items():
450             src: Path = self.tbase_path.joinpath(k)
451             dest: Path = self.pbase_path.joinpath(v)
452             if not src.is_file():
453                 continue
454
455             if not dest.parent.exists():
456                 dest.parent.mkdir(parents=True)
457             
458             self.data["template"] = k
459             template = jinja2.Template(src.read_text(), trim_blocks=True)
460             template.globals.update(**self.__helper_fns)
461             out = template.render(data=self.data)
462
463             dest.write_text(out)
464
465             print(f"rendered: {k} -> {v}")
466
467 import pprint
468
469 def main():
470     parser = argparse.ArgumentParser()
471     parser.add_argument("--arch", default='i386')
472     parser.add_argument("-twd", "--template_dir", default=str(Path.cwd()))
473     parser.add_argument("-pwd", "--project_dir", default=str(Path.cwd()))
474
475     args = parser.parse_args()
476
477     expander = TemplateExpander(Path(args.template_dir), Path(args.project_dir), args.arch)
478     
479     expander.render()
480     # pprint.pprint(expander.data)
481
482 if __name__ == "__main__":
483     main()