Skip to main content

tensorlogic_quantrs_hooks/
lib.rs

//! TL <-> QuantRS2 hooks (PGM/message passing as reductions).
2//!
3//! **Version**: 0.1.0 | **Status**: Production Ready
4//!
5//! This crate provides integration between TensorLogic and probabilistic graphical models (PGMs).
6//! It maps belief propagation and other message passing algorithms onto einsum reduction patterns.
7//!
8//! # Core Concepts
9//!
10//! - **Factor Graphs**: Convert TLExpr predicates into factors
11//! - **Message Passing**: Sum-product and max-product algorithms as tensor operations
12//! - **Inference**: Marginalization and conditional queries via reductions
13//!
14//! # Architecture
15//!
16//! ```text
17//! TLExpr → FactorGraph → MessagePassing → Marginals
18//!    ↓         ↓              ↓              ↓
19//! Predicates Factors    Einsum Ops    Probabilities
20//! ```
21
22mod cache;
23pub mod convergence;
24pub mod dbn;
25mod elimination_ordering;
26mod error;
27mod expectation_propagation;
28mod factor;
29pub mod factor_graph_viz;
30mod graph;
31mod inference;
32pub mod influence;
33mod junction_tree;
34mod linear_chain_crf;
35pub mod loopy_bp;
36pub mod memory;
37mod message_passing;
38mod models;
39mod parallel_message_passing;
40pub mod parameter_learning;
41pub mod quantrs_hooks;
42pub mod quantum_circuit;
43pub mod quantum_simulation;
44mod sampling;
45pub mod tensor_network_bridge;
46mod variable_elimination;
47mod variational;
48pub mod vmp;
49
50pub use cache::{CacheStats, CachedFactor, FactorCache};
51pub use convergence::{
52    ConvergenceConfig, ConvergenceError, ConvergenceMonitor, ConvergenceState, DampingSchedule,
53    InferenceStats,
54};
55pub use dbn::{CoupledDBN, CouplingFactor, DBNBuilder, DynamicBayesianNetwork, TemporalVar};
56pub use elimination_ordering::{EliminationOrdering, EliminationStrategy};
57pub use error::{PgmError, Result};
58pub use expectation_propagation::{ExpectationPropagation, GaussianEP, GaussianSite, Site};
59pub use factor::{Factor, FactorOp};
60pub use factor_graph_viz::{
61    render_ascii, render_dot, FactorGraphModel, FactorGraphStats, VizFactorNode, VizVariableNode,
62};
63pub use graph::{FactorGraph, FactorNode, VariableNode};
64pub use inference::{ConditionalQuery, InferenceEngine, MarginalizationQuery};
65pub use influence::{
66    InfluenceDiagram, InfluenceDiagramBuilder, InfluenceNode, MultiAttributeUtility, NodeType,
67};
68pub use junction_tree::{Clique, JunctionTree, JunctionTreeEdge, Separator};
69pub use linear_chain_crf::{
70    EmissionFeature, FeatureFunction, IdentityFeature, LinearChainCRF, TransitionFeature,
71};
72pub use loopy_bp::{
73    bethe_free_energy, BetheFreeEnergy, CycleAnalysis, CycleDetector, LbpConvergenceMonitor,
74    LbpDampingPolicy, LbpIterStats, LogMessage, LoopyBeliefPropagation, LoopyBpConfig,
75    LoopyBpResult, UpdateSchedule,
76};
77pub use memory::{
78    BlockSparseFactor, CompressedFactor, FactorPool, LazyFactor, MemoryEstimate, PoolStats,
79    SparseFactor, StreamingFactorGraph,
80};
81pub use message_passing::{
82    ConvergenceStats, MaxProductAlgorithm, MessagePassingAlgorithm, SumProductAlgorithm,
83};
84pub use models::{BayesianNetwork, ConditionalRandomField, HiddenMarkovModel, MarkovRandomField};
85pub use parallel_message_passing::{ParallelMaxProduct, ParallelSumProduct};
86pub use parameter_learning::{
87    BaumWelchLearner, BayesianEstimator, MaximumLikelihoodEstimator, SimpleHMM,
88};
89pub use quantrs_hooks::{
90    AnnealingConfig, DistributionExport, DistributionMetadata, ModelExport, ModelStatistics,
91    QuantRSAssignment, QuantRSDistribution, QuantRSInferenceQuery, QuantRSModelExport,
92    QuantRSParameterLearning, QuantRSSamplingHook, QuantumAnnealing, QuantumInference,
93    QuantumSolution, QuantumSolutionMetadata,
94};
95pub use quantum_circuit::{
96    tlexpr_to_qaoa_circuit, IsingModel, QAOAConfig, QAOAResult, QUBOProblem, QuantumCircuitBuilder,
97};
98pub use quantum_simulation::{
99    run_qaoa, QuantumSimulationBackend, SimulatedState, SimulationConfig,
100};
101pub use sampling::{
102    Assignment, GibbsSampler, ImportanceSampler, LikelihoodWeighting, Particle, ParticleFilter,
103    ProposalDistribution, WeightedSample,
104};
105pub use tensor_network_bridge::{
106    factor_graph_to_tensor_network, linear_chain_to_mps, MatrixProductState, Tensor, TensorNetwork,
107    TensorNetworkStats,
108};
109pub use variable_elimination::VariableElimination;
110pub use variational::{BetheApproximation, MeanFieldInference, TreeReweightedBP};
111pub use vmp::{
112    categorical_kl, dirichlet_kl, gaussian_kl, gaussian_kl_fixed_precision,
113    BetaBernoulliObservation, BetaNP, CategoricalNP, DirichletNP, ExponentialFamily, Family,
114    GammaNP, GammaPoissonObservation, GaussianNP, MessageDirection, VariationalMessagePassing,
115    VariationalState, VmpConfig, VmpFactor, VmpMessage, VmpResult,
116};
117
118use scirs2_core::ndarray::ArrayD;
119use std::collections::HashMap;
120use tensorlogic_ir::TLExpr;
121
122/// Convert a TensorLogic expression to a factor graph.
123///
124/// This function analyzes the logical structure and creates a factor graph
125/// where predicates become factors and quantified variables become nodes.
126pub fn expr_to_factor_graph(expr: &TLExpr) -> Result<FactorGraph> {
127    let mut graph = FactorGraph::new();
128
129    // Recursively extract factors from expression
130    extract_factors(expr, &mut graph)?;
131
132    Ok(graph)
133}
134
135/// Extract factors from a TLExpr and add them to the factor graph.
136fn extract_factors(expr: &TLExpr, graph: &mut FactorGraph) -> Result<()> {
137    match expr {
138        TLExpr::Pred { name, args } => {
139            // Create a factor from predicate
140            let var_names: Vec<String> = args
141                .iter()
142                .filter_map(|term| match term {
143                    tensorlogic_ir::Term::Var(v) => Some(v.clone()),
144                    _ => None,
145                })
146                .collect();
147
148            // Add variables if they don't exist
149            for var_name in &var_names {
150                if graph.get_variable(var_name).is_none() {
151                    graph.add_variable(var_name.clone(), "default".to_string());
152                }
153            }
154
155            if !var_names.is_empty() {
156                graph.add_factor_from_predicate(name, &var_names)?;
157            }
158        }
159        TLExpr::And(left, right) => {
160            // Conjunction creates multiple factors
161            extract_factors(left, graph)?;
162            extract_factors(right, graph)?;
163        }
164        TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
165            // Quantified variables become nodes in the factor graph
166            graph.add_variable(var.clone(), domain.clone());
167            extract_factors(body, graph)?;
168        }
169        TLExpr::Imply(premise, conclusion) => {
170            // Implication can be represented as factors
171            extract_factors(premise, graph)?;
172            extract_factors(conclusion, graph)?;
173        }
174        TLExpr::Not(inner) => {
175            // Negation affects factor values
176            extract_factors(inner, graph)?;
177        }
178        _ => {
179            // Other expressions may not directly map to factors
180        }
181    }
182
183    Ok(())
184}
185
186/// Perform message passing inference on a factor graph.
187///
188/// This function runs belief propagation to compute marginal probabilities.
189pub fn message_passing_reduce(
190    graph: &FactorGraph,
191    algorithm: &dyn MessagePassingAlgorithm,
192) -> Result<HashMap<String, ArrayD<f64>>> {
193    algorithm.run(graph)
194}
195
196/// Compute marginal probability for a variable.
197///
198/// This maps to a reduction operation over all other variables.
199pub fn marginalize(
200    joint_distribution: &ArrayD<f64>,
201    variable_idx: usize,
202    axes_to_sum: &[usize],
203) -> Result<ArrayD<f64>> {
204    use scirs2_core::ndarray::Axis;
205
206    let mut result = joint_distribution.clone();
207
208    // Sum over all axes except the target variable
209    for &axis in axes_to_sum.iter().rev() {
210        if axis != variable_idx {
211            result = result.sum_axis(Axis(axis));
212        }
213    }
214
215    Ok(result)
216}
217
218/// Compute conditional probability P(X | Y = y).
219///
220/// This slices the joint distribution at the evidence values.
221pub fn condition(
222    joint_distribution: &ArrayD<f64>,
223    evidence: &HashMap<usize, usize>,
224) -> Result<ArrayD<f64>> {
225    let mut result = joint_distribution.clone();
226
227    // Slice at evidence values
228    for (&var_idx, &value) in evidence {
229        result = result.index_axis_move(scirs2_core::ndarray::Axis(var_idx), value);
230    }
231
232    // Normalize
233    let sum: f64 = result.iter().sum();
234    if sum > 0.0 {
235        result /= sum;
236    }
237
238    Ok(result)
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;
    use tensorlogic_ir::Term;

    /// Build a dynamically-shaped 2x2 table from four row-major cell values.
    fn table_2x2(cells: [f64; 4]) -> ArrayD<f64> {
        Array::from_shape_vec(vec![2, 2], cells.to_vec())
            .expect("unwrap")
            .into_dyn()
    }

    #[test]
    fn test_expr_to_factor_graph() {
        let single_pred = TLExpr::pred("P", vec![Term::var("x")]);
        let built = expr_to_factor_graph(&single_pred).expect("unwrap");
        assert!(!built.is_empty());
    }

    #[test]
    fn test_marginalize_simple() {
        // Uniform P(X, Y); keep axis 0 (X) and sum out axis 1 (Y).
        let uniform = table_2x2([0.25; 4]);
        let p_x = marginalize(&uniform, 0, &[0, 1]).expect("unwrap");

        assert_eq!(p_x.ndim(), 1);
        assert_abs_diff_eq!(p_x.sum(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_condition_simple() {
        // Observe Y = 1 (axis 1) to obtain P(X | Y = 1).
        let table = table_2x2([0.1, 0.2, 0.3, 0.4]);
        let evidence: HashMap<usize, usize> = HashMap::from([(1, 1)]);

        let p_x_given_y = condition(&table, &evidence).expect("unwrap");

        // One axis is consumed by the evidence, and the slice is renormalized.
        assert_eq!(p_x_given_y.ndim(), 1);
        assert_abs_diff_eq!(p_x_given_y.sum(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_factor_graph_construction() {
        let mut graph = FactorGraph::new();
        for &(var, domain) in &[("x", "Domain1"), ("y", "Domain2")] {
            graph.add_variable(var.to_string(), domain.to_string());
        }

        assert_eq!(graph.num_variables(), 2);
    }
}