Skip to main content

shodh_memory/handlers/
graph.rs

1//! Knowledge Graph Handlers
2//!
3//! Handlers for advanced knowledge graph operations including traversal,
4//! entity management, and memory universe visualization.
5
6use axum::{
7    extract::{Path, State},
8    response::Json,
9};
10use serde::Deserialize;
11use tracing::info;
12
13use super::state::MultiUserMemoryManager;
14use super::types::MemoryEvent;
15use crate::errors::{AppError, ValidationErrorExt};
16use crate::graph_memory::{EntityNode, EpisodicNode, GraphStats, GraphTraversal, MemoryUniverse};
17use crate::memory::{Experience, MemoryId};
18use crate::validation;
19use std::sync::Arc;
20
/// Shared handler state: a reference-counted [`MultiUserMemoryManager`] that
/// every handler in this module receives through axum's `State` extractor.
type AppState = Arc<MultiUserMemoryManager>;
22
23/// GET /api/graph/{user_id}/stats - Get graph statistics for a user
24pub async fn get_graph_stats(
25    State(state): State<AppState>,
26    Path(user_id): Path<String>,
27) -> Result<Json<GraphStats>, AppError> {
28    validation::validate_user_id(&user_id).map_validation_err("user_id")?;
29
30    let stats = state
31        .get_user_graph_stats(&user_id)
32        .map_err(AppError::Internal)?;
33
34    Ok(Json(stats))
35}
36
37/// Request to find an entity
38#[derive(Debug, Deserialize)]
39pub struct FindEntityRequest {
40    pub user_id: String,
41    pub entity_name: String,
42}
43
44/// POST /api/graph/entity/find - Find an entity by name
45pub async fn find_entity(
46    State(state): State<AppState>,
47    Json(req): Json<FindEntityRequest>,
48) -> Result<Json<Option<EntityNode>>, AppError> {
49    validation::validate_user_id(&req.user_id).map_validation_err("user_id")?;
50
51    let graph = state
52        .get_user_graph(&req.user_id)
53        .map_err(AppError::Internal)?;
54
55    let graph_guard = graph.read();
56    let entity = graph_guard
57        .find_entity_by_name(&req.entity_name)
58        .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
59
60    Ok(Json(entity))
61}
62
63/// Request to traverse graph
64#[derive(Debug, Deserialize)]
65pub struct TraverseGraphRequest {
66    pub user_id: String,
67    pub entity_name: String,
68    pub max_depth: Option<usize>,
69}
70
71/// POST /api/graph/traverse - Traverse graph from an entity
72pub async fn traverse_graph(
73    State(state): State<AppState>,
74    Json(req): Json<TraverseGraphRequest>,
75) -> Result<Json<GraphTraversal>, AppError> {
76    validation::validate_user_id(&req.user_id).map_validation_err("user_id")?;
77
78    let graph = state
79        .get_user_graph(&req.user_id)
80        .map_err(AppError::Internal)?;
81
82    let graph_guard = graph.read();
83
84    let entity = graph_guard
85        .find_entity_by_name(&req.entity_name)
86        .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?
87        .ok_or_else(|| {
88            AppError::MemoryNotFound(format!("Entity not found: {}", req.entity_name))
89        })?;
90
91    let max_depth = req.max_depth.unwrap_or(2);
92    let traversal = graph_guard
93        .traverse_from_entity(&entity.uuid, max_depth)
94        .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
95
96    Ok(Json(traversal))
97}
98
99/// Request to get an episode
100#[derive(Debug, Deserialize)]
101pub struct GetEpisodeRequest {
102    pub user_id: String,
103    pub episode_uuid: String,
104}
105
106/// POST /api/graph/episode/get - Get an episodic node by UUID
107pub async fn get_episode(
108    State(state): State<AppState>,
109    Json(req): Json<GetEpisodeRequest>,
110) -> Result<Json<Option<EpisodicNode>>, AppError> {
111    validation::validate_user_id(&req.user_id).map_validation_err("user_id")?;
112
113    let graph = state
114        .get_user_graph(&req.user_id)
115        .map_err(AppError::Internal)?;
116
117    let graph_guard = graph.read();
118
119    let episode_uuid =
120        uuid::Uuid::parse_str(&req.episode_uuid).map_err(|_| AppError::InvalidInput {
121            field: "episode_uuid".to_string(),
122            reason: "Invalid UUID format".to_string(),
123        })?;
124
125    let episode = graph_guard
126        .get_episode(&episode_uuid)
127        .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
128
129    Ok(Json(episode))
130}
131
132/// Request to get all entities
133#[derive(Debug, Deserialize)]
134pub struct GetAllEntitiesRequest {
135    pub user_id: String,
136    pub limit: Option<usize>,
137}
138
139/// POST /api/graph/entities/all - Get all entities
140pub async fn get_all_entities(
141    State(state): State<AppState>,
142    Json(req): Json<GetAllEntitiesRequest>,
143) -> Result<Json<serde_json::Value>, AppError> {
144    validation::validate_user_id(&req.user_id).map_validation_err("user_id")?;
145
146    let graph = state
147        .get_user_graph(&req.user_id)
148        .map_err(AppError::Internal)?;
149    let graph_guard = graph.read();
150
151    let entities = graph_guard
152        .get_all_entities()
153        .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
154
155    let limit = req.limit.unwrap_or(100);
156    let entities: Vec<_> = entities.into_iter().take(limit).collect();
157    let count = entities.len();
158
159    Ok(Json(serde_json::json!({
160        "entities": entities,
161        "count": count
162    })))
163}
164
165/// GET /api/graph/{user_id}/universe - Get Memory Universe visualization
166pub async fn get_memory_universe(
167    State(state): State<AppState>,
168    Path(user_id): Path<String>,
169) -> Result<Json<MemoryUniverse>, AppError> {
170    validation::validate_user_id(&user_id).map_validation_err("user_id")?;
171
172    let graph = state.get_user_graph(&user_id).map_err(AppError::Internal)?;
173
174    let graph_guard = graph.read();
175    let universe = graph_guard
176        .get_universe()
177        .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
178
179    Ok(Json(universe))
180}
181
182/// DELETE /api/graph/{user_id}/clear - Clear all graph data for a user
183pub async fn clear_user_graph(
184    State(state): State<AppState>,
185    Path(user_id): Path<String>,
186) -> Result<Json<serde_json::Value>, AppError> {
187    validation::validate_user_id(&user_id).map_validation_err("user_id")?;
188
189    let graph = state.get_user_graph(&user_id).map_err(AppError::Internal)?;
190    let graph_guard = graph.write();
191
192    let (entities, relationships, episodes) = graph_guard
193        .clear_all()
194        .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
195
196    info!(
197        "Cleared graph for user {}: {} entities, {} relationships, {} episodes",
198        user_id, entities, relationships, episodes
199    );
200
201    state.emit_event(MemoryEvent {
202        event_type: "GRAPH_CLEAR".to_string(),
203        timestamp: chrono::Utc::now(),
204        user_id: user_id.clone(),
205        memory_id: Some(format!("{}/{}/{}", entities, relationships, episodes)),
206        content_preview: Some(format!(
207            "Cleared {} entities, {} relationships, {} episodes",
208            entities, relationships, episodes
209        )),
210        memory_type: Some("graph".to_string()),
211        importance: None,
212        count: Some(entities + relationships + episodes),
213        results: None,
214    });
215
216    Ok(Json(serde_json::json!({
217        "cleared": {
218            "entities": entities,
219            "relationships": relationships,
220            "episodes": episodes
221        }
222    })))
223}
224
225/// POST /api/graph/{user_id}/rebuild - Rebuild graph from all existing memories
226pub async fn rebuild_user_graph(
227    State(state): State<AppState>,
228    Path(user_id): Path<String>,
229) -> Result<Json<serde_json::Value>, AppError> {
230    validation::validate_user_id(&user_id).map_validation_err("user_id")?;
231
232    // First, clear existing graph data
233    let graph = state.get_user_graph(&user_id).map_err(AppError::Internal)?;
234    {
235        let graph_guard = graph.write();
236        let _ = graph_guard.clear_all();
237    }
238
239    // Get all memories for this user
240    let memory_sys = state
241        .get_user_memory(&user_id)
242        .map_err(AppError::Internal)?;
243    let memories: Vec<(MemoryId, Experience)> = {
244        let memory_guard = memory_sys.read();
245        memory_guard
246            .get_all_memories()
247            .map_err(AppError::Internal)?
248            .into_iter()
249            .map(|m| (m.id.clone(), m.experience.clone()))
250            .collect()
251    };
252
253    let total_memories = memories.len();
254    let mut processed = 0;
255
256    // Re-process each memory through entity extraction
257    for (memory_id, experience) in memories {
258        if let Err(e) = state.process_experience_into_graph(&user_id, &experience, &memory_id) {
259            tracing::debug!("Failed to process memory {}: {}", memory_id.0, e);
260        } else {
261            processed += 1;
262        }
263    }
264
265    // Get final stats
266    let stats = state
267        .get_user_graph_stats(&user_id)
268        .map_err(AppError::Internal)?;
269    let entities_created = stats.entity_count;
270    let relationships_created = stats.relationship_count;
271
272    info!(
273        "Rebuilt graph for user {}: processed {}/{} memories, created {} entities, {} relationships",
274        user_id, processed, total_memories, entities_created, relationships_created
275    );
276
277    state.emit_event(MemoryEvent {
278        event_type: "GRAPH_REBUILD".to_string(),
279        timestamp: chrono::Utc::now(),
280        user_id: user_id.clone(),
281        memory_id: None,
282        content_preview: Some(format!(
283            "Rebuilt: {} memories -> {} entities, {} relationships",
284            processed, entities_created, relationships_created
285        )),
286        memory_type: Some("graph".to_string()),
287        importance: None,
288        count: Some(entities_created + relationships_created),
289        results: None,
290    });
291
292    Ok(Json(serde_json::json!({
293        "success": true,
294        "processed_memories": processed,
295        "total_memories": total_memories,
296        "entities_created": entities_created,
297        "relationships_created": relationships_created
298    })))
299}
300
301/// Request to invalidate a relationship
302#[derive(Debug, Deserialize)]
303pub struct InvalidateRelationshipRequest {
304    pub user_id: String,
305    pub relationship_uuid: String,
306}
307
308/// POST /api/graph/relationship/invalidate - Invalidate a relationship edge
309pub async fn invalidate_relationship(
310    State(state): State<AppState>,
311    Json(req): Json<InvalidateRelationshipRequest>,
312) -> Result<Json<serde_json::Value>, AppError> {
313    validation::validate_user_id(&req.user_id).map_validation_err("user_id")?;
314
315    let graph = state
316        .get_user_graph(&req.user_id)
317        .map_err(AppError::Internal)?;
318
319    let graph_guard = graph.write();
320
321    let rel_uuid =
322        uuid::Uuid::parse_str(&req.relationship_uuid).map_err(|_| AppError::InvalidInput {
323            field: "relationship_uuid".to_string(),
324            reason: "Invalid UUID format".to_string(),
325        })?;
326
327    graph_guard
328        .invalidate_relationship(&rel_uuid)
329        .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
330
331    state.emit_event(MemoryEvent {
332        event_type: "EDGE_INVALIDATE".to_string(),
333        timestamp: chrono::Utc::now(),
334        user_id: req.user_id.clone(),
335        memory_id: Some(req.relationship_uuid.clone()),
336        content_preview: Some("Relationship invalidated".to_string()),
337        memory_type: Some("graph".to_string()),
338        importance: None,
339        count: None,
340        results: None,
341    });
342
343    Ok(Json(serde_json::json!({
344        "success": true,
345        "message": "Relationship invalidated"
346    })))
347}