1use tracing::{debug, error};
7
8use saorsa_ai::{
9 CompletionRequest, ContentBlock, ContentDelta, Message, StopReason, StreamEvent,
10 StreamingProvider,
11};
12
13use crate::config::AgentConfig;
14use crate::error::{Result, SaorsaAgentError};
15use crate::event::{AgentEvent, EventSender, TurnEndReason};
16use crate::tool::ToolRegistry;
17
18pub struct AgentLoop {
20 provider: Box<dyn StreamingProvider>,
22 config: AgentConfig,
24 tools: ToolRegistry,
26 event_tx: EventSender,
28 messages: Vec<Message>,
30}
31
32impl AgentLoop {
33 pub fn new(
35 provider: Box<dyn StreamingProvider>,
36 config: AgentConfig,
37 tools: ToolRegistry,
38 event_tx: EventSender,
39 ) -> Self {
40 Self {
41 provider,
42 config,
43 tools,
44 event_tx,
45 messages: Vec::new(),
46 }
47 }
48
49 pub async fn run(&mut self, user_message: &str) -> Result<String> {
53 self.messages.push(Message::user(user_message));
54
55 let mut turn = 0u32;
56 let mut final_text = String::new();
57
58 loop {
59 turn += 1;
60
61 if turn > self.config.max_turns {
62 debug!(turn, max = self.config.max_turns, "Max turns reached");
63 let _ = self
64 .event_tx
65 .send(AgentEvent::TurnEnd {
66 turn,
67 reason: TurnEndReason::MaxTurns,
68 })
69 .await;
70 break;
71 }
72
73 let _ = self.event_tx.send(AgentEvent::TurnStart { turn }).await;
74
75 let request = CompletionRequest::new(
76 &self.config.model,
77 self.messages.clone(),
78 self.config.max_tokens,
79 )
80 .system(&self.config.system_prompt)
81 .tools(self.tools.definitions());
82
83 let mut rx = self.provider.stream(request).await?;
85
86 let mut text_content = String::new();
87 let mut tool_calls: Vec<ToolCallInfo> = Vec::new();
88 let mut stop_reason = None;
89
90 while let Some(event) = rx.recv().await {
91 match event {
92 Ok(StreamEvent::ContentBlockStart {
93 content_block: ContentBlock::ToolUse { id, name, .. },
94 ..
95 }) => {
96 tool_calls.push(ToolCallInfo {
97 id,
98 name,
99 input_json: String::new(),
100 });
101 }
102 Ok(StreamEvent::ContentBlockDelta {
103 delta: ContentDelta::TextDelta { text },
104 ..
105 }) => {
106 text_content.push_str(&text);
107 let _ = self.event_tx.send(AgentEvent::TextDelta { text }).await;
108 }
109 Ok(StreamEvent::ContentBlockDelta {
110 delta: ContentDelta::InputJsonDelta { partial_json },
111 ..
112 }) => {
113 if let Some(tc) = tool_calls.last_mut() {
114 tc.input_json.push_str(&partial_json);
115 }
116 }
117 Ok(StreamEvent::MessageDelta {
118 stop_reason: sr, ..
119 }) => {
120 stop_reason = sr;
121 }
122 Ok(StreamEvent::Error { message }) => {
123 error!(message = %message, "Stream error");
124 let _ = self
125 .event_tx
126 .send(AgentEvent::Error {
127 message: message.clone(),
128 })
129 .await;
130 return Err(SaorsaAgentError::Internal(message));
131 }
132 _ => {}
133 }
134 }
135
136 if !text_content.is_empty() {
138 final_text.clone_from(&text_content);
139 let _ = self
140 .event_tx
141 .send(AgentEvent::TextComplete {
142 text: text_content.clone(),
143 })
144 .await;
145 }
146
147 let mut assistant_content: Vec<ContentBlock> = Vec::new();
149 if !text_content.is_empty() {
150 assistant_content.push(ContentBlock::Text { text: text_content });
151 }
152
153 for tc in &tool_calls {
155 let input: serde_json::Value = serde_json::from_str(&tc.input_json)
156 .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
157
158 let _ = self
159 .event_tx
160 .send(AgentEvent::ToolCall {
161 id: tc.id.clone(),
162 name: tc.name.clone(),
163 input: input.clone(),
164 })
165 .await;
166
167 assistant_content.push(ContentBlock::ToolUse {
168 id: tc.id.clone(),
169 name: tc.name.clone(),
170 input,
171 });
172 }
173
174 self.messages.push(Message {
175 role: saorsa_ai::Role::Assistant,
176 content: assistant_content,
177 });
178
179 match stop_reason {
181 Some(StopReason::ToolUse) if !tool_calls.is_empty() => {
182 let tool_results = self.execute_tool_calls(&tool_calls).await;
183
184 for result in &tool_results {
185 self.messages
186 .push(Message::tool_result(&result.id, &result.output));
187 }
188
189 let _ = self
190 .event_tx
191 .send(AgentEvent::TurnEnd {
192 turn,
193 reason: TurnEndReason::ToolUse,
194 })
195 .await;
196
197 }
199 Some(StopReason::MaxTokens) => {
200 let _ = self
201 .event_tx
202 .send(AgentEvent::TurnEnd {
203 turn,
204 reason: TurnEndReason::MaxTokens,
205 })
206 .await;
207 break;
208 }
209 _ => {
210 let _ = self
212 .event_tx
213 .send(AgentEvent::TurnEnd {
214 turn,
215 reason: TurnEndReason::EndTurn,
216 })
217 .await;
218 break;
219 }
220 }
221 }
222
223 Ok(final_text)
224 }
225
226 async fn execute_tool_calls(&self, tool_calls: &[ToolCallInfo]) -> Vec<ToolResultInfo> {
228 let mut results = Vec::new();
229
230 for tc in tool_calls {
231 let input: serde_json::Value = serde_json::from_str(&tc.input_json)
232 .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
233
234 let (output, success) = match self.tools.get(&tc.name) {
235 Some(tool) => match tool.execute(input).await {
236 Ok(result) => (result, true),
237 Err(e) => (format!("Error: {e}"), false),
238 },
239 None => (format!("Unknown tool: {}", tc.name), false),
240 };
241
242 let _ = self
243 .event_tx
244 .send(AgentEvent::ToolResult {
245 id: tc.id.clone(),
246 name: tc.name.clone(),
247 output: output.clone(),
248 success,
249 })
250 .await;
251
252 results.push(ToolResultInfo {
253 id: tc.id.clone(),
254 output,
255 });
256 }
257
258 results
259 }
260
261 pub fn messages(&self) -> &[Message] {
263 &self.messages
264 }
265}
266
/// A tool invocation requested by the model, assembled from streamed events.
#[derive(Debug)]
struct ToolCallInfo {
    /// Provider-assigned id that pairs this call with its result.
    id: String,
    /// Name of the tool to look up in the registry.
    name: String,
    /// Raw argument text, concatenated from the streamed JSON fragments.
    input_json: String,
}
277
/// Outcome of executing one tool call, ready to be sent back to the model.
#[derive(Debug)]
struct ToolResultInfo {
    /// Id of the tool call this result answers.
    id: String,
    /// Tool output, or a description of why execution failed.
    output: String,
}
286
287pub fn default_tools(working_dir: impl Into<std::path::PathBuf>) -> ToolRegistry {
298 use crate::tools::{BashTool, EditTool, FindTool, GrepTool, LsTool, ReadTool, WriteTool};
299 use std::path::PathBuf;
300
301 let wd: PathBuf = working_dir.into();
302 let mut registry = ToolRegistry::new();
303
304 registry.register(Box::new(BashTool::new(wd.clone())));
305 registry.register(Box::new(ReadTool::new(wd.clone())));
306 registry.register(Box::new(WriteTool::new(wd.clone())));
307 registry.register(Box::new(EditTool::new(wd.clone())));
308 registry.register(Box::new(GrepTool::new(wd.clone())));
309 registry.register(Box::new(FindTool::new(wd.clone())));
310 registry.register(Box::new(LsTool::new(wd)));
311
312 registry
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::event_channel;

    /// Provider double that replays a fixed sequence of stream events for every request.
    struct MockProvider {
        events: Vec<StreamEvent>,
    }

    #[async_trait::async_trait]
    impl saorsa_ai::Provider for MockProvider {
        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> saorsa_ai::Result<saorsa_ai::CompletionResponse> {
            Err(saorsa_ai::SaorsaAiError::Internal("not implemented".into()))
        }
    }

    #[async_trait::async_trait]
    impl StreamingProvider for MockProvider {
        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> saorsa_ai::Result<tokio::sync::mpsc::Receiver<saorsa_ai::Result<StreamEvent>>>
        {
            let (tx, rx) = tokio::sync::mpsc::channel(64);
            let events = self.events.clone();
            tokio::spawn(async move {
                for event in events {
                    if tx.send(Ok(event)).await.is_err() {
                        break;
                    }
                }
            });
            Ok(rx)
        }
    }

    /// Boxes a [`MockProvider`] that replays `events`.
    fn mock_provider(events: Vec<StreamEvent>) -> Box<dyn StreamingProvider> {
        Box::new(MockProvider { events })
    }

    /// A provider whose turn streams `text` as a single delta and then ends normally.
    fn mock_text_provider(text: &str) -> Box<dyn StreamingProvider> {
        mock_provider(vec![
            StreamEvent::MessageStart {
                id: "msg_1".into(),
                model: "test".into(),
                usage: saorsa_ai::Usage::default(),
            },
            StreamEvent::ContentBlockStart {
                index: 0,
                content_block: ContentBlock::Text {
                    text: String::new(),
                },
            },
            StreamEvent::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::TextDelta {
                    text: text.to_string(),
                },
            },
            StreamEvent::ContentBlockStop { index: 0 },
            StreamEvent::MessageDelta {
                stop_reason: Some(StopReason::EndTurn),
                usage: saorsa_ai::Usage::default(),
            },
            StreamEvent::MessageStop,
        ])
    }

    #[tokio::test]
    async fn agent_simple_text_response() {
        let provider = mock_text_provider("Hello, world!");
        let config = AgentConfig::default();
        let tools = ToolRegistry::new();
        let (tx, mut rx) = event_channel(64);

        let mut agent = AgentLoop::new(provider, config, tools, tx);
        let handle = tokio::spawn(async move { agent.run("Hi").await });

        // The channel closes once the spawned task finishes and drops the agent's sender.
        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        // Match the nested result explicitly so an `Err` from `run` fails the test
        // instead of silently skipping the text check.
        let result = handle.await;
        assert!(matches!(&result, Ok(Ok(text)) if text.as_str() == "Hello, world!"));

        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::TurnStart { turn: 1 }))
        );
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::TextDelta { .. }))
        );
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::TextComplete { .. }))
        );
        assert!(events.iter().any(|e| matches!(
            e,
            AgentEvent::TurnEnd {
                reason: TurnEndReason::EndTurn,
                ..
            }
        )));
    }

    #[tokio::test]
    async fn agent_max_turns_limit() {
        let provider = mock_text_provider("response");
        let config = AgentConfig::default().max_turns(0);
        let tools = ToolRegistry::new();
        let (tx, _rx) = event_channel(64);

        let mut agent = AgentLoop::new(provider, config, tools, tx);
        let result = agent.run("Hi").await;

        // With a zero-turn budget the provider is never consulted: no text is produced
        // and only the user message is recorded.
        assert!(matches!(&result, Ok(text) if text.is_empty()));
        assert_eq!(agent.messages().len(), 1);
    }

    #[tokio::test]
    async fn agent_tracks_messages() {
        let provider = mock_text_provider("response");
        let config = AgentConfig::default();
        let tools = ToolRegistry::new();
        let (tx, _rx) = event_channel(64);

        let mut agent = AgentLoop::new(provider, config, tools, tx);
        assert!(agent.run("Hello").await.is_ok());

        // One user message followed by the assistant reply.
        let msgs = agent.messages();
        assert_eq!(msgs.len(), 2);
        assert!(matches!(
            msgs.last().map(|m| &m.role),
            Some(saorsa_ai::Role::Assistant)
        ));
    }

    #[tokio::test]
    async fn agent_stream_error_is_reported() {
        let provider = mock_provider(vec![StreamEvent::Error {
            message: "boom".into(),
        }]);
        let config = AgentConfig::default();
        let tools = ToolRegistry::new();
        let (tx, mut rx) = event_channel(64);

        let mut agent = AgentLoop::new(provider, config, tools, tx);
        let handle = tokio::spawn(async move { agent.run("Hi").await });

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        // A provider-reported error aborts the run before the turn can end normally.
        let result = handle.await;
        assert!(matches!(result, Ok(Err(_))));
        assert!(events.iter().any(|e| matches!(e, AgentEvent::Error { .. })));
        assert!(
            !events
                .iter()
                .any(|e| matches!(e, AgentEvent::TurnEnd { .. }))
        );
    }

    #[test]
    fn default_tools_registers_all() {
        let cwd = std::env::current_dir();
        assert!(cwd.is_ok());
        let Ok(dir) = cwd else { unreachable!() };
        let registry = super::default_tools(dir);

        assert_eq!(registry.len(), 7);

        let names = registry.names();
        for expected in ["bash", "read", "write", "edit", "grep", "find", "ls"] {
            assert!(names.contains(&expected), "missing tool: {expected}");
        }
    }
}