Skip to main content

sparrow/tools/
knowledge_graph.rs

1use async_trait::async_trait;
2use serde_json::json;
3use std::sync::Arc;
4
5use super::{Tool, ToolCtx, ToolResult};
6use crate::event::{Block, RiskLevel};
7use crate::memory::{GraphDirection, GraphEdge, GraphNode, KnowledgeGraph, Memory};
8
9pub struct KnowledgeGraphTool {
10    memory: Arc<dyn Memory>,
11}
12
13impl KnowledgeGraphTool {
14    pub fn new(memory: Arc<dyn Memory>) -> Self {
15        Self { memory }
16    }
17}
18
19#[async_trait]
20impl Tool for KnowledgeGraphTool {
21    fn name(&self) -> &str {
22        "knowledge_graph"
23    }
24
25    fn description(&self) -> &str {
26        "Manage Sparrow's persistent knowledge graph: upsert nodes/edges, search, inspect neighbors, export, delete, and optionally sync to local Neo4j."
27    }
28
29    fn schema(&self) -> serde_json::Value {
30        json!({
31            "type": "object",
32            "properties": {
33                "action": {
34                    "type": "string",
35                    "enum": ["upsert_node", "upsert_edge", "get", "neighbors", "search", "export", "delete_node", "delete_edge", "sync_neo4j"]
36                },
37                "id": { "type": "string", "description": "Node id, edge id, or lookup id" },
38                "label": { "type": "string", "description": "Human-readable node label" },
39                "kind": { "type": "string", "description": "Node kind, e.g. user, project, file, decision, feature" },
40                "from_id": { "type": "string", "description": "Source node id for edges" },
41                "to_id": { "type": "string", "description": "Target node id for edges" },
42                "relation": { "type": "string", "description": "Edge relation, e.g. works_on, depends_on, decided" },
43                "weight": { "type": "number", "description": "Edge weight; defaults to 1.0" },
44                "properties": { "type": "object", "description": "JSON metadata for a node or edge" },
45                "query": { "type": "string", "description": "Search query" },
46                "direction": { "type": "string", "enum": ["incoming", "outgoing", "both"] },
47                "limit": { "type": "integer", "description": "Maximum rows to return" }
48            },
49            "required": ["action"]
50        })
51    }
52
53    fn risk(&self) -> RiskLevel {
54        RiskLevel::Mutating
55    }
56
57    async fn call(&self, args: serde_json::Value, _ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
58        let action = args["action"].as_str().unwrap_or("search");
59        let limit = args["limit"].as_u64().unwrap_or(20).clamp(1, 100) as usize;
60        match action {
61            "upsert_node" => {
62                let node = node_from_args(&args)?;
63                self.memory.upsert_graph_node(node.clone())?;
64                Ok(ToolResult::text(format!(
65                    "graph node stored: {} [{}] {}",
66                    node.id, node.kind, node.label
67                )))
68            }
69            "upsert_edge" => {
70                let edge = edge_from_args(&args)?;
71                self.memory.upsert_graph_edge(edge.clone())?;
72                Ok(ToolResult::text(format!(
73                    "graph edge stored: {} {} -{}-> {}",
74                    edge.id, edge.from_id, edge.relation, edge.to_id
75                )))
76            }
77            "get" => {
78                let id = required_str(&args, "id")?;
79                match self.memory.graph_node(&id) {
80                    Some(node) => Ok(ToolResult::ok(vec![Block::Json(json!(node))])),
81                    None => Ok(ToolResult::error(format!("graph node not found: {}", id))),
82                }
83            }
84            "neighbors" => {
85                let id = required_str(&args, "id")?;
86                let direction = GraphDirection::parse(args["direction"].as_str().unwrap_or("both"));
87                let rows = self.memory.graph_neighbors(&id, direction, limit);
88                if rows.is_empty() {
89                    return Ok(ToolResult::text(format!("no graph neighbors for {}", id)));
90                }
91                Ok(ToolResult::ok(vec![Block::Json(json!(rows))]))
92            }
93            "search" => {
94                let query = args["query"].as_str().unwrap_or("");
95                if query.trim().is_empty() {
96                    return Ok(ToolResult::error("knowledge_graph search requires query"));
97                }
98                let nodes = self.memory.search_graph(query, limit);
99                Ok(ToolResult::ok(vec![Block::Json(json!(nodes))]))
100            }
101            "export" => {
102                let graph = self.memory.graph_export();
103                Ok(ToolResult::ok(vec![Block::Json(json!(graph))]))
104            }
105            "delete_node" => {
106                let id = required_str(&args, "id")?;
107                self.memory.delete_graph_node(&id)?;
108                Ok(ToolResult::text(format!("graph node deleted: {}", id)))
109            }
110            "delete_edge" => {
111                let id = required_str(&args, "id")?;
112                self.memory.delete_graph_edge(&id)?;
113                Ok(ToolResult::text(format!("graph edge deleted: {}", id)))
114            }
115            "sync_neo4j" => {
116                let graph = self.memory.graph_export();
117                let count = sync_graph_to_neo4j(&graph).await?;
118                Ok(ToolResult::text(format!(
119                    "synced graph to Neo4j: {} nodes, {} edges, {} statements",
120                    graph.nodes.len(),
121                    graph.edges.len(),
122                    count
123                )))
124            }
125            _ => Ok(ToolResult::error("unknown knowledge_graph action")),
126        }
127    }
128}
129
130pub async fn sync_graph_to_neo4j(graph: &KnowledgeGraph) -> anyhow::Result<usize> {
131    let url = std::env::var("NEO4J_URL")
132        .unwrap_or_else(|_| "http://127.0.0.1:7474/db/neo4j/tx/commit".into());
133    let user = std::env::var("NEO4J_USER")
134        .or_else(|_| std::env::var("NEO4J_USERNAME"))
135        .map_err(|_| anyhow::anyhow!("Neo4j sync requires NEO4J_USER and NEO4J_PASSWORD"))?;
136    let password = std::env::var("NEO4J_PASSWORD")
137        .map_err(|_| anyhow::anyhow!("Neo4j sync requires NEO4J_USER and NEO4J_PASSWORD"))?;
138    let statements = neo4j_statements(graph);
139    if statements.is_empty() {
140        return Ok(0);
141    }
142    let resp = reqwest::Client::new()
143        .post(url)
144        .basic_auth(user, Some(password))
145        .json(&json!({ "statements": statements }))
146        .send()
147        .await?;
148    let status = resp.status();
149    let body: serde_json::Value = resp.json().await.unwrap_or_else(|_| json!({}));
150    if !status.is_success() || body["errors"].as_array().is_some_and(|e| !e.is_empty()) {
151        anyhow::bail!("Neo4j sync failed: HTTP {} {}", status, body);
152    }
153    Ok(statements.len())
154}
155
156fn neo4j_statements(graph: &KnowledgeGraph) -> Vec<serde_json::Value> {
157    let mut statements = Vec::new();
158    for node in &graph.nodes {
159        statements.push(json!({
160            "statement": "MERGE (n:SparrowNode {id: $id}) SET n.label = $label, n.kind = $kind, n.properties = $properties, n.updated_at = $updated_at",
161            "parameters": {
162                "id": node.id,
163                "label": node.label,
164                "kind": node.kind,
165                "properties": node.properties.to_string(),
166                "updated_at": node.updated_at,
167            }
168        }));
169    }
170    for edge in &graph.edges {
171        statements.push(json!({
172            "statement": "MATCH (a:SparrowNode {id: $from_id}), (b:SparrowNode {id: $to_id}) MERGE (a)-[r:SPARROW_REL {id: $id}]->(b) SET r.relation = $relation, r.weight = $weight, r.properties = $properties, r.updated_at = $updated_at",
173            "parameters": {
174                "id": edge.id,
175                "from_id": edge.from_id,
176                "to_id": edge.to_id,
177                "relation": edge.relation,
178                "weight": edge.weight,
179                "properties": edge.properties.to_string(),
180                "updated_at": edge.updated_at,
181            }
182        }));
183    }
184    statements
185}
186
187fn node_from_args(args: &serde_json::Value) -> anyhow::Result<GraphNode> {
188    let id = required_str(args, "id")?;
189    let label = required_str(args, "label")?;
190    let kind = args["kind"].as_str().unwrap_or("entity").trim().to_string();
191    let now = chrono::Utc::now().to_rfc3339();
192    Ok(GraphNode {
193        id,
194        label,
195        kind,
196        properties: args.get("properties").cloned().unwrap_or_else(|| json!({})),
197        created_at: now.clone(),
198        updated_at: now,
199    })
200}
201
202fn edge_from_args(args: &serde_json::Value) -> anyhow::Result<GraphEdge> {
203    let id = args["id"]
204        .as_str()
205        .map(|s| s.trim().to_string())
206        .filter(|s| !s.is_empty())
207        .unwrap_or_else(|| {
208            format!(
209                "{}:{}:{}",
210                args["from_id"].as_str().unwrap_or(""),
211                args["relation"].as_str().unwrap_or("related_to"),
212                args["to_id"].as_str().unwrap_or("")
213            )
214        });
215    let now = chrono::Utc::now().to_rfc3339();
216    Ok(GraphEdge {
217        id,
218        from_id: required_str(args, "from_id")?,
219        to_id: required_str(args, "to_id")?,
220        relation: args["relation"]
221            .as_str()
222            .unwrap_or("related_to")
223            .trim()
224            .to_string(),
225        weight: args["weight"].as_f64().unwrap_or(1.0),
226        properties: args.get("properties").cloned().unwrap_or_else(|| json!({})),
227        created_at: now.clone(),
228        updated_at: now,
229    })
230}
231
232fn required_str(args: &serde_json::Value, key: &str) -> anyhow::Result<String> {
233    args[key]
234        .as_str()
235        .map(str::trim)
236        .filter(|value| !value.is_empty())
237        .map(str::to_string)
238        .ok_or_else(|| anyhow::anyhow!("knowledge_graph requires '{}'", key))
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// One user node plus one `works_on` edge pointing at a project.
    fn sample_graph() -> KnowledgeGraph {
        KnowledgeGraph {
            nodes: vec![GraphNode {
                id: "user:abdou".into(),
                label: "Abdou".into(),
                kind: "user".into(),
                properties: json!({"prefers": "local-first"}),
                created_at: "2026-06-04T00:00:00Z".into(),
                updated_at: "2026-06-04T00:00:00Z".into(),
            }],
            edges: vec![GraphEdge {
                id: "user:abdou:works_on:project:sparrow".into(),
                from_id: "user:abdou".into(),
                to_id: "project:sparrow".into(),
                relation: "works_on".into(),
                weight: 1.0,
                properties: json!({}),
                created_at: "2026-06-04T00:00:00Z".into(),
                updated_at: "2026-06-04T00:00:00Z".into(),
            }],
        }
    }

    #[test]
    fn neo4j_payload_uses_parameterized_statements() {
        let statements = neo4j_statements(&sample_graph());
        assert_eq!(statements.len(), 2);

        let node_cypher = statements[0]["statement"].as_str().unwrap();
        assert!(node_cypher.contains("MERGE"));
        // Values must travel as parameters, never be spliced into the Cypher text.
        assert!(!node_cypher.contains("user:abdou"));
        let node_params = &statements[0]["parameters"];
        assert_eq!(node_params["id"], "user:abdou");
        assert_eq!(node_params["kind"], "user");
        assert_eq!(node_params["properties"], r#"{"prefers":"local-first"}"#);

        let edge_cypher = statements[1]["statement"].as_str().unwrap();
        assert!(edge_cypher.contains("SPARROW_REL"));
        assert!(!edge_cypher.contains("project:sparrow"));
        let edge_params = &statements[1]["parameters"];
        assert_eq!(edge_params["from_id"], "user:abdou");
        assert_eq!(edge_params["to_id"], "project:sparrow");
        assert_eq!(edge_params["relation"], "works_on");
        assert_eq!(edge_params["weight"], 1.0);
        assert_eq!(edge_params["properties"], "{}");
    }

    #[test]
    fn empty_graph_produces_no_statements() {
        let graph = KnowledgeGraph {
            nodes: vec![],
            edges: vec![],
        };
        assert!(neo4j_statements(&graph).is_empty());
    }
}