Skip to main content

saorsa_agent/
agent.rs

1//! Core agent loop for interacting with LLM providers.
2//!
3//! The agent loop sends messages to the LLM, processes streaming responses,
4//! executes tool calls, and continues until the model stops or a turn limit is reached.
5
6use tracing::{debug, error, warn};
7
8use saorsa_ai::{
9    CompletionRequest, ContentBlock, ContentDelta, Message, StopReason, StreamEvent,
10    StreamingProvider,
11};
12
13use crate::config::AgentConfig;
14use crate::error::{Result, SaorsaAgentError};
15use crate::event::{AgentEvent, EventSender, TurnEndReason};
16use crate::tool::ToolRegistry;
17
18/// The core agent loop.
19pub struct AgentLoop {
20    /// Provider for LLM completions.
21    provider: Box<dyn StreamingProvider>,
22    /// Configuration.
23    config: AgentConfig,
24    /// Tool registry.
25    tools: ToolRegistry,
26    /// Event sender for UI integration.
27    event_tx: EventSender,
28    /// Conversation history.
29    messages: Vec<Message>,
30}
31
32impl AgentLoop {
33    /// Create a new agent loop.
34    pub fn new(
35        provider: Box<dyn StreamingProvider>,
36        config: AgentConfig,
37        tools: ToolRegistry,
38        event_tx: EventSender,
39    ) -> Self {
40        Self {
41            provider,
42            config,
43            tools,
44            event_tx,
45            messages: Vec::new(),
46        }
47    }
48
49    /// Add a user message and run the agent loop until completion.
50    ///
51    /// Returns the final assistant text response, or an error.
52    pub async fn run(&mut self, user_message: &str) -> Result<String> {
53        self.messages.push(Message::user(user_message));
54
55        let mut turn = 0u32;
56        let mut final_text = String::new();
57
58        loop {
59            turn += 1;
60
61            if turn > self.config.max_turns {
62                debug!(turn, max = self.config.max_turns, "Max turns reached");
63                let _ = self
64                    .event_tx
65                    .send(AgentEvent::TurnEnd {
66                        turn,
67                        reason: TurnEndReason::MaxTurns,
68                    })
69                    .await;
70                break;
71            }
72
73            let _ = self.event_tx.send(AgentEvent::TurnStart { turn }).await;
74
75            let request = CompletionRequest::new(
76                &self.config.model,
77                self.messages.clone(),
78                self.config.max_tokens,
79            )
80            .system(&self.config.system_prompt)
81            .tools(self.tools.definitions());
82
83            // Stream the response.
84            let mut rx = self.provider.stream(request).await?;
85
86            let mut text_content = String::new();
87            let mut tool_calls: Vec<ToolCallInfo> = Vec::new();
88            let mut stop_reason = None;
89
90            while let Some(event) = rx.recv().await {
91                match event {
92                    Ok(StreamEvent::ContentBlockStart {
93                        content_block: ContentBlock::ToolUse { id, name, .. },
94                        ..
95                    }) => {
96                        tool_calls.push(ToolCallInfo {
97                            id,
98                            name,
99                            input_json: String::new(),
100                        });
101                    }
102                    Ok(StreamEvent::ContentBlockDelta {
103                        delta: ContentDelta::TextDelta { text },
104                        ..
105                    }) => {
106                        text_content.push_str(&text);
107                        let _ = self.event_tx.send(AgentEvent::TextDelta { text }).await;
108                    }
109                    Ok(StreamEvent::ContentBlockDelta {
110                        delta: ContentDelta::InputJsonDelta { partial_json },
111                        ..
112                    }) => {
113                        if let Some(tc) = tool_calls.last_mut() {
114                            tc.input_json.push_str(&partial_json);
115                        }
116                    }
117                    Ok(StreamEvent::ContentBlockDelta {
118                        delta: ContentDelta::ThinkingDelta { text },
119                        ..
120                    }) => {
121                        let _ = self.event_tx.send(AgentEvent::ThinkingDelta { text }).await;
122                    }
123                    Ok(StreamEvent::MessageDelta {
124                        stop_reason: sr, ..
125                    }) => {
126                        stop_reason = sr;
127                    }
128                    Ok(StreamEvent::Error { message }) => {
129                        error!(message = %message, "Stream error");
130                        let _ = self
131                            .event_tx
132                            .send(AgentEvent::Error {
133                                message: message.clone(),
134                            })
135                            .await;
136                        return Err(SaorsaAgentError::Internal(message));
137                    }
138                    _ => {}
139                }
140            }
141
142            // Emit text complete event if we got text.
143            if !text_content.is_empty() {
144                final_text.clone_from(&text_content);
145                let _ = self
146                    .event_tx
147                    .send(AgentEvent::TextComplete {
148                        text: text_content.clone(),
149                    })
150                    .await;
151            }
152
153            // Build the assistant message for history.
154            let mut assistant_content: Vec<ContentBlock> = Vec::new();
155            if !text_content.is_empty() {
156                assistant_content.push(ContentBlock::Text { text: text_content });
157            }
158
159            // Parse tool call inputs once and emit events.
160            let mut parsed_inputs = Vec::with_capacity(tool_calls.len());
161            for tc in &tool_calls {
162                let input: serde_json::Value =
163                    serde_json::from_str(&tc.input_json).unwrap_or_else(|e| {
164                        warn!(
165                            tool = %tc.name,
166                            error = %e,
167                            "Malformed tool call JSON, using empty object"
168                        );
169                        serde_json::Value::Object(serde_json::Map::new())
170                    });
171
172                let _ = self
173                    .event_tx
174                    .send(AgentEvent::ToolCall {
175                        id: tc.id.clone(),
176                        name: tc.name.clone(),
177                        input: input.clone(),
178                    })
179                    .await;
180
181                assistant_content.push(ContentBlock::ToolUse {
182                    id: tc.id.clone(),
183                    name: tc.name.clone(),
184                    input: input.clone(),
185                });
186
187                parsed_inputs.push(input);
188            }
189
190            self.messages.push(Message {
191                role: saorsa_ai::Role::Assistant,
192                content: assistant_content,
193            });
194
195            // Handle tool calls.
196            match stop_reason {
197                Some(StopReason::ToolUse) if !tool_calls.is_empty() => {
198                    let tool_results = self.execute_tool_calls(&tool_calls, &parsed_inputs).await;
199
200                    for result in &tool_results {
201                        self.messages
202                            .push(Message::tool_result(&result.id, &result.output));
203                    }
204
205                    let _ = self
206                        .event_tx
207                        .send(AgentEvent::TurnEnd {
208                            turn,
209                            reason: TurnEndReason::ToolUse,
210                        })
211                        .await;
212
213                    // Continue the loop for the next turn.
214                }
215                Some(StopReason::MaxTokens) => {
216                    let _ = self
217                        .event_tx
218                        .send(AgentEvent::TurnEnd {
219                            turn,
220                            reason: TurnEndReason::MaxTokens,
221                        })
222                        .await;
223                    break;
224                }
225                _ => {
226                    // EndTurn, StopSequence, or None — we're done.
227                    let _ = self
228                        .event_tx
229                        .send(AgentEvent::TurnEnd {
230                            turn,
231                            reason: TurnEndReason::EndTurn,
232                        })
233                        .await;
234                    break;
235                }
236            }
237        }
238
239        Ok(final_text)
240    }
241
242    /// Execute a list of tool calls with pre-parsed inputs and return results.
243    async fn execute_tool_calls(
244        &self,
245        tool_calls: &[ToolCallInfo],
246        inputs: &[serde_json::Value],
247    ) -> Vec<ToolResultInfo> {
248        let mut results = Vec::new();
249
250        for (tc, input) in tool_calls.iter().zip(inputs.iter()) {
251            let (output, success) = match self.tools.get(&tc.name) {
252                Some(tool) => match tool.execute(input.clone()).await {
253                    Ok(result) => (result, true),
254                    Err(e) => (format!("Error: {e}"), false),
255                },
256                None => (format!("Unknown tool: {}", tc.name), false),
257            };
258
259            let _ = self
260                .event_tx
261                .send(AgentEvent::ToolResult {
262                    id: tc.id.clone(),
263                    name: tc.name.clone(),
264                    output: output.clone(),
265                    success,
266                })
267                .await;
268
269            results.push(ToolResultInfo {
270                id: tc.id.clone(),
271                output,
272            });
273        }
274
275        results
276    }
277
278    /// Get the current conversation messages.
279    pub fn messages(&self) -> &[Message] {
280        &self.messages
281    }
282}
283
/// A tool invocation being reassembled from streamed content-block events.
///
/// Created when a tool-use block starts; its `input_json` grows as partial
/// JSON deltas arrive and is parsed only after the stream finishes.
#[derive(Debug)]
struct ToolCallInfo {
    /// Provider-assigned identifier linking this call to its result.
    id: String,
    /// Name of the tool the model asked to run.
    name: String,
    /// Raw JSON text for the tool input, concatenated from stream deltas.
    input_json: String,
}
294
/// Outcome of running one tool call, ready to be fed back to the model.
#[derive(Debug)]
struct ToolResultInfo {
    /// Identifier of the tool call this result answers.
    id: String,
    /// Text produced by the tool (or an error description).
    output: String,
}
303
304/// Create a default tool registry with all built-in tools.
305///
306/// This includes:
307/// - BashTool: Execute shell commands
308/// - ReadTool: Read file contents with optional line ranges
309/// - WriteTool: Write files with diff display
310/// - EditTool: Surgical file editing with ambiguity detection
311/// - GrepTool: Search file contents with regex
312/// - FindTool: Find files by name pattern
313/// - LsTool: List directory contents with metadata
314/// - WebSearchTool: Search the web via DuckDuckGo (no API key required)
315pub fn default_tools(working_dir: impl Into<std::path::PathBuf>) -> ToolRegistry {
316    use crate::tools::{
317        BashTool, EditTool, FindTool, GrepTool, LsTool, ReadTool, WebSearchTool, WriteTool,
318    };
319    use std::path::PathBuf;
320
321    let wd: PathBuf = working_dir.into();
322    let mut registry = ToolRegistry::new();
323
324    registry.register(Box::new(BashTool::new(wd.clone())));
325    registry.register(Box::new(ReadTool::new(wd.clone())));
326    registry.register(Box::new(WriteTool::new(wd.clone())));
327    registry.register(Box::new(EditTool::new(wd.clone())));
328    registry.register(Box::new(GrepTool::new(wd.clone())));
329    registry.register(Box::new(FindTool::new(wd.clone())));
330    registry.register(Box::new(LsTool::new(wd)));
331    registry.register(Box::new(WebSearchTool::new()));
332
333    registry
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::event_channel;

    /// Mock provider that replays a fixed list of stream events on every call.
    struct MockProvider {
        events: Vec<StreamEvent>,
    }

    #[async_trait::async_trait]
    impl saorsa_ai::Provider for MockProvider {
        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> saorsa_ai::Result<saorsa_ai::CompletionResponse> {
            Err(saorsa_ai::SaorsaAiError::Internal("not implemented".into()))
        }
    }

    #[async_trait::async_trait]
    impl StreamingProvider for MockProvider {
        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> saorsa_ai::Result<tokio::sync::mpsc::Receiver<saorsa_ai::Result<StreamEvent>>>
        {
            let (tx, rx) = tokio::sync::mpsc::channel(64);
            let events = self.events.clone();
            tokio::spawn(async move {
                for event in events {
                    if tx.send(Ok(event)).await.is_err() {
                        break;
                    }
                }
            });
            Ok(rx)
        }
    }

    fn mock_text_provider(text: &str) -> Box<dyn StreamingProvider> {
        Box::new(MockProvider {
            events: vec![
                StreamEvent::MessageStart {
                    id: "msg_1".into(),
                    model: "test".into(),
                    usage: saorsa_ai::Usage::default(),
                },
                StreamEvent::ContentBlockStart {
                    index: 0,
                    content_block: ContentBlock::Text {
                        text: String::new(),
                    },
                },
                StreamEvent::ContentBlockDelta {
                    index: 0,
                    delta: ContentDelta::TextDelta {
                        text: text.to_string(),
                    },
                },
                StreamEvent::ContentBlockStop { index: 0 },
                StreamEvent::MessageDelta {
                    stop_reason: Some(StopReason::EndTurn),
                    usage: saorsa_ai::Usage::default(),
                },
                StreamEvent::MessageStop,
            ],
        })
    }

    /// Mock provider whose every response is a single `lookup` tool call with
    /// input `{"query":"x"}` (split across two deltas) and a ToolUse stop.
    fn mock_tool_provider() -> Box<dyn StreamingProvider> {
        Box::new(MockProvider {
            events: vec![
                StreamEvent::MessageStart {
                    id: "msg_1".into(),
                    model: "test".into(),
                    usage: saorsa_ai::Usage::default(),
                },
                StreamEvent::ContentBlockStart {
                    index: 0,
                    content_block: ContentBlock::ToolUse {
                        id: "tool_1".into(),
                        name: "lookup".into(),
                        input: serde_json::Value::Object(serde_json::Map::new()),
                    },
                },
                StreamEvent::ContentBlockDelta {
                    index: 0,
                    delta: ContentDelta::InputJsonDelta {
                        partial_json: r#"{"query":"#.into(),
                    },
                },
                StreamEvent::ContentBlockDelta {
                    index: 0,
                    delta: ContentDelta::InputJsonDelta {
                        partial_json: r#""x"}"#.into(),
                    },
                },
                StreamEvent::ContentBlockStop { index: 0 },
                StreamEvent::MessageDelta {
                    stop_reason: Some(StopReason::ToolUse),
                    usage: saorsa_ai::Usage::default(),
                },
                StreamEvent::MessageStop,
            ],
        })
    }

    #[tokio::test]
    async fn agent_simple_text_response() {
        let provider = mock_text_provider("Hello, world!");
        let config = AgentConfig::default();
        let tools = ToolRegistry::new();
        let (tx, mut rx) = event_channel(64);

        let mut agent = AgentLoop::new(provider, config, tools, tx);

        let handle = tokio::spawn(async move { agent.run("Hi").await });

        // Collect events until the agent (and its sender) is dropped.
        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        // Both the task join and the agent run itself must succeed.
        let result = handle.await;
        assert!(matches!(
            &result,
            Ok(Ok(text)) if text.as_str() == "Hello, world!"
        ));

        // Should have: TurnStart, TextDelta, TextComplete, TurnEnd.
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::TurnStart { turn: 1 }))
        );
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::TextDelta { .. }))
        );
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::TextComplete { .. }))
        );
        assert!(events.iter().any(|e| matches!(
            e,
            AgentEvent::TurnEnd {
                reason: TurnEndReason::EndTurn,
                ..
            }
        )));
    }

    #[tokio::test]
    async fn agent_max_turns_limit() {
        let provider = mock_text_provider("response");
        let config = AgentConfig::default().max_turns(0);
        let tools = ToolRegistry::new();
        let (tx, mut rx) = event_channel(64);

        let mut agent = AgentLoop::new(provider, config, tools, tx);
        let result = agent.run("Hi").await;

        // With max_turns=0 the loop stops before ever calling the provider.
        assert!(matches!(&result, Ok(text) if text.is_empty()));
        assert_eq!(agent.messages().len(), 1);

        drop(agent);
        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }
        assert!(
            !events
                .iter()
                .any(|e| matches!(e, AgentEvent::TurnStart { .. }))
        );
        assert!(events.iter().any(|e| matches!(
            e,
            AgentEvent::TurnEnd {
                reason: TurnEndReason::MaxTurns,
                ..
            }
        )));
    }

    #[tokio::test]
    async fn agent_tracks_messages() {
        let provider = mock_text_provider("response");
        let config = AgentConfig::default();
        let tools = ToolRegistry::new();
        let (tx, _rx) = event_channel(64);

        let mut agent = AgentLoop::new(provider, config, tools, tx);
        let _ = agent.run("Hello").await;

        let msgs = agent.messages();
        // Should have user message + assistant message.
        assert_eq!(msgs.len(), 2);
    }

    #[tokio::test]
    async fn agent_executes_tool_calls_and_records_results() {
        let provider = mock_tool_provider();
        // One turn: the tool call runs, then the turn limit stops the loop.
        let config = AgentConfig::default().max_turns(1);
        let tools = ToolRegistry::new();
        let (tx, mut rx) = event_channel(64);

        let mut agent = AgentLoop::new(provider, config, tools, tx);
        let result = agent.run("Look something up").await;
        assert!(matches!(&result, Ok(text) if text.is_empty()));

        // User prompt, assistant tool_use, tool result.
        let msgs = agent.messages();
        assert_eq!(msgs.len(), 3);
        assert!(msgs.iter().any(|m| m.content.iter().any(|b| matches!(
            b,
            ContentBlock::ToolUse { id, name, input }
                if id == "tool_1" && name == "lookup" && input["query"] == "x"
        ))));

        drop(agent);
        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        assert!(events.iter().any(|e| matches!(
            e,
            AgentEvent::ToolCall { name, .. } if name == "lookup"
        )));
        // The registry is empty, so the call fails as an unknown tool.
        assert!(events.iter().any(|e| matches!(
            e,
            AgentEvent::ToolResult { id, success: false, .. } if id == "tool_1"
        )));
        assert!(events.iter().any(|e| matches!(
            e,
            AgentEvent::TurnEnd {
                turn: 1,
                reason: TurnEndReason::ToolUse,
            }
        )));
        assert!(events.iter().any(|e| matches!(
            e,
            AgentEvent::TurnEnd {
                reason: TurnEndReason::MaxTurns,
                ..
            }
        )));
    }

    #[test]
    fn default_tools_registers_all() {
        let cwd = std::env::current_dir();
        assert!(cwd.is_ok());
        let Ok(dir) = cwd else { unreachable!() };
        let registry = super::default_tools(dir);

        // Verify all 8 tools are registered
        assert_eq!(registry.len(), 8);

        let names = registry.names();
        assert!(names.contains(&"bash"));
        assert!(names.contains(&"read"));
        assert!(names.contains(&"write"));
        assert!(names.contains(&"edit"));
        assert!(names.contains(&"grep"));
        assert!(names.contains(&"find"));
        assert!(names.contains(&"ls"));
        assert!(names.contains(&"web_search"));
    }
}