Skip to main content

walrus_core/agent/
mod.rs

1//! Stateful agent execution unit.
2//!
3//! [`Agent`] owns its configuration, model, and message history. It drives
4//! LLM execution through [`Agent::step`], [`Agent::run`], and
5//! [`Agent::run_stream`]. `run_stream()` is the canonical step loop —
6//! `run()` collects its events and returns the final response.
7
8use crate::model::{Message, Model, Request};
9use anyhow::Result;
10use async_stream::stream;
11use event::{AgentEvent, AgentResponse, AgentStep, AgentStopReason};
12use futures_core::Stream;
13use tokio::sync::mpsc;
14use tool::Dispatcher;
15
16pub use builder::AgentBuilder;
17pub use config::AgentConfig;
18pub use parser::parse_agent_md;
19
20mod builder;
21pub mod config;
22pub mod event;
23mod parser;
24pub mod tool;
25
26/// A stateful agent execution unit.
27///
28/// Generic over `M: Model` — stores the model provider alongside config
29/// and conversation history. Callers drive execution via `step()` (single
30/// LLM round), `run()` (loop to completion), or `run_stream()` (yields
31/// events as a stream).
32pub struct Agent<M: Model> {
33    /// Agent configuration (name, prompt, model, limits, tool_choice).
34    pub config: AgentConfig,
35    /// The model provider for LLM calls.
36    model: M,
37    /// Conversation history (user/assistant/tool messages).
38    pub(crate) history: Vec<Message>,
39}
40
41impl<M: Model> Agent<M> {
42    /// Push a message into the conversation history.
43    pub fn push_message(&mut self, message: Message) {
44        self.history.push(message);
45    }
46
47    /// Return a reference to the conversation history.
48    pub fn messages(&self) -> &[Message] {
49        &self.history
50    }
51
52    /// Clear the conversation history, keeping configuration intact.
53    pub fn clear_history(&mut self) {
54        self.history.clear();
55    }
56
57    /// Perform a single LLM round: send request, dispatch tools, return step.
58    ///
59    /// Composes a [`Request`] from config state (system prompt + history +
60    /// dispatcher tools), calls the stored model, dispatches any tool calls
61    /// via `dispatcher.dispatch()`, and appends results to history.
62    pub async fn step<D: Dispatcher>(&mut self, dispatcher: &D) -> Result<AgentStep> {
63        let model_name = self
64            .config
65            .model
66            .clone()
67            .unwrap_or_else(|| self.model.active_model());
68
69        let mut messages = Vec::with_capacity(1 + self.history.len());
70        if !self.config.system_prompt.is_empty() {
71            messages.push(Message::system(&self.config.system_prompt));
72        }
73        messages.extend(self.history.iter().cloned());
74
75        let tools = dispatcher.tools();
76        let mut request = Request::new(model_name)
77            .with_messages(messages)
78            .with_tool_choice(self.config.tool_choice.clone());
79        if !tools.is_empty() {
80            request = request.with_tools(tools);
81        }
82
83        let response = self.model.send(&request).await?;
84        let tool_calls = response.tool_calls().unwrap_or_default().to_vec();
85
86        // Append the assistant message to history.
87        if let Some(msg) = response.message() {
88            self.history.push(msg);
89        }
90
91        // Dispatch tool calls if any.
92        let mut tool_results = Vec::new();
93        if !tool_calls.is_empty() {
94            let calls: Vec<(&str, &str)> = tool_calls
95                .iter()
96                .map(|tc| (tc.function.name.as_str(), tc.function.arguments.as_str()))
97                .collect();
98
99            let results = dispatcher.dispatch(&calls).await;
100
101            for (tc, result) in tool_calls.iter().zip(results) {
102                let output = match result {
103                    Ok(s) => s,
104                    Err(e) => format!("error: {e}"),
105                };
106
107                let msg = Message::tool(&output, tc.id.clone());
108                self.history.push(msg.clone());
109                tool_results.push(msg);
110            }
111        }
112
113        Ok(AgentStep {
114            response,
115            tool_calls,
116            tool_results,
117        })
118    }
119
120    /// Determine the stop reason for a step with no tool calls.
121    fn stop_reason(step: &AgentStep) -> AgentStopReason {
122        if step.response.content().is_some() {
123            AgentStopReason::TextResponse
124        } else {
125            AgentStopReason::NoAction
126        }
127    }
128
129    /// Run the agent loop to completion, returning the final response.
130    ///
131    /// Wraps [`Agent::run_stream`] — collects all events, sends each through
132    /// `events`, and extracts the `Done` response. The event sender allows
133    /// callers (like Runtime) to observe execution without reimplementing
134    /// the step loop.
135    pub async fn run<D: Dispatcher>(
136        &mut self,
137        dispatcher: &D,
138        events: mpsc::UnboundedSender<AgentEvent>,
139    ) -> AgentResponse {
140        use futures_util::StreamExt;
141
142        let mut stream = std::pin::pin!(self.run_stream(dispatcher));
143        let mut response = None;
144        while let Some(event) = stream.next().await {
145            if let AgentEvent::Done(ref resp) = event {
146                response = Some(resp.clone());
147            }
148            let _ = events.send(event);
149        }
150
151        response.unwrap_or_else(|| AgentResponse {
152            final_response: None,
153            iterations: 0,
154            stop_reason: AgentStopReason::Error("stream ended without Done".into()),
155            steps: vec![],
156        })
157    }
158
159    /// Run the agent loop as a stream of [`AgentEvent`]s.
160    ///
161    /// The canonical step loop. Calls [`Agent::step`] up to `max_iterations`
162    /// times, yielding events as they are produced. Always finishes with a
163    /// `Done` event containing the [`AgentResponse`].
164    pub fn run_stream<'a, D: Dispatcher + 'a>(
165        &'a mut self,
166        dispatcher: &'a D,
167    ) -> impl Stream<Item = AgentEvent> + 'a {
168        stream! {
169            let mut steps = Vec::new();
170            let max = self.config.max_iterations;
171
172            for _ in 0..max {
173                match self.step(dispatcher).await {
174                    Ok(step) => {
175                        let has_tool_calls = !step.tool_calls.is_empty();
176                        let text = step.response.content().cloned();
177
178                        if let Some(ref t) = text {
179                            yield AgentEvent::TextDelta(t.clone());
180                        }
181
182                        if has_tool_calls {
183                            yield AgentEvent::ToolCallsStart(step.tool_calls.clone());
184                            for (tc, result) in step.tool_calls.iter().zip(&step.tool_results) {
185                                yield AgentEvent::ToolResult {
186                                    call_id: tc.id.clone(),
187                                    output: result.content.clone(),
188                                };
189                            }
190                            yield AgentEvent::ToolCallsComplete;
191                        }
192
193                        if !has_tool_calls {
194                            let stop_reason = Self::stop_reason(&step);
195                            steps.push(step);
196                            yield AgentEvent::Done(AgentResponse {
197                                final_response: text,
198                                iterations: steps.len(),
199                                stop_reason,
200                                steps,
201                            });
202                            return;
203                        }
204
205                        steps.push(step);
206                    }
207                    Err(e) => {
208                        yield AgentEvent::Done(AgentResponse {
209                            final_response: None,
210                            iterations: steps.len(),
211                            stop_reason: AgentStopReason::Error(e.to_string()),
212                            steps,
213                        });
214                        return;
215                    }
216                }
217            }
218
219            let final_response = steps.last().and_then(|s| s.response.content().cloned());
220            yield AgentEvent::Done(AgentResponse {
221                final_response,
222                iterations: steps.len(),
223                stop_reason: AgentStopReason::MaxIterations,
224                steps,
225            });
226        }
227    }
228}