// spec_ai_config/config/agent.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::HashSet;
4use thiserror::Error;
5
/// Errors produced when an [`AgentProfile`] fails validation.
#[derive(Debug, Error)]
pub enum AgentError {
    /// The profile contains an out-of-range or contradictory setting;
    /// the payload names the offending field and value.
    #[error("Invalid agent configuration: {0}")]
    Invalid(String),
}
11
/// Configuration for a specific agent profile.
///
/// Every field is optional in serialized form: missing values fall back to
/// the serde defaults declared per field (`#[serde(default)]` or a
/// `default = "..."` function on this type).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentProfile {
    /// System prompt for this agent
    #[serde(default)]
    pub prompt: Option<String>,

    /// Conversational style or personality
    #[serde(default)]
    pub style: Option<String>,

    /// Temperature override for this agent (0.0 to 2.0, checked by `validate`)
    #[serde(default)]
    pub temperature: Option<f32>,

    /// Model provider override (e.g., "openai", "anthropic", "lmstudio");
    /// must be one of the providers accepted by `validate`
    #[serde(default)]
    pub model_provider: Option<String>,

    /// Model name override (e.g., "gpt-4", "claude-3-opus")
    #[serde(default)]
    pub model_name: Option<String>,

    /// List of tools this agent is allowed to use.
    /// `None` means "no allow-list" (all tools permitted unless denied).
    #[serde(default)]
    pub allowed_tools: Option<Vec<String>>,

    /// List of tools this agent is forbidden from using.
    /// A deny entry wins over the allow-list (see `is_tool_allowed`).
    #[serde(default)]
    pub denied_tools: Option<Vec<String>>,

    /// Memory parameters: number of messages to recall (k for top-k)
    #[serde(default = "AgentProfile::default_memory_k")]
    pub memory_k: usize,

    /// Top-p sampling parameter for memory recall (0.0 to 1.0, checked by `validate`)
    #[serde(default = "AgentProfile::default_top_p")]
    pub top_p: f32,

    /// Maximum context window size for this agent
    #[serde(default)]
    pub max_context_tokens: Option<usize>,

    // ========== Knowledge Graph Configuration ==========
    /// Enable knowledge graph features for this agent
    #[serde(default)]
    pub enable_graph: bool,

    /// Use graph-based memory recall (combines with embeddings)
    #[serde(default)]
    pub graph_memory: bool,

    /// Maximum graph traversal depth for context building
    #[serde(default = "AgentProfile::default_graph_depth")]
    pub graph_depth: usize,

    /// Weight for graph-based relevance vs semantic similarity
    /// (0.0 to 1.0, checked by `validate`)
    #[serde(default = "AgentProfile::default_graph_weight")]
    pub graph_weight: f32,

    /// Automatically build graph from conversations
    #[serde(default)]
    pub auto_graph: bool,

    /// Graph-based tool recommendation threshold (0.0 to 1.0, checked by `validate`)
    #[serde(default = "AgentProfile::default_graph_threshold")]
    pub graph_threshold: f32,

    /// Use graph for decision steering
    #[serde(default)]
    pub graph_steering: bool,

    // ========== Multi-Model Reasoning Configuration ==========
    /// Enable fast reasoning with a smaller model
    #[serde(default)]
    pub fast_reasoning: bool,

    /// Model provider for fast reasoning (e.g., "mlx", "ollama", "lmstudio")
    #[serde(default)]
    pub fast_model_provider: Option<String>,

    /// Model name for fast reasoning (e.g., "lmstudio-community/Llama-3.2-3B-Instruct")
    #[serde(default)]
    pub fast_model_name: Option<String>,

    /// Temperature for fast model (typically lower for consistency)
    #[serde(default = "AgentProfile::default_fast_temperature")]
    pub fast_model_temperature: f32,

    /// Tasks to delegate to fast model
    #[serde(default = "AgentProfile::default_fast_tasks")]
    pub fast_model_tasks: Vec<String>,

    /// Confidence threshold to escalate to main model
    #[serde(default = "AgentProfile::default_escalation_threshold")]
    pub escalation_threshold: f32,

    /// Display reasoning summary to user (requires fast model for summarization)
    #[serde(default)]
    pub show_reasoning: bool,

    // ========== Audio Transcription Configuration ==========
    /// Enable audio transcription for this agent
    #[serde(default)]
    pub enable_audio_transcription: bool,

    /// Audio response mode: "immediate" or "batch"
    #[serde(default = "AgentProfile::default_audio_response_mode")]
    pub audio_response_mode: String,

    /// Preferred audio transcription scenario for testing
    #[serde(default)]
    pub audio_scenario: Option<String>,

    // ========== Collective Intelligence Configuration ==========
    /// Enable collective intelligence features for this agent
    #[serde(default)]
    pub enable_collective: bool,

    /// Allow this agent to accept delegated tasks from peers
    #[serde(default = "AgentProfile::default_accept_delegations")]
    pub accept_delegations: bool,

    /// Domains this agent prefers to specialize in
    #[serde(default)]
    pub preferred_domains: Vec<String>,

    /// Maximum concurrent delegated tasks this agent can handle
    #[serde(default = "AgentProfile::default_max_concurrent_tasks")]
    pub max_concurrent_tasks: usize,

    /// Minimum capability score required to accept a delegated task
    #[serde(default = "AgentProfile::default_min_delegation_score")]
    pub min_delegation_score: f32,

    /// Enable sharing learned strategies with the mesh
    #[serde(default)]
    pub share_learnings: bool,

    /// Participate in collective decision-making (voting)
    #[serde(default = "AgentProfile::default_participate_in_voting")]
    pub participate_in_voting: bool,
}
155
156impl AgentProfile {
157    const ALWAYS_ALLOWED_TOOLS: [&'static str; 1] = ["prompt_user"];
158    fn default_memory_k() -> usize {
159        10
160    }
161
162    fn default_top_p() -> f32 {
163        0.9
164    }
165
166    fn default_graph_depth() -> usize {
167        3
168    }
169
170    fn default_graph_weight() -> f32 {
171        0.5 // Equal weight to graph and semantic
172    }
173
174    fn default_graph_threshold() -> f32 {
175        0.7 // Recommend tools with >70% relevance
176    }
177
178    fn default_fast_temperature() -> f32 {
179        0.3 // Lower temperature for consistency in fast model
180    }
181
182    fn default_fast_tasks() -> Vec<String> {
183        vec![
184            "entity_extraction".to_string(),
185            "graph_analysis".to_string(),
186            "decision_routing".to_string(),
187            "tool_selection".to_string(),
188            "confidence_scoring".to_string(),
189        ]
190    }
191
192    fn default_escalation_threshold() -> f32 {
193        0.6 // Escalate to main model if confidence < 60%
194    }
195
196    fn default_audio_response_mode() -> String {
197        "immediate".to_string()
198    }
199
200    fn default_accept_delegations() -> bool {
201        true
202    }
203
204    fn default_max_concurrent_tasks() -> usize {
205        3
206    }
207
208    fn default_min_delegation_score() -> f32 {
209        0.3
210    }
211
212    fn default_participate_in_voting() -> bool {
213        true
214    }
215
216    /// Validate the agent profile configuration
217    pub fn validate(&self) -> Result<()> {
218        // Validate temperature if specified
219        if let Some(temp) = self.temperature {
220            if !(0.0..=2.0).contains(&temp) {
221                return Err(AgentError::Invalid(format!(
222                    "temperature must be between 0.0 and 2.0, got {}",
223                    temp
224                ))
225                .into());
226            }
227        }
228
229        // Validate top_p
230        if self.top_p < 0.0 || self.top_p > 1.0 {
231            return Err(AgentError::Invalid(format!(
232                "top_p must be between 0.0 and 1.0, got {}",
233                self.top_p
234            ))
235            .into());
236        }
237
238        // Validate graph_weight
239        if self.graph_weight < 0.0 || self.graph_weight > 1.0 {
240            return Err(AgentError::Invalid(format!(
241                "graph_weight must be between 0.0 and 1.0, got {}",
242                self.graph_weight
243            ))
244            .into());
245        }
246
247        // Validate graph_threshold
248        if self.graph_threshold < 0.0 || self.graph_threshold > 1.0 {
249            return Err(AgentError::Invalid(format!(
250                "graph_threshold must be between 0.0 and 1.0, got {}",
251                self.graph_threshold
252            ))
253            .into());
254        }
255
256        // Validate that allowed_tools and denied_tools don't overlap
257        if let (Some(allowed), Some(denied)) = (&self.allowed_tools, &self.denied_tools) {
258            let allowed_set: HashSet<_> = allowed.iter().collect();
259            let denied_set: HashSet<_> = denied.iter().collect();
260            let overlap: Vec<_> = allowed_set.intersection(&denied_set).collect();
261
262            if !overlap.is_empty() {
263                return Err(AgentError::Invalid(format!(
264                    "tools cannot be both allowed and denied: {:?}",
265                    overlap
266                ))
267                .into());
268            }
269        }
270
271        // Validate model provider if specified
272        if let Some(provider) = &self.model_provider {
273            let valid_providers = ["mock", "openai", "anthropic", "ollama", "mlx", "lmstudio"];
274            if !valid_providers.contains(&provider.as_str()) {
275                return Err(AgentError::Invalid(format!(
276                    "model_provider must be one of: {}. Got: {}",
277                    valid_providers.join(", "),
278                    provider
279                ))
280                .into());
281            }
282        }
283
284        Ok(())
285    }
286
287    /// Check if a tool is allowed for this agent
288    pub fn is_tool_allowed(&self, tool_name: &str) -> bool {
289        // If denied list exists and contains the tool, deny it
290        if let Some(denied) = &self.denied_tools {
291            if denied.iter().any(|t| t == tool_name) {
292                return false;
293            }
294        }
295
296        if Self::ALWAYS_ALLOWED_TOOLS.contains(&tool_name) {
297            return true;
298        }
299
300        // If allowed list exists, only allow tools in the list
301        if let Some(allowed) = &self.allowed_tools {
302            return allowed.iter().any(|t| t == tool_name);
303        }
304
305        // If no restrictions, allow all tools
306        true
307    }
308
309    /// Get the effective temperature (override or default)
310    pub fn effective_temperature(&self, default: f32) -> f32 {
311        self.temperature.unwrap_or(default)
312    }
313
314    /// Get the effective model provider (override or default)
315    pub fn effective_provider<'a>(&'a self, default: &'a str) -> &'a str {
316        self.model_provider.as_deref().unwrap_or(default)
317    }
318
319    /// Get the effective model name (override or default)
320    pub fn effective_model_name<'a>(&'a self, default: Option<&'a str>) -> Option<&'a str> {
321        self.model_name.as_deref().or(default)
322    }
323}
324
325impl Default for AgentProfile {
326    fn default() -> Self {
327        Self {
328            prompt: None,
329            style: None,
330            temperature: None,
331            model_provider: None,
332            model_name: None,
333            allowed_tools: None,
334            denied_tools: None,
335            memory_k: Self::default_memory_k(),
336            top_p: Self::default_top_p(),
337            max_context_tokens: None,
338            enable_graph: true, // Enable by default
339            graph_memory: true, // Enable by default
340            graph_depth: Self::default_graph_depth(),
341            graph_weight: Self::default_graph_weight(),
342            auto_graph: true, // Enable by default
343            graph_threshold: Self::default_graph_threshold(),
344            graph_steering: true, // Enable by default
345            fast_reasoning: true, // Enable multi-model by default
346            fast_model_provider: Some("lmstudio".to_string()), // Default to LM Studio local server
347            fast_model_name: Some("lmstudio-community/Llama-3.2-3B-Instruct".to_string()),
348            fast_model_temperature: Self::default_fast_temperature(),
349            fast_model_tasks: Self::default_fast_tasks(),
350            escalation_threshold: Self::default_escalation_threshold(),
351            show_reasoning: false,             // Disabled by default
352            enable_audio_transcription: false, // Disabled by default
353            audio_response_mode: Self::default_audio_response_mode(),
354            audio_scenario: None,
355            // Collective intelligence - disabled by default until explicitly enabled
356            enable_collective: false,
357            accept_delegations: Self::default_accept_delegations(),
358            preferred_domains: Vec::new(),
359            max_concurrent_tasks: Self::default_max_concurrent_tasks(),
360            min_delegation_score: Self::default_min_delegation_score(),
361            share_learnings: false, // Disabled by default
362            participate_in_voting: Self::default_participate_in_voting(),
363        }
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_agent_profile() {
        let profile = AgentProfile::default();

        // Memory defaults.
        assert_eq!(profile.memory_k, 10);
        assert_eq!(profile.top_p, 0.9);

        // Multi-model reasoning is on by default, pointed at LM Studio.
        assert!(profile.fast_reasoning);
        assert_eq!(profile.fast_model_provider.as_deref(), Some("lmstudio"));
        assert_eq!(
            profile.fast_model_name.as_deref(),
            Some("lmstudio-community/Llama-3.2-3B-Instruct")
        );
        assert_eq!(profile.fast_model_temperature, 0.3);
        assert_eq!(profile.escalation_threshold, 0.6);

        // Knowledge-graph features are on by default.
        assert!(profile.enable_graph);
        assert!(profile.graph_memory);
        assert!(profile.auto_graph);
        assert!(profile.graph_steering);

        // And the default profile passes its own validation.
        assert!(profile.validate().is_ok());
    }

    #[test]
    fn test_validate_invalid_temperature() {
        let profile = AgentProfile {
            temperature: Some(3.0),
            ..AgentProfile::default()
        };
        assert!(profile.validate().is_err());
    }

    #[test]
    fn test_validate_invalid_top_p() {
        let profile = AgentProfile {
            top_p: 1.5,
            ..AgentProfile::default()
        };
        assert!(profile.validate().is_err());
    }

    #[test]
    fn test_validate_tool_overlap() {
        // "tool2" appears in both lists, which must be rejected.
        let profile = AgentProfile {
            allowed_tools: Some(vec!["tool1".to_string(), "tool2".to_string()]),
            denied_tools: Some(vec!["tool2".to_string(), "tool3".to_string()]),
            ..AgentProfile::default()
        };
        assert!(profile.validate().is_err());
    }

    #[test]
    fn test_is_tool_allowed_no_restrictions() {
        let profile = AgentProfile::default();
        assert!(profile.is_tool_allowed("any_tool"));
        assert!(profile.is_tool_allowed("prompt_user"));
    }

    #[test]
    fn test_is_tool_allowed_with_allowlist() {
        let profile = AgentProfile {
            allowed_tools: Some(vec!["tool1".to_string(), "tool2".to_string()]),
            ..AgentProfile::default()
        };

        assert!(profile.is_tool_allowed("tool1"));
        assert!(profile.is_tool_allowed("tool2"));
        assert!(!profile.is_tool_allowed("tool3"));
        // prompt_user should remain available even if not explicitly listed
        assert!(profile.is_tool_allowed("prompt_user"));
    }

    #[test]
    fn test_is_tool_allowed_with_denylist() {
        let profile = AgentProfile {
            denied_tools: Some(vec!["tool1".to_string(), "prompt_user".to_string()]),
            ..AgentProfile::default()
        };

        assert!(!profile.is_tool_allowed("tool1"));
        assert!(profile.is_tool_allowed("tool2"));
        // An explicit deny overrides the always-allowed list.
        assert!(!profile.is_tool_allowed("prompt_user"));
    }

    #[test]
    fn test_effective_temperature() {
        let mut profile = AgentProfile::default();
        assert_eq!(profile.effective_temperature(0.7), 0.7);

        profile.temperature = Some(0.5);
        assert_eq!(profile.effective_temperature(0.7), 0.5);
    }

    #[test]
    fn test_effective_provider() {
        let mut profile = AgentProfile::default();
        assert_eq!(profile.effective_provider("mock"), "mock");

        profile.model_provider = Some("openai".to_string());
        assert_eq!(profile.effective_provider("mock"), "openai");
    }
}