Skip to main content

systemprompt_loader/extension_loader/
mod.rs

//! Discovers `extensions/<name>/manifest.yaml` files (and, one level deeper,
//! `extensions/<group>/<name>/manifest.yaml`) and resolves the binaries
//! those manifests reference.
//!
//! All operations are infallible at the public API level. Failures are
//! either represented as "not found" (`Option`, empty `Vec`), logged via
//! `tracing` and skipped (manifests that fail to parse), or surfaced through
//! the [`ExtensionValidationResult`] returned by [`ExtensionLoader::validate`].
8
9mod manifest;
10mod result;
11
12use std::collections::HashMap;
13use std::fs;
14use std::path::Path;
15
16use systemprompt_models::DiscoveredExtension;
17
18use manifest::{load_manifest, mtime_of};
19
20pub use result::ExtensionValidationResult;
21
/// Name of Cargo's build output directory, relative to the project root.
const CARGO_TARGET: &str = "target";

/// Stateless namespace for extension discovery and binary resolution.
///
/// All functionality is exposed as associated functions taking the project
/// root, so the type carries no data and is trivially constructible.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ExtensionLoader;
26
27impl ExtensionLoader {
28    #[must_use]
29    pub fn discover(project_root: &Path) -> Vec<DiscoveredExtension> {
30        let extensions_dir = project_root.join("extensions");
31
32        if !extensions_dir.exists() {
33            return vec![];
34        }
35
36        let mut discovered = vec![];
37
38        Self::scan_directory(&extensions_dir, &mut discovered);
39
40        if let Ok(entries) = fs::read_dir(&extensions_dir) {
41            for entry in entries.flatten() {
42                let path = entry.path();
43                if path.is_dir() {
44                    Self::scan_directory(&path, &mut discovered);
45                }
46            }
47        }
48
49        discovered
50    }
51
52    fn scan_directory(dir: &Path, discovered: &mut Vec<DiscoveredExtension>) {
53        let Ok(entries) = fs::read_dir(dir) else {
54            return;
55        };
56
57        for entry in entries.flatten() {
58            let ext_dir = entry.path();
59            if !ext_dir.is_dir() {
60                continue;
61            }
62
63            let manifest_path = ext_dir.join("manifest.yaml");
64            if manifest_path.exists() {
65                match load_manifest(&manifest_path) {
66                    Ok(manifest) => {
67                        discovered.push(DiscoveredExtension::new(manifest, ext_dir, manifest_path));
68                    },
69                    Err(e) => {
70                        tracing::warn!(
71                            path = %manifest_path.display(),
72                            error = %e,
73                            "Failed to parse extension manifest, skipping"
74                        );
75                    },
76                }
77            }
78        }
79    }
80
81    #[must_use]
82    pub fn get_enabled_mcp_extensions(project_root: &Path) -> Vec<DiscoveredExtension> {
83        Self::discover(project_root)
84            .into_iter()
85            .filter(|e| e.is_mcp() && e.is_enabled())
86            .collect()
87    }
88
89    #[must_use]
90    pub fn get_enabled_cli_extensions(project_root: &Path) -> Vec<DiscoveredExtension> {
91        Self::discover(project_root)
92            .into_iter()
93            .filter(|e| e.is_cli() && e.is_enabled())
94            .collect()
95    }
96
97    #[must_use]
98    pub fn find_cli_extension(project_root: &Path, name: &str) -> Option<DiscoveredExtension> {
99        Self::get_enabled_cli_extensions(project_root)
100            .into_iter()
101            .find(|e| {
102                e.binary_name()
103                    .is_some_and(|b| b == name || e.manifest.extension.name == name)
104            })
105    }
106
107    #[must_use]
108    pub fn get_cli_binary_path(
109        project_root: &Path,
110        binary_name: &str,
111    ) -> Option<std::path::PathBuf> {
112        let release_path = project_root
113            .join(CARGO_TARGET)
114            .join("release")
115            .join(binary_name);
116        if release_path.exists() {
117            return Some(release_path);
118        }
119
120        let debug_path = project_root
121            .join(CARGO_TARGET)
122            .join("debug")
123            .join(binary_name);
124        if debug_path.exists() {
125            return Some(debug_path);
126        }
127
128        None
129    }
130
131    #[must_use]
132    pub fn resolve_bin_directory(
133        project_root: &Path,
134        override_path: Option<&Path>,
135    ) -> std::path::PathBuf {
136        if let Some(path) = override_path {
137            return path.to_path_buf();
138        }
139
140        let release_dir = project_root.join(CARGO_TARGET).join("release");
141        let debug_dir = project_root.join(CARGO_TARGET).join("debug");
142
143        let release_binary = release_dir.join("systemprompt");
144        let debug_binary = debug_dir.join("systemprompt");
145
146        match (release_binary.exists(), debug_binary.exists()) {
147            (true, true) => {
148                let release_mtime = mtime_of(&release_binary);
149                let debug_mtime = mtime_of(&debug_binary);
150
151                match (release_mtime, debug_mtime) {
152                    (Some(r), Some(d)) if d > r => debug_dir,
153                    _ => release_dir,
154                }
155            },
156            (true | false, false) => release_dir,
157            (false, true) => debug_dir,
158        }
159    }
160
161    #[must_use]
162    pub fn validate_mcp_binaries(project_root: &Path) -> Vec<(String, std::path::PathBuf)> {
163        let extensions = Self::get_enabled_mcp_extensions(project_root);
164        let target_dir = Self::resolve_bin_directory(project_root, None);
165
166        extensions
167            .into_iter()
168            .filter_map(|ext| {
169                ext.binary_name().and_then(|binary| {
170                    let binary_path = target_dir.join(binary);
171                    if binary_path.exists() {
172                        None
173                    } else {
174                        Some((binary.to_string(), ext.path.clone()))
175                    }
176                })
177            })
178            .collect()
179    }
180
181    #[must_use]
182    pub fn get_mcp_binary_names(project_root: &Path) -> Vec<String> {
183        Self::get_enabled_mcp_extensions(project_root)
184            .iter()
185            .filter_map(|e| e.binary_name().map(String::from))
186            .collect()
187    }
188
189    #[must_use]
190    pub fn get_production_mcp_binary_names(
191        project_root: &Path,
192        services_config: &systemprompt_models::ServicesConfig,
193    ) -> Vec<String> {
194        Self::get_enabled_mcp_extensions(project_root)
195            .iter()
196            .filter_map(|e| {
197                let binary = e.binary_name()?;
198                let is_dev_only = services_config
199                    .mcp_servers
200                    .values()
201                    .find(|d| d.binary == binary)
202                    .is_some_and(|d| d.dev_only);
203                (!is_dev_only).then(|| binary.to_string())
204            })
205            .collect()
206    }
207
208    #[must_use]
209    pub fn build_binary_map(project_root: &Path) -> HashMap<String, DiscoveredExtension> {
210        Self::discover(project_root)
211            .into_iter()
212            .filter_map(|ext| {
213                let name = ext.binary_name()?.to_string();
214                Some((name, ext))
215            })
216            .collect()
217    }
218
219    #[must_use]
220    pub fn validate(project_root: &Path) -> ExtensionValidationResult {
221        ExtensionValidationResult {
222            discovered: Self::discover(project_root),
223            missing_binaries: Self::validate_mcp_binaries(project_root),
224            missing_manifests: vec![],
225        }
226    }
227}