1use std::collections::HashMap;
11
12use shape_value::{VMError, ValueWord};
13
14#[derive(Debug, Clone)]
16pub enum TaskStatus {
17 Pending,
19 Completed(ValueWord),
21 Cancelled,
23}
24
25pub struct TaskScheduler {
34 callables: HashMap<u64, ValueWord>,
37
38 results: HashMap<u64, TaskStatus>,
40
41 external_receivers: HashMap<u64, tokio::sync::oneshot::Receiver<Result<ValueWord, String>>>,
44}
45
46impl TaskScheduler {
47 pub fn new() -> Self {
49 Self {
50 callables: HashMap::new(),
51 results: HashMap::new(),
52 external_receivers: HashMap::new(),
53 }
54 }
55
56 pub fn register(&mut self, task_id: u64, callable: ValueWord) {
60 self.callables.insert(task_id, callable);
61 self.results.insert(task_id, TaskStatus::Pending);
62 }
63
64 pub fn take_callable(&mut self, task_id: u64) -> Option<ValueWord> {
68 self.callables.remove(&task_id)
69 }
70
71 pub fn complete(&mut self, task_id: u64, value: ValueWord) {
73 self.results.insert(task_id, TaskStatus::Completed(value));
74 }
75
76 pub fn cancel(&mut self, task_id: u64) {
78 if let Some(TaskStatus::Pending) = self.results.get(&task_id) {
80 self.results.insert(task_id, TaskStatus::Cancelled);
81 self.callables.remove(&task_id);
82 }
83 }
84
85 pub fn get_result(&self, task_id: u64) -> Option<&TaskStatus> {
87 self.results.get(&task_id)
88 }
89
90 pub fn is_resolved(&self, task_id: u64) -> bool {
92 matches!(
93 self.results.get(&task_id),
94 Some(TaskStatus::Completed(_)) | Some(TaskStatus::Cancelled)
95 )
96 }
97
98 pub fn register_external(
103 &mut self,
104 task_id: u64,
105 ) -> tokio::sync::oneshot::Sender<Result<ValueWord, String>> {
106 let (tx, rx) = tokio::sync::oneshot::channel();
107 self.results.insert(task_id, TaskStatus::Pending);
108 self.external_receivers.insert(task_id, rx);
109 tx
110 }
111
112 pub fn try_resolve_external(&mut self, task_id: u64) -> Option<Result<ValueWord, VMError>> {
117 if let Some(TaskStatus::Completed(val)) = self.results.get(&task_id) {
118 return Some(Ok(val.clone()));
119 }
120 if let Some(rx) = self.external_receivers.get_mut(&task_id) {
121 match rx.try_recv() {
122 Ok(Ok(val)) => {
123 self.results
124 .insert(task_id, TaskStatus::Completed(val.clone()));
125 self.external_receivers.remove(&task_id);
126 Some(Ok(val))
127 }
128 Ok(Err(e)) => {
129 self.external_receivers.remove(&task_id);
130 Some(Err(VMError::RuntimeError(e)))
131 }
132 Err(tokio::sync::oneshot::error::TryRecvError::Empty) => None,
133 Err(tokio::sync::oneshot::error::TryRecvError::Closed) => {
134 self.external_receivers.remove(&task_id);
135 Some(Err(VMError::RuntimeError(
136 "Remote task cancelled".to_string(),
137 )))
138 }
139 }
140 } else {
141 None
142 }
143 }
144
145 pub fn has_external(&self, task_id: u64) -> bool {
147 self.external_receivers.contains_key(&task_id)
148 }
149
150 pub fn take_external_receiver(
155 &mut self,
156 task_id: u64,
157 ) -> Option<tokio::sync::oneshot::Receiver<Result<ValueWord, String>>> {
158 self.external_receivers.remove(&task_id)
159 }
160
161 pub fn resolve_task<F>(&mut self, task_id: u64, executor_fn: F) -> Result<ValueWord, VMError>
169 where
170 F: FnOnce(ValueWord) -> Result<ValueWord, VMError>,
171 {
172 if let Some(TaskStatus::Completed(val)) = self.results.get(&task_id) {
174 return Ok(val.clone());
175 }
176 if let Some(TaskStatus::Cancelled) = self.results.get(&task_id) {
177 return Err(VMError::RuntimeError(format!(
178 "Task {} was cancelled",
179 task_id
180 )));
181 }
182
183 let callable = self.take_callable(task_id).ok_or_else(|| {
185 VMError::RuntimeError(format!("No callable registered for task {}", task_id))
186 })?;
187
188 let result = executor_fn(callable)?;
190
191 self.results
193 .insert(task_id, TaskStatus::Completed(result.clone()));
194
195 Ok(result)
196 }
197
198 pub fn resolve_task_group<F>(
209 &mut self,
210 kind: u8,
211 task_ids: &[u64],
212 mut executor_fn: F,
213 ) -> Result<ValueWord, VMError>
214 where
215 F: FnMut(ValueWord) -> Result<ValueWord, VMError>,
216 {
217 match kind {
218 0 => {
220 let mut results: Vec<ValueWord> = Vec::with_capacity(task_ids.len());
221 for &id in task_ids {
222 let val = self.resolve_task(id, &mut executor_fn)?;
223 results.push(val);
224 }
225 Ok(ValueWord::from_array(std::sync::Arc::new(results)))
226 }
227 1 => {
229 for &id in task_ids {
230 let val = self.resolve_task(id, &mut executor_fn)?;
231 return Ok(val);
232 }
233 Err(VMError::RuntimeError(
234 "Race join with empty task list".to_string(),
235 ))
236 }
237 2 => {
239 let mut last_err = None;
240 for &id in task_ids {
241 match self.resolve_task(id, &mut executor_fn) {
242 Ok(val) => return Ok(val),
243 Err(e) => last_err = Some(e),
244 }
245 }
246 Err(last_err.unwrap_or_else(|| {
247 VMError::RuntimeError("Any join with empty task list".to_string())
248 }))
249 }
250 3 => {
252 let mut results: Vec<ValueWord> = Vec::with_capacity(task_ids.len());
253 for &id in task_ids {
254 match self.resolve_task(id, &mut executor_fn) {
255 Ok(val) => results.push(val),
256 Err(e) => results.push(ValueWord::from_string(std::sync::Arc::new(
257 format!("Error: {}", e),
258 ))),
259 }
260 }
261 Ok(ValueWord::from_array(std::sync::Arc::new(results)))
262 }
263 _ => Err(VMError::RuntimeError(format!(
264 "Unknown join kind: {}",
265 kind
266 ))),
267 }
268 }
269}
270
#[cfg(feature = "gc")]
impl TaskScheduler {
    /// Passes every `ValueWord` the scheduler holds to the GC `visitor`.
    ///
    /// Stored callables are visited first, then the values of `Completed` tasks.
    /// Values still in flight inside external channels are not visited.
    pub(crate) fn scan_roots(&self, visitor: &mut dyn FnMut(*mut u8)) {
        let completed_values = self.results.values().filter_map(|status| match status {
            TaskStatus::Completed(val) => Some(val),
            _ => None,
        });
        for word in self.callables.values().chain(completed_values) {
            shape_gc::roots::trace_nanboxed_bits(word.raw_bits(), visitor);
        }
    }
}
288
289impl std::fmt::Debug for TaskScheduler {
290 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291 f.debug_struct("TaskScheduler")
292 .field("callables", &self.callables)
293 .field("results", &self.results)
294 .field(
295 "external_receivers",
296 &format!("[{} pending]", self.external_receivers.len()),
297 )
298 .finish()
299 }
300}
301
302impl Default for TaskScheduler {
303 fn default() -> Self {
304 Self::new()
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn test_register_and_take_callable() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(42));
        assert!(matches!(sched.get_result(1), Some(TaskStatus::Pending)));

        let callable = sched.take_callable(1);
        assert!(callable.is_some());

        // The callable can only be taken once.
        assert!(sched.take_callable(1).is_none());
    }

    #[test]
    fn test_resolve_task_synchronous() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));

        let result = sched.resolve_task(1, |_callable| Ok(ValueWord::from_f64(99.0)));
        assert!(result.is_ok());
        let val = result.unwrap();
        assert!((val.as_f64().unwrap() - 99.0).abs() < f64::EPSILON);

        // A second resolve returns the cached value without re-running the task.
        let cached = sched.resolve_task(1, |_| panic!("should not be called"));
        assert!(cached.is_ok());
    }

    #[test]
    fn test_resolve_task_without_callable_errors() {
        let mut sched = TaskScheduler::new();
        let result = sched.resolve_task(5, |_| panic!("no callable, executor must not run"));
        assert!(result.is_err());
    }

    #[test]
    fn test_cancel_task() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));

        sched.cancel(1);
        assert!(sched.is_resolved(1));

        let result = sched.resolve_task(1, |_| Ok(ValueWord::none()));
        assert!(result.is_err());
    }

    #[test]
    fn test_cancel_is_noop_for_completed_task() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.complete(1, ValueWord::from_f64(7.0));

        sched.cancel(1);
        assert!(matches!(sched.get_result(1), Some(TaskStatus::Completed(_))));
    }

    #[test]
    fn test_unknown_task_is_unresolved() {
        let sched = TaskScheduler::new();
        assert!(sched.get_result(999).is_none());
        assert!(!sched.is_resolved(999));
        assert!(!sched.has_external(999));
    }

    #[test]
    fn test_resolve_all_group() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let mut call_count = 0u32;
        let result = sched.resolve_task_group(0, &[1, 2], |_callable| {
            call_count += 1;
            Ok(ValueWord::from_f64(call_count as f64))
        });
        assert!(result.is_ok());
        let val = result.unwrap();
        let view = val.as_any_array().expect("Expected array");
        assert_eq!(view.len(), 2);
    }

    #[test]
    fn test_resolve_all_group_empty() {
        let mut sched = TaskScheduler::new();
        let val = sched
            .resolve_task_group(0, &[], |_| Ok(ValueWord::none()))
            .expect("empty all-join should succeed");
        let view = val.as_any_array().expect("Expected array");
        assert_eq!(view.len(), 0);
    }

    #[test]
    fn test_resolve_all_group_stops_at_first_error() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let mut calls = 0u32;
        let result = sched.resolve_task_group(0, &[1, 2], |_| {
            calls += 1;
            Err(VMError::RuntimeError("boom".to_string()))
        });
        assert!(result.is_err());
        assert_eq!(calls, 1);
    }

    #[test]
    fn test_resolve_race_group() {
        let mut sched = TaskScheduler::new();
        sched.register(10, ValueWord::from_function(0));
        sched.register(20, ValueWord::from_function(1));

        let result = sched.resolve_task_group(1, &[10, 20], |_| {
            Ok(ValueWord::from_string(Arc::new("first".to_string())))
        });
        assert!(result.is_ok());
        let val = result.unwrap();
        assert_eq!(val.as_str().unwrap(), "first");
    }

    #[test]
    fn test_resolve_any_group_skips_failures() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let mut calls = 0u32;
        let result = sched.resolve_task_group(2, &[1, 2], |_| {
            calls += 1;
            if calls == 1 {
                Err(VMError::RuntimeError("boom".to_string()))
            } else {
                Ok(ValueWord::from_f64(2.0))
            }
        });
        let val = result.expect("second task should succeed");
        assert!((val.as_f64().unwrap() - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_resolve_settled_group_keeps_errors() {
        let mut sched = TaskScheduler::new();
        sched.register(1, ValueWord::from_function(0));
        sched.register(2, ValueWord::from_function(1));

        let mut calls = 0u32;
        let result = sched.resolve_task_group(3, &[1, 2], |_| {
            calls += 1;
            if calls == 1 {
                Err(VMError::RuntimeError("boom".to_string()))
            } else {
                Ok(ValueWord::from_f64(2.0))
            }
        });
        let val = result.expect("settled join never fails on task errors");
        let view = val.as_any_array().expect("Expected array");
        assert_eq!(view.len(), 2);
    }

    #[test]
    fn test_group_edge_cases() {
        let mut sched = TaskScheduler::new();
        assert!(sched
            .resolve_task_group(1, &[], |_| Ok(ValueWord::none()))
            .is_err());
        assert!(sched
            .resolve_task_group(2, &[], |_| Ok(ValueWord::none()))
            .is_err());
        assert!(sched
            .resolve_task_group(42, &[], |_| Ok(ValueWord::none()))
            .is_err());
    }

    #[test]
    fn test_register_external_and_resolve() {
        let mut sched = TaskScheduler::new();
        let tx = sched.register_external(100);
        assert!(sched.has_external(100));
        assert!(matches!(sched.get_result(100), Some(TaskStatus::Pending)));

        // Nothing has been sent yet.
        assert!(sched.try_resolve_external(100).is_none());

        tx.send(Ok(ValueWord::from_f64(42.0))).unwrap();

        let result = sched.try_resolve_external(100);
        assert!(result.is_some());
        let val = result.unwrap().unwrap();
        assert!((val.as_f64().unwrap() - 42.0).abs() < f64::EPSILON);

        // The receiver is consumed once the value has been delivered.
        assert!(!sched.has_external(100));
    }

    #[test]
    fn test_external_task_error() {
        let mut sched = TaskScheduler::new();
        let tx = sched.register_external(200);

        tx.send(Err("connection refused".to_string())).unwrap();

        let result = sched.try_resolve_external(200);
        assert!(result.is_some());
        assert!(result.unwrap().is_err());
    }

    #[test]
    fn test_external_task_cancelled() {
        let mut sched = TaskScheduler::new();
        let tx = sched.register_external(300);

        // Dropping the sender without sending closes the channel.
        drop(tx);

        let result = sched.try_resolve_external(300);
        assert!(result.is_some());
        assert!(result.unwrap().is_err());
    }

    #[test]
    fn test_take_external_receiver() {
        let mut sched = TaskScheduler::new();
        let _tx = sched.register_external(400);

        assert!(sched.has_external(400));
        let rx = sched.take_external_receiver(400);
        assert!(rx.is_some());
        assert!(!sched.has_external(400));
    }
}