1use std::collections::HashMap;
6use std::path::Path;
7use std::sync::Arc;
8
9use libloading::Symbol;
10use tracing::{debug, trace, warn};
11
12use crate::compile::CompiledCell;
13use crate::error::{Error, Result};
14use crate::graph::CellId;
15use crate::state::{BoxedOutput, StateManager};
16
17use super::context::{AbortHandle, ExecutionCallback};
18use super::ffi::{
19 EntryFn0, EntryFn1, EntryFn2, EntryFn3, EntryFn4, EntryFn5, EntryFn6, EntryFn7, EntryFn8,
20 ExecutionResult, call_cell_n_deps,
21};
22use super::loaded_cell::LoadedCell;
23
/// Owns a buffer handed back by a cell's FFI entry point and releases it with
/// `libc::free` when dropped (a null pointer is skipped).
struct FfiMemoryGuard {
    /// Buffer written by the cell through its `out_ptr` argument; may be null.
    ptr: *mut u8,
}
29
30impl FfiMemoryGuard {
31 unsafe fn new(ptr: *mut u8) -> Self {
32 Self { ptr }
33 }
34
35 fn as_slice(&self, len: usize) -> &[u8] {
36 unsafe { std::slice::from_raw_parts(self.ptr, len) }
37 }
38}
39
40impl Drop for FfiMemoryGuard {
41 fn drop(&mut self) {
42 if !self.ptr.is_null() {
43 unsafe {
44 libc::free(self.ptr as *mut libc::c_void);
45 }
46 }
47 }
48}
49
/// Runs loaded cells one at a time on the calling thread. Each cell is fed the
/// stored outputs of its dependencies, and its own output is stored in turn.
pub struct LinearExecutor {
    /// Loaded cell libraries, keyed by cell id.
    cells: HashMap<CellId, LoadedCell>,
    /// Cell outputs: read to build dependents' inputs, written after each success.
    state: StateManager,
    /// Optional observer notified when a cell starts, completes, or fails.
    callback: Option<Box<dyn ExecutionCallback>>,
    /// Optional cancellation flag, checked before and after each cell runs.
    abort_handle: Option<AbortHandle>,
}
61
62impl LinearExecutor {
63 pub fn new(state_dir: impl AsRef<Path>) -> Result<Self> {
65 Ok(Self {
66 cells: HashMap::new(),
67 state: StateManager::new(state_dir)?,
68 callback: None,
69 abort_handle: None,
70 })
71 }
72
73 pub fn with_state(state: StateManager) -> Self {
75 Self {
76 cells: HashMap::new(),
77 state,
78 callback: None,
79 abort_handle: None,
80 }
81 }
82
83 pub fn set_callback(&mut self, callback: impl ExecutionCallback + 'static) {
85 self.callback = Some(Box::new(callback));
86 }
87
88 pub fn set_abort_handle(&mut self, handle: AbortHandle) {
90 self.abort_handle = Some(handle);
91 }
92
93 pub fn abort_handle(&self) -> Option<&AbortHandle> {
95 self.abort_handle.as_ref()
96 }
97
98 fn is_aborted(&self) -> bool {
100 self.abort_handle
101 .as_ref()
102 .is_some_and(|h| h.is_aborted())
103 }
104
105 pub fn load_cell(&mut self, compiled: CompiledCell, dep_count: usize) -> Result<()> {
107 let cell_id = compiled.cell_id;
108 let loaded = LoadedCell::load(compiled, dep_count)?;
109 self.cells.insert(cell_id, loaded);
110 Ok(())
111 }
112
113 pub fn unload_cell(&mut self, cell_id: CellId) -> Option<LoadedCell> {
115 self.cells.remove(&cell_id)
116 }
117
118 pub fn restore_cell(&mut self, cell: LoadedCell) {
120 self.cells.insert(cell.compiled.cell_id, cell);
121 }
122
123 pub fn is_loaded(&self, cell_id: CellId) -> bool {
125 self.cells.contains_key(&cell_id)
126 }
127
128 pub fn execute_cell(
133 &mut self,
134 cell_id: CellId,
135 inputs: &[Arc<BoxedOutput>],
136 ) -> Result<BoxedOutput> {
137 if self.is_aborted() {
139 return Err(Error::Aborted);
140 }
141
142 let loaded = self
143 .cells
144 .get(&cell_id)
145 .ok_or_else(|| Error::CellNotFound(format!("Cell {:?} not loaded", cell_id)))?;
146
147 if let Some(ref callback) = self.callback {
149 callback.on_cell_started(cell_id, &loaded.compiled.name);
150 }
151
152 let result = self.call_cell_ffi(loaded, inputs);
154
155 if self.is_aborted() {
157 if let Some(ref callback) = self.callback {
158 callback.on_cell_error(cell_id, &loaded.compiled.name, &Error::Aborted);
159 }
160 return Err(Error::Aborted);
161 }
162
163 match &result {
165 Ok(_) => {
166 if let Some(ref callback) = self.callback {
167 callback.on_cell_completed(cell_id, &loaded.compiled.name);
168 }
169 }
170 Err(e) => {
171 if let Some(ref callback) = self.callback {
172 callback.on_cell_error(cell_id, &loaded.compiled.name, e);
173 }
174 }
175 }
176
177 result
178 }
179
180 pub fn execute_and_store(
182 &mut self,
183 cell_id: CellId,
184 inputs: &[Arc<BoxedOutput>],
185 ) -> Result<()> {
186 let output = self.execute_cell(cell_id, inputs)?;
187 self.state.store_output(cell_id, output);
188 Ok(())
189 }
190
191 pub fn execute_in_order(
195 &mut self,
196 order: &[CellId],
197 deps: &HashMap<CellId, Vec<CellId>>,
198 ) -> Result<()> {
199 for &cell_id in order {
200 if self.is_aborted() {
202 return Err(Error::Aborted);
203 }
204
205 let dep_ids = deps.get(&cell_id).cloned().unwrap_or_default();
207 let inputs: Vec<Arc<BoxedOutput>> = dep_ids
208 .iter()
209 .filter_map(|&dep_id| self.state.get_output(dep_id))
210 .collect();
211
212 if inputs.len() != dep_ids.len() {
214 return Err(Error::Execution(format!(
215 "Missing dependencies for cell {:?}: expected {}, got {}",
216 cell_id,
217 dep_ids.len(),
218 inputs.len()
219 )));
220 }
221
222 self.execute_and_store(cell_id, &inputs)?;
223 }
224
225 Ok(())
226 }
227
228 pub fn state(&self) -> &StateManager {
230 &self.state
231 }
232
233 pub fn state_mut(&mut self) -> &mut StateManager {
235 &mut self.state
236 }
237
238 fn call_cell_ffi(
240 &self,
241 loaded: &LoadedCell,
242 inputs: &[Arc<BoxedOutput>],
243 ) -> Result<BoxedOutput> {
244 if inputs.len() != loaded.dep_count {
246 return Err(Error::Execution(format!(
247 "Cell {} expects {} inputs, got {}",
248 loaded.compiled.name,
249 loaded.dep_count,
250 inputs.len()
251 )));
252 }
253
254 if loaded.dep_count == 0 {
256 return self.call_cell_no_deps(loaded);
257 }
258
259 self.call_cell_with_deps(loaded, inputs)
262 }
263
264 fn call_cell_no_deps(&self, loaded: &LoadedCell) -> Result<BoxedOutput> {
266 let symbol_name = loaded.entry_symbol();
267
268 let func: Symbol<EntryFn0> = unsafe { loaded.library.get(symbol_name.as_bytes()) }
270 .map_err(|e| {
271 Error::Execution(format!("Failed to get symbol {}: {}", symbol_name, e))
272 })?;
273
274 let mut out_ptr: *mut u8 = std::ptr::null_mut();
275 let mut out_len: usize = 0;
276
277 let widget_values: &[u8] = &[];
279
280 let result_code = unsafe {
282 func(
283 widget_values.as_ptr(), widget_values.len(),
284 &mut out_ptr, &mut out_len,
285 )
286 };
287
288 self.process_ffi_result(result_code, out_ptr, out_len, &loaded.compiled.name)
289 }
290
291 fn call_cell_with_deps(
296 &self,
297 loaded: &LoadedCell,
298 inputs: &[Arc<BoxedOutput>],
299 ) -> Result<BoxedOutput> {
300 let symbol_name = loaded.entry_symbol();
301
302 let widget_values: &[u8] = &[];
304
305 debug!(
306 cell = %loaded.compiled.name,
307 dep_count = inputs.len(),
308 "Calling FFI entry point"
309 );
310 trace!(
311 cell = %loaded.compiled.name,
312 input_sizes = ?inputs.iter().map(|i| i.bytes().len()).collect::<Vec<_>>(),
313 "Input buffer sizes (bytes)"
314 );
315
316 match inputs.len() {
320 1 => call_cell_n_deps!(self, loaded, symbol_name, inputs, widget_values, EntryFn1, 0),
321 2 => call_cell_n_deps!(self, loaded, symbol_name, inputs, widget_values, EntryFn2, 0, 1),
322 3 => call_cell_n_deps!(self, loaded, symbol_name, inputs, widget_values, EntryFn3, 0, 1, 2),
323 4 => call_cell_n_deps!(self, loaded, symbol_name, inputs, widget_values, EntryFn4, 0, 1, 2, 3),
324 5 => call_cell_n_deps!(self, loaded, symbol_name, inputs, widget_values, EntryFn5, 0, 1, 2, 3, 4),
325 6 => call_cell_n_deps!(
326 self,
327 loaded,
328 symbol_name,
329 inputs,
330 widget_values,
331 EntryFn6,
332 0,
333 1,
334 2,
335 3,
336 4,
337 5
338 ),
339 7 => call_cell_n_deps!(
340 self,
341 loaded,
342 symbol_name,
343 inputs,
344 widget_values,
345 EntryFn7,
346 0,
347 1,
348 2,
349 3,
350 4,
351 5,
352 6
353 ),
354 8 => call_cell_n_deps!(
355 self,
356 loaded,
357 symbol_name,
358 inputs,
359 widget_values,
360 EntryFn8,
361 0,
362 1,
363 2,
364 3,
365 4,
366 5,
367 6,
368 7
369 ),
370 n => Err(Error::Execution(format!(
371 "Cells with {} dependencies not yet supported (max 8)",
372 n
373 ))),
374 }
375 }
376
377 pub(crate) fn process_ffi_result(
386 &self,
387 result_code: i32,
388 out_ptr: *mut u8,
389 out_len: usize,
390 cell_name: &str,
391 ) -> Result<BoxedOutput> {
392 let result = ExecutionResult::from(result_code);
393
394 match result {
395 ExecutionResult::Success => {
396 if out_ptr.is_null() || out_len == 0 {
397 return Err(Error::Execution(format!(
398 "Cell {} returned null output",
399 cell_name
400 )));
401 }
402
403 let memory_guard = unsafe { FfiMemoryGuard::new(out_ptr) };
406 let bytes = memory_guard.as_slice(out_len).to_vec();
407 if bytes.len() < 16 {
413 return Err(Error::Execution(format!(
414 "Cell {} output too short: {} bytes",
415 cell_name, bytes.len()
416 )));
417 }
418
419 let display_len_bytes: [u8; 8] = bytes[0..8].try_into().map_err(|_| {
421 Error::Execution(format!(
422 "Cell {} output has malformed display_len field",
423 cell_name
424 ))
425 })?;
426 let display_len = u64::from_le_bytes(display_len_bytes) as usize;
427 let display_end = 8 + display_len;
428
429 if bytes.len() < display_end + 8 {
430 return Err(Error::Execution(format!(
431 "Cell {} output too short for display data",
432 cell_name
433 )));
434 }
435
436 let widgets_len_bytes: [u8; 8] = bytes[display_end..display_end + 8].try_into().map_err(|_| {
438 Error::Execution(format!(
439 "Cell {} output has malformed widgets_len field",
440 cell_name
441 ))
442 })?;
443 let widgets_len = u64::from_le_bytes(widgets_len_bytes) as usize;
444 let widgets_end = display_end + 8 + widgets_len;
445
446 if bytes.len() < widgets_end {
447 return Err(Error::Execution(format!(
448 "Cell {} output too short for widget data",
449 cell_name
450 )));
451 }
452
453 let display_text = String::from_utf8_lossy(&bytes[8..display_end]).to_string();
455 let rkyv_data = bytes[widgets_end..].to_vec();
456
457 Ok(BoxedOutput::from_raw_bytes_with_display(rkyv_data, display_text))
458 }
459 ExecutionResult::DeserializationError => {
460 warn!(
461 cell = %cell_name,
462 "Cell failed to deserialize input - likely type mismatch. Enable RUST_LOG=debug for details."
463 );
464 Err(Error::Execution(format!(
465 "Cell {} failed to deserialize input - check dependency types match parameter types. Run with RUST_LOG=venus=debug for details.",
466 cell_name
467 )))
468 }
469 ExecutionResult::CellError => Err(Error::Execution(format!(
470 "Cell {} returned an error",
471 cell_name
472 ))),
473 ExecutionResult::SerializationError => Err(Error::Execution(format!(
474 "Cell {} failed to serialize output",
475 cell_name
476 ))),
477 ExecutionResult::Panic => Err(Error::Execution(format!(
478 "Cell {} panicked during execution. Check for unwrap() on None/Err, out-of-bounds access, or other panic sources.",
479 cell_name
480 ))),
481 }
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an executor backed by a fresh temporary state directory. The
    /// `TempDir` is returned so the caller keeps it alive for the executor's lifetime.
    fn new_executor() -> (tempfile::TempDir, LinearExecutor) {
        let temp = tempfile::TempDir::new().unwrap();
        let executor = LinearExecutor::new(temp.path()).unwrap();
        (temp, executor)
    }

    #[test]
    fn test_linear_executor_creation() {
        let (_temp, executor) = new_executor();
        assert!(executor.cells.is_empty());
        assert!(executor.callback.is_none());
        assert!(executor.abort_handle.is_none());
    }

    #[test]
    fn test_with_state_creation() {
        let temp = tempfile::TempDir::new().unwrap();
        let state = StateManager::new(temp.path()).unwrap();
        let executor = LinearExecutor::with_state(state);
        assert!(executor.cells.is_empty());
    }

    #[test]
    fn test_set_callback() {
        let (_temp, mut executor) = new_executor();

        struct TestCallback;
        impl ExecutionCallback for TestCallback {
            fn on_cell_started(&self, _: CellId, _: &str) {}
            fn on_cell_completed(&self, _: CellId, _: &str) {}
            fn on_cell_error(&self, _: CellId, _: &str, _: &Error) {}
            fn on_level_started(&self, _: usize, _: usize) {}
            fn on_level_completed(&self, _: usize) {}
        }

        executor.set_callback(TestCallback);
        assert!(executor.callback.is_some());
    }

    #[test]
    fn test_abort_handle() {
        let (_temp, mut executor) = new_executor();

        let handle = AbortHandle::new();
        executor.set_abort_handle(handle.clone());

        assert!(executor.abort_handle().is_some());
        assert!(!executor.is_aborted());

        handle.abort();
        assert!(executor.is_aborted());
    }

    #[test]
    fn test_is_loaded() {
        let (_temp, mut executor) = new_executor();

        let cell_id = CellId::new(1);
        assert!(!executor.is_loaded(cell_id));

        // Unloading a cell that was never loaded is a no-op.
        assert!(executor.unload_cell(cell_id).is_none());
        assert!(!executor.is_loaded(cell_id));
    }

    #[test]
    fn test_get_state_reference() {
        let (_temp, executor) = new_executor();

        let stats = executor.state().stats();
        assert_eq!(stats.cached_outputs, 0);
    }

    #[test]
    fn test_execute_in_order_empty() {
        let (_temp, mut executor) = new_executor();

        let empty_order: Vec<CellId> = vec![];
        let empty_deps = std::collections::HashMap::new();

        let result = executor.execute_in_order(&empty_order, &empty_deps);
        assert!(result.is_ok());
    }

    #[test]
    fn test_execute_in_order_missing_dependency() {
        let (_temp, mut executor) = new_executor();

        let cell = CellId::new(1);
        let missing_dep = CellId::new(2);
        let mut deps = std::collections::HashMap::new();
        deps.insert(cell, vec![missing_dep]);

        let result = executor.execute_in_order(&[cell], &deps);
        assert!(matches!(
            result,
            Err(Error::Execution(ref msg)) if msg.contains("Missing dependencies")
        ));
    }

    #[test]
    fn test_execute_in_order_aborted() {
        let (_temp, mut executor) = new_executor();

        let handle = AbortHandle::new();
        executor.set_abort_handle(handle.clone());
        handle.abort();

        let result =
            executor.execute_in_order(&[CellId::new(1)], &std::collections::HashMap::new());
        assert!(matches!(result, Err(Error::Aborted)));
    }

    #[test]
    fn test_execute_cell_not_found() {
        let (_temp, mut executor) = new_executor();

        let cell_id = CellId::new(999);
        let result = executor.execute_cell(cell_id, &[]);

        assert!(result.is_err());
        match result {
            Err(Error::CellNotFound(msg)) => {
                assert!(msg.contains("not loaded"));
            }
            _ => panic!("Expected CellNotFound error"),
        }
    }

    #[test]
    fn test_execute_cell_aborted() {
        let (_temp, mut executor) = new_executor();

        let handle = AbortHandle::new();
        executor.set_abort_handle(handle.clone());
        handle.abort();

        let cell_id = CellId::new(1);
        let result = executor.execute_cell(cell_id, &[]);

        assert!(matches!(result, Err(Error::Aborted)));
    }
}