Skip to main content

shodh_memory/handlers/
visualization.rs

1//! Visualization Handlers
2//!
3//! Handlers for brain state visualization and memory graph visualization.
4//! Includes live browser-based graph visualization with SSE updates.
5
6use axum::{
7    extract::{Path, Query, State},
8    response::{Html, Json},
9};
10use serde::{Deserialize, Serialize};
11
12use super::state::MultiUserMemoryManager;
13use crate::errors::{AppError, ValidationErrorExt};
14use crate::memory::GraphStats as VisualizationStats;
15use crate::validation;
16use std::sync::Arc;
17
/// Shared application state extracted by this module's handlers via
/// `State<AppState>`: the multi-user memory manager behind an `Arc`.
type AppState = Arc<MultiUserMemoryManager>;
19
/// Brain state response with memories organized by tier.
///
/// Response body of `GET /api/brain/{user_id}` (`get_brain_state`).
#[derive(Debug, Serialize)]
pub struct BrainStateResponse {
    /// Neurons built from the user's working memories.
    pub working_memory: Vec<MemoryNeuron>,
    /// Neurons built from the user's session memories.
    pub session_memory: Vec<MemoryNeuron>,
    /// Neurons for a sample of long-term memories (`get_brain_state` requests
    /// them with a limit of 50).
    pub longterm_memory: Vec<MemoryNeuron>,
    /// Counts and averages computed over the three lists above.
    pub stats: BrainStats,
}
28
/// Individual memory neuron for visualization.
///
/// One entry per memory returned by `get_brain_state`.
#[derive(Debug, Serialize)]
pub struct MemoryNeuron {
    /// The memory's ID rendered as a string.
    pub id: String,
    /// The first 100 characters (not bytes) of the memory's content.
    pub content_preview: String,
    /// Value of the memory's `activation()` when the response was built.
    pub activation: f32,
    /// Value of the memory's `importance()`.
    pub importance: f32,
    /// Tier the memory was read from: `"working"`, `"session"` or `"longterm"`.
    pub tier: String,
    /// Access count taken from the memory's metadata snapshot.
    pub access_count: u32,
    /// Creation timestamp formatted as RFC 3339.
    pub created_at: String,
}
40
/// Brain statistics.
///
/// Every value describes the neurons returned in the same
/// `BrainStateResponse`. Long-term memory is only sampled, so
/// `total_memories` is not necessarily the user's total number of memories.
#[derive(Debug, Serialize)]
pub struct BrainStats {
    /// `working_count + session_count + longterm_count`.
    pub total_memories: usize,
    /// Number of working-memory neurons returned.
    pub working_count: usize,
    /// Number of session-memory neurons returned.
    pub session_count: usize,
    /// Number of long-term neurons returned (a sample, not the full tier).
    pub longterm_count: usize,
    /// Mean `activation` over all returned neurons; 0.0 when there are none.
    pub avg_activation: f32,
    /// Mean `importance` over all returned neurons; 0.0 when there are none.
    pub avg_importance: f32,
}
51
52/// GET /api/brain/{user_id} - Get brain state visualization
53pub async fn get_brain_state(
54    State(state): State<AppState>,
55    Path(user_id): Path<String>,
56) -> Result<Json<BrainStateResponse>, AppError> {
57    validation::validate_user_id(&user_id).map_validation_err("user_id")?;
58
59    let memory = state
60        .get_user_memory(&user_id)
61        .map_err(AppError::Internal)?;
62
63    let memory_guard = memory.read();
64
65    let mut working_memory = Vec::new();
66    let mut session_memory = Vec::new();
67    let mut longterm_memory = Vec::new();
68    let mut total_activation = 0.0f32;
69    let mut total_importance = 0.0f32;
70
71    // Get working memory
72    for mem in memory_guard.get_working_memories() {
73        let neuron = MemoryNeuron {
74            id: mem.id.0.to_string(),
75            content_preview: mem.experience.content.chars().take(100).collect(),
76            activation: mem.activation(),
77            importance: mem.importance(),
78            tier: "working".to_string(),
79            access_count: mem.metadata_snapshot().access_count,
80            created_at: mem.created_at.to_rfc3339(),
81        };
82        total_activation += neuron.activation;
83        total_importance += neuron.importance;
84        working_memory.push(neuron);
85    }
86
87    // Get session memory
88    for mem in memory_guard.get_session_memories() {
89        let neuron = MemoryNeuron {
90            id: mem.id.0.to_string(),
91            content_preview: mem.experience.content.chars().take(100).collect(),
92            activation: mem.activation(),
93            importance: mem.importance(),
94            tier: "session".to_string(),
95            access_count: mem.metadata_snapshot().access_count,
96            created_at: mem.created_at.to_rfc3339(),
97        };
98        total_activation += neuron.activation;
99        total_importance += neuron.importance;
100        session_memory.push(neuron);
101    }
102
103    // Get longterm memory sample
104    let longterm_sample = memory_guard.get_longterm_memories(50).unwrap_or_default();
105    for mem in longterm_sample {
106        let neuron = MemoryNeuron {
107            id: mem.id.0.to_string(),
108            content_preview: mem.experience.content.chars().take(100).collect(),
109            activation: mem.activation(),
110            importance: mem.importance(),
111            tier: "longterm".to_string(),
112            access_count: mem.metadata_snapshot().access_count,
113            created_at: mem.created_at.to_rfc3339(),
114        };
115        total_activation += neuron.activation;
116        total_importance += neuron.importance;
117        longterm_memory.push(neuron);
118    }
119
120    let total_count = working_memory.len() + session_memory.len() + longterm_memory.len();
121    let stats = BrainStats {
122        total_memories: total_count,
123        working_count: working_memory.len(),
124        session_count: session_memory.len(),
125        longterm_count: longterm_memory.len(),
126        avg_activation: if total_count > 0 {
127            total_activation / total_count as f32
128        } else {
129            0.0
130        },
131        avg_importance: if total_count > 0 {
132            total_importance / total_count as f32
133        } else {
134            0.0
135        },
136    };
137
138    Ok(Json(BrainStateResponse {
139        working_memory,
140        session_memory,
141        longterm_memory,
142        stats,
143    }))
144}
145
146/// GET /api/visualization/{user_id}/stats - Get visualization statistics
147pub async fn get_visualization_stats(
148    State(state): State<AppState>,
149    Path(user_id): Path<String>,
150) -> Result<Json<VisualizationStats>, AppError> {
151    validation::validate_user_id(&user_id).map_validation_err("user_id")?;
152
153    let memory = state
154        .get_user_memory(&user_id)
155        .map_err(AppError::Internal)?;
156
157    let memory_guard = memory.read();
158    let stats = memory_guard.get_visualization_stats();
159
160    Ok(Json(stats))
161}
162
163/// GET /api/visualization/{user_id}/dot - Export graph as DOT format
164pub async fn get_visualization_dot(
165    State(state): State<AppState>,
166    Path(user_id): Path<String>,
167) -> Result<String, AppError> {
168    validation::validate_user_id(&user_id).map_validation_err("user_id")?;
169
170    let memory = state
171        .get_user_memory(&user_id)
172        .map_err(AppError::Internal)?;
173
174    let memory_guard = memory.read();
175    let dot = memory_guard.export_visualization_dot();
176
177    Ok(dot)
178}
179
/// Request to build visualization.
///
/// JSON body of `POST /api/visualization/build`.
#[derive(Debug, Deserialize)]
pub struct BuildVisualizationRequest {
    /// User whose visualization graph should be built; `build_visualization`
    /// validates it before use.
    pub user_id: String,
}
185
186/// POST /api/visualization/build - Build visualization graph
187pub async fn build_visualization(
188    State(state): State<AppState>,
189    Json(req): Json<BuildVisualizationRequest>,
190) -> Result<Json<VisualizationStats>, AppError> {
191    validation::validate_user_id(&req.user_id).map_validation_err("user_id")?;
192
193    let memory = state
194        .get_user_memory(&req.user_id)
195        .map_err(AppError::Internal)?;
196
197    let memory_guard = memory.read();
198    let stats = memory_guard
199        .build_visualization_graph()
200        .map_err(AppError::Internal)?;
201
202    Ok(Json(stats))
203}
204
/// Query parameters for graph view (`GET /graph/view`).
#[derive(Debug, Deserialize)]
pub struct GraphViewParams {
    /// User whose graph the page shows; `graph_view` substitutes `"default"`
    /// when the parameter is absent.
    /// NOTE(review): `graph_view` does not run `validate_user_id` on this
    /// value; it is only HTML-escaped before being embedded in the page.
    pub user_id: Option<String>,
}
210
/// Graph node for d3.js visualization.
#[derive(Debug, Serialize)]
pub struct GraphNode {
    /// Entity UUID or memory ID as a string; edges refer to nodes by this value.
    pub id: String,
    /// Display text: the entity's name, or a short preview of a memory's content.
    pub label: String,
    /// Node kind: `"entity"` or `"memory"`.
    pub node_type: String,
    /// Entity nodes: the entity's first label, or `"entity"` when it has none.
    /// Memory nodes: the memory tier (`"longterm"`).
    pub tier: String,
    /// 1.0 for entity nodes; the memory's importance for memory nodes.
    pub strength: f32,
    /// Display size: 10.0 for entity nodes, `6.0 + 8.0 * importance` for
    /// memory nodes (as set by `get_graph_data`).
    pub size: f32,
}
221
/// Graph edge for d3.js visualization.
#[derive(Debug, Serialize)]
pub struct GraphEdge {
    /// ID string of the source endpoint (entity UUID or memory ID).
    pub source: String,
    /// ID string of the target endpoint (entity UUID or memory ID).
    pub target: String,
    /// Relationship type name, or `"mentions"` for memory-to-entity links.
    pub edge_type: String,
    /// `"L1"`, `"L2"` or `"L3"`; memory-to-entity `"mentions"` links are tagged `"L2"`.
    pub tier: String,
    /// The relationship's effective strength; a fixed 0.5 for `"mentions"` links.
    pub strength: f32,
}
231
/// Graph data response for d3.js.
///
/// Response body of `GET /api/graph/data/{user_id}` (`get_graph_data`).
#[derive(Debug, Serialize)]
pub struct GraphDataResponse {
    /// Entity nodes followed by memory nodes.
    pub nodes: Vec<GraphNode>,
    /// Entity relationship edges followed by memory-to-entity `"mentions"` edges.
    pub edges: Vec<GraphEdge>,
    /// Summary counts for `nodes` and `edges`.
    pub stats: GraphDataStats,
}
239
/// Graph statistics.
#[derive(Debug, Serialize)]
pub struct GraphDataStats {
    /// Number of entries in `nodes` (entity and memory nodes).
    pub total_nodes: usize,
    /// Number of entries in `edges`, including memory-to-entity `"mentions"` links.
    pub total_edges: usize,
    /// Number of L1 (working) entity relationship edges.
    pub l1_edges: usize,
    /// Number of L2 (episodic) entity relationship edges. `"mentions"` links
    /// are tagged `"L2"` but are not counted here.
    pub l2_edges: usize,
    /// Number of L3 (semantic) entity relationship edges.
    pub l3_edges: usize,
}
249
250/// GET /graph/view - Serve interactive graph visualization HTML
251pub async fn graph_view(Query(params): Query<GraphViewParams>) -> Html<String> {
252    let user_id = params.user_id.unwrap_or_else(|| "default".to_string());
253    Html(generate_graph_html(&user_id))
254}
255
256/// GET /api/graph/data/{user_id} - Get graph data as JSON for d3.js
257pub async fn get_graph_data(
258    State(state): State<AppState>,
259    Path(user_id): Path<String>,
260) -> Result<Json<GraphDataResponse>, AppError> {
261    validation::validate_user_id(&user_id).map_validation_err("user_id")?;
262
263    let memory = state
264        .get_user_memory(&user_id)
265        .map_err(AppError::Internal)?;
266
267    let memory_guard = memory.read();
268    let graph = memory_guard
269        .graph_memory()
270        .ok_or_else(|| AppError::Internal(anyhow::anyhow!("Graph memory not initialized")))?;
271    let graph_guard = graph.read();
272
273    let mut nodes = Vec::new();
274    let mut edges = Vec::new();
275    let mut l1_count = 0;
276    let mut l2_count = 0;
277    let mut l3_count = 0;
278
279    // Get entities as nodes
280    if let Ok(entities) = graph_guard.get_all_entities() {
281        for entity in entities.iter().take(200) {
282            let tier_label = entity
283                .labels
284                .first()
285                .map(|l| l.as_str().to_string())
286                .unwrap_or_else(|| "entity".to_string());
287            nodes.push(GraphNode {
288                id: entity.uuid.to_string(),
289                label: entity.name.clone(),
290                node_type: "entity".to_string(),
291                tier: tier_label,
292                strength: 1.0,
293                size: 10.0,
294            });
295        }
296    }
297
298    // Get relationships as edges - sample from each tier for visibility
299    if let Ok(relationships) = graph_guard.get_all_relationships() {
300        use crate::graph_memory::EdgeTier;
301
302        // Separate by tier for proportional sampling
303        let l1_edges: Vec<_> = relationships
304            .iter()
305            .filter(|r| matches!(r.tier, EdgeTier::L1Working))
306            .take(200)
307            .collect();
308        let l2_edges: Vec<_> = relationships
309            .iter()
310            .filter(|r| matches!(r.tier, EdgeTier::L2Episodic))
311            .take(200)
312            .collect();
313        let l3_edges: Vec<_> = relationships
314            .iter()
315            .filter(|r| matches!(r.tier, EdgeTier::L3Semantic))
316            .take(200)
317            .collect();
318
319        // Add edges from each tier
320        for rel in l1_edges
321            .iter()
322            .chain(l2_edges.iter())
323            .chain(l3_edges.iter())
324        {
325            let tier_str = match rel.tier {
326                EdgeTier::L1Working => {
327                    l1_count += 1;
328                    "L1"
329                }
330                EdgeTier::L2Episodic => {
331                    l2_count += 1;
332                    "L2"
333                }
334                EdgeTier::L3Semantic => {
335                    l3_count += 1;
336                    "L3"
337                }
338            };
339
340            edges.push(GraphEdge {
341                source: rel.from_entity.to_string(),
342                target: rel.to_entity.to_string(),
343                edge_type: rel.relation_type.as_str().to_string(),
344                tier: tier_str.to_string(),
345                strength: rel.effective_strength(),
346            });
347        }
348    }
349
350    // Add memory nodes and connect them to their entities
351    let memories = memory_guard.get_longterm_memories(100).unwrap_or_default();
352    let entity_ids: std::collections::HashSet<String> =
353        nodes.iter().map(|n| n.id.clone()).collect();
354
355    for mem in memories {
356        let mem_id = mem.id.0.to_string();
357        nodes.push(GraphNode {
358            id: mem_id.clone(),
359            label: mem.experience.content.chars().take(30).collect::<String>() + "...",
360            node_type: "memory".to_string(),
361            tier: "longterm".to_string(),
362            strength: mem.importance(),
363            size: 6.0 + mem.importance() * 8.0,
364        });
365
366        // Connect memory to its entities
367        for entity_id in mem.entity_ids() {
368            let entity_id_str = entity_id.to_string();
369            if entity_ids.contains(&entity_id_str) {
370                edges.push(GraphEdge {
371                    source: mem_id.clone(),
372                    target: entity_id_str,
373                    edge_type: "mentions".to_string(),
374                    tier: "L2".to_string(),
375                    strength: 0.5,
376                });
377            }
378        }
379    }
380
381    Ok(Json(GraphDataResponse {
382        stats: GraphDataStats {
383            total_nodes: nodes.len(),
384            total_edges: edges.len(),
385            l1_edges: l1_count,
386            l2_edges: l2_count,
387            l3_edges: l3_count,
388        },
389        nodes,
390        edges,
391    }))
392}
393
394/// Generate the HTML page for graph visualization (includes 2D/3D toggle)
395fn generate_graph_html(user_id: &str) -> String {
396    let html = include_str!("graph_view.html");
397    // HTML-escape user_id to prevent reflected XSS via query parameter injection
398    let escaped = user_id
399        .replace('&', "&amp;")
400        .replace('<', "&lt;")
401        .replace('>', "&gt;")
402        .replace('"', "&quot;")
403        .replace('\'', "&#x27;");
404    html.replace("{{USER_ID}}", &escaped)
405}