rust_agent/memory/
base.rs1use anyhow::Error;
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use serde_json::Value;
7use std::pin::Pin;
8use std::future::Future;
9use log::info;
10
11pub type MemoryVariables = HashMap<String, Value>;
13
14pub trait BaseMemory: Send + Sync {
16 fn memory_variables(&self) -> Vec<String>;
18
19 fn load_memory_variables<'a>(&'a self, inputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<HashMap<String, Value>, Error>> + Send + 'a>>;
21
22 fn save_context<'a>(&'a self, inputs: &'a HashMap<String, Value>, outputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>>;
24
25 fn clear<'a>(&'a self) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>>;
27
28 fn clone_box(&self) -> Box<dyn BaseMemory>;
30
31 fn get_session_id(&self) -> Option<&str>;
33
34 fn set_session_id(&mut self, session_id: String);
36
37 fn get_token_count(&self) -> Result<usize, Error>;
39
40 fn as_any(&self) -> &dyn std::any::Any;
42}
43
44#[derive(Debug)]
46pub struct SimpleMemory {
47 memories: Arc<RwLock<HashMap<String, Value>>>,
48 memory_key: String,
49 session_id: Option<String>,
50}
51
52impl Clone for SimpleMemory {
53 fn clone(&self) -> Self {
54 Self {
55 memories: Arc::clone(&self.memories),
56 memory_key: self.memory_key.clone(),
57 session_id: self.session_id.clone(),
58 }
59 }
60}
61
62impl SimpleMemory {
63 pub fn new() -> Self {
64 Self {
65 memories: Arc::new(RwLock::new(HashMap::new())),
66 memory_key: "chat_history".to_string(),
67 session_id: None,
68 }
69 }
70
71 pub fn with_memory_key(memory_key: String) -> Self {
72 Self {
73 memories: Arc::new(RwLock::new(HashMap::new())),
74 memory_key,
75 session_id: None,
76 }
77 }
78
79 pub fn with_memories(memories: HashMap<String, Value>) -> Self {
80 Self {
81 memories: Arc::new(RwLock::new(memories)),
82 memory_key: "chat_history".to_string(),
83 session_id: None,
84 }
85 }
86
87 pub async fn add_message(&self, message: Value) -> Result<(), Error> {
88 let mut memories = self.memories.write().await;
89 let chat_history = memories.entry(self.memory_key.clone()).or_insert_with(|| Value::Array(vec![]));
90
91 if let Value::Array(ref mut arr) = chat_history {
92 arr.push(message);
93 } else {
94 *chat_history = Value::Array(vec![message]);
95 }
96
97 Ok(())
98 }
99
100 pub fn get_memory_key(&self) -> String {
101 self.memory_key.clone()
102 }
103}
104
105impl Default for SimpleMemory {
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
111impl BaseMemory for SimpleMemory {
112 fn memory_variables(&self) -> Vec<String> {
113 vec![self.memory_key.clone()]
114 }
115
116 fn load_memory_variables<'a>(&'a self, _inputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<HashMap<String, Value>, Error>> + Send + 'a>> {
117 let memories = Arc::clone(&self.memories);
118 Box::pin(async move {
119 let memories = memories.read().await;
120 Ok(memories.clone())
121 })
122 }
123
124 fn save_context<'a>(&'a self, inputs: &'a HashMap<String, Value>, outputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
125 let memories = Arc::clone(&self.memories);
126 let input_clone = inputs.clone();
127 let output_clone = outputs.clone();
128 let memory_key = self.memory_key.clone();
129
130 Box::pin(async move {
131 let mut memories = memories.write().await;
132
133 let chat_history = memories.entry(memory_key.clone()).or_insert_with(|| Value::Array(vec![]));
135
136 if !chat_history.is_array() {
138 *chat_history = Value::Array(vec![]);
139 }
140
141 if let Some(input_value) = input_clone.get("input") {
143 let user_message = serde_json::json!({
144 "role": "human",
145 "content": input_value
146 });
147
148 if let Value::Array(ref mut arr) = chat_history {
149 info!("Adding to chat history: {:?}", user_message);
150 arr.push(user_message);
151 }
152 }
153
154 if let Some(output_value) = output_clone.get("output") {
156 let ai_message = serde_json::json!({
157 "role": "ai",
158 "content": output_value
159 });
160
161 if let Value::Array(ref mut arr) = chat_history {
162 arr.push(ai_message);
163 }
164 }
165
166 Ok(())
167 })
168 }
169
170 fn clear<'a>(&'a self) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
171 let memories = Arc::clone(&self.memories);
172 Box::pin(async move {
173 let mut memories = memories.write().await;
174 memories.clear();
175 Ok(())
176 })
177 }
178
179 fn clone_box(&self) -> Box<dyn BaseMemory> {
180 Box::new(self.clone())
181 }
182
183 fn get_session_id(&self) -> Option<&str> {
184 self.session_id.as_deref()
185 }
186
187 fn set_session_id(&mut self, session_id: String) {
188 self.session_id = Some(session_id);
189 }
190
191 fn get_token_count(&self) -> Result<usize, Error> {
192 let count = self.memory_key.len() + self.session_id.as_ref().map(|s| s.len()).unwrap_or(0);
194 Ok(count)
195 }
196
197 fn as_any(&self) -> &dyn std::any::Any {
198 self
199 }
200}
201
202impl Clone for Box<dyn BaseMemory> {
204 fn clone(&self) -> Self {
205 self.as_ref().clone_box()
206 }
207}