1use crate::extensions::ParsedModuleSchema;
8use crate::frontmatter::parse_frontmatter;
9use crate::module_loader::{ModuleCode, ModuleLoader};
10use crate::project::find_project_root;
11use crate::provider_registry::ProviderRegistry;
12use std::collections::HashMap;
13use std::path::{Path, PathBuf};
14use std::sync::{Arc, Mutex};
15
16#[derive(Debug, Clone)]
17pub struct ExtensionModuleSpec {
18 pub name: String,
19 pub path: PathBuf,
20 pub config: serde_json::Value,
21 pub extension_sections: HashMap<String, toml::Value>,
23}
24
25#[derive(Debug, Default)]
33pub struct ExtensionModuleSchemaCache {
34 entries: Mutex<HashMap<String, Option<ParsedModuleSchema>>>,
35}
36
37impl ExtensionModuleSchemaCache {
38 pub fn new() -> Self {
40 Self::default()
41 }
42
43 fn key_for(spec: &ExtensionModuleSpec) -> String {
45 let canonical = spec
46 .path
47 .canonicalize()
48 .unwrap_or_else(|_| spec.path.clone())
49 .to_string_lossy()
50 .to_string();
51 let config_key = serde_json::to_string(&spec.config).unwrap_or_default();
52 format!("{}|{}|{}", spec.name, canonical, config_key)
53 }
54
55 fn get(&self, key: &str) -> Option<Option<ParsedModuleSchema>> {
57 self.entries.lock().ok()?.get(key).cloned()
58 }
59
60 fn insert(&self, key: String, schema: Option<ParsedModuleSchema>) {
62 if let Ok(mut guard) = self.entries.lock() {
63 guard.insert(key, schema);
64 }
65 }
66}
67
68pub fn declared_extension_specs_for_context(
72 current_file: Option<&Path>,
73 workspace_root: Option<&Path>,
74 current_source: Option<&str>,
75) -> Vec<ExtensionModuleSpec> {
76 let mut by_name: HashMap<String, ExtensionModuleSpec> = HashMap::new();
77
78 if let Some(source) = current_source {
79 let (frontmatter, _) = parse_frontmatter(source);
80 if let Some(frontmatter) = frontmatter {
81 let base_dir = current_file
82 .and_then(Path::parent)
83 .map(Path::to_path_buf)
84 .or_else(|| std::env::current_dir().ok())
85 .unwrap_or_else(|| PathBuf::from("."));
86 for extension in frontmatter.extensions {
87 let config = extension.config_as_json();
88 let resolved_path = if extension.path.is_absolute() {
89 extension.path.clone()
90 } else {
91 base_dir.join(&extension.path)
92 };
93 by_name.insert(
94 extension.name.clone(),
95 ExtensionModuleSpec {
96 name: extension.name,
97 path: resolved_path,
98 config,
99 extension_sections: frontmatter.extension_sections.clone(),
100 },
101 );
102 }
103 }
104 }
105
106 let project = current_file
107 .and_then(|file| file.parent())
108 .and_then(find_project_root)
109 .or_else(|| workspace_root.and_then(find_project_root));
110 if let Some(project) = project {
111 for extension in project.config.extensions {
112 by_name.entry(extension.name.clone()).or_insert_with(|| {
113 let config = extension.config_as_json();
114 let resolved_path = if extension.path.is_absolute() {
115 extension.path.clone()
116 } else {
117 project.root_path.join(&extension.path)
118 };
119 ExtensionModuleSpec {
120 name: extension.name,
121 path: resolved_path,
122 config,
123 extension_sections: project.config.extension_sections.clone(),
124 }
125 });
126 }
127 }
128
129 let mut specs: Vec<ExtensionModuleSpec> = by_name.into_values().collect();
130 specs.sort_by(|left, right| left.name.cmp(&right.name));
131 specs
132}
133
134pub fn declared_extension_spec_for_module(
136 module_name: &str,
137 current_file: Option<&Path>,
138 workspace_root: Option<&Path>,
139 current_source: Option<&str>,
140) -> Option<ExtensionModuleSpec> {
141 declared_extension_specs_for_context(current_file, workspace_root, current_source)
142 .into_iter()
143 .find(|spec| spec.name == module_name)
144}
145
146pub fn extension_module_schema_for_spec(
149 spec: &ExtensionModuleSpec,
150 cache: &ExtensionModuleSchemaCache,
151) -> Option<ParsedModuleSchema> {
152 if !spec.path.exists() {
153 return None;
154 }
155
156 let key = ExtensionModuleSchemaCache::key_for(spec);
157
158 if let Some(cached) = cache.get(&key) {
159 return cached;
160 }
161
162 let schema = {
163 let registry = ProviderRegistry::new();
164 match registry.load_extension(&spec.path, &spec.config) {
165 Ok(_) => registry
166 .get_extension_module_schema(&spec.name)
167 .or_else(|| {
168 registry
169 .list_extensions()
170 .first()
171 .and_then(|name| registry.get_extension_module_schema(name))
172 }),
173 Err(_) => None,
174 }
175 };
176
177 cache.insert(key, schema.clone());
178
179 schema
180}
181
182pub fn extension_module_schema_for_context(
185 module_name: &str,
186 current_file: Option<&Path>,
187 workspace_root: Option<&Path>,
188 current_source: Option<&str>,
189 cache: &ExtensionModuleSchemaCache,
190) -> Option<ParsedModuleSchema> {
191 let spec = declared_extension_spec_for_module(
192 module_name,
193 current_file,
194 workspace_root,
195 current_source,
196 )?;
197 extension_module_schema_for_spec(&spec, cache)
198}
199
200pub fn register_declared_extensions_in_loader(
203 loader: &mut ModuleLoader,
204 current_file: Option<&Path>,
205 workspace_root: Option<&Path>,
206 current_source: Option<&str>,
207 cache: &ExtensionModuleSchemaCache,
208) {
209 for spec in declared_extension_specs_for_context(current_file, workspace_root, current_source) {
210 let Some(schema) = extension_module_schema_for_spec(&spec, cache) else {
211 continue;
212 };
213 for artifact in schema.artifacts {
214 let code = match (artifact.source, artifact.compiled) {
215 (Some(source), Some(compiled)) => ModuleCode::Both {
216 source: Arc::from(source.as_str()),
217 compiled: Arc::from(compiled),
218 },
219 (Some(source), None) => ModuleCode::Source(Arc::from(source.as_str())),
220 (None, Some(compiled)) => ModuleCode::Compiled(Arc::from(compiled)),
221 (None, None) => continue,
222 };
223 loader.register_extension_module(artifact.module_path, code);
224 }
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Lays out a minimal project under `root`: a `src` directory, a
    /// `shape.toml` holding `manifest`, and `src/main.shape` holding `entry`.
    /// Returns the path of the entry file.
    fn scaffold_project(root: &Path, manifest: &str, entry: &str) -> PathBuf {
        std::fs::create_dir_all(root.join("src")).expect("create src");
        std::fs::write(root.join("shape.toml"), manifest).expect("write shape.toml");
        let entry_path = root.join("src/main.shape");
        std::fs::write(&entry_path, entry).expect("write main");
        entry_path
    }

    /// An extension declared only in `shape.toml` is found, and its relative
    /// path is resolved against the project root.
    #[test]
    fn test_declared_extension_spec_for_module_uses_project_config() {
        let tmp = tempfile::tempdir().expect("temp dir");
        let root = tmp.path();
        let manifest = r#"
[[extensions]]
name = "proj_ext_unique_for_test"
path = "./extensions/libproj.so"
"#;
        let main_file = scaffold_project(root, manifest, "use proj_ext_unique_for_test");

        let found = declared_extension_spec_for_module(
            "proj_ext_unique_for_test",
            Some(&main_file),
            None,
            None,
        );
        let spec = found.expect("project extension should be discovered");

        assert_eq!(spec.name, "proj_ext_unique_for_test");
        assert_eq!(spec.path, root.join("extensions/libproj.so"));
    }

    /// When frontmatter and project config declare the same name, the
    /// frontmatter entry wins and its path is resolved against the directory
    /// of the current file.
    #[test]
    fn test_declared_extension_specs_frontmatter_overrides_project() {
        let tmp = tempfile::tempdir().expect("temp dir");
        let root = tmp.path();
        let manifest = r#"
[[extensions]]
name = "duckdb"
path = "./project/libproject.so"
"#;
        let main_file = scaffold_project(root, manifest, "use duckdb");

        let source = r#"---
[[extensions]]
name = "duckdb"
path = "./frontmatter/libfront.so"
---
use duckdb
"#;

        let found =
            declared_extension_spec_for_module("duckdb", Some(&main_file), None, Some(source));
        let spec = found.expect("frontmatter extension should be discovered");

        assert_eq!(spec.path, root.join("src/frontmatter/libfront.so"));
    }
}