shape_vm/executor/async_ops/
mod.rs1use crate::{
9 bytecode::{Instruction, OpCode, Operand},
10 executor::VirtualMachine,
11};
12use shape_value::heap_value::HeapValue;
13use shape_value::{VMError, ValueWord};
14
15#[derive(Debug, Clone)]
17pub enum AsyncExecutionResult {
18 Continue,
20 Yielded,
22 Suspended(SuspensionInfo),
24}
25
26#[derive(Debug, Clone)]
28pub struct SuspensionInfo {
29 pub wait_type: WaitType,
31 pub resume_ip: usize,
33}
34
/// The condition a suspended execution is waiting on.
#[derive(Debug, Clone)]
pub enum WaitType {
    /// Emitted by `AwaitBar`; `source` is the string-constant operand, or
    /// `"default"` when the operand is absent or not a string constant.
    NextBar { source: String },
    /// Emitted by `AwaitTick`; `id` comes from a numeric-constant operand, or 0.
    Timer { id: u64 },
    /// Emitted by `Suspend`, which currently never narrows the wait condition.
    AnyEvent,
    /// Emitted by `Await` when the task scheduler could not resolve future `id`.
    Future { id: u64 },
    /// Emitted by `JoinAwait`. `kind` is the 2-bit join kind packed by
    /// `JoinInit` (presumably 0 = all, 1 = race, per the unit tests — confirm);
    /// `task_ids` are in the order the futures were pushed.
    TaskGroup { kind: u8, task_ids: Vec<u64> },
}
49
50impl VirtualMachine {
51 #[inline(always)]
56 pub(in crate::executor) fn exec_async_op(
57 &mut self,
58 instruction: &Instruction,
59 ) -> Result<AsyncExecutionResult, VMError> {
60 use OpCode::*;
61 match instruction.opcode {
62 Yield => self.op_yield(),
63 Suspend => self.op_suspend(instruction),
64 Resume => self.op_resume(instruction),
65 Poll => self.op_poll(),
66 AwaitBar => self.op_await_bar(instruction),
67 AwaitTick => self.op_await_tick(instruction),
68 EmitAlert => self.op_emit_alert(),
69 EmitEvent => self.op_emit_event(),
70 Await => self.op_await(),
71 SpawnTask => self.op_spawn_task(),
72 JoinInit => self.op_join_init(instruction),
73 JoinAwait => self.op_join_await(),
74 CancelTask => self.op_cancel_task(),
75 AsyncScopeEnter => self.op_async_scope_enter(),
76 AsyncScopeExit => self.op_async_scope_exit(),
77 _ => unreachable!(
78 "exec_async_op called with non-async opcode: {:?}",
79 instruction.opcode
80 ),
81 }
82 }
83
84 fn op_yield(&mut self) -> Result<AsyncExecutionResult, VMError> {
89 Ok(AsyncExecutionResult::Yielded)
91 }
92
93 fn op_suspend(&mut self, instruction: &Instruction) -> Result<AsyncExecutionResult, VMError> {
97 let wait_type = match &instruction.operand {
98 Some(Operand::Const(idx)) => {
99 let _ = idx;
102 WaitType::AnyEvent
103 }
104 _ => WaitType::AnyEvent,
105 };
106
107 Ok(AsyncExecutionResult::Suspended(SuspensionInfo {
108 wait_type,
109 resume_ip: self.ip,
110 }))
111 }
112
113 fn op_resume(&mut self, _instruction: &Instruction) -> Result<AsyncExecutionResult, VMError> {
118 Ok(AsyncExecutionResult::Continue)
121 }
122
123 fn op_poll(&mut self) -> Result<AsyncExecutionResult, VMError> {
128 self.push_vw(ValueWord::none()).map_err(|e| e)?;
132 Ok(AsyncExecutionResult::Continue)
133 }
134
135 fn op_await_bar(&mut self, instruction: &Instruction) -> Result<AsyncExecutionResult, VMError> {
140 let source = match &instruction.operand {
141 Some(Operand::Const(idx)) => {
142 match self.program.constants.get(*idx as usize) {
144 Some(crate::bytecode::Constant::String(s)) => s.clone(),
145 _ => "default".to_string(),
146 }
147 }
148 _ => "default".to_string(),
149 };
150
151 Ok(AsyncExecutionResult::Suspended(SuspensionInfo {
152 wait_type: WaitType::NextBar { source },
153 resume_ip: self.ip,
154 }))
155 }
156
157 fn op_await_tick(
161 &mut self,
162 instruction: &Instruction,
163 ) -> Result<AsyncExecutionResult, VMError> {
164 let timer_id = match &instruction.operand {
165 Some(Operand::Const(idx)) => {
166 match self.program.constants.get(*idx as usize) {
168 Some(crate::bytecode::Constant::Number(n)) => *n as u64,
169 _ => 0,
170 }
171 }
172 _ => 0,
173 };
174
175 Ok(AsyncExecutionResult::Suspended(SuspensionInfo {
176 wait_type: WaitType::Timer { id: timer_id },
177 resume_ip: self.ip,
178 }))
179 }
180
181 fn op_emit_alert(&mut self) -> Result<AsyncExecutionResult, VMError> {
186 let _alert_nb = self.pop_vw()?;
187 Ok(AsyncExecutionResult::Continue)
189 }
190
191 fn op_await(&mut self) -> Result<AsyncExecutionResult, VMError> {
199 let nb = self.pop_vw()?;
200 match nb.as_heap_ref() {
201 Some(HeapValue::Future(id)) => {
202 let id = *id;
203
204 let resolved = self.task_scheduler.resolve_task(id, |callable| {
209 Ok(callable)
213 });
214
215 match resolved {
216 Ok(value) => {
217 self.push_vw(value)?;
218 Ok(AsyncExecutionResult::Continue)
219 }
220 Err(_) => {
221 Ok(AsyncExecutionResult::Suspended(SuspensionInfo {
223 wait_type: WaitType::Future { id },
224 resume_ip: self.ip,
225 }))
226 }
227 }
228 }
229 _ => {
230 self.push_vw(nb)?;
232 Ok(AsyncExecutionResult::Continue)
233 }
234 }
235 }
236
237 fn op_spawn_task(&mut self) -> Result<AsyncExecutionResult, VMError> {
245 let callable_nb = self.pop_vw()?;
246
247 let task_id = self.next_future_id();
248 self.task_scheduler.register(task_id, callable_nb);
249
250 if let Some(scope) = self.async_scope_stack.last_mut() {
251 scope.push(task_id);
252 }
253
254 self.push_vw(ValueWord::from_future(task_id))?;
255 Ok(AsyncExecutionResult::Continue)
256 }
257
258 fn op_join_init(&mut self, instruction: &Instruction) -> Result<AsyncExecutionResult, VMError> {
264 let packed = match &instruction.operand {
265 Some(Operand::Count(n)) => *n,
266 _ => {
267 return Err(VMError::RuntimeError(
268 "JoinInit requires Count operand".to_string(),
269 ));
270 }
271 };
272
273 let kind = ((packed >> 14) & 0x03) as u8;
274 let arity = (packed & 0x3FFF) as usize;
275
276 if self.sp < arity {
277 return Err(VMError::StackUnderflow);
278 }
279
280 let mut task_ids = Vec::with_capacity(arity);
281 for _ in 0..arity {
282 let nb = self.pop_vw()?;
283 match nb.as_heap_ref() {
284 Some(HeapValue::Future(id)) => task_ids.push(*id),
285 _ => {
286 return Err(VMError::RuntimeError(format!(
287 "JoinInit expected Future, got {}",
288 nb.type_name()
289 )));
290 }
291 }
292 }
293 task_ids.reverse();
295
296 self.push_vw(ValueWord::from_heap_value(
297 shape_value::heap_value::HeapValue::TaskGroup { kind, task_ids },
298 ))?;
299 Ok(AsyncExecutionResult::Continue)
300 }
301
302 fn op_join_await(&mut self) -> Result<AsyncExecutionResult, VMError> {
309 let nb = self.pop_vw()?;
310 match nb.as_heap_ref() {
311 Some(HeapValue::TaskGroup { kind, task_ids }) => {
312 Ok(AsyncExecutionResult::Suspended(SuspensionInfo {
313 wait_type: WaitType::TaskGroup {
314 kind: *kind,
315 task_ids: task_ids.clone(),
316 },
317 resume_ip: self.ip,
318 }))
319 }
320 _ => Err(VMError::RuntimeError(format!(
321 "JoinAwait expected TaskGroup, got {}",
322 nb.type_name()
323 ))),
324 }
325 }
326
327 fn op_cancel_task(&mut self) -> Result<AsyncExecutionResult, VMError> {
332 let nb = self.pop_vw()?;
333 match nb.as_heap_ref() {
334 Some(HeapValue::Future(id)) => {
335 self.task_scheduler.cancel(*id);
336 Ok(AsyncExecutionResult::Continue)
337 }
338 _ => Err(VMError::RuntimeError(format!(
339 "CancelTask expected Future, got {}",
340 nb.type_name()
341 ))),
342 }
343 }
344
345 fn op_async_scope_enter(&mut self) -> Result<AsyncExecutionResult, VMError> {
350 self.async_scope_stack.push(Vec::new());
351 Ok(AsyncExecutionResult::Continue)
352 }
353
354 fn op_async_scope_exit(&mut self) -> Result<AsyncExecutionResult, VMError> {
360 if let Some(mut scope_tasks) = self.async_scope_stack.pop() {
361 scope_tasks.reverse();
363 for task_id in scope_tasks {
364 self.task_scheduler.cancel(task_id);
365 }
366 }
367 Ok(AsyncExecutionResult::Continue)
369 }
370
371 fn op_emit_event(&mut self) -> Result<AsyncExecutionResult, VMError> {
376 let _event_nb = self.pop_vw()?;
377 Ok(AsyncExecutionResult::Continue)
379 }
380}
381
/// Test helper: reports whether `opcode` belongs to the async family that
/// `VirtualMachine::exec_async_op` dispatches.
#[cfg(test)]
pub fn is_async_opcode(opcode: OpCode) -> bool {
    use OpCode::*;
    matches!(
        opcode,
        // Cooperative scheduling / event opcodes.
        Yield | Suspend | Resume | Poll | AwaitBar | AwaitTick | EmitAlert | EmitEvent
            // Future and task-group opcodes.
            | Await | SpawnTask | JoinInit | JoinAwait | CancelTask
            // Structured-concurrency scope opcodes.
            | AsyncScopeEnter | AsyncScopeExit
    )
}
404
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_async_opcode() {
        assert!(is_async_opcode(OpCode::Yield));
        assert!(is_async_opcode(OpCode::Suspend));
        assert!(is_async_opcode(OpCode::EmitAlert));
        assert!(is_async_opcode(OpCode::AsyncScopeEnter));
        assert!(is_async_opcode(OpCode::AsyncScopeExit));
        assert!(!is_async_opcode(OpCode::Add));
        assert!(!is_async_opcode(OpCode::Jump));
    }

    /// Covers every opcode routed by `exec_async_op`, plus a few that are not.
    #[test]
    fn test_is_async_opcode_all_variants() {
        assert!(is_async_opcode(OpCode::Yield));
        assert!(is_async_opcode(OpCode::Suspend));
        assert!(is_async_opcode(OpCode::Resume));
        assert!(is_async_opcode(OpCode::Poll));
        assert!(is_async_opcode(OpCode::AwaitBar));
        assert!(is_async_opcode(OpCode::AwaitTick));
        assert!(is_async_opcode(OpCode::EmitAlert));
        assert!(is_async_opcode(OpCode::EmitEvent));
        assert!(is_async_opcode(OpCode::Await));
        assert!(is_async_opcode(OpCode::SpawnTask));
        assert!(is_async_opcode(OpCode::JoinInit));
        assert!(is_async_opcode(OpCode::JoinAwait));
        assert!(is_async_opcode(OpCode::CancelTask));
        assert!(is_async_opcode(OpCode::AsyncScopeEnter));
        assert!(is_async_opcode(OpCode::AsyncScopeExit));

        assert!(!is_async_opcode(OpCode::PushConst));
        assert!(!is_async_opcode(OpCode::Return));
        assert!(!is_async_opcode(OpCode::Call));
        assert!(!is_async_opcode(OpCode::Nop));
    }

    #[test]
    fn test_async_execution_result_variants() {
        let continue_result = AsyncExecutionResult::Continue;
        assert!(matches!(continue_result, AsyncExecutionResult::Continue));

        let yielded_result = AsyncExecutionResult::Yielded;
        assert!(matches!(yielded_result, AsyncExecutionResult::Yielded));

        let suspended_result = AsyncExecutionResult::Suspended(SuspensionInfo {
            wait_type: WaitType::AnyEvent,
            resume_ip: 42,
        });
        match suspended_result {
            AsyncExecutionResult::Suspended(info) => {
                assert_eq!(info.resume_ip, 42);
                assert!(matches!(info.wait_type, WaitType::AnyEvent));
            }
            _ => panic!("Expected Suspended"),
        }
    }

    #[test]
    fn test_wait_type_variants() {
        let next_bar = WaitType::NextBar {
            source: "market_data".to_string(),
        };
        match next_bar {
            WaitType::NextBar { source } => assert_eq!(source, "market_data"),
            _ => panic!("Expected NextBar"),
        }

        let timer = WaitType::Timer { id: 123 };
        match timer {
            WaitType::Timer { id } => assert_eq!(id, 123),
            _ => panic!("Expected Timer"),
        }

        let any = WaitType::AnyEvent;
        assert!(matches!(any, WaitType::AnyEvent));
    }

    #[test]
    fn test_suspension_info_creation() {
        let info = SuspensionInfo {
            wait_type: WaitType::Timer { id: 999 },
            resume_ip: 100,
        };

        assert_eq!(info.resume_ip, 100);
        assert!(matches!(info.wait_type, WaitType::Timer { id: 999 }));
    }

    #[test]
    fn test_is_async_opcode_await() {
        assert!(is_async_opcode(OpCode::Await));
    }

    #[test]
    fn test_wait_type_future() {
        let future = WaitType::Future { id: 42 };
        match future {
            WaitType::Future { id } => assert_eq!(id, 42),
            _ => panic!("Expected Future"),
        }
    }

    #[test]
    fn test_is_async_opcode_join_opcodes() {
        assert!(is_async_opcode(OpCode::SpawnTask));
        assert!(is_async_opcode(OpCode::JoinInit));
        assert!(is_async_opcode(OpCode::JoinAwait));
        assert!(is_async_opcode(OpCode::CancelTask));
    }

    #[test]
    fn test_wait_type_task_group() {
        let tg = WaitType::TaskGroup {
            kind: 0,
            task_ids: vec![1, 2, 3],
        };
        match tg {
            WaitType::TaskGroup { kind, task_ids } => {
                assert_eq!(kind, 0);
                assert_eq!(task_ids.len(), 3);
                assert_eq!(task_ids, vec![1, 2, 3]);
            }
            _ => panic!("Expected TaskGroup"),
        }
    }

    #[test]
    fn test_wait_type_task_group_race() {
        let tg = WaitType::TaskGroup {
            kind: 1,
            task_ids: vec![10, 20],
        };
        match tg {
            WaitType::TaskGroup { kind, task_ids } => {
                assert_eq!(kind, 1);
                assert_eq!(task_ids, vec![10, 20]);
            }
            _ => panic!("Expected TaskGroup"),
        }
    }
}
549}