1use crate::dispatch::Dispatcher;
8use crate::event::{AgentEvent, AgentResponse, AgentStep, AgentStopReason};
9use crate::model::{Message, Model, Request};
10use anyhow::Result;
11use async_stream::stream;
12use futures_core::Stream;
13
14pub use builder::AgentBuilder;
15pub use config::AgentConfig;
16
17mod builder;
18pub mod config;
19
20pub struct Agent<M: Model> {
27 pub config: AgentConfig,
29 model: M,
31 pub(crate) history: Vec<Message>,
33}
34
35impl<M: Model> Agent<M> {
36 pub fn push_message(&mut self, message: Message) {
38 self.history.push(message);
39 }
40
41 pub fn messages(&self) -> &[Message] {
43 &self.history
44 }
45
46 pub fn clear_history(&mut self) {
48 self.history.clear();
49 }
50
51 pub async fn step<D: Dispatcher>(&mut self, dispatcher: &D) -> Result<AgentStep> {
57 let model_name = self
58 .config
59 .model
60 .clone()
61 .unwrap_or_else(|| self.model.active_model());
62
63 let mut messages = Vec::with_capacity(1 + self.history.len());
64 if !self.config.system_prompt.is_empty() {
65 messages.push(Message::system(&self.config.system_prompt));
66 }
67 messages.extend(self.history.iter().cloned());
68
69 let tools = dispatcher.tools();
70 let mut request = Request::new(model_name)
71 .with_messages(messages)
72 .with_tool_choice(self.config.tool_choice.clone());
73 if !tools.is_empty() {
74 request = request.with_tools(tools);
75 }
76
77 let response = self.model.send(&request).await?;
78 let tool_calls = response.tool_calls().unwrap_or_default().to_vec();
79
80 if let Some(msg) = response.message() {
82 self.history.push(msg);
83 }
84
85 let mut tool_results = Vec::new();
87 if !tool_calls.is_empty() {
88 let calls: Vec<(&str, &str)> = tool_calls
89 .iter()
90 .map(|tc| (tc.function.name.as_str(), tc.function.arguments.as_str()))
91 .collect();
92
93 let results = dispatcher.dispatch(&calls).await;
94
95 for (tc, result) in tool_calls.iter().zip(results) {
96 let output = match result {
97 Ok(s) => s,
98 Err(e) => format!("error: {e}"),
99 };
100
101 let msg = Message::tool(&output, tc.id.clone());
102 self.history.push(msg.clone());
103 tool_results.push(msg);
104 }
105 }
106
107 Ok(AgentStep {
108 response,
109 tool_calls,
110 tool_results,
111 })
112 }
113
114 fn stop_reason(step: &AgentStep) -> AgentStopReason {
116 if step.response.content().is_some() {
117 AgentStopReason::TextResponse
118 } else {
119 AgentStopReason::NoAction
120 }
121 }
122
123 pub async fn run<D: Dispatcher>(&mut self, dispatcher: &D) -> AgentResponse {
128 let mut steps = Vec::new();
129 let max = self.config.max_iterations;
130
131 for _ in 0..max {
132 match self.step(dispatcher).await {
133 Ok(step) => {
134 let has_tool_calls = !step.tool_calls.is_empty();
135 let text = step.response.content().cloned();
136
137 if !has_tool_calls {
138 let stop_reason = Self::stop_reason(&step);
139 steps.push(step);
140 return AgentResponse {
141 final_response: text,
142 iterations: steps.len(),
143 stop_reason,
144 steps,
145 };
146 }
147
148 steps.push(step);
149 }
150 Err(e) => {
151 return AgentResponse {
152 final_response: None,
153 iterations: steps.len(),
154 stop_reason: AgentStopReason::Error(e.to_string()),
155 steps,
156 };
157 }
158 }
159 }
160
161 let final_response = steps.last().and_then(|s| s.response.content().cloned());
162 AgentResponse {
163 final_response,
164 iterations: steps.len(),
165 stop_reason: AgentStopReason::MaxIterations,
166 steps,
167 }
168 }
169
170 pub fn run_stream<'a, D: Dispatcher + 'a>(
176 &'a mut self,
177 dispatcher: &'a D,
178 ) -> impl Stream<Item = AgentEvent> + 'a {
179 stream! {
180 let mut steps = Vec::new();
181 let max = self.config.max_iterations;
182
183 for _ in 0..max {
184 match self.step(dispatcher).await {
185 Ok(step) => {
186 let has_tool_calls = !step.tool_calls.is_empty();
187 let text = step.response.content().cloned();
188
189 if let Some(ref t) = text {
190 yield AgentEvent::TextDelta(t.clone());
191 }
192
193 if has_tool_calls {
194 yield AgentEvent::ToolCallsStart(step.tool_calls.clone());
195 for (tc, result) in step.tool_calls.iter().zip(&step.tool_results) {
196 yield AgentEvent::ToolResult {
197 call_id: tc.id.clone(),
198 output: result.content.clone(),
199 };
200 }
201 yield AgentEvent::ToolCallsComplete;
202 }
203
204 if !has_tool_calls {
205 let stop_reason = Self::stop_reason(&step);
206 steps.push(step);
207 let response = AgentResponse {
208 final_response: text,
209 iterations: steps.len(),
210 stop_reason,
211 steps,
212 };
213 yield AgentEvent::Done(response);
214 return;
215 }
216
217 steps.push(step);
218 }
219 Err(e) => {
220 let response = AgentResponse {
221 final_response: None,
222 iterations: steps.len(),
223 stop_reason: AgentStopReason::Error(e.to_string()),
224 steps,
225 };
226 yield AgentEvent::Done(response);
227 return;
228 }
229 }
230 }
231
232 let final_response = steps.last().and_then(|s| s.response.content().cloned());
233 let response = AgentResponse {
234 final_response,
235 iterations: steps.len(),
236 stop_reason: AgentStopReason::MaxIterations,
237 steps,
238 };
239 yield AgentEvent::Done(response);
240 }
241 }
242}