shape_runtime/context/
data_cache.rs1use shape_ast::error::{Result, ShapeError};
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, Default)]
13pub enum DataLoadMode {
14 #[default]
18 Async,
19
20 Sync,
24}
25
26impl super::ExecutionContext {
27 pub async fn prefetch_data(&mut self, queries: Vec<crate::data::DataQuery>) -> Result<()> {
45 if let Some(cache) = &self.data_cache {
46 cache
47 .prefetch(queries)
48 .await
49 .map_err(|e| ShapeError::DataError {
50 message: format!("Failed to prefetch data: {}", e),
51 symbol: None,
52 timeframe: None,
53 })?;
54 }
55 Ok(())
56 }
57
58 pub fn start_live_feed(&mut self) -> Result<()> {
63 let id = self.get_current_id()?;
64 let timeframe = self.get_current_timeframe()?;
65
66 if let Some(cache) = &mut self.data_cache {
67 cache
68 .subscribe_live(&id, &timeframe)
69 .map_err(|e| ShapeError::RuntimeError {
70 message: format!("Failed to start live feed: {}", e),
71 location: None,
72 })?;
73 }
74 Ok(())
75 }
76
77 pub fn stop_live_feed(&mut self) -> Result<()> {
79 let id = self.get_current_id()?;
80 let timeframe = self.get_current_timeframe()?;
81
82 if let Some(cache) = &mut self.data_cache {
83 cache.unsubscribe_live(&id, &timeframe);
84 }
85 Ok(())
86 }
87
88 pub fn has_data_cache(&self) -> bool {
90 self.data_cache.is_some()
91 }
92
93 pub fn data_cache(&self) -> Option<&crate::data::DataCache> {
95 self.data_cache.as_ref()
96 }
97
98 pub fn async_provider(&self) -> Option<crate::data::SharedAsyncProvider> {
103 self.data_cache.as_ref().map(|cache| cache.provider())
104 }
105
106 pub fn register_provider(&self, name: &str, provider: crate::data::SharedAsyncProvider) {
110 self.provider_registry.register(name, provider);
111 }
112
113 pub fn get_provider(&self, name: &str) -> Result<crate::data::SharedAsyncProvider> {
115 self.provider_registry
116 .get(name)
117 .ok_or_else(|| ShapeError::RuntimeError {
118 message: format!("Provider '{}' not registered", name),
119 location: None,
120 })
121 }
122
123 pub fn get_default_provider(&self) -> Result<crate::data::SharedAsyncProvider> {
125 self.provider_registry
126 .get_default()
127 .ok_or_else(|| ShapeError::RuntimeError {
128 message: "No default provider configured".to_string(),
129 location: None,
130 })
131 }
132
133 pub fn set_default_provider(&self, name: &str) -> Result<()> {
135 self.provider_registry.set_default(name)
136 }
137
138 pub fn register_type_mapping(
142 &self,
143 type_name: &str,
144 mapping: super::super::type_mapping::TypeMapping,
145 ) {
146 self.type_mapping_registry.register(type_name, mapping);
147 }
148
149 pub fn get_type_mapping(
153 &self,
154 type_name: &str,
155 ) -> Result<super::super::type_mapping::TypeMapping> {
156 self.type_mapping_registry
157 .get(type_name)
158 .ok_or_else(|| ShapeError::RuntimeError {
159 message: format!("Type mapping for '{}' not found", type_name),
160 location: None,
161 })
162 }
163
164 pub fn has_type_mapping(&self, type_name: &str) -> bool {
166 self.type_mapping_registry.has(type_name)
167 }
168
169 pub fn load_extension(
180 &self,
181 path: &std::path::Path,
182 config: &serde_json::Value,
183 ) -> Result<super::super::extensions::LoadedExtension> {
184 self.provider_registry.load_extension(path, config)
185 }
186
187 pub fn unload_extension(&self, name: &str) -> bool {
189 self.provider_registry.unload_extension(name)
190 }
191
192 pub fn list_extensions(&self) -> Vec<String> {
194 self.provider_registry.list_extensions()
195 }
196
197 pub fn get_extension_query_schema(
199 &self,
200 name: &str,
201 ) -> Option<super::super::extensions::ParsedQuerySchema> {
202 self.provider_registry.get_extension_query_schema(name)
203 }
204
205 pub fn get_extension_output_schema(
207 &self,
208 name: &str,
209 ) -> Option<super::super::extensions::ParsedOutputSchema> {
210 self.provider_registry.get_extension_output_schema(name)
211 }
212
213 pub fn get_extension(
215 &self,
216 name: &str,
217 ) -> Option<std::sync::Arc<super::super::extensions::ExtensionDataSource>> {
218 self.provider_registry.get_extension(name)
219 }
220
221 pub fn get_extension_module_schema(
223 &self,
224 module_name: &str,
225 ) -> Option<super::super::extensions::ParsedModuleSchema> {
226 self.provider_registry
227 .get_extension_module_schema(module_name)
228 }
229
230 pub fn get_language_runtime(
232 &self,
233 language_id: &str,
234 ) -> Option<std::sync::Arc<super::super::plugins::language_runtime::PluginLanguageRuntime>>
235 {
236 self.provider_registry.get_language_runtime(language_id)
237 }
238
239 pub fn module_exports_from_extensions(
241 &self,
242 ) -> Vec<super::super::module_exports::ModuleExports> {
243 self.provider_registry.module_exports_from_extensions()
244 }
245
246 pub fn invoke_extension_module_nb(
248 &self,
249 module_name: &str,
250 function: &str,
251 args: &[shape_value::ValueWord],
252 ) -> Result<shape_value::ValueWord> {
253 self.provider_registry
254 .invoke_extension_module_nb(module_name, function, args)
255 }
256
257 pub fn invoke_extension_module_wire(
259 &self,
260 module_name: &str,
261 function: &str,
262 args: &[shape_wire::WireValue],
263 ) -> Result<shape_wire::WireValue> {
264 self.provider_registry
265 .invoke_extension_module_wire(module_name, function, args)
266 }
267
268 pub fn data_load_mode(&self) -> DataLoadMode {
270 self.data_load_mode
271 }
272
273 pub fn set_data_load_mode(&mut self, mode: DataLoadMode) {
275 self.data_load_mode = mode;
276 }
277
278 pub fn is_repl_mode(&self) -> bool {
280 self.data_load_mode == DataLoadMode::Sync
281 }
282
283 pub fn set_data_provider(&mut self, provider: std::sync::Arc<dyn std::any::Any + Send + Sync>) {
285 self.data_provider = Some(provider);
286 }
287
288 #[inline]
290 pub fn data_provider(&self) -> Result<std::sync::Arc<dyn std::any::Any + Send + Sync>> {
291 self.data_provider
292 .as_ref()
293 .ok_or_else(|| ShapeError::RuntimeError {
294 message: "No DataProvider configured. Use engine's async provider.".to_string(),
295 location: None,
296 })
297 .cloned()
298 }
299}