1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use serde_json::{json, Map, Value};
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::persistence::Persistence;
8use crate::tools::{Tool, ToolResult};
9use crate::types::{EdgeType, NodeType, TraversalDirection};
10
11pub struct GraphTool {
12 persistence: Arc<Persistence>,
13}
14
15impl GraphTool {
16 pub fn new(persistence: Arc<Persistence>) -> Self {
17 Self { persistence }
18 }
19}
20
21#[async_trait]
22impl Tool for GraphTool {
23 fn name(&self) -> &str {
24 "graph"
25 }
26
27 fn description(&self) -> &str {
28 "Create, query, traverse, and synchronize knowledge graphs. Supports operations: \
29 create_node, create_edge, delete_node, delete_edge, get_node, get_edge, \
30 list_nodes, list_edges, find_path, traverse_neighbors, update_node, \
31 node_degree, list_hubs, enable_sync, disable_sync, sync_status, force_sync, \
32 list_sync_configs"
33 }
34
35 fn parameters(&self) -> Value {
36 json!({
37 "type": "object",
38 "properties": {
39 "operation": {
40 "type": "string",
41 "enum": [
42 "create_node", "create_edge", "delete_node", "delete_edge",
43 "get_node", "get_edge", "list_nodes", "list_edges",
44 "find_path", "traverse_neighbors", "update_node",
45 "node_degree", "list_hubs",
46 "enable_sync", "disable_sync", "sync_status", "force_sync",
47 "list_sync_configs"
48 ],
49 "description": "The graph operation to perform"
50 },
51 "session_id": {
52 "type": "string",
53 "description": "Session ID for graph isolation"
54 },
55 "node_id": {
56 "type": "integer",
57 "description": "Node ID (for get_node, delete_node, update_node, traverse_neighbors)"
58 },
59 "edge_id": {
60 "type": "integer",
61 "description": "Edge ID (for get_edge, delete_edge)"
62 },
63 "node_type": {
64 "type": "string",
65 "enum": ["entity", "concept", "fact", "message", "tool_result", "event"],
66 "description": "Type of node to create or filter by"
67 },
68 "label": {
69 "type": "string",
70 "description": "Semantic label for the node (e.g., 'Person', 'Location', 'Action')"
71 },
72 "properties": {
73 "type": "object",
74 "description": "JSON properties for the node or edge"
75 },
76 "source_id": {
77 "type": "integer",
78 "description": "Source node ID for edge creation or path finding"
79 },
80 "target_id": {
81 "type": "integer",
82 "description": "Target node ID for edge creation or path finding"
83 },
84 "edge_type": {
85 "type": "string",
86 "enum": [
87 "RELATES_TO", "CAUSED_BY", "PART_OF", "MENTIONS",
88 "FOLLOWS_FROM", "USES", "PRODUCES", "DEPENDS_ON"
89 ],
90 "description": "Type of edge relationship"
91 },
92 "custom_edge_type": {
93 "type": "string",
94 "description": "Custom edge type if not using predefined types"
95 },
96 "predicate": {
97 "type": "string",
98 "description": "RDF-style predicate for the edge"
99 },
100 "weight": {
101 "type": "number",
102 "default": 1.0,
103 "description": "Weight for the edge"
104 },
105 "direction": {
106 "type": "string",
107 "enum": ["outgoing", "incoming", "both"],
108 "default": "outgoing",
109 "description": "Direction for traversal and degree-based operations"
110 },
111 "depth": {
112 "type": "integer",
113 "default": 1,
114 "minimum": 1,
115 "maximum": 10,
116 "description": "Depth for traversal operations"
117 },
118 "max_hops": {
119 "type": "integer",
120 "default": 10,
121 "minimum": 1,
122 "maximum": 20,
123 "description": "Maximum hops for path finding"
124 },
125 "limit": {
126 "type": "integer",
127 "default": 100,
128 "minimum": 1,
129 "maximum": 1000,
130 "description": "Limit for list operations"
131 },
132 "min_degree": {
133 "type": "integer",
134 "default": 1,
135 "minimum": 0,
136 "description": "Minimum degree threshold when listing hubs"
137 },
138 "graph_name": {
139 "type": "string",
140 "default": "default",
141 "description": "Graph name for sync operations"
142 },
143 "peer_instance_id": {
144 "type": "string",
145 "description": "Peer instance ID to sync with (for force_sync)"
146 },
147 "sync_enabled": {
148 "type": "boolean",
149 "description": "Enable or disable sync (for enable_sync/disable_sync)"
150 }
151 },
152 "required": ["operation", "session_id"]
153 })
154 }
155
156 async fn execute(&self, args: Value) -> Result<ToolResult> {
157 let operation = args["operation"]
158 .as_str()
159 .context("operation must be a string")?;
160
161 let session_id = args["session_id"]
162 .as_str()
163 .context("session_id must be a string")?;
164
165 let persistence = Arc::clone(&self.persistence);
167
168 match operation {
169 "create_node" => {
170 let node_type = args["node_type"]
171 .as_str()
172 .context("node_type is required for create_node")?;
173 let label = args["label"]
174 .as_str()
175 .context("label is required for create_node")?;
176 let properties = args["properties"].clone();
177
178 let node_type = NodeType::from_str(node_type);
179 let session_id = session_id.to_string();
180 let label = label.to_string();
181
182 let result = tokio::task::spawn_blocking(move || {
183 persistence.insert_graph_node(&session_id, node_type, &label, &properties, None)
184 })
185 .await
186 .context("task join error")??;
187
188 Ok(ToolResult::success(
189 json!({
190 "node_id": result,
191 "message": format!("Created node with ID {}", result)
192 })
193 .to_string(),
194 ))
195 }
196
197 "create_edge" => {
198 let source_id = args["source_id"]
199 .as_i64()
200 .context("source_id is required for create_edge")?;
201 let target_id = args["target_id"]
202 .as_i64()
203 .context("target_id is required for create_edge")?;
204
205 let edge_type = if let Some(custom) = args["custom_edge_type"].as_str() {
206 EdgeType::Custom(custom.to_string())
207 } else if let Some(et) = args["edge_type"].as_str() {
208 EdgeType::from_str(et)
209 } else {
210 EdgeType::RelatesTo
211 };
212
213 let predicate = args["predicate"].as_str().map(|s| s.to_string());
214 let properties = if args["properties"].is_null() {
215 None
216 } else {
217 Some(args["properties"].clone())
218 };
219 let weight = args["weight"].as_f64().unwrap_or(1.0) as f32;
220 let session_id = session_id.to_string();
221
222 let result = tokio::task::spawn_blocking(move || {
223 persistence.insert_graph_edge(
224 &session_id,
225 source_id,
226 target_id,
227 edge_type,
228 predicate.as_deref(),
229 properties.as_ref(),
230 weight,
231 )
232 })
233 .await
234 .context("task join error")??;
235
236 Ok(ToolResult::success(
237 json!({
238 "edge_id": result,
239 "message": format!("Created edge with ID {}", result)
240 })
241 .to_string(),
242 ))
243 }
244
245 "get_node" => {
246 let node_id = args["node_id"]
247 .as_i64()
248 .context("node_id is required for get_node")?;
249
250 let result =
251 tokio::task::spawn_blocking(move || persistence.get_graph_node(node_id))
252 .await
253 .context("task join error")??;
254
255 match result {
256 Some(node) => Ok(ToolResult::success(serde_json::to_string_pretty(&node)?)),
257 None => Ok(ToolResult::failure(format!("Node {} not found", node_id))),
258 }
259 }
260
261 "get_edge" => {
262 let edge_id = args["edge_id"]
263 .as_i64()
264 .context("edge_id is required for get_edge")?;
265
266 let result =
267 tokio::task::spawn_blocking(move || persistence.get_graph_edge(edge_id))
268 .await
269 .context("task join error")??;
270
271 match result {
272 Some(edge) => Ok(ToolResult::success(serde_json::to_string_pretty(&edge)?)),
273 None => Ok(ToolResult::failure(format!("Edge {} not found", edge_id))),
274 }
275 }
276
277 "list_nodes" => {
278 let node_type = args["node_type"].as_str().map(NodeType::from_str);
279 let limit = args["limit"].as_i64().or(Some(100));
280 let session_id = session_id.to_string();
281
282 let result = tokio::task::spawn_blocking(move || {
283 persistence.list_graph_nodes(&session_id, node_type, limit)
284 })
285 .await
286 .context("task join error")??;
287
288 Ok(ToolResult::success(
289 json!({
290 "count": result.len(),
291 "nodes": result
292 })
293 .to_string(),
294 ))
295 }
296
297 "list_edges" => {
298 let source_id = args["source_id"].as_i64();
299 let target_id = args["target_id"].as_i64();
300 let session_id = session_id.to_string();
301
302 let result = tokio::task::spawn_blocking(move || {
303 persistence.list_graph_edges(&session_id, source_id, target_id)
304 })
305 .await
306 .context("task join error")??;
307
308 Ok(ToolResult::success(
309 json!({
310 "count": result.len(),
311 "edges": result
312 })
313 .to_string(),
314 ))
315 }
316
317 "delete_node" => {
318 let node_id = args["node_id"]
319 .as_i64()
320 .context("node_id is required for delete_node")?;
321
322 tokio::task::spawn_blocking(move || persistence.delete_graph_node(node_id))
323 .await
324 .context("task join error")??;
325
326 Ok(ToolResult::success(format!("Deleted node {}", node_id)))
327 }
328
329 "delete_edge" => {
330 let edge_id = args["edge_id"]
331 .as_i64()
332 .context("edge_id is required for delete_edge")?;
333
334 tokio::task::spawn_blocking(move || persistence.delete_graph_edge(edge_id))
335 .await
336 .context("task join error")??;
337
338 Ok(ToolResult::success(format!("Deleted edge {}", edge_id)))
339 }
340
341 "update_node" => {
342 let node_id = args["node_id"]
343 .as_i64()
344 .context("node_id is required for update_node")?;
345 let properties = args["properties"].clone();
346
347 tokio::task::spawn_blocking(move || {
348 persistence.update_graph_node(node_id, &properties)
349 })
350 .await
351 .context("task join error")??;
352
353 Ok(ToolResult::success(format!("Updated node {}", node_id)))
354 }
355
356 "node_degree" => {
357 let node_id = args["node_id"]
358 .as_i64()
359 .context("node_id is required for node_degree")?;
360 let edge_type_filter = args["edge_type"].as_str().map(EdgeType::from_str);
361 let session_id = session_id.to_string();
362
363 let (in_degree, out_degree, by_type) = tokio::task::spawn_blocking(move || {
364 let edges = persistence.list_graph_edges(&session_id, None, None)?;
365 let mut in_degree: i64 = 0;
366 let mut out_degree: i64 = 0;
367 let mut by_type: HashMap<String, (i64, i64)> = HashMap::new();
368
369 for edge in edges {
370 if let Some(ref filter) = edge_type_filter {
371 if &edge.edge_type != filter {
372 continue;
373 }
374 }
375
376 let key = edge.edge_type.as_str();
377
378 if edge.source_id == node_id {
379 out_degree += 1;
380 let entry = by_type.entry(key.clone()).or_insert((0, 0));
381 entry.1 += 1;
382 }
383 if edge.target_id == node_id {
384 in_degree += 1;
385 let entry = by_type.entry(key.clone()).or_insert((0, 0));
386 entry.0 += 1;
387 }
388 }
389
390 Ok::<_, anyhow::Error>((in_degree, out_degree, by_type))
391 })
392 .await
393 .context("task join error")??;
394
395 let total_degree = in_degree + out_degree;
396
397 let mut by_type_json = Map::new();
398 for (edge_type, (in_d, out_d)) in by_type {
399 by_type_json.insert(
400 edge_type,
401 json!({
402 "in_degree": in_d,
403 "out_degree": out_d,
404 "total_degree": in_d + out_d
405 }),
406 );
407 }
408
409 Ok(ToolResult::success(
410 json!({
411 "node_id": node_id,
412 "in_degree": in_degree,
413 "out_degree": out_degree,
414 "total_degree": total_degree,
415 "by_edge_type": by_type_json
416 })
417 .to_string(),
418 ))
419 }
420
421 "find_path" => {
422 let source_id = args["source_id"]
423 .as_i64()
424 .context("source_id is required for find_path")?;
425 let target_id = args["target_id"]
426 .as_i64()
427 .context("target_id is required for find_path")?;
428 let max_hops = args["max_hops"].as_u64().map(|h| h as usize);
429 let session_id = session_id.to_string();
430
431 let result = tokio::task::spawn_blocking(move || {
432 persistence.find_shortest_path(&session_id, source_id, target_id, max_hops)
433 })
434 .await
435 .context("task join error")??;
436
437 match result {
438 Some(path) => Ok(ToolResult::success(
439 json!({
440 "found": true,
441 "length": path.length,
442 "total_weight": path.weight,
443 "path": path
444 })
445 .to_string(),
446 )),
447 None => Ok(ToolResult::success(
448 json!({
449 "found": false,
450 "message": format!("No path found from {} to {}", source_id, target_id)
451 })
452 .to_string(),
453 )),
454 }
455 }
456
457 "traverse_neighbors" => {
458 let node_id = args["node_id"]
459 .as_i64()
460 .context("node_id is required for traverse_neighbors")?;
461 let depth = args["depth"].as_u64().unwrap_or(1) as usize;
462 let direction = args["direction"]
463 .as_str()
464 .map(|d| match d {
465 "incoming" => TraversalDirection::Incoming,
466 "both" => TraversalDirection::Both,
467 _ => TraversalDirection::Outgoing,
468 })
469 .unwrap_or(TraversalDirection::Outgoing);
470 let session_id = session_id.to_string();
471
472 let result = tokio::task::spawn_blocking(move || {
473 persistence.traverse_neighbors(&session_id, node_id, direction, depth)
474 })
475 .await
476 .context("task join error")??;
477
478 Ok(ToolResult::success(
479 json!({
480 "count": result.len(),
481 "neighbors": result
482 })
483 .to_string(),
484 ))
485 }
486
487 "list_hubs" => {
488 let direction = args["direction"]
489 .as_str()
490 .map(|d| match d {
491 "incoming" => TraversalDirection::Incoming,
492 "both" => TraversalDirection::Both,
493 _ => TraversalDirection::Outgoing,
494 })
495 .unwrap_or(TraversalDirection::Outgoing);
496 let min_degree = args["min_degree"].as_i64().unwrap_or(1).max(0);
497 let limit = args["limit"].as_i64().unwrap_or(10).max(1);
498 let edge_type_filter = args["edge_type"].as_str().map(EdgeType::from_str);
499 let session_id = session_id.to_string();
500
501 let hubs = tokio::task::spawn_blocking(move || {
502 let edges = persistence.list_graph_edges(&session_id, None, None)?;
503 let mut degrees: HashMap<i64, (i64, i64)> = HashMap::new();
504
505 for edge in edges {
506 if let Some(ref filter) = edge_type_filter {
507 if &edge.edge_type != filter {
508 continue;
509 }
510 }
511
512 {
514 let entry = degrees.entry(edge.source_id).or_insert((0, 0));
515 entry.1 += 1;
516 }
517 {
519 let entry = degrees.entry(edge.target_id).or_insert((0, 0));
520 entry.0 += 1;
521 }
522 }
523
524 let mut nodes_with_degree: Vec<(i64, i64, i64, i64)> = degrees
526 .into_iter()
527 .map(|(node_id, (in_d, out_d))| {
528 let total = in_d + out_d;
529 (node_id, in_d, out_d, total)
530 })
531 .filter(|(_, in_d, out_d, total)| {
532 let score = match direction {
533 TraversalDirection::Incoming => *in_d,
534 TraversalDirection::Outgoing => *out_d,
535 TraversalDirection::Both => *total,
536 };
537 score >= min_degree
538 })
539 .collect();
540
541 nodes_with_degree.sort_by(|a, b| {
542 let score_a = match direction {
543 TraversalDirection::Incoming => a.1,
544 TraversalDirection::Outgoing => a.2,
545 TraversalDirection::Both => a.3,
546 };
547 let score_b = match direction {
548 TraversalDirection::Incoming => b.1,
549 TraversalDirection::Outgoing => b.2,
550 TraversalDirection::Both => b.3,
551 };
552 score_b.cmp(&score_a).then_with(|| a.0.cmp(&b.0))
553 });
554
555 nodes_with_degree.truncate(limit as usize);
556
557 let mut result = Vec::new();
559 for (node_id, in_d, out_d, total) in nodes_with_degree {
560 if let Some(node) = persistence.get_graph_node(node_id)? {
561 result.push((node, in_d, out_d, total));
562 }
563 }
564
565 Ok::<_, anyhow::Error>(result)
566 })
567 .await
568 .context("task join error")??;
569
570 let hubs_json: Vec<Value> = hubs
571 .into_iter()
572 .map(|(node, in_d, out_d, total)| {
573 json!({
574 "node": node,
575 "in_degree": in_d,
576 "out_degree": out_d,
577 "total_degree": total
578 })
579 })
580 .collect();
581
582 let direction_str = match direction {
583 TraversalDirection::Incoming => "incoming",
584 TraversalDirection::Outgoing => "outgoing",
585 TraversalDirection::Both => "both",
586 };
587
588 Ok(ToolResult::success(
589 json!({
590 "direction": direction_str,
591 "min_degree": min_degree,
592 "count": hubs_json.len(),
593 "hubs": hubs_json
594 })
595 .to_string(),
596 ))
597 }
598
599 "enable_sync" => {
600 let graph_name = args["graph_name"].as_str().unwrap_or("default");
601 let graph_name = graph_name.to_string();
602 let graph_name_display = graph_name.clone();
603 let session_id = session_id.to_string();
604
605 tokio::task::spawn_blocking(move || {
606 persistence.graph_set_sync_enabled(&session_id, &graph_name, true)
607 })
608 .await
609 .context("task join error")??;
610
611 Ok(ToolResult::success(
612 json!({
613 "message": format!("Sync enabled for graph '{}'", graph_name_display),
614 "graph_name": graph_name_display,
615 "sync_enabled": true
616 })
617 .to_string(),
618 ))
619 }
620
621 "disable_sync" => {
622 let graph_name = args["graph_name"].as_str().unwrap_or("default");
623 let graph_name = graph_name.to_string();
624 let graph_name_display = graph_name.clone();
625 let session_id = session_id.to_string();
626
627 tokio::task::spawn_blocking(move || {
628 persistence.graph_set_sync_enabled(&session_id, &graph_name, false)
629 })
630 .await
631 .context("task join error")??;
632
633 Ok(ToolResult::success(
634 json!({
635 "message": format!("Sync disabled for graph '{}'", graph_name_display),
636 "graph_name": graph_name_display,
637 "sync_enabled": false
638 })
639 .to_string(),
640 ))
641 }
642
643 "sync_status" => {
644 let graph_name = args["graph_name"].as_str().unwrap_or("default");
645 let graph_name = graph_name.to_string();
646 let graph_name_display = graph_name.clone();
647 let session_id = session_id.to_string();
648 let instance_id = persistence.instance_id().to_string();
649
650 let result = tokio::task::spawn_blocking(move || {
651 let sync_enabled =
652 persistence.graph_get_sync_enabled(&session_id, &graph_name)?;
653 let vector_clock =
654 persistence.graph_sync_state_get(&instance_id, &session_id, &graph_name)?;
655
656 let since = chrono::Utc::now()
658 .checked_sub_signed(chrono::Duration::hours(1))
659 .unwrap()
660 .to_rfc3339();
661 let changes = persistence.graph_changelog_get_since(&session_id, &since)?;
662
663 Ok::<_, anyhow::Error>((sync_enabled, vector_clock, changes.len()))
664 })
665 .await
666 .context("task join error")??;
667
668 Ok(ToolResult::success(
669 json!({
670 "graph_name": graph_name_display,
671 "sync_enabled": result.0,
672 "vector_clock": result.1.unwrap_or_else(|| "{}".to_string()),
673 "pending_changes": result.2,
674 })
675 .to_string(),
676 ))
677 }
678
679 #[cfg(feature = "api")]
680 "force_sync" => {
681 let graph_name = args["graph_name"].as_str().unwrap_or("default");
682 let peer_instance_id = args["peer_instance_id"]
683 .as_str()
684 .context("peer_instance_id is required for force_sync")?;
685
686 let graph_name = graph_name.to_string();
687 let graph_name_display = graph_name.clone();
688 let session_id = session_id.to_string();
689 let peer_instance_id = peer_instance_id.to_string();
690 let peer_display = peer_instance_id.clone();
691 let instance_id = persistence.instance_id().to_string();
692
693 let sync_engine = crate::sync::SyncEngine::new((*persistence).clone(), instance_id);
695
696 let result = sync_engine.sync_full(&session_id, &graph_name).await?;
697
698 Ok(ToolResult::success(
699 json!({
700 "message": format!("Sync initiated with peer {}", peer_display),
701 "graph_name": graph_name_display,
702 "nodes_synced": result.nodes.len(),
703 "edges_synced": result.edges.len(),
704 })
705 .to_string(),
706 ))
707 }
708
709 #[cfg(not(feature = "api"))]
710 "force_sync" => Ok(ToolResult::failure(
711 "force_sync requires the 'api' feature to be enabled".to_string(),
712 )),
713
714 "list_sync_configs" => {
715 let session_id = session_id.to_string();
716
717 let result = tokio::task::spawn_blocking(move || {
718 let graphs = persistence.graph_list(&session_id)?;
719 let mut configs = Vec::new();
720 for graph_name in graphs {
721 let sync_enabled =
722 persistence.graph_get_sync_enabled(&session_id, &graph_name)?;
723 configs.push(json!({
724 "graph_name": graph_name,
725 "sync_enabled": sync_enabled,
726 }));
727 }
728 Ok::<_, anyhow::Error>((session_id.clone(), configs))
729 })
730 .await
731 .context("task join error")??;
732
733 Ok(ToolResult::success(
734 json!({
735 "session_id": result.0,
736 "graphs": result.1,
737 })
738 .to_string(),
739 ))
740 }
741
742 _ => Ok(ToolResult::failure(format!(
743 "Unknown operation: {}",
744 operation
745 ))),
746 }
747 }
748}