1use crate::error::{PgmError, Result};
32use crate::factor::Factor;
33use crate::graph::FactorGraph;
34use scirs2_core::ndarray::ArrayD;
35use std::collections::{HashMap, HashSet, VecDeque};
36
/// A node of the junction tree: a set of variable names plus an optional
/// potential defined over (a subset of) those variables.
#[derive(Debug, Clone)]
pub struct Clique {
    /// Identifier of the clique, as supplied to [`Clique::new`].
    pub id: usize,
    /// Names of the variables that belong to this clique.
    pub variables: HashSet<String>,
    /// Potential attached to the clique; `None` until one is assigned.
    pub potential: Option<Factor>,
}
47
48impl Clique {
49 pub fn new(id: usize, variables: HashSet<String>) -> Self {
51 Self {
52 id,
53 variables,
54 potential: None,
55 }
56 }
57
58 pub fn contains_all(&self, vars: &HashSet<String>) -> bool {
60 vars.is_subset(&self.variables)
61 }
62
63 pub fn intersection(&self, other: &Clique) -> HashSet<String> {
65 self.variables
66 .intersection(&other.variables)
67 .cloned()
68 .collect()
69 }
70}
71
/// The variables shared by two adjacent cliques, with an optional potential
/// over them.
#[derive(Debug, Clone)]
pub struct Separator {
    /// Names of the variables common to both cliques of the edge.
    pub variables: HashSet<String>,
    /// Optional potential over the separator variables; `None` when unset.
    pub potential: Option<Factor>,
}
80
81impl Separator {
82 pub fn from_cliques(c1: &Clique, c2: &Clique) -> Self {
84 Self {
85 variables: c1.intersection(c2),
86 potential: None,
87 }
88 }
89}
90
/// An undirected edge of the junction tree connecting two cliques, together
/// with the messages exchanged across it.
#[derive(Debug, Clone)]
pub struct JunctionTreeEdge {
    /// Index (into the tree's clique list) of the first endpoint.
    pub clique1: usize,
    /// Index (into the tree's clique list) of the second endpoint.
    pub clique2: usize,
    /// Separator associated with this edge.
    pub separator: Separator,
    /// Message sent from `clique1` to `clique2`; `None` until it is computed.
    pub message_1_to_2: Option<Factor>,
    /// Message sent from `clique2` to `clique1`; `None` until it is computed.
    pub message_2_to_1: Option<Factor>,
}
105
106impl JunctionTreeEdge {
107 pub fn new(clique1: usize, clique2: usize, separator: Separator) -> Self {
109 Self {
110 clique1,
111 clique2,
112 separator,
113 message_1_to_2: None,
114 message_2_to_1: None,
115 }
116 }
117}
118
/// A junction tree (clique tree): cliques connected by separator edges, plus
/// an index from variable names to the cliques that contain them.
#[derive(Debug, Clone)]
pub struct JunctionTree {
    /// Cliques of the tree; edges refer to them by position in this vector.
    pub cliques: Vec<Clique>,
    /// Edges connecting pairs of cliques.
    pub edges: Vec<JunctionTreeEdge>,
    /// For each variable name, the indices of the cliques that contain it.
    pub var_to_cliques: HashMap<String, Vec<usize>>,
    /// Set to `true` by `calibrate`; the query methods refuse to run while it
    /// is `false`.
    pub calibrated: bool,
}
131
132impl JunctionTree {
133 pub fn new() -> Self {
135 Self {
136 cliques: Vec::new(),
137 edges: Vec::new(),
138 var_to_cliques: HashMap::new(),
139 calibrated: false,
140 }
141 }
142
143 pub fn from_factor_graph(graph: &FactorGraph) -> Result<Self> {
151 let interaction_graph = Self::build_interaction_graph(graph)?;
153
154 let triangulated = Self::triangulate(&interaction_graph)?;
156
157 let cliques = Self::find_maximal_cliques(&triangulated)?;
159
160 let mut tree = Self::build_tree_from_cliques(cliques)?;
162
163 tree.initialize_potentials(graph)?;
165
166 Ok(tree)
167 }
168
169 fn build_interaction_graph(graph: &FactorGraph) -> Result<HashMap<String, HashSet<String>>> {
175 let mut adjacency: HashMap<String, HashSet<String>> = HashMap::new();
176
177 for var_name in graph.variable_names() {
179 adjacency.insert(var_name.clone(), HashSet::new());
180 }
181
182 for factor in graph.factors() {
184 let vars = &factor.variables;
185 for i in 0..vars.len() {
187 for j in (i + 1)..vars.len() {
188 let v1 = &vars[i];
189 let v2 = &vars[j];
190
191 adjacency.entry(v1.clone()).or_default().insert(v2.clone());
192 adjacency.entry(v2.clone()).or_default().insert(v1.clone());
193 }
194 }
195 }
196
197 Ok(adjacency)
198 }
199
200 fn triangulate(
205 graph: &HashMap<String, HashSet<String>>,
206 ) -> Result<HashMap<String, HashSet<String>>> {
207 let mut triangulated = graph.clone();
208 let mut remaining: HashSet<String> = graph.keys().cloned().collect();
209
210 while !remaining.is_empty() {
211 let var = Self::find_min_fill_variable(&triangulated, &remaining)?;
213
214 let neighbors: Vec<String> = triangulated
216 .get(&var)
217 .ok_or_else(|| PgmError::InvalidGraph("Variable not found".to_string()))?
218 .intersection(&remaining)
219 .cloned()
220 .collect();
221
222 for i in 0..neighbors.len() {
224 for j in (i + 1)..neighbors.len() {
225 let n1 = &neighbors[i];
226 let n2 = &neighbors[j];
227
228 triangulated
229 .entry(n1.clone())
230 .or_default()
231 .insert(n2.clone());
232 triangulated
233 .entry(n2.clone())
234 .or_default()
235 .insert(n1.clone());
236 }
237 }
238
239 remaining.remove(&var);
241 }
242
243 Ok(triangulated)
244 }
245
246 fn find_min_fill_variable(
248 graph: &HashMap<String, HashSet<String>>,
249 remaining: &HashSet<String>,
250 ) -> Result<String> {
251 let mut min_fill = usize::MAX;
252 let mut best_var = None;
253
254 for var in remaining {
255 let neighbors: Vec<String> = graph
256 .get(var)
257 .ok_or_else(|| PgmError::InvalidGraph("Variable not found".to_string()))?
258 .intersection(remaining)
259 .cloned()
260 .collect();
261
262 let mut fill_count = 0;
264 for i in 0..neighbors.len() {
265 for j in (i + 1)..neighbors.len() {
266 let n1 = &neighbors[i];
267 let n2 = &neighbors[j];
268 if !graph
269 .get(n1)
270 .expect("n1 neighbor set present in triangulated graph")
271 .contains(n2)
272 {
273 fill_count += 1;
274 }
275 }
276 }
277
278 if fill_count < min_fill {
279 min_fill = fill_count;
280 best_var = Some(var.clone());
281 }
282 }
283
284 best_var.ok_or_else(|| PgmError::InvalidGraph("No variable found".to_string()))
285 }
286
287 fn find_maximal_cliques(
291 graph: &HashMap<String, HashSet<String>>,
292 ) -> Result<Vec<HashSet<String>>> {
293 let mut cliques = Vec::new();
294 let mut visited: HashSet<String> = HashSet::new();
295
296 for var in graph.keys() {
298 if visited.contains(var) {
299 continue;
300 }
301
302 let mut clique: HashSet<String> = HashSet::new();
303 clique.insert(var.clone());
304
305 for neighbor in graph.get(var).expect("var present in graph adjacency") {
307 let is_fully_connected = clique.iter().all(|c| {
309 c == neighbor
310 || graph
311 .get(neighbor)
312 .expect("neighbor present in graph adjacency")
313 .contains(c)
314 });
315
316 if is_fully_connected {
317 clique.insert(neighbor.clone());
318 }
319 }
320
321 let is_maximal = !cliques
323 .iter()
324 .any(|c: &HashSet<String>| c.is_superset(&clique));
325
326 if is_maximal {
327 cliques.retain(|c| !clique.is_superset(c));
329 cliques.push(clique.clone());
330 }
331
332 visited.insert(var.clone());
333 }
334
335 if cliques.is_empty() && !graph.is_empty() {
337 let all_vars: HashSet<String> = graph.keys().cloned().collect();
339 cliques.push(all_vars);
340 }
341
342 Ok(cliques)
343 }
344
345 fn build_tree_from_cliques(clique_sets: Vec<HashSet<String>>) -> Result<Self> {
349 let mut tree = JunctionTree::new();
350
351 for (id, vars) in clique_sets.into_iter().enumerate() {
353 let clique = Clique::new(id, vars.clone());
354
355 for var in &vars {
357 tree.var_to_cliques.entry(var.clone()).or_default().push(id);
358 }
359
360 tree.cliques.push(clique);
361 }
362
363 if tree.cliques.len() > 1 {
365 tree.build_maximum_spanning_tree()?;
366 }
367
368 Ok(tree)
369 }
370
371 fn build_maximum_spanning_tree(&mut self) -> Result<()> {
375 let n = self.cliques.len();
376 if n == 0 {
377 return Ok(());
378 }
379
380 let mut in_tree = vec![false; n];
381 let mut edges_to_add: Vec<(usize, usize, usize)> = Vec::new();
382
383 in_tree[0] = true;
385 let mut tree_size = 1;
386
387 while tree_size < n {
388 let mut best_edge = None;
389 let mut best_weight = 0;
390
391 for i in 0..n {
393 if !in_tree[i] {
394 continue;
395 }
396
397 for (j, &is_in_tree) in in_tree.iter().enumerate().take(n) {
398 if is_in_tree {
399 continue;
400 }
401
402 let separator = self.cliques[i].intersection(&self.cliques[j]);
403 let weight = separator.len();
404
405 if weight > best_weight {
406 best_weight = weight;
407 best_edge = Some((i, j, weight));
408 }
409 }
410 }
411
412 if let Some((i, j, _)) = best_edge {
413 edges_to_add.push((i, j, best_weight));
414 in_tree[j] = true;
415 tree_size += 1;
416 } else {
417 break;
418 }
419 }
420
421 for (c1, c2, _) in edges_to_add {
423 let separator = Separator::from_cliques(&self.cliques[c1], &self.cliques[c2]);
424 let edge = JunctionTreeEdge::new(c1, c2, separator);
425 self.edges.push(edge);
426 }
427
428 Ok(())
429 }
430
431 fn initialize_potentials(&mut self, graph: &FactorGraph) -> Result<()> {
435 for factor in graph.factors() {
436 let factor_vars: HashSet<String> = factor.variables.iter().cloned().collect();
437
438 let clique_idx = self
440 .cliques
441 .iter()
442 .position(|c| c.contains_all(&factor_vars))
443 .ok_or_else(|| {
444 PgmError::InvalidGraph(format!(
445 "No clique contains all variables for factor: {:?}",
446 factor.name
447 ))
448 })?;
449
450 let clique = &mut self.cliques[clique_idx];
451
452 if let Some(ref mut potential) = clique.potential {
454 *potential = potential.product(factor)?;
455 } else {
456 clique.potential = Some(factor.clone());
457 }
458 }
459
460 for clique in &mut self.cliques {
462 if clique.potential.is_none() {
463 clique.potential = Some(Self::create_uniform_potential(&clique.variables, graph)?);
465 }
466 }
467
468 Ok(())
469 }
470
471 fn create_uniform_potential(
473 variables: &HashSet<String>,
474 graph: &FactorGraph,
475 ) -> Result<Factor> {
476 let var_vec: Vec<String> = variables.iter().cloned().collect();
477 let mut shape = Vec::new();
478
479 for var in &var_vec {
480 let cardinality = graph
481 .get_variable(var)
482 .ok_or_else(|| PgmError::InvalidGraph(format!("Variable {} not found", var)))?
483 .cardinality;
484 shape.push(cardinality);
485 }
486
487 let size: usize = shape.iter().product();
488 let values = vec![1.0; size];
489
490 let array = ArrayD::from_shape_vec(shape, values)
491 .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))?;
492
493 Factor::new("uniform".to_string(), var_vec, array)
494 }
495
496 pub fn calibrate(&mut self) -> Result<()> {
500 if self.edges.is_empty() {
501 self.calibrated = true;
502 return Ok(());
503 }
504
505 let root = 0;
507 self.collect_evidence(root, None)?;
508
509 self.distribute_evidence(root, None)?;
511
512 self.calibrated = true;
513 Ok(())
514 }
515
516 fn collect_evidence(&mut self, current: usize, parent: Option<usize>) -> Result<()> {
518 let children: Vec<usize> = self.get_neighbors(current, parent);
520
521 for child in &children {
523 self.collect_evidence(*child, Some(current))?;
524 }
525
526 if let Some(parent_idx) = parent {
528 self.send_message(current, parent_idx)?;
529 }
530
531 Ok(())
532 }
533
534 fn distribute_evidence(&mut self, current: usize, parent: Option<usize>) -> Result<()> {
536 let children: Vec<usize> = self.get_neighbors(current, parent);
538
539 for child in &children {
541 self.send_message(current, *child)?;
542 self.distribute_evidence(*child, Some(current))?;
543 }
544
545 Ok(())
546 }
547
548 fn get_neighbors(&self, clique: usize, parent: Option<usize>) -> Vec<usize> {
550 let mut neighbors = Vec::new();
551
552 for edge in &self.edges {
553 if edge.clique1 == clique {
554 if parent != Some(edge.clique2) {
555 neighbors.push(edge.clique2);
556 }
557 } else if edge.clique2 == clique && parent != Some(edge.clique1) {
558 neighbors.push(edge.clique1);
559 }
560 }
561
562 neighbors
563 }
564
565 fn send_message(&mut self, from: usize, to: usize) -> Result<()> {
567 let edge_idx = self
569 .edges
570 .iter()
571 .position(|e| {
572 (e.clique1 == from && e.clique2 == to) || (e.clique1 == to && e.clique2 == from)
573 })
574 .ok_or_else(|| PgmError::InvalidGraph("Edge not found".to_string()))?;
575
576 let separator_vars = self.edges[edge_idx].separator.variables.clone();
578
579 let clique_potential = self.cliques[from].potential.clone().ok_or_else(|| {
581 PgmError::InvalidGraph("Clique potential not initialized".to_string())
582 })?;
583
584 let mut message = clique_potential;
586 let all_vars: HashSet<String> = message.variables.iter().cloned().collect();
587 let vars_to_eliminate: Vec<String> =
588 all_vars.difference(&separator_vars).cloned().collect();
589
590 for var in vars_to_eliminate {
591 message = message.marginalize_out(&var)?;
592 }
593
594 let edge = &mut self.edges[edge_idx];
596 if edge.clique1 == from {
597 edge.message_1_to_2 = Some(message);
598 } else {
599 edge.message_2_to_1 = Some(message);
600 }
601
602 Ok(())
603 }
604
605 pub fn query_marginal(&self, variable: &str) -> Result<ArrayD<f64>> {
607 if !self.calibrated {
608 return Err(PgmError::InvalidGraph(
609 "Tree must be calibrated before querying".to_string(),
610 ));
611 }
612
613 let clique_indices = self
615 .var_to_cliques
616 .get(variable)
617 .ok_or_else(|| PgmError::InvalidGraph(format!("Variable {} not found", variable)))?;
618
619 if clique_indices.is_empty() {
620 return Err(PgmError::InvalidGraph(format!(
621 "No clique contains variable {}",
622 variable
623 )));
624 }
625
626 let clique = &self.cliques[clique_indices[0]];
628 let mut belief = clique.potential.clone().ok_or_else(|| {
629 PgmError::InvalidGraph("Clique potential not initialized".to_string())
630 })?;
631
632 let all_vars: HashSet<String> = belief.variables.iter().cloned().collect();
634 let mut target_set = HashSet::new();
635 target_set.insert(variable.to_string());
636 let vars_to_eliminate: Vec<String> = all_vars.difference(&target_set).cloned().collect();
637
638 for var in vars_to_eliminate {
639 belief = belief.marginalize_out(&var)?;
640 }
641
642 belief.normalize();
644
645 Ok(belief.values)
646 }
647
648 pub fn query_joint_marginal(&self, variables: &[String]) -> Result<ArrayD<f64>> {
650 if !self.calibrated {
651 return Err(PgmError::InvalidGraph(
652 "Tree must be calibrated before querying".to_string(),
653 ));
654 }
655
656 let var_set: HashSet<String> = variables.iter().cloned().collect();
657
658 let clique = self
660 .cliques
661 .iter()
662 .find(|c| c.contains_all(&var_set))
663 .ok_or_else(|| {
664 PgmError::InvalidGraph(format!("No clique contains all variables: {:?}", variables))
665 })?;
666
667 let mut belief = clique.potential.clone().ok_or_else(|| {
668 PgmError::InvalidGraph("Clique potential not initialized".to_string())
669 })?;
670
671 let all_vars: HashSet<String> = belief.variables.iter().cloned().collect();
673 let vars_to_eliminate: Vec<String> = all_vars.difference(&var_set).cloned().collect();
674
675 for var in vars_to_eliminate {
676 belief = belief.marginalize_out(&var)?;
677 }
678
679 belief.normalize();
681
682 Ok(belief.values)
683 }
684
685 pub fn treewidth(&self) -> usize {
689 self.cliques
690 .iter()
691 .map(|c| c.variables.len())
692 .max()
693 .unwrap_or(0)
694 .saturating_sub(1)
695 }
696
697 pub fn verify_running_intersection_property(&self) -> bool {
701 for var in self.var_to_cliques.keys() {
702 let cliques_with_var = self
703 .var_to_cliques
704 .get(var)
705 .expect("var present in var_to_cliques, iterating over known keys");
706
707 if cliques_with_var.len() <= 1 {
708 continue;
709 }
710
711 if !self.is_connected_subgraph(cliques_with_var) {
713 return false;
714 }
715 }
716
717 true
718 }
719
720 fn is_connected_subgraph(&self, cliques: &[usize]) -> bool {
722 if cliques.is_empty() {
723 return true;
724 }
725
726 let clique_set: HashSet<usize> = cliques.iter().copied().collect();
727 let mut visited = HashSet::new();
728 let mut queue = VecDeque::new();
729
730 queue.push_back(cliques[0]);
732 visited.insert(cliques[0]);
733
734 while let Some(current) = queue.pop_front() {
735 for edge in &self.edges {
736 let neighbor = if edge.clique1 == current {
737 Some(edge.clique2)
738 } else if edge.clique2 == current {
739 Some(edge.clique1)
740 } else {
741 None
742 };
743
744 if let Some(n) = neighbor {
745 if clique_set.contains(&n) && !visited.contains(&n) {
746 visited.insert(n);
747 queue.push_back(n);
748 }
749 }
750 }
751 }
752
753 visited.len() == cliques.len()
754 }
755}
756
757impl Default for JunctionTree {
758 fn default() -> Self {
759 Self::new()
760 }
761}
762
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::FactorGraph;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Builds a set of owned variable names.
    fn name_set(names: &[&str]) -> HashSet<String> {
        let mut set = HashSet::new();
        for name in names {
            set.insert((*name).to_string());
        }
        set
    }

    /// Builds a factor graph in which every listed variable is binary.
    fn binary_graph(names: &[&str]) -> FactorGraph {
        let mut graph = FactorGraph::new();
        for name in names {
            graph.add_variable_with_card((*name).to_string(), "Binary".to_string(), 2);
        }
        graph
    }

    /// Builds a 2x2 factor over two binary variables from four values.
    fn pairwise(name: &str, first: &str, second: &str, values: [f64; 4]) -> Factor {
        Factor::new(
            name.to_string(),
            vec![first.to_string(), second.to_string()],
            Array::from_shape_vec(vec![2, 2], values.to_vec())
                .expect("unwrap")
                .into_dyn(),
        )
        .expect("unwrap")
    }

    /// Two binary variables `x`, `y` joined by a single factor `P(x,y)`.
    fn xy_graph(values: [f64; 4]) -> FactorGraph {
        let mut graph = binary_graph(&["x", "y"]);
        graph
            .add_factor(pairwise("P(x,y)", "x", "y", values))
            .expect("unwrap");
        graph
    }

    #[test]
    fn test_clique_creation() {
        let clique = Clique::new(0, name_set(&["x", "y"]));
        assert_eq!(clique.id, 0);
        assert_eq!(clique.variables.len(), 2);
    }

    #[test]
    fn test_clique_intersection() {
        let c1 = Clique::new(0, name_set(&["x", "y"]));
        let c2 = Clique::new(1, name_set(&["y", "z"]));

        let intersection = c1.intersection(&c2);
        assert_eq!(intersection.len(), 1);
        assert!(intersection.contains("y"));
    }

    #[test]
    fn test_interaction_graph() {
        let mut graph = binary_graph(&["x", "y", "z"]);
        graph
            .add_factor(pairwise("P(x,y)", "x", "y", [0.1, 0.2, 0.3, 0.4]))
            .expect("unwrap");
        graph
            .add_factor(pairwise("P(y,z)", "y", "z", [0.5, 0.1, 0.2, 0.2]))
            .expect("unwrap");

        let interaction_graph = JunctionTree::build_interaction_graph(&graph).expect("unwrap");

        // Each factor links its two variables in both directions.
        for (from, to) in [("x", "y"), ("y", "x"), ("y", "z"), ("z", "y")] {
            assert!(interaction_graph.get(from).expect("unwrap").contains(to));
        }
    }

    #[test]
    fn test_junction_tree_construction() {
        let graph = xy_graph([0.3, 0.7, 0.4, 0.6]);

        let tree = JunctionTree::from_factor_graph(&graph).expect("unwrap");

        assert!(!tree.cliques.is_empty());
        assert!(tree.verify_running_intersection_property());
    }

    #[test]
    fn test_junction_tree_calibration() {
        let graph = xy_graph([0.25, 0.25, 0.25, 0.25]);

        let mut tree = JunctionTree::from_factor_graph(&graph).expect("unwrap");
        tree.calibrate().expect("unwrap");

        assert!(tree.calibrated);
    }

    #[test]
    fn test_marginal_query() {
        let graph = xy_graph([0.1, 0.4, 0.2, 0.3]);

        let mut tree = JunctionTree::from_factor_graph(&graph).expect("unwrap");
        tree.calibrate().expect("unwrap");

        let marginal_x = tree.query_marginal("x").expect("unwrap");

        // Rows of the table sum to 0.5 each.
        assert_abs_diff_eq!(marginal_x[[0]], 0.5, epsilon = 1e-6);
        assert_abs_diff_eq!(marginal_x[[1]], 0.5, epsilon = 1e-6);
    }

    #[test]
    fn test_treewidth() {
        let mut graph = binary_graph(&["x", "y", "z"]);

        let pxy = pairwise("P(x,y)", "x", "y", [0.3, 0.7, 0.4, 0.6]);
        let pyz = pairwise("P(y,z)", "y", "z", [0.5, 0.5, 0.6, 0.4]);

        graph.add_factor(pxy).expect("unwrap");
        graph.add_factor(pyz).expect("unwrap");

        let tree = JunctionTree::from_factor_graph(&graph).expect("unwrap");

        assert!(tree.treewidth() <= 2);
    }
}