// systemprompt_loader/enhanced_config_loader.rs

1use anyhow::{Context, Result};
2use std::collections::{HashMap, HashSet};
3use std::fs;
4use std::path::{Path, PathBuf};
5
6use systemprompt_models::mcp::Deployment;
7use systemprompt_models::services::{
8    AgentConfig, AiConfig, PartialServicesConfig, PluginConfig, SchedulerConfig, ServicesConfig,
9    Settings as ServicesSettings, WebConfig,
10};
11use systemprompt_models::AppPaths;
12
13use crate::ConfigWriter;
14
/// Loads a services configuration from a root YAML file plus the files it
/// includes.
#[derive(Debug)]
pub struct EnhancedConfigLoader {
    /// Directory that relative include paths are joined onto.
    base_path: PathBuf,
    /// Location of the root configuration file.
    config_path: PathBuf,
}
20
21#[derive(serde::Deserialize)]
22struct RootConfig {
23    #[serde(default)]
24    includes: Vec<String>,
25    #[serde(flatten)]
26    config: PartialServicesRootConfig,
27}
28
29#[derive(serde::Deserialize, Default)]
30struct PartialServicesRootConfig {
31    #[serde(default)]
32    pub agents: HashMap<String, AgentConfig>,
33    #[serde(default)]
34    pub mcp_servers: HashMap<String, Deployment>,
35    #[serde(default)]
36    pub settings: ServicesSettings,
37    #[serde(default)]
38    pub scheduler: Option<SchedulerConfig>,
39    #[serde(default)]
40    pub ai: Option<AiConfig>,
41    #[serde(default)]
42    pub web: Option<WebConfig>,
43    #[serde(default)]
44    pub plugins: HashMap<String, PluginConfig>,
45}
46
47impl EnhancedConfigLoader {
48    pub fn new(config_path: PathBuf) -> Self {
49        let base_path = config_path
50            .parent()
51            .unwrap_or_else(|| Path::new("."))
52            .to_path_buf();
53        Self {
54            base_path,
55            config_path,
56        }
57    }
58
59    pub fn from_env() -> Result<Self> {
60        let paths = AppPaths::get().map_err(|e| anyhow::anyhow!("{}", e))?;
61        let config_path = paths.system().settings().to_path_buf();
62        Ok(Self::new(config_path))
63    }
64
65    pub fn load(&self) -> Result<ServicesConfig> {
66        let content = fs::read_to_string(&self.config_path)
67            .with_context(|| format!("Failed to read config: {}", self.config_path.display()))?;
68
69        self.load_from_content(&content)
70    }
71
72    pub fn load_from_content(&self, content: &str) -> Result<ServicesConfig> {
73        let root: RootConfig = serde_yaml::from_str(content)
74            .with_context(|| format!("Failed to parse config: {}", self.config_path.display()))?;
75
76        let mut merged = ServicesConfig {
77            agents: root.config.agents,
78            mcp_servers: root.config.mcp_servers,
79            settings: root.config.settings,
80            scheduler: root.config.scheduler,
81            ai: root.config.ai.unwrap_or_else(AiConfig::default),
82            web: root.config.web.unwrap_or_else(WebConfig::default),
83            plugins: root.config.plugins,
84        };
85
86        for include_path in &root.includes {
87            let partial = self.load_include(include_path)?;
88            Self::merge_partial(&mut merged, partial)?;
89        }
90
91        self.discover_and_load_agents(&root.includes, &mut merged)?;
92
93        self.resolve_includes(&mut merged)?;
94
95        merged.settings.apply_env_overrides();
96
97        merged
98            .validate()
99            .map_err(|e| anyhow::anyhow!("Services config validation failed: {}", e))?;
100
101        Ok(merged)
102    }
103
104    fn discover_and_load_agents(
105        &self,
106        existing_includes: &[String],
107        merged: &mut ServicesConfig,
108    ) -> Result<()> {
109        let agents_dir = self.base_path.join("../agents");
110
111        if !agents_dir.exists() {
112            return Ok(());
113        }
114
115        let included_files: HashSet<String> = existing_includes
116            .iter()
117            .filter_map(|inc| {
118                Path::new(inc)
119                    .file_name()
120                    .map(|f| f.to_string_lossy().to_string())
121            })
122            .collect();
123
124        let entries = fs::read_dir(&agents_dir).with_context(|| {
125            format!("Failed to read agents directory: {}", agents_dir.display())
126        })?;
127
128        for entry in entries {
129            let path = entry
130                .with_context(|| format!("Failed to read entry in: {}", agents_dir.display()))?
131                .path();
132
133            let is_yaml = path
134                .extension()
135                .is_some_and(|ext| ext == "yaml" || ext == "yml");
136
137            if !is_yaml {
138                continue;
139            }
140
141            let file_name = path
142                .file_name()
143                .map(|f| f.to_string_lossy().to_string())
144                .ok_or_else(|| anyhow::anyhow!("Invalid file path: {}", path.display()))?;
145
146            if included_files.contains(&file_name) {
147                continue;
148            }
149
150            let relative_path = format!("../agents/{}", file_name);
151            let partial = self.load_include(&relative_path)?;
152            Self::merge_partial(merged, partial)?;
153
154            ConfigWriter::add_include(&relative_path, &self.config_path).with_context(|| {
155                format!(
156                    "Failed to add discovered agent to includes: {}",
157                    relative_path
158                )
159            })?;
160        }
161
162        Ok(())
163    }
164
165    fn load_include(&self, path: &str) -> Result<PartialServicesConfig> {
166        let full_path = self.base_path.join(path);
167
168        if !full_path.exists() {
169            anyhow::bail!(
170                "Include file not found: {}\nReferenced in: {}/config.yaml\nEither create the \
171                 file or remove it from the includes list.",
172                full_path.display(),
173                self.base_path.display()
174            );
175        }
176
177        let content = fs::read_to_string(&full_path)
178            .with_context(|| format!("Failed to read include: {}", full_path.display()))?;
179
180        serde_yaml::from_str(&content)
181            .with_context(|| format!("Failed to parse include: {}", full_path.display()))
182    }
183
184    fn merge_partial(target: &mut ServicesConfig, partial: PartialServicesConfig) -> Result<()> {
185        for (name, agent) in partial.agents {
186            if target.agents.contains_key(&name) {
187                anyhow::bail!("Duplicate agent definition: {name}");
188            }
189            target.agents.insert(name, agent);
190        }
191
192        for (name, mcp) in partial.mcp_servers {
193            if target.mcp_servers.contains_key(&name) {
194                anyhow::bail!("Duplicate MCP server definition: {name}");
195            }
196            target.mcp_servers.insert(name, mcp);
197        }
198
199        if partial.scheduler.is_some() && target.scheduler.is_none() {
200            target.scheduler = partial.scheduler;
201        }
202
203        if let Some(ai) = partial.ai {
204            if target.ai.providers.is_empty() && !ai.providers.is_empty() {
205                target.ai = ai;
206            } else {
207                for (name, provider) in ai.providers {
208                    target.ai.providers.insert(name, provider);
209                }
210            }
211        }
212
213        if let Some(web) = partial.web {
214            target.web = web;
215        }
216
217        for (name, plugin) in partial.plugins {
218            if target.plugins.contains_key(&name) {
219                anyhow::bail!("Duplicate plugin definition: {name}");
220            }
221            target.plugins.insert(name, plugin);
222        }
223
224        Ok(())
225    }
226
227    fn resolve_includes(&self, config: &mut ServicesConfig) -> Result<()> {
228        for (name, agent) in &mut config.agents {
229            if let Some(ref system_prompt) = agent.metadata.system_prompt {
230                if let Some(include_path) = system_prompt.strip_prefix("!include ") {
231                    let full_path = self.base_path.join(include_path.trim());
232                    let resolved = fs::read_to_string(&full_path).with_context(|| {
233                        format!(
234                            "Failed to resolve system_prompt include for agent '{name}': {}",
235                            full_path.display()
236                        )
237                    })?;
238                    agent.metadata.system_prompt = Some(resolved);
239                }
240            }
241        }
242
243        Ok(())
244    }
245
246    pub fn validate_file(path: &Path) -> Result<()> {
247        let loader = Self::new(path.to_path_buf());
248        let _config = loader.load()?;
249        Ok(())
250    }
251
252    pub fn get_includes(&self) -> Result<Vec<String>> {
253        #[derive(serde::Deserialize)]
254        struct IncludesOnly {
255            #[serde(default)]
256            includes: Vec<String>,
257        }
258
259        let content = fs::read_to_string(&self.config_path)?;
260        let parsed: IncludesOnly = serde_yaml::from_str(&content)?;
261        Ok(parsed.includes)
262    }
263
264    pub fn list_all_includes(&self) -> Result<Vec<(String, bool)>> {
265        self.get_includes()?
266            .into_iter()
267            .map(|include| {
268                let exists = self.base_path.join(&include).exists();
269                Ok((include, exists))
270            })
271            .collect()
272    }
273
274    pub fn base_path(&self) -> &Path {
275        &self.base_path
276    }
277}