shape_vm/executor/async_ops/
mod.rs1use crate::{
46 bytecode::{Instruction, OpCode, Operand},
47 executor::VirtualMachine,
48};
49use shape_value::heap_value::HeapValue;
50use shape_value::{VMError, ValueWord};
51
/// Outcome of executing one async-family instruction via `VirtualMachine::exec_async_op`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AsyncExecutionResult {
    /// The instruction finished; execution can carry on without suspending.
    Continue,
    /// A `Yield` instruction ran and control should be handed back to the caller.
    Yielded,
    /// Execution must pause until the condition in the [`SuspensionInfo`] is satisfied.
    Suspended(SuspensionInfo),
}

/// Describes why execution was suspended and where it should pick up again.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SuspensionInfo {
    /// What the suspended code is waiting for.
    pub wait_type: WaitType,
    /// Instruction pointer to resume from. The async handlers record the VM's `ip` at the
    /// moment they run (NOTE(review): confirm whether `ip` has already been advanced past
    /// the suspending instruction by then).
    pub resume_ip: usize,
}

/// The condition a suspended computation is waiting on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WaitType {
    /// The next bar from the named data source (`"default"` when `AwaitBar` has no usable
    /// string constant).
    NextBar { source: String },
    /// The timer with the given id (taken from `AwaitTick`'s numeric constant, `0` if absent).
    Timer { id: u64 },
    /// Any event at all; produced by `Suspend`.
    AnyEvent,
    /// Completion of the future/task with the given id; produced by `Await`.
    Future { id: u64 },
    /// Completion of a group of tasks built by `JoinInit` and awaited by `JoinAwait`.
    ///
    /// `kind` is the 2-bit join mode decoded from the `JoinInit` operand (NOTE(review): the
    /// meaning of each value, e.g. all vs. race, is defined by the task scheduler);
    /// `task_ids` lists the futures in the order they were pushed.
    TaskGroup { kind: u8, task_ids: Vec<u64> },
}
86
87impl VirtualMachine {
88 #[inline(always)]
93 pub(in crate::executor) fn exec_async_op(
94 &mut self,
95 instruction: &Instruction,
96 ) -> Result<AsyncExecutionResult, VMError> {
97 use OpCode::*;
98 match instruction.opcode {
99 Yield => self.op_yield(),
100 Suspend => self.op_suspend(instruction),
101 Resume => self.op_resume(instruction),
102 Poll => self.op_poll(),
103 AwaitBar => self.op_await_bar(instruction),
104 AwaitTick => self.op_await_tick(instruction),
105 EmitAlert => self.op_emit_alert(),
106 EmitEvent => self.op_emit_event(),
107 Await => self.op_await(),
108 SpawnTask => self.op_spawn_task(),
109 JoinInit => self.op_join_init(instruction),
110 JoinAwait => self.op_join_await(),
111 CancelTask => self.op_cancel_task(),
112 AsyncScopeEnter => self.op_async_scope_enter(),
113 AsyncScopeExit => self.op_async_scope_exit(),
114 _ => unreachable!(
115 "exec_async_op called with non-async opcode: {:?}",
116 instruction.opcode
117 ),
118 }
119 }
120
121 fn op_yield(&mut self) -> Result<AsyncExecutionResult, VMError> {
126 Ok(AsyncExecutionResult::Yielded)
128 }
129
130 fn op_suspend(&mut self, instruction: &Instruction) -> Result<AsyncExecutionResult, VMError> {
134 let wait_type = match &instruction.operand {
135 Some(Operand::Const(idx)) => {
136 let _ = idx;
139 WaitType::AnyEvent
140 }
141 _ => WaitType::AnyEvent,
142 };
143
144 Ok(AsyncExecutionResult::Suspended(SuspensionInfo {
145 wait_type,
146 resume_ip: self.ip,
147 }))
148 }
149
150 fn op_resume(&mut self, _instruction: &Instruction) -> Result<AsyncExecutionResult, VMError> {
155 Ok(AsyncExecutionResult::Continue)
158 }
159
160 fn op_poll(&mut self) -> Result<AsyncExecutionResult, VMError> {
165 self.push_vw(ValueWord::none()).map_err(|e| e)?;
169 Ok(AsyncExecutionResult::Continue)
170 }
171
172 fn op_await_bar(&mut self, instruction: &Instruction) -> Result<AsyncExecutionResult, VMError> {
177 let source = match &instruction.operand {
178 Some(Operand::Const(idx)) => {
179 match self.program.constants.get(*idx as usize) {
181 Some(crate::bytecode::Constant::String(s)) => s.clone(),
182 _ => "default".to_string(),
183 }
184 }
185 _ => "default".to_string(),
186 };
187
188 Ok(AsyncExecutionResult::Suspended(SuspensionInfo {
189 wait_type: WaitType::NextBar { source },
190 resume_ip: self.ip,
191 }))
192 }
193
194 fn op_await_tick(
198 &mut self,
199 instruction: &Instruction,
200 ) -> Result<AsyncExecutionResult, VMError> {
201 let timer_id = match &instruction.operand {
202 Some(Operand::Const(idx)) => {
203 match self.program.constants.get(*idx as usize) {
205 Some(crate::bytecode::Constant::Number(n)) => *n as u64,
206 _ => 0,
207 }
208 }
209 _ => 0,
210 };
211
212 Ok(AsyncExecutionResult::Suspended(SuspensionInfo {
213 wait_type: WaitType::Timer { id: timer_id },
214 resume_ip: self.ip,
215 }))
216 }
217
218 fn op_emit_alert(&mut self) -> Result<AsyncExecutionResult, VMError> {
223 let _alert_nb = self.pop_vw()?;
224 Ok(AsyncExecutionResult::Continue)
226 }
227
228 fn op_await(&mut self) -> Result<AsyncExecutionResult, VMError> {
236 let sp_before = self.sp;
237 let nb = self.pop_vw()?;
238 match nb.as_heap_ref() {
239 Some(HeapValue::Future(id)) => {
240 let id = *id;
241
242 let resolved = self.task_scheduler.resolve_task(id, |callable| {
247 Ok(callable)
251 });
252
253 match resolved {
254 Ok(value) => {
255 self.push_vw(value)?;
256 debug_assert_eq!(
258 self.sp, sp_before,
259 "op_await: stack depth changed (before={}, after={})",
260 sp_before, self.sp
261 );
262 Ok(AsyncExecutionResult::Continue)
263 }
264 Err(_) => {
265 Ok(AsyncExecutionResult::Suspended(SuspensionInfo {
267 wait_type: WaitType::Future { id },
268 resume_ip: self.ip,
269 }))
270 }
271 }
272 }
273 _ => {
274 self.push_vw(nb)?;
276 debug_assert_eq!(
277 self.sp, sp_before,
278 "op_await (sync shortcut): stack depth changed (before={}, after={})",
279 sp_before, self.sp
280 );
281 Ok(AsyncExecutionResult::Continue)
282 }
283 }
284 }
285
286 fn op_spawn_task(&mut self) -> Result<AsyncExecutionResult, VMError> {
296 let sp_before = self.sp;
297 let callable_nb = self.pop_vw()?;
298
299 let task_id = self.next_future_id();
300 self.task_scheduler.register(task_id, callable_nb);
301
302 if let Some(scope) = self.async_scope_stack.last_mut() {
303 scope.push(task_id);
304 }
305
306 self.push_vw(ValueWord::from_future(task_id))?;
307 debug_assert_eq!(
309 self.sp, sp_before,
310 "op_spawn_task: stack depth changed (before={}, after={})",
311 sp_before, self.sp
312 );
313 Ok(AsyncExecutionResult::Continue)
314 }
315
316 fn op_join_init(&mut self, instruction: &Instruction) -> Result<AsyncExecutionResult, VMError> {
322 let packed = match &instruction.operand {
323 Some(Operand::Count(n)) => *n,
324 _ => {
325 return Err(VMError::RuntimeError(
326 "JoinInit requires Count operand".to_string(),
327 ));
328 }
329 };
330
331 let kind = ((packed >> 14) & 0x03) as u8;
332 let arity = (packed & 0x3FFF) as usize;
333
334 if self.sp < arity {
335 return Err(VMError::StackUnderflow);
336 }
337
338 let mut task_ids = Vec::with_capacity(arity);
339 for _ in 0..arity {
340 let nb = self.pop_vw()?;
341 match nb.as_heap_ref() {
342 Some(HeapValue::Future(id)) => task_ids.push(*id),
343 _ => {
344 return Err(VMError::RuntimeError(format!(
345 "JoinInit expected Future, got {}",
346 nb.type_name()
347 )));
348 }
349 }
350 }
351 task_ids.reverse();
353
354 self.push_vw(ValueWord::from_heap_value(
355 shape_value::heap_value::HeapValue::TaskGroup { kind, task_ids },
356 ))?;
357 Ok(AsyncExecutionResult::Continue)
358 }
359
360 fn op_join_await(&mut self) -> Result<AsyncExecutionResult, VMError> {
367 let sp_before = self.sp;
368 let nb = self.pop_vw()?;
369 match nb.as_heap_ref() {
370 Some(HeapValue::TaskGroup { kind, task_ids }) => {
371 let kind = *kind;
372 let task_ids = task_ids.clone();
373
374 let result = self
375 .task_scheduler
376 .resolve_task_group(kind, &task_ids, |callable| Ok(callable));
377
378 match result {
379 Ok(value) => {
380 self.push_vw(value)?;
381 debug_assert_eq!(
383 self.sp, sp_before,
384 "op_join_await: stack depth changed (before={}, after={})",
385 sp_before, self.sp
386 );
387 Ok(AsyncExecutionResult::Continue)
388 }
389 Err(_) => {
390 Ok(AsyncExecutionResult::Suspended(SuspensionInfo {
392 wait_type: WaitType::TaskGroup { kind, task_ids },
393 resume_ip: self.ip,
394 }))
395 }
396 }
397 }
398 _ => Err(VMError::RuntimeError(format!(
399 "JoinAwait expected TaskGroup, got {}",
400 nb.type_name()
401 ))),
402 }
403 }
404
405 fn op_cancel_task(&mut self) -> Result<AsyncExecutionResult, VMError> {
410 let nb = self.pop_vw()?;
411 match nb.as_heap_ref() {
412 Some(HeapValue::Future(id)) => {
413 self.task_scheduler.cancel(*id);
414 Ok(AsyncExecutionResult::Continue)
415 }
416 _ => Err(VMError::RuntimeError(format!(
417 "CancelTask expected Future, got {}",
418 nb.type_name()
419 ))),
420 }
421 }
422
423 fn op_async_scope_enter(&mut self) -> Result<AsyncExecutionResult, VMError> {
428 let depth_before = self.async_scope_stack.len();
429 self.async_scope_stack.push(Vec::new());
430 debug_assert_eq!(
431 self.async_scope_stack.len(),
432 depth_before + 1,
433 "op_async_scope_enter: scope stack depth not incremented"
434 );
435 Ok(AsyncExecutionResult::Continue)
436 }
437
438 fn op_async_scope_exit(&mut self) -> Result<AsyncExecutionResult, VMError> {
444 debug_assert!(
445 !self.async_scope_stack.is_empty(),
446 "op_async_scope_exit: scope stack is empty (mismatched Enter/Exit)"
447 );
448 if let Some(mut scope_tasks) = self.async_scope_stack.pop() {
449 scope_tasks.reverse();
451 for task_id in scope_tasks {
452 self.task_scheduler.cancel(task_id);
453 }
454 }
455 Ok(AsyncExecutionResult::Continue)
457 }
458
459 fn op_emit_event(&mut self) -> Result<AsyncExecutionResult, VMError> {
464 let _event_nb = self.pop_vw()?;
465 Ok(AsyncExecutionResult::Continue)
467 }
468}
469
/// Reports whether `opcode` belongs to the async family handled by
/// `VirtualMachine::exec_async_op`.
///
/// Test-only helper; the opcode list here must match the dispatch in `exec_async_op`.
#[cfg(test)]
pub fn is_async_opcode(opcode: OpCode) -> bool {
    use OpCode::*;
    matches!(
        opcode,
        // Cooperative scheduling primitives.
        Yield | Suspend | Resume | Poll
        // Event-stream waits and emissions.
        | AwaitBar | AwaitTick | EmitAlert | EmitEvent
        // Futures and structured concurrency.
        | Await | SpawnTask | JoinInit | JoinAwait | CancelTask
        | AsyncScopeEnter | AsyncScopeExit
    )
}
492
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_async_opcode() {
        assert!(is_async_opcode(OpCode::Yield));
        assert!(is_async_opcode(OpCode::Suspend));
        assert!(is_async_opcode(OpCode::EmitAlert));
        assert!(is_async_opcode(OpCode::AsyncScopeEnter));
        assert!(is_async_opcode(OpCode::AsyncScopeExit));
        assert!(!is_async_opcode(OpCode::Add));
        assert!(!is_async_opcode(OpCode::Jump));
    }

    /// Covers every opcode dispatched by `VirtualMachine::exec_async_op`, plus a sample of
    /// non-async opcodes.
    #[test]
    fn test_is_async_opcode_all_variants() {
        // Scheduling / event opcodes.
        assert!(is_async_opcode(OpCode::Yield));
        assert!(is_async_opcode(OpCode::Suspend));
        assert!(is_async_opcode(OpCode::Resume));
        assert!(is_async_opcode(OpCode::Poll));
        assert!(is_async_opcode(OpCode::AwaitBar));
        assert!(is_async_opcode(OpCode::AwaitTick));
        assert!(is_async_opcode(OpCode::EmitAlert));
        assert!(is_async_opcode(OpCode::EmitEvent));

        // Future / structured-concurrency opcodes.
        assert!(is_async_opcode(OpCode::Await));
        assert!(is_async_opcode(OpCode::SpawnTask));
        assert!(is_async_opcode(OpCode::JoinInit));
        assert!(is_async_opcode(OpCode::JoinAwait));
        assert!(is_async_opcode(OpCode::CancelTask));
        assert!(is_async_opcode(OpCode::AsyncScopeEnter));
        assert!(is_async_opcode(OpCode::AsyncScopeExit));

        // Non-async opcodes.
        assert!(!is_async_opcode(OpCode::PushConst));
        assert!(!is_async_opcode(OpCode::Return));
        assert!(!is_async_opcode(OpCode::Call));
        assert!(!is_async_opcode(OpCode::Nop));
    }

    #[test]
    fn test_async_execution_result_variants() {
        let continue_result = AsyncExecutionResult::Continue;
        assert!(matches!(continue_result, AsyncExecutionResult::Continue));

        let yielded_result = AsyncExecutionResult::Yielded;
        assert!(matches!(yielded_result, AsyncExecutionResult::Yielded));

        let suspended_result = AsyncExecutionResult::Suspended(SuspensionInfo {
            wait_type: WaitType::AnyEvent,
            resume_ip: 42,
        });
        match suspended_result {
            AsyncExecutionResult::Suspended(info) => {
                assert_eq!(info.resume_ip, 42);
                assert!(matches!(info.wait_type, WaitType::AnyEvent));
            }
            _ => panic!("Expected Suspended"),
        }
    }

    #[test]
    fn test_wait_type_variants() {
        let next_bar = WaitType::NextBar {
            source: "market_data".to_string(),
        };
        match next_bar {
            WaitType::NextBar { source } => assert_eq!(source, "market_data"),
            _ => panic!("Expected NextBar"),
        }

        let timer = WaitType::Timer { id: 123 };
        match timer {
            WaitType::Timer { id } => assert_eq!(id, 123),
            _ => panic!("Expected Timer"),
        }

        let any = WaitType::AnyEvent;
        assert!(matches!(any, WaitType::AnyEvent));
    }

    #[test]
    fn test_suspension_info_creation() {
        let info = SuspensionInfo {
            wait_type: WaitType::Timer { id: 999 },
            resume_ip: 100,
        };

        assert_eq!(info.resume_ip, 100);
        assert!(matches!(info.wait_type, WaitType::Timer { id: 999 }));
    }

    #[test]
    fn test_is_async_opcode_await() {
        assert!(is_async_opcode(OpCode::Await));
    }

    #[test]
    fn test_wait_type_future() {
        let future = WaitType::Future { id: 42 };
        match future {
            WaitType::Future { id } => assert_eq!(id, 42),
            _ => panic!("Expected Future"),
        }
    }

    #[test]
    fn test_is_async_opcode_join_opcodes() {
        assert!(is_async_opcode(OpCode::SpawnTask));
        assert!(is_async_opcode(OpCode::JoinInit));
        assert!(is_async_opcode(OpCode::JoinAwait));
        assert!(is_async_opcode(OpCode::CancelTask));
    }

    #[test]
    fn test_wait_type_task_group() {
        let tg = WaitType::TaskGroup {
            kind: 0,
            task_ids: vec![1, 2, 3],
        };
        match tg {
            WaitType::TaskGroup { kind, task_ids } => {
                assert_eq!(kind, 0);
                assert_eq!(task_ids.len(), 3);
                assert_eq!(task_ids, vec![1, 2, 3]);
            }
            _ => panic!("Expected TaskGroup"),
        }
    }

    #[test]
    fn test_wait_type_task_group_race() {
        let tg = WaitType::TaskGroup {
            kind: 1,
            task_ids: vec![10, 20],
        };
        match tg {
            WaitType::TaskGroup { kind, task_ids } => {
                assert_eq!(kind, 1);
                assert_eq!(task_ids, vec![10, 20]);
            }
            _ => panic!("Expected TaskGroup"),
        }
    }
}