mcp_agent_hybrid_chatbot/
mcp_agent_hybrid_chatbot.rs

1// 混合模式MCP Agent示例 - 同时支持本地工具和MCP服务器工具
2use rust_agent::{run_agent, OpenAIChatModel, McpClient, SimpleMcpClient, McpTool, McpAgent, CompositeMemory, BaseMemory, Agent};
3use std::sync::Arc;
4use std::collections::HashMap;
5use chrono;
6use serde_json::{Value, json};
7
8// 初始化日志记录器
9use log::LevelFilter;
10use env_logger;
11use tokio;
12use log::{info, error};
13
14#[tokio::main]
15async fn main() {
16    // 初始化日志记录器
17    env_logger::Builder::new()
18        .filter_level(LevelFilter::Info)  // 设置日志级别为Error以便查看错误信息
19        .init();
20    
21    info!("=== Rust Agent 混合模式示例 ===");
22    
23    // 从环境变量获取API密钥和基本URL
24    let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "OPENAI_API_KEY".to_string());
25    let base_url = std::env::var("OPENAI_API_URL").ok();
26    let mcp_url = std::env::var("MCP_URL").unwrap_or("http://127.0.0.1:6000".to_string());  // 默认MCP服务器地址
27    
28    // 创建OpenAI模型实例 - 支持Openai兼容 API
29    let model = OpenAIChatModel::new(api_key.clone(), base_url)
30        .with_model(std::env::var("OPENAI_API_MODEL").unwrap_or_else(|_| "gpt-3.5-turbo".to_string()))
31        .with_temperature(0.7)
32        .with_max_tokens(8*1024);
33    
34    // 初始化MCP客户端
35    let mut mcp_client = SimpleMcpClient::new(mcp_url.clone());
36    
37    // 清空默认工具(可选)
38    mcp_client.clear_tools();
39    
40    // 添加本地自定义工具定义
41    mcp_client.add_tools(vec![
42        McpTool {
43            name: "get_local_time".to_string(),
44            description: "Get the current local time and date. For example: 'What time is it?'".to_string(),
45        },
46    ]);
47    
48    // 注册本地工具处理器
49    mcp_client.register_tool_handler("get_local_time".to_string(), |_params: HashMap<String, Value>| async move {
50        let now = chrono::Local::now();
51        Ok(json!({
52            "current_time": now.format("%Y-%m-%d %H:%M:%S").to_string(),
53            "timezone": "Local"
54        }))
55    });
56    
57    // 不注册calculate_expression的本地处理器,让其使用服务端工具
58    info!("Using local 'get_local_time' tool and server-side calculation tools...");
59    
60    // 尝试连接到 MCP 服务器
61    match mcp_client.connect(&mcp_url).await {
62        Ok(_) => {
63            info!("Successfully connected to MCP server at {}", mcp_url);
64            
65            // 设置连接状态为已连接
66            mcp_client.set_server_connected(true);
67            
68            // // 获取服务器工具
69            // match mcp_client.get_tools().await {
70            //     Ok(server_tools) => {
71            //         info!("Retrieved {} tools from MCP server", server_tools.len());
72            //         for tool in &server_tools {
73            //             info!("Server tool: {} - {}", tool.name, tool.description);
74            //         }
75            //     },
76            //     Err(e) => {
77            //         error!("Failed to get tools from MCP server: {}", e);
78            //     }
79            // }
80        },
81        Err(e) => {
82            error!("Failed to connect to MCP server: {}", e);
83        }
84    }
85
86    // 创建记忆模块实例 - 使用CompositeMemory替代SimpleMemory
87    let memory = CompositeMemory::with_basic_params(
88        "data".into(),           // 组合记忆数据存储目录
89        200,           // 摘要阈值(token数量)
90        10          // 保留的最近消息数量
91    ).await.expect("Failed to create CompositeMemory");
92
93    // 创建Agent实例
94    let client_arc: Arc<dyn McpClient> = Arc::new(mcp_client);
95    let mut agent = McpAgent::with_openai_model_and_memory(
96        client_arc.clone(),
97        "You are an AI assistant that can use both local tools and remote MCP server tools. Please decide whether to use tools based on the user's needs.".to_string(),
98        model.clone(),
99        Box::new(memory.clone())
100    );
101    
102    // 自动从MCP客户端获取工具并添加到Agent
103    // 这会同时添加MCP服务器工具和本地工具
104    if let Err(e) = agent.auto_add_tools().await {
105        error!("Warning: Failed to auto add tools to McpAgent: {}", e);
106    }
107
108    info!("Using model: {}", model.model_name().map_or("Model not specified", |v| v));
109    info!("----------------------------------------");
110    
111    println!("基于MCP的混合模式AI Agent聊天机器人已启动!");
112    println!("输入'退出'结束对话");
113    println!("----------------------------------------");
114    
115    // 显示可用工具
116    println!("Available tools:");
117    let tools = agent.tools();
118    if tools.is_empty() {
119        println!("No tools available");
120    } else {
121        for (index, tool) in tools.iter().enumerate() {
122            println!("{}. {}: {}", index + 1, tool.name(), tool.description());
123        }
124    }
125    println!("----------------------------------------");
126    
127    // 示例对话
128    println!("示例对话:");
129    let examples = vec![
130        "What time is it?",
131        "What is 15.5 plus 24.3?",
132    ];
133    
134    for example in examples {
135        println!("你: {}", example);
136        
137        // 创建输入上下文
138        let mut inputs = HashMap::new();
139        inputs.insert("input".to_string(), serde_json::Value::String(example.to_string()));
140        
141        // 运行Agent
142        match run_agent(&agent, example.to_string()).await {
143            Ok(response) => {
144                // 尝试解析response为JSON并提取content字段
145                if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&response) {
146                    if let Some(content) = json_value.get("content").and_then(|v| v.as_str()) {
147                        println!("助手: {}", content);
148                    } else {
149                        println!("助手: {}", response);
150                    }
151                } else {
152                    println!("助手: {}", response);
153                }
154            },
155            Err(e) => {
156                println!("助手: 抱歉,处理您的请求时出现错误: {}", e);
157            },
158        }
159        
160        info!("----------------------------------------");
161    }
162    
163    // 交互式对话循环
164    println!("现在开始交互式对话(输入'退出'结束对话):");
165    loop {
166        let mut user_input = String::new();
167        println!("你: ");
168        std::io::stdin().read_line(&mut user_input).expect("读取输入失败");
169        let user_input = user_input.trim();
170        
171        if user_input.to_lowercase() == "退出" || user_input.to_lowercase() == "exit" {
172            println!("再见!");
173            break;
174        }
175        
176        if user_input.is_empty() {
177            continue;
178        }
179        
180        // 运行Agent
181        match run_agent(&agent, user_input.to_string()).await {
182            Ok(response) => {
183                // 尝试解析response为JSON并提取content字段
184                if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&response) {
185                    if let Some(content) = json_value.get("content").and_then(|v| v.as_str()) {
186                        println!("助手: {}", content);
187                    } else {
188                        println!("助手: {}", response);
189                    }
190                } else {
191                    println!("助手: {}", response);
192                }
193            },
194            Err(e) => {
195                println!("助手: 抱歉,处理您的请求时出现错误: {}", e);
196            },
197        }
198        
199        info!("----------------------------------------");
200    }
201    
202    // 打印对话历史
203    info!("对话历史:");
204    match memory.load_memory_variables(&HashMap::new()).await {
205        Ok(memories) => {
206            if let Some(chat_history) = memories.get("chat_history") {
207                if let serde_json::Value::Array(messages) = chat_history {
208                    for (i, message) in messages.iter().enumerate() {
209                        if let serde_json::Value::Object(msg) = message {
210                            let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("unknown");
211                            let content = msg.get("content").and_then(|v| v.as_str()).unwrap_or("");
212                            info!("{}. {}: {}", i + 1, role, content);
213                        }
214                    }
215                }
216            }
217        },
218        Err(e) => {
219            info!("Failed to load memory variables: {}", e);
220        }
221    }
222    
223    // 断开MCP连接
224    if let Err(e) = client_arc.disconnect().await {
225        error!("Failed to disconnect MCP client: {}", e);
226    }
227}