1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::HashSet;
4use thiserror::Error;
5
6#[derive(Debug, Error)]
7pub enum AgentError {
8 #[error("Invalid agent configuration: {0}")]
9 Invalid(String),
10}
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct AgentProfile {
15 #[serde(default)]
17 pub prompt: Option<String>,
18
19 #[serde(default)]
21 pub style: Option<String>,
22
23 #[serde(default)]
25 pub temperature: Option<f32>,
26
27 #[serde(default)]
29 pub model_provider: Option<String>,
30
31 #[serde(default)]
33 pub model_name: Option<String>,
34
35 #[serde(default)]
37 pub allowed_tools: Option<Vec<String>>,
38
39 #[serde(default)]
41 pub denied_tools: Option<Vec<String>>,
42
43 #[serde(default = "AgentProfile::default_memory_k")]
45 pub memory_k: usize,
46
47 #[serde(default = "AgentProfile::default_top_p")]
49 pub top_p: f32,
50
51 #[serde(default)]
53 pub max_context_tokens: Option<usize>,
54
55 #[serde(default)]
58 pub enable_graph: bool,
59
60 #[serde(default)]
62 pub graph_memory: bool,
63
64 #[serde(default = "AgentProfile::default_graph_depth")]
66 pub graph_depth: usize,
67
68 #[serde(default = "AgentProfile::default_graph_weight")]
70 pub graph_weight: f32,
71
72 #[serde(default)]
74 pub auto_graph: bool,
75
76 #[serde(default = "AgentProfile::default_graph_threshold")]
78 pub graph_threshold: f32,
79
80 #[serde(default)]
82 pub graph_steering: bool,
83
84 #[serde(default)]
87 pub fast_reasoning: bool,
88
89 #[serde(default)]
91 pub fast_model_provider: Option<String>,
92
93 #[serde(default)]
95 pub fast_model_name: Option<String>,
96
97 #[serde(default = "AgentProfile::default_fast_temperature")]
99 pub fast_model_temperature: f32,
100
101 #[serde(default = "AgentProfile::default_fast_tasks")]
103 pub fast_model_tasks: Vec<String>,
104
105 #[serde(default = "AgentProfile::default_escalation_threshold")]
107 pub escalation_threshold: f32,
108
109 #[serde(default)]
111 pub show_reasoning: bool,
112
113 #[serde(default)]
116 pub enable_audio_transcription: bool,
117
118 #[serde(default = "AgentProfile::default_audio_response_mode")]
120 pub audio_response_mode: String,
121
122 #[serde(default)]
124 pub audio_scenario: Option<String>,
125
126 #[serde(default)]
129 pub enable_collective: bool,
130
131 #[serde(default = "AgentProfile::default_accept_delegations")]
133 pub accept_delegations: bool,
134
135 #[serde(default)]
137 pub preferred_domains: Vec<String>,
138
139 #[serde(default = "AgentProfile::default_max_concurrent_tasks")]
141 pub max_concurrent_tasks: usize,
142
143 #[serde(default = "AgentProfile::default_min_delegation_score")]
145 pub min_delegation_score: f32,
146
147 #[serde(default)]
149 pub share_learnings: bool,
150
151 #[serde(default = "AgentProfile::default_participate_in_voting")]
153 pub participate_in_voting: bool,
154}
155
156impl AgentProfile {
157 const ALWAYS_ALLOWED_TOOLS: [&'static str; 1] = ["prompt_user"];
158 fn default_memory_k() -> usize {
159 10
160 }
161
162 fn default_top_p() -> f32 {
163 0.9
164 }
165
166 fn default_graph_depth() -> usize {
167 3
168 }
169
170 fn default_graph_weight() -> f32 {
171 0.5 }
173
174 fn default_graph_threshold() -> f32 {
175 0.7 }
177
178 fn default_fast_temperature() -> f32 {
179 0.3 }
181
182 fn default_fast_tasks() -> Vec<String> {
183 vec![
184 "entity_extraction".to_string(),
185 "graph_analysis".to_string(),
186 "decision_routing".to_string(),
187 "tool_selection".to_string(),
188 "confidence_scoring".to_string(),
189 ]
190 }
191
192 fn default_escalation_threshold() -> f32 {
193 0.6 }
195
196 fn default_audio_response_mode() -> String {
197 "immediate".to_string()
198 }
199
200 fn default_accept_delegations() -> bool {
201 true
202 }
203
204 fn default_max_concurrent_tasks() -> usize {
205 3
206 }
207
208 fn default_min_delegation_score() -> f32 {
209 0.3
210 }
211
212 fn default_participate_in_voting() -> bool {
213 true
214 }
215
216 pub fn validate(&self) -> Result<()> {
218 if let Some(temp) = self.temperature {
220 if !(0.0..=2.0).contains(&temp) {
221 return Err(AgentError::Invalid(format!(
222 "temperature must be between 0.0 and 2.0, got {}",
223 temp
224 ))
225 .into());
226 }
227 }
228
229 if self.top_p < 0.0 || self.top_p > 1.0 {
231 return Err(AgentError::Invalid(format!(
232 "top_p must be between 0.0 and 1.0, got {}",
233 self.top_p
234 ))
235 .into());
236 }
237
238 if self.graph_weight < 0.0 || self.graph_weight > 1.0 {
240 return Err(AgentError::Invalid(format!(
241 "graph_weight must be between 0.0 and 1.0, got {}",
242 self.graph_weight
243 ))
244 .into());
245 }
246
247 if self.graph_threshold < 0.0 || self.graph_threshold > 1.0 {
249 return Err(AgentError::Invalid(format!(
250 "graph_threshold must be between 0.0 and 1.0, got {}",
251 self.graph_threshold
252 ))
253 .into());
254 }
255
256 if let (Some(allowed), Some(denied)) = (&self.allowed_tools, &self.denied_tools) {
258 let allowed_set: HashSet<_> = allowed.iter().collect();
259 let denied_set: HashSet<_> = denied.iter().collect();
260 let overlap: Vec<_> = allowed_set.intersection(&denied_set).collect();
261
262 if !overlap.is_empty() {
263 return Err(AgentError::Invalid(format!(
264 "tools cannot be both allowed and denied: {:?}",
265 overlap
266 ))
267 .into());
268 }
269 }
270
271 if let Some(provider) = &self.model_provider {
273 let valid_providers = ["mock", "openai", "anthropic", "ollama", "mlx", "lmstudio"];
274 if !valid_providers.contains(&provider.as_str()) {
275 return Err(AgentError::Invalid(format!(
276 "model_provider must be one of: {}. Got: {}",
277 valid_providers.join(", "),
278 provider
279 ))
280 .into());
281 }
282 }
283
284 Ok(())
285 }
286
287 pub fn is_tool_allowed(&self, tool_name: &str) -> bool {
289 if let Some(denied) = &self.denied_tools {
291 if denied.iter().any(|t| t == tool_name) {
292 return false;
293 }
294 }
295
296 if Self::ALWAYS_ALLOWED_TOOLS.contains(&tool_name) {
297 return true;
298 }
299
300 if let Some(allowed) = &self.allowed_tools {
302 return allowed.iter().any(|t| t == tool_name);
303 }
304
305 true
307 }
308
309 pub fn effective_temperature(&self, default: f32) -> f32 {
311 self.temperature.unwrap_or(default)
312 }
313
314 pub fn effective_provider<'a>(&'a self, default: &'a str) -> &'a str {
316 self.model_provider.as_deref().unwrap_or(default)
317 }
318
319 pub fn effective_model_name<'a>(&'a self, default: Option<&'a str>) -> Option<&'a str> {
321 self.model_name.as_deref().or(default)
322 }
323}
324
325impl Default for AgentProfile {
326 fn default() -> Self {
327 Self {
328 prompt: None,
329 style: None,
330 temperature: None,
331 model_provider: None,
332 model_name: None,
333 allowed_tools: None,
334 denied_tools: None,
335 memory_k: Self::default_memory_k(),
336 top_p: Self::default_top_p(),
337 max_context_tokens: None,
338 enable_graph: true, graph_memory: true, graph_depth: Self::default_graph_depth(),
341 graph_weight: Self::default_graph_weight(),
342 auto_graph: true, graph_threshold: Self::default_graph_threshold(),
344 graph_steering: true, fast_reasoning: true, fast_model_provider: Some("lmstudio".to_string()), fast_model_name: Some("lmstudio-community/Llama-3.2-3B-Instruct".to_string()),
348 fast_model_temperature: Self::default_fast_temperature(),
349 fast_model_tasks: Self::default_fast_tasks(),
350 escalation_threshold: Self::default_escalation_threshold(),
351 show_reasoning: false, enable_audio_transcription: false, audio_response_mode: Self::default_audio_response_mode(),
354 audio_scenario: None,
355 enable_collective: false,
357 accept_delegations: Self::default_accept_delegations(),
358 preferred_domains: Vec::new(),
359 max_concurrent_tasks: Self::default_max_concurrent_tasks(),
360 min_delegation_score: Self::default_min_delegation_score(),
361 share_learnings: false, participate_in_voting: Self::default_participate_in_voting(),
363 }
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// The in-code default carries the documented values and validates.
    #[test]
    fn test_default_agent_profile() {
        let profile = AgentProfile::default();
        assert_eq!(profile.memory_k, 10);
        assert_eq!(profile.top_p, 0.9);

        assert!(profile.fast_reasoning);
        assert_eq!(profile.fast_model_provider, Some("lmstudio".to_string()));
        assert_eq!(
            profile.fast_model_name,
            Some("lmstudio-community/Llama-3.2-3B-Instruct".to_string())
        );
        assert_eq!(profile.fast_model_temperature, 0.3);
        assert_eq!(profile.escalation_threshold, 0.6);

        assert!(profile.enable_graph);
        assert!(profile.graph_memory);
        assert!(profile.auto_graph);
        assert!(profile.graph_steering);

        assert!(profile.validate().is_ok());
    }

    #[test]
    fn test_validate_invalid_temperature() {
        let profile = AgentProfile {
            temperature: Some(3.0),
            ..AgentProfile::default()
        };
        assert!(profile.validate().is_err());
    }

    #[test]
    fn test_validate_invalid_top_p() {
        let profile = AgentProfile {
            top_p: 1.5,
            ..AgentProfile::default()
        };
        assert!(profile.validate().is_err());
    }

    #[test]
    fn test_validate_invalid_graph_weight() {
        let profile = AgentProfile {
            graph_weight: 1.5,
            ..AgentProfile::default()
        };
        assert!(profile.validate().is_err());
    }

    #[test]
    fn test_validate_unknown_provider() {
        let unknown = AgentProfile {
            model_provider: Some("not-a-provider".to_string()),
            ..AgentProfile::default()
        };
        assert!(unknown.validate().is_err());

        let known = AgentProfile {
            model_provider: Some("openai".to_string()),
            ..AgentProfile::default()
        };
        assert!(known.validate().is_ok());
    }

    #[test]
    fn test_validate_tool_overlap() {
        let profile = AgentProfile {
            allowed_tools: Some(vec!["tool1".to_string(), "tool2".to_string()]),
            denied_tools: Some(vec!["tool2".to_string(), "tool3".to_string()]),
            ..AgentProfile::default()
        };
        assert!(profile.validate().is_err());
    }

    #[test]
    fn test_is_tool_allowed_no_restrictions() {
        let profile = AgentProfile::default();
        assert!(profile.is_tool_allowed("any_tool"));
        assert!(profile.is_tool_allowed("prompt_user"));
    }

    #[test]
    fn test_is_tool_allowed_with_allowlist() {
        let profile = AgentProfile {
            allowed_tools: Some(vec!["tool1".to_string(), "tool2".to_string()]),
            ..AgentProfile::default()
        };

        assert!(profile.is_tool_allowed("tool1"));
        assert!(profile.is_tool_allowed("tool2"));
        assert!(!profile.is_tool_allowed("tool3"));
        // Always-allowed tools bypass the allowlist.
        assert!(profile.is_tool_allowed("prompt_user"));
    }

    #[test]
    fn test_is_tool_allowed_with_denylist() {
        let profile = AgentProfile {
            denied_tools: Some(vec!["tool1".to_string(), "prompt_user".to_string()]),
            ..AgentProfile::default()
        };

        assert!(!profile.is_tool_allowed("tool1"));
        assert!(profile.is_tool_allowed("tool2"));
        // An explicit deny beats the always-allowed list.
        assert!(!profile.is_tool_allowed("prompt_user"));
    }

    #[test]
    fn test_is_tool_allowed_denylist_beats_allowlist() {
        let profile = AgentProfile {
            allowed_tools: Some(vec!["tool1".to_string()]),
            denied_tools: Some(vec!["tool1".to_string()]),
            ..AgentProfile::default()
        };
        assert!(!profile.is_tool_allowed("tool1"));
    }

    #[test]
    fn test_effective_temperature() {
        let mut profile = AgentProfile::default();
        assert_eq!(profile.effective_temperature(0.7), 0.7);

        profile.temperature = Some(0.5);
        assert_eq!(profile.effective_temperature(0.7), 0.5);
    }

    #[test]
    fn test_effective_provider() {
        let mut profile = AgentProfile::default();
        assert_eq!(profile.effective_provider("mock"), "mock");

        profile.model_provider = Some("openai".to_string());
        assert_eq!(profile.effective_provider("mock"), "openai");
    }

    #[test]
    fn test_effective_model_name() {
        let mut profile = AgentProfile::default();
        assert_eq!(profile.effective_model_name(None), None);
        assert_eq!(profile.effective_model_name(Some("fallback")), Some("fallback"));

        profile.model_name = Some("configured".to_string());
        assert_eq!(profile.effective_model_name(Some("fallback")), Some("configured"));
    }
}