1use crate::tools::{ToolRegistry, parse_tool_calls, format_tool_output};
7use anyhow::Result;
8use std::collections::HashMap;
9
/// Tunable settings for an [`Agent`].
///
/// Build one field-by-field or start from [`AgentConfig::default`].
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// Iteration limit that `Agent::process_tool_calls` compares its
    /// internal counter against before executing tools.
    pub max_iterations: usize,
    /// Base system prompt; `Agent::get_system_prompt` appends the tool
    /// descriptions to this text.
    pub system_prompt: String,
}
15
16impl Default for AgentConfig {
17 fn default() -> Self {
18 Self {
19 max_iterations: 40,
20 system_prompt: crate::config::SYSTEM_PROMPT.to_string(),
21 }
22 }
23}
24
25pub struct Agent {
27 config: AgentConfig,
28 tools: ToolRegistry,
29 iteration_count: usize,
30}
31
32impl Agent {
33 pub fn new(config: AgentConfig) -> Self {
34 Self {
35 config,
36 tools: ToolRegistry::new(),
37 iteration_count: 0,
38 }
39 }
40
41 pub fn get_tool_descriptions(&self) -> String {
43 let mut descriptions = String::from("\n\nYou have access to the following tools:\n\n");
44
45 for tool in self.tools.all_tools() {
46 descriptions.push_str(&format!("**{}**\n", tool.name));
47 descriptions.push_str(&format!("{}\n", tool.description));
48 descriptions.push_str(&format!("Schema: {}\n\n", serde_json::to_string_pretty(&tool.input_schema).unwrap()));
49 }
50
51 descriptions.push_str(r#"To use a tool, wrap your tool call in <tool_use> tags:
52<tool_use>
53{"tool": "bash", "params": {"command": "ls -la"}}
54</tool_use>
55
56Always think step by step and use tools to explore, understand, and solve problems."#);
57
58
59 descriptions
60 }
61
62 pub async fn process_tool_calls(&mut self, response: &str) -> Result<Vec<(String, String)>> {
64 let tool_calls = parse_tool_calls(response);
65 let mut results = Vec::new();
66
67 if tool_calls.is_empty() {
68 return Ok(results);
69 }
70
71 self.iteration_count += 1;
72 if self.iteration_count >= self.config.max_iterations {
73 println!("⚠️ Maximum iterations ({}) reached. Stopping tool execution.", self.config.max_iterations);
74 return Ok(results);
75 }
76
77 for tool_call in tool_calls {
78 let colored_tool = match tool_call.tool.as_str() {
79 "read_file" => format!("\x1b[34m{}\x1b[0m", tool_call.tool),
80 "bash" | "exec" => format!("\x1b[38;5;208m{}\x1b[0m", tool_call.tool),
81 _ => tool_call.tool.clone(),
82 };
83 println!("Executing {}: {}",
84 colored_tool,
85 serde_json::to_string(&tool_call.params).unwrap_or_default());
86
87 match self.tools.execute(&tool_call.tool, tool_call.params).await {
88 Ok(output) => {
89 let formatted = format_tool_output(&tool_call.tool, &output);
90 println!("{formatted}");
91 results.push((tool_call.tool, output));
92 }
93 Err(e) => {
94 let error_msg = format!("Error executing {}: {}", tool_call.tool, e);
95 println!("❌ {error_msg}");
96 results.push((tool_call.tool, error_msg));
97 }
98 }
99 }
100
101 Ok(results)
102 }
103
104 pub fn reset(&mut self) {
106 self.iteration_count = 0;
107 }
108
109 pub fn get_system_prompt(&self) -> String {
111 format!("{}{}", self.config.system_prompt, self.get_tool_descriptions())
112 }
113}
114
115pub struct AgentSession {
117 agent: Agent,
118 context: HashMap<String, String>,
119}
120
121impl AgentSession {
122 pub fn new(config: AgentConfig) -> Self {
123 Self {
124 agent: Agent::new(config),
125 context: HashMap::new(),
126 }
127 }
128
129 pub fn add_context(&mut self, key: String, value: String) {
131 self.context.insert(key, value);
132 }
133
134 pub fn agent(&mut self) -> &mut Agent {
136 &mut self.agent
137 }
138
139 pub fn context(&self) -> &HashMap<String, String> {
141 &self.context
142 }
143
144 pub fn reset(&mut self) {
146 self.agent.reset();
147 self.context.clear();
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a response that mixes a `<tool_use>` JSON call with two
    /// `#`-prefixed lines and checks the resulting tool names and commands.
    #[test]
    fn test_tool_parsing() {
        let response = r#"Let me check the files:
<tool_use>
{"tool": "bash", "params": {"command": "ls -la"}}
</tool_use>

And also:
# exec pwd
# read_file test.txt"#;

        let parsed = parse_tool_calls(response);

        // Three calls are expected in total.
        assert_eq!(parsed.len(), 3);

        let bash_calls: Vec<_> = parsed.iter().filter(|call| call.tool == "bash").collect();
        let editor_count = parsed.iter().filter(|call| call.tool == "editor").count();

        // The assertions expect two calls reported as `bash` and one as `editor`.
        assert_eq!(bash_calls.len(), 2);
        assert_eq!(editor_count, 1);

        // True when some bash call carries `wanted` as its string `command` parameter.
        let has_bash_command = |wanted: &str| {
            bash_calls
                .iter()
                .any(|call| call.params.get("command").and_then(|v| v.as_str()) == Some(wanted))
        };

        assert!(has_bash_command("ls -la"));
        assert!(has_bash_command("pwd"));
    }
}