Skip to main content

tensorlogic_quantrs_hooks/
junction_tree.rs

//! Junction tree algorithm for exact inference in probabilistic graphical models.
//!
//! The junction tree algorithm (also known as the clique tree algorithm) is an exact inference
//! algorithm that works by:
//! 1. Converting the factor graph into a chordal graph via triangulation
//! 2. Identifying the maximal cliques (maximal sets of pairwise-connected variables)
//! 3. Building a junction tree whose nodes are those cliques
//! 4. Passing messages between cliques to compute exact marginals
//!
//! # Algorithm Overview
//!
//! ```text
//! Factor Graph → Moralize → Triangulate → Find Cliques → Build Tree → Calibrate
//!       ↓            ↓           ↓             ↓            ↓           ↓
//!    Variables   Undirected   Chordal     Maximal      Junction    Marginals
//!                 Graph       Graph       Cliques       Tree
//! ```
//!
//! # Complexity
//!
//! With `n` cliques, treewidth `w`, and maximum variable cardinality `d`:
//!
//! - Time: O(n × d^(w+1))
//! - Space: O(n × d^(w+1)) for the clique potentials (separator messages are smaller)
//! - Exact for any graph structure (unlike loopy BP), but exponential in the treewidth
//!
//! # References
//!
//! - Koller & Friedman, "Probabilistic Graphical Models", Chapter 9
//! - Lauritzen & Spiegelhalter, "Local Computations with Probabilities on
//!   Graphical Structures and their Application to Expert Systems" (1988)

31use crate::error::{PgmError, Result};
32use crate::factor::Factor;
33use crate::graph::FactorGraph;
34use scirs2_core::ndarray::ArrayD;
35use std::collections::{HashMap, HashSet, VecDeque};
36
37/// A clique in the junction tree (a maximal set of connected variables).
38#[derive(Debug, Clone)]
39pub struct Clique {
40    /// Unique identifier for the clique
41    pub id: usize,
42    /// Variables in this clique
43    pub variables: HashSet<String>,
44    /// Potential function (product of all factors involving these variables)
45    pub potential: Option<Factor>,
46}
47
48impl Clique {
49    /// Create a new clique with the given variables.
50    pub fn new(id: usize, variables: HashSet<String>) -> Self {
51        Self {
52            id,
53            variables,
54            potential: None,
55        }
56    }
57
58    /// Check if this clique contains all variables in the given set.
59    pub fn contains_all(&self, vars: &HashSet<String>) -> bool {
60        vars.is_subset(&self.variables)
61    }
62
63    /// Get the intersection of variables with another clique.
64    pub fn intersection(&self, other: &Clique) -> HashSet<String> {
65        self.variables
66            .intersection(&other.variables)
67            .cloned()
68            .collect()
69    }
70}
71
72/// A separator between two cliques (their shared variables).
73#[derive(Debug, Clone)]
74pub struct Separator {
75    /// Variables in the separator
76    pub variables: HashSet<String>,
77    /// Message potential
78    pub potential: Option<Factor>,
79}
80
81impl Separator {
82    /// Create a new separator from the intersection of two cliques.
83    pub fn from_cliques(c1: &Clique, c2: &Clique) -> Self {
84        Self {
85            variables: c1.intersection(c2),
86            potential: None,
87        }
88    }
89}
90
91/// Edge in the junction tree connecting two cliques.
92#[derive(Debug, Clone)]
93pub struct JunctionTreeEdge {
94    /// ID of the first clique
95    pub clique1: usize,
96    /// ID of the second clique
97    pub clique2: usize,
98    /// Separator (shared variables)
99    pub separator: Separator,
100    /// Message from clique1 to clique2
101    pub message_1_to_2: Option<Factor>,
102    /// Message from clique2 to clique1
103    pub message_2_to_1: Option<Factor>,
104}
105
106impl JunctionTreeEdge {
107    /// Create a new edge between two cliques.
108    pub fn new(clique1: usize, clique2: usize, separator: Separator) -> Self {
109        Self {
110            clique1,
111            clique2,
112            separator,
113            message_1_to_2: None,
114            message_2_to_1: None,
115        }
116    }
117}
118
119/// Junction tree structure for exact inference.
120#[derive(Debug, Clone)]
121pub struct JunctionTree {
122    /// Cliques in the tree
123    pub cliques: Vec<Clique>,
124    /// Edges connecting cliques
125    pub edges: Vec<JunctionTreeEdge>,
126    /// Variable to clique mapping (for query efficiency)
127    pub var_to_cliques: HashMap<String, Vec<usize>>,
128    /// Whether the tree has been calibrated
129    pub calibrated: bool,
130}
131
132impl JunctionTree {
133    /// Create a new junction tree.
134    pub fn new() -> Self {
135        Self {
136            cliques: Vec::new(),
137            edges: Vec::new(),
138            var_to_cliques: HashMap::new(),
139            calibrated: false,
140        }
141    }
142
143    /// Build a junction tree from a factor graph.
144    ///
145    /// This implements the complete junction tree construction algorithm:
146    /// 1. Moralize the graph (if directed)
147    /// 2. Triangulate to create a chordal graph
148    /// 3. Identify maximal cliques
149    /// 4. Build a junction tree satisfying the running intersection property
150    pub fn from_factor_graph(graph: &FactorGraph) -> Result<Self> {
151        // Step 1: Extract interaction graph (moralized graph)
152        let interaction_graph = Self::build_interaction_graph(graph)?;
153
154        // Step 2: Triangulate the graph using min-fill heuristic
155        let triangulated = Self::triangulate(&interaction_graph)?;
156
157        // Step 3: Find maximal cliques
158        let cliques = Self::find_maximal_cliques(&triangulated)?;
159
160        // Step 4: Build junction tree from cliques
161        let mut tree = Self::build_tree_from_cliques(cliques)?;
162
163        // Step 5: Initialize clique potentials
164        tree.initialize_potentials(graph)?;
165
166        Ok(tree)
167    }
168
169    /// Build the interaction graph (moralized graph).
170    ///
171    /// For factor graphs, the interaction graph is an undirected graph where:
172    /// - Nodes are variables
173    /// - Edges connect variables that appear together in some factor
174    fn build_interaction_graph(graph: &FactorGraph) -> Result<HashMap<String, HashSet<String>>> {
175        let mut adjacency: HashMap<String, HashSet<String>> = HashMap::new();
176
177        // Initialize all variables
178        for var_name in graph.variable_names() {
179            adjacency.insert(var_name.clone(), HashSet::new());
180        }
181
182        // Add edges for each factor
183        for factor in graph.factors() {
184            let vars = &factor.variables;
185            // Connect all pairs of variables in the factor
186            for i in 0..vars.len() {
187                for j in (i + 1)..vars.len() {
188                    let v1 = &vars[i];
189                    let v2 = &vars[j];
190
191                    adjacency.entry(v1.clone()).or_default().insert(v2.clone());
192                    adjacency.entry(v2.clone()).or_default().insert(v1.clone());
193                }
194            }
195        }
196
197        Ok(adjacency)
198    }
199
200    /// Triangulate the graph to make it chordal.
201    ///
202    /// Uses the min-fill heuristic for variable elimination ordering.
203    /// A chordal graph has the property that every cycle of length ≥4 has a chord.
204    fn triangulate(
205        graph: &HashMap<String, HashSet<String>>,
206    ) -> Result<HashMap<String, HashSet<String>>> {
207        let mut triangulated = graph.clone();
208        let mut remaining: HashSet<String> = graph.keys().cloned().collect();
209
210        while !remaining.is_empty() {
211            // Find variable with minimum fill-in edges
212            let var = Self::find_min_fill_variable(&triangulated, &remaining)?;
213
214            // Get neighbors of the variable
215            let neighbors: Vec<String> = triangulated
216                .get(&var)
217                .ok_or_else(|| PgmError::InvalidGraph("Variable not found".to_string()))?
218                .intersection(&remaining)
219                .cloned()
220                .collect();
221
222            // Add fill-in edges (connect all pairs of neighbors)
223            for i in 0..neighbors.len() {
224                for j in (i + 1)..neighbors.len() {
225                    let n1 = &neighbors[i];
226                    let n2 = &neighbors[j];
227
228                    triangulated
229                        .entry(n1.clone())
230                        .or_default()
231                        .insert(n2.clone());
232                    triangulated
233                        .entry(n2.clone())
234                        .or_default()
235                        .insert(n1.clone());
236                }
237            }
238
239            // Remove variable from remaining set
240            remaining.remove(&var);
241        }
242
243        Ok(triangulated)
244    }
245
246    /// Find the variable with minimum fill-in (min-fill heuristic).
247    fn find_min_fill_variable(
248        graph: &HashMap<String, HashSet<String>>,
249        remaining: &HashSet<String>,
250    ) -> Result<String> {
251        let mut min_fill = usize::MAX;
252        let mut best_var = None;
253
254        for var in remaining {
255            let neighbors: Vec<String> = graph
256                .get(var)
257                .ok_or_else(|| PgmError::InvalidGraph("Variable not found".to_string()))?
258                .intersection(remaining)
259                .cloned()
260                .collect();
261
262            // Count how many edges need to be added
263            let mut fill_count = 0;
264            for i in 0..neighbors.len() {
265                for j in (i + 1)..neighbors.len() {
266                    let n1 = &neighbors[i];
267                    let n2 = &neighbors[j];
268                    if !graph
269                        .get(n1)
270                        .expect("n1 neighbor set present in triangulated graph")
271                        .contains(n2)
272                    {
273                        fill_count += 1;
274                    }
275                }
276            }
277
278            if fill_count < min_fill {
279                min_fill = fill_count;
280                best_var = Some(var.clone());
281            }
282        }
283
284        best_var.ok_or_else(|| PgmError::InvalidGraph("No variable found".to_string()))
285    }
286
287    /// Find maximal cliques in a triangulated graph.
288    ///
289    /// Uses a greedy algorithm to identify maximal cliques.
290    fn find_maximal_cliques(
291        graph: &HashMap<String, HashSet<String>>,
292    ) -> Result<Vec<HashSet<String>>> {
293        let mut cliques = Vec::new();
294        let mut visited: HashSet<String> = HashSet::new();
295
296        // Start from each variable and grow cliques
297        for var in graph.keys() {
298            if visited.contains(var) {
299                continue;
300            }
301
302            let mut clique: HashSet<String> = HashSet::new();
303            clique.insert(var.clone());
304
305            // Add neighbors that form a clique
306            for neighbor in graph.get(var).expect("var present in graph adjacency") {
307                // Check if neighbor is connected to all current clique members
308                let is_fully_connected = clique.iter().all(|c| {
309                    c == neighbor
310                        || graph
311                            .get(neighbor)
312                            .expect("neighbor present in graph adjacency")
313                            .contains(c)
314                });
315
316                if is_fully_connected {
317                    clique.insert(neighbor.clone());
318                }
319            }
320
321            // Check if this is a maximal clique
322            let is_maximal = !cliques
323                .iter()
324                .any(|c: &HashSet<String>| c.is_superset(&clique));
325
326            if is_maximal {
327                // Remove non-maximal cliques that are subsets of this one
328                cliques.retain(|c| !clique.is_superset(c));
329                cliques.push(clique.clone());
330            }
331
332            visited.insert(var.clone());
333        }
334
335        // Ensure we have at least one clique
336        if cliques.is_empty() && !graph.is_empty() {
337            // Create a clique with all variables (fallback)
338            let all_vars: HashSet<String> = graph.keys().cloned().collect();
339            cliques.push(all_vars);
340        }
341
342        Ok(cliques)
343    }
344
345    /// Build a junction tree from maximal cliques.
346    ///
347    /// Uses a maximum spanning tree algorithm based on separator size.
348    fn build_tree_from_cliques(clique_sets: Vec<HashSet<String>>) -> Result<Self> {
349        let mut tree = JunctionTree::new();
350
351        // Create clique nodes
352        for (id, vars) in clique_sets.into_iter().enumerate() {
353            let clique = Clique::new(id, vars.clone());
354
355            // Update variable to clique mapping
356            for var in &vars {
357                tree.var_to_cliques.entry(var.clone()).or_default().push(id);
358            }
359
360            tree.cliques.push(clique);
361        }
362
363        // Build maximum spanning tree based on separator size
364        if tree.cliques.len() > 1 {
365            tree.build_maximum_spanning_tree()?;
366        }
367
368        Ok(tree)
369    }
370
371    /// Build a maximum spanning tree connecting cliques.
372    ///
373    /// Uses Prim's algorithm with separator size as edge weight.
374    fn build_maximum_spanning_tree(&mut self) -> Result<()> {
375        let n = self.cliques.len();
376        if n == 0 {
377            return Ok(());
378        }
379
380        let mut in_tree = vec![false; n];
381        let mut edges_to_add: Vec<(usize, usize, usize)> = Vec::new();
382
383        // Start with clique 0
384        in_tree[0] = true;
385        let mut tree_size = 1;
386
387        while tree_size < n {
388            let mut best_edge = None;
389            let mut best_weight = 0;
390
391            // Find best edge to add
392            for i in 0..n {
393                if !in_tree[i] {
394                    continue;
395                }
396
397                for (j, &is_in_tree) in in_tree.iter().enumerate().take(n) {
398                    if is_in_tree {
399                        continue;
400                    }
401
402                    let separator = self.cliques[i].intersection(&self.cliques[j]);
403                    let weight = separator.len();
404
405                    if weight > best_weight {
406                        best_weight = weight;
407                        best_edge = Some((i, j, weight));
408                    }
409                }
410            }
411
412            if let Some((i, j, _)) = best_edge {
413                edges_to_add.push((i, j, best_weight));
414                in_tree[j] = true;
415                tree_size += 1;
416            } else {
417                break;
418            }
419        }
420
421        // Create edges
422        for (c1, c2, _) in edges_to_add {
423            let separator = Separator::from_cliques(&self.cliques[c1], &self.cliques[c2]);
424            let edge = JunctionTreeEdge::new(c1, c2, separator);
425            self.edges.push(edge);
426        }
427
428        Ok(())
429    }
430
431    /// Initialize clique potentials from the factor graph.
432    ///
433    /// Assigns each factor to a clique that contains all its variables.
434    fn initialize_potentials(&mut self, graph: &FactorGraph) -> Result<()> {
435        for factor in graph.factors() {
436            let factor_vars: HashSet<String> = factor.variables.iter().cloned().collect();
437
438            // Find a clique that contains all variables in this factor
439            let clique_idx = self
440                .cliques
441                .iter()
442                .position(|c| c.contains_all(&factor_vars))
443                .ok_or_else(|| {
444                    PgmError::InvalidGraph(format!(
445                        "No clique contains all variables for factor: {:?}",
446                        factor.name
447                    ))
448                })?;
449
450            let clique = &mut self.cliques[clique_idx];
451
452            // Multiply factor into clique potential
453            if let Some(ref mut potential) = clique.potential {
454                *potential = potential.product(factor)?;
455            } else {
456                clique.potential = Some(factor.clone());
457            }
458        }
459
460        // Initialize cliques without factors to uniform potentials
461        for clique in &mut self.cliques {
462            if clique.potential.is_none() {
463                // Create uniform potential
464                clique.potential = Some(Self::create_uniform_potential(&clique.variables, graph)?);
465            }
466        }
467
468        Ok(())
469    }
470
471    /// Create a uniform potential over a set of variables.
472    fn create_uniform_potential(
473        variables: &HashSet<String>,
474        graph: &FactorGraph,
475    ) -> Result<Factor> {
476        let var_vec: Vec<String> = variables.iter().cloned().collect();
477        let mut shape = Vec::new();
478
479        for var in &var_vec {
480            let cardinality = graph
481                .get_variable(var)
482                .ok_or_else(|| PgmError::InvalidGraph(format!("Variable {} not found", var)))?
483                .cardinality;
484            shape.push(cardinality);
485        }
486
487        let size: usize = shape.iter().product();
488        let values = vec![1.0; size];
489
490        let array = ArrayD::from_shape_vec(shape, values)
491            .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))?;
492
493        Factor::new("uniform".to_string(), var_vec, array)
494    }
495
496    /// Calibrate the junction tree by passing messages.
497    ///
498    /// This implements the message passing schedule for exact inference.
499    pub fn calibrate(&mut self) -> Result<()> {
500        if self.edges.is_empty() {
501            self.calibrated = true;
502            return Ok(());
503        }
504
505        // Collect evidence (inward pass) from leaves to root
506        let root = 0;
507        self.collect_evidence(root, None)?;
508
509        // Distribute evidence (outward pass) from root to leaves
510        self.distribute_evidence(root, None)?;
511
512        self.calibrated = true;
513        Ok(())
514    }
515
516    /// Collect evidence (inward pass).
517    fn collect_evidence(&mut self, current: usize, parent: Option<usize>) -> Result<()> {
518        // Find children (neighbors except parent)
519        let children: Vec<usize> = self.get_neighbors(current, parent);
520
521        // Recursively collect from children
522        for child in &children {
523            self.collect_evidence(*child, Some(current))?;
524        }
525
526        // Send message to parent if exists
527        if let Some(parent_idx) = parent {
528            self.send_message(current, parent_idx)?;
529        }
530
531        Ok(())
532    }
533
534    /// Distribute evidence (outward pass).
535    fn distribute_evidence(&mut self, current: usize, parent: Option<usize>) -> Result<()> {
536        // Find children (neighbors except parent)
537        let children: Vec<usize> = self.get_neighbors(current, parent);
538
539        // Send messages to all children
540        for child in &children {
541            self.send_message(current, *child)?;
542            self.distribute_evidence(*child, Some(current))?;
543        }
544
545        Ok(())
546    }
547
548    /// Get neighbors of a clique excluding the parent.
549    fn get_neighbors(&self, clique: usize, parent: Option<usize>) -> Vec<usize> {
550        let mut neighbors = Vec::new();
551
552        for edge in &self.edges {
553            if edge.clique1 == clique {
554                if parent != Some(edge.clique2) {
555                    neighbors.push(edge.clique2);
556                }
557            } else if edge.clique2 == clique && parent != Some(edge.clique1) {
558                neighbors.push(edge.clique1);
559            }
560        }
561
562        neighbors
563    }
564
565    /// Send a message from one clique to another.
566    fn send_message(&mut self, from: usize, to: usize) -> Result<()> {
567        // Find the edge
568        let edge_idx = self
569            .edges
570            .iter()
571            .position(|e| {
572                (e.clique1 == from && e.clique2 == to) || (e.clique1 == to && e.clique2 == from)
573            })
574            .ok_or_else(|| PgmError::InvalidGraph("Edge not found".to_string()))?;
575
576        // Get separator variables
577        let separator_vars = self.edges[edge_idx].separator.variables.clone();
578
579        // Get clique potential
580        let clique_potential = self.cliques[from].potential.clone().ok_or_else(|| {
581            PgmError::InvalidGraph("Clique potential not initialized".to_string())
582        })?;
583
584        // Marginalize out variables not in separator
585        let mut message = clique_potential;
586        let all_vars: HashSet<String> = message.variables.iter().cloned().collect();
587        let vars_to_eliminate: Vec<String> =
588            all_vars.difference(&separator_vars).cloned().collect();
589
590        for var in vars_to_eliminate {
591            message = message.marginalize_out(&var)?;
592        }
593
594        // Store message
595        let edge = &mut self.edges[edge_idx];
596        if edge.clique1 == from {
597            edge.message_1_to_2 = Some(message);
598        } else {
599            edge.message_2_to_1 = Some(message);
600        }
601
602        Ok(())
603    }
604
605    /// Query marginal probability for a variable.
606    pub fn query_marginal(&self, variable: &str) -> Result<ArrayD<f64>> {
607        if !self.calibrated {
608            return Err(PgmError::InvalidGraph(
609                "Tree must be calibrated before querying".to_string(),
610            ));
611        }
612
613        // Find a clique containing this variable
614        let clique_indices = self
615            .var_to_cliques
616            .get(variable)
617            .ok_or_else(|| PgmError::InvalidGraph(format!("Variable {} not found", variable)))?;
618
619        if clique_indices.is_empty() {
620            return Err(PgmError::InvalidGraph(format!(
621                "No clique contains variable {}",
622                variable
623            )));
624        }
625
626        // Get belief from first clique
627        let clique = &self.cliques[clique_indices[0]];
628        let mut belief = clique.potential.clone().ok_or_else(|| {
629            PgmError::InvalidGraph("Clique potential not initialized".to_string())
630        })?;
631
632        // Marginalize out all variables except the query variable
633        let all_vars: HashSet<String> = belief.variables.iter().cloned().collect();
634        let mut target_set = HashSet::new();
635        target_set.insert(variable.to_string());
636        let vars_to_eliminate: Vec<String> = all_vars.difference(&target_set).cloned().collect();
637
638        for var in vars_to_eliminate {
639            belief = belief.marginalize_out(&var)?;
640        }
641
642        // Normalize
643        belief.normalize();
644
645        Ok(belief.values)
646    }
647
648    /// Query joint marginal over multiple variables.
649    pub fn query_joint_marginal(&self, variables: &[String]) -> Result<ArrayD<f64>> {
650        if !self.calibrated {
651            return Err(PgmError::InvalidGraph(
652                "Tree must be calibrated before querying".to_string(),
653            ));
654        }
655
656        let var_set: HashSet<String> = variables.iter().cloned().collect();
657
658        // Find a clique containing all these variables
659        let clique = self
660            .cliques
661            .iter()
662            .find(|c| c.contains_all(&var_set))
663            .ok_or_else(|| {
664                PgmError::InvalidGraph(format!("No clique contains all variables: {:?}", variables))
665            })?;
666
667        let mut belief = clique.potential.clone().ok_or_else(|| {
668            PgmError::InvalidGraph("Clique potential not initialized".to_string())
669        })?;
670
671        // Marginalize out variables not in query
672        let all_vars: HashSet<String> = belief.variables.iter().cloned().collect();
673        let vars_to_eliminate: Vec<String> = all_vars.difference(&var_set).cloned().collect();
674
675        for var in vars_to_eliminate {
676            belief = belief.marginalize_out(&var)?;
677        }
678
679        // Normalize
680        belief.normalize();
681
682        Ok(belief.values)
683    }
684
685    /// Get the treewidth of this junction tree.
686    ///
687    /// The treewidth is the size of the largest clique minus 1.
688    pub fn treewidth(&self) -> usize {
689        self.cliques
690            .iter()
691            .map(|c| c.variables.len())
692            .max()
693            .unwrap_or(0)
694            .saturating_sub(1)
695    }
696
697    /// Check if the junction tree satisfies the running intersection property.
698    ///
699    /// For every variable X, the set of cliques containing X forms a connected subtree.
700    pub fn verify_running_intersection_property(&self) -> bool {
701        for var in self.var_to_cliques.keys() {
702            let cliques_with_var = self
703                .var_to_cliques
704                .get(var)
705                .expect("var present in var_to_cliques, iterating over known keys");
706
707            if cliques_with_var.len() <= 1 {
708                continue;
709            }
710
711            // Check if these cliques form a connected component
712            if !self.is_connected_subgraph(cliques_with_var) {
713                return false;
714            }
715        }
716
717        true
718    }
719
720    /// Check if a set of cliques forms a connected subgraph.
721    fn is_connected_subgraph(&self, cliques: &[usize]) -> bool {
722        if cliques.is_empty() {
723            return true;
724        }
725
726        let clique_set: HashSet<usize> = cliques.iter().copied().collect();
727        let mut visited = HashSet::new();
728        let mut queue = VecDeque::new();
729
730        // Start BFS from first clique
731        queue.push_back(cliques[0]);
732        visited.insert(cliques[0]);
733
734        while let Some(current) = queue.pop_front() {
735            for edge in &self.edges {
736                let neighbor = if edge.clique1 == current {
737                    Some(edge.clique2)
738                } else if edge.clique2 == current {
739                    Some(edge.clique1)
740                } else {
741                    None
742                };
743
744                if let Some(n) = neighbor {
745                    if clique_set.contains(&n) && !visited.contains(&n) {
746                        visited.insert(n);
747                        queue.push_back(n);
748                    }
749                }
750            }
751        }
752
753        visited.len() == cliques.len()
754    }
755}
756
757impl Default for JunctionTree {
758    fn default() -> Self {
759        Self::new()
760    }
761}
762
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::FactorGraph;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Factor graph whose variables are all binary ("Binary", cardinality 2).
    fn binary_graph(names: &[&str]) -> FactorGraph {
        let mut graph = FactorGraph::new();
        for name in names {
            graph.add_variable_with_card(name.to_string(), "Binary".to_string(), 2);
        }
        graph
    }

    /// A 2×2 factor over `(first, second)` with row-major `values`.
    fn pairwise(name: &str, first: &str, second: &str, values: [f64; 4]) -> Factor {
        let table = Array::from_shape_vec(vec![2, 2], values.to_vec())
            .expect("unwrap")
            .into_dyn();
        Factor::new(
            name.to_string(),
            vec![first.to_string(), second.to_string()],
            table,
        )
        .expect("unwrap")
    }

    /// Owned variable-name set from string literals.
    fn var_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|n| n.to_string()).collect()
    }

    #[test]
    fn test_clique_creation() {
        let clique = Clique::new(0, var_set(&["x", "y"]));
        assert_eq!(clique.id, 0);
        assert_eq!(clique.variables.len(), 2);
    }

    #[test]
    fn test_clique_intersection() {
        let c1 = Clique::new(0, var_set(&["x", "y"]));
        let c2 = Clique::new(1, var_set(&["y", "z"]));

        let shared = c1.intersection(&c2);
        assert_eq!(shared.len(), 1);
        assert!(shared.contains("y"));
    }

    #[test]
    fn test_interaction_graph() {
        let mut graph = binary_graph(&["x", "y", "z"]);
        graph
            .add_factor(pairwise("P(x,y)", "x", "y", [0.1, 0.2, 0.3, 0.4]))
            .expect("unwrap");
        graph
            .add_factor(pairwise("P(y,z)", "y", "z", [0.5, 0.1, 0.2, 0.2]))
            .expect("unwrap");

        let interaction_graph = JunctionTree::build_interaction_graph(&graph).expect("unwrap");

        for (a, b) in [("x", "y"), ("y", "x"), ("y", "z"), ("z", "y")] {
            assert!(interaction_graph.get(a).expect("unwrap").contains(b));
        }
    }

    #[test]
    fn test_junction_tree_construction() {
        let mut graph = binary_graph(&["x", "y"]);
        graph
            .add_factor(pairwise("P(x,y)", "x", "y", [0.3, 0.7, 0.4, 0.6]))
            .expect("unwrap");

        let tree = JunctionTree::from_factor_graph(&graph).expect("unwrap");

        assert!(!tree.cliques.is_empty());
        assert!(tree.verify_running_intersection_property());
    }

    #[test]
    fn test_junction_tree_calibration() {
        let mut graph = binary_graph(&["x", "y"]);
        graph
            .add_factor(pairwise("P(x,y)", "x", "y", [0.25, 0.25, 0.25, 0.25]))
            .expect("unwrap");

        let mut tree = JunctionTree::from_factor_graph(&graph).expect("unwrap");
        tree.calibrate().expect("unwrap");

        assert!(tree.calibrated);
    }

    #[test]
    fn test_marginal_query() {
        let mut graph = binary_graph(&["x", "y"]);
        graph
            .add_factor(pairwise("P(x,y)", "x", "y", [0.1, 0.4, 0.2, 0.3]))
            .expect("unwrap");

        let mut tree = JunctionTree::from_factor_graph(&graph).expect("unwrap");
        tree.calibrate().expect("unwrap");

        let marginal_x = tree.query_marginal("x").expect("unwrap");

        // P(x=0) = 0.1 + 0.4 = 0.5 and P(x=1) = 0.2 + 0.3 = 0.5
        assert_abs_diff_eq!(marginal_x[[0]], 0.5, epsilon = 1e-6);
        assert_abs_diff_eq!(marginal_x[[1]], 0.5, epsilon = 1e-6);
    }

    #[test]
    fn test_treewidth() {
        let mut graph = binary_graph(&["x", "y", "z"]);
        let pxy = pairwise("P(x,y)", "x", "y", [0.3, 0.7, 0.4, 0.6]);
        let pyz = pairwise("P(y,z)", "y", "z", [0.5, 0.5, 0.6, 0.4]);

        graph.add_factor(pxy).expect("unwrap");
        graph.add_factor(pyz).expect("unwrap");

        let tree = JunctionTree::from_factor_graph(&graph).expect("unwrap");

        // A 3-variable chain has treewidth at most 2 (largest clique size 3, minus 1)
        assert!(tree.treewidth() <= 2);
    }
}