tensorlogic_quantrs_hooks/
lib.rs

//! TL <-> QuantRS2 hooks (PGM/message passing as reductions).
//!
//! This crate provides integration between TensorLogic and probabilistic graphical models (PGMs).
//! It maps belief propagation and other message passing algorithms onto einsum reduction patterns.
//!
//! # Core Concepts
//!
//! - **Factor Graphs**: Convert TLExpr predicates into factors
//! - **Message Passing**: Sum-product and max-product algorithms as tensor operations
//! - **Inference**: Marginalization and conditional queries via reductions
//!
//! # Architecture
//!
//! ```text
//! TLExpr → FactorGraph → MessagePassing → Marginals
//!    ↓         ↓              ↓              ↓
//! Predicates Factors    Einsum Ops    Probabilities
//! ```
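//!
//! # Example
//!
//! A minimal end-to-end sketch of the pipeline above. The predicate and
//! factor-graph calls mirror the unit tests at the bottom of this file; the
//! `SumProductAlgorithm::default()` constructor is assumed here for
//! illustration (see the `message_passing` module for the actual API).
//!
//! ```ignore
//! use tensorlogic_ir::{TLExpr, Term};
//! use tensorlogic_quantrs_hooks::{
//!     expr_to_factor_graph, message_passing_reduce, SumProductAlgorithm,
//! };
//!
//! // Build a factor graph from a single unary predicate P(x).
//! let expr = TLExpr::pred("P", vec![Term::var("x")]);
//! let graph = expr_to_factor_graph(&expr).unwrap();
//!
//! // Run sum-product belief propagation; the result maps each variable name
//! // to its marginal as an ArrayD<f64>.
//! let algorithm = SumProductAlgorithm::default();
//! let marginals = message_passing_reduce(&graph, &algorithm).unwrap();
//! ```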

mod error;
mod expectation_propagation;
mod factor;
mod graph;
mod inference;
mod junction_tree;
mod linear_chain_crf;
mod message_passing;
mod models;
pub mod parameter_learning;
pub mod quantrs_hooks;
mod sampling;
mod variable_elimination;
mod variational;

pub use error::{PgmError, Result};
pub use expectation_propagation::{ExpectationPropagation, GaussianEP, GaussianSite, Site};
pub use factor::{Factor, FactorOp};
pub use graph::{FactorGraph, FactorNode, VariableNode};
pub use inference::{ConditionalQuery, InferenceEngine, MarginalizationQuery};
pub use junction_tree::{Clique, JunctionTree, JunctionTreeEdge, Separator};
pub use linear_chain_crf::{
    EmissionFeature, FeatureFunction, IdentityFeature, LinearChainCRF, TransitionFeature,
};
pub use message_passing::{
    ConvergenceStats, MaxProductAlgorithm, MessagePassingAlgorithm, SumProductAlgorithm,
};
pub use models::{BayesianNetwork, ConditionalRandomField, HiddenMarkovModel, MarkovRandomField};
pub use parameter_learning::{
    BaumWelchLearner, BayesianEstimator, MaximumLikelihoodEstimator, SimpleHMM,
};
pub use quantrs_hooks::{
    DistributionExport, DistributionMetadata, ModelExport, ModelStatistics, QuantRSAssignment,
    QuantRSDistribution, QuantRSInferenceQuery, QuantRSModelExport, QuantRSParameterLearning,
    QuantRSSamplingHook,
};
pub use sampling::{Assignment, GibbsSampler};
pub use variable_elimination::VariableElimination;
pub use variational::{BetheApproximation, MeanFieldInference, TreeReweightedBP};

use scirs2_core::ndarray::ArrayD;
use std::collections::HashMap;
use tensorlogic_ir::TLExpr;

/// Convert a TensorLogic expression to a factor graph.
///
/// This function analyzes the logical structure and creates a factor graph
/// where predicates become factors and quantified variables become nodes.
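///
/// # Example
///
/// A minimal sketch mirroring the unit tests below: a single unary predicate
/// `P(x)` produces a non-empty graph with a variable node for `x` and a factor
/// for `P`.
///
/// ```
/// use tensorlogic_ir::{TLExpr, Term};
/// use tensorlogic_quantrs_hooks::expr_to_factor_graph;
///
/// let expr = TLExpr::pred("P", vec![Term::var("x")]);
/// let graph = expr_to_factor_graph(&expr).unwrap();
/// assert!(!graph.is_empty());
/// ```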
pub fn expr_to_factor_graph(expr: &TLExpr) -> Result<FactorGraph> {
    let mut graph = FactorGraph::new();

    // Recursively extract factors from expression
    extract_factors(expr, &mut graph)?;

    Ok(graph)
}

/// Extract factors from a TLExpr and add them to the factor graph.
fn extract_factors(expr: &TLExpr, graph: &mut FactorGraph) -> Result<()> {
    match expr {
        TLExpr::Pred { name, args } => {
            // Create a factor from predicate
            let var_names: Vec<String> = args
                .iter()
                .filter_map(|term| match term {
                    tensorlogic_ir::Term::Var(v) => Some(v.clone()),
                    _ => None,
                })
                .collect();

            // Add variables if they don't exist
            for var_name in &var_names {
                if graph.get_variable(var_name).is_none() {
                    graph.add_variable(var_name.clone(), "default".to_string());
                }
            }

            if !var_names.is_empty() {
                graph.add_factor_from_predicate(name, &var_names)?;
            }
        }
        TLExpr::And(left, right) => {
            // Conjunction creates multiple factors
            extract_factors(left, graph)?;
            extract_factors(right, graph)?;
        }
        TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
            // Quantified variables become nodes in the factor graph
            graph.add_variable(var.clone(), domain.clone());
            extract_factors(body, graph)?;
        }
        TLExpr::Imply(premise, conclusion) => {
            // Implication can be represented as factors
            extract_factors(premise, graph)?;
            extract_factors(conclusion, graph)?;
        }
        TLExpr::Not(inner) => {
            // Negation affects factor values
            extract_factors(inner, graph)?;
        }
        _ => {
            // Other expressions may not directly map to factors
        }
    }

    Ok(())
}

/// Perform message passing inference on a factor graph.
///
/// This function runs belief propagation to compute marginal probabilities.
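///
/// # Example
///
/// A usage sketch: any [`MessagePassingAlgorithm`] implementation can be passed
/// in, e.g. sum-product for marginals or max-product for MAP-style beliefs. The
/// `::default()` constructors are assumed here for illustration; see the
/// `message_passing` module for the actual constructors.
///
/// ```ignore
/// use tensorlogic_quantrs_hooks::{
///     message_passing_reduce, MaxProductAlgorithm, SumProductAlgorithm,
/// };
///
/// // `graph` is a FactorGraph, e.g. built with `expr_to_factor_graph`.
/// let marginals = message_passing_reduce(&graph, &SumProductAlgorithm::default()).unwrap();
/// let max_beliefs = message_passing_reduce(&graph, &MaxProductAlgorithm::default()).unwrap();
/// ```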
pub fn message_passing_reduce(
    graph: &FactorGraph,
    algorithm: &dyn MessagePassingAlgorithm,
) -> Result<HashMap<String, ArrayD<f64>>> {
    algorithm.run(graph)
}

/// Compute marginal probability for a variable.
///
/// This maps to a reduction operation over all other variables.
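///
/// # Example
///
/// A small sketch mirroring the unit test below: a uniform 2x2 joint
/// distribution P(X, Y) is reduced to the one-dimensional marginal P(X) by
/// summing out axis 1.
///
/// ```
/// use scirs2_core::ndarray::Array;
/// use tensorlogic_quantrs_hooks::marginalize;
///
/// let joint = Array::from_shape_vec(vec![2, 2], vec![0.25, 0.25, 0.25, 0.25])
///     .unwrap()
///     .into_dyn();
///
/// // Keep axis 0 (X); axis 1 (Y) is summed out.
/// let marginal = marginalize(&joint, 0, &[0, 1]).unwrap();
/// assert_eq!(marginal.ndim(), 1);
/// ```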
pub fn marginalize(
    joint_distribution: &ArrayD<f64>,
    variable_idx: usize,
    axes_to_sum: &[usize],
) -> Result<ArrayD<f64>> {
    use scirs2_core::ndarray::Axis;

    let mut result = joint_distribution.clone();

    // Sum over all axes except the target variable. Axes are collapsed from
    // highest to lowest (assuming `axes_to_sum` is sorted ascending) so the
    // remaining axis indices stay valid after each reduction.
    for &axis in axes_to_sum.iter().rev() {
        if axis != variable_idx {
            result = result.sum_axis(Axis(axis));
        }
    }

    Ok(result)
}

/// Compute conditional probability P(X | Y = y).
///
/// This slices the joint distribution at the evidence values.
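///
/// # Example
///
/// A small sketch mirroring the unit test below: conditioning a 2x2 joint
/// distribution on Y = 1 yields a normalized one-dimensional distribution over X.
///
/// ```
/// use scirs2_core::ndarray::Array;
/// use std::collections::HashMap;
/// use tensorlogic_quantrs_hooks::condition;
///
/// let joint = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
///     .unwrap()
///     .into_dyn();
///
/// let mut evidence = HashMap::new();
/// evidence.insert(1, 1); // axis 1 (Y) observed at index 1
///
/// let conditional = condition(&joint, &evidence).unwrap();
/// assert_eq!(conditional.ndim(), 1);
/// ```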
pub fn condition(
    joint_distribution: &ArrayD<f64>,
    evidence: &HashMap<usize, usize>,
) -> Result<ArrayD<f64>> {
    let mut result = joint_distribution.clone();

    // Slice at evidence values. Removing an axis shifts the indices of all
    // higher axes, so slice from the highest axis down to keep the remaining
    // indices valid (HashMap iteration order is arbitrary).
    let mut sorted_evidence: Vec<(usize, usize)> = evidence
        .iter()
        .map(|(&var_idx, &value)| (var_idx, value))
        .collect();
    sorted_evidence.sort_by(|a, b| b.0.cmp(&a.0));

    for (var_idx, value) in sorted_evidence {
        result = result.index_axis_move(scirs2_core::ndarray::Axis(var_idx), value);
    }

    // Normalize so the conditional distribution sums to one
    let sum: f64 = result.iter().sum();
    if sum > 0.0 {
        result /= sum;
    }

    Ok(result)
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;
    use tensorlogic_ir::Term;

    #[test]
    fn test_expr_to_factor_graph() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        let graph = expr_to_factor_graph(&expr).unwrap();
        assert!(!graph.is_empty());
    }

    #[test]
    fn test_marginalize_simple() {
        // 2x2 joint distribution: P(X, Y)
        let joint = Array::from_shape_vec(vec![2, 2], vec![0.25, 0.25, 0.25, 0.25])
            .unwrap()
            .into_dyn();

        // Marginalize over Y (axis 1) to get P(X)
        let marginal = marginalize(&joint, 0, &[0, 1]).unwrap();

        assert_eq!(marginal.ndim(), 1);
        assert_abs_diff_eq!(marginal.sum(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_condition_simple() {
        // 2x2 joint distribution
        let joint = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
            .unwrap()
            .into_dyn();

        // Condition on Y=1: P(X | Y=1)
        let mut evidence = HashMap::new();
        evidence.insert(1, 1);

        let conditional = condition(&joint, &evidence).unwrap();

        // Should have one dimension less
        assert_eq!(conditional.ndim(), 1);
        // Should be normalized
        assert_abs_diff_eq!(conditional.sum(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_factor_graph_construction() {
        let mut graph = FactorGraph::new();
        graph.add_variable("x".to_string(), "Domain1".to_string());
        graph.add_variable("y".to_string(), "Domain2".to_string());

        assert_eq!(graph.num_variables(), 2);
    }
}