1use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5use rustant_core::error::ToolError;
6use rustant_core::types::{RiskLevel, ToolOutput};
7use serde::{Deserialize, Serialize};
8use serde_json::{Value, json};
9use std::collections::{HashMap, HashSet, VecDeque};
10use std::path::PathBuf;
11use std::time::Duration;
12
13use crate::registry::Tool;
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16enum NodeType {
17 Paper,
18 Concept,
19 Method,
20 Dataset,
21 Person,
22 Organization,
23}
24
25impl NodeType {
26 fn from_str_loose(s: &str) -> Option<Self> {
27 match s.to_lowercase().as_str() {
28 "paper" => Some(Self::Paper),
29 "concept" => Some(Self::Concept),
30 "method" => Some(Self::Method),
31 "dataset" => Some(Self::Dataset),
32 "person" => Some(Self::Person),
33 "organization" => Some(Self::Organization),
34 _ => None,
35 }
36 }
37
38 fn as_str(&self) -> &str {
39 match self {
40 Self::Paper => "Paper",
41 Self::Concept => "Concept",
42 Self::Method => "Method",
43 Self::Dataset => "Dataset",
44 Self::Person => "Person",
45 Self::Organization => "Organization",
46 }
47 }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
51enum RelationshipType {
52 Cites,
53 Implements,
54 Extends,
55 Contradicts,
56 BuildsOn,
57 AuthoredBy,
58 UsesDataset,
59 RelatedTo,
60}
61
62impl RelationshipType {
63 fn from_str_loose(s: &str) -> Option<Self> {
64 match s.to_lowercase().as_str() {
65 "cites" => Some(Self::Cites),
66 "implements" => Some(Self::Implements),
67 "extends" => Some(Self::Extends),
68 "contradicts" => Some(Self::Contradicts),
69 "builds_on" => Some(Self::BuildsOn),
70 "authored_by" => Some(Self::AuthoredBy),
71 "uses_dataset" => Some(Self::UsesDataset),
72 "related_to" => Some(Self::RelatedTo),
73 _ => None,
74 }
75 }
76
77 fn as_str(&self) -> &str {
78 match self {
79 Self::Cites => "Cites",
80 Self::Implements => "Implements",
81 Self::Extends => "Extends",
82 Self::Contradicts => "Contradicts",
83 Self::BuildsOn => "BuildsOn",
84 Self::AuthoredBy => "AuthoredBy",
85 Self::UsesDataset => "UsesDataset",
86 Self::RelatedTo => "RelatedTo",
87 }
88 }
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92struct GraphNode {
93 id: String,
94 name: String,
95 node_type: NodeType,
96 description: String,
97 tags: Vec<String>,
98 metadata: HashMap<String, String>,
99 created_at: DateTime<Utc>,
100 updated_at: DateTime<Utc>,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104struct Edge {
105 source_id: String,
106 target_id: String,
107 relationship_type: RelationshipType,
108 strength: f64,
109 notes: String,
110 created_at: DateTime<Utc>,
111}
112
113#[derive(Debug, Default, Serialize, Deserialize)]
114struct KnowledgeGraphState {
115 nodes: Vec<GraphNode>,
116 edges: Vec<Edge>,
117 next_auto_id: usize,
118}
119
120#[derive(Debug, Deserialize)]
122struct ArxivLibraryFile {
123 entries: Vec<ArxivLibraryEntry>,
124}
125
126#[derive(Debug, Deserialize)]
127struct ArxivLibraryEntry {
128 paper: ArxivPaperRef,
129 #[serde(default)]
130 tags: Vec<String>,
131}
132
133#[derive(Debug, Deserialize)]
134struct ArxivPaperRef {
135 #[serde(alias = "arxiv_id")]
136 id: String,
137 title: String,
138 authors: Vec<String>,
139 #[serde(alias = "summary")]
140 abstract_text: String,
141}
142
/// Tool that maintains a local knowledge graph stored under a workspace directory.
pub struct KnowledgeGraphTool {
    /// Root directory under which the `.rustant` state folders live.
    workspace: PathBuf,
}
146
147impl KnowledgeGraphTool {
148 pub fn new(workspace: PathBuf) -> Self {
149 Self { workspace }
150 }
151
152 fn state_path(&self) -> PathBuf {
153 self.workspace
154 .join(".rustant")
155 .join("knowledge")
156 .join("graph.json")
157 }
158
159 fn load_state(&self) -> KnowledgeGraphState {
160 let path = self.state_path();
161 if path.exists() {
162 std::fs::read_to_string(&path)
163 .ok()
164 .and_then(|s| serde_json::from_str(&s).ok())
165 .unwrap_or_default()
166 } else {
167 KnowledgeGraphState::default()
168 }
169 }
170
171 fn save_state(&self, state: &KnowledgeGraphState) -> Result<(), ToolError> {
172 let path = self.state_path();
173 if let Some(parent) = path.parent() {
174 std::fs::create_dir_all(parent).map_err(|e| ToolError::ExecutionFailed {
175 name: "knowledge_graph".to_string(),
176 message: format!("Failed to create state dir: {}", e),
177 })?;
178 }
179 let json = serde_json::to_string_pretty(state).map_err(|e| ToolError::ExecutionFailed {
180 name: "knowledge_graph".to_string(),
181 message: format!("Failed to serialize state: {}", e),
182 })?;
183 let tmp = path.with_extension("json.tmp");
184 std::fs::write(&tmp, &json).map_err(|e| ToolError::ExecutionFailed {
185 name: "knowledge_graph".to_string(),
186 message: format!("Failed to write state: {}", e),
187 })?;
188 std::fs::rename(&tmp, &path).map_err(|e| ToolError::ExecutionFailed {
189 name: "knowledge_graph".to_string(),
190 message: format!("Failed to rename state file: {}", e),
191 })?;
192 Ok(())
193 }
194
195 fn slug(name: &str) -> String {
196 name.to_lowercase().replace(' ', "-")
197 }
198
199 fn arxiv_library_path(&self) -> PathBuf {
200 self.workspace
201 .join(".rustant")
202 .join("arxiv")
203 .join("library.json")
204 }
205}
206
207#[async_trait]
208impl Tool for KnowledgeGraphTool {
209 fn name(&self) -> &str {
210 "knowledge_graph"
211 }
212
213 fn description(&self) -> &str {
214 "Local knowledge graph of concepts, papers, methods, and relationships. Actions: add_node, get_node, update_node, remove_node, add_edge, remove_edge, neighbors, search, list, path, stats, import_arxiv, export_dot."
215 }
216
217 fn parameters_schema(&self) -> Value {
218 json!({
219 "type": "object",
220 "properties": {
221 "action": {
222 "type": "string",
223 "enum": [
224 "add_node", "get_node", "update_node", "remove_node",
225 "add_edge", "remove_edge", "neighbors", "search",
226 "list", "path", "stats", "import_arxiv", "export_dot"
227 ],
228 "description": "Action to perform"
229 },
230 "id": { "type": "string", "description": "Node ID" },
231 "name": { "type": "string", "description": "Node name" },
232 "node_type": {
233 "type": "string",
234 "enum": ["paper", "concept", "method", "dataset", "person", "organization"],
235 "description": "Type of node"
236 },
237 "description": { "type": "string", "description": "Node description" },
238 "tags": {
239 "type": "array",
240 "items": { "type": "string" },
241 "description": "Tags for the node"
242 },
243 "metadata": {
244 "type": "object",
245 "additionalProperties": { "type": "string" },
246 "description": "Key-value metadata for the node"
247 },
248 "source_id": { "type": "string", "description": "Source node ID for edge" },
249 "target_id": { "type": "string", "description": "Target node ID for edge" },
250 "relationship_type": {
251 "type": "string",
252 "enum": ["cites", "implements", "extends", "contradicts", "builds_on", "authored_by", "uses_dataset", "related_to"],
253 "description": "Type of relationship"
254 },
255 "strength": {
256 "type": "number",
257 "description": "Edge strength 0.0-1.0 (default 0.5)"
258 },
259 "notes": { "type": "string", "description": "Edge notes" },
260 "depth": {
261 "type": "integer",
262 "description": "Traversal depth for neighbors (1-3, default 1)"
263 },
264 "query": { "type": "string", "description": "Search query" },
265 "tag": { "type": "string", "description": "Filter by tag" },
266 "from_id": { "type": "string", "description": "Path start node ID" },
267 "to_id": { "type": "string", "description": "Path end node ID" },
268 "arxiv_id": { "type": "string", "description": "ArXiv paper ID for import" },
269 "filter_type": {
270 "type": "string",
271 "enum": ["paper", "concept", "method", "dataset", "person", "organization"],
272 "description": "Filter by node type"
273 }
274 },
275 "required": ["action"]
276 })
277 }
278
279 fn risk_level(&self) -> RiskLevel {
280 RiskLevel::Write
281 }
282
283 fn timeout(&self) -> Duration {
284 Duration::from_secs(30)
285 }
286
287 async fn execute(&self, args: Value) -> Result<ToolOutput, ToolError> {
288 let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("");
289 let mut state = self.load_state();
290
291 match action {
292 "add_node" => {
293 let name = args
294 .get("name")
295 .and_then(|v| v.as_str())
296 .unwrap_or("")
297 .trim();
298 if name.is_empty() {
299 return Ok(ToolOutput::text("Missing required parameter 'name'."));
300 }
301
302 let node_type_str = args.get("node_type").and_then(|v| v.as_str()).unwrap_or("");
303 let node_type = match NodeType::from_str_loose(node_type_str) {
304 Some(nt) => nt,
305 None => {
306 return Ok(ToolOutput::text(format!(
307 "Invalid node_type '{}'. Use: paper, concept, method, dataset, person, organization.",
308 node_type_str
309 )));
310 }
311 };
312
313 let id = args
314 .get("id")
315 .and_then(|v| v.as_str())
316 .map(|s| s.to_string())
317 .unwrap_or_else(|| Self::slug(name));
318
319 if state.nodes.iter().any(|n| n.id == id) {
321 return Ok(ToolOutput::text(format!(
322 "Node with id '{}' already exists.",
323 id
324 )));
325 }
326
327 let description = args
328 .get("description")
329 .and_then(|v| v.as_str())
330 .unwrap_or("")
331 .to_string();
332
333 let tags: Vec<String> = args
334 .get("tags")
335 .and_then(|v| v.as_array())
336 .map(|arr| {
337 arr.iter()
338 .filter_map(|v| v.as_str().map(|s| s.to_string()))
339 .collect()
340 })
341 .unwrap_or_default();
342
343 let metadata: HashMap<String, String> = args
344 .get("metadata")
345 .and_then(|v| v.as_object())
346 .map(|obj| {
347 obj.iter()
348 .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
349 .collect()
350 })
351 .unwrap_or_default();
352
353 let now = Utc::now();
354 state.nodes.push(GraphNode {
355 id: id.clone(),
356 name: name.to_string(),
357 node_type,
358 description,
359 tags,
360 metadata,
361 created_at: now,
362 updated_at: now,
363 });
364 state.next_auto_id += 1;
365 self.save_state(&state)?;
366
367 Ok(ToolOutput::text(format!(
368 "Added node '{}' (id: {}).",
369 name, id
370 )))
371 }
372
373 "get_node" => {
374 let id = args.get("id").and_then(|v| v.as_str()).unwrap_or("");
375 if id.is_empty() {
376 return Ok(ToolOutput::text("Missing required parameter 'id'."));
377 }
378
379 let node = match state.nodes.iter().find(|n| n.id == id) {
380 Some(n) => n,
381 None => {
382 return Ok(ToolOutput::text(format!("Node '{}' not found.", id)));
383 }
384 };
385
386 let connected_edges: Vec<&Edge> = state
387 .edges
388 .iter()
389 .filter(|e| e.source_id == id || e.target_id == id)
390 .collect();
391
392 let mut output = format!(
393 "Node: {} ({})\n ID: {}\n Type: {}\n Description: {}\n Tags: {}\n Created: {}\n Updated: {}",
394 node.name,
395 node.node_type.as_str(),
396 node.id,
397 node.node_type.as_str(),
398 if node.description.is_empty() {
399 "(none)"
400 } else {
401 &node.description
402 },
403 if node.tags.is_empty() {
404 "(none)".to_string()
405 } else {
406 node.tags.join(", ")
407 },
408 node.created_at.format("%Y-%m-%d %H:%M"),
409 node.updated_at.format("%Y-%m-%d %H:%M"),
410 );
411
412 if !node.metadata.is_empty() {
413 output.push_str("\n Metadata:");
414 for (k, v) in &node.metadata {
415 output.push_str(&format!("\n {}: {}", k, v));
416 }
417 }
418
419 if !connected_edges.is_empty() {
420 output.push_str(&format!("\n Edges ({}):", connected_edges.len()));
421 for edge in &connected_edges {
422 let direction = if edge.source_id == id {
423 format!(
424 "--[{}]--> {}",
425 edge.relationship_type.as_str(),
426 edge.target_id
427 )
428 } else {
429 format!(
430 "<--[{}]-- {}",
431 edge.relationship_type.as_str(),
432 edge.source_id
433 )
434 };
435 output.push_str(&format!(
436 "\n {} (strength: {:.2})",
437 direction, edge.strength
438 ));
439 }
440 }
441
442 Ok(ToolOutput::text(output))
443 }
444
445 "update_node" => {
446 let id = args.get("id").and_then(|v| v.as_str()).unwrap_or("");
447 if id.is_empty() {
448 return Ok(ToolOutput::text("Missing required parameter 'id'."));
449 }
450
451 let node = match state.nodes.iter_mut().find(|n| n.id == id) {
452 Some(n) => n,
453 None => {
454 return Ok(ToolOutput::text(format!("Node '{}' not found.", id)));
455 }
456 };
457
458 if let Some(name) = args.get("name").and_then(|v| v.as_str()) {
459 node.name = name.to_string();
460 }
461 if let Some(desc) = args.get("description").and_then(|v| v.as_str()) {
462 node.description = desc.to_string();
463 }
464 if let Some(tags) = args.get("tags").and_then(|v| v.as_array()) {
465 node.tags = tags
466 .iter()
467 .filter_map(|v| v.as_str().map(|s| s.to_string()))
468 .collect();
469 }
470 node.updated_at = Utc::now();
471 self.save_state(&state)?;
472
473 Ok(ToolOutput::text(format!("Updated node '{}'.", id)))
474 }
475
476 "remove_node" => {
477 let id = args.get("id").and_then(|v| v.as_str()).unwrap_or("");
478 if id.is_empty() {
479 return Ok(ToolOutput::text("Missing required parameter 'id'."));
480 }
481
482 let before_nodes = state.nodes.len();
483 state.nodes.retain(|n| n.id != id);
484 if state.nodes.len() == before_nodes {
485 return Ok(ToolOutput::text(format!("Node '{}' not found.", id)));
486 }
487
488 let before_edges = state.edges.len();
490 state
491 .edges
492 .retain(|e| e.source_id != id && e.target_id != id);
493 let edges_removed = before_edges - state.edges.len();
494
495 self.save_state(&state)?;
496
497 Ok(ToolOutput::text(format!(
498 "Removed node '{}' and {} connected edge(s).",
499 id, edges_removed
500 )))
501 }
502
503 "add_edge" => {
504 let source_id = args.get("source_id").and_then(|v| v.as_str()).unwrap_or("");
505 let target_id = args.get("target_id").and_then(|v| v.as_str()).unwrap_or("");
506 let rel_str = args
507 .get("relationship_type")
508 .and_then(|v| v.as_str())
509 .unwrap_or("");
510
511 if source_id.is_empty() || target_id.is_empty() {
512 return Ok(ToolOutput::text(
513 "Missing required parameters 'source_id' and 'target_id'.",
514 ));
515 }
516
517 if !state.nodes.iter().any(|n| n.id == source_id) {
519 return Ok(ToolOutput::text(format!(
520 "Source node '{}' not found.",
521 source_id
522 )));
523 }
524 if !state.nodes.iter().any(|n| n.id == target_id) {
525 return Ok(ToolOutput::text(format!(
526 "Target node '{}' not found.",
527 target_id
528 )));
529 }
530
531 let relationship_type = match RelationshipType::from_str_loose(rel_str) {
532 Some(rt) => rt,
533 None => {
534 return Ok(ToolOutput::text(format!(
535 "Invalid relationship_type '{}'. Use: cites, implements, extends, contradicts, builds_on, authored_by, uses_dataset, related_to.",
536 rel_str
537 )));
538 }
539 };
540
541 let strength = args
542 .get("strength")
543 .and_then(|v| v.as_f64())
544 .unwrap_or(0.5)
545 .clamp(0.0, 1.0);
546
547 let notes = args
548 .get("notes")
549 .and_then(|v| v.as_str())
550 .unwrap_or("")
551 .to_string();
552
553 state.edges.push(Edge {
554 source_id: source_id.to_string(),
555 target_id: target_id.to_string(),
556 relationship_type,
557 strength,
558 notes,
559 created_at: Utc::now(),
560 });
561
562 self.save_state(&state)?;
563
564 Ok(ToolOutput::text(format!(
565 "Added edge {} --[{}]--> {} (strength: {:.2}).",
566 source_id, rel_str, target_id, strength
567 )))
568 }
569
570 "remove_edge" => {
571 let source_id = args.get("source_id").and_then(|v| v.as_str()).unwrap_or("");
572 let target_id = args.get("target_id").and_then(|v| v.as_str()).unwrap_or("");
573
574 if source_id.is_empty() || target_id.is_empty() {
575 return Ok(ToolOutput::text(
576 "Missing required parameters 'source_id' and 'target_id'.",
577 ));
578 }
579
580 let rel_filter = args
581 .get("relationship_type")
582 .and_then(|v| v.as_str())
583 .and_then(RelationshipType::from_str_loose);
584
585 let before = state.edges.len();
586 state.edges.retain(|e| {
587 if e.source_id == source_id && e.target_id == target_id {
588 if let Some(ref rt) = rel_filter {
589 return &e.relationship_type != rt;
590 }
591 return false;
592 }
593 true
594 });
595 let removed = before - state.edges.len();
596
597 if removed == 0 {
598 return Ok(ToolOutput::text("No matching edge(s) found."));
599 }
600
601 self.save_state(&state)?;
602
603 Ok(ToolOutput::text(format!(
604 "Removed {} edge(s) from '{}' to '{}'.",
605 removed, source_id, target_id
606 )))
607 }
608
609 "neighbors" => {
610 let id = args.get("id").and_then(|v| v.as_str()).unwrap_or("");
611 if id.is_empty() {
612 return Ok(ToolOutput::text("Missing required parameter 'id'."));
613 }
614
615 if !state.nodes.iter().any(|n| n.id == id) {
616 return Ok(ToolOutput::text(format!("Node '{}' not found.", id)));
617 }
618
619 let max_depth = args
620 .get("depth")
621 .and_then(|v| v.as_u64())
622 .unwrap_or(1)
623 .clamp(1, 3) as usize;
624
625 let rel_filter = args
626 .get("relationship_type")
627 .and_then(|v| v.as_str())
628 .and_then(RelationshipType::from_str_loose);
629
630 let mut visited: HashSet<String> = HashSet::new();
632 visited.insert(id.to_string());
633 let mut queue: VecDeque<(String, usize)> = VecDeque::new();
634 queue.push_back((id.to_string(), 0));
635 let mut found_nodes: Vec<(String, usize)> = Vec::new();
636
637 while let Some((current_id, depth)) = queue.pop_front() {
638 if depth >= max_depth {
639 continue;
640 }
641
642 for edge in &state.edges {
643 if let Some(ref rt) = rel_filter
644 && &edge.relationship_type != rt
645 {
646 continue;
647 }
648
649 let neighbor_id = if edge.source_id == current_id {
650 &edge.target_id
651 } else if edge.target_id == current_id {
652 &edge.source_id
653 } else {
654 continue;
655 };
656
657 if !visited.contains(neighbor_id) {
658 visited.insert(neighbor_id.clone());
659 found_nodes.push((neighbor_id.clone(), depth + 1));
660 queue.push_back((neighbor_id.clone(), depth + 1));
661 }
662 }
663 }
664
665 if found_nodes.is_empty() {
666 return Ok(ToolOutput::text(format!(
667 "No neighbors found for '{}' within depth {}.",
668 id, max_depth
669 )));
670 }
671
672 let mut output = format!("Neighbors of '{}' (depth {}):\n", id, max_depth);
673 for (nid, depth) in &found_nodes {
674 if let Some(node) = state.nodes.iter().find(|n| n.id == *nid) {
675 output.push_str(&format!(
676 " [depth {}] {} — {} ({})\n",
677 depth,
678 node.id,
679 node.name,
680 node.node_type.as_str()
681 ));
682 }
683 }
684
685 Ok(ToolOutput::text(output.trim_end().to_string()))
686 }
687
688 "search" => {
689 let query = args
690 .get("query")
691 .and_then(|v| v.as_str())
692 .unwrap_or("")
693 .to_lowercase();
694 if query.is_empty() {
695 return Ok(ToolOutput::text("Missing required parameter 'query'."));
696 }
697
698 let type_filter = args
699 .get("node_type")
700 .and_then(|v| v.as_str())
701 .and_then(NodeType::from_str_loose);
702
703 let matches: Vec<&GraphNode> = state
704 .nodes
705 .iter()
706 .filter(|n| {
707 if let Some(ref nt) = type_filter
708 && &n.node_type != nt
709 {
710 return false;
711 }
712 n.name.to_lowercase().contains(&query)
713 || n.description.to_lowercase().contains(&query)
714 || n.tags.iter().any(|t| t.to_lowercase().contains(&query))
715 })
716 .collect();
717
718 if matches.is_empty() {
719 return Ok(ToolOutput::text(format!("No nodes matching '{}'.", query)));
720 }
721
722 let mut output = format!("Found {} node(s):\n", matches.len());
723 for node in &matches {
724 output.push_str(&format!(
725 " {} — {} ({})\n",
726 node.id,
727 node.name,
728 node.node_type.as_str()
729 ));
730 }
731
732 Ok(ToolOutput::text(output.trim_end().to_string()))
733 }
734
735 "list" => {
736 let type_filter = args
737 .get("node_type")
738 .and_then(|v| v.as_str())
739 .and_then(NodeType::from_str_loose);
740
741 let tag_filter = args.get("tag").and_then(|v| v.as_str());
742
743 let filtered: Vec<&GraphNode> = state
744 .nodes
745 .iter()
746 .filter(|n| {
747 if let Some(ref nt) = type_filter
748 && &n.node_type != nt
749 {
750 return false;
751 }
752 if let Some(tag) = tag_filter
753 && !n.tags.iter().any(|t| t.eq_ignore_ascii_case(tag))
754 {
755 return false;
756 }
757 true
758 })
759 .collect();
760
761 if filtered.is_empty() {
762 return Ok(ToolOutput::text("No nodes found."));
763 }
764
765 let mut output = format!("Nodes ({}):\n", filtered.len());
766 for node in &filtered {
767 output.push_str(&format!(
768 " {} — {} ({})\n",
769 node.id,
770 node.name,
771 node.node_type.as_str()
772 ));
773 }
774
775 Ok(ToolOutput::text(output.trim_end().to_string()))
776 }
777
778 "path" => {
779 let from_id = args.get("from_id").and_then(|v| v.as_str()).unwrap_or("");
780 let to_id = args.get("to_id").and_then(|v| v.as_str()).unwrap_or("");
781
782 if from_id.is_empty() || to_id.is_empty() {
783 return Ok(ToolOutput::text(
784 "Missing required parameters 'from_id' and 'to_id'.",
785 ));
786 }
787
788 if !state.nodes.iter().any(|n| n.id == from_id) {
789 return Ok(ToolOutput::text(format!("Node '{}' not found.", from_id)));
790 }
791 if !state.nodes.iter().any(|n| n.id == to_id) {
792 return Ok(ToolOutput::text(format!("Node '{}' not found.", to_id)));
793 }
794
795 let mut visited: HashSet<String> = HashSet::new();
797 let mut parent: HashMap<String, String> = HashMap::new();
798 let mut queue: VecDeque<String> = VecDeque::new();
799
800 visited.insert(from_id.to_string());
801 queue.push_back(from_id.to_string());
802
803 let mut found = false;
804 while let Some(current) = queue.pop_front() {
805 if current == to_id {
806 found = true;
807 break;
808 }
809
810 for edge in &state.edges {
811 let neighbor = if edge.source_id == current {
812 &edge.target_id
813 } else if edge.target_id == current {
814 &edge.source_id
815 } else {
816 continue;
817 };
818
819 if !visited.contains(neighbor) {
820 visited.insert(neighbor.clone());
821 parent.insert(neighbor.clone(), current.clone());
822 queue.push_back(neighbor.clone());
823 }
824 }
825 }
826
827 if !found {
828 return Ok(ToolOutput::text(format!(
829 "No path found from '{}' to '{}'.",
830 from_id, to_id
831 )));
832 }
833
834 let mut path_ids: Vec<String> = Vec::new();
836 let mut current = to_id.to_string();
837 while current != from_id {
838 path_ids.push(current.clone());
839 current = parent.get(¤t).unwrap().clone();
840 }
841 path_ids.push(from_id.to_string());
842 path_ids.reverse();
843
844 let mut output = format!(
845 "Path from '{}' to '{}' ({} hops):\n",
846 from_id,
847 to_id,
848 path_ids.len() - 1
849 );
850 for (i, pid) in path_ids.iter().enumerate() {
851 let node_name = state
852 .nodes
853 .iter()
854 .find(|n| n.id == *pid)
855 .map(|n| n.name.as_str())
856 .unwrap_or("?");
857 if i > 0 {
858 output.push_str(" -> ");
859 } else {
860 output.push_str(" ");
861 }
862 output.push_str(&format!("{} ({})\n", pid, node_name));
863 }
864
865 Ok(ToolOutput::text(output.trim_end().to_string()))
866 }
867
868 "stats" => {
869 let total_nodes = state.nodes.len();
870 let total_edges = state.edges.len();
871
872 if total_nodes == 0 {
873 return Ok(ToolOutput::text(
874 "Knowledge graph is empty. Use add_node to get started.",
875 ));
876 }
877
878 let mut type_counts: HashMap<&str, usize> = HashMap::new();
880 for node in &state.nodes {
881 *type_counts.entry(node.node_type.as_str()).or_insert(0) += 1;
882 }
883
884 let mut edge_type_counts: HashMap<&str, usize> = HashMap::new();
886 for edge in &state.edges {
887 *edge_type_counts
888 .entry(edge.relationship_type.as_str())
889 .or_insert(0) += 1;
890 }
891
892 let mut connection_counts: HashMap<&str, usize> = HashMap::new();
894 for edge in &state.edges {
895 *connection_counts.entry(&edge.source_id).or_insert(0) += 1;
896 *connection_counts.entry(&edge.target_id).or_insert(0) += 1;
897 }
898 let mut top_connected: Vec<(&&str, &usize)> = connection_counts.iter().collect();
899 top_connected.sort_by(|a, b| b.1.cmp(a.1));
900 top_connected.truncate(5);
901
902 let mut output = format!(
903 "Knowledge Graph Stats:\n Nodes: {}\n Edges: {}\n\n Nodes by type:\n",
904 total_nodes, total_edges
905 );
906 let mut sorted_types: Vec<_> = type_counts.iter().collect();
907 sorted_types.sort_by_key(|(k, _)| *k);
908 for (t, c) in &sorted_types {
909 output.push_str(&format!(" {}: {}\n", t, c));
910 }
911
912 if !edge_type_counts.is_empty() {
913 output.push_str("\n Edges by type:\n");
914 let mut sorted_etypes: Vec<_> = edge_type_counts.iter().collect();
915 sorted_etypes.sort_by_key(|(k, _)| *k);
916 for (t, c) in &sorted_etypes {
917 output.push_str(&format!(" {}: {}\n", t, c));
918 }
919 }
920
921 if !top_connected.is_empty() {
922 output.push_str("\n Most connected nodes:\n");
923 for (nid, count) in &top_connected {
924 let node_name = state
925 .nodes
926 .iter()
927 .find(|n| n.id == ***nid)
928 .map(|n| n.name.as_str())
929 .unwrap_or("?");
930 output.push_str(&format!(
931 " {} ({}) — {} connections\n",
932 nid, node_name, count
933 ));
934 }
935 }
936
937 Ok(ToolOutput::text(output.trim_end().to_string()))
938 }
939
940 "import_arxiv" => {
941 let arxiv_id = args.get("arxiv_id").and_then(|v| v.as_str()).unwrap_or("");
942 if arxiv_id.is_empty() {
943 return Ok(ToolOutput::text("Missing required parameter 'arxiv_id'."));
944 }
945
946 let lib_path = self.arxiv_library_path();
947 if !lib_path.exists() {
948 return Ok(ToolOutput::text(
949 "ArXiv library not found. Save papers with the arxiv_research tool first.",
950 ));
951 }
952
953 let lib_json =
954 std::fs::read_to_string(&lib_path).map_err(|e| ToolError::ExecutionFailed {
955 name: "knowledge_graph".to_string(),
956 message: format!("Failed to read arxiv library: {}", e),
957 })?;
958
959 let library: ArxivLibraryFile =
960 serde_json::from_str(&lib_json).map_err(|e| ToolError::ExecutionFailed {
961 name: "knowledge_graph".to_string(),
962 message: format!("Failed to parse arxiv library: {}", e),
963 })?;
964
965 let entry = match library.entries.iter().find(|e| e.paper.id == arxiv_id) {
966 Some(e) => e,
967 None => {
968 return Ok(ToolOutput::text(format!(
969 "Paper '{}' not found in arxiv library.",
970 arxiv_id
971 )));
972 }
973 };
974
975 let now = Utc::now();
976 let paper_id = Self::slug(&entry.paper.title);
977 let mut added_nodes = 0;
978 let mut added_edges = 0;
979
980 if !state.nodes.iter().any(|n| n.id == paper_id) {
982 let mut metadata = HashMap::new();
983 metadata.insert("arxiv_id".to_string(), entry.paper.id.clone());
984 state.nodes.push(GraphNode {
985 id: paper_id.clone(),
986 name: entry.paper.title.clone(),
987 node_type: NodeType::Paper,
988 description: entry.paper.abstract_text.clone(),
989 tags: entry.tags.clone(),
990 metadata,
991 created_at: now,
992 updated_at: now,
993 });
994 added_nodes += 1;
995 }
996
997 for author in &entry.paper.authors {
999 let author_id = Self::slug(author);
1000 if !state.nodes.iter().any(|n| n.id == author_id) {
1001 state.nodes.push(GraphNode {
1002 id: author_id.clone(),
1003 name: author.clone(),
1004 node_type: NodeType::Person,
1005 description: String::new(),
1006 tags: Vec::new(),
1007 metadata: HashMap::new(),
1008 created_at: now,
1009 updated_at: now,
1010 });
1011 added_nodes += 1;
1012 }
1013
1014 let edge_exists = state.edges.iter().any(|e| {
1016 e.source_id == paper_id
1017 && e.target_id == author_id
1018 && e.relationship_type == RelationshipType::AuthoredBy
1019 });
1020 if !edge_exists {
1021 state.edges.push(Edge {
1022 source_id: paper_id.clone(),
1023 target_id: author_id,
1024 relationship_type: RelationshipType::AuthoredBy,
1025 strength: 1.0,
1026 notes: String::new(),
1027 created_at: now,
1028 });
1029 added_edges += 1;
1030 }
1031 }
1032
1033 state.next_auto_id += added_nodes;
1034 self.save_state(&state)?;
1035
1036 Ok(ToolOutput::text(format!(
1037 "Imported '{}' from arxiv: {} node(s) and {} edge(s) added.",
1038 entry.paper.title, added_nodes, added_edges
1039 )))
1040 }
1041
1042 "export_dot" => {
1043 let type_filter = args
1044 .get("filter_type")
1045 .and_then(|v| v.as_str())
1046 .and_then(NodeType::from_str_loose);
1047
1048 let filtered_nodes: Vec<&GraphNode> = state
1049 .nodes
1050 .iter()
1051 .filter(|n| {
1052 type_filter
1053 .as_ref()
1054 .map(|nt| &n.node_type == nt)
1055 .unwrap_or(true)
1056 })
1057 .collect();
1058
1059 let node_ids: HashSet<&str> =
1060 filtered_nodes.iter().map(|n| n.id.as_str()).collect();
1061
1062 let mut dot = String::from("digraph KnowledgeGraph {\n");
1063 dot.push_str(" rankdir=LR;\n");
1064 dot.push_str(" node [shape=box];\n\n");
1065
1066 for node in &filtered_nodes {
1067 dot.push_str(&format!(
1068 " \"{}\" [label=\"{}\\n({})\"];\n",
1069 node.id,
1070 node.name.replace('"', "\\\""),
1071 node.node_type.as_str()
1072 ));
1073 }
1074
1075 dot.push('\n');
1076
1077 for edge in &state.edges {
1078 if node_ids.contains(edge.source_id.as_str())
1079 && node_ids.contains(edge.target_id.as_str())
1080 {
1081 dot.push_str(&format!(
1082 " \"{}\" -> \"{}\" [label=\"{}\"];\n",
1083 edge.source_id,
1084 edge.target_id,
1085 edge.relationship_type.as_str()
1086 ));
1087 }
1088 }
1089
1090 dot.push_str("}\n");
1091
1092 Ok(ToolOutput::text(dot))
1093 }
1094
1095 _ => Ok(ToolOutput::text(format!(
1096 "Unknown action: '{}'. Use: add_node, get_node, update_node, remove_node, add_edge, remove_edge, neighbors, search, list, path, stats, import_arxiv, export_dot.",
1097 action
1098 ))),
1099 }
1100 }
1101}
1102
1103#[cfg(test)]
1104mod tests {
1105 use super::*;
1106 use tempfile::TempDir;
1107
1108 fn make_tool() -> (TempDir, KnowledgeGraphTool) {
1109 let dir = TempDir::new().unwrap();
1110 let workspace = dir.path().canonicalize().unwrap();
1111 let tool = KnowledgeGraphTool::new(workspace);
1112 (dir, tool)
1113 }
1114
1115 #[test]
1116 fn test_tool_properties() {
1117 let (_dir, tool) = make_tool();
1118 assert_eq!(tool.name(), "knowledge_graph");
1119 assert_eq!(tool.risk_level(), RiskLevel::Write);
1120 assert_eq!(tool.timeout(), Duration::from_secs(30));
1121 assert!(!tool.description().is_empty());
1122 }
1123
1124 #[test]
1125 fn test_schema_validation() {
1126 let (_dir, tool) = make_tool();
1127 let schema = tool.parameters_schema();
1128 assert!(schema.get("properties").is_some());
1129 let required = schema["required"].as_array().unwrap();
1130 assert!(required.iter().any(|v| v.as_str() == Some("action")));
1131 let actions = schema["properties"]["action"]["enum"].as_array().unwrap();
1132 assert_eq!(actions.len(), 13);
1133 }
1134
1135 #[tokio::test]
1136 async fn test_add_node_basic() {
1137 let (_dir, tool) = make_tool();
1138 let result = tool
1139 .execute(json!({
1140 "action": "add_node",
1141 "name": "Attention Is All You Need",
1142 "node_type": "paper",
1143 "description": "Transformer architecture paper"
1144 }))
1145 .await
1146 .unwrap();
1147 assert!(result.content.contains("Added node"));
1148 assert!(result.content.contains("attention-is-all-you-need"));
1149 }
1150
1151 #[tokio::test]
1152 async fn test_add_node_auto_id_from_slug() {
1153 let (_dir, tool) = make_tool();
1154 let result = tool
1155 .execute(json!({
1156 "action": "add_node",
1157 "name": "Deep Learning Basics",
1158 "node_type": "concept"
1159 }))
1160 .await
1161 .unwrap();
1162 assert!(result.content.contains("deep-learning-basics"));
1163
1164 let state = tool.load_state();
1166 assert_eq!(state.nodes[0].id, "deep-learning-basics");
1167 assert_eq!(state.nodes[0].name, "Deep Learning Basics");
1168 }
1169
1170 #[tokio::test]
1171 async fn test_get_node_with_edges() {
1172 let (_dir, tool) = make_tool();
1173 tool.execute(json!({
1175 "action": "add_node",
1176 "name": "Paper A",
1177 "node_type": "paper",
1178 "id": "paper-a"
1179 }))
1180 .await
1181 .unwrap();
1182 tool.execute(json!({
1183 "action": "add_node",
1184 "name": "Concept B",
1185 "node_type": "concept",
1186 "id": "concept-b"
1187 }))
1188 .await
1189 .unwrap();
1190 tool.execute(json!({
1191 "action": "add_edge",
1192 "source_id": "paper-a",
1193 "target_id": "concept-b",
1194 "relationship_type": "implements"
1195 }))
1196 .await
1197 .unwrap();
1198
1199 let result = tool
1200 .execute(json!({ "action": "get_node", "id": "paper-a" }))
1201 .await
1202 .unwrap();
1203 assert!(result.content.contains("Paper A"));
1204 assert!(result.content.contains("Implements"));
1205 assert!(result.content.contains("concept-b"));
1206 }
1207
1208 #[tokio::test]
1209 async fn test_update_node() {
1210 let (_dir, tool) = make_tool();
1211 tool.execute(json!({
1212 "action": "add_node",
1213 "name": "Old Name",
1214 "node_type": "concept",
1215 "id": "test-node"
1216 }))
1217 .await
1218 .unwrap();
1219
1220 let result = tool
1221 .execute(json!({
1222 "action": "update_node",
1223 "id": "test-node",
1224 "name": "New Name",
1225 "description": "Updated description",
1226 "tags": ["updated"]
1227 }))
1228 .await
1229 .unwrap();
1230 assert!(result.content.contains("Updated node"));
1231
1232 let state = tool.load_state();
1233 let node = state.nodes.iter().find(|n| n.id == "test-node").unwrap();
1234 assert_eq!(node.name, "New Name");
1235 assert_eq!(node.description, "Updated description");
1236 assert_eq!(node.tags, vec!["updated"]);
1237 }
1238
1239 #[tokio::test]
1240 async fn test_remove_node_cascades_edges() {
1241 let (_dir, tool) = make_tool();
1242 tool.execute(
1244 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1245 )
1246 .await
1247 .unwrap();
1248 tool.execute(
1249 json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1250 )
1251 .await
1252 .unwrap();
1253 tool.execute(
1254 json!({ "action": "add_node", "name": "C", "node_type": "concept", "id": "c" }),
1255 )
1256 .await
1257 .unwrap();
1258 tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1259 .await.unwrap();
1260 tool.execute(json!({ "action": "add_edge", "source_id": "b", "target_id": "c", "relationship_type": "related_to" }))
1261 .await.unwrap();
1262
1263 let state = tool.load_state();
1264 assert_eq!(state.edges.len(), 2);
1265
1266 let result = tool
1268 .execute(json!({ "action": "remove_node", "id": "b" }))
1269 .await
1270 .unwrap();
1271 assert!(result.content.contains("Removed node 'b'"));
1272 assert!(result.content.contains("2 connected edge(s)"));
1273
1274 let state = tool.load_state();
1275 assert_eq!(state.nodes.len(), 2);
1276 assert_eq!(state.edges.len(), 0);
1277 }
1278
1279 #[tokio::test]
1280 async fn test_add_edge_validates_nodes() {
1281 let (_dir, tool) = make_tool();
1282 tool.execute(
1283 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1284 )
1285 .await
1286 .unwrap();
1287
1288 let result = tool
1290 .execute(json!({
1291 "action": "add_edge",
1292 "source_id": "a",
1293 "target_id": "nonexistent",
1294 "relationship_type": "related_to"
1295 }))
1296 .await
1297 .unwrap();
1298 assert!(result.content.contains("not found"));
1299
1300 let result = tool
1302 .execute(json!({
1303 "action": "add_edge",
1304 "source_id": "nonexistent",
1305 "target_id": "a",
1306 "relationship_type": "related_to"
1307 }))
1308 .await
1309 .unwrap();
1310 assert!(result.content.contains("not found"));
1311 }
1312
1313 #[tokio::test]
1314 async fn test_add_edge_strength_default() {
1315 let (_dir, tool) = make_tool();
1316 tool.execute(
1317 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1318 )
1319 .await
1320 .unwrap();
1321 tool.execute(
1322 json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1323 )
1324 .await
1325 .unwrap();
1326
1327 tool.execute(json!({
1328 "action": "add_edge",
1329 "source_id": "a",
1330 "target_id": "b",
1331 "relationship_type": "related_to"
1332 }))
1333 .await
1334 .unwrap();
1335
1336 let state = tool.load_state();
1337 assert_eq!(state.edges.len(), 1);
1338 assert!((state.edges[0].strength - 0.5).abs() < f64::EPSILON);
1339 }
1340
1341 #[tokio::test]
1342 async fn test_remove_edge() {
1343 let (_dir, tool) = make_tool();
1344 tool.execute(
1345 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1346 )
1347 .await
1348 .unwrap();
1349 tool.execute(
1350 json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1351 )
1352 .await
1353 .unwrap();
1354 tool.execute(json!({
1355 "action": "add_edge",
1356 "source_id": "a",
1357 "target_id": "b",
1358 "relationship_type": "related_to"
1359 }))
1360 .await
1361 .unwrap();
1362
1363 assert_eq!(tool.load_state().edges.len(), 1);
1364
1365 let result = tool
1366 .execute(json!({
1367 "action": "remove_edge",
1368 "source_id": "a",
1369 "target_id": "b"
1370 }))
1371 .await
1372 .unwrap();
1373 assert!(result.content.contains("Removed"));
1374 assert_eq!(tool.load_state().edges.len(), 0);
1375 }
1376
1377 #[tokio::test]
1378 async fn test_neighbors_depth_1() {
1379 let (_dir, tool) = make_tool();
1380 tool.execute(
1381 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1382 )
1383 .await
1384 .unwrap();
1385 tool.execute(
1386 json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1387 )
1388 .await
1389 .unwrap();
1390 tool.execute(
1391 json!({ "action": "add_node", "name": "C", "node_type": "concept", "id": "c" }),
1392 )
1393 .await
1394 .unwrap();
1395 tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1396 .await.unwrap();
1397 tool.execute(json!({ "action": "add_edge", "source_id": "b", "target_id": "c", "relationship_type": "related_to" }))
1398 .await.unwrap();
1399
1400 let result = tool
1401 .execute(json!({ "action": "neighbors", "id": "a", "depth": 1 }))
1402 .await
1403 .unwrap();
1404 assert!(result.content.contains("b"));
1405 assert!(!result.content.contains(" c ") && !result.content.contains("\n [depth 1] c"));
1407 }
1408
1409 #[tokio::test]
1410 async fn test_neighbors_depth_2() {
1411 let (_dir, tool) = make_tool();
1412 tool.execute(
1413 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1414 )
1415 .await
1416 .unwrap();
1417 tool.execute(
1418 json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1419 )
1420 .await
1421 .unwrap();
1422 tool.execute(
1423 json!({ "action": "add_node", "name": "C", "node_type": "concept", "id": "c" }),
1424 )
1425 .await
1426 .unwrap();
1427 tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1428 .await.unwrap();
1429 tool.execute(json!({ "action": "add_edge", "source_id": "b", "target_id": "c", "relationship_type": "related_to" }))
1430 .await.unwrap();
1431
1432 let result = tool
1433 .execute(json!({ "action": "neighbors", "id": "a", "depth": 2 }))
1434 .await
1435 .unwrap();
1436 assert!(result.content.contains("b"));
1438 assert!(result.content.contains("c"));
1439 }
1440
1441 #[tokio::test]
1442 async fn test_search_by_name() {
1443 let (_dir, tool) = make_tool();
1444 tool.execute(json!({
1445 "action": "add_node",
1446 "name": "Transformer Architecture",
1447 "node_type": "method",
1448 "id": "transformer"
1449 }))
1450 .await
1451 .unwrap();
1452
1453 let result = tool
1454 .execute(json!({ "action": "search", "query": "transformer" }))
1455 .await
1456 .unwrap();
1457 assert!(result.content.contains("Transformer Architecture"));
1458 }
1459
1460 #[tokio::test]
1461 async fn test_search_by_tag() {
1462 let (_dir, tool) = make_tool();
1463 tool.execute(json!({
1464 "action": "add_node",
1465 "name": "BERT",
1466 "node_type": "method",
1467 "id": "bert",
1468 "tags": ["nlp", "language-model"]
1469 }))
1470 .await
1471 .unwrap();
1472
1473 let result = tool
1474 .execute(json!({ "action": "search", "query": "nlp" }))
1475 .await
1476 .unwrap();
1477 assert!(result.content.contains("BERT"));
1478 }
1479
1480 #[tokio::test]
1481 async fn test_search_filter_type() {
1482 let (_dir, tool) = make_tool();
1483 tool.execute(json!({ "action": "add_node", "name": "ML Paper", "node_type": "paper", "id": "ml-paper" }))
1484 .await.unwrap();
1485 tool.execute(json!({ "action": "add_node", "name": "ML Concept", "node_type": "concept", "id": "ml-concept" }))
1486 .await.unwrap();
1487
1488 let result = tool
1489 .execute(json!({ "action": "search", "query": "ml", "node_type": "paper" }))
1490 .await
1491 .unwrap();
1492 assert!(result.content.contains("ML Paper"));
1493 assert!(!result.content.contains("ML Concept"));
1494 }
1495
1496 #[tokio::test]
1497 async fn test_list_all() {
1498 let (_dir, tool) = make_tool();
1499 tool.execute(
1500 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1501 )
1502 .await
1503 .unwrap();
1504 tool.execute(json!({ "action": "add_node", "name": "B", "node_type": "paper", "id": "b" }))
1505 .await
1506 .unwrap();
1507
1508 let result = tool.execute(json!({ "action": "list" })).await.unwrap();
1509 assert!(result.content.contains("Nodes (2)"));
1510 assert!(result.content.contains("a"));
1511 assert!(result.content.contains("b"));
1512 }
1513
1514 #[tokio::test]
1515 async fn test_list_filter_type() {
1516 let (_dir, tool) = make_tool();
1517 tool.execute(
1518 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1519 )
1520 .await
1521 .unwrap();
1522 tool.execute(json!({ "action": "add_node", "name": "B", "node_type": "paper", "id": "b" }))
1523 .await
1524 .unwrap();
1525
1526 let result = tool
1527 .execute(json!({ "action": "list", "node_type": "concept" }))
1528 .await
1529 .unwrap();
1530 assert!(result.content.contains("Nodes (1)"));
1531 assert!(result.content.contains("a"));
1532 assert!(!result.content.contains(" b "));
1533 }
1534
1535 #[tokio::test]
1536 async fn test_path_direct() {
1537 let (_dir, tool) = make_tool();
1538 tool.execute(
1539 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1540 )
1541 .await
1542 .unwrap();
1543 tool.execute(
1544 json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1545 )
1546 .await
1547 .unwrap();
1548 tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1549 .await.unwrap();
1550
1551 let result = tool
1552 .execute(json!({ "action": "path", "from_id": "a", "to_id": "b" }))
1553 .await
1554 .unwrap();
1555 assert!(result.content.contains("1 hops"));
1556 assert!(result.content.contains("a"));
1557 assert!(result.content.contains("b"));
1558 }
1559
1560 #[tokio::test]
1561 async fn test_path_indirect() {
1562 let (_dir, tool) = make_tool();
1563 tool.execute(
1564 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1565 )
1566 .await
1567 .unwrap();
1568 tool.execute(
1569 json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1570 )
1571 .await
1572 .unwrap();
1573 tool.execute(
1574 json!({ "action": "add_node", "name": "C", "node_type": "concept", "id": "c" }),
1575 )
1576 .await
1577 .unwrap();
1578 tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1579 .await.unwrap();
1580 tool.execute(json!({ "action": "add_edge", "source_id": "b", "target_id": "c", "relationship_type": "related_to" }))
1581 .await.unwrap();
1582
1583 let result = tool
1584 .execute(json!({ "action": "path", "from_id": "a", "to_id": "c" }))
1585 .await
1586 .unwrap();
1587 assert!(result.content.contains("2 hops"));
1588 assert!(result.content.contains("a"));
1589 assert!(result.content.contains("b"));
1590 assert!(result.content.contains("c"));
1591 }
1592
1593 #[tokio::test]
1594 async fn test_path_no_path() {
1595 let (_dir, tool) = make_tool();
1596 tool.execute(
1597 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1598 )
1599 .await
1600 .unwrap();
1601 tool.execute(
1602 json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1603 )
1604 .await
1605 .unwrap();
1606 let result = tool
1609 .execute(json!({ "action": "path", "from_id": "a", "to_id": "b" }))
1610 .await
1611 .unwrap();
1612 assert!(result.content.contains("No path found"));
1613 }
1614
1615 #[tokio::test]
1616 async fn test_stats() {
1617 let (_dir, tool) = make_tool();
1618 tool.execute(
1619 json!({ "action": "add_node", "name": "Paper 1", "node_type": "paper", "id": "p1" }),
1620 )
1621 .await
1622 .unwrap();
1623 tool.execute(json!({ "action": "add_node", "name": "Concept 1", "node_type": "concept", "id": "c1" }))
1624 .await.unwrap();
1625 tool.execute(
1626 json!({ "action": "add_node", "name": "Author 1", "node_type": "person", "id": "a1" }),
1627 )
1628 .await
1629 .unwrap();
1630 tool.execute(json!({ "action": "add_edge", "source_id": "p1", "target_id": "c1", "relationship_type": "implements" }))
1631 .await.unwrap();
1632 tool.execute(json!({ "action": "add_edge", "source_id": "p1", "target_id": "a1", "relationship_type": "authored_by" }))
1633 .await.unwrap();
1634
1635 let result = tool.execute(json!({ "action": "stats" })).await.unwrap();
1636 assert!(result.content.contains("Nodes: 3"));
1637 assert!(result.content.contains("Edges: 2"));
1638 assert!(result.content.contains("Paper: 1"));
1639 assert!(result.content.contains("Concept: 1"));
1640 assert!(result.content.contains("Person: 1"));
1641 assert!(result.content.contains("Most connected"));
1642 assert!(result.content.contains("p1"));
1643 }
1644
1645 #[tokio::test]
1646 async fn test_export_dot() {
1647 let (_dir, tool) = make_tool();
1648 tool.execute(
1649 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1650 )
1651 .await
1652 .unwrap();
1653 tool.execute(
1654 json!({ "action": "add_node", "name": "B", "node_type": "method", "id": "b" }),
1655 )
1656 .await
1657 .unwrap();
1658 tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "implements" }))
1659 .await.unwrap();
1660
1661 let result = tool
1662 .execute(json!({ "action": "export_dot" }))
1663 .await
1664 .unwrap();
1665 assert!(result.content.contains("digraph KnowledgeGraph"));
1666 assert!(result.content.contains("\"a\""));
1667 assert!(result.content.contains("\"b\""));
1668 assert!(result.content.contains("Implements"));
1669 assert!(result.content.contains("->"));
1670 }
1671
1672 #[tokio::test]
1673 async fn test_state_roundtrip() {
1674 let (_dir, tool) = make_tool();
1675 tool.execute(json!({
1677 "action": "add_node",
1678 "name": "Test Node",
1679 "node_type": "dataset",
1680 "id": "test-node",
1681 "description": "A test dataset",
1682 "tags": ["test", "data"],
1683 "metadata": { "source": "kaggle", "size": "1GB" }
1684 }))
1685 .await
1686 .unwrap();
1687
1688 let state = tool.load_state();
1690 assert_eq!(state.nodes.len(), 1);
1691 let node = &state.nodes[0];
1692 assert_eq!(node.id, "test-node");
1693 assert_eq!(node.name, "Test Node");
1694 assert_eq!(node.node_type, NodeType::Dataset);
1695 assert_eq!(node.description, "A test dataset");
1696 assert_eq!(node.tags, vec!["test", "data"]);
1697 assert_eq!(node.metadata.get("source").unwrap(), "kaggle");
1698 assert_eq!(node.metadata.get("size").unwrap(), "1GB");
1699 }
1700
1701 #[tokio::test]
1702 async fn test_unknown_action() {
1703 let (_dir, tool) = make_tool();
1704 let result = tool.execute(json!({ "action": "foobar" })).await.unwrap();
1705 assert!(result.content.contains("Unknown action"));
1706 assert!(result.content.contains("foobar"));
1707 }
1708}