// rust_agent/agents/mcp_agent.rs

use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;

use anyhow::anyhow;
use log::info;
use serde_json::Value;

use crate::{
    parse_model_output, Agent, AgentAction, AgentFinish, AgentOutput, BaseMemory,
    ChatMessageContent, ChatModel, McpClient, McpToolAdapter, ModelChatMessage,
    OpenAIChatModel, Runnable, Tool,
};

12/// McpAgent is an intelligent agent implementation based on MCP services
13/// It can connect to MCP servers, process user inputs, call tools, and generate responses
14pub struct McpAgent {
15    client: Arc<dyn McpClient>,
16    tools: Vec<Box<dyn Tool + Send + Sync>>,
17    system_prompt: String,
18    openai_model: Option<OpenAIChatModel>,
19    memory: Option<Box<dyn BaseMemory>>,
20}
21
22impl McpAgent {
23    /// Create a new McpAgent instance
24    pub fn new(client: Arc<dyn McpClient>, system_prompt: String) -> Self {
25        Self {
26            client,
27            tools: Vec::new(),
28            system_prompt,
29            openai_model: None, // Default to not setting OpenAI model
30            memory: None, // Default to not setting memory module
31        }
32    }
33    
34    /// Create a new McpAgent instance with specified OpenAIChatModel
35    pub fn with_openai_model(client: Arc<dyn McpClient>, system_prompt: String, openai_model: OpenAIChatModel) -> Self {
36        Self {
37            client,
38            tools: Vec::new(),
39            system_prompt,
40            openai_model: Some(openai_model),
41            memory: None, // Default to not setting memory module
42        }
43    }
44    
45    /// Create a new McpAgent instance with specified memory module
46    pub fn with_memory(client: Arc<dyn McpClient>, system_prompt: String, memory: Box<dyn BaseMemory>) -> Self {
47        Self {
48            client,
49            tools: Vec::new(),
50            system_prompt,
51            openai_model: None,
52            memory: Some(memory),
53        }
54    }
55
56    /// Create a new McpAgent instance with specified OpenAIChatModel and memory module
57    pub fn with_openai_model_and_memory(client: Arc<dyn McpClient>, system_prompt: String, openai_model: OpenAIChatModel, memory: Box<dyn BaseMemory>) -> Self {
58        Self {
59            client,
60            tools: Vec::new(),
61            system_prompt,
62            openai_model: Some(openai_model),
63            memory: Some(memory),
64        }
65    }
66    
67    /// Get a reference to the memory module
68    pub fn get_memory(&self) -> Option<&Box<dyn BaseMemory>> {
69        self.memory.as_ref()
70    }
71
72    /// Add a tool to the Agent
73    pub fn add_tool(&mut self, tool: Box<dyn Tool + Send + Sync>) {
74        self.tools.push(tool);
75    }
76    
77    /// Automatically get tools from MCP client and add them to the Agent
78    /// This method gets all available tools from the MCP client and wraps them as McpToolAdapter before adding to the Agent
79    /// Local tool registration and addition are handled by the caller
80    pub async fn auto_add_tools(&mut self) -> Result<(), anyhow::Error> {
81        use crate::McpToolAdapter;
82        
83        // Get tool list from MCP client
84        let tools = self.client.get_tools().await?;
85
86        // Print information about the obtained tools
87        for tool in &tools {
88            info!("MCP Client Get Tool: {} - {}", tool.name, tool.description);
89        }
90        
91        // Wrap each tool as McpToolAdapter and add to the Agent
92        for tool in tools {
93            let tool_adapter = McpToolAdapter::new(
94                self.client.clone(),
95                tool
96            );
97            self.add_tool(Box::new(tool_adapter));
98        }
99        
100        Ok(())
101    }
102}
103
104impl Agent for McpAgent {
105    fn tools(&self) -> Vec<Box<dyn Tool + Send + Sync>> {
106        // Return a cloned version of the tool list
107        // To solve the problem that Box<dyn Tool> cannot be directly cloned, we create new tool adapter instances
108        let mut cloned_tools: Vec<Box<dyn Tool + Send + Sync>> = Vec::new();
109
110        // Since McpToolAdapter can be recreated through client and McpTool,
111        // we iterate through existing tools and create new adapter instances for each tool
112        for tool in &self.tools {
113            // Check if the tool is of type McpToolAdapter
114            if let Some(mcp_tool_adapter) = tool.as_any().downcast_ref::<McpToolAdapter>() {
115                // Recreate McpToolAdapter instance
116                let cloned_adapter = McpToolAdapter::new(
117                    mcp_tool_adapter.get_client(),
118                    mcp_tool_adapter.get_mcp_tool(),
119                );
120                cloned_tools.push(Box::new(cloned_adapter));
121            } else {
122                // For other types of tools, we skip or need to implement other cloning mechanisms
123                // Here we can add logs or error handling
124                info!(
125                    "Warning: Unable to clone non-McpToolAdapter type tool: {}",
126                    tool.name()
127                );
128            }
129        }
130
131        cloned_tools
132    }
133
134    fn execute(
135        &self,
136        _action: &AgentAction,
137    ) -> std::pin::Pin<
138        Box<dyn std::future::Future<Output = Result<String, anyhow::Error>> + Send + '_>,
139    > {
140        Box::pin(async move {
141            // In practical applications, there should be a mechanism here to find and call tools
142            // Since we cannot clone the tool list, we simplify the implementation here
143            Err(anyhow!("Tool execution functionality is not implemented yet"))
144        })
145    }
146
147    fn clone_agent(&self) -> Box<dyn Agent> {
148        // Create a new McpAgent instance, copy basic fields, but do not copy tools (simplified implementation)
149        let new_agent = McpAgent::new(
150            self.client.clone(),
151            self.system_prompt.clone(),
152        );
153
154        // Note: We do not copy tools here because Box<dyn Tool> cannot be directly cloned
155        Box::new(new_agent)
156    }
157}
158
159impl Clone for McpAgent {
160    fn clone(&self) -> Self {
161        // Create a new McpAgent instance, but do not copy the tool list (simplified implementation)
162        Self {
163            client: Arc::clone(&self.client),
164            tools: Vec::new(), // Do not copy tools because Box<dyn Tool> cannot be directly cloned
165            system_prompt: self.system_prompt.clone(),
166            openai_model: self.openai_model.clone(), // Clone OpenAI model instance
167            memory: self.memory.clone(), // Clone memory module
168        }
169    }
170}
171
172impl Runnable<std::collections::HashMap<String, String>, AgentOutput> for McpAgent {
173    fn invoke(
174        &self,
175        input: std::collections::HashMap<String, String>,
176    ) -> Pin<Box<dyn std::future::Future<Output = Result<AgentOutput, anyhow::Error>> + Send>> {
177        // Capture system prompt in advance
178        let system_prompt = self.system_prompt.clone();
179        let input_text = input
180            .get("input")
181            .cloned()
182            .unwrap_or_default()
183            .to_string()
184            .trim()
185            .to_string();
186
187        // Capture tool descriptions in advance to avoid using self in async move
188        let tool_descriptions: String = if !self.tools.is_empty() {
189            let mut descriptions = String::new();
190            for tool in &self.tools {
191                descriptions.push_str(&format!("- {}: {}\n", tool.name(), tool.description()));
192            }
193            descriptions
194        } else {
195            String::new()
196        };
197
198        // Capture memory module in advance to avoid using self in async move
199        let memory_clone = self.memory.clone();
200
201        // Build enhanced system prompt using ReAct framework format
202        let enhanced_system_prompt = if !tool_descriptions.is_empty() {
203            format!("{}
204You are an AI assistant that follows the ReAct (Reasoning and Acting) framework. 
205You should think step by step and decide whether to use tools based on user needs.
206You should carefully review and when confirming the use of the tool, if there are omissions, errors, or other issues with the parameters, you should reply and remind the user.
207Available tools:\n{}\n\nWhen you need to use a tool, please respond in the following JSON format:
208            \n{{\"call_tool\": {{\"name\": \"Tool Name\", \"parameters\": {{\"parameter_name\": \"parameter_value\"}}}}}}
209        When you don't need to use a tool, please respond in the following JSON format:\n{{\"content\": \"Your answer\"}}
210        Please think carefully about whether the user's request requires a tool to be used, and only use tools when necessary.", 
211            system_prompt, tool_descriptions)
212        } else {
213            system_prompt
214        };
215
216        // Capture OpenAI model instance in advance to avoid using self in async move
217        let openai_model_clone = self.openai_model.clone();
218
219        Box::pin(async move {
220            // Check if input is empty
221            if input_text.is_empty() {
222                let mut return_values = std::collections::HashMap::new();
223                return_values.insert("answer".to_string(), "Please enter valid content".to_string());
224                // Get model name from OpenAI model, use default value if not available
225                let model_name = if let Some(ref openai_model) = openai_model_clone {
226                    openai_model.model_name().map(|s| s.to_string()).unwrap_or("unknown".to_string())
227                } else {
228                    "unknown".to_string()
229                };
230                return_values.insert("model".to_string(), model_name);
231                return Ok(AgentOutput::Finish(AgentFinish { return_values }));
232            }
233
234            // Use the passed OpenAI model instance or create a new instance
235            let model = if let Some(ref openai_model) = openai_model_clone {
236                // Use the passed OpenAI model instance
237                openai_model
238            } else {
239                // If no OpenAI model instance is provided, return an error
240                let mut return_values = std::collections::HashMap::new();
241                return_values.insert("answer".to_string(), "No OpenAI model provided".to_string());
242                return_values.insert("model".to_string(), "unknown".to_string());
243                return Ok(AgentOutput::Finish(AgentFinish { return_values }));
244            };
245
246            // Build message list
247            let mut messages = Vec::new();
248
249            // Get summary content and append to system prompt
250            let enhanced_system_prompt_with_summary = {
251                let mut enhanced_prompt = enhanced_system_prompt;
252                if let Some(memory) = &memory_clone {
253                    // Try to get summary content from memory module
254                    // Here we use downcast_ref to check if it's CompositeMemory type
255                    if let Some(composite_memory) = memory.as_any().downcast_ref::<crate::memory::composite_memory::CompositeMemory>() {
256                        // If it's CompositeMemory, call get_summary method to get summary
257                        match composite_memory.get_summary().await {
258                            Ok(Some(summary)) => {
259                                // Append summary content to system prompt
260                                enhanced_prompt = format!("{}\n\nPrevious conversation summary: {}", enhanced_prompt, summary);
261                                log::info!("Summary appended to system prompt");
262                            },
263                            Ok(None) => {
264                                log::info!("No summary content found");
265                            },
266                            Err(e) => {
267                                log::warn!("Error getting summary: {}", e);
268                            }
269                        }
270                    } else {
271                        // If not CompositeMemory, try to get summary from memory variables
272                        match memory.load_memory_variables(&std::collections::HashMap::new()).await {
273                            Ok(memories) => {
274                                if let Some(summary) = memories.get("summary") {
275                                    if let Some(summary_str) = summary.as_str() {
276                                        // Append summary content to system prompt
277                                        enhanced_prompt = format!("{}\n\nPrevious conversation summary: {}", enhanced_prompt, summary_str);
278                                        log::info!("Summary retrieved from memory variables and appended to system prompt");
279                                    }
280                                }
281                            },
282                            Err(e) => {
283                                log::warn!("Error getting summary from memory variables: {}", e);
284                            }
285                        }
286                    }
287                }
288                enhanced_prompt
289            };
290
291            // Add system message
292            messages.push(ModelChatMessage::System(ChatMessageContent {
293                content: enhanced_system_prompt_with_summary,
294                name: None,
295                additional_kwargs: std::collections::HashMap::new(),
296            }));
297
298            // If there is a memory module, load memory variables and add them to the message list
299            if let Some(memory) = &memory_clone {
300                match memory.load_memory_variables(&std::collections::HashMap::new()).await {
301                    Ok(memories) => {
302                        info!("Loaded memory variables: {:?}", memories);
303                        if let Some(chat_history) = memories.get("chat_history") {
304                            if let serde_json::Value::Array(messages_array) = chat_history {
305                                for message in messages_array {
306                                    if let serde_json::Value::Object(msg_obj) = message {
307                                        let role = msg_obj.get("role").and_then(|v| v.as_str()).unwrap_or("unknown");
308                                        let content = msg_obj.get("content").and_then(|v| v.as_str()).unwrap_or("");
309                                        
310                                        // Skip empty content messages
311                                        if content.trim().is_empty() {
312                                            continue;
313                                        }
314                                        
315                                        // Skip assistant messages containing complete history messages
316                                        if role == "assistant" && content.contains("user:") && content.contains("assistant:") {
317                                            continue;
318                                        }
319                                        
320                                        // Add debug log
321                                        // info!("Loaded message: role={}, content={}", role, content);
322                                        
323                                        match role {
324                                            "human" | "user" => {
325                                                // Add debug log
326                                                log::info!("Loaded human message: content={}", content);
327                                                messages.push(ModelChatMessage::Human(ChatMessageContent {
328                                                    content: content.to_string(),
329                                                    name: None,
330                                                    additional_kwargs: std::collections::HashMap::new(),
331                                                }));
332                                            },
333                                            "ai" | "assistant" => {
334                                                // Add debug log
335                                                log::info!("Loaded AI message: content={}", content);
336                                                messages.push(ModelChatMessage::AIMessage(ChatMessageContent {
337                                                    content: content.to_string(),
338                                                    name: None,
339                                                    additional_kwargs: std::collections::HashMap::new(),
340                                                }));
341                                            },
342                                            "tool" => {
343                                                // Handle tool messages
344                                                let content_str = content.to_string();
345                                                // Add debug log
346                                                log::info!("Loaded tool message: content={}", content_str);
347                                                messages.push(ModelChatMessage::ToolMessage(ChatMessageContent {
348                                                    content: content_str,
349                                                    name: None,
350                                                    additional_kwargs: std::collections::HashMap::new(),
351                                                }));
352                                            },
353                                            _ => {
354                                                // Add debug log
355                                                log::info!("Loaded unknown role message: role={}, content={}", role, content);
356                                                // Ignore messages with unknown roles
357                                            }
358                                        }
359                                    }
360                                }
361                            }
362                        }
363                    },
364                    Err(e) => {
365                        // If loading memory fails, log the error but continue execution
366                        log::warn!("Failed to load memory variables: {}", e);
367                    }
368                }
369            }
370
371            // Add current user message
372            messages.push(ModelChatMessage::Human(ChatMessageContent {
373                content: input_text.clone(),
374                name: None,
375                additional_kwargs: std::collections::HashMap::new(),
376            }));
377            // info!("Added current user message: role=user, content={}", input_text);
378            
379            // Add debug log, showing all messages
380            log::info!("Messages to be sent to model:");
381            for (i, msg) in messages.iter().enumerate() {
382                match msg {
383                    ModelChatMessage::System(content) => {
384                        log::info!("  {}. role=system, content={}", i+1, content.content);
385                    },
386                    ModelChatMessage::Human(content) => {
387                        log::info!("  {}. role=user, content={}", i+1, content.content);
388                    },
389                    ModelChatMessage::AIMessage(content) => {
390                        log::info!("  {}. role=assistant, content={}", i+1, content.content);
391                    },
392                    ModelChatMessage::ToolMessage(content) => {
393                        log::info!("  {}. role=tool, content={}", i+1, content.content);
394                    },
395                }
396            }
397
398            // Call the language model
399            let result = model.invoke(messages).await;
400
401            match result {
402                Ok(completion) => {
403                    // Parse model output
404                    let content = match completion.message {
405                        ModelChatMessage::AIMessage(content) => content.content,
406                        _ => { format!("{},{:?}", "Non-AI message received", completion.message) }
407                    };
408
409                    // Get model name from OpenAI model, use default value if not available
410                    let model_name = model.model_name().map(|s| s.to_string()).unwrap_or("unknown".to_string());
411
412                    // If there is a memory module, save the current conversation to memory
413                    if let Some(memory) = &memory_clone {
414                        let mut inputs = std::collections::HashMap::new();
415                        inputs.insert("input".to_string(), serde_json::Value::String(input_text.clone()));
416                        
417                        // Preprocess content, if it's JSON string format, extract the content field
418                        let processed_content = if content.starts_with('"') && content.ends_with('"') {
419                            // Try to parse as JSON string
420                            match serde_json::from_str::<serde_json::Value>(&content) {
421                                Ok(serde_json::Value::String(s)) => s,
422                                _ => content.clone(),
423                            }
424                        } else if content.starts_with('{') && content.ends_with('}') {
425                            // Try to parse as JSON object
426                            match serde_json::from_str::<serde_json::Value>(&content) {
427                                Ok(json_obj) => {
428                                    // If it's a JSON object, try to extract the content field
429                                    if let Some(content_value) = json_obj.get("content") {
430                                        if let Some(content_str) = content_value.as_str() {
431                                            content_str.to_string()
432                                        } else {
433                                            content.clone()
434                                        }
435                                    } else {
436                                        content.clone()
437                                    }
438                                },
439                                _ => content.clone(),
440                            }
441                        } else {
442                            content.clone()
443                        };
444                        
445                        let mut outputs = std::collections::HashMap::new();
446                        outputs.insert("output".to_string(), serde_json::Value::String(processed_content));
447                        
448                        if let Err(e) = memory.save_context(&inputs, &outputs).await {
449                            log::warn!("Failed to save context to memory: {}", e);
450                        }
451                    }
452
453                    // Parse model output, determine if tool call is needed
454                    // Here should correctly parse the JSON format of model output
455                    if let Ok(parsed_output) = parse_model_output(&content) {
456                        match parsed_output {
457                            AgentOutput::Action(action) => {
458                                // Directly return the Action parsed by the model
459                                return Ok(AgentOutput::Action(action));
460                            }
461                            AgentOutput::Finish(_) => {
462                                // Directly return the answer
463                                let mut return_values = std::collections::HashMap::new();
464                                return_values.insert("answer".to_string(), content.clone());
465                                return_values.insert("model".to_string(), model_name);
466                                return Ok(AgentOutput::Finish(AgentFinish { return_values }));
467                            }
468                        }
469                    } else {
470                        // If parsing fails, try to extract tool call information
471                        // Check if tool call keywords are included
472                        if content.contains("call_tool") {
473                            // Try to extract JSON format tool call from content
474                            // Here should more intelligently parse tool calls instead of using default tools
475                            if let Ok(agent_action) = parse_tool_call_from_content(&content) {
476                                Ok(AgentOutput::Action(agent_action))
477                            } else {
478                                // If unable to parse tool call, directly return the answer
479                                let mut return_values = std::collections::HashMap::new();
480                                return_values.insert("answer".to_string(), content.clone());
481                                return_values.insert("model".to_string(), model_name);
482                                Ok(AgentOutput::Finish(AgentFinish { return_values }))
483                            }
484                        } else {
485                            // Directly return the answer
486                            let mut return_values = std::collections::HashMap::new();
487                            return_values.insert("answer".to_string(), content.clone());
488                            return_values.insert("model".to_string(), model_name);
489                            Ok(AgentOutput::Finish(AgentFinish { return_values }))
490                        }
491                    }
492                }
493                Err(e) => {
494                    // Return error message when an error occurs
495                    // Get model name from OpenAI model, use default value if not available
496                    let model_name = if let Some(ref model) = openai_model_clone {
497                        model.model_name().map(|s| s.to_string()).unwrap_or("unknown".to_string())
498                    } else {
499                        "unknown".to_string()
500                    };
501                    
502                    // Even if model call fails, save user message to memory
503                    if let Some(memory) = &memory_clone {
504                        let mut inputs = std::collections::HashMap::new();
505                        inputs.insert("input".to_string(), serde_json::Value::String(input_text.clone()));
506                        
507                        let mut outputs = std::collections::HashMap::new();
508                        outputs.insert("output".to_string(), serde_json::Value::String(format!("Model invocation failed: {}", e)));
509                        
510                        if let Err(e) = memory.save_context(&inputs, &outputs).await {
511                            log::warn!("Failed to save context to memory: {}", e);
512                        }
513                    }
514                    
515                    let mut return_values = std::collections::HashMap::new();
516                    return_values.insert("answer".to_string(), format!("Model invocation failed: {}", e));
517                    return_values.insert("model".to_string(), model_name);
518                    Ok(AgentOutput::Finish(AgentFinish { return_values }))
519                }
520            }
521        })
522    }
523
524    fn clone_to_owned(
525        &self,
526    ) -> Box<dyn Runnable<std::collections::HashMap<String, String>, AgentOutput> + Send + Sync>
527    {
528        Box::new(self.clone())
529    }
530}
531
532/// Extract JSON object string from content
533fn extract_json_object(content: &str) -> Option<String> {
534    // Find the first '{' and the last '}'
535    if let Some(start) = content.find('{') {
536        if let Some(end) = content.rfind('}') {
537            if end > start {
538                // Extract possible JSON object
539                let json_str = &content[start..=end];
540                
541                // Verify if it's a valid JSON object
542                if let Ok(value) = serde_json::from_str::<serde_json::Value>(json_str) {
543                    if value.is_object() {
544                        return Some(json_str.to_string());
545                    }
546                }
547            }
548        }
549    }
550    None
551}
552
553/// Parse tool call from content
554fn parse_tool_call_from_content(content: &str) -> Result<AgentAction, anyhow::Error> {
555    // Try to extract JSON object
556    if let Some(json_str) = extract_json_object(content) {
557        // Parse JSON
558        let value: Value = serde_json::from_str(&json_str)?;
559        
560        // Check if there's a call_tool field
561        if let Some(call_tool) = value.get("call_tool").and_then(|v| v.as_object()) {
562            // Extract tool name
563            let tool_name = call_tool
564                .get("name")
565                .and_then(|v| v.as_str())
566                .ok_or_else(|| anyhow::anyhow!("Missing tool name"))?
567                .to_string();
568            
569            // Extract parameters and convert to string
570            let tool_input = call_tool
571                .get("parameters")
572                .cloned()
573                .unwrap_or(Value::Object(serde_json::Map::new()))
574                .to_string();
575            
576            // Create AgentAction
577            let action = AgentAction {
578                tool: tool_name,
579                tool_input,
580                log: content.to_string(),
581                thought: None,
582            };
583            
584            return Ok(action);
585        }
586    }
587    
588    // If unable to parse, return error
589    Err(anyhow::anyhow!("Failed to parse tool call from content"))
590}