Skip to main content

shape_runtime/context/
data_cache.rs

1//! Data cache and provider management for ExecutionContext
2//!
3//! Handles async data loading, prefetching, and live data feeds.
4
5use shape_ast::error::{Result, ShapeError};
6
7/// Data loading execution mode (Phase 8)
8///
9/// Determines how runtime data access behaves:
10/// - Async: Data must be prefetched before execution (scripts, backtests)
11/// - Sync: Data requests can block during execution (REPL only)
12#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, Default)]
13pub enum DataLoadMode {
14    /// Async mode - data must be prefetched.
15    /// Data requests return cached data or errors.
16    /// Used for: scripts, backtests, production
17    #[default]
18    Async,
19
20    /// Sync mode - data requests can block.
21    /// Uses tokio::runtime::Handle::current().block_on()
22    /// Used for: REPL, interactive exploration
23    Sync,
24}
25
26impl super::ExecutionContext {
27    /// Prefetch data before execution (Phase 6)
28    ///
29    /// This async method loads all required data concurrently and populates the cache.
30    /// Must be called before execution starts.
31    ///
32    /// # Arguments
33    ///
34    /// * `queries` - List of DataQuery objects specifying what data to load
35    ///
36    /// # Example
37    ///
38    /// ```ignore
39    /// let queries = vec![
40    ///     DataQuery::new("AAPL", Timeframe::d1()).limit(1000),
41    /// ];
42    /// ctx.prefetch_data(queries).await?;
43    /// ```
44    pub async fn prefetch_data(&mut self, queries: Vec<crate::data::DataQuery>) -> Result<()> {
45        if let Some(cache) = &self.data_cache {
46            cache
47                .prefetch(queries)
48                .await
49                .map_err(|e| ShapeError::DataError {
50                    message: format!("Failed to prefetch data: {}", e),
51                    symbol: None,
52                    timeframe: None,
53                })?;
54        }
55        Ok(())
56    }
57
58    /// Start live data feed (Phase 6)
59    ///
60    /// Subscribes to live bar updates for the current symbol/timeframe.
61    /// New bars will be appended to the live buffer as they arrive.
62    pub fn start_live_feed(&mut self) -> Result<()> {
63        let id = self.get_current_id()?;
64        let timeframe = self.get_current_timeframe()?;
65
66        if let Some(cache) = &mut self.data_cache {
67            cache
68                .subscribe_live(&id, &timeframe)
69                .map_err(|e| ShapeError::RuntimeError {
70                    message: format!("Failed to start live feed: {}", e),
71                    location: None,
72                })?;
73        }
74        Ok(())
75    }
76
77    /// Stop live data feed (Phase 6)
78    pub fn stop_live_feed(&mut self) -> Result<()> {
79        let id = self.get_current_id()?;
80        let timeframe = self.get_current_timeframe()?;
81
82        if let Some(cache) = &mut self.data_cache {
83            cache.unsubscribe_live(&id, &timeframe);
84        }
85        Ok(())
86    }
87
88    /// Check if using async data cache (Phase 6)
89    pub fn has_data_cache(&self) -> bool {
90        self.data_cache.is_some()
91    }
92
93    /// Get reference to data cache (Phase 8)
94    pub fn data_cache(&self) -> Option<&crate::data::DataCache> {
95        self.data_cache.as_ref()
96    }
97
98    /// Get async data provider (Phase 7)
99    ///
100    /// Returns the AsyncDataProvider from the data cache if available.
101    /// This is used for constructing TableRef and other lazy data references.
102    pub fn async_provider(&self) -> Option<crate::data::SharedAsyncProvider> {
103        self.data_cache.as_ref().map(|cache| cache.provider())
104    }
105
106    /// Register a data provider (Phase 8)
107    ///
108    /// Registers a named provider in the registry.
109    pub fn register_provider(&self, name: &str, provider: crate::data::SharedAsyncProvider) {
110        self.provider_registry.register(name, provider);
111    }
112
113    /// Get provider by name (Phase 8)
114    pub fn get_provider(&self, name: &str) -> Result<crate::data::SharedAsyncProvider> {
115        self.provider_registry
116            .get(name)
117            .ok_or_else(|| ShapeError::RuntimeError {
118                message: format!("Provider '{}' not registered", name),
119                location: None,
120            })
121    }
122
123    /// Get default provider (Phase 8)
124    pub fn get_default_provider(&self) -> Result<crate::data::SharedAsyncProvider> {
125        self.provider_registry
126            .get_default()
127            .ok_or_else(|| ShapeError::RuntimeError {
128                message: "No default provider configured".to_string(),
129                location: None,
130            })
131    }
132
133    /// Set default provider (Phase 8)
134    pub fn set_default_provider(&self, name: &str) -> Result<()> {
135        self.provider_registry.set_default(name)
136    }
137
138    /// Register a type mapping (Phase 8)
139    ///
140    /// Registers a type mapping that defines the expected DataFrame structure.
141    pub fn register_type_mapping(
142        &self,
143        type_name: &str,
144        mapping: super::super::type_mapping::TypeMapping,
145    ) {
146        self.type_mapping_registry.register(type_name, mapping);
147    }
148
149    /// Get type mapping (Phase 8)
150    ///
151    /// Retrieves the type mapping for validation.
152    pub fn get_type_mapping(
153        &self,
154        type_name: &str,
155    ) -> Result<super::super::type_mapping::TypeMapping> {
156        self.type_mapping_registry
157            .get(type_name)
158            .ok_or_else(|| ShapeError::RuntimeError {
159                message: format!("Type mapping for '{}' not found", type_name),
160                location: None,
161            })
162    }
163
164    /// Check if type mapping exists (Phase 8)
165    pub fn has_type_mapping(&self, type_name: &str) -> bool {
166        self.type_mapping_registry.has(type_name)
167    }
168
169    // ========================================================================
170    // Extension Management
171    // ========================================================================
172
173    /// Load a data source extension from a shared library
174    ///
175    /// # Arguments
176    ///
177    /// * `path` - Path to the extension shared library (.so, .dll, .dylib)
178    /// * `config` - Configuration value for the extension
179    pub fn load_extension(
180        &self,
181        path: &std::path::Path,
182        config: &serde_json::Value,
183    ) -> Result<super::super::extensions::LoadedExtension> {
184        self.provider_registry.load_extension(path, config)
185    }
186
187    /// Unload an extension by name
188    pub fn unload_extension(&self, name: &str) -> bool {
189        self.provider_registry.unload_extension(name)
190    }
191
192    /// List all loaded extension names
193    pub fn list_extensions(&self) -> Vec<String> {
194        self.provider_registry.list_extensions()
195    }
196
197    /// Get query schema for an extension (for LSP autocomplete)
198    pub fn get_extension_query_schema(
199        &self,
200        name: &str,
201    ) -> Option<super::super::extensions::ParsedQuerySchema> {
202        self.provider_registry.get_extension_query_schema(name)
203    }
204
205    /// Get output schema for an extension (for LSP autocomplete)
206    pub fn get_extension_output_schema(
207        &self,
208        name: &str,
209    ) -> Option<super::super::extensions::ParsedOutputSchema> {
210        self.provider_registry.get_extension_output_schema(name)
211    }
212
213    /// Get an extension data source by name
214    pub fn get_extension(
215        &self,
216        name: &str,
217    ) -> Option<std::sync::Arc<super::super::extensions::ExtensionDataSource>> {
218        self.provider_registry.get_extension(name)
219    }
220
221    /// Get extension module schema by module namespace.
222    pub fn get_extension_module_schema(
223        &self,
224        module_name: &str,
225    ) -> Option<super::super::extensions::ParsedModuleSchema> {
226        self.provider_registry
227            .get_extension_module_schema(module_name)
228    }
229
230    /// Get a language runtime by its language identifier (e.g., "python").
231    pub fn get_language_runtime(
232        &self,
233        language_id: &str,
234    ) -> Option<std::sync::Arc<super::super::plugins::language_runtime::PluginLanguageRuntime>>
235    {
236        self.provider_registry.get_language_runtime(language_id)
237    }
238
239    /// Return all loaded language runtimes, keyed by language identifier.
240    pub fn language_runtimes(
241        &self,
242    ) -> std::collections::HashMap<String, std::sync::Arc<super::super::plugins::language_runtime::PluginLanguageRuntime>>
243    {
244        self.provider_registry.language_runtimes()
245    }
246
247    /// Build VM extension modules from loaded extension module capabilities.
248    pub fn module_exports_from_extensions(
249        &self,
250    ) -> Vec<super::super::module_exports::ModuleExports> {
251        self.provider_registry.module_exports_from_extensions()
252    }
253
254    /// Invoke one loaded module export via module namespace.
255    pub fn invoke_extension_module_nb(
256        &self,
257        module_name: &str,
258        function: &str,
259        args: &[shape_value::ValueWord],
260    ) -> Result<shape_value::ValueWord> {
261        self.provider_registry
262            .invoke_extension_module_nb(module_name, function, args)
263    }
264
265    /// Invoke one loaded module export via module namespace.
266    pub fn invoke_extension_module_wire(
267        &self,
268        module_name: &str,
269        function: &str,
270        args: &[shape_wire::WireValue],
271    ) -> Result<shape_wire::WireValue> {
272        self.provider_registry
273            .invoke_extension_module_wire(module_name, function, args)
274    }
275
276    /// Get current data load mode (Phase 8)
277    pub fn data_load_mode(&self) -> DataLoadMode {
278        self.data_load_mode
279    }
280
281    /// Set data load mode (Phase 8)
282    pub fn set_data_load_mode(&mut self, mode: DataLoadMode) {
283        self.data_load_mode = mode;
284    }
285
286    /// Check if in REPL mode (sync loading allowed)
287    pub fn is_repl_mode(&self) -> bool {
288        self.data_load_mode == DataLoadMode::Sync
289    }
290
291    /// Set the DuckDB provider
292    pub fn set_data_provider(&mut self, provider: std::sync::Arc<dyn std::any::Any + Send + Sync>) {
293        self.data_provider = Some(provider);
294    }
295
296    /// Get DataProvider (legacy compatibility - returns type-erased Arc)
297    #[inline]
298    pub fn data_provider(&self) -> Result<std::sync::Arc<dyn std::any::Any + Send + Sync>> {
299        self.data_provider
300            .as_ref()
301            .ok_or_else(|| ShapeError::RuntimeError {
302                message: "No DataProvider configured. Use engine's async provider.".to_string(),
303                location: None,
304            })
305            .cloned()
306    }
307}