1use chrono::Utc;
8use rusqlite::Connection;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12use crate::error::Result;
13
/// Graphviz colour names handed out to entity types when DOT export colours
/// nodes by type. Types receive colours in order of first appearance and the
/// palette is reused cyclically once every entry has been assigned.
const TYPE_COLORS: &[&str] = &[
    "blue",
    "red",
    "green",
    "orange",
    "purple",
    "brown",
    "cyan",
    "magenta",
];
18
/// Options controlling how the knowledge graph is rendered as Graphviz DOT.
///
/// Construct with `DotConfig::default()` and override individual fields with
/// struct-update syntax, e.g. `DotConfig { rankdir: "TB".into(), ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct DotConfig {
    /// Layout direction written as the graph's `rankdir` attribute
    /// (e.g. `LR`, `TB`).
    pub rankdir: String,
    /// Default node shape written as `node [shape=...]` (e.g. `ellipse`, `box`).
    pub node_shape: String,
    /// When `true`, each distinct entity type is drawn in its own colour.
    pub color_by_type: bool,
    /// If `Some(n)`, only the first `n` entities (ordered by id) are exported;
    /// `None` exports every entity.
    pub max_nodes: Option<usize>,
}
31
32impl Default for DotConfig {
33 fn default() -> Self {
34 Self {
35 rankdir: "LR".to_string(),
36 node_shape: "ellipse".to_string(),
37 color_by_type: true,
38 max_nodes: None,
39 }
40 }
41}
42
43pub fn export_dot(conn: &Connection, config: &DotConfig) -> Result<String> {
58 let nodes = query_nodes(conn)?;
59 let links = query_links(conn)?;
60
61 let nodes = if let Some(max) = config.max_nodes {
62 nodes.into_iter().take(max).collect::<Vec<_>>()
63 } else {
64 nodes
65 };
66
67 let node_ids: std::collections::HashSet<i64> = nodes.iter().map(|n| n.id).collect();
69
70 let mut type_color_map: HashMap<String, &str> = HashMap::new();
72 if config.color_by_type {
73 let mut color_idx = 0;
74 for node in &nodes {
75 type_color_map
76 .entry(node.node_type.clone())
77 .or_insert_with(|| {
78 let color = TYPE_COLORS[color_idx % TYPE_COLORS.len()];
79 color_idx += 1;
80 color
81 });
82 }
83 }
84
85 let mut dot = String::new();
86 dot.push_str("digraph knowledge_graph {\n");
87 dot.push_str(&format!(" rankdir={};\n", config.rankdir));
88 dot.push_str(&format!(" node [shape={}];\n", config.node_shape));
89 dot.push('\n');
90
91 for node in &nodes {
93 let label = escape_dot_label(&node.name);
94 if config.color_by_type {
95 let color = type_color_map
96 .get(&node.node_type)
97 .copied()
98 .unwrap_or("black");
99 dot.push_str(&format!(
100 " {} [label=\"{}\" color={}];\n",
101 node.id, label, color
102 ));
103 } else {
104 dot.push_str(&format!(" {} [label=\"{}\"];\n", node.id, label));
105 }
106 }
107
108 dot.push('\n');
109
110 for link in &links {
112 if node_ids.contains(&link.source) && node_ids.contains(&link.target) {
113 let rel_label = escape_dot_label(&link.link_type);
114 dot.push_str(&format!(
115 " {} -> {} [label=\"{}\" weight={:.4}];\n",
116 link.source, link.target, rel_label, link.weight
117 ));
118 }
119 }
120
121 dot.push_str("}\n");
122 Ok(dot)
123}
124
/// Escapes `s` for use inside a double-quoted DOT string.
///
/// Backslashes are doubled, double quotes are backslash-escaped, and line
/// feeds become the two-character sequence `\n`. Every other character is
/// copied through unchanged.
fn escape_dot_label(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            other => escaped.push(other),
        }
    }
    escaped
}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct D3Node {
135 pub id: i64,
136 pub name: String,
137 #[serde(rename = "type")]
138 pub node_type: String,
139 pub properties: HashMap<String, serde_json::Value>,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct D3Link {
145 pub source: i64,
146 pub target: i64,
147 #[serde(rename = "type")]
148 pub link_type: String,
149 pub weight: f64,
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct D3ExportMetadata {
155 pub node_count: usize,
156 pub edge_count: usize,
157 pub exported_at: String,
158}
159
160#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct D3ExportGraph {
163 pub nodes: Vec<D3Node>,
164 pub links: Vec<D3Link>,
165 pub metadata: D3ExportMetadata,
166}
167
168pub fn export_d3_json(conn: &Connection) -> Result<D3ExportGraph> {
173 let nodes = query_nodes(conn)?;
174 let links = query_links(conn)?;
175
176 let metadata = D3ExportMetadata {
177 node_count: nodes.len(),
178 edge_count: links.len(),
179 exported_at: Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string(),
180 };
181
182 Ok(D3ExportGraph {
183 nodes,
184 links,
185 metadata,
186 })
187}
188
189fn query_nodes(conn: &Connection) -> Result<Vec<D3Node>> {
190 let mut stmt =
191 conn.prepare("SELECT id, entity_type, name, properties FROM kg_entities ORDER BY id")?;
192
193 let rows = stmt.query_map([], |row| {
194 let id: i64 = row.get(0)?;
195 let node_type: String = row.get(1)?;
196 let name: String = row.get(2)?;
197 let properties_json: String = row.get(3)?;
198 Ok((id, node_type, name, properties_json))
199 })?;
200
201 let mut nodes = Vec::new();
202 for row in rows {
203 let (id, node_type, name, properties_json) = row?;
204 let properties: HashMap<String, serde_json::Value> =
205 serde_json::from_str(&properties_json).unwrap_or_default();
206 nodes.push(D3Node {
207 id,
208 name,
209 node_type,
210 properties,
211 });
212 }
213
214 Ok(nodes)
215}
216
217fn query_links(conn: &Connection) -> Result<Vec<D3Link>> {
218 let mut stmt = conn
219 .prepare("SELECT source_id, target_id, rel_type, weight FROM kg_relations ORDER BY id")?;
220
221 let rows = stmt.query_map([], |row| {
222 let source: i64 = row.get(0)?;
223 let target: i64 = row.get(1)?;
224 let link_type: String = row.get(2)?;
225 let weight: f64 = row.get(3)?;
226 Ok((source, target, link_type, weight))
227 })?;
228
229 let mut links = Vec::new();
230 for row in rows {
231 let (source, target, link_type, weight) = row?;
232 links.push(D3Link {
233 source,
234 target,
235 link_type,
236 weight,
237 });
238 }
239
240 Ok(links)
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Entity, KnowledgeGraph, Relation};

    /// Opens a fresh, empty in-memory knowledge graph for a single test.
    fn setup() -> KnowledgeGraph {
        KnowledgeGraph::open_in_memory().unwrap()
    }

    // --- DOT export -------------------------------------------------------

    #[test]
    fn test_export_dot_empty_graph() {
        let kg = setup();
        let config = DotConfig::default();
        let dot = export_dot(kg.connection(), &config).unwrap();

        assert!(dot.contains("digraph knowledge_graph {"));
        assert!(dot.contains("rankdir=LR;"));
        assert!(dot.contains("node [shape=ellipse];"));
        // No entities means no edges either.
        assert!(!dot.contains("->"));
        assert!(dot.ends_with("}\n"));
    }

    #[test]
    fn test_export_dot_nodes_and_edges() {
        let kg = setup();
        let id1 = kg
            .insert_entity(&Entity::new("concept", "Deep Learning"))
            .unwrap();
        let id2 = kg
            .insert_entity(&Entity::new("concept", "Neural Networks"))
            .unwrap();
        kg.insert_relation(&Relation::new(id1, id2, "related_to", 0.8).unwrap())
            .unwrap();

        let config = DotConfig::default();
        let dot = export_dot(kg.connection(), &config).unwrap();

        assert!(dot.contains("Deep Learning"));
        assert!(dot.contains("Neural Networks"));
        assert!(dot.contains("related_to"));
        assert!(dot.contains(&format!("{} -> {}", id1, id2)));
        assert!(dot.contains("weight=0.8000"));
    }

    #[test]
    fn test_export_dot_rankdir_tb() {
        let kg = setup();
        kg.insert_entity(&Entity::new("concept", "AI")).unwrap();

        let config = DotConfig {
            rankdir: "TB".to_string(),
            ..Default::default()
        };
        let dot = export_dot(kg.connection(), &config).unwrap();
        assert!(dot.contains("rankdir=TB;"));
    }

    #[test]
    fn test_export_dot_custom_node_shape() {
        let kg = setup();
        kg.insert_entity(&Entity::new("concept", "AI")).unwrap();

        let config = DotConfig {
            node_shape: "box".to_string(),
            ..Default::default()
        };
        let dot = export_dot(kg.connection(), &config).unwrap();
        assert!(dot.contains("node [shape=box];"));
    }

    #[test]
    fn test_export_dot_color_by_type() {
        let kg = setup();
        kg.insert_entity(&Entity::new("paper", "Paper A")).unwrap();
        kg.insert_entity(&Entity::new("author", "Alice")).unwrap();
        kg.insert_entity(&Entity::new("paper", "Paper B")).unwrap();

        let config = DotConfig {
            color_by_type: true,
            ..Default::default()
        };
        let dot = export_dot(kg.connection(), &config).unwrap();

        // Colours follow the palette in order of first appearance (by entity
        // id), and entities of the same type share one colour.
        assert!(dot.contains("label=\"Paper A\" color=blue"));
        assert!(dot.contains("label=\"Alice\" color=red"));
        assert!(dot.contains("label=\"Paper B\" color=blue"));
    }

    #[test]
    fn test_export_dot_no_color() {
        let kg = setup();
        kg.insert_entity(&Entity::new("concept", "AI")).unwrap();

        let config = DotConfig {
            color_by_type: false,
            ..Default::default()
        };
        let dot = export_dot(kg.connection(), &config).unwrap();
        assert!(!dot.contains("color="));
    }

    #[test]
    fn test_export_dot_max_nodes() {
        let kg = setup();
        for i in 0..5 {
            kg.insert_entity(&Entity::new("concept", &format!("Concept {i}")))
                .unwrap();
        }

        let config = DotConfig {
            max_nodes: Some(3),
            ..Default::default()
        };
        let dot = export_dot(kg.connection(), &config).unwrap();

        assert!(dot.contains("Concept 0"));
        assert!(dot.contains("Concept 1"));
        assert!(dot.contains("Concept 2"));
        assert!(!dot.contains("Concept 3"));
        assert!(!dot.contains("Concept 4"));
    }

    #[test]
    fn test_export_dot_edges_filtered_by_max_nodes() {
        let kg = setup();
        let id1 = kg.insert_entity(&Entity::new("concept", "A")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("concept", "B")).unwrap();
        let id3 = kg.insert_entity(&Entity::new("concept", "C")).unwrap();
        kg.insert_relation(&Relation::new(id1, id2, "link", 1.0).unwrap())
            .unwrap();
        kg.insert_relation(&Relation::new(id2, id3, "link", 1.0).unwrap())
            .unwrap();

        // Only A and B survive the cut, so only the A -> B edge may appear.
        let config = DotConfig {
            max_nodes: Some(2),
            color_by_type: false,
            ..Default::default()
        };
        let dot = export_dot(kg.connection(), &config).unwrap();

        assert!(dot.contains(&format!("{} -> {}", id1, id2)));
        assert!(!dot.contains(&format!("{} -> {}", id2, id3)));
    }

    #[test]
    fn test_export_dot_escapes_node_labels() {
        let kg = setup();
        kg.insert_entity(&Entity::new("concept", "Say \"hi\""))
            .unwrap();

        let config = DotConfig {
            color_by_type: false,
            ..Default::default()
        };
        let dot = export_dot(kg.connection(), &config).unwrap();

        // The embedded quotes must be escaped so the label stays one DOT string.
        assert!(dot.contains("label=\"Say \\\"hi\\\"\""));
    }

    #[test]
    fn test_escape_dot_label() {
        assert_eq!(escape_dot_label("hello"), "hello");
        assert_eq!(escape_dot_label("say \"hi\""), "say \\\"hi\\\"");
        assert_eq!(escape_dot_label("line\nnew"), "line\\nnew");
        assert_eq!(escape_dot_label("back\\slash"), "back\\\\slash");
        // Mixed input: each special character is escaped exactly once.
        assert_eq!(escape_dot_label("a\\\"b\nc"), "a\\\\\\\"b\\nc");
    }

    // --- D3 JSON export ---------------------------------------------------

    #[test]
    fn test_export_empty_graph() {
        let kg = setup();
        let result = export_d3_json(kg.connection()).unwrap();

        assert_eq!(result.nodes.len(), 0);
        assert_eq!(result.links.len(), 0);
        assert_eq!(result.metadata.node_count, 0);
        assert_eq!(result.metadata.edge_count, 0);
        assert!(!result.metadata.exported_at.is_empty());
    }

    #[test]
    fn test_export_nodes_only() {
        let kg = setup();

        let mut paper = Entity::new("paper", "Deep Learning");
        paper.set_property("year", serde_json::json!(2024));
        kg.insert_entity(&paper).unwrap();

        let result = export_d3_json(kg.connection()).unwrap();

        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.links.len(), 0);
        assert_eq!(result.metadata.node_count, 1);
        assert_eq!(result.metadata.edge_count, 0);

        let node = &result.nodes[0];
        assert_eq!(node.name, "Deep Learning");
        assert_eq!(node.node_type, "paper");
        assert_eq!(node.properties["year"], serde_json::json!(2024));
    }

    #[test]
    fn test_export_nodes_and_links() {
        let kg = setup();

        let id1 = kg.insert_entity(&Entity::new("paper", "Paper A")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("paper", "Paper B")).unwrap();
        kg.insert_relation(&Relation::new(id1, id2, "cites", 0.8).unwrap())
            .unwrap();

        let result = export_d3_json(kg.connection()).unwrap();

        assert_eq!(result.nodes.len(), 2);
        assert_eq!(result.links.len(), 1);
        assert_eq!(result.metadata.node_count, 2);
        assert_eq!(result.metadata.edge_count, 1);

        let link = &result.links[0];
        assert_eq!(link.source, id1);
        assert_eq!(link.target, id2);
        assert_eq!(link.link_type, "cites");
        assert!((link.weight - 0.8).abs() < 1e-9);
    }

    #[test]
    fn test_export_json_serialization() {
        let kg = setup();

        let id1 = kg
            .insert_entity(&Entity::new("concept", "Neural Networks"))
            .unwrap();
        let id2 = kg
            .insert_entity(&Entity::new("concept", "Deep Learning"))
            .unwrap();
        kg.insert_relation(&Relation::new(id1, id2, "related_to", 0.9).unwrap())
            .unwrap();

        let graph = export_d3_json(kg.connection()).unwrap();
        let json_str = serde_json::to_string_pretty(&graph).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();

        assert!(parsed["nodes"].is_array());
        assert!(parsed["links"].is_array());
        assert!(parsed["metadata"].is_object());
        assert_eq!(parsed["metadata"]["node_count"], 2);
        assert_eq!(parsed["metadata"]["edge_count"], 1);
        assert!(parsed["metadata"]["exported_at"].is_string());

        let nodes = parsed["nodes"].as_array().unwrap();
        assert_eq!(nodes[0]["name"], "Neural Networks");
        assert_eq!(nodes[0]["type"], "concept");

        let links = parsed["links"].as_array().unwrap();
        assert_eq!(links[0]["type"], "related_to");
        assert_eq!(links[0]["weight"], 0.9);
    }

    #[test]
    fn test_export_multiple_relations() {
        let kg = setup();

        let id1 = kg.insert_entity(&Entity::new("author", "Alice")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("paper", "Paper X")).unwrap();
        let id3 = kg.insert_entity(&Entity::new("topic", "ML")).unwrap();

        kg.insert_relation(&Relation::new(id1, id2, "wrote", 1.0).unwrap())
            .unwrap();
        kg.insert_relation(&Relation::new(id2, id3, "covers", 0.7).unwrap())
            .unwrap();

        let result = export_d3_json(kg.connection()).unwrap();

        assert_eq!(result.nodes.len(), 3);
        assert_eq!(result.links.len(), 2);
        assert_eq!(result.metadata.node_count, 3);
        assert_eq!(result.metadata.edge_count, 2);
    }
}