Skip to main content

tidepool_codegen/
effect_machine.rs

1use crate::context::VMContext;
2use crate::heap_bridge;
3use crate::layout;
4use crate::yield_type::{Yield, YieldError};
5use tidepool_heap::layout as heap_layout;
6
/// Constructor tags for the freer-simple Eff type.
///
/// These identify which DataCon a heap-allocated constructor represents,
/// allowing the effect machine to distinguish Val (pure result) from
/// E (effect request) and destructure Union wrappers and Leaf/Node continuations.
///
/// Tags are resolved from a `DataConTable` via [`ConTags::from_table`] and
/// compared against the con_tag word of heap Con objects at runtime.
#[derive(Debug, Clone, Copy)]
pub struct ConTags {
    /// Con_tag for the Val constructor (pure result).
    pub val: u64,
    /// Con_tag for the E constructor (effect request).
    pub e: u64,
    /// Con_tag for the Union constructor (effect type wrapper).
    pub union: u64,
    /// Con_tag for the Leaf constructor (leaf continuation).
    pub leaf: u64,
    /// Con_tag for the Node constructor (composed continuation).
    pub node: u64,
}
25
26impl ConTags {
27    /// Resolve freer-simple constructor tags from a DataConTable.
28    pub fn from_table(table: &tidepool_repr::DataConTable) -> Option<Self> {
29        Some(ConTags {
30            val: table.get_by_name("Val")?.0,
31            e: table.get_by_name("E")?.0,
32            union: table.get_by_name("Union")?.0,
33            leaf: table.get_by_name("Leaf")?.0,
34            node: table.get_by_name("Node")?.0,
35        })
36    }
37}
38
/// Compiled effect machine — drives JIT-compiled freer-simple effect stacks.
///
/// The step/resume protocol:
/// 1. step() calls the compiled function, parses the result:
///    - Con with Val con_tag → Yield::Done(value)
///    - Con with E con_tag → Yield::Request(tag, request, continuation)
/// 2. resume(continuation, response) applies the continuation tree to the response
///    and parses the resulting heap object.
pub struct CompiledEffectMachine {
    /// Entry point of the finalized JIT function: takes the VM context and
    /// returns a heap pointer to an Eff value (Val or E constructor).
    func_ptr: unsafe extern "C" fn(*mut VMContext) -> *mut u8,
    /// VM context owned by this machine; passed to every compiled call and
    /// used for nursery allocation and tail-call resolution.
    vmctx: VMContext,
    /// Constructor tags resolved from the program's DataConTable.
    tags: ConTags,
}
52
// SAFETY: func_ptr and the heap pointers handled by this machine are raw
// pointers / function pointers, which are Send. NOTE(review): this also
// asserts the owned VMContext may move across threads — confirm VMContext
// holds no thread-affine state.
unsafe impl Send for CompiledEffectMachine {}
55
56impl CompiledEffectMachine {
57    pub fn new(
58        func_ptr: unsafe extern "C" fn(*mut VMContext) -> *mut u8,
59        vmctx: VMContext,
60        tags: ConTags,
61    ) -> Self {
62        Self {
63            func_ptr,
64            vmctx,
65            tags,
66        }
67    }
68
    /// Access the VMContext (e.g., to update nursery pointers after GC).
    ///
    /// The machine owns its context, so external code mutates it between
    /// steps through this accessor.
    pub fn vmctx_mut(&mut self) -> &mut VMContext {
        &mut self.vmctx
    }
73
    /// Execute the compiled function and parse the result.
    ///
    /// Runs the JIT entry point once, drains any pending tail calls the JIT
    /// recorded in the context, then classifies the resulting heap object as
    /// `Yield::Done`, `Yield::Request`, or `Yield::Error`.
    pub fn step(&mut self) -> Yield {
        // SAFETY: func_ptr is a finalized JIT function pointer. vmctx is valid and
        // owned by this machine. The function returns a heap pointer to an Eff value.
        let mut result: *mut u8 = unsafe { (self.func_ptr)(&mut self.vmctx) };
        // SAFETY: resolve_tail_calls reads/writes vmctx.tail_callee/tail_arg which
        // are valid heap pointers set by JIT tail-call sites.
        unsafe {
            self.resolve_tail_calls(&mut result);
        }
        self.parse_result(result)
    }
86
    /// Resume after handling an effect by applying the continuation to the response.
    ///
    /// Walks the Leaf/Node continuation tree with `response` as the argument,
    /// resolves any tail calls left pending in vmctx, and parses the
    /// resulting Eff heap object into a `Yield`.
    ///
    /// # Safety
    ///
    /// `continuation` and `response` must be valid heap pointers from the nursery.
    pub unsafe fn resume(&mut self, continuation: *mut u8, response: *mut u8) -> Yield {
        // SAFETY: Caller guarantees continuation and response are valid nursery heap pointers.
        let mut result = self.apply_cont_heap(continuation, response);
        self.resolve_tail_calls(&mut result);
        self.parse_result(result)
    }
98
    /// Parse a heap-allocated Eff result into a Yield.
    ///
    /// Expects `result` (after forcing) to be a Con whose con_tag is either
    /// `tags.val` (pure result → `Yield::Done`) or `tags.e` (effect request →
    /// `Yield::Request`). A pending runtime error reported by the host
    /// functions takes precedence over the heap object itself.
    fn parse_result(&mut self, result: *mut u8) -> Yield {
        // Check for runtime error FIRST (before null check), because runtime_error
        // now returns a "poison" non-null Lit object to prevent segfaults in JIT code.
        if let Some(err) = crate::host_fns::take_runtime_error() {
            return Yield::Error(YieldError::from(err));
        }
        if result.is_null() {
            return Yield::Error(YieldError::NullPointer);
        }

        // Force result if it's a thunk (lazy Con field from parent)
        let result = self.force_ptr(result);
        if result.is_null() {
            return Yield::Error(YieldError::NullPointer);
        }

        // SAFETY: result is non-null (checked above) and points to a valid heap object.
        // All field reads below use known layout offsets from tidepool_heap::layout.
        let tag = unsafe { *result };
        if tag != layout::TAG_CON {
            return Yield::Error(YieldError::UnexpectedTag(tag));
        }

        let con_tag = unsafe { *(result.add(layout::CON_TAG_OFFSET as usize) as *const u64) };

        if con_tag == self.tags.val {
            // Val(value) — extract value from fields[0]
            let num_fields =
                unsafe { *(result.add(layout::CON_NUM_FIELDS_OFFSET as usize) as *const u16) };
            if num_fields < 1 {
                return Yield::Error(YieldError::BadValFields(num_fields));
            }
            let value =
                unsafe { *(result.add(layout::CON_FIELDS_OFFSET as usize) as *const *mut u8) };
            // Force value field — it may be a thunk
            let value = self.force_ptr(value);
            // A null forced value is still reported as Done; the caller decides
            // how to treat a null payload.
            Yield::Done(value)
        } else if con_tag == self.tags.e {
            // E(union, continuation) — extract Union and k
            let num_fields =
                unsafe { *(result.add(layout::CON_NUM_FIELDS_OFFSET as usize) as *const u16) };
            if num_fields != 2 {
                return Yield::Error(YieldError::BadEFields(num_fields));
            }
            let mut union_ptr =
                unsafe { *(result.add(layout::CON_FIELDS_OFFSET as usize) as *const *mut u8) };
            let mut continuation =
                unsafe { *(result.add(layout::CON_FIELDS_OFFSET as usize + 8) as *const *mut u8) };

            // Force all field pointers — they may be thunks from lazy Con fields
            union_ptr = self.force_ptr(union_ptr);
            if union_ptr.is_null() {
                return Yield::Error(YieldError::NullPointer);
            }
            continuation = self.force_ptr(continuation);
            if continuation.is_null() {
                return Yield::Error(YieldError::NullPointer);
            }

            let union_tag = unsafe { *union_ptr };
            if union_tag != layout::TAG_CON {
                return Yield::Error(YieldError::UnexpectedTag(union_tag));
            }

            // Union carries exactly two fields: a Lit holding the effect tag,
            // and the effect's request payload.
            let union_num_fields =
                unsafe { *(union_ptr.add(layout::CON_NUM_FIELDS_OFFSET as usize) as *const u16) };
            if union_num_fields != 2 {
                return Yield::Error(YieldError::BadUnionFields(union_num_fields));
            }

            let tag_ptr =
                unsafe { *(union_ptr.add(layout::CON_FIELDS_OFFSET as usize) as *const *mut u8) };
            let tag_ptr = self.force_ptr(tag_ptr);
            if tag_ptr.is_null() {
                return Yield::Error(YieldError::NullPointer);
            }
            // Read the actual tag value from the Lit HeapObject (offset 16 = LIT_VALUE_OFFSET)
            let tag_ptr_tag = unsafe { *tag_ptr };
            let effect_tag =
                unsafe { *(tag_ptr.add(layout::LIT_VALUE_OFFSET as usize) as *const u64) };
            let mut request = unsafe {
                *(union_ptr.add(layout::CON_FIELDS_OFFSET as usize + 8) as *const *mut u8)
            };
            // The request payload may itself be null; force_ptr passes null through.
            request = self.force_ptr(request);

            // Optional diagnostics, enabled by the TIDEPOOL_TRACE_EFFECTS env var.
            if std::env::var("TIDEPOOL_TRACE_EFFECTS").is_ok() {
                eprintln!(
                    "[effect_machine] effect_tag={} tag_ptr_tag={} union_con_tag={} request_tag={}",
                    effect_tag,
                    tag_ptr_tag,
                    unsafe { *(union_ptr.add(layout::CON_TAG_OFFSET as usize) as *const u64) },
                    // 255 is a sentinel printed when the request pointer is null.
                    if request.is_null() {
                        255
                    } else {
                        unsafe { *request }
                    }
                );
            }

            Yield::Request {
                tag: effect_tag,
                request,
                continuation,
            }
        } else {
            Yield::Error(YieldError::UnexpectedConTag(con_tag))
        }
    }
208
209    /// Force a heap pointer if it's a thunk, returning the WHNF result.
210    /// Loops to handle chains (thunk returning thunk).
211    fn force_ptr(&mut self, ptr: *mut u8) -> *mut u8 {
212        let mut current = ptr;
213        loop {
214            if current.is_null() {
215                return current;
216            }
217            // SAFETY: current is non-null (checked above) and points to a valid heap object.
218            let tag = unsafe { *current };
219            if tag == layout::TAG_THUNK {
220                let vmctx = &mut self.vmctx as *mut VMContext;
221                current = crate::host_fns::heap_force(vmctx, current);
222            } else {
223                return current;
224            }
225        }
226    }
227
228    /// Apply a Leaf/Node continuation tree to a value, yielding a new Eff result.
229    ///
230    /// Mirrors the interpreter's `apply_cont` on raw heap pointers:
231    /// - Leaf(f): call f(arg) via call_closure
232    /// - Node(k1, k2): apply k1(arg), if Val(y) → k2(y), if E(union, k') → E(union, Node(k', k2))
233    /// - Closure: direct call_closure (degenerate continuation fallback)
234    ///
235    /// # Safety
236    ///
237    /// `k` and `arg` must be valid heap pointers.
238    unsafe fn apply_cont_heap(&mut self, k: *mut u8, arg: *mut u8) -> *mut u8 {
239        // SAFETY: k and arg are valid heap pointers (or null, handled below).
240        // All field reads use known layout offsets. Recursive calls maintain the invariant.
241        if k.is_null() {
242            return std::ptr::null_mut();
243        }
244
245        // Force k and arg in case they are thunks (lazy Con fields)
246        let k = self.force_ptr(k);
247        if k.is_null() {
248            return std::ptr::null_mut();
249        }
250        let arg = self.force_ptr(arg);
251
252        let tag = *k;
253        match tag {
254            t if t == layout::TAG_CON => {
255                let con_tag = unsafe { *(k.add(layout::CON_TAG_OFFSET as usize) as *const u64) };
256
257                if con_tag == self.tags.leaf {
258                    // Leaf(f) — extract closure f at field[0], call f(arg)
259                    let f = self.force_ptr(unsafe {
260                        *(k.add(layout::CON_FIELDS_OFFSET as usize) as *const *mut u8)
261                    });
262                    self.call_closure(f, arg)
263                } else if con_tag == self.tags.node {
264                    // Node(k1, k2) — apply k1 to arg, then compose with k2
265                    let k1 = self.force_ptr(unsafe {
266                        *(k.add(layout::CON_FIELDS_OFFSET as usize) as *const *mut u8)
267                    });
268                    let k2 = self.force_ptr(unsafe {
269                        *(k.add(layout::CON_FIELDS_OFFSET as usize + 8) as *const *mut u8)
270                    });
271
272                    let result = self.apply_cont_heap(k1, arg);
273                    if result.is_null() {
274                        return std::ptr::null_mut();
275                    }
276
277                    // Force result in case it's a thunk
278                    let result = self.force_ptr(result);
279                    if result.is_null() {
280                        return std::ptr::null_mut();
281                    }
282
283                    // Check if result is Val or E
284                    let result_tag = unsafe { *result };
285                    if result_tag != layout::TAG_CON {
286                        return std::ptr::null_mut();
287                    }
288
289                    let result_con_tag =
290                        unsafe { *(result.add(layout::CON_TAG_OFFSET as usize) as *const u64) };
291
292                    if result_con_tag == self.tags.val {
293                        // Val(y) — extract y, apply k2(y)
294                        let y = self.force_ptr(unsafe {
295                            *(result.add(layout::CON_FIELDS_OFFSET as usize) as *const *mut u8)
296                        });
297                        self.apply_cont_heap(k2, y)
298                    } else if result_con_tag == self.tags.e {
299                        // E(union, k') — compose: E(union, Node(k', k2))
300                        let union_val = self.force_ptr(unsafe {
301                            *(result.add(layout::CON_FIELDS_OFFSET as usize) as *const *mut u8)
302                        });
303                        let k_prime = self.force_ptr(unsafe {
304                            *(result.add(layout::CON_FIELDS_OFFSET as usize + 8) as *const *mut u8)
305                        });
306
307                        // Allocate Node(k', k2)
308                        let new_node = self.alloc_con(self.tags.node, &[k_prime, k2]);
309                        if new_node.is_null() {
310                            return std::ptr::null_mut();
311                        }
312                        // Allocate E(union, new_node)
313                        self.alloc_con(self.tags.e, &[union_val, new_node])
314                    } else {
315                        std::ptr::null_mut()
316                    }
317                } else {
318                    // Unknown Con tag in continuation position — error
319                    std::ptr::null_mut()
320                }
321            }
322            t if t == layout::TAG_CLOSURE => {
323                // Raw closure (degenerate continuation fallback)
324                self.call_closure(k, arg)
325            }
326            t if t == layout::TAG_THUNK => {
327                // Thunk in continuation position — already forced above, shouldn't happen
328                std::ptr::null_mut()
329            }
330            _ => std::ptr::null_mut(),
331        }
332    }
333
    /// Call a compiled closure: read code_ptr from closure[8], invoke it.
    ///
    /// Resolves any tail calls the callee left pending before returning the
    /// final result. Tracing (call names, heap validation, capture dumps) is
    /// controlled by `crate::debug::trace_level()`.
    ///
    /// # Safety
    ///
    /// `closure` must point to a valid Closure HeapObject.
    unsafe fn call_closure(&mut self, closure: *mut u8, arg: *mut u8) -> *mut u8 {
        // SAFETY: closure is a valid Closure heap object. Reading code_ptr at the known offset.
        let code_ptr = *(closure.add(layout::CLOSURE_CODE_PTR_OFFSET as usize) as *const usize);

        let trace = crate::debug::trace_level();
        if trace >= crate::debug::TraceLevel::Calls {
            let name = crate::debug::lookup_lambda(code_ptr)
                .unwrap_or_else(|| format!("0x{:x}", code_ptr));
            eprintln!(
                "[trace] call_closure {} closure={:?} arg={}",
                name,
                closure,
                crate::debug::heap_describe(arg),
            );
        }
        if trace >= crate::debug::TraceLevel::Heap {
            // At Heap trace level, validate both pointers before the call and
            // bail out with null rather than letting JIT code touch bad memory.
            if let Err(e) = crate::debug::heap_validate_deep(closure) {
                eprintln!("[trace] INVALID closure: {}", e);
                eprintln!("[trace]   {}", crate::debug::heap_describe(closure));
                return std::ptr::null_mut();
            }
            if let Err(e) = crate::debug::heap_validate(arg) {
                eprintln!("[trace] INVALID arg: {}", e);
                return std::ptr::null_mut();
            }
            // Dump captures
            let num_captured =
                *(closure.add(layout::CLOSURE_NUM_CAPTURED_OFFSET as usize) as *const u16);
            for i in 0..num_captured as usize {
                let cap = *(closure.add(layout::CLOSURE_CAPTURED_OFFSET as usize + 8 * i)
                    as *const *const u8);
                if cap.is_null() {
                    eprintln!("[trace]   capture[{}] = NULL", i);
                } else {
                    eprintln!(
                        "[trace]   capture[{}] = {}",
                        i,
                        crate::debug::heap_describe(cap)
                    );
                }
            }
        }

        // SAFETY: code_ptr was set during JIT compilation and points to a finalized
        // Cranelift function with the closure calling convention (vmctx, self, arg) -> result.
        let func: unsafe extern "C" fn(*mut VMContext, *mut u8, *mut u8) -> *mut u8 =
            std::mem::transmute(code_ptr);
        let mut result = func(&mut self.vmctx, closure, arg);
        // SAFETY: After a closure call, pending tail calls may be stored in vmctx.
        unsafe {
            self.resolve_tail_calls(&mut result);
        }

        if trace >= crate::debug::TraceLevel::Calls {
            let name = crate::debug::lookup_lambda(code_ptr)
                .unwrap_or_else(|| format!("0x{:x}", code_ptr));
            if result.is_null() {
                eprintln!("[trace] {} returned NULL", name);
            } else {
                eprintln!(
                    "[trace] {} returned {}",
                    name,
                    crate::debug::heap_describe(result)
                );
            }
        }

        result
    }
408
    /// Resolve pending tail calls stored in VMContext by the JIT.
    ///
    /// JIT tail-call sites return null and stash the callee/argument pair in
    /// `vmctx.tail_callee` / `vmctx.tail_arg`; this trampoline loop keeps
    /// invoking those callees until one produces a non-null result or no
    /// tail call remains pending. The host call-depth counter is reset
    /// before each hop.
    ///
    /// # Safety
    /// VMContext must have valid tail_callee/tail_arg if non-null.
    unsafe fn resolve_tail_calls(&mut self, result: &mut *mut u8) {
        // SAFETY: tail_callee and tail_arg are valid heap pointers set by JIT tail-call
        // sites. Code pointers in closures point to finalized JIT functions.
        while result.is_null() && !self.vmctx.tail_callee.is_null() {
            let callee = self.vmctx.tail_callee;
            let arg = self.vmctx.tail_arg;
            // Clear the pending slot before the call so the callee can record
            // a fresh tail call of its own.
            self.vmctx.tail_callee = std::ptr::null_mut();
            self.vmctx.tail_arg = std::ptr::null_mut();
            crate::host_fns::reset_call_depth();
            let code_ptr = *(callee.add(layout::CLOSURE_CODE_PTR_OFFSET as usize) as *const usize);
            let func: unsafe extern "C" fn(*mut VMContext, *mut u8, *mut u8) -> *mut u8 =
                std::mem::transmute(code_ptr);
            *result = func(&mut self.vmctx, callee, arg);
        }
    }
428
    /// Allocate a Con HeapObject on the nursery with the given tag and fields.
    ///
    /// Writes the object header via `heap_layout::write_header`, then the
    /// con_tag (u64), num_fields (u16), and one 8-byte field pointer per
    /// entry in `fields`. Returns null if the nursery bump allocation fails.
    ///
    /// NOTE(review): `size as u16` silently truncates for allocations larger
    /// than u16::MAX bytes (~8k fields) — confirm callers never exceed this.
    unsafe fn alloc_con(&mut self, con_tag: u64, fields: &[*mut u8]) -> *mut u8 {
        // SAFETY: Bump-allocating from vmctx nursery. Writing Con header, tag,
        // num_fields, and field pointers at known layout offsets within the allocation.
        // 24 bytes of fixed header area before the fields — presumably equal to
        // CON_FIELDS_OFFSET; confirm against the layout constants.
        let size = 24 + 8 * fields.len();
        let ptr = heap_bridge::bump_alloc_from_vmctx(&mut self.vmctx, size);
        if ptr.is_null() {
            return std::ptr::null_mut();
        }
        heap_layout::write_header(ptr, layout::TAG_CON, size as u16);
        *(ptr.add(layout::CON_TAG_OFFSET as usize) as *mut u64) = con_tag;
        *(ptr.add(layout::CON_NUM_FIELDS_OFFSET as usize) as *mut u16) = fields.len() as u16;
        for (i, &fp) in fields.iter().enumerate() {
            *(ptr.add(layout::CON_FIELDS_OFFSET as usize + 8 * i) as *mut *mut u8) = fp;
        }
        ptr
    }
446}