venus_core/execute/
parallel.rs1use std::collections::HashMap;
6use std::path::Path;
7use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
8
9use rayon::prelude::*;
10
11use super::context::ExecutionCallback;
12use super::{LinearExecutor, LoadedCell};
13use crate::compile::CompiledCell;
14use crate::error::{Error, Result};
15use crate::graph::CellId;
16use crate::state::{BoxedOutput, StateManager};
17
/// Runs cells level by level, dispatching the cells of each level through rayon.
///
/// All real work is delegated to a single [`LinearExecutor`] behind a mutex.
/// Each cell acquires that mutex for the whole of its execution.
pub struct ParallelExecutor {
    /// Shared executor that owns the loaded cells and the state manager.
    inner: Arc<Mutex<LinearExecutor>>,
    /// Optional observer notified when each non-empty level starts and completes.
    callback: Option<Arc<dyn ExecutionCallback>>,
}
29
30fn lock_error<T>(e: PoisonError<T>) -> Error {
34 Error::Execution(format!("Executor lock poisoned (thread panicked): {}", e))
35}
36
37impl ParallelExecutor {
38 pub fn new(state_dir: impl AsRef<Path>) -> Result<Self> {
40 let inner = LinearExecutor::new(state_dir)?;
41 Ok(Self {
42 inner: Arc::new(Mutex::new(inner)),
43 callback: None,
44 })
45 }
46
47 pub fn with_state(state: StateManager) -> Self {
49 let inner = LinearExecutor::with_state(state);
50 Self {
51 inner: Arc::new(Mutex::new(inner)),
52 callback: None,
53 }
54 }
55
56 fn acquire_lock(&self) -> Result<MutexGuard<'_, LinearExecutor>> {
60 self.inner.lock().map_err(lock_error)
61 }
62
63 pub fn set_callback(&mut self, callback: impl ExecutionCallback + 'static) {
65 self.callback = Some(Arc::new(callback));
66 }
67
68 pub fn load_cell(&self, compiled: CompiledCell, dep_count: usize) -> Result<()> {
70 self.acquire_lock()?.load_cell(compiled, dep_count)
71 }
72
73 pub fn unload_cell(&self, cell_id: CellId) -> Result<Option<LoadedCell>> {
75 Ok(self.acquire_lock()?.unload_cell(cell_id))
76 }
77
78 pub fn execute_parallel(
84 &self,
85 levels: &[Vec<CellId>],
86 deps: &HashMap<CellId, Vec<CellId>>,
87 ) -> Result<()> {
88 for (level_idx, level_cells) in levels.iter().enumerate() {
89 if level_cells.is_empty() {
90 continue;
91 }
92
93 if let Some(ref callback) = self.callback {
95 callback.on_level_started(level_idx, level_cells.len());
96 }
97
98 let results: Vec<Result<()>> = level_cells
100 .par_iter()
101 .map(|&cell_id| self.execute_single_cell(cell_id, deps))
102 .collect();
103
104 let errors: Vec<_> = results.into_iter().filter_map(|r| r.err()).collect();
106 if !errors.is_empty() {
107 return Err(errors.into_iter().next().unwrap());
109 }
110
111 if let Some(ref callback) = self.callback {
113 callback.on_level_completed(level_idx);
114 }
115 }
116
117 Ok(())
118 }
119
120 fn execute_single_cell(
124 &self,
125 cell_id: CellId,
126 deps: &HashMap<CellId, Vec<CellId>>,
127 ) -> Result<()> {
128 let dep_ids = deps.get(&cell_id).cloned().unwrap_or_default();
129
130 let mut inner = self.acquire_lock()?;
139
140 let inputs: Vec<Arc<BoxedOutput>> = dep_ids
142 .iter()
143 .filter_map(|&dep_id| inner.state().get_output(dep_id))
144 .collect();
145
146 if inputs.len() != dep_ids.len() {
148 return Err(Error::Execution(format!(
149 "Missing dependencies for cell {:?}: expected {}, got {}",
150 cell_id,
151 dep_ids.len(),
152 inputs.len()
153 )));
154 }
155
156 let output = inner.execute_cell(cell_id, &inputs)?;
158 inner.state_mut().store_output(cell_id, output);
159
160 Ok(())
161 }
162
163 pub fn inner(&self) -> &Arc<Mutex<LinearExecutor>> {
165 &self.inner
166 }
167
168 pub fn flush(&self) -> Result<()> {
170 self.acquire_lock()?.state_mut().flush()
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execute::AbortHandle;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Callback that counts level events so tests can observe which hooks fired.
    #[derive(Clone, Default)]
    struct CountingCallback {
        levels_started: Arc<AtomicUsize>,
        levels_completed: Arc<AtomicUsize>,
    }

    impl ExecutionCallback for CountingCallback {
        fn on_cell_started(&self, _: CellId, _: &str) {}
        fn on_cell_completed(&self, _: CellId, _: &str) {}
        fn on_cell_error(&self, _: CellId, _: &str, _: &Error) {}
        fn on_level_started(&self, _: usize, _: usize) {
            self.levels_started.fetch_add(1, Ordering::SeqCst);
        }
        fn on_level_completed(&self, _: usize) {
            self.levels_completed.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[test]
    fn test_parallel_executor_creation() {
        let temp = tempfile::TempDir::new().unwrap();
        let executor = ParallelExecutor::new(temp.path()).unwrap();

        let inner = executor.inner.lock().unwrap();
        assert_eq!(inner.state().stats().cached_outputs, 0);
    }

    #[test]
    fn test_with_state_creation() {
        let temp = tempfile::TempDir::new().unwrap();
        let state = StateManager::new(temp.path()).unwrap();
        let executor = ParallelExecutor::with_state(state);

        let inner = executor.inner.lock().unwrap();
        assert_eq!(inner.state().stats().cached_outputs, 0);
    }

    #[test]
    fn test_set_callback() {
        let temp = tempfile::TempDir::new().unwrap();
        let mut executor = ParallelExecutor::new(temp.path()).unwrap();

        struct TestCallback;
        impl ExecutionCallback for TestCallback {
            fn on_cell_started(&self, _: CellId, _: &str) {}
            fn on_cell_completed(&self, _: CellId, _: &str) {}
            fn on_cell_error(&self, _: CellId, _: &str, _: &Error) {}
            fn on_level_started(&self, _: usize, _: usize) {}
            fn on_level_completed(&self, _: usize) {}
        }

        assert!(executor.callback.is_none());
        executor.set_callback(TestCallback);
        assert!(executor.callback.is_some());
    }

    #[test]
    fn test_abort_handle() {
        let temp = tempfile::TempDir::new().unwrap();
        let executor = ParallelExecutor::new(temp.path()).unwrap();

        let handle = AbortHandle::new();

        {
            let mut inner = executor.inner.lock().unwrap();
            inner.set_abort_handle(handle.clone());
            assert!(inner.abort_handle().is_some());
        }

        handle.abort();

        assert!(handle.is_aborted());
    }

    #[test]
    fn test_empty_levels() {
        let temp = tempfile::TempDir::new().unwrap();
        let executor = ParallelExecutor::new(temp.path()).unwrap();

        let levels: Vec<Vec<CellId>> = vec![];
        let deps = HashMap::new();

        executor.execute_parallel(&levels, &deps).unwrap();
    }

    #[test]
    fn test_empty_levels_skip_callbacks() {
        let temp = tempfile::TempDir::new().unwrap();
        let mut executor = ParallelExecutor::new(temp.path()).unwrap();

        let counter = CountingCallback::default();
        executor.set_callback(counter.clone());

        let levels: Vec<Vec<CellId>> = vec![vec![], vec![]];
        let deps = HashMap::new();
        executor.execute_parallel(&levels, &deps).unwrap();

        assert_eq!(counter.levels_started.load(Ordering::SeqCst), 0);
        assert_eq!(counter.levels_completed.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_missing_dependency_fails_level() {
        let temp = tempfile::TempDir::new().unwrap();
        let mut executor = ParallelExecutor::new(temp.path()).unwrap();

        let counter = CountingCallback::default();
        executor.set_callback(counter.clone());

        // Cell 2 depends on cell 1, which never produced an output.
        let levels: Vec<Vec<CellId>> = vec![vec![CellId::new(2)]];
        let mut deps = HashMap::new();
        deps.insert(CellId::new(2), vec![CellId::new(1)]);

        match executor.execute_parallel(&levels, &deps) {
            Err(Error::Execution(msg)) => assert!(msg.contains("Missing dependencies")),
            other => panic!("expected missing-dependency error, got {:?}", other),
        }

        assert_eq!(counter.levels_started.load(Ordering::SeqCst), 1);
        assert_eq!(counter.levels_completed.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_is_loaded() {
        let temp = tempfile::TempDir::new().unwrap();
        let executor = ParallelExecutor::new(temp.path()).unwrap();

        let cell_id = CellId::new(1);

        let inner = executor.inner.lock().unwrap();
        assert!(!inner.is_loaded(cell_id));
    }

    #[test]
    fn test_get_state_reference() {
        let temp = tempfile::TempDir::new().unwrap();
        let executor = ParallelExecutor::new(temp.path()).unwrap();

        let inner = executor.inner.lock().unwrap();
        let stats = inner.state().stats();
        assert_eq!(stats.cached_outputs, 0);
    }

    #[test]
    fn test_execute_parallel_aborted() {
        let temp = tempfile::TempDir::new().unwrap();
        let executor = ParallelExecutor::new(temp.path()).unwrap();

        let handle = AbortHandle::new();
        {
            let mut inner = executor.inner.lock().unwrap();
            inner.set_abort_handle(handle.clone());
        }
        handle.abort();

        let levels: Vec<Vec<CellId>> = vec![vec![CellId::new(1)]];
        let deps = HashMap::new();

        let result = executor.execute_parallel(&levels, &deps);
        assert!(matches!(result, Err(Error::Aborted)));
    }
}