1use std::path::PathBuf;
3use rust_agent::{run_agent, OpenAIChatModel, McpClient, SimpleMcpClient, McpTool, McpAgent, SimpleMemory, BaseMemory, CompositeMemory};
4use std::sync::Arc;
5use std::collections::HashMap;
6use chrono;
7use serde_json::{Value, json};
8use anyhow::Error;
9
10use log::LevelFilter;
12use env_logger;
13use log::{info, error};
14
15#[tokio::main]
16async fn main() {
17 env_logger::Builder::new()
19 .filter_level(LevelFilter::Info)
20 .init();
21
22 info!("=== Rust Agent 使用示例 ===");
23
24 let memory_type = std::env::var("MEMORY_TYPE").unwrap_or_else(|_| "composite".to_string());
26 let summary_threshold = std::env::var("SUMMARY_THRESHOLD")
27 .ok()
28 .and_then(|s| s.parse().ok())
29 .unwrap_or(200);
30 let recent_messages_count = std::env::var("RECENT_MESSAGES_COUNT")
31 .ok()
32 .and_then(|s| s.parse().ok())
33 .unwrap_or(10);
34
35 info!("使用记忆类型: {}", memory_type);
36 info!("摘要阈值: {}", summary_threshold);
37 info!("保留最近消息数: {}", recent_messages_count);
38
39 let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "OPENAI_API_KEY".to_string());
41 let base_url = std::env::var("OPENAI_API_URL").ok();
42 let mcp_url = std::env::var("MCP_URL").unwrap_or("http://localhost:8000/mcp".to_string());
43
44 let model = OpenAIChatModel::new(api_key.clone(), base_url)
46 .with_model(std::env::var("OPENAI_API_MODEL").unwrap_or_else(|_| "gpt-3.5-turbo".to_string()))
47 .with_temperature(0.7)
48 .with_max_tokens(8*1024);
49
50 let mut mcp_client = SimpleMcpClient::new(mcp_url.clone());
53
54 mcp_client.clear_tools();
56
57 mcp_client.add_tools(vec![
59 McpTool {
60 name: "get_weather".to_string(),
61 description: format!(
62 "Get weather information for a specified city. For example: 'What's the weather like in Beijing?'.
63 The parameter request body you should extract is: '\"parameters\": {{ \"city\": \"{}\" }}'",
64 "city".to_string()),
65 },
66 McpTool {
67 name: "simple_calculate".to_string(),
68 description: format!(
69 "Execute simple mathematical calculations. For example: 'What is 9.11 plus 9.8?'.
70 The parameter request body you should extract is: '\"parameters\": {{ \"expression\": \"{}\" }}'",
71 "expression".to_string()),
72 },
73 ]);
74
75 mcp_client.register_tool_handler("get_weather".to_string(), |params: HashMap<String, Value>| async move {
77 let default_city = Value::String("Shanghai".to_string());
78 let city_value = params.get("city").unwrap_or(&default_city);
79 let city = city_value.as_str().unwrap_or("Shanghai");
80 Ok(json!({
81 "city": city,
82 "temperature": "25°C",
83 "weather": "Sunny",
84 "humidity": "40%",
85 "updated_at": chrono::Utc::now().to_rfc3339()
86 }))
87 });
88
89 mcp_client.register_tool_handler("simple_calculate".to_string(), |params: HashMap<String, Value>| async move {
90 let expression_value = params.get("expression").ok_or_else(|| Error::msg("Missing calculation expression"))?;
91 let expression = expression_value.as_str().ok_or_else(|| Error::msg("Expression format error"))?;
92
93 let result = parse_and_calculate(expression)?;
95
96 Ok(json!({
97 "expression": expression,
98 "result": result,
99 "calculated_at": chrono::Utc::now().to_rfc3339()
100 }))
101 });
102
103 info!("Using local tools only, not connecting to MCP server...");
105
106 info!("Using model: {}", model.model_name().map_or("Model not specified", |v| v));
107 info!("Using API URL: {}", model.base_url());
108 info!("----------------------------------------");
109
110 let client_arc: Arc<dyn McpClient> = Arc::new(mcp_client);
111
112 let memory: Box<dyn BaseMemory> = match memory_type.as_str() {
114 "simple" => {
115 info!("使用SimpleMemory (仅内存记忆)");
116 Box::new(SimpleMemory::new())
117 },
118 "composite" => {
119 info!("使用CompositeMemory (组合记忆 - 支持中长期记忆和摘要记忆)");
120 let memory = CompositeMemory::with_basic_params(
123 PathBuf::from("./data/memory"),
124 summary_threshold,
125 recent_messages_count,
126 ).await.expect("Failed to create composite memory");
127
128 Box::new(memory)
129 },
130 _ => {
131 error!("未知的记忆类型: {}, 使用默认的SimpleMemory", memory_type);
132 Box::new(SimpleMemory::new())
133 }
134 };
135
136 let user_system_prompt = "You are an AI assistant that can use tools to answer user questions. Please decide whether to use tools based on the user's needs.".to_string();
138
139 let mut agent = McpAgent::with_openai_model_and_memory(
140 client_arc.clone(),
141 user_system_prompt,
142 model.clone(),
143 memory
144 );
145
146 if let Err(e) = agent.auto_add_tools().await {
148 error!("Failed to auto add tools from MCP server: {}", e);
149 }
150
151 println!("基于MCP的AI Agent聊天机器人已启动!");
152 println!("记忆类型: {}", memory_type);
153 if memory_type == "composite" {
154 println!("摘要功能: 已启用 (阈值: {} 条消息)", summary_threshold);
155 println!("中长期记忆: 已启用");
156 }
157 println!("输入'退出'结束对话");
158 println!("----------------------------------------");
159 println!("Using tools example:");
160 let tools = client_arc.get_tools().await.unwrap_or_else(|e| {
161 error!("Failed to get tools from MCP server: {}", e);
162 vec![
164 McpTool {
165 name: "get_weather".to_string(),
166 description: "Get the weather information for a specified city. For example: 'What's the weather like in Beijing?'".to_string(),
167 },
168 McpTool {
169 name: "simple_calculate".to_string(),
170 description: "Perform simple mathematical calculations. For example: 'What is 9.11 plus 9.8?'".to_string(),
171 },
172 ]
173 });
174
175 let mut index = 0;
177 for tool in &tools {
178 index += 1;
179
180 println!("{index}. {}: {}", tool.name, tool.description);
181 }
182
183 println!("----------------------------------------");
184 loop {
186 let mut user_input = String::new();
187 println!("你: ");
188 std::io::stdin().read_line(&mut user_input).expect("读取输入失败");
189 println!("");
190 let user_input = user_input.trim();
191
192 if user_input.to_lowercase() == "退出" || user_input.to_lowercase() == "exit" {
193 println!("再见!");
194 break;
195 }
196
197 let mut inputs = HashMap::new();
199 inputs.insert("input".to_string(), serde_json::Value::String(user_input.to_string()));
200
201 match run_agent(&agent, user_input.to_string()).await {
203 Ok(response) => {
204 if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&response) {
206 if let Some(content) = json_value.get("content").and_then(|v| v.as_str()) {
207 println!("助手: {}", content);
208 } else {
209 println!("助手: {}", response);
210 }
211 } else {
212 println!("助手: {}", response);
213 }
214 },
215 Err(e) => {
216 println!("助手: 抱歉,处理您的请求时出现错误: {}", e);
217 },
218 }
219
220 info!("----------------------------------------");
221 }
222
223 info!("对话历史:");
225 if let Some(memory) = agent.get_memory() {
226 match memory.load_memory_variables(&HashMap::new()).await {
227 Ok(memories) => {
228 if let Some(chat_history) = memories.get("chat_history") {
229 if let serde_json::Value::Array(messages) = chat_history {
230 info!("总消息数: {}", messages.len());
231 for (i, message) in messages.iter().enumerate() {
232 if let serde_json::Value::Object(msg) = message {
233 let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("unknown");
234 let content = msg.get("content").and_then(|v| v.as_str()).unwrap_or("");
235 let display_content = if content.len() > 100 {
237 format!("{}...", &content[..100])
238 } else {
239 content.to_string()
240 };
241 info!("{}. {}: {}", i + 1, role, display_content);
242 }
243 }
244 }
245 }
246
247 if let Some(summary) = memories.get("summary") {
249 if let serde_json::Value::String(summary_text) = summary {
250 info!("对话摘要: {}", summary_text);
251 }
252 }
253 },
254 Err(e) => {
255 info!("Failed to load memory variables: {}", e);
256 }
257 }
258 } else {
259 info!("No memory available");
260 }
261
262 if let Err(e) = client_arc.disconnect().await {
264 error!("Failed to disconnect MCP client: {}", e);
265 }
266}
267
268fn parse_and_calculate(expression: &str) -> Result<f64, Error> {
270 let expression = expression.replace(" ", "");
271
272 for op_char in ["+", "-", "*", "/"].iter() {
274 if let Some(pos) = expression.find(op_char) {
275 let left_str = &expression[0..pos];
277 let right_str = &expression[pos + 1..];
278
279 let left = left_str.parse::<f64>().map_err(|e|
281 Error::msg(format!("Left operand format error: {}", e)))?;
282 let right = right_str.parse::<f64>().map_err(|e|
283 Error::msg(format!("Right operand format error: {}", e)))?;
284
285 let result = match *op_char {
287 "+" => left + right,
288 "-" => left - right,
289 "*" => left * right,
290 "/" => {
291 if right == 0.0 {
292 return Err(Error::msg("除数不能为零"));
293 }
294 left / right
295 },
296 _ => unreachable!()
297 };
298
299 return Ok(result);
300 }
301 }
302
303 if let Ok(number) = expression.parse::<f64>() {
305 return Ok(number);
306 }
307
308 Err(Error::msg(format!("Failed to parse expression: {}", expression)))
309}