spec_ai_config/config/
agent.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::HashSet;
4use thiserror::Error;
5
/// Errors produced when an [`AgentProfile`] fails validation.
#[derive(Debug, Error)]
pub enum AgentError {
    /// A configuration value is out of range or internally inconsistent;
    /// the payload is a human-readable description of the problem.
    #[error("Invalid agent configuration: {0}")]
    Invalid(String),
}
11
/// Configuration for a specific agent profile.
///
/// Every field is either optional or `serde(default)`-ed so a profile can be
/// deserialized from a sparse config file; call `validate()` after loading to
/// enforce range and consistency constraints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentProfile {
    /// System prompt for this agent
    #[serde(default)]
    pub prompt: Option<String>,

    /// Conversational style or personality
    #[serde(default)]
    pub style: Option<String>,

    /// Temperature override for this agent (0.0 to 2.0); `None` means
    /// "use the application-wide default" (see `effective_temperature`)
    #[serde(default)]
    pub temperature: Option<f32>,

    /// Model provider override (e.g., "openai", "anthropic", "lmstudio");
    /// must be one of the providers accepted by `validate()`
    #[serde(default)]
    pub model_provider: Option<String>,

    /// Model name override (e.g., "gpt-4", "claude-3-opus")
    #[serde(default)]
    pub model_name: Option<String>,

    /// List of tools this agent is allowed to use; `None` means no allow-list
    /// (all tools permitted unless denied)
    #[serde(default)]
    pub allowed_tools: Option<Vec<String>>,

    /// List of tools this agent is forbidden from using; an entry here wins
    /// over `allowed_tools` and over the always-allowed built-ins
    #[serde(default)]
    pub denied_tools: Option<Vec<String>>,

    /// Memory parameters: number of messages to recall (k for top-k)
    #[serde(default = "AgentProfile::default_memory_k")]
    pub memory_k: usize,

    /// Top-p sampling parameter for memory recall (0.0 to 1.0)
    #[serde(default = "AgentProfile::default_top_p")]
    pub top_p: f32,

    /// Maximum context window size for this agent, in tokens
    #[serde(default)]
    pub max_context_tokens: Option<usize>,

    // ========== Knowledge Graph Configuration ==========
    /// Enable knowledge graph features for this agent
    #[serde(default)]
    pub enable_graph: bool,

    /// Use graph-based memory recall (combines with embeddings)
    #[serde(default)]
    pub graph_memory: bool,

    /// Maximum graph traversal depth for context building
    #[serde(default = "AgentProfile::default_graph_depth")]
    pub graph_depth: usize,

    /// Weight for graph-based relevance vs semantic similarity (0.0 to 1.0)
    #[serde(default = "AgentProfile::default_graph_weight")]
    pub graph_weight: f32,

    /// Automatically build graph from conversations
    #[serde(default)]
    pub auto_graph: bool,

    /// Graph-based tool recommendation threshold (0.0 to 1.0)
    #[serde(default = "AgentProfile::default_graph_threshold")]
    pub graph_threshold: f32,

    /// Use graph for decision steering
    #[serde(default)]
    pub graph_steering: bool,

    // ========== Multi-Model Reasoning Configuration ==========
    /// Enable fast reasoning with a smaller model
    #[serde(default)]
    pub fast_reasoning: bool,

    /// Model provider for fast reasoning (e.g., "mlx", "ollama", "lmstudio")
    #[serde(default)]
    pub fast_model_provider: Option<String>,

    /// Model name for fast reasoning (e.g., "lmstudio-community/Llama-3.2-3B-Instruct")
    #[serde(default)]
    pub fast_model_name: Option<String>,

    /// Temperature for fast model (typically lower for consistency)
    #[serde(default = "AgentProfile::default_fast_temperature")]
    pub fast_model_temperature: f32,

    /// Tasks to delegate to fast model
    #[serde(default = "AgentProfile::default_fast_tasks")]
    pub fast_model_tasks: Vec<String>,

    /// Confidence threshold below which work escalates to the main model
    #[serde(default = "AgentProfile::default_escalation_threshold")]
    pub escalation_threshold: f32,

    /// Display reasoning summary to user (requires fast model for summarization)
    #[serde(default)]
    pub show_reasoning: bool,

    // ========== Audio Transcription Configuration ==========
    /// Enable audio transcription for this agent
    #[serde(default)]
    pub enable_audio_transcription: bool,

    /// Audio response mode: "immediate" or "batch"
    /// NOTE(review): `validate()` does not currently check this value — confirm
    /// whether other modes are intentionally tolerated.
    #[serde(default = "AgentProfile::default_audio_response_mode")]
    pub audio_response_mode: String,

    /// Preferred audio transcription scenario for testing
    #[serde(default)]
    pub audio_scenario: Option<String>,
}
126
127impl AgentProfile {
128    const ALWAYS_ALLOWED_TOOLS: [&'static str; 1] = ["prompt_user"];
129    fn default_memory_k() -> usize {
130        10
131    }
132
133    fn default_top_p() -> f32 {
134        0.9
135    }
136
137    fn default_graph_depth() -> usize {
138        3
139    }
140
141    fn default_graph_weight() -> f32 {
142        0.5 // Equal weight to graph and semantic
143    }
144
145    fn default_graph_threshold() -> f32 {
146        0.7 // Recommend tools with >70% relevance
147    }
148
149    fn default_fast_temperature() -> f32 {
150        0.3 // Lower temperature for consistency in fast model
151    }
152
153    fn default_fast_tasks() -> Vec<String> {
154        vec![
155            "entity_extraction".to_string(),
156            "graph_analysis".to_string(),
157            "decision_routing".to_string(),
158            "tool_selection".to_string(),
159            "confidence_scoring".to_string(),
160        ]
161    }
162
163    fn default_escalation_threshold() -> f32 {
164        0.6 // Escalate to main model if confidence < 60%
165    }
166
167    fn default_audio_response_mode() -> String {
168        "immediate".to_string()
169    }
170
171    /// Validate the agent profile configuration
172    pub fn validate(&self) -> Result<()> {
173        // Validate temperature if specified
174        if let Some(temp) = self.temperature {
175            if !(0.0..=2.0).contains(&temp) {
176                return Err(AgentError::Invalid(format!(
177                    "temperature must be between 0.0 and 2.0, got {}",
178                    temp
179                ))
180                .into());
181            }
182        }
183
184        // Validate top_p
185        if self.top_p < 0.0 || self.top_p > 1.0 {
186            return Err(AgentError::Invalid(format!(
187                "top_p must be between 0.0 and 1.0, got {}",
188                self.top_p
189            ))
190            .into());
191        }
192
193        // Validate graph_weight
194        if self.graph_weight < 0.0 || self.graph_weight > 1.0 {
195            return Err(AgentError::Invalid(format!(
196                "graph_weight must be between 0.0 and 1.0, got {}",
197                self.graph_weight
198            ))
199            .into());
200        }
201
202        // Validate graph_threshold
203        if self.graph_threshold < 0.0 || self.graph_threshold > 1.0 {
204            return Err(AgentError::Invalid(format!(
205                "graph_threshold must be between 0.0 and 1.0, got {}",
206                self.graph_threshold
207            ))
208            .into());
209        }
210
211        // Validate that allowed_tools and denied_tools don't overlap
212        if let (Some(allowed), Some(denied)) = (&self.allowed_tools, &self.denied_tools) {
213            let allowed_set: HashSet<_> = allowed.iter().collect();
214            let denied_set: HashSet<_> = denied.iter().collect();
215            let overlap: Vec<_> = allowed_set.intersection(&denied_set).collect();
216
217            if !overlap.is_empty() {
218                return Err(AgentError::Invalid(format!(
219                    "tools cannot be both allowed and denied: {:?}",
220                    overlap
221                ))
222                .into());
223            }
224        }
225
226        // Validate model provider if specified
227        if let Some(provider) = &self.model_provider {
228            let valid_providers = ["mock", "openai", "anthropic", "ollama", "mlx", "lmstudio"];
229            if !valid_providers.contains(&provider.as_str()) {
230                return Err(AgentError::Invalid(format!(
231                    "model_provider must be one of: {}. Got: {}",
232                    valid_providers.join(", "),
233                    provider
234                ))
235                .into());
236            }
237        }
238
239        Ok(())
240    }
241
242    /// Check if a tool is allowed for this agent
243    pub fn is_tool_allowed(&self, tool_name: &str) -> bool {
244        // If denied list exists and contains the tool, deny it
245        if let Some(denied) = &self.denied_tools {
246            if denied.iter().any(|t| t == tool_name) {
247                return false;
248            }
249        }
250
251        if Self::ALWAYS_ALLOWED_TOOLS.contains(&tool_name) {
252            return true;
253        }
254
255        // If allowed list exists, only allow tools in the list
256        if let Some(allowed) = &self.allowed_tools {
257            return allowed.iter().any(|t| t == tool_name);
258        }
259
260        // If no restrictions, allow all tools
261        true
262    }
263
264    /// Get the effective temperature (override or default)
265    pub fn effective_temperature(&self, default: f32) -> f32 {
266        self.temperature.unwrap_or(default)
267    }
268
269    /// Get the effective model provider (override or default)
270    pub fn effective_provider<'a>(&'a self, default: &'a str) -> &'a str {
271        self.model_provider.as_deref().unwrap_or(default)
272    }
273
274    /// Get the effective model name (override or default)
275    pub fn effective_model_name<'a>(&'a self, default: Option<&'a str>) -> Option<&'a str> {
276        self.model_name.as_deref().or(default)
277    }
278}
279
280impl Default for AgentProfile {
281    fn default() -> Self {
282        Self {
283            prompt: None,
284            style: None,
285            temperature: None,
286            model_provider: None,
287            model_name: None,
288            allowed_tools: None,
289            denied_tools: None,
290            memory_k: Self::default_memory_k(),
291            top_p: Self::default_top_p(),
292            max_context_tokens: None,
293            enable_graph: true, // Enable by default
294            graph_memory: true, // Enable by default
295            graph_depth: Self::default_graph_depth(),
296            graph_weight: Self::default_graph_weight(),
297            auto_graph: true, // Enable by default
298            graph_threshold: Self::default_graph_threshold(),
299            graph_steering: true, // Enable by default
300            fast_reasoning: true, // Enable multi-model by default
301            fast_model_provider: Some("lmstudio".to_string()), // Default to LM Studio local server
302            fast_model_name: Some("lmstudio-community/Llama-3.2-3B-Instruct".to_string()),
303            fast_model_temperature: Self::default_fast_temperature(),
304            fast_model_tasks: Self::default_fast_tasks(),
305            escalation_threshold: Self::default_escalation_threshold(),
306            show_reasoning: false,             // Disabled by default
307            enable_audio_transcription: false, // Disabled by default
308            audio_response_mode: Self::default_audio_response_mode(),
309            audio_scenario: None,
310        }
311    }
312}
313
314#[cfg(test)]
315mod tests {
316    use super::*;
317
318    #[test]
319    fn test_default_agent_profile() {
320        let profile = AgentProfile::default();
321        assert_eq!(profile.memory_k, 10);
322        assert_eq!(profile.top_p, 0.9);
323
324        // Verify multi-model is enabled by default
325        assert!(profile.fast_reasoning);
326        assert_eq!(profile.fast_model_provider, Some("lmstudio".to_string()));
327        assert_eq!(
328            profile.fast_model_name,
329            Some("lmstudio-community/Llama-3.2-3B-Instruct".to_string())
330        );
331        assert_eq!(profile.fast_model_temperature, 0.3);
332        assert_eq!(profile.escalation_threshold, 0.6);
333
334        // Verify knowledge graph is enabled by default
335        assert!(profile.enable_graph);
336        assert!(profile.graph_memory);
337        assert!(profile.auto_graph);
338        assert!(profile.graph_steering);
339
340        assert!(profile.validate().is_ok());
341    }
342
343    #[test]
344    fn test_validate_invalid_temperature() {
345        let mut profile = AgentProfile::default();
346        profile.temperature = Some(3.0);
347        assert!(profile.validate().is_err());
348    }
349
350    #[test]
351    fn test_validate_invalid_top_p() {
352        let mut profile = AgentProfile::default();
353        profile.top_p = 1.5;
354        assert!(profile.validate().is_err());
355    }
356
357    #[test]
358    fn test_validate_tool_overlap() {
359        let mut profile = AgentProfile::default();
360        profile.allowed_tools = Some(vec!["tool1".to_string(), "tool2".to_string()]);
361        profile.denied_tools = Some(vec!["tool2".to_string(), "tool3".to_string()]);
362        assert!(profile.validate().is_err());
363    }
364
365    #[test]
366    fn test_is_tool_allowed_no_restrictions() {
367        let profile = AgentProfile::default();
368        assert!(profile.is_tool_allowed("any_tool"));
369        assert!(profile.is_tool_allowed("prompt_user"));
370    }
371
372    #[test]
373    fn test_is_tool_allowed_with_allowlist() {
374        let mut profile = AgentProfile::default();
375        profile.allowed_tools = Some(vec!["tool1".to_string(), "tool2".to_string()]);
376
377        assert!(profile.is_tool_allowed("tool1"));
378        assert!(profile.is_tool_allowed("tool2"));
379        assert!(!profile.is_tool_allowed("tool3"));
380        // prompt_user should remain available even if not explicitly listed
381        assert!(profile.is_tool_allowed("prompt_user"));
382    }
383
384    #[test]
385    fn test_is_tool_allowed_with_denylist() {
386        let mut profile = AgentProfile::default();
387        profile.denied_tools = Some(vec!["tool1".to_string(), "prompt_user".to_string()]);
388
389        assert!(!profile.is_tool_allowed("tool1"));
390        assert!(profile.is_tool_allowed("tool2"));
391        assert!(!profile.is_tool_allowed("prompt_user"));
392    }
393
394    #[test]
395    fn test_effective_temperature() {
396        let mut profile = AgentProfile::default();
397        assert_eq!(profile.effective_temperature(0.7), 0.7);
398
399        profile.temperature = Some(0.5);
400        assert_eq!(profile.effective_temperature(0.7), 0.5);
401    }
402
403    #[test]
404    fn test_effective_provider() {
405        let mut profile = AgentProfile::default();
406        assert_eq!(profile.effective_provider("mock"), "mock");
407
408        profile.model_provider = Some("openai".to_string());
409        assert_eq!(profile.effective_provider("mock"), "openai");
410    }
411}