1mod builder;
8mod execution;
9mod query_extraction;
10mod stdlib;
11mod types;
12
13pub use crate::query_result::QueryType;
15pub use builder::ShapeEngineBuilder;
16use shape_value::ValueWord;
17pub use types::{
18 EngineBootstrapState, ExecutionMetrics, ExecutionResult, ExecutionType, Message, MessageLevel,
19};
20
21use crate::Runtime;
22use crate::data::DataFrame;
23use shape_ast::error::{Result, ShapeError};
24
25#[cfg(feature = "jit")]
26use std::collections::HashMap;
27
28use crate::hashing::HashDigest;
29use crate::snapshot::{ContextSnapshot, ExecutionSnapshot, SemanticSnapshot, SnapshotStore};
30use serde::Serialize;
31use shape_ast::Program;
32use shape_wire::WireValue;
33
34pub trait ExpressionEvaluator: Send + Sync {
40 fn eval_statements(
42 &self,
43 stmts: &[shape_ast::Statement],
44 ctx: &mut crate::context::ExecutionContext,
45 ) -> Result<ValueWord>;
46
47 fn eval_expr(
49 &self,
50 expr: &shape_ast::Expr,
51 ctx: &mut crate::context::ExecutionContext,
52 ) -> Result<ValueWord>;
53}
54
/// Output returned by [`ProgramExecutor::execute_program`].
pub struct ProgramExecutorResult {
    /// Result value encoded as a [`WireValue`].
    pub wire_value: WireValue,
    /// Optional type metadata accompanying the result.
    pub type_info: Option<shape_wire::metadata::TypeInfo>,
    /// Kind of execution that produced this result.
    pub execution_type: ExecutionType,
    /// Optional JSON content associated with the result.
    pub content_json: Option<serde_json::Value>,
    /// Optional HTML content associated with the result.
    pub content_html: Option<String>,
    /// Optional terminal-oriented text content associated with the result.
    pub content_terminal: Option<String>,
}
64
65pub trait ProgramExecutor {
67 fn execute_program(
68 &mut self,
69 engine: &mut ShapeEngine,
70 program: &Program,
71 ) -> Result<ProgramExecutorResult>;
72}
73
/// The Shape execution engine: a [`Runtime`] plus engine-level state for source text,
/// snapshots, the script path and exported symbols.
pub struct ShapeEngine {
    /// Underlying runtime; most engine operations go through its persistent execution context.
    pub runtime: Runtime,
    /// Data frame passed to [`ShapeEngine::with_data`]; `DataFrame::default()` for the other
    /// constructors in this module.
    pub default_data: DataFrame,
    /// Cache keyed by `u64` whose values are currently `()`; only compiled with the `jit`
    /// feature. NOTE(review): what the key identifies is not visible here — confirm.
    #[cfg(feature = "jit")]
    pub(crate) jit_cache: HashMap<u64, ()>,
    /// Source text set via [`ShapeEngine::set_source`].
    pub(crate) current_source: Option<String>,
    /// Store used by the snapshot APIs; `None` until [`ShapeEngine::enable_snapshot_store`].
    pub(crate) snapshot_store: Option<SnapshotStore>,
    /// Hash of the last snapshot written by [`ShapeEngine::snapshot_with_hashes`].
    pub(crate) last_snapshot: Option<HashDigest>,
    /// Path set via [`ShapeEngine::set_script_path`]; copied into execution snapshots.
    pub(crate) script_path: Option<String>,
    /// Exported symbol names; captured in semantic snapshots and bootstrap state.
    pub(crate) exported_symbols: std::collections::HashSet<String>,
}
94
95impl ShapeEngine {
96 pub fn new() -> Result<Self> {
98 let mut runtime = Runtime::new_without_stdlib();
99 runtime.enable_persistent_context_without_data();
100
101 Ok(Self {
102 runtime,
103 default_data: DataFrame::default(),
104 #[cfg(feature = "jit")]
105 jit_cache: HashMap::new(),
106 current_source: None,
107 snapshot_store: None,
108 last_snapshot: None,
109 script_path: None,
110 exported_symbols: std::collections::HashSet::new(),
111 })
112 }
113
114 pub fn with_data(data: DataFrame) -> Result<Self> {
116 let mut runtime = Runtime::new_without_stdlib();
117 runtime.enable_persistent_context(&data);
118 Ok(Self {
119 runtime,
120 default_data: data,
121 #[cfg(feature = "jit")]
122 jit_cache: HashMap::new(),
123 current_source: None,
124 snapshot_store: None,
125 last_snapshot: None,
126 script_path: None,
127 exported_symbols: std::collections::HashSet::new(),
128 })
129 }
130
131 pub fn with_async_provider(provider: crate::data::SharedAsyncProvider) -> Result<Self> {
136 let runtime_handle = tokio::runtime::Handle::try_current()
137 .map_err(|_| ShapeError::RuntimeError {
138 message: "No tokio runtime available. Ensure with_async_provider is called within a tokio context.".to_string(),
139 location: None,
140 })?;
141 let mut runtime = Runtime::new_without_stdlib();
142
143 let ctx = crate::context::ExecutionContext::with_async_provider(provider, runtime_handle);
145 runtime.set_persistent_context(ctx);
146
147 Ok(Self {
148 runtime,
149 default_data: DataFrame::default(),
150 #[cfg(feature = "jit")]
151 jit_cache: HashMap::new(),
152 current_source: None,
153 snapshot_store: None,
154 last_snapshot: None,
155 script_path: None,
156 exported_symbols: std::collections::HashSet::new(),
157 })
158 }
159
160 pub fn init_repl(&mut self) {
166 if let Some(ctx) = self.runtime.persistent_context_mut() {
168 ctx.set_output_adapter(Box::new(crate::output_adapter::ReplAdapter));
169 }
170 }
171
172 pub fn capture_bootstrap_state(&self) -> Result<EngineBootstrapState> {
176 let context =
177 self.runtime
178 .persistent_context()
179 .cloned()
180 .ok_or_else(|| ShapeError::RuntimeError {
181 message: "No persistent context available for bootstrap capture".to_string(),
182 location: None,
183 })?;
184 Ok(EngineBootstrapState {
185 semantic: SemanticSnapshot {
186 exported_symbols: self.exported_symbols.clone(),
187 },
188 context,
189 })
190 }
191
192 pub fn apply_bootstrap_state(&mut self, state: &EngineBootstrapState) {
194 self.exported_symbols = state.semantic.exported_symbols.clone();
195 self.runtime.set_persistent_context(state.context.clone());
196 }
197
198 pub fn set_script_path(&mut self, path: impl Into<String>) {
200 self.script_path = Some(path.into());
201 }
202
203 pub fn script_path(&self) -> Option<&str> {
205 self.script_path.as_deref()
206 }
207
208 pub fn enable_snapshot_store(&mut self, store: SnapshotStore) {
210 self.snapshot_store = Some(store);
211 }
212
213 pub fn last_snapshot(&self) -> Option<&HashDigest> {
215 self.last_snapshot.as_ref()
216 }
217
218 pub fn snapshot_store(&self) -> Option<&SnapshotStore> {
220 self.snapshot_store.as_ref()
221 }
222
223 pub fn store_snapshot_blob<T: Serialize>(&self, value: &T) -> Result<HashDigest> {
225 let store = self
226 .snapshot_store
227 .as_ref()
228 .ok_or_else(|| ShapeError::RuntimeError {
229 message: "Snapshot store not configured".to_string(),
230 location: None,
231 })?;
232 Ok(store.put_struct(value)?)
233 }
234
235 pub fn snapshot_with_hashes(
237 &mut self,
238 vm_hash: Option<HashDigest>,
239 bytecode_hash: Option<HashDigest>,
240 ) -> Result<HashDigest> {
241 let store = self
242 .snapshot_store
243 .as_ref()
244 .ok_or_else(|| ShapeError::RuntimeError {
245 message: "Snapshot store not configured".to_string(),
246 location: None,
247 })?;
248
249 let semantic = SemanticSnapshot {
250 exported_symbols: self.exported_symbols.clone(),
251 };
252 let semantic_hash = store.put_struct(&semantic)?;
253
254 let context = if let Some(ctx) = self.runtime.persistent_context() {
255 ctx.snapshot(store)?
256 } else {
257 return Err(ShapeError::RuntimeError {
258 message: "No persistent context for snapshot".to_string(),
259 location: None,
260 });
261 };
262 let context_hash = store.put_struct(&context)?;
263
264 let snapshot = ExecutionSnapshot {
265 version: crate::snapshot::SNAPSHOT_VERSION,
266 created_at_ms: chrono::Utc::now().timestamp_millis(),
267 semantic_hash,
268 context_hash,
269 vm_hash,
270 bytecode_hash,
271 script_path: self.script_path.clone(),
272 };
273
274 let snapshot_hash = store.put_snapshot(&snapshot)?;
275 self.last_snapshot = Some(snapshot_hash.clone());
276 Ok(snapshot_hash)
277 }
278
279 pub fn load_snapshot(
281 &self,
282 snapshot_id: &HashDigest,
283 ) -> Result<(
284 SemanticSnapshot,
285 ContextSnapshot,
286 Option<HashDigest>,
287 Option<HashDigest>,
288 )> {
289 let store = self
290 .snapshot_store
291 .as_ref()
292 .ok_or_else(|| ShapeError::RuntimeError {
293 message: "Snapshot store not configured".to_string(),
294 location: None,
295 })?;
296 let snapshot = store.get_snapshot(snapshot_id)?;
297 let semantic: SemanticSnapshot =
298 store
299 .get_struct(&snapshot.semantic_hash)
300 .map_err(|e| ShapeError::RuntimeError {
301 message: format!("failed to deserialize SemanticSnapshot: {e}"),
302 location: None,
303 })?;
304 let context: ContextSnapshot =
305 store
306 .get_struct(&snapshot.context_hash)
307 .map_err(|e| ShapeError::RuntimeError {
308 message: format!("failed to deserialize ContextSnapshot: {e}"),
309 location: None,
310 })?;
311 Ok((semantic, context, snapshot.vm_hash, snapshot.bytecode_hash))
312 }
313
314 pub fn apply_snapshot(
316 &mut self,
317 semantic: SemanticSnapshot,
318 context: ContextSnapshot,
319 ) -> Result<()> {
320 self.exported_symbols = semantic.exported_symbols;
321 if let Some(ctx) = self.runtime.persistent_context_mut() {
322 let store = self
323 .snapshot_store
324 .as_ref()
325 .ok_or_else(|| ShapeError::RuntimeError {
326 message: "Snapshot store not configured".to_string(),
327 location: None,
328 })?;
329 ctx.restore_from_snapshot(context, store)?;
330 Ok(())
331 } else {
332 Err(ShapeError::RuntimeError {
333 message: "No persistent context for snapshot".to_string(),
334 location: None,
335 })
336 }
337 }
338
339 pub fn register_extension_modules(
342 &mut self,
343 modules: &[crate::extensions::ParsedModuleSchema],
344 ) {
345 self.runtime.register_extension_module_artifacts(modules);
346 }
347
348 pub fn set_source(&mut self, source: &str) {
353 self.current_source = Some(source.to_string());
354 }
355
356 pub fn current_source(&self) -> Option<&str> {
358 self.current_source.as_deref()
359 }
360
361 pub fn register_provider(&mut self, name: &str, provider: crate::data::SharedAsyncProvider) {
372 if let Some(ctx) = self.runtime.persistent_context_mut() {
373 ctx.register_provider(name, provider);
374 }
375 }
376
377 pub fn set_default_provider(&mut self, name: &str) -> Result<()> {
381 if let Some(ctx) = self.runtime.persistent_context_mut() {
382 ctx.set_default_provider(name)
383 } else {
384 Err(ShapeError::RuntimeError {
385 message: "No execution context available".to_string(),
386 location: None,
387 })
388 }
389 }
390
391 pub fn register_type_mapping(
418 &mut self,
419 type_name: &str,
420 mapping: crate::type_mapping::TypeMapping,
421 ) {
422 if let Some(ctx) = self.runtime.persistent_context_mut() {
423 ctx.register_type_mapping(type_name, mapping);
424 }
425 }
426
427 pub fn get_runtime(&self) -> &Runtime {
429 &self.runtime
430 }
431
432 pub fn get_runtime_mut(&mut self) -> &mut Runtime {
434 &mut self.runtime
435 }
436
437 pub fn get_variable_format_hint(&self, name: &str) -> Option<String> {
442 self.runtime
443 .persistent_context()
444 .and_then(|ctx| ctx.get_variable_format_hint(name))
445 }
446
447 pub fn format_value_string(
479 &mut self,
480 value: f64,
481 type_name: &str,
482 format_name: Option<&str>,
483 params: &std::collections::HashMap<String, serde_json::Value>,
484 ) -> Result<String> {
485 use std::sync::Arc;
486
487 let (resolved_type_name, merged_params) =
489 self.resolve_type_alias_for_formatting(type_name, params)?;
490
491 let param_values: std::collections::HashMap<String, ValueWord> = merged_params
493 .iter()
494 .map(|(k, v)| {
495 let runtime_val = match v {
496 serde_json::Value::Number(n) => ValueWord::from_f64(n.as_f64().unwrap_or(0.0)),
497 serde_json::Value::String(s) => ValueWord::from_string(Arc::new(s.clone())),
498 serde_json::Value::Bool(b) => ValueWord::from_bool(*b),
499 _ => ValueWord::none(),
500 };
501 (k.clone(), runtime_val)
502 })
503 .collect();
504
505 let runtime_value = ValueWord::from_f64(value);
507
508 self.runtime.format_value(
510 runtime_value,
511 resolved_type_name.as_str(),
512 format_name,
513 param_values,
514 )
515 }
516
517 fn resolve_type_alias_for_formatting(
522 &self,
523 type_name: &str,
524 params: &std::collections::HashMap<String, serde_json::Value>,
525 ) -> Result<(String, std::collections::HashMap<String, serde_json::Value>)> {
526 let resolved = self
528 .runtime
529 .persistent_context()
530 .map(|ctx| ctx.resolve_type_for_format(type_name));
531
532 if let Some((base_type, Some(overrides))) = resolved {
533 if base_type != type_name {
534 let mut merged = std::collections::HashMap::new();
535
536 for (key, val) in overrides {
538 let json_val = if let Some(n) = val.as_f64() {
539 serde_json::json!(n)
540 } else if val.is_bool() {
541 serde_json::json!(val.as_bool())
542 } else {
543 continue;
545 };
546 merged.insert(key, json_val);
547 }
548
549 for (key, val) in params {
551 merged.insert(key.clone(), val.clone());
552 }
553
554 return Ok((base_type, merged));
555 }
556 }
557
558 Ok((type_name.to_string(), params.clone()))
560 }
561
562 pub fn load_extension(
588 &mut self,
589 path: &std::path::Path,
590 config: &serde_json::Value,
591 ) -> Result<crate::extensions::LoadedExtension> {
592 if let Some(ctx) = self.runtime.persistent_context_mut() {
593 ctx.load_extension(path, config)
594 } else {
595 Err(ShapeError::RuntimeError {
596 message: "No execution context available for extension loading".to_string(),
597 location: None,
598 })
599 }
600 }
601
602 pub fn unload_extension(&mut self, name: &str) -> bool {
612 if let Some(ctx) = self.runtime.persistent_context_mut() {
613 ctx.unload_extension(name)
614 } else {
615 false
616 }
617 }
618
619 pub fn list_extensions(&self) -> Vec<String> {
621 if let Some(ctx) = self.runtime.persistent_context() {
622 ctx.list_extensions()
623 } else {
624 Vec::new()
625 }
626 }
627
628 pub fn get_extension_query_schema(
638 &self,
639 name: &str,
640 ) -> Option<crate::extensions::ParsedQuerySchema> {
641 if let Some(ctx) = self.runtime.persistent_context() {
642 ctx.get_extension_query_schema(name)
643 } else {
644 None
645 }
646 }
647
648 pub fn get_extension_output_schema(
658 &self,
659 name: &str,
660 ) -> Option<crate::extensions::ParsedOutputSchema> {
661 if let Some(ctx) = self.runtime.persistent_context() {
662 ctx.get_extension_output_schema(name)
663 } else {
664 None
665 }
666 }
667
668 pub fn get_extension(
670 &self,
671 name: &str,
672 ) -> Option<std::sync::Arc<crate::extensions::ExtensionDataSource>> {
673 if let Some(ctx) = self.runtime.persistent_context() {
674 ctx.get_extension(name)
675 } else {
676 None
677 }
678 }
679
680 pub fn get_extension_module_schema(
682 &self,
683 module_name: &str,
684 ) -> Option<crate::extensions::ParsedModuleSchema> {
685 if let Some(ctx) = self.runtime.persistent_context() {
686 ctx.get_extension_module_schema(module_name)
687 } else {
688 None
689 }
690 }
691
692 pub fn module_exports_from_extensions(&self) -> Vec<crate::module_exports::ModuleExports> {
694 if let Some(ctx) = self.runtime.persistent_context() {
695 ctx.module_exports_from_extensions()
696 } else {
697 Vec::new()
698 }
699 }
700
701 pub fn language_runtimes(
703 &self,
704 ) -> std::collections::HashMap<String, std::sync::Arc<crate::plugins::language_runtime::PluginLanguageRuntime>>
705 {
706 if let Some(ctx) = self.runtime.persistent_context() {
707 ctx.language_runtimes()
708 } else {
709 std::collections::HashMap::new()
710 }
711 }
712
713 pub fn invoke_extension_module_nb(
715 &self,
716 module_name: &str,
717 function: &str,
718 args: &[shape_value::ValueWord],
719 ) -> Result<shape_value::ValueWord> {
720 if let Some(ctx) = self.runtime.persistent_context() {
721 ctx.invoke_extension_module_nb(module_name, function, args)
722 } else {
723 Err(shape_ast::error::ShapeError::RuntimeError {
724 message: "No runtime context available".to_string(),
725 location: None,
726 })
727 }
728 }
729
730 pub fn invoke_extension_module_wire(
732 &self,
733 module_name: &str,
734 function: &str,
735 args: &[shape_wire::WireValue],
736 ) -> Result<shape_wire::WireValue> {
737 if let Some(ctx) = self.runtime.persistent_context() {
738 ctx.invoke_extension_module_wire(module_name, function, args)
739 } else {
740 Err(shape_ast::error::ShapeError::RuntimeError {
741 message: "No runtime context available".to_string(),
742 location: None,
743 })
744 }
745 }
746
747 pub fn enable_progress_tracking(
768 &mut self,
769 ) -> std::sync::Arc<crate::progress::ProgressRegistry> {
770 let registry = crate::progress::ProgressRegistry::new();
772 if let Some(ctx) = self.runtime.persistent_context_mut() {
773 ctx.set_progress_registry(registry.clone());
774 }
775 registry
776 }
777
778 pub fn progress_registry(&self) -> Option<std::sync::Arc<crate::progress::ProgressRegistry>> {
780 self.runtime
781 .persistent_context()
782 .and_then(|ctx| ctx.progress_registry())
783 .cloned()
784 }
785
786 pub fn has_pending_progress(&self) -> bool {
788 if let Some(registry) = self.progress_registry() {
789 !registry.is_empty()
790 } else {
791 false
792 }
793 }
794
795 pub fn poll_progress(&self) -> Option<crate::progress::ProgressEvent> {
799 self.progress_registry()
800 .and_then(|registry| registry.try_recv())
801 }
802}
803
804impl Default for ShapeEngine {
805 fn default() -> Self {
806 Self::new().expect("Failed to create default Shape engine")
807 }
808}
809
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::{ParsedModuleArtifact, ParsedModuleSchema};

    /// Registering a module schema that carries a source artifact must make the module
    /// loadable, with its exports, through the runtime's configured module loader.
    #[test]
    fn test_register_extension_modules_registers_module_loader_artifacts() {
        let artifact = ParsedModuleArtifact {
            module_path: "duckdb".to_string(),
            source: Some("pub fn connect(uri) { uri }".to_string()),
            compiled: None,
        };
        let schema = ParsedModuleSchema {
            module_name: "duckdb".to_string(),
            functions: Vec::new(),
            artifacts: vec![artifact],
        };

        let mut engine = ShapeEngine::new().expect("engine should create");
        engine.register_extension_modules(&[schema]);

        let mut module_loader = engine.runtime.configured_module_loader();
        let loaded = module_loader
            .load_module("duckdb")
            .expect("registered extension module artifact should load");

        let has_connect = loaded.exports.contains_key("connect");
        assert!(has_connect, "expected connect export");
    }
}