// systemprompt_loader/extension_loader.rs

use anyhow::{Context, Result};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};

use systemprompt_models::{DiscoveredExtension, ExtensionManifest};
7
/// Directory name, relative to the project root, that is joined with a build profile
/// (`release` / `debug`) when looking up compiled extension binaries.
const CARGO_TARGET: &str = "target";
9
/// Stateless namespace for extension discovery and binary validation.
///
/// The type carries no data; all functionality is exposed as associated functions
/// that take the project root as an argument.
#[derive(Debug, Clone, Copy)]
pub struct ExtensionLoader;
12
13impl ExtensionLoader {
14    pub fn discover(project_root: &Path) -> Vec<DiscoveredExtension> {
15        let extensions_dir = project_root.join("extensions");
16
17        if !extensions_dir.exists() {
18            return vec![];
19        }
20
21        let mut discovered = vec![];
22
23        Self::scan_directory(&extensions_dir, &mut discovered);
24
25        if let Ok(entries) = fs::read_dir(&extensions_dir) {
26            for entry in entries.flatten() {
27                let path = entry.path();
28                if path.is_dir() {
29                    Self::scan_directory(&path, &mut discovered);
30                }
31            }
32        }
33
34        discovered
35    }
36
37    fn scan_directory(dir: &Path, discovered: &mut Vec<DiscoveredExtension>) {
38        let Ok(entries) = fs::read_dir(dir) else {
39            return;
40        };
41
42        for entry in entries.flatten() {
43            let ext_dir = entry.path();
44            if !ext_dir.is_dir() {
45                continue;
46            }
47
48            let manifest_path = ext_dir.join("manifest.yaml");
49            if manifest_path.exists() {
50                match Self::load_manifest(&manifest_path) {
51                    Ok(manifest) => {
52                        discovered.push(DiscoveredExtension::new(manifest, ext_dir, manifest_path));
53                    },
54                    Err(e) => {
55                        tracing::warn!(
56                            path = %manifest_path.display(),
57                            error = %e,
58                            "Failed to parse extension manifest, skipping"
59                        );
60                    },
61                }
62            }
63        }
64    }
65
66    fn load_manifest(path: &Path) -> Result<ExtensionManifest> {
67        let content = fs::read_to_string(path)
68            .with_context(|| format!("Failed to read manifest: {}", path.display()))?;
69
70        serde_yaml::from_str(&content)
71            .with_context(|| format!("Failed to parse manifest: {}", path.display()))
72    }
73
74    pub fn get_enabled_mcp_extensions(project_root: &Path) -> Vec<DiscoveredExtension> {
75        Self::discover(project_root)
76            .into_iter()
77            .filter(|e| e.is_mcp() && e.is_enabled())
78            .collect()
79    }
80
81    pub fn get_enabled_cli_extensions(project_root: &Path) -> Vec<DiscoveredExtension> {
82        Self::discover(project_root)
83            .into_iter()
84            .filter(|e| e.is_cli() && e.is_enabled())
85            .collect()
86    }
87
88    pub fn find_cli_extension(project_root: &Path, name: &str) -> Option<DiscoveredExtension> {
89        Self::get_enabled_cli_extensions(project_root)
90            .into_iter()
91            .find(|e| {
92                e.binary_name()
93                    .is_some_and(|b| b == name || e.manifest.extension.name == name)
94            })
95    }
96
97    pub fn get_cli_binary_path(
98        project_root: &Path,
99        binary_name: &str,
100    ) -> Option<std::path::PathBuf> {
101        let release_path = project_root
102            .join(CARGO_TARGET)
103            .join("release")
104            .join(binary_name);
105        if release_path.exists() {
106            return Some(release_path);
107        }
108
109        let debug_path = project_root
110            .join(CARGO_TARGET)
111            .join("debug")
112            .join(binary_name);
113        if debug_path.exists() {
114            return Some(debug_path);
115        }
116
117        None
118    }
119
120    pub fn validate_mcp_binaries(project_root: &Path) -> Vec<(String, std::path::PathBuf)> {
121        let extensions = Self::get_enabled_mcp_extensions(project_root);
122        let target_dir = project_root.join(CARGO_TARGET).join("release");
123
124        extensions
125            .into_iter()
126            .filter_map(|ext| {
127                ext.binary_name().and_then(|binary| {
128                    let binary_path = target_dir.join(binary);
129                    if binary_path.exists() {
130                        None
131                    } else {
132                        Some((binary.to_string(), ext.path.clone()))
133                    }
134                })
135            })
136            .collect()
137    }
138
139    pub fn get_mcp_binary_names(project_root: &Path) -> Vec<String> {
140        Self::get_enabled_mcp_extensions(project_root)
141            .iter()
142            .filter_map(|e| e.binary_name().map(String::from))
143            .collect()
144    }
145
146    pub fn get_production_mcp_binary_names(
147        project_root: &Path,
148        services_config: &systemprompt_models::ServicesConfig,
149    ) -> Vec<String> {
150        Self::get_enabled_mcp_extensions(project_root)
151            .iter()
152            .filter_map(|e| {
153                let binary = e.binary_name()?;
154                let is_dev_only = services_config
155                    .mcp_servers
156                    .values()
157                    .find(|d| d.binary == binary)
158                    .is_some_and(|d| d.dev_only);
159                (!is_dev_only).then(|| binary.to_string())
160            })
161            .collect()
162    }
163
164    pub fn build_binary_map(project_root: &Path) -> HashMap<String, DiscoveredExtension> {
165        Self::discover(project_root)
166            .into_iter()
167            .filter_map(|ext| {
168                let name = ext.binary_name()?.to_string();
169                Some((name, ext))
170            })
171            .collect()
172    }
173
174    pub fn validate(project_root: &Path) -> ExtensionValidationResult {
175        ExtensionValidationResult {
176            discovered: Self::discover(project_root),
177            missing_binaries: Self::validate_mcp_binaries(project_root),
178            missing_manifests: vec![],
179        }
180    }
181}
182
183#[derive(Debug)]
184pub struct ExtensionValidationResult {
185    pub discovered: Vec<DiscoveredExtension>,
186    pub missing_binaries: Vec<(String, std::path::PathBuf)>,
187    pub missing_manifests: Vec<std::path::PathBuf>,
188}
189
190impl ExtensionValidationResult {
191    pub fn is_valid(&self) -> bool {
192        self.missing_binaries.is_empty()
193    }
194
195    pub fn format_missing_binaries(&self) -> String {
196        self.missing_binaries
197            .iter()
198            .map(|(binary, path)| format!("  ✗ {} ({})", binary, path.display()))
199            .collect::<Vec<_>>()
200            .join("\n")
201    }
202}