systemprompt_loader/extension_loader/
mod.rs1mod manifest;
10mod result;
11
12use std::collections::HashMap;
13use std::fs;
14use std::path::Path;
15
16use systemprompt_models::DiscoveredExtension;
17
18use manifest::{load_manifest, mtime_of};
19
20pub use result::ExtensionValidationResult;
21
/// Name of the directory, directly under the project root, whose `release`
/// and `debug` sub-directories are probed for built binaries.
const CARGO_TARGET: &str = "target";
23
/// Stateless entry point for extension discovery and binary lookup.
///
/// The type carries no data; every operation is an associated function that
/// takes the project root as an explicit argument.
#[derive(Debug, Clone, Copy)]
pub struct ExtensionLoader;
26
27impl ExtensionLoader {
28 #[must_use]
29 pub fn discover(project_root: &Path) -> Vec<DiscoveredExtension> {
30 let extensions_dir = project_root.join("extensions");
31
32 if !extensions_dir.exists() {
33 return vec![];
34 }
35
36 let mut discovered = vec![];
37
38 Self::scan_directory(&extensions_dir, &mut discovered);
39
40 if let Ok(entries) = fs::read_dir(&extensions_dir) {
41 for entry in entries.flatten() {
42 let path = entry.path();
43 if path.is_dir() {
44 Self::scan_directory(&path, &mut discovered);
45 }
46 }
47 }
48
49 discovered
50 }
51
52 fn scan_directory(dir: &Path, discovered: &mut Vec<DiscoveredExtension>) {
53 let Ok(entries) = fs::read_dir(dir) else {
54 return;
55 };
56
57 for entry in entries.flatten() {
58 let ext_dir = entry.path();
59 if !ext_dir.is_dir() {
60 continue;
61 }
62
63 let manifest_path = ext_dir.join("manifest.yaml");
64 if manifest_path.exists() {
65 match load_manifest(&manifest_path) {
66 Ok(manifest) => {
67 discovered.push(DiscoveredExtension::new(manifest, ext_dir, manifest_path));
68 },
69 Err(e) => {
70 tracing::warn!(
71 path = %manifest_path.display(),
72 error = %e,
73 "Failed to parse extension manifest, skipping"
74 );
75 },
76 }
77 }
78 }
79 }
80
81 #[must_use]
82 pub fn get_enabled_mcp_extensions(project_root: &Path) -> Vec<DiscoveredExtension> {
83 Self::discover(project_root)
84 .into_iter()
85 .filter(|e| e.is_mcp() && e.is_enabled())
86 .collect()
87 }
88
89 #[must_use]
90 pub fn get_enabled_cli_extensions(project_root: &Path) -> Vec<DiscoveredExtension> {
91 Self::discover(project_root)
92 .into_iter()
93 .filter(|e| e.is_cli() && e.is_enabled())
94 .collect()
95 }
96
97 #[must_use]
98 pub fn find_cli_extension(project_root: &Path, name: &str) -> Option<DiscoveredExtension> {
99 Self::get_enabled_cli_extensions(project_root)
100 .into_iter()
101 .find(|e| {
102 e.binary_name()
103 .is_some_and(|b| b == name || e.manifest.extension.name == name)
104 })
105 }
106
107 #[must_use]
108 pub fn get_cli_binary_path(
109 project_root: &Path,
110 binary_name: &str,
111 ) -> Option<std::path::PathBuf> {
112 let release_path = project_root
113 .join(CARGO_TARGET)
114 .join("release")
115 .join(binary_name);
116 if release_path.exists() {
117 return Some(release_path);
118 }
119
120 let debug_path = project_root
121 .join(CARGO_TARGET)
122 .join("debug")
123 .join(binary_name);
124 if debug_path.exists() {
125 return Some(debug_path);
126 }
127
128 None
129 }
130
131 #[must_use]
132 pub fn resolve_bin_directory(
133 project_root: &Path,
134 override_path: Option<&Path>,
135 ) -> std::path::PathBuf {
136 if let Some(path) = override_path {
137 return path.to_path_buf();
138 }
139
140 let release_dir = project_root.join(CARGO_TARGET).join("release");
141 let debug_dir = project_root.join(CARGO_TARGET).join("debug");
142
143 let release_binary = release_dir.join("systemprompt");
144 let debug_binary = debug_dir.join("systemprompt");
145
146 match (release_binary.exists(), debug_binary.exists()) {
147 (true, true) => {
148 let release_mtime = mtime_of(&release_binary);
149 let debug_mtime = mtime_of(&debug_binary);
150
151 match (release_mtime, debug_mtime) {
152 (Some(r), Some(d)) if d > r => debug_dir,
153 _ => release_dir,
154 }
155 },
156 (true | false, false) => release_dir,
157 (false, true) => debug_dir,
158 }
159 }
160
161 #[must_use]
162 pub fn validate_mcp_binaries(project_root: &Path) -> Vec<(String, std::path::PathBuf)> {
163 let extensions = Self::get_enabled_mcp_extensions(project_root);
164 let target_dir = Self::resolve_bin_directory(project_root, None);
165
166 extensions
167 .into_iter()
168 .filter_map(|ext| {
169 ext.binary_name().and_then(|binary| {
170 let binary_path = target_dir.join(binary);
171 if binary_path.exists() {
172 None
173 } else {
174 Some((binary.to_string(), ext.path.clone()))
175 }
176 })
177 })
178 .collect()
179 }
180
181 #[must_use]
182 pub fn get_mcp_binary_names(project_root: &Path) -> Vec<String> {
183 Self::get_enabled_mcp_extensions(project_root)
184 .iter()
185 .filter_map(|e| e.binary_name().map(String::from))
186 .collect()
187 }
188
189 #[must_use]
190 pub fn get_production_mcp_binary_names(
191 project_root: &Path,
192 services_config: &systemprompt_models::ServicesConfig,
193 ) -> Vec<String> {
194 Self::get_enabled_mcp_extensions(project_root)
195 .iter()
196 .filter_map(|e| {
197 let binary = e.binary_name()?;
198 let is_dev_only = services_config
199 .mcp_servers
200 .values()
201 .find(|d| d.binary == binary)
202 .is_some_and(|d| d.dev_only);
203 (!is_dev_only).then(|| binary.to_string())
204 })
205 .collect()
206 }
207
208 #[must_use]
209 pub fn build_binary_map(project_root: &Path) -> HashMap<String, DiscoveredExtension> {
210 Self::discover(project_root)
211 .into_iter()
212 .filter_map(|ext| {
213 let name = ext.binary_name()?.to_string();
214 Some((name, ext))
215 })
216 .collect()
217 }
218
219 #[must_use]
220 pub fn validate(project_root: &Path) -> ExtensionValidationResult {
221 ExtensionValidationResult {
222 discovered: Self::discover(project_root),
223 missing_binaries: Self::validate_mcp_binaries(project_root),
224 missing_manifests: vec![],
225 }
226 }
227}