1use tracing::{debug, error, warn};
7
8use saorsa_ai::{
9 CompletionRequest, ContentBlock, ContentDelta, Message, StopReason, StreamEvent,
10 StreamingProvider,
11};
12
13use crate::config::AgentConfig;
14use crate::error::{Result, SaorsaAgentError};
15use crate::event::{AgentEvent, EventSender, TurnEndReason};
16use crate::tool::ToolRegistry;
17
/// Drives a multi-turn conversation between a streaming LLM provider and a set of tools.
///
/// Each turn sends the accumulated `messages` to `provider`, forwards progress as
/// [`AgentEvent`]s on `event_tx`, and executes any tools the model requests before
/// starting the next turn.
pub struct AgentLoop {
    /// Streaming backend queried once per turn.
    provider: Box<dyn StreamingProvider>,
    /// Supplies the model name, token and turn limits, and the system prompt.
    config: AgentConfig,
    /// Tools advertised to the model and executed when it stops for tool use.
    tools: ToolRegistry,
    /// Destination for progress events; send results are discarded by the loop.
    event_tx: EventSender,
    /// Conversation history: user, assistant, and tool-result messages in order.
    messages: Vec<Message>,
}
31
32impl AgentLoop {
33 pub fn new(
35 provider: Box<dyn StreamingProvider>,
36 config: AgentConfig,
37 tools: ToolRegistry,
38 event_tx: EventSender,
39 ) -> Self {
40 Self {
41 provider,
42 config,
43 tools,
44 event_tx,
45 messages: Vec::new(),
46 }
47 }
48
49 pub async fn run(&mut self, user_message: &str) -> Result<String> {
53 self.messages.push(Message::user(user_message));
54
55 let mut turn = 0u32;
56 let mut final_text = String::new();
57
58 loop {
59 turn += 1;
60
61 if turn > self.config.max_turns {
62 debug!(turn, max = self.config.max_turns, "Max turns reached");
63 let _ = self
64 .event_tx
65 .send(AgentEvent::TurnEnd {
66 turn,
67 reason: TurnEndReason::MaxTurns,
68 })
69 .await;
70 break;
71 }
72
73 let _ = self.event_tx.send(AgentEvent::TurnStart { turn }).await;
74
75 let request = CompletionRequest::new(
76 &self.config.model,
77 self.messages.clone(),
78 self.config.max_tokens,
79 )
80 .system(&self.config.system_prompt)
81 .tools(self.tools.definitions());
82
83 let mut rx = self.provider.stream(request).await?;
85
86 let mut text_content = String::new();
87 let mut tool_calls: Vec<ToolCallInfo> = Vec::new();
88 let mut stop_reason = None;
89
90 while let Some(event) = rx.recv().await {
91 match event {
92 Ok(StreamEvent::ContentBlockStart {
93 content_block: ContentBlock::ToolUse { id, name, .. },
94 ..
95 }) => {
96 tool_calls.push(ToolCallInfo {
97 id,
98 name,
99 input_json: String::new(),
100 });
101 }
102 Ok(StreamEvent::ContentBlockDelta {
103 delta: ContentDelta::TextDelta { text },
104 ..
105 }) => {
106 text_content.push_str(&text);
107 let _ = self.event_tx.send(AgentEvent::TextDelta { text }).await;
108 }
109 Ok(StreamEvent::ContentBlockDelta {
110 delta: ContentDelta::InputJsonDelta { partial_json },
111 ..
112 }) => {
113 if let Some(tc) = tool_calls.last_mut() {
114 tc.input_json.push_str(&partial_json);
115 }
116 }
117 Ok(StreamEvent::ContentBlockDelta {
118 delta: ContentDelta::ThinkingDelta { text },
119 ..
120 }) => {
121 let _ = self.event_tx.send(AgentEvent::ThinkingDelta { text }).await;
122 }
123 Ok(StreamEvent::MessageDelta {
124 stop_reason: sr, ..
125 }) => {
126 stop_reason = sr;
127 }
128 Ok(StreamEvent::Error { message }) => {
129 error!(message = %message, "Stream error");
130 let _ = self
131 .event_tx
132 .send(AgentEvent::Error {
133 message: message.clone(),
134 })
135 .await;
136 return Err(SaorsaAgentError::Internal(message));
137 }
138 _ => {}
139 }
140 }
141
142 if !text_content.is_empty() {
144 final_text.clone_from(&text_content);
145 let _ = self
146 .event_tx
147 .send(AgentEvent::TextComplete {
148 text: text_content.clone(),
149 })
150 .await;
151 }
152
153 let mut assistant_content: Vec<ContentBlock> = Vec::new();
155 if !text_content.is_empty() {
156 assistant_content.push(ContentBlock::Text { text: text_content });
157 }
158
159 let mut parsed_inputs = Vec::with_capacity(tool_calls.len());
161 for tc in &tool_calls {
162 let input: serde_json::Value =
163 serde_json::from_str(&tc.input_json).unwrap_or_else(|e| {
164 warn!(
165 tool = %tc.name,
166 error = %e,
167 "Malformed tool call JSON, using empty object"
168 );
169 serde_json::Value::Object(serde_json::Map::new())
170 });
171
172 let _ = self
173 .event_tx
174 .send(AgentEvent::ToolCall {
175 id: tc.id.clone(),
176 name: tc.name.clone(),
177 input: input.clone(),
178 })
179 .await;
180
181 assistant_content.push(ContentBlock::ToolUse {
182 id: tc.id.clone(),
183 name: tc.name.clone(),
184 input: input.clone(),
185 });
186
187 parsed_inputs.push(input);
188 }
189
190 self.messages.push(Message {
191 role: saorsa_ai::Role::Assistant,
192 content: assistant_content,
193 });
194
195 match stop_reason {
197 Some(StopReason::ToolUse) if !tool_calls.is_empty() => {
198 let tool_results = self.execute_tool_calls(&tool_calls, &parsed_inputs).await;
199
200 for result in &tool_results {
201 self.messages
202 .push(Message::tool_result(&result.id, &result.output));
203 }
204
205 let _ = self
206 .event_tx
207 .send(AgentEvent::TurnEnd {
208 turn,
209 reason: TurnEndReason::ToolUse,
210 })
211 .await;
212
213 }
215 Some(StopReason::MaxTokens) => {
216 let _ = self
217 .event_tx
218 .send(AgentEvent::TurnEnd {
219 turn,
220 reason: TurnEndReason::MaxTokens,
221 })
222 .await;
223 break;
224 }
225 _ => {
226 let _ = self
228 .event_tx
229 .send(AgentEvent::TurnEnd {
230 turn,
231 reason: TurnEndReason::EndTurn,
232 })
233 .await;
234 break;
235 }
236 }
237 }
238
239 Ok(final_text)
240 }
241
242 async fn execute_tool_calls(
244 &self,
245 tool_calls: &[ToolCallInfo],
246 inputs: &[serde_json::Value],
247 ) -> Vec<ToolResultInfo> {
248 let mut results = Vec::new();
249
250 for (tc, input) in tool_calls.iter().zip(inputs.iter()) {
251 let (output, success) = match self.tools.get(&tc.name) {
252 Some(tool) => match tool.execute(input.clone()).await {
253 Ok(result) => (result, true),
254 Err(e) => (format!("Error: {e}"), false),
255 },
256 None => (format!("Unknown tool: {}", tc.name), false),
257 };
258
259 let _ = self
260 .event_tx
261 .send(AgentEvent::ToolResult {
262 id: tc.id.clone(),
263 name: tc.name.clone(),
264 output: output.clone(),
265 success,
266 })
267 .await;
268
269 results.push(ToolResultInfo {
270 id: tc.id.clone(),
271 output,
272 });
273 }
274
275 results
276 }
277
278 pub fn messages(&self) -> &[Message] {
280 &self.messages
281 }
282}
283
/// A tool invocation collected from the provider stream during a single turn.
#[derive(Debug)]
struct ToolCallInfo {
    /// Tool-use id from the stream; reused as the id of the matching tool result.
    id: String,
    /// Name used to look the tool up in the `ToolRegistry`.
    name: String,
    /// Concatenated `InputJsonDelta` fragments; parsed only after the stream ends.
    input_json: String,
}
294
/// Outcome of executing one tool call, appended to the history as a tool-result message.
#[derive(Debug)]
struct ToolResultInfo {
    /// Id of the tool call this result answers.
    id: String,
    /// Tool output, or a description of why the tool failed or was not found.
    output: String,
}
303
304pub fn default_tools(working_dir: impl Into<std::path::PathBuf>) -> ToolRegistry {
316 use crate::tools::{
317 BashTool, EditTool, FindTool, GrepTool, LsTool, ReadTool, WebSearchTool, WriteTool,
318 };
319 use std::path::PathBuf;
320
321 let wd: PathBuf = working_dir.into();
322 let mut registry = ToolRegistry::new();
323
324 registry.register(Box::new(BashTool::new(wd.clone())));
325 registry.register(Box::new(ReadTool::new(wd.clone())));
326 registry.register(Box::new(WriteTool::new(wd.clone())));
327 registry.register(Box::new(EditTool::new(wd.clone())));
328 registry.register(Box::new(GrepTool::new(wd.clone())));
329 registry.register(Box::new(FindTool::new(wd.clone())));
330 registry.register(Box::new(LsTool::new(wd)));
331 registry.register(Box::new(WebSearchTool::new()));
332
333 registry
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::event_channel;

    /// Provider stub that replays the same fixed event sequence on every `stream` call.
    struct MockProvider {
        events: Vec<StreamEvent>,
    }

    #[async_trait::async_trait]
    impl saorsa_ai::Provider for MockProvider {
        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> saorsa_ai::Result<saorsa_ai::CompletionResponse> {
            Err(saorsa_ai::SaorsaAiError::Internal("not implemented".into()))
        }
    }

    #[async_trait::async_trait]
    impl StreamingProvider for MockProvider {
        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> saorsa_ai::Result<tokio::sync::mpsc::Receiver<saorsa_ai::Result<StreamEvent>>>
        {
            let (tx, rx) = tokio::sync::mpsc::channel(64);
            let events = self.events.clone();
            tokio::spawn(async move {
                for event in events {
                    if tx.send(Ok(event)).await.is_err() {
                        break;
                    }
                }
            });
            Ok(rx)
        }
    }

    /// Provider that answers every turn with `text` and an `EndTurn` stop.
    fn mock_text_provider(text: &str) -> Box<dyn StreamingProvider> {
        Box::new(MockProvider {
            events: vec![
                StreamEvent::MessageStart {
                    id: "msg_1".into(),
                    model: "test".into(),
                    usage: saorsa_ai::Usage::default(),
                },
                StreamEvent::ContentBlockStart {
                    index: 0,
                    content_block: ContentBlock::Text {
                        text: String::new(),
                    },
                },
                StreamEvent::ContentBlockDelta {
                    index: 0,
                    delta: ContentDelta::TextDelta {
                        text: text.to_string(),
                    },
                },
                StreamEvent::ContentBlockStop { index: 0 },
                StreamEvent::MessageDelta {
                    stop_reason: Some(StopReason::EndTurn),
                    usage: saorsa_ai::Usage::default(),
                },
                StreamEvent::MessageStop,
            ],
        })
    }

    /// Provider that, on every turn, requests `tool_name` without any input fragments.
    fn mock_tool_use_provider(tool_name: &str) -> Box<dyn StreamingProvider> {
        Box::new(MockProvider {
            events: vec![
                StreamEvent::MessageStart {
                    id: "msg_1".into(),
                    model: "test".into(),
                    usage: saorsa_ai::Usage::default(),
                },
                StreamEvent::ContentBlockStart {
                    index: 0,
                    content_block: ContentBlock::ToolUse {
                        id: "toolu_1".into(),
                        name: tool_name.to_string(),
                        input: serde_json::Value::Object(serde_json::Map::new()),
                    },
                },
                StreamEvent::ContentBlockStop { index: 0 },
                StreamEvent::MessageDelta {
                    stop_reason: Some(StopReason::ToolUse),
                    usage: saorsa_ai::Usage::default(),
                },
                StreamEvent::MessageStop,
            ],
        })
    }

    #[tokio::test]
    async fn agent_simple_text_response() {
        let provider = mock_text_provider("Hello, world!");
        let config = AgentConfig::default();
        let tools = ToolRegistry::new();
        let (tx, mut rx) = event_channel(64);

        let mut agent = AgentLoop::new(provider, config, tools, tx);

        let handle = tokio::spawn(async move { agent.run("Hi").await });

        // The channel closes once the task finishes and drops the agent's sender.
        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        // Check BOTH the join and the run result: an `Err` from `run` must fail the
        // test rather than silently skipping the text assertion.
        let joined = handle.await;
        assert!(matches!(joined, Ok(Ok(_))), "agent run should succeed");
        let Ok(Ok(text)) = joined else { unreachable!() };
        assert_eq!(text, "Hello, world!");

        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::TurnStart { turn: 1 }))
        );
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::TextDelta { .. }))
        );
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::TextComplete { .. }))
        );
        assert!(events.iter().any(|e| matches!(
            e,
            AgentEvent::TurnEnd {
                reason: TurnEndReason::EndTurn,
                ..
            }
        )));
    }

    #[tokio::test]
    async fn agent_max_turns_limit() {
        let provider = mock_text_provider("response");
        let config = AgentConfig::default().max_turns(0);
        let tools = ToolRegistry::new();
        let (tx, _rx) = event_channel(64);

        let mut agent = AgentLoop::new(provider, config, tools, tx);
        let result = agent.run("Hi").await;

        // With no turns allowed the provider is never consulted.
        assert!(matches!(result, Ok(ref text) if text.is_empty()));
        assert_eq!(agent.messages().len(), 1);
    }

    #[tokio::test]
    async fn agent_tracks_messages() {
        let provider = mock_text_provider("response");
        let config = AgentConfig::default();
        let tools = ToolRegistry::new();
        let (tx, _rx) = event_channel(64);

        let mut agent = AgentLoop::new(provider, config, tools, tx);
        let result = agent.run("Hello").await;
        assert!(result.is_ok());

        // user message + assistant reply
        let msgs = agent.messages();
        assert_eq!(msgs.len(), 2);
    }

    #[tokio::test]
    async fn agent_reports_unknown_tool_and_stops_at_max_turns() {
        let provider = mock_tool_use_provider("missing_tool");
        let config = AgentConfig::default().max_turns(2);
        let tools = ToolRegistry::new();
        let (tx, mut rx) = event_channel(64);

        let mut agent = AgentLoop::new(provider, config, tools, tx);
        let handle = tokio::spawn(async move {
            let result = agent.run("Use a tool").await;
            (result, agent.messages().len())
        });

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        let joined = handle.await;
        assert!(matches!(joined, Ok((Ok(_), _))), "agent run should succeed");
        let Ok((Ok(text), message_count)) = joined else { unreachable!() };
        assert!(text.is_empty());
        // user + 2 x (assistant tool_use + tool_result)
        assert_eq!(message_count, 5);

        let tool_use_ends = events
            .iter()
            .filter(|e| {
                matches!(
                    e,
                    AgentEvent::TurnEnd {
                        reason: TurnEndReason::ToolUse,
                        ..
                    }
                )
            })
            .count();
        assert_eq!(tool_use_ends, 2);

        assert!(events.iter().any(|e| matches!(
            e,
            AgentEvent::TurnEnd {
                turn: 3,
                reason: TurnEndReason::MaxTurns,
            }
        )));
        assert!(events.iter().any(|e| matches!(
            e,
            AgentEvent::ToolCall { name, input, .. }
                if name == "missing_tool" && input.is_object()
        )));
        assert!(events.iter().any(|e| matches!(
            e,
            AgentEvent::ToolResult { name, output, success: false, .. }
                if name == "missing_tool" && output.contains("Unknown tool")
        )));
    }

    #[test]
    fn default_tools_registers_all() {
        let cwd = std::env::current_dir();
        assert!(cwd.is_ok());
        let Ok(dir) = cwd else { unreachable!() };
        let registry = super::default_tools(dir);

        assert_eq!(registry.len(), 8);

        let names = registry.names();
        assert!(names.contains(&"bash"));
        assert!(names.contains(&"read"));
        assert!(names.contains(&"write"));
        assert!(names.contains(&"edit"));
        assert!(names.contains(&"grep"));
        assert!(names.contains(&"find"));
        assert!(names.contains(&"ls"));
        assert!(names.contains(&"web_search"));
    }
}