shape_vm/executor/async_ops/
mod.rs1use crate::{
9 bytecode::{Instruction, OpCode, Operand},
10 executor::VirtualMachine,
11};
12use shape_value::heap_value::HeapValue;
13use shape_value::{VMError, ValueWord};
14
/// Outcome of executing a single async-family instruction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AsyncExecutionResult {
    /// Execution proceeds with the next instruction.
    Continue,
    /// Returned for a `Yield` instruction; the caller decides how to proceed.
    Yielded,
    /// Execution is suspended until the described condition is satisfied.
    Suspended(SuspensionInfo),
}

/// Describes what a suspended execution is waiting on and where it picks back up.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SuspensionInfo {
    /// The condition being waited on.
    pub wait_type: WaitType,
    /// Instruction pointer to resume at (the async handlers record the VM's current `ip`).
    pub resume_ip: usize,
}

/// The condition a suspended execution is waiting on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WaitType {
    /// The next bar from the named data source.
    NextBar { source: String },
    /// The timer identified by `id`.
    Timer { id: u64 },
    /// Any event; no more specific condition.
    AnyEvent,
    /// The future identified by `id` (presumably until it resolves — confirm with the scheduler).
    Future { id: u64 },
    /// A joined group of tasks.
    ///
    /// `kind` is the 2-bit join kind packed into `JoinInit`'s operand
    /// (NOTE(review): tests suggest 0 = all and 1 = race — confirm); `task_ids` are listed in
    /// the order the futures were pushed.
    TaskGroup { kind: u8, task_ids: Vec<u64> },
}
49
50impl VirtualMachine {
51 #[inline(always)]
56 pub(in crate::executor) fn exec_async_op(
57 &mut self,
58 instruction: &Instruction,
59 ) -> Result<AsyncExecutionResult, VMError> {
60 use OpCode::*;
61 match instruction.opcode {
62 Yield => self.op_yield(),
63 Suspend => self.op_suspend(instruction),
64 Resume => self.op_resume(instruction),
65 Poll => self.op_poll(),
66 AwaitBar => self.op_await_bar(instruction),
67 AwaitTick => self.op_await_tick(instruction),
68 EmitAlert => self.op_emit_alert(),
69 EmitEvent => self.op_emit_event(),
70 Await => self.op_await(),
71 SpawnTask => self.op_spawn_task(),
72 JoinInit => self.op_join_init(instruction),
73 JoinAwait => self.op_join_await(),
74 CancelTask => self.op_cancel_task(),
75 AsyncScopeEnter => self.op_async_scope_enter(),
76 AsyncScopeExit => self.op_async_scope_exit(),
77 _ => unreachable!(
78 "exec_async_op called with non-async opcode: {:?}",
79 instruction.opcode
80 ),
81 }
82 }
83
84 fn op_yield(&mut self) -> Result<AsyncExecutionResult, VMError> {
89 Ok(AsyncExecutionResult::Yielded)
91 }
92
93 fn op_suspend(&mut self, instruction: &Instruction) -> Result<AsyncExecutionResult, VMError> {
97 let wait_type = match &instruction.operand {
98 Some(Operand::Const(idx)) => {
99 let _ = idx;
102 WaitType::AnyEvent
103 }
104 _ => WaitType::AnyEvent,
105 };
106
107 Ok(AsyncExecutionResult::Suspended(SuspensionInfo {
108 wait_type,
109 resume_ip: self.ip,
110 }))
111 }
112
113 fn op_resume(&mut self, _instruction: &Instruction) -> Result<AsyncExecutionResult, VMError> {
118 Ok(AsyncExecutionResult::Continue)
121 }
122
123 fn op_poll(&mut self) -> Result<AsyncExecutionResult, VMError> {
128 self.push_vw(ValueWord::none()).map_err(|e| e)?;
132 Ok(AsyncExecutionResult::Continue)
133 }
134
135 fn op_await_bar(&mut self, instruction: &Instruction) -> Result<AsyncExecutionResult, VMError> {
140 let source = match &instruction.operand {
141 Some(Operand::Const(idx)) => {
142 match self.program.constants.get(*idx as usize) {
144 Some(crate::bytecode::Constant::String(s)) => s.clone(),
145 _ => "default".to_string(),
146 }
147 }
148 _ => "default".to_string(),
149 };
150
151 Ok(AsyncExecutionResult::Suspended(SuspensionInfo {
152 wait_type: WaitType::NextBar { source },
153 resume_ip: self.ip,
154 }))
155 }
156
157 fn op_await_tick(
161 &mut self,
162 instruction: &Instruction,
163 ) -> Result<AsyncExecutionResult, VMError> {
164 let timer_id = match &instruction.operand {
165 Some(Operand::Const(idx)) => {
166 match self.program.constants.get(*idx as usize) {
168 Some(crate::bytecode::Constant::Number(n)) => *n as u64,
169 _ => 0,
170 }
171 }
172 _ => 0,
173 };
174
175 Ok(AsyncExecutionResult::Suspended(SuspensionInfo {
176 wait_type: WaitType::Timer { id: timer_id },
177 resume_ip: self.ip,
178 }))
179 }
180
181 fn op_emit_alert(&mut self) -> Result<AsyncExecutionResult, VMError> {
186 let _alert_nb = self.pop_vw()?;
187 Ok(AsyncExecutionResult::Continue)
189 }
190
191 fn op_await(&mut self) -> Result<AsyncExecutionResult, VMError> {
196 let nb = self.pop_vw()?;
197 match nb.as_heap_ref() {
198 Some(HeapValue::Future(id)) => {
199 let id = *id;
200 Ok(AsyncExecutionResult::Suspended(SuspensionInfo {
201 wait_type: WaitType::Future { id },
202 resume_ip: self.ip,
203 }))
204 }
205 _ => {
206 self.push_vw(nb)?;
208 Ok(AsyncExecutionResult::Continue)
209 }
210 }
211 }
212
213 fn op_spawn_task(&mut self) -> Result<AsyncExecutionResult, VMError> {
221 let callable_nb = self.pop_vw()?;
222
223 let task_id = self.next_future_id();
224 self.task_scheduler.register(task_id, callable_nb);
225
226 if let Some(scope) = self.async_scope_stack.last_mut() {
227 scope.push(task_id);
228 }
229
230 self.push_vw(ValueWord::from_future(task_id))?;
231 Ok(AsyncExecutionResult::Continue)
232 }
233
234 fn op_join_init(&mut self, instruction: &Instruction) -> Result<AsyncExecutionResult, VMError> {
240 let packed = match &instruction.operand {
241 Some(Operand::Count(n)) => *n,
242 _ => {
243 return Err(VMError::RuntimeError(
244 "JoinInit requires Count operand".to_string(),
245 ));
246 }
247 };
248
249 let kind = ((packed >> 14) & 0x03) as u8;
250 let arity = (packed & 0x3FFF) as usize;
251
252 if self.sp < arity {
253 return Err(VMError::StackUnderflow);
254 }
255
256 let mut task_ids = Vec::with_capacity(arity);
257 for _ in 0..arity {
258 let nb = self.pop_vw()?;
259 match nb.as_heap_ref() {
260 Some(HeapValue::Future(id)) => task_ids.push(*id),
261 _ => {
262 return Err(VMError::RuntimeError(format!(
263 "JoinInit expected Future, got {}",
264 nb.type_name()
265 )));
266 }
267 }
268 }
269 task_ids.reverse();
271
272 self.push_vw(ValueWord::from_heap_value(
273 shape_value::heap_value::HeapValue::TaskGroup { kind, task_ids },
274 ))?;
275 Ok(AsyncExecutionResult::Continue)
276 }
277
278 fn op_join_await(&mut self) -> Result<AsyncExecutionResult, VMError> {
285 let nb = self.pop_vw()?;
286 match nb.as_heap_ref() {
287 Some(HeapValue::TaskGroup { kind, task_ids }) => {
288 Ok(AsyncExecutionResult::Suspended(SuspensionInfo {
289 wait_type: WaitType::TaskGroup {
290 kind: *kind,
291 task_ids: task_ids.clone(),
292 },
293 resume_ip: self.ip,
294 }))
295 }
296 _ => Err(VMError::RuntimeError(format!(
297 "JoinAwait expected TaskGroup, got {}",
298 nb.type_name()
299 ))),
300 }
301 }
302
303 fn op_cancel_task(&mut self) -> Result<AsyncExecutionResult, VMError> {
308 let nb = self.pop_vw()?;
309 match nb.as_heap_ref() {
310 Some(HeapValue::Future(id)) => {
311 self.task_scheduler.cancel(*id);
312 Ok(AsyncExecutionResult::Continue)
313 }
314 _ => Err(VMError::RuntimeError(format!(
315 "CancelTask expected Future, got {}",
316 nb.type_name()
317 ))),
318 }
319 }
320
321 fn op_async_scope_enter(&mut self) -> Result<AsyncExecutionResult, VMError> {
326 self.async_scope_stack.push(Vec::new());
327 Ok(AsyncExecutionResult::Continue)
328 }
329
330 fn op_async_scope_exit(&mut self) -> Result<AsyncExecutionResult, VMError> {
336 if let Some(mut scope_tasks) = self.async_scope_stack.pop() {
337 scope_tasks.reverse();
339 for task_id in scope_tasks {
340 self.task_scheduler.cancel(task_id);
341 }
342 }
343 Ok(AsyncExecutionResult::Continue)
345 }
346
347 fn op_emit_event(&mut self) -> Result<AsyncExecutionResult, VMError> {
352 let _event_nb = self.pop_vw()?;
353 Ok(AsyncExecutionResult::Continue)
355 }
356}
357
/// Test helper: returns `true` exactly for the opcodes that `VirtualMachine::exec_async_op`
/// dispatches. Keep this set in sync with that dispatch.
#[cfg(test)]
pub fn is_async_opcode(opcode: OpCode) -> bool {
    use OpCode::*;
    matches!(
        opcode,
        // Coroutine control.
        Yield | Suspend | Resume | Poll
            // Data-driven waits and emissions.
            | AwaitBar | AwaitTick | EmitAlert | EmitEvent
            // Futures and structured concurrency.
            | Await | SpawnTask | JoinInit | JoinAwait | CancelTask
            | AsyncScopeEnter | AsyncScopeExit
    )
}
380
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_async_opcode() {
        assert!(is_async_opcode(OpCode::Yield));
        assert!(is_async_opcode(OpCode::Suspend));
        assert!(is_async_opcode(OpCode::EmitAlert));
        assert!(is_async_opcode(OpCode::AsyncScopeEnter));
        assert!(is_async_opcode(OpCode::AsyncScopeExit));
        assert!(!is_async_opcode(OpCode::Add));
        assert!(!is_async_opcode(OpCode::Jump));
    }

    /// Every opcode routed by `VirtualMachine::exec_async_op` must be recognised, and a
    /// sample of ordinary opcodes must not be.
    #[test]
    fn test_is_async_opcode_all_variants() {
        assert!(is_async_opcode(OpCode::Yield));
        assert!(is_async_opcode(OpCode::Suspend));
        assert!(is_async_opcode(OpCode::Resume));
        assert!(is_async_opcode(OpCode::Poll));
        assert!(is_async_opcode(OpCode::AwaitBar));
        assert!(is_async_opcode(OpCode::AwaitTick));
        assert!(is_async_opcode(OpCode::EmitAlert));
        assert!(is_async_opcode(OpCode::EmitEvent));
        assert!(is_async_opcode(OpCode::Await));
        assert!(is_async_opcode(OpCode::SpawnTask));
        assert!(is_async_opcode(OpCode::JoinInit));
        assert!(is_async_opcode(OpCode::JoinAwait));
        assert!(is_async_opcode(OpCode::CancelTask));
        assert!(is_async_opcode(OpCode::AsyncScopeEnter));
        assert!(is_async_opcode(OpCode::AsyncScopeExit));

        assert!(!is_async_opcode(OpCode::PushConst));
        assert!(!is_async_opcode(OpCode::Return));
        assert!(!is_async_opcode(OpCode::Call));
        assert!(!is_async_opcode(OpCode::Nop));
    }

    #[test]
    fn test_async_execution_result_variants() {
        let continue_result = AsyncExecutionResult::Continue;
        assert!(matches!(continue_result, AsyncExecutionResult::Continue));

        let yielded_result = AsyncExecutionResult::Yielded;
        assert!(matches!(yielded_result, AsyncExecutionResult::Yielded));

        let suspended_result = AsyncExecutionResult::Suspended(SuspensionInfo {
            wait_type: WaitType::AnyEvent,
            resume_ip: 42,
        });
        match suspended_result {
            AsyncExecutionResult::Suspended(info) => {
                assert_eq!(info.resume_ip, 42);
                assert!(matches!(info.wait_type, WaitType::AnyEvent));
            }
            _ => panic!("Expected Suspended"),
        }
    }

    #[test]
    fn test_wait_type_variants() {
        let next_bar = WaitType::NextBar {
            source: "market_data".to_string(),
        };
        match next_bar {
            WaitType::NextBar { source } => assert_eq!(source, "market_data"),
            _ => panic!("Expected NextBar"),
        }

        let timer = WaitType::Timer { id: 123 };
        match timer {
            WaitType::Timer { id } => assert_eq!(id, 123),
            _ => panic!("Expected Timer"),
        }

        let any = WaitType::AnyEvent;
        assert!(matches!(any, WaitType::AnyEvent));
    }

    #[test]
    fn test_suspension_info_creation() {
        let info = SuspensionInfo {
            wait_type: WaitType::Timer { id: 999 },
            resume_ip: 100,
        };

        assert_eq!(info.resume_ip, 100);
        assert!(matches!(info.wait_type, WaitType::Timer { id: 999 }));
    }

    #[test]
    fn test_is_async_opcode_await() {
        assert!(is_async_opcode(OpCode::Await));
    }

    #[test]
    fn test_wait_type_future() {
        let future = WaitType::Future { id: 42 };
        match future {
            WaitType::Future { id } => assert_eq!(id, 42),
            _ => panic!("Expected Future"),
        }
    }

    #[test]
    fn test_is_async_opcode_join_opcodes() {
        assert!(is_async_opcode(OpCode::SpawnTask));
        assert!(is_async_opcode(OpCode::JoinInit));
        assert!(is_async_opcode(OpCode::JoinAwait));
        assert!(is_async_opcode(OpCode::CancelTask));
    }

    #[test]
    fn test_wait_type_task_group() {
        let tg = WaitType::TaskGroup {
            kind: 0,
            task_ids: vec![1, 2, 3],
        };
        match tg {
            WaitType::TaskGroup { kind, task_ids } => {
                assert_eq!(kind, 0);
                assert_eq!(task_ids.len(), 3);
                assert_eq!(task_ids, vec![1, 2, 3]);
            }
            _ => panic!("Expected TaskGroup"),
        }
    }

    #[test]
    fn test_wait_type_task_group_race() {
        let tg = WaitType::TaskGroup {
            kind: 1,
            task_ids: vec![10, 20],
        };
        match tg {
            WaitType::TaskGroup { kind, task_ids } => {
                assert_eq!(kind, 1);
                assert_eq!(task_ids, vec![10, 20]);
            }
            _ => panic!("Expected TaskGroup"),
        }
    }
}