tensorlogic_quantrs_hooks/
lib.rs

1//! TL <-> QuantrS2 hooks (PGM/message passing as reductions).
2//!
3//! **Version**: 0.1.0-alpha.2 | **Status**: Production Ready
4//!
5//! This crate provides integration between TensorLogic and probabilistic graphical models (PGMs).
6//! It maps belief propagation and other message passing algorithms onto einsum reduction patterns.
7//!
8//! # Core Concepts
9//!
10//! - **Factor Graphs**: Convert TLExpr predicates into factors
11//! - **Message Passing**: Sum-product and max-product algorithms as tensor operations
12//! - **Inference**: Marginalization and conditional queries via reductions
13//!
14//! # Architecture
15//!
16//! ```text
17//! TLExpr → FactorGraph → MessagePassing → Marginals
18//!    ↓         ↓              ↓              ↓
19//! Predicates Factors    Einsum Ops    Probabilities
20//! ```
21
22mod cache;
23pub mod dbn;
24mod elimination_ordering;
25mod error;
26mod expectation_propagation;
27mod factor;
28mod graph;
29mod inference;
30pub mod influence;
31mod junction_tree;
32mod linear_chain_crf;
33pub mod memory;
34mod message_passing;
35mod models;
36mod parallel_message_passing;
37pub mod parameter_learning;
38pub mod quantrs_hooks;
39mod sampling;
40mod variable_elimination;
41mod variational;
42
43pub use cache::{CacheStats, CachedFactor, FactorCache};
44pub use dbn::{CoupledDBN, CouplingFactor, DBNBuilder, DynamicBayesianNetwork, TemporalVar};
45pub use elimination_ordering::{EliminationOrdering, EliminationStrategy};
46pub use error::{PgmError, Result};
47pub use expectation_propagation::{ExpectationPropagation, GaussianEP, GaussianSite, Site};
48pub use factor::{Factor, FactorOp};
49pub use graph::{FactorGraph, FactorNode, VariableNode};
50pub use inference::{ConditionalQuery, InferenceEngine, MarginalizationQuery};
51pub use influence::{
52    InfluenceDiagram, InfluenceDiagramBuilder, InfluenceNode, MultiAttributeUtility, NodeType,
53};
54pub use junction_tree::{Clique, JunctionTree, JunctionTreeEdge, Separator};
55pub use linear_chain_crf::{
56    EmissionFeature, FeatureFunction, IdentityFeature, LinearChainCRF, TransitionFeature,
57};
58pub use memory::{
59    BlockSparseFactor, CompressedFactor, FactorPool, LazyFactor, MemoryEstimate, PoolStats,
60    SparseFactor, StreamingFactorGraph,
61};
62pub use message_passing::{
63    ConvergenceStats, MaxProductAlgorithm, MessagePassingAlgorithm, SumProductAlgorithm,
64};
65pub use models::{BayesianNetwork, ConditionalRandomField, HiddenMarkovModel, MarkovRandomField};
66pub use parallel_message_passing::{ParallelMaxProduct, ParallelSumProduct};
67pub use parameter_learning::{
68    BaumWelchLearner, BayesianEstimator, MaximumLikelihoodEstimator, SimpleHMM,
69};
70pub use quantrs_hooks::{
71    DistributionExport, DistributionMetadata, ModelExport, ModelStatistics, QuantRSAssignment,
72    QuantRSDistribution, QuantRSInferenceQuery, QuantRSModelExport, QuantRSParameterLearning,
73    QuantRSSamplingHook,
74};
75pub use sampling::{
76    Assignment, GibbsSampler, ImportanceSampler, LikelihoodWeighting, Particle, ParticleFilter,
77    ProposalDistribution, WeightedSample,
78};
79pub use variable_elimination::VariableElimination;
80pub use variational::{BetheApproximation, MeanFieldInference, TreeReweightedBP};
81
82use scirs2_core::ndarray::ArrayD;
83use std::collections::HashMap;
84use tensorlogic_ir::TLExpr;
85
86/// Convert a TensorLogic expression to a factor graph.
87///
88/// This function analyzes the logical structure and creates a factor graph
89/// where predicates become factors and quantified variables become nodes.
90pub fn expr_to_factor_graph(expr: &TLExpr) -> Result<FactorGraph> {
91    let mut graph = FactorGraph::new();
92
93    // Recursively extract factors from expression
94    extract_factors(expr, &mut graph)?;
95
96    Ok(graph)
97}
98
99/// Extract factors from a TLExpr and add them to the factor graph.
100fn extract_factors(expr: &TLExpr, graph: &mut FactorGraph) -> Result<()> {
101    match expr {
102        TLExpr::Pred { name, args } => {
103            // Create a factor from predicate
104            let var_names: Vec<String> = args
105                .iter()
106                .filter_map(|term| match term {
107                    tensorlogic_ir::Term::Var(v) => Some(v.clone()),
108                    _ => None,
109                })
110                .collect();
111
112            // Add variables if they don't exist
113            for var_name in &var_names {
114                if graph.get_variable(var_name).is_none() {
115                    graph.add_variable(var_name.clone(), "default".to_string());
116                }
117            }
118
119            if !var_names.is_empty() {
120                graph.add_factor_from_predicate(name, &var_names)?;
121            }
122        }
123        TLExpr::And(left, right) => {
124            // Conjunction creates multiple factors
125            extract_factors(left, graph)?;
126            extract_factors(right, graph)?;
127        }
128        TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
129            // Quantified variables become nodes in the factor graph
130            graph.add_variable(var.clone(), domain.clone());
131            extract_factors(body, graph)?;
132        }
133        TLExpr::Imply(premise, conclusion) => {
134            // Implication can be represented as factors
135            extract_factors(premise, graph)?;
136            extract_factors(conclusion, graph)?;
137        }
138        TLExpr::Not(inner) => {
139            // Negation affects factor values
140            extract_factors(inner, graph)?;
141        }
142        _ => {
143            // Other expressions may not directly map to factors
144        }
145    }
146
147    Ok(())
148}
149
/// Perform message passing inference on a factor graph.
///
/// Delegates to the supplied belief-propagation algorithm (e.g. sum-product
/// or max-product) and returns its result: a map from variable name to the
/// variable's marginal tensor.
pub fn message_passing_reduce(
    graph: &FactorGraph,
    algorithm: &dyn MessagePassingAlgorithm,
) -> Result<HashMap<String, ArrayD<f64>>> {
    algorithm.run(graph)
}
159
160/// Compute marginal probability for a variable.
161///
162/// This maps to a reduction operation over all other variables.
163pub fn marginalize(
164    joint_distribution: &ArrayD<f64>,
165    variable_idx: usize,
166    axes_to_sum: &[usize],
167) -> Result<ArrayD<f64>> {
168    use scirs2_core::ndarray::Axis;
169
170    let mut result = joint_distribution.clone();
171
172    // Sum over all axes except the target variable
173    for &axis in axes_to_sum.iter().rev() {
174        if axis != variable_idx {
175            result = result.sum_axis(Axis(axis));
176        }
177    }
178
179    Ok(result)
180}
181
182/// Compute conditional probability P(X | Y = y).
183///
184/// This slices the joint distribution at the evidence values.
185pub fn condition(
186    joint_distribution: &ArrayD<f64>,
187    evidence: &HashMap<usize, usize>,
188) -> Result<ArrayD<f64>> {
189    let mut result = joint_distribution.clone();
190
191    // Slice at evidence values
192    for (&var_idx, &value) in evidence {
193        result = result.index_axis_move(scirs2_core::ndarray::Axis(var_idx), value);
194    }
195
196    // Normalize
197    let sum: f64 = result.iter().sum();
198    if sum > 0.0 {
199        result /= sum;
200    }
201
202    Ok(result)
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;
    use tensorlogic_ir::Term;

    #[test]
    fn test_expr_to_factor_graph() {
        // A single unary predicate must yield a non-empty graph.
        let graph = expr_to_factor_graph(&TLExpr::pred("P", vec![Term::var("x")])).unwrap();
        assert!(!graph.is_empty());
    }

    #[test]
    fn test_marginalize_simple() {
        // Uniform 2x2 joint P(X, Y); marginalizing over Y must leave P(X).
        let values = vec![0.25, 0.25, 0.25, 0.25];
        let joint = Array::from_shape_vec(vec![2, 2], values).unwrap().into_dyn();

        let marginal = marginalize(&joint, 0, &[0, 1]).unwrap();

        assert_eq!(marginal.ndim(), 1);
        assert_abs_diff_eq!(marginal.sum(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_condition_simple() {
        let joint = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
            .unwrap()
            .into_dyn();

        // P(X | Y = 1): one dimension removed, result renormalized.
        let evidence = HashMap::from([(1, 1)]);
        let conditional = condition(&joint, &evidence).unwrap();

        assert_eq!(conditional.ndim(), 1);
        assert_abs_diff_eq!(conditional.sum(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_factor_graph_construction() {
        let mut graph = FactorGraph::new();
        for (var, dom) in [("x", "Domain1"), ("y", "Domain2")] {
            graph.add_variable(var.to_string(), dom.to_string());
        }
        assert_eq!(graph.num_variables(), 2);
    }
}