1use crate::tools::{ToolRegistry, parse_tool_calls, format_tool_output};
7use anyhow::Result;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::Arc;
11
/// Configuration for an [`Agent`].
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// Upper bound on the number of tool-execution rounds the agent will run
    /// before refusing to execute further tool calls (until reset).
    pub max_iterations: usize,
    /// Base system prompt; the tool descriptions are appended to it.
    pub system_prompt: String,
}
17
18impl Default for AgentConfig {
19 fn default() -> Self {
20 Self {
21 max_iterations: 40,
22 system_prompt: crate::config::SYSTEM_PROMPT.to_string(),
23 }
24 }
25}
26
27pub struct Agent {
29 config: AgentConfig,
30 tools: ToolRegistry,
31 iteration_count: usize,
32 interrupt_flag: Arc<AtomicBool>,
33}
34
35impl Agent {
36 pub fn new(config: AgentConfig) -> Self {
37 Self {
38 config,
39 tools: ToolRegistry::new(),
40 iteration_count: 0,
41 interrupt_flag: Arc::new(AtomicBool::new(false)),
42 }
43 }
44
45 pub fn get_tool_descriptions(&self) -> String {
47 let mut descriptions = String::from("\n\nYou have access to the following tools:\n\n");
48
49 for tool in self.tools.all_tools() {
50 descriptions.push_str(&format!("**{}**\n", tool.name));
51 descriptions.push_str(&format!("{}\n", tool.description));
52 descriptions.push_str(&format!("Schema: {}\n\n", serde_json::to_string_pretty(&tool.input_schema).unwrap()));
53 }
54
55 descriptions.push_str(r#"To use a tool, wrap your tool call in <tool_use> tags:
56<tool_use>
57{"tool": "bash", "params": {"command": "ls -la"}}
58</tool_use>
59
60Always think step by step and use tools to explore, understand, and solve problems."#);
61
62
63 descriptions
64 }
65
66 pub async fn process_tool_calls(&mut self, response: &str) -> Result<Vec<(String, String)>> {
68 let tool_calls = parse_tool_calls(response);
69 let mut results = Vec::new();
70
71 if self.interrupt_flag.load(Ordering::Relaxed) {
73 println!("\n⏹️ Agent interrupted by user");
74 self.interrupt_flag.store(false, Ordering::Relaxed);
75 return Ok(results);
76 }
77
78 if tool_calls.is_empty() {
79 return Ok(results);
80 }
81
82 self.iteration_count += 1;
83 if self.iteration_count >= self.config.max_iterations {
84 println!("⚠️ Maximum iterations ({}) reached. Stopping tool execution.", self.config.max_iterations);
85 return Ok(results);
86 }
87
88 for tool_call in tool_calls {
89 if self.interrupt_flag.load(Ordering::Relaxed) {
91 println!("\n⏹️ Agent interrupted by user");
92 self.interrupt_flag.store(false, Ordering::Relaxed);
93 break;
94 }
95
96 let colored_tool = match tool_call.tool.as_str() {
97 "read_file" => format!("\x1b[34m{}\x1b[0m", tool_call.tool),
98 "bash" | "exec" => format!("\x1b[38;5;208m{}\x1b[0m", tool_call.tool),
99 _ => tool_call.tool.clone(),
100 };
101 println!("Executing {}: {}",
102 colored_tool,
103 serde_json::to_string(&tool_call.params).unwrap_or_default());
104
105 match self.tools.execute(&tool_call.tool, tool_call.params).await {
106 Ok(output) => {
107 let formatted = format_tool_output(&tool_call.tool, &output);
108 println!("{formatted}");
109 results.push((tool_call.tool, output));
110 }
111 Err(e) => {
112 let error_msg = format!("Error executing {}: {}", tool_call.tool, e);
113 println!("❌ {error_msg}");
114 results.push((tool_call.tool, error_msg));
115 }
116 }
117 }
118
119 Ok(results)
120 }
121
122 pub fn reset(&mut self) {
124 self.iteration_count = 0;
125 }
126
127 pub fn interrupt_flag(&self) -> Arc<AtomicBool> {
129 Arc::clone(&self.interrupt_flag)
130 }
131
132 pub fn get_system_prompt(&self) -> String {
134 format!("{}{}", self.config.system_prompt, self.get_tool_descriptions())
135 }
136}
137
138pub struct AgentSession {
140 agent: Agent,
141 context: HashMap<String, String>,
142}
143
144impl AgentSession {
145 pub fn new(config: AgentConfig) -> Self {
146 Self {
147 agent: Agent::new(config),
148 context: HashMap::new(),
149 }
150 }
151
152 pub fn add_context(&mut self, key: String, value: String) {
154 self.context.insert(key, value);
155 }
156
157 pub fn agent(&mut self) -> &mut Agent {
159 &mut self.agent
160 }
161
162 pub fn context(&self) -> &HashMap<String, String> {
164 &self.context
165 }
166
167 pub fn interrupt_flag(&self) -> Arc<AtomicBool> {
169 self.agent.interrupt_flag()
170 }
171
172 pub fn reset(&mut self) {
174 self.agent.reset();
175 self.context.clear();
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `parse_tool_calls` picks up both a JSON `<tool_use>` block and
    /// the `# <tool> <args>` shorthand lines appearing in the same response.
    // NOTE(review): this test expects `# exec pwd` to be parsed as a "bash" call and
    // `# read_file test.txt` as an "editor" call, whereas `process_tool_calls` treats
    // "exec" and "read_file" as tool names of their own. Confirm against
    // `parse_tool_calls` in crate::tools that these expectations are still current.
    #[test]
    fn test_tool_parsing() {
        let response = r#"Let me check the files:
<tool_use>
{"tool": "bash", "params": {"command": "ls -la"}}
</tool_use>

And also:
# exec pwd
# read_file test.txt"#;

        let tool_calls = parse_tool_calls(response);

        // One call from the <tool_use> block plus one per shorthand line.
        assert_eq!(tool_calls.len(), 3);

        let bash_calls: Vec<_> = tool_calls.iter().filter(|tc| tc.tool == "bash").collect();
        let editor_calls: Vec<_> = tool_calls.iter().filter(|tc| tc.tool == "editor").collect();

        assert_eq!(bash_calls.len(), 2);
        assert_eq!(editor_calls.len(), 1);

        // Both the JSON-form command and the shorthand command must survive parsing.
        assert!(bash_calls.iter().any(|tc|
            tc.params.get("command").and_then(|v| v.as_str()) == Some("ls -la")));
        assert!(bash_calls.iter().any(|tc|
            tc.params.get("command").and_then(|v| v.as_str()) == Some("pwd")));
    }
}