mcp_agent_client_chatbot/
mcp_agent_client_chatbot.rs

1// MCP 客户端调用服务端工具的端到端示例
2// 此示例演示了如何使用 MCP 客户端连接到 MCP 服务器并调用其注册的工具
3
4use rust_agent::{run_agent, OpenAIChatModel, McpClient, SimpleMcpClient, McpTool, McpAgent, SimpleMemory, BaseMemory};
5use std::sync::Arc;
6use std::collections::HashMap;
7use serde_json::{Value, json};
8
9// 初始化日志记录器
10use log::LevelFilter;
11use env_logger;
12use log::{info, error};
13
14#[tokio::main]
15async fn main() {
16    // 初始化日志记录器
17    env_logger::Builder::new()
18        .filter_level(LevelFilter::Info)
19        .init();
20    
21    info!("=== Rust Agent MCP 客户端调用服务端工具示例 ===");
22    
23    // 从环境变量获取 MCP 服务器 URL
24    let mcp_url = std::env::var("MCP_URL").unwrap_or("http://127.0.0.1:6000".to_string());
25    
26    // 创建 OpenAI 模型实例(可选,用于智能决策是否调用工具)
27    // 如果没有设置 API 密钥,则使用一个占位符
28    let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "sk-00000".to_string());
29    let base_url = std::env::var("OPENAI_API_URL").unwrap_or("https://api.deepseek.com/v1".to_string());
30    let model = OpenAIChatModel::new(api_key.clone(), Some(base_url))
31        .with_model(std::env::var("OPENAI_API_MODEL").unwrap_or_else(|_| "deepseek-chat".to_string()))
32        .with_temperature(0.7)
33        .with_max_tokens(8*1024);
34    
35    // 初始化 MCP 客户端
36    let mut mcp_client = SimpleMcpClient::new(mcp_url.clone());
37    
38    // 清空默认工具(可选)
39    mcp_client.clear_tools();
40    
41    // 不添加任何本地工具,完全依赖服务端工具
42    info!("Not adding any local tools, will use server-side tools only...");
43    
44    // 注意:客户端不需要实现工具处理器,因为工具实际在服务端执行
45    // 客户端只需要知道工具的名称和描述即可
46    
47    // 连接到 MCP 服务器
48    info!("正在连接到 MCP 服务器: {}", mcp_url);
49    if let Err(e) = mcp_client.connect(&mcp_url).await {
50        error!("连接到 MCP 服务器失败: {}", e);
51        return;
52    } else {
53        mcp_client.set_server_connected(true);
54    }
55    info!("成功连接到 MCP 服务器");
56    
57    // let model_name = model.model_name().map_or("未指定模型".to_string(), |v| v.to_string());
58    
59    info!("----------------------------------------");
60    
61    let client_arc: Arc<dyn McpClient> = Arc::new(mcp_client);
62    
63    // 创建记忆模块实例
64    let memory = SimpleMemory::new();
65
66    // 创建 Agent 实例
67    let mut agent = McpAgent::with_openai_model_and_memory(
68        client_arc.clone(),
69        "You are an AI assistant that can use tools to answer user questions. Please decide whether to use tools based on the user's needs.".to_string(),
70        model.clone(),
71        Box::new(memory.clone())
72    );
73    
74    // 自动从 MCP 客户端获取工具并添加到 Agent
75    if let Err(e) = agent.auto_add_tools().await {
76        error!("自动添加工具到 Agent 失败: {}", e);
77        return;
78    }
79    
80    println!("MCP 客户端 Agent 已启动!");
81    println!("输入'退出'结束对话");
82    println!("----------------------------------------");
83    
84    // 获取并显示可用工具
85    let tools = client_arc.get_tools().await.unwrap_or_else(|e| {
86        error!("获取工具列表失败: {}", e);
87        vec![]
88    });
89    
90    println!("可用工具:");
91    for (index, tool) in tools.iter().enumerate() {
92        println!("{}. {}: {}", index + 1, tool.name, tool.description);
93    }
94    
95    println!("----------------------------------------");
96    
97    // 演示直接调用工具
98    println!("演示直接调用工具:");
99    
100    // 调用天气工具
101    println!("\n1. 调用天气工具获取北京天气: What's the weather like in Beijing?");
102    let mut weather_params = HashMap::new();
103    weather_params.insert("city".to_string(), Value::String("Beijing".to_string()));
104    
105    match client_arc.call_tool("get_weather", weather_params).await {
106        Ok(result) => {
107            println!("天气查询结果: {}", serde_json::to_string_pretty(&result).unwrap_or_else(|_| "无法格式化结果".to_string()));
108        },
109        Err(e) => {
110            println!("调用天气工具失败: {}", e);
111        }
112    }
113    
114    // 调用计算工具
115    println!("\n2. 调用计算工具计算: 'What is 15.5 plus 24.3?'");
116    let calc_params = json!({
117        "expression": "15.5 + 24.3"
118    });
119    let calc_params_map: HashMap<String, Value> = serde_json::from_value(calc_params).unwrap();
120    
121    match client_arc.call_tool("simple_calculate", calc_params_map).await {
122        Ok(result) => {
123            println!("计算结果: {}", serde_json::to_string_pretty(&result).unwrap_or_else(|_| "无法格式化结果".to_string()));
124        },
125        Err(e) => {
126            println!("调用计算工具失败: {}", e);
127        }
128    }
129    
130    println!("----------------------------------------");
131    
132    // 交互式对话循环
133    println!("现在进入交互模式,您可以询问天气或数学计算问题:");
134    loop {
135        println!("\n您: ");
136        let mut user_input = String::new();
137        std::io::stdin().read_line(&mut user_input).expect("读取输入失败");
138        let user_input = user_input.trim();
139        
140        if user_input.to_lowercase() == "退出" || user_input.to_lowercase() == "exit" {
141            println!("再见!");
142            break;
143        }
144        
145        if user_input.is_empty() {
146            continue;
147        }
148        
149        // 运行 Agent
150        match run_agent(&agent, user_input.to_string()).await {
151            Ok(response) => {
152                // 尝试解析 response 为 JSON 并提取 content 字段
153                if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&response) {
154                    if let Some(content) = json_value.get("content").and_then(|v| v.as_str()) {
155                        println!("助手: {}", content);
156                    } else {
157                        println!("助手: {}", response);
158                    }
159                } else {
160                    println!("助手: {}", response);
161                }
162            },
163            Err(e) => {
164                println!("助手: 抱歉,处理您的请求时出现错误: {}", e);
165            },
166        }
167    }
168    
169    // 打印对话历史
170    info!("\n对话历史:");
171    match memory.load_memory_variables(&HashMap::new()).await {
172        Ok(memories) => {
173            if let Some(chat_history) = memories.get("chat_history") {
174                if let serde_json::Value::Array(messages) = chat_history {
175                    for (i, message) in messages.iter().enumerate() {
176                        if let serde_json::Value::Object(msg) = message {
177                            let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("unknown");
178                            let content = msg.get("content").and_then(|v| v.as_str()).unwrap_or("");
179                            info!("{}. {}: {}", i + 1, role, content);
180                        }
181                    }
182                }
183            }
184        },
185        Err(e) => {
186            error!("加载记忆变量失败: {}", e);
187        }
188    }
189    
190    // 断开 MCP 连接
191    if let Err(e) = client_arc.disconnect().await {
192        error!("断开 MCP 客户端连接失败: {}", e);
193    }
194    
195    info!("\nMCP 客户端示例结束");
196}