Skip to main content

tensorlogic_quantrs_hooks/
junction_tree.rs

1//! Junction tree algorithm for exact inference in probabilistic graphical models.
2//!
//! The junction tree algorithm (also known as the clique tree algorithm) is an exact inference
//! algorithm that works by:
//! 1. Building the interaction graph (variables that share a factor are connected) and
//!    triangulating it so that it becomes chordal
//! 2. Extracting its maximal cliques (maximal sets of pairwise-connected variables)
//! 3. Building a junction tree whose nodes are those cliques
//! 4. Passing messages between cliques to compute exact marginals
9//!
10//! # Algorithm Overview
11//!
12//! ```text
13//! Factor Graph → Moralize → Triangulate → Find Cliques → Build Tree → Calibrate
14//!       ↓            ↓           ↓             ↓            ↓           ↓
15//!    Variables   Undirected   Chordal     Maximal      Junction    Marginals
16//!                 Graph       Graph       Cliques       Tree
17//! ```
18//!
19//! # Complexity
20//!
//! - Time: O(n × d^(w+1)), where n is the number of cliques, d the largest variable
//!   cardinality and w the treewidth
//! - Space: O(n × d^(w+1)) for the clique potentials
//! - Exact for any graph structure (unlike loopy BP)
24//!
25//! # References
26//!
27//! - Koller & Friedman, "Probabilistic Graphical Models", Chapter 9
28//! - Lauritzen & Spiegelhalter, "Local Computations with Probabilities on
29//!   Graphical Structures and their Application to Expert Systems" (1988)
30
31use crate::error::{PgmError, Result};
32use crate::factor::Factor;
33use crate::graph::FactorGraph;
34use scirs2_core::ndarray::ArrayD;
35use std::collections::{HashMap, HashSet, VecDeque};
36
37/// A clique in the junction tree (a maximal set of connected variables).
38#[derive(Debug, Clone)]
39pub struct Clique {
40    /// Unique identifier for the clique
41    pub id: usize,
42    /// Variables in this clique
43    pub variables: HashSet<String>,
44    /// Potential function (product of all factors involving these variables)
45    pub potential: Option<Factor>,
46}
47
48impl Clique {
49    /// Create a new clique with the given variables.
50    pub fn new(id: usize, variables: HashSet<String>) -> Self {
51        Self {
52            id,
53            variables,
54            potential: None,
55        }
56    }
57
58    /// Check if this clique contains all variables in the given set.
59    pub fn contains_all(&self, vars: &HashSet<String>) -> bool {
60        vars.is_subset(&self.variables)
61    }
62
63    /// Get the intersection of variables with another clique.
64    pub fn intersection(&self, other: &Clique) -> HashSet<String> {
65        self.variables
66            .intersection(&other.variables)
67            .cloned()
68            .collect()
69    }
70}
71
72/// A separator between two cliques (their shared variables).
73#[derive(Debug, Clone)]
74pub struct Separator {
75    /// Variables in the separator
76    pub variables: HashSet<String>,
77    /// Message potential
78    pub potential: Option<Factor>,
79}
80
81impl Separator {
82    /// Create a new separator from the intersection of two cliques.
83    pub fn from_cliques(c1: &Clique, c2: &Clique) -> Self {
84        Self {
85            variables: c1.intersection(c2),
86            potential: None,
87        }
88    }
89}
90
91/// Edge in the junction tree connecting two cliques.
92#[derive(Debug, Clone)]
93pub struct JunctionTreeEdge {
94    /// ID of the first clique
95    pub clique1: usize,
96    /// ID of the second clique
97    pub clique2: usize,
98    /// Separator (shared variables)
99    pub separator: Separator,
100    /// Message from clique1 to clique2
101    pub message_1_to_2: Option<Factor>,
102    /// Message from clique2 to clique1
103    pub message_2_to_1: Option<Factor>,
104}
105
106impl JunctionTreeEdge {
107    /// Create a new edge between two cliques.
108    pub fn new(clique1: usize, clique2: usize, separator: Separator) -> Self {
109        Self {
110            clique1,
111            clique2,
112            separator,
113            message_1_to_2: None,
114            message_2_to_1: None,
115        }
116    }
117}
118
119/// Junction tree structure for exact inference.
120#[derive(Debug, Clone)]
121pub struct JunctionTree {
122    /// Cliques in the tree
123    pub cliques: Vec<Clique>,
124    /// Edges connecting cliques
125    pub edges: Vec<JunctionTreeEdge>,
126    /// Variable to clique mapping (for query efficiency)
127    pub var_to_cliques: HashMap<String, Vec<usize>>,
128    /// Whether the tree has been calibrated
129    pub calibrated: bool,
130}
131
132impl JunctionTree {
133    /// Create a new junction tree.
134    pub fn new() -> Self {
135        Self {
136            cliques: Vec::new(),
137            edges: Vec::new(),
138            var_to_cliques: HashMap::new(),
139            calibrated: false,
140        }
141    }
142
143    /// Build a junction tree from a factor graph.
144    ///
145    /// This implements the complete junction tree construction algorithm:
146    /// 1. Moralize the graph (if directed)
147    /// 2. Triangulate to create a chordal graph
148    /// 3. Identify maximal cliques
149    /// 4. Build a junction tree satisfying the running intersection property
150    pub fn from_factor_graph(graph: &FactorGraph) -> Result<Self> {
151        // Step 1: Extract interaction graph (moralized graph)
152        let interaction_graph = Self::build_interaction_graph(graph)?;
153
154        // Step 2: Triangulate the graph using min-fill heuristic
155        let triangulated = Self::triangulate(&interaction_graph)?;
156
157        // Step 3: Find maximal cliques
158        let cliques = Self::find_maximal_cliques(&triangulated)?;
159
160        // Step 4: Build junction tree from cliques
161        let mut tree = Self::build_tree_from_cliques(cliques)?;
162
163        // Step 5: Initialize clique potentials
164        tree.initialize_potentials(graph)?;
165
166        Ok(tree)
167    }
168
169    /// Build the interaction graph (moralized graph).
170    ///
171    /// For factor graphs, the interaction graph is an undirected graph where:
172    /// - Nodes are variables
173    /// - Edges connect variables that appear together in some factor
174    fn build_interaction_graph(graph: &FactorGraph) -> Result<HashMap<String, HashSet<String>>> {
175        let mut adjacency: HashMap<String, HashSet<String>> = HashMap::new();
176
177        // Initialize all variables
178        for var_name in graph.variable_names() {
179            adjacency.insert(var_name.clone(), HashSet::new());
180        }
181
182        // Add edges for each factor
183        for factor in graph.factors() {
184            let vars = &factor.variables;
185            // Connect all pairs of variables in the factor
186            for i in 0..vars.len() {
187                for j in (i + 1)..vars.len() {
188                    let v1 = &vars[i];
189                    let v2 = &vars[j];
190
191                    adjacency.entry(v1.clone()).or_default().insert(v2.clone());
192                    adjacency.entry(v2.clone()).or_default().insert(v1.clone());
193                }
194            }
195        }
196
197        Ok(adjacency)
198    }
199
200    /// Triangulate the graph to make it chordal.
201    ///
202    /// Uses the min-fill heuristic for variable elimination ordering.
203    /// A chordal graph has the property that every cycle of length ≥4 has a chord.
204    fn triangulate(
205        graph: &HashMap<String, HashSet<String>>,
206    ) -> Result<HashMap<String, HashSet<String>>> {
207        let mut triangulated = graph.clone();
208        let mut remaining: HashSet<String> = graph.keys().cloned().collect();
209
210        while !remaining.is_empty() {
211            // Find variable with minimum fill-in edges
212            let var = Self::find_min_fill_variable(&triangulated, &remaining)?;
213
214            // Get neighbors of the variable
215            let neighbors: Vec<String> = triangulated
216                .get(&var)
217                .ok_or_else(|| PgmError::InvalidGraph("Variable not found".to_string()))?
218                .intersection(&remaining)
219                .cloned()
220                .collect();
221
222            // Add fill-in edges (connect all pairs of neighbors)
223            for i in 0..neighbors.len() {
224                for j in (i + 1)..neighbors.len() {
225                    let n1 = &neighbors[i];
226                    let n2 = &neighbors[j];
227
228                    triangulated
229                        .entry(n1.clone())
230                        .or_default()
231                        .insert(n2.clone());
232                    triangulated
233                        .entry(n2.clone())
234                        .or_default()
235                        .insert(n1.clone());
236                }
237            }
238
239            // Remove variable from remaining set
240            remaining.remove(&var);
241        }
242
243        Ok(triangulated)
244    }
245
246    /// Find the variable with minimum fill-in (min-fill heuristic).
247    fn find_min_fill_variable(
248        graph: &HashMap<String, HashSet<String>>,
249        remaining: &HashSet<String>,
250    ) -> Result<String> {
251        let mut min_fill = usize::MAX;
252        let mut best_var = None;
253
254        for var in remaining {
255            let neighbors: Vec<String> = graph
256                .get(var)
257                .ok_or_else(|| PgmError::InvalidGraph("Variable not found".to_string()))?
258                .intersection(remaining)
259                .cloned()
260                .collect();
261
262            // Count how many edges need to be added
263            let mut fill_count = 0;
264            for i in 0..neighbors.len() {
265                for j in (i + 1)..neighbors.len() {
266                    let n1 = &neighbors[i];
267                    let n2 = &neighbors[j];
268                    if !graph.get(n1).unwrap().contains(n2) {
269                        fill_count += 1;
270                    }
271                }
272            }
273
274            if fill_count < min_fill {
275                min_fill = fill_count;
276                best_var = Some(var.clone());
277            }
278        }
279
280        best_var.ok_or_else(|| PgmError::InvalidGraph("No variable found".to_string()))
281    }
282
283    /// Find maximal cliques in a triangulated graph.
284    ///
285    /// Uses a greedy algorithm to identify maximal cliques.
286    fn find_maximal_cliques(
287        graph: &HashMap<String, HashSet<String>>,
288    ) -> Result<Vec<HashSet<String>>> {
289        let mut cliques = Vec::new();
290        let mut visited: HashSet<String> = HashSet::new();
291
292        // Start from each variable and grow cliques
293        for var in graph.keys() {
294            if visited.contains(var) {
295                continue;
296            }
297
298            let mut clique: HashSet<String> = HashSet::new();
299            clique.insert(var.clone());
300
301            // Add neighbors that form a clique
302            for neighbor in graph.get(var).unwrap() {
303                // Check if neighbor is connected to all current clique members
304                let is_fully_connected = clique
305                    .iter()
306                    .all(|c| c == neighbor || graph.get(neighbor).unwrap().contains(c));
307
308                if is_fully_connected {
309                    clique.insert(neighbor.clone());
310                }
311            }
312
313            // Check if this is a maximal clique
314            let is_maximal = !cliques
315                .iter()
316                .any(|c: &HashSet<String>| c.is_superset(&clique));
317
318            if is_maximal {
319                // Remove non-maximal cliques that are subsets of this one
320                cliques.retain(|c| !clique.is_superset(c));
321                cliques.push(clique.clone());
322            }
323
324            visited.insert(var.clone());
325        }
326
327        // Ensure we have at least one clique
328        if cliques.is_empty() && !graph.is_empty() {
329            // Create a clique with all variables (fallback)
330            let all_vars: HashSet<String> = graph.keys().cloned().collect();
331            cliques.push(all_vars);
332        }
333
334        Ok(cliques)
335    }
336
337    /// Build a junction tree from maximal cliques.
338    ///
339    /// Uses a maximum spanning tree algorithm based on separator size.
340    fn build_tree_from_cliques(clique_sets: Vec<HashSet<String>>) -> Result<Self> {
341        let mut tree = JunctionTree::new();
342
343        // Create clique nodes
344        for (id, vars) in clique_sets.into_iter().enumerate() {
345            let clique = Clique::new(id, vars.clone());
346
347            // Update variable to clique mapping
348            for var in &vars {
349                tree.var_to_cliques.entry(var.clone()).or_default().push(id);
350            }
351
352            tree.cliques.push(clique);
353        }
354
355        // Build maximum spanning tree based on separator size
356        if tree.cliques.len() > 1 {
357            tree.build_maximum_spanning_tree()?;
358        }
359
360        Ok(tree)
361    }
362
363    /// Build a maximum spanning tree connecting cliques.
364    ///
365    /// Uses Prim's algorithm with separator size as edge weight.
366    fn build_maximum_spanning_tree(&mut self) -> Result<()> {
367        let n = self.cliques.len();
368        if n == 0 {
369            return Ok(());
370        }
371
372        let mut in_tree = vec![false; n];
373        let mut edges_to_add: Vec<(usize, usize, usize)> = Vec::new();
374
375        // Start with clique 0
376        in_tree[0] = true;
377        let mut tree_size = 1;
378
379        while tree_size < n {
380            let mut best_edge = None;
381            let mut best_weight = 0;
382
383            // Find best edge to add
384            for i in 0..n {
385                if !in_tree[i] {
386                    continue;
387                }
388
389                for (j, &is_in_tree) in in_tree.iter().enumerate().take(n) {
390                    if is_in_tree {
391                        continue;
392                    }
393
394                    let separator = self.cliques[i].intersection(&self.cliques[j]);
395                    let weight = separator.len();
396
397                    if weight > best_weight {
398                        best_weight = weight;
399                        best_edge = Some((i, j, weight));
400                    }
401                }
402            }
403
404            if let Some((i, j, _)) = best_edge {
405                edges_to_add.push((i, j, best_weight));
406                in_tree[j] = true;
407                tree_size += 1;
408            } else {
409                break;
410            }
411        }
412
413        // Create edges
414        for (c1, c2, _) in edges_to_add {
415            let separator = Separator::from_cliques(&self.cliques[c1], &self.cliques[c2]);
416            let edge = JunctionTreeEdge::new(c1, c2, separator);
417            self.edges.push(edge);
418        }
419
420        Ok(())
421    }
422
423    /// Initialize clique potentials from the factor graph.
424    ///
425    /// Assigns each factor to a clique that contains all its variables.
426    fn initialize_potentials(&mut self, graph: &FactorGraph) -> Result<()> {
427        for factor in graph.factors() {
428            let factor_vars: HashSet<String> = factor.variables.iter().cloned().collect();
429
430            // Find a clique that contains all variables in this factor
431            let clique_idx = self
432                .cliques
433                .iter()
434                .position(|c| c.contains_all(&factor_vars))
435                .ok_or_else(|| {
436                    PgmError::InvalidGraph(format!(
437                        "No clique contains all variables for factor: {:?}",
438                        factor.name
439                    ))
440                })?;
441
442            let clique = &mut self.cliques[clique_idx];
443
444            // Multiply factor into clique potential
445            if let Some(ref mut potential) = clique.potential {
446                *potential = potential.product(factor)?;
447            } else {
448                clique.potential = Some(factor.clone());
449            }
450        }
451
452        // Initialize cliques without factors to uniform potentials
453        for clique in &mut self.cliques {
454            if clique.potential.is_none() {
455                // Create uniform potential
456                clique.potential = Some(Self::create_uniform_potential(&clique.variables, graph)?);
457            }
458        }
459
460        Ok(())
461    }
462
463    /// Create a uniform potential over a set of variables.
464    fn create_uniform_potential(
465        variables: &HashSet<String>,
466        graph: &FactorGraph,
467    ) -> Result<Factor> {
468        let var_vec: Vec<String> = variables.iter().cloned().collect();
469        let mut shape = Vec::new();
470
471        for var in &var_vec {
472            let cardinality = graph
473                .get_variable(var)
474                .ok_or_else(|| PgmError::InvalidGraph(format!("Variable {} not found", var)))?
475                .cardinality;
476            shape.push(cardinality);
477        }
478
479        let size: usize = shape.iter().product();
480        let values = vec![1.0; size];
481
482        let array = ArrayD::from_shape_vec(shape, values)
483            .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))?;
484
485        Factor::new("uniform".to_string(), var_vec, array)
486    }
487
488    /// Calibrate the junction tree by passing messages.
489    ///
490    /// This implements the message passing schedule for exact inference.
491    pub fn calibrate(&mut self) -> Result<()> {
492        if self.edges.is_empty() {
493            self.calibrated = true;
494            return Ok(());
495        }
496
497        // Collect evidence (inward pass) from leaves to root
498        let root = 0;
499        self.collect_evidence(root, None)?;
500
501        // Distribute evidence (outward pass) from root to leaves
502        self.distribute_evidence(root, None)?;
503
504        self.calibrated = true;
505        Ok(())
506    }
507
508    /// Collect evidence (inward pass).
509    fn collect_evidence(&mut self, current: usize, parent: Option<usize>) -> Result<()> {
510        // Find children (neighbors except parent)
511        let children: Vec<usize> = self.get_neighbors(current, parent);
512
513        // Recursively collect from children
514        for child in &children {
515            self.collect_evidence(*child, Some(current))?;
516        }
517
518        // Send message to parent if exists
519        if let Some(parent_idx) = parent {
520            self.send_message(current, parent_idx)?;
521        }
522
523        Ok(())
524    }
525
526    /// Distribute evidence (outward pass).
527    fn distribute_evidence(&mut self, current: usize, parent: Option<usize>) -> Result<()> {
528        // Find children (neighbors except parent)
529        let children: Vec<usize> = self.get_neighbors(current, parent);
530
531        // Send messages to all children
532        for child in &children {
533            self.send_message(current, *child)?;
534            self.distribute_evidence(*child, Some(current))?;
535        }
536
537        Ok(())
538    }
539
540    /// Get neighbors of a clique excluding the parent.
541    fn get_neighbors(&self, clique: usize, parent: Option<usize>) -> Vec<usize> {
542        let mut neighbors = Vec::new();
543
544        for edge in &self.edges {
545            if edge.clique1 == clique {
546                if parent.is_none() || parent.unwrap() != edge.clique2 {
547                    neighbors.push(edge.clique2);
548                }
549            } else if edge.clique2 == clique
550                && (parent.is_none() || parent.unwrap() != edge.clique1)
551            {
552                neighbors.push(edge.clique1);
553            }
554        }
555
556        neighbors
557    }
558
559    /// Send a message from one clique to another.
560    fn send_message(&mut self, from: usize, to: usize) -> Result<()> {
561        // Find the edge
562        let edge_idx = self
563            .edges
564            .iter()
565            .position(|e| {
566                (e.clique1 == from && e.clique2 == to) || (e.clique1 == to && e.clique2 == from)
567            })
568            .ok_or_else(|| PgmError::InvalidGraph("Edge not found".to_string()))?;
569
570        // Get separator variables
571        let separator_vars = self.edges[edge_idx].separator.variables.clone();
572
573        // Get clique potential
574        let clique_potential = self.cliques[from].potential.clone().ok_or_else(|| {
575            PgmError::InvalidGraph("Clique potential not initialized".to_string())
576        })?;
577
578        // Marginalize out variables not in separator
579        let mut message = clique_potential;
580        let all_vars: HashSet<String> = message.variables.iter().cloned().collect();
581        let vars_to_eliminate: Vec<String> =
582            all_vars.difference(&separator_vars).cloned().collect();
583
584        for var in vars_to_eliminate {
585            message = message.marginalize_out(&var)?;
586        }
587
588        // Store message
589        let edge = &mut self.edges[edge_idx];
590        if edge.clique1 == from {
591            edge.message_1_to_2 = Some(message);
592        } else {
593            edge.message_2_to_1 = Some(message);
594        }
595
596        Ok(())
597    }
598
599    /// Query marginal probability for a variable.
600    pub fn query_marginal(&self, variable: &str) -> Result<ArrayD<f64>> {
601        if !self.calibrated {
602            return Err(PgmError::InvalidGraph(
603                "Tree must be calibrated before querying".to_string(),
604            ));
605        }
606
607        // Find a clique containing this variable
608        let clique_indices = self
609            .var_to_cliques
610            .get(variable)
611            .ok_or_else(|| PgmError::InvalidGraph(format!("Variable {} not found", variable)))?;
612
613        if clique_indices.is_empty() {
614            return Err(PgmError::InvalidGraph(format!(
615                "No clique contains variable {}",
616                variable
617            )));
618        }
619
620        // Get belief from first clique
621        let clique = &self.cliques[clique_indices[0]];
622        let mut belief = clique.potential.clone().ok_or_else(|| {
623            PgmError::InvalidGraph("Clique potential not initialized".to_string())
624        })?;
625
626        // Marginalize out all variables except the query variable
627        let all_vars: HashSet<String> = belief.variables.iter().cloned().collect();
628        let mut target_set = HashSet::new();
629        target_set.insert(variable.to_string());
630        let vars_to_eliminate: Vec<String> = all_vars.difference(&target_set).cloned().collect();
631
632        for var in vars_to_eliminate {
633            belief = belief.marginalize_out(&var)?;
634        }
635
636        // Normalize
637        belief.normalize();
638
639        Ok(belief.values)
640    }
641
642    /// Query joint marginal over multiple variables.
643    pub fn query_joint_marginal(&self, variables: &[String]) -> Result<ArrayD<f64>> {
644        if !self.calibrated {
645            return Err(PgmError::InvalidGraph(
646                "Tree must be calibrated before querying".to_string(),
647            ));
648        }
649
650        let var_set: HashSet<String> = variables.iter().cloned().collect();
651
652        // Find a clique containing all these variables
653        let clique = self
654            .cliques
655            .iter()
656            .find(|c| c.contains_all(&var_set))
657            .ok_or_else(|| {
658                PgmError::InvalidGraph(format!("No clique contains all variables: {:?}", variables))
659            })?;
660
661        let mut belief = clique.potential.clone().ok_or_else(|| {
662            PgmError::InvalidGraph("Clique potential not initialized".to_string())
663        })?;
664
665        // Marginalize out variables not in query
666        let all_vars: HashSet<String> = belief.variables.iter().cloned().collect();
667        let vars_to_eliminate: Vec<String> = all_vars.difference(&var_set).cloned().collect();
668
669        for var in vars_to_eliminate {
670            belief = belief.marginalize_out(&var)?;
671        }
672
673        // Normalize
674        belief.normalize();
675
676        Ok(belief.values)
677    }
678
679    /// Get the treewidth of this junction tree.
680    ///
681    /// The treewidth is the size of the largest clique minus 1.
682    pub fn treewidth(&self) -> usize {
683        self.cliques
684            .iter()
685            .map(|c| c.variables.len())
686            .max()
687            .unwrap_or(0)
688            .saturating_sub(1)
689    }
690
691    /// Check if the junction tree satisfies the running intersection property.
692    ///
693    /// For every variable X, the set of cliques containing X forms a connected subtree.
694    pub fn verify_running_intersection_property(&self) -> bool {
695        for var in self.var_to_cliques.keys() {
696            let cliques_with_var = self.var_to_cliques.get(var).unwrap();
697
698            if cliques_with_var.len() <= 1 {
699                continue;
700            }
701
702            // Check if these cliques form a connected component
703            if !self.is_connected_subgraph(cliques_with_var) {
704                return false;
705            }
706        }
707
708        true
709    }
710
711    /// Check if a set of cliques forms a connected subgraph.
712    fn is_connected_subgraph(&self, cliques: &[usize]) -> bool {
713        if cliques.is_empty() {
714            return true;
715        }
716
717        let clique_set: HashSet<usize> = cliques.iter().copied().collect();
718        let mut visited = HashSet::new();
719        let mut queue = VecDeque::new();
720
721        // Start BFS from first clique
722        queue.push_back(cliques[0]);
723        visited.insert(cliques[0]);
724
725        while let Some(current) = queue.pop_front() {
726            for edge in &self.edges {
727                let neighbor = if edge.clique1 == current {
728                    Some(edge.clique2)
729                } else if edge.clique2 == current {
730                    Some(edge.clique1)
731                } else {
732                    None
733                };
734
735                if let Some(n) = neighbor {
736                    if clique_set.contains(&n) && !visited.contains(&n) {
737                        visited.insert(n);
738                        queue.push_back(n);
739                    }
740                }
741            }
742        }
743
744        visited.len() == cliques.len()
745    }
746}
747
748impl Default for JunctionTree {
749    fn default() -> Self {
750        Self::new()
751    }
752}
753
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::FactorGraph;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Collect string literals into an owned variable set.
    fn var_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// A factor graph with one binary variable per name.
    fn binary_graph(names: &[&str]) -> FactorGraph {
        let mut graph = FactorGraph::new();
        for name in names {
            graph.add_variable_with_card(name.to_string(), "Binary".to_string(), 2);
        }
        graph
    }

    /// A 2×2 factor over `(first, second)` whose table is given in row-major order.
    fn pairwise(name: &str, first: &str, second: &str, table: [f64; 4]) -> Factor {
        let values = Array::from_shape_vec(vec![2, 2], table.to_vec())
            .unwrap()
            .into_dyn();
        Factor::new(
            name.to_string(),
            vec![first.to_string(), second.to_string()],
            values,
        )
        .unwrap()
    }

    #[test]
    fn test_clique_creation() {
        let clique = Clique::new(0, var_set(&["x", "y"]));
        assert_eq!(clique.id, 0);
        assert_eq!(clique.variables.len(), 2);
    }

    #[test]
    fn test_clique_intersection() {
        let left = Clique::new(0, var_set(&["x", "y"]));
        let right = Clique::new(1, var_set(&["y", "z"]));

        let shared = left.intersection(&right);
        assert_eq!(shared.len(), 1);
        assert!(shared.contains("y"));
    }

    #[test]
    fn test_interaction_graph() {
        let mut graph = binary_graph(&["x", "y", "z"]);
        graph
            .add_factor(pairwise("P(x,y)", "x", "y", [0.1, 0.2, 0.3, 0.4]))
            .unwrap();
        graph
            .add_factor(pairwise("P(y,z)", "y", "z", [0.5, 0.1, 0.2, 0.2]))
            .unwrap();

        let adjacency = JunctionTree::build_interaction_graph(&graph).unwrap();

        // Both directions of each factor edge must be present.
        for (a, b) in [("x", "y"), ("y", "x"), ("y", "z"), ("z", "y")] {
            assert!(adjacency.get(a).unwrap().contains(b));
        }
    }

    #[test]
    fn test_junction_tree_construction() {
        let mut graph = binary_graph(&["x", "y"]);
        graph
            .add_factor(pairwise("P(x,y)", "x", "y", [0.3, 0.7, 0.4, 0.6]))
            .unwrap();

        let tree = JunctionTree::from_factor_graph(&graph).unwrap();

        assert!(!tree.cliques.is_empty());
        assert!(tree.verify_running_intersection_property());
    }

    #[test]
    fn test_junction_tree_calibration() {
        let mut graph = binary_graph(&["x", "y"]);
        graph
            .add_factor(pairwise("P(x,y)", "x", "y", [0.25, 0.25, 0.25, 0.25]))
            .unwrap();

        let mut tree = JunctionTree::from_factor_graph(&graph).unwrap();
        tree.calibrate().unwrap();

        assert!(tree.calibrated);
    }

    #[test]
    fn test_marginal_query() {
        let mut graph = binary_graph(&["x", "y"]);
        graph
            .add_factor(pairwise("P(x,y)", "x", "y", [0.1, 0.4, 0.2, 0.3]))
            .unwrap();

        let mut tree = JunctionTree::from_factor_graph(&graph).unwrap();
        tree.calibrate().unwrap();

        let marginal_x = tree.query_marginal("x").unwrap();

        // P(x=0) = 0.1 + 0.4 = 0.5, P(x=1) = 0.2 + 0.3 = 0.5
        assert_abs_diff_eq!(marginal_x[[0]], 0.5, epsilon = 1e-6);
        assert_abs_diff_eq!(marginal_x[[1]], 0.5, epsilon = 1e-6);
    }

    #[test]
    fn test_treewidth() {
        let mut graph = binary_graph(&["x", "y", "z"]);
        graph
            .add_factor(pairwise("P(x,y)", "x", "y", [0.3, 0.7, 0.4, 0.6]))
            .unwrap();
        graph
            .add_factor(pairwise("P(y,z)", "y", "z", [0.5, 0.5, 0.6, 0.4]))
            .unwrap();

        let tree = JunctionTree::from_factor_graph(&graph).unwrap();

        // A chain x - y - z never needs a clique larger than 3 variables.
        assert!(tree.treewidth() <= 2);
    }
}