Skip to main content

shape_vm/executor/
task_scheduler.rs

//! Task scheduler for the async host runtime.
//!
//! Manages spawned async tasks: stores their callables, tracks completion,
//! and executes them (synchronously for now) when the VM suspends on an await.
//! It also tracks externally-completed tasks (e.g. remote calls), whose results
//! are delivered by background Tokio tasks over oneshot channels.
//!
//! The initial design runs inline tasks at await-time (synchronous execution).
//! True concurrent execution via Tokio can be layered on later by changing
//! `resolve_task` to spawn on the Tokio runtime.
9
10use std::collections::HashMap;
11
12use shape_value::{VMError, ValueWord};
13
14/// Completion status of a spawned task.
15#[derive(Debug, Clone)]
16pub enum TaskStatus {
17    /// Task has been spawned but not yet executed.
18    Pending,
19    /// Task finished successfully with a result value.
20    Completed(ValueWord),
21    /// Task was cancelled before completion.
22    Cancelled,
23}
24
25/// Scheduler that tracks spawned async tasks by their future ID.
26///
27/// The VM's `SpawnTask` opcode registers a callable here. When the VM later
28/// suspends on `WaitType::Future { id }`, the host looks up the callable,
29/// executes it, and stores the result so the VM can resume.
30///
31/// Supports both inline tasks (callable executed synchronously at await-time)
32/// and external tasks (completed by background Tokio tasks via oneshot channels).
33pub struct TaskScheduler {
34    /// Map from task_id to the callable value (Closure or Function) that
35    /// was passed to `spawn`. Consumed on first execution.
36    callables: HashMap<u64, ValueWord>,
37
38    /// Map from task_id to its completion status.
39    results: HashMap<u64, TaskStatus>,
40
41    /// External completion channels — Tokio background tasks send results here.
42    /// Used for remote calls and other externally-completed futures.
43    external_receivers: HashMap<u64, tokio::sync::oneshot::Receiver<Result<ValueWord, String>>>,
44}
45
46impl TaskScheduler {
47    /// Create a new, empty scheduler.
48    pub fn new() -> Self {
49        Self {
50            callables: HashMap::new(),
51            results: HashMap::new(),
52            external_receivers: HashMap::new(),
53        }
54    }
55
56    /// Register a callable for a given task_id.
57    ///
58    /// Called by `op_spawn_task` when a new task is spawned.
59    pub fn register(&mut self, task_id: u64, callable: ValueWord) {
60        self.callables.insert(task_id, callable);
61        self.results.insert(task_id, TaskStatus::Pending);
62    }
63
64    /// Take (remove) the callable for `task_id` so it can be executed.
65    ///
66    /// Returns `None` if the task was already consumed or never registered.
67    pub fn take_callable(&mut self, task_id: u64) -> Option<ValueWord> {
68        self.callables.remove(&task_id)
69    }
70
71    /// Record a completed result for a task.
72    pub fn complete(&mut self, task_id: u64, value: ValueWord) {
73        self.results.insert(task_id, TaskStatus::Completed(value));
74    }
75
76    /// Mark a task as cancelled.
77    pub fn cancel(&mut self, task_id: u64) {
78        // Only cancel if still pending
79        if let Some(TaskStatus::Pending) = self.results.get(&task_id) {
80            self.results.insert(task_id, TaskStatus::Cancelled);
81            self.callables.remove(&task_id);
82        }
83    }
84
85    /// Get the result for a task, if it has completed.
86    pub fn get_result(&self, task_id: u64) -> Option<&TaskStatus> {
87        self.results.get(&task_id)
88    }
89
90    /// Check whether a task has a stored result (completed or cancelled).
91    pub fn is_resolved(&self, task_id: u64) -> bool {
92        matches!(
93            self.results.get(&task_id),
94            Some(TaskStatus::Completed(_)) | Some(TaskStatus::Cancelled)
95        )
96    }
97
98    /// Register an externally-completed task (e.g., remote call).
99    ///
100    /// Returns a `oneshot::Sender` that the background task uses to deliver the
101    /// result. The scheduler marks the task as Pending and stores the receiver.
102    pub fn register_external(
103        &mut self,
104        task_id: u64,
105    ) -> tokio::sync::oneshot::Sender<Result<ValueWord, String>> {
106        let (tx, rx) = tokio::sync::oneshot::channel();
107        self.results.insert(task_id, TaskStatus::Pending);
108        self.external_receivers.insert(task_id, rx);
109        tx
110    }
111
112    /// Try to resolve an external task (non-blocking check).
113    ///
114    /// Returns `Some(Ok(val))` if the external task completed successfully,
115    /// `Some(Err(..))` on error/cancellation, or `None` if still pending.
116    pub fn try_resolve_external(&mut self, task_id: u64) -> Option<Result<ValueWord, VMError>> {
117        if let Some(TaskStatus::Completed(val)) = self.results.get(&task_id) {
118            return Some(Ok(val.clone()));
119        }
120        if let Some(rx) = self.external_receivers.get_mut(&task_id) {
121            match rx.try_recv() {
122                Ok(Ok(val)) => {
123                    self.results
124                        .insert(task_id, TaskStatus::Completed(val.clone()));
125                    self.external_receivers.remove(&task_id);
126                    Some(Ok(val))
127                }
128                Ok(Err(e)) => {
129                    self.external_receivers.remove(&task_id);
130                    Some(Err(VMError::RuntimeError(e)))
131                }
132                Err(tokio::sync::oneshot::error::TryRecvError::Empty) => None,
133                Err(tokio::sync::oneshot::error::TryRecvError::Closed) => {
134                    self.external_receivers.remove(&task_id);
135                    Some(Err(VMError::RuntimeError(
136                        "Remote task cancelled".to_string(),
137                    )))
138                }
139            }
140        } else {
141            None
142        }
143    }
144
145    /// Check whether a task has an external receiver (is externally-completed).
146    pub fn has_external(&self, task_id: u64) -> bool {
147        self.external_receivers.contains_key(&task_id)
148    }
149
150    /// Take the external receiver for async awaiting.
151    ///
152    /// Used by `execute_with_async` when it needs to truly `.await` an external
153    /// task's completion.
154    pub fn take_external_receiver(
155        &mut self,
156        task_id: u64,
157    ) -> Option<tokio::sync::oneshot::Receiver<Result<ValueWord, String>>> {
158        self.external_receivers.remove(&task_id)
159    }
160
161    /// Resolve a single task by executing its callable on a fresh VM executor.
162    ///
163    /// This is the synchronous (inline) strategy: the callable is executed
164    /// immediately when awaited. Returns the result value, or an error.
165    ///
166    /// The `executor_fn` callback receives the callable ValueWord and must
167    /// execute it, returning the result.
168    pub fn resolve_task<F>(&mut self, task_id: u64, executor_fn: F) -> Result<ValueWord, VMError>
169    where
170        F: FnOnce(ValueWord) -> Result<ValueWord, VMError>,
171    {
172        // If already resolved, return the cached result
173        if let Some(TaskStatus::Completed(val)) = self.results.get(&task_id) {
174            return Ok(val.clone());
175        }
176        if let Some(TaskStatus::Cancelled) = self.results.get(&task_id) {
177            return Err(VMError::RuntimeError(format!(
178                "Task {} was cancelled",
179                task_id
180            )));
181        }
182
183        // Take the callable (consume it)
184        let callable = self.take_callable(task_id).ok_or_else(|| {
185            VMError::RuntimeError(format!("No callable registered for task {}", task_id))
186        })?;
187
188        // Execute synchronously
189        let result = executor_fn(callable)?;
190
191        // Cache the result
192        self.results
193            .insert(task_id, TaskStatus::Completed(result.clone()));
194
195        Ok(result)
196    }
197
198    /// Resolve a task group according to the join strategy.
199    ///
200    /// Join kinds (encoded in the high 2 bits of JoinInit's packed operand):
201    ///   0 = All  — wait for all tasks, return array of results
202    ///   1 = Race — return first completed result
203    ///   2 = Any  — return first successful result (skip errors)
204    ///   3 = AllSettled — return array of {status, value/error} for every task
205    ///
206    /// Since we execute synchronously, "race" and "any" still run all tasks
207    /// sequentially but return early on the first applicable result.
208    pub fn resolve_task_group<F>(
209        &mut self,
210        kind: u8,
211        task_ids: &[u64],
212        mut executor_fn: F,
213    ) -> Result<ValueWord, VMError>
214    where
215        F: FnMut(ValueWord) -> Result<ValueWord, VMError>,
216    {
217        match kind {
218            // All: collect all results into an array
219            0 => {
220                let mut results: Vec<ValueWord> = Vec::with_capacity(task_ids.len());
221                for &id in task_ids {
222                    let val = self.resolve_task(id, &mut executor_fn)?;
223                    results.push(val);
224                }
225                Ok(ValueWord::from_array(std::sync::Arc::new(results)))
226            }
227            // Race: return first result (all run, but we return first)
228            1 => {
229                for &id in task_ids {
230                    let val = self.resolve_task(id, &mut executor_fn)?;
231                    return Ok(val);
232                }
233                Err(VMError::RuntimeError(
234                    "Race join with empty task list".to_string(),
235                ))
236            }
237            // Any: return first success, skip errors
238            2 => {
239                let mut last_err = None;
240                for &id in task_ids {
241                    match self.resolve_task(id, &mut executor_fn) {
242                        Ok(val) => return Ok(val),
243                        Err(e) => last_err = Some(e),
244                    }
245                }
246                Err(last_err.unwrap_or_else(|| {
247                    VMError::RuntimeError("Any join with empty task list".to_string())
248                }))
249            }
250            // AllSettled: collect {status, value/error} for each
251            3 => {
252                let mut results: Vec<ValueWord> = Vec::with_capacity(task_ids.len());
253                for &id in task_ids {
254                    match self.resolve_task(id, &mut executor_fn) {
255                        Ok(val) => results.push(val),
256                        Err(e) => results.push(ValueWord::from_string(std::sync::Arc::new(
257                            format!("Error: {}", e),
258                        ))),
259                    }
260                }
261                Ok(ValueWord::from_array(std::sync::Arc::new(results)))
262            }
263            _ => Err(VMError::RuntimeError(format!(
264                "Unknown join kind: {}",
265                kind
266            ))),
267        }
268    }
269}
270
#[cfg(feature = "gc")]
impl TaskScheduler {
    /// Report every heap pointer reachable from values the scheduler owns.
    ///
    /// Invoked while the GC enumerates roots. Callables that have not been
    /// consumed yet and the values of completed tasks can both point into the
    /// heap. Callables are visited first, then completed results.
    pub(crate) fn scan_roots(&self, visitor: &mut dyn FnMut(*mut u8)) {
        let pending_callables = self.callables.values();
        let completed_values = self.results.values().filter_map(|status| match status {
            TaskStatus::Completed(val) => Some(val),
            TaskStatus::Pending | TaskStatus::Cancelled => None,
        });
        for root in pending_callables.chain(completed_values) {
            shape_gc::roots::trace_nanboxed_bits(root.raw_bits(), visitor);
        }
    }
}
288
289impl std::fmt::Debug for TaskScheduler {
290    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291        f.debug_struct("TaskScheduler")
292            .field("callables", &self.callables)
293            .field("results", &self.results)
294            .field(
295                "external_receivers",
296                &format!("[{} pending]", self.external_receivers.len()),
297            )
298            .finish()
299    }
300}
301
302impl Default for TaskScheduler {
303    fn default() -> Self {
304        Self::new()
305    }
306}
307
#[cfg(test)]
mod tests {
    //! Unit tests for `TaskScheduler`: inline resolution, cancellation,
    //! every join kind, and externally-completed tasks.

    use super::*;
    use std::sync::Arc;

    #[test]
    fn test_register_and_take_callable() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(42));
        assert!(matches!(sched.get_result(1), Some(TaskStatus::Pending)));

        let callable = sched.take_callable(1);
        assert!(callable.is_some());

        // Second take returns None (consumed)
        assert!(sched.take_callable(1).is_none());
    }

    #[test]
    fn test_resolve_task_synchronous() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));

        let result = sched.resolve_task(1, |_callable| Ok(ValueWord::from_f64(99.0)));
        assert!(result.is_ok());
        let val = result.unwrap();
        assert!((val.as_f64().unwrap() - 99.0).abs() < f64::EPSILON);

        // Second resolve returns cached result
        let cached = sched.resolve_task(1, |_| panic!("should not be called"));
        assert!(cached.is_ok());
    }

    #[test]
    fn test_resolve_unregistered_task_errors() {
        let mut sched = TaskScheduler::new();
        assert!(sched.resolve_task(5, |_| Ok(ValueWord::none())).is_err());
    }

    #[test]
    fn test_cancel_task() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));

        sched.cancel(1);
        assert!(sched.is_resolved(1));

        let result = sched.resolve_task(1, |_| Ok(ValueWord::none()));
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_all_group() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let mut call_count = 0u32;
        let result = sched.resolve_task_group(0, &[1, 2], |_callable| {
            call_count += 1;
            Ok(ValueWord::from_f64(call_count as f64))
        });
        assert!(result.is_ok());
        let val = result.unwrap();
        let view = val.as_any_array().expect("Expected array");
        assert_eq!(view.len(), 2);
    }

    #[test]
    fn test_all_group_stops_at_first_error() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let mut call_count = 0u32;
        let result = sched.resolve_task_group(0, &[1, 2], |_| {
            call_count += 1;
            Err(VMError::RuntimeError("boom".to_string()))
        });
        assert!(result.is_err());
        // The second task must not run once the first one has failed.
        assert_eq!(call_count, 1);
    }

    #[test]
    fn test_resolve_race_group() {
        let mut sched = TaskScheduler::new();
        sched.register(10, ValueWord::from_function(0));
        sched.register(20, ValueWord::from_function(1));

        let result = sched.resolve_task_group(1, &[10, 20], |_| {
            Ok(ValueWord::from_string(Arc::new("first".to_string())))
        });
        assert!(result.is_ok());
        let val = result.unwrap();
        assert_eq!(val.as_str().unwrap(), "first");
    }

    #[test]
    fn test_resolve_any_group_skips_errors() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let mut call_count = 0u32;
        let result = sched.resolve_task_group(2, &[1, 2], |_| {
            call_count += 1;
            if call_count == 1 {
                Err(VMError::RuntimeError("boom".to_string()))
            } else {
                Ok(ValueWord::from_f64(7.0))
            }
        });
        let val = result.expect("Any join should return the second task's value");
        assert!((val.as_f64().unwrap() - 7.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_resolve_any_group_all_fail() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));

        let result = sched.resolve_task_group(2, &[1], |_| {
            Err(VMError::RuntimeError("boom".to_string()))
        });
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_all_settled_group_keeps_failures() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let mut call_count = 0u32;
        let result = sched.resolve_task_group(3, &[1, 2], |_| {
            call_count += 1;
            if call_count == 1 {
                Err(VMError::RuntimeError("boom".to_string()))
            } else {
                Ok(ValueWord::from_f64(1.0))
            }
        });
        let val = result.expect("AllSettled should not fail because a task failed");
        let view = val.as_any_array().expect("Expected array");
        assert_eq!(view.len(), 2);
    }

    #[test]
    fn test_group_empty_list_and_unknown_kind() {
        let mut sched = TaskScheduler::new();
        assert!(sched
            .resolve_task_group(1, &[], |_| Ok(ValueWord::none()))
            .is_err());
        assert!(sched
            .resolve_task_group(2, &[], |_| Ok(ValueWord::none()))
            .is_err());
        assert!(sched
            .resolve_task_group(9, &[1], |_| Ok(ValueWord::none()))
            .is_err());
    }

    #[test]
    fn test_register_external_and_resolve() {
        let mut sched = TaskScheduler::new();
        let tx = sched.register_external(100);
        assert!(sched.has_external(100));
        assert!(matches!(sched.get_result(100), Some(TaskStatus::Pending)));

        // Not yet resolved
        assert!(sched.try_resolve_external(100).is_none());

        // Send result from "background task"
        tx.send(Ok(ValueWord::from_f64(42.0))).unwrap();

        // Now resolves
        let result = sched.try_resolve_external(100);
        assert!(result.is_some());
        let val = result.unwrap().unwrap();
        assert!((val.as_f64().unwrap() - 42.0).abs() < f64::EPSILON);

        // Receiver removed after resolution
        assert!(!sched.has_external(100));
    }

    #[test]
    fn test_external_task_error() {
        let mut sched = TaskScheduler::new();
        let tx = sched.register_external(200);

        tx.send(Err("connection refused".to_string())).unwrap();

        let result = sched.try_resolve_external(200);
        assert!(result.is_some());
        assert!(result.unwrap().is_err());
    }

    #[test]
    fn test_external_task_cancelled() {
        let mut sched = TaskScheduler::new();
        let tx = sched.register_external(300);

        // Drop sender to simulate cancellation
        drop(tx);

        let result = sched.try_resolve_external(300);
        assert!(result.is_some());
        assert!(result.unwrap().is_err());
    }

    #[test]
    fn test_take_external_receiver() {
        let mut sched = TaskScheduler::new();
        let _tx = sched.register_external(400);

        assert!(sched.has_external(400));
        let rx = sched.take_external_receiver(400);
        assert!(rx.is_some());
        assert!(!sched.has_external(400));
    }
}