1use crate::error::{PgmError, Result};
32use crate::factor::Factor;
33use crate::graph::FactorGraph;
34use scirs2_core::ndarray::ArrayD;
35use std::collections::{HashMap, HashSet, VecDeque};
36
37#[derive(Debug, Clone)]
39pub struct Clique {
40 pub id: usize,
42 pub variables: HashSet<String>,
44 pub potential: Option<Factor>,
46}
47
48impl Clique {
49 pub fn new(id: usize, variables: HashSet<String>) -> Self {
51 Self {
52 id,
53 variables,
54 potential: None,
55 }
56 }
57
58 pub fn contains_all(&self, vars: &HashSet<String>) -> bool {
60 vars.is_subset(&self.variables)
61 }
62
63 pub fn intersection(&self, other: &Clique) -> HashSet<String> {
65 self.variables
66 .intersection(&other.variables)
67 .cloned()
68 .collect()
69 }
70}
71
72#[derive(Debug, Clone)]
74pub struct Separator {
75 pub variables: HashSet<String>,
77 pub potential: Option<Factor>,
79}
80
81impl Separator {
82 pub fn from_cliques(c1: &Clique, c2: &Clique) -> Self {
84 Self {
85 variables: c1.intersection(c2),
86 potential: None,
87 }
88 }
89}
90
91#[derive(Debug, Clone)]
93pub struct JunctionTreeEdge {
94 pub clique1: usize,
96 pub clique2: usize,
98 pub separator: Separator,
100 pub message_1_to_2: Option<Factor>,
102 pub message_2_to_1: Option<Factor>,
104}
105
106impl JunctionTreeEdge {
107 pub fn new(clique1: usize, clique2: usize, separator: Separator) -> Self {
109 Self {
110 clique1,
111 clique2,
112 separator,
113 message_1_to_2: None,
114 message_2_to_1: None,
115 }
116 }
117}
118
119#[derive(Debug, Clone)]
121pub struct JunctionTree {
122 pub cliques: Vec<Clique>,
124 pub edges: Vec<JunctionTreeEdge>,
126 pub var_to_cliques: HashMap<String, Vec<usize>>,
128 pub calibrated: bool,
130}
131
132impl JunctionTree {
133 pub fn new() -> Self {
135 Self {
136 cliques: Vec::new(),
137 edges: Vec::new(),
138 var_to_cliques: HashMap::new(),
139 calibrated: false,
140 }
141 }
142
143 pub fn from_factor_graph(graph: &FactorGraph) -> Result<Self> {
151 let interaction_graph = Self::build_interaction_graph(graph)?;
153
154 let triangulated = Self::triangulate(&interaction_graph)?;
156
157 let cliques = Self::find_maximal_cliques(&triangulated)?;
159
160 let mut tree = Self::build_tree_from_cliques(cliques)?;
162
163 tree.initialize_potentials(graph)?;
165
166 Ok(tree)
167 }
168
169 fn build_interaction_graph(graph: &FactorGraph) -> Result<HashMap<String, HashSet<String>>> {
175 let mut adjacency: HashMap<String, HashSet<String>> = HashMap::new();
176
177 for var_name in graph.variable_names() {
179 adjacency.insert(var_name.clone(), HashSet::new());
180 }
181
182 for factor in graph.factors() {
184 let vars = &factor.variables;
185 for i in 0..vars.len() {
187 for j in (i + 1)..vars.len() {
188 let v1 = &vars[i];
189 let v2 = &vars[j];
190
191 adjacency.entry(v1.clone()).or_default().insert(v2.clone());
192 adjacency.entry(v2.clone()).or_default().insert(v1.clone());
193 }
194 }
195 }
196
197 Ok(adjacency)
198 }
199
200 fn triangulate(
205 graph: &HashMap<String, HashSet<String>>,
206 ) -> Result<HashMap<String, HashSet<String>>> {
207 let mut triangulated = graph.clone();
208 let mut remaining: HashSet<String> = graph.keys().cloned().collect();
209
210 while !remaining.is_empty() {
211 let var = Self::find_min_fill_variable(&triangulated, &remaining)?;
213
214 let neighbors: Vec<String> = triangulated
216 .get(&var)
217 .ok_or_else(|| PgmError::InvalidGraph("Variable not found".to_string()))?
218 .intersection(&remaining)
219 .cloned()
220 .collect();
221
222 for i in 0..neighbors.len() {
224 for j in (i + 1)..neighbors.len() {
225 let n1 = &neighbors[i];
226 let n2 = &neighbors[j];
227
228 triangulated
229 .entry(n1.clone())
230 .or_default()
231 .insert(n2.clone());
232 triangulated
233 .entry(n2.clone())
234 .or_default()
235 .insert(n1.clone());
236 }
237 }
238
239 remaining.remove(&var);
241 }
242
243 Ok(triangulated)
244 }
245
246 fn find_min_fill_variable(
248 graph: &HashMap<String, HashSet<String>>,
249 remaining: &HashSet<String>,
250 ) -> Result<String> {
251 let mut min_fill = usize::MAX;
252 let mut best_var = None;
253
254 for var in remaining {
255 let neighbors: Vec<String> = graph
256 .get(var)
257 .ok_or_else(|| PgmError::InvalidGraph("Variable not found".to_string()))?
258 .intersection(remaining)
259 .cloned()
260 .collect();
261
262 let mut fill_count = 0;
264 for i in 0..neighbors.len() {
265 for j in (i + 1)..neighbors.len() {
266 let n1 = &neighbors[i];
267 let n2 = &neighbors[j];
268 if !graph.get(n1).unwrap().contains(n2) {
269 fill_count += 1;
270 }
271 }
272 }
273
274 if fill_count < min_fill {
275 min_fill = fill_count;
276 best_var = Some(var.clone());
277 }
278 }
279
280 best_var.ok_or_else(|| PgmError::InvalidGraph("No variable found".to_string()))
281 }
282
283 fn find_maximal_cliques(
287 graph: &HashMap<String, HashSet<String>>,
288 ) -> Result<Vec<HashSet<String>>> {
289 let mut cliques = Vec::new();
290 let mut visited: HashSet<String> = HashSet::new();
291
292 for var in graph.keys() {
294 if visited.contains(var) {
295 continue;
296 }
297
298 let mut clique: HashSet<String> = HashSet::new();
299 clique.insert(var.clone());
300
301 for neighbor in graph.get(var).unwrap() {
303 let is_fully_connected = clique
305 .iter()
306 .all(|c| c == neighbor || graph.get(neighbor).unwrap().contains(c));
307
308 if is_fully_connected {
309 clique.insert(neighbor.clone());
310 }
311 }
312
313 let is_maximal = !cliques
315 .iter()
316 .any(|c: &HashSet<String>| c.is_superset(&clique));
317
318 if is_maximal {
319 cliques.retain(|c| !clique.is_superset(c));
321 cliques.push(clique.clone());
322 }
323
324 visited.insert(var.clone());
325 }
326
327 if cliques.is_empty() && !graph.is_empty() {
329 let all_vars: HashSet<String> = graph.keys().cloned().collect();
331 cliques.push(all_vars);
332 }
333
334 Ok(cliques)
335 }
336
337 fn build_tree_from_cliques(clique_sets: Vec<HashSet<String>>) -> Result<Self> {
341 let mut tree = JunctionTree::new();
342
343 for (id, vars) in clique_sets.into_iter().enumerate() {
345 let clique = Clique::new(id, vars.clone());
346
347 for var in &vars {
349 tree.var_to_cliques.entry(var.clone()).or_default().push(id);
350 }
351
352 tree.cliques.push(clique);
353 }
354
355 if tree.cliques.len() > 1 {
357 tree.build_maximum_spanning_tree()?;
358 }
359
360 Ok(tree)
361 }
362
363 fn build_maximum_spanning_tree(&mut self) -> Result<()> {
367 let n = self.cliques.len();
368 if n == 0 {
369 return Ok(());
370 }
371
372 let mut in_tree = vec![false; n];
373 let mut edges_to_add: Vec<(usize, usize, usize)> = Vec::new();
374
375 in_tree[0] = true;
377 let mut tree_size = 1;
378
379 while tree_size < n {
380 let mut best_edge = None;
381 let mut best_weight = 0;
382
383 for i in 0..n {
385 if !in_tree[i] {
386 continue;
387 }
388
389 for (j, &is_in_tree) in in_tree.iter().enumerate().take(n) {
390 if is_in_tree {
391 continue;
392 }
393
394 let separator = self.cliques[i].intersection(&self.cliques[j]);
395 let weight = separator.len();
396
397 if weight > best_weight {
398 best_weight = weight;
399 best_edge = Some((i, j, weight));
400 }
401 }
402 }
403
404 if let Some((i, j, _)) = best_edge {
405 edges_to_add.push((i, j, best_weight));
406 in_tree[j] = true;
407 tree_size += 1;
408 } else {
409 break;
410 }
411 }
412
413 for (c1, c2, _) in edges_to_add {
415 let separator = Separator::from_cliques(&self.cliques[c1], &self.cliques[c2]);
416 let edge = JunctionTreeEdge::new(c1, c2, separator);
417 self.edges.push(edge);
418 }
419
420 Ok(())
421 }
422
423 fn initialize_potentials(&mut self, graph: &FactorGraph) -> Result<()> {
427 for factor in graph.factors() {
428 let factor_vars: HashSet<String> = factor.variables.iter().cloned().collect();
429
430 let clique_idx = self
432 .cliques
433 .iter()
434 .position(|c| c.contains_all(&factor_vars))
435 .ok_or_else(|| {
436 PgmError::InvalidGraph(format!(
437 "No clique contains all variables for factor: {:?}",
438 factor.name
439 ))
440 })?;
441
442 let clique = &mut self.cliques[clique_idx];
443
444 if let Some(ref mut potential) = clique.potential {
446 *potential = potential.product(factor)?;
447 } else {
448 clique.potential = Some(factor.clone());
449 }
450 }
451
452 for clique in &mut self.cliques {
454 if clique.potential.is_none() {
455 clique.potential = Some(Self::create_uniform_potential(&clique.variables, graph)?);
457 }
458 }
459
460 Ok(())
461 }
462
463 fn create_uniform_potential(
465 variables: &HashSet<String>,
466 graph: &FactorGraph,
467 ) -> Result<Factor> {
468 let var_vec: Vec<String> = variables.iter().cloned().collect();
469 let mut shape = Vec::new();
470
471 for var in &var_vec {
472 let cardinality = graph
473 .get_variable(var)
474 .ok_or_else(|| PgmError::InvalidGraph(format!("Variable {} not found", var)))?
475 .cardinality;
476 shape.push(cardinality);
477 }
478
479 let size: usize = shape.iter().product();
480 let values = vec![1.0; size];
481
482 let array = ArrayD::from_shape_vec(shape, values)
483 .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))?;
484
485 Factor::new("uniform".to_string(), var_vec, array)
486 }
487
488 pub fn calibrate(&mut self) -> Result<()> {
492 if self.edges.is_empty() {
493 self.calibrated = true;
494 return Ok(());
495 }
496
497 let root = 0;
499 self.collect_evidence(root, None)?;
500
501 self.distribute_evidence(root, None)?;
503
504 self.calibrated = true;
505 Ok(())
506 }
507
508 fn collect_evidence(&mut self, current: usize, parent: Option<usize>) -> Result<()> {
510 let children: Vec<usize> = self.get_neighbors(current, parent);
512
513 for child in &children {
515 self.collect_evidence(*child, Some(current))?;
516 }
517
518 if let Some(parent_idx) = parent {
520 self.send_message(current, parent_idx)?;
521 }
522
523 Ok(())
524 }
525
526 fn distribute_evidence(&mut self, current: usize, parent: Option<usize>) -> Result<()> {
528 let children: Vec<usize> = self.get_neighbors(current, parent);
530
531 for child in &children {
533 self.send_message(current, *child)?;
534 self.distribute_evidence(*child, Some(current))?;
535 }
536
537 Ok(())
538 }
539
540 fn get_neighbors(&self, clique: usize, parent: Option<usize>) -> Vec<usize> {
542 let mut neighbors = Vec::new();
543
544 for edge in &self.edges {
545 if edge.clique1 == clique {
546 if parent.is_none() || parent.unwrap() != edge.clique2 {
547 neighbors.push(edge.clique2);
548 }
549 } else if edge.clique2 == clique
550 && (parent.is_none() || parent.unwrap() != edge.clique1)
551 {
552 neighbors.push(edge.clique1);
553 }
554 }
555
556 neighbors
557 }
558
559 fn send_message(&mut self, from: usize, to: usize) -> Result<()> {
561 let edge_idx = self
563 .edges
564 .iter()
565 .position(|e| {
566 (e.clique1 == from && e.clique2 == to) || (e.clique1 == to && e.clique2 == from)
567 })
568 .ok_or_else(|| PgmError::InvalidGraph("Edge not found".to_string()))?;
569
570 let separator_vars = self.edges[edge_idx].separator.variables.clone();
572
573 let clique_potential = self.cliques[from].potential.clone().ok_or_else(|| {
575 PgmError::InvalidGraph("Clique potential not initialized".to_string())
576 })?;
577
578 let mut message = clique_potential;
580 let all_vars: HashSet<String> = message.variables.iter().cloned().collect();
581 let vars_to_eliminate: Vec<String> =
582 all_vars.difference(&separator_vars).cloned().collect();
583
584 for var in vars_to_eliminate {
585 message = message.marginalize_out(&var)?;
586 }
587
588 let edge = &mut self.edges[edge_idx];
590 if edge.clique1 == from {
591 edge.message_1_to_2 = Some(message);
592 } else {
593 edge.message_2_to_1 = Some(message);
594 }
595
596 Ok(())
597 }
598
599 pub fn query_marginal(&self, variable: &str) -> Result<ArrayD<f64>> {
601 if !self.calibrated {
602 return Err(PgmError::InvalidGraph(
603 "Tree must be calibrated before querying".to_string(),
604 ));
605 }
606
607 let clique_indices = self
609 .var_to_cliques
610 .get(variable)
611 .ok_or_else(|| PgmError::InvalidGraph(format!("Variable {} not found", variable)))?;
612
613 if clique_indices.is_empty() {
614 return Err(PgmError::InvalidGraph(format!(
615 "No clique contains variable {}",
616 variable
617 )));
618 }
619
620 let clique = &self.cliques[clique_indices[0]];
622 let mut belief = clique.potential.clone().ok_or_else(|| {
623 PgmError::InvalidGraph("Clique potential not initialized".to_string())
624 })?;
625
626 let all_vars: HashSet<String> = belief.variables.iter().cloned().collect();
628 let mut target_set = HashSet::new();
629 target_set.insert(variable.to_string());
630 let vars_to_eliminate: Vec<String> = all_vars.difference(&target_set).cloned().collect();
631
632 for var in vars_to_eliminate {
633 belief = belief.marginalize_out(&var)?;
634 }
635
636 belief.normalize();
638
639 Ok(belief.values)
640 }
641
642 pub fn query_joint_marginal(&self, variables: &[String]) -> Result<ArrayD<f64>> {
644 if !self.calibrated {
645 return Err(PgmError::InvalidGraph(
646 "Tree must be calibrated before querying".to_string(),
647 ));
648 }
649
650 let var_set: HashSet<String> = variables.iter().cloned().collect();
651
652 let clique = self
654 .cliques
655 .iter()
656 .find(|c| c.contains_all(&var_set))
657 .ok_or_else(|| {
658 PgmError::InvalidGraph(format!("No clique contains all variables: {:?}", variables))
659 })?;
660
661 let mut belief = clique.potential.clone().ok_or_else(|| {
662 PgmError::InvalidGraph("Clique potential not initialized".to_string())
663 })?;
664
665 let all_vars: HashSet<String> = belief.variables.iter().cloned().collect();
667 let vars_to_eliminate: Vec<String> = all_vars.difference(&var_set).cloned().collect();
668
669 for var in vars_to_eliminate {
670 belief = belief.marginalize_out(&var)?;
671 }
672
673 belief.normalize();
675
676 Ok(belief.values)
677 }
678
679 pub fn treewidth(&self) -> usize {
683 self.cliques
684 .iter()
685 .map(|c| c.variables.len())
686 .max()
687 .unwrap_or(0)
688 .saturating_sub(1)
689 }
690
691 pub fn verify_running_intersection_property(&self) -> bool {
695 for var in self.var_to_cliques.keys() {
696 let cliques_with_var = self.var_to_cliques.get(var).unwrap();
697
698 if cliques_with_var.len() <= 1 {
699 continue;
700 }
701
702 if !self.is_connected_subgraph(cliques_with_var) {
704 return false;
705 }
706 }
707
708 true
709 }
710
711 fn is_connected_subgraph(&self, cliques: &[usize]) -> bool {
713 if cliques.is_empty() {
714 return true;
715 }
716
717 let clique_set: HashSet<usize> = cliques.iter().copied().collect();
718 let mut visited = HashSet::new();
719 let mut queue = VecDeque::new();
720
721 queue.push_back(cliques[0]);
723 visited.insert(cliques[0]);
724
725 while let Some(current) = queue.pop_front() {
726 for edge in &self.edges {
727 let neighbor = if edge.clique1 == current {
728 Some(edge.clique2)
729 } else if edge.clique2 == current {
730 Some(edge.clique1)
731 } else {
732 None
733 };
734
735 if let Some(n) = neighbor {
736 if clique_set.contains(&n) && !visited.contains(&n) {
737 visited.insert(n);
738 queue.push_back(n);
739 }
740 }
741 }
742 }
743
744 visited.len() == cliques.len()
745 }
746}
747
748impl Default for JunctionTree {
749 fn default() -> Self {
750 Self::new()
751 }
752}
753
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::FactorGraph;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Collects string literals into an owned set of names.
    fn name_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| (*name).to_string()).collect()
    }

    /// Builds a factor graph containing one binary variable per name.
    fn binary_graph(names: &[&str]) -> FactorGraph {
        let mut graph = FactorGraph::new();
        for name in names {
            graph.add_variable_with_card((*name).to_string(), "Binary".to_string(), 2);
        }
        graph
    }

    /// Builds the 2x2 factor `P(first,second)` from a row-major table.
    fn pairwise_factor(first: &str, second: &str, table: [f64; 4]) -> Factor {
        Factor::new(
            format!("P({},{})", first, second),
            vec![first.to_string(), second.to_string()],
            Array::from_shape_vec(vec![2, 2], table.to_vec())
                .unwrap()
                .into_dyn(),
        )
        .unwrap()
    }

    #[test]
    fn test_clique_creation() {
        let clique = Clique::new(0, name_set(&["x", "y"]));
        assert_eq!(clique.id, 0);
        assert_eq!(clique.variables.len(), 2);
    }

    #[test]
    fn test_clique_intersection() {
        let c1 = Clique::new(0, name_set(&["x", "y"]));
        let c2 = Clique::new(1, name_set(&["y", "z"]));

        let intersection = c1.intersection(&c2);
        assert_eq!(intersection.len(), 1);
        assert!(intersection.contains("y"));
    }

    #[test]
    fn test_interaction_graph() {
        let mut graph = binary_graph(&["x", "y", "z"]);
        graph
            .add_factor(pairwise_factor("x", "y", [0.1, 0.2, 0.3, 0.4]))
            .unwrap();
        graph
            .add_factor(pairwise_factor("y", "z", [0.5, 0.1, 0.2, 0.2]))
            .unwrap();

        let interaction_graph = JunctionTree::build_interaction_graph(&graph).unwrap();

        // Edges must be recorded in both directions.
        assert!(interaction_graph.get("x").unwrap().contains("y"));
        assert!(interaction_graph.get("y").unwrap().contains("x"));
        assert!(interaction_graph.get("y").unwrap().contains("z"));
        assert!(interaction_graph.get("z").unwrap().contains("y"));
    }

    #[test]
    fn test_junction_tree_construction() {
        let mut graph = binary_graph(&["x", "y"]);
        graph
            .add_factor(pairwise_factor("x", "y", [0.3, 0.7, 0.4, 0.6]))
            .unwrap();

        let tree = JunctionTree::from_factor_graph(&graph).unwrap();

        assert!(!tree.cliques.is_empty());
        assert!(tree.verify_running_intersection_property());
    }

    #[test]
    fn test_junction_tree_calibration() {
        let mut graph = binary_graph(&["x", "y"]);
        graph
            .add_factor(pairwise_factor("x", "y", [0.25, 0.25, 0.25, 0.25]))
            .unwrap();

        let mut tree = JunctionTree::from_factor_graph(&graph).unwrap();
        tree.calibrate().unwrap();

        assert!(tree.calibrated);
    }

    #[test]
    fn test_marginal_query() {
        let mut graph = binary_graph(&["x", "y"]);
        graph
            .add_factor(pairwise_factor("x", "y", [0.1, 0.4, 0.2, 0.3]))
            .unwrap();

        let mut tree = JunctionTree::from_factor_graph(&graph).unwrap();
        tree.calibrate().unwrap();

        let marginal_x = tree.query_marginal("x").unwrap();

        // Row sums of the table: 0.1 + 0.4 and 0.2 + 0.3.
        assert_abs_diff_eq!(marginal_x[[0]], 0.5, epsilon = 1e-6);
        assert_abs_diff_eq!(marginal_x[[1]], 0.5, epsilon = 1e-6);
    }

    #[test]
    fn test_treewidth() {
        let mut graph = binary_graph(&["x", "y", "z"]);
        let pxy = pairwise_factor("x", "y", [0.3, 0.7, 0.4, 0.6]);
        let pyz = pairwise_factor("y", "z", [0.5, 0.5, 0.6, 0.4]);

        graph.add_factor(pxy).unwrap();
        graph.add_factor(pyz).unwrap();

        let tree = JunctionTree::from_factor_graph(&graph).unwrap();

        assert!(tree.treewidth() <= 2);
    }
}