Skip to main content

shape_vm/executor/
task_scheduler.rs

1//! Task scheduler for the async host runtime.
2//!
3//! Manages spawned async tasks: stores their callables, tracks completion,
4//! and executes them (synchronously for now) when the VM suspends on an await.
5//!
6//! The initial design runs tasks inline (synchronous execution at await-time).
7//! True concurrent execution via Tokio can be layered on later by changing
8//! `resolve_task` to spawn on the Tokio runtime.
9
10use std::collections::HashMap;
11
12use shape_value::{VMError, ValueWord};
13
14/// Completion status of a spawned task.
15#[derive(Debug, Clone)]
16pub enum TaskStatus {
17    /// Task has been spawned but not yet executed.
18    Pending,
19    /// Task finished successfully with a result value.
20    Completed(ValueWord),
21    /// Task was cancelled before completion.
22    Cancelled,
23}
24
25/// Scheduler that tracks spawned async tasks by their future ID.
26///
27/// The VM's `SpawnTask` opcode registers a callable here. When the VM later
28/// suspends on `WaitType::Future { id }`, the host looks up the callable,
29/// executes it, and stores the result so the VM can resume.
30#[derive(Debug)]
31pub struct TaskScheduler {
32    /// Map from task_id to the callable value (Closure or Function) that
33    /// was passed to `spawn`. Consumed on first execution.
34    callables: HashMap<u64, ValueWord>,
35
36    /// Map from task_id to its completion status.
37    results: HashMap<u64, TaskStatus>,
38}
39
40impl TaskScheduler {
41    /// Create a new, empty scheduler.
42    pub fn new() -> Self {
43        Self {
44            callables: HashMap::new(),
45            results: HashMap::new(),
46        }
47    }
48
49    /// Register a callable for a given task_id.
50    ///
51    /// Called by `op_spawn_task` when a new task is spawned.
52    pub fn register(&mut self, task_id: u64, callable: ValueWord) {
53        self.callables.insert(task_id, callable);
54        self.results.insert(task_id, TaskStatus::Pending);
55    }
56
57    /// Take (remove) the callable for `task_id` so it can be executed.
58    ///
59    /// Returns `None` if the task was already consumed or never registered.
60    pub fn take_callable(&mut self, task_id: u64) -> Option<ValueWord> {
61        self.callables.remove(&task_id)
62    }
63
64    /// Record a completed result for a task.
65    pub fn complete(&mut self, task_id: u64, value: ValueWord) {
66        self.results.insert(task_id, TaskStatus::Completed(value));
67    }
68
69    /// Mark a task as cancelled.
70    pub fn cancel(&mut self, task_id: u64) {
71        // Only cancel if still pending
72        if let Some(TaskStatus::Pending) = self.results.get(&task_id) {
73            self.results.insert(task_id, TaskStatus::Cancelled);
74            self.callables.remove(&task_id);
75        }
76    }
77
78    /// Get the result for a task, if it has completed.
79    pub fn get_result(&self, task_id: u64) -> Option<&TaskStatus> {
80        self.results.get(&task_id)
81    }
82
83    /// Check whether a task has a stored result (completed or cancelled).
84    pub fn is_resolved(&self, task_id: u64) -> bool {
85        matches!(
86            self.results.get(&task_id),
87            Some(TaskStatus::Completed(_)) | Some(TaskStatus::Cancelled)
88        )
89    }
90
91    /// Resolve a single task by executing its callable on a fresh VM executor.
92    ///
93    /// This is the synchronous (inline) strategy: the callable is executed
94    /// immediately when awaited. Returns the result value, or an error.
95    ///
96    /// The `executor_fn` callback receives the callable ValueWord and must
97    /// execute it, returning the result.
98    pub fn resolve_task<F>(&mut self, task_id: u64, executor_fn: F) -> Result<ValueWord, VMError>
99    where
100        F: FnOnce(ValueWord) -> Result<ValueWord, VMError>,
101    {
102        // If already resolved, return the cached result
103        if let Some(TaskStatus::Completed(val)) = self.results.get(&task_id) {
104            return Ok(val.clone());
105        }
106        if let Some(TaskStatus::Cancelled) = self.results.get(&task_id) {
107            return Err(VMError::RuntimeError(format!(
108                "Task {} was cancelled",
109                task_id
110            )));
111        }
112
113        // Take the callable (consume it)
114        let callable = self.take_callable(task_id).ok_or_else(|| {
115            VMError::RuntimeError(format!("No callable registered for task {}", task_id))
116        })?;
117
118        // Execute synchronously
119        let result = executor_fn(callable)?;
120
121        // Cache the result
122        self.results
123            .insert(task_id, TaskStatus::Completed(result.clone()));
124
125        Ok(result)
126    }
127
128    /// Resolve a task group according to the join strategy.
129    ///
130    /// Join kinds (encoded in the high 2 bits of JoinInit's packed operand):
131    ///   0 = All  — wait for all tasks, return array of results
132    ///   1 = Race — return first completed result
133    ///   2 = Any  — return first successful result (skip errors)
134    ///   3 = AllSettled — return array of {status, value/error} for every task
135    ///
136    /// Since we execute synchronously, "race" and "any" still run all tasks
137    /// sequentially but return early on the first applicable result.
138    pub fn resolve_task_group<F>(
139        &mut self,
140        kind: u8,
141        task_ids: &[u64],
142        mut executor_fn: F,
143    ) -> Result<ValueWord, VMError>
144    where
145        F: FnMut(ValueWord) -> Result<ValueWord, VMError>,
146    {
147        match kind {
148            // All: collect all results into an array
149            0 => {
150                let mut results: Vec<ValueWord> = Vec::with_capacity(task_ids.len());
151                for &id in task_ids {
152                    let val = self.resolve_task(id, &mut executor_fn)?;
153                    results.push(val);
154                }
155                Ok(ValueWord::from_array(std::sync::Arc::new(results)))
156            }
157            // Race: return first result (all run, but we return first)
158            1 => {
159                for &id in task_ids {
160                    let val = self.resolve_task(id, &mut executor_fn)?;
161                    return Ok(val);
162                }
163                Err(VMError::RuntimeError(
164                    "Race join with empty task list".to_string(),
165                ))
166            }
167            // Any: return first success, skip errors
168            2 => {
169                let mut last_err = None;
170                for &id in task_ids {
171                    match self.resolve_task(id, &mut executor_fn) {
172                        Ok(val) => return Ok(val),
173                        Err(e) => last_err = Some(e),
174                    }
175                }
176                Err(last_err.unwrap_or_else(|| {
177                    VMError::RuntimeError("Any join with empty task list".to_string())
178                }))
179            }
180            // AllSettled: collect {status, value/error} for each
181            3 => {
182                let mut results: Vec<ValueWord> = Vec::with_capacity(task_ids.len());
183                for &id in task_ids {
184                    match self.resolve_task(id, &mut executor_fn) {
185                        Ok(val) => results.push(val),
186                        Err(e) => results.push(ValueWord::from_string(std::sync::Arc::new(
187                            format!("Error: {}", e),
188                        ))),
189                    }
190                }
191                Ok(ValueWord::from_array(std::sync::Arc::new(results)))
192            }
193            _ => Err(VMError::RuntimeError(format!(
194                "Unknown join kind: {}",
195                kind
196            ))),
197        }
198    }
199}
200
#[cfg(feature = "gc")]
impl TaskScheduler {
    /// Hand every ValueWord root held by the scheduler to the GC tracer.
    ///
    /// Invoked during GC root enumeration. Two kinds of values are roots:
    /// callables that have not been taken yet, and the values stored for
    /// completed tasks. `Pending` and `Cancelled` statuses carry no value.
    pub(crate) fn scan_roots(&self, visitor: &mut dyn FnMut(*mut u8)) {
        let completed_values = self.results.values().filter_map(|status| match status {
            TaskStatus::Completed(val) => Some(val),
            TaskStatus::Pending | TaskStatus::Cancelled => None,
        });

        // Callables are visited first, then completed values.
        for root in self.callables.values().chain(completed_values) {
            shape_gc::roots::trace_nanboxed_bits(root.raw_bits(), visitor);
        }
    }
}
218
219impl Default for TaskScheduler {
220    fn default() -> Self {
221        Self::new()
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Executor failure used by the error-path tests.
    fn boom() -> VMError {
        VMError::RuntimeError("boom".to_string())
    }

    #[test]
    fn test_register_and_take_callable() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(42));
        assert!(matches!(sched.get_result(1), Some(TaskStatus::Pending)));

        let callable = sched.take_callable(1);
        assert!(callable.is_some());

        // Second take returns None (consumed)
        assert!(sched.take_callable(1).is_none());
    }

    #[test]
    fn test_resolve_task_synchronous() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));

        let result = sched.resolve_task(1, |_callable| Ok(ValueWord::from_f64(99.0)));
        assert!(result.is_ok());
        let val = result.unwrap();
        assert!((val.as_f64().unwrap() - 99.0).abs() < f64::EPSILON);

        // Second resolve returns cached result
        let cached = sched.resolve_task(1, |_| panic!("should not be called"));
        assert!(cached.is_ok());
    }

    #[test]
    fn test_resolve_unregistered_task_errors() {
        let mut sched = TaskScheduler::new();
        assert!(sched.resolve_task(7, |_| Ok(ValueWord::none())).is_err());
        assert!(sched.get_result(7).is_none());
    }

    #[test]
    fn test_cancel_task() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));

        sched.cancel(1);
        assert!(sched.is_resolved(1));

        let result = sched.resolve_task(1, |_| Ok(ValueWord::none()));
        assert!(result.is_err());
    }

    #[test]
    fn test_cancel_after_completion_is_noop() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched
            .resolve_task(1, |_| Ok(ValueWord::from_f64(5.0)))
            .unwrap();

        sched.cancel(1);
        assert!(matches!(sched.get_result(1), Some(TaskStatus::Completed(_))));

        let cached = sched.resolve_task(1, |_| panic!("should not be called"));
        assert!((cached.unwrap().as_f64().unwrap() - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_resolve_all_group() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let mut call_count = 0u32;
        let result = sched.resolve_task_group(0, &[1, 2], |_callable| {
            call_count += 1;
            Ok(ValueWord::from_f64(call_count as f64))
        });
        assert!(result.is_ok());
        let val = result.unwrap();
        let view = val.as_any_array().expect("Expected array");
        assert_eq!(view.len(), 2);
        assert_eq!(call_count, 2);
    }

    #[test]
    fn test_resolve_all_group_propagates_first_error() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let result = sched.resolve_task_group(0, &[1, 2], |_| Err(boom()));
        assert!(result.is_err());
        // The second task was never executed, so its callable is still present.
        assert!(sched.take_callable(2).is_some());
    }

    #[test]
    fn test_resolve_race_group() {
        let mut sched = TaskScheduler::new();
        sched.register(10, ValueWord::from_function(0));
        sched.register(20, ValueWord::from_function(1));

        let mut calls = 0u32;
        let result = sched.resolve_task_group(1, &[10, 20], |_| {
            calls += 1;
            Ok(ValueWord::from_string(Arc::new("first".to_string())))
        });
        assert!(result.is_ok());
        let val = result.unwrap();
        assert_eq!(val.as_str().unwrap(), "first");

        // Only the first task in the group is executed.
        assert_eq!(calls, 1);
        assert!(sched.take_callable(20).is_some());
    }

    #[test]
    fn test_resolve_race_group_empty_errors() {
        let mut sched = TaskScheduler::new();
        let result = sched.resolve_task_group(1, &[], |_| Ok(ValueWord::none()));
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_any_group_skips_errors() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let mut calls = 0u32;
        let result = sched.resolve_task_group(2, &[1, 2], |_| {
            calls += 1;
            if calls == 1 {
                Err(boom())
            } else {
                Ok(ValueWord::from_f64(7.0))
            }
        });
        let val = result.expect("second task should succeed");
        assert!((val.as_f64().unwrap() - 7.0).abs() < f64::EPSILON);
        assert_eq!(calls, 2);
    }

    #[test]
    fn test_resolve_any_group_all_fail() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));

        let result = sched.resolve_task_group(2, &[1], |_| Err(boom()));
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_all_settled_group_keeps_failures() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let mut calls = 0u32;
        let result = sched.resolve_task_group(3, &[1, 2], |_| {
            calls += 1;
            if calls == 1 {
                Err(boom())
            } else {
                Ok(ValueWord::from_f64(1.0))
            }
        });
        let val = result.expect("AllSettled does not fail for registered tasks");
        let view = val.as_any_array().expect("Expected array");
        assert_eq!(view.len(), 2);
        assert_eq!(calls, 2);
    }

    #[test]
    fn test_unknown_join_kind_errors() {
        let mut sched = TaskScheduler::new();
        let result = sched.resolve_task_group(4, &[], |_| Ok(ValueWord::none()));
        assert!(result.is_err());
    }
}