1mod config;
11mod data_access;
12mod data_cache;
13mod registries;
14mod scope;
15mod variables;
16
17pub use data_cache::DataLoadMode;
19pub use variables::Variable;
20
21use std::collections::HashMap;
22use std::sync::Arc;
23
24use super::alerts::AlertRouter;
25use super::annotation_context::{AnnotationContext, AnnotationRegistry};
26use super::data::DataFrame;
27use super::event_queue::{SharedEventQueue, SuspensionState};
28use super::lookahead_guard::LookAheadGuard;
29use super::metadata::MetadataRegistry;
30use super::simulation::KernelCompiler;
31use super::type_methods::TypeMethodRegistry;
32use super::type_schema::TypeSchemaRegistry;
33use crate::data::Timeframe;
34use crate::snapshot::{
35 ContextSnapshot, SnapshotStore, SuspensionStateSnapshot, TypeAliasRuntimeEntrySnapshot,
36 VariableSnapshot, nanboxed_to_serializable, serializable_to_nanboxed,
37};
38use anyhow::{Result, anyhow};
39use chrono::{DateTime, Utc};
40use shape_value::ValueWord;
41
/// Runtime state used while executing a program.
///
/// Holds data-access handles (a type-erased provider and/or an async-backed
/// data cache), shared registries, the variable scopes, the current data
/// position (id, row, timeframes, row/date range) and optional runtime
/// services (event queue, alert router, progress registry, kernel compiler).
/// Build it through one of the `new*` / `with_*` constructors.
#[derive(Clone)]
pub struct ExecutionContext {
    /// Type-erased data provider passed to `with_data_provider*`; `None` otherwise.
    data_provider: Option<Arc<dyn std::any::Any + Send + Sync>>,
    /// Async-provider-backed cache; set by `with_async_provider` in this file.
    /// Included in snapshots when present.
    pub(crate) data_cache: Option<crate::data::DataCache>,
    /// Shared provider registry.
    provider_registry: Arc<super::provider_registry::ProviderRegistry>,
    /// Shared type-mapping registry.
    type_mapping_registry: Arc<super::type_mapping::TypeMappingRegistry>,
    /// Shared type-schema registry (constructors load it via `with_stdlib_types`).
    type_schema_registry: Arc<TypeSchemaRegistry>,
    /// Shared metadata registry (exposed by `metadata_registry()`).
    metadata_registry: Arc<MetadataRegistry>,
    /// Data loading mode; captured in snapshots.
    data_load_mode: DataLoadMode,
    /// Identifier of the current data series (from `DataFrame::id` in `new_with_registry`).
    current_id: Option<String>,
    /// Current row index; `new_with_registry` starts at the frame's last row.
    current_row_index: usize,
    /// Variable scopes; every constructor here starts with one empty scope.
    variable_scopes: Vec<HashMap<String, Variable>>,
    /// Optional reference date/time in UTC.
    reference_datetime: Option<DateTime<Utc>>,
    /// Timeframe currently in effect.
    current_timeframe: Option<Timeframe>,
    /// Timeframe of the underlying data (set alongside `current_timeframe` from a `DataFrame`).
    base_timeframe: Option<Timeframe>,
    /// Optional look-ahead guard.
    lookahead_guard: Option<LookAheadGuard>,
    /// Shared registry of type methods.
    type_method_registry: Arc<TypeMethodRegistry>,
    /// Optional (start, end) UTC date range.
    date_range: Option<(DateTime<Utc>, DateTime<Utc>)>,
    /// Row-range start; constructors default it to `0`.
    range_start: usize,
    /// Row-range end; constructors default it to `usize::MAX`.
    range_end: usize,
    /// Whether the row range is in effect; constructors default it to `false`.
    range_active: bool,
    /// Closures registered by name.
    pattern_registry: HashMap<String, super::closure::Closure>,
    /// Current annotation context.
    annotation_context: AnnotationContext,
    /// Registered annotations.
    annotation_registry: AnnotationRegistry,
    /// Optional shared event queue.
    event_queue: Option<SharedEventQueue>,
    /// Saved state while execution is suspended (`is_suspended()` checks this).
    suspension_state: Option<SuspensionState>,
    /// Alert router that `emit_alert` forwards to; alerts are dropped when `None`.
    alert_pipeline: Option<Arc<AlertRouter>>,
    /// Destination for program output; constructors default to `StdoutAdapter`.
    output_adapter: Box<dyn crate::output_adapter::OutputAdapter>,
    /// Runtime type aliases: alias name -> entry (see `register_type_alias`).
    type_alias_registry: HashMap<String, TypeAliasRuntimeEntry>,
    /// Enum definitions known at runtime.
    enum_registry: EnumRegistry,
    /// Optional shared progress registry.
    progress_registry: Option<Arc<super::progress::ProgressRegistry>>,
    /// Optional kernel compiler.
    kernel_compiler: Option<Arc<dyn KernelCompiler>>,
}
114
/// Runtime record of a type alias, as stored by `ExecutionContext::register_type_alias`.
#[derive(Debug, Clone)]
pub struct TypeAliasRuntimeEntry {
    /// Name of the type the alias resolves to.
    pub base_type: String,
    /// Optional parameter overrides keyed by name (e.g. `"decimals"` in this
    /// module's tests); `None` when the alias adds no overrides.
    pub overrides: Option<HashMap<String, ValueWord>>,
}
123
124#[derive(Debug, Clone, Default)]
130pub struct EnumRegistry {
131 enums: HashMap<String, shape_ast::ast::EnumDef>,
133}
134
135impl EnumRegistry {
136 pub fn new() -> Self {
138 Self {
139 enums: HashMap::new(),
140 }
141 }
142
143 pub fn register(&mut self, enum_def: shape_ast::ast::EnumDef) {
145 self.enums.insert(enum_def.name.clone(), enum_def);
146 }
147
148 pub fn get(&self, name: &str) -> Option<&shape_ast::ast::EnumDef> {
150 self.enums.get(name)
151 }
152
153 pub fn contains(&self, name: &str) -> bool {
155 self.enums.contains_key(name)
156 }
157
158 pub fn names(&self) -> impl Iterator<Item = &String> {
160 self.enums.keys()
161 }
162
163 pub fn value_matches_type(&self, value_enum_name: &str, type_name: &str) -> bool {
169 if value_enum_name == type_name {
171 return true;
172 }
173 false
176 }
177}
178
179impl std::fmt::Debug for ExecutionContext {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 f.debug_struct("ExecutionContext")
182 .field("data_provider", &"<DataProvider>")
183 .field("current_id", &self.current_id)
184 .field("current_row_index", &self.current_row_index)
185 .field("variable_scopes", &self.variable_scopes)
186 .field("reference_datetime", &self.reference_datetime)
187 .field("current_timeframe", &self.current_timeframe)
188 .field("lookahead_guard", &self.lookahead_guard)
189 .finish()
190 }
191}
192
193impl ExecutionContext {
194 pub fn new_with_registry(
196 data: &DataFrame,
197 type_method_registry: Arc<TypeMethodRegistry>,
198 ) -> Self {
199 let current_row_index = if data.row_count() == 0 {
201 0
202 } else {
203 data.row_count() - 1
204 };
205
206 Self {
207 data_provider: None,
208 data_cache: None,
209 provider_registry: Arc::new(super::provider_registry::ProviderRegistry::new()),
210 type_mapping_registry: Arc::new(super::type_mapping::TypeMappingRegistry::new()),
211 type_schema_registry: Arc::new(TypeSchemaRegistry::with_stdlib_types()),
212 metadata_registry: Arc::new(MetadataRegistry::new()),
213 data_load_mode: DataLoadMode::default(),
214 current_id: Some(data.id.clone()),
215 current_row_index,
216 variable_scopes: vec![HashMap::new()], reference_datetime: None,
219 current_timeframe: Some(data.timeframe),
220 base_timeframe: Some(data.timeframe),
221 lookahead_guard: None,
222 type_method_registry,
223 date_range: None,
224 range_start: 0,
225 range_end: usize::MAX,
226 range_active: false,
227 pattern_registry: HashMap::new(),
228 annotation_context: AnnotationContext::new(),
229 annotation_registry: AnnotationRegistry::new(),
230 event_queue: None,
231 suspension_state: None,
232 alert_pipeline: None,
233 output_adapter: Box::new(crate::output_adapter::StdoutAdapter),
234 type_alias_registry: HashMap::new(),
235 enum_registry: EnumRegistry::new(),
236 progress_registry: None,
237 kernel_compiler: None,
238 }
239 }
240
241 pub fn new(data: &DataFrame) -> Self {
243 Self::new_with_registry(data, Arc::new(TypeMethodRegistry::new()))
244 }
245
246 pub fn new_empty_with_registry(type_method_registry: Arc<TypeMethodRegistry>) -> Self {
248 Self {
249 data_provider: None,
250 data_cache: None,
251 provider_registry: Arc::new(super::provider_registry::ProviderRegistry::new()),
252 type_mapping_registry: Arc::new(super::type_mapping::TypeMappingRegistry::new()),
253 type_schema_registry: Arc::new(TypeSchemaRegistry::with_stdlib_types()),
254 metadata_registry: Arc::new(MetadataRegistry::new()),
255 data_load_mode: DataLoadMode::default(),
256 current_id: None,
257 current_row_index: 0,
258 variable_scopes: vec![HashMap::new()], reference_datetime: None,
261 current_timeframe: None,
262 base_timeframe: None,
263 lookahead_guard: None,
264 type_method_registry,
265 date_range: None,
266 range_start: 0,
267 range_end: usize::MAX,
268 range_active: false,
269 pattern_registry: HashMap::new(),
270 annotation_context: AnnotationContext::new(),
271 annotation_registry: AnnotationRegistry::new(),
272 event_queue: None,
273 suspension_state: None,
274 alert_pipeline: None,
275 output_adapter: Box::new(crate::output_adapter::StdoutAdapter),
276 type_alias_registry: HashMap::new(),
277 enum_registry: EnumRegistry::new(),
278 progress_registry: None,
279 kernel_compiler: None,
280 }
281 }
282
283 pub fn new_empty() -> Self {
285 Self::new_empty_with_registry(Arc::new(TypeMethodRegistry::new()))
286 }
287
288 pub fn with_data_provider_and_registry(
290 data_provider: Arc<dyn std::any::Any + Send + Sync>,
291 type_method_registry: Arc<TypeMethodRegistry>,
292 ) -> Self {
293 Self {
294 data_provider: Some(data_provider),
295 data_cache: None,
296 provider_registry: Arc::new(super::provider_registry::ProviderRegistry::new()),
297 type_mapping_registry: Arc::new(super::type_mapping::TypeMappingRegistry::new()),
298 type_schema_registry: Arc::new(TypeSchemaRegistry::with_stdlib_types()),
299 metadata_registry: Arc::new(MetadataRegistry::new()),
300 data_load_mode: DataLoadMode::default(),
301 current_id: None,
302 current_row_index: 0,
303 variable_scopes: vec![HashMap::new()],
304 reference_datetime: None,
306 current_timeframe: None,
307 base_timeframe: None,
308 lookahead_guard: None,
309 type_method_registry,
310 date_range: None,
311 range_start: 0,
312 range_end: usize::MAX,
313 range_active: false,
314 pattern_registry: HashMap::new(),
315 annotation_context: AnnotationContext::new(),
316 annotation_registry: AnnotationRegistry::new(),
317 event_queue: None,
318 suspension_state: None,
319 alert_pipeline: None,
320 output_adapter: Box::new(crate::output_adapter::StdoutAdapter),
321 type_alias_registry: HashMap::new(),
322 enum_registry: EnumRegistry::new(),
323 progress_registry: None,
324 kernel_compiler: None,
325 }
326 }
327
328 pub fn with_data_provider(data_provider: Arc<dyn std::any::Any + Send + Sync>) -> Self {
330 Self::with_data_provider_and_registry(data_provider, Arc::new(TypeMethodRegistry::new()))
331 }
332
333 pub fn with_async_provider(
338 provider: crate::data::SharedAsyncProvider,
339 runtime: tokio::runtime::Handle,
340 ) -> Self {
341 let data_cache = crate::data::DataCache::new(provider, runtime);
342 Self {
343 data_provider: None,
344 data_cache: Some(data_cache),
345 provider_registry: Arc::new(super::provider_registry::ProviderRegistry::new()),
346 type_mapping_registry: Arc::new(super::type_mapping::TypeMappingRegistry::new()),
347 type_schema_registry: Arc::new(TypeSchemaRegistry::with_stdlib_types()),
348 metadata_registry: Arc::new(MetadataRegistry::new()),
349 data_load_mode: DataLoadMode::default(),
350 current_id: None,
351 current_row_index: 0,
352 variable_scopes: vec![HashMap::new()],
353 reference_datetime: None,
355 current_timeframe: None,
356 base_timeframe: None,
357 lookahead_guard: None,
358 type_method_registry: Arc::new(TypeMethodRegistry::new()),
359 date_range: None,
360 range_start: 0,
361 range_end: usize::MAX,
362 range_active: false,
363 pattern_registry: HashMap::new(),
364 annotation_context: AnnotationContext::new(),
365 annotation_registry: AnnotationRegistry::new(),
366 event_queue: None,
367 suspension_state: None,
368 alert_pipeline: None,
369 output_adapter: Box::new(crate::output_adapter::StdoutAdapter),
370 type_alias_registry: HashMap::new(),
371 enum_registry: EnumRegistry::new(),
372 progress_registry: None,
373 kernel_compiler: None,
374 }
375 }
376
377 pub fn set_output_adapter(&mut self, adapter: Box<dyn crate::output_adapter::OutputAdapter>) {
379 self.output_adapter = adapter;
380 }
381
382 pub fn output_adapter_mut(&mut self) -> &mut Box<dyn crate::output_adapter::OutputAdapter> {
384 &mut self.output_adapter
385 }
386
387 pub fn metadata_registry(&self) -> &Arc<MetadataRegistry> {
389 &self.metadata_registry
390 }
391
392 pub fn register_type_alias(
401 &mut self,
402 alias_name: &str,
403 base_type: &str,
404 overrides: Option<HashMap<String, ValueWord>>,
405 ) {
406 self.type_alias_registry.insert(
407 alias_name.to_string(),
408 TypeAliasRuntimeEntry {
409 base_type: base_type.to_string(),
410 overrides,
411 },
412 );
413 }
414
415 pub fn lookup_type_alias(&self, name: &str) -> Option<&TypeAliasRuntimeEntry> {
419 self.type_alias_registry.get(name)
420 }
421
422 pub fn resolve_type_for_format(
427 &self,
428 type_name: &str,
429 ) -> (String, Option<HashMap<String, ValueWord>>) {
430 if let Some(entry) = self.type_alias_registry.get(type_name) {
431 (entry.base_type.clone(), entry.overrides.clone())
432 } else {
433 (type_name.to_string(), None)
434 }
435 }
436
437 pub fn snapshot(&self, store: &SnapshotStore) -> Result<ContextSnapshot> {
443 let mut scopes = Vec::new();
444 for scope in &self.variable_scopes {
445 let mut snap_scope = HashMap::new();
446 for (name, var) in scope.iter() {
447 let value = nanboxed_to_serializable(&var.value, store)?;
448 let format_overrides = match &var.format_overrides {
449 Some(map) => {
450 let mut out = HashMap::new();
451 for (k, v) in map.iter() {
452 out.insert(k.clone(), nanboxed_to_serializable(v, store)?);
453 }
454 Some(out)
455 }
456 None => None,
457 };
458 snap_scope.insert(
459 name.clone(),
460 VariableSnapshot {
461 value,
462 kind: var.kind,
463 is_initialized: var.is_initialized,
464 is_function_scoped: var.is_function_scoped,
465 format_hint: var.format_hint.clone(),
466 format_overrides,
467 },
468 );
469 }
470 scopes.push(snap_scope);
471 }
472
473 let mut alias_registry = HashMap::new();
474 for (name, entry) in self.type_alias_registry.iter() {
475 let overrides = match &entry.overrides {
476 Some(map) => {
477 let mut out = HashMap::new();
478 for (k, v) in map.iter() {
479 out.insert(k.clone(), nanboxed_to_serializable(v, store)?);
480 }
481 Some(out)
482 }
483 None => None,
484 };
485 alias_registry.insert(
486 name.clone(),
487 TypeAliasRuntimeEntrySnapshot {
488 base_type: entry.base_type.clone(),
489 overrides,
490 },
491 );
492 }
493
494 let enum_registry = self
495 .enum_registry
496 .names()
497 .filter_map(|name| {
498 self.enum_registry
499 .get(name)
500 .cloned()
501 .map(|def| (name.clone(), def))
502 })
503 .collect::<HashMap<_, _>>();
504
505 let suspension_state = match self.suspension_state() {
506 Some(state) => {
507 let mut locals = Vec::new();
508 for v in state.saved_locals.iter() {
509 locals.push(nanboxed_to_serializable(v, store)?);
510 }
511 let mut stack = Vec::new();
512 for v in state.saved_stack.iter() {
513 stack.push(nanboxed_to_serializable(v, store)?);
514 }
515 Some(SuspensionStateSnapshot {
516 waiting_for: state.waiting_for.clone(),
517 resume_pc: state.resume_pc,
518 saved_locals: locals,
519 saved_stack: stack,
520 })
521 }
522 None => None,
523 };
524
525 let data_cache = match &self.data_cache {
526 Some(cache) => Some(cache.snapshot(store)?),
527 None => None,
528 };
529
530 Ok(ContextSnapshot {
531 data_load_mode: self.data_load_mode,
532 data_cache,
533 current_id: self.current_id.clone(),
534 current_row_index: self.current_row_index,
535 variable_scopes: scopes,
536 reference_datetime: self.reference_datetime,
537 current_timeframe: self.current_timeframe,
538 base_timeframe: self.base_timeframe,
539 date_range: self.date_range,
540 range_start: self.range_start,
541 range_end: self.range_end,
542 range_active: self.range_active,
543 type_alias_registry: alias_registry,
544 enum_registry,
545 suspension_state,
546 })
547 }
548
549 pub fn restore_from_snapshot(
551 &mut self,
552 snapshot: ContextSnapshot,
553 store: &SnapshotStore,
554 ) -> Result<()> {
555 self.data_load_mode = snapshot.data_load_mode;
556 self.current_id = snapshot.current_id;
557 self.current_row_index = snapshot.current_row_index;
558 self.reference_datetime = snapshot.reference_datetime;
559 self.current_timeframe = snapshot.current_timeframe;
560 self.base_timeframe = snapshot.base_timeframe;
561 self.date_range = snapshot.date_range;
562 self.range_start = snapshot.range_start;
563 self.range_end = snapshot.range_end;
564 self.range_active = snapshot.range_active;
565
566 match snapshot.data_cache {
567 Some(cache_snapshot) => {
568 if let Some(cache) = &self.data_cache {
569 cache.restore_from_snapshot(cache_snapshot, store)?;
570 } else {
571 return Err(anyhow!(
572 "Snapshot includes data cache, but context has no async provider"
573 ));
574 }
575 }
576 None => {
577 if let Some(cache) = &self.data_cache {
578 cache.clear();
579 }
580 }
581 }
582
583 self.variable_scopes.clear();
584 for scope in snapshot.variable_scopes.into_iter() {
585 let mut restored = HashMap::new();
586 for (name, var) in scope.into_iter() {
587 let value = serializable_to_nanboxed(&var.value, store)?;
588 let format_overrides = match var.format_overrides {
589 Some(map) => {
590 let mut out = HashMap::new();
591 for (k, v) in map.into_iter() {
592 out.insert(k, serializable_to_nanboxed(&v, store)?);
593 }
594 Some(out)
595 }
596 None => None,
597 };
598 restored.insert(
599 name,
600 Variable {
601 value,
602 kind: var.kind,
603 is_initialized: var.is_initialized,
604 is_function_scoped: var.is_function_scoped,
605 format_hint: var.format_hint,
606 format_overrides,
607 },
608 );
609 }
610 self.variable_scopes.push(restored);
611 }
612
613 self.type_alias_registry.clear();
614 for (name, entry) in snapshot.type_alias_registry.into_iter() {
615 let overrides = match entry.overrides {
616 Some(map) => {
617 let mut out = HashMap::new();
618 for (k, v) in map.into_iter() {
619 out.insert(k, serializable_to_nanboxed(&v, store)?);
620 }
621 Some(out)
622 }
623 None => None,
624 };
625 self.type_alias_registry.insert(
626 name,
627 TypeAliasRuntimeEntry {
628 base_type: entry.base_type,
629 overrides,
630 },
631 );
632 }
633
634 self.enum_registry = EnumRegistry::default();
635 for (_name, def) in snapshot.enum_registry.into_iter() {
636 self.enum_registry.register(def);
637 }
638
639 if let Some(state) = snapshot.suspension_state {
640 let mut locals = Vec::new();
641 for v in state.saved_locals.into_iter() {
642 locals.push(serializable_to_nanboxed(&v, store)?);
643 }
644 let mut stack = Vec::new();
645 for v in state.saved_stack.into_iter() {
646 stack.push(serializable_to_nanboxed(&v, store)?);
647 }
648 self.set_suspension_state(
649 SuspensionState::new(state.waiting_for, state.resume_pc)
650 .with_locals(locals)
651 .with_stack(stack),
652 );
653 } else {
654 self.clear_suspension_state();
655 }
656
657 Ok(())
661 }
662
663 pub fn set_event_queue(&mut self, queue: SharedEventQueue) {
667 self.event_queue = Some(queue);
668 }
669
670 pub fn event_queue(&self) -> Option<&SharedEventQueue> {
672 self.event_queue.as_ref()
673 }
674
675 pub fn event_queue_mut(&mut self) -> Option<&mut SharedEventQueue> {
677 self.event_queue.as_mut()
678 }
679
680 pub fn set_suspension_state(&mut self, state: SuspensionState) {
682 self.suspension_state = Some(state);
683 }
684
685 pub fn suspension_state(&self) -> Option<&SuspensionState> {
687 self.suspension_state.as_ref()
688 }
689
690 pub fn clear_suspension_state(&mut self) -> Option<SuspensionState> {
692 self.suspension_state.take()
693 }
694
695 pub fn is_suspended(&self) -> bool {
697 self.suspension_state.is_some()
698 }
699
700 pub fn set_alert_pipeline(&mut self, pipeline: Arc<AlertRouter>) {
702 self.alert_pipeline = Some(pipeline);
703 }
704
705 pub fn alert_pipeline(&self) -> Option<&Arc<AlertRouter>> {
707 self.alert_pipeline.as_ref()
708 }
709
710 pub fn emit_alert(&self, alert: super::alerts::Alert) {
712 if let Some(pipeline) = &self.alert_pipeline {
713 pipeline.emit(alert);
714 }
715 }
716
717 pub fn set_progress_registry(&mut self, registry: Arc<super::progress::ProgressRegistry>) {
719 self.progress_registry = Some(registry);
720 }
721
722 pub fn progress_registry(&self) -> Option<&Arc<super::progress::ProgressRegistry>> {
724 self.progress_registry.as_ref()
725 }
726
727 pub fn set_kernel_compiler(&mut self, compiler: Arc<dyn KernelCompiler>) {
732 self.kernel_compiler = Some(compiler);
733 }
734
735 pub fn kernel_compiler(&self) -> Option<&Arc<dyn KernelCompiler>> {
737 self.kernel_compiler.as_ref()
738 }
739}
740
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::{AsyncDataProvider, CacheKey, DataQuery, NullAsyncProvider, Timeframe};
    use crate::snapshot::SnapshotStore;
    use shape_ast::ast::VarKind;
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    #[test]
    fn test_execution_context_new_empty() {
        let ctx = ExecutionContext::new_empty();
        assert_eq!(ctx.current_row_index(), 0);
    }

    #[test]
    fn test_execution_context_set_current_row() {
        let mut ctx = ExecutionContext::new_empty();
        ctx.set_current_row(5).unwrap();
        assert_eq!(ctx.current_row_index(), 5);
    }

    #[test]
    fn test_execution_context_variable_scope() {
        let mut ctx = ExecutionContext::new_empty();

        // Assigning an undeclared name may legitimately be rejected; the
        // outcome is deliberately not asserted, it just must not panic.
        let _ = ctx.set_variable("x", ValueWord::from_f64(10.0));

        // A declared variable must be readable back from the global scope.
        ctx.declare_variable("y", VarKind::Let, Some(ValueWord::from_f64(7.0)))
            .unwrap();
        assert_eq!(
            ctx.get_variable("y").unwrap(),
            Some(ValueWord::from_f64(7.0))
        );
    }

    #[test]
    fn test_type_alias_registry_basic() {
        let mut ctx = ExecutionContext::new_empty();

        let mut overrides = HashMap::new();
        overrides.insert("decimals".to_string(), ValueWord::from_f64(4.0));
        ctx.register_type_alias("Percent4", "Percent", Some(overrides));

        let entry = ctx
            .lookup_type_alias("Percent4")
            .expect("alias should be registered");
        assert_eq!(entry.base_type, "Percent");

        let overrides = entry
            .overrides
            .as_ref()
            .expect("overrides should be kept");
        assert_eq!(
            overrides.get("decimals").and_then(|v| v.as_f64()),
            Some(4.0)
        );
    }

    #[test]
    fn test_type_alias_registry_no_overrides() {
        let mut ctx = ExecutionContext::new_empty();

        ctx.register_type_alias("MyPercent", "Percent", None);

        let entry = ctx
            .lookup_type_alias("MyPercent")
            .expect("alias should be registered");
        assert_eq!(entry.base_type, "Percent");
        assert!(entry.overrides.is_none());
    }

    #[test]
    fn test_type_alias_registry_reregister_replaces() {
        let mut ctx = ExecutionContext::new_empty();

        ctx.register_type_alias("Alias", "Percent", None);
        ctx.register_type_alias("Alias", "Number", None);

        let entry = ctx.lookup_type_alias("Alias").expect("alias should exist");
        assert_eq!(entry.base_type, "Number");
    }

    #[test]
    fn test_type_alias_registry_unknown_type() {
        let ctx = ExecutionContext::new_empty();

        assert!(ctx.lookup_type_alias("NonExistent").is_none());
    }

    #[test]
    fn test_resolve_type_for_format_alias() {
        let mut ctx = ExecutionContext::new_empty();

        let mut overrides = HashMap::new();
        overrides.insert("decimals".to_string(), ValueWord::from_f64(4.0));
        ctx.register_type_alias("Percent4", "Percent", Some(overrides.clone()));

        let (base_type, resolved_overrides) = ctx.resolve_type_for_format("Percent4");
        assert_eq!(base_type, "Percent");
        let resolved_overrides = resolved_overrides.expect("overrides should resolve");
        assert_eq!(
            resolved_overrides.get("decimals").and_then(|v| v.as_f64()),
            Some(4.0)
        );
    }

    #[test]
    fn test_resolve_type_for_format_non_alias() {
        let ctx = ExecutionContext::new_empty();

        let (base_type, resolved_overrides) = ctx.resolve_type_for_format("Number");
        assert_eq!(base_type, "Number");
        assert!(resolved_overrides.is_none());
    }

    #[test]
    fn test_enum_registry_value_matches_type_is_exact_name_match() {
        let registry = EnumRegistry::new();
        assert!(registry.value_matches_type("Color", "Color"));
        assert!(!registry.value_matches_type("Color", "Shape"));
        assert!(!registry.contains("Color"));
        assert_eq!(registry.names().count(), 0);
    }

    #[derive(Clone)]
    struct TestAsyncProvider {
        frames: Arc<HashMap<CacheKey, DataFrame>>,
        load_calls: Arc<AtomicUsize>,
    }

    impl AsyncDataProvider for TestAsyncProvider {
        fn load<'a>(
            &'a self,
            query: &'a DataQuery,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = Result<DataFrame, crate::data::AsyncDataError>>
                    + Send
                    + 'a,
            >,
        > {
            let key = CacheKey::new(query.id.clone(), query.timeframe);
            let frames = self.frames.clone();
            let calls = self.load_calls.clone();
            Box::pin(async move {
                calls.fetch_add(1, Ordering::SeqCst);
                frames
                    .get(&key)
                    .cloned()
                    .ok_or_else(|| crate::data::AsyncDataError::SymbolNotFound(query.id.clone()))
            })
        }

        fn has_data(&self, symbol: &str, timeframe: &Timeframe) -> bool {
            let key = CacheKey::new(symbol.to_string(), *timeframe);
            self.frames.contains_key(&key)
        }

        fn symbols(&self) -> Vec<String> {
            self.frames.keys().map(|k| k.id.clone()).collect()
        }
    }

    /// Removes the wrapped directory on drop (best effort), so snapshot tests
    /// clean up the temp dir even when an assertion fails.
    struct TempDirGuard(std::path::PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    fn temp_snapshot_root(name: &str) -> std::path::PathBuf {
        let ts = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_millis();
        // The process id keeps concurrently running test binaries apart.
        std::env::temp_dir().join(format!(
            "shape_ctx_snapshot_{}_{}_{}",
            name,
            std::process::id(),
            ts
        ))
    }

    fn make_df(id: &str, timeframe: Timeframe) -> DataFrame {
        let mut df = DataFrame::new(id, timeframe);
        df.timestamps = vec![1, 2, 3];
        df.add_column("a", vec![10.0, 11.0, 12.0]);
        df
    }

    #[tokio::test]
    async fn test_execution_context_snapshot_restores_data_cache() {
        let tf = Timeframe::d1();
        let df = make_df("TEST", tf);
        let mut frames = HashMap::new();
        frames.insert(CacheKey::new("TEST".to_string(), tf), df);
        let load_calls = Arc::new(AtomicUsize::new(0));
        let provider = Arc::new(TestAsyncProvider {
            frames: Arc::new(frames),
            load_calls: load_calls.clone(),
        });

        let mut ctx =
            ExecutionContext::with_async_provider(provider, tokio::runtime::Handle::current());
        ctx.prefetch_data(vec![DataQuery::new("TEST", tf)])
            .await
            .unwrap();
        ctx.declare_variable("x", VarKind::Let, Some(ValueWord::from_f64(42.0)))
            .unwrap();

        let root = temp_snapshot_root("context_cache");
        // Declared before `store`, so it is dropped (and the dir removed) after it.
        let _cleanup = TempDirGuard(root.clone());
        let store = SnapshotStore::new(root).unwrap();
        let snapshot = ctx.snapshot(&store).unwrap();

        let mut restored = ExecutionContext::with_async_provider(
            Arc::new(NullAsyncProvider::default()),
            tokio::runtime::Handle::current(),
        );
        restored.restore_from_snapshot(snapshot, &store).unwrap();

        let val = restored.get_variable("x").unwrap();
        assert_eq!(val, Some(ValueWord::from_f64(42.0)));

        let row = restored
            .data_cache()
            .unwrap()
            .get_row("TEST", &tf, 0)
            .expect("row should be cached");
        assert_eq!(row.timestamp, 1);
        assert_eq!(row.fields.get("a"), Some(&10.0));

        // The restored cache served the row without another provider load.
        assert_eq!(load_calls.load(Ordering::SeqCst), 1);

        // A snapshot carrying a data cache cannot be restored into a context
        // that has no async provider.
        let mut plain = ExecutionContext::new_empty();
        let err = plain
            .restore_from_snapshot(ctx.snapshot(&store).unwrap(), &store)
            .unwrap_err();
        assert!(err.to_string().contains("no async provider"));
    }
}
955}