// systemprompt_loader/enhanced_config_loader.rs

use anyhow::{Context, Result};
use std::collections::{hash_map::Entry, HashMap, HashSet};
use std::fs;
use std::path::{Path, PathBuf};

use systemprompt_models::mcp::Deployment;
use systemprompt_models::services::{
    AgentConfig, AiConfig, PartialServicesConfig, SchedulerConfig, ServicesConfig,
    Settings as ServicesSettings, WebConfig,
};
use systemprompt_models::AppPaths;

use crate::ConfigWriter;

/// Loader for the services configuration.
///
/// Reads a root YAML file, merges the files listed under its `includes` key,
/// and picks up agent definitions found next to it.
#[derive(Debug)]
pub struct EnhancedConfigLoader {
    /// Directory that include paths are joined onto (the root config's parent).
    base_path: PathBuf,
    /// Location of the root services config file.
    config_path: PathBuf,
}

21#[derive(serde::Deserialize)]
22struct RootConfig {
23    #[serde(default)]
24    includes: Vec<String>,
25    #[serde(flatten)]
26    config: PartialServicesRootConfig,
27}
28
29#[derive(serde::Deserialize, Default)]
30struct PartialServicesRootConfig {
31    #[serde(default)]
32    pub agents: HashMap<String, AgentConfig>,
33    #[serde(default)]
34    pub mcp_servers: HashMap<String, Deployment>,
35    #[serde(default)]
36    pub settings: ServicesSettings,
37    #[serde(default)]
38    pub scheduler: Option<SchedulerConfig>,
39    #[serde(default)]
40    pub ai: Option<AiConfig>,
41    #[serde(default)]
42    pub web: Option<WebConfig>,
43}
44
45impl EnhancedConfigLoader {
46    pub fn new(config_path: PathBuf) -> Self {
47        let base_path = config_path
48            .parent()
49            .unwrap_or_else(|| Path::new("."))
50            .to_path_buf();
51        Self {
52            base_path,
53            config_path,
54        }
55    }
56
57    pub fn from_env() -> Result<Self> {
58        let paths = AppPaths::get().map_err(|e| anyhow::anyhow!("{}", e))?;
59        let config_path = paths.system().settings().to_path_buf();
60        Ok(Self::new(config_path))
61    }
62
63    pub fn load(&self) -> Result<ServicesConfig> {
64        let content = fs::read_to_string(&self.config_path)
65            .with_context(|| format!("Failed to read config: {}", self.config_path.display()))?;
66
67        self.load_from_content(&content)
68    }
69
70    pub fn load_from_content(&self, content: &str) -> Result<ServicesConfig> {
71        let root: RootConfig = serde_yaml::from_str(content)
72            .with_context(|| format!("Failed to parse config: {}", self.config_path.display()))?;
73
74        let mut merged = ServicesConfig {
75            agents: root.config.agents,
76            mcp_servers: root.config.mcp_servers,
77            settings: root.config.settings,
78            scheduler: root.config.scheduler,
79            ai: root.config.ai.unwrap_or_else(AiConfig::default),
80            web: root.config.web.unwrap_or_else(WebConfig::default),
81        };
82
83        for include_path in &root.includes {
84            let partial = self.load_include(include_path)?;
85            Self::merge_partial(&mut merged, partial)?;
86        }
87
88        self.discover_and_load_agents(&root.includes, &mut merged)?;
89
90        self.resolve_includes(&mut merged)?;
91
92        merged.settings.apply_env_overrides();
93
94        merged
95            .validate()
96            .map_err(|e| anyhow::anyhow!("Services config validation failed: {}", e))?;
97
98        Ok(merged)
99    }
100
101    fn discover_and_load_agents(
102        &self,
103        existing_includes: &[String],
104        merged: &mut ServicesConfig,
105    ) -> Result<()> {
106        let agents_dir = self.base_path.join("../agents");
107
108        if !agents_dir.exists() {
109            return Ok(());
110        }
111
112        let included_files: HashSet<String> = existing_includes
113            .iter()
114            .filter_map(|inc| {
115                Path::new(inc)
116                    .file_name()
117                    .map(|f| f.to_string_lossy().to_string())
118            })
119            .collect();
120
121        let entries = fs::read_dir(&agents_dir).with_context(|| {
122            format!("Failed to read agents directory: {}", agents_dir.display())
123        })?;
124
125        for entry in entries {
126            let path = entry
127                .with_context(|| format!("Failed to read entry in: {}", agents_dir.display()))?
128                .path();
129
130            let is_yaml = path
131                .extension()
132                .is_some_and(|ext| ext == "yaml" || ext == "yml");
133
134            if !is_yaml {
135                continue;
136            }
137
138            let file_name = path
139                .file_name()
140                .map(|f| f.to_string_lossy().to_string())
141                .ok_or_else(|| anyhow::anyhow!("Invalid file path: {}", path.display()))?;
142
143            if included_files.contains(&file_name) {
144                continue;
145            }
146
147            let relative_path = format!("../agents/{}", file_name);
148            let partial = self.load_include(&relative_path)?;
149            Self::merge_partial(merged, partial)?;
150
151            ConfigWriter::add_include(&relative_path, &self.config_path).with_context(|| {
152                format!(
153                    "Failed to add discovered agent to includes: {}",
154                    relative_path
155                )
156            })?;
157        }
158
159        Ok(())
160    }
161
162    fn load_include(&self, path: &str) -> Result<PartialServicesConfig> {
163        let full_path = self.base_path.join(path);
164
165        if !full_path.exists() {
166            anyhow::bail!(
167                "Include file not found: {}\nReferenced in: {}/config.yaml\nEither create the \
168                 file or remove it from the includes list.",
169                full_path.display(),
170                self.base_path.display()
171            );
172        }
173
174        let content = fs::read_to_string(&full_path)
175            .with_context(|| format!("Failed to read include: {}", full_path.display()))?;
176
177        serde_yaml::from_str(&content)
178            .with_context(|| format!("Failed to parse include: {}", full_path.display()))
179    }
180
181    fn merge_partial(target: &mut ServicesConfig, partial: PartialServicesConfig) -> Result<()> {
182        for (name, agent) in partial.agents {
183            if target.agents.contains_key(&name) {
184                anyhow::bail!("Duplicate agent definition: {name}");
185            }
186            target.agents.insert(name, agent);
187        }
188
189        for (name, mcp) in partial.mcp_servers {
190            if target.mcp_servers.contains_key(&name) {
191                anyhow::bail!("Duplicate MCP server definition: {name}");
192            }
193            target.mcp_servers.insert(name, mcp);
194        }
195
196        if partial.scheduler.is_some() && target.scheduler.is_none() {
197            target.scheduler = partial.scheduler;
198        }
199
200        Ok(())
201    }
202
203    fn resolve_includes(&self, config: &mut ServicesConfig) -> Result<()> {
204        for (name, agent) in &mut config.agents {
205            if let Some(ref system_prompt) = agent.metadata.system_prompt {
206                if let Some(include_path) = system_prompt.strip_prefix("!include ") {
207                    let full_path = self.base_path.join(include_path.trim());
208                    let resolved = fs::read_to_string(&full_path).with_context(|| {
209                        format!(
210                            "Failed to resolve system_prompt include for agent '{name}': {}",
211                            full_path.display()
212                        )
213                    })?;
214                    agent.metadata.system_prompt = Some(resolved);
215                }
216            }
217        }
218
219        Ok(())
220    }
221
222    pub fn validate_file(path: &Path) -> Result<()> {
223        let loader = Self::new(path.to_path_buf());
224        let _config = loader.load()?;
225        Ok(())
226    }
227
228    pub fn get_includes(&self) -> Result<Vec<String>> {
229        #[derive(serde::Deserialize)]
230        struct IncludesOnly {
231            #[serde(default)]
232            includes: Vec<String>,
233        }
234
235        let content = fs::read_to_string(&self.config_path)?;
236        let parsed: IncludesOnly = serde_yaml::from_str(&content)?;
237        Ok(parsed.includes)
238    }
239
240    pub fn list_all_includes(&self) -> Result<Vec<(String, bool)>> {
241        self.get_includes()?
242            .into_iter()
243            .map(|include| {
244                let exists = self.base_path.join(&include).exists();
245                Ok((include, exists))
246            })
247            .collect()
248    }
249
250    pub fn base_path(&self) -> &Path {
251        &self.base_path
252    }
253}