venus_core/execute/
parallel.rs

1//! Parallel executor for Venus notebooks.
2//!
3//! Executes cells in parallel based on dependency levels using Rayon.
4
5use std::collections::HashMap;
6use std::path::Path;
7use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
8
9use rayon::prelude::*;
10
11use super::context::ExecutionCallback;
12use super::{LinearExecutor, LoadedCell};
13use crate::compile::CompiledCell;
14use crate::error::{Error, Result};
15use crate::graph::CellId;
16use crate::state::{BoxedOutput, StateManager};
17
/// Parallel executor that runs independent cells concurrently.
///
/// Cells are grouped by dependency level and executed in parallel
/// within each level. Levels are processed sequentially to maintain
/// dependency ordering.
pub struct ParallelExecutor {
    /// Inner linear executor, wrapped for thread-safe access.
    ///
    /// NOTE(review): because every cell execution takes this `Mutex` for its
    /// whole duration (see `execute_single_cell`), cells within a level are
    /// effectively serialized; only level ordering is truly enforced here.
    inner: Arc<Mutex<LinearExecutor>>,
    /// Optional observer notified when levels start/complete and as cells
    /// progress; shared via `Arc` so it can outlive `&self` borrows.
    callback: Option<Arc<dyn ExecutionCallback>>,
}
29
30/// Helper to convert PoisonError to our Error type.
31///
32/// Centralizes lock error handling to eliminate duplication.
33fn lock_error<T>(e: PoisonError<T>) -> Error {
34    Error::Execution(format!("Executor lock poisoned (thread panicked): {}", e))
35}
36
37impl ParallelExecutor {
38    /// Create a new parallel executor.
39    pub fn new(state_dir: impl AsRef<Path>) -> Result<Self> {
40        let inner = LinearExecutor::new(state_dir)?;
41        Ok(Self {
42            inner: Arc::new(Mutex::new(inner)),
43            callback: None,
44        })
45    }
46
47    /// Create with an existing state manager.
48    pub fn with_state(state: StateManager) -> Self {
49        let inner = LinearExecutor::with_state(state);
50        Self {
51            inner: Arc::new(Mutex::new(inner)),
52            callback: None,
53        }
54    }
55
56    /// Acquire the inner executor lock.
57    ///
58    /// Helper method to centralize lock acquisition and error handling.
59    fn acquire_lock(&self) -> Result<MutexGuard<'_, LinearExecutor>> {
60        self.inner.lock().map_err(lock_error)
61    }
62
63    /// Set the execution callback.
64    pub fn set_callback(&mut self, callback: impl ExecutionCallback + 'static) {
65        self.callback = Some(Arc::new(callback));
66    }
67
68    /// Load a compiled cell for execution.
69    pub fn load_cell(&self, compiled: CompiledCell, dep_count: usize) -> Result<()> {
70        self.acquire_lock()?.load_cell(compiled, dep_count)
71    }
72
73    /// Unload a cell.
74    pub fn unload_cell(&self, cell_id: CellId) -> Result<Option<LoadedCell>> {
75        Ok(self.acquire_lock()?.unload_cell(cell_id))
76    }
77
78    /// Execute cells in parallel based on dependency levels.
79    ///
80    /// # Arguments
81    /// * `levels` - Cells grouped by dependency level (earlier levels have no deps on later ones)
82    /// * `deps` - Dependency map: cell_id -> list of dependency cell_ids
83    pub fn execute_parallel(
84        &self,
85        levels: &[Vec<CellId>],
86        deps: &HashMap<CellId, Vec<CellId>>,
87    ) -> Result<()> {
88        for (level_idx, level_cells) in levels.iter().enumerate() {
89            if level_cells.is_empty() {
90                continue;
91            }
92
93            // Notify callback
94            if let Some(ref callback) = self.callback {
95                callback.on_level_started(level_idx, level_cells.len());
96            }
97
98            // Execute all cells in this level in parallel
99            let results: Vec<Result<()>> = level_cells
100                .par_iter()
101                .map(|&cell_id| self.execute_single_cell(cell_id, deps))
102                .collect();
103
104            // Check for errors
105            let errors: Vec<_> = results.into_iter().filter_map(|r| r.err()).collect();
106            if !errors.is_empty() {
107                // Note: Returns first error. Error aggregation could be added in future.
108                return Err(errors.into_iter().next().unwrap());
109            }
110
111            // Notify callback
112            if let Some(ref callback) = self.callback {
113                callback.on_level_completed(level_idx);
114            }
115        }
116
117        Ok(())
118    }
119
120    /// Execute a single cell, gathering its dependencies.
121    ///
122    /// Acquires the lock once for the entire operation to minimize contention.
123    fn execute_single_cell(
124        &self,
125        cell_id: CellId,
126        deps: &HashMap<CellId, Vec<CellId>>,
127    ) -> Result<()> {
128        let dep_ids = deps.get(&cell_id).cloned().unwrap_or_default();
129
130        // Known limitation: Cells within a level execute sequentially (not in parallel)
131        // because execute_cell requires &mut self. This is a correctness-first design.
132        // True intra-level parallelism would require:
133        //   1. Separating read-only FFI calls from state mutations
134        //   2. RwLock instead of Mutex for concurrent reads
135        //   3. Read lock during FFI, exclusive lock only for output storage
136        // Inter-level parallelism (different levels execute in order) is preserved.
137
138        let mut inner = self.acquire_lock()?;
139
140        // Gather dependency outputs
141        let inputs: Vec<Arc<BoxedOutput>> = dep_ids
142            .iter()
143            .filter_map(|&dep_id| inner.state().get_output(dep_id))
144            .collect();
145
146        // Verify we have all inputs
147        if inputs.len() != dep_ids.len() {
148            return Err(Error::Execution(format!(
149                "Missing dependencies for cell {:?}: expected {}, got {}",
150                cell_id,
151                dep_ids.len(),
152                inputs.len()
153            )));
154        }
155
156        // Execute and store output atomically to prevent races
157        let output = inner.execute_cell(cell_id, &inputs)?;
158        inner.state_mut().store_output(cell_id, output);
159
160        Ok(())
161    }
162
163    /// Get access to the inner executor.
164    pub fn inner(&self) -> &Arc<Mutex<LinearExecutor>> {
165        &self.inner
166    }
167
168    /// Flush all cached outputs to disk.
169    pub fn flush(&self) -> Result<()> {
170        self.acquire_lock()?.state_mut().flush()
171    }
172}
173
174// Note: ExecutionStats was removed as unused dead code.
175// If metrics collection is needed in the future, it can be re-added
176// with fields: cells_executed, levels, max_parallelism, total_time_ms.
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execute::AbortHandle;

    #[test]
    fn test_parallel_executor_creation() {
        let dir = tempfile::TempDir::new().unwrap();
        let exec = ParallelExecutor::new(dir.path()).unwrap();

        // A freshly created executor starts with no cached outputs.
        let guard = exec.inner.lock().unwrap();
        assert_eq!(guard.state().stats().cached_outputs, 0);
    }

    #[test]
    fn test_with_state_creation() {
        let dir = tempfile::TempDir::new().unwrap();
        let manager = StateManager::new(dir.path()).unwrap();
        let exec = ParallelExecutor::with_state(manager);

        let guard = exec.inner.lock().unwrap();
        assert_eq!(guard.state().stats().cached_outputs, 0);
    }

    #[test]
    fn test_set_callback() {
        let dir = tempfile::TempDir::new().unwrap();
        let mut exec = ParallelExecutor::new(dir.path()).unwrap();

        struct TestCallback;
        impl ExecutionCallback for TestCallback {
            fn on_cell_started(&self, _: CellId, _: &str) {}
            fn on_cell_completed(&self, _: CellId, _: &str) {}
            fn on_cell_error(&self, _: CellId, _: &str, _: &Error) {}
            fn on_level_started(&self, _: usize, _: usize) {}
            fn on_level_completed(&self, _: usize) {}
        }

        // The callback field is private, so we can only verify that
        // installing one compiles and does not panic.
        exec.set_callback(TestCallback);
    }

    #[test]
    fn test_abort_handle() {
        let dir = tempfile::TempDir::new().unwrap();
        let exec = ParallelExecutor::new(dir.path()).unwrap();

        let handle = AbortHandle::new();

        {
            let mut guard = exec.inner.lock().unwrap();
            guard.set_abort_handle(handle.clone());
            assert!(guard.abort_handle().is_some());
        }

        handle.abort();

        // The abort flag is observable; execution paths check it later.
        assert!(handle.is_aborted());
    }

    #[test]
    fn test_empty_levels() {
        let dir = tempfile::TempDir::new().unwrap();
        let exec = ParallelExecutor::new(dir.path()).unwrap();

        let levels: Vec<Vec<CellId>> = Vec::new();
        let deps = HashMap::new();

        // No levels means no work — this must succeed.
        exec.execute_parallel(&levels, &deps).unwrap();
    }

    #[test]
    fn test_is_loaded() {
        let dir = tempfile::TempDir::new().unwrap();
        let exec = ParallelExecutor::new(dir.path()).unwrap();

        let id = CellId::new(1);

        let guard = exec.inner.lock().unwrap();
        assert!(!guard.is_loaded(id));
    }

    #[test]
    fn test_get_state_reference() {
        let dir = tempfile::TempDir::new().unwrap();
        let exec = ParallelExecutor::new(dir.path()).unwrap();

        let guard = exec.inner.lock().unwrap();
        assert_eq!(guard.state().stats().cached_outputs, 0);
    }

    #[test]
    fn test_execute_parallel_aborted() {
        let dir = tempfile::TempDir::new().unwrap();
        let exec = ParallelExecutor::new(dir.path()).unwrap();

        let handle = AbortHandle::new();
        {
            let mut guard = exec.inner.lock().unwrap();
            guard.set_abort_handle(handle.clone());
        }
        handle.abort();

        let levels = vec![vec![CellId::new(1)]];
        let deps = HashMap::new();

        // An aborted executor must refuse to run and surface Error::Aborted.
        let outcome = exec.execute_parallel(&levels, &deps);
        assert!(matches!(outcome, Err(Error::Aborted)));
    }
}
291}