1use crate::{
9 dispatch::Dispatcher,
10 event::{AgentEvent, AgentResponse, AgentStep, AgentStopReason},
11 model::{Message, Model, Request},
12};
13use anyhow::Result;
14use async_stream::stream;
15use futures_core::Stream;
16use tokio::sync::mpsc;
17
18pub use builder::AgentBuilder;
19pub use config::AgentConfig;
20pub use parser::parse_agent_md;
21
22mod builder;
23pub mod config;
24mod parser;
25
26pub struct Agent<M: Model> {
33 pub config: AgentConfig,
35 model: M,
37 pub(crate) history: Vec<Message>,
39}
40
41impl<M: Model> Agent<M> {
42 pub fn push_message(&mut self, message: Message) {
44 self.history.push(message);
45 }
46
47 pub fn messages(&self) -> &[Message] {
49 &self.history
50 }
51
52 pub fn clear_history(&mut self) {
54 self.history.clear();
55 }
56
57 pub async fn step<D: Dispatcher>(&mut self, dispatcher: &D) -> Result<AgentStep> {
63 let model_name = self
64 .config
65 .model
66 .clone()
67 .unwrap_or_else(|| self.model.active_model());
68
69 let mut messages = Vec::with_capacity(1 + self.history.len());
70 if !self.config.system_prompt.is_empty() {
71 messages.push(Message::system(&self.config.system_prompt));
72 }
73 messages.extend(self.history.iter().cloned());
74
75 let tools = dispatcher.tools();
76 let mut request = Request::new(model_name)
77 .with_messages(messages)
78 .with_tool_choice(self.config.tool_choice.clone());
79 if !tools.is_empty() {
80 request = request.with_tools(tools);
81 }
82
83 let response = self.model.send(&request).await?;
84 let tool_calls = response.tool_calls().unwrap_or_default().to_vec();
85
86 if let Some(msg) = response.message() {
88 self.history.push(msg);
89 }
90
91 let mut tool_results = Vec::new();
93 if !tool_calls.is_empty() {
94 let calls: Vec<(&str, &str)> = tool_calls
95 .iter()
96 .map(|tc| (tc.function.name.as_str(), tc.function.arguments.as_str()))
97 .collect();
98
99 let results = dispatcher.dispatch(&calls).await;
100
101 for (tc, result) in tool_calls.iter().zip(results) {
102 let output = match result {
103 Ok(s) => s,
104 Err(e) => format!("error: {e}"),
105 };
106
107 let msg = Message::tool(&output, tc.id.clone());
108 self.history.push(msg.clone());
109 tool_results.push(msg);
110 }
111 }
112
113 Ok(AgentStep {
114 response,
115 tool_calls,
116 tool_results,
117 })
118 }
119
120 fn stop_reason(step: &AgentStep) -> AgentStopReason {
122 if step.response.content().is_some() {
123 AgentStopReason::TextResponse
124 } else {
125 AgentStopReason::NoAction
126 }
127 }
128
129 pub async fn run<D: Dispatcher>(
136 &mut self,
137 dispatcher: &D,
138 events: mpsc::UnboundedSender<AgentEvent>,
139 ) -> AgentResponse {
140 use futures_util::StreamExt;
141
142 let mut stream = std::pin::pin!(self.run_stream(dispatcher));
143 let mut response = None;
144 while let Some(event) = stream.next().await {
145 if let AgentEvent::Done(ref resp) = event {
146 response = Some(resp.clone());
147 }
148 let _ = events.send(event);
149 }
150
151 response.unwrap_or_else(|| AgentResponse {
152 final_response: None,
153 iterations: 0,
154 stop_reason: AgentStopReason::Error("stream ended without Done".into()),
155 steps: vec![],
156 })
157 }
158
159 pub fn run_stream<'a, D: Dispatcher + 'a>(
165 &'a mut self,
166 dispatcher: &'a D,
167 ) -> impl Stream<Item = AgentEvent> + 'a {
168 stream! {
169 let mut steps = Vec::new();
170 let max = self.config.max_iterations;
171
172 for _ in 0..max {
173 match self.step(dispatcher).await {
174 Ok(step) => {
175 let has_tool_calls = !step.tool_calls.is_empty();
176 let text = step.response.content().cloned();
177
178 if let Some(ref t) = text {
179 yield AgentEvent::TextDelta(t.clone());
180 }
181
182 if has_tool_calls {
183 yield AgentEvent::ToolCallsStart(step.tool_calls.clone());
184 for (tc, result) in step.tool_calls.iter().zip(&step.tool_results) {
185 yield AgentEvent::ToolResult {
186 call_id: tc.id.clone(),
187 output: result.content.clone(),
188 };
189 }
190 yield AgentEvent::ToolCallsComplete;
191 }
192
193 if !has_tool_calls {
194 let stop_reason = Self::stop_reason(&step);
195 steps.push(step);
196 yield AgentEvent::Done(AgentResponse {
197 final_response: text,
198 iterations: steps.len(),
199 stop_reason,
200 steps,
201 });
202 return;
203 }
204
205 steps.push(step);
206 }
207 Err(e) => {
208 yield AgentEvent::Done(AgentResponse {
209 final_response: None,
210 iterations: steps.len(),
211 stop_reason: AgentStopReason::Error(e.to_string()),
212 steps,
213 });
214 return;
215 }
216 }
217 }
218
219 let final_response = steps.last().and_then(|s| s.response.content().cloned());
220 yield AgentEvent::Done(AgentResponse {
221 final_response,
222 iterations: steps.len(),
223 stop_reason: AgentStopReason::MaxIterations,
224 steps,
225 });
226 }
227 }
228}