Skip to main content

walrus_core/agent/
mod.rs

1//! Immutable agent definition and execution methods.
2//!
3//! [`Agent`] owns its configuration, model, tool schemas, and an optional
4//! [`ToolSender`] for dispatching tool calls to the runtime. Conversation
5//! history is passed in externally — the agent itself is stateless.
6//! It drives LLM execution through [`Agent::step`], [`Agent::run`], and
7//! [`Agent::run_stream`]. `run_stream()` is the canonical step loop —
8//! `run()` collects its events and returns the final response.
9
10use crate::model::{
11    Choice, CompletionMeta, Delta, Message, MessageBuilder, Model, Request, Response, Role, Tool,
12    Usage,
13};
14use anyhow::Result;
15use async_stream::stream;
16pub use builder::AgentBuilder;
17pub use config::AgentConfig;
18use event::{AgentEvent, AgentResponse, AgentStep, AgentStopReason};
19use futures_core::Stream;
20use futures_util::StreamExt;
21use std::sync::Arc;
22use tokio::sync::{mpsc, oneshot};
23pub use tool::{AsTool, ToolDescription, ToolRequest, ToolSender};
24
25mod builder;
26mod compact;
27pub mod config;
28pub mod event;
29pub mod tool;
30
31/// Extract sender from the last user message in history.
32fn last_sender(history: &[Message]) -> compact_str::CompactString {
33    history
34        .iter()
35        .rev()
36        .find(|m| m.role == Role::User)
37        .map(|m| m.sender.clone())
38        .unwrap_or_default()
39}
40
/// Runtime hook invoked while the agent auto-compacts its conversation.
///
/// The agent stores this as an `Arc<dyn CompactHook>` trait object, so the
/// runtime's hook can be called back without making [`Agent`] generic over
/// the hook type.
pub trait CompactHook: Send + Sync {
    /// Give the runtime a chance to enrich the compaction prompt in place
    /// before it is sent to the LLM. `agent` identifies the agent being
    /// compacted (presumably its configured name — confirm in `compact`).
    fn on_compact(&self, agent: &str, prompt: &mut String);
}
49
50/// An immutable agent definition.
51///
52/// Generic over `M: Model` — stores the model provider alongside config,
53/// tool schemas, and an optional sender for tool dispatch. Conversation
54/// history is owned externally and passed into execution methods.
55/// Callers drive execution via `step()` (single LLM round), `run()` (loop to
56/// completion), or `run_stream()` (yields events as a stream).
57pub struct Agent<M: Model> {
58    /// Agent configuration (name, prompt, model, limits, tool_choice).
59    pub config: AgentConfig,
60    /// The model provider for LLM calls.
61    model: M,
62    /// Tool schemas advertised to the LLM. Set once at build time.
63    tools: Vec<Tool>,
64    /// Sender for dispatching tool calls to the runtime. None = no tools.
65    tool_tx: Option<ToolSender>,
66    /// Compact hook for auto-compaction enrichment.
67    compact_hook: Option<Arc<dyn CompactHook>>,
68}
69
70impl<M: Model> Agent<M> {
71    /// Perform a single LLM round: send request, dispatch tools, return step.
72    ///
73    /// Composes a [`Request`] from config state (system prompt + history +
74    /// tool schemas), calls the stored model, dispatches any tool calls via
75    /// the [`ToolSender`] channel, and appends results to history.
76    pub async fn step(&self, history: &mut Vec<Message>) -> Result<AgentStep> {
77        let model_name = self
78            .config
79            .model
80            .clone()
81            .unwrap_or_else(|| self.model.active_model());
82
83        let mut messages = Vec::with_capacity(1 + history.len());
84        if !self.config.system_prompt.is_empty() {
85            messages.push(Message::system(&self.config.system_prompt));
86        }
87        messages.extend(history.iter().cloned());
88
89        let mut request = Request::new(model_name)
90            .with_messages(messages)
91            .with_tool_choice(self.config.tool_choice.clone())
92            .with_think(self.config.thinking);
93        if !self.tools.is_empty() {
94            request = request.with_tools(self.tools.clone());
95        }
96
97        let response = self.model.send(&request).await?;
98        let tool_calls = response.tool_calls().unwrap_or_default().to_vec();
99
100        if let Some(msg) = response.message() {
101            history.push(msg);
102        }
103
104        let mut tool_results = Vec::new();
105        if !tool_calls.is_empty() {
106            let sender = last_sender(history);
107            for tc in &tool_calls {
108                let result = self
109                    .dispatch_tool(&tc.function.name, &tc.function.arguments, &sender)
110                    .await;
111                let msg = Message::tool(&result, tc.id.clone());
112                history.push(msg.clone());
113                tool_results.push(msg);
114            }
115        }
116
117        Ok(AgentStep {
118            response,
119            tool_calls,
120            tool_results,
121        })
122    }
123
124    /// Dispatch a single tool call via the tool sender channel.
125    ///
126    /// Returns the result string. If no sender is configured, returns an error
127    /// message without panicking.
128    async fn dispatch_tool(&self, name: &str, args: &str, sender: &str) -> String {
129        let Some(tx) = &self.tool_tx else {
130            return format!("tool '{name}' called but no tool sender configured");
131        };
132        let (reply_tx, reply_rx) = oneshot::channel();
133        let req = ToolRequest {
134            name: name.to_owned(),
135            args: args.to_owned(),
136            agent: self.config.name.to_string(),
137            reply: reply_tx,
138            task_id: None,
139            sender: sender.into(),
140        };
141        if tx.send(req).is_err() {
142            return format!("tool channel closed while calling '{name}'");
143        }
144        reply_rx
145            .await
146            .unwrap_or_else(|_| format!("tool '{name}' dropped reply"))
147    }
148
149    /// Determine the stop reason for a step with no tool calls.
150    fn stop_reason(step: &AgentStep) -> AgentStopReason {
151        if step.response.content().is_some() {
152            AgentStopReason::TextResponse
153        } else {
154            AgentStopReason::NoAction
155        }
156    }
157
158    /// Run the agent loop to completion, returning the final response.
159    ///
160    /// Wraps [`Agent::run_stream`] — collects all events, sends each through
161    /// `events`, and extracts the `Done` response.
162    pub async fn run(
163        &self,
164        history: &mut Vec<Message>,
165        events: mpsc::UnboundedSender<AgentEvent>,
166    ) -> AgentResponse {
167        let mut stream = std::pin::pin!(self.run_stream(history));
168        let mut response = None;
169        while let Some(event) = stream.next().await {
170            if let AgentEvent::Done(ref resp) = event {
171                response = Some(resp.clone());
172            }
173            let _ = events.send(event);
174        }
175
176        response.unwrap_or_else(|| AgentResponse {
177            final_response: None,
178            iterations: 0,
179            stop_reason: AgentStopReason::Error("stream ended without Done".into()),
180            steps: vec![],
181        })
182    }
183
184    /// Run the agent loop as a stream of [`AgentEvent`]s.
185    ///
186    /// Uses the model's streaming API so text deltas are yielded token-by-token.
187    /// Tool call responses are dispatched after the stream completes (arguments
188    /// arrive incrementally and must be fully accumulated first).
189    pub fn run_stream<'a>(
190        &'a self,
191        history: &'a mut Vec<Message>,
192    ) -> impl Stream<Item = AgentEvent> + 'a {
193        stream! {
194            let mut steps = Vec::new();
195            let max = self.config.max_iterations;
196
197            for _ in 0..max {
198                // Build the request (same logic as step()).
199                let model_name = self
200                    .config
201                    .model
202                    .clone()
203                    .unwrap_or_else(|| self.model.active_model());
204
205                let mut messages = Vec::with_capacity(1 + history.len());
206                if !self.config.system_prompt.is_empty() {
207                    messages.push(Message::system(&self.config.system_prompt));
208                }
209                messages.extend(history.iter().cloned());
210
211                let mut request = Request::new(model_name)
212                    .with_messages(messages)
213                    .with_tool_choice(self.config.tool_choice.clone())
214                    .with_think(self.config.thinking);
215                if !self.tools.is_empty() {
216                    request = request.with_tools(self.tools.clone());
217                }
218
219                // Stream from the model, yielding text deltas as they arrive.
220                let mut builder = MessageBuilder::new(Role::Assistant);
221                let mut finish_reason = None;
222                let mut last_meta = CompletionMeta::default();
223                let mut last_usage = None;
224                let mut stream_error = None;
225
226                {
227                    let mut chunk_stream = std::pin::pin!(self.model.stream(request));
228                    while let Some(result) = chunk_stream.next().await {
229                        match result {
230                            Ok(chunk) => {
231                                if let Some(text) = chunk.content() {
232                                    yield AgentEvent::TextDelta(text.to_owned());
233                                }
234                                if let Some(reason) = chunk.reasoning_content() {
235                                    yield AgentEvent::ThinkingDelta(reason.to_owned());
236                                }
237                                if let Some(r) = chunk.reason() {
238                                    finish_reason = Some(*r);
239                                }
240                                last_meta = chunk.meta.clone();
241                                if chunk.usage.is_some() {
242                                    last_usage = chunk.usage.clone();
243                                }
244                                builder.accept(&chunk);
245                            }
246                            Err(e) => {
247                                stream_error = Some(e.to_string());
248                                break;
249                            }
250                        }
251                    }
252                }
253                if let Some(e) = stream_error {
254                    yield AgentEvent::Done(AgentResponse {
255                        final_response: None,
256                        iterations: steps.len(),
257                        stop_reason: AgentStopReason::Error(e),
258                        steps,
259                    });
260                    return;
261                }
262
263                // Build the accumulated message and response.
264                let msg = builder.build();
265                let tool_calls = msg.tool_calls.to_vec();
266                let content = if msg.content.is_empty() {
267                    None
268                } else {
269                    Some(msg.content.clone())
270                };
271
272                let response = Response {
273                    meta: last_meta,
274                    choices: vec![Choice {
275                        index: 0,
276                        delta: Delta {
277                            role: Some(Role::Assistant),
278                            content: content.clone(),
279                            reasoning_content: if msg.reasoning_content.is_empty() {
280                                None
281                            } else {
282                                Some(msg.reasoning_content.clone())
283                            },
284                            tool_calls: if tool_calls.is_empty() {
285                                None
286                            } else {
287                                Some(tool_calls.clone())
288                            },
289                        },
290                        finish_reason,
291                        logprobs: None,
292                    }],
293                    usage: last_usage.unwrap_or(Usage {
294                        prompt_tokens: 0,
295                        completion_tokens: 0,
296                        total_tokens: 0,
297                        prompt_cache_hit_tokens: None,
298                        prompt_cache_miss_tokens: None,
299                        completion_tokens_details: None,
300                    }),
301                };
302
303                history.push(msg);
304                let has_tool_calls = !tool_calls.is_empty();
305
306                // Dispatch tool calls if any.
307                let mut tool_results = Vec::new();
308                if has_tool_calls {
309                    let sender = last_sender(history);
310                    yield AgentEvent::ToolCallsStart(tool_calls.clone());
311                    for tc in &tool_calls {
312                        let result = self
313                            .dispatch_tool(&tc.function.name, &tc.function.arguments, &sender)
314                            .await;
315                        let msg = Message::tool(&result, tc.id.clone());
316                        history.push(msg.clone());
317                        tool_results.push(msg);
318                        yield AgentEvent::ToolResult {
319                            call_id: tc.id.clone(),
320                            output: result,
321                        };
322                    }
323                    yield AgentEvent::ToolCallsComplete;
324                }
325
326                // Auto-compaction: check token estimate after each step.
327                if let Some(threshold) = self.config.compact_threshold
328                    && Self::estimate_tokens(history) > threshold
329                {
330                    if let Some(summary) = self.compact(history).await {
331                        *history = vec![Message::user(&summary)];
332                        yield AgentEvent::TextDelta(
333                            "\n[context compacted]\n".to_owned(),
334                        );
335                    }
336                    continue;
337                }
338
339                let step = AgentStep {
340                    response,
341                    tool_calls,
342                    tool_results,
343                };
344
345                if !has_tool_calls {
346                    let stop_reason = Self::stop_reason(&step);
347                    steps.push(step);
348                    yield AgentEvent::Done(AgentResponse {
349                        final_response: content,
350                        iterations: steps.len(),
351                        stop_reason,
352                        steps,
353                    });
354                    return;
355                }
356
357                steps.push(step);
358            }
359
360            let final_response = steps.last().and_then(|s| s.response.content().cloned());
361            yield AgentEvent::Done(AgentResponse {
362                final_response,
363                iterations: steps.len(),
364                stop_reason: AgentStopReason::MaxIterations,
365                steps,
366            });
367        }
368    }
369}