spec_ai_core/tools/builtin/graph.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use serde_json::{json, Map, Value};
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::persistence::Persistence;
8use crate::tools::{Tool, ToolResult};
9use crate::types::{EdgeType, NodeType, TraversalDirection};
10
11pub struct GraphTool {
12    persistence: Arc<Persistence>,
13}
14
15impl GraphTool {
16    pub fn new(persistence: Arc<Persistence>) -> Self {
17        Self { persistence }
18    }
19}
20
21#[async_trait]
22impl Tool for GraphTool {
23    fn name(&self) -> &str {
24        "graph"
25    }
26
27    fn description(&self) -> &str {
28        "Create, query, traverse, and synchronize knowledge graphs. Supports operations: \
29         create_node, create_edge, delete_node, delete_edge, get_node, get_edge, \
30         list_nodes, list_edges, find_path, traverse_neighbors, update_node, \
31         node_degree, list_hubs, enable_sync, disable_sync, sync_status, force_sync, \
32         list_sync_configs"
33    }
34
35    fn parameters(&self) -> Value {
36        json!({
37            "type": "object",
38            "properties": {
39                "operation": {
40                    "type": "string",
41                    "enum": [
42                        "create_node", "create_edge", "delete_node", "delete_edge",
43                        "get_node", "get_edge", "list_nodes", "list_edges",
44                        "find_path", "traverse_neighbors", "update_node",
45                        "node_degree", "list_hubs",
46                        "enable_sync", "disable_sync", "sync_status", "force_sync",
47                        "list_sync_configs"
48                    ],
49                    "description": "The graph operation to perform"
50                },
51                "session_id": {
52                    "type": "string",
53                    "description": "Session ID for graph isolation"
54                },
55                "node_id": {
56                    "type": "integer",
57                    "description": "Node ID (for get_node, delete_node, update_node, traverse_neighbors)"
58                },
59                "edge_id": {
60                    "type": "integer",
61                    "description": "Edge ID (for get_edge, delete_edge)"
62                },
63                "node_type": {
64                    "type": "string",
65                    "enum": ["entity", "concept", "fact", "message", "tool_result", "event"],
66                    "description": "Type of node to create or filter by"
67                },
68                "label": {
69                    "type": "string",
70                    "description": "Semantic label for the node (e.g., 'Person', 'Location', 'Action')"
71                },
72                "properties": {
73                    "type": "object",
74                    "description": "JSON properties for the node or edge"
75                },
76                "source_id": {
77                    "type": "integer",
78                    "description": "Source node ID for edge creation or path finding"
79                },
80                "target_id": {
81                    "type": "integer",
82                    "description": "Target node ID for edge creation or path finding"
83                },
84                "edge_type": {
85                    "type": "string",
86                    "enum": [
87                        "RELATES_TO", "CAUSED_BY", "PART_OF", "MENTIONS",
88                        "FOLLOWS_FROM", "USES", "PRODUCES", "DEPENDS_ON"
89                    ],
90                    "description": "Type of edge relationship"
91                },
92                "custom_edge_type": {
93                    "type": "string",
94                    "description": "Custom edge type if not using predefined types"
95                },
96                "predicate": {
97                    "type": "string",
98                    "description": "RDF-style predicate for the edge"
99                },
100                "weight": {
101                    "type": "number",
102                    "default": 1.0,
103                    "description": "Weight for the edge"
104                },
105                "direction": {
106                    "type": "string",
107                    "enum": ["outgoing", "incoming", "both"],
108                    "default": "outgoing",
109                    "description": "Direction for traversal and degree-based operations"
110                },
111                "depth": {
112                    "type": "integer",
113                    "default": 1,
114                    "minimum": 1,
115                    "maximum": 10,
116                    "description": "Depth for traversal operations"
117                },
118                "max_hops": {
119                    "type": "integer",
120                    "default": 10,
121                    "minimum": 1,
122                    "maximum": 20,
123                    "description": "Maximum hops for path finding"
124                },
125                "limit": {
126                    "type": "integer",
127                    "default": 100,
128                    "minimum": 1,
129                    "maximum": 1000,
130                    "description": "Limit for list operations"
131                },
132                "min_degree": {
133                    "type": "integer",
134                    "default": 1,
135                    "minimum": 0,
136                    "description": "Minimum degree threshold when listing hubs"
137                },
138                "graph_name": {
139                    "type": "string",
140                    "default": "default",
141                    "description": "Graph name for sync operations"
142                },
143                "peer_instance_id": {
144                    "type": "string",
145                    "description": "Peer instance ID to sync with (for force_sync)"
146                },
147                "sync_enabled": {
148                    "type": "boolean",
149                    "description": "Enable or disable sync (for enable_sync/disable_sync)"
150                }
151            },
152            "required": ["operation", "session_id"]
153        })
154    }
155
156    async fn execute(&self, args: Value) -> Result<ToolResult> {
157        let operation = args["operation"]
158            .as_str()
159            .context("operation must be a string")?;
160
161        let session_id = args["session_id"]
162            .as_str()
163            .context("session_id must be a string")?;
164
165        // Clone persistence for use in spawn_blocking
166        let persistence = Arc::clone(&self.persistence);
167
168        match operation {
169            "create_node" => {
170                let node_type = args["node_type"]
171                    .as_str()
172                    .context("node_type is required for create_node")?;
173                let label = args["label"]
174                    .as_str()
175                    .context("label is required for create_node")?;
176                let properties = args["properties"].clone();
177
178                let node_type = NodeType::from_str(node_type);
179                let session_id = session_id.to_string();
180                let label = label.to_string();
181
182                let result = tokio::task::spawn_blocking(move || {
183                    persistence.insert_graph_node(&session_id, node_type, &label, &properties, None)
184                })
185                .await
186                .context("task join error")??;
187
188                Ok(ToolResult::success(
189                    json!({
190                        "node_id": result,
191                        "message": format!("Created node with ID {}", result)
192                    })
193                    .to_string(),
194                ))
195            }
196
197            "create_edge" => {
198                let source_id = args["source_id"]
199                    .as_i64()
200                    .context("source_id is required for create_edge")?;
201                let target_id = args["target_id"]
202                    .as_i64()
203                    .context("target_id is required for create_edge")?;
204
205                let edge_type = if let Some(custom) = args["custom_edge_type"].as_str() {
206                    EdgeType::Custom(custom.to_string())
207                } else if let Some(et) = args["edge_type"].as_str() {
208                    EdgeType::from_str(et)
209                } else {
210                    EdgeType::RelatesTo
211                };
212
213                let predicate = args["predicate"].as_str().map(|s| s.to_string());
214                let properties = if args["properties"].is_null() {
215                    None
216                } else {
217                    Some(args["properties"].clone())
218                };
219                let weight = args["weight"].as_f64().unwrap_or(1.0) as f32;
220                let session_id = session_id.to_string();
221
222                let result = tokio::task::spawn_blocking(move || {
223                    persistence.insert_graph_edge(
224                        &session_id,
225                        source_id,
226                        target_id,
227                        edge_type,
228                        predicate.as_deref(),
229                        properties.as_ref(),
230                        weight,
231                    )
232                })
233                .await
234                .context("task join error")??;
235
236                Ok(ToolResult::success(
237                    json!({
238                        "edge_id": result,
239                        "message": format!("Created edge with ID {}", result)
240                    })
241                    .to_string(),
242                ))
243            }
244
245            "get_node" => {
246                let node_id = args["node_id"]
247                    .as_i64()
248                    .context("node_id is required for get_node")?;
249
250                let result =
251                    tokio::task::spawn_blocking(move || persistence.get_graph_node(node_id))
252                        .await
253                        .context("task join error")??;
254
255                match result {
256                    Some(node) => Ok(ToolResult::success(serde_json::to_string_pretty(&node)?)),
257                    None => Ok(ToolResult::failure(format!("Node {} not found", node_id))),
258                }
259            }
260
261            "get_edge" => {
262                let edge_id = args["edge_id"]
263                    .as_i64()
264                    .context("edge_id is required for get_edge")?;
265
266                let result =
267                    tokio::task::spawn_blocking(move || persistence.get_graph_edge(edge_id))
268                        .await
269                        .context("task join error")??;
270
271                match result {
272                    Some(edge) => Ok(ToolResult::success(serde_json::to_string_pretty(&edge)?)),
273                    None => Ok(ToolResult::failure(format!("Edge {} not found", edge_id))),
274                }
275            }
276
277            "list_nodes" => {
278                let node_type = args["node_type"].as_str().map(NodeType::from_str);
279                let limit = args["limit"].as_i64().or(Some(100));
280                let session_id = session_id.to_string();
281
282                let result = tokio::task::spawn_blocking(move || {
283                    persistence.list_graph_nodes(&session_id, node_type, limit)
284                })
285                .await
286                .context("task join error")??;
287
288                Ok(ToolResult::success(
289                    json!({
290                        "count": result.len(),
291                        "nodes": result
292                    })
293                    .to_string(),
294                ))
295            }
296
297            "list_edges" => {
298                let source_id = args["source_id"].as_i64();
299                let target_id = args["target_id"].as_i64();
300                let session_id = session_id.to_string();
301
302                let result = tokio::task::spawn_blocking(move || {
303                    persistence.list_graph_edges(&session_id, source_id, target_id)
304                })
305                .await
306                .context("task join error")??;
307
308                Ok(ToolResult::success(
309                    json!({
310                        "count": result.len(),
311                        "edges": result
312                    })
313                    .to_string(),
314                ))
315            }
316
317            "delete_node" => {
318                let node_id = args["node_id"]
319                    .as_i64()
320                    .context("node_id is required for delete_node")?;
321
322                tokio::task::spawn_blocking(move || persistence.delete_graph_node(node_id))
323                    .await
324                    .context("task join error")??;
325
326                Ok(ToolResult::success(format!("Deleted node {}", node_id)))
327            }
328
329            "delete_edge" => {
330                let edge_id = args["edge_id"]
331                    .as_i64()
332                    .context("edge_id is required for delete_edge")?;
333
334                tokio::task::spawn_blocking(move || persistence.delete_graph_edge(edge_id))
335                    .await
336                    .context("task join error")??;
337
338                Ok(ToolResult::success(format!("Deleted edge {}", edge_id)))
339            }
340
341            "update_node" => {
342                let node_id = args["node_id"]
343                    .as_i64()
344                    .context("node_id is required for update_node")?;
345                let properties = args["properties"].clone();
346
347                tokio::task::spawn_blocking(move || {
348                    persistence.update_graph_node(node_id, &properties)
349                })
350                .await
351                .context("task join error")??;
352
353                Ok(ToolResult::success(format!("Updated node {}", node_id)))
354            }
355
356            "node_degree" => {
357                let node_id = args["node_id"]
358                    .as_i64()
359                    .context("node_id is required for node_degree")?;
360                let edge_type_filter = args["edge_type"].as_str().map(EdgeType::from_str);
361                let session_id = session_id.to_string();
362
363                let (in_degree, out_degree, by_type) = tokio::task::spawn_blocking(move || {
364                    let edges = persistence.list_graph_edges(&session_id, None, None)?;
365                    let mut in_degree: i64 = 0;
366                    let mut out_degree: i64 = 0;
367                    let mut by_type: HashMap<String, (i64, i64)> = HashMap::new();
368
369                    for edge in edges {
370                        if let Some(ref filter) = edge_type_filter {
371                            if &edge.edge_type != filter {
372                                continue;
373                            }
374                        }
375
376                        let key = edge.edge_type.as_str();
377
378                        if edge.source_id == node_id {
379                            out_degree += 1;
380                            let entry = by_type.entry(key.clone()).or_insert((0, 0));
381                            entry.1 += 1;
382                        }
383                        if edge.target_id == node_id {
384                            in_degree += 1;
385                            let entry = by_type.entry(key.clone()).or_insert((0, 0));
386                            entry.0 += 1;
387                        }
388                    }
389
390                    Ok::<_, anyhow::Error>((in_degree, out_degree, by_type))
391                })
392                .await
393                .context("task join error")??;
394
395                let total_degree = in_degree + out_degree;
396
397                let mut by_type_json = Map::new();
398                for (edge_type, (in_d, out_d)) in by_type {
399                    by_type_json.insert(
400                        edge_type,
401                        json!({
402                            "in_degree": in_d,
403                            "out_degree": out_d,
404                            "total_degree": in_d + out_d
405                        }),
406                    );
407                }
408
409                Ok(ToolResult::success(
410                    json!({
411                        "node_id": node_id,
412                        "in_degree": in_degree,
413                        "out_degree": out_degree,
414                        "total_degree": total_degree,
415                        "by_edge_type": by_type_json
416                    })
417                    .to_string(),
418                ))
419            }
420
421            "find_path" => {
422                let source_id = args["source_id"]
423                    .as_i64()
424                    .context("source_id is required for find_path")?;
425                let target_id = args["target_id"]
426                    .as_i64()
427                    .context("target_id is required for find_path")?;
428                let max_hops = args["max_hops"].as_u64().map(|h| h as usize);
429                let session_id = session_id.to_string();
430
431                let result = tokio::task::spawn_blocking(move || {
432                    persistence.find_shortest_path(&session_id, source_id, target_id, max_hops)
433                })
434                .await
435                .context("task join error")??;
436
437                match result {
438                    Some(path) => Ok(ToolResult::success(
439                        json!({
440                            "found": true,
441                            "length": path.length,
442                            "total_weight": path.weight,
443                            "path": path
444                        })
445                        .to_string(),
446                    )),
447                    None => Ok(ToolResult::success(
448                        json!({
449                            "found": false,
450                            "message": format!("No path found from {} to {}", source_id, target_id)
451                        })
452                        .to_string(),
453                    )),
454                }
455            }
456
457            "traverse_neighbors" => {
458                let node_id = args["node_id"]
459                    .as_i64()
460                    .context("node_id is required for traverse_neighbors")?;
461                let depth = args["depth"].as_u64().unwrap_or(1) as usize;
462                let direction = args["direction"]
463                    .as_str()
464                    .map(|d| match d {
465                        "incoming" => TraversalDirection::Incoming,
466                        "both" => TraversalDirection::Both,
467                        _ => TraversalDirection::Outgoing,
468                    })
469                    .unwrap_or(TraversalDirection::Outgoing);
470                let session_id = session_id.to_string();
471
472                let result = tokio::task::spawn_blocking(move || {
473                    persistence.traverse_neighbors(&session_id, node_id, direction, depth)
474                })
475                .await
476                .context("task join error")??;
477
478                Ok(ToolResult::success(
479                    json!({
480                        "count": result.len(),
481                        "neighbors": result
482                    })
483                    .to_string(),
484                ))
485            }
486
487            "list_hubs" => {
488                let direction = args["direction"]
489                    .as_str()
490                    .map(|d| match d {
491                        "incoming" => TraversalDirection::Incoming,
492                        "both" => TraversalDirection::Both,
493                        _ => TraversalDirection::Outgoing,
494                    })
495                    .unwrap_or(TraversalDirection::Outgoing);
496                let min_degree = args["min_degree"].as_i64().unwrap_or(1).max(0);
497                let limit = args["limit"].as_i64().unwrap_or(10).max(1);
498                let edge_type_filter = args["edge_type"].as_str().map(EdgeType::from_str);
499                let session_id = session_id.to_string();
500
501                let hubs = tokio::task::spawn_blocking(move || {
502                    let edges = persistence.list_graph_edges(&session_id, None, None)?;
503                    let mut degrees: HashMap<i64, (i64, i64)> = HashMap::new();
504
505                    for edge in edges {
506                        if let Some(ref filter) = edge_type_filter {
507                            if &edge.edge_type != filter {
508                                continue;
509                            }
510                        }
511
512                        // out-degree for source
513                        {
514                            let entry = degrees.entry(edge.source_id).or_insert((0, 0));
515                            entry.1 += 1;
516                        }
517                        // in-degree for target
518                        {
519                            let entry = degrees.entry(edge.target_id).or_insert((0, 0));
520                            entry.0 += 1;
521                        }
522                    }
523
524                    // Convert to vector and filter by min_degree and direction
525                    let mut nodes_with_degree: Vec<(i64, i64, i64, i64)> = degrees
526                        .into_iter()
527                        .map(|(node_id, (in_d, out_d))| {
528                            let total = in_d + out_d;
529                            (node_id, in_d, out_d, total)
530                        })
531                        .filter(|(_, in_d, out_d, total)| {
532                            let score = match direction {
533                                TraversalDirection::Incoming => *in_d,
534                                TraversalDirection::Outgoing => *out_d,
535                                TraversalDirection::Both => *total,
536                            };
537                            score >= min_degree
538                        })
539                        .collect();
540
541                    nodes_with_degree.sort_by(|a, b| {
542                        let score_a = match direction {
543                            TraversalDirection::Incoming => a.1,
544                            TraversalDirection::Outgoing => a.2,
545                            TraversalDirection::Both => a.3,
546                        };
547                        let score_b = match direction {
548                            TraversalDirection::Incoming => b.1,
549                            TraversalDirection::Outgoing => b.2,
550                            TraversalDirection::Both => b.3,
551                        };
552                        score_b.cmp(&score_a).then_with(|| a.0.cmp(&b.0))
553                    });
554
555                    nodes_with_degree.truncate(limit as usize);
556
557                    // Fetch node details for the selected hubs
558                    let mut result = Vec::new();
559                    for (node_id, in_d, out_d, total) in nodes_with_degree {
560                        if let Some(node) = persistence.get_graph_node(node_id)? {
561                            result.push((node, in_d, out_d, total));
562                        }
563                    }
564
565                    Ok::<_, anyhow::Error>(result)
566                })
567                .await
568                .context("task join error")??;
569
570                let hubs_json: Vec<Value> = hubs
571                    .into_iter()
572                    .map(|(node, in_d, out_d, total)| {
573                        json!({
574                            "node": node,
575                            "in_degree": in_d,
576                            "out_degree": out_d,
577                            "total_degree": total
578                        })
579                    })
580                    .collect();
581
582                let direction_str = match direction {
583                    TraversalDirection::Incoming => "incoming",
584                    TraversalDirection::Outgoing => "outgoing",
585                    TraversalDirection::Both => "both",
586                };
587
588                Ok(ToolResult::success(
589                    json!({
590                        "direction": direction_str,
591                        "min_degree": min_degree,
592                        "count": hubs_json.len(),
593                        "hubs": hubs_json
594                    })
595                    .to_string(),
596                ))
597            }
598
599            "enable_sync" => {
600                let graph_name = args["graph_name"].as_str().unwrap_or("default");
601                let graph_name = graph_name.to_string();
602                let graph_name_display = graph_name.clone();
603                let session_id = session_id.to_string();
604
605                tokio::task::spawn_blocking(move || {
606                    persistence.graph_set_sync_enabled(&session_id, &graph_name, true)
607                })
608                .await
609                .context("task join error")??;
610
611                Ok(ToolResult::success(
612                    json!({
613                        "message": format!("Sync enabled for graph '{}'", graph_name_display),
614                        "graph_name": graph_name_display,
615                        "sync_enabled": true
616                    })
617                    .to_string(),
618                ))
619            }
620
621            "disable_sync" => {
622                let graph_name = args["graph_name"].as_str().unwrap_or("default");
623                let graph_name = graph_name.to_string();
624                let graph_name_display = graph_name.clone();
625                let session_id = session_id.to_string();
626
627                tokio::task::spawn_blocking(move || {
628                    persistence.graph_set_sync_enabled(&session_id, &graph_name, false)
629                })
630                .await
631                .context("task join error")??;
632
633                Ok(ToolResult::success(
634                    json!({
635                        "message": format!("Sync disabled for graph '{}'", graph_name_display),
636                        "graph_name": graph_name_display,
637                        "sync_enabled": false
638                    })
639                    .to_string(),
640                ))
641            }
642
643            "sync_status" => {
644                let graph_name = args["graph_name"].as_str().unwrap_or("default");
645                let graph_name = graph_name.to_string();
646                let graph_name_display = graph_name.clone();
647                let session_id = session_id.to_string();
648                let instance_id = persistence.instance_id().to_string();
649
650                let result = tokio::task::spawn_blocking(move || {
651                    let sync_enabled =
652                        persistence.graph_get_sync_enabled(&session_id, &graph_name)?;
653                    let vector_clock =
654                        persistence.graph_sync_state_get(&instance_id, &session_id, &graph_name)?;
655
656                    // Count recent changes
657                    let since = chrono::Utc::now()
658                        .checked_sub_signed(chrono::Duration::hours(1))
659                        .unwrap()
660                        .to_rfc3339();
661                    let changes = persistence.graph_changelog_get_since(&session_id, &since)?;
662
663                    Ok::<_, anyhow::Error>((sync_enabled, vector_clock, changes.len()))
664                })
665                .await
666                .context("task join error")??;
667
668                Ok(ToolResult::success(
669                    json!({
670                        "graph_name": graph_name_display,
671                        "sync_enabled": result.0,
672                        "vector_clock": result.1.unwrap_or_else(|| "{}".to_string()),
673                        "pending_changes": result.2,
674                    })
675                    .to_string(),
676                ))
677            }
678
679            #[cfg(feature = "api")]
680            "force_sync" => {
681                let graph_name = args["graph_name"].as_str().unwrap_or("default");
682                let peer_instance_id = args["peer_instance_id"]
683                    .as_str()
684                    .context("peer_instance_id is required for force_sync")?;
685
686                let graph_name = graph_name.to_string();
687                let graph_name_display = graph_name.clone();
688                let session_id = session_id.to_string();
689                let peer_instance_id = peer_instance_id.to_string();
690                let peer_display = peer_instance_id.clone();
691                let instance_id = persistence.instance_id().to_string();
692
693                // Create sync engine and trigger sync
694                let sync_engine = crate::sync::SyncEngine::new((*persistence).clone(), instance_id);
695
696                let result = sync_engine.sync_full(&session_id, &graph_name).await?;
697
698                Ok(ToolResult::success(
699                    json!({
700                        "message": format!("Sync initiated with peer {}", peer_display),
701                        "graph_name": graph_name_display,
702                        "nodes_synced": result.nodes.len(),
703                        "edges_synced": result.edges.len(),
704                    })
705                    .to_string(),
706                ))
707            }
708
709            #[cfg(not(feature = "api"))]
710            "force_sync" => Ok(ToolResult::failure(
711                "force_sync requires the 'api' feature to be enabled".to_string(),
712            )),
713
714            "list_sync_configs" => {
715                let session_id = session_id.to_string();
716
717                let result = tokio::task::spawn_blocking(move || {
718                    let graphs = persistence.graph_list(&session_id)?;
719                    let mut configs = Vec::new();
720                    for graph_name in graphs {
721                        let sync_enabled =
722                            persistence.graph_get_sync_enabled(&session_id, &graph_name)?;
723                        configs.push(json!({
724                            "graph_name": graph_name,
725                            "sync_enabled": sync_enabled,
726                        }));
727                    }
728                    Ok::<_, anyhow::Error>((session_id.clone(), configs))
729                })
730                .await
731                .context("task join error")??;
732
733                Ok(ToolResult::success(
734                    json!({
735                        "session_id": result.0,
736                        "graphs": result.1,
737                    })
738                    .to_string(),
739                ))
740            }
741
742            _ => Ok(ToolResult::failure(format!(
743                "Unknown operation: {}",
744                operation
745            ))),
746        }
747    }
748}