Skip to main content

shape_runtime/
extension_context.rs

1//! Context-aware extension discovery and module-artifact registration.
2//!
3//! This module is the single source of truth for resolving declared
4//! `[[extensions]]` across frontmatter / project config and exposing
5//! extension module artifacts to the unified module loader.
6
use crate::extensions::ParsedModuleSchema;
use crate::frontmatter::parse_frontmatter;
use crate::module_loader::{ModuleCode, ModuleLoader};
use crate::project::find_project_root;
use crate::provider_registry::ProviderRegistry;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, MutexGuard};
15
16#[derive(Debug, Clone)]
17pub struct ExtensionModuleSpec {
18    pub name: String,
19    pub path: PathBuf,
20    pub config: serde_json::Value,
21    /// Extension sections from the project config, available for section claims.
22    pub extension_sections: HashMap<String, toml::Value>,
23}
24
25/// Process-local cache for parsed extension module schemas.
26///
27/// Loading an extension `.so` is expensive (`dlopen` + schema parsing) so
28/// repeat callers for the same `(name, canonical path, config)` key share a
29/// cached [`ParsedModuleSchema`]. Caches are owned by their user (e.g.
30/// [`crate::Runtime::extension_module_schemas`] or the LSP's per-process
31/// cache) — there is no process-global instance.
32#[derive(Debug, Default)]
33pub struct ExtensionModuleSchemaCache {
34    entries: Mutex<HashMap<String, Option<ParsedModuleSchema>>>,
35}
36
37impl ExtensionModuleSchemaCache {
38    /// Create a fresh empty cache.
39    pub fn new() -> Self {
40        Self::default()
41    }
42
43    /// Build the cache key used for a given [`ExtensionModuleSpec`].
44    fn key_for(spec: &ExtensionModuleSpec) -> String {
45        let canonical = spec
46            .path
47            .canonicalize()
48            .unwrap_or_else(|_| spec.path.clone())
49            .to_string_lossy()
50            .to_string();
51        let config_key = serde_json::to_string(&spec.config).unwrap_or_default();
52        format!("{}|{}|{}", spec.name, canonical, config_key)
53    }
54
55    /// Fetch a cached schema result, if any.
56    fn get(&self, key: &str) -> Option<Option<ParsedModuleSchema>> {
57        self.entries.lock().ok()?.get(key).cloned()
58    }
59
60    /// Insert a schema result into the cache.
61    fn insert(&self, key: String, schema: Option<ParsedModuleSchema>) {
62        if let Ok(mut guard) = self.entries.lock() {
63            guard.insert(key, schema);
64        }
65    }
66}
67
68/// Resolve declared extension module specs for the current context.
69///
70/// Precedence: frontmatter > shape.toml.
71pub fn declared_extension_specs_for_context(
72    current_file: Option<&Path>,
73    workspace_root: Option<&Path>,
74    current_source: Option<&str>,
75) -> Vec<ExtensionModuleSpec> {
76    let mut by_name: HashMap<String, ExtensionModuleSpec> = HashMap::new();
77
78    if let Some(source) = current_source {
79        let (frontmatter, _) = parse_frontmatter(source);
80        if let Some(frontmatter) = frontmatter {
81            let base_dir = current_file
82                .and_then(Path::parent)
83                .map(Path::to_path_buf)
84                .or_else(|| std::env::current_dir().ok())
85                .unwrap_or_else(|| PathBuf::from("."));
86            for extension in frontmatter.extensions {
87                let config = extension.config_as_json();
88                let resolved_path = if extension.path.is_absolute() {
89                    extension.path.clone()
90                } else {
91                    base_dir.join(&extension.path)
92                };
93                by_name.insert(
94                    extension.name.clone(),
95                    ExtensionModuleSpec {
96                        name: extension.name,
97                        path: resolved_path,
98                        config,
99                        extension_sections: frontmatter.extension_sections.clone(),
100                    },
101                );
102            }
103        }
104    }
105
106    let project = current_file
107        .and_then(|file| file.parent())
108        .and_then(find_project_root)
109        .or_else(|| workspace_root.and_then(find_project_root));
110    if let Some(project) = project {
111        for extension in project.config.extensions {
112            by_name.entry(extension.name.clone()).or_insert_with(|| {
113                let config = extension.config_as_json();
114                let resolved_path = if extension.path.is_absolute() {
115                    extension.path.clone()
116                } else {
117                    project.root_path.join(&extension.path)
118                };
119                ExtensionModuleSpec {
120                    name: extension.name,
121                    path: resolved_path,
122                    config,
123                    extension_sections: project.config.extension_sections.clone(),
124                }
125            });
126        }
127    }
128
129    let mut specs: Vec<ExtensionModuleSpec> = by_name.into_values().collect();
130    specs.sort_by(|left, right| left.name.cmp(&right.name));
131    specs
132}
133
134/// Resolve one declared extension module spec by module namespace.
135pub fn declared_extension_spec_for_module(
136    module_name: &str,
137    current_file: Option<&Path>,
138    workspace_root: Option<&Path>,
139    current_source: Option<&str>,
140) -> Option<ExtensionModuleSpec> {
141    declared_extension_specs_for_context(current_file, workspace_root, current_source)
142        .into_iter()
143        .find(|spec| spec.name == module_name)
144}
145
146/// Load one declared extension's `shape.module` schema, consulting the
147/// provided cache before hitting the provider registry.
148pub fn extension_module_schema_for_spec(
149    spec: &ExtensionModuleSpec,
150    cache: &ExtensionModuleSchemaCache,
151) -> Option<ParsedModuleSchema> {
152    if !spec.path.exists() {
153        return None;
154    }
155
156    let key = ExtensionModuleSchemaCache::key_for(spec);
157
158    if let Some(cached) = cache.get(&key) {
159        return cached;
160    }
161
162    let schema = {
163        let registry = ProviderRegistry::new();
164        match registry.load_extension(&spec.path, &spec.config) {
165            Ok(_) => registry
166                .get_extension_module_schema(&spec.name)
167                .or_else(|| {
168                    registry
169                        .list_extensions()
170                        .first()
171                        .and_then(|name| registry.get_extension_module_schema(name))
172                }),
173            Err(_) => None,
174        }
175    };
176
177    cache.insert(key, schema.clone());
178
179    schema
180}
181
182/// Load one declared extension module schema by name for current context,
183/// consulting the provided cache.
184pub fn extension_module_schema_for_context(
185    module_name: &str,
186    current_file: Option<&Path>,
187    workspace_root: Option<&Path>,
188    current_source: Option<&str>,
189    cache: &ExtensionModuleSchemaCache,
190) -> Option<ParsedModuleSchema> {
191    let spec = declared_extension_spec_for_module(
192        module_name,
193        current_file,
194        workspace_root,
195        current_source,
196    )?;
197    extension_module_schema_for_spec(&spec, cache)
198}
199
200/// Register declared extension module artifacts into the given module loader,
201/// consulting the provided cache for already-parsed schemas.
202pub fn register_declared_extensions_in_loader(
203    loader: &mut ModuleLoader,
204    current_file: Option<&Path>,
205    workspace_root: Option<&Path>,
206    current_source: Option<&str>,
207    cache: &ExtensionModuleSchemaCache,
208) {
209    for spec in declared_extension_specs_for_context(current_file, workspace_root, current_source) {
210        let Some(schema) = extension_module_schema_for_spec(&spec, cache) else {
211            continue;
212        };
213        for artifact in schema.artifacts {
214            let code = match (artifact.source, artifact.compiled) {
215                (Some(source), Some(compiled)) => ModuleCode::Both {
216                    source: Arc::from(source.as_str()),
217                    compiled: Arc::from(compiled),
218                },
219                (Some(source), None) => ModuleCode::Source(Arc::from(source.as_str())),
220                (None, Some(compiled)) => ModuleCode::Compiled(Arc::from(compiled)),
221                (None, None) => continue,
222            };
223            loader.register_extension_module(artifact.module_path, code);
224        }
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Lay out a minimal project under `root`, returning the entry file path.
    ///
    /// The project is a `shape.toml` containing `manifest` plus a
    /// `src/main.shape` entry file containing `main_source`.
    fn write_project(root: &Path, manifest: &str, main_source: &str) -> PathBuf {
        std::fs::create_dir_all(root.join("src")).expect("create src");
        std::fs::write(root.join("shape.toml"), manifest).expect("write shape.toml");
        let main_file = root.join("src/main.shape");
        std::fs::write(&main_file, main_source).expect("write main");
        main_file
    }

    #[test]
    fn test_declared_extension_spec_for_module_uses_project_config() {
        let tmp = tempfile::tempdir().expect("temp dir");
        let root = tmp.path();
        let main_file = write_project(
            root,
            r#"
[[extensions]]
name = "proj_ext_unique_for_test"
path = "./extensions/libproj.so"
"#,
            "use proj_ext_unique_for_test",
        );

        let spec = declared_extension_spec_for_module(
            "proj_ext_unique_for_test",
            Some(main_file.as_path()),
            None,
            None,
        )
        .expect("project extension should be discovered");

        assert_eq!(spec.name, "proj_ext_unique_for_test");
        assert_eq!(spec.path, root.join("extensions/libproj.so"));
    }

    #[test]
    fn test_declared_extension_specs_frontmatter_overrides_project() {
        let tmp = tempfile::tempdir().expect("temp dir");
        let root = tmp.path();
        let main_file = write_project(
            root,
            r#"
[[extensions]]
name = "duckdb"
path = "./project/libproject.so"
"#,
            "use duckdb",
        );

        // Same extension name as the project config, but a different path:
        // the frontmatter declaration must win.
        let source = r#"---
[[extensions]]
name = "duckdb"
path = "./frontmatter/libfront.so"
---
use duckdb
"#;

        let spec = declared_extension_spec_for_module(
            "duckdb",
            Some(main_file.as_path()),
            None,
            Some(source),
        )
        .expect("frontmatter extension should be discovered");

        assert_eq!(spec.path, root.join("src/frontmatter/libfront.so"));
    }
}