tensorlogic_quantrs_hooks/
lib.rs1mod cache;
23pub mod dbn;
24mod elimination_ordering;
25mod error;
26mod expectation_propagation;
27mod factor;
28mod graph;
29mod inference;
30pub mod influence;
31mod junction_tree;
32mod linear_chain_crf;
33pub mod memory;
34mod message_passing;
35mod models;
36mod parallel_message_passing;
37pub mod parameter_learning;
38pub mod quantrs_hooks;
39pub mod quantum_circuit;
40pub mod quantum_simulation;
41mod sampling;
42pub mod tensor_network_bridge;
43mod variable_elimination;
44mod variational;
45
46pub use cache::{CacheStats, CachedFactor, FactorCache};
47pub use dbn::{CoupledDBN, CouplingFactor, DBNBuilder, DynamicBayesianNetwork, TemporalVar};
48pub use elimination_ordering::{EliminationOrdering, EliminationStrategy};
49pub use error::{PgmError, Result};
50pub use expectation_propagation::{ExpectationPropagation, GaussianEP, GaussianSite, Site};
51pub use factor::{Factor, FactorOp};
52pub use graph::{FactorGraph, FactorNode, VariableNode};
53pub use inference::{ConditionalQuery, InferenceEngine, MarginalizationQuery};
54pub use influence::{
55 InfluenceDiagram, InfluenceDiagramBuilder, InfluenceNode, MultiAttributeUtility, NodeType,
56};
57pub use junction_tree::{Clique, JunctionTree, JunctionTreeEdge, Separator};
58pub use linear_chain_crf::{
59 EmissionFeature, FeatureFunction, IdentityFeature, LinearChainCRF, TransitionFeature,
60};
61pub use memory::{
62 BlockSparseFactor, CompressedFactor, FactorPool, LazyFactor, MemoryEstimate, PoolStats,
63 SparseFactor, StreamingFactorGraph,
64};
65pub use message_passing::{
66 ConvergenceStats, MaxProductAlgorithm, MessagePassingAlgorithm, SumProductAlgorithm,
67};
68pub use models::{BayesianNetwork, ConditionalRandomField, HiddenMarkovModel, MarkovRandomField};
69pub use parallel_message_passing::{ParallelMaxProduct, ParallelSumProduct};
70pub use parameter_learning::{
71 BaumWelchLearner, BayesianEstimator, MaximumLikelihoodEstimator, SimpleHMM,
72};
73pub use quantrs_hooks::{
74 AnnealingConfig, DistributionExport, DistributionMetadata, ModelExport, ModelStatistics,
75 QuantRSAssignment, QuantRSDistribution, QuantRSInferenceQuery, QuantRSModelExport,
76 QuantRSParameterLearning, QuantRSSamplingHook, QuantumAnnealing, QuantumInference,
77 QuantumSolution, QuantumSolutionMetadata,
78};
79pub use quantum_circuit::{
80 tlexpr_to_qaoa_circuit, IsingModel, QAOAConfig, QAOAResult, QUBOProblem, QuantumCircuitBuilder,
81};
82pub use quantum_simulation::{
83 run_qaoa, QuantumSimulationBackend, SimulatedState, SimulationConfig,
84};
85pub use sampling::{
86 Assignment, GibbsSampler, ImportanceSampler, LikelihoodWeighting, Particle, ParticleFilter,
87 ProposalDistribution, WeightedSample,
88};
89pub use tensor_network_bridge::{
90 factor_graph_to_tensor_network, linear_chain_to_mps, MatrixProductState, Tensor, TensorNetwork,
91 TensorNetworkStats,
92};
93pub use variable_elimination::VariableElimination;
94pub use variational::{BetheApproximation, MeanFieldInference, TreeReweightedBP};
95
96use scirs2_core::ndarray::ArrayD;
97use std::collections::HashMap;
98use tensorlogic_ir::TLExpr;
99
100pub fn expr_to_factor_graph(expr: &TLExpr) -> Result<FactorGraph> {
105 let mut graph = FactorGraph::new();
106
107 extract_factors(expr, &mut graph)?;
109
110 Ok(graph)
111}
112
113fn extract_factors(expr: &TLExpr, graph: &mut FactorGraph) -> Result<()> {
115 match expr {
116 TLExpr::Pred { name, args } => {
117 let var_names: Vec<String> = args
119 .iter()
120 .filter_map(|term| match term {
121 tensorlogic_ir::Term::Var(v) => Some(v.clone()),
122 _ => None,
123 })
124 .collect();
125
126 for var_name in &var_names {
128 if graph.get_variable(var_name).is_none() {
129 graph.add_variable(var_name.clone(), "default".to_string());
130 }
131 }
132
133 if !var_names.is_empty() {
134 graph.add_factor_from_predicate(name, &var_names)?;
135 }
136 }
137 TLExpr::And(left, right) => {
138 extract_factors(left, graph)?;
140 extract_factors(right, graph)?;
141 }
142 TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
143 graph.add_variable(var.clone(), domain.clone());
145 extract_factors(body, graph)?;
146 }
147 TLExpr::Imply(premise, conclusion) => {
148 extract_factors(premise, graph)?;
150 extract_factors(conclusion, graph)?;
151 }
152 TLExpr::Not(inner) => {
153 extract_factors(inner, graph)?;
155 }
156 _ => {
157 }
159 }
160
161 Ok(())
162}
163
164pub fn message_passing_reduce(
168 graph: &FactorGraph,
169 algorithm: &dyn MessagePassingAlgorithm,
170) -> Result<HashMap<String, ArrayD<f64>>> {
171 algorithm.run(graph)
172}
173
174pub fn marginalize(
178 joint_distribution: &ArrayD<f64>,
179 variable_idx: usize,
180 axes_to_sum: &[usize],
181) -> Result<ArrayD<f64>> {
182 use scirs2_core::ndarray::Axis;
183
184 let mut result = joint_distribution.clone();
185
186 for &axis in axes_to_sum.iter().rev() {
188 if axis != variable_idx {
189 result = result.sum_axis(Axis(axis));
190 }
191 }
192
193 Ok(result)
194}
195
196pub fn condition(
200 joint_distribution: &ArrayD<f64>,
201 evidence: &HashMap<usize, usize>,
202) -> Result<ArrayD<f64>> {
203 let mut result = joint_distribution.clone();
204
205 for (&var_idx, &value) in evidence {
207 result = result.index_axis_move(scirs2_core::ndarray::Axis(var_idx), value);
208 }
209
210 let sum: f64 = result.iter().sum();
212 if sum > 0.0 {
213 result /= sum;
214 }
215
216 Ok(result)
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;
    use tensorlogic_ir::Term;

    /// 2x2 joint P(A, B) with distinct entries. Summing over the wrong axis
    /// therefore yields a visibly different marginal, which a uniform joint
    /// would hide.
    fn skewed_joint() -> ArrayD<f64> {
        Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
            .unwrap()
            .into_dyn()
    }

    /// Assert that `actual`, read in logical order, matches `expected` element-wise.
    fn assert_values(actual: &ArrayD<f64>, expected: &[f64]) {
        let values: Vec<f64> = actual.iter().copied().collect();
        assert_eq!(values.len(), expected.len());
        for (got, want) in values.iter().zip(expected) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_expr_to_factor_graph() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        let graph = expr_to_factor_graph(&expr).unwrap();
        assert!(!graph.is_empty());
        // The single variable term `x` is registered exactly once.
        assert_eq!(graph.num_variables(), 1);
    }

    #[test]
    fn test_marginalize_simple() {
        let joint = skewed_joint();

        let marginal = marginalize(&joint, 0, &[0, 1]).unwrap();

        assert_eq!(marginal.ndim(), 1);
        assert_abs_diff_eq!(marginal.sum(), 1.0, epsilon = 1e-10);
        // P(A) is the row sums.
        assert_values(&marginal, &[0.3, 0.7]);
    }

    #[test]
    fn test_marginalize_keeps_requested_axis() {
        let joint = skewed_joint();

        let marginal = marginalize(&joint, 1, &[0, 1]).unwrap();

        assert_eq!(marginal.ndim(), 1);
        // P(B) is the column sums.
        assert_values(&marginal, &[0.4, 0.6]);
    }

    #[test]
    fn test_condition_simple() {
        let joint = skewed_joint();

        let mut evidence = HashMap::new();
        evidence.insert(1, 1);

        let conditional = condition(&joint, &evidence).unwrap();

        assert_eq!(conditional.ndim(), 1);
        assert_abs_diff_eq!(conditional.sum(), 1.0, epsilon = 1e-10);
        // P(A | B = 1) = [0.2, 0.4] / 0.6
        assert_values(&conditional, &[1.0 / 3.0, 2.0 / 3.0]);
    }

    #[test]
    fn test_condition_on_first_axis() {
        let joint = skewed_joint();

        let mut evidence = HashMap::new();
        evidence.insert(0, 1);

        let conditional = condition(&joint, &evidence).unwrap();

        assert_eq!(conditional.ndim(), 1);
        // P(B | A = 1) = [0.3, 0.4] / 0.7
        assert_values(&conditional, &[3.0 / 7.0, 4.0 / 7.0]);
    }

    #[test]
    fn test_factor_graph_construction() {
        let mut graph = FactorGraph::new();
        graph.add_variable("x".to_string(), "Domain1".to_string());
        graph.add_variable("y".to_string(), "Domain2".to_string());

        assert_eq!(graph.num_variables(), 2);
    }
}