tensorlogic_quantrs_hooks/
factor_graph_viz.rs1use std::fmt::Write;
8
9use serde::{Deserialize, Serialize};
10
11use crate::graph::FactorGraph;
12
13#[derive(Debug, Clone, Default)]
23pub struct FactorGraphModel {
24 pub variables: Vec<VizVariableNode>,
26 pub factors: Vec<VizFactorNode>,
28}
29
/// A variable node as shown in a rendered factor graph.
#[derive(Clone, Debug)]
pub struct VizVariableNode {
    /// Display name of the variable.
    pub name: String,
    /// Size of the variable's domain (populated from the variable's
    /// cardinality when the model is built from a `FactorGraph`).
    pub domain_size: usize,
}
38
/// A factor node as shown in a rendered factor graph.
#[derive(Clone, Debug)]
pub struct VizFactorNode {
    /// Display name of the factor.
    pub name: String,
    /// Positions (within the owning model's variable list) of the variables
    /// this factor is connected to.
    pub variable_indices: Vec<usize>,
}
47
48impl FactorGraphModel {
49 pub fn new() -> Self {
51 Self::default()
52 }
53
54 pub fn from_factor_graph(fg: &FactorGraph) -> Self {
58 let mut var_names: Vec<String> = fg.variable_names().cloned().collect();
60 var_names.sort();
61
62 let mut name_to_idx: std::collections::HashMap<String, usize> =
63 std::collections::HashMap::new();
64
65 let mut model = Self::new();
66 for name in &var_names {
67 let card = fg.get_variable(name).map(|v| v.cardinality).unwrap_or(2);
68 let idx = model.add_variable(name.clone(), card);
69 name_to_idx.insert(name.clone(), idx);
70 }
71
72 for factor in fg.factors() {
73 let indices: Vec<usize> = factor
74 .variables
75 .iter()
76 .filter_map(|v| name_to_idx.get(v).copied())
77 .collect();
78 model.add_factor(factor.name.clone(), indices);
79 }
80
81 model
82 }
83
84 pub fn add_variable(&mut self, name: impl Into<String>, domain_size: usize) -> usize {
86 let idx = self.variables.len();
87 self.variables.push(VizVariableNode {
88 name: name.into(),
89 domain_size,
90 });
91 idx
92 }
93
94 pub fn add_factor(&mut self, name: impl Into<String>, variable_indices: Vec<usize>) {
96 self.factors.push(VizFactorNode {
97 name: name.into(),
98 variable_indices,
99 });
100 }
101
102 pub fn variable_count(&self) -> usize {
104 self.variables.len()
105 }
106
107 pub fn factor_count(&self) -> usize {
109 self.factors.len()
110 }
111
112 pub fn edge_count(&self) -> usize {
114 self.factors.iter().map(|f| f.variable_indices.len()).sum()
115 }
116}
117
118#[derive(Debug, Clone, Default, Serialize, Deserialize)]
124pub struct FactorGraphStats {
125 pub variable_count: usize,
127 pub factor_count: usize,
129 pub edge_count: usize,
131 pub max_factor_arity: usize,
133 pub avg_factor_arity: f64,
135 pub max_variable_degree: usize,
137 pub avg_variable_degree: f64,
139 pub is_tree: bool,
141 pub treewidth_upper_bound: usize,
143}
144
145impl FactorGraphStats {
146 pub fn compute(model: &FactorGraphModel) -> Self {
148 let variable_count = model.variable_count();
149 let factor_count = model.factor_count();
150 let edge_count = model.edge_count();
151
152 let max_factor_arity = model
153 .factors
154 .iter()
155 .map(|f| f.variable_indices.len())
156 .max()
157 .unwrap_or(0);
158
159 let avg_factor_arity = if factor_count > 0 {
160 edge_count as f64 / factor_count as f64
161 } else {
162 0.0
163 };
164
165 let mut var_degrees = vec![0usize; variable_count];
167 for factor in &model.factors {
168 for &vi in &factor.variable_indices {
169 if vi < variable_count {
170 var_degrees[vi] += 1;
171 }
172 }
173 }
174
175 let max_variable_degree = var_degrees.iter().copied().max().unwrap_or(0);
176 let avg_variable_degree = if variable_count > 0 {
177 var_degrees.iter().sum::<usize>() as f64 / variable_count as f64
178 } else {
179 0.0
180 };
181
182 let total_nodes = variable_count + factor_count;
187 let is_tree = total_nodes > 0 && edge_count + 1 == total_nodes;
188
189 let treewidth_upper_bound = if max_factor_arity > 0 {
190 max_factor_arity - 1
191 } else {
192 0
193 };
194
195 Self {
196 variable_count,
197 factor_count,
198 edge_count,
199 max_factor_arity,
200 avg_factor_arity,
201 max_variable_degree,
202 avg_variable_degree,
203 is_tree,
204 treewidth_upper_bound,
205 }
206 }
207
208 pub fn summary(&self) -> String {
210 format!(
211 "{} vars, {} factors, {} edges, treewidth\u{2264}{}{}",
212 self.variable_count,
213 self.factor_count,
214 self.edge_count,
215 self.treewidth_upper_bound,
216 if self.is_tree { " (tree)" } else { "" }
217 )
218 }
219}
220
221pub fn render_ascii(model: &FactorGraphModel) -> String {
227 let mut out = String::new();
228
229 let _ = writeln!(out, "Factor Graph:");
230
231 let var_descs: Vec<String> = model
233 .variables
234 .iter()
235 .map(|v| format!("{}({})", v.name, v.domain_size))
236 .collect();
237 let _ = writeln!(
238 out,
239 " Variables ({}): {}",
240 model.variable_count(),
241 var_descs.join(", ")
242 );
243
244 let fac_descs: Vec<String> = model
246 .factors
247 .iter()
248 .map(|f| format!("{}({})", f.name, f.variable_indices.len()))
249 .collect();
250 let _ = writeln!(
251 out,
252 " Factors ({}): {}",
253 model.factor_count(),
254 fac_descs.join(", ")
255 );
256
257 let _ = writeln!(out, " Connections:");
259 for factor in &model.factors {
260 let var_names: Vec<&str> = factor
261 .variable_indices
262 .iter()
263 .filter_map(|&i| model.variables.get(i).map(|v| v.name.as_str()))
264 .collect();
265 let _ = writeln!(
266 out,
267 " {} \u{2500}\u{2500} {}",
268 factor.name,
269 var_names.join(", ")
270 );
271 }
272
273 out
274}
275
276pub fn render_dot(model: &FactorGraphModel) -> String {
281 let mut dot = String::new();
282
283 let _ = writeln!(dot, "graph FactorGraph {{");
284 let _ = writeln!(dot, " rankdir=LR;");
285
286 for (i, var) in model.variables.iter().enumerate() {
288 let _ = writeln!(dot, " v{} [label=\"{}\", shape=circle];", i, var.name);
289 }
290
291 for (i, factor) in model.factors.iter().enumerate() {
293 let _ = writeln!(
294 dot,
295 " f{} [label=\"{}\", shape=square, style=filled, fillcolor=lightgray];",
296 i, factor.name
297 );
298 for &vi in &factor.variable_indices {
299 let _ = writeln!(dot, " f{} -- v{};", i, vi);
300 }
301 }
302
303 let _ = writeln!(dot, "}}");
304 dot
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a model over three binary variables `A`, `B`, `C` (indices
    /// 0, 1, 2) with one pairwise factor per `(name, lhs, rhs)` entry.
    fn three_binary_vars_with(pairs: &[(&str, usize, usize)]) -> FactorGraphModel {
        let mut model = FactorGraphModel::new();
        for &label in &["A", "B", "C"] {
            model.add_variable(label, 2);
        }
        for &(name, lhs, rhs) in pairs {
            model.add_factor(name, vec![lhs, rhs]);
        }
        model
    }

    /// A - f1 - B - f2 - C: a tree-structured chain.
    fn chain_model() -> FactorGraphModel {
        three_binary_vars_with(&[("f1", 0, 1), ("f2", 1, 2)])
    }

    /// The chain plus a third factor joining A and C, which closes a loop.
    fn loopy_model() -> FactorGraphModel {
        three_binary_vars_with(&[("f1", 0, 1), ("f2", 1, 2), ("f3", 0, 2)])
    }

    // ---- FactorGraphModel ----

    #[test]
    fn test_model_new_empty() {
        let model = FactorGraphModel::new();
        assert_eq!(model.variable_count(), 0);
        assert_eq!(model.factor_count(), 0);
        assert_eq!(model.edge_count(), 0);
    }

    #[test]
    fn test_model_add_variable() {
        let mut model = FactorGraphModel::new();
        let slot = model.add_variable("X", 4);
        assert_eq!(slot, 0);
        assert_eq!(model.variable_count(), 1);
        assert_eq!(model.variables[0].domain_size, 4);
    }

    #[test]
    fn test_model_add_factor() {
        let mut model = FactorGraphModel::new();
        let only = model.add_variable("A", 2);
        model.add_factor("f1", vec![only]);
        assert_eq!(model.factor_count(), 1);
    }

    #[test]
    fn test_model_counts() {
        let model = chain_model();
        assert_eq!(model.variable_count(), 3);
        assert_eq!(model.factor_count(), 2);
        assert_eq!(model.edge_count(), 4);
    }

    // ---- FactorGraphStats ----

    #[test]
    fn test_stats_empty() {
        let stats = FactorGraphStats::compute(&FactorGraphModel::new());
        assert_eq!(stats.variable_count, 0);
        assert_eq!(stats.factor_count, 0);
        assert_eq!(stats.edge_count, 0);
        assert_eq!(stats.max_factor_arity, 0);
        assert!((stats.avg_factor_arity - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_stats_simple_chain() {
        let stats = FactorGraphStats::compute(&chain_model());
        assert_eq!(stats.variable_count, 3);
        assert_eq!(stats.factor_count, 2);
        assert_eq!(stats.edge_count, 4);
    }

    #[test]
    fn test_stats_max_factor_arity() {
        let mut model = FactorGraphModel::new();
        let scope: Vec<usize> = ["A", "B", "C"]
            .iter()
            .map(|&label| model.add_variable(label, 2))
            .collect();
        model.add_factor("big", scope);
        let stats = FactorGraphStats::compute(&model);
        assert_eq!(stats.max_factor_arity, 3);
    }

    #[test]
    fn test_stats_avg_factor_arity() {
        let stats = FactorGraphStats::compute(&chain_model());
        // Two factors with two variables each.
        assert!((stats.avg_factor_arity - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_stats_variable_degree() {
        let stats = FactorGraphStats::compute(&chain_model());
        // B sits in both f1 and f2.
        assert_eq!(stats.max_variable_degree, 2);
    }

    #[test]
    fn test_stats_is_tree_true() {
        let stats = FactorGraphStats::compute(&chain_model());
        assert!(stats.is_tree);
    }

    #[test]
    fn test_stats_is_tree_false() {
        let stats = FactorGraphStats::compute(&loopy_model());
        assert!(!stats.is_tree);
    }

    #[test]
    fn test_stats_treewidth() {
        let stats = FactorGraphStats::compute(&chain_model());
        assert_eq!(stats.treewidth_upper_bound, 1);
    }

    #[test]
    fn test_stats_summary() {
        let text = FactorGraphStats::compute(&chain_model()).summary();
        assert!(text.contains("vars"));
        assert!(text.contains("factors"));
    }

    // ---- render_ascii ----

    #[test]
    fn test_render_ascii_header() {
        let text = render_ascii(&chain_model());
        assert!(text.contains("Factor Graph:"));
    }

    #[test]
    fn test_render_ascii_variables() {
        let text = render_ascii(&chain_model());
        assert!(text.contains("A(2)"));
        assert!(text.contains("B(2)"));
        assert!(text.contains("C(2)"));
    }

    #[test]
    fn test_render_ascii_connections() {
        let text = render_ascii(&chain_model());
        assert!(text.contains("f1"));
        assert!(text.contains("A"));
        assert!(text.contains("B"));
    }

    // ---- render_dot ----

    #[test]
    fn test_render_dot_undirected() {
        let text = render_dot(&chain_model());
        assert!(text.starts_with("graph "));
        assert!(!text.contains("digraph"));
    }

    #[test]
    fn test_render_dot_nodes() {
        let text = render_dot(&chain_model());
        assert!(text.contains("v0"));
        assert!(text.contains("shape=circle"));
        assert!(text.contains("f0"));
        assert!(text.contains("shape=square"));
    }
}