systemprompt_loader/
extension_loader.rs1use anyhow::{Context, Result};
2use std::collections::HashMap;
3use std::fs;
4use std::path::Path;
5
6use systemprompt_models::{DiscoveredExtension, ExtensionManifest};
7
/// Build output directory name, joined onto the project root (and then onto
/// `release` / `debug`) when locating compiled extension binaries.
const CARGO_TARGET: &str = "target";

/// Stateless entry point for discovering and validating extensions of a project.
///
/// The type carries no data; every operation is an associated function taking the
/// project root.
#[derive(Clone, Copy, Debug)]
pub struct ExtensionLoader;
12
13impl ExtensionLoader {
14 pub fn discover(project_root: &Path) -> Vec<DiscoveredExtension> {
15 let extensions_dir = project_root.join("extensions");
16
17 if !extensions_dir.exists() {
18 return vec![];
19 }
20
21 let mut discovered = vec![];
22
23 Self::scan_directory(&extensions_dir, &mut discovered);
24
25 if let Ok(entries) = fs::read_dir(&extensions_dir) {
26 for entry in entries.flatten() {
27 let path = entry.path();
28 if path.is_dir() {
29 Self::scan_directory(&path, &mut discovered);
30 }
31 }
32 }
33
34 discovered
35 }
36
37 fn scan_directory(dir: &Path, discovered: &mut Vec<DiscoveredExtension>) {
38 let Ok(entries) = fs::read_dir(dir) else {
39 return;
40 };
41
42 for entry in entries.flatten() {
43 let ext_dir = entry.path();
44 if !ext_dir.is_dir() {
45 continue;
46 }
47
48 let manifest_path = ext_dir.join("manifest.yaml");
49 if manifest_path.exists() {
50 match Self::load_manifest(&manifest_path) {
51 Ok(manifest) => {
52 discovered.push(DiscoveredExtension::new(manifest, ext_dir, manifest_path));
53 },
54 Err(e) => {
55 tracing::warn!(
56 path = %manifest_path.display(),
57 error = %e,
58 "Failed to parse extension manifest, skipping"
59 );
60 },
61 }
62 }
63 }
64 }
65
66 fn load_manifest(path: &Path) -> Result<ExtensionManifest> {
67 let content = fs::read_to_string(path)
68 .with_context(|| format!("Failed to read manifest: {}", path.display()))?;
69
70 serde_yaml::from_str(&content)
71 .with_context(|| format!("Failed to parse manifest: {}", path.display()))
72 }
73
74 pub fn get_enabled_mcp_extensions(project_root: &Path) -> Vec<DiscoveredExtension> {
75 Self::discover(project_root)
76 .into_iter()
77 .filter(|e| e.is_mcp() && e.is_enabled())
78 .collect()
79 }
80
81 pub fn get_enabled_cli_extensions(project_root: &Path) -> Vec<DiscoveredExtension> {
82 Self::discover(project_root)
83 .into_iter()
84 .filter(|e| e.is_cli() && e.is_enabled())
85 .collect()
86 }
87
88 pub fn find_cli_extension(project_root: &Path, name: &str) -> Option<DiscoveredExtension> {
89 Self::get_enabled_cli_extensions(project_root)
90 .into_iter()
91 .find(|e| {
92 e.binary_name()
93 .is_some_and(|b| b == name || e.manifest.extension.name == name)
94 })
95 }
96
97 pub fn get_cli_binary_path(
98 project_root: &Path,
99 binary_name: &str,
100 ) -> Option<std::path::PathBuf> {
101 let release_path = project_root
102 .join(CARGO_TARGET)
103 .join("release")
104 .join(binary_name);
105 if release_path.exists() {
106 return Some(release_path);
107 }
108
109 let debug_path = project_root
110 .join(CARGO_TARGET)
111 .join("debug")
112 .join(binary_name);
113 if debug_path.exists() {
114 return Some(debug_path);
115 }
116
117 None
118 }
119
120 pub fn validate_mcp_binaries(project_root: &Path) -> Vec<(String, std::path::PathBuf)> {
121 let extensions = Self::get_enabled_mcp_extensions(project_root);
122 let target_dir = project_root.join(CARGO_TARGET).join("release");
123
124 extensions
125 .into_iter()
126 .filter_map(|ext| {
127 ext.binary_name().and_then(|binary| {
128 let binary_path = target_dir.join(binary);
129 if binary_path.exists() {
130 None
131 } else {
132 Some((binary.to_string(), ext.path.clone()))
133 }
134 })
135 })
136 .collect()
137 }
138
139 pub fn get_mcp_binary_names(project_root: &Path) -> Vec<String> {
140 Self::get_enabled_mcp_extensions(project_root)
141 .iter()
142 .filter_map(|e| e.binary_name().map(String::from))
143 .collect()
144 }
145
146 pub fn get_production_mcp_binary_names(
147 project_root: &Path,
148 services_config: &systemprompt_models::ServicesConfig,
149 ) -> Vec<String> {
150 Self::get_enabled_mcp_extensions(project_root)
151 .iter()
152 .filter_map(|e| {
153 let binary = e.binary_name()?;
154 let is_dev_only = services_config
155 .mcp_servers
156 .values()
157 .find(|d| d.binary == binary)
158 .is_some_and(|d| d.dev_only);
159 (!is_dev_only).then(|| binary.to_string())
160 })
161 .collect()
162 }
163
164 pub fn build_binary_map(project_root: &Path) -> HashMap<String, DiscoveredExtension> {
165 Self::discover(project_root)
166 .into_iter()
167 .filter_map(|ext| {
168 let name = ext.binary_name()?.to_string();
169 Some((name, ext))
170 })
171 .collect()
172 }
173
174 pub fn validate(project_root: &Path) -> ExtensionValidationResult {
175 ExtensionValidationResult {
176 discovered: Self::discover(project_root),
177 missing_binaries: Self::validate_mcp_binaries(project_root),
178 missing_manifests: vec![],
179 }
180 }
181}
182
183#[derive(Debug)]
184pub struct ExtensionValidationResult {
185 pub discovered: Vec<DiscoveredExtension>,
186 pub missing_binaries: Vec<(String, std::path::PathBuf)>,
187 pub missing_manifests: Vec<std::path::PathBuf>,
188}
189
190impl ExtensionValidationResult {
191 pub fn is_valid(&self) -> bool {
192 self.missing_binaries.is_empty()
193 }
194
195 pub fn format_missing_binaries(&self) -> String {
196 self.missing_binaries
197 .iter()
198 .map(|(binary, path)| format!(" ✗ {} ({})", binary, path.display()))
199 .collect::<Vec<_>>()
200 .join("\n")
201 }
202}