1mod builder;
8mod execution;
9mod query_extraction;
10mod stdlib;
11mod types;
12
13pub use crate::query_result::QueryType;
15pub use builder::ShapeEngineBuilder;
16use shape_value::ValueWord;
17pub use types::{
18 EngineBootstrapState, ExecutionMetrics, ExecutionResult, ExecutionType, Message, MessageLevel,
19};
20
21use crate::Runtime;
22use crate::data::DataFrame;
23use crate::semantic::SemanticAnalyzer;
24use shape_ast::error::{Result, ShapeError};
25
26#[cfg(feature = "jit")]
27use std::collections::HashMap;
28
29use crate::hashing::HashDigest;
30use crate::snapshot::{ContextSnapshot, ExecutionSnapshot, SemanticSnapshot, SnapshotStore};
31use serde::Serialize;
32use shape_ast::Program;
33use shape_wire::WireValue;
34
35pub trait ExpressionEvaluator: Send + Sync {
41 fn eval_statements(
43 &self,
44 stmts: &[shape_ast::Statement],
45 ctx: &mut crate::context::ExecutionContext,
46 ) -> Result<ValueWord>;
47
48 fn eval_expr(
50 &self,
51 expr: &shape_ast::Expr,
52 ctx: &mut crate::context::ExecutionContext,
53 ) -> Result<ValueWord>;
54}
55
56pub struct ProgramExecutorResult {
58 pub wire_value: WireValue,
59 pub type_info: Option<shape_wire::metadata::TypeInfo>,
60 pub execution_type: ExecutionType,
61 pub content_json: Option<serde_json::Value>,
62 pub content_html: Option<String>,
63 pub content_terminal: Option<String>,
64}
65
66pub trait ProgramExecutor {
68 fn execute_program(
69 &self,
70 engine: &mut ShapeEngine,
71 program: &Program,
72 ) -> Result<ProgramExecutorResult>;
73}
74
75pub struct ShapeEngine {
77 pub runtime: Runtime,
79 pub(crate) analyzer: SemanticAnalyzer,
81 pub default_data: DataFrame,
83 #[cfg(feature = "jit")]
85 pub(crate) jit_cache: HashMap<u64, ()>,
86 pub(crate) current_source: Option<String>,
88 pub(crate) snapshot_store: Option<SnapshotStore>,
90 pub(crate) last_snapshot: Option<HashDigest>,
92 pub(crate) script_path: Option<String>,
94}
95
96impl ShapeEngine {
97 pub fn new() -> Result<Self> {
99 let mut runtime = Runtime::new_without_stdlib();
100 runtime.enable_persistent_context_without_data();
101
102 Ok(Self {
103 runtime,
104 analyzer: SemanticAnalyzer::new(),
105 default_data: DataFrame::default(),
106 #[cfg(feature = "jit")]
107 jit_cache: HashMap::new(),
108 current_source: None,
109 snapshot_store: None,
110 last_snapshot: None,
111 script_path: None,
112 })
113 }
114
115 pub fn with_data(data: DataFrame) -> Result<Self> {
117 let mut runtime = Runtime::new_without_stdlib();
118 runtime.enable_persistent_context(&data);
119 Ok(Self {
120 runtime,
121 analyzer: SemanticAnalyzer::new(),
122 default_data: data,
123 #[cfg(feature = "jit")]
124 jit_cache: HashMap::new(),
125 current_source: None,
126 snapshot_store: None,
127 last_snapshot: None,
128 script_path: None,
129 })
130 }
131
132 pub fn with_async_provider(provider: crate::data::SharedAsyncProvider) -> Result<Self> {
137 let runtime_handle = tokio::runtime::Handle::try_current()
138 .map_err(|_| ShapeError::RuntimeError {
139 message: "No tokio runtime available. Ensure with_async_provider is called within a tokio context.".to_string(),
140 location: None,
141 })?;
142 let mut runtime = Runtime::new_without_stdlib();
143
144 let ctx = crate::context::ExecutionContext::with_async_provider(provider, runtime_handle);
146 runtime.set_persistent_context(ctx);
147
148 Ok(Self {
149 runtime,
150 analyzer: SemanticAnalyzer::new(),
151 default_data: DataFrame::default(),
152 #[cfg(feature = "jit")]
153 jit_cache: HashMap::new(),
154 current_source: None,
155 snapshot_store: None,
156 last_snapshot: None,
157 script_path: None,
158 })
159 }
160
161 pub fn init_repl(&mut self) {
167 self.analyzer.init_repl_scope();
168
169 if let Some(ctx) = self.runtime.persistent_context_mut() {
171 ctx.set_output_adapter(Box::new(crate::output_adapter::ReplAdapter));
172 }
173 }
174
175 pub fn capture_bootstrap_state(&self) -> Result<EngineBootstrapState> {
179 let context =
180 self.runtime
181 .persistent_context()
182 .cloned()
183 .ok_or_else(|| ShapeError::RuntimeError {
184 message: "No persistent context available for bootstrap capture".to_string(),
185 location: None,
186 })?;
187 Ok(EngineBootstrapState {
188 semantic: self.analyzer.snapshot(),
189 context,
190 })
191 }
192
193 pub fn apply_bootstrap_state(&mut self, state: &EngineBootstrapState) {
195 self.analyzer.restore_from_snapshot(state.semantic.clone());
196 self.runtime.set_persistent_context(state.context.clone());
197 }
198
199 pub fn set_script_path(&mut self, path: impl Into<String>) {
201 self.script_path = Some(path.into());
202 }
203
204 pub fn script_path(&self) -> Option<&str> {
206 self.script_path.as_deref()
207 }
208
209 pub fn enable_snapshot_store(&mut self, store: SnapshotStore) {
211 self.snapshot_store = Some(store);
212 }
213
214 pub fn last_snapshot(&self) -> Option<&HashDigest> {
216 self.last_snapshot.as_ref()
217 }
218
219 pub fn snapshot_store(&self) -> Option<&SnapshotStore> {
221 self.snapshot_store.as_ref()
222 }
223
224 pub fn store_snapshot_blob<T: Serialize>(&self, value: &T) -> Result<HashDigest> {
226 let store = self
227 .snapshot_store
228 .as_ref()
229 .ok_or_else(|| ShapeError::RuntimeError {
230 message: "Snapshot store not configured".to_string(),
231 location: None,
232 })?;
233 Ok(store.put_struct(value)?)
234 }
235
236 pub fn snapshot_with_hashes(
238 &mut self,
239 vm_hash: Option<HashDigest>,
240 bytecode_hash: Option<HashDigest>,
241 ) -> Result<HashDigest> {
242 let store = self
243 .snapshot_store
244 .as_ref()
245 .ok_or_else(|| ShapeError::RuntimeError {
246 message: "Snapshot store not configured".to_string(),
247 location: None,
248 })?;
249
250 let semantic = self.analyzer.snapshot();
251 let semantic_hash = store.put_struct(&semantic)?;
252
253 let context = if let Some(ctx) = self.runtime.persistent_context() {
254 ctx.snapshot(store)?
255 } else {
256 return Err(ShapeError::RuntimeError {
257 message: "No persistent context for snapshot".to_string(),
258 location: None,
259 });
260 };
261 let context_hash = store.put_struct(&context)?;
262
263 let snapshot = ExecutionSnapshot {
264 version: crate::snapshot::SNAPSHOT_VERSION,
265 created_at_ms: chrono::Utc::now().timestamp_millis(),
266 semantic_hash,
267 context_hash,
268 vm_hash,
269 bytecode_hash,
270 script_path: self.script_path.clone(),
271 };
272
273 let snapshot_hash = store.put_snapshot(&snapshot)?;
274 self.last_snapshot = Some(snapshot_hash.clone());
275 Ok(snapshot_hash)
276 }
277
278 pub fn load_snapshot(
280 &self,
281 snapshot_id: &HashDigest,
282 ) -> Result<(
283 SemanticSnapshot,
284 ContextSnapshot,
285 Option<HashDigest>,
286 Option<HashDigest>,
287 )> {
288 let store = self
289 .snapshot_store
290 .as_ref()
291 .ok_or_else(|| ShapeError::RuntimeError {
292 message: "Snapshot store not configured".to_string(),
293 location: None,
294 })?;
295 let snapshot = store.get_snapshot(snapshot_id)?;
296 let semantic: SemanticSnapshot =
297 store
298 .get_struct(&snapshot.semantic_hash)
299 .map_err(|e| ShapeError::RuntimeError {
300 message: format!("failed to deserialize SemanticSnapshot: {e}"),
301 location: None,
302 })?;
303 let context: ContextSnapshot =
304 store
305 .get_struct(&snapshot.context_hash)
306 .map_err(|e| ShapeError::RuntimeError {
307 message: format!("failed to deserialize ContextSnapshot: {e}"),
308 location: None,
309 })?;
310 Ok((semantic, context, snapshot.vm_hash, snapshot.bytecode_hash))
311 }
312
313 pub fn apply_snapshot(
315 &mut self,
316 semantic: SemanticSnapshot,
317 context: ContextSnapshot,
318 ) -> Result<()> {
319 self.analyzer.restore_from_snapshot(semantic);
320 if let Some(ctx) = self.runtime.persistent_context_mut() {
321 let store = self
322 .snapshot_store
323 .as_ref()
324 .ok_or_else(|| ShapeError::RuntimeError {
325 message: "Snapshot store not configured".to_string(),
326 location: None,
327 })?;
328 ctx.restore_from_snapshot(context, store)?;
329 Ok(())
330 } else {
331 Err(ShapeError::RuntimeError {
332 message: "No persistent context for snapshot".to_string(),
333 location: None,
334 })
335 }
336 }
337
338 pub fn register_extension_modules(
341 &mut self,
342 modules: &[crate::extensions::ParsedModuleSchema],
343 ) {
344 self.runtime.register_extension_module_artifacts(modules);
345 self.analyzer.register_extension_modules(modules);
346 }
347
348 pub fn set_source(&mut self, source: &str) {
353 self.current_source = Some(source.to_string());
354 }
355
356 pub fn current_source(&self) -> Option<&str> {
358 self.current_source.as_deref()
359 }
360
361 pub fn analyze_incremental(
367 &mut self,
368 program: &shape_ast::Program,
369 source: &str,
370 ) -> Result<()> {
371 self.analyzer.set_source(source);
372 self.analyzer.analyze_incremental(program)
373 }
374
375 pub fn register_provider(&mut self, name: &str, provider: crate::data::SharedAsyncProvider) {
386 if let Some(ctx) = self.runtime.persistent_context_mut() {
387 ctx.register_provider(name, provider);
388 }
389 }
390
391 pub fn set_default_provider(&mut self, name: &str) -> Result<()> {
395 if let Some(ctx) = self.runtime.persistent_context_mut() {
396 ctx.set_default_provider(name)
397 } else {
398 Err(ShapeError::RuntimeError {
399 message: "No execution context available".to_string(),
400 location: None,
401 })
402 }
403 }
404
405 pub fn register_type_mapping(
432 &mut self,
433 type_name: &str,
434 mapping: crate::type_mapping::TypeMapping,
435 ) {
436 if let Some(ctx) = self.runtime.persistent_context_mut() {
437 ctx.register_type_mapping(type_name, mapping);
438 }
439 }
440
441 pub fn get_runtime(&self) -> &Runtime {
443 &self.runtime
444 }
445
446 pub fn get_runtime_mut(&mut self) -> &mut Runtime {
448 &mut self.runtime
449 }
450
451 pub fn get_variable_format_hint(&self, name: &str) -> Option<String> {
456 self.runtime
457 .persistent_context()
458 .and_then(|ctx| ctx.get_variable_format_hint(name))
459 }
460
461 pub fn format_value_string(
493 &mut self,
494 value: f64,
495 type_name: &str,
496 format_name: Option<&str>,
497 params: &std::collections::HashMap<String, serde_json::Value>,
498 ) -> Result<String> {
499 use std::sync::Arc;
500
501 let (resolved_type_name, merged_params) =
503 self.resolve_type_alias_for_formatting(type_name, params)?;
504
505 let param_values: std::collections::HashMap<String, ValueWord> = merged_params
507 .iter()
508 .map(|(k, v)| {
509 let runtime_val = match v {
510 serde_json::Value::Number(n) => ValueWord::from_f64(n.as_f64().unwrap_or(0.0)),
511 serde_json::Value::String(s) => ValueWord::from_string(Arc::new(s.clone())),
512 serde_json::Value::Bool(b) => ValueWord::from_bool(*b),
513 _ => ValueWord::none(),
514 };
515 (k.clone(), runtime_val)
516 })
517 .collect();
518
519 let runtime_value = ValueWord::from_f64(value);
521
522 self.runtime.format_value(
524 runtime_value,
525 resolved_type_name.as_str(),
526 format_name,
527 param_values,
528 )
529 }
530
531 fn resolve_type_alias_for_formatting(
536 &self,
537 type_name: &str,
538 params: &std::collections::HashMap<String, serde_json::Value>,
539 ) -> Result<(String, std::collections::HashMap<String, serde_json::Value>)> {
540 if let Some(alias_entry) = self.analyzer.lookup_type_alias(type_name) {
542 let base_type_name = Self::get_base_type_name(&alias_entry.type_annotation);
544
545 let mut merged = std::collections::HashMap::new();
548
549 if let Some(overrides) = &alias_entry.meta_param_overrides {
551 for (key, expr) in overrides {
552 if let Some(json_val) = Self::expr_to_json(expr) {
554 merged.insert(key.clone(), json_val);
555 }
556 }
557 }
558
559 for (key, val) in params {
561 merged.insert(key.clone(), val.clone());
562 }
563
564 Ok((base_type_name, merged))
565 } else {
566 Ok((type_name.to_string(), params.clone()))
568 }
569 }
570
571 fn get_base_type_name(ty: &shape_ast::ast::TypeAnnotation) -> String {
573 match ty {
574 shape_ast::ast::TypeAnnotation::Basic(name) => name.clone(),
575 shape_ast::ast::TypeAnnotation::Reference(name) => name.clone(),
576 shape_ast::ast::TypeAnnotation::Generic { name, .. } => name.clone(),
577 _ => "Unknown".to_string(),
578 }
579 }
580
581 fn expr_to_json(expr: &shape_ast::ast::Expr) -> Option<serde_json::Value> {
583 use shape_ast::ast::{Expr, Literal};
584 match expr {
585 Expr::Literal(Literal::Number(n), _) => Some(serde_json::json!(n)),
586 Expr::Literal(Literal::String(s), _) => Some(serde_json::json!(s)),
587 Expr::Literal(Literal::Bool(b), _) => Some(serde_json::json!(b)),
588 _ => None, }
590 }
591
592 pub fn load_extension(
618 &mut self,
619 path: &std::path::Path,
620 config: &serde_json::Value,
621 ) -> Result<crate::extensions::LoadedExtension> {
622 if let Some(ctx) = self.runtime.persistent_context_mut() {
623 ctx.load_extension(path, config)
624 } else {
625 Err(ShapeError::RuntimeError {
626 message: "No execution context available for extension loading".to_string(),
627 location: None,
628 })
629 }
630 }
631
632 pub fn unload_extension(&mut self, name: &str) -> bool {
642 if let Some(ctx) = self.runtime.persistent_context_mut() {
643 ctx.unload_extension(name)
644 } else {
645 false
646 }
647 }
648
649 pub fn list_extensions(&self) -> Vec<String> {
651 if let Some(ctx) = self.runtime.persistent_context() {
652 ctx.list_extensions()
653 } else {
654 Vec::new()
655 }
656 }
657
658 pub fn get_extension_query_schema(
668 &self,
669 name: &str,
670 ) -> Option<crate::extensions::ParsedQuerySchema> {
671 if let Some(ctx) = self.runtime.persistent_context() {
672 ctx.get_extension_query_schema(name)
673 } else {
674 None
675 }
676 }
677
678 pub fn get_extension_output_schema(
688 &self,
689 name: &str,
690 ) -> Option<crate::extensions::ParsedOutputSchema> {
691 if let Some(ctx) = self.runtime.persistent_context() {
692 ctx.get_extension_output_schema(name)
693 } else {
694 None
695 }
696 }
697
698 pub fn get_extension(
700 &self,
701 name: &str,
702 ) -> Option<std::sync::Arc<crate::extensions::ExtensionDataSource>> {
703 if let Some(ctx) = self.runtime.persistent_context() {
704 ctx.get_extension(name)
705 } else {
706 None
707 }
708 }
709
710 pub fn get_extension_module_schema(
712 &self,
713 module_name: &str,
714 ) -> Option<crate::extensions::ParsedModuleSchema> {
715 if let Some(ctx) = self.runtime.persistent_context() {
716 ctx.get_extension_module_schema(module_name)
717 } else {
718 None
719 }
720 }
721
722 pub fn module_exports_from_extensions(&self) -> Vec<crate::module_exports::ModuleExports> {
724 if let Some(ctx) = self.runtime.persistent_context() {
725 ctx.module_exports_from_extensions()
726 } else {
727 Vec::new()
728 }
729 }
730
731 pub fn invoke_extension_module_nb(
733 &self,
734 module_name: &str,
735 function: &str,
736 args: &[shape_value::ValueWord],
737 ) -> Result<shape_value::ValueWord> {
738 if let Some(ctx) = self.runtime.persistent_context() {
739 ctx.invoke_extension_module_nb(module_name, function, args)
740 } else {
741 Err(shape_ast::error::ShapeError::RuntimeError {
742 message: "No runtime context available".to_string(),
743 location: None,
744 })
745 }
746 }
747
748 pub fn invoke_extension_module_wire(
750 &self,
751 module_name: &str,
752 function: &str,
753 args: &[shape_wire::WireValue],
754 ) -> Result<shape_wire::WireValue> {
755 if let Some(ctx) = self.runtime.persistent_context() {
756 ctx.invoke_extension_module_wire(module_name, function, args)
757 } else {
758 Err(shape_ast::error::ShapeError::RuntimeError {
759 message: "No runtime context available".to_string(),
760 location: None,
761 })
762 }
763 }
764
765 pub fn enable_progress_tracking(
786 &mut self,
787 ) -> std::sync::Arc<crate::progress::ProgressRegistry> {
788 let registry = crate::progress::ProgressRegistry::new();
790 if let Some(ctx) = self.runtime.persistent_context_mut() {
791 ctx.set_progress_registry(registry.clone());
792 }
793 registry
794 }
795
796 pub fn progress_registry(&self) -> Option<std::sync::Arc<crate::progress::ProgressRegistry>> {
798 self.runtime
799 .persistent_context()
800 .and_then(|ctx| ctx.progress_registry())
801 .cloned()
802 }
803
804 pub fn has_pending_progress(&self) -> bool {
806 if let Some(registry) = self.progress_registry() {
807 !registry.is_empty()
808 } else {
809 false
810 }
811 }
812
813 pub fn poll_progress(&self) -> Option<crate::progress::ProgressEvent> {
817 self.progress_registry()
818 .and_then(|registry| registry.try_recv())
819 }
820}
821
822impl Default for ShapeEngine {
823 fn default() -> Self {
824 Self::new().expect("Failed to create default Shape engine")
825 }
826}
827
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::{ParsedModuleArtifact, ParsedModuleSchema};

    /// Registering an extension module schema should make its bundled source
    /// artifact loadable through the runtime's configured module loader.
    #[test]
    fn test_register_extension_modules_registers_module_loader_artifacts() {
        let mut engine = ShapeEngine::new().expect("engine should create");

        let duckdb_artifact = ParsedModuleArtifact {
            module_path: "duckdb".to_string(),
            source: Some("pub fn connect(uri) { uri }".to_string()),
            compiled: None,
        };
        let duckdb_schema = ParsedModuleSchema {
            module_name: "duckdb".to_string(),
            functions: Vec::new(),
            artifacts: vec![duckdb_artifact],
        };
        engine.register_extension_modules(&[duckdb_schema]);

        let mut loader = engine.runtime.configured_module_loader();
        let loaded = loader
            .load_module("duckdb")
            .expect("registered extension module artifact should load");

        let has_connect = loaded.exports.contains_key("connect");
        assert!(has_connect, "expected connect export");
    }
}