unia/agent.rs
//! Agent struct for automatic tool execution with LLM providers.

use crate::client::{Client, ClientError};
use crate::model::{FinishReason, Message, Part, Response, Usage};
use serde_json::json;
use std::collections::HashMap;
use tracing::{debug, info, warn};

use crate::mcp::MCPServer;
/// Agent that automatically executes tools in a loop.
///
/// Unlike the raw `Client`, an `Agent` handles tool execution automatically:
/// 1. Sends a request with tool definitions from the MCP server (if configured)
/// 2. Receives a response with potential function calls
/// 3. Executes the tools automatically
/// 4. Adds the results back to the conversation
/// 5. Loops until no more function calls are returned
///
/// # Example
/// ```ignore
/// use unia::agent::Agent;
/// use unia::providers::{Gemini, Provider};
/// use unia::model::{Message, Part};
///
/// let client = Gemini::create("api_key".to_string(), "gemini-3.0-pro".to_string());
/// let agent = Agent::new(client)
///     .with_server(weather_server);
///
/// let messages = vec![
///     Message::User(vec![
///         Part::Text { content: "What's the weather?".into(), finished: true }
///     ])
/// ];
///
/// let response = agent.chat(messages).await?;
/// ```
pub struct Agent<C: Client> {
    client: C,
    max_iterations: usize,
    server: Option<Box<dyn MCPServer>>,
}

impl<C: Client> Agent<C> {
    /// Create a new agent.
    ///
    /// # Arguments
    /// - `client`: The initialized client instance
    ///
    /// Tools are fetched from the configured MCP server if available.
    pub fn new(client: C) -> Self {
        Self {
            client,
            max_iterations: 10,
            server: None,
        }
    }

    /// Set the MCP server for the agent.
    pub fn with_server<S: MCPServer + 'static>(mut self, server: S) -> Self {
        self.server = Some(Box::new(server));
        self
    }

    /// Set the maximum number of iterations for the agentic loop (defaults to 10).
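    ///
    /// # Example
    /// A minimal sketch, assuming `client` is built as in the struct-level example:
    /// ```ignore
    /// let agent = Agent::new(client).with_max_iterations(5);
    /// ```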
    pub fn with_max_iterations(mut self, max: usize) -> Self {
        self.max_iterations = max;
        self
    }

    /// Send a chat request with automatic tool execution.
    ///
    /// This method automatically handles the tool execution loop:
    /// - Sends the request to the LLM with tools from the MCP server (if configured)
    /// - Executes any tool calls
    /// - Continues until there are no more tool calls or the iteration limit is reached
    ///
    /// # Arguments
    /// - `messages`: Conversation messages
    ///
    /// # Returns
    /// The response containing all messages generated during execution, including
    /// tool calls and their results.
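    ///
    /// # Example
    /// A minimal sketch of the round trip, assuming `agent` and `messages` are set
    /// up as in the struct-level example:
    /// ```ignore
    /// let response = agent.chat(messages).await?;
    /// // `response.data` holds every message produced along the way:
    /// // assistant turns, tool calls, and tool results, in order.
    /// ```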
    pub async fn chat(&self, mut messages: Vec<Message>) -> Result<Response, ClientError> {
        debug!(
            "Starting agent chat loop with {} initial messages",
            messages.len()
        );

        let mut current_response = Response {
            data: Vec::new(),
            usage: Usage::default(),
            finish: FinishReason::Unfinished,
        };

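        // Resolve the tool list once up front. `tool_map` records which backing
        // server (if any) each tool name came from, so calls can be routed back.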
        let (tools, tool_map) = if let Some(server) = &self.server {
            match server.list_tools().await {
                Ok(tools) => {
                    let map: HashMap<String, Option<String>> = tools
                        .iter()
                        .map(|t| (t.value.name.to_string(), t.server_id.clone()))
                        .collect();
                    (tools.into_iter().map(|t| t.value).collect(), map)
                }
                Err(e) => {
                    return Err(ClientError::ProviderError(format!(
                        "Failed to list tools from MCP server: {}",
                        e
                    )));
                }
            }
        } else {
            (Vec::new(), HashMap::new())
        };

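        // Agentic loop: ask the model, run any requested tools, feed the results
        // back, and repeat until the model stops calling tools or the budget runs out.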
        for iteration in 0..self.max_iterations {
            debug!("Agent iteration {}/{}", iteration + 1, self.max_iterations);

            let response = self.client.request(messages.clone(), tools.clone()).await?;
            current_response.usage += response.usage;
            current_response.finish = response.finish.clone();

            let mut tool_calls_executed = false;

            for msg in response.data {
                messages.push(msg.clone());
                current_response.data.push(msg.clone());

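                // Scan every part of the message for function calls to execute.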
                for part in msg.parts() {
                    if let Part::FunctionCall {
                        id,
                        name,
                        arguments,
                        ..
                    } = part
                    {
                        tool_calls_executed = true;
                        info!("Tool call requested: {}", name);
                        debug!("Tool arguments: {}", arguments);

                        let server = self.server.as_ref().ok_or_else(|| {
                            ClientError::Config("No MCP server configured".to_string())
                        })?;
                        let server_id = tool_map.get(name).cloned().flatten();
                        let result = server
                            .call_tool(name.clone(), arguments.clone(), server_id)
                            .await;

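                        // Pair the tool output with the originating call id so the
                        // provider can match each result to its call.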
                        let response_part = match result {
                            Ok(mut part) => {
                                info!("Tool {} executed successfully", name);
                                debug!("Tool result: {:?}", part);
                                if let Part::FunctionResponse {
                                    id: ref mut pid, ..
                                } = part
                                {
                                    *pid = id.clone();
                                }
                                part
                            }
                            Err(e) => {
                                warn!("Tool {} execution failed: {}", name, e);
                                Part::FunctionResponse {
                                    id: id.clone(),
                                    name: name.clone(),
                                    response: json!({ "error": format!("Error: {}", e) }),
                                    parts: vec![],
                                    finished: true,
                                }
                            }
                        };

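                        // Feed the tool result back into the conversation as a user message.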
                        let response_msg = Message::User(vec![response_part]);
                        messages.push(response_msg.clone());
                        current_response.data.push(response_msg);
                    }
                }
            }

            if !tool_calls_executed {
                debug!("No more function calls, agent loop complete");
                return Ok(current_response);
            }
        }

        warn!(
            "Max iterations ({}) reached in agent loop",
            self.max_iterations
        );
        Err(ClientError::Config(
            "Max iterations reached in agent loop".to_string(),
        ))
    }

    /// Send a streaming chat request with automatic tool execution.
    ///
    /// This method automatically handles the tool execution loop with streaming:
    /// - Sends a streaming request to the LLM with tools from the MCP server (if configured)
    /// - Executes any tool calls
    /// - Continues until there are no more tool calls or the iteration limit is reached
    ///
    /// # Arguments
    /// - `messages`: Conversation messages
    ///
    /// # Returns
    /// A stream of cumulative `Response` snapshots, updated as chunks arrive and as
    /// tool results are appended.
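    ///
    /// # Example
    /// A minimal consumption sketch, assuming `agent` and `messages` are set up as
    /// in the struct-level example; each yielded item is a cumulative snapshot:
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = agent.chat_stream(messages);
    /// while let Some(snapshot) = stream.next().await {
    ///     let response = snapshot?;
    ///     // Render or diff the accumulated `response.data` here.
    /// }
    /// ```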
    pub fn chat_stream<'a>(
        &'a self,
        mut messages: Vec<Message>,
    ) -> std::pin::Pin<Box<dyn futures::Stream<Item = Result<Response, ClientError>> + Send + 'a>>
    where
        C: crate::client::StreamingClient,
    {
        Box::pin(async_stream::try_stream! {
            use futures::StreamExt;

            debug!("Starting agent streaming chat loop");

            let mut current_response = Response {
                data: Vec::new(),
                usage: Usage::default(),
                finish: FinishReason::Unfinished,
            };

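            // Resolve tools once. Unlike `chat`, a listing failure here degrades to
            // a tool-less request instead of ending the stream with an error.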
            let (tools, tool_map) = if let Some(server) = &self.server {
                match server.list_tools().await {
                    Ok(tools) => {
                        let map: HashMap<String, Option<String>> = tools
                            .iter()
                            .map(|t| (t.value.name.to_string(), t.server_id.clone()))
                            .collect();
                        (tools.into_iter().map(|t| t.value).collect(), map)
                    }
                    Err(e) => {
                        warn!("Failed to list tools from MCP server: {}", e);
                        (Vec::new(), HashMap::new())
                    }
                }
            } else {
                (Vec::new(), HashMap::new())
            };

            for iteration in 0..self.max_iterations {
                debug!(
                    "Agent streaming iteration {}/{}",
                    iteration + 1,
                    self.max_iterations
                );

                let mut stream = self.client.request_stream(messages.clone(), tools.clone()).await?;

                // Snapshot of the state before this turn.
                let base_data_len = current_response.data.len();
                let base_usage = current_response.usage.clone();

                while let Some(response_result) = stream.next().await {
                    let response = response_result?;

                    // Update `current_response`: truncate back to the snapshot to
                    // drop this turn's previous partials, then apply the new chunk.
                    current_response.data.truncate(base_data_len);
                    current_response.data.extend(response.data.clone());

                    current_response.usage = base_usage.clone();
                    current_response.usage += response.usage;
                    current_response.finish = response.finish;

                    yield current_response.clone();
                }

                // After the stream ends, `current_response` holds the full assistant
                // output for this turn; append the new messages to the history.
                if current_response.data.len() > base_data_len {
                    for i in base_data_len..current_response.data.len() {
                        messages.push(current_response.data[i].clone());
                    }
                }

                // Check for tool calls.
                let mut tool_calls_executed = false;
                let mut tool_responses = Vec::new();

                // Only the LAST message is checked for tool calls, which should be
                // the assistant's message.
                if let Some(msg) = current_response.data.last() {
                    for part in msg.parts() {
                        if let Part::FunctionCall { id, name, arguments, finished, .. } = part {
                            if *finished {
                                tool_calls_executed = true;
                                info!("Executing tool: {}", name);

                                let server = self.server.as_ref().ok_or_else(|| {
                                    ClientError::Config("No MCP server configured".to_string())
                                })?;
                                let server_id = tool_map.get(name).cloned().flatten();
                                let result = server
                                    .call_tool(name.clone(), arguments.clone(), server_id)
                                    .await;

                                let response_part = match result {
                                    Ok(mut part) => {
                                        if let Part::FunctionResponse { id: ref mut pid, .. } = part {
                                            *pid = id.clone();
                                        }
                                        part
                                    }
                                    Err(e) => Part::FunctionResponse {
                                        id: id.clone(),
                                        name: name.clone(),
                                        response: json!({ "error": format!("Error: {}", e) }),
                                        parts: vec![],
                                        finished: true,
                                    },
                                };
                                tool_responses.push(response_part);
                            }
                        }
                    }
                }

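                // Batch all tool results into a single user message, record it in
                // both the history and the snapshot, then surface the update.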
                if tool_calls_executed {
                    let tool_msg = Message::User(tool_responses);
                    messages.push(tool_msg.clone());
                    current_response.data.push(tool_msg);

                    yield current_response.clone();
                } else {
                    // No tool calls, so we are done.
                    return;
                }
            }

            warn!(
                "Max iterations ({}) reached in streaming agent loop",
                self.max_iterations
            );
            Err(ClientError::Config(
                "Max iterations reached in agent loop".to_string(),
            ))?;
        })
    }
}