Skip to main content

shape_runtime/context/
mod.rs

//! Execution context for the Shape runtime.
//!
//! This module contains [`ExecutionContext`], which manages runtime state including:
//! - Variable scopes and bindings
//! - Data access and caching
//! - Backtest state and series caching
//! - Type registries (type methods, type schemas, metadata, type aliases, enums)
//! - Configuration (symbol, timeframe, date range)
//! - Async event queue, suspension state, and alert routing
//! - Snapshot/restore of the dynamic execution state
9
10mod config;
11mod data_access;
12mod data_cache;
13mod registries;
14mod scope;
15mod variables;
16
17// Re-export public types
18pub use data_cache::DataLoadMode;
19pub use variables::Variable;
20
21use std::collections::HashMap;
22use std::sync::Arc;
23
24use super::alerts::AlertRouter;
25use super::annotation_context::{AnnotationContext, AnnotationRegistry};
26use super::data::DataFrame;
27use super::event_queue::{SharedEventQueue, SuspensionState};
28use super::lookahead_guard::LookAheadGuard;
29use super::metadata::MetadataRegistry;
30use super::simulation::KernelCompiler;
31use super::type_methods::TypeMethodRegistry;
32use super::type_schema::TypeSchemaRegistry;
33use crate::data::Timeframe;
34use crate::snapshot::{
35    ContextSnapshot, SnapshotStore, SuspensionStateSnapshot, TypeAliasRuntimeEntrySnapshot,
36    VariableSnapshot, nanboxed_to_serializable, serializable_to_nanboxed,
37};
38use anyhow::{Result, anyhow};
39use chrono::{DateTime, Utc};
40use shape_value::ValueWord;
41
42/// Execution context for evaluating expressions
43#[derive(Clone)]
44pub struct ExecutionContext {
45    /// Market data provider (abstraction layer - legacy)
46    data_provider: Option<Arc<dyn std::any::Any + Send + Sync>>,
47    /// Data cache for async provider (Phase 6)
48    /// Clone is cheap since all heavy data is Arc-wrapped internally
49    pub(crate) data_cache: Option<crate::data::DataCache>,
50    /// Provider registry (Phase 8)
51    provider_registry: Arc<super::provider_registry::ProviderRegistry>,
52    /// Type mapping registry (Phase 8)
53    type_mapping_registry: Arc<super::type_mapping::TypeMappingRegistry>,
54    /// Type schema registry for JIT type specialization
55    type_schema_registry: Arc<TypeSchemaRegistry>,
56    /// Metadata registry for generic type metadata (Logic)
57    metadata_registry: Arc<MetadataRegistry>,
58    /// Execution mode for data loading (Phase 8)
59    data_load_mode: DataLoadMode,
60    /// Current ID being analyzed (e.g. symbol, sensor ID)
61    current_id: Option<String>,
62    /// Current data row index (for pattern matching)
63    current_row_index: usize,
64    /// Variable bindings (stack of scopes for function calls)
65    variable_scopes: Vec<HashMap<String, Variable>>,
66    /// Expression evaluator
67    // TODO: Replace with BytecodeExecutor/VM
68    // evaluator: Evaluator,
69    /// Reference datetime for relative data row access
70    reference_datetime: Option<DateTime<Utc>>,
71    /// Current timeframe for data row operations
72    current_timeframe: Option<Timeframe>,
73    /// Base timeframe of the actual data in DuckDB
74    base_timeframe: Option<Timeframe>,
75    /// Look-ahead bias guard
76    lookahead_guard: Option<LookAheadGuard>,
77    /// Registry for user-defined type methods
78    type_method_registry: Arc<TypeMethodRegistry>,
79    /// Date range for data loading (start, end) as native DateTime
80    date_range: Option<(DateTime<Utc>, DateTime<Utc>)>,
81    /// Walk-forward range start (inclusive)
82    range_start: usize,
83    /// Walk-forward range end (exclusive)
84    range_end: usize,
85    /// Whether a custom range is active
86    range_active: bool,
87    /// Decorator-based registry for pattern functions (generic - works for any domain)
88    /// NOTE: This will be replaced by annotation_context.registry("patterns")
89    /// once lifecycle hooks are fully integrated
90    pattern_registry: HashMap<String, super::closure::Closure>,
91    /// Annotation context for lifecycle hooks (cache, state, registries, emit)
92    annotation_context: AnnotationContext,
93    /// Registry for `annotation ... { ... }` definitions
94    annotation_registry: AnnotationRegistry,
95    /// Event queue for async operations (streaming, real-time data)
96    event_queue: Option<SharedEventQueue>,
97    /// Suspension state when execution is paused waiting for an event
98    suspension_state: Option<SuspensionState>,
99    /// Alert pipeline for sending alerts to output sinks
100    alert_pipeline: Option<Arc<AlertRouter>>,
101    /// Output adapter for handling print() results
102    output_adapter: Box<dyn crate::output_adapter::OutputAdapter>,
103    /// Type alias registry for meta parameter overrides
104    /// Maps alias name (e.g., "Percent4") -> (base_type, overrides)
105    type_alias_registry: HashMap<String, TypeAliasRuntimeEntry>,
106    /// Enum definition registry for sum type support
107    enum_registry: EnumRegistry,
108    /// Progress registry for monitoring load operations
109    progress_registry: Option<Arc<super::progress::ProgressRegistry>>,
110    /// Optional JIT kernel compiler for high-performance simulation.
111    /// Set this to enable JIT compilation of simulation kernels.
112    kernel_compiler: Option<Arc<dyn KernelCompiler>>,
113}
114
115/// Runtime entry for a type alias with meta parameter overrides
116#[derive(Debug, Clone)]
117pub struct TypeAliasRuntimeEntry {
118    /// The base type name (e.g., "Percent" for `type Percent4 = Percent { decimals: 4 }`)
119    pub base_type: String,
120    /// Meta parameter overrides as runtime values
121    pub overrides: Option<HashMap<String, ValueWord>>,
122}
123
124/// Registry for enum definitions
125///
126/// Enables enum sum types by tracking which enums exist and their variants.
127/// Used for pattern matching resolution when matching against union types like
128/// `type SaveError = NetworkError | DiskError`.
129#[derive(Debug, Clone, Default)]
130pub struct EnumRegistry {
131    /// Map from enum name to its definition
132    enums: HashMap<String, shape_ast::ast::EnumDef>,
133}
134
135impl EnumRegistry {
136    /// Create a new empty enum registry
137    pub fn new() -> Self {
138        Self {
139            enums: HashMap::new(),
140        }
141    }
142
143    /// Register an enum definition
144    pub fn register(&mut self, enum_def: shape_ast::ast::EnumDef) {
145        self.enums.insert(enum_def.name.clone(), enum_def);
146    }
147
148    /// Look up an enum by name
149    pub fn get(&self, name: &str) -> Option<&shape_ast::ast::EnumDef> {
150        self.enums.get(name)
151    }
152
153    /// Check if an enum exists
154    pub fn contains(&self, name: &str) -> bool {
155        self.enums.contains_key(name)
156    }
157
158    /// Get all enum names
159    pub fn names(&self) -> impl Iterator<Item = &String> {
160        self.enums.keys()
161    }
162
163    /// Check if an enum value belongs to a given enum or union type
164    ///
165    /// For simple enum types, checks if `value_enum_name` matches.
166    /// For union types (resolved from type aliases), checks if the enum
167    /// is one of the union members.
168    pub fn value_matches_type(&self, value_enum_name: &str, type_name: &str) -> bool {
169        // Direct match
170        if value_enum_name == type_name {
171            return true;
172        }
173        // Otherwise, the type_name might be a union type alias
174        // which needs to be resolved externally
175        false
176    }
177}
178
179impl std::fmt::Debug for ExecutionContext {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        f.debug_struct("ExecutionContext")
182            .field("data_provider", &"<DataProvider>")
183            .field("current_id", &self.current_id)
184            .field("current_row_index", &self.current_row_index)
185            .field("variable_scopes", &self.variable_scopes)
186            .field("reference_datetime", &self.reference_datetime)
187            .field("current_timeframe", &self.current_timeframe)
188            .field("lookahead_guard", &self.lookahead_guard)
189            .finish()
190    }
191}
192
193impl ExecutionContext {
194    /// Create a new execution context with shared type method registry
195    pub fn new_with_registry(
196        data: &DataFrame,
197        type_method_registry: Arc<TypeMethodRegistry>,
198    ) -> Self {
199        // Set current_row_index to last row so [-1] gives most recent value
200        let current_row_index = if data.row_count() == 0 {
201            0
202        } else {
203            data.row_count() - 1
204        };
205
206        Self {
207            data_provider: None,
208            data_cache: None,
209            provider_registry: Arc::new(super::provider_registry::ProviderRegistry::new()),
210            type_mapping_registry: Arc::new(super::type_mapping::TypeMappingRegistry::new()),
211            type_schema_registry: Arc::new(TypeSchemaRegistry::with_stdlib_types()),
212            metadata_registry: Arc::new(MetadataRegistry::new()),
213            data_load_mode: DataLoadMode::default(),
214            current_id: Some(data.id.clone()),
215            current_row_index,
216            variable_scopes: vec![HashMap::new()], // Start with root scope
217            // evaluator: Evaluator::new(),
218            reference_datetime: None,
219            current_timeframe: Some(data.timeframe),
220            base_timeframe: Some(data.timeframe),
221            lookahead_guard: None,
222            type_method_registry,
223            date_range: None,
224            range_start: 0,
225            range_end: usize::MAX,
226            range_active: false,
227            pattern_registry: HashMap::new(),
228            annotation_context: AnnotationContext::new(),
229            annotation_registry: AnnotationRegistry::new(),
230            event_queue: None,
231            suspension_state: None,
232            alert_pipeline: None,
233            output_adapter: Box::new(crate::output_adapter::StdoutAdapter),
234            type_alias_registry: HashMap::new(),
235            enum_registry: EnumRegistry::new(),
236            progress_registry: None,
237            kernel_compiler: None,
238        }
239    }
240
241    /// Create a new execution context
242    pub fn new(data: &DataFrame) -> Self {
243        Self::new_with_registry(data, Arc::new(TypeMethodRegistry::new()))
244    }
245
246    /// Create a new execution context without market data with shared registry
247    pub fn new_empty_with_registry(type_method_registry: Arc<TypeMethodRegistry>) -> Self {
248        Self {
249            data_provider: None,
250            data_cache: None,
251            provider_registry: Arc::new(super::provider_registry::ProviderRegistry::new()),
252            type_mapping_registry: Arc::new(super::type_mapping::TypeMappingRegistry::new()),
253            type_schema_registry: Arc::new(TypeSchemaRegistry::with_stdlib_types()),
254            metadata_registry: Arc::new(MetadataRegistry::new()),
255            data_load_mode: DataLoadMode::default(),
256            current_id: None,
257            current_row_index: 0,
258            variable_scopes: vec![HashMap::new()], // Start with root scope
259            // evaluator: Evaluator::new(),
260            reference_datetime: None,
261            current_timeframe: None,
262            base_timeframe: None,
263            lookahead_guard: None,
264            type_method_registry,
265            date_range: None,
266            range_start: 0,
267            range_end: usize::MAX,
268            range_active: false,
269            pattern_registry: HashMap::new(),
270            annotation_context: AnnotationContext::new(),
271            annotation_registry: AnnotationRegistry::new(),
272            event_queue: None,
273            suspension_state: None,
274            alert_pipeline: None,
275            output_adapter: Box::new(crate::output_adapter::StdoutAdapter),
276            type_alias_registry: HashMap::new(),
277            enum_registry: EnumRegistry::new(),
278            progress_registry: None,
279            kernel_compiler: None,
280        }
281    }
282
283    /// Create a new execution context without market data
284    pub fn new_empty() -> Self {
285        Self::new_empty_with_registry(Arc::new(TypeMethodRegistry::new()))
286    }
287
288    /// Create a new execution context with DuckDB provider and shared registry
289    pub fn with_data_provider_and_registry(
290        data_provider: Arc<dyn std::any::Any + Send + Sync>,
291        type_method_registry: Arc<TypeMethodRegistry>,
292    ) -> Self {
293        Self {
294            data_provider: Some(data_provider),
295            data_cache: None,
296            provider_registry: Arc::new(super::provider_registry::ProviderRegistry::new()),
297            type_mapping_registry: Arc::new(super::type_mapping::TypeMappingRegistry::new()),
298            type_schema_registry: Arc::new(TypeSchemaRegistry::with_stdlib_types()),
299            metadata_registry: Arc::new(MetadataRegistry::new()),
300            data_load_mode: DataLoadMode::default(),
301            current_id: None,
302            current_row_index: 0,
303            variable_scopes: vec![HashMap::new()],
304            // evaluator: Evaluator::new(),
305            reference_datetime: None,
306            current_timeframe: None,
307            base_timeframe: None,
308            lookahead_guard: None,
309            type_method_registry,
310            date_range: None,
311            range_start: 0,
312            range_end: usize::MAX,
313            range_active: false,
314            pattern_registry: HashMap::new(),
315            annotation_context: AnnotationContext::new(),
316            annotation_registry: AnnotationRegistry::new(),
317            event_queue: None,
318            suspension_state: None,
319            alert_pipeline: None,
320            output_adapter: Box::new(crate::output_adapter::StdoutAdapter),
321            type_alias_registry: HashMap::new(),
322            enum_registry: EnumRegistry::new(),
323            progress_registry: None,
324            kernel_compiler: None,
325        }
326    }
327
328    /// Create a new execution context with DuckDB provider
329    pub fn with_data_provider(data_provider: Arc<dyn std::any::Any + Send + Sync>) -> Self {
330        Self::with_data_provider_and_registry(data_provider, Arc::new(TypeMethodRegistry::new()))
331    }
332
333    /// Create with async data provider (Phase 6)
334    ///
335    /// This constructor sets up ExecutionContext with a DataCache for async data loading.
336    /// Call `prefetch_data()` before executing to populate the cache.
337    pub fn with_async_provider(
338        provider: crate::data::SharedAsyncProvider,
339        runtime: tokio::runtime::Handle,
340    ) -> Self {
341        let data_cache = crate::data::DataCache::new(provider, runtime);
342        Self {
343            data_provider: None,
344            data_cache: Some(data_cache),
345            provider_registry: Arc::new(super::provider_registry::ProviderRegistry::new()),
346            type_mapping_registry: Arc::new(super::type_mapping::TypeMappingRegistry::new()),
347            type_schema_registry: Arc::new(TypeSchemaRegistry::with_stdlib_types()),
348            metadata_registry: Arc::new(MetadataRegistry::new()),
349            data_load_mode: DataLoadMode::default(),
350            current_id: None,
351            current_row_index: 0,
352            variable_scopes: vec![HashMap::new()],
353            // evaluator: Evaluator::new(),
354            reference_datetime: None,
355            current_timeframe: None,
356            base_timeframe: None,
357            lookahead_guard: None,
358            type_method_registry: Arc::new(TypeMethodRegistry::new()),
359            date_range: None,
360            range_start: 0,
361            range_end: usize::MAX,
362            range_active: false,
363            pattern_registry: HashMap::new(),
364            annotation_context: AnnotationContext::new(),
365            annotation_registry: AnnotationRegistry::new(),
366            event_queue: None,
367            suspension_state: None,
368            alert_pipeline: None,
369            output_adapter: Box::new(crate::output_adapter::StdoutAdapter),
370            type_alias_registry: HashMap::new(),
371            enum_registry: EnumRegistry::new(),
372            progress_registry: None,
373            kernel_compiler: None,
374        }
375    }
376
377    /// Set the output adapter for print() handling
378    pub fn set_output_adapter(&mut self, adapter: Box<dyn crate::output_adapter::OutputAdapter>) {
379        self.output_adapter = adapter;
380    }
381
382    /// Get mutable reference to output adapter
383    pub fn output_adapter_mut(&mut self) -> &mut Box<dyn crate::output_adapter::OutputAdapter> {
384        &mut self.output_adapter
385    }
386
387    /// Get the metadata registry
388    pub fn metadata_registry(&self) -> &Arc<MetadataRegistry> {
389        &self.metadata_registry
390    }
391
392    // =========================================================================
393    // Type Alias Registry Methods
394    // =========================================================================
395
396    /// Register a type alias for runtime meta resolution
397    ///
398    /// Used when loading stdlib to make type aliases available for formatting.
399    /// Example: `type Percent4 = Percent { decimals: 4 }`
400    pub fn register_type_alias(
401        &mut self,
402        alias_name: &str,
403        base_type: &str,
404        overrides: Option<HashMap<String, ValueWord>>,
405    ) {
406        self.type_alias_registry.insert(
407            alias_name.to_string(),
408            TypeAliasRuntimeEntry {
409                base_type: base_type.to_string(),
410                overrides,
411            },
412        );
413    }
414
415    /// Look up a type alias
416    ///
417    /// Returns the base type name and any parameter overrides.
418    pub fn lookup_type_alias(&self, name: &str) -> Option<&TypeAliasRuntimeEntry> {
419        self.type_alias_registry.get(name)
420    }
421
422    /// Resolve a type name, following aliases if needed
423    ///
424    /// If the type is an alias, returns (base_type, Some(overrides))
425    /// If not an alias, returns (type_name, None)
426    pub fn resolve_type_for_format(
427        &self,
428        type_name: &str,
429    ) -> (String, Option<HashMap<String, ValueWord>>) {
430        if let Some(entry) = self.type_alias_registry.get(type_name) {
431            (entry.base_type.clone(), entry.overrides.clone())
432        } else {
433            (type_name.to_string(), None)
434        }
435    }
436
437    // =========================================================================
438    // Snapshotting
439    // =========================================================================
440
441    /// Create a serializable snapshot of the dynamic execution state.
442    pub fn snapshot(&self, store: &SnapshotStore) -> Result<ContextSnapshot> {
443        let mut scopes = Vec::new();
444        for scope in &self.variable_scopes {
445            let mut snap_scope = HashMap::new();
446            for (name, var) in scope.iter() {
447                let value = nanboxed_to_serializable(&var.value, store)?;
448                let format_overrides = match &var.format_overrides {
449                    Some(map) => {
450                        let mut out = HashMap::new();
451                        for (k, v) in map.iter() {
452                            out.insert(k.clone(), nanboxed_to_serializable(v, store)?);
453                        }
454                        Some(out)
455                    }
456                    None => None,
457                };
458                snap_scope.insert(
459                    name.clone(),
460                    VariableSnapshot {
461                        value,
462                        kind: var.kind,
463                        is_initialized: var.is_initialized,
464                        is_function_scoped: var.is_function_scoped,
465                        format_hint: var.format_hint.clone(),
466                        format_overrides,
467                    },
468                );
469            }
470            scopes.push(snap_scope);
471        }
472
473        let mut alias_registry = HashMap::new();
474        for (name, entry) in self.type_alias_registry.iter() {
475            let overrides = match &entry.overrides {
476                Some(map) => {
477                    let mut out = HashMap::new();
478                    for (k, v) in map.iter() {
479                        out.insert(k.clone(), nanboxed_to_serializable(v, store)?);
480                    }
481                    Some(out)
482                }
483                None => None,
484            };
485            alias_registry.insert(
486                name.clone(),
487                TypeAliasRuntimeEntrySnapshot {
488                    base_type: entry.base_type.clone(),
489                    overrides,
490                },
491            );
492        }
493
494        let enum_registry = self
495            .enum_registry
496            .names()
497            .filter_map(|name| {
498                self.enum_registry
499                    .get(name)
500                    .cloned()
501                    .map(|def| (name.clone(), def))
502            })
503            .collect::<HashMap<_, _>>();
504
505        let suspension_state = match self.suspension_state() {
506            Some(state) => {
507                let mut locals = Vec::new();
508                for v in state.saved_locals.iter() {
509                    locals.push(nanboxed_to_serializable(v, store)?);
510                }
511                let mut stack = Vec::new();
512                for v in state.saved_stack.iter() {
513                    stack.push(nanboxed_to_serializable(v, store)?);
514                }
515                Some(SuspensionStateSnapshot {
516                    waiting_for: state.waiting_for.clone(),
517                    resume_pc: state.resume_pc,
518                    saved_locals: locals,
519                    saved_stack: stack,
520                })
521            }
522            None => None,
523        };
524
525        let data_cache = match &self.data_cache {
526            Some(cache) => Some(cache.snapshot(store)?),
527            None => None,
528        };
529
530        Ok(ContextSnapshot {
531            data_load_mode: self.data_load_mode,
532            data_cache,
533            current_id: self.current_id.clone(),
534            current_row_index: self.current_row_index,
535            variable_scopes: scopes,
536            reference_datetime: self.reference_datetime,
537            current_timeframe: self.current_timeframe,
538            base_timeframe: self.base_timeframe,
539            date_range: self.date_range,
540            range_start: self.range_start,
541            range_end: self.range_end,
542            range_active: self.range_active,
543            type_alias_registry: alias_registry,
544            enum_registry,
545            suspension_state,
546        })
547    }
548
549    /// Restore execution state from a snapshot.
550    pub fn restore_from_snapshot(
551        &mut self,
552        snapshot: ContextSnapshot,
553        store: &SnapshotStore,
554    ) -> Result<()> {
555        self.data_load_mode = snapshot.data_load_mode;
556        self.current_id = snapshot.current_id;
557        self.current_row_index = snapshot.current_row_index;
558        self.reference_datetime = snapshot.reference_datetime;
559        self.current_timeframe = snapshot.current_timeframe;
560        self.base_timeframe = snapshot.base_timeframe;
561        self.date_range = snapshot.date_range;
562        self.range_start = snapshot.range_start;
563        self.range_end = snapshot.range_end;
564        self.range_active = snapshot.range_active;
565
566        match snapshot.data_cache {
567            Some(cache_snapshot) => {
568                if let Some(cache) = &self.data_cache {
569                    cache.restore_from_snapshot(cache_snapshot, store)?;
570                } else {
571                    return Err(anyhow!(
572                        "Snapshot includes data cache, but context has no async provider"
573                    ));
574                }
575            }
576            None => {
577                if let Some(cache) = &self.data_cache {
578                    cache.clear();
579                }
580            }
581        }
582
583        self.variable_scopes.clear();
584        for scope in snapshot.variable_scopes.into_iter() {
585            let mut restored = HashMap::new();
586            for (name, var) in scope.into_iter() {
587                let value = serializable_to_nanboxed(&var.value, store)?;
588                let format_overrides = match var.format_overrides {
589                    Some(map) => {
590                        let mut out = HashMap::new();
591                        for (k, v) in map.into_iter() {
592                            out.insert(k, serializable_to_nanboxed(&v, store)?);
593                        }
594                        Some(out)
595                    }
596                    None => None,
597                };
598                restored.insert(
599                    name,
600                    Variable {
601                        value,
602                        kind: var.kind,
603                        is_initialized: var.is_initialized,
604                        is_function_scoped: var.is_function_scoped,
605                        format_hint: var.format_hint,
606                        format_overrides,
607                    },
608                );
609            }
610            self.variable_scopes.push(restored);
611        }
612
613        self.type_alias_registry.clear();
614        for (name, entry) in snapshot.type_alias_registry.into_iter() {
615            let overrides = match entry.overrides {
616                Some(map) => {
617                    let mut out = HashMap::new();
618                    for (k, v) in map.into_iter() {
619                        out.insert(k, serializable_to_nanboxed(&v, store)?);
620                    }
621                    Some(out)
622                }
623                None => None,
624            };
625            self.type_alias_registry.insert(
626                name,
627                TypeAliasRuntimeEntry {
628                    base_type: entry.base_type,
629                    overrides,
630                },
631            );
632        }
633
634        self.enum_registry = EnumRegistry::default();
635        for (_name, def) in snapshot.enum_registry.into_iter() {
636            self.enum_registry.register(def);
637        }
638
639        if let Some(state) = snapshot.suspension_state {
640            let mut locals = Vec::new();
641            for v in state.saved_locals.into_iter() {
642                locals.push(serializable_to_nanboxed(&v, store)?);
643            }
644            let mut stack = Vec::new();
645            for v in state.saved_stack.into_iter() {
646                stack.push(serializable_to_nanboxed(&v, store)?);
647            }
648            self.set_suspension_state(
649                SuspensionState::new(state.waiting_for, state.resume_pc)
650                    .with_locals(locals)
651                    .with_stack(stack),
652            );
653        } else {
654            self.clear_suspension_state();
655        }
656
657        // Note: output_adapter is NOT restored from snapshot.
658        // It's set by the caller (StdoutAdapter for scripts, ReplAdapter for REPL).
659
660        Ok(())
661    }
662
663    /// Set indicator cache
664
665    /// Set the event queue for async operations
666    pub fn set_event_queue(&mut self, queue: SharedEventQueue) {
667        self.event_queue = Some(queue);
668    }
669
670    /// Get the event queue
671    pub fn event_queue(&self) -> Option<&SharedEventQueue> {
672        self.event_queue.as_ref()
673    }
674
675    /// Get mutable reference to event queue
676    pub fn event_queue_mut(&mut self) -> Option<&mut SharedEventQueue> {
677        self.event_queue.as_mut()
678    }
679
680    /// Set suspension state (called when yielding/suspending)
681    pub fn set_suspension_state(&mut self, state: SuspensionState) {
682        self.suspension_state = Some(state);
683    }
684
685    /// Get suspension state
686    pub fn suspension_state(&self) -> Option<&SuspensionState> {
687        self.suspension_state.as_ref()
688    }
689
690    /// Clear suspension state (called when resuming)
691    pub fn clear_suspension_state(&mut self) -> Option<SuspensionState> {
692        self.suspension_state.take()
693    }
694
695    /// Check if execution is suspended
696    pub fn is_suspended(&self) -> bool {
697        self.suspension_state.is_some()
698    }
699
700    /// Set the alert pipeline for routing alerts to sinks
701    pub fn set_alert_pipeline(&mut self, pipeline: Arc<AlertRouter>) {
702        self.alert_pipeline = Some(pipeline);
703    }
704
705    /// Get the alert pipeline
706    pub fn alert_pipeline(&self) -> Option<&Arc<AlertRouter>> {
707        self.alert_pipeline.as_ref()
708    }
709
710    /// Emit an alert through the pipeline
711    pub fn emit_alert(&self, alert: super::alerts::Alert) {
712        if let Some(pipeline) = &self.alert_pipeline {
713            pipeline.emit(alert);
714        }
715    }
716
717    /// Set the progress registry for monitoring load operations
718    pub fn set_progress_registry(&mut self, registry: Arc<super::progress::ProgressRegistry>) {
719        self.progress_registry = Some(registry);
720    }
721
722    /// Get the progress registry
723    pub fn progress_registry(&self) -> Option<&Arc<super::progress::ProgressRegistry>> {
724        self.progress_registry.as_ref()
725    }
726
727    /// Set the JIT kernel compiler for high-performance simulation.
728    ///
729    /// This enables JIT compilation of simulation kernels when the state is a TypedObject.
730    /// The compiler should be an instance of `shape_jit::JITCompiler` wrapped in Arc.
731    pub fn set_kernel_compiler(&mut self, compiler: Arc<dyn KernelCompiler>) {
732        self.kernel_compiler = Some(compiler);
733    }
734
735    /// Get the JIT kernel compiler, if set.
736    pub fn kernel_compiler(&self) -> Option<&Arc<dyn KernelCompiler>> {
737        self.kernel_compiler.as_ref()
738    }
739}
740
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::{AsyncDataProvider, CacheKey, DataQuery, NullAsyncProvider, Timeframe};
    use crate::snapshot::SnapshotStore;
    use shape_ast::ast::VarKind;
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    #[test]
    fn test_execution_context_new_empty() {
        let ctx = ExecutionContext::new_empty();
        assert_eq!(ctx.current_row_index(), 0);
    }

    #[test]
    fn test_execution_context_set_current_row() {
        let mut ctx = ExecutionContext::new_empty();
        ctx.set_current_row(5).unwrap();
        assert_eq!(ctx.current_row_index(), 5);
    }

    /// A variable declared through the public API can be read back with the
    /// same value.
    #[test]
    fn test_execution_context_variable_scope() {
        let mut ctx = ExecutionContext::new_empty();

        ctx.declare_variable("x", VarKind::Let, Some(ValueWord::from_f64(10.0)))
            .expect("declaring a fresh variable should succeed");

        assert_eq!(ctx.get_variable("x").unwrap(), Some(ValueWord::from_f64(10.0)));
    }

    // =========================================================================
    // Type Alias Registry Tests
    // =========================================================================

    #[test]
    fn test_type_alias_registry_basic() {
        let mut ctx = ExecutionContext::new_empty();

        // Register a type alias: type Percent4 = Percent { decimals: 4 }
        let mut overrides = HashMap::new();
        overrides.insert("decimals".to_string(), ValueWord::from_f64(4.0));
        ctx.register_type_alias("Percent4", "Percent", Some(overrides));

        // Look up the alias
        let entry = ctx
            .lookup_type_alias("Percent4")
            .expect("alias Percent4 should be registered");
        assert_eq!(entry.base_type, "Percent");

        let overrides = entry
            .overrides
            .as_ref()
            .expect("alias Percent4 should keep its overrides");
        assert_eq!(
            overrides.get("decimals").and_then(|v| v.as_f64()),
            Some(4.0)
        );
    }

    #[test]
    fn test_type_alias_registry_no_overrides() {
        let mut ctx = ExecutionContext::new_empty();

        // Register a type alias without overrides
        ctx.register_type_alias("MyPercent", "Percent", None);

        let entry = ctx
            .lookup_type_alias("MyPercent")
            .expect("alias MyPercent should be registered");
        assert_eq!(entry.base_type, "Percent");
        assert!(entry.overrides.is_none());
    }

    #[test]
    fn test_type_alias_registry_unknown_type() {
        let ctx = ExecutionContext::new_empty();

        // Look up a non-existent alias
        let entry = ctx.lookup_type_alias("NonExistent");
        assert!(entry.is_none());
    }

    #[test]
    fn test_resolve_type_for_format_alias() {
        let mut ctx = ExecutionContext::new_empty();

        // Register a type alias
        let mut overrides = HashMap::new();
        overrides.insert("decimals".to_string(), ValueWord::from_f64(4.0));
        ctx.register_type_alias("Percent4", "Percent", Some(overrides));

        // Resolve the alias
        let (base_type, resolved_overrides) = ctx.resolve_type_for_format("Percent4");
        assert_eq!(base_type, "Percent");
        assert_eq!(
            resolved_overrides
                .expect("resolved alias should carry its overrides")
                .get("decimals")
                .and_then(|v| v.as_f64()),
            Some(4.0)
        );
    }

    #[test]
    fn test_resolve_type_for_format_non_alias() {
        let ctx = ExecutionContext::new_empty();

        // Resolve a non-alias type
        let (base_type, resolved_overrides) = ctx.resolve_type_for_format("Number");
        assert_eq!(base_type, "Number");
        assert!(resolved_overrides.is_none());
    }

    /// In-memory async provider that serves pre-built frames and counts how
    /// many times `load` was invoked.
    #[derive(Clone)]
    struct TestAsyncProvider {
        frames: Arc<HashMap<CacheKey, DataFrame>>,
        load_calls: Arc<AtomicUsize>,
    }

    impl AsyncDataProvider for TestAsyncProvider {
        fn load<'a>(
            &'a self,
            query: &'a DataQuery,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = Result<DataFrame, crate::data::AsyncDataError>>
                    + Send
                    + 'a,
            >,
        > {
            let key = CacheKey::new(query.id.clone(), query.timeframe);
            let frames = self.frames.clone();
            let calls = self.load_calls.clone();
            Box::pin(async move {
                calls.fetch_add(1, Ordering::SeqCst);
                frames
                    .get(&key)
                    .cloned()
                    .ok_or_else(|| crate::data::AsyncDataError::SymbolNotFound(query.id.clone()))
            })
        }

        fn has_data(&self, symbol: &str, timeframe: &Timeframe) -> bool {
            let key = CacheKey::new(symbol.to_string(), *timeframe);
            self.frames.contains_key(&key)
        }

        fn symbols(&self) -> Vec<String> {
            self.frames.keys().map(|k| k.id.clone()).collect()
        }
    }

    /// Removes the wrapped directory when dropped.
    ///
    /// The directory is removed even if an assertion panics first, so snapshot
    /// tests do not leave files behind in the system temp directory.
    struct TempDirGuard(std::path::PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            // Best effort: the directory may never have been created.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Build a unique path under the system temp directory for a snapshot store.
    ///
    /// The process id and a nanosecond timestamp keep concurrent test runs from
    /// colliding.
    fn temp_snapshot_root(name: &str) -> std::path::PathBuf {
        let ts = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the Unix epoch")
            .as_nanos();
        std::env::temp_dir().join(format!(
            "shape_ctx_snapshot_{}_{}_{}",
            name,
            std::process::id(),
            ts
        ))
    }

    /// Three-row frame with a single numeric column `a`.
    fn make_df(id: &str, timeframe: Timeframe) -> DataFrame {
        let mut df = DataFrame::new(id, timeframe);
        df.timestamps = vec![1, 2, 3];
        df.add_column("a", vec![10.0, 11.0, 12.0]);
        df
    }

    #[tokio::test]
    async fn test_execution_context_snapshot_restores_data_cache() {
        let tf = Timeframe::d1();
        let df = make_df("TEST", tf);
        let mut frames = HashMap::new();
        frames.insert(CacheKey::new("TEST".to_string(), tf), df);
        let load_calls = Arc::new(AtomicUsize::new(0));
        let provider = Arc::new(TestAsyncProvider {
            frames: Arc::new(frames),
            load_calls: load_calls.clone(),
        });

        let mut ctx =
            ExecutionContext::with_async_provider(provider, tokio::runtime::Handle::current());
        ctx.prefetch_data(vec![DataQuery::new("TEST", tf)])
            .await
            .unwrap();
        ctx.declare_variable("x", VarKind::Let, Some(ValueWord::from_f64(42.0)))
            .unwrap();

        // Declared before every other local that touches the directory, so it
        // is dropped (and the directory removed) last.
        let snapshot_dir = TempDirGuard(temp_snapshot_root("context_cache"));
        let store = SnapshotStore::new(snapshot_dir.0.clone()).unwrap();
        let snapshot = ctx.snapshot(&store).unwrap();

        // The restored context gets a `NullAsyncProvider`, presumably one that
        // serves no data, so any cached rows it exposes must come from the
        // snapshot.
        let mut restored = ExecutionContext::with_async_provider(
            Arc::new(NullAsyncProvider::default()),
            tokio::runtime::Handle::current(),
        );
        restored.restore_from_snapshot(snapshot, &store).unwrap();

        let val = restored.get_variable("x").unwrap();
        assert_eq!(val, Some(ValueWord::from_f64(42.0)));

        let row = restored
            .data_cache()
            .unwrap()
            .get_row("TEST", &tf, 0)
            .expect("row should be cached");
        assert_eq!(row.timestamp, 1);
        assert_eq!(row.fields.get("a"), Some(&10.0));

        // Only the initial prefetch hit the original provider; snapshotting
        // did not trigger another load.
        assert_eq!(load_calls.load(Ordering::SeqCst), 1);
    }
}