venus_core/execute/
executor.rs

1//! Linear executor for sequential cell execution.
2//!
3//! Executes cells in dependency order, one at a time.
4
5use std::collections::HashMap;
6use std::path::Path;
7use std::sync::Arc;
8
9use libloading::Symbol;
10use tracing::{debug, trace, warn};
11
12use crate::compile::CompiledCell;
13use crate::error::{Error, Result};
14use crate::graph::CellId;
15use crate::state::{BoxedOutput, StateManager};
16
17use super::context::{AbortHandle, ExecutionCallback};
18use super::ffi::{
19    EntryFn0, EntryFn1, EntryFn2, EntryFn3, EntryFn4, EntryFn5, EntryFn6, EntryFn7, EntryFn8,
20    ExecutionResult, call_cell_n_deps,
21};
22use super::loaded_cell::LoadedCell;
23
24/// RAII guard for FFI-allocated memory.
25/// Ensures libc::free is called even if panic occurs during processing.
26struct FfiMemoryGuard {
27    ptr: *mut u8,
28}
29
30impl FfiMemoryGuard {
31    unsafe fn new(ptr: *mut u8) -> Self {
32        Self { ptr }
33    }
34
35    fn as_slice(&self, len: usize) -> &[u8] {
36        unsafe { std::slice::from_raw_parts(self.ptr, len) }
37    }
38}
39
40impl Drop for FfiMemoryGuard {
41    fn drop(&mut self) {
42        if !self.ptr.is_null() {
43            unsafe {
44                libc::free(self.ptr as *mut libc::c_void);
45            }
46        }
47    }
48}
49
50/// Linear executor that runs cells sequentially in dependency order.
51pub struct LinearExecutor {
52    /// Loaded cell libraries
53    cells: HashMap<CellId, LoadedCell>,
54    /// State manager for inputs/outputs
55    state: StateManager,
56    /// Execution callback for progress reporting
57    callback: Option<Box<dyn ExecutionCallback>>,
58    /// Abort handle for cooperative cancellation
59    abort_handle: Option<AbortHandle>,
60}
61
62impl LinearExecutor {
63    /// Create a new linear executor.
64    pub fn new(state_dir: impl AsRef<Path>) -> Result<Self> {
65        Ok(Self {
66            cells: HashMap::new(),
67            state: StateManager::new(state_dir)?,
68            callback: None,
69            abort_handle: None,
70        })
71    }
72
73    /// Create with an existing state manager.
74    pub fn with_state(state: StateManager) -> Self {
75        Self {
76            cells: HashMap::new(),
77            state,
78            callback: None,
79            abort_handle: None,
80        }
81    }
82
83    /// Set the execution callback for progress reporting.
84    pub fn set_callback(&mut self, callback: impl ExecutionCallback + 'static) {
85        self.callback = Some(Box::new(callback));
86    }
87
88    /// Set the abort handle for cooperative cancellation.
89    pub fn set_abort_handle(&mut self, handle: AbortHandle) {
90        self.abort_handle = Some(handle);
91    }
92
93    /// Get the current abort handle.
94    pub fn abort_handle(&self) -> Option<&AbortHandle> {
95        self.abort_handle.as_ref()
96    }
97
98    /// Check if execution has been aborted.
99    fn is_aborted(&self) -> bool {
100        self.abort_handle
101            .as_ref()
102            .is_some_and(|h| h.is_aborted())
103    }
104
105    /// Load a compiled cell for execution.
106    pub fn load_cell(&mut self, compiled: CompiledCell, dep_count: usize) -> Result<()> {
107        let cell_id = compiled.cell_id;
108        let loaded = LoadedCell::load(compiled, dep_count)?;
109        self.cells.insert(cell_id, loaded);
110        Ok(())
111    }
112
113    /// Unload a cell (e.g., before hot-reload).
114    pub fn unload_cell(&mut self, cell_id: CellId) -> Option<LoadedCell> {
115        self.cells.remove(&cell_id)
116    }
117
118    /// Restore a previously unloaded cell (for hot-reload rollback).
119    pub fn restore_cell(&mut self, cell: LoadedCell) {
120        self.cells.insert(cell.compiled.cell_id, cell);
121    }
122
123    /// Check if a cell is loaded.
124    pub fn is_loaded(&self, cell_id: CellId) -> bool {
125        self.cells.contains_key(&cell_id)
126    }
127
128    /// Execute a single cell with the given inputs.
129    ///
130    /// Returns the serialized output on success.
131    /// Returns `Error::Aborted` if abort was requested before execution.
132    pub fn execute_cell(
133        &mut self,
134        cell_id: CellId,
135        inputs: &[Arc<BoxedOutput>],
136    ) -> Result<BoxedOutput> {
137        // Check for abort before starting
138        if self.is_aborted() {
139            return Err(Error::Aborted);
140        }
141
142        let loaded = self
143            .cells
144            .get(&cell_id)
145            .ok_or_else(|| Error::CellNotFound(format!("Cell {:?} not loaded", cell_id)))?;
146
147        // Notify callback
148        if let Some(ref callback) = self.callback {
149            callback.on_cell_started(cell_id, &loaded.compiled.name);
150        }
151
152        // Execute the cell
153        let result = self.call_cell_ffi(loaded, inputs);
154
155        // Check for abort after execution (cell may have been aborted mid-flight)
156        if self.is_aborted() {
157            if let Some(ref callback) = self.callback {
158                callback.on_cell_error(cell_id, &loaded.compiled.name, &Error::Aborted);
159            }
160            return Err(Error::Aborted);
161        }
162
163        // Notify callback
164        match &result {
165            Ok(_) => {
166                if let Some(ref callback) = self.callback {
167                    callback.on_cell_completed(cell_id, &loaded.compiled.name);
168                }
169            }
170            Err(e) => {
171                if let Some(ref callback) = self.callback {
172                    callback.on_cell_error(cell_id, &loaded.compiled.name, e);
173                }
174            }
175        }
176
177        result
178    }
179
180    /// Execute a cell and store the output in the state manager.
181    pub fn execute_and_store(
182        &mut self,
183        cell_id: CellId,
184        inputs: &[Arc<BoxedOutput>],
185    ) -> Result<()> {
186        let output = self.execute_cell(cell_id, inputs)?;
187        self.state.store_output(cell_id, output);
188        Ok(())
189    }
190
191    /// Execute cells in the given order, resolving dependencies from state.
192    ///
193    /// Returns `Error::Aborted` if abort was requested during execution.
194    pub fn execute_in_order(
195        &mut self,
196        order: &[CellId],
197        deps: &HashMap<CellId, Vec<CellId>>,
198    ) -> Result<()> {
199        for &cell_id in order {
200            // Check for abort before each cell
201            if self.is_aborted() {
202                return Err(Error::Aborted);
203            }
204
205            // Gather inputs from dependencies
206            let dep_ids = deps.get(&cell_id).cloned().unwrap_or_default();
207            let inputs: Vec<Arc<BoxedOutput>> = dep_ids
208                .iter()
209                .filter_map(|&dep_id| self.state.get_output(dep_id))
210                .collect();
211
212            // Check we have all required inputs
213            if inputs.len() != dep_ids.len() {
214                return Err(Error::Execution(format!(
215                    "Missing dependencies for cell {:?}: expected {}, got {}",
216                    cell_id,
217                    dep_ids.len(),
218                    inputs.len()
219                )));
220            }
221
222            self.execute_and_store(cell_id, &inputs)?;
223        }
224
225        Ok(())
226    }
227
228    /// Get a reference to the state manager.
229    pub fn state(&self) -> &StateManager {
230        &self.state
231    }
232
233    /// Get a mutable reference to the state manager.
234    pub fn state_mut(&mut self) -> &mut StateManager {
235        &mut self.state
236    }
237
238    /// Call the cell's FFI entry point.
239    fn call_cell_ffi(
240        &self,
241        loaded: &LoadedCell,
242        inputs: &[Arc<BoxedOutput>],
243    ) -> Result<BoxedOutput> {
244        // Verify input count matches
245        if inputs.len() != loaded.dep_count {
246            return Err(Error::Execution(format!(
247                "Cell {} expects {} inputs, got {}",
248                loaded.compiled.name,
249                loaded.dep_count,
250                inputs.len()
251            )));
252        }
253
254        // For cells with no dependencies, use the simple path
255        if loaded.dep_count == 0 {
256            return self.call_cell_no_deps(loaded);
257        }
258
259        // For cells with dependencies, we need to construct the FFI call dynamically
260        // This is complex because the number of parameters varies
261        self.call_cell_with_deps(loaded, inputs)
262    }
263
264    /// Call a cell with no dependencies.
265    fn call_cell_no_deps(&self, loaded: &LoadedCell) -> Result<BoxedOutput> {
266        let symbol_name = loaded.entry_symbol();
267
268        // Safety: We trust the symbol exists and has the correct signature
269        let func: Symbol<EntryFn0> = unsafe { loaded.library.get(symbol_name.as_bytes()) }
270            .map_err(|e| {
271                Error::Execution(format!("Failed to get symbol {}: {}", symbol_name, e))
272            })?;
273
274        let mut out_ptr: *mut u8 = std::ptr::null_mut();
275        let mut out_len: usize = 0;
276
277        // Empty widget values (LinearExecutor doesn't support widgets)
278        let widget_values: &[u8] = &[];
279
280        // Safety: We're calling a function generated by our compiler
281        let result_code = unsafe {
282            func(
283                widget_values.as_ptr(), widget_values.len(),
284                &mut out_ptr, &mut out_len,
285            )
286        };
287
288        self.process_ffi_result(result_code, out_ptr, out_len, &loaded.compiled.name)
289    }
290
291    /// Call a cell with dependencies (up to 8 supported via macro).
292    ///
293    /// Uses the `call_cell_n_deps!` macro to eliminate code duplication.
294    /// Each match arm generates the appropriate typed FFI call.
295    fn call_cell_with_deps(
296        &self,
297        loaded: &LoadedCell,
298        inputs: &[Arc<BoxedOutput>],
299    ) -> Result<BoxedOutput> {
300        let symbol_name = loaded.entry_symbol();
301
302        // Empty widget values (LinearExecutor doesn't support widgets)
303        let widget_values: &[u8] = &[];
304
305        debug!(
306            cell = %loaded.compiled.name,
307            dep_count = inputs.len(),
308            "Calling FFI entry point"
309        );
310        trace!(
311            cell = %loaded.compiled.name,
312            input_sizes = ?inputs.iter().map(|i| i.bytes().len()).collect::<Vec<_>>(),
313            "Input buffer sizes (bytes)"
314        );
315
316        // Implementation note: Uses match arms for 1-10 dependencies.
317        // libffi could support arbitrary counts but adds complexity and overhead.
318        // Current limit (10 dependencies) is sufficient for typical notebook cells.
319        match inputs.len() {
320            1 => call_cell_n_deps!(self, loaded, symbol_name, inputs, widget_values, EntryFn1, 0),
321            2 => call_cell_n_deps!(self, loaded, symbol_name, inputs, widget_values, EntryFn2, 0, 1),
322            3 => call_cell_n_deps!(self, loaded, symbol_name, inputs, widget_values, EntryFn3, 0, 1, 2),
323            4 => call_cell_n_deps!(self, loaded, symbol_name, inputs, widget_values, EntryFn4, 0, 1, 2, 3),
324            5 => call_cell_n_deps!(self, loaded, symbol_name, inputs, widget_values, EntryFn5, 0, 1, 2, 3, 4),
325            6 => call_cell_n_deps!(
326                self,
327                loaded,
328                symbol_name,
329                inputs,
330                widget_values,
331                EntryFn6,
332                0,
333                1,
334                2,
335                3,
336                4,
337                5
338            ),
339            7 => call_cell_n_deps!(
340                self,
341                loaded,
342                symbol_name,
343                inputs,
344                widget_values,
345                EntryFn7,
346                0,
347                1,
348                2,
349                3,
350                4,
351                5,
352                6
353            ),
354            8 => call_cell_n_deps!(
355                self,
356                loaded,
357                symbol_name,
358                inputs,
359                widget_values,
360                EntryFn8,
361                0,
362                1,
363                2,
364                3,
365                4,
366                5,
367                6,
368                7
369            ),
370            n => Err(Error::Execution(format!(
371                "Cells with {} dependencies not yet supported (max 8)",
372                n
373            ))),
374        }
375    }
376
377    /// Process the FFI result and convert output to BoxedOutput.
378    ///
379    /// Output format from cells:
380    /// - display_len (8 bytes, u64 LE): length of display string
381    /// - display_bytes (N bytes): display string (UTF-8)
382    /// - widgets_len (8 bytes, u64 LE): length of widgets JSON
383    /// - widgets_json (M bytes): JSON-encoded widget definitions
384    /// - rkyv_data (remaining bytes): rkyv-serialized data
385    pub(crate) fn process_ffi_result(
386        &self,
387        result_code: i32,
388        out_ptr: *mut u8,
389        out_len: usize,
390        cell_name: &str,
391    ) -> Result<BoxedOutput> {
392        let result = ExecutionResult::from(result_code);
393
394        match result {
395            ExecutionResult::Success => {
396                if out_ptr.is_null() || out_len == 0 {
397                    return Err(Error::Execution(format!(
398                        "Cell {} returned null output",
399                        cell_name
400                    )));
401                }
402
403                // Safety: The cell allocated this memory via libc malloc
404                // Use RAII guard to ensure cleanup even if processing panics
405                let memory_guard = unsafe { FfiMemoryGuard::new(out_ptr) };
406                let bytes = memory_guard.as_slice(out_len).to_vec();
407                // Guard's Drop will free the memory automatically
408
409                // Parse output format:
410                // display_len (8) | display_bytes (N) | widgets_len (8) | widgets_json (M) | rkyv_data
411
412                if bytes.len() < 16 {
413                    return Err(Error::Execution(format!(
414                        "Cell {} output too short: {} bytes",
415                        cell_name, bytes.len()
416                    )));
417                }
418
419                // Read display_len
420                let display_len_bytes: [u8; 8] = bytes[0..8].try_into().map_err(|_| {
421                    Error::Execution(format!(
422                        "Cell {} output has malformed display_len field",
423                        cell_name
424                    ))
425                })?;
426                let display_len = u64::from_le_bytes(display_len_bytes) as usize;
427                let display_end = 8 + display_len;
428
429                if bytes.len() < display_end + 8 {
430                    return Err(Error::Execution(format!(
431                        "Cell {} output too short for display data",
432                        cell_name
433                    )));
434                }
435
436                // Read widgets_len
437                let widgets_len_bytes: [u8; 8] = bytes[display_end..display_end + 8].try_into().map_err(|_| {
438                    Error::Execution(format!(
439                        "Cell {} output has malformed widgets_len field",
440                        cell_name
441                    ))
442                })?;
443                let widgets_len = u64::from_le_bytes(widgets_len_bytes) as usize;
444                let widgets_end = display_end + 8 + widgets_len;
445
446                if bytes.len() < widgets_end {
447                    return Err(Error::Execution(format!(
448                        "Cell {} output too short for widget data",
449                        cell_name
450                    )));
451                }
452
453                // Format is: display_len | display_bytes | widgets_len | widgets_json | rkyv_data
454                let display_text = String::from_utf8_lossy(&bytes[8..display_end]).to_string();
455                let rkyv_data = bytes[widgets_end..].to_vec();
456
457                Ok(BoxedOutput::from_raw_bytes_with_display(rkyv_data, display_text))
458            }
459            ExecutionResult::DeserializationError => {
460                warn!(
461                    cell = %cell_name,
462                    "Cell failed to deserialize input - likely type mismatch. Enable RUST_LOG=debug for details."
463                );
464                Err(Error::Execution(format!(
465                    "Cell {} failed to deserialize input - check dependency types match parameter types. Run with RUST_LOG=venus=debug for details.",
466                    cell_name
467                )))
468            }
469            ExecutionResult::CellError => Err(Error::Execution(format!(
470                "Cell {} returned an error",
471                cell_name
472            ))),
473            ExecutionResult::SerializationError => Err(Error::Execution(format!(
474                "Cell {} failed to serialize output",
475                cell_name
476            ))),
477            ExecutionResult::Panic => Err(Error::Execution(format!(
478                "Cell {} panicked during execution. Check for unwrap() on None/Err, out-of-bounds access, or other panic sources.",
479                cell_name
480            ))),
481        }
482    }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an executor backed by a fresh temporary state directory.
    ///
    /// The `TempDir` is returned as well so that the directory outlives the
    /// executor for the duration of the test.
    fn fresh_executor() -> (tempfile::TempDir, LinearExecutor) {
        let dir = tempfile::TempDir::new().unwrap();
        let executor = LinearExecutor::new(dir.path()).unwrap();
        (dir, executor)
    }

    /// A new executor has no cells, callback, or abort handle.
    #[test]
    fn test_linear_executor_creation() {
        let (_dir, executor) = fresh_executor();
        assert!(executor.cells.is_empty());
        assert!(executor.callback.is_none());
        assert!(executor.abort_handle.is_none());
    }

    /// An executor built from an existing state manager has no cells.
    #[test]
    fn test_with_state_creation() {
        let dir = tempfile::TempDir::new().unwrap();
        let manager = StateManager::new(dir.path()).unwrap();
        let executor = LinearExecutor::with_state(manager);
        assert!(executor.cells.is_empty());
    }

    /// Installing a callback makes it visible on the executor.
    #[test]
    fn test_set_callback() {
        struct TestCallback;
        impl ExecutionCallback for TestCallback {
            fn on_cell_started(&self, _: CellId, _: &str) {}
            fn on_cell_completed(&self, _: CellId, _: &str) {}
            fn on_cell_error(&self, _: CellId, _: &str, _: &Error) {}
            fn on_level_started(&self, _: usize, _: usize) {}
            fn on_level_completed(&self, _: usize) {}
        }

        let (_dir, mut executor) = fresh_executor();
        executor.set_callback(TestCallback);
        assert!(executor.callback.is_some());
    }

    /// Aborting through a clone of the handle is observed by the executor.
    #[test]
    fn test_abort_handle() {
        let (_dir, mut executor) = fresh_executor();

        let handle = AbortHandle::new();
        executor.set_abort_handle(handle.clone());

        assert!(executor.abort_handle().is_some());
        assert!(!executor.is_aborted());

        handle.abort();
        assert!(executor.is_aborted());
    }

    /// A cell that was never loaded is reported as not loaded.
    #[test]
    fn test_is_loaded() {
        let (_dir, executor) = fresh_executor();

        let cell_id = CellId::new(1);
        assert!(!executor.is_loaded(cell_id));

        // Note: We can't actually load without a real dylib, but we test the interface
        assert!(!executor.is_loaded(cell_id));
    }

    /// The state manager of a new executor reports no cached outputs.
    #[test]
    fn test_get_state_reference() {
        let (_dir, executor) = fresh_executor();

        let stats = executor.state().stats();
        assert_eq!(stats.cached_outputs, 0);
    }

    /// Executing an empty order succeeds.
    #[test]
    fn test_execute_in_order_empty() {
        let (_dir, mut executor) = fresh_executor();

        let no_cells: Vec<CellId> = Vec::new();
        let no_deps = std::collections::HashMap::new();

        assert!(executor.execute_in_order(&no_cells, &no_deps).is_ok());
    }

    /// Executing an unknown cell yields `CellNotFound`.
    #[test]
    fn test_execute_cell_not_found() {
        let (_dir, mut executor) = fresh_executor();

        let result = executor.execute_cell(CellId::new(999), &[]);

        assert!(result.is_err());
        let Err(Error::CellNotFound(msg)) = result else {
            panic!("Expected CellNotFound error");
        };
        assert!(msg.contains("not loaded"));
    }

    /// An abort requested up front is reported before the cell lookup.
    #[test]
    fn test_execute_cell_aborted() {
        let (_dir, mut executor) = fresh_executor();

        let handle = AbortHandle::new();
        executor.set_abort_handle(handle.clone());
        handle.abort();

        let result = executor.execute_cell(CellId::new(1), &[]);

        assert!(matches!(result, Err(Error::Aborted)));
    }
}
605}