volition_core/
agent.rs

1// volition-agent-core/src/agent.rs
2
3use crate::UserInteraction;
4use crate::config::AgentConfig;
5use crate::errors::AgentError;
6use crate::mcp::McpConnection;
7use crate::models::chat::{ApiResponse, ChatMessage};
8use crate::models::tools::{
9    ToolDefinition, ToolParameter, ToolParameterType, ToolParametersDefinition,
10};
11use crate::providers::{Provider, ProviderRegistry};
12use crate::strategies::{NextStep, Strategy};
13use anyhow::{Context, Result, anyhow};
14use rmcp::model::Tool as McpTool;
15use serde_json::{Map, Value};
16use std::collections::HashMap;
17use std::path::Path;
18use std::sync::Arc;
19use tokio::sync::Mutex;
20use tracing::{debug, info, trace, warn};
21
22use crate::AgentState;
23
24pub struct Agent<UI: UserInteraction> {
25    provider_registry: ProviderRegistry,
26    mcp_connections: HashMap<String, Arc<Mutex<McpConnection>>>,
27    #[allow(dead_code)] // Field currently unused
28    http_client: reqwest::Client,
29    #[allow(dead_code)] // Field currently unused
30    ui_handler: Arc<UI>,
31    strategy: Box<dyn Strategy<UI> + Send + Sync>,
32    state: AgentState,
33    current_provider_id: String,
34}
35
36fn mcp_schema_to_tool_params(schema_val: Option<&Map<String, Value>>) -> ToolParametersDefinition {
37    let default_params = ToolParametersDefinition {
38        param_type: "object".to_string(),
39        properties: HashMap::new(),
40        required: Vec::new(),
41    };
42
43    let schema = match schema_val {
44        Some(s) => s,
45        None => return default_params,
46    };
47
48    let props_val = schema.get("properties").and_then(Value::as_object);
49    let required_val = schema.get("required").and_then(Value::as_array);
50    let mut properties = HashMap::new();
51
52    if let Some(props_map) = props_val {
53        for (key, val) in props_map {
54            let prop_obj = match val.as_object() {
55                Some(obj) => obj,
56                None => continue,
57            };
58
59            let param_type_str = prop_obj
60                .get("type")
61                .and_then(Value::as_str)
62                .unwrap_or("string");
63            let description = prop_obj
64                .get("description")
65                .and_then(Value::as_str)
66                .unwrap_or("")
67                .to_string();
68
69            let param_type = match param_type_str {
70                "string" => ToolParameterType::String,
71                "integer" => ToolParameterType::Integer,
72                "number" => ToolParameterType::Number,
73                "boolean" => ToolParameterType::Boolean,
74                "array" => ToolParameterType::Array,
75                "object" => ToolParameterType::Object,
76                _ => ToolParameterType::String,
77            };
78
79            let items = if param_type == ToolParameterType::Array {
80                prop_obj.get("items")
81                    .and_then(Value::as_object)
82                    .map(|items_obj| {
83                        let item_type_str = items_obj
84                            .get("type")
85                            .and_then(Value::as_str)
86                            .unwrap_or("string");
87                        let item_desc = items_obj
88                            .get("description")
89                            .and_then(Value::as_str)
90                            .unwrap_or("Array item")
91                            .to_string();
92                        let item_type = match item_type_str {
93                            "string" => ToolParameterType::String,
94                            "integer" => ToolParameterType::Integer,
95                            "number" => ToolParameterType::Number,
96                            "boolean" => ToolParameterType::Boolean,
97                            "array" => ToolParameterType::Array,
98                            "object" => ToolParameterType::Object,
99                            _ => ToolParameterType::String,
100                        };
101                        Box::new(ToolParameter {
102                            param_type: item_type,
103                            description: item_desc,
104                            enum_values: None,
105                            items: None, // Nested items not supported for now
106                        })
107                    })
108                    .or_else(|| Some(Box::new(ToolParameter {
109                        param_type: ToolParameterType::String,
110                        description: "Array item".to_string(),
111                        enum_values: None,
112                        items: None,
113                    })))
114            } else {
115                None
116            };
117
118            properties.insert(
119                key.clone(),
120                ToolParameter {
121                    param_type,
122                    description,
123                    enum_values: None,
124                    items,
125                },
126            );
127        }
128    }
129
130    let required = required_val
131        .map(|arr| {
132            arr.iter()
133                .filter_map(|v| v.as_str().map(String::from))
134                .collect()
135        })
136        .unwrap_or_default();
137
138    ToolParametersDefinition {
139        param_type: "object".to_string(),
140        properties,
141        required,
142    }
143}
144
145
146// --- DummyClientService struct ---
147struct DummyClientService;
148impl rmcp::service::Service<rmcp::service::RoleClient> for DummyClientService {
149    #[allow(refining_impl_trait)] // Allow Pin<Box<dyn Future>> where trait uses impl Future
150    fn handle_request(
151        &self,
152        _request: rmcp::model::ServerRequest,
153        _context: rmcp::service::RequestContext<rmcp::service::RoleClient>,
154    ) -> std::pin::Pin<
155        Box<
156            dyn std::future::Future<Output = Result<rmcp::model::ClientResult, rmcp::Error>> + Send,
157        >,
158    > {
159        Box::pin(async {
160            Err(rmcp::Error::method_not_found::<rmcp::model::InitializeResultMethod>())
161        })
162    }
163    #[allow(refining_impl_trait)] // Allow Pin<Box<dyn Future>> where trait uses impl Future
164    fn handle_notification(
165        &self,
166        _notification: rmcp::model::ServerNotification,
167    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), rmcp::Error>> + Send>> {
168        Box::pin(async { Ok(()) })
169    }
170    fn get_peer(&self) -> Option<rmcp::service::Peer<rmcp::service::RoleClient>> {
171        None
172    }
173    fn set_peer(&mut self, _peer: rmcp::service::Peer<rmcp::service::RoleClient>) {}
174    fn get_info(&self) -> rmcp::model::ClientInfo {
175        rmcp::model::ClientInfo::default()
176    }
177}
178
179
180impl<UI: UserInteraction + 'static> Agent<UI> {
181    #[allow(clippy::too_many_arguments)]
182    pub fn new(
183        config: AgentConfig,
184        ui_handler: Arc<UI>,
185        strategy: Box<dyn Strategy<UI> + Send + Sync>,
186        history: Option<Vec<ChatMessage>>,
187        current_user_input: String,
188        provider_registry_override: Option<ProviderRegistry>,
189        mcp_connections_override: Option<HashMap<String, Arc<Mutex<McpConnection>>>>,
190    ) -> Result<Self> {
191        let http_client = reqwest::Client::builder()
192            .build()
193            .context("Failed to build HTTP client for Agent")?;
194
195        let provider_registry = match provider_registry_override {
196            Some(registry) => registry,
197            None => {
198                let mut registry = ProviderRegistry::new(config.default_provider.clone());
199                for (id, provider_conf) in config.providers {
200                    let api_key = if !provider_conf.api_key_env_var.is_empty() {
201                        match std::env::var(&provider_conf.api_key_env_var) {
202                            Ok(key) => key,
203                            Err(e) => {
204                                warn!(provider_id = %id, env_var = %provider_conf.api_key_env_var, error = %e, "API key environment variable not set or invalid");
205                                String::new()
206                            }
207                        }
208                    } else {
209                        String::new()
210                    };
211                    let model_config = provider_conf.model_config;
212                    let provider: Box<dyn Provider> = match provider_conf.provider_type.as_str() {
213                        "gemini" => Box::new(crate::providers::gemini::GeminiProvider::new(
214                            model_config,
215                            http_client.clone(),
216                            api_key,
217                        )),
218                        "ollama" => Box::new(crate::providers::ollama::OllamaProvider::new(
219                            model_config,
220                            http_client.clone(),
221                            api_key, // Note: OllamaProvider ignores the key in its `new` fn
222                        )),
223                        "openai" => Box::new(crate::providers::openai::OpenAIProvider::new(
224                            model_config,
225                            http_client.clone(),
226                            api_key,
227                        )),
228                        _ => {
229                            return Err(anyhow!(
230                                "Unsupported provider type: '{}' specified for provider ID '{}'. Supported types: gemini, ollama, openai.",
231                                provider_conf.provider_type,
232                                id // Added provider ID to error message for clarity
233                            ));
234                        }
235                    };
236                    registry.register(id.clone(), provider); // Register the created provider instance
237                }
238                registry
239            }
240        };
241
242        let mcp_connections = match mcp_connections_override {
243            Some(connections) => connections,
244            None => {
245                let mut connections = HashMap::new();
246                for (id, server_conf) in config.mcp_servers {
247                    let connection = McpConnection::new(server_conf.command, server_conf.args);
248                    connections.insert(id, Arc::new(Mutex::new(connection)));
249                }
250                connections
251            }
252        };
253
254        let initial_state = AgentState::new_turn(history, current_user_input);
255        let default_provider_id = provider_registry.default_provider_id().to_string();
256
257        info!(
258            strategy = strategy.name(),
259            default_provider = %default_provider_id,
260            "Initializing MCP Agent with strategy."
261        );
262
263        Ok(Self {
264            provider_registry,
265            mcp_connections,
266            http_client,
267            ui_handler,
268            strategy,
269            state: initial_state,
270            current_provider_id: default_provider_id,
271        })
272    }
273
274    // --- ensure_mcp_connection, switch_provider, get_completion, call_mcp_tool, get_mcp_resource, list_mcp_tools remain unchanged ---
275    async fn ensure_mcp_connection(&self, server_id: &str) -> Result<()> {
276        let conn_mutex = self
277            .mcp_connections
278            .get(server_id)
279            .ok_or_else(|| anyhow!("MCP server config not found: {}", server_id))?;
280        let conn_guard = conn_mutex.lock().await;
281        let ct = tokio_util::sync::CancellationToken::new();
282        conn_guard
283            .establish_connection_external(DummyClientService, ct)
284            .await
285    }
286
287    pub fn switch_provider(&mut self, provider_id: &str) -> Result<()> {
288        self.provider_registry.get(provider_id)?;
289        if self.current_provider_id != provider_id {
290            debug!(old_provider = %self.current_provider_id, new_provider = %provider_id, "Switching provider");
291            self.current_provider_id = provider_id.to_string();
292        }
293        Ok(())
294    }
295
296    pub async fn get_completion(
297        &self,
298        messages: Vec<ChatMessage>,
299        tools: Option<&[ToolDefinition]>,
300    ) -> Result<ApiResponse> {
301        let provider = self.provider_registry.get(&self.current_provider_id)?;
302        debug!(provider = %self.current_provider_id, num_messages = messages.len(), "Getting completion from provider");
303        provider.get_completion(messages, tools).await
304    }
305
306    pub async fn call_mcp_tool(
307        &self,
308        server_id: &str,
309        tool_name: &str,
310        args: Value,
311    ) -> Result<Value> {
312        self.ensure_mcp_connection(server_id).await?;
313        let conn_mutex = self.mcp_connections.get(server_id).unwrap();
314        let conn = conn_mutex.lock().await;
315        conn.call_tool(tool_name, args).await
316    }
317
318    pub async fn get_mcp_resource(&self, server_id: &str, uri: &str) -> Result<Value> {
319        self.ensure_mcp_connection(server_id).await?;
320        let conn_mutex = self.mcp_connections.get(server_id).unwrap();
321        let conn = conn_mutex.lock().await;
322        debug!(server = %server_id, uri = %uri, "Getting MCP resource");
323        conn.get_resource(uri).await
324    }
325
326     pub async fn list_mcp_tools(&self) -> Result<Vec<McpTool>> {
327        let mut all_tools = Vec::new();
328        for (id, conn_mutex) in &self.mcp_connections {
329            match self.ensure_mcp_connection(id).await {
330                Ok(_) => {
331                    let conn = conn_mutex.lock().await;
332                    match conn.list_tools().await {
333                        Ok(tools) => all_tools.extend(tools),
334                        Err(e) => {
335                            warn!(server_id = %id, error = ?e, "Failed to list tools from MCP server (post-connection)")
336                        }
337                    }
338                }
339                Err(e) => {
340                    warn!(server_id = %id, error = ?e, "Failed to ensure MCP connection for listing tools");
341                }
342            }
343        }
344        Ok(all_tools)
345    }
346
347
348    pub async fn run(&mut self, _working_dir: &Path) -> Result<(String, AgentState), AgentError> {
349        info!(strategy = self.strategy.name(), "Starting MCP agent run.");
350
351        let mut next_step = self.strategy.initialize_interaction(&mut self.state)?;
352
353        loop {
354            trace!(?next_step, "Processing next step.");
355            match next_step {
356                NextStep::CallApi(state_from_strategy) => {
357                    // --- This block remains unchanged ---
358                    self.state = state_from_strategy;
359                    let mcp_tools = self
360                        .list_mcp_tools()
361                        .await
362                        .map_err(|e| AgentError::Mcp(e.context("Failed to list MCP tools")))?;
363
364                    let tool_definitions: Vec<ToolDefinition> = mcp_tools
365                        .iter()
366                        .map(|mcp_tool| {
367                            let schema_map = mcp_tool.input_schema.as_ref();
368                            ToolDefinition {
369                                name: mcp_tool.name.to_string(),
370                                description: mcp_tool.description.to_string(),
371                                parameters: mcp_schema_to_tool_params(Some(schema_map)),
372                            }
373                         })
374                        .collect();
375
376                    debug!(
377                        provider = %self.current_provider_id,
378                        num_messages = self.state.messages.len(),
379                        num_tools = tool_definitions.len(),
380                        "Sending request to AI provider."
381                    );
382
383                    let api_response = self
384                        .get_completion(
385                            self.state.messages.clone(),
386                            if tool_definitions.is_empty() { None } else { Some(&tool_definitions) },
387                        )
388                        .await
389                        .map_err(|e| AgentError::Api(e.context("API call failed during agent run")))?;
390
391                    debug!("Received response from AI.");
392                    trace!(response = %serde_json::to_string_pretty(&api_response).unwrap_or_default(), "Full API Response");
393
394                    next_step = self
395                        .strategy
396                        .process_api_response(&mut self.state, api_response)?;
397                }
398                NextStep::CallTools(state_from_strategy) => {
399                    self.state = state_from_strategy;
400                    let tool_calls_to_execute = self.state.pending_tool_calls.clone();
401
402                    if tool_calls_to_execute.is_empty() {
403                        warn!("Strategy requested tool calls, but none were pending.");
404                        return Err(AgentError::Strategy(
405                            "Strategy requested tool calls, but none were pending in state".to_string(),
406                        ));
407                    }
408
409                    // Print assistant message before tool execution
410                    if let Some(last_message) = self.state.messages.last() {
411                        if last_message.role == "assistant" {
412                            if let Some(content) = &last_message.content {
413                                if !content.trim().is_empty() {
414                                     println!("\nAssistant: {}", content);
415                                }
416                            }
417                        }
418                    }
419
420                    info!(
421                        count = tool_calls_to_execute.len(),
422                        "Executing {} requested tool call(s) via MCP.",
423                        tool_calls_to_execute.len()
424                    );
425
426                    let mut tool_results = Vec::new();
427                    for tool_call in &tool_calls_to_execute {
428                        let tool_name = &tool_call.function.name;
429                        let args: Value = serde_json::from_str(&tool_call.function.arguments)
430                            .map_err(|e| {
431                                warn!(tool_call_id = %tool_call.id, tool_name=%tool_name, args_str=%tool_call.function.arguments, error=%e, "Failed to parse tool arguments JSON string. Using null.");
432                                e
433                            })
434                            .unwrap_or(Value::Null);
435
436                        let server_id = match tool_name.as_str() {
437                            "read_file" | "write_file" => "filesystem",
438                            "shell" => "shell",
439                            "git_diff" | "git_status" | "git_commit" => "git",
440                            "search_text" => "search",
441                            _ => {
442                                warn!(tool_name = %tool_name, "Cannot map tool to MCP server, skipping.");
443                                tool_results.push(crate::ToolResult {
444                                    tool_call_id: tool_call.id.clone(),
445                                    output: format!("Error: Unknown tool name '{}'", tool_name),
446                                    status: crate::ToolExecutionStatus::Failure,
447                                });
448                                continue;
449                            }
450                        };
451
452                        println!(
453                            "\n\x1b[33m▶\x1b[0m Running: {}({})",
454                            tool_name,
455                            &tool_call.function.arguments
456                        );
457
458                        match self.call_mcp_tool(server_id, tool_name, args).await {
459                            Ok(output_value) => {
460                                let output_str = match output_value {
461                                    Value::String(s) => s,
462                                    Value::Object(map) if map.contains_key("content") => {
463                                        serde_json::to_string(&map).unwrap_or_else(|_| "<invalid JSON object>".to_string())
464                                    },
465                                    Value::Object(map) if map.contains_key("text") => map
466                                        .get("text")
467                                        .and_then(Value::as_str)
468                                        .unwrap_or("")
469                                        .to_string(),
470                                    Value::Array(arr) if arr.is_empty() => {
471                                        if tool_name == "write_file" {
472                                             "<write successful>".to_string() // Specific message for successful write
473                                        } else {
474                                             "<empty array result>".to_string() // Generic for other tools
475                                        }
476                                    }
477                                    Value::Array(arr) => serde_json::to_string_pretty(&arr)
478                                        .unwrap_or_else(|_| "<invalid JSON array>".to_string()),
479                                    Value::Object(map) => serde_json::to_string_pretty(&map)
480                                        .unwrap_or_else(|_| "<invalid JSON object>".to_string()),
481                                    Value::Null => "<no output>".to_string(),
482                                    other => other.to_string(),
483                                };
484                                tool_results.push(crate::ToolResult {
485                                    tool_call_id: tool_call.id.clone(),
486                                    output: output_str,
487                                    status: crate::ToolExecutionStatus::Success,
488                                });
489                            }
490                            Err(e) => {
491                                tool_results.push(crate::ToolResult {
492                                    tool_call_id: tool_call.id.clone(),
493                                    output: format!(
494                                        "Error executing MCP tool '{}' on server '{}': {}",
495                                        tool_name, server_id, e
496                                    ),
497                                    status: crate::ToolExecutionStatus::Failure,
498                                });
499                            }
500                        }
501                    } // End of for tool_call loop
502
503                    // Log summary
504                    let results_map: HashMap<_, _> = tool_results
505                        .iter()
506                        .map(|r| (r.tool_call_id.as_str(), r))
507                        .collect();
508
509                    for tool_call in &tool_calls_to_execute {
510                        if let Some(result) = results_map.get(tool_call.id.as_str()) {
511                            let status_icon = match result.status {
512                                crate::ToolExecutionStatus::Success => "\n\x1b[32m✓\x1b[0m",
513                                crate::ToolExecutionStatus::Failure => "\n\x1b[31m✗\x1b[0m",
514                            };
515                            const MAX_SUMMARY_LEN: usize = 70;
516                            let output_preview = result.output.chars().take(MAX_SUMMARY_LEN).collect::<String>();
517                            let ellipsis = if result.output.len() > MAX_SUMMARY_LEN { "..." } else { "" };
518
519                             println!(
520                                "{} {}({}) -> {:?} \"{}{}\"",
521                                status_icon,
522                                tool_call.function.name,
523                                tool_call.function.arguments,
524                                result.status,
525                                output_preview.replace('\n', " "),
526                                ellipsis
527                            );
528                        } else {
529                            warn!(tool_call_id = %tool_call.id, "Result mismatch during summary generation.");
530                        }
531                    }
532
533                    debug!(
534                        count = tool_results.len(),
535                        "Passing {} tool result(s) back to strategy.",
536                        tool_results.len()
537                    );
538
539                    next_step = self
540                        .strategy
541                        .process_tool_results(&mut self.state, tool_results)?;
542                }
543                NextStep::DelegateTask(delegation_input) => {
544                     // --- This block remains unchanged ---
545                    warn!(task = ?delegation_input.task_description, "Delegation requested, but not yet implemented.");
546                    let delegation_result = crate::DelegationResult {
547                        result: "Delegation is not implemented.".to_string(),
548                    };
549                    next_step = self
550                        .strategy
551                        .process_delegation_result(&mut self.state, delegation_result)?;
552                }
553                NextStep::Completed(final_message) => {
554                    // --- This block remains unchanged ---
555                    info!("Strategy indicated completion.");
556                    trace!(message = %final_message, "Final message from strategy.");
557                    return Ok((final_message, self.state.clone()));
558                }
559            } // End match next_step
560        } // End loop
561    } // End run
562} // End impl Agent