pub struct CompositeMemory { /* private fields */ }
Composite memory implementation
This struct combines multiple memory types behind a single interface. It manages message history and summary memory at the same time, and it generates conversation summaries automatically once the stored history grows past a configured threshold.
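A minimal usage sketch, assuming the crate's Result alias and the BaseMemory trait are in scope and that save_context accepts an "output" key; adjust the import paths and keys to your setup:

use std::collections::HashMap;
use std::path::PathBuf;
use serde_json::Value;

// Hypothetical imports; the actual module paths depend on the crate layout.
// use your_crate::memory::{BaseMemory, CompositeMemory};

async fn demo() -> Result<()> {
    // 200-token summary threshold, keep the 10 most recent messages.
    let memory = CompositeMemory::with_basic_params(
        PathBuf::from("./data/memory"),
        200,
        10,
    ).await?;

    // Record one exchange through the BaseMemory interface.
    let mut inputs = HashMap::new();
    inputs.insert("input".to_string(), Value::String("What time is it?".to_string()));
    let mut outputs = HashMap::new();
    outputs.insert("output".to_string(), Value::String("It is 12:00.".to_string()));
    memory.save_context(&inputs, &outputs).await?;

    // Read back the history and, if one has been generated, the summary.
    let vars = memory.load_memory_variables(&HashMap::new()).await?;
    println!("chat_history: {:?}", vars.get("chat_history"));
    println!("summary: {:?}", vars.get("summary"));
    Ok(())
}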
Implementations§
impl CompositeMemory
pub async fn with_basic_params(
    data_dir: PathBuf,
    summary_threshold: usize,
    recent_messages_count: usize,
) -> Result<Self>
Create a composite memory instance with basic parameters. This is the recommended constructor: it only requires the essential parameters, and the session_id is generated automatically.
Examples found in repository§
examples/mcp_agent_hybrid_chatbot.rs (lines 87-91)
15 async fn main() {
16     // Initialize the logger
17 env_logger::Builder::new()
18         .filter_level(LevelFilter::Info) // Set the log level to Info so log output is visible
19 .init();
20
21 info!("=== Rust Agent 混合模式示例 ===");
22
23     // Read the API key and base URL from environment variables
24 let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "OPENAI_API_KEY".to_string());
25 let base_url = std::env::var("OPENAI_API_URL").ok();
26     let mcp_url = std::env::var("MCP_URL").unwrap_or("http://127.0.0.1:6000".to_string()); // Default MCP server address
27
28     // Create the OpenAI model instance - supports OpenAI-compatible APIs
29 let model = OpenAIChatModel::new(api_key.clone(), base_url)
30 .with_model(std::env::var("OPENAI_API_MODEL").unwrap_or_else(|_| "gpt-3.5-turbo".to_string()))
31 .with_temperature(0.7)
32 .with_max_tokens(8*1024);
33
34     // Initialize the MCP client
35 let mut mcp_client = SimpleMcpClient::new(mcp_url.clone());
36
37     // Clear the default tools (optional)
38 mcp_client.clear_tools();
39
40     // Add local custom tool definitions
41 mcp_client.add_tools(vec![
42 McpTool {
43 name: "get_local_time".to_string(),
44 description: "Get the current local time and date. For example: 'What time is it?'".to_string(),
45 },
46 ]);
47
48     // Register a local tool handler
49 mcp_client.register_tool_handler("get_local_time".to_string(), |_params: HashMap<String, Value>| async move {
50 let now = chrono::Local::now();
51 Ok(json!({
52 "current_time": now.format("%Y-%m-%d %H:%M:%S").to_string(),
53 "timezone": "Local"
54 }))
55 });
56
57     // Do not register a local handler for calculate_expression; let it use the server-side tool
58 info!("Using local 'get_local_time' tool and server-side calculation tools...");
59
60     // Try to connect to the MCP server
61 match mcp_client.connect(&mcp_url).await {
62 Ok(_) => {
63 info!("Successfully connected to MCP server at {}", mcp_url);
64
65             // Mark the connection state as connected
66 mcp_client.set_server_connected(true);
67
68             // // Fetch the server's tools
69 // match mcp_client.get_tools().await {
70 // Ok(server_tools) => {
71 // info!("Retrieved {} tools from MCP server", server_tools.len());
72 // for tool in &server_tools {
73 // info!("Server tool: {} - {}", tool.name, tool.description);
74 // }
75 // },
76 // Err(e) => {
77 // error!("Failed to get tools from MCP server: {}", e);
78 // }
79 // }
80 },
81 Err(e) => {
82 error!("Failed to connect to MCP server: {}", e);
83 }
84 }
85
86     // Create the memory instance - use CompositeMemory instead of SimpleMemory
87 let memory = CompositeMemory::with_basic_params(
88         "data".into(), // Data directory for the composite memory
89         200,           // Summary threshold (token count)
90         10             // Number of recent messages to keep
91 ).await.expect("Failed to create CompositeMemory");
92
93     // Create the agent instance
94 let client_arc: Arc<dyn McpClient> = Arc::new(mcp_client);
95 let mut agent = McpAgent::with_openai_model_and_memory(
96 client_arc.clone(),
97 "You are an AI assistant that can use both local tools and remote MCP server tools. Please decide whether to use tools based on the user's needs.".to_string(),
98 model.clone(),
99 Box::new(memory.clone())
100 );
101
102    // Automatically fetch tools from the MCP client and add them to the agent
103    // This adds both MCP server tools and local tools
104 if let Err(e) = agent.auto_add_tools().await {
105 error!("Warning: Failed to auto add tools to McpAgent: {}", e);
106 }
107
108 info!("Using model: {}", model.model_name().map_or("Model not specified", |v| v));
109 info!("----------------------------------------");
110
111 println!("基于MCP的混合模式AI Agent聊天机器人已启动!");
112 println!("输入'退出'结束对话");
113 println!("----------------------------------------");
114
115    // Show the available tools
116 println!("Available tools:");
117 let tools = agent.tools();
118 if tools.is_empty() {
119 println!("No tools available");
120 } else {
121 for (index, tool) in tools.iter().enumerate() {
122 println!("{}. {}: {}", index + 1, tool.name(), tool.description());
123 }
124 }
125 println!("----------------------------------------");
126
127    // Example conversations
128 println!("示例对话:");
129 let examples = vec![
130 "What time is it?",
131 "What is 15.5 plus 24.3?",
132 ];
133
134 for example in examples {
135 println!("你: {}", example);
136
137        // Build the input context
138 let mut inputs = HashMap::new();
139 inputs.insert("input".to_string(), serde_json::Value::String(example.to_string()));
140
141        // Run the agent
142 match run_agent(&agent, example.to_string()).await {
143 Ok(response) => {
144                // Try to parse the response as JSON and extract the content field
145 if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&response) {
146 if let Some(content) = json_value.get("content").and_then(|v| v.as_str()) {
147 println!("助手: {}", content);
148 } else {
149 println!("助手: {}", response);
150 }
151 } else {
152 println!("助手: {}", response);
153 }
154 },
155 Err(e) => {
156 println!("助手: 抱歉,处理您的请求时出现错误: {}", e);
157 },
158 }
159
160 info!("----------------------------------------");
161 }
162
163    // Interactive conversation loop
164 println!("现在开始交互式对话(输入'退出'结束对话):");
165 loop {
166 let mut user_input = String::new();
167 println!("你: ");
168 std::io::stdin().read_line(&mut user_input).expect("读取输入失败");
169 let user_input = user_input.trim();
170
171 if user_input.to_lowercase() == "退出" || user_input.to_lowercase() == "exit" {
172 println!("再见!");
173 break;
174 }
175
176 if user_input.is_empty() {
177 continue;
178 }
179
180        // Run the agent
181 match run_agent(&agent, user_input.to_string()).await {
182 Ok(response) => {
183                // Try to parse the response as JSON and extract the content field
184 if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&response) {
185 if let Some(content) = json_value.get("content").and_then(|v| v.as_str()) {
186 println!("助手: {}", content);
187 } else {
188 println!("助手: {}", response);
189 }
190 } else {
191 println!("助手: {}", response);
192 }
193 },
194 Err(e) => {
195 println!("助手: 抱歉,处理您的请求时出现错误: {}", e);
196 },
197 }
198
199 info!("----------------------------------------");
200 }
201
202    // Print the conversation history
203 info!("对话历史:");
204 match memory.load_memory_variables(&HashMap::new()).await {
205 Ok(memories) => {
206 if let Some(chat_history) = memories.get("chat_history") {
207 if let serde_json::Value::Array(messages) = chat_history {
208 for (i, message) in messages.iter().enumerate() {
209 if let serde_json::Value::Object(msg) = message {
210 let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("unknown");
211 let content = msg.get("content").and_then(|v| v.as_str()).unwrap_or("");
212 info!("{}. {}: {}", i + 1, role, content);
213 }
214 }
215 }
216 }
217 },
218 Err(e) => {
219 info!("Failed to load memory variables: {}", e);
220 }
221 }
222
223    // Disconnect the MCP client
224 if let Err(e) = client_arc.disconnect().await {
225 error!("Failed to disconnect MCP client: {}", e);
226 }
227 }
More examples
examples/mcp_agent_local_chatbot.rs (lines 122-126)
16 async fn main() {
17     // Initialize the logger
18 env_logger::Builder::new()
19 .filter_level(LevelFilter::Info)
20 .init();
21
22 info!("=== Rust Agent 使用示例 ===");
23
24     // Read the memory type configuration
25 let memory_type = std::env::var("MEMORY_TYPE").unwrap_or_else(|_| "composite".to_string());
26 let summary_threshold = std::env::var("SUMMARY_THRESHOLD")
27 .ok()
28 .and_then(|s| s.parse().ok())
29 .unwrap_or(200);
30 let recent_messages_count = std::env::var("RECENT_MESSAGES_COUNT")
31 .ok()
32 .and_then(|s| s.parse().ok())
33 .unwrap_or(10);
34
35 info!("使用记忆类型: {}", memory_type);
36 info!("摘要阈值: {}", summary_threshold);
37 info!("保留最近消息数: {}", recent_messages_count);
38
39     // Read the API key and base URL from environment variables
40 let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "OPENAI_API_KEY".to_string());
41 let base_url = std::env::var("OPENAI_API_URL").ok();
42 let mcp_url = std::env::var("MCP_URL").unwrap_or("http://localhost:8000/mcp".to_string());
43
44     // Create the OpenAI model instance - supports OpenAI-compatible APIs
45 let model = OpenAIChatModel::new(api_key.clone(), base_url)
46 .with_model(std::env::var("OPENAI_API_MODEL").unwrap_or_else(|_| "gpt-3.5-turbo".to_string()))
47 .with_temperature(0.7)
48 .with_max_tokens(8*1024);
49
50     // Initialize the MCP client
51     // After initializing the MCP client, customize the tools and tool handlers
52 let mut mcp_client = SimpleMcpClient::new(mcp_url.clone());
53
54     // Clear the default tools (optional)
55 mcp_client.clear_tools();
56
57     // Add custom tools
58 mcp_client.add_tools(vec![
59 McpTool {
60 name: "get_weather".to_string(),
61 description: format!(
62 "Get weather information for a specified city. For example: 'What's the weather like in Beijing?'.
63 The parameter request body you should extract is: '\"parameters\": {{ \"city\": \"{}\" }}'",
64 "city".to_string()),
65 },
66 McpTool {
67 name: "simple_calculate".to_string(),
68 description: format!(
69 "Execute simple mathematical calculations. For example: 'What is 9.11 plus 9.8?'.
70 The parameter request body you should extract is: '\"parameters\": {{ \"expression\": \"{}\" }}'",
71 "expression".to_string()),
72 },
73 ]);
74
75     // Register custom tool handlers
76 mcp_client.register_tool_handler("get_weather".to_string(), |params: HashMap<String, Value>| async move {
77 let default_city = Value::String("Shanghai".to_string());
78 let city_value = params.get("city").unwrap_or(&default_city);
79 let city = city_value.as_str().unwrap_or("Shanghai");
80 Ok(json!({
81 "city": city,
82 "temperature": "25°C",
83 "weather": "Sunny",
84 "humidity": "40%",
85 "updated_at": chrono::Utc::now().to_rfc3339()
86 }))
87 });
88
89 mcp_client.register_tool_handler("simple_calculate".to_string(), |params: HashMap<String, Value>| async move {
90 let expression_value = params.get("expression").ok_or_else(|| Error::msg("Missing calculation expression"))?;
91 let expression = expression_value.as_str().ok_or_else(|| Error::msg("Expression format error"))?;
92
93         // Parse the expression, extracting the operands and the operator
94 let result = parse_and_calculate(expression)?;
95
96 Ok(json!({
97 "expression": expression,
98 "result": result,
99 "calculated_at": chrono::Utc::now().to_rfc3339()
100 }))
101 });
102
103    // Do not connect to the MCP server; use local tools only
104 info!("Using local tools only, not connecting to MCP server...");
105
106 info!("Using model: {}", model.model_name().map_or("Model not specified", |v| v));
107 info!("Using API URL: {}", model.base_url());
108 info!("----------------------------------------");
109
110 let client_arc: Arc<dyn McpClient> = Arc::new(mcp_client);
111
112    // Create a different kind of memory instance depending on the configuration
113 let memory: Box<dyn BaseMemory> = match memory_type.as_str() {
114 "simple" => {
115 info!("使用SimpleMemory (仅内存记忆)");
116 Box::new(SimpleMemory::new())
117 },
118 "composite" => {
119 info!("使用CompositeMemory (组合记忆 - 支持中长期记忆和摘要记忆)");
120            // Use the new simplified interface; only the essential parameters are required
121            // The session_id is generated internally
122 let memory = CompositeMemory::with_basic_params(
123 PathBuf::from("./data/memory"),
124 summary_threshold,
125 recent_messages_count,
126 ).await.expect("Failed to create composite memory");
127
128 Box::new(memory)
129 },
130 _ => {
131 error!("未知的记忆类型: {}, 使用默认的SimpleMemory", memory_type);
132 Box::new(SimpleMemory::new())
133 }
134 };
135
136    // Create the agent instance, passing the temperature, max_tokens and memory parameters
137 let user_system_prompt = "You are an AI assistant that can use tools to answer user questions. Please decide whether to use tools based on the user's needs.".to_string();
138
139 let mut agent = McpAgent::with_openai_model_and_memory(
140 client_arc.clone(),
141 user_system_prompt,
142 model.clone(),
143 memory
144 );
145
146    // Try to automatically fetch tools from the MCP server and add them to the agent
147 if let Err(e) = agent.auto_add_tools().await {
148 error!("Failed to auto add tools from MCP server: {}", e);
149 }
150
151 println!("基于MCP的AI Agent聊天机器人已启动!");
152 println!("记忆类型: {}", memory_type);
153 if memory_type == "composite" {
154 println!("摘要功能: 已启用 (阈值: {} 条消息)", summary_threshold);
155 println!("中长期记忆: 已启用");
156 }
157 println!("输入'退出'结束对话");
158 println!("----------------------------------------");
159 println!("Using tools example:");
160 let tools = client_arc.get_tools().await.unwrap_or_else(|e| {
161 error!("Failed to get tools from MCP server: {}", e);
162        // Fall back to the local tool list
163 vec![
164 McpTool {
165 name: "get_weather".to_string(),
166 description: "Get the weather information for a specified city. For example: 'What's the weather like in Beijing?'".to_string(),
167 },
168 McpTool {
169 name: "simple_calculate".to_string(),
170 description: "Perform simple mathematical calculations. For example: 'What is 9.11 plus 9.8?'".to_string(),
171 },
172 ]
173 });
174
175    // Print the tool list
176 let mut index = 0;
177 for tool in &tools {
178 index += 1;
179
180 println!("{index}. {}: {}", tool.name, tool.description);
181 }
182
183 println!("----------------------------------------");
184    // Conversation loop
185 loop {
186 let mut user_input = String::new();
187 println!("你: ");
188 std::io::stdin().read_line(&mut user_input).expect("读取输入失败");
189 println!("");
190 let user_input = user_input.trim();
191
192 if user_input.to_lowercase() == "退出" || user_input.to_lowercase() == "exit" {
193 println!("再见!");
194 break;
195 }
196
197        // Build the input context
198 let mut inputs = HashMap::new();
199 inputs.insert("input".to_string(), serde_json::Value::String(user_input.to_string()));
200
201        // Run the agent
202 match run_agent(&agent, user_input.to_string()).await {
203 Ok(response) => {
204                // Try to parse the response as JSON and extract the content field
205 if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&response) {
206 if let Some(content) = json_value.get("content").and_then(|v| v.as_str()) {
207 println!("助手: {}", content);
208 } else {
209 println!("助手: {}", response);
210 }
211 } else {
212 println!("助手: {}", response);
213 }
214 },
215 Err(e) => {
216 println!("助手: 抱歉,处理您的请求时出现错误: {}", e);
217 },
218 }
219
220 info!("----------------------------------------");
221 }
222
223    // Print the conversation history
224 info!("对话历史:");
225 if let Some(memory) = agent.get_memory() {
226 match memory.load_memory_variables(&HashMap::new()).await {
227 Ok(memories) => {
228 if let Some(chat_history) = memories.get("chat_history") {
229 if let serde_json::Value::Array(messages) = chat_history {
230 info!("总消息数: {}", messages.len());
231 for (i, message) in messages.iter().enumerate() {
232 if let serde_json::Value::Object(msg) = message {
233 let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("unknown");
234 let content = msg.get("content").and_then(|v| v.as_str()).unwrap_or("");
235                        // Limit the content length for display
236 let display_content = if content.len() > 100 {
237 format!("{}...", &content[..100])
238 } else {
239 content.to_string()
240 };
241 info!("{}. {}: {}", i + 1, role, display_content);
242 }
243 }
244 }
245 }
246
247            // If there is a summary, print it as well
248 if let Some(summary) = memories.get("summary") {
249 if let serde_json::Value::String(summary_text) = summary {
250 info!("对话摘要: {}", summary_text);
251 }
252 }
253 },
254 Err(e) => {
255 info!("Failed to load memory variables: {}", e);
256 }
257 }
258 } else {
259 info!("No memory available");
260 }
261
262    // Disconnect the MCP client
263 if let Err(e) = client_arc.disconnect().await {
264 error!("Failed to disconnect MCP client: {}", e);
265 }
266 }
pub async fn with_config(config: CompositeMemoryConfig) -> Result<Self>
Create a composite memory instance from a CompositeMemoryConfig.
pub async fn with_session_id(session_id: String) -> Result<Self>
Create a composite memory instance with session ID
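For example, a sketch inside an async context using only the documented constructor; the printed value assumes the session ID is stored as given:

// Reuse an existing conversation by passing its session ID explicitly.
let memory = CompositeMemory::with_session_id("session-2024-demo".to_string()).await?;
println!("session: {:?}", memory.get_session_id());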
pub async fn add_message(&self, message: ChatMessage) -> Result<()>
Add a message to memory.
pub async fn get_message_count(&self) -> Result<usize>
Get the number of stored messages.
pub async fn get_recent_messages(
    &self,
    count: usize,
) -> Result<Vec<ChatMessage>>
Get the most recent N messages
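A sketch inside an async context, assuming only the signatures documented on this page:

// Inspect how much history is stored, then fetch a recent window of it.
let total = memory.get_message_count().await?;
let recent = memory.get_recent_messages(total.min(10)).await?;
println!("{} messages stored, showing the last {}", total, recent.len());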
pub async fn cleanup_old_messages(&self) -> Result<()>
Clean up old messages
pub async fn get_memory_stats(&self) -> Result<Value>
Get memory statistics
pub async fn get_summary(&self) -> Result<Option<String>>
Get summary content
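A maintenance sketch under the same assumptions (async context, the crate's Result alias):

// Drop old messages, then inspect the statistics and the current summary.
memory.cleanup_old_messages().await?;
let stats = memory.get_memory_stats().await?; // serde_json::Value
println!("memory stats: {}", stats);
if let Some(summary) = memory.get_summary().await? {
    println!("conversation summary so far: {}", summary);
}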
Trait Implementations§
impl BaseMemory for CompositeMemory
fn memory_variables(&self) -> Vec<String>
fn load_memory_variables<'a>( &'a self, inputs: &'a HashMap<String, Value>, ) -> Pin<Box<dyn Future<Output = Result<HashMap<String, Value>>> + Send + 'a>>
fn save_context<'a>( &'a self, inputs: &'a HashMap<String, Value>, outputs: &'a HashMap<String, Value>, ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
fn clear<'a>(&'a self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
fn clone_box(&self) -> Box<dyn BaseMemory>
fn get_session_id(&self) -> Option<&str>
fn set_session_id(&mut self, session_id: String)
fn get_token_count(&self) -> Result<usize, Error>
fn as_any(&self) -> &dyn Any
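Because CompositeMemory implements BaseMemory, it can sit behind a trait object, as the local chatbot example above does when switching between SimpleMemory and CompositeMemory. A minimal sketch, with the async context and import paths assumed:

let memory: Box<dyn BaseMemory> = Box::new(
    CompositeMemory::with_basic_params(PathBuf::from("./data/memory"), 200, 10).await?,
);
println!("variables exposed: {:?}", memory.memory_variables());
let vars = memory.load_memory_variables(&HashMap::new()).await?;
println!("chat_history present: {}", vars.contains_key("chat_history"));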
impl Clone for CompositeMemory
fn clone(&self) -> CompositeMemory
Returns a duplicate of the value.
1.0.0 · fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
Auto Trait Implementations§
impl Freeze for CompositeMemory
impl !RefUnwindSafe for CompositeMemory
impl Send for CompositeMemory
impl Sync for CompositeMemory
impl Unpin for CompositeMemory
impl !UnwindSafe for CompositeMemory
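CompositeMemory is Clone, Send and Sync, so one handle can be given to an agent while another is kept for later inspection, as the hybrid chatbot example above does. A sketch assuming a Tokio runtime, an enclosing async function, and that the method futures are Send:

let memory = CompositeMemory::with_basic_params("data".into(), 200, 10).await?;
let for_agent: Box<dyn BaseMemory> = Box::new(memory.clone());
// ... hand `for_agent` to the agent ...
// The remaining handle can also be moved into a spawned task:
let handle = tokio::spawn({
    let memory = memory.clone();
    async move { memory.get_message_count().await }
});
println!("count from background task: {:?}", handle.await);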
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.