Skip to main content

saorsa_agent/
agent.rs

1//! Core agent loop for interacting with LLM providers.
2//!
3//! The agent loop sends messages to the LLM, processes streaming responses,
4//! executes tool calls, and continues until the model stops or a turn limit is reached.
5
6use tracing::{debug, error};
7
8use saorsa_ai::{
9    CompletionRequest, ContentBlock, ContentDelta, Message, StopReason, StreamEvent,
10    StreamingProvider,
11};
12
13use crate::config::AgentConfig;
14use crate::error::{Result, SaorsaAgentError};
15use crate::event::{AgentEvent, EventSender, TurnEndReason};
16use crate::tool::ToolRegistry;
17
18/// The core agent loop.
19pub struct AgentLoop {
20    /// Provider for LLM completions.
21    provider: Box<dyn StreamingProvider>,
22    /// Configuration.
23    config: AgentConfig,
24    /// Tool registry.
25    tools: ToolRegistry,
26    /// Event sender for UI integration.
27    event_tx: EventSender,
28    /// Conversation history.
29    messages: Vec<Message>,
30}
31
32impl AgentLoop {
33    /// Create a new agent loop.
34    pub fn new(
35        provider: Box<dyn StreamingProvider>,
36        config: AgentConfig,
37        tools: ToolRegistry,
38        event_tx: EventSender,
39    ) -> Self {
40        Self {
41            provider,
42            config,
43            tools,
44            event_tx,
45            messages: Vec::new(),
46        }
47    }
48
49    /// Add a user message and run the agent loop until completion.
50    ///
51    /// Returns the final assistant text response, or an error.
52    pub async fn run(&mut self, user_message: &str) -> Result<String> {
53        self.messages.push(Message::user(user_message));
54
55        let mut turn = 0u32;
56        let mut final_text = String::new();
57
58        loop {
59            turn += 1;
60
61            if turn > self.config.max_turns {
62                debug!(turn, max = self.config.max_turns, "Max turns reached");
63                let _ = self
64                    .event_tx
65                    .send(AgentEvent::TurnEnd {
66                        turn,
67                        reason: TurnEndReason::MaxTurns,
68                    })
69                    .await;
70                break;
71            }
72
73            let _ = self.event_tx.send(AgentEvent::TurnStart { turn }).await;
74
75            let request = CompletionRequest::new(
76                &self.config.model,
77                self.messages.clone(),
78                self.config.max_tokens,
79            )
80            .system(&self.config.system_prompt)
81            .tools(self.tools.definitions());
82
83            // Stream the response.
84            let mut rx = self.provider.stream(request).await?;
85
86            let mut text_content = String::new();
87            let mut tool_calls: Vec<ToolCallInfo> = Vec::new();
88            let mut stop_reason = None;
89
90            while let Some(event) = rx.recv().await {
91                match event {
92                    Ok(StreamEvent::ContentBlockStart {
93                        content_block: ContentBlock::ToolUse { id, name, .. },
94                        ..
95                    }) => {
96                        tool_calls.push(ToolCallInfo {
97                            id,
98                            name,
99                            input_json: String::new(),
100                        });
101                    }
102                    Ok(StreamEvent::ContentBlockDelta {
103                        delta: ContentDelta::TextDelta { text },
104                        ..
105                    }) => {
106                        text_content.push_str(&text);
107                        let _ = self.event_tx.send(AgentEvent::TextDelta { text }).await;
108                    }
109                    Ok(StreamEvent::ContentBlockDelta {
110                        delta: ContentDelta::InputJsonDelta { partial_json },
111                        ..
112                    }) => {
113                        if let Some(tc) = tool_calls.last_mut() {
114                            tc.input_json.push_str(&partial_json);
115                        }
116                    }
117                    Ok(StreamEvent::MessageDelta {
118                        stop_reason: sr, ..
119                    }) => {
120                        stop_reason = sr;
121                    }
122                    Ok(StreamEvent::Error { message }) => {
123                        error!(message = %message, "Stream error");
124                        let _ = self
125                            .event_tx
126                            .send(AgentEvent::Error {
127                                message: message.clone(),
128                            })
129                            .await;
130                        return Err(SaorsaAgentError::Internal(message));
131                    }
132                    _ => {}
133                }
134            }
135
136            // Emit text complete event if we got text.
137            if !text_content.is_empty() {
138                final_text.clone_from(&text_content);
139                let _ = self
140                    .event_tx
141                    .send(AgentEvent::TextComplete {
142                        text: text_content.clone(),
143                    })
144                    .await;
145            }
146
147            // Build the assistant message for history.
148            let mut assistant_content: Vec<ContentBlock> = Vec::new();
149            if !text_content.is_empty() {
150                assistant_content.push(ContentBlock::Text { text: text_content });
151            }
152
153            // Emit tool call events.
154            for tc in &tool_calls {
155                let input: serde_json::Value = serde_json::from_str(&tc.input_json)
156                    .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
157
158                let _ = self
159                    .event_tx
160                    .send(AgentEvent::ToolCall {
161                        id: tc.id.clone(),
162                        name: tc.name.clone(),
163                        input: input.clone(),
164                    })
165                    .await;
166
167                assistant_content.push(ContentBlock::ToolUse {
168                    id: tc.id.clone(),
169                    name: tc.name.clone(),
170                    input,
171                });
172            }
173
174            self.messages.push(Message {
175                role: saorsa_ai::Role::Assistant,
176                content: assistant_content,
177            });
178
179            // Handle tool calls.
180            match stop_reason {
181                Some(StopReason::ToolUse) if !tool_calls.is_empty() => {
182                    let tool_results = self.execute_tool_calls(&tool_calls).await;
183
184                    for result in &tool_results {
185                        self.messages
186                            .push(Message::tool_result(&result.id, &result.output));
187                    }
188
189                    let _ = self
190                        .event_tx
191                        .send(AgentEvent::TurnEnd {
192                            turn,
193                            reason: TurnEndReason::ToolUse,
194                        })
195                        .await;
196
197                    // Continue the loop for the next turn.
198                }
199                Some(StopReason::MaxTokens) => {
200                    let _ = self
201                        .event_tx
202                        .send(AgentEvent::TurnEnd {
203                            turn,
204                            reason: TurnEndReason::MaxTokens,
205                        })
206                        .await;
207                    break;
208                }
209                _ => {
210                    // EndTurn, StopSequence, or None — we're done.
211                    let _ = self
212                        .event_tx
213                        .send(AgentEvent::TurnEnd {
214                            turn,
215                            reason: TurnEndReason::EndTurn,
216                        })
217                        .await;
218                    break;
219                }
220            }
221        }
222
223        Ok(final_text)
224    }
225
226    /// Execute a list of tool calls and return results.
227    async fn execute_tool_calls(&self, tool_calls: &[ToolCallInfo]) -> Vec<ToolResultInfo> {
228        let mut results = Vec::new();
229
230        for tc in tool_calls {
231            let input: serde_json::Value = serde_json::from_str(&tc.input_json)
232                .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
233
234            let (output, success) = match self.tools.get(&tc.name) {
235                Some(tool) => match tool.execute(input).await {
236                    Ok(result) => (result, true),
237                    Err(e) => (format!("Error: {e}"), false),
238                },
239                None => (format!("Unknown tool: {}", tc.name), false),
240            };
241
242            let _ = self
243                .event_tx
244                .send(AgentEvent::ToolResult {
245                    id: tc.id.clone(),
246                    name: tc.name.clone(),
247                    output: output.clone(),
248                    success,
249                })
250                .await;
251
252            results.push(ToolResultInfo {
253                id: tc.id.clone(),
254                output,
255            });
256        }
257
258        results
259    }
260
261    /// Get the current conversation messages.
262    pub fn messages(&self) -> &[Message] {
263        &self.messages
264    }
265}
266
/// A tool invocation being reconstructed from streaming events.
///
/// The ID and name arrive with the `ToolUse` content-block start. The input
/// follows as JSON fragments that are appended to `input_json`.
#[derive(Debug)]
struct ToolCallInfo {
    /// Identifier the provider assigned to this tool use.
    id: String,
    /// Name of the tool the model asked for.
    name: String,
    /// Raw JSON input, concatenated from the streamed fragments.
    input_json: String,
}
277
/// The output of one executed tool call, ready to be fed back to the model.
#[derive(Debug)]
struct ToolResultInfo {
    /// ID of the tool use this result answers.
    id: String,
    /// Text produced by the tool, or an error description.
    output: String,
}
286
287/// Create a default tool registry with all built-in tools.
288///
289/// This includes:
290/// - BashTool: Execute shell commands
291/// - ReadTool: Read file contents with optional line ranges
292/// - WriteTool: Write files with diff display
293/// - EditTool: Surgical file editing with ambiguity detection
294/// - GrepTool: Search file contents with regex
295/// - FindTool: Find files by name pattern
296/// - LsTool: List directory contents with metadata
297pub fn default_tools(working_dir: impl Into<std::path::PathBuf>) -> ToolRegistry {
298    use crate::tools::{BashTool, EditTool, FindTool, GrepTool, LsTool, ReadTool, WriteTool};
299    use std::path::PathBuf;
300
301    let wd: PathBuf = working_dir.into();
302    let mut registry = ToolRegistry::new();
303
304    registry.register(Box::new(BashTool::new(wd.clone())));
305    registry.register(Box::new(ReadTool::new(wd.clone())));
306    registry.register(Box::new(WriteTool::new(wd.clone())));
307    registry.register(Box::new(EditTool::new(wd.clone())));
308    registry.register(Box::new(GrepTool::new(wd.clone())));
309    registry.register(Box::new(FindTool::new(wd.clone())));
310    registry.register(Box::new(LsTool::new(wd)));
311
312    registry
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::event_channel;

    /// Mock provider that replays a fixed list of stream events.
    struct MockProvider {
        events: Vec<StreamEvent>,
    }

    #[async_trait::async_trait]
    impl saorsa_ai::Provider for MockProvider {
        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> saorsa_ai::Result<saorsa_ai::CompletionResponse> {
            Err(saorsa_ai::SaorsaAiError::Internal("not implemented".into()))
        }
    }

    #[async_trait::async_trait]
    impl StreamingProvider for MockProvider {
        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> saorsa_ai::Result<tokio::sync::mpsc::Receiver<saorsa_ai::Result<StreamEvent>>>
        {
            let (tx, rx) = tokio::sync::mpsc::channel(64);
            let events = self.events.clone();
            tokio::spawn(async move {
                for event in events {
                    if tx.send(Ok(event)).await.is_err() {
                        break;
                    }
                }
            });
            Ok(rx)
        }
    }

    /// Provider that answers with a single text block and ends the turn.
    fn mock_text_provider(text: &str) -> Box<dyn StreamingProvider> {
        Box::new(MockProvider {
            events: vec![
                StreamEvent::MessageStart {
                    id: "msg_1".into(),
                    model: "test".into(),
                    usage: saorsa_ai::Usage::default(),
                },
                StreamEvent::ContentBlockStart {
                    index: 0,
                    content_block: ContentBlock::Text {
                        text: String::new(),
                    },
                },
                StreamEvent::ContentBlockDelta {
                    index: 0,
                    delta: ContentDelta::TextDelta {
                        text: text.to_string(),
                    },
                },
                StreamEvent::ContentBlockStop { index: 0 },
                StreamEvent::MessageDelta {
                    stop_reason: Some(StopReason::EndTurn),
                    usage: saorsa_ai::Usage::default(),
                },
                StreamEvent::MessageStop,
            ],
        })
    }

    /// Provider whose stream immediately reports a provider-side error.
    fn mock_error_provider(message: &str) -> Box<dyn StreamingProvider> {
        Box::new(MockProvider {
            events: vec![StreamEvent::Error {
                message: message.into(),
            }],
        })
    }

    /// Drain every event once all senders have been dropped.
    async fn collect_events(
        rx: &mut tokio::sync::mpsc::Receiver<AgentEvent>,
    ) -> Vec<AgentEvent> {
        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }
        events
    }

    #[tokio::test]
    async fn agent_simple_text_response() {
        let provider = mock_text_provider("Hello, world!");
        let config = AgentConfig::default();
        let tools = ToolRegistry::new();
        let (tx, mut rx) = event_channel(64);

        let mut agent = AgentLoop::new(provider, config, tools, tx);

        let handle = tokio::spawn(async move { agent.run("Hi").await });

        // Collect events; the channel closes when the agent is dropped at task end.
        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        // Both the task and the agent run must succeed; previously an `Err`
        // from `run` silently skipped the text assertion.
        let result = handle.await;
        assert!(
            matches!(&result, Ok(Ok(_))),
            "agent task panicked or returned an error"
        );
        let Ok(Ok(text)) = result else { unreachable!() };
        assert_eq!(text, "Hello, world!");

        // Should have: TurnStart, TextDelta, TextComplete, TurnEnd.
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::TurnStart { turn: 1 }))
        );
        assert!(events.iter().any(
            |e| matches!(e, AgentEvent::TextDelta { text } if text == "Hello, world!")
        ));
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::TextComplete { .. }))
        );
        assert!(events.iter().any(|e| matches!(
            e,
            AgentEvent::TurnEnd {
                reason: TurnEndReason::EndTurn,
                ..
            }
        )));
    }

    #[tokio::test]
    async fn agent_max_turns_limit() {
        let provider = mock_text_provider("response");
        let config = AgentConfig::default().max_turns(0);
        let tools = ToolRegistry::new();
        let (tx, mut rx) = event_channel(64);

        let mut agent = AgentLoop::new(provider, config, tools, tx);
        let result = agent.run("Hi").await;

        // With max_turns=0 the loop stops before contacting the provider.
        assert!(matches!(&result, Ok(text) if text.is_empty()));
        assert_eq!(agent.messages().len(), 1);

        // Dropping the agent closes the event channel so draining terminates.
        drop(agent);
        let events = collect_events(&mut rx).await;
        assert!(
            !events
                .iter()
                .any(|e| matches!(e, AgentEvent::TurnStart { .. }))
        );
        assert!(events.iter().any(|e| matches!(
            e,
            AgentEvent::TurnEnd {
                reason: TurnEndReason::MaxTurns,
                ..
            }
        )));
    }

    #[tokio::test]
    async fn agent_tracks_messages() {
        let provider = mock_text_provider("response");
        let config = AgentConfig::default();
        let tools = ToolRegistry::new();
        let (tx, _rx) = event_channel(64);

        let mut agent = AgentLoop::new(provider, config, tools, tx);
        let result = agent.run("Hello").await;
        assert!(result.is_ok());

        let msgs = agent.messages();
        // Should have user message + assistant message.
        assert_eq!(msgs.len(), 2);
        assert!(matches!(msgs[1].role, saorsa_ai::Role::Assistant));
    }

    #[tokio::test]
    async fn agent_stream_error_is_reported() {
        let provider = mock_error_provider("boom");
        let config = AgentConfig::default();
        let tools = ToolRegistry::new();
        let (tx, mut rx) = event_channel(64);

        let mut agent = AgentLoop::new(provider, config, tools, tx);
        let result = agent.run("Hi").await;
        assert!(result.is_err());

        drop(agent);
        let events = collect_events(&mut rx).await;
        assert!(events
            .iter()
            .any(|e| matches!(e, AgentEvent::Error { message } if message == "boom")));
    }

    #[test]
    fn default_tools_registers_all() {
        let cwd = std::env::current_dir();
        assert!(cwd.is_ok());
        let Ok(dir) = cwd else { unreachable!() };
        let registry = super::default_tools(dir);

        // Verify all 7 tools are registered
        assert_eq!(registry.len(), 7);

        let names = registry.names();
        assert!(names.contains(&"bash"));
        assert!(names.contains(&"read"));
        assert!(names.contains(&"write"));
        assert!(names.contains(&"edit"));
        assert!(names.contains(&"grep"));
        assert!(names.contains(&"find"));
        assert!(names.contains(&"ls"));
    }
}
480}