// shodh_memory/memory/visualization.rs

//! Memory system visualization using petgraph.
//! Builds a graph of memory connections (tiered memories, experiences, contexts) as they are created and promoted, loosely resembling a neural network.
3
4#![allow(dead_code)]
5
6use crate::memory::{ExperienceType, Memory, MemoryId};
7use petgraph::dot::{Config, Dot};
8use petgraph::graph::{DiGraph, NodeIndex};
9use serde::Serialize;
10use std::collections::HashMap;
11use std::fmt;
12use tracing::{debug, info, trace};
13
14/// Node type in the memory graph
15#[derive(Debug, Clone)]
16pub enum MemoryNode {
17    WorkingMemory {
18        id: MemoryId,
19        importance: f32,
20    },
21    SessionMemory {
22        id: MemoryId,
23        importance: f32,
24    },
25    LongTermMemory {
26        id: MemoryId,
27        importance: f32,
28        compressed: bool,
29    },
30    Experience {
31        exp_type: ExperienceType,
32        content: String,
33    },
34    Context {
35        context_id: String,
36        decay: f32,
37    },
38}
39
40impl fmt::Display for MemoryNode {
41    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
42        match self {
43            MemoryNode::WorkingMemory { id: _, importance } => {
44                write!(f, "WM\\n{importance:.2}")
45            }
46            MemoryNode::SessionMemory { id: _, importance } => {
47                write!(f, "SM\\n{importance:.2}")
48            }
49            MemoryNode::LongTermMemory {
50                id: _,
51                importance,
52                compressed,
53            } => {
54                write!(
55                    f,
56                    "LTM\\n{:.2}{}",
57                    importance,
58                    if *compressed { "🗜️" } else { "" }
59                )
60            }
61            MemoryNode::Experience { exp_type, .. } => {
62                write!(f, "{exp_type:?}")
63            }
64            MemoryNode::Context {
65                context_id: _,
66                decay,
67            } => {
68                write!(f, "CTX\\n{decay:.2}")
69            }
70        }
71    }
72}
73
/// Relationship between two nodes in the memory graph.
#[derive(Debug, Clone)]
pub enum MemoryEdge {
    /// Tier promotion of one memory (working -> session -> long-term).
    Promotion,
    /// Semantic relatedness; the payload is the similarity score.
    SemanticSimilarity(f32),
    /// Temporal ordering: one event happened after the other.
    TemporalSuccession,
    /// Causal relationship: one event caused the other.
    CausalLink,
    /// Indirect relationship through a shared context.
    ContextRelation,
}

/// Compact symbolic label for an edge; similarity scores use two decimals.
impl fmt::Display for MemoryEdge {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let symbol = match self {
            // The only variant whose label depends on its payload.
            Self::SemanticSimilarity(score) => return write!(f, "~{score:.2}"),
            Self::Promotion => "→",
            Self::TemporalSuccession => "⏭",
            Self::CausalLink => "⚡",
            Self::ContextRelation => "⊕",
        };
        f.write_str(symbol)
    }
}
95
96/// Memory visualization graph
97pub struct MemoryGraph {
98    graph: DiGraph<MemoryNode, MemoryEdge>,
99    node_map: HashMap<String, NodeIndex>,
100}
101
102impl Default for MemoryGraph {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108impl MemoryGraph {
109    pub fn new() -> Self {
110        Self {
111            graph: DiGraph::new(),
112            node_map: HashMap::new(),
113        }
114    }
115
116    /// Add a memory to the graph
117    pub fn add_memory(&mut self, memory: &Memory, tier: &str) -> NodeIndex {
118        let key = format!("{}_{}", tier, memory.id.0);
119
120        if let Some(&idx) = self.node_map.get(&key) {
121            return idx;
122        }
123
124        let node = match tier {
125            "working" => MemoryNode::WorkingMemory {
126                id: memory.id.clone(),
127                importance: memory.importance(),
128            },
129            "session" => MemoryNode::SessionMemory {
130                id: memory.id.clone(),
131                importance: memory.importance(),
132            },
133            "longterm" => MemoryNode::LongTermMemory {
134                id: memory.id.clone(),
135                importance: memory.importance(),
136                compressed: memory.compressed,
137            },
138            _ => {
139                tracing::error!(
140                    "Invalid tier '{}' passed to add_memory for memory {}, defaulting to WorkingMemory",
141                    tier,
142                    memory.id.0
143                );
144                MemoryNode::WorkingMemory {
145                    id: memory.id.clone(),
146                    importance: memory.importance(),
147                }
148            }
149        };
150
151        let idx = self.graph.add_node(node);
152        self.node_map.insert(key, idx);
153        idx
154    }
155
156    /// Add an experience node
157    pub fn add_experience(&mut self, exp_type: ExperienceType, content: &str) -> NodeIndex {
158        let node = MemoryNode::Experience {
159            exp_type,
160            content: content.chars().take(50).collect(),
161        };
162        self.graph.add_node(node)
163    }
164
165    /// Add a context node
166    pub fn add_context(&mut self, context_id: &str, decay: f32) -> NodeIndex {
167        let node = MemoryNode::Context {
168            context_id: context_id.to_string(),
169            decay,
170        };
171        self.graph.add_node(node)
172    }
173
174    /// Add an edge between nodes
175    pub fn add_edge(&mut self, from: NodeIndex, to: NodeIndex, edge_type: MemoryEdge) {
176        self.graph.add_edge(from, to, edge_type);
177    }
178
179    /// Visualize memory promotion (working -> session -> longterm)
180    pub fn log_promotion(&mut self, from_tier: &str, to_tier: &str, memory_id: &MemoryId) {
181        let from_key = format!("{}_{}", from_tier, memory_id.0);
182        let to_key = format!("{}_{}", to_tier, memory_id.0);
183
184        if let (Some(&from_idx), Some(&to_idx)) =
185            (self.node_map.get(&from_key), self.node_map.get(&to_key))
186        {
187            self.add_edge(from_idx, to_idx, MemoryEdge::Promotion);
188            debug!(
189                from = from_tier.to_uppercase().as_str(),
190                to = to_tier.to_uppercase().as_str(),
191                "Graph tier promotion"
192            );
193        }
194    }
195
196    /// Export graph as DOT format for Graphviz
197    pub fn to_dot(&self) -> String {
198        format!(
199            "{:?}",
200            Dot::with_config(&self.graph, &[Config::EdgeNoLabel])
201        )
202    }
203
204    /// Get statistics about the graph
205    pub fn stats(&self) -> GraphStats {
206        GraphStats {
207            total_nodes: self.graph.node_count(),
208            total_edges: self.graph.edge_count(),
209            working_memory_count: self.count_tier("working"),
210            session_memory_count: self.count_tier("session"),
211            longterm_memory_count: self.count_tier("longterm"),
212        }
213    }
214
215    fn count_tier(&self, tier: &str) -> usize {
216        self.node_map
217            .keys()
218            .filter(|k| k.starts_with(&format!("{tier}_")))
219            .count()
220    }
221
222    /// Print ASCII visualization of current memory state
223    pub fn print_ascii_visualization(&self) {
224        let stats = self.stats();
225
226        info!(
227            working = stats.working_memory_count,
228            session = stats.session_memory_count,
229            longterm = stats.longterm_memory_count,
230            nodes = stats.total_nodes,
231            edges = stats.total_edges,
232            "Memory system visualization: working={}, session={}, longterm={}, nodes={}, edges={}",
233            stats.working_memory_count,
234            stats.session_memory_count,
235            stats.longterm_memory_count,
236            stats.total_nodes,
237            stats.total_edges,
238        );
239    }
240}
241
242/// Graph statistics
243#[derive(Debug, Clone, Serialize)]
244pub struct GraphStats {
245    pub total_nodes: usize,
246    pub total_edges: usize,
247    pub working_memory_count: usize,
248    pub session_memory_count: usize,
249    pub longterm_memory_count: usize,
250}
251
252/// Logger for memory operations
253pub struct MemoryLogger {
254    pub graph: MemoryGraph,
255    enabled: bool,
256}
257
258impl MemoryLogger {
259    pub fn new(enabled: bool) -> Self {
260        Self {
261            graph: MemoryGraph::new(),
262            enabled,
263        }
264    }
265
266    /// Log memory creation
267    pub fn log_created(&mut self, memory: &Memory, tier: &str) {
268        if !self.enabled {
269            return;
270        }
271
272        debug!(
273            tier = tier.to_uppercase().as_str(),
274            importance = memory.importance(),
275            experience_type = ?memory.experience.experience_type,
276            "Memory created"
277        );
278
279        self.graph.add_memory(memory, tier);
280    }
281
282    /// Log memory access
283    pub fn log_accessed(&self, memory_id: &MemoryId, tier: &str) {
284        if !self.enabled {
285            return;
286        }
287
288        trace!(
289            tier = tier.to_uppercase().as_str(),
290            memory_id = %memory_id.0,
291            "Memory accessed"
292        );
293    }
294
295    /// Log memory promotion
296    pub fn log_promoted(&mut self, memory_id: &MemoryId, from: &str, to: &str, count: usize) {
297        if !self.enabled {
298            return;
299        }
300
301        debug!(
302            from = from.to_uppercase().as_str(),
303            to = to.to_uppercase().as_str(),
304            count,
305            "Memory tier promotion"
306        );
307
308        self.graph.log_promotion(from, to, memory_id);
309    }
310
311    /// Log compression
312    pub fn log_compressed(
313        &self,
314        _memory_id: &MemoryId,
315        original_size: usize,
316        compressed_size: usize,
317    ) {
318        if !self.enabled {
319            return;
320        }
321
322        let ratio = (compressed_size as f32 / original_size as f32 * 100.0) as usize;
323        debug!(original_size, compressed_size, ratio, "Memory compressed");
324    }
325
326    /// Log retrieval
327    pub fn log_retrieved(&self, query: &str, result_count: usize, sources: &[&str]) {
328        if !self.enabled {
329            return;
330        }
331
332        debug!(
333            query = %query.chars().take(50).collect::<String>(),
334            result_count,
335            sources = %sources.join(", "),
336            "Memory retrieved"
337        );
338    }
339
340    /// Show visualization
341    pub fn show_visualization(&self) {
342        if !self.enabled {
343            return;
344        }
345
346        self.graph.print_ascii_visualization();
347    }
348
349    /// Export graph
350    pub fn export_dot(&self, path: &std::path::Path) -> anyhow::Result<()> {
351        if !self.enabled {
352            return Ok(());
353        }
354
355        let dot = self.graph.to_dot();
356        std::fs::write(path, dot)?;
357        info!(path = %path.display(), "Graph exported");
358        Ok(())
359    }
360
361    /// Get graph statistics
362    pub fn get_stats(&self) -> GraphStats {
363        self.graph.stats()
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{Experience, ExperienceType, Memory, MemoryId};

    /// Builds a fresh conversation memory (importance 0.5) with a random id.
    fn create_test_memory() -> Memory {
        use uuid::Uuid;

        let experience = Experience {
            experience_type: ExperienceType::Conversation,
            content: "test content".to_string(),
            ..Default::default()
        };
        let id = MemoryId(Uuid::new_v4());
        let importance = 0.5;

        // agent_id, run_id, actor_id and created_at are all left as None.
        Memory::new(id, experience, importance, None, None, None, None)
    }

    #[test]
    fn test_add_memory_with_valid_tiers() {
        let memory = create_test_memory();
        let mut graph = MemoryGraph::new();

        // One node per recognised tier name, each at a distinct index.
        let indices: Vec<_> = ["working", "session", "longterm"]
            .iter()
            .map(|tier| graph.add_memory(&memory, tier))
            .collect();

        assert_eq!(graph.graph.node_count(), 3);
        assert_ne!(indices[0], indices[1]);
        assert_ne!(indices[1], indices[2]);
        assert_ne!(indices[0], indices[2]);
    }

    #[test]
    fn test_add_memory_with_invalid_tier_does_not_panic() {
        let memory = create_test_memory();
        let mut graph = MemoryGraph::new();

        // An unknown tier must be logged and fall back to WorkingMemory, not panic.
        let idx = graph.add_memory(&memory, "invalid_tier_name");

        assert_eq!(graph.graph.node_count(), 1);
        let node = graph.graph.node_weight(idx).unwrap();
        assert!(
            matches!(node, MemoryNode::WorkingMemory { .. }),
            "Expected WorkingMemory for invalid tier, got {node:?}"
        );
    }

    #[test]
    fn test_add_memory_with_various_invalid_tiers() {
        let memory = create_test_memory();
        let mut graph = MemoryGraph::new();

        // Tier matching is exact and case-sensitive; none of these may panic.
        let invalid_tiers = [
            "",
            "Working",
            "WORKING",
            "long-term",
            "unknown",
            "123",
            "session_memory",
        ];
        for &tier in invalid_tiers.iter() {
            graph.add_memory(&memory, tier);
        }

        // Each distinct tier string is its own key, so each yields its own node.
        assert_eq!(graph.graph.node_count(), invalid_tiers.len());
    }
}