1use anyhow::{Context, Result, anyhow};
6use serde::Deserialize;
7use std::collections::HashMap;
8use url::Url;
9
10#[derive(Deserialize, Debug, Clone)]
13pub struct AgentConfig {
14 pub system_prompt: String,
15 pub default_provider: String,
16 #[serde(default)]
17 pub providers: HashMap<String, ProviderInstanceConfig>,
18 #[serde(default)]
19 pub mcp_servers: HashMap<String, McpServerConfig>,
20 #[serde(default)]
21 pub strategies: HashMap<String, StrategyConfig>,
22}
23
24#[derive(Deserialize, Debug, Clone)]
25pub struct ProviderInstanceConfig {
26 #[serde(rename = "type")]
28 pub provider_type: String,
29 pub api_key_env_var: String,
30 pub model_config: ModelConfig,
31}
32
33#[derive(Deserialize, Debug, Clone)]
34pub struct McpServerConfig {
35 pub command: String,
36 #[serde(default)]
37 pub args: Vec<String>,
38}
39
40#[derive(Deserialize, Debug, Clone)]
41pub struct StrategyConfig {
42 pub planning_provider: Option<String>,
43 pub execution_provider: Option<String>,
44}
45
46#[derive(Deserialize, Debug, Clone)]
47pub struct ModelConfig {
48 pub model_name: String,
49 #[serde(default)]
50 pub parameters: Option<toml::Value>,
51 #[serde(default)]
52 pub endpoint: Option<String>,
53}
54
55impl AgentConfig {
56 pub fn from_toml_str(config_toml_content: &str) -> Result<AgentConfig> {
57 let config: AgentConfig = match toml::from_str(config_toml_content) {
58 Ok(cfg) => cfg,
59 Err(e) => {
60 tracing::error!(error=%e, content=%config_toml_content, "Failed to parse TOML content");
61 return Err(anyhow!(e))
62 .context("Failed to parse configuration TOML content. Check TOML syntax.");
63 }
64 };
65
66 if config.system_prompt.trim().is_empty() {
68 return Err(anyhow!("'system_prompt' in config content is empty."));
69 }
70 if config.default_provider.trim().is_empty() {
71 return Err(anyhow!(
72 "'default_provider' key in config content is empty."
73 ));
74 }
75 if !config.providers.contains_key(&config.default_provider) {
76 return Err(anyhow!(
77 "Default provider '{}' not found in [providers] map.",
78 config.default_provider
79 ));
80 }
81
82 for (key, provider) in &config.providers {
84 if provider.provider_type.trim().is_empty() {
86 return Err(anyhow!(
87 "Provider '{}' is missing 'type' (provider_type).",
88 key
89 ));
90 }
91 if provider.model_config.model_name.trim().is_empty() {
92 return Err(anyhow!(
93 "Provider '{}' is missing 'model_config.model_name'.",
94 key
95 ));
96 }
97 if provider.api_key_env_var.trim().is_empty() && provider.provider_type != "ollama" {
98 return Err(anyhow!("Provider '{}' is missing 'api_key_env_var'.", key));
100 }
101 if let Some(endpoint) = &provider.model_config.endpoint {
102 if endpoint.trim().is_empty() {
103 return Err(anyhow!(
104 "Provider '{}' has an empty 'model_config.endpoint'.",
105 key
106 ));
107 }
108 Url::parse(endpoint).with_context(|| {
109 format!(
110 "Invalid URL format for endpoint ('{}') in provider '{}'.",
111 endpoint, key
112 )
113 })?;
114 } else if provider.provider_type != "ollama" {
115 }
118 if let Some(params) = &provider.model_config.parameters {
119 if !params.is_table() && !params.is_str() {
120 return Err(anyhow!(
121 "Provider '{}' has invalid 'model_config.parameters'. Expected a TOML table or string.",
122 key
123 ));
124 }
125 }
126 }
127
128 for (key, server) in &config.mcp_servers {
130 if server.command.trim().is_empty() {
131 return Err(anyhow!("MCP Server '{}' has an empty 'command'.", key));
132 }
133 }
134
135 tracing::info!("Successfully parsed and validated agent configuration.");
136 Ok(config)
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully valid configuration exercising every top-level section.
    fn valid_mcp_config_content() -> String {
        r#"
        system_prompt = "You are Volition MCP."
        default_provider = "gemini_default"

        [providers.gemini_default]
        type = "gemini"
        api_key_env_var = "GOOGLE_API_KEY"
        [providers.gemini_default.model_config]
        model_name = "gemini-2.5-pro"
        endpoint = "https://example.com/gemini"
        parameters = { temperature = 0.6 }

        [providers.openai_fast]
        type = "openai"
        api_key_env_var = "OPENAI_API_KEY"
        [providers.openai_fast.model_config]
        model_name = "gpt-4o-mini"
        endpoint = "https://example.com/openai"
        parameters = { temperature = 0.1 }

        [mcp_servers.filesystem]
        command = "echo"
        args = ["fs"]

        [mcp_servers.shell]
        command = "echo"
        args = ["sh"]

        [strategies.plan_execute]
        planning_provider = "openai_fast"
        execution_provider = "gemini_default"
        "#
        .to_string()
    }

    /// Builds a config whose only provider is `p` (also the default provider).
    /// `provider_toml` is the body of `[providers.p]` and `model_toml` the body
    /// of `[providers.p.model_config]`.
    fn single_provider_config(provider_toml: &str, model_toml: &str) -> String {
        format!(
            "system_prompt = \"Valid\"\ndefault_provider = \"p\"\n\n[providers.p]\n{}\n[providers.p.model_config]\n{}\n",
            provider_toml, model_toml
        )
    }

    /// Parses `content`, fails the test if it succeeds, and returns the full error chain.
    fn expect_error(content: &str) -> String {
        match AgentConfig::from_toml_str(content) {
            Ok(cfg) => panic!("expected a configuration error, got {:?}", cfg),
            Err(e) => format!("{:#}", e),
        }
    }

    #[test]
    fn test_mcp_config_parse_success() {
        let content = valid_mcp_config_content();
        let result = AgentConfig::from_toml_str(&content);
        assert!(
            result.is_ok(),
            "Parse failed: {:?}\nContent:\n{}",
            result.err(),
            content
        );
        let config = result.unwrap();
        assert_eq!(config.default_provider, "gemini_default");
        assert_eq!(config.providers.len(), 2);
        assert!(config.providers.contains_key("gemini_default"));
        assert_eq!(config.providers["gemini_default"].provider_type, "gemini");
        assert_eq!(config.providers["openai_fast"].provider_type, "openai");
        assert_eq!(
            config.providers["openai_fast"].model_config.model_name,
            "gpt-4o-mini"
        );
        assert!(
            config.providers["gemini_default"]
                .model_config
                .parameters
                .is_some()
        );
        assert_eq!(config.mcp_servers.len(), 2);
        assert_eq!(config.mcp_servers["filesystem"].command, "echo");
        assert_eq!(config.mcp_servers["filesystem"].args, vec!["fs".to_string()]);

        let strategy = &config.strategies["plan_execute"];
        assert_eq!(strategy.planning_provider.as_deref(), Some("openai_fast"));
        assert_eq!(strategy.execution_provider.as_deref(), Some("gemini_default"));
    }

    #[test]
    fn test_mcp_config_missing_default_provider_def() {
        let content = r#"
        system_prompt = "Valid"
        default_provider = "missing_provider"
        [providers.gemini_default]
        type = "gemini"
        api_key_env_var = "GOOGLE_API_KEY"
        [providers.gemini_default.model_config]
        model_name = "gemini-2.5-pro"
        endpoint = "https://example.com"
        "#;
        let error_string = expect_error(content);
        assert!(
            error_string.contains("Default provider 'missing_provider' not found"),
            "Unexpected error message: {}",
            error_string
        );
    }

    #[test]
    fn test_blank_system_prompt_rejected() {
        let content = "system_prompt = \"   \"\ndefault_provider = \"p\"\n";
        let err = expect_error(content);
        assert!(
            err.contains("'system_prompt' in config content is empty."),
            "Unexpected error message: {}",
            err
        );
    }

    #[test]
    fn test_missing_api_key_rejected_for_non_ollama() {
        let content = single_provider_config(
            "type = \"gemini\"\napi_key_env_var = \"\"",
            "model_name = \"m\"",
        );
        let err = expect_error(&content);
        assert!(
            err.contains("Provider 'p' is missing 'api_key_env_var'."),
            "Unexpected error message: {}",
            err
        );
    }

    #[test]
    fn test_blank_api_key_allowed_for_ollama() {
        let content = single_provider_config(
            "type = \"ollama\"\napi_key_env_var = \"\"",
            "model_name = \"llama3\"",
        );
        let result = AgentConfig::from_toml_str(&content);
        assert!(result.is_ok(), "Parse failed: {:?}", result.err());
    }

    #[test]
    fn test_invalid_endpoint_url_rejected() {
        let content = single_provider_config(
            "type = \"gemini\"\napi_key_env_var = \"K\"",
            "model_name = \"m\"\nendpoint = \"not a url\"",
        );
        let err = expect_error(&content);
        assert!(
            err.contains("Invalid URL format for endpoint ('not a url') in provider 'p'."),
            "Unexpected error message: {}",
            err
        );
    }

    #[test]
    fn test_non_table_non_string_parameters_rejected() {
        let content = single_provider_config(
            "type = \"gemini\"\napi_key_env_var = \"K\"",
            "model_name = \"m\"\nparameters = 5",
        );
        let err = expect_error(&content);
        assert!(
            err.contains("Provider 'p' has invalid 'model_config.parameters'."),
            "Unexpected error message: {}",
            err
        );
    }

    #[test]
    fn test_blank_mcp_command_rejected() {
        let content = format!(
            "{}\n[mcp_servers.fs]\ncommand = \"  \"\n",
            single_provider_config(
                "type = \"gemini\"\napi_key_env_var = \"K\"",
                "model_name = \"m\"",
            )
        );
        let err = expect_error(&content);
        assert!(
            err.contains("MCP Server 'fs' has an empty 'command'."),
            "Unexpected error message: {}",
            err
        );
    }

    #[test]
    fn test_malformed_toml_rejected() {
        let err = expect_error("system_prompt = ");
        assert!(
            err.contains("Failed to parse configuration TOML content"),
            "Unexpected error message: {}",
            err
        );
    }
}