1use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5use rustant_core::error::ToolError;
6use rustant_core::types::{RiskLevel, ToolOutput};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use std::collections::{HashMap, HashSet, VecDeque};
10use std::path::PathBuf;
11use std::time::Duration;
12
13use crate::registry::Tool;
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16enum NodeType {
17 Paper,
18 Concept,
19 Method,
20 Dataset,
21 Person,
22 Organization,
23}
24
25impl NodeType {
26 fn from_str_loose(s: &str) -> Option<Self> {
27 match s.to_lowercase().as_str() {
28 "paper" => Some(Self::Paper),
29 "concept" => Some(Self::Concept),
30 "method" => Some(Self::Method),
31 "dataset" => Some(Self::Dataset),
32 "person" => Some(Self::Person),
33 "organization" => Some(Self::Organization),
34 _ => None,
35 }
36 }
37
38 fn as_str(&self) -> &str {
39 match self {
40 Self::Paper => "Paper",
41 Self::Concept => "Concept",
42 Self::Method => "Method",
43 Self::Dataset => "Dataset",
44 Self::Person => "Person",
45 Self::Organization => "Organization",
46 }
47 }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
51enum RelationshipType {
52 Cites,
53 Implements,
54 Extends,
55 Contradicts,
56 BuildsOn,
57 AuthoredBy,
58 UsesDataset,
59 RelatedTo,
60}
61
62impl RelationshipType {
63 fn from_str_loose(s: &str) -> Option<Self> {
64 match s.to_lowercase().as_str() {
65 "cites" => Some(Self::Cites),
66 "implements" => Some(Self::Implements),
67 "extends" => Some(Self::Extends),
68 "contradicts" => Some(Self::Contradicts),
69 "builds_on" => Some(Self::BuildsOn),
70 "authored_by" => Some(Self::AuthoredBy),
71 "uses_dataset" => Some(Self::UsesDataset),
72 "related_to" => Some(Self::RelatedTo),
73 _ => None,
74 }
75 }
76
77 fn as_str(&self) -> &str {
78 match self {
79 Self::Cites => "Cites",
80 Self::Implements => "Implements",
81 Self::Extends => "Extends",
82 Self::Contradicts => "Contradicts",
83 Self::BuildsOn => "BuildsOn",
84 Self::AuthoredBy => "AuthoredBy",
85 Self::UsesDataset => "UsesDataset",
86 Self::RelatedTo => "RelatedTo",
87 }
88 }
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92struct GraphNode {
93 id: String,
94 name: String,
95 node_type: NodeType,
96 description: String,
97 tags: Vec<String>,
98 metadata: HashMap<String, String>,
99 created_at: DateTime<Utc>,
100 updated_at: DateTime<Utc>,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104struct Edge {
105 source_id: String,
106 target_id: String,
107 relationship_type: RelationshipType,
108 strength: f64,
109 notes: String,
110 created_at: DateTime<Utc>,
111}
112
113#[derive(Debug, Default, Serialize, Deserialize)]
114struct KnowledgeGraphState {
115 nodes: Vec<GraphNode>,
116 edges: Vec<Edge>,
117 next_auto_id: usize,
118}
119
120#[derive(Debug, Deserialize)]
122struct ArxivLibraryFile {
123 entries: Vec<ArxivLibraryEntry>,
124}
125
126#[derive(Debug, Deserialize)]
127struct ArxivLibraryEntry {
128 paper: ArxivPaperRef,
129 #[serde(default)]
130 tags: Vec<String>,
131}
132
133#[derive(Debug, Deserialize)]
134struct ArxivPaperRef {
135 #[serde(alias = "arxiv_id")]
136 id: String,
137 title: String,
138 authors: Vec<String>,
139 #[serde(alias = "summary")]
140 abstract_text: String,
141}
142
/// Tool exposing a knowledge graph stored per workspace under `<workspace>/.rustant/`.
pub struct KnowledgeGraphTool {
    /// Root directory against which the graph file and the arxiv library are resolved.
    workspace: PathBuf,
}
146
147impl KnowledgeGraphTool {
148 pub fn new(workspace: PathBuf) -> Self {
149 Self { workspace }
150 }
151
152 fn state_path(&self) -> PathBuf {
153 self.workspace
154 .join(".rustant")
155 .join("knowledge")
156 .join("graph.json")
157 }
158
159 fn load_state(&self) -> KnowledgeGraphState {
160 let path = self.state_path();
161 if path.exists() {
162 std::fs::read_to_string(&path)
163 .ok()
164 .and_then(|s| serde_json::from_str(&s).ok())
165 .unwrap_or_default()
166 } else {
167 KnowledgeGraphState::default()
168 }
169 }
170
171 fn save_state(&self, state: &KnowledgeGraphState) -> Result<(), ToolError> {
172 let path = self.state_path();
173 if let Some(parent) = path.parent() {
174 std::fs::create_dir_all(parent).map_err(|e| ToolError::ExecutionFailed {
175 name: "knowledge_graph".to_string(),
176 message: format!("Failed to create state dir: {}", e),
177 })?;
178 }
179 let json = serde_json::to_string_pretty(state).map_err(|e| ToolError::ExecutionFailed {
180 name: "knowledge_graph".to_string(),
181 message: format!("Failed to serialize state: {}", e),
182 })?;
183 let tmp = path.with_extension("json.tmp");
184 std::fs::write(&tmp, &json).map_err(|e| ToolError::ExecutionFailed {
185 name: "knowledge_graph".to_string(),
186 message: format!("Failed to write state: {}", e),
187 })?;
188 std::fs::rename(&tmp, &path).map_err(|e| ToolError::ExecutionFailed {
189 name: "knowledge_graph".to_string(),
190 message: format!("Failed to rename state file: {}", e),
191 })?;
192 Ok(())
193 }
194
195 fn slug(name: &str) -> String {
196 name.to_lowercase().replace(' ', "-")
197 }
198
199 fn arxiv_library_path(&self) -> PathBuf {
200 self.workspace
201 .join(".rustant")
202 .join("arxiv")
203 .join("library.json")
204 }
205}
206
207#[async_trait]
208impl Tool for KnowledgeGraphTool {
209 fn name(&self) -> &str {
210 "knowledge_graph"
211 }
212
213 fn description(&self) -> &str {
214 "Local knowledge graph of concepts, papers, methods, and relationships. Actions: add_node, get_node, update_node, remove_node, add_edge, remove_edge, neighbors, search, list, path, stats, import_arxiv, export_dot."
215 }
216
217 fn parameters_schema(&self) -> Value {
218 json!({
219 "type": "object",
220 "properties": {
221 "action": {
222 "type": "string",
223 "enum": [
224 "add_node", "get_node", "update_node", "remove_node",
225 "add_edge", "remove_edge", "neighbors", "search",
226 "list", "path", "stats", "import_arxiv", "export_dot"
227 ],
228 "description": "Action to perform"
229 },
230 "id": { "type": "string", "description": "Node ID" },
231 "name": { "type": "string", "description": "Node name" },
232 "node_type": {
233 "type": "string",
234 "enum": ["paper", "concept", "method", "dataset", "person", "organization"],
235 "description": "Type of node"
236 },
237 "description": { "type": "string", "description": "Node description" },
238 "tags": {
239 "type": "array",
240 "items": { "type": "string" },
241 "description": "Tags for the node"
242 },
243 "metadata": {
244 "type": "object",
245 "additionalProperties": { "type": "string" },
246 "description": "Key-value metadata for the node"
247 },
248 "source_id": { "type": "string", "description": "Source node ID for edge" },
249 "target_id": { "type": "string", "description": "Target node ID for edge" },
250 "relationship_type": {
251 "type": "string",
252 "enum": ["cites", "implements", "extends", "contradicts", "builds_on", "authored_by", "uses_dataset", "related_to"],
253 "description": "Type of relationship"
254 },
255 "strength": {
256 "type": "number",
257 "description": "Edge strength 0.0-1.0 (default 0.5)"
258 },
259 "notes": { "type": "string", "description": "Edge notes" },
260 "depth": {
261 "type": "integer",
262 "description": "Traversal depth for neighbors (1-3, default 1)"
263 },
264 "query": { "type": "string", "description": "Search query" },
265 "tag": { "type": "string", "description": "Filter by tag" },
266 "from_id": { "type": "string", "description": "Path start node ID" },
267 "to_id": { "type": "string", "description": "Path end node ID" },
268 "arxiv_id": { "type": "string", "description": "ArXiv paper ID for import" },
269 "filter_type": {
270 "type": "string",
271 "enum": ["paper", "concept", "method", "dataset", "person", "organization"],
272 "description": "Filter by node type"
273 }
274 },
275 "required": ["action"]
276 })
277 }
278
279 fn risk_level(&self) -> RiskLevel {
280 RiskLevel::Write
281 }
282
283 fn timeout(&self) -> Duration {
284 Duration::from_secs(30)
285 }
286
287 async fn execute(&self, args: Value) -> Result<ToolOutput, ToolError> {
288 let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("");
289 let mut state = self.load_state();
290
291 match action {
292 "add_node" => {
293 let name = args
294 .get("name")
295 .and_then(|v| v.as_str())
296 .unwrap_or("")
297 .trim();
298 if name.is_empty() {
299 return Ok(ToolOutput::text("Missing required parameter 'name'."));
300 }
301
302 let node_type_str = args
303 .get("node_type")
304 .and_then(|v| v.as_str())
305 .unwrap_or("");
306 let node_type = match NodeType::from_str_loose(node_type_str) {
307 Some(nt) => nt,
308 None => {
309 return Ok(ToolOutput::text(format!(
310 "Invalid node_type '{}'. Use: paper, concept, method, dataset, person, organization.",
311 node_type_str
312 )));
313 }
314 };
315
316 let id = args
317 .get("id")
318 .and_then(|v| v.as_str())
319 .map(|s| s.to_string())
320 .unwrap_or_else(|| Self::slug(name));
321
322 if state.nodes.iter().any(|n| n.id == id) {
324 return Ok(ToolOutput::text(format!(
325 "Node with id '{}' already exists.",
326 id
327 )));
328 }
329
330 let description = args
331 .get("description")
332 .and_then(|v| v.as_str())
333 .unwrap_or("")
334 .to_string();
335
336 let tags: Vec<String> = args
337 .get("tags")
338 .and_then(|v| v.as_array())
339 .map(|arr| {
340 arr.iter()
341 .filter_map(|v| v.as_str().map(|s| s.to_string()))
342 .collect()
343 })
344 .unwrap_or_default();
345
346 let metadata: HashMap<String, String> = args
347 .get("metadata")
348 .and_then(|v| v.as_object())
349 .map(|obj| {
350 obj.iter()
351 .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
352 .collect()
353 })
354 .unwrap_or_default();
355
356 let now = Utc::now();
357 state.nodes.push(GraphNode {
358 id: id.clone(),
359 name: name.to_string(),
360 node_type,
361 description,
362 tags,
363 metadata,
364 created_at: now,
365 updated_at: now,
366 });
367 state.next_auto_id += 1;
368 self.save_state(&state)?;
369
370 Ok(ToolOutput::text(format!(
371 "Added node '{}' (id: {}).",
372 name, id
373 )))
374 }
375
376 "get_node" => {
377 let id = args.get("id").and_then(|v| v.as_str()).unwrap_or("");
378 if id.is_empty() {
379 return Ok(ToolOutput::text("Missing required parameter 'id'."));
380 }
381
382 let node = match state.nodes.iter().find(|n| n.id == id) {
383 Some(n) => n,
384 None => {
385 return Ok(ToolOutput::text(format!("Node '{}' not found.", id)));
386 }
387 };
388
389 let connected_edges: Vec<&Edge> = state
390 .edges
391 .iter()
392 .filter(|e| e.source_id == id || e.target_id == id)
393 .collect();
394
395 let mut output = format!(
396 "Node: {} ({})\n ID: {}\n Type: {}\n Description: {}\n Tags: {}\n Created: {}\n Updated: {}",
397 node.name,
398 node.node_type.as_str(),
399 node.id,
400 node.node_type.as_str(),
401 if node.description.is_empty() { "(none)" } else { &node.description },
402 if node.tags.is_empty() { "(none)".to_string() } else { node.tags.join(", ") },
403 node.created_at.format("%Y-%m-%d %H:%M"),
404 node.updated_at.format("%Y-%m-%d %H:%M"),
405 );
406
407 if !node.metadata.is_empty() {
408 output.push_str("\n Metadata:");
409 for (k, v) in &node.metadata {
410 output.push_str(&format!("\n {}: {}", k, v));
411 }
412 }
413
414 if !connected_edges.is_empty() {
415 output.push_str(&format!("\n Edges ({}):", connected_edges.len()));
416 for edge in &connected_edges {
417 let direction = if edge.source_id == id {
418 format!("--[{}]--> {}", edge.relationship_type.as_str(), edge.target_id)
419 } else {
420 format!(
421 "<--[{}]-- {}",
422 edge.relationship_type.as_str(),
423 edge.source_id
424 )
425 };
426 output.push_str(&format!(
427 "\n {} (strength: {:.2})",
428 direction, edge.strength
429 ));
430 }
431 }
432
433 Ok(ToolOutput::text(output))
434 }
435
436 "update_node" => {
437 let id = args.get("id").and_then(|v| v.as_str()).unwrap_or("");
438 if id.is_empty() {
439 return Ok(ToolOutput::text("Missing required parameter 'id'."));
440 }
441
442 let node = match state.nodes.iter_mut().find(|n| n.id == id) {
443 Some(n) => n,
444 None => {
445 return Ok(ToolOutput::text(format!("Node '{}' not found.", id)));
446 }
447 };
448
449 if let Some(name) = args.get("name").and_then(|v| v.as_str()) {
450 node.name = name.to_string();
451 }
452 if let Some(desc) = args.get("description").and_then(|v| v.as_str()) {
453 node.description = desc.to_string();
454 }
455 if let Some(tags) = args.get("tags").and_then(|v| v.as_array()) {
456 node.tags = tags
457 .iter()
458 .filter_map(|v| v.as_str().map(|s| s.to_string()))
459 .collect();
460 }
461 node.updated_at = Utc::now();
462 self.save_state(&state)?;
463
464 Ok(ToolOutput::text(format!("Updated node '{}'.", id)))
465 }
466
467 "remove_node" => {
468 let id = args.get("id").and_then(|v| v.as_str()).unwrap_or("");
469 if id.is_empty() {
470 return Ok(ToolOutput::text("Missing required parameter 'id'."));
471 }
472
473 let before_nodes = state.nodes.len();
474 state.nodes.retain(|n| n.id != id);
475 if state.nodes.len() == before_nodes {
476 return Ok(ToolOutput::text(format!("Node '{}' not found.", id)));
477 }
478
479 let before_edges = state.edges.len();
481 state
482 .edges
483 .retain(|e| e.source_id != id && e.target_id != id);
484 let edges_removed = before_edges - state.edges.len();
485
486 self.save_state(&state)?;
487
488 Ok(ToolOutput::text(format!(
489 "Removed node '{}' and {} connected edge(s).",
490 id, edges_removed
491 )))
492 }
493
494 "add_edge" => {
495 let source_id = args
496 .get("source_id")
497 .and_then(|v| v.as_str())
498 .unwrap_or("");
499 let target_id = args
500 .get("target_id")
501 .and_then(|v| v.as_str())
502 .unwrap_or("");
503 let rel_str = args
504 .get("relationship_type")
505 .and_then(|v| v.as_str())
506 .unwrap_or("");
507
508 if source_id.is_empty() || target_id.is_empty() {
509 return Ok(ToolOutput::text(
510 "Missing required parameters 'source_id' and 'target_id'.",
511 ));
512 }
513
514 if !state.nodes.iter().any(|n| n.id == source_id) {
516 return Ok(ToolOutput::text(format!(
517 "Source node '{}' not found.",
518 source_id
519 )));
520 }
521 if !state.nodes.iter().any(|n| n.id == target_id) {
522 return Ok(ToolOutput::text(format!(
523 "Target node '{}' not found.",
524 target_id
525 )));
526 }
527
528 let relationship_type = match RelationshipType::from_str_loose(rel_str) {
529 Some(rt) => rt,
530 None => {
531 return Ok(ToolOutput::text(format!(
532 "Invalid relationship_type '{}'. Use: cites, implements, extends, contradicts, builds_on, authored_by, uses_dataset, related_to.",
533 rel_str
534 )));
535 }
536 };
537
538 let strength = args
539 .get("strength")
540 .and_then(|v| v.as_f64())
541 .unwrap_or(0.5)
542 .clamp(0.0, 1.0);
543
544 let notes = args
545 .get("notes")
546 .and_then(|v| v.as_str())
547 .unwrap_or("")
548 .to_string();
549
550 state.edges.push(Edge {
551 source_id: source_id.to_string(),
552 target_id: target_id.to_string(),
553 relationship_type,
554 strength,
555 notes,
556 created_at: Utc::now(),
557 });
558
559 self.save_state(&state)?;
560
561 Ok(ToolOutput::text(format!(
562 "Added edge {} --[{}]--> {} (strength: {:.2}).",
563 source_id, rel_str, target_id, strength
564 )))
565 }
566
567 "remove_edge" => {
568 let source_id = args
569 .get("source_id")
570 .and_then(|v| v.as_str())
571 .unwrap_or("");
572 let target_id = args
573 .get("target_id")
574 .and_then(|v| v.as_str())
575 .unwrap_or("");
576
577 if source_id.is_empty() || target_id.is_empty() {
578 return Ok(ToolOutput::text(
579 "Missing required parameters 'source_id' and 'target_id'.",
580 ));
581 }
582
583 let rel_filter = args
584 .get("relationship_type")
585 .and_then(|v| v.as_str())
586 .and_then(RelationshipType::from_str_loose);
587
588 let before = state.edges.len();
589 state.edges.retain(|e| {
590 if e.source_id == source_id && e.target_id == target_id {
591 if let Some(ref rt) = rel_filter {
592 return &e.relationship_type != rt;
593 }
594 return false;
595 }
596 true
597 });
598 let removed = before - state.edges.len();
599
600 if removed == 0 {
601 return Ok(ToolOutput::text("No matching edge(s) found."));
602 }
603
604 self.save_state(&state)?;
605
606 Ok(ToolOutput::text(format!(
607 "Removed {} edge(s) from '{}' to '{}'.",
608 removed, source_id, target_id
609 )))
610 }
611
612 "neighbors" => {
613 let id = args.get("id").and_then(|v| v.as_str()).unwrap_or("");
614 if id.is_empty() {
615 return Ok(ToolOutput::text("Missing required parameter 'id'."));
616 }
617
618 if !state.nodes.iter().any(|n| n.id == id) {
619 return Ok(ToolOutput::text(format!("Node '{}' not found.", id)));
620 }
621
622 let max_depth = args
623 .get("depth")
624 .and_then(|v| v.as_u64())
625 .unwrap_or(1)
626 .clamp(1, 3) as usize;
627
628 let rel_filter = args
629 .get("relationship_type")
630 .and_then(|v| v.as_str())
631 .and_then(RelationshipType::from_str_loose);
632
633 let mut visited: HashSet<String> = HashSet::new();
635 visited.insert(id.to_string());
636 let mut queue: VecDeque<(String, usize)> = VecDeque::new();
637 queue.push_back((id.to_string(), 0));
638 let mut found_nodes: Vec<(String, usize)> = Vec::new();
639
640 while let Some((current_id, depth)) = queue.pop_front() {
641 if depth >= max_depth {
642 continue;
643 }
644
645 for edge in &state.edges {
646 if let Some(ref rt) = rel_filter {
647 if &edge.relationship_type != rt {
648 continue;
649 }
650 }
651
652 let neighbor_id = if edge.source_id == current_id {
653 &edge.target_id
654 } else if edge.target_id == current_id {
655 &edge.source_id
656 } else {
657 continue;
658 };
659
660 if !visited.contains(neighbor_id) {
661 visited.insert(neighbor_id.clone());
662 found_nodes.push((neighbor_id.clone(), depth + 1));
663 queue.push_back((neighbor_id.clone(), depth + 1));
664 }
665 }
666 }
667
668 if found_nodes.is_empty() {
669 return Ok(ToolOutput::text(format!(
670 "No neighbors found for '{}' within depth {}.",
671 id, max_depth
672 )));
673 }
674
675 let mut output = format!(
676 "Neighbors of '{}' (depth {}):\n",
677 id, max_depth
678 );
679 for (nid, depth) in &found_nodes {
680 if let Some(node) = state.nodes.iter().find(|n| n.id == *nid) {
681 output.push_str(&format!(
682 " [depth {}] {} — {} ({})\n",
683 depth,
684 node.id,
685 node.name,
686 node.node_type.as_str()
687 ));
688 }
689 }
690
691 Ok(ToolOutput::text(output.trim_end().to_string()))
692 }
693
694 "search" => {
695 let query = args
696 .get("query")
697 .and_then(|v| v.as_str())
698 .unwrap_or("")
699 .to_lowercase();
700 if query.is_empty() {
701 return Ok(ToolOutput::text("Missing required parameter 'query'."));
702 }
703
704 let type_filter = args
705 .get("node_type")
706 .and_then(|v| v.as_str())
707 .and_then(NodeType::from_str_loose);
708
709 let matches: Vec<&GraphNode> = state
710 .nodes
711 .iter()
712 .filter(|n| {
713 if let Some(ref nt) = type_filter {
714 if &n.node_type != nt {
715 return false;
716 }
717 }
718 n.name.to_lowercase().contains(&query)
719 || n.description.to_lowercase().contains(&query)
720 || n.tags
721 .iter()
722 .any(|t| t.to_lowercase().contains(&query))
723 })
724 .collect();
725
726 if matches.is_empty() {
727 return Ok(ToolOutput::text(format!(
728 "No nodes matching '{}'.",
729 query
730 )));
731 }
732
733 let mut output = format!("Found {} node(s):\n", matches.len());
734 for node in &matches {
735 output.push_str(&format!(
736 " {} — {} ({})\n",
737 node.id,
738 node.name,
739 node.node_type.as_str()
740 ));
741 }
742
743 Ok(ToolOutput::text(output.trim_end().to_string()))
744 }
745
746 "list" => {
747 let type_filter = args
748 .get("node_type")
749 .and_then(|v| v.as_str())
750 .and_then(NodeType::from_str_loose);
751
752 let tag_filter = args.get("tag").and_then(|v| v.as_str());
753
754 let filtered: Vec<&GraphNode> = state
755 .nodes
756 .iter()
757 .filter(|n| {
758 if let Some(ref nt) = type_filter {
759 if &n.node_type != nt {
760 return false;
761 }
762 }
763 if let Some(tag) = tag_filter {
764 if !n
765 .tags
766 .iter()
767 .any(|t| t.eq_ignore_ascii_case(tag))
768 {
769 return false;
770 }
771 }
772 true
773 })
774 .collect();
775
776 if filtered.is_empty() {
777 return Ok(ToolOutput::text("No nodes found."));
778 }
779
780 let mut output = format!("Nodes ({}):\n", filtered.len());
781 for node in &filtered {
782 output.push_str(&format!(
783 " {} — {} ({})\n",
784 node.id,
785 node.name,
786 node.node_type.as_str()
787 ));
788 }
789
790 Ok(ToolOutput::text(output.trim_end().to_string()))
791 }
792
793 "path" => {
794 let from_id = args
795 .get("from_id")
796 .and_then(|v| v.as_str())
797 .unwrap_or("");
798 let to_id = args.get("to_id").and_then(|v| v.as_str()).unwrap_or("");
799
800 if from_id.is_empty() || to_id.is_empty() {
801 return Ok(ToolOutput::text(
802 "Missing required parameters 'from_id' and 'to_id'.",
803 ));
804 }
805
806 if !state.nodes.iter().any(|n| n.id == from_id) {
807 return Ok(ToolOutput::text(format!(
808 "Node '{}' not found.",
809 from_id
810 )));
811 }
812 if !state.nodes.iter().any(|n| n.id == to_id) {
813 return Ok(ToolOutput::text(format!(
814 "Node '{}' not found.",
815 to_id
816 )));
817 }
818
819 let mut visited: HashSet<String> = HashSet::new();
821 let mut parent: HashMap<String, String> = HashMap::new();
822 let mut queue: VecDeque<String> = VecDeque::new();
823
824 visited.insert(from_id.to_string());
825 queue.push_back(from_id.to_string());
826
827 let mut found = false;
828 while let Some(current) = queue.pop_front() {
829 if current == to_id {
830 found = true;
831 break;
832 }
833
834 for edge in &state.edges {
835 let neighbor = if edge.source_id == current {
836 &edge.target_id
837 } else if edge.target_id == current {
838 &edge.source_id
839 } else {
840 continue;
841 };
842
843 if !visited.contains(neighbor) {
844 visited.insert(neighbor.clone());
845 parent.insert(neighbor.clone(), current.clone());
846 queue.push_back(neighbor.clone());
847 }
848 }
849 }
850
851 if !found {
852 return Ok(ToolOutput::text(format!(
853 "No path found from '{}' to '{}'.",
854 from_id, to_id
855 )));
856 }
857
858 let mut path_ids: Vec<String> = Vec::new();
860 let mut current = to_id.to_string();
861 while current != from_id {
862 path_ids.push(current.clone());
863 current = parent.get(¤t).unwrap().clone();
864 }
865 path_ids.push(from_id.to_string());
866 path_ids.reverse();
867
868 let mut output = format!(
869 "Path from '{}' to '{}' ({} hops):\n",
870 from_id,
871 to_id,
872 path_ids.len() - 1
873 );
874 for (i, pid) in path_ids.iter().enumerate() {
875 let node_name = state
876 .nodes
877 .iter()
878 .find(|n| n.id == *pid)
879 .map(|n| n.name.as_str())
880 .unwrap_or("?");
881 if i > 0 {
882 output.push_str(" -> ");
883 } else {
884 output.push_str(" ");
885 }
886 output.push_str(&format!("{} ({})\n", pid, node_name));
887 }
888
889 Ok(ToolOutput::text(output.trim_end().to_string()))
890 }
891
892 "stats" => {
893 let total_nodes = state.nodes.len();
894 let total_edges = state.edges.len();
895
896 if total_nodes == 0 {
897 return Ok(ToolOutput::text(
898 "Knowledge graph is empty. Use add_node to get started.",
899 ));
900 }
901
902 let mut type_counts: HashMap<&str, usize> = HashMap::new();
904 for node in &state.nodes {
905 *type_counts.entry(node.node_type.as_str()).or_insert(0) += 1;
906 }
907
908 let mut edge_type_counts: HashMap<&str, usize> = HashMap::new();
910 for edge in &state.edges {
911 *edge_type_counts
912 .entry(edge.relationship_type.as_str())
913 .or_insert(0) += 1;
914 }
915
916 let mut connection_counts: HashMap<&str, usize> = HashMap::new();
918 for edge in &state.edges {
919 *connection_counts.entry(&edge.source_id).or_insert(0) += 1;
920 *connection_counts.entry(&edge.target_id).or_insert(0) += 1;
921 }
922 let mut top_connected: Vec<(&&str, &usize)> = connection_counts.iter().collect();
923 top_connected.sort_by(|a, b| b.1.cmp(a.1));
924 top_connected.truncate(5);
925
926 let mut output = format!(
927 "Knowledge Graph Stats:\n Nodes: {}\n Edges: {}\n\n Nodes by type:\n",
928 total_nodes, total_edges
929 );
930 let mut sorted_types: Vec<_> = type_counts.iter().collect();
931 sorted_types.sort_by_key(|(k, _)| *k);
932 for (t, c) in &sorted_types {
933 output.push_str(&format!(" {}: {}\n", t, c));
934 }
935
936 if !edge_type_counts.is_empty() {
937 output.push_str("\n Edges by type:\n");
938 let mut sorted_etypes: Vec<_> = edge_type_counts.iter().collect();
939 sorted_etypes.sort_by_key(|(k, _)| *k);
940 for (t, c) in &sorted_etypes {
941 output.push_str(&format!(" {}: {}\n", t, c));
942 }
943 }
944
945 if !top_connected.is_empty() {
946 output.push_str("\n Most connected nodes:\n");
947 for (nid, count) in &top_connected {
948 let node_name = state
949 .nodes
950 .iter()
951 .find(|n| n.id == ***nid)
952 .map(|n| n.name.as_str())
953 .unwrap_or("?");
954 output.push_str(&format!(" {} ({}) — {} connections\n", nid, node_name, count));
955 }
956 }
957
958 Ok(ToolOutput::text(output.trim_end().to_string()))
959 }
960
961 "import_arxiv" => {
962 let arxiv_id = args
963 .get("arxiv_id")
964 .and_then(|v| v.as_str())
965 .unwrap_or("");
966 if arxiv_id.is_empty() {
967 return Ok(ToolOutput::text(
968 "Missing required parameter 'arxiv_id'.",
969 ));
970 }
971
972 let lib_path = self.arxiv_library_path();
973 if !lib_path.exists() {
974 return Ok(ToolOutput::text(
975 "ArXiv library not found. Save papers with the arxiv_research tool first.",
976 ));
977 }
978
979 let lib_json = std::fs::read_to_string(&lib_path).map_err(|e| {
980 ToolError::ExecutionFailed {
981 name: "knowledge_graph".to_string(),
982 message: format!("Failed to read arxiv library: {}", e),
983 }
984 })?;
985
986 let library: ArxivLibraryFile =
987 serde_json::from_str(&lib_json).map_err(|e| ToolError::ExecutionFailed {
988 name: "knowledge_graph".to_string(),
989 message: format!("Failed to parse arxiv library: {}", e),
990 })?;
991
992 let entry = match library
993 .entries
994 .iter()
995 .find(|e| e.paper.id == arxiv_id)
996 {
997 Some(e) => e,
998 None => {
999 return Ok(ToolOutput::text(format!(
1000 "Paper '{}' not found in arxiv library.",
1001 arxiv_id
1002 )));
1003 }
1004 };
1005
1006 let now = Utc::now();
1007 let paper_id = Self::slug(&entry.paper.title);
1008 let mut added_nodes = 0;
1009 let mut added_edges = 0;
1010
1011 if !state.nodes.iter().any(|n| n.id == paper_id) {
1013 let mut metadata = HashMap::new();
1014 metadata.insert("arxiv_id".to_string(), entry.paper.id.clone());
1015 state.nodes.push(GraphNode {
1016 id: paper_id.clone(),
1017 name: entry.paper.title.clone(),
1018 node_type: NodeType::Paper,
1019 description: entry.paper.abstract_text.clone(),
1020 tags: entry.tags.clone(),
1021 metadata,
1022 created_at: now,
1023 updated_at: now,
1024 });
1025 added_nodes += 1;
1026 }
1027
1028 for author in &entry.paper.authors {
1030 let author_id = Self::slug(author);
1031 if !state.nodes.iter().any(|n| n.id == author_id) {
1032 state.nodes.push(GraphNode {
1033 id: author_id.clone(),
1034 name: author.clone(),
1035 node_type: NodeType::Person,
1036 description: String::new(),
1037 tags: Vec::new(),
1038 metadata: HashMap::new(),
1039 created_at: now,
1040 updated_at: now,
1041 });
1042 added_nodes += 1;
1043 }
1044
1045 let edge_exists = state.edges.iter().any(|e| {
1047 e.source_id == paper_id
1048 && e.target_id == author_id
1049 && e.relationship_type == RelationshipType::AuthoredBy
1050 });
1051 if !edge_exists {
1052 state.edges.push(Edge {
1053 source_id: paper_id.clone(),
1054 target_id: author_id,
1055 relationship_type: RelationshipType::AuthoredBy,
1056 strength: 1.0,
1057 notes: String::new(),
1058 created_at: now,
1059 });
1060 added_edges += 1;
1061 }
1062 }
1063
1064 state.next_auto_id += added_nodes;
1065 self.save_state(&state)?;
1066
1067 Ok(ToolOutput::text(format!(
1068 "Imported '{}' from arxiv: {} node(s) and {} edge(s) added.",
1069 entry.paper.title, added_nodes, added_edges
1070 )))
1071 }
1072
1073 "export_dot" => {
1074 let type_filter = args
1075 .get("filter_type")
1076 .and_then(|v| v.as_str())
1077 .and_then(NodeType::from_str_loose);
1078
1079 let filtered_nodes: Vec<&GraphNode> = state
1080 .nodes
1081 .iter()
1082 .filter(|n| {
1083 type_filter
1084 .as_ref()
1085 .map(|nt| &n.node_type == nt)
1086 .unwrap_or(true)
1087 })
1088 .collect();
1089
1090 let node_ids: HashSet<&str> =
1091 filtered_nodes.iter().map(|n| n.id.as_str()).collect();
1092
1093 let mut dot = String::from("digraph KnowledgeGraph {\n");
1094 dot.push_str(" rankdir=LR;\n");
1095 dot.push_str(" node [shape=box];\n\n");
1096
1097 for node in &filtered_nodes {
1098 dot.push_str(&format!(
1099 " \"{}\" [label=\"{}\\n({})\"];\n",
1100 node.id,
1101 node.name.replace('"', "\\\""),
1102 node.node_type.as_str()
1103 ));
1104 }
1105
1106 dot.push('\n');
1107
1108 for edge in &state.edges {
1109 if node_ids.contains(edge.source_id.as_str())
1110 && node_ids.contains(edge.target_id.as_str())
1111 {
1112 dot.push_str(&format!(
1113 " \"{}\" -> \"{}\" [label=\"{}\"];\n",
1114 edge.source_id,
1115 edge.target_id,
1116 edge.relationship_type.as_str()
1117 ));
1118 }
1119 }
1120
1121 dot.push_str("}\n");
1122
1123 Ok(ToolOutput::text(dot))
1124 }
1125
1126 _ => Ok(ToolOutput::text(format!(
1127 "Unknown action: '{}'. Use: add_node, get_node, update_node, remove_node, add_edge, remove_edge, neighbors, search, list, path, stats, import_arxiv, export_dot.",
1128 action
1129 ))),
1130 }
1131 }
1132}
1133
1134#[cfg(test)]
1135mod tests {
1136 use super::*;
1137 use tempfile::TempDir;
1138
1139 fn make_tool() -> (TempDir, KnowledgeGraphTool) {
1140 let dir = TempDir::new().unwrap();
1141 let workspace = dir.path().canonicalize().unwrap();
1142 let tool = KnowledgeGraphTool::new(workspace);
1143 (dir, tool)
1144 }
1145
1146 #[test]
1147 fn test_tool_properties() {
1148 let (_dir, tool) = make_tool();
1149 assert_eq!(tool.name(), "knowledge_graph");
1150 assert_eq!(tool.risk_level(), RiskLevel::Write);
1151 assert_eq!(tool.timeout(), Duration::from_secs(30));
1152 assert!(!tool.description().is_empty());
1153 }
1154
1155 #[test]
1156 fn test_schema_validation() {
1157 let (_dir, tool) = make_tool();
1158 let schema = tool.parameters_schema();
1159 assert!(schema.get("properties").is_some());
1160 let required = schema["required"].as_array().unwrap();
1161 assert!(required.iter().any(|v| v.as_str() == Some("action")));
1162 let actions = schema["properties"]["action"]["enum"].as_array().unwrap();
1163 assert_eq!(actions.len(), 13);
1164 }
1165
1166 #[tokio::test]
1167 async fn test_add_node_basic() {
1168 let (_dir, tool) = make_tool();
1169 let result = tool
1170 .execute(json!({
1171 "action": "add_node",
1172 "name": "Attention Is All You Need",
1173 "node_type": "paper",
1174 "description": "Transformer architecture paper"
1175 }))
1176 .await
1177 .unwrap();
1178 assert!(result.content.contains("Added node"));
1179 assert!(result.content.contains("attention-is-all-you-need"));
1180 }
1181
1182 #[tokio::test]
1183 async fn test_add_node_auto_id_from_slug() {
1184 let (_dir, tool) = make_tool();
1185 let result = tool
1186 .execute(json!({
1187 "action": "add_node",
1188 "name": "Deep Learning Basics",
1189 "node_type": "concept"
1190 }))
1191 .await
1192 .unwrap();
1193 assert!(result.content.contains("deep-learning-basics"));
1194
1195 let state = tool.load_state();
1197 assert_eq!(state.nodes[0].id, "deep-learning-basics");
1198 assert_eq!(state.nodes[0].name, "Deep Learning Basics");
1199 }
1200
1201 #[tokio::test]
1202 async fn test_get_node_with_edges() {
1203 let (_dir, tool) = make_tool();
1204 tool.execute(json!({
1206 "action": "add_node",
1207 "name": "Paper A",
1208 "node_type": "paper",
1209 "id": "paper-a"
1210 }))
1211 .await
1212 .unwrap();
1213 tool.execute(json!({
1214 "action": "add_node",
1215 "name": "Concept B",
1216 "node_type": "concept",
1217 "id": "concept-b"
1218 }))
1219 .await
1220 .unwrap();
1221 tool.execute(json!({
1222 "action": "add_edge",
1223 "source_id": "paper-a",
1224 "target_id": "concept-b",
1225 "relationship_type": "implements"
1226 }))
1227 .await
1228 .unwrap();
1229
1230 let result = tool
1231 .execute(json!({ "action": "get_node", "id": "paper-a" }))
1232 .await
1233 .unwrap();
1234 assert!(result.content.contains("Paper A"));
1235 assert!(result.content.contains("Implements"));
1236 assert!(result.content.contains("concept-b"));
1237 }
1238
1239 #[tokio::test]
1240 async fn test_update_node() {
1241 let (_dir, tool) = make_tool();
1242 tool.execute(json!({
1243 "action": "add_node",
1244 "name": "Old Name",
1245 "node_type": "concept",
1246 "id": "test-node"
1247 }))
1248 .await
1249 .unwrap();
1250
1251 let result = tool
1252 .execute(json!({
1253 "action": "update_node",
1254 "id": "test-node",
1255 "name": "New Name",
1256 "description": "Updated description",
1257 "tags": ["updated"]
1258 }))
1259 .await
1260 .unwrap();
1261 assert!(result.content.contains("Updated node"));
1262
1263 let state = tool.load_state();
1264 let node = state.nodes.iter().find(|n| n.id == "test-node").unwrap();
1265 assert_eq!(node.name, "New Name");
1266 assert_eq!(node.description, "Updated description");
1267 assert_eq!(node.tags, vec!["updated"]);
1268 }
1269
1270 #[tokio::test]
1271 async fn test_remove_node_cascades_edges() {
1272 let (_dir, tool) = make_tool();
1273 tool.execute(
1275 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1276 )
1277 .await
1278 .unwrap();
1279 tool.execute(
1280 json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1281 )
1282 .await
1283 .unwrap();
1284 tool.execute(
1285 json!({ "action": "add_node", "name": "C", "node_type": "concept", "id": "c" }),
1286 )
1287 .await
1288 .unwrap();
1289 tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1290 .await.unwrap();
1291 tool.execute(json!({ "action": "add_edge", "source_id": "b", "target_id": "c", "relationship_type": "related_to" }))
1292 .await.unwrap();
1293
1294 let state = tool.load_state();
1295 assert_eq!(state.edges.len(), 2);
1296
1297 let result = tool
1299 .execute(json!({ "action": "remove_node", "id": "b" }))
1300 .await
1301 .unwrap();
1302 assert!(result.content.contains("Removed node 'b'"));
1303 assert!(result.content.contains("2 connected edge(s)"));
1304
1305 let state = tool.load_state();
1306 assert_eq!(state.nodes.len(), 2);
1307 assert_eq!(state.edges.len(), 0);
1308 }
1309
1310 #[tokio::test]
1311 async fn test_add_edge_validates_nodes() {
1312 let (_dir, tool) = make_tool();
1313 tool.execute(
1314 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1315 )
1316 .await
1317 .unwrap();
1318
1319 let result = tool
1321 .execute(json!({
1322 "action": "add_edge",
1323 "source_id": "a",
1324 "target_id": "nonexistent",
1325 "relationship_type": "related_to"
1326 }))
1327 .await
1328 .unwrap();
1329 assert!(result.content.contains("not found"));
1330
1331 let result = tool
1333 .execute(json!({
1334 "action": "add_edge",
1335 "source_id": "nonexistent",
1336 "target_id": "a",
1337 "relationship_type": "related_to"
1338 }))
1339 .await
1340 .unwrap();
1341 assert!(result.content.contains("not found"));
1342 }
1343
1344 #[tokio::test]
1345 async fn test_add_edge_strength_default() {
1346 let (_dir, tool) = make_tool();
1347 tool.execute(
1348 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1349 )
1350 .await
1351 .unwrap();
1352 tool.execute(
1353 json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1354 )
1355 .await
1356 .unwrap();
1357
1358 tool.execute(json!({
1359 "action": "add_edge",
1360 "source_id": "a",
1361 "target_id": "b",
1362 "relationship_type": "related_to"
1363 }))
1364 .await
1365 .unwrap();
1366
1367 let state = tool.load_state();
1368 assert_eq!(state.edges.len(), 1);
1369 assert!((state.edges[0].strength - 0.5).abs() < f64::EPSILON);
1370 }
1371
1372 #[tokio::test]
1373 async fn test_remove_edge() {
1374 let (_dir, tool) = make_tool();
1375 tool.execute(
1376 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1377 )
1378 .await
1379 .unwrap();
1380 tool.execute(
1381 json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1382 )
1383 .await
1384 .unwrap();
1385 tool.execute(json!({
1386 "action": "add_edge",
1387 "source_id": "a",
1388 "target_id": "b",
1389 "relationship_type": "related_to"
1390 }))
1391 .await
1392 .unwrap();
1393
1394 assert_eq!(tool.load_state().edges.len(), 1);
1395
1396 let result = tool
1397 .execute(json!({
1398 "action": "remove_edge",
1399 "source_id": "a",
1400 "target_id": "b"
1401 }))
1402 .await
1403 .unwrap();
1404 assert!(result.content.contains("Removed"));
1405 assert_eq!(tool.load_state().edges.len(), 0);
1406 }
1407
1408 #[tokio::test]
1409 async fn test_neighbors_depth_1() {
1410 let (_dir, tool) = make_tool();
1411 tool.execute(
1412 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1413 )
1414 .await
1415 .unwrap();
1416 tool.execute(
1417 json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1418 )
1419 .await
1420 .unwrap();
1421 tool.execute(
1422 json!({ "action": "add_node", "name": "C", "node_type": "concept", "id": "c" }),
1423 )
1424 .await
1425 .unwrap();
1426 tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1427 .await.unwrap();
1428 tool.execute(json!({ "action": "add_edge", "source_id": "b", "target_id": "c", "relationship_type": "related_to" }))
1429 .await.unwrap();
1430
1431 let result = tool
1432 .execute(json!({ "action": "neighbors", "id": "a", "depth": 1 }))
1433 .await
1434 .unwrap();
1435 assert!(result.content.contains("b"));
1436 assert!(!result.content.contains(" c ") && !result.content.contains("\n [depth 1] c"));
1438 }
1439
1440 #[tokio::test]
1441 async fn test_neighbors_depth_2() {
1442 let (_dir, tool) = make_tool();
1443 tool.execute(
1444 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1445 )
1446 .await
1447 .unwrap();
1448 tool.execute(
1449 json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1450 )
1451 .await
1452 .unwrap();
1453 tool.execute(
1454 json!({ "action": "add_node", "name": "C", "node_type": "concept", "id": "c" }),
1455 )
1456 .await
1457 .unwrap();
1458 tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1459 .await.unwrap();
1460 tool.execute(json!({ "action": "add_edge", "source_id": "b", "target_id": "c", "relationship_type": "related_to" }))
1461 .await.unwrap();
1462
1463 let result = tool
1464 .execute(json!({ "action": "neighbors", "id": "a", "depth": 2 }))
1465 .await
1466 .unwrap();
1467 assert!(result.content.contains("b"));
1469 assert!(result.content.contains("c"));
1470 }
1471
1472 #[tokio::test]
1473 async fn test_search_by_name() {
1474 let (_dir, tool) = make_tool();
1475 tool.execute(json!({
1476 "action": "add_node",
1477 "name": "Transformer Architecture",
1478 "node_type": "method",
1479 "id": "transformer"
1480 }))
1481 .await
1482 .unwrap();
1483
1484 let result = tool
1485 .execute(json!({ "action": "search", "query": "transformer" }))
1486 .await
1487 .unwrap();
1488 assert!(result.content.contains("Transformer Architecture"));
1489 }
1490
1491 #[tokio::test]
1492 async fn test_search_by_tag() {
1493 let (_dir, tool) = make_tool();
1494 tool.execute(json!({
1495 "action": "add_node",
1496 "name": "BERT",
1497 "node_type": "method",
1498 "id": "bert",
1499 "tags": ["nlp", "language-model"]
1500 }))
1501 .await
1502 .unwrap();
1503
1504 let result = tool
1505 .execute(json!({ "action": "search", "query": "nlp" }))
1506 .await
1507 .unwrap();
1508 assert!(result.content.contains("BERT"));
1509 }
1510
1511 #[tokio::test]
1512 async fn test_search_filter_type() {
1513 let (_dir, tool) = make_tool();
1514 tool.execute(json!({ "action": "add_node", "name": "ML Paper", "node_type": "paper", "id": "ml-paper" }))
1515 .await.unwrap();
1516 tool.execute(json!({ "action": "add_node", "name": "ML Concept", "node_type": "concept", "id": "ml-concept" }))
1517 .await.unwrap();
1518
1519 let result = tool
1520 .execute(json!({ "action": "search", "query": "ml", "node_type": "paper" }))
1521 .await
1522 .unwrap();
1523 assert!(result.content.contains("ML Paper"));
1524 assert!(!result.content.contains("ML Concept"));
1525 }
1526
1527 #[tokio::test]
1528 async fn test_list_all() {
1529 let (_dir, tool) = make_tool();
1530 tool.execute(
1531 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1532 )
1533 .await
1534 .unwrap();
1535 tool.execute(json!({ "action": "add_node", "name": "B", "node_type": "paper", "id": "b" }))
1536 .await
1537 .unwrap();
1538
1539 let result = tool.execute(json!({ "action": "list" })).await.unwrap();
1540 assert!(result.content.contains("Nodes (2)"));
1541 assert!(result.content.contains("a"));
1542 assert!(result.content.contains("b"));
1543 }
1544
1545 #[tokio::test]
1546 async fn test_list_filter_type() {
1547 let (_dir, tool) = make_tool();
1548 tool.execute(
1549 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1550 )
1551 .await
1552 .unwrap();
1553 tool.execute(json!({ "action": "add_node", "name": "B", "node_type": "paper", "id": "b" }))
1554 .await
1555 .unwrap();
1556
1557 let result = tool
1558 .execute(json!({ "action": "list", "node_type": "concept" }))
1559 .await
1560 .unwrap();
1561 assert!(result.content.contains("Nodes (1)"));
1562 assert!(result.content.contains("a"));
1563 assert!(!result.content.contains(" b "));
1564 }
1565
1566 #[tokio::test]
1567 async fn test_path_direct() {
1568 let (_dir, tool) = make_tool();
1569 tool.execute(
1570 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1571 )
1572 .await
1573 .unwrap();
1574 tool.execute(
1575 json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1576 )
1577 .await
1578 .unwrap();
1579 tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1580 .await.unwrap();
1581
1582 let result = tool
1583 .execute(json!({ "action": "path", "from_id": "a", "to_id": "b" }))
1584 .await
1585 .unwrap();
1586 assert!(result.content.contains("1 hops"));
1587 assert!(result.content.contains("a"));
1588 assert!(result.content.contains("b"));
1589 }
1590
1591 #[tokio::test]
1592 async fn test_path_indirect() {
1593 let (_dir, tool) = make_tool();
1594 tool.execute(
1595 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1596 )
1597 .await
1598 .unwrap();
1599 tool.execute(
1600 json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1601 )
1602 .await
1603 .unwrap();
1604 tool.execute(
1605 json!({ "action": "add_node", "name": "C", "node_type": "concept", "id": "c" }),
1606 )
1607 .await
1608 .unwrap();
1609 tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1610 .await.unwrap();
1611 tool.execute(json!({ "action": "add_edge", "source_id": "b", "target_id": "c", "relationship_type": "related_to" }))
1612 .await.unwrap();
1613
1614 let result = tool
1615 .execute(json!({ "action": "path", "from_id": "a", "to_id": "c" }))
1616 .await
1617 .unwrap();
1618 assert!(result.content.contains("2 hops"));
1619 assert!(result.content.contains("a"));
1620 assert!(result.content.contains("b"));
1621 assert!(result.content.contains("c"));
1622 }
1623
1624 #[tokio::test]
1625 async fn test_path_no_path() {
1626 let (_dir, tool) = make_tool();
1627 tool.execute(
1628 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1629 )
1630 .await
1631 .unwrap();
1632 tool.execute(
1633 json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1634 )
1635 .await
1636 .unwrap();
1637 let result = tool
1640 .execute(json!({ "action": "path", "from_id": "a", "to_id": "b" }))
1641 .await
1642 .unwrap();
1643 assert!(result.content.contains("No path found"));
1644 }
1645
1646 #[tokio::test]
1647 async fn test_stats() {
1648 let (_dir, tool) = make_tool();
1649 tool.execute(
1650 json!({ "action": "add_node", "name": "Paper 1", "node_type": "paper", "id": "p1" }),
1651 )
1652 .await
1653 .unwrap();
1654 tool.execute(json!({ "action": "add_node", "name": "Concept 1", "node_type": "concept", "id": "c1" }))
1655 .await.unwrap();
1656 tool.execute(
1657 json!({ "action": "add_node", "name": "Author 1", "node_type": "person", "id": "a1" }),
1658 )
1659 .await
1660 .unwrap();
1661 tool.execute(json!({ "action": "add_edge", "source_id": "p1", "target_id": "c1", "relationship_type": "implements" }))
1662 .await.unwrap();
1663 tool.execute(json!({ "action": "add_edge", "source_id": "p1", "target_id": "a1", "relationship_type": "authored_by" }))
1664 .await.unwrap();
1665
1666 let result = tool.execute(json!({ "action": "stats" })).await.unwrap();
1667 assert!(result.content.contains("Nodes: 3"));
1668 assert!(result.content.contains("Edges: 2"));
1669 assert!(result.content.contains("Paper: 1"));
1670 assert!(result.content.contains("Concept: 1"));
1671 assert!(result.content.contains("Person: 1"));
1672 assert!(result.content.contains("Most connected"));
1673 assert!(result.content.contains("p1"));
1674 }
1675
1676 #[tokio::test]
1677 async fn test_export_dot() {
1678 let (_dir, tool) = make_tool();
1679 tool.execute(
1680 json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1681 )
1682 .await
1683 .unwrap();
1684 tool.execute(
1685 json!({ "action": "add_node", "name": "B", "node_type": "method", "id": "b" }),
1686 )
1687 .await
1688 .unwrap();
1689 tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "implements" }))
1690 .await.unwrap();
1691
1692 let result = tool
1693 .execute(json!({ "action": "export_dot" }))
1694 .await
1695 .unwrap();
1696 assert!(result.content.contains("digraph KnowledgeGraph"));
1697 assert!(result.content.contains("\"a\""));
1698 assert!(result.content.contains("\"b\""));
1699 assert!(result.content.contains("Implements"));
1700 assert!(result.content.contains("->"));
1701 }
1702
1703 #[tokio::test]
1704 async fn test_state_roundtrip() {
1705 let (_dir, tool) = make_tool();
1706 tool.execute(json!({
1708 "action": "add_node",
1709 "name": "Test Node",
1710 "node_type": "dataset",
1711 "id": "test-node",
1712 "description": "A test dataset",
1713 "tags": ["test", "data"],
1714 "metadata": { "source": "kaggle", "size": "1GB" }
1715 }))
1716 .await
1717 .unwrap();
1718
1719 let state = tool.load_state();
1721 assert_eq!(state.nodes.len(), 1);
1722 let node = &state.nodes[0];
1723 assert_eq!(node.id, "test-node");
1724 assert_eq!(node.name, "Test Node");
1725 assert_eq!(node.node_type, NodeType::Dataset);
1726 assert_eq!(node.description, "A test dataset");
1727 assert_eq!(node.tags, vec!["test", "data"]);
1728 assert_eq!(node.metadata.get("source").unwrap(), "kaggle");
1729 assert_eq!(node.metadata.get("size").unwrap(), "1GB");
1730 }
1731
1732 #[tokio::test]
1733 async fn test_unknown_action() {
1734 let (_dir, tool) = make_tool();
1735 let result = tool.execute(json!({ "action": "foobar" })).await.unwrap();
1736 assert!(result.content.contains("Unknown action"));
1737 assert!(result.content.contains("foobar"));
1738 }
1739}