Skip to main content

shape_runtime/context/
data_cache.rs

1//! Data cache and provider management for ExecutionContext
2//!
3//! Handles async data loading, prefetching, and live data feeds.
4
5use shape_ast::error::{Result, ShapeError};
6
7/// Data loading execution mode (Phase 8)
8///
9/// Determines how runtime data access behaves:
10/// - Async: Data must be prefetched before execution (scripts, backtests)
11/// - Sync: Data requests can block during execution (REPL only)
12#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, Default)]
13pub enum DataLoadMode {
14    /// Async mode - data must be prefetched.
15    /// Data requests return cached data or errors.
16    /// Used for: scripts, backtests, production
17    #[default]
18    Async,
19
20    /// Sync mode - data requests can block.
21    /// Uses tokio::runtime::Handle::current().block_on()
22    /// Used for: REPL, interactive exploration
23    Sync,
24}
25
26impl super::ExecutionContext {
27    /// Prefetch data before execution (Phase 6)
28    ///
29    /// This async method loads all required data concurrently and populates the cache.
30    /// Must be called before execution starts.
31    ///
32    /// # Arguments
33    ///
34    /// * `queries` - List of DataQuery objects specifying what data to load
35    ///
36    /// # Example
37    ///
38    /// ```ignore
39    /// let queries = vec![
40    ///     DataQuery::new("AAPL", Timeframe::d1()).limit(1000),
41    /// ];
42    /// ctx.prefetch_data(queries).await?;
43    /// ```
44    pub async fn prefetch_data(&mut self, queries: Vec<crate::data::DataQuery>) -> Result<()> {
45        if let Some(cache) = &self.data_cache {
46            cache
47                .prefetch(queries)
48                .await
49                .map_err(|e| ShapeError::DataError {
50                    message: format!("Failed to prefetch data: {}", e),
51                    symbol: None,
52                    timeframe: None,
53                })?;
54        }
55        Ok(())
56    }
57
58    /// Start live data feed (Phase 6)
59    ///
60    /// Subscribes to live bar updates for the current symbol/timeframe.
61    /// New bars will be appended to the live buffer as they arrive.
62    pub fn start_live_feed(&mut self) -> Result<()> {
63        let id = self.get_current_id()?;
64        let timeframe = self.get_current_timeframe()?;
65
66        if let Some(cache) = &mut self.data_cache {
67            cache
68                .subscribe_live(&id, &timeframe)
69                .map_err(|e| ShapeError::RuntimeError {
70                    message: format!("Failed to start live feed: {}", e),
71                    location: None,
72                })?;
73        }
74        Ok(())
75    }
76
77    /// Stop live data feed (Phase 6)
78    pub fn stop_live_feed(&mut self) -> Result<()> {
79        let id = self.get_current_id()?;
80        let timeframe = self.get_current_timeframe()?;
81
82        if let Some(cache) = &mut self.data_cache {
83            cache.unsubscribe_live(&id, &timeframe);
84        }
85        Ok(())
86    }
87
88    /// Check if using async data cache (Phase 6)
89    pub fn has_data_cache(&self) -> bool {
90        self.data_cache.is_some()
91    }
92
93    /// Get reference to data cache (Phase 8)
94    pub fn data_cache(&self) -> Option<&crate::data::DataCache> {
95        self.data_cache.as_ref()
96    }
97
98    /// Get async data provider (Phase 7)
99    ///
100    /// Returns the AsyncDataProvider from the data cache if available.
101    /// This is used for constructing TableRef and other lazy data references.
102    pub fn async_provider(&self) -> Option<crate::data::SharedAsyncProvider> {
103        self.data_cache.as_ref().map(|cache| cache.provider())
104    }
105
106    /// Register a data provider (Phase 8)
107    ///
108    /// Registers a named provider in the registry.
109    pub fn register_provider(&self, name: &str, provider: crate::data::SharedAsyncProvider) {
110        self.provider_registry.register(name, provider);
111    }
112
113    /// Get provider by name (Phase 8)
114    pub fn get_provider(&self, name: &str) -> Result<crate::data::SharedAsyncProvider> {
115        self.provider_registry
116            .get(name)
117            .ok_or_else(|| ShapeError::RuntimeError {
118                message: format!("Provider '{}' not registered", name),
119                location: None,
120            })
121    }
122
123    /// Get default provider (Phase 8)
124    pub fn get_default_provider(&self) -> Result<crate::data::SharedAsyncProvider> {
125        self.provider_registry
126            .get_default()
127            .ok_or_else(|| ShapeError::RuntimeError {
128                message: "No default provider configured".to_string(),
129                location: None,
130            })
131    }
132
133    /// Set default provider (Phase 8)
134    pub fn set_default_provider(&self, name: &str) -> Result<()> {
135        self.provider_registry.set_default(name)
136    }
137
138    /// Register a type mapping (Phase 8)
139    ///
140    /// Registers a type mapping that defines the expected DataFrame structure.
141    pub fn register_type_mapping(
142        &self,
143        type_name: &str,
144        mapping: super::super::type_mapping::TypeMapping,
145    ) {
146        self.type_mapping_registry.register(type_name, mapping);
147    }
148
149    /// Get type mapping (Phase 8)
150    ///
151    /// Retrieves the type mapping for validation.
152    pub fn get_type_mapping(
153        &self,
154        type_name: &str,
155    ) -> Result<super::super::type_mapping::TypeMapping> {
156        self.type_mapping_registry
157            .get(type_name)
158            .ok_or_else(|| ShapeError::RuntimeError {
159                message: format!("Type mapping for '{}' not found", type_name),
160                location: None,
161            })
162    }
163
164    /// Check if type mapping exists (Phase 8)
165    pub fn has_type_mapping(&self, type_name: &str) -> bool {
166        self.type_mapping_registry.has(type_name)
167    }
168
169    // ========================================================================
170    // Extension Management
171    // ========================================================================
172
173    /// Load a data source extension from a shared library
174    ///
175    /// # Arguments
176    ///
177    /// * `path` - Path to the extension shared library (.so, .dll, .dylib)
178    /// * `config` - Configuration value for the extension
179    pub fn load_extension(
180        &self,
181        path: &std::path::Path,
182        config: &serde_json::Value,
183    ) -> Result<super::super::extensions::LoadedExtension> {
184        self.provider_registry.load_extension(path, config)
185    }
186
187    /// Unload an extension by name
188    pub fn unload_extension(&self, name: &str) -> bool {
189        self.provider_registry.unload_extension(name)
190    }
191
192    /// List all loaded extension names
193    pub fn list_extensions(&self) -> Vec<String> {
194        self.provider_registry.list_extensions()
195    }
196
197    /// Get query schema for an extension (for LSP autocomplete)
198    pub fn get_extension_query_schema(
199        &self,
200        name: &str,
201    ) -> Option<super::super::extensions::ParsedQuerySchema> {
202        self.provider_registry.get_extension_query_schema(name)
203    }
204
205    /// Get output schema for an extension (for LSP autocomplete)
206    pub fn get_extension_output_schema(
207        &self,
208        name: &str,
209    ) -> Option<super::super::extensions::ParsedOutputSchema> {
210        self.provider_registry.get_extension_output_schema(name)
211    }
212
213    /// Get an extension data source by name
214    pub fn get_extension(
215        &self,
216        name: &str,
217    ) -> Option<std::sync::Arc<super::super::extensions::ExtensionDataSource>> {
218        self.provider_registry.get_extension(name)
219    }
220
221    /// Get extension module schema by module namespace.
222    pub fn get_extension_module_schema(
223        &self,
224        module_name: &str,
225    ) -> Option<super::super::extensions::ParsedModuleSchema> {
226        self.provider_registry
227            .get_extension_module_schema(module_name)
228    }
229
230    /// Get a language runtime by its language identifier (e.g., "python").
231    pub fn get_language_runtime(
232        &self,
233        language_id: &str,
234    ) -> Option<std::sync::Arc<super::super::plugins::language_runtime::PluginLanguageRuntime>>
235    {
236        self.provider_registry.get_language_runtime(language_id)
237    }
238
239    /// Build VM extension modules from loaded extension module capabilities.
240    pub fn module_exports_from_extensions(
241        &self,
242    ) -> Vec<super::super::module_exports::ModuleExports> {
243        self.provider_registry.module_exports_from_extensions()
244    }
245
246    /// Invoke one loaded module export via module namespace.
247    pub fn invoke_extension_module_nb(
248        &self,
249        module_name: &str,
250        function: &str,
251        args: &[shape_value::ValueWord],
252    ) -> Result<shape_value::ValueWord> {
253        self.provider_registry
254            .invoke_extension_module_nb(module_name, function, args)
255    }
256
257    /// Invoke one loaded module export via module namespace.
258    pub fn invoke_extension_module_wire(
259        &self,
260        module_name: &str,
261        function: &str,
262        args: &[shape_wire::WireValue],
263    ) -> Result<shape_wire::WireValue> {
264        self.provider_registry
265            .invoke_extension_module_wire(module_name, function, args)
266    }
267
268    /// Get current data load mode (Phase 8)
269    pub fn data_load_mode(&self) -> DataLoadMode {
270        self.data_load_mode
271    }
272
273    /// Set data load mode (Phase 8)
274    pub fn set_data_load_mode(&mut self, mode: DataLoadMode) {
275        self.data_load_mode = mode;
276    }
277
278    /// Check if in REPL mode (sync loading allowed)
279    pub fn is_repl_mode(&self) -> bool {
280        self.data_load_mode == DataLoadMode::Sync
281    }
282
283    /// Set the DuckDB provider
284    pub fn set_data_provider(&mut self, provider: std::sync::Arc<dyn std::any::Any + Send + Sync>) {
285        self.data_provider = Some(provider);
286    }
287
288    /// Get DataProvider (legacy compatibility - returns type-erased Arc)
289    #[inline]
290    pub fn data_provider(&self) -> Result<std::sync::Arc<dyn std::any::Any + Send + Sync>> {
291        self.data_provider
292            .as_ref()
293            .ok_or_else(|| ShapeError::RuntimeError {
294                message: "No DataProvider configured. Use engine's async provider.".to_string(),
295                location: None,
296            })
297            .cloned()
298    }
299}