tensorlogic_quantrs_hooks/
inference.rs1use scirs2_core::ndarray::ArrayD;
4use std::collections::HashMap;
5
6use crate::error::Result;
7use crate::graph::FactorGraph;
8use crate::message_passing::MessagePassingAlgorithm;
9
/// Request for the marginal distribution of a single variable.
///
/// Plain value type: it can be compared, hashed (e.g. used as a cache key), and cloned.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MarginalizationQuery {
    /// Name of the variable whose marginal is requested.
    pub variable: String,
}
16
/// Request for the distributions of several variables given observed evidence.
///
/// NOTE(review): `InferenceEngine::conditional` in this file does not currently read
/// `evidence`; see that method's documentation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConditionalQuery {
    /// Names of the variables whose distributions are requested.
    pub query_vars: Vec<String>,
    /// Observed assignments keyed by variable name.
    ///
    /// Presumably the value is the index of the observed state within the variable's
    /// domain — TODO confirm.
    pub evidence: HashMap<String, usize>,
}
25
26pub struct InferenceEngine {
28 graph: FactorGraph,
30 algorithm: Box<dyn MessagePassingAlgorithm>,
32}
33
34impl InferenceEngine {
35 pub fn new(graph: FactorGraph, algorithm: Box<dyn MessagePassingAlgorithm>) -> Self {
37 Self { graph, algorithm }
38 }
39
40 pub fn marginalize(&self, query: &MarginalizationQuery) -> Result<ArrayD<f64>> {
42 let marginals = self.algorithm.run(&self.graph)?;
43 marginals
44 .get(&query.variable)
45 .cloned()
46 .ok_or_else(|| crate::error::PgmError::VariableNotFound(query.variable.clone()))
47 }
48
49 pub fn conditional(&self, query: &ConditionalQuery) -> Result<HashMap<String, ArrayD<f64>>> {
51 let marginals = self.algorithm.run(&self.graph)?;
53
54 let mut result = HashMap::new();
56 for var in &query.query_vars {
57 if let Some(marginal) = marginals.get(var) {
58 result.insert(var.clone(), marginal.clone());
59 }
60 }
61
62 Ok(result)
63 }
64
65 pub fn joint(&self) -> Result<ArrayD<f64>> {
69 use crate::factor::Factor;
70
71 let all_vars: Vec<String> = self.graph.variable_names().cloned().collect();
73
74 if all_vars.is_empty() {
75 return Err(crate::error::PgmError::InvalidGraph(
76 "No variables in graph".to_string(),
77 ));
78 }
79
80 let mut joint_factor: Option<Factor> = None;
82
83 for factor_id in self.graph.factor_ids() {
84 if let Some(factor) = self.graph.get_factor(factor_id) {
85 joint_factor = if let Some(existing) = joint_factor {
86 Some(existing.product(factor)?)
87 } else {
88 Some(factor.clone())
89 };
90 }
91 }
92
93 if let Some(mut joint) = joint_factor {
95 joint.normalize();
96 Ok(joint.values)
97 } else {
98 let shape: Vec<usize> = all_vars
100 .iter()
101 .filter_map(|v| self.graph.get_variable(v))
102 .map(|n| n.cardinality)
103 .collect();
104 let size: usize = shape.iter().product();
105 Ok(ArrayD::from_elem(shape, 1.0 / size as f64))
106 }
107 }
108
109 pub fn graph(&self) -> &FactorGraph {
111 &self.graph
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message_passing::SumProductAlgorithm;

    /// Builds an engine over a fresh graph that holds the given `(name, domain)` variables
    /// and no factors, evaluated with the default sum-product algorithm.
    fn engine_with_variables(variables: &[(&str, &str)]) -> InferenceEngine {
        let mut graph = FactorGraph::new();
        for &(name, domain) in variables {
            graph.add_variable(name.to_string(), domain.to_string());
        }
        InferenceEngine::new(graph, Box::new(SumProductAlgorithm::default()))
    }

    /// Total probability mass of the engine's joint distribution; the joint must succeed.
    fn joint_mass(engine: &InferenceEngine) -> f64 {
        let joint = engine.joint();
        assert!(joint.is_ok());
        joint.unwrap().iter().sum()
    }

    #[test]
    fn test_inference_engine() {
        let engine = engine_with_variables(&[("x", "D1")]);
        let query = MarginalizationQuery {
            variable: "x".to_string(),
        };

        assert!(engine.marginalize(&query).is_ok());
    }

    #[test]
    fn test_conditional_query() {
        let engine = engine_with_variables(&[("x", "D1"), ("y", "D2")]);
        let query = ConditionalQuery {
            query_vars: vec!["x".to_string()],
            evidence: HashMap::new(),
        };

        assert!(engine.conditional(&query).is_ok());
    }

    #[test]
    fn test_joint_probability() {
        let engine = engine_with_variables(&[("var_0", "D1")]);
        assert!((joint_mass(&engine) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_joint_with_multiple_variables() {
        let engine = engine_with_variables(&[("var_0", "D1"), ("var_1", "D2")]);
        assert!((joint_mass(&engine) - 1.0).abs() < 1e-6);
    }
}