1use axum::{
7 extract::{Path, State},
8 response::Json,
9};
10use serde::Deserialize;
11use tracing::info;
12
13use super::state::MultiUserMemoryManager;
14use super::types::MemoryEvent;
15use crate::errors::{AppError, ValidationErrorExt};
16use crate::graph_memory::{EntityNode, EpisodicNode, GraphStats, GraphTraversal, MemoryUniverse};
17use crate::memory::{Experience, MemoryId};
18use crate::validation;
19use std::sync::Arc;
20
21type AppState = Arc<MultiUserMemoryManager>;
22
23pub async fn get_graph_stats(
25 State(state): State<AppState>,
26 Path(user_id): Path<String>,
27) -> Result<Json<GraphStats>, AppError> {
28 validation::validate_user_id(&user_id).map_validation_err("user_id")?;
29
30 let stats = state
31 .get_user_graph_stats(&user_id)
32 .map_err(AppError::Internal)?;
33
34 Ok(Json(stats))
35}
36
37#[derive(Debug, Deserialize)]
39pub struct FindEntityRequest {
40 pub user_id: String,
41 pub entity_name: String,
42}
43
44pub async fn find_entity(
46 State(state): State<AppState>,
47 Json(req): Json<FindEntityRequest>,
48) -> Result<Json<Option<EntityNode>>, AppError> {
49 validation::validate_user_id(&req.user_id).map_validation_err("user_id")?;
50
51 let graph = state
52 .get_user_graph(&req.user_id)
53 .map_err(AppError::Internal)?;
54
55 let graph_guard = graph.read();
56 let entity = graph_guard
57 .find_entity_by_name(&req.entity_name)
58 .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
59
60 Ok(Json(entity))
61}
62
63#[derive(Debug, Deserialize)]
65pub struct TraverseGraphRequest {
66 pub user_id: String,
67 pub entity_name: String,
68 pub max_depth: Option<usize>,
69}
70
71pub async fn traverse_graph(
73 State(state): State<AppState>,
74 Json(req): Json<TraverseGraphRequest>,
75) -> Result<Json<GraphTraversal>, AppError> {
76 validation::validate_user_id(&req.user_id).map_validation_err("user_id")?;
77
78 let graph = state
79 .get_user_graph(&req.user_id)
80 .map_err(AppError::Internal)?;
81
82 let graph_guard = graph.read();
83
84 let entity = graph_guard
85 .find_entity_by_name(&req.entity_name)
86 .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?
87 .ok_or_else(|| {
88 AppError::MemoryNotFound(format!("Entity not found: {}", req.entity_name))
89 })?;
90
91 let max_depth = req.max_depth.unwrap_or(2);
92 let traversal = graph_guard
93 .traverse_from_entity(&entity.uuid, max_depth)
94 .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
95
96 Ok(Json(traversal))
97}
98
99#[derive(Debug, Deserialize)]
101pub struct GetEpisodeRequest {
102 pub user_id: String,
103 pub episode_uuid: String,
104}
105
106pub async fn get_episode(
108 State(state): State<AppState>,
109 Json(req): Json<GetEpisodeRequest>,
110) -> Result<Json<Option<EpisodicNode>>, AppError> {
111 validation::validate_user_id(&req.user_id).map_validation_err("user_id")?;
112
113 let graph = state
114 .get_user_graph(&req.user_id)
115 .map_err(AppError::Internal)?;
116
117 let graph_guard = graph.read();
118
119 let episode_uuid =
120 uuid::Uuid::parse_str(&req.episode_uuid).map_err(|_| AppError::InvalidInput {
121 field: "episode_uuid".to_string(),
122 reason: "Invalid UUID format".to_string(),
123 })?;
124
125 let episode = graph_guard
126 .get_episode(&episode_uuid)
127 .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
128
129 Ok(Json(episode))
130}
131
132#[derive(Debug, Deserialize)]
134pub struct GetAllEntitiesRequest {
135 pub user_id: String,
136 pub limit: Option<usize>,
137}
138
139pub async fn get_all_entities(
141 State(state): State<AppState>,
142 Json(req): Json<GetAllEntitiesRequest>,
143) -> Result<Json<serde_json::Value>, AppError> {
144 validation::validate_user_id(&req.user_id).map_validation_err("user_id")?;
145
146 let graph = state
147 .get_user_graph(&req.user_id)
148 .map_err(AppError::Internal)?;
149 let graph_guard = graph.read();
150
151 let entities = graph_guard
152 .get_all_entities()
153 .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
154
155 let limit = req.limit.unwrap_or(100);
156 let entities: Vec<_> = entities.into_iter().take(limit).collect();
157 let count = entities.len();
158
159 Ok(Json(serde_json::json!({
160 "entities": entities,
161 "count": count
162 })))
163}
164
165pub async fn get_memory_universe(
167 State(state): State<AppState>,
168 Path(user_id): Path<String>,
169) -> Result<Json<MemoryUniverse>, AppError> {
170 validation::validate_user_id(&user_id).map_validation_err("user_id")?;
171
172 let graph = state.get_user_graph(&user_id).map_err(AppError::Internal)?;
173
174 let graph_guard = graph.read();
175 let universe = graph_guard
176 .get_universe()
177 .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
178
179 Ok(Json(universe))
180}
181
182pub async fn clear_user_graph(
184 State(state): State<AppState>,
185 Path(user_id): Path<String>,
186) -> Result<Json<serde_json::Value>, AppError> {
187 validation::validate_user_id(&user_id).map_validation_err("user_id")?;
188
189 let graph = state.get_user_graph(&user_id).map_err(AppError::Internal)?;
190 let graph_guard = graph.write();
191
192 let (entities, relationships, episodes) = graph_guard
193 .clear_all()
194 .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
195
196 info!(
197 "Cleared graph for user {}: {} entities, {} relationships, {} episodes",
198 user_id, entities, relationships, episodes
199 );
200
201 state.emit_event(MemoryEvent {
202 event_type: "GRAPH_CLEAR".to_string(),
203 timestamp: chrono::Utc::now(),
204 user_id: user_id.clone(),
205 memory_id: Some(format!("{}/{}/{}", entities, relationships, episodes)),
206 content_preview: Some(format!(
207 "Cleared {} entities, {} relationships, {} episodes",
208 entities, relationships, episodes
209 )),
210 memory_type: Some("graph".to_string()),
211 importance: None,
212 count: Some(entities + relationships + episodes),
213 results: None,
214 });
215
216 Ok(Json(serde_json::json!({
217 "cleared": {
218 "entities": entities,
219 "relationships": relationships,
220 "episodes": episodes
221 }
222 })))
223}
224
225pub async fn rebuild_user_graph(
227 State(state): State<AppState>,
228 Path(user_id): Path<String>,
229) -> Result<Json<serde_json::Value>, AppError> {
230 validation::validate_user_id(&user_id).map_validation_err("user_id")?;
231
232 let graph = state.get_user_graph(&user_id).map_err(AppError::Internal)?;
234 {
235 let graph_guard = graph.write();
236 let _ = graph_guard.clear_all();
237 }
238
239 let memory_sys = state
241 .get_user_memory(&user_id)
242 .map_err(AppError::Internal)?;
243 let memories: Vec<(MemoryId, Experience)> = {
244 let memory_guard = memory_sys.read();
245 memory_guard
246 .get_all_memories()
247 .map_err(AppError::Internal)?
248 .into_iter()
249 .map(|m| (m.id.clone(), m.experience.clone()))
250 .collect()
251 };
252
253 let total_memories = memories.len();
254 let mut processed = 0;
255
256 for (memory_id, experience) in memories {
258 if let Err(e) = state.process_experience_into_graph(&user_id, &experience, &memory_id) {
259 tracing::debug!("Failed to process memory {}: {}", memory_id.0, e);
260 } else {
261 processed += 1;
262 }
263 }
264
265 let stats = state
267 .get_user_graph_stats(&user_id)
268 .map_err(AppError::Internal)?;
269 let entities_created = stats.entity_count;
270 let relationships_created = stats.relationship_count;
271
272 info!(
273 "Rebuilt graph for user {}: processed {}/{} memories, created {} entities, {} relationships",
274 user_id, processed, total_memories, entities_created, relationships_created
275 );
276
277 state.emit_event(MemoryEvent {
278 event_type: "GRAPH_REBUILD".to_string(),
279 timestamp: chrono::Utc::now(),
280 user_id: user_id.clone(),
281 memory_id: None,
282 content_preview: Some(format!(
283 "Rebuilt: {} memories -> {} entities, {} relationships",
284 processed, entities_created, relationships_created
285 )),
286 memory_type: Some("graph".to_string()),
287 importance: None,
288 count: Some(entities_created + relationships_created),
289 results: None,
290 });
291
292 Ok(Json(serde_json::json!({
293 "success": true,
294 "processed_memories": processed,
295 "total_memories": total_memories,
296 "entities_created": entities_created,
297 "relationships_created": relationships_created
298 })))
299}
300
301#[derive(Debug, Deserialize)]
303pub struct InvalidateRelationshipRequest {
304 pub user_id: String,
305 pub relationship_uuid: String,
306}
307
308pub async fn invalidate_relationship(
310 State(state): State<AppState>,
311 Json(req): Json<InvalidateRelationshipRequest>,
312) -> Result<Json<serde_json::Value>, AppError> {
313 validation::validate_user_id(&req.user_id).map_validation_err("user_id")?;
314
315 let graph = state
316 .get_user_graph(&req.user_id)
317 .map_err(AppError::Internal)?;
318
319 let graph_guard = graph.write();
320
321 let rel_uuid =
322 uuid::Uuid::parse_str(&req.relationship_uuid).map_err(|_| AppError::InvalidInput {
323 field: "relationship_uuid".to_string(),
324 reason: "Invalid UUID format".to_string(),
325 })?;
326
327 graph_guard
328 .invalidate_relationship(&rel_uuid)
329 .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
330
331 state.emit_event(MemoryEvent {
332 event_type: "EDGE_INVALIDATE".to_string(),
333 timestamp: chrono::Utc::now(),
334 user_id: req.user_id.clone(),
335 memory_id: Some(req.relationship_uuid.clone()),
336 content_preview: Some("Relationship invalidated".to_string()),
337 memory_type: Some("graph".to_string()),
338 importance: None,
339 count: None,
340 results: None,
341 });
342
343 Ok(Json(serde_json::json!({
344 "success": true,
345 "message": "Relationship invalidated"
346 })))
347}