1#![allow(dead_code)]
5
6use crate::memory::{ExperienceType, Memory, MemoryId};
7use petgraph::dot::{Config, Dot};
8use petgraph::graph::{DiGraph, NodeIndex};
9use serde::Serialize;
10use std::collections::HashMap;
11use std::fmt;
12use tracing::{debug, info, trace};
13
/// A vertex in a [`MemoryGraph`].
///
/// The three memory variants are created by `MemoryGraph::add_memory` (one node per
/// memory/tier-string pair); `Experience` and `Context` nodes are created by
/// `MemoryGraph::add_experience` and `MemoryGraph::add_context`.
#[derive(Debug, Clone)]
pub enum MemoryNode {
    /// Memory added under the `"working"` tier; also the fallback variant `add_memory`
    /// uses for unrecognized tier strings.
    WorkingMemory {
        /// Identifier of the memory this node represents.
        id: MemoryId,
        /// Snapshot of `Memory::importance()` taken when the node was added.
        importance: f32,
    },
    /// Memory added under the `"session"` tier.
    SessionMemory {
        /// Identifier of the memory this node represents.
        id: MemoryId,
        /// Snapshot of `Memory::importance()` taken when the node was added.
        importance: f32,
    },
    /// Memory added under the `"longterm"` tier.
    LongTermMemory {
        /// Identifier of the memory this node represents.
        id: MemoryId,
        /// Snapshot of `Memory::importance()` taken when the node was added.
        importance: f32,
        /// Copy of the memory's `compressed` flag at insertion time.
        compressed: bool,
    },
    /// Experience node; `add_experience` stores at most the first 50 `char`s of content.
    Experience {
        exp_type: ExperienceType,
        content: String,
    },
    /// Context node carrying the caller-supplied decay value.
    Context {
        context_id: String,
        decay: f32,
    },
}
39
40impl fmt::Display for MemoryNode {
41 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
42 match self {
43 MemoryNode::WorkingMemory { id: _, importance } => {
44 write!(f, "WM\\n{importance:.2}")
45 }
46 MemoryNode::SessionMemory { id: _, importance } => {
47 write!(f, "SM\\n{importance:.2}")
48 }
49 MemoryNode::LongTermMemory {
50 id: _,
51 importance,
52 compressed,
53 } => {
54 write!(
55 f,
56 "LTM\\n{:.2}{}",
57 importance,
58 if *compressed { "🗜️" } else { "" }
59 )
60 }
61 MemoryNode::Experience { exp_type, .. } => {
62 write!(f, "{exp_type:?}")
63 }
64 MemoryNode::Context {
65 context_id: _,
66 decay,
67 } => {
68 write!(f, "CTX\\n{decay:.2}")
69 }
70 }
71 }
72}
73
/// Directed relationship between two [`MemoryNode`]s in a [`MemoryGraph`].
#[derive(Debug, Clone)]
pub enum MemoryEdge {
    /// The same memory moved between tiers; added by `MemoryGraph::log_promotion`.
    Promotion,
    /// Similarity link carrying a score (rendered with two decimals by `Display`).
    SemanticSimilarity(f32),
    /// NOTE(review): presumably links a memory to the one that followed it in time;
    /// not constructed in the code visible here.
    TemporalSuccession,
    /// NOTE(review): presumably a cause -> effect link; not constructed in the code visible here.
    CausalLink,
    /// NOTE(review): presumably links a memory to a `Context` node; not constructed in the
    /// code visible here.
    ContextRelation,
}
83
84impl fmt::Display for MemoryEdge {
85 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
86 match self {
87 MemoryEdge::Promotion => write!(f, "→"),
88 MemoryEdge::SemanticSimilarity(score) => write!(f, "~{score:.2}"),
89 MemoryEdge::TemporalSuccession => write!(f, "⏭"),
90 MemoryEdge::CausalLink => write!(f, "⚡"),
91 MemoryEdge::ContextRelation => write!(f, "⊕"),
92 }
93 }
94}
95
/// Directed graph of memories across tiers, plus experience and context nodes.
///
/// Provides DOT export ([`MemoryGraph::to_dot`]) and summary counts ([`MemoryGraph::stats`]).
pub struct MemoryGraph {
    /// Backing petgraph graph holding every node and edge.
    graph: DiGraph<MemoryNode, MemoryEdge>,
    /// Maps `"{tier}_{memory id}"` keys to the nodes created by `add_memory`; used for
    /// de-duplication and promotion lookups. Experience/context nodes are not registered here.
    node_map: HashMap<String, NodeIndex>,
}
101
102impl Default for MemoryGraph {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108impl MemoryGraph {
109 pub fn new() -> Self {
110 Self {
111 graph: DiGraph::new(),
112 node_map: HashMap::new(),
113 }
114 }
115
116 pub fn add_memory(&mut self, memory: &Memory, tier: &str) -> NodeIndex {
118 let key = format!("{}_{}", tier, memory.id.0);
119
120 if let Some(&idx) = self.node_map.get(&key) {
121 return idx;
122 }
123
124 let node = match tier {
125 "working" => MemoryNode::WorkingMemory {
126 id: memory.id.clone(),
127 importance: memory.importance(),
128 },
129 "session" => MemoryNode::SessionMemory {
130 id: memory.id.clone(),
131 importance: memory.importance(),
132 },
133 "longterm" => MemoryNode::LongTermMemory {
134 id: memory.id.clone(),
135 importance: memory.importance(),
136 compressed: memory.compressed,
137 },
138 _ => {
139 tracing::error!(
140 "Invalid tier '{}' passed to add_memory for memory {}, defaulting to WorkingMemory",
141 tier,
142 memory.id.0
143 );
144 MemoryNode::WorkingMemory {
145 id: memory.id.clone(),
146 importance: memory.importance(),
147 }
148 }
149 };
150
151 let idx = self.graph.add_node(node);
152 self.node_map.insert(key, idx);
153 idx
154 }
155
156 pub fn add_experience(&mut self, exp_type: ExperienceType, content: &str) -> NodeIndex {
158 let node = MemoryNode::Experience {
159 exp_type,
160 content: content.chars().take(50).collect(),
161 };
162 self.graph.add_node(node)
163 }
164
165 pub fn add_context(&mut self, context_id: &str, decay: f32) -> NodeIndex {
167 let node = MemoryNode::Context {
168 context_id: context_id.to_string(),
169 decay,
170 };
171 self.graph.add_node(node)
172 }
173
174 pub fn add_edge(&mut self, from: NodeIndex, to: NodeIndex, edge_type: MemoryEdge) {
176 self.graph.add_edge(from, to, edge_type);
177 }
178
179 pub fn log_promotion(&mut self, from_tier: &str, to_tier: &str, memory_id: &MemoryId) {
181 let from_key = format!("{}_{}", from_tier, memory_id.0);
182 let to_key = format!("{}_{}", to_tier, memory_id.0);
183
184 if let (Some(&from_idx), Some(&to_idx)) =
185 (self.node_map.get(&from_key), self.node_map.get(&to_key))
186 {
187 self.add_edge(from_idx, to_idx, MemoryEdge::Promotion);
188 debug!(
189 from = from_tier.to_uppercase().as_str(),
190 to = to_tier.to_uppercase().as_str(),
191 "Graph tier promotion"
192 );
193 }
194 }
195
196 pub fn to_dot(&self) -> String {
198 format!(
199 "{:?}",
200 Dot::with_config(&self.graph, &[Config::EdgeNoLabel])
201 )
202 }
203
204 pub fn stats(&self) -> GraphStats {
206 GraphStats {
207 total_nodes: self.graph.node_count(),
208 total_edges: self.graph.edge_count(),
209 working_memory_count: self.count_tier("working"),
210 session_memory_count: self.count_tier("session"),
211 longterm_memory_count: self.count_tier("longterm"),
212 }
213 }
214
215 fn count_tier(&self, tier: &str) -> usize {
216 self.node_map
217 .keys()
218 .filter(|k| k.starts_with(&format!("{tier}_")))
219 .count()
220 }
221
222 pub fn print_ascii_visualization(&self) {
224 let stats = self.stats();
225
226 info!(
227 working = stats.working_memory_count,
228 session = stats.session_memory_count,
229 longterm = stats.longterm_memory_count,
230 nodes = stats.total_nodes,
231 edges = stats.total_edges,
232 "Memory system visualization: working={}, session={}, longterm={}, nodes={}, edges={}",
233 stats.working_memory_count,
234 stats.session_memory_count,
235 stats.longterm_memory_count,
236 stats.total_nodes,
237 stats.total_edges,
238 );
239 }
240}
241
/// Summary counts returned by [`MemoryGraph::stats`].
#[derive(Debug, Clone, Serialize)]
pub struct GraphStats {
    /// Number of nodes of any kind in the graph.
    pub total_nodes: usize,
    /// Number of edges of any kind in the graph.
    pub total_edges: usize,
    /// Number of memory nodes attributed to the working tier.
    pub working_memory_count: usize,
    /// Number of memory nodes attributed to the session tier.
    pub session_memory_count: usize,
    /// Number of memory nodes attributed to the long-term tier.
    pub longterm_memory_count: usize,
}
251
/// Logs memory lifecycle events through `tracing` and mirrors creations and promotions
/// into a [`MemoryGraph`].
pub struct MemoryLogger {
    /// Graph populated by `log_created` (nodes) and `log_promoted` (promotion edges).
    pub graph: MemoryGraph,
    /// Master switch: when false, the `log_*`, `show_visualization` and `export_dot`
    /// methods return immediately.
    enabled: bool,
}
257
258impl MemoryLogger {
259 pub fn new(enabled: bool) -> Self {
260 Self {
261 graph: MemoryGraph::new(),
262 enabled,
263 }
264 }
265
266 pub fn log_created(&mut self, memory: &Memory, tier: &str) {
268 if !self.enabled {
269 return;
270 }
271
272 debug!(
273 tier = tier.to_uppercase().as_str(),
274 importance = memory.importance(),
275 experience_type = ?memory.experience.experience_type,
276 "Memory created"
277 );
278
279 self.graph.add_memory(memory, tier);
280 }
281
282 pub fn log_accessed(&self, memory_id: &MemoryId, tier: &str) {
284 if !self.enabled {
285 return;
286 }
287
288 trace!(
289 tier = tier.to_uppercase().as_str(),
290 memory_id = %memory_id.0,
291 "Memory accessed"
292 );
293 }
294
295 pub fn log_promoted(&mut self, memory_id: &MemoryId, from: &str, to: &str, count: usize) {
297 if !self.enabled {
298 return;
299 }
300
301 debug!(
302 from = from.to_uppercase().as_str(),
303 to = to.to_uppercase().as_str(),
304 count,
305 "Memory tier promotion"
306 );
307
308 self.graph.log_promotion(from, to, memory_id);
309 }
310
311 pub fn log_compressed(
313 &self,
314 _memory_id: &MemoryId,
315 original_size: usize,
316 compressed_size: usize,
317 ) {
318 if !self.enabled {
319 return;
320 }
321
322 let ratio = (compressed_size as f32 / original_size as f32 * 100.0) as usize;
323 debug!(original_size, compressed_size, ratio, "Memory compressed");
324 }
325
326 pub fn log_retrieved(&self, query: &str, result_count: usize, sources: &[&str]) {
328 if !self.enabled {
329 return;
330 }
331
332 debug!(
333 query = %query.chars().take(50).collect::<String>(),
334 result_count,
335 sources = %sources.join(", "),
336 "Memory retrieved"
337 );
338 }
339
340 pub fn show_visualization(&self) {
342 if !self.enabled {
343 return;
344 }
345
346 self.graph.print_ascii_visualization();
347 }
348
349 pub fn export_dot(&self, path: &std::path::Path) -> anyhow::Result<()> {
351 if !self.enabled {
352 return Ok(());
353 }
354
355 let dot = self.graph.to_dot();
356 std::fs::write(path, dot)?;
357 info!(path = %path.display(), "Graph exported");
358 Ok(())
359 }
360
361 pub fn get_stats(&self) -> GraphStats {
363 self.graph.stats()
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{Experience, ExperienceType, Memory, MemoryId};

    /// Builds a conversation memory with a fresh random id for graph tests.
    fn create_test_memory() -> Memory {
        use uuid::Uuid;

        let experience = Experience {
            experience_type: ExperienceType::Conversation,
            content: "test content".to_string(),
            ..Default::default()
        };

        Memory::new(
            MemoryId(Uuid::new_v4()),
            experience,
            0.5,
            None,
            None,
            None,
            None,
        )
    }

    #[test]
    fn test_add_memory_with_valid_tiers() {
        let mut graph = MemoryGraph::new();
        let memory = create_test_memory();

        // The same memory in three tiers yields three distinct nodes.
        let idx1 = graph.add_memory(&memory, "working");
        let idx2 = graph.add_memory(&memory, "session");
        let idx3 = graph.add_memory(&memory, "longterm");

        assert_eq!(graph.graph.node_count(), 3);
        assert!(idx1 != idx2 && idx2 != idx3 && idx1 != idx3);
    }

    #[test]
    fn test_add_memory_with_invalid_tier_does_not_panic() {
        let mut graph = MemoryGraph::new();
        let memory = create_test_memory();

        let idx = graph.add_memory(&memory, "invalid_tier_name");

        assert_eq!(graph.graph.node_count(), 1);

        assert!(graph.graph.node_weight(idx).is_some());

        // Unknown tiers fall back to a working-memory node.
        let node = graph.graph.node_weight(idx).unwrap();
        match node {
            MemoryNode::WorkingMemory { .. } => {}
            _ => panic!("Expected WorkingMemory for invalid tier, got {node:?}"),
        }
    }

    #[test]
    fn test_add_memory_with_various_invalid_tiers() {
        let mut graph = MemoryGraph::new();
        let memory = create_test_memory();

        // Tier matching is exact and case-sensitive; each distinct string gets its own key.
        let invalid_tiers = [
            "",
            "Working",
            "WORKING",
            "long-term",
            "unknown",
            "123",
            "session_memory",
        ];

        for tier in invalid_tiers {
            let _ = graph.add_memory(&memory, tier);
        }

        assert_eq!(graph.graph.node_count(), 7);
    }

    #[test]
    fn test_add_memory_is_idempotent_per_tier() {
        let mut graph = MemoryGraph::new();
        let memory = create_test_memory();

        let first = graph.add_memory(&memory, "session");
        let second = graph.add_memory(&memory, "session");

        assert_eq!(first, second);
        assert_eq!(graph.graph.node_count(), 1);
    }

    #[test]
    fn test_log_promotion_adds_edge_only_when_both_tiers_exist() {
        let mut graph = MemoryGraph::new();
        let memory = create_test_memory();
        graph.add_memory(&memory, "working");
        graph.add_memory(&memory, "session");

        // No long-term node exists yet, so nothing is linked.
        graph.log_promotion("working", "longterm", &memory.id);
        assert_eq!(graph.graph.edge_count(), 0);

        graph.log_promotion("working", "session", &memory.id);
        assert_eq!(graph.graph.edge_count(), 1);
    }

    #[test]
    fn test_stats_counts_each_valid_tier() {
        let mut graph = MemoryGraph::new();
        let memory = create_test_memory();
        for tier in ["working", "session", "longterm"] {
            graph.add_memory(&memory, tier);
        }

        let stats = graph.stats();
        assert_eq!(stats.total_nodes, 3);
        assert_eq!(stats.total_edges, 0);
        assert_eq!(stats.working_memory_count, 1);
        assert_eq!(stats.session_memory_count, 1);
        assert_eq!(stats.longterm_memory_count, 1);
    }

    #[test]
    fn test_disabled_logger_records_nothing() {
        let mut logger = MemoryLogger::new(false);
        let memory = create_test_memory();

        logger.log_created(&memory, "working");

        assert_eq!(logger.get_stats().total_nodes, 0);
    }
}