// tensorlogic_quantrs_hooks/lib.rs
//! TL <-> QuantrS2 hooks (PGM/message passing as reductions).
//!
//! **Version**: 0.1.0-beta.1 | **Status**: Beta
//!
//! This crate provides integration between TensorLogic and probabilistic graphical models (PGMs).
//! It maps belief propagation and other message passing algorithms onto einsum reduction patterns.
//!
//! # Core Concepts
//!
//! - **Factor Graphs**: Convert TLExpr predicates into factors
//! - **Message Passing**: Sum-product and max-product algorithms as tensor operations
//! - **Inference**: Marginalization and conditional queries via reductions
//!
//! # Architecture
//!
//! ```text
//! TLExpr → FactorGraph → MessagePassing → Marginals
//!    ↓         ↓              ↓              ↓
//! Predicates Factors    Einsum Ops    Probabilities
//! ```
21
22mod cache;
23pub mod dbn;
24mod elimination_ordering;
25mod error;
26mod expectation_propagation;
27mod factor;
28mod graph;
29mod inference;
30pub mod influence;
31mod junction_tree;
32mod linear_chain_crf;
33pub mod memory;
34mod message_passing;
35mod models;
36mod parallel_message_passing;
37pub mod parameter_learning;
38pub mod quantrs_hooks;
39pub mod quantum_circuit;
40pub mod quantum_simulation;
41mod sampling;
42pub mod tensor_network_bridge;
43mod variable_elimination;
44mod variational;
45
46pub use cache::{CacheStats, CachedFactor, FactorCache};
47pub use dbn::{CoupledDBN, CouplingFactor, DBNBuilder, DynamicBayesianNetwork, TemporalVar};
48pub use elimination_ordering::{EliminationOrdering, EliminationStrategy};
49pub use error::{PgmError, Result};
50pub use expectation_propagation::{ExpectationPropagation, GaussianEP, GaussianSite, Site};
51pub use factor::{Factor, FactorOp};
52pub use graph::{FactorGraph, FactorNode, VariableNode};
53pub use inference::{ConditionalQuery, InferenceEngine, MarginalizationQuery};
54pub use influence::{
55    InfluenceDiagram, InfluenceDiagramBuilder, InfluenceNode, MultiAttributeUtility, NodeType,
56};
57pub use junction_tree::{Clique, JunctionTree, JunctionTreeEdge, Separator};
58pub use linear_chain_crf::{
59    EmissionFeature, FeatureFunction, IdentityFeature, LinearChainCRF, TransitionFeature,
60};
61pub use memory::{
62    BlockSparseFactor, CompressedFactor, FactorPool, LazyFactor, MemoryEstimate, PoolStats,
63    SparseFactor, StreamingFactorGraph,
64};
65pub use message_passing::{
66    ConvergenceStats, MaxProductAlgorithm, MessagePassingAlgorithm, SumProductAlgorithm,
67};
68pub use models::{BayesianNetwork, ConditionalRandomField, HiddenMarkovModel, MarkovRandomField};
69pub use parallel_message_passing::{ParallelMaxProduct, ParallelSumProduct};
70pub use parameter_learning::{
71    BaumWelchLearner, BayesianEstimator, MaximumLikelihoodEstimator, SimpleHMM,
72};
73pub use quantrs_hooks::{
74    AnnealingConfig, DistributionExport, DistributionMetadata, ModelExport, ModelStatistics,
75    QuantRSAssignment, QuantRSDistribution, QuantRSInferenceQuery, QuantRSModelExport,
76    QuantRSParameterLearning, QuantRSSamplingHook, QuantumAnnealing, QuantumInference,
77    QuantumSolution, QuantumSolutionMetadata,
78};
79pub use quantum_circuit::{
80    tlexpr_to_qaoa_circuit, IsingModel, QAOAConfig, QAOAResult, QUBOProblem, QuantumCircuitBuilder,
81};
82pub use quantum_simulation::{
83    run_qaoa, QuantumSimulationBackend, SimulatedState, SimulationConfig,
84};
85pub use sampling::{
86    Assignment, GibbsSampler, ImportanceSampler, LikelihoodWeighting, Particle, ParticleFilter,
87    ProposalDistribution, WeightedSample,
88};
89pub use tensor_network_bridge::{
90    factor_graph_to_tensor_network, linear_chain_to_mps, MatrixProductState, Tensor, TensorNetwork,
91    TensorNetworkStats,
92};
93pub use variable_elimination::VariableElimination;
94pub use variational::{BetheApproximation, MeanFieldInference, TreeReweightedBP};
95
96use scirs2_core::ndarray::ArrayD;
97use std::collections::HashMap;
98use tensorlogic_ir::TLExpr;
99
100/// Convert a TensorLogic expression to a factor graph.
101///
102/// This function analyzes the logical structure and creates a factor graph
103/// where predicates become factors and quantified variables become nodes.
104pub fn expr_to_factor_graph(expr: &TLExpr) -> Result<FactorGraph> {
105    let mut graph = FactorGraph::new();
106
107    // Recursively extract factors from expression
108    extract_factors(expr, &mut graph)?;
109
110    Ok(graph)
111}
112
113/// Extract factors from a TLExpr and add them to the factor graph.
114fn extract_factors(expr: &TLExpr, graph: &mut FactorGraph) -> Result<()> {
115    match expr {
116        TLExpr::Pred { name, args } => {
117            // Create a factor from predicate
118            let var_names: Vec<String> = args
119                .iter()
120                .filter_map(|term| match term {
121                    tensorlogic_ir::Term::Var(v) => Some(v.clone()),
122                    _ => None,
123                })
124                .collect();
125
126            // Add variables if they don't exist
127            for var_name in &var_names {
128                if graph.get_variable(var_name).is_none() {
129                    graph.add_variable(var_name.clone(), "default".to_string());
130                }
131            }
132
133            if !var_names.is_empty() {
134                graph.add_factor_from_predicate(name, &var_names)?;
135            }
136        }
137        TLExpr::And(left, right) => {
138            // Conjunction creates multiple factors
139            extract_factors(left, graph)?;
140            extract_factors(right, graph)?;
141        }
142        TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
143            // Quantified variables become nodes in the factor graph
144            graph.add_variable(var.clone(), domain.clone());
145            extract_factors(body, graph)?;
146        }
147        TLExpr::Imply(premise, conclusion) => {
148            // Implication can be represented as factors
149            extract_factors(premise, graph)?;
150            extract_factors(conclusion, graph)?;
151        }
152        TLExpr::Not(inner) => {
153            // Negation affects factor values
154            extract_factors(inner, graph)?;
155        }
156        _ => {
157            // Other expressions may not directly map to factors
158        }
159    }
160
161    Ok(())
162}
163
/// Perform message passing inference on a factor graph.
///
/// Delegates to the supplied [`MessagePassingAlgorithm`] implementation
/// (e.g. sum-product or max-product) and returns its per-variable result
/// tensors keyed by variable name.
///
/// # Errors
///
/// Propagates any error produced by the algorithm's `run` method.
pub fn message_passing_reduce(
    graph: &FactorGraph,
    algorithm: &dyn MessagePassingAlgorithm,
) -> Result<HashMap<String, ArrayD<f64>>> {
    algorithm.run(graph)
}
173
174/// Compute marginal probability for a variable.
175///
176/// This maps to a reduction operation over all other variables.
177pub fn marginalize(
178    joint_distribution: &ArrayD<f64>,
179    variable_idx: usize,
180    axes_to_sum: &[usize],
181) -> Result<ArrayD<f64>> {
182    use scirs2_core::ndarray::Axis;
183
184    let mut result = joint_distribution.clone();
185
186    // Sum over all axes except the target variable
187    for &axis in axes_to_sum.iter().rev() {
188        if axis != variable_idx {
189            result = result.sum_axis(Axis(axis));
190        }
191    }
192
193    Ok(result)
194}
195
196/// Compute conditional probability P(X | Y = y).
197///
198/// This slices the joint distribution at the evidence values.
199pub fn condition(
200    joint_distribution: &ArrayD<f64>,
201    evidence: &HashMap<usize, usize>,
202) -> Result<ArrayD<f64>> {
203    let mut result = joint_distribution.clone();
204
205    // Slice at evidence values
206    for (&var_idx, &value) in evidence {
207        result = result.index_axis_move(scirs2_core::ndarray::Axis(var_idx), value);
208    }
209
210    // Normalize
211    let sum: f64 = result.iter().sum();
212    if sum > 0.0 {
213        result /= sum;
214    }
215
216    Ok(result)
217}
218
219#[cfg(test)]
220mod tests {
221    use super::*;
222    use approx::assert_abs_diff_eq;
223    use scirs2_core::ndarray::Array;
224    use tensorlogic_ir::Term;
225
226    #[test]
227    fn test_expr_to_factor_graph() {
228        let expr = TLExpr::pred("P", vec![Term::var("x")]);
229        let graph = expr_to_factor_graph(&expr).unwrap();
230        assert!(!graph.is_empty());
231    }
232
233    #[test]
234    fn test_marginalize_simple() {
235        // 2x2 joint distribution: P(X, Y)
236        let joint = Array::from_shape_vec(vec![2, 2], vec![0.25, 0.25, 0.25, 0.25])
237            .unwrap()
238            .into_dyn();
239
240        // Marginalize over Y (axis 1) to get P(X)
241        let marginal = marginalize(&joint, 0, &[0, 1]).unwrap();
242
243        assert_eq!(marginal.ndim(), 1);
244        assert_abs_diff_eq!(marginal.sum(), 1.0, epsilon = 1e-10);
245    }
246
247    #[test]
248    fn test_condition_simple() {
249        // 2x2 joint distribution
250        let joint = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
251            .unwrap()
252            .into_dyn();
253
254        // Condition on Y=1: P(X | Y=1)
255        let mut evidence = HashMap::new();
256        evidence.insert(1, 1);
257
258        let conditional = condition(&joint, &evidence).unwrap();
259
260        // Should have one dimension less
261        assert_eq!(conditional.ndim(), 1);
262        // Should be normalized
263        assert_abs_diff_eq!(conditional.sum(), 1.0, epsilon = 1e-10);
264    }
265
266    #[test]
267    fn test_factor_graph_construction() {
268        let mut graph = FactorGraph::new();
269        graph.add_variable("x".to_string(), "Domain1".to_string());
270        graph.add_variable("y".to_string(), "Domain2".to_string());
271
272        assert_eq!(graph.num_variables(), 2);
273    }
274}