Skip to main content

systemprompt_loader/
extension_loader.rs

1use anyhow::{Context, Result};
2use std::collections::HashMap;
3use std::fs;
4use std::path::Path;
5
6use systemprompt_models::{DiscoveredExtension, ExtensionManifest};
7
/// Name of Cargo's build-output directory, relative to the project root.
const CARGO_TARGET: &str = "target";
9
/// Stateless namespace for discovering extensions on disk and locating
/// their built binaries. All functionality lives in associated functions.
#[derive(Clone, Copy, Debug)]
pub struct ExtensionLoader;
12
13impl ExtensionLoader {
14    pub fn discover(project_root: &Path) -> Vec<DiscoveredExtension> {
15        let extensions_dir = project_root.join("extensions");
16
17        if !extensions_dir.exists() {
18            return vec![];
19        }
20
21        let mut discovered = vec![];
22
23        Self::scan_directory(&extensions_dir, &mut discovered);
24
25        if let Ok(entries) = fs::read_dir(&extensions_dir) {
26            for entry in entries.flatten() {
27                let path = entry.path();
28                if path.is_dir() {
29                    Self::scan_directory(&path, &mut discovered);
30                }
31            }
32        }
33
34        discovered
35    }
36
37    fn scan_directory(dir: &Path, discovered: &mut Vec<DiscoveredExtension>) {
38        let Ok(entries) = fs::read_dir(dir) else {
39            return;
40        };
41
42        for entry in entries.flatten() {
43            let ext_dir = entry.path();
44            if !ext_dir.is_dir() {
45                continue;
46            }
47
48            let manifest_path = ext_dir.join("manifest.yaml");
49            if manifest_path.exists() {
50                match Self::load_manifest(&manifest_path) {
51                    Ok(manifest) => {
52                        discovered.push(DiscoveredExtension::new(manifest, ext_dir, manifest_path));
53                    },
54                    Err(e) => {
55                        tracing::warn!(
56                            path = %manifest_path.display(),
57                            error = %e,
58                            "Failed to parse extension manifest, skipping"
59                        );
60                    },
61                }
62            }
63        }
64    }
65
66    fn load_manifest(path: &Path) -> Result<ExtensionManifest> {
67        let content = fs::read_to_string(path)
68            .with_context(|| format!("Failed to read manifest: {}", path.display()))?;
69
70        serde_yaml::from_str(&content)
71            .with_context(|| format!("Failed to parse manifest: {}", path.display()))
72    }
73
74    pub fn get_enabled_mcp_extensions(project_root: &Path) -> Vec<DiscoveredExtension> {
75        Self::discover(project_root)
76            .into_iter()
77            .filter(|e| e.is_mcp() && e.is_enabled())
78            .collect()
79    }
80
81    pub fn get_enabled_cli_extensions(project_root: &Path) -> Vec<DiscoveredExtension> {
82        Self::discover(project_root)
83            .into_iter()
84            .filter(|e| e.is_cli() && e.is_enabled())
85            .collect()
86    }
87
88    pub fn find_cli_extension(project_root: &Path, name: &str) -> Option<DiscoveredExtension> {
89        Self::get_enabled_cli_extensions(project_root)
90            .into_iter()
91            .find(|e| {
92                e.binary_name()
93                    .is_some_and(|b| b == name || e.manifest.extension.name == name)
94            })
95    }
96
97    pub fn get_cli_binary_path(
98        project_root: &Path,
99        binary_name: &str,
100    ) -> Option<std::path::PathBuf> {
101        let release_path = project_root
102            .join(CARGO_TARGET)
103            .join("release")
104            .join(binary_name);
105        if release_path.exists() {
106            return Some(release_path);
107        }
108
109        let debug_path = project_root
110            .join(CARGO_TARGET)
111            .join("debug")
112            .join(binary_name);
113        if debug_path.exists() {
114            return Some(debug_path);
115        }
116
117        None
118    }
119
120    pub fn resolve_bin_directory(project_root: &Path) -> std::path::PathBuf {
121        let release_dir = project_root.join(CARGO_TARGET).join("release");
122        let debug_dir = project_root.join(CARGO_TARGET).join("debug");
123
124        let release_binary = release_dir.join("systemprompt");
125        let debug_binary = debug_dir.join("systemprompt");
126
127        match (release_binary.exists(), debug_binary.exists()) {
128            (true, true) => {
129                let release_mtime = fs::metadata(&release_binary)
130                    .and_then(|m| m.modified())
131                    .ok();
132                let debug_mtime = fs::metadata(&debug_binary).and_then(|m| m.modified()).ok();
133
134                match (release_mtime, debug_mtime) {
135                    (Some(r), Some(d)) if d > r => debug_dir,
136                    _ => release_dir,
137                }
138            },
139            (true | false, false) => release_dir,
140            (false, true) => debug_dir,
141        }
142    }
143
144    pub fn validate_mcp_binaries(project_root: &Path) -> Vec<(String, std::path::PathBuf)> {
145        let extensions = Self::get_enabled_mcp_extensions(project_root);
146        let target_dir = project_root.join(CARGO_TARGET).join("release");
147
148        extensions
149            .into_iter()
150            .filter_map(|ext| {
151                ext.binary_name().and_then(|binary| {
152                    let binary_path = target_dir.join(binary);
153                    if binary_path.exists() {
154                        None
155                    } else {
156                        Some((binary.to_string(), ext.path.clone()))
157                    }
158                })
159            })
160            .collect()
161    }
162
163    pub fn get_mcp_binary_names(project_root: &Path) -> Vec<String> {
164        Self::get_enabled_mcp_extensions(project_root)
165            .iter()
166            .filter_map(|e| e.binary_name().map(String::from))
167            .collect()
168    }
169
170    pub fn get_production_mcp_binary_names(
171        project_root: &Path,
172        services_config: &systemprompt_models::ServicesConfig,
173    ) -> Vec<String> {
174        Self::get_enabled_mcp_extensions(project_root)
175            .iter()
176            .filter_map(|e| {
177                let binary = e.binary_name()?;
178                let is_dev_only = services_config
179                    .mcp_servers
180                    .values()
181                    .find(|d| d.binary == binary)
182                    .is_some_and(|d| d.dev_only);
183                (!is_dev_only).then(|| binary.to_string())
184            })
185            .collect()
186    }
187
188    pub fn build_binary_map(project_root: &Path) -> HashMap<String, DiscoveredExtension> {
189        Self::discover(project_root)
190            .into_iter()
191            .filter_map(|ext| {
192                let name = ext.binary_name()?.to_string();
193                Some((name, ext))
194            })
195            .collect()
196    }
197
198    pub fn validate(project_root: &Path) -> ExtensionValidationResult {
199        ExtensionValidationResult {
200            discovered: Self::discover(project_root),
201            missing_binaries: Self::validate_mcp_binaries(project_root),
202            missing_manifests: vec![],
203        }
204    }
205}
206
207#[derive(Debug)]
208pub struct ExtensionValidationResult {
209    pub discovered: Vec<DiscoveredExtension>,
210    pub missing_binaries: Vec<(String, std::path::PathBuf)>,
211    pub missing_manifests: Vec<std::path::PathBuf>,
212}
213
214impl ExtensionValidationResult {
215    pub fn is_valid(&self) -> bool {
216        self.missing_binaries.is_empty()
217    }
218
219    pub fn format_missing_binaries(&self) -> String {
220        self.missing_binaries
221            .iter()
222            .map(|(binary, path)| format!("  ✗ {} ({})", binary, path.display()))
223            .collect::<Vec<_>>()
224            .join("\n")
225    }
226}