Skip to main content

tensorlogic_quantrs_hooks/
inference.rs

1//! High-level inference operations.
2
3use scirs2_core::ndarray::ArrayD;
4use std::collections::HashMap;
5
6use crate::error::Result;
7use crate::graph::FactorGraph;
8use crate::message_passing::MessagePassingAlgorithm;
9
/// Query for the marginal probability P(X) of a single variable.
///
/// Equality and hashing compare the queried variable name, so queries can be
/// deduplicated or used as cache keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MarginalizationQuery {
    /// Name of the variable whose marginal is requested.
    pub variable: String,
}
16
/// Query for the conditional probability P(X | Y = y).
///
/// `Default` yields an empty query (no query variables, no evidence).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ConditionalQuery {
    /// Names of the variables whose distributions are requested.
    pub query_vars: Vec<String>,
    /// Observed evidence, mapping a variable name to its observed state index.
    pub evidence: HashMap<String, usize>,
}
25
26/// Inference engine for PGM queries.
27pub struct InferenceEngine {
28    /// Factor graph
29    graph: FactorGraph,
30    /// Message passing algorithm
31    algorithm: Box<dyn MessagePassingAlgorithm>,
32}
33
34impl InferenceEngine {
35    /// Create a new inference engine.
36    pub fn new(graph: FactorGraph, algorithm: Box<dyn MessagePassingAlgorithm>) -> Self {
37        Self { graph, algorithm }
38    }
39
40    /// Compute marginal probability for a variable.
41    pub fn marginalize(&self, query: &MarginalizationQuery) -> Result<ArrayD<f64>> {
42        let marginals = self.algorithm.run(&self.graph)?;
43        marginals
44            .get(&query.variable)
45            .cloned()
46            .ok_or_else(|| crate::error::PgmError::VariableNotFound(query.variable.clone()))
47    }
48
49    /// Compute conditional probability.
50    pub fn conditional(&self, query: &ConditionalQuery) -> Result<HashMap<String, ArrayD<f64>>> {
51        // Run inference with evidence
52        let marginals = self.algorithm.run(&self.graph)?;
53
54        // Filter to query variables
55        let mut result = HashMap::new();
56        for var in &query.query_vars {
57            if let Some(marginal) = marginals.get(var) {
58                result.insert(var.clone(), marginal.clone());
59            }
60        }
61
62        Ok(result)
63    }
64
65    /// Compute joint probability over all variables.
66    ///
67    /// This computes P(X₁, X₂, ..., Xₙ) = ∏ᵢ φᵢ(Xᵢ)
68    pub fn joint(&self) -> Result<ArrayD<f64>> {
69        use crate::factor::Factor;
70
71        // Collect all variables
72        let all_vars: Vec<String> = self.graph.variable_names().cloned().collect();
73
74        if all_vars.is_empty() {
75            return Err(crate::error::PgmError::InvalidGraph(
76                "No variables in graph".to_string(),
77            ));
78        }
79
80        // Start with first factor or uniform distribution
81        let mut joint_factor: Option<Factor> = None;
82
83        for factor_id in self.graph.factor_ids() {
84            if let Some(factor) = self.graph.get_factor(factor_id) {
85                joint_factor = if let Some(existing) = joint_factor {
86                    Some(existing.product(factor)?)
87                } else {
88                    Some(factor.clone())
89                };
90            }
91        }
92
93        // If no factors, return uniform distribution
94        if let Some(mut joint) = joint_factor {
95            joint.normalize();
96            Ok(joint.values)
97        } else {
98            // No factors - return uniform
99            let shape: Vec<usize> = all_vars
100                .iter()
101                .filter_map(|v| self.graph.get_variable(v))
102                .map(|n| n.cardinality)
103                .collect();
104            let size: usize = shape.iter().product();
105            Ok(ArrayD::from_elem(shape, 1.0 / size as f64))
106        }
107    }
108
109    /// Get the factor graph.
110    pub fn graph(&self) -> &FactorGraph {
111        &self.graph
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message_passing::SumProductAlgorithm;

    /// Build an engine over a factor-free graph containing the given
    /// `(name, domain)` variables, using the default sum-product algorithm.
    fn engine_over(vars: &[(&str, &str)]) -> InferenceEngine {
        let mut graph = FactorGraph::new();
        for &(name, domain) in vars {
            graph.add_variable(name.to_string(), domain.to_string());
        }
        InferenceEngine::new(graph, Box::new(SumProductAlgorithm::default()))
    }

    /// Assert that the entries of `values` sum to one within a small tolerance.
    fn assert_normalized(values: &ArrayD<f64>) {
        let total: f64 = values.iter().sum();
        assert!((total - 1.0).abs() < 1e-6, "expected total 1.0, got {total}");
    }

    #[test]
    fn test_inference_engine() {
        let engine = engine_over(&[("x", "D1")]);
        let query = MarginalizationQuery {
            variable: "x".to_string(),
        };
        assert!(engine.marginalize(&query).is_ok());
    }

    #[test]
    fn test_conditional_query() {
        let engine = engine_over(&[("x", "D1"), ("y", "D2")]);
        let query = ConditionalQuery {
            query_vars: vec!["x".to_string()],
            evidence: HashMap::new(),
        };
        assert!(engine.conditional(&query).is_ok());
    }

    #[test]
    fn test_joint_probability() {
        let joint = engine_over(&[("var_0", "D1")]).joint();
        assert!(joint.is_ok());
        assert_normalized(&joint.unwrap());
    }

    #[test]
    fn test_joint_with_multiple_variables() {
        let joint = engine_over(&[("var_0", "D1"), ("var_1", "D2")]).joint();
        assert!(joint.is_ok());
        assert_normalized(&joint.unwrap());
    }
}
186}