tensorlogic_quantrs_hooks/
lib.rs1mod cache;
23pub mod dbn;
24mod elimination_ordering;
25mod error;
26mod expectation_propagation;
27mod factor;
28mod graph;
29mod inference;
30pub mod influence;
31mod junction_tree;
32mod linear_chain_crf;
33pub mod memory;
34mod message_passing;
35mod models;
36mod parallel_message_passing;
37pub mod parameter_learning;
38pub mod quantrs_hooks;
39mod sampling;
40mod variable_elimination;
41mod variational;
42
43pub use cache::{CacheStats, CachedFactor, FactorCache};
44pub use dbn::{CoupledDBN, CouplingFactor, DBNBuilder, DynamicBayesianNetwork, TemporalVar};
45pub use elimination_ordering::{EliminationOrdering, EliminationStrategy};
46pub use error::{PgmError, Result};
47pub use expectation_propagation::{ExpectationPropagation, GaussianEP, GaussianSite, Site};
48pub use factor::{Factor, FactorOp};
49pub use graph::{FactorGraph, FactorNode, VariableNode};
50pub use inference::{ConditionalQuery, InferenceEngine, MarginalizationQuery};
51pub use influence::{
52 InfluenceDiagram, InfluenceDiagramBuilder, InfluenceNode, MultiAttributeUtility, NodeType,
53};
54pub use junction_tree::{Clique, JunctionTree, JunctionTreeEdge, Separator};
55pub use linear_chain_crf::{
56 EmissionFeature, FeatureFunction, IdentityFeature, LinearChainCRF, TransitionFeature,
57};
58pub use memory::{
59 BlockSparseFactor, CompressedFactor, FactorPool, LazyFactor, MemoryEstimate, PoolStats,
60 SparseFactor, StreamingFactorGraph,
61};
62pub use message_passing::{
63 ConvergenceStats, MaxProductAlgorithm, MessagePassingAlgorithm, SumProductAlgorithm,
64};
65pub use models::{BayesianNetwork, ConditionalRandomField, HiddenMarkovModel, MarkovRandomField};
66pub use parallel_message_passing::{ParallelMaxProduct, ParallelSumProduct};
67pub use parameter_learning::{
68 BaumWelchLearner, BayesianEstimator, MaximumLikelihoodEstimator, SimpleHMM,
69};
70pub use quantrs_hooks::{
71 DistributionExport, DistributionMetadata, ModelExport, ModelStatistics, QuantRSAssignment,
72 QuantRSDistribution, QuantRSInferenceQuery, QuantRSModelExport, QuantRSParameterLearning,
73 QuantRSSamplingHook,
74};
75pub use sampling::{
76 Assignment, GibbsSampler, ImportanceSampler, LikelihoodWeighting, Particle, ParticleFilter,
77 ProposalDistribution, WeightedSample,
78};
79pub use variable_elimination::VariableElimination;
80pub use variational::{BetheApproximation, MeanFieldInference, TreeReweightedBP};
81
82use scirs2_core::ndarray::ArrayD;
83use std::collections::HashMap;
84use tensorlogic_ir::TLExpr;
85
86pub fn expr_to_factor_graph(expr: &TLExpr) -> Result<FactorGraph> {
91 let mut graph = FactorGraph::new();
92
93 extract_factors(expr, &mut graph)?;
95
96 Ok(graph)
97}
98
99fn extract_factors(expr: &TLExpr, graph: &mut FactorGraph) -> Result<()> {
101 match expr {
102 TLExpr::Pred { name, args } => {
103 let var_names: Vec<String> = args
105 .iter()
106 .filter_map(|term| match term {
107 tensorlogic_ir::Term::Var(v) => Some(v.clone()),
108 _ => None,
109 })
110 .collect();
111
112 for var_name in &var_names {
114 if graph.get_variable(var_name).is_none() {
115 graph.add_variable(var_name.clone(), "default".to_string());
116 }
117 }
118
119 if !var_names.is_empty() {
120 graph.add_factor_from_predicate(name, &var_names)?;
121 }
122 }
123 TLExpr::And(left, right) => {
124 extract_factors(left, graph)?;
126 extract_factors(right, graph)?;
127 }
128 TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
129 graph.add_variable(var.clone(), domain.clone());
131 extract_factors(body, graph)?;
132 }
133 TLExpr::Imply(premise, conclusion) => {
134 extract_factors(premise, graph)?;
136 extract_factors(conclusion, graph)?;
137 }
138 TLExpr::Not(inner) => {
139 extract_factors(inner, graph)?;
141 }
142 _ => {
143 }
145 }
146
147 Ok(())
148}
149
150pub fn message_passing_reduce(
154 graph: &FactorGraph,
155 algorithm: &dyn MessagePassingAlgorithm,
156) -> Result<HashMap<String, ArrayD<f64>>> {
157 algorithm.run(graph)
158}
159
160pub fn marginalize(
164 joint_distribution: &ArrayD<f64>,
165 variable_idx: usize,
166 axes_to_sum: &[usize],
167) -> Result<ArrayD<f64>> {
168 use scirs2_core::ndarray::Axis;
169
170 let mut result = joint_distribution.clone();
171
172 for &axis in axes_to_sum.iter().rev() {
174 if axis != variable_idx {
175 result = result.sum_axis(Axis(axis));
176 }
177 }
178
179 Ok(result)
180}
181
182pub fn condition(
186 joint_distribution: &ArrayD<f64>,
187 evidence: &HashMap<usize, usize>,
188) -> Result<ArrayD<f64>> {
189 let mut result = joint_distribution.clone();
190
191 for (&var_idx, &value) in evidence {
193 result = result.index_axis_move(scirs2_core::ndarray::Axis(var_idx), value);
194 }
195
196 let sum: f64 = result.iter().sum();
198 if sum > 0.0 {
199 result /= sum;
200 }
201
202 Ok(result)
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;
    use tensorlogic_ir::Term;

    /// Builds a dynamic-rank array from a shape and row-major values.
    fn dyn_array(shape: &[usize], values: Vec<f64>) -> ArrayD<f64> {
        Array::from_shape_vec(shape.to_vec(), values)
            .unwrap()
            .into_dyn()
    }

    #[test]
    fn test_expr_to_factor_graph() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        let graph = expr_to_factor_graph(&expr).unwrap();
        assert!(!graph.is_empty());
    }

    #[test]
    fn test_marginalize_simple() {
        let joint = dyn_array(&[2, 2], vec![0.25, 0.25, 0.25, 0.25]);

        let marginal = marginalize(&joint, 0, &[0, 1]).unwrap();

        assert_eq!(marginal.ndim(), 1);
        assert_abs_diff_eq!(marginal.sum(), 1.0, epsilon = 1e-10);
        // Axis 0 is kept, so each entry is a row sum of the uniform table.
        for &p in marginal.iter() {
            assert_abs_diff_eq!(p, 0.5, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_marginalize_keeps_middle_axis() {
        // Row-major values 0..8 over a 2x2x2 table; keeping axis 1 sums over axes 0 and 2:
        // j = 0 -> 0 + 1 + 4 + 5 = 10, j = 1 -> 2 + 3 + 6 + 7 = 18.
        let joint = dyn_array(&[2, 2, 2], vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);

        let marginal = marginalize(&joint, 1, &[0, 1, 2]).unwrap();

        assert_eq!(marginal.ndim(), 1);
        let values: Vec<f64> = marginal.iter().copied().collect();
        assert_eq!(values.len(), 2);
        assert_abs_diff_eq!(values[0], 10.0, epsilon = 1e-10);
        assert_abs_diff_eq!(values[1], 18.0, epsilon = 1e-10);
    }

    #[test]
    fn test_condition_simple() {
        let joint = dyn_array(&[2, 2], vec![0.1, 0.2, 0.3, 0.4]);

        let mut evidence = HashMap::new();
        evidence.insert(1, 1);

        let conditional = condition(&joint, &evidence).unwrap();

        assert_eq!(conditional.ndim(), 1);
        assert_abs_diff_eq!(conditional.sum(), 1.0, epsilon = 1e-10);
        // Fixing axis 1 to index 1 selects column [0.2, 0.4], which renormalizes to [1/3, 2/3].
        let values: Vec<f64> = conditional.iter().copied().collect();
        assert_eq!(values.len(), 2);
        assert_abs_diff_eq!(values[0], 1.0 / 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(values[1], 2.0 / 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_factor_graph_construction() {
        let mut graph = FactorGraph::new();
        graph.add_variable("x".to_string(), "Domain1".to_string());
        graph.add_variable("y".to_string(), "Domain2".to_string());

        assert_eq!(graph.num_variables(), 2);
    }
}