Skip to main content

rustant_tools/
knowledge_graph.rs

1//! Knowledge graph tool — local graph of concepts, papers, methods, and relationships.
2
3use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5use rustant_core::error::ToolError;
6use rustant_core::types::{RiskLevel, ToolOutput};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use std::collections::{HashMap, HashSet, VecDeque};
10use std::path::PathBuf;
11use std::time::Duration;
12
13use crate::registry::Tool;
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16enum NodeType {
17    Paper,
18    Concept,
19    Method,
20    Dataset,
21    Person,
22    Organization,
23}
24
25impl NodeType {
26    fn from_str_loose(s: &str) -> Option<Self> {
27        match s.to_lowercase().as_str() {
28            "paper" => Some(Self::Paper),
29            "concept" => Some(Self::Concept),
30            "method" => Some(Self::Method),
31            "dataset" => Some(Self::Dataset),
32            "person" => Some(Self::Person),
33            "organization" => Some(Self::Organization),
34            _ => None,
35        }
36    }
37
38    fn as_str(&self) -> &str {
39        match self {
40            Self::Paper => "Paper",
41            Self::Concept => "Concept",
42            Self::Method => "Method",
43            Self::Dataset => "Dataset",
44            Self::Person => "Person",
45            Self::Organization => "Organization",
46        }
47    }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
51enum RelationshipType {
52    Cites,
53    Implements,
54    Extends,
55    Contradicts,
56    BuildsOn,
57    AuthoredBy,
58    UsesDataset,
59    RelatedTo,
60}
61
62impl RelationshipType {
63    fn from_str_loose(s: &str) -> Option<Self> {
64        match s.to_lowercase().as_str() {
65            "cites" => Some(Self::Cites),
66            "implements" => Some(Self::Implements),
67            "extends" => Some(Self::Extends),
68            "contradicts" => Some(Self::Contradicts),
69            "builds_on" => Some(Self::BuildsOn),
70            "authored_by" => Some(Self::AuthoredBy),
71            "uses_dataset" => Some(Self::UsesDataset),
72            "related_to" => Some(Self::RelatedTo),
73            _ => None,
74        }
75    }
76
77    fn as_str(&self) -> &str {
78        match self {
79            Self::Cites => "Cites",
80            Self::Implements => "Implements",
81            Self::Extends => "Extends",
82            Self::Contradicts => "Contradicts",
83            Self::BuildsOn => "BuildsOn",
84            Self::AuthoredBy => "AuthoredBy",
85            Self::UsesDataset => "UsesDataset",
86            Self::RelatedTo => "RelatedTo",
87        }
88    }
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92struct GraphNode {
93    id: String,
94    name: String,
95    node_type: NodeType,
96    description: String,
97    tags: Vec<String>,
98    metadata: HashMap<String, String>,
99    created_at: DateTime<Utc>,
100    updated_at: DateTime<Utc>,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104struct Edge {
105    source_id: String,
106    target_id: String,
107    relationship_type: RelationshipType,
108    strength: f64,
109    notes: String,
110    created_at: DateTime<Utc>,
111}
112
113#[derive(Debug, Default, Serialize, Deserialize)]
114struct KnowledgeGraphState {
115    nodes: Vec<GraphNode>,
116    edges: Vec<Edge>,
117    next_auto_id: usize,
118}
119
120/// Minimal struct for deserializing papers from the ArXiv library.
121#[derive(Debug, Deserialize)]
122struct ArxivLibraryFile {
123    entries: Vec<ArxivLibraryEntry>,
124}
125
126#[derive(Debug, Deserialize)]
127struct ArxivLibraryEntry {
128    paper: ArxivPaperRef,
129    #[serde(default)]
130    tags: Vec<String>,
131}
132
133#[derive(Debug, Deserialize)]
134struct ArxivPaperRef {
135    #[serde(alias = "arxiv_id")]
136    id: String,
137    title: String,
138    authors: Vec<String>,
139    #[serde(alias = "summary")]
140    abstract_text: String,
141}
142
/// Tool exposing a file-backed knowledge graph scoped to one workspace.
pub struct KnowledgeGraphTool {
    /// Workspace root; all state is stored beneath `<workspace>/.rustant/`.
    workspace: PathBuf,
}
146
147impl KnowledgeGraphTool {
148    pub fn new(workspace: PathBuf) -> Self {
149        Self { workspace }
150    }
151
152    fn state_path(&self) -> PathBuf {
153        self.workspace
154            .join(".rustant")
155            .join("knowledge")
156            .join("graph.json")
157    }
158
159    fn load_state(&self) -> KnowledgeGraphState {
160        let path = self.state_path();
161        if path.exists() {
162            std::fs::read_to_string(&path)
163                .ok()
164                .and_then(|s| serde_json::from_str(&s).ok())
165                .unwrap_or_default()
166        } else {
167            KnowledgeGraphState::default()
168        }
169    }
170
171    fn save_state(&self, state: &KnowledgeGraphState) -> Result<(), ToolError> {
172        let path = self.state_path();
173        if let Some(parent) = path.parent() {
174            std::fs::create_dir_all(parent).map_err(|e| ToolError::ExecutionFailed {
175                name: "knowledge_graph".to_string(),
176                message: format!("Failed to create state dir: {}", e),
177            })?;
178        }
179        let json = serde_json::to_string_pretty(state).map_err(|e| ToolError::ExecutionFailed {
180            name: "knowledge_graph".to_string(),
181            message: format!("Failed to serialize state: {}", e),
182        })?;
183        let tmp = path.with_extension("json.tmp");
184        std::fs::write(&tmp, &json).map_err(|e| ToolError::ExecutionFailed {
185            name: "knowledge_graph".to_string(),
186            message: format!("Failed to write state: {}", e),
187        })?;
188        std::fs::rename(&tmp, &path).map_err(|e| ToolError::ExecutionFailed {
189            name: "knowledge_graph".to_string(),
190            message: format!("Failed to rename state file: {}", e),
191        })?;
192        Ok(())
193    }
194
195    fn slug(name: &str) -> String {
196        name.to_lowercase().replace(' ', "-")
197    }
198
199    fn arxiv_library_path(&self) -> PathBuf {
200        self.workspace
201            .join(".rustant")
202            .join("arxiv")
203            .join("library.json")
204    }
205}
206
207#[async_trait]
208impl Tool for KnowledgeGraphTool {
209    fn name(&self) -> &str {
210        "knowledge_graph"
211    }
212
213    fn description(&self) -> &str {
214        "Local knowledge graph of concepts, papers, methods, and relationships. Actions: add_node, get_node, update_node, remove_node, add_edge, remove_edge, neighbors, search, list, path, stats, import_arxiv, export_dot."
215    }
216
217    fn parameters_schema(&self) -> Value {
218        json!({
219            "type": "object",
220            "properties": {
221                "action": {
222                    "type": "string",
223                    "enum": [
224                        "add_node", "get_node", "update_node", "remove_node",
225                        "add_edge", "remove_edge", "neighbors", "search",
226                        "list", "path", "stats", "import_arxiv", "export_dot"
227                    ],
228                    "description": "Action to perform"
229                },
230                "id": { "type": "string", "description": "Node ID" },
231                "name": { "type": "string", "description": "Node name" },
232                "node_type": {
233                    "type": "string",
234                    "enum": ["paper", "concept", "method", "dataset", "person", "organization"],
235                    "description": "Type of node"
236                },
237                "description": { "type": "string", "description": "Node description" },
238                "tags": {
239                    "type": "array",
240                    "items": { "type": "string" },
241                    "description": "Tags for the node"
242                },
243                "metadata": {
244                    "type": "object",
245                    "additionalProperties": { "type": "string" },
246                    "description": "Key-value metadata for the node"
247                },
248                "source_id": { "type": "string", "description": "Source node ID for edge" },
249                "target_id": { "type": "string", "description": "Target node ID for edge" },
250                "relationship_type": {
251                    "type": "string",
252                    "enum": ["cites", "implements", "extends", "contradicts", "builds_on", "authored_by", "uses_dataset", "related_to"],
253                    "description": "Type of relationship"
254                },
255                "strength": {
256                    "type": "number",
257                    "description": "Edge strength 0.0-1.0 (default 0.5)"
258                },
259                "notes": { "type": "string", "description": "Edge notes" },
260                "depth": {
261                    "type": "integer",
262                    "description": "Traversal depth for neighbors (1-3, default 1)"
263                },
264                "query": { "type": "string", "description": "Search query" },
265                "tag": { "type": "string", "description": "Filter by tag" },
266                "from_id": { "type": "string", "description": "Path start node ID" },
267                "to_id": { "type": "string", "description": "Path end node ID" },
268                "arxiv_id": { "type": "string", "description": "ArXiv paper ID for import" },
269                "filter_type": {
270                    "type": "string",
271                    "enum": ["paper", "concept", "method", "dataset", "person", "organization"],
272                    "description": "Filter by node type"
273                }
274            },
275            "required": ["action"]
276        })
277    }
278
279    fn risk_level(&self) -> RiskLevel {
280        RiskLevel::Write
281    }
282
283    fn timeout(&self) -> Duration {
284        Duration::from_secs(30)
285    }
286
287    async fn execute(&self, args: Value) -> Result<ToolOutput, ToolError> {
288        let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("");
289        let mut state = self.load_state();
290
291        match action {
292            "add_node" => {
293                let name = args
294                    .get("name")
295                    .and_then(|v| v.as_str())
296                    .unwrap_or("")
297                    .trim();
298                if name.is_empty() {
299                    return Ok(ToolOutput::text("Missing required parameter 'name'."));
300                }
301
302                let node_type_str = args
303                    .get("node_type")
304                    .and_then(|v| v.as_str())
305                    .unwrap_or("");
306                let node_type = match NodeType::from_str_loose(node_type_str) {
307                    Some(nt) => nt,
308                    None => {
309                        return Ok(ToolOutput::text(format!(
310                            "Invalid node_type '{}'. Use: paper, concept, method, dataset, person, organization.",
311                            node_type_str
312                        )));
313                    }
314                };
315
316                let id = args
317                    .get("id")
318                    .and_then(|v| v.as_str())
319                    .map(|s| s.to_string())
320                    .unwrap_or_else(|| Self::slug(name));
321
322                // Check for duplicate id
323                if state.nodes.iter().any(|n| n.id == id) {
324                    return Ok(ToolOutput::text(format!(
325                        "Node with id '{}' already exists.",
326                        id
327                    )));
328                }
329
330                let description = args
331                    .get("description")
332                    .and_then(|v| v.as_str())
333                    .unwrap_or("")
334                    .to_string();
335
336                let tags: Vec<String> = args
337                    .get("tags")
338                    .and_then(|v| v.as_array())
339                    .map(|arr| {
340                        arr.iter()
341                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
342                            .collect()
343                    })
344                    .unwrap_or_default();
345
346                let metadata: HashMap<String, String> = args
347                    .get("metadata")
348                    .and_then(|v| v.as_object())
349                    .map(|obj| {
350                        obj.iter()
351                            .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
352                            .collect()
353                    })
354                    .unwrap_or_default();
355
356                let now = Utc::now();
357                state.nodes.push(GraphNode {
358                    id: id.clone(),
359                    name: name.to_string(),
360                    node_type,
361                    description,
362                    tags,
363                    metadata,
364                    created_at: now,
365                    updated_at: now,
366                });
367                state.next_auto_id += 1;
368                self.save_state(&state)?;
369
370                Ok(ToolOutput::text(format!(
371                    "Added node '{}' (id: {}).",
372                    name, id
373                )))
374            }
375
376            "get_node" => {
377                let id = args.get("id").and_then(|v| v.as_str()).unwrap_or("");
378                if id.is_empty() {
379                    return Ok(ToolOutput::text("Missing required parameter 'id'."));
380                }
381
382                let node = match state.nodes.iter().find(|n| n.id == id) {
383                    Some(n) => n,
384                    None => {
385                        return Ok(ToolOutput::text(format!("Node '{}' not found.", id)));
386                    }
387                };
388
389                let connected_edges: Vec<&Edge> = state
390                    .edges
391                    .iter()
392                    .filter(|e| e.source_id == id || e.target_id == id)
393                    .collect();
394
395                let mut output = format!(
396                    "Node: {} ({})\n  ID: {}\n  Type: {}\n  Description: {}\n  Tags: {}\n  Created: {}\n  Updated: {}",
397                    node.name,
398                    node.node_type.as_str(),
399                    node.id,
400                    node.node_type.as_str(),
401                    if node.description.is_empty() { "(none)" } else { &node.description },
402                    if node.tags.is_empty() { "(none)".to_string() } else { node.tags.join(", ") },
403                    node.created_at.format("%Y-%m-%d %H:%M"),
404                    node.updated_at.format("%Y-%m-%d %H:%M"),
405                );
406
407                if !node.metadata.is_empty() {
408                    output.push_str("\n  Metadata:");
409                    for (k, v) in &node.metadata {
410                        output.push_str(&format!("\n    {}: {}", k, v));
411                    }
412                }
413
414                if !connected_edges.is_empty() {
415                    output.push_str(&format!("\n  Edges ({}):", connected_edges.len()));
416                    for edge in &connected_edges {
417                        let direction = if edge.source_id == id {
418                            format!("--[{}]--> {}", edge.relationship_type.as_str(), edge.target_id)
419                        } else {
420                            format!(
421                                "<--[{}]-- {}",
422                                edge.relationship_type.as_str(),
423                                edge.source_id
424                            )
425                        };
426                        output.push_str(&format!(
427                            "\n    {} (strength: {:.2})",
428                            direction, edge.strength
429                        ));
430                    }
431                }
432
433                Ok(ToolOutput::text(output))
434            }
435
436            "update_node" => {
437                let id = args.get("id").and_then(|v| v.as_str()).unwrap_or("");
438                if id.is_empty() {
439                    return Ok(ToolOutput::text("Missing required parameter 'id'."));
440                }
441
442                let node = match state.nodes.iter_mut().find(|n| n.id == id) {
443                    Some(n) => n,
444                    None => {
445                        return Ok(ToolOutput::text(format!("Node '{}' not found.", id)));
446                    }
447                };
448
449                if let Some(name) = args.get("name").and_then(|v| v.as_str()) {
450                    node.name = name.to_string();
451                }
452                if let Some(desc) = args.get("description").and_then(|v| v.as_str()) {
453                    node.description = desc.to_string();
454                }
455                if let Some(tags) = args.get("tags").and_then(|v| v.as_array()) {
456                    node.tags = tags
457                        .iter()
458                        .filter_map(|v| v.as_str().map(|s| s.to_string()))
459                        .collect();
460                }
461                node.updated_at = Utc::now();
462                self.save_state(&state)?;
463
464                Ok(ToolOutput::text(format!("Updated node '{}'.", id)))
465            }
466
467            "remove_node" => {
468                let id = args.get("id").and_then(|v| v.as_str()).unwrap_or("");
469                if id.is_empty() {
470                    return Ok(ToolOutput::text("Missing required parameter 'id'."));
471                }
472
473                let before_nodes = state.nodes.len();
474                state.nodes.retain(|n| n.id != id);
475                if state.nodes.len() == before_nodes {
476                    return Ok(ToolOutput::text(format!("Node '{}' not found.", id)));
477                }
478
479                // Cascade-delete edges referencing this node
480                let before_edges = state.edges.len();
481                state
482                    .edges
483                    .retain(|e| e.source_id != id && e.target_id != id);
484                let edges_removed = before_edges - state.edges.len();
485
486                self.save_state(&state)?;
487
488                Ok(ToolOutput::text(format!(
489                    "Removed node '{}' and {} connected edge(s).",
490                    id, edges_removed
491                )))
492            }
493
494            "add_edge" => {
495                let source_id = args
496                    .get("source_id")
497                    .and_then(|v| v.as_str())
498                    .unwrap_or("");
499                let target_id = args
500                    .get("target_id")
501                    .and_then(|v| v.as_str())
502                    .unwrap_or("");
503                let rel_str = args
504                    .get("relationship_type")
505                    .and_then(|v| v.as_str())
506                    .unwrap_or("");
507
508                if source_id.is_empty() || target_id.is_empty() {
509                    return Ok(ToolOutput::text(
510                        "Missing required parameters 'source_id' and 'target_id'.",
511                    ));
512                }
513
514                // Validate nodes exist
515                if !state.nodes.iter().any(|n| n.id == source_id) {
516                    return Ok(ToolOutput::text(format!(
517                        "Source node '{}' not found.",
518                        source_id
519                    )));
520                }
521                if !state.nodes.iter().any(|n| n.id == target_id) {
522                    return Ok(ToolOutput::text(format!(
523                        "Target node '{}' not found.",
524                        target_id
525                    )));
526                }
527
528                let relationship_type = match RelationshipType::from_str_loose(rel_str) {
529                    Some(rt) => rt,
530                    None => {
531                        return Ok(ToolOutput::text(format!(
532                            "Invalid relationship_type '{}'. Use: cites, implements, extends, contradicts, builds_on, authored_by, uses_dataset, related_to.",
533                            rel_str
534                        )));
535                    }
536                };
537
538                let strength = args
539                    .get("strength")
540                    .and_then(|v| v.as_f64())
541                    .unwrap_or(0.5)
542                    .clamp(0.0, 1.0);
543
544                let notes = args
545                    .get("notes")
546                    .and_then(|v| v.as_str())
547                    .unwrap_or("")
548                    .to_string();
549
550                state.edges.push(Edge {
551                    source_id: source_id.to_string(),
552                    target_id: target_id.to_string(),
553                    relationship_type,
554                    strength,
555                    notes,
556                    created_at: Utc::now(),
557                });
558
559                self.save_state(&state)?;
560
561                Ok(ToolOutput::text(format!(
562                    "Added edge {} --[{}]--> {} (strength: {:.2}).",
563                    source_id, rel_str, target_id, strength
564                )))
565            }
566
567            "remove_edge" => {
568                let source_id = args
569                    .get("source_id")
570                    .and_then(|v| v.as_str())
571                    .unwrap_or("");
572                let target_id = args
573                    .get("target_id")
574                    .and_then(|v| v.as_str())
575                    .unwrap_or("");
576
577                if source_id.is_empty() || target_id.is_empty() {
578                    return Ok(ToolOutput::text(
579                        "Missing required parameters 'source_id' and 'target_id'.",
580                    ));
581                }
582
583                let rel_filter = args
584                    .get("relationship_type")
585                    .and_then(|v| v.as_str())
586                    .and_then(RelationshipType::from_str_loose);
587
588                let before = state.edges.len();
589                state.edges.retain(|e| {
590                    if e.source_id == source_id && e.target_id == target_id {
591                        if let Some(ref rt) = rel_filter {
592                            return &e.relationship_type != rt;
593                        }
594                        return false;
595                    }
596                    true
597                });
598                let removed = before - state.edges.len();
599
600                if removed == 0 {
601                    return Ok(ToolOutput::text("No matching edge(s) found."));
602                }
603
604                self.save_state(&state)?;
605
606                Ok(ToolOutput::text(format!(
607                    "Removed {} edge(s) from '{}' to '{}'.",
608                    removed, source_id, target_id
609                )))
610            }
611
612            "neighbors" => {
613                let id = args.get("id").and_then(|v| v.as_str()).unwrap_or("");
614                if id.is_empty() {
615                    return Ok(ToolOutput::text("Missing required parameter 'id'."));
616                }
617
618                if !state.nodes.iter().any(|n| n.id == id) {
619                    return Ok(ToolOutput::text(format!("Node '{}' not found.", id)));
620                }
621
622                let max_depth = args
623                    .get("depth")
624                    .and_then(|v| v.as_u64())
625                    .unwrap_or(1)
626                    .clamp(1, 3) as usize;
627
628                let rel_filter = args
629                    .get("relationship_type")
630                    .and_then(|v| v.as_str())
631                    .and_then(RelationshipType::from_str_loose);
632
633                // BFS traversal
634                let mut visited: HashSet<String> = HashSet::new();
635                visited.insert(id.to_string());
636                let mut queue: VecDeque<(String, usize)> = VecDeque::new();
637                queue.push_back((id.to_string(), 0));
638                let mut found_nodes: Vec<(String, usize)> = Vec::new();
639
640                while let Some((current_id, depth)) = queue.pop_front() {
641                    if depth >= max_depth {
642                        continue;
643                    }
644
645                    for edge in &state.edges {
646                        if let Some(ref rt) = rel_filter {
647                            if &edge.relationship_type != rt {
648                                continue;
649                            }
650                        }
651
652                        let neighbor_id = if edge.source_id == current_id {
653                            &edge.target_id
654                        } else if edge.target_id == current_id {
655                            &edge.source_id
656                        } else {
657                            continue;
658                        };
659
660                        if !visited.contains(neighbor_id) {
661                            visited.insert(neighbor_id.clone());
662                            found_nodes.push((neighbor_id.clone(), depth + 1));
663                            queue.push_back((neighbor_id.clone(), depth + 1));
664                        }
665                    }
666                }
667
668                if found_nodes.is_empty() {
669                    return Ok(ToolOutput::text(format!(
670                        "No neighbors found for '{}' within depth {}.",
671                        id, max_depth
672                    )));
673                }
674
675                let mut output = format!(
676                    "Neighbors of '{}' (depth {}):\n",
677                    id, max_depth
678                );
679                for (nid, depth) in &found_nodes {
680                    if let Some(node) = state.nodes.iter().find(|n| n.id == *nid) {
681                        output.push_str(&format!(
682                            "  [depth {}] {} — {} ({})\n",
683                            depth,
684                            node.id,
685                            node.name,
686                            node.node_type.as_str()
687                        ));
688                    }
689                }
690
691                Ok(ToolOutput::text(output.trim_end().to_string()))
692            }
693
694            "search" => {
695                let query = args
696                    .get("query")
697                    .and_then(|v| v.as_str())
698                    .unwrap_or("")
699                    .to_lowercase();
700                if query.is_empty() {
701                    return Ok(ToolOutput::text("Missing required parameter 'query'."));
702                }
703
704                let type_filter = args
705                    .get("node_type")
706                    .and_then(|v| v.as_str())
707                    .and_then(NodeType::from_str_loose);
708
709                let matches: Vec<&GraphNode> = state
710                    .nodes
711                    .iter()
712                    .filter(|n| {
713                        if let Some(ref nt) = type_filter {
714                            if &n.node_type != nt {
715                                return false;
716                            }
717                        }
718                        n.name.to_lowercase().contains(&query)
719                            || n.description.to_lowercase().contains(&query)
720                            || n.tags
721                                .iter()
722                                .any(|t| t.to_lowercase().contains(&query))
723                    })
724                    .collect();
725
726                if matches.is_empty() {
727                    return Ok(ToolOutput::text(format!(
728                        "No nodes matching '{}'.",
729                        query
730                    )));
731                }
732
733                let mut output = format!("Found {} node(s):\n", matches.len());
734                for node in &matches {
735                    output.push_str(&format!(
736                        "  {} — {} ({})\n",
737                        node.id,
738                        node.name,
739                        node.node_type.as_str()
740                    ));
741                }
742
743                Ok(ToolOutput::text(output.trim_end().to_string()))
744            }
745
746            "list" => {
747                let type_filter = args
748                    .get("node_type")
749                    .and_then(|v| v.as_str())
750                    .and_then(NodeType::from_str_loose);
751
752                let tag_filter = args.get("tag").and_then(|v| v.as_str());
753
754                let filtered: Vec<&GraphNode> = state
755                    .nodes
756                    .iter()
757                    .filter(|n| {
758                        if let Some(ref nt) = type_filter {
759                            if &n.node_type != nt {
760                                return false;
761                            }
762                        }
763                        if let Some(tag) = tag_filter {
764                            if !n
765                                .tags
766                                .iter()
767                                .any(|t| t.eq_ignore_ascii_case(tag))
768                            {
769                                return false;
770                            }
771                        }
772                        true
773                    })
774                    .collect();
775
776                if filtered.is_empty() {
777                    return Ok(ToolOutput::text("No nodes found."));
778                }
779
780                let mut output = format!("Nodes ({}):\n", filtered.len());
781                for node in &filtered {
782                    output.push_str(&format!(
783                        "  {} — {} ({})\n",
784                        node.id,
785                        node.name,
786                        node.node_type.as_str()
787                    ));
788                }
789
790                Ok(ToolOutput::text(output.trim_end().to_string()))
791            }
792
793            "path" => {
794                let from_id = args
795                    .get("from_id")
796                    .and_then(|v| v.as_str())
797                    .unwrap_or("");
798                let to_id = args.get("to_id").and_then(|v| v.as_str()).unwrap_or("");
799
800                if from_id.is_empty() || to_id.is_empty() {
801                    return Ok(ToolOutput::text(
802                        "Missing required parameters 'from_id' and 'to_id'.",
803                    ));
804                }
805
806                if !state.nodes.iter().any(|n| n.id == from_id) {
807                    return Ok(ToolOutput::text(format!(
808                        "Node '{}' not found.",
809                        from_id
810                    )));
811                }
812                if !state.nodes.iter().any(|n| n.id == to_id) {
813                    return Ok(ToolOutput::text(format!(
814                        "Node '{}' not found.",
815                        to_id
816                    )));
817                }
818
819                // BFS shortest path (bidirectional edges)
820                let mut visited: HashSet<String> = HashSet::new();
821                let mut parent: HashMap<String, String> = HashMap::new();
822                let mut queue: VecDeque<String> = VecDeque::new();
823
824                visited.insert(from_id.to_string());
825                queue.push_back(from_id.to_string());
826
827                let mut found = false;
828                while let Some(current) = queue.pop_front() {
829                    if current == to_id {
830                        found = true;
831                        break;
832                    }
833
834                    for edge in &state.edges {
835                        let neighbor = if edge.source_id == current {
836                            &edge.target_id
837                        } else if edge.target_id == current {
838                            &edge.source_id
839                        } else {
840                            continue;
841                        };
842
843                        if !visited.contains(neighbor) {
844                            visited.insert(neighbor.clone());
845                            parent.insert(neighbor.clone(), current.clone());
846                            queue.push_back(neighbor.clone());
847                        }
848                    }
849                }
850
851                if !found {
852                    return Ok(ToolOutput::text(format!(
853                        "No path found from '{}' to '{}'.",
854                        from_id, to_id
855                    )));
856                }
857
858                // Reconstruct path
859                let mut path_ids: Vec<String> = Vec::new();
860                let mut current = to_id.to_string();
861                while current != from_id {
862                    path_ids.push(current.clone());
863                    current = parent.get(&current).unwrap().clone();
864                }
865                path_ids.push(from_id.to_string());
866                path_ids.reverse();
867
868                let mut output = format!(
869                    "Path from '{}' to '{}' ({} hops):\n",
870                    from_id,
871                    to_id,
872                    path_ids.len() - 1
873                );
874                for (i, pid) in path_ids.iter().enumerate() {
875                    let node_name = state
876                        .nodes
877                        .iter()
878                        .find(|n| n.id == *pid)
879                        .map(|n| n.name.as_str())
880                        .unwrap_or("?");
881                    if i > 0 {
882                        output.push_str("  -> ");
883                    } else {
884                        output.push_str("  ");
885                    }
886                    output.push_str(&format!("{} ({})\n", pid, node_name));
887                }
888
889                Ok(ToolOutput::text(output.trim_end().to_string()))
890            }
891
892            "stats" => {
893                let total_nodes = state.nodes.len();
894                let total_edges = state.edges.len();
895
896                if total_nodes == 0 {
897                    return Ok(ToolOutput::text(
898                        "Knowledge graph is empty. Use add_node to get started.",
899                    ));
900                }
901
902                // Node counts by type
903                let mut type_counts: HashMap<&str, usize> = HashMap::new();
904                for node in &state.nodes {
905                    *type_counts.entry(node.node_type.as_str()).or_insert(0) += 1;
906                }
907
908                // Edge counts by type
909                let mut edge_type_counts: HashMap<&str, usize> = HashMap::new();
910                for edge in &state.edges {
911                    *edge_type_counts
912                        .entry(edge.relationship_type.as_str())
913                        .or_insert(0) += 1;
914                }
915
916                // Top-5 most connected nodes
917                let mut connection_counts: HashMap<&str, usize> = HashMap::new();
918                for edge in &state.edges {
919                    *connection_counts.entry(&edge.source_id).or_insert(0) += 1;
920                    *connection_counts.entry(&edge.target_id).or_insert(0) += 1;
921                }
922                let mut top_connected: Vec<(&&str, &usize)> = connection_counts.iter().collect();
923                top_connected.sort_by(|a, b| b.1.cmp(a.1));
924                top_connected.truncate(5);
925
926                let mut output = format!(
927                    "Knowledge Graph Stats:\n  Nodes: {}\n  Edges: {}\n\n  Nodes by type:\n",
928                    total_nodes, total_edges
929                );
930                let mut sorted_types: Vec<_> = type_counts.iter().collect();
931                sorted_types.sort_by_key(|(k, _)| *k);
932                for (t, c) in &sorted_types {
933                    output.push_str(&format!("    {}: {}\n", t, c));
934                }
935
936                if !edge_type_counts.is_empty() {
937                    output.push_str("\n  Edges by type:\n");
938                    let mut sorted_etypes: Vec<_> = edge_type_counts.iter().collect();
939                    sorted_etypes.sort_by_key(|(k, _)| *k);
940                    for (t, c) in &sorted_etypes {
941                        output.push_str(&format!("    {}: {}\n", t, c));
942                    }
943                }
944
945                if !top_connected.is_empty() {
946                    output.push_str("\n  Most connected nodes:\n");
947                    for (nid, count) in &top_connected {
948                        let node_name = state
949                            .nodes
950                            .iter()
951                            .find(|n| n.id == ***nid)
952                            .map(|n| n.name.as_str())
953                            .unwrap_or("?");
954                        output.push_str(&format!("    {} ({}) — {} connections\n", nid, node_name, count));
955                    }
956                }
957
958                Ok(ToolOutput::text(output.trim_end().to_string()))
959            }
960
961            "import_arxiv" => {
962                let arxiv_id = args
963                    .get("arxiv_id")
964                    .and_then(|v| v.as_str())
965                    .unwrap_or("");
966                if arxiv_id.is_empty() {
967                    return Ok(ToolOutput::text(
968                        "Missing required parameter 'arxiv_id'.",
969                    ));
970                }
971
972                let lib_path = self.arxiv_library_path();
973                if !lib_path.exists() {
974                    return Ok(ToolOutput::text(
975                        "ArXiv library not found. Save papers with the arxiv_research tool first.",
976                    ));
977                }
978
979                let lib_json = std::fs::read_to_string(&lib_path).map_err(|e| {
980                    ToolError::ExecutionFailed {
981                        name: "knowledge_graph".to_string(),
982                        message: format!("Failed to read arxiv library: {}", e),
983                    }
984                })?;
985
986                let library: ArxivLibraryFile =
987                    serde_json::from_str(&lib_json).map_err(|e| ToolError::ExecutionFailed {
988                        name: "knowledge_graph".to_string(),
989                        message: format!("Failed to parse arxiv library: {}", e),
990                    })?;
991
992                let entry = match library
993                    .entries
994                    .iter()
995                    .find(|e| e.paper.id == arxiv_id)
996                {
997                    Some(e) => e,
998                    None => {
999                        return Ok(ToolOutput::text(format!(
1000                            "Paper '{}' not found in arxiv library.",
1001                            arxiv_id
1002                        )));
1003                    }
1004                };
1005
1006                let now = Utc::now();
1007                let paper_id = Self::slug(&entry.paper.title);
1008                let mut added_nodes = 0;
1009                let mut added_edges = 0;
1010
1011                // Create Paper node if not exists
1012                if !state.nodes.iter().any(|n| n.id == paper_id) {
1013                    let mut metadata = HashMap::new();
1014                    metadata.insert("arxiv_id".to_string(), entry.paper.id.clone());
1015                    state.nodes.push(GraphNode {
1016                        id: paper_id.clone(),
1017                        name: entry.paper.title.clone(),
1018                        node_type: NodeType::Paper,
1019                        description: entry.paper.abstract_text.clone(),
1020                        tags: entry.tags.clone(),
1021                        metadata,
1022                        created_at: now,
1023                        updated_at: now,
1024                    });
1025                    added_nodes += 1;
1026                }
1027
1028                // Create Person nodes for each author + AuthoredBy edges
1029                for author in &entry.paper.authors {
1030                    let author_id = Self::slug(author);
1031                    if !state.nodes.iter().any(|n| n.id == author_id) {
1032                        state.nodes.push(GraphNode {
1033                            id: author_id.clone(),
1034                            name: author.clone(),
1035                            node_type: NodeType::Person,
1036                            description: String::new(),
1037                            tags: Vec::new(),
1038                            metadata: HashMap::new(),
1039                            created_at: now,
1040                            updated_at: now,
1041                        });
1042                        added_nodes += 1;
1043                    }
1044
1045                    // Add AuthoredBy edge (paper -> author)
1046                    let edge_exists = state.edges.iter().any(|e| {
1047                        e.source_id == paper_id
1048                            && e.target_id == author_id
1049                            && e.relationship_type == RelationshipType::AuthoredBy
1050                    });
1051                    if !edge_exists {
1052                        state.edges.push(Edge {
1053                            source_id: paper_id.clone(),
1054                            target_id: author_id,
1055                            relationship_type: RelationshipType::AuthoredBy,
1056                            strength: 1.0,
1057                            notes: String::new(),
1058                            created_at: now,
1059                        });
1060                        added_edges += 1;
1061                    }
1062                }
1063
1064                state.next_auto_id += added_nodes;
1065                self.save_state(&state)?;
1066
1067                Ok(ToolOutput::text(format!(
1068                    "Imported '{}' from arxiv: {} node(s) and {} edge(s) added.",
1069                    entry.paper.title, added_nodes, added_edges
1070                )))
1071            }
1072
1073            "export_dot" => {
1074                let type_filter = args
1075                    .get("filter_type")
1076                    .and_then(|v| v.as_str())
1077                    .and_then(NodeType::from_str_loose);
1078
1079                let filtered_nodes: Vec<&GraphNode> = state
1080                    .nodes
1081                    .iter()
1082                    .filter(|n| {
1083                        type_filter
1084                            .as_ref()
1085                            .map(|nt| &n.node_type == nt)
1086                            .unwrap_or(true)
1087                    })
1088                    .collect();
1089
1090                let node_ids: HashSet<&str> =
1091                    filtered_nodes.iter().map(|n| n.id.as_str()).collect();
1092
1093                let mut dot = String::from("digraph KnowledgeGraph {\n");
1094                dot.push_str("  rankdir=LR;\n");
1095                dot.push_str("  node [shape=box];\n\n");
1096
1097                for node in &filtered_nodes {
1098                    dot.push_str(&format!(
1099                        "  \"{}\" [label=\"{}\\n({})\"];\n",
1100                        node.id,
1101                        node.name.replace('"', "\\\""),
1102                        node.node_type.as_str()
1103                    ));
1104                }
1105
1106                dot.push('\n');
1107
1108                for edge in &state.edges {
1109                    if node_ids.contains(edge.source_id.as_str())
1110                        && node_ids.contains(edge.target_id.as_str())
1111                    {
1112                        dot.push_str(&format!(
1113                            "  \"{}\" -> \"{}\" [label=\"{}\"];\n",
1114                            edge.source_id,
1115                            edge.target_id,
1116                            edge.relationship_type.as_str()
1117                        ));
1118                    }
1119                }
1120
1121                dot.push_str("}\n");
1122
1123                Ok(ToolOutput::text(dot))
1124            }
1125
1126            _ => Ok(ToolOutput::text(format!(
1127                "Unknown action: '{}'. Use: add_node, get_node, update_node, remove_node, add_edge, remove_edge, neighbors, search, list, path, stats, import_arxiv, export_dot.",
1128                action
1129            ))),
1130        }
1131    }
1132}
1133
1134#[cfg(test)]
1135mod tests {
1136    use super::*;
1137    use tempfile::TempDir;
1138
1139    fn make_tool() -> (TempDir, KnowledgeGraphTool) {
1140        let dir = TempDir::new().unwrap();
1141        let workspace = dir.path().canonicalize().unwrap();
1142        let tool = KnowledgeGraphTool::new(workspace);
1143        (dir, tool)
1144    }
1145
1146    #[test]
1147    fn test_tool_properties() {
1148        let (_dir, tool) = make_tool();
1149        assert_eq!(tool.name(), "knowledge_graph");
1150        assert_eq!(tool.risk_level(), RiskLevel::Write);
1151        assert_eq!(tool.timeout(), Duration::from_secs(30));
1152        assert!(!tool.description().is_empty());
1153    }
1154
1155    #[test]
1156    fn test_schema_validation() {
1157        let (_dir, tool) = make_tool();
1158        let schema = tool.parameters_schema();
1159        assert!(schema.get("properties").is_some());
1160        let required = schema["required"].as_array().unwrap();
1161        assert!(required.iter().any(|v| v.as_str() == Some("action")));
1162        let actions = schema["properties"]["action"]["enum"].as_array().unwrap();
1163        assert_eq!(actions.len(), 13);
1164    }
1165
1166    #[tokio::test]
1167    async fn test_add_node_basic() {
1168        let (_dir, tool) = make_tool();
1169        let result = tool
1170            .execute(json!({
1171                "action": "add_node",
1172                "name": "Attention Is All You Need",
1173                "node_type": "paper",
1174                "description": "Transformer architecture paper"
1175            }))
1176            .await
1177            .unwrap();
1178        assert!(result.content.contains("Added node"));
1179        assert!(result.content.contains("attention-is-all-you-need"));
1180    }
1181
1182    #[tokio::test]
1183    async fn test_add_node_auto_id_from_slug() {
1184        let (_dir, tool) = make_tool();
1185        let result = tool
1186            .execute(json!({
1187                "action": "add_node",
1188                "name": "Deep Learning Basics",
1189                "node_type": "concept"
1190            }))
1191            .await
1192            .unwrap();
1193        assert!(result.content.contains("deep-learning-basics"));
1194
1195        // Verify the id was generated correctly
1196        let state = tool.load_state();
1197        assert_eq!(state.nodes[0].id, "deep-learning-basics");
1198        assert_eq!(state.nodes[0].name, "Deep Learning Basics");
1199    }
1200
1201    #[tokio::test]
1202    async fn test_get_node_with_edges() {
1203        let (_dir, tool) = make_tool();
1204        // Add two nodes and an edge
1205        tool.execute(json!({
1206            "action": "add_node",
1207            "name": "Paper A",
1208            "node_type": "paper",
1209            "id": "paper-a"
1210        }))
1211        .await
1212        .unwrap();
1213        tool.execute(json!({
1214            "action": "add_node",
1215            "name": "Concept B",
1216            "node_type": "concept",
1217            "id": "concept-b"
1218        }))
1219        .await
1220        .unwrap();
1221        tool.execute(json!({
1222            "action": "add_edge",
1223            "source_id": "paper-a",
1224            "target_id": "concept-b",
1225            "relationship_type": "implements"
1226        }))
1227        .await
1228        .unwrap();
1229
1230        let result = tool
1231            .execute(json!({ "action": "get_node", "id": "paper-a" }))
1232            .await
1233            .unwrap();
1234        assert!(result.content.contains("Paper A"));
1235        assert!(result.content.contains("Implements"));
1236        assert!(result.content.contains("concept-b"));
1237    }
1238
1239    #[tokio::test]
1240    async fn test_update_node() {
1241        let (_dir, tool) = make_tool();
1242        tool.execute(json!({
1243            "action": "add_node",
1244            "name": "Old Name",
1245            "node_type": "concept",
1246            "id": "test-node"
1247        }))
1248        .await
1249        .unwrap();
1250
1251        let result = tool
1252            .execute(json!({
1253                "action": "update_node",
1254                "id": "test-node",
1255                "name": "New Name",
1256                "description": "Updated description",
1257                "tags": ["updated"]
1258            }))
1259            .await
1260            .unwrap();
1261        assert!(result.content.contains("Updated node"));
1262
1263        let state = tool.load_state();
1264        let node = state.nodes.iter().find(|n| n.id == "test-node").unwrap();
1265        assert_eq!(node.name, "New Name");
1266        assert_eq!(node.description, "Updated description");
1267        assert_eq!(node.tags, vec!["updated"]);
1268    }
1269
1270    #[tokio::test]
1271    async fn test_remove_node_cascades_edges() {
1272        let (_dir, tool) = make_tool();
1273        // Add three nodes and edges
1274        tool.execute(
1275            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1276        )
1277        .await
1278        .unwrap();
1279        tool.execute(
1280            json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1281        )
1282        .await
1283        .unwrap();
1284        tool.execute(
1285            json!({ "action": "add_node", "name": "C", "node_type": "concept", "id": "c" }),
1286        )
1287        .await
1288        .unwrap();
1289        tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1290            .await.unwrap();
1291        tool.execute(json!({ "action": "add_edge", "source_id": "b", "target_id": "c", "relationship_type": "related_to" }))
1292            .await.unwrap();
1293
1294        let state = tool.load_state();
1295        assert_eq!(state.edges.len(), 2);
1296
1297        // Remove node B — should cascade-delete both edges
1298        let result = tool
1299            .execute(json!({ "action": "remove_node", "id": "b" }))
1300            .await
1301            .unwrap();
1302        assert!(result.content.contains("Removed node 'b'"));
1303        assert!(result.content.contains("2 connected edge(s)"));
1304
1305        let state = tool.load_state();
1306        assert_eq!(state.nodes.len(), 2);
1307        assert_eq!(state.edges.len(), 0);
1308    }
1309
1310    #[tokio::test]
1311    async fn test_add_edge_validates_nodes() {
1312        let (_dir, tool) = make_tool();
1313        tool.execute(
1314            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1315        )
1316        .await
1317        .unwrap();
1318
1319        // Target node doesn't exist
1320        let result = tool
1321            .execute(json!({
1322                "action": "add_edge",
1323                "source_id": "a",
1324                "target_id": "nonexistent",
1325                "relationship_type": "related_to"
1326            }))
1327            .await
1328            .unwrap();
1329        assert!(result.content.contains("not found"));
1330
1331        // Source node doesn't exist
1332        let result = tool
1333            .execute(json!({
1334                "action": "add_edge",
1335                "source_id": "nonexistent",
1336                "target_id": "a",
1337                "relationship_type": "related_to"
1338            }))
1339            .await
1340            .unwrap();
1341        assert!(result.content.contains("not found"));
1342    }
1343
1344    #[tokio::test]
1345    async fn test_add_edge_strength_default() {
1346        let (_dir, tool) = make_tool();
1347        tool.execute(
1348            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1349        )
1350        .await
1351        .unwrap();
1352        tool.execute(
1353            json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1354        )
1355        .await
1356        .unwrap();
1357
1358        tool.execute(json!({
1359            "action": "add_edge",
1360            "source_id": "a",
1361            "target_id": "b",
1362            "relationship_type": "related_to"
1363        }))
1364        .await
1365        .unwrap();
1366
1367        let state = tool.load_state();
1368        assert_eq!(state.edges.len(), 1);
1369        assert!((state.edges[0].strength - 0.5).abs() < f64::EPSILON);
1370    }
1371
1372    #[tokio::test]
1373    async fn test_remove_edge() {
1374        let (_dir, tool) = make_tool();
1375        tool.execute(
1376            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1377        )
1378        .await
1379        .unwrap();
1380        tool.execute(
1381            json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1382        )
1383        .await
1384        .unwrap();
1385        tool.execute(json!({
1386            "action": "add_edge",
1387            "source_id": "a",
1388            "target_id": "b",
1389            "relationship_type": "related_to"
1390        }))
1391        .await
1392        .unwrap();
1393
1394        assert_eq!(tool.load_state().edges.len(), 1);
1395
1396        let result = tool
1397            .execute(json!({
1398                "action": "remove_edge",
1399                "source_id": "a",
1400                "target_id": "b"
1401            }))
1402            .await
1403            .unwrap();
1404        assert!(result.content.contains("Removed"));
1405        assert_eq!(tool.load_state().edges.len(), 0);
1406    }
1407
1408    #[tokio::test]
1409    async fn test_neighbors_depth_1() {
1410        let (_dir, tool) = make_tool();
1411        tool.execute(
1412            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1413        )
1414        .await
1415        .unwrap();
1416        tool.execute(
1417            json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1418        )
1419        .await
1420        .unwrap();
1421        tool.execute(
1422            json!({ "action": "add_node", "name": "C", "node_type": "concept", "id": "c" }),
1423        )
1424        .await
1425        .unwrap();
1426        tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1427            .await.unwrap();
1428        tool.execute(json!({ "action": "add_edge", "source_id": "b", "target_id": "c", "relationship_type": "related_to" }))
1429            .await.unwrap();
1430
1431        let result = tool
1432            .execute(json!({ "action": "neighbors", "id": "a", "depth": 1 }))
1433            .await
1434            .unwrap();
1435        assert!(result.content.contains("b"));
1436        // At depth 1, should NOT find C
1437        assert!(!result.content.contains(" c ") && !result.content.contains("\n  [depth 1] c"));
1438    }
1439
1440    #[tokio::test]
1441    async fn test_neighbors_depth_2() {
1442        let (_dir, tool) = make_tool();
1443        tool.execute(
1444            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1445        )
1446        .await
1447        .unwrap();
1448        tool.execute(
1449            json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1450        )
1451        .await
1452        .unwrap();
1453        tool.execute(
1454            json!({ "action": "add_node", "name": "C", "node_type": "concept", "id": "c" }),
1455        )
1456        .await
1457        .unwrap();
1458        tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1459            .await.unwrap();
1460        tool.execute(json!({ "action": "add_edge", "source_id": "b", "target_id": "c", "relationship_type": "related_to" }))
1461            .await.unwrap();
1462
1463        let result = tool
1464            .execute(json!({ "action": "neighbors", "id": "a", "depth": 2 }))
1465            .await
1466            .unwrap();
1467        // At depth 2, should find both B and C
1468        assert!(result.content.contains("b"));
1469        assert!(result.content.contains("c"));
1470    }
1471
1472    #[tokio::test]
1473    async fn test_search_by_name() {
1474        let (_dir, tool) = make_tool();
1475        tool.execute(json!({
1476            "action": "add_node",
1477            "name": "Transformer Architecture",
1478            "node_type": "method",
1479            "id": "transformer"
1480        }))
1481        .await
1482        .unwrap();
1483
1484        let result = tool
1485            .execute(json!({ "action": "search", "query": "transformer" }))
1486            .await
1487            .unwrap();
1488        assert!(result.content.contains("Transformer Architecture"));
1489    }
1490
1491    #[tokio::test]
1492    async fn test_search_by_tag() {
1493        let (_dir, tool) = make_tool();
1494        tool.execute(json!({
1495            "action": "add_node",
1496            "name": "BERT",
1497            "node_type": "method",
1498            "id": "bert",
1499            "tags": ["nlp", "language-model"]
1500        }))
1501        .await
1502        .unwrap();
1503
1504        let result = tool
1505            .execute(json!({ "action": "search", "query": "nlp" }))
1506            .await
1507            .unwrap();
1508        assert!(result.content.contains("BERT"));
1509    }
1510
1511    #[tokio::test]
1512    async fn test_search_filter_type() {
1513        let (_dir, tool) = make_tool();
1514        tool.execute(json!({ "action": "add_node", "name": "ML Paper", "node_type": "paper", "id": "ml-paper" }))
1515            .await.unwrap();
1516        tool.execute(json!({ "action": "add_node", "name": "ML Concept", "node_type": "concept", "id": "ml-concept" }))
1517            .await.unwrap();
1518
1519        let result = tool
1520            .execute(json!({ "action": "search", "query": "ml", "node_type": "paper" }))
1521            .await
1522            .unwrap();
1523        assert!(result.content.contains("ML Paper"));
1524        assert!(!result.content.contains("ML Concept"));
1525    }
1526
1527    #[tokio::test]
1528    async fn test_list_all() {
1529        let (_dir, tool) = make_tool();
1530        tool.execute(
1531            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1532        )
1533        .await
1534        .unwrap();
1535        tool.execute(json!({ "action": "add_node", "name": "B", "node_type": "paper", "id": "b" }))
1536            .await
1537            .unwrap();
1538
1539        let result = tool.execute(json!({ "action": "list" })).await.unwrap();
1540        assert!(result.content.contains("Nodes (2)"));
1541        assert!(result.content.contains("a"));
1542        assert!(result.content.contains("b"));
1543    }
1544
1545    #[tokio::test]
1546    async fn test_list_filter_type() {
1547        let (_dir, tool) = make_tool();
1548        tool.execute(
1549            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1550        )
1551        .await
1552        .unwrap();
1553        tool.execute(json!({ "action": "add_node", "name": "B", "node_type": "paper", "id": "b" }))
1554            .await
1555            .unwrap();
1556
1557        let result = tool
1558            .execute(json!({ "action": "list", "node_type": "concept" }))
1559            .await
1560            .unwrap();
1561        assert!(result.content.contains("Nodes (1)"));
1562        assert!(result.content.contains("a"));
1563        assert!(!result.content.contains(" b "));
1564    }
1565
1566    #[tokio::test]
1567    async fn test_path_direct() {
1568        let (_dir, tool) = make_tool();
1569        tool.execute(
1570            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1571        )
1572        .await
1573        .unwrap();
1574        tool.execute(
1575            json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1576        )
1577        .await
1578        .unwrap();
1579        tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1580            .await.unwrap();
1581
1582        let result = tool
1583            .execute(json!({ "action": "path", "from_id": "a", "to_id": "b" }))
1584            .await
1585            .unwrap();
1586        assert!(result.content.contains("1 hops"));
1587        assert!(result.content.contains("a"));
1588        assert!(result.content.contains("b"));
1589    }
1590
1591    #[tokio::test]
1592    async fn test_path_indirect() {
1593        let (_dir, tool) = make_tool();
1594        tool.execute(
1595            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1596        )
1597        .await
1598        .unwrap();
1599        tool.execute(
1600            json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1601        )
1602        .await
1603        .unwrap();
1604        tool.execute(
1605            json!({ "action": "add_node", "name": "C", "node_type": "concept", "id": "c" }),
1606        )
1607        .await
1608        .unwrap();
1609        tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "related_to" }))
1610            .await.unwrap();
1611        tool.execute(json!({ "action": "add_edge", "source_id": "b", "target_id": "c", "relationship_type": "related_to" }))
1612            .await.unwrap();
1613
1614        let result = tool
1615            .execute(json!({ "action": "path", "from_id": "a", "to_id": "c" }))
1616            .await
1617            .unwrap();
1618        assert!(result.content.contains("2 hops"));
1619        assert!(result.content.contains("a"));
1620        assert!(result.content.contains("b"));
1621        assert!(result.content.contains("c"));
1622    }
1623
1624    #[tokio::test]
1625    async fn test_path_no_path() {
1626        let (_dir, tool) = make_tool();
1627        tool.execute(
1628            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1629        )
1630        .await
1631        .unwrap();
1632        tool.execute(
1633            json!({ "action": "add_node", "name": "B", "node_type": "concept", "id": "b" }),
1634        )
1635        .await
1636        .unwrap();
1637        // No edges between them
1638
1639        let result = tool
1640            .execute(json!({ "action": "path", "from_id": "a", "to_id": "b" }))
1641            .await
1642            .unwrap();
1643        assert!(result.content.contains("No path found"));
1644    }
1645
1646    #[tokio::test]
1647    async fn test_stats() {
1648        let (_dir, tool) = make_tool();
1649        tool.execute(
1650            json!({ "action": "add_node", "name": "Paper 1", "node_type": "paper", "id": "p1" }),
1651        )
1652        .await
1653        .unwrap();
1654        tool.execute(json!({ "action": "add_node", "name": "Concept 1", "node_type": "concept", "id": "c1" }))
1655            .await.unwrap();
1656        tool.execute(
1657            json!({ "action": "add_node", "name": "Author 1", "node_type": "person", "id": "a1" }),
1658        )
1659        .await
1660        .unwrap();
1661        tool.execute(json!({ "action": "add_edge", "source_id": "p1", "target_id": "c1", "relationship_type": "implements" }))
1662            .await.unwrap();
1663        tool.execute(json!({ "action": "add_edge", "source_id": "p1", "target_id": "a1", "relationship_type": "authored_by" }))
1664            .await.unwrap();
1665
1666        let result = tool.execute(json!({ "action": "stats" })).await.unwrap();
1667        assert!(result.content.contains("Nodes: 3"));
1668        assert!(result.content.contains("Edges: 2"));
1669        assert!(result.content.contains("Paper: 1"));
1670        assert!(result.content.contains("Concept: 1"));
1671        assert!(result.content.contains("Person: 1"));
1672        assert!(result.content.contains("Most connected"));
1673        assert!(result.content.contains("p1"));
1674    }
1675
1676    #[tokio::test]
1677    async fn test_export_dot() {
1678        let (_dir, tool) = make_tool();
1679        tool.execute(
1680            json!({ "action": "add_node", "name": "A", "node_type": "concept", "id": "a" }),
1681        )
1682        .await
1683        .unwrap();
1684        tool.execute(
1685            json!({ "action": "add_node", "name": "B", "node_type": "method", "id": "b" }),
1686        )
1687        .await
1688        .unwrap();
1689        tool.execute(json!({ "action": "add_edge", "source_id": "a", "target_id": "b", "relationship_type": "implements" }))
1690            .await.unwrap();
1691
1692        let result = tool
1693            .execute(json!({ "action": "export_dot" }))
1694            .await
1695            .unwrap();
1696        assert!(result.content.contains("digraph KnowledgeGraph"));
1697        assert!(result.content.contains("\"a\""));
1698        assert!(result.content.contains("\"b\""));
1699        assert!(result.content.contains("Implements"));
1700        assert!(result.content.contains("->"));
1701    }
1702
1703    #[tokio::test]
1704    async fn test_state_roundtrip() {
1705        let (_dir, tool) = make_tool();
1706        // Add node, save, reload and verify
1707        tool.execute(json!({
1708            "action": "add_node",
1709            "name": "Test Node",
1710            "node_type": "dataset",
1711            "id": "test-node",
1712            "description": "A test dataset",
1713            "tags": ["test", "data"],
1714            "metadata": { "source": "kaggle", "size": "1GB" }
1715        }))
1716        .await
1717        .unwrap();
1718
1719        // Reload state from disk
1720        let state = tool.load_state();
1721        assert_eq!(state.nodes.len(), 1);
1722        let node = &state.nodes[0];
1723        assert_eq!(node.id, "test-node");
1724        assert_eq!(node.name, "Test Node");
1725        assert_eq!(node.node_type, NodeType::Dataset);
1726        assert_eq!(node.description, "A test dataset");
1727        assert_eq!(node.tags, vec!["test", "data"]);
1728        assert_eq!(node.metadata.get("source").unwrap(), "kaggle");
1729        assert_eq!(node.metadata.get("size").unwrap(), "1GB");
1730    }
1731
1732    #[tokio::test]
1733    async fn test_unknown_action() {
1734        let (_dir, tool) = make_tool();
1735        let result = tool.execute(json!({ "action": "foobar" })).await.unwrap();
1736        assert!(result.content.contains("Unknown action"));
1737        assert!(result.content.contains("foobar"));
1738    }
1739}