1use std::collections::HashMap;
7use std::path::Path;
8use std::sync::{Arc, Mutex};
9
10use crate::compile::CompiledCell;
11use crate::error::{Error, Result};
12use crate::graph::CellId;
13use crate::ipc::{WorkerKillHandle, WorkerPool};
14use crate::state::{BoxedOutput, StateManager};
15
16use super::context::{AbortHandle, ExecutionCallback};
17
/// Executes registered, compiled cells on workers drawn from a [`WorkerPool`] and
/// stores their outputs in a [`StateManager`].
pub struct ProcessExecutor {
    /// Registered cells keyed by id, each with the dependency count passed to the
    /// worker when the cell is loaded.
    cells: HashMap<CellId, CompiledCellInfo>,
    /// Holds cell outputs; `execute_and_store` writes here and `execute_in_order`
    /// reads dependency inputs from here.
    state: StateManager,
    /// Optional observer notified when a cell starts, completes, or fails.
    callback: Option<Box<dyn ExecutionCallback>>,
    /// Optional cooperative abort flag, checked before and around each execution.
    abort_handle: Option<AbortHandle>,
    /// Source of workers; presumably out-of-process workers given the `crate::ipc`
    /// origin — TODO confirm.
    worker_pool: WorkerPool,
    /// Kill handle of the worker currently running a cell (`None` when idle).
    /// Shared via `Arc<Mutex<..>>` with [`ExecutorKillHandle`] clones so the running
    /// worker can be killed without `&mut` access to the executor.
    current_worker_kill: Arc<Mutex<Option<WorkerKillHandle>>>,
}
41
/// Bookkeeping for one registered cell.
struct CompiledCellInfo {
    /// The compiled artifact (dylib path, entry symbol, name) loaded into a worker.
    compiled: CompiledCell,
    /// Number of dependencies; forwarded to `load_cell` alongside the artifact.
    dep_count: usize,
}
47
/// Cloneable handle that can kill whichever worker a [`ProcessExecutor`] is
/// currently running, without borrowing the executor.
///
/// It shares the executor's current-worker slot, so it always targets the worker
/// in use at the moment `kill` is called (or does nothing when idle).
#[derive(Clone)]
pub struct ExecutorKillHandle {
    /// Same slot as `ProcessExecutor::current_worker_kill`.
    inner: Arc<Mutex<Option<WorkerKillHandle>>>,
}
56
57impl ExecutorKillHandle {
58 pub fn kill(&self) {
62 if let Ok(guard) = self.inner.lock() {
63 if let Some(ref kill_handle) = *guard {
64 kill_handle.kill();
65 }
66 }
67 }
68}
69
70impl ProcessExecutor {
71 pub fn new(state_dir: impl AsRef<Path>) -> Result<Self> {
73 Ok(Self {
74 cells: HashMap::new(),
75 state: StateManager::new(state_dir)?,
76 callback: None,
77 abort_handle: None,
78 worker_pool: WorkerPool::new(4), current_worker_kill: Arc::new(Mutex::new(None)),
80 })
81 }
82
83 pub fn with_state(state: StateManager) -> Self {
85 Self {
86 cells: HashMap::new(),
87 state,
88 callback: None,
89 abort_handle: None,
90 worker_pool: WorkerPool::new(4),
91 current_worker_kill: Arc::new(Mutex::new(None)),
92 }
93 }
94
95 pub fn with_warm_pool(state_dir: impl AsRef<Path>, pool_size: usize) -> Result<Self> {
97 Ok(Self {
98 cells: HashMap::new(),
99 state: StateManager::new(state_dir)?,
100 callback: None,
101 abort_handle: None,
102 worker_pool: WorkerPool::with_warm_workers(pool_size, pool_size.min(2))?,
103 current_worker_kill: Arc::new(Mutex::new(None)),
104 })
105 }
106
107 pub fn set_callback(&mut self, callback: impl ExecutionCallback + 'static) {
109 self.callback = Some(Box::new(callback));
110 }
111
112 pub fn set_abort_handle(&mut self, handle: AbortHandle) {
114 self.abort_handle = Some(handle);
115 }
116
117 pub fn abort_handle(&self) -> Option<&AbortHandle> {
119 self.abort_handle.as_ref()
120 }
121
122 fn is_aborted(&self) -> bool {
124 self.abort_handle
125 .as_ref()
126 .is_some_and(|h| h.is_aborted())
127 }
128
129 pub fn register_cell(&mut self, compiled: CompiledCell, dep_count: usize) {
134 let cell_id = compiled.cell_id;
135 self.cells.insert(cell_id, CompiledCellInfo {
136 compiled,
137 dep_count,
138 });
139 }
140
141 pub fn unregister_cell(&mut self, cell_id: CellId) -> Option<CompiledCell> {
143 self.cells.remove(&cell_id).map(|info| info.compiled)
144 }
145
146 pub fn is_registered(&self, cell_id: CellId) -> bool {
148 self.cells.contains_key(&cell_id)
149 }
150
151 pub fn execute_cell(
155 &mut self,
156 cell_id: CellId,
157 inputs: &[Arc<BoxedOutput>],
158 ) -> Result<BoxedOutput> {
159 self.execute_cell_with_widgets(cell_id, inputs, Vec::new())
160 .map(|(output, _widgets_json)| output)
161 }
162
163 pub fn execute_cell_with_widgets(
168 &mut self,
169 cell_id: CellId,
170 inputs: &[Arc<BoxedOutput>],
171 widget_values_json: Vec<u8>,
172 ) -> Result<(BoxedOutput, Vec<u8>)> {
173 if self.is_aborted() {
175 return Err(Error::Aborted);
176 }
177
178 let info = self
179 .cells
180 .get(&cell_id)
181 .ok_or_else(|| Error::CellNotFound(format!("Cell {:?} not registered", cell_id)))?;
182
183 let compiled = &info.compiled;
184 let dep_count = info.dep_count;
185
186 if let Some(ref callback) = self.callback {
188 callback.on_cell_started(cell_id, &compiled.name);
189 }
190
191 let mut worker = self.worker_pool.get()?;
193
194 {
196 let mut kill_guard = self.current_worker_kill.lock().unwrap();
197 *kill_guard = Some(WorkerKillHandle::new(&worker));
198 }
199
200 worker.load_cell(
202 compiled.dylib_path.clone(),
203 dep_count,
204 compiled.entry_symbol.clone(),
205 compiled.name.clone(),
206 )?;
207
208 let input_bytes: Vec<Vec<u8>> = inputs
210 .iter()
211 .map(|output| output.bytes().to_vec())
212 .collect();
213
214 if self.is_aborted() {
216 let _ = worker.kill();
218 {
219 let mut kill_guard = self.current_worker_kill.lock().unwrap();
220 *kill_guard = None;
221 }
222 if let Some(ref callback) = self.callback {
223 callback.on_cell_error(cell_id, &compiled.name, &Error::Aborted);
224 }
225 return Err(Error::Aborted);
226 }
227
228 let result = worker.execute_with_widgets(input_bytes, widget_values_json);
230
231 {
233 let mut kill_guard = self.current_worker_kill.lock().unwrap();
234 *kill_guard = None;
235 }
236
237 self.worker_pool.put(worker);
239
240 if self.is_aborted() {
242 if let Some(ref callback) = self.callback {
243 callback.on_cell_error(cell_id, &compiled.name, &Error::Aborted);
244 }
245 return Err(Error::Aborted);
246 }
247
248 match result {
250 Ok((bytes, widgets_json)) => {
251 let output = self.parse_output_bytes(&bytes, &compiled.name)?;
253
254 if let Some(ref callback) = self.callback {
255 callback.on_cell_completed(cell_id, &compiled.name);
256 }
257
258 Ok((output, widgets_json))
259 }
260 Err(e) => {
261 if let Some(ref callback) = self.callback {
262 callback.on_cell_error(cell_id, &compiled.name, &e);
263 }
264 Err(e)
265 }
266 }
267 }
268
269 fn parse_output_bytes(&self, bytes: &[u8], cell_name: &str) -> Result<BoxedOutput> {
278 if bytes.len() < 16 {
279 return Err(Error::Execution(format!(
280 "Cell {} output too short: {} bytes",
281 cell_name,
282 bytes.len()
283 )));
284 }
285
286 let display_len = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
288 let display_end = 8 + display_len;
289
290 if bytes.len() < display_end {
291 return Err(Error::Execution(format!(
292 "Cell {} output too short for display data",
293 cell_name
294 )));
295 }
296
297 let display_text = String::from_utf8_lossy(&bytes[8..display_end]).to_string();
300 let rkyv_data = bytes[display_end..].to_vec();
301
302 Ok(BoxedOutput::from_raw_bytes_with_display(rkyv_data, display_text))
303 }
304
305 pub fn execute_and_store(
307 &mut self,
308 cell_id: CellId,
309 inputs: &[Arc<BoxedOutput>],
310 ) -> Result<()> {
311 let output = self.execute_cell(cell_id, inputs)?;
312 self.state.store_output(cell_id, output);
313 Ok(())
314 }
315
316 pub fn execute_in_order(
318 &mut self,
319 order: &[CellId],
320 deps: &HashMap<CellId, Vec<CellId>>,
321 ) -> Result<()> {
322 for &cell_id in order {
323 if self.is_aborted() {
325 return Err(Error::Aborted);
326 }
327
328 let dep_ids = deps.get(&cell_id).cloned().unwrap_or_default();
330 let inputs: Vec<Arc<BoxedOutput>> = dep_ids
331 .iter()
332 .filter_map(|&dep_id| self.state.get_output(dep_id))
333 .collect();
334
335 if inputs.len() != dep_ids.len() {
337 return Err(Error::Execution(format!(
338 "Missing dependencies for cell {:?}: expected {}, got {}",
339 cell_id,
340 dep_ids.len(),
341 inputs.len()
342 )));
343 }
344
345 self.execute_and_store(cell_id, &inputs)?;
346 }
347
348 Ok(())
349 }
350
351 pub fn kill_current(&self) {
357 if let Ok(guard) = self.current_worker_kill.lock() {
358 if let Some(ref kill_handle) = *guard {
359 kill_handle.kill();
360 }
361 }
362 }
363
364 pub fn get_kill_handle(&self) -> Option<ExecutorKillHandle> {
369 Some(ExecutorKillHandle {
370 inner: self.current_worker_kill.clone(),
371 })
372 }
373
374 pub fn abort(&mut self) {
378 if let Some(ref handle) = self.abort_handle {
379 handle.abort();
380 }
381 self.kill_current();
382 }
383
384 pub fn state(&self) -> &StateManager {
386 &self.state
387 }
388
389 pub fn state_mut(&mut self) -> &mut StateManager {
391 &mut self.state
392 }
393
394 pub fn shutdown(&mut self) {
396 self.worker_pool.shutdown();
397 }
398}
399
400impl Drop for ProcessExecutor {
401 fn drop(&mut self) {
402 self.shutdown();
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed executor has no registered cells.
    #[test]
    fn test_process_executor_creation() {
        let state_dir = tempfile::TempDir::new().unwrap();
        let exec = ProcessExecutor::new(state_dir.path()).unwrap();
        assert_eq!(exec.cells.len(), 0);
    }

    /// A warm pool of two reports both workers as available.
    #[test]
    #[ignore = "Requires venus-worker binary"]
    fn test_process_executor_worker_pool() {
        let state_dir = tempfile::TempDir::new().unwrap();
        let exec = ProcessExecutor::with_warm_pool(state_dir.path(), 2).unwrap();
        let available = exec.worker_pool.available_count();
        assert_eq!(available, 2);
    }
}