1use crate::UserInteraction;
4use crate::config::AgentConfig;
5use crate::errors::AgentError;
6use crate::mcp::McpConnection;
7use crate::models::chat::{ApiResponse, ChatMessage};
8use crate::models::tools::{
9 ToolDefinition, ToolParameter, ToolParameterType, ToolParametersDefinition,
10};
11use crate::providers::{Provider, ProviderRegistry};
12use crate::strategies::{NextStep, Strategy};
13use anyhow::{Context, Result, anyhow};
14use rmcp::model::Tool as McpTool;
15use serde_json::{Map, Value};
16use std::collections::HashMap;
17use std::path::Path;
18use std::sync::Arc;
19use tokio::sync::Mutex;
20use tracing::{debug, info, trace, warn};
21
22use crate::AgentState;
23
24pub struct Agent<UI: UserInteraction> {
25 provider_registry: ProviderRegistry,
26 mcp_connections: HashMap<String, Arc<Mutex<McpConnection>>>,
27 #[allow(dead_code)] http_client: reqwest::Client,
29 #[allow(dead_code)] ui_handler: Arc<UI>,
31 strategy: Box<dyn Strategy<UI> + Send + Sync>,
32 state: AgentState,
33 current_provider_id: String,
34}
35
36fn mcp_schema_to_tool_params(schema_val: Option<&Map<String, Value>>) -> ToolParametersDefinition {
37 let default_params = ToolParametersDefinition {
38 param_type: "object".to_string(),
39 properties: HashMap::new(),
40 required: Vec::new(),
41 };
42
43 let schema = match schema_val {
44 Some(s) => s,
45 None => return default_params,
46 };
47
48 let props_val = schema.get("properties").and_then(Value::as_object);
49 let required_val = schema.get("required").and_then(Value::as_array);
50 let mut properties = HashMap::new();
51
52 if let Some(props_map) = props_val {
53 for (key, val) in props_map {
54 let prop_obj = match val.as_object() {
55 Some(obj) => obj,
56 None => continue,
57 };
58
59 let param_type_str = prop_obj
60 .get("type")
61 .and_then(Value::as_str)
62 .unwrap_or("string");
63 let description = prop_obj
64 .get("description")
65 .and_then(Value::as_str)
66 .unwrap_or("")
67 .to_string();
68
69 let param_type = match param_type_str {
70 "string" => ToolParameterType::String,
71 "integer" => ToolParameterType::Integer,
72 "number" => ToolParameterType::Number,
73 "boolean" => ToolParameterType::Boolean,
74 "array" => ToolParameterType::Array,
75 "object" => ToolParameterType::Object,
76 _ => ToolParameterType::String,
77 };
78
79 let items = if param_type == ToolParameterType::Array {
80 prop_obj.get("items")
81 .and_then(Value::as_object)
82 .map(|items_obj| {
83 let item_type_str = items_obj
84 .get("type")
85 .and_then(Value::as_str)
86 .unwrap_or("string");
87 let item_desc = items_obj
88 .get("description")
89 .and_then(Value::as_str)
90 .unwrap_or("Array item")
91 .to_string();
92 let item_type = match item_type_str {
93 "string" => ToolParameterType::String,
94 "integer" => ToolParameterType::Integer,
95 "number" => ToolParameterType::Number,
96 "boolean" => ToolParameterType::Boolean,
97 "array" => ToolParameterType::Array,
98 "object" => ToolParameterType::Object,
99 _ => ToolParameterType::String,
100 };
101 Box::new(ToolParameter {
102 param_type: item_type,
103 description: item_desc,
104 enum_values: None,
105 items: None, })
107 })
108 .or_else(|| Some(Box::new(ToolParameter {
109 param_type: ToolParameterType::String,
110 description: "Array item".to_string(),
111 enum_values: None,
112 items: None,
113 })))
114 } else {
115 None
116 };
117
118 properties.insert(
119 key.clone(),
120 ToolParameter {
121 param_type,
122 description,
123 enum_values: None,
124 items,
125 },
126 );
127 }
128 }
129
130 let required = required_val
131 .map(|arr| {
132 arr.iter()
133 .filter_map(|v| v.as_str().map(String::from))
134 .collect()
135 })
136 .unwrap_or_default();
137
138 ToolParametersDefinition {
139 param_type: "object".to_string(),
140 properties,
141 required,
142 }
143}
144
145
146struct DummyClientService;
148impl rmcp::service::Service<rmcp::service::RoleClient> for DummyClientService {
149 #[allow(refining_impl_trait)] fn handle_request(
151 &self,
152 _request: rmcp::model::ServerRequest,
153 _context: rmcp::service::RequestContext<rmcp::service::RoleClient>,
154 ) -> std::pin::Pin<
155 Box<
156 dyn std::future::Future<Output = Result<rmcp::model::ClientResult, rmcp::Error>> + Send,
157 >,
158 > {
159 Box::pin(async {
160 Err(rmcp::Error::method_not_found::<rmcp::model::InitializeResultMethod>())
161 })
162 }
163 #[allow(refining_impl_trait)] fn handle_notification(
165 &self,
166 _notification: rmcp::model::ServerNotification,
167 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), rmcp::Error>> + Send>> {
168 Box::pin(async { Ok(()) })
169 }
170 fn get_peer(&self) -> Option<rmcp::service::Peer<rmcp::service::RoleClient>> {
171 None
172 }
173 fn set_peer(&mut self, _peer: rmcp::service::Peer<rmcp::service::RoleClient>) {}
174 fn get_info(&self) -> rmcp::model::ClientInfo {
175 rmcp::model::ClientInfo::default()
176 }
177}
178
179
180impl<UI: UserInteraction + 'static> Agent<UI> {
181 #[allow(clippy::too_many_arguments)]
182 pub fn new(
183 config: AgentConfig,
184 ui_handler: Arc<UI>,
185 strategy: Box<dyn Strategy<UI> + Send + Sync>,
186 history: Option<Vec<ChatMessage>>,
187 current_user_input: String,
188 provider_registry_override: Option<ProviderRegistry>,
189 mcp_connections_override: Option<HashMap<String, Arc<Mutex<McpConnection>>>>,
190 ) -> Result<Self> {
191 let http_client = reqwest::Client::builder()
192 .build()
193 .context("Failed to build HTTP client for Agent")?;
194
195 let provider_registry = match provider_registry_override {
196 Some(registry) => registry,
197 None => {
198 let mut registry = ProviderRegistry::new(config.default_provider.clone());
199 for (id, provider_conf) in config.providers {
200 let api_key = if !provider_conf.api_key_env_var.is_empty() {
201 match std::env::var(&provider_conf.api_key_env_var) {
202 Ok(key) => key,
203 Err(e) => {
204 warn!(provider_id = %id, env_var = %provider_conf.api_key_env_var, error = %e, "API key environment variable not set or invalid");
205 String::new()
206 }
207 }
208 } else {
209 String::new()
210 };
211 let model_config = provider_conf.model_config;
212 let provider: Box<dyn Provider> = match provider_conf.provider_type.as_str() {
213 "gemini" => Box::new(crate::providers::gemini::GeminiProvider::new(
214 model_config,
215 http_client.clone(),
216 api_key,
217 )),
218 "ollama" => Box::new(crate::providers::ollama::OllamaProvider::new(
219 model_config,
220 http_client.clone(),
221 api_key, )),
223 "openai" => Box::new(crate::providers::openai::OpenAIProvider::new(
224 model_config,
225 http_client.clone(),
226 api_key,
227 )),
228 _ => {
229 return Err(anyhow!(
230 "Unsupported provider type: '{}' specified for provider ID '{}'. Supported types: gemini, ollama, openai.",
231 provider_conf.provider_type,
232 id ));
234 }
235 };
236 registry.register(id.clone(), provider); }
238 registry
239 }
240 };
241
242 let mcp_connections = match mcp_connections_override {
243 Some(connections) => connections,
244 None => {
245 let mut connections = HashMap::new();
246 for (id, server_conf) in config.mcp_servers {
247 let connection = McpConnection::new(server_conf.command, server_conf.args);
248 connections.insert(id, Arc::new(Mutex::new(connection)));
249 }
250 connections
251 }
252 };
253
254 let initial_state = AgentState::new_turn(history, current_user_input);
255 let default_provider_id = provider_registry.default_provider_id().to_string();
256
257 info!(
258 strategy = strategy.name(),
259 default_provider = %default_provider_id,
260 "Initializing MCP Agent with strategy."
261 );
262
263 Ok(Self {
264 provider_registry,
265 mcp_connections,
266 http_client,
267 ui_handler,
268 strategy,
269 state: initial_state,
270 current_provider_id: default_provider_id,
271 })
272 }
273
274 async fn ensure_mcp_connection(&self, server_id: &str) -> Result<()> {
276 let conn_mutex = self
277 .mcp_connections
278 .get(server_id)
279 .ok_or_else(|| anyhow!("MCP server config not found: {}", server_id))?;
280 let conn_guard = conn_mutex.lock().await;
281 let ct = tokio_util::sync::CancellationToken::new();
282 conn_guard
283 .establish_connection_external(DummyClientService, ct)
284 .await
285 }
286
287 pub fn switch_provider(&mut self, provider_id: &str) -> Result<()> {
288 self.provider_registry.get(provider_id)?;
289 if self.current_provider_id != provider_id {
290 debug!(old_provider = %self.current_provider_id, new_provider = %provider_id, "Switching provider");
291 self.current_provider_id = provider_id.to_string();
292 }
293 Ok(())
294 }
295
296 pub async fn get_completion(
297 &self,
298 messages: Vec<ChatMessage>,
299 tools: Option<&[ToolDefinition]>,
300 ) -> Result<ApiResponse> {
301 let provider = self.provider_registry.get(&self.current_provider_id)?;
302 debug!(provider = %self.current_provider_id, num_messages = messages.len(), "Getting completion from provider");
303 provider.get_completion(messages, tools).await
304 }
305
306 pub async fn call_mcp_tool(
307 &self,
308 server_id: &str,
309 tool_name: &str,
310 args: Value,
311 ) -> Result<Value> {
312 self.ensure_mcp_connection(server_id).await?;
313 let conn_mutex = self.mcp_connections.get(server_id).unwrap();
314 let conn = conn_mutex.lock().await;
315 conn.call_tool(tool_name, args).await
316 }
317
318 pub async fn get_mcp_resource(&self, server_id: &str, uri: &str) -> Result<Value> {
319 self.ensure_mcp_connection(server_id).await?;
320 let conn_mutex = self.mcp_connections.get(server_id).unwrap();
321 let conn = conn_mutex.lock().await;
322 debug!(server = %server_id, uri = %uri, "Getting MCP resource");
323 conn.get_resource(uri).await
324 }
325
326 pub async fn list_mcp_tools(&self) -> Result<Vec<McpTool>> {
327 let mut all_tools = Vec::new();
328 for (id, conn_mutex) in &self.mcp_connections {
329 match self.ensure_mcp_connection(id).await {
330 Ok(_) => {
331 let conn = conn_mutex.lock().await;
332 match conn.list_tools().await {
333 Ok(tools) => all_tools.extend(tools),
334 Err(e) => {
335 warn!(server_id = %id, error = ?e, "Failed to list tools from MCP server (post-connection)")
336 }
337 }
338 }
339 Err(e) => {
340 warn!(server_id = %id, error = ?e, "Failed to ensure MCP connection for listing tools");
341 }
342 }
343 }
344 Ok(all_tools)
345 }
346
347
348 pub async fn run(&mut self, _working_dir: &Path) -> Result<(String, AgentState), AgentError> {
349 info!(strategy = self.strategy.name(), "Starting MCP agent run.");
350
351 let mut next_step = self.strategy.initialize_interaction(&mut self.state)?;
352
353 loop {
354 trace!(?next_step, "Processing next step.");
355 match next_step {
356 NextStep::CallApi(state_from_strategy) => {
357 self.state = state_from_strategy;
359 let mcp_tools = self
360 .list_mcp_tools()
361 .await
362 .map_err(|e| AgentError::Mcp(e.context("Failed to list MCP tools")))?;
363
364 let tool_definitions: Vec<ToolDefinition> = mcp_tools
365 .iter()
366 .map(|mcp_tool| {
367 let schema_map = mcp_tool.input_schema.as_ref();
368 ToolDefinition {
369 name: mcp_tool.name.to_string(),
370 description: mcp_tool.description.to_string(),
371 parameters: mcp_schema_to_tool_params(Some(schema_map)),
372 }
373 })
374 .collect();
375
376 debug!(
377 provider = %self.current_provider_id,
378 num_messages = self.state.messages.len(),
379 num_tools = tool_definitions.len(),
380 "Sending request to AI provider."
381 );
382
383 let api_response = self
384 .get_completion(
385 self.state.messages.clone(),
386 if tool_definitions.is_empty() { None } else { Some(&tool_definitions) },
387 )
388 .await
389 .map_err(|e| AgentError::Api(e.context("API call failed during agent run")))?;
390
391 debug!("Received response from AI.");
392 trace!(response = %serde_json::to_string_pretty(&api_response).unwrap_or_default(), "Full API Response");
393
394 next_step = self
395 .strategy
396 .process_api_response(&mut self.state, api_response)?;
397 }
398 NextStep::CallTools(state_from_strategy) => {
399 self.state = state_from_strategy;
400 let tool_calls_to_execute = self.state.pending_tool_calls.clone();
401
402 if tool_calls_to_execute.is_empty() {
403 warn!("Strategy requested tool calls, but none were pending.");
404 return Err(AgentError::Strategy(
405 "Strategy requested tool calls, but none were pending in state".to_string(),
406 ));
407 }
408
409 if let Some(last_message) = self.state.messages.last() {
411 if last_message.role == "assistant" {
412 if let Some(content) = &last_message.content {
413 if !content.trim().is_empty() {
414 println!("\nAssistant: {}", content);
415 }
416 }
417 }
418 }
419
420 info!(
421 count = tool_calls_to_execute.len(),
422 "Executing {} requested tool call(s) via MCP.",
423 tool_calls_to_execute.len()
424 );
425
426 let mut tool_results = Vec::new();
427 for tool_call in &tool_calls_to_execute {
428 let tool_name = &tool_call.function.name;
429 let args: Value = serde_json::from_str(&tool_call.function.arguments)
430 .map_err(|e| {
431 warn!(tool_call_id = %tool_call.id, tool_name=%tool_name, args_str=%tool_call.function.arguments, error=%e, "Failed to parse tool arguments JSON string. Using null.");
432 e
433 })
434 .unwrap_or(Value::Null);
435
436 let server_id = match tool_name.as_str() {
437 "read_file" | "write_file" => "filesystem",
438 "shell" => "shell",
439 "git_diff" | "git_status" | "git_commit" => "git",
440 "search_text" => "search",
441 _ => {
442 warn!(tool_name = %tool_name, "Cannot map tool to MCP server, skipping.");
443 tool_results.push(crate::ToolResult {
444 tool_call_id: tool_call.id.clone(),
445 output: format!("Error: Unknown tool name '{}'", tool_name),
446 status: crate::ToolExecutionStatus::Failure,
447 });
448 continue;
449 }
450 };
451
452 println!(
453 "\n\x1b[33m▶\x1b[0m Running: {}({})",
454 tool_name,
455 &tool_call.function.arguments
456 );
457
458 match self.call_mcp_tool(server_id, tool_name, args).await {
459 Ok(output_value) => {
460 let output_str = match output_value {
461 Value::String(s) => s,
462 Value::Object(map) if map.contains_key("content") => {
463 serde_json::to_string(&map).unwrap_or_else(|_| "<invalid JSON object>".to_string())
464 },
465 Value::Object(map) if map.contains_key("text") => map
466 .get("text")
467 .and_then(Value::as_str)
468 .unwrap_or("")
469 .to_string(),
470 Value::Array(arr) if arr.is_empty() => {
471 if tool_name == "write_file" {
472 "<write successful>".to_string() } else {
474 "<empty array result>".to_string() }
476 }
477 Value::Array(arr) => serde_json::to_string_pretty(&arr)
478 .unwrap_or_else(|_| "<invalid JSON array>".to_string()),
479 Value::Object(map) => serde_json::to_string_pretty(&map)
480 .unwrap_or_else(|_| "<invalid JSON object>".to_string()),
481 Value::Null => "<no output>".to_string(),
482 other => other.to_string(),
483 };
484 tool_results.push(crate::ToolResult {
485 tool_call_id: tool_call.id.clone(),
486 output: output_str,
487 status: crate::ToolExecutionStatus::Success,
488 });
489 }
490 Err(e) => {
491 tool_results.push(crate::ToolResult {
492 tool_call_id: tool_call.id.clone(),
493 output: format!(
494 "Error executing MCP tool '{}' on server '{}': {}",
495 tool_name, server_id, e
496 ),
497 status: crate::ToolExecutionStatus::Failure,
498 });
499 }
500 }
501 } let results_map: HashMap<_, _> = tool_results
505 .iter()
506 .map(|r| (r.tool_call_id.as_str(), r))
507 .collect();
508
509 for tool_call in &tool_calls_to_execute {
510 if let Some(result) = results_map.get(tool_call.id.as_str()) {
511 let status_icon = match result.status {
512 crate::ToolExecutionStatus::Success => "\n\x1b[32m✓\x1b[0m",
513 crate::ToolExecutionStatus::Failure => "\n\x1b[31m✗\x1b[0m",
514 };
515 const MAX_SUMMARY_LEN: usize = 70;
516 let output_preview = result.output.chars().take(MAX_SUMMARY_LEN).collect::<String>();
517 let ellipsis = if result.output.len() > MAX_SUMMARY_LEN { "..." } else { "" };
518
519 println!(
520 "{} {}({}) -> {:?} \"{}{}\"",
521 status_icon,
522 tool_call.function.name,
523 tool_call.function.arguments,
524 result.status,
525 output_preview.replace('\n', " "),
526 ellipsis
527 );
528 } else {
529 warn!(tool_call_id = %tool_call.id, "Result mismatch during summary generation.");
530 }
531 }
532
533 debug!(
534 count = tool_results.len(),
535 "Passing {} tool result(s) back to strategy.",
536 tool_results.len()
537 );
538
539 next_step = self
540 .strategy
541 .process_tool_results(&mut self.state, tool_results)?;
542 }
543 NextStep::DelegateTask(delegation_input) => {
544 warn!(task = ?delegation_input.task_description, "Delegation requested, but not yet implemented.");
546 let delegation_result = crate::DelegationResult {
547 result: "Delegation is not implemented.".to_string(),
548 };
549 next_step = self
550 .strategy
551 .process_delegation_result(&mut self.state, delegation_result)?;
552 }
553 NextStep::Completed(final_message) => {
554 info!("Strategy indicated completion.");
556 trace!(message = %final_message, "Final message from strategy.");
557 return Ok((final_message, self.state.clone()));
558 }
559 } } } }