systemprompt_loader/
enhanced_config_loader.rs1use anyhow::{Context, Result};
2use std::collections::{HashMap, HashSet};
3use std::fs;
4use std::path::{Path, PathBuf};
5
6use systemprompt_models::mcp::Deployment;
7use systemprompt_models::services::{
8 AgentConfig, AiConfig, PartialServicesConfig, PluginConfig, SchedulerConfig, ServicesConfig,
9 Settings as ServicesSettings, WebConfig,
10};
11use systemprompt_models::AppPaths;
12
13use crate::ConfigWriter;
14
/// Loads a services configuration from a root YAML file and the files it
/// pulls in through its `includes` list.
#[derive(Debug)]
pub struct EnhancedConfigLoader {
    /// Directory that relative include paths are joined onto.
    base_path: PathBuf,
    /// Location of the root configuration file.
    config_path: PathBuf,
}
20
21#[derive(serde::Deserialize)]
22struct RootConfig {
23 #[serde(default)]
24 includes: Vec<String>,
25 #[serde(flatten)]
26 config: PartialServicesRootConfig,
27}
28
29#[derive(serde::Deserialize, Default)]
30struct PartialServicesRootConfig {
31 #[serde(default)]
32 pub agents: HashMap<String, AgentConfig>,
33 #[serde(default)]
34 pub mcp_servers: HashMap<String, Deployment>,
35 #[serde(default)]
36 pub settings: ServicesSettings,
37 #[serde(default)]
38 pub scheduler: Option<SchedulerConfig>,
39 #[serde(default)]
40 pub ai: Option<AiConfig>,
41 #[serde(default)]
42 pub web: Option<WebConfig>,
43 #[serde(default)]
44 pub plugins: HashMap<String, PluginConfig>,
45}
46
47impl EnhancedConfigLoader {
48 pub fn new(config_path: PathBuf) -> Self {
49 let base_path = config_path
50 .parent()
51 .unwrap_or_else(|| Path::new("."))
52 .to_path_buf();
53 Self {
54 base_path,
55 config_path,
56 }
57 }
58
59 pub fn from_env() -> Result<Self> {
60 let paths = AppPaths::get().map_err(|e| anyhow::anyhow!("{}", e))?;
61 let config_path = paths.system().settings().to_path_buf();
62 Ok(Self::new(config_path))
63 }
64
65 pub fn load(&self) -> Result<ServicesConfig> {
66 let content = fs::read_to_string(&self.config_path)
67 .with_context(|| format!("Failed to read config: {}", self.config_path.display()))?;
68
69 self.load_from_content(&content)
70 }
71
72 pub fn load_from_content(&self, content: &str) -> Result<ServicesConfig> {
73 let root: RootConfig = serde_yaml::from_str(content)
74 .with_context(|| format!("Failed to parse config: {}", self.config_path.display()))?;
75
76 let mut merged = ServicesConfig {
77 agents: root.config.agents,
78 mcp_servers: root.config.mcp_servers,
79 settings: root.config.settings,
80 scheduler: root.config.scheduler,
81 ai: root.config.ai.unwrap_or_else(AiConfig::default),
82 web: root.config.web.unwrap_or_else(WebConfig::default),
83 plugins: root.config.plugins,
84 };
85
86 for include_path in &root.includes {
87 let partial = self.load_include(include_path)?;
88 Self::merge_partial(&mut merged, partial)?;
89 }
90
91 self.discover_and_load_agents(&root.includes, &mut merged)?;
92
93 self.resolve_includes(&mut merged)?;
94
95 merged.settings.apply_env_overrides();
96
97 merged
98 .validate()
99 .map_err(|e| anyhow::anyhow!("Services config validation failed: {}", e))?;
100
101 Ok(merged)
102 }
103
104 fn discover_and_load_agents(
105 &self,
106 existing_includes: &[String],
107 merged: &mut ServicesConfig,
108 ) -> Result<()> {
109 let agents_dir = self.base_path.join("../agents");
110
111 if !agents_dir.exists() {
112 return Ok(());
113 }
114
115 let included_files: HashSet<String> = existing_includes
116 .iter()
117 .filter_map(|inc| {
118 Path::new(inc)
119 .file_name()
120 .map(|f| f.to_string_lossy().to_string())
121 })
122 .collect();
123
124 let entries = fs::read_dir(&agents_dir).with_context(|| {
125 format!("Failed to read agents directory: {}", agents_dir.display())
126 })?;
127
128 for entry in entries {
129 let path = entry
130 .with_context(|| format!("Failed to read entry in: {}", agents_dir.display()))?
131 .path();
132
133 let is_yaml = path
134 .extension()
135 .is_some_and(|ext| ext == "yaml" || ext == "yml");
136
137 if !is_yaml {
138 continue;
139 }
140
141 let file_name = path
142 .file_name()
143 .map(|f| f.to_string_lossy().to_string())
144 .ok_or_else(|| anyhow::anyhow!("Invalid file path: {}", path.display()))?;
145
146 if included_files.contains(&file_name) {
147 continue;
148 }
149
150 let relative_path = format!("../agents/{}", file_name);
151 let partial = self.load_include(&relative_path)?;
152 Self::merge_partial(merged, partial)?;
153
154 ConfigWriter::add_include(&relative_path, &self.config_path).with_context(|| {
155 format!(
156 "Failed to add discovered agent to includes: {}",
157 relative_path
158 )
159 })?;
160 }
161
162 Ok(())
163 }
164
165 fn load_include(&self, path: &str) -> Result<PartialServicesConfig> {
166 let full_path = self.base_path.join(path);
167
168 if !full_path.exists() {
169 anyhow::bail!(
170 "Include file not found: {}\nReferenced in: {}/config.yaml\nEither create the \
171 file or remove it from the includes list.",
172 full_path.display(),
173 self.base_path.display()
174 );
175 }
176
177 let content = fs::read_to_string(&full_path)
178 .with_context(|| format!("Failed to read include: {}", full_path.display()))?;
179
180 serde_yaml::from_str(&content)
181 .with_context(|| format!("Failed to parse include: {}", full_path.display()))
182 }
183
184 fn merge_partial(target: &mut ServicesConfig, partial: PartialServicesConfig) -> Result<()> {
185 for (name, agent) in partial.agents {
186 if target.agents.contains_key(&name) {
187 anyhow::bail!("Duplicate agent definition: {name}");
188 }
189 target.agents.insert(name, agent);
190 }
191
192 for (name, mcp) in partial.mcp_servers {
193 if target.mcp_servers.contains_key(&name) {
194 anyhow::bail!("Duplicate MCP server definition: {name}");
195 }
196 target.mcp_servers.insert(name, mcp);
197 }
198
199 if partial.scheduler.is_some() && target.scheduler.is_none() {
200 target.scheduler = partial.scheduler;
201 }
202
203 if let Some(ai) = partial.ai {
204 if target.ai.providers.is_empty() && !ai.providers.is_empty() {
205 target.ai = ai;
206 } else {
207 for (name, provider) in ai.providers {
208 target.ai.providers.insert(name, provider);
209 }
210 }
211 }
212
213 if let Some(web) = partial.web {
214 target.web = web;
215 }
216
217 for (name, plugin) in partial.plugins {
218 if target.plugins.contains_key(&name) {
219 anyhow::bail!("Duplicate plugin definition: {name}");
220 }
221 target.plugins.insert(name, plugin);
222 }
223
224 Ok(())
225 }
226
227 fn resolve_includes(&self, config: &mut ServicesConfig) -> Result<()> {
228 for (name, agent) in &mut config.agents {
229 if let Some(ref system_prompt) = agent.metadata.system_prompt {
230 if let Some(include_path) = system_prompt.strip_prefix("!include ") {
231 let full_path = self.base_path.join(include_path.trim());
232 let resolved = fs::read_to_string(&full_path).with_context(|| {
233 format!(
234 "Failed to resolve system_prompt include for agent '{name}': {}",
235 full_path.display()
236 )
237 })?;
238 agent.metadata.system_prompt = Some(resolved);
239 }
240 }
241 }
242
243 Ok(())
244 }
245
246 pub fn validate_file(path: &Path) -> Result<()> {
247 let loader = Self::new(path.to_path_buf());
248 let _config = loader.load()?;
249 Ok(())
250 }
251
252 pub fn get_includes(&self) -> Result<Vec<String>> {
253 #[derive(serde::Deserialize)]
254 struct IncludesOnly {
255 #[serde(default)]
256 includes: Vec<String>,
257 }
258
259 let content = fs::read_to_string(&self.config_path)?;
260 let parsed: IncludesOnly = serde_yaml::from_str(&content)?;
261 Ok(parsed.includes)
262 }
263
264 pub fn list_all_includes(&self) -> Result<Vec<(String, bool)>> {
265 self.get_includes()?
266 .into_iter()
267 .map(|include| {
268 let exists = self.base_path.join(&include).exists();
269 Ok((include, exists))
270 })
271 .collect()
272 }
273
274 pub fn base_path(&self) -> &Path {
275 &self.base_path
276 }
277}