// unia/agent.rs

1//! Agent struct for automatic tool execution with LLM providers.
2
3use crate::client::{Client, ClientError};
4use crate::model::{FinishReason, Message, Part, Response, Usage};
5use serde_json::json;
6use std::collections::HashMap;
7use tracing::{debug, info, warn};
8
9use crate::mcp::MCPServer;
10
11/// Agent that automatically executes tools in a loop.
12///
13/// Unlike the raw `Client`, an `Agent` handles tool execution automatically:
14/// 1. Sends request with tool definitions from the MCP server (if configured)
15/// 2. Receives response with potential function calls
16/// 3. Executes tools automatically
17/// 4. Adds results back to conversation
18/// 5. Loops until no more function calls
19///
20/// # Example
21/// ```ignore
22/// use unia::agent::Agent;
23/// use unia::providers::{Gemini, Provider};
24/// use unia::model::{Message, Part};
25///
26/// let client = Gemini::create("api_key".to_string(), "gemini-3.0-pro".to_string());
27/// let agent = Agent::new(client)
28///     .with_server(weather_server);
29///
30/// let messages = vec![
31///     Message::User(vec![
32///         Part::Text { content: "What's the weather?".into(), finished: true }
33///     ])
34/// ];
35///
36/// let response = agent.chat(messages).await?;
37/// ```
38pub struct Agent<C: Client> {
39    client: C,
40    max_iterations: usize,
41    server: Option<Box<dyn MCPServer>>,
42}
43
44impl<C: Client> Agent<C> {
45    /// Create a new agent.
46    ///
47    /// # Arguments
48    /// - `client`: The initialized client instance
49    ///
50    /// Tools are fetched from the configured MCP server if available.
51    pub fn new(client: C) -> Self {
52        Self {
53            client,
54            max_iterations: 10,
55            server: None,
56        }
57    }
58
59    /// Set the MCP server for the agent.
60    pub fn with_server<S: MCPServer + 'static>(mut self, server: S) -> Self {
61        self.server = Some(Box::new(server));
62        self
63    }
64
65    /// Set the maximum number of iterations for the agentic loop.
66    pub fn with_max_iterations(mut self, max: usize) -> Self {
67        self.max_iterations = max;
68        self
69    }
70
    /// Send a chat request with automatic tool execution.
    ///
    /// This method automatically handles the tool execution loop:
    /// - Sends request to LLM with tools from the MCP server (if configured)
    /// - Executes any tool calls
    /// - Continues until no more tool calls or max iterations reached
    ///
    /// # Arguments
    /// - `messages`: Conversation messages
    ///
    /// # Returns
    /// The response containing all new messages generated during the execution (including tool calls and results)
    ///
    /// # Errors
    /// - `ClientError::ProviderError` if the MCP server fails to list its tools
    /// - `ClientError::Config` if the model requests a tool while no server is
    ///   configured, or if `max_iterations` is exhausted before the model stops
    ///   calling tools (note: accumulated usage/messages are dropped in that case)
    pub async fn chat(&self, mut messages: Vec<Message>) -> Result<Response, ClientError> {
        debug!(
            "Starting agent chat loop with {} initial messages",
            messages.len()
        );

        // Accumulates every new message produced during the loop plus summed usage.
        let mut current_response = Response {
            data: Vec::new(),
            usage: Usage::default(),
            finish: FinishReason::Unfinished,
        };

        // Fetch the tool list once up front. `tool_map` remembers which backing
        // server id (if any) advertised each tool name, so calls can be routed
        // back to the right server.
        let (tools, tool_map) = if let Some(server) = &self.server {
            match server.list_tools().await {
                Ok(tools) => {
                    let map: HashMap<String, Option<String>> = tools
                        .iter()
                        .map(|t| (t.value.name.to_string(), t.server_id.clone()))
                        .collect();
                    (tools.into_iter().map(|t| t.value).collect(), map)
                }
                Err(e) => {
                    return Err(ClientError::ProviderError(format!(
                        "Failed to list tools from MCP server: {}",
                        e
                    )));
                }
            }
        } else {
            // No server configured: request is sent with an empty tool set.
            (Vec::new(), HashMap::new())
        };

        for iteration in 0..self.max_iterations {
            debug!("Agent iteration {}/{}", iteration + 1, self.max_iterations);

            let response = self.client.request(messages.clone(), tools.clone()).await?;
            // Sum token usage across iterations; keep only the latest finish reason.
            current_response.usage += response.usage;
            current_response.finish = response.finish.clone();

            let mut tool_calls_executed = false;

            for msg in response.data {
                // Record the model's message both in the ongoing conversation
                // and in the response returned to the caller.
                messages.push(msg.clone());
                current_response.data.push(msg.clone());

                for part in msg.parts() {
                    if let Part::FunctionCall {
                        id,
                        name,
                        arguments,
                        ..
                    } = part
                    {
                        tool_calls_executed = true;
                        info!("Tool call requested: {}", name);
                        debug!("Tool arguments: {}", arguments);

                        // A call can only be honoured with a server configured;
                        // without one the tool list was empty, so this indicates
                        // a misbehaving model or a misconfiguration.
                        let server = self.server.as_ref().ok_or_else(|| {
                            ClientError::Config("No MCP server configured".to_string())
                        })?;
                        let server_id = tool_map.get(name).cloned().flatten();
                        let result = server
                            .call_tool(name.clone(), arguments.clone(), server_id)
                            .await;

                        let response_part = match result {
                            Ok(mut part) => {
                                info!("Tool {} executed successfully", name);
                                debug!("Tool result: {:?}", part);
                                // Stamp the result with the call's id so the
                                // provider can pair response with request.
                                if let Part::FunctionResponse {
                                    id: ref mut pid, ..
                                } = part
                                {
                                    *pid = id.clone();
                                }
                                part
                            }
                            Err(e) => {
                                // Tool failures are reported back to the model
                                // as an error payload instead of aborting the chat.
                                warn!("Tool {} execution failed: {}", name, e);
                                Part::FunctionResponse {
                                    id: id.clone(),
                                    name: name.clone(),
                                    response: json!({ "error": format!("Error: {}", e) }),
                                    parts: vec![],
                                    finished: true,
                                }
                            }
                        };

                        // Each tool result is sent back as its own user message.
                        // NOTE(review): `chat_stream` instead batches all results
                        // of a turn into one message — confirm providers accept both.
                        let response_msg = Message::User(vec![response_part]);
                        messages.push(response_msg.clone());
                        current_response.data.push(response_msg);
                    }
                }
            }

            // A turn with no function calls means the model is done.
            if !tool_calls_executed {
                debug!("No more function calls, agent loop complete");
                return Ok(current_response);
            }
        }

        warn!(
            "Max iterations ({}) reached in agent loop",
            self.max_iterations
        );
        Err(ClientError::Config(
            "Max iterations reached in agent loop".to_string(),
        ))
    }
193
    /// Send a streaming chat request with automatic tool execution.
    ///
    /// This method automatically handles the tool execution loop with streaming:
    /// - Sends streaming request to LLM with tools from the MCP server (if configured)
    /// - Executes any tool calls
    /// - Continues until no more tool calls or max iterations reached
    ///
    /// # Arguments
    /// - `messages`: Conversation messages
    ///
    /// # Returns
    /// A stream of chunks for the final response after all tool executions complete.
    /// Each yielded `Response` is a cumulative snapshot: newer partials of the
    /// current turn replace the previous partials rather than appending to them.
    pub fn chat_stream<'a>(
        &'a self,
        mut messages: Vec<Message>,
    ) -> std::pin::Pin<Box<dyn futures::Stream<Item = Result<Response, ClientError>> + Send + 'a>>
    where
        C: crate::client::StreamingClient,
    {
        Box::pin(async_stream::try_stream! {
            debug!("Starting agent streaming chat loop");
            use futures::StreamExt;

            // Cumulative snapshot re-yielded after every chunk.
            let mut current_response = Response {
                data: Vec::new(),
                usage: Usage::default(),
                finish: FinishReason::Unfinished,
            };

            // Fetch the tool list once; `tool_map` routes each tool name back
            // to the server id that advertised it.
            let (tools, tool_map) = if let Some(server) = &self.server {
                match server.list_tools().await {
                    Ok(tools) => {
                        let map: HashMap<String, Option<String>> = tools
                            .iter()
                            .map(|t| (t.value.name.to_string(), t.server_id.clone()))
                            .collect();
                        (tools.into_iter().map(|t| t.value).collect(), map)
                    }
                    Err(e) => {
                        // NOTE(review): unlike `chat`, a listing failure here is
                        // only logged and the request proceeds with no tools —
                        // confirm this asymmetry is intentional.
                        warn!("Failed to list tools from MCP server: {}", e);
                        (Vec::new(), HashMap::new())
                    }
                }
            } else {
                (Vec::new(), HashMap::new())
            };

            for iteration in 0..self.max_iterations {
                debug!(
                    "Agent streaming iteration {}/{}",
                    iteration + 1,
                    self.max_iterations
                );

                let mut stream = self.client.request_stream(messages.clone(), tools.clone()).await?;

                // Snapshot of state before this turn
                let base_data_len = current_response.data.len();
                let base_usage = current_response.usage.clone();

                while let Some(response_result) = stream.next().await {
                    let response = response_result?;

                    // Update current_response
                    // Truncate to base length to remove previous partials of this turn
                    // (assumes each chunk carries the turn's full partial data so
                    // far, not a delta — TODO confirm against StreamingClient impls).
                    current_response.data.truncate(base_data_len);
                    current_response.data.extend(response.data.clone());

                    // Usage is rebuilt from the pre-turn snapshot so partial
                    // chunks are not double-counted.
                    current_response.usage = base_usage.clone();
                    current_response.usage += response.usage;
                    current_response.finish = response.finish;

                    yield current_response.clone();
                }

                // After stream, current_response contains the full assistant message for this turn.
                // Update messages history
                if current_response.data.len() > base_data_len {
                     // The new messages added in this turn
                     for i in base_data_len..current_response.data.len() {
                         messages.push(current_response.data[i].clone());
                     }
                }

                // Check for tool calls
                let mut tool_calls_executed = false;
                let mut tool_responses = Vec::new();

                // We only check the LAST message for tool calls, which should be the assistant's message
                // NOTE(review): function calls in any earlier message of this turn
                // would be silently skipped — verify providers emit at most one
                // assistant message per turn.
                if let Some(msg) = current_response.data.last() {
                    for part in msg.parts() {
                        if let Part::FunctionCall { id, name, arguments, finished, .. } = part {
                            // Only execute calls whose arguments finished streaming.
                            if *finished {
                                tool_calls_executed = true;
                                info!("Executing tool: {}", name);

                                let server = self.server.as_ref().ok_or_else(|| ClientError::Config("No MCP server configured".to_string()))?;
                                let server_id = tool_map.get(name).cloned().flatten();
                                let result = server
                                    .call_tool(name.clone(), arguments.clone(), server_id)
                                    .await;

                                let response_part = match result {
                                    Ok(mut part) => {
                                        // Pair the result with its originating call id.
                                        if let Part::FunctionResponse { id: ref mut pid, .. } = part {
                                            *pid = id.clone();
                                        }
                                        part
                                    }
                                    Err(e) => {
                                        // Report the failure to the model instead of aborting.
                                        Part::FunctionResponse {
                                            id: id.clone(),
                                            name: name.clone(),
                                            response: json!({ "error": format!("Error: {}", e) }),
                                            parts: vec![],
                                            finished: true,
                                        }
                                    },
                                };
                                tool_responses.push(response_part);
                            }
                        }
                    }
                }

                if tool_calls_executed {
                    // All tool results of this turn go back as ONE user message.
                    let tool_msg = Message::User(tool_responses);
                    messages.push(tool_msg.clone());
                    current_response.data.push(tool_msg);

                    yield current_response.clone();
                } else {
                    // No tool calls, we are done
                    return;
                }
            }

            warn!(
                "Max iterations ({}) reached in streaming agent loop",
                self.max_iterations
            );
            Err(ClientError::Config(
                "Max iterations reached in agent loop".to_string(),
            ))?;
        })
    }
340}