CompositeMemory

Struct CompositeMemory 

Source
pub struct CompositeMemory { /* private fields */ }
Expand description

Composite memory implementation

This struct combines multiple memory types behind a single, unified interface. It manages message history and summary memory at the same time, and provides intelligent summary generation.

Implementations§

Source§

impl CompositeMemory

Source

pub async fn new() -> Result<Self>

Create a new composite memory instance

Source

pub async fn with_basic_params( data_dir: PathBuf, summary_threshold: usize, recent_messages_count: usize, ) -> Result<Self>

Create a composite memory instance with basic parameters. This is the recommended constructor: it requires only the necessary parameters, and the session_id is generated automatically internally.

Examples found in repository?
examples/mcp_agent_hybrid_chatbot.rs (lines 87-91)
15async fn main() {
16    // 初始化日志记录器
17    env_logger::Builder::new()
18        .filter_level(LevelFilter::Info)  // 设置日志级别为Error以便查看错误信息
19        .init();
20    
21    info!("=== Rust Agent 混合模式示例 ===");
22    
23    // 从环境变量获取API密钥和基本URL
24    let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "OPENAI_API_KEY".to_string());
25    let base_url = std::env::var("OPENAI_API_URL").ok();
26    let mcp_url = std::env::var("MCP_URL").unwrap_or("http://127.0.0.1:6000".to_string());  // 默认MCP服务器地址
27    
28    // 创建OpenAI模型实例 - 支持Openai兼容 API
29    let model = OpenAIChatModel::new(api_key.clone(), base_url)
30        .with_model(std::env::var("OPENAI_API_MODEL").unwrap_or_else(|_| "gpt-3.5-turbo".to_string()))
31        .with_temperature(0.7)
32        .with_max_tokens(8*1024);
33    
34    // 初始化MCP客户端
35    let mut mcp_client = SimpleMcpClient::new(mcp_url.clone());
36    
37    // 清空默认工具(可选)
38    mcp_client.clear_tools();
39    
40    // 添加本地自定义工具定义
41    mcp_client.add_tools(vec![
42        McpTool {
43            name: "get_local_time".to_string(),
44            description: "Get the current local time and date. For example: 'What time is it?'".to_string(),
45        },
46    ]);
47    
48    // 注册本地工具处理器
49    mcp_client.register_tool_handler("get_local_time".to_string(), |_params: HashMap<String, Value>| async move {
50        let now = chrono::Local::now();
51        Ok(json!({
52            "current_time": now.format("%Y-%m-%d %H:%M:%S").to_string(),
53            "timezone": "Local"
54        }))
55    });
56    
57    // 不注册calculate_expression的本地处理器,让其使用服务端工具
58    info!("Using local 'get_local_time' tool and server-side calculation tools...");
59    
60    // 尝试连接到 MCP 服务器
61    match mcp_client.connect(&mcp_url).await {
62        Ok(_) => {
63            info!("Successfully connected to MCP server at {}", mcp_url);
64            
65            // 设置连接状态为已连接
66            mcp_client.set_server_connected(true);
67            
68            // // 获取服务器工具
69            // match mcp_client.get_tools().await {
70            //     Ok(server_tools) => {
71            //         info!("Retrieved {} tools from MCP server", server_tools.len());
72            //         for tool in &server_tools {
73            //             info!("Server tool: {} - {}", tool.name, tool.description);
74            //         }
75            //     },
76            //     Err(e) => {
77            //         error!("Failed to get tools from MCP server: {}", e);
78            //     }
79            // }
80        },
81        Err(e) => {
82            error!("Failed to connect to MCP server: {}", e);
83        }
84    }
85
86    // 创建记忆模块实例 - 使用CompositeMemory替代SimpleMemory
87    let memory = CompositeMemory::with_basic_params(
88        "data".into(),           // 组合记忆数据存储目录
89        200,           // 摘要阈值(token数量)
90        10          // 保留的最近消息数量
91    ).await.expect("Failed to create CompositeMemory");
92
93    // 创建Agent实例
94    let client_arc: Arc<dyn McpClient> = Arc::new(mcp_client);
95    let mut agent = McpAgent::with_openai_model_and_memory(
96        client_arc.clone(),
97        "You are an AI assistant that can use both local tools and remote MCP server tools. Please decide whether to use tools based on the user's needs.".to_string(),
98        model.clone(),
99        Box::new(memory.clone())
100    );
101    
102    // 自动从MCP客户端获取工具并添加到Agent
103    // 这会同时添加MCP服务器工具和本地工具
104    if let Err(e) = agent.auto_add_tools().await {
105        error!("Warning: Failed to auto add tools to McpAgent: {}", e);
106    }
107
108    info!("Using model: {}", model.model_name().map_or("Model not specified", |v| v));
109    info!("----------------------------------------");
110    
111    println!("基于MCP的混合模式AI Agent聊天机器人已启动!");
112    println!("输入'退出'结束对话");
113    println!("----------------------------------------");
114    
115    // 显示可用工具
116    println!("Available tools:");
117    let tools = agent.tools();
118    if tools.is_empty() {
119        println!("No tools available");
120    } else {
121        for (index, tool) in tools.iter().enumerate() {
122            println!("{}. {}: {}", index + 1, tool.name(), tool.description());
123        }
124    }
125    println!("----------------------------------------");
126    
127    // 示例对话
128    println!("示例对话:");
129    let examples = vec![
130        "What time is it?",
131        "What is 15.5 plus 24.3?",
132    ];
133    
134    for example in examples {
135        println!("你: {}", example);
136        
137        // 创建输入上下文
138        let mut inputs = HashMap::new();
139        inputs.insert("input".to_string(), serde_json::Value::String(example.to_string()));
140        
141        // 运行Agent
142        match run_agent(&agent, example.to_string()).await {
143            Ok(response) => {
144                // 尝试解析response为JSON并提取content字段
145                if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&response) {
146                    if let Some(content) = json_value.get("content").and_then(|v| v.as_str()) {
147                        println!("助手: {}", content);
148                    } else {
149                        println!("助手: {}", response);
150                    }
151                } else {
152                    println!("助手: {}", response);
153                }
154            },
155            Err(e) => {
156                println!("助手: 抱歉,处理您的请求时出现错误: {}", e);
157            },
158        }
159        
160        info!("----------------------------------------");
161    }
162    
163    // 交互式对话循环
164    println!("现在开始交互式对话(输入'退出'结束对话):");
165    loop {
166        let mut user_input = String::new();
167        println!("你: ");
168        std::io::stdin().read_line(&mut user_input).expect("读取输入失败");
169        let user_input = user_input.trim();
170        
171        if user_input.to_lowercase() == "退出" || user_input.to_lowercase() == "exit" {
172            println!("再见!");
173            break;
174        }
175        
176        if user_input.is_empty() {
177            continue;
178        }
179        
180        // 运行Agent
181        match run_agent(&agent, user_input.to_string()).await {
182            Ok(response) => {
183                // 尝试解析response为JSON并提取content字段
184                if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&response) {
185                    if let Some(content) = json_value.get("content").and_then(|v| v.as_str()) {
186                        println!("助手: {}", content);
187                    } else {
188                        println!("助手: {}", response);
189                    }
190                } else {
191                    println!("助手: {}", response);
192                }
193            },
194            Err(e) => {
195                println!("助手: 抱歉,处理您的请求时出现错误: {}", e);
196            },
197        }
198        
199        info!("----------------------------------------");
200    }
201    
202    // 打印对话历史
203    info!("对话历史:");
204    match memory.load_memory_variables(&HashMap::new()).await {
205        Ok(memories) => {
206            if let Some(chat_history) = memories.get("chat_history") {
207                if let serde_json::Value::Array(messages) = chat_history {
208                    for (i, message) in messages.iter().enumerate() {
209                        if let serde_json::Value::Object(msg) = message {
210                            let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("unknown");
211                            let content = msg.get("content").and_then(|v| v.as_str()).unwrap_or("");
212                            info!("{}. {}: {}", i + 1, role, content);
213                        }
214                    }
215                }
216            }
217        },
218        Err(e) => {
219            info!("Failed to load memory variables: {}", e);
220        }
221    }
222    
223    // 断开MCP连接
224    if let Err(e) = client_arc.disconnect().await {
225        error!("Failed to disconnect MCP client: {}", e);
226    }
227}
More examples
Hide additional examples
examples/mcp_agent_local_chatbot.rs (lines 122-126)
16async fn main() {
17    // 初始化日志记录器
18    env_logger::Builder::new()
19        .filter_level(LevelFilter::Info)
20        .init();
21    
22    info!("=== Rust Agent 使用示例 ===");
23    
24    // 获取记忆类型配置
25    let memory_type = std::env::var("MEMORY_TYPE").unwrap_or_else(|_| "composite".to_string());
26    let summary_threshold = std::env::var("SUMMARY_THRESHOLD")
27        .ok()
28        .and_then(|s| s.parse().ok())
29        .unwrap_or(200);
30    let recent_messages_count = std::env::var("RECENT_MESSAGES_COUNT")
31        .ok()
32        .and_then(|s| s.parse().ok())
33        .unwrap_or(10);
34    
35    info!("使用记忆类型: {}", memory_type);
36    info!("摘要阈值: {}", summary_threshold);
37    info!("保留最近消息数: {}", recent_messages_count);
38    
39    // 从环境变量获取API密钥和基本URL
40    let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "OPENAI_API_KEY".to_string());
41    let base_url = std::env::var("OPENAI_API_URL").ok();
42    let mcp_url = std::env::var("MCP_URL").unwrap_or("http://localhost:8000/mcp".to_string());
43    
44    // 创建OpenAI模型实例 - 支持Openai兼容 API
45    let model = OpenAIChatModel::new(api_key.clone(), base_url)
46        .with_model(std::env::var("OPENAI_API_MODEL").unwrap_or_else(|_| "gpt-3.5-turbo".to_string()))
47        .with_temperature(0.7)
48        .with_max_tokens(8*1024);
49    
50    // 初始化MCP客户端
51    // 在初始化 MCP 客户端后,自定义工具和工具处理器
52    let mut mcp_client = SimpleMcpClient::new(mcp_url.clone());
53    
54    // 清空默认工具(可选)
55    mcp_client.clear_tools();
56    
57    // 添加自定义工具
58    mcp_client.add_tools(vec![
59        McpTool {
60            name: "get_weather".to_string(),
61            description: format!(
62                "Get weather information for a specified city. For example: 'What's the weather like in Beijing?'.
63                The parameter request body you should extract is: '\"parameters\": {{ \"city\": \"{}\" }}'",
64                "city".to_string()),
65        },
66        McpTool {
67            name: "simple_calculate".to_string(),
68            description: format!(
69                "Execute simple mathematical calculations. For example: 'What is 9.11 plus 9.8?'.
70                The parameter request body you should extract is: '\"parameters\": {{ \"expression\": \"{}\" }}'",
71                "expression".to_string()),
72        },
73    ]);
74    
75    // 注册自定义工具处理器
76    mcp_client.register_tool_handler("get_weather".to_string(), |params: HashMap<String, Value>| async move {
77        let default_city = Value::String("Shanghai".to_string());
78        let city_value = params.get("city").unwrap_or(&default_city);
79        let city = city_value.as_str().unwrap_or("Shanghai");
80        Ok(json!({
81            "city": city,
82            "temperature": "25°C",
83            "weather": "Sunny",
84            "humidity": "40%",
85            "updated_at": chrono::Utc::now().to_rfc3339()
86        }))
87    });
88    
89    mcp_client.register_tool_handler("simple_calculate".to_string(), |params: HashMap<String, Value>| async move {
90        let expression_value = params.get("expression").ok_or_else(|| Error::msg("Missing calculation expression"))?;
91        let expression = expression_value.as_str().ok_or_else(|| Error::msg("Expression format error"))?;
92        
93        // 解析表达式,提取操作数和运算符
94        let result = parse_and_calculate(expression)?;
95        
96        Ok(json!({
97            "expression": expression,
98            "result": result,
99            "calculated_at": chrono::Utc::now().to_rfc3339()
100        }))
101    });
102    
103    // 不连接到 MCP 服务器,仅使用本地工具
104    info!("Using local tools only, not connecting to MCP server...");
105
106    info!("Using model: {}", model.model_name().map_or("Model not specified", |v| v));
107    info!("Using API URL: {}", model.base_url());
108    info!("----------------------------------------");
109    
110    let client_arc: Arc<dyn McpClient> = Arc::new(mcp_client);
111    
112    // 根据配置创建不同类型的记忆模块实例
113    let memory: Box<dyn BaseMemory> = match memory_type.as_str() {
114        "simple" => {
115            info!("使用SimpleMemory (仅内存记忆)");
116            Box::new(SimpleMemory::new())
117        },
118        "composite" => {
119            info!("使用CompositeMemory (组合记忆 - 支持中长期记忆和摘要记忆)");
120            // 使用新的简化接口,只需要提供必要的参数
121            // session_id将在内部自动生成
122            let memory = CompositeMemory::with_basic_params(
123                PathBuf::from("./data/memory"),
124                summary_threshold,
125                recent_messages_count,
126            ).await.expect("Failed to create composite memory");
127            
128            Box::new(memory)
129        },
130        _ => {
131            error!("未知的记忆类型: {}, 使用默认的SimpleMemory", memory_type);
132            Box::new(SimpleMemory::new())
133        }
134    };
135
136    // 创建Agent实例,并传递temperature、max_tokens和memory参数
137    let user_system_prompt = "You are an AI assistant that can use tools to answer user questions. Please decide whether to use tools based on the user's needs.".to_string();
138    
139    let mut agent = McpAgent::with_openai_model_and_memory(
140        client_arc.clone(),
141        user_system_prompt,
142        model.clone(),
143        memory
144    );
145    
146    // 尝试从MCP服务器自动获取工具并添加到Agent
147    if let Err(e) = agent.auto_add_tools().await {
148        error!("Failed to auto add tools from MCP server: {}", e);
149    }
150    
151    println!("基于MCP的AI Agent聊天机器人已启动!");
152    println!("记忆类型: {}", memory_type);
153    if memory_type == "composite" {
154        println!("摘要功能: 已启用 (阈值: {} 条消息)", summary_threshold);
155        println!("中长期记忆: 已启用");
156    }
157    println!("输入'退出'结束对话");
158    println!("----------------------------------------");
159    println!("Using tools example:");
160    let tools = client_arc.get_tools().await.unwrap_or_else(|e| {
161        error!("Failed to get tools from MCP server: {}", e);
162        // 返回本地工具列表
163        vec![
164            McpTool {
165                name: "get_weather".to_string(),
166                description: "Get the weather information for a specified city. For example: 'What's the weather like in Beijing?'".to_string(),
167            },
168            McpTool {
169                name: "simple_calculate".to_string(),
170                description: "Perform simple mathematical calculations. For example: 'What is 9.11 plus 9.8?'".to_string(),
171            },
172        ]
173    });
174    
175    // 打印工具列表
176    let mut index = 0;
177    for tool in &tools {
178        index += 1;
179
180        println!("{index}. {}: {}", tool.name, tool.description);
181    }
182    
183    println!("----------------------------------------");
184    // 对话循环
185    loop {
186        let mut user_input = String::new();
187        println!("你: ");
188        std::io::stdin().read_line(&mut user_input).expect("读取输入失败");
189        println!("");
190        let user_input = user_input.trim();
191        
192        if user_input.to_lowercase() == "退出" || user_input.to_lowercase() == "exit" {
193            println!("再见!");
194            break;
195        }
196        
197        // 创建输入上下文
198        let mut inputs = HashMap::new();
199        inputs.insert("input".to_string(), serde_json::Value::String(user_input.to_string()));
200        
201        // 运行Agent
202        match run_agent(&agent, user_input.to_string()).await {
203            Ok(response) => {
204                // 尝试解析response为JSON并提取content字段
205                if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&response) {
206                    if let Some(content) = json_value.get("content").and_then(|v| v.as_str()) {
207                        println!("助手: {}", content);
208                    } else {
209                        println!("助手: {}", response);
210                    }
211                } else {
212                    println!("助手: {}", response);
213                }
214            },
215            Err(e) => {
216                println!("助手: 抱歉,处理您的请求时出现错误: {}", e);
217            },
218        }
219        
220        info!("----------------------------------------");
221    }
222    
223    // 打印对话历史
224    info!("对话历史:");
225    if let Some(memory) = agent.get_memory() {
226        match memory.load_memory_variables(&HashMap::new()).await {
227            Ok(memories) => {
228                if let Some(chat_history) = memories.get("chat_history") {
229                    if let serde_json::Value::Array(messages) = chat_history {
230                        info!("总消息数: {}", messages.len());
231                        for (i, message) in messages.iter().enumerate() {
232                            if let serde_json::Value::Object(msg) = message {
233                                let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("unknown");
234                                let content = msg.get("content").and_then(|v| v.as_str()).unwrap_or("");
235                                // 限制内容长度以便显示
236                                let display_content = if content.len() > 100 {
237                                    format!("{}...", &content[..100])
238                                } else {
239                                    content.to_string()
240                                };
241                                info!("{}. {}: {}", i + 1, role, display_content);
242                            }
243                        }
244                    }
245                }
246                
247                // 如果有摘要,也打印出来
248                if let Some(summary) = memories.get("summary") {
249                    if let serde_json::Value::String(summary_text) = summary {
250                        info!("对话摘要: {}", summary_text);
251                    }
252                }
253            },
254            Err(e) => {
255                info!("Failed to load memory variables: {}", e);
256            }
257        }
258    } else {
259        info!("No memory available");
260    }
261    
262    // 断开MCP连接
263    if let Err(e) = client_arc.disconnect().await {
264        error!("Failed to disconnect MCP client: {}", e);
265    }
266}
Source

pub async fn with_config(config: CompositeMemoryConfig) -> Result<Self>

Create a composite memory instance from the given configuration

Source

pub async fn with_session_id(session_id: String) -> Result<Self>

Create a composite memory instance with the given session ID

Source

pub async fn add_message(&self, message: ChatMessage) -> Result<()>

Add message to memory

Source

pub async fn get_message_count(&self) -> Result<usize>

Get message count

Source

pub async fn get_recent_messages( &self, count: usize, ) -> Result<Vec<ChatMessage>>

Get the most recent `count` messages

Source

pub async fn cleanup_old_messages(&self) -> Result<()>

Clean up old messages

Source

pub async fn get_memory_stats(&self) -> Result<Value>

Get memory statistics

Source

pub async fn get_summary(&self) -> Result<Option<String>>

Get summary content

Source§

impl CompositeMemory

Source

pub fn as_any(&self) -> &dyn Any

Get Any reference for type conversion

Trait Implementations§

Source§

impl BaseMemory for CompositeMemory

Source§

fn memory_variables(&self) -> Vec<String>

Source§

fn load_memory_variables<'a>( &'a self, inputs: &'a HashMap<String, Value>, ) -> Pin<Box<dyn Future<Output = Result<HashMap<String, Value>>> + Send + 'a>>

Source§

fn save_context<'a>( &'a self, inputs: &'a HashMap<String, Value>, outputs: &'a HashMap<String, Value>, ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>

Source§

fn clear<'a>(&'a self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>

Source§

fn clone_box(&self) -> Box<dyn BaseMemory>

Source§

fn get_session_id(&self) -> Option<&str>

Source§

fn set_session_id(&mut self, session_id: String)

Source§

fn get_token_count(&self) -> Result<usize, Error>

Source§

fn as_any(&self) -> &dyn Any

Source§

impl Clone for CompositeMemory

Source§

fn clone(&self) -> CompositeMemory

Returns a duplicate of the value. Read more
1.0.0 · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl Debug for CompositeMemory

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T> FromRef<T> for T
where T: Clone,

Source§

fn from_ref(input: &T) -> T

Converts to this type from a reference to the input type.
Source§

impl<T> Instrument for T

Source§

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper. Read more
Source§

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper. Read more
Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<T> WithSubscriber for T

Source§

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper. Read more
Source§

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper. Read more
Source§

impl<T> ErasedDestructor for T
where T: 'static,

Source§

impl<A, B, T> HttpServerConnExec<A, B> for T
where B: Body,