1use axum::{
7 extract::{Path, Query, State},
8 response::{Html, Json},
9};
10use serde::{Deserialize, Serialize};
11
12use super::state::MultiUserMemoryManager;
13use crate::errors::{AppError, ValidationErrorExt};
14use crate::memory::GraphStats as VisualizationStats;
15use crate::validation;
16use std::sync::Arc;
17
18type AppState = Arc<MultiUserMemoryManager>;
19
20#[derive(Debug, Serialize)]
22pub struct BrainStateResponse {
23 pub working_memory: Vec<MemoryNeuron>,
24 pub session_memory: Vec<MemoryNeuron>,
25 pub longterm_memory: Vec<MemoryNeuron>,
26 pub stats: BrainStats,
27}
28
29#[derive(Debug, Serialize)]
31pub struct MemoryNeuron {
32 pub id: String,
33 pub content_preview: String,
34 pub activation: f32,
35 pub importance: f32,
36 pub tier: String,
37 pub access_count: u32,
38 pub created_at: String,
39}
40
41#[derive(Debug, Serialize)]
43pub struct BrainStats {
44 pub total_memories: usize,
45 pub working_count: usize,
46 pub session_count: usize,
47 pub longterm_count: usize,
48 pub avg_activation: f32,
49 pub avg_importance: f32,
50}
51
52pub async fn get_brain_state(
54 State(state): State<AppState>,
55 Path(user_id): Path<String>,
56) -> Result<Json<BrainStateResponse>, AppError> {
57 validation::validate_user_id(&user_id).map_validation_err("user_id")?;
58
59 let memory = state
60 .get_user_memory(&user_id)
61 .map_err(AppError::Internal)?;
62
63 let memory_guard = memory.read();
64
65 let mut working_memory = Vec::new();
66 let mut session_memory = Vec::new();
67 let mut longterm_memory = Vec::new();
68 let mut total_activation = 0.0f32;
69 let mut total_importance = 0.0f32;
70
71 for mem in memory_guard.get_working_memories() {
73 let neuron = MemoryNeuron {
74 id: mem.id.0.to_string(),
75 content_preview: mem.experience.content.chars().take(100).collect(),
76 activation: mem.activation(),
77 importance: mem.importance(),
78 tier: "working".to_string(),
79 access_count: mem.metadata_snapshot().access_count,
80 created_at: mem.created_at.to_rfc3339(),
81 };
82 total_activation += neuron.activation;
83 total_importance += neuron.importance;
84 working_memory.push(neuron);
85 }
86
87 for mem in memory_guard.get_session_memories() {
89 let neuron = MemoryNeuron {
90 id: mem.id.0.to_string(),
91 content_preview: mem.experience.content.chars().take(100).collect(),
92 activation: mem.activation(),
93 importance: mem.importance(),
94 tier: "session".to_string(),
95 access_count: mem.metadata_snapshot().access_count,
96 created_at: mem.created_at.to_rfc3339(),
97 };
98 total_activation += neuron.activation;
99 total_importance += neuron.importance;
100 session_memory.push(neuron);
101 }
102
103 let longterm_sample = memory_guard.get_longterm_memories(50).unwrap_or_default();
105 for mem in longterm_sample {
106 let neuron = MemoryNeuron {
107 id: mem.id.0.to_string(),
108 content_preview: mem.experience.content.chars().take(100).collect(),
109 activation: mem.activation(),
110 importance: mem.importance(),
111 tier: "longterm".to_string(),
112 access_count: mem.metadata_snapshot().access_count,
113 created_at: mem.created_at.to_rfc3339(),
114 };
115 total_activation += neuron.activation;
116 total_importance += neuron.importance;
117 longterm_memory.push(neuron);
118 }
119
120 let total_count = working_memory.len() + session_memory.len() + longterm_memory.len();
121 let stats = BrainStats {
122 total_memories: total_count,
123 working_count: working_memory.len(),
124 session_count: session_memory.len(),
125 longterm_count: longterm_memory.len(),
126 avg_activation: if total_count > 0 {
127 total_activation / total_count as f32
128 } else {
129 0.0
130 },
131 avg_importance: if total_count > 0 {
132 total_importance / total_count as f32
133 } else {
134 0.0
135 },
136 };
137
138 Ok(Json(BrainStateResponse {
139 working_memory,
140 session_memory,
141 longterm_memory,
142 stats,
143 }))
144}
145
146pub async fn get_visualization_stats(
148 State(state): State<AppState>,
149 Path(user_id): Path<String>,
150) -> Result<Json<VisualizationStats>, AppError> {
151 validation::validate_user_id(&user_id).map_validation_err("user_id")?;
152
153 let memory = state
154 .get_user_memory(&user_id)
155 .map_err(AppError::Internal)?;
156
157 let memory_guard = memory.read();
158 let stats = memory_guard.get_visualization_stats();
159
160 Ok(Json(stats))
161}
162
163pub async fn get_visualization_dot(
165 State(state): State<AppState>,
166 Path(user_id): Path<String>,
167) -> Result<String, AppError> {
168 validation::validate_user_id(&user_id).map_validation_err("user_id")?;
169
170 let memory = state
171 .get_user_memory(&user_id)
172 .map_err(AppError::Internal)?;
173
174 let memory_guard = memory.read();
175 let dot = memory_guard.export_visualization_dot();
176
177 Ok(dot)
178}
179
180#[derive(Debug, Deserialize)]
182pub struct BuildVisualizationRequest {
183 pub user_id: String,
184}
185
186pub async fn build_visualization(
188 State(state): State<AppState>,
189 Json(req): Json<BuildVisualizationRequest>,
190) -> Result<Json<VisualizationStats>, AppError> {
191 validation::validate_user_id(&req.user_id).map_validation_err("user_id")?;
192
193 let memory = state
194 .get_user_memory(&req.user_id)
195 .map_err(AppError::Internal)?;
196
197 let memory_guard = memory.read();
198 let stats = memory_guard
199 .build_visualization_graph()
200 .map_err(AppError::Internal)?;
201
202 Ok(Json(stats))
203}
204
205#[derive(Debug, Deserialize)]
207pub struct GraphViewParams {
208 pub user_id: Option<String>,
209}
210
211#[derive(Debug, Serialize)]
213pub struct GraphNode {
214 pub id: String,
215 pub label: String,
216 pub node_type: String, pub tier: String, pub strength: f32,
219 pub size: f32,
220}
221
222#[derive(Debug, Serialize)]
224pub struct GraphEdge {
225 pub source: String,
226 pub target: String,
227 pub edge_type: String,
228 pub tier: String, pub strength: f32,
230}
231
232#[derive(Debug, Serialize)]
234pub struct GraphDataResponse {
235 pub nodes: Vec<GraphNode>,
236 pub edges: Vec<GraphEdge>,
237 pub stats: GraphDataStats,
238}
239
240#[derive(Debug, Serialize)]
242pub struct GraphDataStats {
243 pub total_nodes: usize,
244 pub total_edges: usize,
245 pub l1_edges: usize,
246 pub l2_edges: usize,
247 pub l3_edges: usize,
248}
249
250pub async fn graph_view(Query(params): Query<GraphViewParams>) -> Html<String> {
252 let user_id = params.user_id.unwrap_or_else(|| "default".to_string());
253 Html(generate_graph_html(&user_id))
254}
255
256pub async fn get_graph_data(
258 State(state): State<AppState>,
259 Path(user_id): Path<String>,
260) -> Result<Json<GraphDataResponse>, AppError> {
261 validation::validate_user_id(&user_id).map_validation_err("user_id")?;
262
263 let memory = state
264 .get_user_memory(&user_id)
265 .map_err(AppError::Internal)?;
266
267 let memory_guard = memory.read();
268 let graph = memory_guard
269 .graph_memory()
270 .ok_or_else(|| AppError::Internal(anyhow::anyhow!("Graph memory not initialized")))?;
271 let graph_guard = graph.read();
272
273 let mut nodes = Vec::new();
274 let mut edges = Vec::new();
275 let mut l1_count = 0;
276 let mut l2_count = 0;
277 let mut l3_count = 0;
278
279 if let Ok(entities) = graph_guard.get_all_entities() {
281 for entity in entities.iter().take(200) {
282 let tier_label = entity
283 .labels
284 .first()
285 .map(|l| l.as_str().to_string())
286 .unwrap_or_else(|| "entity".to_string());
287 nodes.push(GraphNode {
288 id: entity.uuid.to_string(),
289 label: entity.name.clone(),
290 node_type: "entity".to_string(),
291 tier: tier_label,
292 strength: 1.0,
293 size: 10.0,
294 });
295 }
296 }
297
298 if let Ok(relationships) = graph_guard.get_all_relationships() {
300 use crate::graph_memory::EdgeTier;
301
302 let l1_edges: Vec<_> = relationships
304 .iter()
305 .filter(|r| matches!(r.tier, EdgeTier::L1Working))
306 .take(200)
307 .collect();
308 let l2_edges: Vec<_> = relationships
309 .iter()
310 .filter(|r| matches!(r.tier, EdgeTier::L2Episodic))
311 .take(200)
312 .collect();
313 let l3_edges: Vec<_> = relationships
314 .iter()
315 .filter(|r| matches!(r.tier, EdgeTier::L3Semantic))
316 .take(200)
317 .collect();
318
319 for rel in l1_edges
321 .iter()
322 .chain(l2_edges.iter())
323 .chain(l3_edges.iter())
324 {
325 let tier_str = match rel.tier {
326 EdgeTier::L1Working => {
327 l1_count += 1;
328 "L1"
329 }
330 EdgeTier::L2Episodic => {
331 l2_count += 1;
332 "L2"
333 }
334 EdgeTier::L3Semantic => {
335 l3_count += 1;
336 "L3"
337 }
338 };
339
340 edges.push(GraphEdge {
341 source: rel.from_entity.to_string(),
342 target: rel.to_entity.to_string(),
343 edge_type: rel.relation_type.as_str().to_string(),
344 tier: tier_str.to_string(),
345 strength: rel.effective_strength(),
346 });
347 }
348 }
349
350 let memories = memory_guard.get_longterm_memories(100).unwrap_or_default();
352 let entity_ids: std::collections::HashSet<String> =
353 nodes.iter().map(|n| n.id.clone()).collect();
354
355 for mem in memories {
356 let mem_id = mem.id.0.to_string();
357 nodes.push(GraphNode {
358 id: mem_id.clone(),
359 label: mem.experience.content.chars().take(30).collect::<String>() + "...",
360 node_type: "memory".to_string(),
361 tier: "longterm".to_string(),
362 strength: mem.importance(),
363 size: 6.0 + mem.importance() * 8.0,
364 });
365
366 for entity_id in mem.entity_ids() {
368 let entity_id_str = entity_id.to_string();
369 if entity_ids.contains(&entity_id_str) {
370 edges.push(GraphEdge {
371 source: mem_id.clone(),
372 target: entity_id_str,
373 edge_type: "mentions".to_string(),
374 tier: "L2".to_string(),
375 strength: 0.5,
376 });
377 }
378 }
379 }
380
381 Ok(Json(GraphDataResponse {
382 stats: GraphDataStats {
383 total_nodes: nodes.len(),
384 total_edges: edges.len(),
385 l1_edges: l1_count,
386 l2_edges: l2_count,
387 l3_edges: l3_count,
388 },
389 nodes,
390 edges,
391 }))
392}
393
394fn generate_graph_html(user_id: &str) -> String {
396 let html = include_str!("graph_view.html");
397 let escaped = user_id
399 .replace('&', "&")
400 .replace('<', "<")
401 .replace('>', ">")
402 .replace('"', """)
403 .replace('\'', "'");
404 html.replace("{{USER_ID}}", &escaped)
405}