systemprompt_loader/
extension_loader.rs1use anyhow::{Context, Result};
2use std::collections::HashMap;
3use std::fs;
4use std::path::Path;
5
6use systemprompt_models::{DiscoveredExtension, ExtensionManifest};
7
/// Name of Cargo's build output directory, resolved relative to the project root.
const CARGO_TARGET: &str = "target";
9
/// Stateless namespace for extension discovery and binary lookup.
///
/// The type carries no data; all functionality is exposed as associated
/// functions that take the project root as an argument.
#[derive(Debug, Clone, Copy)]
pub struct ExtensionLoader;
12
13impl ExtensionLoader {
14 pub fn discover(project_root: &Path) -> Vec<DiscoveredExtension> {
15 let extensions_dir = project_root.join("extensions");
16
17 if !extensions_dir.exists() {
18 return vec![];
19 }
20
21 let mut discovered = vec![];
22
23 Self::scan_directory(&extensions_dir, &mut discovered);
24
25 if let Ok(entries) = fs::read_dir(&extensions_dir) {
26 for entry in entries.flatten() {
27 let path = entry.path();
28 if path.is_dir() {
29 Self::scan_directory(&path, &mut discovered);
30 }
31 }
32 }
33
34 discovered
35 }
36
37 fn scan_directory(dir: &Path, discovered: &mut Vec<DiscoveredExtension>) {
38 let Ok(entries) = fs::read_dir(dir) else {
39 return;
40 };
41
42 for entry in entries.flatten() {
43 let ext_dir = entry.path();
44 if !ext_dir.is_dir() {
45 continue;
46 }
47
48 let manifest_path = ext_dir.join("manifest.yaml");
49 if manifest_path.exists() {
50 match Self::load_manifest(&manifest_path) {
51 Ok(manifest) => {
52 discovered.push(DiscoveredExtension::new(manifest, ext_dir, manifest_path));
53 },
54 Err(e) => {
55 tracing::warn!(
56 path = %manifest_path.display(),
57 error = %e,
58 "Failed to parse extension manifest, skipping"
59 );
60 },
61 }
62 }
63 }
64 }
65
66 fn load_manifest(path: &Path) -> Result<ExtensionManifest> {
67 let content = fs::read_to_string(path)
68 .with_context(|| format!("Failed to read manifest: {}", path.display()))?;
69
70 serde_yaml::from_str(&content)
71 .with_context(|| format!("Failed to parse manifest: {}", path.display()))
72 }
73
74 pub fn get_enabled_mcp_extensions(project_root: &Path) -> Vec<DiscoveredExtension> {
75 Self::discover(project_root)
76 .into_iter()
77 .filter(|e| e.is_mcp() && e.is_enabled())
78 .collect()
79 }
80
81 pub fn get_enabled_cli_extensions(project_root: &Path) -> Vec<DiscoveredExtension> {
82 Self::discover(project_root)
83 .into_iter()
84 .filter(|e| e.is_cli() && e.is_enabled())
85 .collect()
86 }
87
88 pub fn find_cli_extension(project_root: &Path, name: &str) -> Option<DiscoveredExtension> {
89 Self::get_enabled_cli_extensions(project_root)
90 .into_iter()
91 .find(|e| {
92 e.binary_name()
93 .is_some_and(|b| b == name || e.manifest.extension.name == name)
94 })
95 }
96
97 pub fn get_cli_binary_path(
98 project_root: &Path,
99 binary_name: &str,
100 ) -> Option<std::path::PathBuf> {
101 let release_path = project_root
102 .join(CARGO_TARGET)
103 .join("release")
104 .join(binary_name);
105 if release_path.exists() {
106 return Some(release_path);
107 }
108
109 let debug_path = project_root
110 .join(CARGO_TARGET)
111 .join("debug")
112 .join(binary_name);
113 if debug_path.exists() {
114 return Some(debug_path);
115 }
116
117 None
118 }
119
120 pub fn resolve_bin_directory(project_root: &Path) -> std::path::PathBuf {
121 let release_dir = project_root.join(CARGO_TARGET).join("release");
122 let debug_dir = project_root.join(CARGO_TARGET).join("debug");
123
124 let release_binary = release_dir.join("systemprompt");
125 let debug_binary = debug_dir.join("systemprompt");
126
127 match (release_binary.exists(), debug_binary.exists()) {
128 (true, true) => {
129 let release_mtime = fs::metadata(&release_binary)
130 .and_then(|m| m.modified())
131 .ok();
132 let debug_mtime = fs::metadata(&debug_binary).and_then(|m| m.modified()).ok();
133
134 match (release_mtime, debug_mtime) {
135 (Some(r), Some(d)) if d > r => debug_dir,
136 _ => release_dir,
137 }
138 },
139 (true | false, false) => release_dir,
140 (false, true) => debug_dir,
141 }
142 }
143
144 pub fn validate_mcp_binaries(project_root: &Path) -> Vec<(String, std::path::PathBuf)> {
145 let extensions = Self::get_enabled_mcp_extensions(project_root);
146 let target_dir = project_root.join(CARGO_TARGET).join("release");
147
148 extensions
149 .into_iter()
150 .filter_map(|ext| {
151 ext.binary_name().and_then(|binary| {
152 let binary_path = target_dir.join(binary);
153 if binary_path.exists() {
154 None
155 } else {
156 Some((binary.to_string(), ext.path.clone()))
157 }
158 })
159 })
160 .collect()
161 }
162
163 pub fn get_mcp_binary_names(project_root: &Path) -> Vec<String> {
164 Self::get_enabled_mcp_extensions(project_root)
165 .iter()
166 .filter_map(|e| e.binary_name().map(String::from))
167 .collect()
168 }
169
170 pub fn get_production_mcp_binary_names(
171 project_root: &Path,
172 services_config: &systemprompt_models::ServicesConfig,
173 ) -> Vec<String> {
174 Self::get_enabled_mcp_extensions(project_root)
175 .iter()
176 .filter_map(|e| {
177 let binary = e.binary_name()?;
178 let is_dev_only = services_config
179 .mcp_servers
180 .values()
181 .find(|d| d.binary == binary)
182 .is_some_and(|d| d.dev_only);
183 (!is_dev_only).then(|| binary.to_string())
184 })
185 .collect()
186 }
187
188 pub fn build_binary_map(project_root: &Path) -> HashMap<String, DiscoveredExtension> {
189 Self::discover(project_root)
190 .into_iter()
191 .filter_map(|ext| {
192 let name = ext.binary_name()?.to_string();
193 Some((name, ext))
194 })
195 .collect()
196 }
197
198 pub fn validate(project_root: &Path) -> ExtensionValidationResult {
199 ExtensionValidationResult {
200 discovered: Self::discover(project_root),
201 missing_binaries: Self::validate_mcp_binaries(project_root),
202 missing_manifests: vec![],
203 }
204 }
205}
206
207#[derive(Debug)]
208pub struct ExtensionValidationResult {
209 pub discovered: Vec<DiscoveredExtension>,
210 pub missing_binaries: Vec<(String, std::path::PathBuf)>,
211 pub missing_manifests: Vec<std::path::PathBuf>,
212}
213
214impl ExtensionValidationResult {
215 pub fn is_valid(&self) -> bool {
216 self.missing_binaries.is_empty()
217 }
218
219 pub fn format_missing_binaries(&self) -> String {
220 self.missing_binaries
221 .iter()
222 .map(|(binary, path)| format!(" ✗ {} ({})", binary, path.display()))
223 .collect::<Vec<_>>()
224 .join("\n")
225 }
226}