Skip to main content

rustant_tools/
knowledge_graph.rs

1//! Knowledge graph tool — local graph of concepts, papers, methods, and relationships.
2
3use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5use rustant_core::error::ToolError;
6use rustant_core::types::{RiskLevel, ToolOutput};
7use serde::{Deserialize, Serialize};
8use serde_json::{Value, json};
9use std::collections::{HashMap, HashSet, VecDeque};
10use std::path::PathBuf;
11use std::time::Duration;
12
13use crate::registry::Tool;
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16enum NodeType {
17    Paper,
18    Concept,
19    Method,
20    Dataset,
21    Person,
22    Organization,
23}
24
25impl NodeType {
26    fn from_str_loose(s: &str) -> Option<Self> {
27        match s.to_lowercase().as_str() {
28            "paper" => Some(Self::Paper),
29            "concept" => Some(Self::Concept),
30            "method" => Some(Self::Method),
31            "dataset" => Some(Self::Dataset),
32            "person" => Some(Self::Person),
33            "organization" => Some(Self::Organization),
34            _ => None,
35        }
36    }
37
38    fn as_str(&self) -> &str {
39        match self {
40            Self::Paper => "Paper",
41            Self::Concept => "Concept",
42            Self::Method => "Method",
43            Self::Dataset => "Dataset",
44            Self::Person => "Person",
45            Self::Organization => "Organization",
46        }
47    }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
51enum RelationshipType {
52    Cites,
53    Implements,
54    Extends,
55    Contradicts,
56    BuildsOn,
57    AuthoredBy,
58    UsesDataset,
59    RelatedTo,
60}
61
62impl RelationshipType {
63    fn from_str_loose(s: &str) -> Option<Self> {
64        match s.to_lowercase().as_str() {
65            "cites" => Some(Self::Cites),
66            "implements" => Some(Self::Implements),
67            "extends" => Some(Self::Extends),
68            "contradicts" => Some(Self::Contradicts),
69            "builds_on" => Some(Self::BuildsOn),
70            "authored_by" => Some(Self::AuthoredBy),
71            "uses_dataset" => Some(Self::UsesDataset),
72            "related_to" => Some(Self::RelatedTo),
73            _ => None,
74        }
75    }
76
77    fn as_str(&self) -> &str {
78        match self {
79            Self::Cites => "Cites",
80            Self::Implements => "Implements",
81            Self::Extends => "Extends",
82            Self::Contradicts => "Contradicts",
83            Self::BuildsOn => "BuildsOn",
84            Self::AuthoredBy => "AuthoredBy",
85            Self::UsesDataset => "UsesDataset",
86            Self::RelatedTo => "RelatedTo",
87        }
88    }
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92struct GraphNode {
93    id: String,
94    name: String,
95    node_type: NodeType,
96    description: String,
97    tags: Vec<String>,
98    metadata: HashMap<String, String>,
99    created_at: DateTime<Utc>,
100    updated_at: DateTime<Utc>,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104struct Edge {
105    source_id: String,
106    target_id: String,
107    relationship_type: RelationshipType,
108    strength: f64,
109    notes: String,
110    created_at: DateTime<Utc>,
111}
112
113#[derive(Debug, Default, Serialize, Deserialize)]
114struct KnowledgeGraphState {
115    nodes: Vec<GraphNode>,
116    edges: Vec<Edge>,
117    next_auto_id: usize,
118}
119
120/// Minimal struct for deserializing papers from the ArXiv library.
121#[derive(Debug, Deserialize)]
122struct ArxivLibraryFile {
123    entries: Vec<ArxivLibraryEntry>,
124}
125
126#[derive(Debug, Deserialize)]
127struct ArxivLibraryEntry {
128    paper: ArxivPaperRef,
129    #[serde(default)]
130    tags: Vec<String>,
131}
132
133#[derive(Debug, Deserialize)]
134struct ArxivPaperRef {
135    #[serde(alias = "arxiv_id")]
136    id: String,
137    title: String,
138    authors: Vec<String>,
139    #[serde(alias = "summary")]
140    abstract_text: String,
141}
142
143pub struct KnowledgeGraphTool {
144    workspace: PathBuf,
145}
146
147impl KnowledgeGraphTool {
148    pub fn new(workspace: PathBuf) -> Self {
149        Self { workspace }
150    }
151
152    fn state_path(&self) -> PathBuf {
153        self.workspace
154            .join(".rustant")
155            .join("knowledge")
156            .join("graph.json")
157    }
158
159    fn load_state(&self) -> KnowledgeGraphState {
160        let path = self.state_path();
161        if path.exists() {
162            std::fs::read_to_string(&path)
163                .ok()
164                .and_then(|s| serde_json::from_str(&s).ok())
165                .unwrap_or_default()
166        } else {
167            KnowledgeGraphState::default()
168        }
169    }
170
171    fn save_state(&self, state: &KnowledgeGraphState) -> Result<(), ToolError> {
172        let path = self.state_path();
173        if let Some(parent) = path.parent() {
174            std::fs::create_dir_all(parent).map_err(|e| ToolError::ExecutionFailed {
175                name: "knowledge_graph".to_string(),
176                message: format!("Failed to create state dir: {}", e),
177            })?;
178        }
179        let json = serde_json::to_string_pretty(state).map_err(|e| ToolError::ExecutionFailed {
180            name: "knowledge_graph".to_string(),
181            message: format!("Failed to serialize state: {}", e),
182        })?;
183        let tmp = path.with_extension("json.tmp");
184        std::fs::write(&tmp, &json).map_err(|e| ToolError::ExecutionFailed {
185            name: "knowledge_graph".to_string(),
186            message: format!("Failed to write state: {}", e),
187        })?;
188        std::fs::rename(&tmp, &path).map_err(|e| ToolError::ExecutionFailed {
189            name: "knowledge_graph".to_string(),
190            message: format!("Failed to rename state file: {}", e),
191        })?;
192        Ok(())
193    }
194
195    fn slug(name: &str) -> String {
196        name.to_lowercase().replace(' ', "-")
197    }
198
199    fn arxiv_library_path(&self) -> PathBuf {
200        self.workspace
201            .join(".rustant")
202            .join("arxiv")
203            .join("library.json")
204    }
205}
206
207#[async_trait]
208impl Tool for KnowledgeGraphTool {
209    fn name(&self) -> &str {
210        "knowledge_graph"
211    }
212
213    fn description(&self) -> &str {
214        "Local knowledge graph of concepts, papers, methods, and relationships. Actions: add_node, get_node, update_node, remove_node, add_edge, remove_edge, neighbors, search, list, path, stats, import_arxiv, export_dot."
215    }
216
217    fn parameters_schema(&self) -> Value {
218        json!({
219            "type": "object",
220            "properties": {
221                "action": {
222                    "type": "string",
223                    "enum": [
224                        "add_node", "get_node", "update_node", "remove_node",
225                        "add_edge", "remove_edge", "neighbors", "search",
226                        "list", "path", "stats", "import_arxiv", "export_dot"
227                    ],
228                    "description": "Action to perform"
229                },
230                "id": { "type": "string", "description": "Node ID" },
231                "name": { "type": "string", "description": "Node name" },
232                "node_type": {
233                    "type": "string",
234                    "enum": ["paper", "concept", "method", "dataset", "person", "organization"],
235                    "description": "Type of node"
236                },
237                "description": { "type": "string", "description": "Node description" },
238                "tags": {
239                    "type": "array",
240                    "items": { "type": "string" },
241                    "description": "Tags for the node"
242                },
243                "metadata": {
244                    "type": "object",
245                    "additionalProperties": { "type": "string" },
246                    "description": "Key-value metadata for the node"
247                },
248                "source_id": { "type": "string", "description": "Source node ID for edge" },
249                "target_id": { "type": "string", "description": "Target node ID for edge" },
250                "relationship_type": {
251                    "type": "string",
252                    "enum": ["cites", "implements", "extends", "contradicts", "builds_on", "authored_by", "uses_dataset", "related_to"],
253                    "description": "Type of relationship"
254                },
255                "strength": {
256                    "type": "number",
257                    "description": "Edge strength 0.0-1.0 (default 0.5)"
258                },
259                "notes": { "type": "string", "description": "Edge notes" },
260                "depth": {
261                    "type": "integer",
262                    "description": "Traversal depth for neighbors (1-3, default 1)"
263                },
264                "query": { "type": "string", "description": "Search query" },
265                "tag": { "type": "string", "description": "Filter by tag" },
266                "from_id": { "type": "string", "description": "Path start node ID" },
267                "to_id": { "type": "string", "description": "Path end node ID" },
268                "arxiv_id": { "type": "string", "description": "ArXiv paper ID for import" },
269                "filter_type": {
270                    "type": "string",
271                    "enum": ["paper", "concept", "method", "dataset", "person", "organization"],
272                    "description": "Filter by node type"
273                }
274            },
275            "required": ["action"]
276        })
277    }
278
279    fn risk_level(&self) -> RiskLevel {
280        RiskLevel::Write
281    }
282
283    fn timeout(&self) -> Duration {
284        Duration::from_secs(30)
285    }
286
287    async fn execute(&self, args: Value) -> Result<ToolOutput, ToolError> {
288        let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("");
289        let mut state = self.load_state();
290
291        match action {
292            "add_node" => {
293                let name = args
294                    .get("name")
295                    .and_then(|v| v.as_str())
296                    .unwrap_or("")
297                    .trim();
298                if name.is_empty() {
299                    return Ok(ToolOutput::text("Missing required parameter 'name'."));
300                }
301
302                let node_type_str = args.get("node_type").and_then(|v| v.as_str()).unwrap_or("");
303                let node_type = match NodeType::from_str_loose(node_type_str) {
304                    Some(nt) => nt,
305                    None => {
306                        return Ok(ToolOutput::text(format!(
307                            "Invalid node_type '{}'. Use: paper, concept, method, dataset, person, organization.",
308                            node_type_str
309                        )));
310                    }
311                };
312
313                let id = args
314                    .get("id")
315                    .and_then(|v| v.as_str())
316                    .map(|s| s.to_string())
317                    .unwrap_or_else(|| Self::slug(name));
318
319                // Check for duplicate id
320                if state.nodes.iter().any(|n| n.id == id) {
321                    return Ok(ToolOutput::text(format!(
322                        "Node with id '{}' already exists.",
323                        id
324                    )));
325                }
326
327                let description = args
328                    .get("description")
329                    .and_then(|v| v.as_str())
330                    .unwrap_or("")
331                    .to_string();
332
333                let tags: Vec<String> = args
334                    .get("tags")
335                    .and_then(|v| v.as_array())
336                    .map(|arr| {
337                        arr.iter()
338                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
339                            .collect()
340                    })
341                    .unwrap_or_default();
342
343                let metadata: HashMap<String, String> = args
344                    .get("metadata")
345                    .and_then(|v| v.as_object())
346                    .map(|obj| {
347                        obj.iter()
348                            .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
349                            .collect()
350                    })
351                    .unwrap_or_default();
352
353                let now = Utc::now();
354                state.nodes.push(GraphNode {
355                    id: id.clone(),
356                    name: name.to_string(),
357                    node_type,
358                    description,
359                    tags,
360                    metadata,
361                    created_at: now,
362                    updated_at: now,
363                });
364                state.next_auto_id += 1;
365                self.save_state(&state)?;
366
367                Ok(ToolOutput::text(format!(
368                    "Added node '{}' (id: {}).",
369                    name, id
370                )))
371            }
372
373            "get_node" => {
374                let id = args.get("id").and_then(|v| v.as_str()).unwrap_or("");
375                if id.is_empty() {
376                    return Ok(ToolOutput::text("Missing required parameter 'id'."));
377                }
378
379                let node = match state.nodes.iter().find(|n| n.id == id) {
380                    Some(n) => n,
381                    None => {
382                        return Ok(ToolOutput::text(format!("Node '{}' not found.", id)));
383                    }
384                };
385
386                let connected_edges: Vec<&Edge> = state
387                    .edges
388                    .iter()
389                    .filter(|e| e.source_id == id || e.target_id == id)
390                    .collect();
391
392                let mut output = format!(
393                    "Node: {} ({})\n  ID: {}\n  Type: {}\n  Description: {}\n  Tags: {}\n  Created: {}\n  Updated: {}",
394                    node.name,
395                    node.node_type.as_str(),
396                    node.id,
397                    node.node_type.as_str(),
398                    if node.description.is_empty() {
399                        "(none)"
400                    } else {
401                        &node.description
402                    },
403                    if node.tags.is_empty() {
404                        "(none)".to_string()
405                    } else {
406                        node.tags.join(", ")
407                    },
408                    node.created_at.format("%Y-%m-%d %H:%M"),
409                    node.updated_at.format("%Y-%m-%d %H:%M"),
410                );
411
412                if !node.metadata.is_empty() {
413                    output.push_str("\n  Metadata:");
414                    for (k, v) in &node.metadata {
415                        output.push_str(&format!("\n    {}: {}", k, v));
416                    }
417                }
418
419                if !connected_edges.is_empty() {
420                    output.push_str(&format!("\n  Edges ({}):", connected_edges.len()));
421                    for edge in &connected_edges {
422                        let direction = if edge.source_id == id {
423                            format!(
424                                "--[{}]--> {}",
425                                edge.relationship_type.as_str(),
426                                edge.target_id
427                            )
428                        } else {
429                            format!(
430                                "<--[{}]-- {}",
431                                edge.relationship_type.as_str(),
432                                edge.source_id
433                            )
434                        };
435                        output.push_str(&format!(
436                            "\n    {} (strength: {:.2})",
437                            direction, edge.strength
438                        ));
439                    }
440                }
441
442                Ok(ToolOutput::text(output))
443            }
444
445            "update_node" => {
446                let id = args.get("id").and_then(|v| v.as_str()).unwrap_or("");
447                if id.is_empty() {
448                    return Ok(ToolOutput::text("Missing required parameter 'id'."));
449                }
450
451                let node = match state.nodes.iter_mut().find(|n| n.id == id) {
452                    Some(n) => n,
453                    None => {
454                        return Ok(ToolOutput::text(format!("Node '{}' not found.", id)));
455                    }
456                };
457
458                if let Some(name) = args.get("name").and_then(|v| v.as_str()) {
459                    node.name = name.to_string();
460                }
461                if let Some(desc) = args.get("description").and_then(|v| v.as_str()) {
462                    node.description = desc.to_string();
463                }
464                if let Some(tags) = args.get("tags").and_then(|v| v.as_array()) {
465                    node.tags = tags
466                        .iter()
467                        .filter_map(|v| v.as_str().map(|s| s.to_string()))
468                        .collect();
469                }
470                node.updated_at = Utc::now();
471                self.save_state(&state)?;
472
473                Ok(ToolOutput::text(format!("Updated node '{}'.", id)))
474            }
475
476            "remove_node" => {
477                let id = args.get("id").and_then(|v| v.as_str()).unwrap_or("");
478                if id.is_empty() {
479                    return Ok(ToolOutput::text("Missing required parameter 'id'."));
480                }
481
482                let before_nodes = state.nodes.len();
483                state.nodes.retain(|n| n.id != id);
484                if state.nodes.len() == before_nodes {
485                    return Ok(ToolOutput::text(format!("Node '{}' not found.", id)));
486                }
487
488                // Cascade-delete edges referencing this node
489                let before_edges = state.edges.len();
490                state
491                    .edges
492                    .retain(|e| e.source_id != id && e.target_id != id);
493                let edges_removed = before_edges - state.edges.len();
494
495                self.save_state(&state)?;
496
497                Ok(ToolOutput::text(format!(
498                    "Removed node '{}' and {} connected edge(s).",
499                    id, edges_removed
500                )))
501            }
502
503            "add_edge" => {
504                let source_id = args.get("source_id").and_then(|v| v.as_str()).unwrap_or("");
505                let target_id = args.get("target_id").and_then(|v| v.as_str()).unwrap_or("");
506                let rel_str = args
507                    .get("relationship_type")
508                    .and_then(|v| v.as_str())
509                    .unwrap_or("");
510
511                if source_id.is_empty() || target_id.is_empty() {
512                    return Ok(ToolOutput::text(
513                        "Missing required parameters 'source_id' and 'target_id'.",
514                    ));
515                }
516
517                // Validate nodes exist
518                if !state.nodes.iter().any(|n| n.id == source_id) {
519                    return Ok(ToolOutput::text(format!(
520                        "Source node '{}' not found.",
521                        source_id
522                    )));
523                }
524                if !state.nodes.iter().any(|n| n.id == target_id) {
525                    return Ok(ToolOutput::text(format!(
526                        "Target node '{}' not found.",
527                        target_id
528                    )));
529                }
530
531                let relationship_type = match RelationshipType::from_str_loose(rel_str) {
532                    Some(rt) => rt,
533                    None => {
534                        return Ok(ToolOutput::text(format!(
535                            "Invalid relationship_type '{}'. Use: cites, implements, extends, contradicts, builds_on, authored_by, uses_dataset, related_to.",
536                            rel_str
537                        )));
538                    }
539                };
540
541                let strength = args
542                    .get("strength")
543                    .and_then(|v| v.as_f64())
544                    .unwrap_or(0.5)
545                    .clamp(0.0, 1.0);
546
547                let notes = args
548                    .get("notes")
549                    .and_then(|v| v.as_str())
550                    .unwrap_or("")
551                    .to_string();
552
553                state.edges.push(Edge {
554                    source_id: source_id.to_string(),
555                    target_id: target_id.to_string(),
556                    relationship_type,
557                    strength,
558                    notes,
559                    created_at: Utc::now(),
560                });
561
562                self.save_state(&state)?;
563
564                Ok(ToolOutput::text(format!(
565                    "Added edge {} --[{}]--> {} (strength: {:.2}).",
566                    source_id, rel_str, target_id, strength
567                )))
568            }
569
570            "remove_edge" => {
571                let source_id = args.get("source_id").and_then(|v| v.as_str()).unwrap_or("");
572                let target_id = args.get("target_id").and_then(|v| v.as_str()).unwrap_or("");
573
574                if source_id.is_empty() || target_id.is_empty() {
575                    return Ok(ToolOutput::text(
576                        "Missing required parameters 'source_id' and 'target_id'.",
577                    ));
578                }
579
580                let rel_filter = args
581                    .get("relationship_type")
582                    .and_then(|v| v.as_str())
583                    .and_then(RelationshipType::from_str_loose);
584
585                let before = state.edges.len();
586                state.edges.retain(|e| {
587                    if e.source_id == source_id && e.target_id == target_id {
588                        if let Some(ref rt) = rel_filter {
589                            return &e.relationship_type != rt;
590                        }
591                        return false;
592                    }
593                    true
594                });
595                let removed = before - state.edges.len();
596
597                if removed == 0 {
598                    return Ok(ToolOutput::text("No matching edge(s) found."));
599                }
600
601                self.save_state(&state)?;
602
603                Ok(ToolOutput::text(format!(
604                    "Removed {} edge(s) from '{}' to '{}'.",
605                    removed, source_id, target_id
606                )))
607            }
608
609            "neighbors" => {
610                let id = args.get("id").and_then(|v| v.as_str()).unwrap_or("");
611                if id.is_empty() {
612                    return Ok(ToolOutput::text("Missing required parameter 'id'."));
613                }
614
615                if !state.nodes.iter().any(|n| n.id == id) {
616                    return Ok(ToolOutput::text(format!("Node '{}' not found.", id)));
617                }
618
619                let max_depth = args
620                    .get("depth")
621                    .and_then(|v| v.as_u64())
622                    .unwrap_or(1)
623                    .clamp(1, 3) as usize;
624
625                let rel_filter = args
626                    .get("relationship_type")
627                    .and_then(|v| v.as_str())
628                    .and_then(RelationshipType::from_str_loose);
629
630                // BFS traversal
631                let mut visited: HashSet<String> = HashSet::new();
632                visited.insert(id.to_string());
633                let mut queue: VecDeque<(String, usize)> = VecDeque::new();
634                queue.push_back((id.to_string(), 0));
635                let mut found_nodes: Vec<(String, usize)> = Vec::new();
636
637                while let Some((current_id, depth)) = queue.pop_front() {
638                    if depth >= max_depth {
639                        continue;
640                    }
641
642                    for edge in &state.edges {
643                        if let Some(ref rt) = rel_filter
644                            && &edge.relationship_type != rt
645                        {
646                            continue;
647                        }
648
649                        let neighbor_id = if edge.source_id == current_id {
650                            &edge.target_id
651                        } else if edge.target_id == current_id {
652                            &edge.source_id
653                        } else {
654                            continue;
655                        };
656
657                        if !visited.contains(neighbor_id) {
658                            visited.insert(neighbor_id.clone());
659                            found_nodes.push((neighbor_id.clone(), depth + 1));
660                            queue.push_back((neighbor_id.clone(), depth + 1));
661                        }
662                    }
663                }
664
665                if found_nodes.is_empty() {
666                    return Ok(ToolOutput::text(format!(
667                        "No neighbors found for '{}' within depth {}.",
668                        id, max_depth
669                    )));
670                }
671
672                let mut output = format!("Neighbors of '{}' (depth {}):\n", id, max_depth);
673                for (nid, depth) in &found_nodes {
674                    if let Some(node) = state.nodes.iter().find(|n| n.id == *nid) {
675                        output.push_str(&format!(
676                            "  [depth {}] {} — {} ({})\n",
677                            depth,
678                            node.id,
679                            node.name,
680                            node.node_type.as_str()
681                        ));
682                    }
683                }
684
685                Ok(ToolOutput::text(output.trim_end().to_string()))
686            }
687
688            "search" => {
689                let query = args
690                    .get("query")
691                    .and_then(|v| v.as_str())
692                    .unwrap_or("")
693                    .to_lowercase();
694                if query.is_empty() {
695                    return Ok(ToolOutput::text("Missing required parameter 'query'."));
696                }
697
698                let type_filter = args
699                    .get("node_type")
700                    .and_then(|v| v.as_str())
701                    .and_then(NodeType::from_str_loose);
702
703                let matches: Vec<&GraphNode> = state
704                    .nodes
705                    .iter()
706                    .filter(|n| {
707                        if let Some(ref nt) = type_filter
708                            && &n.node_type != nt
709                        {
710                            return false;
711                        }
712                        n.name.to_lowercase().contains(&query)
713                            || n.description.to_lowercase().contains(&query)
714                            || n.tags.iter().any(|t| t.to_lowercase().contains(&query))
715                    })
716                    .collect();
717
718                if matches.is_empty() {
719                    return Ok(ToolOutput::text(format!("No nodes matching '{}'.", query)));
720                }
721
722                let mut output = format!("Found {} node(s):\n", matches.len());
723                for node in &matches {
724                    output.push_str(&format!(
725                        "  {} — {} ({})\n",
726                        node.id,
727                        node.name,
728                        node.node_type.as_str()
729                    ));
730                }
731
732                Ok(ToolOutput::text(output.trim_end().to_string()))
733            }
734
735            "list" => {
736                let type_filter = args
737                    .get("node_type")
738                    .and_then(|v| v.as_str())
739                    .and_then(NodeType::from_str_loose);
740
741                let tag_filter = args.get("tag").and_then(|v| v.as_str());
742
743                let filtered: Vec<&GraphNode> = state
744                    .nodes
745                    .iter()
746                    .filter(|n| {
747                        if let Some(ref nt) = type_filter
748                            && &n.node_type != nt
749                        {
750                            return false;
751                        }
752                        if let Some(tag) = tag_filter
753                            && !n.tags.iter().any(|t| t.eq_ignore_ascii_case(tag))
754                        {
755                            return false;
756                        }
757                        true
758                    })
759                    .collect();
760
761                if filtered.is_empty() {
762                    return Ok(ToolOutput::text("No nodes found."));
763                }
764
765                let mut output = format!("Nodes ({}):\n", filtered.len());
766                for node in &filtered {
767                    output.push_str(&format!(
768                        "  {} — {} ({})\n",
769                        node.id,
770                        node.name,
771                        node.node_type.as_str()
772                    ));
773                }
774
775                Ok(ToolOutput::text(output.trim_end().to_string()))
776            }
777
778            "path" => {
779                let from_id = args.get("from_id").and_then(|v| v.as_str()).unwrap_or("");
780                let to_id = args.get("to_id").and_then(|v| v.as_str()).unwrap_or("");
781
782                if from_id.is_empty() || to_id.is_empty() {
783                    return Ok(ToolOutput::text(
784                        "Missing required parameters 'from_id' and 'to_id'.",
785                    ));
786                }
787
788                if !state.nodes.iter().any(|n| n.id == from_id) {
789                    return Ok(ToolOutput::text(format!("Node '{}' not found.", from_id)));
790                }
791                if !state.nodes.iter().any(|n| n.id == to_id) {
792                    return Ok(ToolOutput::text(format!("Node '{}' not found.", to_id)));
793                }
794
795                // BFS shortest path (bidirectional edges)
796                let mut visited: HashSet<String> = HashSet::new();
797                let mut parent: HashMap<String, String> = HashMap::new();
798                let mut queue: VecDeque<String> = VecDeque::new();
799
800                visited.insert(from_id.to_string());
801                queue.push_back(from_id.to_string());
802
803                let mut found = false;
804                while let Some(current) = queue.pop_front() {
805                    if current == to_id {
806                        found = true;
807                        break;
808                    }
809
810                    for edge in &state.edges {
811                        let neighbor = if edge.source_id == current {
812                            &edge.target_id
813                        } else if edge.target_id == current {
814                            &edge.source_id
815                        } else {
816                            continue;
817                        };
818
819                        if !visited.contains(neighbor) {
820                            visited.insert(neighbor.clone());
821                            parent.insert(neighbor.clone(), current.clone());
822                            queue.push_back(neighbor.clone());
823                        }
824                    }
825                }
826
827                if !found {
828                    return Ok(ToolOutput::text(format!(
829                        "No path found from '{}' to '{}'.",
830                        from_id, to_id
831                    )));
832                }
833
834                // Reconstruct path
835                let mut path_ids: Vec<String> = Vec::new();
836                let mut current = to_id.to_string();
837                while current != from_id {
838                    path_ids.push(current.clone());
839                    current = parent.get(&current).unwrap().clone();
840                }
841                path_ids.push(from_id.to_string());
842                path_ids.reverse();
843
844                let mut output = format!(
845                    "Path from '{}' to '{}' ({} hops):\n",
846                    from_id,
847                    to_id,
848                    path_ids.len() - 1
849                );
850                for (i, pid) in path_ids.iter().enumerate() {
851                    let node_name = state
852                        .nodes
853                        .iter()
854                        .find(|n| n.id == *pid)
855                        .map(|n| n.name.as_str())
856                        .unwrap_or("?");
857                    if i > 0 {
858                        output.push_str("  -> ");
859                    } else {
860                        output.push_str("  ");
861                    }
862                    output.push_str(&format!("{} ({})\n", pid, node_name));
863                }
864
865                Ok(ToolOutput::text(output.trim_end().to_string()))
866            }
867
868            "stats" => {
869                let total_nodes = state.nodes.len();
870                let total_edges = state.edges.len();
871
872                if total_nodes == 0 {
873                    return Ok(ToolOutput::text(
874                        "Knowledge graph is empty. Use add_node to get started.",
875                    ));
876                }
877
878                // Node counts by type
879                let mut type_counts: HashMap<&str, usize> = HashMap::new();
880                for node in &state.nodes {
881                    *type_counts.entry(node.node_type.as_str()).or_insert(0) += 1;
882                }
883
884                // Edge counts by type
885                let mut edge_type_counts: HashMap<&str, usize> = HashMap::new();
886                for edge in &state.edges {
887                    *edge_type_counts
888                        .entry(edge.relationship_type.as_str())
889                        .or_insert(0) += 1;
890                }
891
892                // Top-5 most connected nodes
893                let mut connection_counts: HashMap<&str, usize> = HashMap::new();
894                for edge in &state.edges {
895                    *connection_counts.entry(&edge.source_id).or_insert(0) += 1;
896                    *connection_counts.entry(&edge.target_id).or_insert(0) += 1;
897                }
898                let mut top_connected: Vec<(&&str, &usize)> = connection_counts.iter().collect();
899                top_connected.sort_by(|a, b| b.1.cmp(a.1));
900                top_connected.truncate(5);
901
902                let mut output = format!(
903                    "Knowledge Graph Stats:\n  Nodes: {}\n  Edges: {}\n\n  Nodes by type:\n",
904                    total_nodes, total_edges
905                );
906                let mut sorted_types: Vec<_> = type_counts.iter().collect();
907                sorted_types.sort_by_key(|(k, _)| *k);
908                for (t, c) in &sorted_types {
909                    output.push_str(&format!("    {}: {}\n", t, c));
910                }
911
912                if !edge_type_counts.is_empty() {
913                    output.push_str("\n  Edges by type:\n");
914                    let mut sorted_etypes: Vec<_> = edge_type_counts.iter().collect();
915                    sorted_etypes.sort_by_key(|(k, _)| *k);
916                    for (t, c) in &sorted_etypes {
917                        output.push_str(&format!("    {}: {}\n", t, c));
918                    }
919                }
920
921                if !top_connected.is_empty() {
922                    output.push_str("\n  Most connected nodes:\n");
923                    for (nid, count) in &top_connected {
924                        let node_name = state
925                            .nodes
926                            .iter()
927                            .find(|n| n.id == ***nid)
928                            .map(|n| n.name.as_str())
929                            .unwrap_or("?");
930                        output.push_str(&format!(
931                            "    {} ({}) — {} connections\n",
932                            nid, node_name, count
933                        ));
934                    }
935                }
936
937                Ok(ToolOutput::text(output.trim_end().to_string()))
938            }
939
940            "import_arxiv" => {
941                let arxiv_id = args.get("arxiv_id").and_then(|v| v.as_str()).unwrap_or("");
942                if arxiv_id.is_empty() {
943                    return Ok(ToolOutput::text("Missing required parameter 'arxiv_id'."));
944                }
945
946                let lib_path = self.arxiv_library_path();
947                if !lib_path.exists() {
948                    return Ok(ToolOutput::text(
949                        "ArXiv library not found. Save papers with the arxiv_research tool first.",
950                    ));
951                }
952
953                let lib_json =
954                    std::fs::read_to_string(&lib_path).map_err(|e| ToolError::ExecutionFailed {
955                        name: "knowledge_graph".to_string(),
956                        message: format!("Failed to read arxiv library: {}", e),
957                    })?;
958
959                let library: ArxivLibraryFile =
960                    serde_json::from_str(&lib_json).map_err(|e| ToolError::ExecutionFailed {
961                        name: "knowledge_graph".to_string(),
962                        message: format!("Failed to parse arxiv library: {}", e),
963                    })?;
964
965                let entry = match library.entries.iter().find(|e| e.paper.id == arxiv_id) {
966                    Some(e) => e,
967                    None => {
968                        return Ok(ToolOutput::text(format!(
969                            "Paper '{}' not found in arxiv library.",
970                            arxiv_id
971                        )));
972                    }
973                };
974
975                let now = Utc::now();
976                let paper_id = Self::slug(&entry.paper.title);
977                let mut added_nodes = 0;
978                let mut added_edges = 0;
979
980                // Create Paper node if not exists
981                if !state.nodes.iter().any(|n| n.id == paper_id) {
982                    let mut metadata = HashMap::new();
983                    metadata.insert("arxiv_id".to_string(), entry.paper.id.clone());
984                    state.nodes.push(GraphNode {
985                        id: paper_id.clone(),
986                        name: entry.paper.title.clone(),
987                        node_type: NodeType::Paper,
988                        description: entry.paper.abstract_text.clone(),
989                        tags: entry.tags.clone(),
990                        metadata,
991                        created_at: now,
992                        updated_at: now,
993                    });
994                    added_nodes += 1;
995                }
996
997                // Create Person nodes for each author + AuthoredBy edges
998                for author in &entry.paper.authors {
999                    let author_id = Self::slug(author);
1000                    if !state.nodes.iter().any(|n| n.id == author_id) {
1001                        state.nodes.push(GraphNode {
1002                            id: author_id.clone(),
1003                            name: author.clone(),
1004                            node_type: NodeType::Person,
1005                            description: String::new(),
1006                            tags: Vec::new(),
1007                            metadata: HashMap::new(),
1008                            created_at: now,
1009                            updated_at: now,
1010                        });
1011                        added_nodes += 1;
1012                    }
1013
1014                    // Add AuthoredBy edge (paper -> author)
1015                    let edge_exists = state.edges.iter().any(|e| {
1016                        e.source_id == paper_id
1017                            && e.target_id == author_id
1018                            && e.relationship_type == RelationshipType::AuthoredBy
1019                    });
1020                    if !edge_exists {
1021                        state.edges.push(Edge {
1022                            source_id: paper_id.clone(),
1023                            target_id: author_id,
1024                            relationship_type: RelationshipType::AuthoredBy,
1025                            strength: 1.0,
1026                            notes: String::new(),
1027                            created_at: now,
1028                        });
1029                        added_edges += 1;
1030                    }
1031                }
1032
1033                state.next_auto_id += added_nodes;
1034                self.save_state(&state)?;
1035
1036                Ok(ToolOutput::text(format!(
1037                    "Imported '{}' from arxiv: {} node(s) and {} edge(s) added.",
1038                    entry.paper.title, added_nodes, added_edges
1039                )))
1040            }
1041
1042            "export_dot" => {
1043                let type_filter = args
1044                    .get("filter_type")
1045                    .and_then(|v| v.as_str())
1046                    .and_then(NodeType::from_str_loose);
1047
1048                let filtered_nodes: Vec<&GraphNode> = state
1049                    .nodes
1050                    .iter()
1051                    .filter(|n| {
1052                        type_filter
1053                            .as_ref()
1054                            .map(|nt| &n.node_type == nt)
1055                            .unwrap_or(true)
1056                    })
1057                    .collect();
1058
1059                let node_ids: HashSet<&str> =
1060                    filtered_nodes.iter().map(|n| n.id.as_str()).collect();
1061
1062                let mut dot = String::from("digraph KnowledgeGraph {\n");
1063                dot.push_str("  rankdir=LR;\n");
1064                dot.push_str("  node [shape=box];\n\n");
1065
1066                for node in &filtered_nodes {
1067                    dot.push_str(&format!(
1068                        "  \"{}\" [label=\"{}\\n({})\"];\n",
1069                        node.id,
1070                        node.name.replace('"', "\\\""),
1071                        node.node_type.as_str()
1072                    ));
1073                }
1074
1075                dot.push('\n');
1076
1077                for edge in &state.edges {
1078                    if node_ids.contains(edge.source_id.as_str())
1079                        && node_ids.contains(edge.target_id.as_str())
1080                    {
1081                        dot.push_str(&format!(
1082                            "  \"{}\" -> \"{}\" [label=\"{}\"];\n",
1083                            edge.source_id,
1084                            edge.target_id,
1085                            edge.relationship_type.as_str()
1086                        ));
1087                    }
1088                }
1089
1090                dot.push_str("}\n");
1091
1092                Ok(ToolOutput::text(dot))
1093            }
1094
1095            _ => Ok(ToolOutput::text(format!(
1096                "Unknown action: '{}'. Use: add_node, get_node, update_node, remove_node, add_edge, remove_edge, neighbors, search, list, path, stats, import_arxiv, export_dot.",
1097                action
1098            ))),
1099        }
1100    }
1101}
1102
1103#[cfg(test)]
1104mod tests {
1105    use super::*;
1106    use tempfile::TempDir;
1107
1108    fn make_tool() -> (TempDir, KnowledgeGraphTool) {
1109        let dir = TempDir::new().unwrap();
1110        let workspace = dir.path().canonicalize().unwrap();
1111        let tool = KnowledgeGraphTool::new(workspace);
1112        (dir, tool)
1113    }
1114
1115    #[test]
1116    fn test_tool_properties() {
1117        let (_dir, tool) = make_tool();
1118        assert_eq!(tool.name(), "knowledge_graph");
1119        assert_eq!(tool.risk_level(), RiskLevel::Write);
1120        assert_eq!(tool.timeout(), Duration::from_secs(30));
1121        assert!(!tool.description().is_empty());
1122    }
1123
1124    #[test]
1125    fn test_schema_validation() {
1126        let (_dir, tool) = make_tool();
1127        let schema = tool.parameters_schema();
1128        assert!(schema.get("properties").is_some());
1129        let required = schema["required"].as_array().unwrap();
1130        assert!(required.iter().any(|v| v.as_str() == Some("action")));
1131        let actions = schema["properties"]["action"]["enum"].as_array().unwrap();
1132        assert_eq!(actions.len(), 13);
1133    }
1134
1135    #[tokio::test]
1136    async fn test_add_node_basic() {
1137        let (_dir, tool) = make_tool();
1138        let result = tool
1139            .execute(json!({
1140                "action": "add_node",
1141                "name": "Attention Is All You Need",
1142                "node_type": "paper",
1143                "description": "Transformer architecture paper"
1144            }))
1145            .await
1146            .unwrap();
1147        assert!(result.content.contains("Added node"));
1148        assert!(result.content.contains("attention-is-all-you-need"));
1149    }
1150
1151    #[tokio::test]
1152    async fn test_add_node_auto_id_from_slug() {
1153        let (_dir, tool) = make_tool();
1154        let result = tool
1155            .execute(json!({
1156                "action": "add_node",
1157                "name": "Deep Learning Basics",
1158                "node_type": "concept"
1159            }))
1160            .await
1161            .unwrap();
1162        assert!(result.content.contains("deep-learning-basics"));
1163
1164        // Verify the id was generated correctly
1165        let state = tool.load_state();
1166        assert_eq!(state.nodes[0].id, "deep-learning-basics");
1167        assert_eq!(state.nodes[0].name, "Deep Learning Basics");
1168    }
1169
1170    #[tokio::test]
1171    async fn test_get_node_with_edges() {
1172        let (_dir, tool) = make_tool();
1173        // Add two nodes and an edge
1174        tool.execute(json!({
1175            "action": "add_node",
1176            "name": "Paper A",
1177            "node_type": "paper",
1178            "id": "paper-a"
1179        }))
1180        .await
1181        .unwrap();
1182        tool.execute(json!({
1183            "action": "add_node",
1184            "name": "Concept B",
1185            "node_type": "concept",
1186            "id": "concept-b"
1187        }))
1188        .await
1189        .unwrap();
1190        tool.execute(json!({
1191            "action": "add_edge",
1192            "source_id": "paper-a",
1193            "target_id": "concept-b",
1194            "relationship_type": "implements"
1195        }))
1196        .await
1197        .unwrap();
1198
1199        let result = tool
1200            .execute(json!({ "action": "get_node", "id": "paper-a" }))
1201            .await
1202            .unwrap();
1203        assert!(result.content.contains("Paper A"));
1204        assert!(result.content.contains("Implements"));
1205        assert!(result.content.contains("concept-b"));
1206    }
1207
1208    #[tokio::test]
1209    async fn test_update_node() {
1210        let (_dir, tool) = make_tool();
1211        tool.execute(json!({
1212            "action": "add_node",
1213            "name": "Old Name",
1214            "node_type": "concept",
1215            "id": "test-node"
1216        }))
1217        .await
1218        .unwrap();
1219
1220        let result = tool
1221            .execute(json!({
1222                "action": "update_node",
1223                "id": "test-node",
1224                "name": "New Name",
1225                "description": "Updated description",
1226                "tags": ["updated"]
1227            }))
1228            .await
1229            .unwrap();
1230        assert!(result.content.contains("Updated node"));
1231
1232        let state = tool.load_state();
1233        let node = state.nodes.iter().find(|n| n.id == "test-node").unwrap();
1234        assert_eq!(node.name, "New Name");
1235        assert_eq!(node.description, "Updated description");
1236        assert_eq!(node.tags, vec!["updated"]);
1237    }
1238
1239    #[tokio::test]
1240    async fn test_remove_node_cascades_edges() {
1241        let (_dir, tool) = make_tool();
1242        // Add three nodes and edges
1243        tool.execute(
1244            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1245        )
1246        .await
1247        .unwrap();
1248        tool.execute(
1249            json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1250        )
1251        .await
1252        .unwrap();
1253        tool.execute(
1254            json!({ "action": "add_node", "name": "C", "node_type": "concept", "id": "c" }),
1255        )
1256        .await
1257        .unwrap();
1258        tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1259            .await.unwrap();
1260        tool.execute(json!({ "action": "add_edge", "source_id": "b", "target_id": "c", "relationship_type": "related_to" }))
1261            .await.unwrap();
1262
1263        let state = tool.load_state();
1264        assert_eq!(state.edges.len(), 2);
1265
1266        // Remove node B — should cascade-delete both edges
1267        let result = tool
1268            .execute(json!({ "action": "remove_node", "id": "b" }))
1269            .await
1270            .unwrap();
1271        assert!(result.content.contains("Removed node 'b'"));
1272        assert!(result.content.contains("2 connected edge(s)"));
1273
1274        let state = tool.load_state();
1275        assert_eq!(state.nodes.len(), 2);
1276        assert_eq!(state.edges.len(), 0);
1277    }
1278
1279    #[tokio::test]
1280    async fn test_add_edge_validates_nodes() {
1281        let (_dir, tool) = make_tool();
1282        tool.execute(
1283            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1284        )
1285        .await
1286        .unwrap();
1287
1288        // Target node doesn't exist
1289        let result = tool
1290            .execute(json!({
1291                "action": "add_edge",
1292                "source_id": "a",
1293                "target_id": "nonexistent",
1294                "relationship_type": "related_to"
1295            }))
1296            .await
1297            .unwrap();
1298        assert!(result.content.contains("not found"));
1299
1300        // Source node doesn't exist
1301        let result = tool
1302            .execute(json!({
1303                "action": "add_edge",
1304                "source_id": "nonexistent",
1305                "target_id": "a",
1306                "relationship_type": "related_to"
1307            }))
1308            .await
1309            .unwrap();
1310        assert!(result.content.contains("not found"));
1311    }
1312
1313    #[tokio::test]
1314    async fn test_add_edge_strength_default() {
1315        let (_dir, tool) = make_tool();
1316        tool.execute(
1317            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1318        )
1319        .await
1320        .unwrap();
1321        tool.execute(
1322            json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1323        )
1324        .await
1325        .unwrap();
1326
1327        tool.execute(json!({
1328            "action": "add_edge",
1329            "source_id": "a",
1330            "target_id": "b",
1331            "relationship_type": "related_to"
1332        }))
1333        .await
1334        .unwrap();
1335
1336        let state = tool.load_state();
1337        assert_eq!(state.edges.len(), 1);
1338        assert!((state.edges[0].strength - 0.5).abs() < f64::EPSILON);
1339    }
1340
1341    #[tokio::test]
1342    async fn test_remove_edge() {
1343        let (_dir, tool) = make_tool();
1344        tool.execute(
1345            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1346        )
1347        .await
1348        .unwrap();
1349        tool.execute(
1350            json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1351        )
1352        .await
1353        .unwrap();
1354        tool.execute(json!({
1355            "action": "add_edge",
1356            "source_id": "a",
1357            "target_id": "b",
1358            "relationship_type": "related_to"
1359        }))
1360        .await
1361        .unwrap();
1362
1363        assert_eq!(tool.load_state().edges.len(), 1);
1364
1365        let result = tool
1366            .execute(json!({
1367                "action": "remove_edge",
1368                "source_id": "a",
1369                "target_id": "b"
1370            }))
1371            .await
1372            .unwrap();
1373        assert!(result.content.contains("Removed"));
1374        assert_eq!(tool.load_state().edges.len(), 0);
1375    }
1376
1377    #[tokio::test]
1378    async fn test_neighbors_depth_1() {
1379        let (_dir, tool) = make_tool();
1380        tool.execute(
1381            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1382        )
1383        .await
1384        .unwrap();
1385        tool.execute(
1386            json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1387        )
1388        .await
1389        .unwrap();
1390        tool.execute(
1391            json!({ "action": "add_node", "name": "C", "node_type": "concept", "id": "c" }),
1392        )
1393        .await
1394        .unwrap();
1395        tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1396            .await.unwrap();
1397        tool.execute(json!({ "action": "add_edge", "source_id": "b", "target_id": "c", "relationship_type": "related_to" }))
1398            .await.unwrap();
1399
1400        let result = tool
1401            .execute(json!({ "action": "neighbors", "id": "a", "depth": 1 }))
1402            .await
1403            .unwrap();
1404        assert!(result.content.contains("b"));
1405        // At depth 1, should NOT find C
1406        assert!(!result.content.contains(" c ") && !result.content.contains("\n  [depth 1] c"));
1407    }
1408
1409    #[tokio::test]
1410    async fn test_neighbors_depth_2() {
1411        let (_dir, tool) = make_tool();
1412        tool.execute(
1413            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1414        )
1415        .await
1416        .unwrap();
1417        tool.execute(
1418            json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1419        )
1420        .await
1421        .unwrap();
1422        tool.execute(
1423            json!({ "action": "add_node", "name": "C", "node_type": "concept", "id": "c" }),
1424        )
1425        .await
1426        .unwrap();
1427        tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1428            .await.unwrap();
1429        tool.execute(json!({ "action": "add_edge", "source_id": "b", "target_id": "c", "relationship_type": "related_to" }))
1430            .await.unwrap();
1431
1432        let result = tool
1433            .execute(json!({ "action": "neighbors", "id": "a", "depth": 2 }))
1434            .await
1435            .unwrap();
1436        // At depth 2, should find both B and C
1437        assert!(result.content.contains("b"));
1438        assert!(result.content.contains("c"));
1439    }
1440
1441    #[tokio::test]
1442    async fn test_search_by_name() {
1443        let (_dir, tool) = make_tool();
1444        tool.execute(json!({
1445            "action": "add_node",
1446            "name": "Transformer Architecture",
1447            "node_type": "method",
1448            "id": "transformer"
1449        }))
1450        .await
1451        .unwrap();
1452
1453        let result = tool
1454            .execute(json!({ "action": "search", "query": "transformer" }))
1455            .await
1456            .unwrap();
1457        assert!(result.content.contains("Transformer Architecture"));
1458    }
1459
1460    #[tokio::test]
1461    async fn test_search_by_tag() {
1462        let (_dir, tool) = make_tool();
1463        tool.execute(json!({
1464            "action": "add_node",
1465            "name": "BERT",
1466            "node_type": "method",
1467            "id": "bert",
1468            "tags": ["nlp", "language-model"]
1469        }))
1470        .await
1471        .unwrap();
1472
1473        let result = tool
1474            .execute(json!({ "action": "search", "query": "nlp" }))
1475            .await
1476            .unwrap();
1477        assert!(result.content.contains("BERT"));
1478    }
1479
1480    #[tokio::test]
1481    async fn test_search_filter_type() {
1482        let (_dir, tool) = make_tool();
1483        tool.execute(json!({ "action": "add_node", "name": "ML Paper", "node_type": "paper", "id": "ml-paper" }))
1484            .await.unwrap();
1485        tool.execute(json!({ "action": "add_node", "name": "ML Concept", "node_type": "concept", "id": "ml-concept" }))
1486            .await.unwrap();
1487
1488        let result = tool
1489            .execute(json!({ "action": "search", "query": "ml", "node_type": "paper" }))
1490            .await
1491            .unwrap();
1492        assert!(result.content.contains("ML Paper"));
1493        assert!(!result.content.contains("ML Concept"));
1494    }
1495
1496    #[tokio::test]
1497    async fn test_list_all() {
1498        let (_dir, tool) = make_tool();
1499        tool.execute(
1500            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1501        )
1502        .await
1503        .unwrap();
1504        tool.execute(json!({ "action": "add_node", "name": "B", "node_type": "paper", "id": "b" }))
1505            .await
1506            .unwrap();
1507
1508        let result = tool.execute(json!({ "action": "list" })).await.unwrap();
1509        assert!(result.content.contains("Nodes (2)"));
1510        assert!(result.content.contains("a"));
1511        assert!(result.content.contains("b"));
1512    }
1513
1514    #[tokio::test]
1515    async fn test_list_filter_type() {
1516        let (_dir, tool) = make_tool();
1517        tool.execute(
1518            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1519        )
1520        .await
1521        .unwrap();
1522        tool.execute(json!({ "action": "add_node", "name": "B", "node_type": "paper", "id": "b" }))
1523            .await
1524            .unwrap();
1525
1526        let result = tool
1527            .execute(json!({ "action": "list", "node_type": "concept" }))
1528            .await
1529            .unwrap();
1530        assert!(result.content.contains("Nodes (1)"));
1531        assert!(result.content.contains("a"));
1532        assert!(!result.content.contains(" b "));
1533    }
1534
1535    #[tokio::test]
1536    async fn test_path_direct() {
1537        let (_dir, tool) = make_tool();
1538        tool.execute(
1539            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1540        )
1541        .await
1542        .unwrap();
1543        tool.execute(
1544            json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1545        )
1546        .await
1547        .unwrap();
1548        tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1549            .await.unwrap();
1550
1551        let result = tool
1552            .execute(json!({ "action": "path", "from_id": "a", "to_id": "b" }))
1553            .await
1554            .unwrap();
1555        assert!(result.content.contains("1 hops"));
1556        assert!(result.content.contains("a"));
1557        assert!(result.content.contains("b"));
1558    }
1559
1560    #[tokio::test]
1561    async fn test_path_indirect() {
1562        let (_dir, tool) = make_tool();
1563        tool.execute(
1564            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1565        )
1566        .await
1567        .unwrap();
1568        tool.execute(
1569            json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1570        )
1571        .await
1572        .unwrap();
1573        tool.execute(
1574            json!({ "action": "add_node", "name": "C", "node_type": "concept", "id": "c" }),
1575        )
1576        .await
1577        .unwrap();
1578        tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1579            .await.unwrap();
1580        tool.execute(json!({ "action": "add_edge", "source_id": "b", "target_id": "c", "relationship_type": "related_to" }))
1581            .await.unwrap();
1582
1583        let result = tool
1584            .execute(json!({ "action": "path", "from_id": "a", "to_id": "c" }))
1585            .await
1586            .unwrap();
1587        assert!(result.content.contains("2 hops"));
1588        assert!(result.content.contains("a"));
1589        assert!(result.content.contains("b"));
1590        assert!(result.content.contains("c"));
1591    }
1592
1593    #[tokio::test]
1594    async fn test_path_no_path() {
1595        let (_dir, tool) = make_tool();
1596        tool.execute(
1597            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1598        )
1599        .await
1600        .unwrap();
1601        tool.execute(
1602            json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1603        )
1604        .await
1605        .unwrap();
1606        // No edges between them
1607
1608        let result = tool
1609            .execute(json!({ "action": "path", "from_id": "a", "to_id": "b" }))
1610            .await
1611            .unwrap();
1612        assert!(result.content.contains("No path found"));
1613    }
1614
1615    #[tokio::test]
1616    async fn test_stats() {
1617        let (_dir, tool) = make_tool();
1618        tool.execute(
1619            json!({ "action": "add_node", "name": "Paper 1", "node_type": "paper", "id": "p1" }),
1620        )
1621        .await
1622        .unwrap();
1623        tool.execute(json!({ "action": "add_node", "name": "Concept 1", "node_type": "concept", "id": "c1" }))
1624            .await.unwrap();
1625        tool.execute(
1626            json!({ "action": "add_node", "name": "Author 1", "node_type": "person", "id": "a1" }),
1627        )
1628        .await
1629        .unwrap();
1630        tool.execute(json!({ "action": "add_edge", "source_id": "p1", "target_id": "c1", "relationship_type": "implements" }))
1631            .await.unwrap();
1632        tool.execute(json!({ "action": "add_edge", "source_id": "p1", "target_id": "a1", "relationship_type": "authored_by" }))
1633            .await.unwrap();
1634
1635        let result = tool.execute(json!({ "action": "stats" })).await.unwrap();
1636        assert!(result.content.contains("Nodes: 3"));
1637        assert!(result.content.contains("Edges: 2"));
1638        assert!(result.content.contains("Paper: 1"));
1639        assert!(result.content.contains("Concept: 1"));
1640        assert!(result.content.contains("Person: 1"));
1641        assert!(result.content.contains("Most connected"));
1642        assert!(result.content.contains("p1"));
1643    }
1644
1645    #[tokio::test]
1646    async fn test_export_dot() {
1647        let (_dir, tool) = make_tool();
1648        tool.execute(
1649            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1650        )
1651        .await
1652        .unwrap();
1653        tool.execute(
1654            json!({ "action": "add_node", "name": "B", "node_type": "method", "id": "b" }),
1655        )
1656        .await
1657        .unwrap();
1658        tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "implements" }))
1659            .await.unwrap();
1660
1661        let result = tool
1662            .execute(json!({ "action": "export_dot" }))
1663            .await
1664            .unwrap();
1665        assert!(result.content.contains("digraph KnowledgeGraph"));
1666        assert!(result.content.contains("\"a\""));
1667        assert!(result.content.contains("\"b\""));
1668        assert!(result.content.contains("Implements"));
1669        assert!(result.content.contains("->"));
1670    }
1671
1672    #[tokio::test]
1673    async fn test_state_roundtrip() {
1674        let (_dir, tool) = make_tool();
1675        // Add node, save, reload and verify
1676        tool.execute(json!({
1677            "action": "add_node",
1678            "name": "Test Node",
1679            "node_type": "dataset",
1680            "id": "test-node",
1681            "description": "A test dataset",
1682            "tags": ["test", "data"],
1683            "metadata": { "source": "kaggle", "size": "1GB" }
1684        }))
1685        .await
1686        .unwrap();
1687
1688        // Reload state from disk
1689        let state = tool.load_state();
1690        assert_eq!(state.nodes.len(), 1);
1691        let node = &state.nodes[0];
1692        assert_eq!(node.id, "test-node");
1693        assert_eq!(node.name, "Test Node");
1694        assert_eq!(node.node_type, NodeType::Dataset);
1695        assert_eq!(node.description, "A test dataset");
1696        assert_eq!(node.tags, vec!["test", "data"]);
1697        assert_eq!(node.metadata.get("source").unwrap(), "kaggle");
1698        assert_eq!(node.metadata.get("size").unwrap(), "1GB");
1699    }
1700
1701    #[tokio::test]
1702    async fn test_unknown_action() {
1703        let (_dir, tool) = make_tool();
1704        let result = tool.execute(json!({ "action": "foobar" })).await.unwrap();
1705        assert!(result.content.contains("Unknown action"));
1706        assert!(result.content.contains("foobar"));
1707    }
1708}