Skip to main content

spec_ai/spec_ai_api/api/
graph_handlers.rs

//! Graph API handlers for direct knowledge graph access.
//!
//! These endpoints expose the knowledge graph as a generic key-value store
//! with nodes and edges. Clients interpret the data in domain-specific ways.
5use crate::spec_ai_api::api::handlers::AppState;
6use crate::spec_ai_api::api::models::ErrorResponse;
7use axum::{
8    extract::{Json, Path, Query, State},
9    http::StatusCode,
10    response::{
11        IntoResponse, Response,
12        sse::{Event, Sse},
13    },
14};
15use futures::stream::Stream;
16use serde::{Deserialize, Serialize};
17use serde_json::Value as JsonValue;
18use crate::spec_ai_core::bootstrap_self::plugin::BootstrapPlugin;
19use crate::spec_ai_core::bootstrap_self::plugin::{BootstrapMode, PluginContext};
20use crate::spec_ai_core::bootstrap_self::plugins::universal_code::UniversalCodePlugin;
21use crate::spec_ai_knowledge_graph::{EdgeType, NodeType};
22use std::convert::Infallible;
23use std::time::Duration;
24
25// ============================================================================
26// Request/Response Types
27// ============================================================================
28
29/// Query parameters for listing nodes
30#[derive(Debug, Deserialize)]
31pub struct ListNodesQuery {
32    /// Session ID to scope the query
33    pub session_id: String,
34    /// Optional node type filter
35    pub node_type: Option<String>,
36    /// Maximum number of nodes to return
37    pub limit: Option<usize>,
38}
39
40/// Query parameters for listing edges
41#[derive(Debug, Deserialize)]
42pub struct ListEdgesQuery {
43    /// Session ID to scope the query
44    pub session_id: String,
45    /// Optional source node ID filter
46    pub source_id: Option<i64>,
47    /// Optional target node ID filter
48    pub target_id: Option<i64>,
49}
50
51/// Request to create a new node
52#[derive(Debug, Deserialize)]
53pub struct CreateNodeRequest {
54    /// Session ID for the node
55    pub session_id: String,
56    /// Node type (entity, concept, fact, message, tool_result, event, goal)
57    pub node_type: String,
58    /// Human-readable label
59    pub label: String,
60    /// Arbitrary properties as JSON
61    #[serde(default)]
62    pub properties: JsonValue,
63}
64
65/// Request to update a node's properties
66#[derive(Debug, Deserialize)]
67pub struct UpdateNodeRequest {
68    /// New properties (replaces existing)
69    pub properties: JsonValue,
70}
71
72/// Request to create a new edge
73#[derive(Debug, Deserialize)]
74pub struct CreateEdgeRequest {
75    /// Session ID for the edge
76    pub session_id: String,
77    /// Source node ID
78    pub source_id: i64,
79    /// Target node ID
80    pub target_id: i64,
81    /// Edge type
82    pub edge_type: String,
83    /// Optional predicate/relationship name
84    pub predicate: Option<String>,
85    /// Optional properties
86    pub properties: Option<JsonValue>,
87    /// Edge weight (0.0 to 1.0)
88    #[serde(default = "default_weight")]
89    pub weight: f32,
90}
91
/// Serde default for `CreateEdgeRequest::weight`: a request that omits the
/// field gets a weight of 1.0.
fn default_weight() -> f32 {
    const FULL_WEIGHT: f32 = 1.0;
    FULL_WEIGHT
}
95
96/// Response containing a single node
97#[derive(Debug, Serialize)]
98pub struct NodeResponse {
99    pub id: i64,
100    pub session_id: String,
101    pub node_type: String,
102    pub label: String,
103    pub properties: JsonValue,
104    pub created_at: String,
105    pub updated_at: String,
106}
107
108/// Response containing multiple nodes
109#[derive(Debug, Serialize)]
110pub struct NodesListResponse {
111    pub nodes: Vec<NodeResponse>,
112    pub count: usize,
113}
114
115/// Response containing a single edge
116#[derive(Debug, Serialize)]
117pub struct EdgeResponse {
118    pub id: i64,
119    pub session_id: String,
120    pub source_id: i64,
121    pub target_id: i64,
122    pub edge_type: String,
123    pub predicate: Option<String>,
124    pub properties: Option<JsonValue>,
125    pub weight: f32,
126    pub created_at: String,
127}
128
129/// Response containing multiple edges
130#[derive(Debug, Serialize)]
131pub struct EdgesListResponse {
132    pub edges: Vec<EdgeResponse>,
133    pub count: usize,
134}
135
136/// Query parameters for changelog stream
137#[derive(Debug, Deserialize)]
138pub struct ChangelogStreamQuery {
139    /// Session ID to watch
140    pub session_id: String,
141    /// Optional: only return changes after this timestamp (ISO 8601)
142    pub since: Option<String>,
143}
144
145/// A changelog event sent via SSE
146#[derive(Debug, Serialize)]
147pub struct ChangelogEvent {
148    pub entity_type: String,
149    pub entity_id: i64,
150    pub operation: String,
151    pub timestamp: String,
152    pub data: Option<JsonValue>,
153}
154
155// ============================================================================
156// Node Handlers
157// ============================================================================
158
159/// List nodes with optional filtering
160pub async fn list_nodes(
161    State(state): State<AppState>,
162    Query(query): Query<ListNodesQuery>,
163) -> Response {
164    let node_type = query.node_type.map(|s| NodeType::from_str(&s));
165    let limit = query.limit.map(|l| l as i64);
166
167    match state
168        .persistence
169        .list_graph_nodes(&query.session_id, node_type, limit)
170    {
171        Ok(nodes) => {
172            let response_nodes: Vec<NodeResponse> = nodes
173                .into_iter()
174                .map(|n| NodeResponse {
175                    id: n.id,
176                    session_id: n.session_id,
177                    node_type: n.node_type.as_str().to_string(),
178                    label: n.label,
179                    properties: n.properties,
180                    created_at: n.created_at.to_rfc3339(),
181                    updated_at: n.updated_at.to_rfc3339(),
182                })
183                .collect();
184
185            let count = response_nodes.len();
186            Json(NodesListResponse {
187                nodes: response_nodes,
188                count,
189            })
190            .into_response()
191        }
192        Err(e) => (
193            StatusCode::INTERNAL_SERVER_ERROR,
194            Json(ErrorResponse::new("database_error", e.to_string())),
195        )
196            .into_response(),
197    }
198}
199
200/// Get a single node by ID
201pub async fn get_node(State(state): State<AppState>, Path(node_id): Path<i64>) -> Response {
202    match state.persistence.get_graph_node(node_id) {
203        Ok(Some(node)) => Json(NodeResponse {
204            id: node.id,
205            session_id: node.session_id,
206            node_type: node.node_type.as_str().to_string(),
207            label: node.label,
208            properties: node.properties,
209            created_at: node.created_at.to_rfc3339(),
210            updated_at: node.updated_at.to_rfc3339(),
211        })
212        .into_response(),
213        Ok(None) => (
214            StatusCode::NOT_FOUND,
215            Json(ErrorResponse::new("not_found", "Node not found")),
216        )
217            .into_response(),
218        Err(e) => (
219            StatusCode::INTERNAL_SERVER_ERROR,
220            Json(ErrorResponse::new("database_error", e.to_string())),
221        )
222            .into_response(),
223    }
224}
225
226/// Create a new node
227pub async fn create_node(
228    State(state): State<AppState>,
229    Json(request): Json<CreateNodeRequest>,
230) -> Response {
231    let node_type = NodeType::from_str(&request.node_type);
232
233    match state.persistence.insert_graph_node(
234        &request.session_id,
235        node_type,
236        &request.label,
237        &request.properties,
238        None,
239    ) {
240        Ok(node_id) => {
241            // Fetch the created node to return it
242            match state.persistence.get_graph_node(node_id) {
243                Ok(Some(node)) => (
244                    StatusCode::CREATED,
245                    Json(NodeResponse {
246                        id: node.id,
247                        session_id: node.session_id,
248                        node_type: node.node_type.as_str().to_string(),
249                        label: node.label,
250                        properties: node.properties,
251                        created_at: node.created_at.to_rfc3339(),
252                        updated_at: node.updated_at.to_rfc3339(),
253                    }),
254                )
255                    .into_response(),
256                _ => (
257                    StatusCode::CREATED,
258                    Json(serde_json::json!({ "id": node_id })),
259                )
260                    .into_response(),
261            }
262        }
263        Err(e) => (
264            StatusCode::INTERNAL_SERVER_ERROR,
265            Json(ErrorResponse::new("database_error", e.to_string())),
266        )
267            .into_response(),
268    }
269}
270
271/// Update a node's properties
272pub async fn update_node(
273    State(state): State<AppState>,
274    Path(node_id): Path<i64>,
275    Json(request): Json<UpdateNodeRequest>,
276) -> Response {
277    // First check if node exists
278    match state.persistence.get_graph_node(node_id) {
279        Ok(None) => {
280            return (
281                StatusCode::NOT_FOUND,
282                Json(ErrorResponse::new("not_found", "Node not found")),
283            )
284                .into_response();
285        }
286        Err(e) => {
287            return (
288                StatusCode::INTERNAL_SERVER_ERROR,
289                Json(ErrorResponse::new("database_error", e.to_string())),
290            )
291                .into_response();
292        }
293        Ok(Some(_)) => {}
294    }
295
296    match state
297        .persistence
298        .update_graph_node(node_id, &request.properties)
299    {
300        Ok(()) => {
301            // Fetch updated node
302            match state.persistence.get_graph_node(node_id) {
303                Ok(Some(node)) => Json(NodeResponse {
304                    id: node.id,
305                    session_id: node.session_id,
306                    node_type: node.node_type.as_str().to_string(),
307                    label: node.label,
308                    properties: node.properties,
309                    created_at: node.created_at.to_rfc3339(),
310                    updated_at: node.updated_at.to_rfc3339(),
311                })
312                .into_response(),
313                _ => StatusCode::NO_CONTENT.into_response(),
314            }
315        }
316        Err(e) => (
317            StatusCode::INTERNAL_SERVER_ERROR,
318            Json(ErrorResponse::new("database_error", e.to_string())),
319        )
320            .into_response(),
321    }
322}
323
324/// Delete a node
325pub async fn delete_node(State(state): State<AppState>, Path(node_id): Path<i64>) -> Response {
326    match state.persistence.delete_graph_node(node_id) {
327        Ok(()) => StatusCode::NO_CONTENT.into_response(),
328        Err(e) => (
329            StatusCode::INTERNAL_SERVER_ERROR,
330            Json(ErrorResponse::new("database_error", e.to_string())),
331        )
332            .into_response(),
333    }
334}
335
336// ============================================================================
337// Edge Handlers
338// ============================================================================
339
340/// List edges with optional filtering
341pub async fn list_edges(
342    State(state): State<AppState>,
343    Query(query): Query<ListEdgesQuery>,
344) -> Response {
345    match state
346        .persistence
347        .list_graph_edges(&query.session_id, query.source_id, query.target_id)
348    {
349        Ok(edges) => {
350            let response_edges: Vec<EdgeResponse> = edges
351                .into_iter()
352                .map(|e| EdgeResponse {
353                    id: e.id,
354                    session_id: e.session_id,
355                    source_id: e.source_id,
356                    target_id: e.target_id,
357                    edge_type: e.edge_type.as_str(),
358                    predicate: e.predicate,
359                    properties: e.properties,
360                    weight: e.weight,
361                    created_at: e.created_at.to_rfc3339(),
362                })
363                .collect();
364
365            let count = response_edges.len();
366            Json(EdgesListResponse {
367                edges: response_edges,
368                count,
369            })
370            .into_response()
371        }
372        Err(e) => (
373            StatusCode::INTERNAL_SERVER_ERROR,
374            Json(ErrorResponse::new("database_error", e.to_string())),
375        )
376            .into_response(),
377    }
378}
379
380/// Get a single edge by ID
381pub async fn get_edge(State(state): State<AppState>, Path(edge_id): Path<i64>) -> Response {
382    match state.persistence.get_graph_edge(edge_id) {
383        Ok(Some(edge)) => Json(EdgeResponse {
384            id: edge.id,
385            session_id: edge.session_id,
386            source_id: edge.source_id,
387            target_id: edge.target_id,
388            edge_type: edge.edge_type.as_str(),
389            predicate: edge.predicate,
390            properties: edge.properties,
391            weight: edge.weight,
392            created_at: edge.created_at.to_rfc3339(),
393        })
394        .into_response(),
395        Ok(None) => (
396            StatusCode::NOT_FOUND,
397            Json(ErrorResponse::new("not_found", "Edge not found")),
398        )
399            .into_response(),
400        Err(e) => (
401            StatusCode::INTERNAL_SERVER_ERROR,
402            Json(ErrorResponse::new("database_error", e.to_string())),
403        )
404            .into_response(),
405    }
406}
407
408/// Create a new edge
409pub async fn create_edge(
410    State(state): State<AppState>,
411    Json(request): Json<CreateEdgeRequest>,
412) -> Response {
413    let edge_type = EdgeType::from_str(&request.edge_type);
414
415    match state.persistence.insert_graph_edge(
416        &request.session_id,
417        request.source_id,
418        request.target_id,
419        edge_type,
420        request.predicate.as_deref(),
421        request.properties.as_ref(),
422        request.weight,
423    ) {
424        Ok(edge_id) => {
425            // Fetch the created edge to return it
426            match state.persistence.get_graph_edge(edge_id) {
427                Ok(Some(edge)) => (
428                    StatusCode::CREATED,
429                    Json(EdgeResponse {
430                        id: edge.id,
431                        session_id: edge.session_id,
432                        source_id: edge.source_id,
433                        target_id: edge.target_id,
434                        edge_type: edge.edge_type.as_str(),
435                        predicate: edge.predicate,
436                        properties: edge.properties,
437                        weight: edge.weight,
438                        created_at: edge.created_at.to_rfc3339(),
439                    }),
440                )
441                    .into_response(),
442                _ => (
443                    StatusCode::CREATED,
444                    Json(serde_json::json!({ "id": edge_id })),
445                )
446                    .into_response(),
447            }
448        }
449        Err(e) => (
450            StatusCode::INTERNAL_SERVER_ERROR,
451            Json(ErrorResponse::new("database_error", e.to_string())),
452        )
453            .into_response(),
454    }
455}
456
457/// Delete an edge
458pub async fn delete_edge(State(state): State<AppState>, Path(edge_id): Path<i64>) -> Response {
459    match state.persistence.delete_graph_edge(edge_id) {
460        Ok(()) => StatusCode::NO_CONTENT.into_response(),
461        Err(e) => (
462            StatusCode::INTERNAL_SERVER_ERROR,
463            Json(ErrorResponse::new("database_error", e.to_string())),
464        )
465            .into_response(),
466    }
467}
468
469// ============================================================================
470// Changelog Stream (SSE)
471// ============================================================================
472
473/// Stream changelog events via Server-Sent Events
474pub async fn stream_changelog(
475    State(state): State<AppState>,
476    Query(query): Query<ChangelogStreamQuery>,
477) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
478    let session_id = query.session_id;
479    let since = query
480        .since
481        .unwrap_or_else(|| chrono::Utc::now().to_rfc3339());
482
483    let stream = async_stream::stream! {
484        let mut last_timestamp = since;
485        let mut interval = tokio::time::interval(Duration::from_millis(500));
486
487        loop {
488            interval.tick().await;
489
490            // Poll for new changelog entries
491            match state.persistence.graph_changelog_get_since(&session_id, &last_timestamp) {
492                Ok(entries) => {
493                    for entry in entries {
494                        let timestamp_str = entry.created_at.to_rfc3339();
495                        let event = ChangelogEvent {
496                            entity_type: entry.entity_type.clone(),
497                            entity_id: entry.entity_id,
498                            operation: entry.operation.clone(),
499                            timestamp: timestamp_str.clone(),
500                            data: entry.data.and_then(|s| serde_json::from_str(&s).ok()),
501                        };
502
503                        // Update last timestamp for next poll
504                        last_timestamp = timestamp_str;
505
506                        if let Ok(json) = serde_json::to_string(&event) {
507                            yield Ok(Event::default().data(json));
508                        }
509                    }
510                }
511                Err(e) => {
512                    tracing::warn!("Changelog poll error: {}", e);
513                }
514            }
515        }
516    };
517
518    Sse::new(stream).keep_alive(
519        axum::response::sse::KeepAlive::new()
520            .interval(Duration::from_secs(15))
521            .text("ping"),
522    )
523}
524
525// ============================================================================
526// Bootstrap Handler
527// ============================================================================
528
529/// Request to bootstrap a knowledge graph from a directory
530#[derive(Debug, Deserialize)]
531pub struct BootstrapRequest {
532    /// Session ID for the graph (optional, defaults to "visionos-dashboard")
533    pub session_id: Option<String>,
534}
535
536/// Response from bootstrap operation
537#[derive(Debug, Serialize)]
538pub struct BootstrapResponse {
539    pub session_id: String,
540    pub nodes_created: usize,
541    pub edges_created: usize,
542    pub root_node_id: Option<i64>,
543}
544
545/// Bootstrap a knowledge graph from the server's current working directory
546pub async fn bootstrap_graph(
547    State(state): State<AppState>,
548    Json(request): Json<BootstrapRequest>,
549) -> Response {
550    let session_id = request
551        .session_id
552        .unwrap_or_else(|| "visionos-dashboard".to_string());
553
554    // Get current working directory
555    let cwd = match std::env::current_dir() {
556        Ok(path) => path,
557        Err(e) => {
558            return (
559                StatusCode::INTERNAL_SERVER_ERROR,
560                Json(ErrorResponse::new(
561                    "cwd_error",
562                    format!("Failed to get current directory: {}", e),
563                )),
564            )
565                .into_response();
566        }
567    };
568
569    tracing::info!("Bootstrapping knowledge graph from: {:?}", cwd);
570
571    // Create plugin context
572    let context = PluginContext {
573        persistence: &state.persistence,
574        session_id: &session_id,
575        repo_root: &cwd,
576        mode: BootstrapMode::Fresh,
577    };
578
579    // Run the universal code plugin
580    let plugin = UniversalCodePlugin;
581
582    if !plugin.should_activate(&cwd) {
583        return (
584            StatusCode::BAD_REQUEST,
585            Json(ErrorResponse::new(
586                "not_a_repository",
587                "Current directory does not appear to be a code repository",
588            )),
589        )
590            .into_response();
591    }
592
593    match plugin.run(context) {
594        Ok(outcome) => {
595            tracing::info!(
596                "Bootstrap complete: {} nodes, {} edges created",
597                outcome.nodes_created,
598                outcome.edges_created
599            );
600            (
601                StatusCode::CREATED,
602                Json(BootstrapResponse {
603                    session_id,
604                    nodes_created: outcome.nodes_created,
605                    edges_created: outcome.edges_created,
606                    root_node_id: outcome.root_node_id,
607                }),
608            )
609                .into_response()
610        }
611        Err(e) => {
612            tracing::error!("Bootstrap failed: {}", e);
613            (
614                StatusCode::INTERNAL_SERVER_ERROR,
615                Json(ErrorResponse::new("bootstrap_error", e.to_string())),
616            )
617                .into_response()
618        }
619    }
620}