tensorlogic_quantrs_hooks/
lib.rs1mod error;
21mod expectation_propagation;
22mod factor;
23mod graph;
24mod inference;
25mod junction_tree;
26mod linear_chain_crf;
27mod message_passing;
28mod models;
29pub mod parameter_learning;
30pub mod quantrs_hooks;
31mod sampling;
32mod variable_elimination;
33mod variational;
34
35pub use error::{PgmError, Result};
36pub use expectation_propagation::{ExpectationPropagation, GaussianEP, GaussianSite, Site};
37pub use factor::{Factor, FactorOp};
38pub use graph::{FactorGraph, FactorNode, VariableNode};
39pub use inference::{ConditionalQuery, InferenceEngine, MarginalizationQuery};
40pub use junction_tree::{Clique, JunctionTree, JunctionTreeEdge, Separator};
41pub use linear_chain_crf::{
42 EmissionFeature, FeatureFunction, IdentityFeature, LinearChainCRF, TransitionFeature,
43};
44pub use message_passing::{
45 ConvergenceStats, MaxProductAlgorithm, MessagePassingAlgorithm, SumProductAlgorithm,
46};
47pub use models::{BayesianNetwork, ConditionalRandomField, HiddenMarkovModel, MarkovRandomField};
48pub use parameter_learning::{
49 BaumWelchLearner, BayesianEstimator, MaximumLikelihoodEstimator, SimpleHMM,
50};
51pub use quantrs_hooks::{
52 DistributionExport, DistributionMetadata, ModelExport, ModelStatistics, QuantRSAssignment,
53 QuantRSDistribution, QuantRSInferenceQuery, QuantRSModelExport, QuantRSParameterLearning,
54 QuantRSSamplingHook,
55};
56pub use sampling::{Assignment, GibbsSampler};
57pub use variable_elimination::VariableElimination;
58pub use variational::{BetheApproximation, MeanFieldInference, TreeReweightedBP};
59
60use scirs2_core::ndarray::ArrayD;
61use std::collections::HashMap;
62use tensorlogic_ir::TLExpr;
63
64pub fn expr_to_factor_graph(expr: &TLExpr) -> Result<FactorGraph> {
69 let mut graph = FactorGraph::new();
70
71 extract_factors(expr, &mut graph)?;
73
74 Ok(graph)
75}
76
77fn extract_factors(expr: &TLExpr, graph: &mut FactorGraph) -> Result<()> {
79 match expr {
80 TLExpr::Pred { name, args } => {
81 let var_names: Vec<String> = args
83 .iter()
84 .filter_map(|term| match term {
85 tensorlogic_ir::Term::Var(v) => Some(v.clone()),
86 _ => None,
87 })
88 .collect();
89
90 for var_name in &var_names {
92 if graph.get_variable(var_name).is_none() {
93 graph.add_variable(var_name.clone(), "default".to_string());
94 }
95 }
96
97 if !var_names.is_empty() {
98 graph.add_factor_from_predicate(name, &var_names)?;
99 }
100 }
101 TLExpr::And(left, right) => {
102 extract_factors(left, graph)?;
104 extract_factors(right, graph)?;
105 }
106 TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
107 graph.add_variable(var.clone(), domain.clone());
109 extract_factors(body, graph)?;
110 }
111 TLExpr::Imply(premise, conclusion) => {
112 extract_factors(premise, graph)?;
114 extract_factors(conclusion, graph)?;
115 }
116 TLExpr::Not(inner) => {
117 extract_factors(inner, graph)?;
119 }
120 _ => {
121 }
123 }
124
125 Ok(())
126}
127
128pub fn message_passing_reduce(
132 graph: &FactorGraph,
133 algorithm: &dyn MessagePassingAlgorithm,
134) -> Result<HashMap<String, ArrayD<f64>>> {
135 algorithm.run(graph)
136}
137
138pub fn marginalize(
142 joint_distribution: &ArrayD<f64>,
143 variable_idx: usize,
144 axes_to_sum: &[usize],
145) -> Result<ArrayD<f64>> {
146 use scirs2_core::ndarray::Axis;
147
148 let mut result = joint_distribution.clone();
149
150 for &axis in axes_to_sum.iter().rev() {
152 if axis != variable_idx {
153 result = result.sum_axis(Axis(axis));
154 }
155 }
156
157 Ok(result)
158}
159
160pub fn condition(
164 joint_distribution: &ArrayD<f64>,
165 evidence: &HashMap<usize, usize>,
166) -> Result<ArrayD<f64>> {
167 let mut result = joint_distribution.clone();
168
169 for (&var_idx, &value) in evidence {
171 result = result.index_axis_move(scirs2_core::ndarray::Axis(var_idx), value);
172 }
173
174 let sum: f64 = result.iter().sum();
176 if sum > 0.0 {
177 result /= sum;
178 }
179
180 Ok(result)
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;
    use tensorlogic_ir::Term;

    #[test]
    fn test_expr_to_factor_graph() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        let graph = expr_to_factor_graph(&expr).unwrap();
        assert!(!graph.is_empty());
    }

    #[test]
    fn test_marginalize_simple() {
        let joint = Array::from_shape_vec(vec![2, 2], vec![0.25, 0.25, 0.25, 0.25])
            .unwrap()
            .into_dyn();

        // Axis 0 is the kept variable, so only axis 1 is summed out.
        let marginal = marginalize(&joint, 0, &[0, 1]).unwrap();

        assert_eq!(marginal.ndim(), 1);
        assert_abs_diff_eq!(marginal.sum(), 1.0, epsilon = 1e-10);
        for &p in marginal.iter() {
            assert_abs_diff_eq!(p, 0.5, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_marginalize_non_uniform() {
        // Row-major layout: [[0.1, 0.2], [0.3, 0.4]].
        let joint = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
            .unwrap()
            .into_dyn();

        // Row marginal: sum over columns.
        let rows: Vec<f64> = marginalize(&joint, 0, &[1]).unwrap().iter().copied().collect();
        assert_eq!(rows.len(), 2);
        assert_abs_diff_eq!(rows[0], 0.3, epsilon = 1e-10);
        assert_abs_diff_eq!(rows[1], 0.7, epsilon = 1e-10);

        // Column marginal: sum over rows.
        let cols: Vec<f64> = marginalize(&joint, 1, &[0]).unwrap().iter().copied().collect();
        assert_eq!(cols.len(), 2);
        assert_abs_diff_eq!(cols[0], 0.4, epsilon = 1e-10);
        assert_abs_diff_eq!(cols[1], 0.6, epsilon = 1e-10);
    }

    #[test]
    fn test_condition_simple() {
        // Row-major layout: [[0.1, 0.2], [0.3, 0.4]].
        let joint = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
            .unwrap()
            .into_dyn();

        // Observe axis 1 at index 1, i.e. select column [0.2, 0.4].
        let mut evidence = HashMap::new();
        evidence.insert(1, 1);

        let conditional = condition(&joint, &evidence).unwrap();

        assert_eq!(conditional.ndim(), 1);
        assert_abs_diff_eq!(conditional.sum(), 1.0, epsilon = 1e-10);

        // [0.2, 0.4] normalized by 0.6 gives [1/3, 2/3].
        let values: Vec<f64> = conditional.iter().copied().collect();
        assert_abs_diff_eq!(values[0], 1.0 / 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(values[1], 2.0 / 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_factor_graph_construction() {
        let mut graph = FactorGraph::new();
        graph.add_variable("x".to_string(), "Domain1".to_string());
        graph.add_variable("y".to_string(), "Domain2".to_string());

        assert_eq!(graph.num_variables(), 2);
    }
}