shape_vm/executor/
task_scheduler.rs1use std::collections::HashMap;
11
12use shape_value::{VMError, ValueWord};
13
/// Lifecycle state of a task tracked by [`TaskScheduler`].
#[derive(Debug, Clone)]
pub enum TaskStatus {
    /// Registered via [`TaskScheduler::register`] and not yet resolved.
    Pending,
    /// Resolved; holds the task's result value.
    Completed(ValueWord),
    /// Cancelled via [`TaskScheduler::cancel`] while it was still pending.
    Cancelled,
}
24
/// Bookkeeping for spawned tasks: the callable to run for each task id and the
/// task's current [`TaskStatus`].
///
/// Tasks are resolved synchronously on demand: `resolve_task` runs the callable
/// through a caller-supplied executor closure.
#[derive(Debug)]
pub struct TaskScheduler {
    /// Callables that have not been taken for execution yet, keyed by task id.
    callables: HashMap<u64, ValueWord>,

    /// Current status for each known task id.
    results: HashMap<u64, TaskStatus>,
}
39
40impl TaskScheduler {
41 pub fn new() -> Self {
43 Self {
44 callables: HashMap::new(),
45 results: HashMap::new(),
46 }
47 }
48
49 pub fn register(&mut self, task_id: u64, callable: ValueWord) {
53 self.callables.insert(task_id, callable);
54 self.results.insert(task_id, TaskStatus::Pending);
55 }
56
57 pub fn take_callable(&mut self, task_id: u64) -> Option<ValueWord> {
61 self.callables.remove(&task_id)
62 }
63
64 pub fn complete(&mut self, task_id: u64, value: ValueWord) {
66 self.results.insert(task_id, TaskStatus::Completed(value));
67 }
68
69 pub fn cancel(&mut self, task_id: u64) {
71 if let Some(TaskStatus::Pending) = self.results.get(&task_id) {
73 self.results.insert(task_id, TaskStatus::Cancelled);
74 self.callables.remove(&task_id);
75 }
76 }
77
78 pub fn get_result(&self, task_id: u64) -> Option<&TaskStatus> {
80 self.results.get(&task_id)
81 }
82
83 pub fn is_resolved(&self, task_id: u64) -> bool {
85 matches!(
86 self.results.get(&task_id),
87 Some(TaskStatus::Completed(_)) | Some(TaskStatus::Cancelled)
88 )
89 }
90
91 pub fn resolve_task<F>(&mut self, task_id: u64, executor_fn: F) -> Result<ValueWord, VMError>
99 where
100 F: FnOnce(ValueWord) -> Result<ValueWord, VMError>,
101 {
102 if let Some(TaskStatus::Completed(val)) = self.results.get(&task_id) {
104 return Ok(val.clone());
105 }
106 if let Some(TaskStatus::Cancelled) = self.results.get(&task_id) {
107 return Err(VMError::RuntimeError(format!(
108 "Task {} was cancelled",
109 task_id
110 )));
111 }
112
113 let callable = self.take_callable(task_id).ok_or_else(|| {
115 VMError::RuntimeError(format!("No callable registered for task {}", task_id))
116 })?;
117
118 let result = executor_fn(callable)?;
120
121 self.results
123 .insert(task_id, TaskStatus::Completed(result.clone()));
124
125 Ok(result)
126 }
127
128 pub fn resolve_task_group<F>(
139 &mut self,
140 kind: u8,
141 task_ids: &[u64],
142 mut executor_fn: F,
143 ) -> Result<ValueWord, VMError>
144 where
145 F: FnMut(ValueWord) -> Result<ValueWord, VMError>,
146 {
147 match kind {
148 0 => {
150 let mut results: Vec<ValueWord> = Vec::with_capacity(task_ids.len());
151 for &id in task_ids {
152 let val = self.resolve_task(id, &mut executor_fn)?;
153 results.push(val);
154 }
155 Ok(ValueWord::from_array(std::sync::Arc::new(results)))
156 }
157 1 => {
159 for &id in task_ids {
160 let val = self.resolve_task(id, &mut executor_fn)?;
161 return Ok(val);
162 }
163 Err(VMError::RuntimeError(
164 "Race join with empty task list".to_string(),
165 ))
166 }
167 2 => {
169 let mut last_err = None;
170 for &id in task_ids {
171 match self.resolve_task(id, &mut executor_fn) {
172 Ok(val) => return Ok(val),
173 Err(e) => last_err = Some(e),
174 }
175 }
176 Err(last_err.unwrap_or_else(|| {
177 VMError::RuntimeError("Any join with empty task list".to_string())
178 }))
179 }
180 3 => {
182 let mut results: Vec<ValueWord> = Vec::with_capacity(task_ids.len());
183 for &id in task_ids {
184 match self.resolve_task(id, &mut executor_fn) {
185 Ok(val) => results.push(val),
186 Err(e) => results.push(ValueWord::from_string(std::sync::Arc::new(
187 format!("Error: {}", e),
188 ))),
189 }
190 }
191 Ok(ValueWord::from_array(std::sync::Arc::new(results)))
192 }
193 _ => Err(VMError::RuntimeError(format!(
194 "Unknown join kind: {}",
195 kind
196 ))),
197 }
198 }
199}
200
#[cfg(feature = "gc")]
impl TaskScheduler {
    /// Reports every heap value this scheduler keeps alive to the GC `visitor`.
    ///
    /// The roots are all not-yet-taken callables, followed by the result value of
    /// each `Completed` task. Pending and cancelled statuses hold no value.
    pub(crate) fn scan_roots(&self, visitor: &mut dyn FnMut(*mut u8)) {
        let completed_values = self.results.values().filter_map(|status| match status {
            TaskStatus::Completed(val) => Some(val),
            TaskStatus::Pending | TaskStatus::Cancelled => None,
        });

        for word in self.callables.values().chain(completed_values) {
            shape_gc::roots::trace_nanboxed_bits(word.raw_bits(), visitor);
        }
    }
}
218
219impl Default for TaskScheduler {
220 fn default() -> Self {
221 Self::new()
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn test_register_and_take_callable() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(42));
        assert!(matches!(sched.get_result(1), Some(TaskStatus::Pending)));

        let callable = sched.take_callable(1);
        assert!(callable.is_some());

        // The callable was moved out, so a second take finds nothing.
        assert!(sched.take_callable(1).is_none());
    }

    #[test]
    fn test_resolve_task_synchronous() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));

        let result = sched.resolve_task(1, |_callable| Ok(ValueWord::from_f64(99.0)));
        assert!(result.is_ok());
        let val = result.unwrap();
        assert!((val.as_f64().unwrap() - 99.0).abs() < f64::EPSILON);
        assert!(sched.is_resolved(1));

        // A completed task returns its cached value without re-running the executor.
        let cached = sched.resolve_task(1, |_| panic!("should not be called"));
        assert!(cached.is_ok());
        let cached = cached.unwrap();
        assert!((cached.as_f64().unwrap() - 99.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_resolve_task_executor_error_leaves_task_pending() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));

        let result = sched.resolve_task(1, |_| Err(VMError::RuntimeError("boom".to_string())));
        assert!(matches!(result, Err(VMError::RuntimeError(ref msg)) if msg == "boom"));
        assert!(!sched.is_resolved(1));
        assert!(matches!(sched.get_result(1), Some(TaskStatus::Pending)));

        // The failed attempt consumed the callable, so a retry cannot run it again.
        assert!(sched
            .resolve_task(1, |_| panic!("should not be called"))
            .is_err());
    }

    #[test]
    fn test_resolve_unknown_task_fails() {
        let mut sched = TaskScheduler::new();
        let result = sched.resolve_task(42, |_| panic!("should not be called"));
        assert!(result.is_err());
    }

    #[test]
    fn test_cancel_task() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));

        sched.cancel(1);
        assert!(sched.is_resolved(1));
        assert!(sched.take_callable(1).is_none());

        let result = sched.resolve_task(1, |_| Ok(ValueWord::none()));
        assert!(result.is_err());
    }

    #[test]
    fn test_cancel_after_completion_is_noop() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched
            .resolve_task(1, |_| Ok(ValueWord::from_f64(1.0)))
            .expect("task should resolve");

        sched.cancel(1);
        assert!(matches!(sched.get_result(1), Some(TaskStatus::Completed(_))));
    }

    #[test]
    fn test_resolve_all_group() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let mut call_count = 0u32;
        let result = sched.resolve_task_group(0, &[1, 2], |_callable| {
            call_count += 1;
            Ok(ValueWord::from_f64(call_count as f64))
        });
        assert!(result.is_ok());
        assert_eq!(call_count, 2);
        let val = result.unwrap();
        let view = val.as_any_array().expect("Expected array");
        assert_eq!(view.len(), 2);
    }

    #[test]
    fn test_resolve_all_group_stops_at_first_error() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let result = sched.resolve_task_group(0, &[1, 2], |_| {
            Err(VMError::RuntimeError("fail".to_string()))
        });
        assert!(result.is_err());
        // Task 2 was never run, so its callable is still registered.
        assert!(sched.take_callable(2).is_some());
    }

    #[test]
    fn test_resolve_race_group() {
        let mut sched = TaskScheduler::new();
        sched.register(10, ValueWord::from_function(0));
        sched.register(20, ValueWord::from_function(1));

        let result = sched.resolve_task_group(1, &[10, 20], |_| {
            Ok(ValueWord::from_string(Arc::new("first".to_string())))
        });
        assert!(result.is_ok());
        let val = result.unwrap();
        assert_eq!(val.as_str().unwrap(), "first");
        // Only the first task is run; the second stays pending.
        assert!(matches!(sched.get_result(20), Some(TaskStatus::Pending)));
    }

    #[test]
    fn test_resolve_race_group_empty() {
        let mut sched = TaskScheduler::new();
        let result = sched.resolve_task_group(1, &[], |_| panic!("should not be called"));
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_any_group_skips_failures() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let mut calls = 0u32;
        let result = sched.resolve_task_group(2, &[1, 2], |_| {
            calls += 1;
            if calls == 1 {
                Err(VMError::RuntimeError("first failed".to_string()))
            } else {
                Ok(ValueWord::from_f64(7.0))
            }
        });
        let val = result.expect("second task should succeed");
        assert!((val.as_f64().unwrap() - 7.0).abs() < f64::EPSILON);
        assert_eq!(calls, 2);
    }

    #[test]
    fn test_resolve_any_group_all_fail_returns_last_error() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let mut calls = 0u32;
        let result = sched.resolve_task_group(2, &[1, 2], |_| {
            calls += 1;
            Err(VMError::RuntimeError(format!("failure {}", calls)))
        });
        assert!(matches!(result, Err(VMError::RuntimeError(ref msg)) if msg == "failure 2"));
    }

    #[test]
    fn test_resolve_any_group_empty() {
        let mut sched = TaskScheduler::new();
        let result = sched.resolve_task_group(2, &[], |_| panic!("should not be called"));
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_settled_group_keeps_errors() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let mut calls = 0u32;
        let result = sched.resolve_task_group(3, &[1, 2], |_| {
            calls += 1;
            if calls == 1 {
                Ok(ValueWord::from_f64(1.0))
            } else {
                Err(VMError::RuntimeError("bad".to_string()))
            }
        });
        let val = result.expect("settled join does not fail on task errors");
        let view = val.as_any_array().expect("Expected array");
        assert_eq!(view.len(), 2);
    }

    #[test]
    fn test_unknown_join_kind() {
        let mut sched = TaskScheduler::new();
        let result = sched.resolve_task_group(9, &[], |_| panic!("should not be called"));
        assert!(
            matches!(result, Err(VMError::RuntimeError(ref msg)) if msg == "Unknown join kind: 9")
        );
    }
}