mcp_agent_local_chatbot/
mcp_agent_local_chatbot.rs

1// 基于MCP的AI Agent聊天机器人示例
2use std::path::PathBuf;
3use rust_agent::{run_agent, OpenAIChatModel, McpClient, SimpleMcpClient, McpTool, McpAgent, SimpleMemory, BaseMemory, CompositeMemory};
4use std::sync::Arc;
5use std::collections::HashMap;
6use chrono;
7use serde_json::{Value, json};
8use anyhow::Error;
9
10// 初始化日志记录器
11use log::LevelFilter;
12use env_logger;
13use log::{info, error};
14
15#[tokio::main]
16async fn main() {
17    // 初始化日志记录器
18    env_logger::Builder::new()
19        .filter_level(LevelFilter::Info)
20        .init();
21    
22    info!("=== Rust Agent 使用示例 ===");
23    
24    // 获取记忆类型配置
25    let memory_type = std::env::var("MEMORY_TYPE").unwrap_or_else(|_| "composite".to_string());
26    let summary_threshold = std::env::var("SUMMARY_THRESHOLD")
27        .ok()
28        .and_then(|s| s.parse().ok())
29        .unwrap_or(200);
30    let recent_messages_count = std::env::var("RECENT_MESSAGES_COUNT")
31        .ok()
32        .and_then(|s| s.parse().ok())
33        .unwrap_or(10);
34    
35    info!("使用记忆类型: {}", memory_type);
36    info!("摘要阈值: {}", summary_threshold);
37    info!("保留最近消息数: {}", recent_messages_count);
38    
39    // 从环境变量获取API密钥和基本URL
40    let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "OPENAI_API_KEY".to_string());
41    let base_url = std::env::var("OPENAI_API_URL").ok();
42    let mcp_url = std::env::var("MCP_URL").unwrap_or("http://localhost:8000/mcp".to_string());
43    
44    // 创建OpenAI模型实例 - 支持Openai兼容 API
45    let model = OpenAIChatModel::new(api_key.clone(), base_url)
46        .with_model(std::env::var("OPENAI_API_MODEL").unwrap_or_else(|_| "gpt-3.5-turbo".to_string()))
47        .with_temperature(0.7)
48        .with_max_tokens(8*1024);
49    
50    // 初始化MCP客户端
51    // 在初始化 MCP 客户端后,自定义工具和工具处理器
52    let mut mcp_client = SimpleMcpClient::new(mcp_url.clone());
53    
54    // 清空默认工具(可选)
55    mcp_client.clear_tools();
56    
57    // 添加自定义工具
58    mcp_client.add_tools(vec![
59        McpTool {
60            name: "get_weather".to_string(),
61            description: format!(
62                "Get weather information for a specified city. For example: 'What's the weather like in Beijing?'.
63                The parameter request body you should extract is: '\"parameters\": {{ \"city\": \"{}\" }}'",
64                "city".to_string()),
65        },
66        McpTool {
67            name: "simple_calculate".to_string(),
68            description: format!(
69                "Execute simple mathematical calculations. For example: 'What is 9.11 plus 9.8?'.
70                The parameter request body you should extract is: '\"parameters\": {{ \"expression\": \"{}\" }}'",
71                "expression".to_string()),
72        },
73    ]);
74    
75    // 注册自定义工具处理器
76    mcp_client.register_tool_handler("get_weather".to_string(), |params: HashMap<String, Value>| async move {
77        let default_city = Value::String("Shanghai".to_string());
78        let city_value = params.get("city").unwrap_or(&default_city);
79        let city = city_value.as_str().unwrap_or("Shanghai");
80        Ok(json!({
81            "city": city,
82            "temperature": "25°C",
83            "weather": "Sunny",
84            "humidity": "40%",
85            "updated_at": chrono::Utc::now().to_rfc3339()
86        }))
87    });
88    
89    mcp_client.register_tool_handler("simple_calculate".to_string(), |params: HashMap<String, Value>| async move {
90        let expression_value = params.get("expression").ok_or_else(|| Error::msg("Missing calculation expression"))?;
91        let expression = expression_value.as_str().ok_or_else(|| Error::msg("Expression format error"))?;
92        
93        // 解析表达式,提取操作数和运算符
94        let result = parse_and_calculate(expression)?;
95        
96        Ok(json!({
97            "expression": expression,
98            "result": result,
99            "calculated_at": chrono::Utc::now().to_rfc3339()
100        }))
101    });
102    
103    // 不连接到 MCP 服务器,仅使用本地工具
104    info!("Using local tools only, not connecting to MCP server...");
105
106    info!("Using model: {}", model.model_name().map_or("Model not specified", |v| v));
107    info!("Using API URL: {}", model.base_url());
108    info!("----------------------------------------");
109    
110    let client_arc: Arc<dyn McpClient> = Arc::new(mcp_client);
111    
112    // 根据配置创建不同类型的记忆模块实例
113    let memory: Box<dyn BaseMemory> = match memory_type.as_str() {
114        "simple" => {
115            info!("使用SimpleMemory (仅内存记忆)");
116            Box::new(SimpleMemory::new())
117        },
118        "composite" => {
119            info!("使用CompositeMemory (组合记忆 - 支持中长期记忆和摘要记忆)");
120            // 使用新的简化接口,只需要提供必要的参数
121            // session_id将在内部自动生成
122            let memory = CompositeMemory::with_basic_params(
123                PathBuf::from("./data/memory"),
124                summary_threshold,
125                recent_messages_count,
126            ).await.expect("Failed to create composite memory");
127            
128            Box::new(memory)
129        },
130        _ => {
131            error!("未知的记忆类型: {}, 使用默认的SimpleMemory", memory_type);
132            Box::new(SimpleMemory::new())
133        }
134    };
135
136    // 创建Agent实例,并传递temperature、max_tokens和memory参数
137    let user_system_prompt = "You are an AI assistant that can use tools to answer user questions. Please decide whether to use tools based on the user's needs.".to_string();
138    
139    let mut agent = McpAgent::with_openai_model_and_memory(
140        client_arc.clone(),
141        user_system_prompt,
142        model.clone(),
143        memory
144    );
145    
146    // 尝试从MCP服务器自动获取工具并添加到Agent
147    if let Err(e) = agent.auto_add_tools().await {
148        error!("Failed to auto add tools from MCP server: {}", e);
149    }
150    
151    println!("基于MCP的AI Agent聊天机器人已启动!");
152    println!("记忆类型: {}", memory_type);
153    if memory_type == "composite" {
154        println!("摘要功能: 已启用 (阈值: {} 条消息)", summary_threshold);
155        println!("中长期记忆: 已启用");
156    }
157    println!("输入'退出'结束对话");
158    println!("----------------------------------------");
159    println!("Using tools example:");
160    let tools = client_arc.get_tools().await.unwrap_or_else(|e| {
161        error!("Failed to get tools from MCP server: {}", e);
162        // 返回本地工具列表
163        vec![
164            McpTool {
165                name: "get_weather".to_string(),
166                description: "Get the weather information for a specified city. For example: 'What's the weather like in Beijing?'".to_string(),
167            },
168            McpTool {
169                name: "simple_calculate".to_string(),
170                description: "Perform simple mathematical calculations. For example: 'What is 9.11 plus 9.8?'".to_string(),
171            },
172        ]
173    });
174    
175    // 打印工具列表
176    let mut index = 0;
177    for tool in &tools {
178        index += 1;
179
180        println!("{index}. {}: {}", tool.name, tool.description);
181    }
182    
183    println!("----------------------------------------");
184    // 对话循环
185    loop {
186        let mut user_input = String::new();
187        println!("你: ");
188        std::io::stdin().read_line(&mut user_input).expect("读取输入失败");
189        println!("");
190        let user_input = user_input.trim();
191        
192        if user_input.to_lowercase() == "退出" || user_input.to_lowercase() == "exit" {
193            println!("再见!");
194            break;
195        }
196        
197        // 创建输入上下文
198        let mut inputs = HashMap::new();
199        inputs.insert("input".to_string(), serde_json::Value::String(user_input.to_string()));
200        
201        // 运行Agent
202        match run_agent(&agent, user_input.to_string()).await {
203            Ok(response) => {
204                // 尝试解析response为JSON并提取content字段
205                if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&response) {
206                    if let Some(content) = json_value.get("content").and_then(|v| v.as_str()) {
207                        println!("助手: {}", content);
208                    } else {
209                        println!("助手: {}", response);
210                    }
211                } else {
212                    println!("助手: {}", response);
213                }
214            },
215            Err(e) => {
216                println!("助手: 抱歉,处理您的请求时出现错误: {}", e);
217            },
218        }
219        
220        info!("----------------------------------------");
221    }
222    
223    // 打印对话历史
224    info!("对话历史:");
225    if let Some(memory) = agent.get_memory() {
226        match memory.load_memory_variables(&HashMap::new()).await {
227            Ok(memories) => {
228                if let Some(chat_history) = memories.get("chat_history") {
229                    if let serde_json::Value::Array(messages) = chat_history {
230                        info!("总消息数: {}", messages.len());
231                        for (i, message) in messages.iter().enumerate() {
232                            if let serde_json::Value::Object(msg) = message {
233                                let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("unknown");
234                                let content = msg.get("content").and_then(|v| v.as_str()).unwrap_or("");
235                                // 限制内容长度以便显示
236                                let display_content = if content.len() > 100 {
237                                    format!("{}...", &content[..100])
238                                } else {
239                                    content.to_string()
240                                };
241                                info!("{}. {}: {}", i + 1, role, display_content);
242                            }
243                        }
244                    }
245                }
246                
247                // 如果有摘要,也打印出来
248                if let Some(summary) = memories.get("summary") {
249                    if let serde_json::Value::String(summary_text) = summary {
250                        info!("对话摘要: {}", summary_text);
251                    }
252                }
253            },
254            Err(e) => {
255                info!("Failed to load memory variables: {}", e);
256            }
257        }
258    } else {
259        info!("No memory available");
260    }
261    
262    // 断开MCP连接
263    if let Err(e) = client_arc.disconnect().await {
264        error!("Failed to disconnect MCP client: {}", e);
265    }
266}
267
268// 解析表达式并计算结果
269fn parse_and_calculate(expression: &str) -> Result<f64, Error> {
270    let expression = expression.replace(" ", "");
271    
272    // 尝试匹配不同的运算符
273    for op_char in ["+", "-", "*", "/"].iter() {
274        if let Some(pos) = expression.find(op_char) {
275            // 提取左右操作数
276            let left_str = &expression[0..pos];
277            let right_str = &expression[pos + 1..];
278            
279            // 转换为浮点数
280            let left = left_str.parse::<f64>().map_err(|e| 
281                Error::msg(format!("Left operand format error: {}", e)))?;
282            let right = right_str.parse::<f64>().map_err(|e| 
283                Error::msg(format!("Right operand format error: {}", e)))?;
284            
285            // 执行相应的运算
286            let result = match *op_char {
287                "+" => left + right,
288                "-" => left - right,
289                "*" => left * right,
290                "/" => {
291                    if right == 0.0 {
292                        return Err(Error::msg("除数不能为零"));
293                    }
294                    left / right
295                },
296                _ => unreachable!()
297            };
298            
299            return Ok(result);
300        }
301    }
302    
303    // 如果没有找到运算符,尝试将整个表达式解析为数字
304    if let Ok(number) = expression.parse::<f64>() {
305        return Ok(number);
306    }
307    
308    Err(Error::msg(format!("Failed to parse expression: {}", expression)))
309}