Skip to main content

tensorlogic_quantrs_hooks/
factor_graph_viz.rs

1//! Factor graph visualization and structural analysis.
2//!
//! Renders factor graphs as plain text (with Unicode box-drawing connectors)
//! and DOT (Graphviz) format, and computes structural statistics (arity and
//! degree summaries, tree detection, treewidth bounds).
6
7use std::fmt::Write;
8
9use serde::{Deserialize, Serialize};
10
11use crate::graph::FactorGraph;
12
13// ---------------------------------------------------------------------------
14// Lightweight visualization model
15// ---------------------------------------------------------------------------
16
/// A lightweight factor graph representation for visualization.
///
/// This is intentionally decoupled from [`FactorGraph`] so that callers
/// can build ad-hoc models (e.g. from external data) and still use the
/// rendering / statistics helpers.
///
/// Two models compare equal when their variables and factors are equal
/// element-wise and appear in the same order.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FactorGraphModel {
    /// Variable nodes; a variable's position in this vector is its index.
    pub variables: Vec<VizVariableNode>,
    /// Factor nodes, in insertion order.
    pub factors: Vec<VizFactorNode>,
}

/// A variable node inside a [`FactorGraphModel`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VizVariableNode {
    /// Human-readable name.
    pub name: String,
    /// Number of values the variable can take.
    pub domain_size: usize,
}

/// A factor node inside a [`FactorGraphModel`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VizFactorNode {
    /// Human-readable name.
    pub name: String,
    /// Indices into [`FactorGraphModel::variables`] that this factor touches,
    /// in scope order. Because the field is public, nothing guarantees the
    /// indices are in range or distinct; consumers must tolerate bad entries.
    pub variable_indices: Vec<usize>,
}
47
48impl FactorGraphModel {
49    /// Create a new, empty model.
50    pub fn new() -> Self {
51        Self::default()
52    }
53
54    /// Build a [`FactorGraphModel`] from an existing [`FactorGraph`].
55    ///
56    /// Variable ordering is arbitrary (HashMap iteration order).
57    pub fn from_factor_graph(fg: &FactorGraph) -> Self {
58        // Collect variables into a stable ordering.
59        let mut var_names: Vec<String> = fg.variable_names().cloned().collect();
60        var_names.sort();
61
62        let mut name_to_idx: std::collections::HashMap<String, usize> =
63            std::collections::HashMap::new();
64
65        let mut model = Self::new();
66        for name in &var_names {
67            let card = fg.get_variable(name).map(|v| v.cardinality).unwrap_or(2);
68            let idx = model.add_variable(name.clone(), card);
69            name_to_idx.insert(name.clone(), idx);
70        }
71
72        for factor in fg.factors() {
73            let indices: Vec<usize> = factor
74                .variables
75                .iter()
76                .filter_map(|v| name_to_idx.get(v).copied())
77                .collect();
78            model.add_factor(factor.name.clone(), indices);
79        }
80
81        model
82    }
83
84    /// Add a variable node. Returns the index of the newly added variable.
85    pub fn add_variable(&mut self, name: impl Into<String>, domain_size: usize) -> usize {
86        let idx = self.variables.len();
87        self.variables.push(VizVariableNode {
88            name: name.into(),
89            domain_size,
90        });
91        idx
92    }
93
94    /// Add a factor node connecting the given variable indices.
95    pub fn add_factor(&mut self, name: impl Into<String>, variable_indices: Vec<usize>) {
96        self.factors.push(VizFactorNode {
97            name: name.into(),
98            variable_indices,
99        });
100    }
101
102    /// Number of variable nodes.
103    pub fn variable_count(&self) -> usize {
104        self.variables.len()
105    }
106
107    /// Number of factor nodes.
108    pub fn factor_count(&self) -> usize {
109        self.factors.len()
110    }
111
112    /// Total number of edges (factor-variable connections).
113    pub fn edge_count(&self) -> usize {
114        self.factors.iter().map(|f| f.variable_indices.len()).sum()
115    }
116}
117
118// ---------------------------------------------------------------------------
119// Statistics
120// ---------------------------------------------------------------------------
121
122/// Structural statistics for a factor graph.
123#[derive(Debug, Clone, Default, Serialize, Deserialize)]
124pub struct FactorGraphStats {
125    /// Number of variable nodes.
126    pub variable_count: usize,
127    /// Number of factor nodes.
128    pub factor_count: usize,
129    /// Total edges (variable-factor connections).
130    pub edge_count: usize,
131    /// Maximum number of variables any single factor connects.
132    pub max_factor_arity: usize,
133    /// Average factor arity.
134    pub avg_factor_arity: f64,
135    /// Maximum degree of any variable (number of factors it participates in).
136    pub max_variable_degree: usize,
137    /// Average variable degree.
138    pub avg_variable_degree: f64,
139    /// Whether the factor graph forms a tree (no loops).
140    pub is_tree: bool,
141    /// Upper bound on treewidth (max factor arity - 1).
142    pub treewidth_upper_bound: usize,
143}
144
145impl FactorGraphStats {
146    /// Compute statistics from a [`FactorGraphModel`].
147    pub fn compute(model: &FactorGraphModel) -> Self {
148        let variable_count = model.variable_count();
149        let factor_count = model.factor_count();
150        let edge_count = model.edge_count();
151
152        let max_factor_arity = model
153            .factors
154            .iter()
155            .map(|f| f.variable_indices.len())
156            .max()
157            .unwrap_or(0);
158
159        let avg_factor_arity = if factor_count > 0 {
160            edge_count as f64 / factor_count as f64
161        } else {
162            0.0
163        };
164
165        // Variable degree = number of factors connected to it.
166        let mut var_degrees = vec![0usize; variable_count];
167        for factor in &model.factors {
168            for &vi in &factor.variable_indices {
169                if vi < variable_count {
170                    var_degrees[vi] += 1;
171                }
172            }
173        }
174
175        let max_variable_degree = var_degrees.iter().copied().max().unwrap_or(0);
176        let avg_variable_degree = if variable_count > 0 {
177            var_degrees.iter().sum::<usize>() as f64 / variable_count as f64
178        } else {
179            0.0
180        };
181
182        // Tree check: a bipartite factor graph is a tree when
183        // |edges| == |variable nodes| + |factor nodes| - 1 and the graph is
184        // connected. We use the simpler edge-count heuristic here (sufficient
185        // for the upper-bound use-case).
186        let total_nodes = variable_count + factor_count;
187        let is_tree = total_nodes > 0 && edge_count + 1 == total_nodes;
188
189        let treewidth_upper_bound = if max_factor_arity > 0 {
190            max_factor_arity - 1
191        } else {
192            0
193        };
194
195        Self {
196            variable_count,
197            factor_count,
198            edge_count,
199            max_factor_arity,
200            avg_factor_arity,
201            max_variable_degree,
202            avg_variable_degree,
203            is_tree,
204            treewidth_upper_bound,
205        }
206    }
207
208    /// One-line summary string.
209    pub fn summary(&self) -> String {
210        format!(
211            "{} vars, {} factors, {} edges, treewidth\u{2264}{}{}",
212            self.variable_count,
213            self.factor_count,
214            self.edge_count,
215            self.treewidth_upper_bound,
216            if self.is_tree { " (tree)" } else { "" }
217        )
218    }
219}
220
221// ---------------------------------------------------------------------------
222// Rendering helpers
223// ---------------------------------------------------------------------------
224
225/// Render a [`FactorGraphModel`] as human-readable ASCII text.
226pub fn render_ascii(model: &FactorGraphModel) -> String {
227    let mut out = String::new();
228
229    let _ = writeln!(out, "Factor Graph:");
230
231    // Variables line
232    let var_descs: Vec<String> = model
233        .variables
234        .iter()
235        .map(|v| format!("{}({})", v.name, v.domain_size))
236        .collect();
237    let _ = writeln!(
238        out,
239        "  Variables ({}): {}",
240        model.variable_count(),
241        var_descs.join(", ")
242    );
243
244    // Factors line
245    let fac_descs: Vec<String> = model
246        .factors
247        .iter()
248        .map(|f| format!("{}({})", f.name, f.variable_indices.len()))
249        .collect();
250    let _ = writeln!(
251        out,
252        "  Factors ({}):  {}",
253        model.factor_count(),
254        fac_descs.join(", ")
255    );
256
257    // Connections
258    let _ = writeln!(out, "  Connections:");
259    for factor in &model.factors {
260        let var_names: Vec<&str> = factor
261            .variable_indices
262            .iter()
263            .filter_map(|&i| model.variables.get(i).map(|v| v.name.as_str()))
264            .collect();
265        let _ = writeln!(
266            out,
267            "    {} \u{2500}\u{2500} {}",
268            factor.name,
269            var_names.join(", ")
270        );
271    }
272
273    out
274}
275
276/// Render a [`FactorGraphModel`] as a DOT (Graphviz) graph.
277///
278/// The output uses the undirected `graph` keyword because factor graphs
279/// are inherently undirected.
280pub fn render_dot(model: &FactorGraphModel) -> String {
281    let mut dot = String::new();
282
283    let _ = writeln!(dot, "graph FactorGraph {{");
284    let _ = writeln!(dot, "  rankdir=LR;");
285
286    // Variable nodes as circles.
287    for (i, var) in model.variables.iter().enumerate() {
288        let _ = writeln!(dot, "  v{} [label=\"{}\", shape=circle];", i, var.name);
289    }
290
291    // Factor nodes as filled squares, with edges.
292    for (i, factor) in model.factors.iter().enumerate() {
293        let _ = writeln!(
294            dot,
295            "  f{} [label=\"{}\", shape=square, style=filled, fillcolor=lightgray];",
296            i, factor.name
297        );
298        for &vi in &factor.variable_indices {
299            let _ = writeln!(dot, "  f{} -- v{};", i, vi);
300        }
301    }
302
303    let _ = writeln!(dot, "}}");
304    dot
305}
306
307// ---------------------------------------------------------------------------
308// Tests
309// ---------------------------------------------------------------------------
310
#[cfg(test)]
mod tests {
    use super::*;

    // -- helpers --

    /// A-f1-B-f2-C chain (3 vars, 2 binary factors, 4 edges).
    fn chain_model() -> FactorGraphModel {
        let mut m = FactorGraphModel::new();
        let a = m.add_variable("A", 2);
        let b = m.add_variable("B", 2);
        let c = m.add_variable("C", 2);
        m.add_factor("f1", vec![a, b]);
        m.add_factor("f2", vec![b, c]);
        m
    }

    /// A loopy model: the cycle A-f1-B-f2-C-f3-A built from three pairwise
    /// factors (3 vars, 3 factors, 6 edges).
    fn loopy_model() -> FactorGraphModel {
        let mut m = FactorGraphModel::new();
        let a = m.add_variable("A", 2);
        let b = m.add_variable("B", 2);
        let c = m.add_variable("C", 2);
        m.add_factor("f1", vec![a, b]);
        m.add_factor("f2", vec![b, c]);
        m.add_factor("f3", vec![a, c]);
        m
    }

    // -- FactorGraphModel basics --

    #[test]
    fn test_model_new_empty() {
        let m = FactorGraphModel::new();
        assert_eq!(m.variable_count(), 0);
        assert_eq!(m.factor_count(), 0);
        assert_eq!(m.edge_count(), 0);
    }

    #[test]
    fn test_model_add_variable() {
        let mut m = FactorGraphModel::new();
        let idx = m.add_variable("X", 4);
        assert_eq!(idx, 0);
        assert_eq!(m.variable_count(), 1);
        assert_eq!(m.variables[0].domain_size, 4);
    }

    #[test]
    fn test_model_add_factor() {
        let mut m = FactorGraphModel::new();
        let a = m.add_variable("A", 2);
        m.add_factor("f1", vec![a]);
        assert_eq!(m.factor_count(), 1);
        assert_eq!(m.factors[0].variable_indices, vec![a]);
    }

    #[test]
    fn test_model_counts() {
        let m = chain_model();
        assert_eq!(m.variable_count(), 3);
        assert_eq!(m.factor_count(), 2);
        assert_eq!(m.edge_count(), 4);
    }

    // -- FactorGraphStats --

    #[test]
    fn test_stats_empty() {
        let m = FactorGraphModel::new();
        let s = FactorGraphStats::compute(&m);
        assert_eq!(s.variable_count, 0);
        assert_eq!(s.factor_count, 0);
        assert_eq!(s.edge_count, 0);
        assert_eq!(s.max_factor_arity, 0);
        assert!((s.avg_factor_arity - 0.0).abs() < f64::EPSILON);
        assert!(!s.is_tree);
        assert_eq!(s.treewidth_upper_bound, 0);
    }

    #[test]
    fn test_stats_simple_chain() {
        let s = FactorGraphStats::compute(&chain_model());
        assert_eq!(s.variable_count, 3);
        assert_eq!(s.factor_count, 2);
        assert_eq!(s.edge_count, 4);
    }

    #[test]
    fn test_stats_max_factor_arity() {
        let mut m = FactorGraphModel::new();
        let a = m.add_variable("A", 2);
        let b = m.add_variable("B", 2);
        let c = m.add_variable("C", 2);
        m.add_factor("big", vec![a, b, c]);
        let s = FactorGraphStats::compute(&m);
        assert_eq!(s.max_factor_arity, 3);
    }

    #[test]
    fn test_stats_avg_factor_arity() {
        // chain: 4 edges / 2 factors = 2.0
        let s = FactorGraphStats::compute(&chain_model());
        assert!((s.avg_factor_arity - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_stats_variable_degree() {
        // In chain, B appears in f1 and f2 => degree 2.
        let s = FactorGraphStats::compute(&chain_model());
        assert_eq!(s.max_variable_degree, 2);
    }

    #[test]
    fn test_stats_is_tree_true() {
        // chain: 5 nodes (3 var + 2 factor), 4 edges => tree.
        let s = FactorGraphStats::compute(&chain_model());
        assert!(s.is_tree);
    }

    #[test]
    fn test_stats_is_tree_false() {
        // loopy: 6 nodes, 6 edges => not a tree.
        let s = FactorGraphStats::compute(&loopy_model());
        assert!(!s.is_tree);
    }

    #[test]
    fn test_stats_treewidth() {
        let s = FactorGraphStats::compute(&chain_model());
        // max arity = 2 and the chain is a tree, so the bound is 1
        assert_eq!(s.treewidth_upper_bound, 1);
    }

    #[test]
    fn test_stats_summary() {
        let s = FactorGraphStats::compute(&chain_model());
        let summary = s.summary();
        assert!(summary.contains("vars"));
        assert!(summary.contains("factors"));
    }

    // -- Rendering --

    #[test]
    fn test_render_ascii_header() {
        let out = render_ascii(&chain_model());
        assert!(out.contains("Factor Graph:"));
    }

    #[test]
    fn test_render_ascii_variables() {
        let out = render_ascii(&chain_model());
        assert!(out.contains("A(2)"));
        assert!(out.contains("B(2)"));
        assert!(out.contains("C(2)"));
    }

    #[test]
    fn test_render_ascii_connections() {
        let out = render_ascii(&chain_model());
        // Each factor is listed with exactly the variables in its scope.
        assert!(out.contains("f1 \u{2500}\u{2500} A, B"));
        assert!(out.contains("f2 \u{2500}\u{2500} B, C"));
    }

    #[test]
    fn test_render_dot_undirected() {
        let dot = render_dot(&chain_model());
        // Must use undirected "graph", not "digraph".
        assert!(dot.starts_with("graph "));
        assert!(!dot.contains("digraph"));
    }

    #[test]
    fn test_render_dot_nodes() {
        let dot = render_dot(&chain_model());
        // Variable nodes
        assert!(dot.contains("v0 [label=\"A\", shape=circle];"));
        assert!(dot.contains("shape=circle"));
        // Factor nodes
        assert!(dot.contains("f0"));
        assert!(dot.contains("shape=square"));
    }

    #[test]
    fn test_render_dot_edges() {
        let dot = render_dot(&chain_model());
        // f1 (index 0) touches A, B; f2 (index 1) touches B, C.
        assert!(dot.contains("f0 -- v0;"));
        assert!(dot.contains("f0 -- v1;"));
        assert!(dot.contains("f1 -- v1;"));
        assert!(dot.contains("f1 -- v2;"));
        assert!(!dot.contains("f0 -- v2;"));
    }
}
489        // Factor nodes
490        assert!(dot.contains("f0"));
491        assert!(dot.contains("shape=square"));
492    }
493}