1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::HashSet;
4use thiserror::Error;
5
6#[derive(Debug, Error)]
7pub enum AgentError {
8 #[error("Invalid agent configuration: {0}")]
9 Invalid(String),
10}
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct AgentProfile {
15 #[serde(default)]
17 pub prompt: Option<String>,
18
19 #[serde(default)]
21 pub style: Option<String>,
22
23 #[serde(default)]
25 pub temperature: Option<f32>,
26
27 #[serde(default)]
29 pub model_provider: Option<String>,
30
31 #[serde(default)]
33 pub model_name: Option<String>,
34
35 #[serde(default)]
37 pub allowed_tools: Option<Vec<String>>,
38
39 #[serde(default)]
41 pub denied_tools: Option<Vec<String>>,
42
43 #[serde(default = "AgentProfile::default_memory_k")]
45 pub memory_k: usize,
46
47 #[serde(default = "AgentProfile::default_top_p")]
49 pub top_p: f32,
50
51 #[serde(default)]
53 pub max_context_tokens: Option<usize>,
54
55 #[serde(default)]
58 pub enable_graph: bool,
59
60 #[serde(default)]
62 pub graph_memory: bool,
63
64 #[serde(default = "AgentProfile::default_graph_depth")]
66 pub graph_depth: usize,
67
68 #[serde(default = "AgentProfile::default_graph_weight")]
70 pub graph_weight: f32,
71
72 #[serde(default)]
74 pub auto_graph: bool,
75
76 #[serde(default = "AgentProfile::default_graph_threshold")]
78 pub graph_threshold: f32,
79
80 #[serde(default)]
82 pub graph_steering: bool,
83
84 #[serde(default)]
87 pub fast_reasoning: bool,
88
89 #[serde(default)]
91 pub fast_model_provider: Option<String>,
92
93 #[serde(default)]
95 pub fast_model_name: Option<String>,
96
97 #[serde(default = "AgentProfile::default_fast_temperature")]
99 pub fast_model_temperature: f32,
100
101 #[serde(default = "AgentProfile::default_fast_tasks")]
103 pub fast_model_tasks: Vec<String>,
104
105 #[serde(default = "AgentProfile::default_escalation_threshold")]
107 pub escalation_threshold: f32,
108
109 #[serde(default)]
111 pub show_reasoning: bool,
112
113 #[serde(default)]
116 pub enable_audio_transcription: bool,
117
118 #[serde(default = "AgentProfile::default_audio_response_mode")]
120 pub audio_response_mode: String,
121
122 #[serde(default)]
124 pub audio_scenario: Option<String>,
125}
126
127impl AgentProfile {
128 const ALWAYS_ALLOWED_TOOLS: [&'static str; 1] = ["prompt_user"];
129 fn default_memory_k() -> usize {
130 10
131 }
132
133 fn default_top_p() -> f32 {
134 0.9
135 }
136
137 fn default_graph_depth() -> usize {
138 3
139 }
140
141 fn default_graph_weight() -> f32 {
142 0.5 }
144
145 fn default_graph_threshold() -> f32 {
146 0.7 }
148
149 fn default_fast_temperature() -> f32 {
150 0.3 }
152
153 fn default_fast_tasks() -> Vec<String> {
154 vec![
155 "entity_extraction".to_string(),
156 "graph_analysis".to_string(),
157 "decision_routing".to_string(),
158 "tool_selection".to_string(),
159 "confidence_scoring".to_string(),
160 ]
161 }
162
163 fn default_escalation_threshold() -> f32 {
164 0.6 }
166
167 fn default_audio_response_mode() -> String {
168 "immediate".to_string()
169 }
170
171 pub fn validate(&self) -> Result<()> {
173 if let Some(temp) = self.temperature {
175 if !(0.0..=2.0).contains(&temp) {
176 return Err(AgentError::Invalid(format!(
177 "temperature must be between 0.0 and 2.0, got {}",
178 temp
179 ))
180 .into());
181 }
182 }
183
184 if self.top_p < 0.0 || self.top_p > 1.0 {
186 return Err(AgentError::Invalid(format!(
187 "top_p must be between 0.0 and 1.0, got {}",
188 self.top_p
189 ))
190 .into());
191 }
192
193 if self.graph_weight < 0.0 || self.graph_weight > 1.0 {
195 return Err(AgentError::Invalid(format!(
196 "graph_weight must be between 0.0 and 1.0, got {}",
197 self.graph_weight
198 ))
199 .into());
200 }
201
202 if self.graph_threshold < 0.0 || self.graph_threshold > 1.0 {
204 return Err(AgentError::Invalid(format!(
205 "graph_threshold must be between 0.0 and 1.0, got {}",
206 self.graph_threshold
207 ))
208 .into());
209 }
210
211 if let (Some(allowed), Some(denied)) = (&self.allowed_tools, &self.denied_tools) {
213 let allowed_set: HashSet<_> = allowed.iter().collect();
214 let denied_set: HashSet<_> = denied.iter().collect();
215 let overlap: Vec<_> = allowed_set.intersection(&denied_set).collect();
216
217 if !overlap.is_empty() {
218 return Err(AgentError::Invalid(format!(
219 "tools cannot be both allowed and denied: {:?}",
220 overlap
221 ))
222 .into());
223 }
224 }
225
226 if let Some(provider) = &self.model_provider {
228 let valid_providers = ["mock", "openai", "anthropic", "ollama", "mlx", "lmstudio"];
229 if !valid_providers.contains(&provider.as_str()) {
230 return Err(AgentError::Invalid(format!(
231 "model_provider must be one of: {}. Got: {}",
232 valid_providers.join(", "),
233 provider
234 ))
235 .into());
236 }
237 }
238
239 Ok(())
240 }
241
242 pub fn is_tool_allowed(&self, tool_name: &str) -> bool {
244 if let Some(denied) = &self.denied_tools {
246 if denied.iter().any(|t| t == tool_name) {
247 return false;
248 }
249 }
250
251 if Self::ALWAYS_ALLOWED_TOOLS.contains(&tool_name) {
252 return true;
253 }
254
255 if let Some(allowed) = &self.allowed_tools {
257 return allowed.iter().any(|t| t == tool_name);
258 }
259
260 true
262 }
263
264 pub fn effective_temperature(&self, default: f32) -> f32 {
266 self.temperature.unwrap_or(default)
267 }
268
269 pub fn effective_provider<'a>(&'a self, default: &'a str) -> &'a str {
271 self.model_provider.as_deref().unwrap_or(default)
272 }
273
274 pub fn effective_model_name<'a>(&'a self, default: Option<&'a str>) -> Option<&'a str> {
276 self.model_name.as_deref().or(default)
277 }
278}
279
280impl Default for AgentProfile {
281 fn default() -> Self {
282 Self {
283 prompt: None,
284 style: None,
285 temperature: None,
286 model_provider: None,
287 model_name: None,
288 allowed_tools: None,
289 denied_tools: None,
290 memory_k: Self::default_memory_k(),
291 top_p: Self::default_top_p(),
292 max_context_tokens: None,
293 enable_graph: true, graph_memory: true, graph_depth: Self::default_graph_depth(),
296 graph_weight: Self::default_graph_weight(),
297 auto_graph: true, graph_threshold: Self::default_graph_threshold(),
299 graph_steering: true, fast_reasoning: true, fast_model_provider: Some("lmstudio".to_string()), fast_model_name: Some("lmstudio-community/Llama-3.2-3B-Instruct".to_string()),
303 fast_model_temperature: Self::default_fast_temperature(),
304 fast_model_tasks: Self::default_fast_tasks(),
305 escalation_threshold: Self::default_escalation_threshold(),
306 show_reasoning: false, enable_audio_transcription: false, audio_response_mode: Self::default_audio_response_mode(),
309 audio_scenario: None,
310 }
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned tool list from string slices.
    fn tool_list(names: &[&str]) -> Option<Vec<String>> {
        Some(names.iter().map(|name| (*name).to_string()).collect())
    }

    #[test]
    fn test_default_agent_profile() {
        let defaults = AgentProfile::default();

        // Memory and sampling.
        assert_eq!(defaults.memory_k, 10);
        assert_eq!(defaults.top_p, 0.9);

        // Fast-reasoning model.
        assert!(defaults.fast_reasoning);
        assert_eq!(defaults.fast_model_provider.as_deref(), Some("lmstudio"));
        assert_eq!(
            defaults.fast_model_name.as_deref(),
            Some("lmstudio-community/Llama-3.2-3B-Instruct")
        );
        assert_eq!(defaults.fast_model_temperature, 0.3);
        assert_eq!(defaults.escalation_threshold, 0.6);

        // Graph features are all switched on.
        assert!(defaults.enable_graph);
        assert!(defaults.graph_memory);
        assert!(defaults.auto_graph);
        assert!(defaults.graph_steering);

        // The defaults must themselves pass validation.
        assert!(defaults.validate().is_ok());
    }

    #[test]
    fn test_validate_invalid_temperature() {
        let profile = AgentProfile {
            temperature: Some(3.0),
            ..AgentProfile::default()
        };
        assert!(profile.validate().is_err());
    }

    #[test]
    fn test_validate_invalid_top_p() {
        let profile = AgentProfile {
            top_p: 1.5,
            ..AgentProfile::default()
        };
        assert!(profile.validate().is_err());
    }

    #[test]
    fn test_validate_tool_overlap() {
        let profile = AgentProfile {
            allowed_tools: tool_list(&["tool1", "tool2"]),
            denied_tools: tool_list(&["tool2", "tool3"]),
            ..AgentProfile::default()
        };
        assert!(profile.validate().is_err());
    }

    #[test]
    fn test_is_tool_allowed_no_restrictions() {
        let profile = AgentProfile::default();
        for tool in ["any_tool", "prompt_user"] {
            assert!(profile.is_tool_allowed(tool), "{} should be allowed", tool);
        }
    }

    #[test]
    fn test_is_tool_allowed_with_allowlist() {
        let profile = AgentProfile {
            allowed_tools: tool_list(&["tool1", "tool2"]),
            ..AgentProfile::default()
        };

        // Listed tools, plus the always-allowed `prompt_user`, pass.
        for tool in ["tool1", "tool2", "prompt_user"] {
            assert!(profile.is_tool_allowed(tool), "{} should be allowed", tool);
        }
        assert!(!profile.is_tool_allowed("tool3"));
    }

    #[test]
    fn test_is_tool_allowed_with_denylist() {
        let profile = AgentProfile {
            denied_tools: tool_list(&["tool1", "prompt_user"]),
            ..AgentProfile::default()
        };

        assert!(profile.is_tool_allowed("tool2"));
        // The deny-list overrides even the always-allowed tools.
        for tool in ["tool1", "prompt_user"] {
            assert!(!profile.is_tool_allowed(tool), "{} should be denied", tool);
        }
    }

    #[test]
    fn test_effective_temperature() {
        let unset = AgentProfile::default();
        assert_eq!(unset.effective_temperature(0.7), 0.7);

        let overridden = AgentProfile {
            temperature: Some(0.5),
            ..unset
        };
        assert_eq!(overridden.effective_temperature(0.7), 0.5);
    }

    #[test]
    fn test_effective_provider() {
        let unset = AgentProfile::default();
        assert_eq!(unset.effective_provider("mock"), "mock");

        let overridden = AgentProfile {
            model_provider: Some("openai".to_string()),
            ..unset
        };
        assert_eq!(overridden.effective_provider("mock"), "openai");
    }
}