fix leakage on mid-level page table when freeing vms
[lunaix-os.git] / lunaix-os / kernel / mm / procvm.c
1 #include <lunaix/mm/procvm.h>
2 #include <lunaix/mm/valloc.h>
3 #include <lunaix/mm/region.h>
4 #include <lunaix/mm/page.h>
5 #include <lunaix/mm/mmap.h>
6 #include <lunaix/process.h>
7 #include <lunaix/syslog.h>
8
9 #include <asm/mm_defs.h>
10
11 #include <klibc/string.h>
12
13 #define alloc_pagetable_trace(ptep, pte, ord, level)                        \
14     ({                                                                      \
15         alloc_kpage_at(ptep, pte, ord);                                     \
16     })
17
18 #define free_pagetable_trace(ptep, pte, level)                              \
19     ({                                                                      \
20         struct leaflet* leaflet = pte_leaflet_aligned(pte);                 \
21         assert(leaflet_order(leaflet) == 0);                                \
22         leaflet_return(leaflet);                                            \
23         set_pte(ptep, null_pte);                                            \
24     })
25
26 struct proc_mm*
27 procvm_create(struct proc_info* proc) {
28     struct proc_mm* mm = vzalloc(sizeof(struct proc_mm));
29
30     assert(mm);
31
32     mm->heap = 0;
33     mm->proc = proc;
34
35     llist_init_head(&mm->regions);
36     return mm;
37 }
38
39 static inline unsigned int
40 __ptep_advancement(struct leaflet* leaflet, int level)
41 {
42     size_t shifts = MAX(MAX_LEVEL - level - 1, 1) * LEVEL_SHIFT;
43     return (1 << (leaflet_order(leaflet) % shifts)) - 1;
44 }
45
46 static inline int
47 __descend(ptr_t dest_mnt, ptr_t src_mnt, ptr_t va, bool alloc)
48 {
49     pte_t *dest, *src, pte;
50
51     int i = 0;
52     while (!pt_last_level(i))
53     {
54         dest = mklntep_va(i, dest_mnt, va);
55         src  = mklntep_va(i, src_mnt, va);
56         pte  = pte_at(src);
57
58         if (!pte_isloaded(pte) || pte_huge(pte)) {
59             break;
60         }
61
62         if (alloc && pte_isnull(pte_at(dest))) {
63             alloc_pagetable_trace(dest, pte, 0, i);
64         }
65
66         i++;
67     }
68
69     return i;
70 }
71
72 static void
73 __free_hierarchy(ptr_t mnt, ptr_t va, int level)
74 {
75     pte_t pte, *ptep, *ptep_next;
76
77     if (pt_last_level(level)) {
78         return;
79     }
80
81     __free_hierarchy(mnt, va, level + 1);
82
83     ptep = mklntep_va(level, mnt, va);
84     pte = pte_at(ptep);
85     if (pte_isnull(pte)) {
86         return;
87     }
88
89     ptep_next = ptep_step_into(ptep);
90     for (unsigned i = 0; i < LEVEL_SIZE; i++, ptep_next++)
91     {
92         if (!pte_isnull(pte_at(ptep_next))) {
93             return;
94         }
95     }
96     
97     free_pagetable_trace(ptep, pte, level);
98 }
99
100 static inline void
101 copy_leaf(pte_t* dest, pte_t* src, pte_t pte, int level)
102 {
103     struct leaflet* leaflet;
104
105     set_pte(dest, pte);
106
107     if (!pte_isloaded(pte)) {
108         return;
109     }
110
111     leaflet = pte_leaflet(pte);
112     assert(leaflet_refcount(leaflet));
113     
114     if (leaflet_ppfn(leaflet) == pte_ppfn(pte)) {
115         leaflet_borrow(leaflet);
116     }
117 }
118
119 static inline void
120 copy_root(pte_t* dest, pte_t* src, pte_t pte, int level)
121 {
122     alloc_pagetable_trace(dest, pte, 0, level);
123 }
124
125 static void
126 vmrcpy(ptr_t dest_mnt, ptr_t src_mnt, struct mm_region* region)
127 {
128     pte_t *src, *dest;
129     ptr_t loc;
130     int level;
131     struct leaflet* leaflet;
132
133     loc  = region->start;
134     src  = mkptep_va(src_mnt, loc);
135     dest = mkptep_va(dest_mnt, loc);
136
137     level = __descend(dest_mnt, src_mnt, loc, true);
138
139     while (loc < region->end)
140     {
141         pte_t pte = *src;
142
143         if (pte_isnull(pte)) {
144             goto cont;
145         } 
146         
147         if (pt_last_level(level) || pte_huge(pte)) {
148             copy_leaf(dest, src, pte, level);
149             goto cont;
150         }
151         
152         if (!pt_last_level(level)) {
153             copy_root(dest, src, pte, level);
154
155             src = ptep_step_into(src);
156             dest = ptep_step_into(dest);
157             level++;
158
159             continue;
160         }
161         
162     cont:
163         loc += lnt_page_size(level);
164         while (ptep_vfn(src) == MAX_PTEN - 1) {
165             assert(level > 0);
166             src = ptep_step_out(src);
167             dest = ptep_step_out(dest);
168             level--;
169         }
170
171         src++;
172         dest++;
173     }
174 }
175
176 static inline void
177 vmrfree_hierachy(ptr_t vm_mnt, struct mm_region* region)
178 {
179     __free_hierarchy(vm_mnt, region->start, 0);
180 }
181
182 static void
183 vmrfree(ptr_t vm_mnt, struct mm_region* region)
184 {
185     pte_t *src, *end;
186     ptr_t loc;
187     int level;
188     struct leaflet* leaflet;
189
190     loc  = region->start;
191     src  = mkptep_va(vm_mnt, region->start);
192     end  = mkptep_va(vm_mnt, region->end);
193
194     level = __descend(vm_mnt, vm_mnt, loc, false);
195
196     while (src < end)
197     {
198         pte_t pte = *src;
199         ptr_t pa  = pte_paddr(pte);
200
201         if (pte_isnull(pte)) {
202             goto cont;
203         } 
204
205         if (!pt_last_level(level) && !pte_huge(pte)) {
206             src = ptep_step_into(src);
207             level++;
208
209             continue;
210         }
211
212         set_pte(src, null_pte);
213         
214         if (pte_isloaded(pte)) {
215             leaflet = pte_leaflet_aligned(pte);
216             leaflet_return(leaflet);
217
218             src += __ptep_advancement(leaflet, level);
219         }
220
221     cont:
222         while (ptep_vfn(src) == MAX_PTEN - 1) {
223             src = ptep_step_out(src);
224             free_pagetable_trace(src, pte_at(src), level);
225             
226             level--;
227         }
228
229         src++;
230     }
231 }
232
233 static void
234 vmscpy(struct proc_mm* dest_mm, struct proc_mm* src_mm)
235 {
236     // Build the self-reference on dest vms
237
238     /* 
239      *        -- What the heck are ptep_ssm and ptep_sms ? --
240      *      
241      *      ptep_dest point to the pagetable itself that is mounted
242      *          at dest_mnt (or simply mnt): 
243      *              mnt -> self -> self -> self -> L0TE@offset
244      * 
245      *      ptep_sms shallowed the recursion chain:
246      *              self -> mnt -> self -> self -> L0TE@self
247      * 
248      *      ptep_ssm shallowed the recursion chain:
249      *              self -> self -> mnt -> self -> L0TE@self
250      *      
251      *      Now, here is the problem, back to x86_32, the translation is 
252      *      a depth-3 recursion:
253      *              L0T -> LFT -> Page
254      *      
255      *      So ptep_ssm will terminate at mnt and give us a leaf
256      *      slot for allocate a fresh page table for mnt:
257      *              self -> self -> L0TE@mnt
258      * 
259      *      but in x86_64 translation has extra two more step:
260      *              L0T -> L1T -> L2T -> LFT -> Page
261      *      
262      *      So we must continue push down.... 
263      *      ptep_sssms shallowed the recursion chain:
264      *              self -> self -> self -> mnt  -> L0TE@self
265      * 
266      *      ptep_ssssm shallowed the recursion chain:
267      *              self -> self -> self -> self -> L0TE@mnt
268      * 
269      *      Note: PML4: 2 extra steps
270      *            PML5: 3 extra steps
271     */
272
273     ptr_t  dest_mnt, src_mnt;
274     
275     dest_mnt = dest_mm->vm_mnt;
276     assert(dest_mnt);
277
278     pte_t* ptep_ssm     = mkl0tep_va(VMS_SELF, dest_mnt);
279     pte_t* ptep_smx     = mkl1tep_va(VMS_SELF, dest_mnt);
280     pte_t  pte_sms      = mkpte_prot(KERNEL_PGTAB);
281
282     pte_sms = alloc_pagetable_trace(ptep_ssm, pte_sms, 0, 0);
283     set_pte(&ptep_smx[VMS_SELF_L0TI], pte_sms);
284     
285     tlb_flush_kernel((ptr_t)dest_mnt);
286
287     if (!src_mm) {
288         goto done;
289     }
290
291     src_mnt = src_mm->vm_mnt;
292
293     struct mm_region *pos, *n;
294     llist_for_each(pos, n, &src_mm->regions, head)
295     {
296         vmrcpy(dest_mnt, src_mnt, pos);
297     }
298
299 done:;
300     procvm_link_kernel(dest_mnt);
301     
302     dest_mm->vmroot = pte_paddr(pte_sms);
303 }
304
305 static void
306 vmsfree(struct proc_mm* mm)
307 {
308     struct leaflet* leaflet;
309     struct mm_region *pos, *n;
310     ptr_t vm_mnt;
311     pte_t* ptep_self;
312     
313     vm_mnt    = mm->vm_mnt;
314     ptep_self = mkl0tep(mkptep_va(vm_mnt, VMS_SELF));
315
316     // first pass: free region mappings
317     llist_for_each(pos, n, &mm->regions, head)
318     {
319         vmrfree(vm_mnt, pos);
320     }
321
322     // second pass: free the hierarchical 
323     llist_for_each(pos, n, &mm->regions, head)
324     {
325         vmrfree_hierachy(vm_mnt, pos);
326     }
327
328     procvm_unlink_kernel();
329
330     free_pagetable_trace(ptep_self, pte_at(ptep_self), 0);
331 }
332
333 static inline void
334 __attach_to_current_vms(struct proc_mm* guest_mm)
335 {
336     struct proc_mm* mm_current = vmspace(__current);
337     if (mm_current) {
338         assert(!mm_current->guest_mm);
339         mm_current->guest_mm = guest_mm;
340     }
341 }
342
343 static inline void
344 __detach_from_current_vms(struct proc_mm* guest_mm)
345 {
346     struct proc_mm* mm_current = vmspace(__current);
347     if (mm_current) {
348         assert(mm_current->guest_mm == guest_mm);
349         mm_current->guest_mm = NULL;
350     }
351 }
352
353 void
354 procvm_prune_vmr(ptr_t vm_mnt, struct mm_region* region)
355 {
356     vmrfree(vm_mnt, region);
357     vmrfree_hierachy(vm_mnt, region);
358 }
359
360 void
361 procvm_dupvms_mount(struct proc_mm* mm) {
362     assert(__current);
363     assert(!mm->vm_mnt);
364
365     struct proc_mm* mm_current = vmspace(__current);
366     
367     __attach_to_current_vms(mm);
368    
369     mm->heap = mm_current->heap;
370     mm->vm_mnt = VMS_MOUNT_1;
371     
372     vmscpy(mm, mm_current);  
373     region_copy_mm(mm_current, mm);
374 }
375
376 void
377 procvm_mount(struct proc_mm* mm)
378 {
379     // if current mm is already active
380     if (active_vms(mm->vm_mnt)) {
381         return;
382     }
383     
384     // we are double mounting
385     assert(!mm->vm_mnt);
386     assert(mm->vmroot);
387
388     vms_mount(VMS_MOUNT_1, mm->vmroot);
389
390     __attach_to_current_vms(mm);
391
392     mm->vm_mnt = VMS_MOUNT_1;
393 }
394
395 void
396 procvm_unmount(struct proc_mm* mm)
397 {
398     if (active_vms(mm->vm_mnt)) {
399         return;
400     }
401     
402     assert(mm->vm_mnt);
403     vms_unmount(VMS_MOUNT_1);
404     
405     struct proc_mm* mm_current = vmspace(__current);
406     if (mm_current) {
407         mm_current->guest_mm = NULL;
408     }
409
410     mm->vm_mnt = 0;
411 }
412
413 void
414 procvm_initvms_mount(struct proc_mm* mm)
415 {
416     assert(!mm->vm_mnt);
417
418     __attach_to_current_vms(mm);
419
420     mm->vm_mnt = VMS_MOUNT_1;
421     vmscpy(mm, NULL);
422 }
423
424 void
425 procvm_unmount_release(struct proc_mm* mm) {
426     ptr_t vm_mnt = mm->vm_mnt;
427     struct mm_region *pos, *n;
428
429     llist_for_each(pos, n, &mm->regions, head)
430     {
431         mem_sync_pages(vm_mnt, pos, pos->start, pos->end - pos->start, 0);
432     }
433
434     vmsfree(mm);
435
436     llist_for_each(pos, n, &mm->regions, head)
437     {
438         region_release(pos);
439     }
440
441     vms_unmount(vm_mnt);
442     vfree(mm);
443
444     __detach_from_current_vms(mm);
445 }
446
447 void
448 procvm_mount_self(struct proc_mm* mm) 
449 {
450     assert(!mm->vm_mnt);
451
452     mm->vm_mnt = VMS_SELF;
453 }
454
455 void
456 procvm_unmount_self(struct proc_mm* mm)
457 {
458     assert(active_vms(mm->vm_mnt));
459
460     mm->vm_mnt = 0;
461 }
462
463 ptr_t
464 procvm_enter_remote(struct remote_vmctx* rvmctx, struct proc_mm* mm, 
465                     ptr_t remote_base, size_t size)
466 {
467     ptr_t vm_mnt = mm->vm_mnt;
468     assert(vm_mnt);
469     
470     pfn_t size_pn = pfn(size + PAGE_SIZE);
471     assert(size_pn < REMOTEVM_MAX_PAGES);
472
473     struct mm_region* region = region_get(&mm->regions, remote_base);
474     assert(region && region_contains(region, remote_base + size));
475
476     rvmctx->vms_mnt = vm_mnt;
477     rvmctx->page_cnt = size_pn;
478
479     remote_base = page_aligned(remote_base);
480     rvmctx->remote = remote_base;
481     rvmctx->local_mnt = PG_MOUNT_VAR;
482
483     pte_t* rptep = mkptep_va(vm_mnt, remote_base);
484     pte_t* lptep = mkptep_va(VMS_SELF, rvmctx->local_mnt);
485
486     pte_t pte, rpte = null_pte;
487     rpte = region_tweakpte(region, rpte);
488
489     for (size_t i = 0; i < size_pn; i++)
490     {
491         pte = vmm_tryptep(rptep, PAGE_SIZE);
492         if (pte_isloaded(pte)) {
493             set_pte(lptep, pte);
494             continue;
495         }
496
497         ptr_t pa = ppage_addr(pmm_alloc_normal(0));
498         set_pte(lptep, mkpte(pa, KERNEL_DATA));
499         set_pte(rptep, pte_setpaddr(rpte, pa));
500     }
501
502     return vm_mnt;
503     
504 }
505
506 int
507 procvm_copy_remote_transaction(struct remote_vmctx* rvmctx, 
508                    ptr_t remote_dest, void* local_src, size_t sz)
509 {
510     if (remote_dest < rvmctx->remote) {
511         return -1;
512     }
513
514     ptr_t offset = remote_dest - rvmctx->remote;
515     if (pfn(offset + sz) >= rvmctx->page_cnt) {
516         return -1;
517     }
518
519     memcpy((void*)(rvmctx->local_mnt + offset), local_src, sz);
520
521     return sz;
522 }
523
524 void
525 procvm_exit_remote(struct remote_vmctx* rvmctx)
526 {
527     pte_t* lptep = mkptep_va(VMS_SELF, rvmctx->local_mnt);
528     vmm_unset_ptes(lptep, rvmctx->page_cnt);
529 }