toast_api/
agent.rs

1//! Agent implementation for toast
2//! 
3//! This module provides an agent that can use tools to accomplish tasks,
4//! similar to the DGM paper's approach but adapted for toast's architecture.
5
6use crate::tools::{ToolRegistry, parse_tool_calls, format_tool_output};
7use anyhow::Result;
8use std::collections::HashMap;
9
/// Configuration for an [`Agent`].
///
/// Derives `Debug` and `Clone` so configurations can be logged and reused
/// across several agents or sessions.
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// Upper bound on the number of tool-calling iterations (responses that
    /// contain tool calls) an agent processes before it stops executing tools.
    pub max_iterations: usize,
    /// Base system prompt; `Agent::get_system_prompt` appends the tool
    /// descriptions to it.
    pub system_prompt: String,
}
15
16impl Default for AgentConfig {
17    fn default() -> Self {
18        Self {
19            max_iterations: 40,
20            system_prompt: crate::config::SYSTEM_PROMPT.to_string(),
21        }
22    }
23}
24
25/// Agent that can use tools to accomplish tasks
26pub struct Agent {
27    config: AgentConfig,
28    tools: ToolRegistry,
29    iteration_count: usize,
30}
31
32impl Agent {
33    pub fn new(config: AgentConfig) -> Self {
34        Self {
35            config,
36            tools: ToolRegistry::new(),
37            iteration_count: 0,
38        }
39    }
40
41    /// Get tool descriptions for the system prompt
42    pub fn get_tool_descriptions(&self) -> String {
43        let mut descriptions = String::from("\n\nYou have access to the following tools:\n\n");
44        
45        for tool in self.tools.all_tools() {
46            descriptions.push_str(&format!("**{}**\n", tool.name));
47            descriptions.push_str(&format!("{}\n", tool.description));
48            descriptions.push_str(&format!("Schema: {}\n\n", serde_json::to_string_pretty(&tool.input_schema).unwrap()));
49        }
50
51        descriptions.push_str(r#"To use a tool, wrap your tool call in <tool_use> tags:
52<tool_use>
53{"tool": "bash", "params": {"command": "ls -la"}}
54</tool_use>
55
56Always think step by step and use tools to explore, understand, and solve problems."#);
57
58
59        descriptions
60    }
61
62    /// Process a response and execute any tool calls
63    pub async fn process_tool_calls(&mut self, response: &str) -> Result<Vec<(String, String)>> {
64        let tool_calls = parse_tool_calls(response);
65        let mut results = Vec::new();
66
67        if tool_calls.is_empty() {
68            return Ok(results);
69        }
70
71        self.iteration_count += 1;
72        if self.iteration_count >= self.config.max_iterations {
73            println!("⚠️  Maximum iterations ({}) reached. Stopping tool execution.", self.config.max_iterations);
74            return Ok(results);
75        }
76
77        for tool_call in tool_calls {
78            let colored_tool = match tool_call.tool.as_str() {
79                "read_file" => format!("\x1b[34m{}\x1b[0m", tool_call.tool),
80                "bash" | "exec" => format!("\x1b[38;5;208m{}\x1b[0m", tool_call.tool),
81                _ => tool_call.tool.clone(),
82            };
83            println!("Executing {}: {}", 
84                     colored_tool, 
85                     serde_json::to_string(&tool_call.params).unwrap_or_default());
86
87            match self.tools.execute(&tool_call.tool, tool_call.params).await {
88                Ok(output) => {
89                    let formatted = format_tool_output(&tool_call.tool, &output);
90                    println!("{formatted}");
91                    results.push((tool_call.tool, output));
92                }
93                Err(e) => {
94                    let error_msg = format!("Error executing {}: {}", tool_call.tool, e);
95                    println!("❌ {error_msg}");
96                    results.push((tool_call.tool, error_msg));
97                }
98            }
99        }
100
101        Ok(results)
102    }
103
104    /// Reset iteration count for a new task
105    pub fn reset(&mut self) {
106        self.iteration_count = 0;
107    }
108
109    /// Get enhanced system prompt with tool descriptions
110    pub fn get_system_prompt(&self) -> String {
111        format!("{}{}", self.config.system_prompt, self.get_tool_descriptions())
112    }
113}
114
115/// Agent session that maintains context across multiple interactions
116pub struct AgentSession {
117    agent: Agent,
118    context: HashMap<String, String>,
119}
120
121impl AgentSession {
122    pub fn new(config: AgentConfig) -> Self {
123        Self {
124            agent: Agent::new(config),
125            context: HashMap::new(),
126        }
127    }
128
129    /// Add context information
130    pub fn add_context(&mut self, key: String, value: String) {
131        self.context.insert(key, value);
132    }
133
134    /// Get the agent
135    pub fn agent(&mut self) -> &mut Agent {
136        &mut self.agent
137    }
138
139    /// Get context
140    pub fn context(&self) -> &HashMap<String, String> {
141        &self.context
142    }
143
144    /// Reset for a new task
145    pub fn reset(&mut self) {
146        self.agent.reset();
147        self.context.clear();
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_parsing() {
        // One JSON `<tool_use>` call plus two `# <tool> <arg>` shorthand lines.
        // The shorthand lines deliberately start at column 0.
        let response = r#"Let me check the files:
<tool_use>
{"tool": "bash", "params": {"command": "ls -la"}}
</tool_use>

And also:
# exec pwd
# read_file test.txt"#;

        let calls = parse_tool_calls(response);
        assert_eq!(calls.len(), 3, "expected one JSON call and two shorthand calls");

        // `exec` is expected to map onto `bash`, and `read_file` onto `editor`.
        let bash: Vec<_> = calls.iter().filter(|call| call.tool == "bash").collect();
        let editor_count = calls.iter().filter(|call| call.tool == "editor").count();
        assert_eq!(bash.len(), 2);
        assert_eq!(editor_count, 1);

        // Both bash commands must survive parsing verbatim.
        let commands: Vec<&str> = bash
            .iter()
            .filter_map(|call| call.params.get("command").and_then(|v| v.as_str()))
            .collect();
        assert!(commands.contains(&"ls -la"));
        assert!(commands.contains(&"pwd"));
    }
}