Skip to main content

sqlite_knowledge_graph/export/
mod.rs

1//! Graph visualization export module.
2//!
3//! Supports exporting knowledge graphs to various formats for visualization:
4//! - D3.js JSON format (nodes + links + metadata)
5//! - DOT (Graphviz) format for graph visualization
6
7use chrono::Utc;
8use rusqlite::Connection;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12use crate::error::Result;
13
// Color names emitted as the DOT `color` attribute when `color_by_type` is on.
// `export_dot` hands them out to entity types in first-seen order and indexes
// this slice modulo its length, so the palette repeats once more than 8
// distinct types appear.
const TYPE_COLORS: &[&str] = &[
    "blue", "red", "green", "orange", "purple", "brown", "cyan", "magenta",
];
18
/// Options controlling how [`export_dot`] renders the graph.
///
/// Use [`DotConfig::default`] and struct-update syntax to override only the
/// fields you care about.
#[derive(Debug, Clone)]
pub struct DotConfig {
    /// Value written to the graph-level `rankdir` attribute (`LR`, `TB`, `RL`
    /// or `BT`). Defaults to `"LR"`.
    pub rankdir: String,
    /// Value written to the default node `shape` attribute. Defaults to
    /// `"ellipse"`.
    pub node_shape: String,
    /// When `true`, every node gets a `color` attribute chosen per entity
    /// type. Defaults to `true`.
    pub color_by_type: bool,
    /// Upper bound on the number of exported nodes; `None` exports all of
    /// them. Defaults to `None`.
    pub max_nodes: Option<usize>,
}

impl Default for DotConfig {
    /// Left-to-right layout, elliptical nodes, per-type coloring, no node cap.
    fn default() -> Self {
        DotConfig {
            rankdir: String::from("LR"),
            node_shape: String::from("ellipse"),
            color_by_type: true,
            max_nodes: None,
        }
    }
}
42
43/// Export the knowledge graph in DOT (Graphviz) format.
44///
45/// Generates a DOT format string suitable for rendering with Graphviz tools
46/// such as `dot`, `neato`, `fdp`, etc.
47///
48/// # Example output
49/// ```text
50/// digraph knowledge_graph {
51///     rankdir=LR;
52///     node [shape=ellipse];
53///     1 [label="Deep Learning" color=blue];
54///     1 -> 2 [label="related_to" weight=0.8];
55/// }
56/// ```
57pub fn export_dot(conn: &Connection, config: &DotConfig) -> Result<String> {
58    let nodes = query_nodes(conn)?;
59    let links = query_links(conn)?;
60
61    let nodes = if let Some(max) = config.max_nodes {
62        nodes.into_iter().take(max).collect::<Vec<_>>()
63    } else {
64        nodes
65    };
66
67    // Build set of included node ids for filtering links
68    let node_ids: std::collections::HashSet<i64> = nodes.iter().map(|n| n.id).collect();
69
70    // Build type -> color mapping
71    let mut type_color_map: HashMap<String, &str> = HashMap::new();
72    if config.color_by_type {
73        let mut color_idx = 0;
74        for node in &nodes {
75            type_color_map
76                .entry(node.node_type.clone())
77                .or_insert_with(|| {
78                    let color = TYPE_COLORS[color_idx % TYPE_COLORS.len()];
79                    color_idx += 1;
80                    color
81                });
82        }
83    }
84
85    let mut dot = String::new();
86    dot.push_str("digraph knowledge_graph {\n");
87    dot.push_str(&format!("    rankdir={};\n", config.rankdir));
88    dot.push_str(&format!("    node [shape={}];\n", config.node_shape));
89    dot.push('\n');
90
91    // Emit nodes
92    for node in &nodes {
93        let label = escape_dot_label(&node.name);
94        if config.color_by_type {
95            let color = type_color_map
96                .get(&node.node_type)
97                .copied()
98                .unwrap_or("black");
99            dot.push_str(&format!(
100                "    {} [label=\"{}\" color={}];\n",
101                node.id, label, color
102            ));
103        } else {
104            dot.push_str(&format!("    {} [label=\"{}\"];\n", node.id, label));
105        }
106    }
107
108    dot.push('\n');
109
110    // Emit edges (only for included nodes)
111    for link in &links {
112        if node_ids.contains(&link.source) && node_ids.contains(&link.target) {
113            let rel_label = escape_dot_label(&link.link_type);
114            dot.push_str(&format!(
115                "    {} -> {} [label=\"{}\" weight={:.4}];\n",
116                link.source, link.target, rel_label, link.weight
117            ));
118        }
119    }
120
121    dot.push_str("}\n");
122    Ok(dot)
123}
124
/// Escape a string for use inside a double-quoted DOT label.
///
/// Backslashes are doubled, double quotes are backslash-escaped, and line
/// feeds become the two-character sequence `\n`. Every other character is
/// copied through unchanged.
fn escape_dot_label(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            other => escaped.push(other),
        }
    }
    escaped
}
131
/// A node in the D3.js export format.
///
/// Built from one `kg_entities` row by `query_nodes`. Serializes with the
/// keys `id`, `name`, `type` and `properties`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct D3Node {
    /// Entity id, read from the `id` column.
    pub id: i64,
    /// Entity name, read from the `name` column.
    pub name: String,
    /// Entity type, read from the `entity_type` column; appears under the key
    /// `type` in the serialized form.
    #[serde(rename = "type")]
    pub node_type: String,
    /// Properties decoded from the JSON text in the `properties` column;
    /// empty when that text does not parse as a JSON object.
    pub properties: HashMap<String, serde_json::Value>,
}
141
/// A link (edge) in the D3.js export format.
///
/// Built from one `kg_relations` row by `query_links`. Serializes with the
/// keys `source`, `target`, `type` and `weight`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct D3Link {
    /// Id of the entity the relation starts at (`source_id` column).
    pub source: i64,
    /// Id of the entity the relation points to (`target_id` column).
    pub target: i64,
    /// Relation type (`rel_type` column); appears under the key `type` in the
    /// serialized form.
    #[serde(rename = "type")]
    pub link_type: String,
    /// Relation weight (`weight` column).
    pub weight: f64,
}
151
/// Metadata for the exported graph.
///
/// Filled in by `export_d3_json` at export time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct D3ExportMetadata {
    /// Number of nodes in the export.
    pub node_count: usize,
    /// Number of links in the export.
    pub edge_count: usize,
    /// UTC time of the export, formatted as `YYYY-MM-DDTHH:MM:SSZ`.
    pub exported_at: String,
}
159
/// D3.js force-directed graph export format.
///
/// This is the value returned by `export_d3_json`; serialize it (for example
/// with `serde_json`) to obtain the JSON document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct D3ExportGraph {
    /// All exported entities, in ascending id order.
    pub nodes: Vec<D3Node>,
    /// All exported relations, in ascending relation id order.
    pub links: Vec<D3Link>,
    /// Counts and timestamp describing this export.
    pub metadata: D3ExportMetadata,
}
167
168/// Export the knowledge graph in D3.js JSON format.
169///
170/// Queries all entities and relations from the database and formats them
171/// as a D3.js force-directed graph with nodes, links, and metadata.
172pub fn export_d3_json(conn: &Connection) -> Result<D3ExportGraph> {
173    let nodes = query_nodes(conn)?;
174    let links = query_links(conn)?;
175
176    let metadata = D3ExportMetadata {
177        node_count: nodes.len(),
178        edge_count: links.len(),
179        exported_at: Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string(),
180    };
181
182    Ok(D3ExportGraph {
183        nodes,
184        links,
185        metadata,
186    })
187}
188
189fn query_nodes(conn: &Connection) -> Result<Vec<D3Node>> {
190    let mut stmt =
191        conn.prepare("SELECT id, entity_type, name, properties FROM kg_entities ORDER BY id")?;
192
193    let rows = stmt.query_map([], |row| {
194        let id: i64 = row.get(0)?;
195        let node_type: String = row.get(1)?;
196        let name: String = row.get(2)?;
197        let properties_json: String = row.get(3)?;
198        Ok((id, node_type, name, properties_json))
199    })?;
200
201    let mut nodes = Vec::new();
202    for row in rows {
203        let (id, node_type, name, properties_json) = row?;
204        let properties: HashMap<String, serde_json::Value> =
205            serde_json::from_str(&properties_json).unwrap_or_default();
206        nodes.push(D3Node {
207            id,
208            name,
209            node_type,
210            properties,
211        });
212    }
213
214    Ok(nodes)
215}
216
217fn query_links(conn: &Connection) -> Result<Vec<D3Link>> {
218    let mut stmt = conn
219        .prepare("SELECT source_id, target_id, rel_type, weight FROM kg_relations ORDER BY id")?;
220
221    let rows = stmt.query_map([], |row| {
222        let source: i64 = row.get(0)?;
223        let target: i64 = row.get(1)?;
224        let link_type: String = row.get(2)?;
225        let weight: f64 = row.get(3)?;
226        Ok((source, target, link_type, weight))
227    })?;
228
229    let mut links = Vec::new();
230    for row in rows {
231        let (source, target, link_type, weight) = row?;
232        links.push(D3Link {
233            source,
234            target,
235            link_type,
236            weight,
237        });
238    }
239
240    Ok(links)
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Entity, KnowledgeGraph, Relation};

    /// Open a fresh in-memory knowledge graph for a single test.
    fn setup() -> KnowledgeGraph {
        KnowledgeGraph::open_in_memory().unwrap()
    }

    // ===== DOT export tests =====

    #[test]
    fn test_export_dot_empty_graph() {
        let kg = setup();
        let config = DotConfig::default();
        let dot = export_dot(kg.connection(), &config).unwrap();

        assert!(dot.contains("digraph knowledge_graph {"));
        assert!(dot.contains("rankdir=LR;"));
        assert!(dot.contains("node [shape=ellipse];"));
        assert!(dot.ends_with("}\n"));
    }

    #[test]
    fn test_export_dot_nodes_and_edges() {
        let kg = setup();
        let id1 = kg
            .insert_entity(&Entity::new("concept", "Deep Learning"))
            .unwrap();
        let id2 = kg
            .insert_entity(&Entity::new("concept", "Neural Networks"))
            .unwrap();
        kg.insert_relation(&Relation::new(id1, id2, "related_to", 0.8).unwrap())
            .unwrap();

        let config = DotConfig::default();
        let dot = export_dot(kg.connection(), &config).unwrap();

        assert!(dot.contains("Deep Learning"));
        assert!(dot.contains("Neural Networks"));
        assert!(dot.contains("related_to"));
        assert!(dot.contains(&format!("{} -> {}", id1, id2)));
        // Weights are always written with four decimal places.
        assert!(dot.contains("weight=0.8000"));
    }

    #[test]
    fn test_export_dot_rankdir_tb() {
        let kg = setup();
        kg.insert_entity(&Entity::new("concept", "AI")).unwrap();

        let config = DotConfig {
            rankdir: "TB".to_string(),
            ..Default::default()
        };
        let dot = export_dot(kg.connection(), &config).unwrap();
        assert!(dot.contains("rankdir=TB;"));
    }

    #[test]
    fn test_export_dot_custom_node_shape() {
        let kg = setup();
        kg.insert_entity(&Entity::new("concept", "AI")).unwrap();

        let config = DotConfig {
            node_shape: "box".to_string(),
            ..Default::default()
        };
        let dot = export_dot(kg.connection(), &config).unwrap();
        assert!(dot.contains("node [shape=box];"));
    }

    #[test]
    fn test_export_dot_color_by_type() {
        let kg = setup();
        kg.insert_entity(&Entity::new("paper", "Paper A")).unwrap();
        kg.insert_entity(&Entity::new("author", "Alice")).unwrap();

        let config = DotConfig {
            color_by_type: true,
            ..Default::default()
        };
        let dot = export_dot(kg.connection(), &config).unwrap();

        // Colors are handed out in first-seen order and nodes are read in id
        // order, so the first inserted type gets the first palette entry and
        // the second type gets the next one.
        assert_ne!(TYPE_COLORS[0], TYPE_COLORS[1]);
        assert!(dot.contains(&format!(
            "[label=\"Paper A\" color={}];",
            TYPE_COLORS[0]
        )));
        assert!(dot.contains(&format!(
            "[label=\"Alice\" color={}];",
            TYPE_COLORS[1]
        )));
    }

    #[test]
    fn test_export_dot_no_color() {
        let kg = setup();
        kg.insert_entity(&Entity::new("concept", "AI")).unwrap();

        let config = DotConfig {
            color_by_type: false,
            ..Default::default()
        };
        let dot = export_dot(kg.connection(), &config).unwrap();
        // No color attribute when disabled
        assert!(!dot.contains("color="));
    }

    #[test]
    fn test_export_dot_max_nodes() {
        let kg = setup();
        for i in 0..5 {
            kg.insert_entity(&Entity::new("concept", &format!("Concept {i}")))
                .unwrap();
        }

        let config = DotConfig {
            max_nodes: Some(3),
            ..Default::default()
        };
        let dot = export_dot(kg.connection(), &config).unwrap();

        // Only the first 3 nodes should appear
        assert!(dot.contains("Concept 0"));
        assert!(dot.contains("Concept 1"));
        assert!(dot.contains("Concept 2"));
        assert!(!dot.contains("Concept 3"));
        assert!(!dot.contains("Concept 4"));
    }

    #[test]
    fn test_export_dot_edges_filtered_by_max_nodes() {
        let kg = setup();
        let id1 = kg.insert_entity(&Entity::new("concept", "A")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("concept", "B")).unwrap();
        let id3 = kg.insert_entity(&Entity::new("concept", "C")).unwrap();
        kg.insert_relation(&Relation::new(id1, id2, "link", 1.0).unwrap())
            .unwrap();
        kg.insert_relation(&Relation::new(id2, id3, "link", 1.0).unwrap())
            .unwrap();

        // Only include first 2 nodes; edge id2->id3 should be omitted
        let config = DotConfig {
            max_nodes: Some(2),
            color_by_type: false,
            ..Default::default()
        };
        let dot = export_dot(kg.connection(), &config).unwrap();

        assert!(dot.contains(&format!("{} -> {}", id1, id2)));
        assert!(!dot.contains(&format!("{} -> {}", id2, id3)));
    }

    #[test]
    fn test_escape_dot_label() {
        assert_eq!(escape_dot_label("hello"), "hello");
        assert_eq!(escape_dot_label("say \"hi\""), "say \\\"hi\\\"");
        assert_eq!(escape_dot_label("line\nnew"), "line\\nnew");
        assert_eq!(escape_dot_label("back\\slash"), "back\\\\slash");
    }

    // ===== D3.js JSON export tests =====

    #[test]
    fn test_export_empty_graph() {
        let kg = setup();
        let result = export_d3_json(kg.connection()).unwrap();

        assert_eq!(result.nodes.len(), 0);
        assert_eq!(result.links.len(), 0);
        assert_eq!(result.metadata.node_count, 0);
        assert_eq!(result.metadata.edge_count, 0);
        assert!(!result.metadata.exported_at.is_empty());
    }

    #[test]
    fn test_export_nodes_only() {
        let kg = setup();

        let mut paper = Entity::new("paper", "Deep Learning");
        paper.set_property("year", serde_json::json!(2024));
        kg.insert_entity(&paper).unwrap();

        let result = export_d3_json(kg.connection()).unwrap();

        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.links.len(), 0);
        assert_eq!(result.metadata.node_count, 1);
        assert_eq!(result.metadata.edge_count, 0);

        let node = &result.nodes[0];
        assert_eq!(node.name, "Deep Learning");
        assert_eq!(node.node_type, "paper");
        assert_eq!(node.properties["year"], serde_json::json!(2024));
    }

    #[test]
    fn test_export_nodes_and_links() {
        let kg = setup();

        let id1 = kg.insert_entity(&Entity::new("paper", "Paper A")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("paper", "Paper B")).unwrap();
        kg.insert_relation(&Relation::new(id1, id2, "cites", 0.8).unwrap())
            .unwrap();

        let result = export_d3_json(kg.connection()).unwrap();

        assert_eq!(result.nodes.len(), 2);
        assert_eq!(result.links.len(), 1);
        assert_eq!(result.metadata.node_count, 2);
        assert_eq!(result.metadata.edge_count, 1);

        let link = &result.links[0];
        assert_eq!(link.source, id1);
        assert_eq!(link.target, id2);
        assert_eq!(link.link_type, "cites");
        assert!((link.weight - 0.8).abs() < 1e-9);
    }

    #[test]
    fn test_export_json_serialization() {
        let kg = setup();

        let id1 = kg
            .insert_entity(&Entity::new("concept", "Neural Networks"))
            .unwrap();
        let id2 = kg
            .insert_entity(&Entity::new("concept", "Deep Learning"))
            .unwrap();
        kg.insert_relation(&Relation::new(id1, id2, "related_to", 0.9).unwrap())
            .unwrap();

        let graph = export_d3_json(kg.connection()).unwrap();
        let json_str = serde_json::to_string_pretty(&graph).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();

        assert!(parsed["nodes"].is_array());
        assert!(parsed["links"].is_array());
        assert!(parsed["metadata"].is_object());
        assert_eq!(parsed["metadata"]["node_count"], 2);
        assert_eq!(parsed["metadata"]["edge_count"], 1);
        assert!(parsed["metadata"]["exported_at"].is_string());

        // `node_type` / `link_type` are serialized under the key `type`.
        let nodes = parsed["nodes"].as_array().unwrap();
        assert_eq!(nodes[0]["name"], "Neural Networks");
        assert_eq!(nodes[0]["type"], "concept");

        let links = parsed["links"].as_array().unwrap();
        assert_eq!(links[0]["type"], "related_to");
        assert_eq!(links[0]["weight"], 0.9);
    }

    #[test]
    fn test_export_multiple_relations() {
        let kg = setup();

        let id1 = kg.insert_entity(&Entity::new("author", "Alice")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("paper", "Paper X")).unwrap();
        let id3 = kg.insert_entity(&Entity::new("topic", "ML")).unwrap();

        kg.insert_relation(&Relation::new(id1, id2, "wrote", 1.0).unwrap())
            .unwrap();
        kg.insert_relation(&Relation::new(id2, id3, "covers", 0.7).unwrap())
            .unwrap();

        let result = export_d3_json(kg.connection()).unwrap();

        assert_eq!(result.nodes.len(), 3);
        assert_eq!(result.links.len(), 2);
        assert_eq!(result.metadata.node_count, 3);
        assert_eq!(result.metadata.edge_count, 2);
    }
}