1mod cache;
23pub mod convergence;
24pub mod dbn;
25mod elimination_ordering;
26mod error;
27mod expectation_propagation;
28mod factor;
29pub mod factor_graph_viz;
30mod graph;
31mod inference;
32pub mod influence;
33mod junction_tree;
34mod linear_chain_crf;
35pub mod loopy_bp;
36pub mod memory;
37mod message_passing;
38mod models;
39mod parallel_message_passing;
40pub mod parameter_learning;
41pub mod quantrs_hooks;
42pub mod quantum_circuit;
43pub mod quantum_simulation;
44mod sampling;
45pub mod tensor_network_bridge;
46mod variable_elimination;
47mod variational;
48pub mod vmp;
49
50pub use cache::{CacheStats, CachedFactor, FactorCache};
51pub use convergence::{
52 ConvergenceConfig, ConvergenceError, ConvergenceMonitor, ConvergenceState, DampingSchedule,
53 InferenceStats,
54};
55pub use dbn::{CoupledDBN, CouplingFactor, DBNBuilder, DynamicBayesianNetwork, TemporalVar};
56pub use elimination_ordering::{EliminationOrdering, EliminationStrategy};
57pub use error::{PgmError, Result};
58pub use expectation_propagation::{ExpectationPropagation, GaussianEP, GaussianSite, Site};
59pub use factor::{Factor, FactorOp};
60pub use factor_graph_viz::{
61 render_ascii, render_dot, FactorGraphModel, FactorGraphStats, VizFactorNode, VizVariableNode,
62};
63pub use graph::{FactorGraph, FactorNode, VariableNode};
64pub use inference::{ConditionalQuery, InferenceEngine, MarginalizationQuery};
65pub use influence::{
66 InfluenceDiagram, InfluenceDiagramBuilder, InfluenceNode, MultiAttributeUtility, NodeType,
67};
68pub use junction_tree::{Clique, JunctionTree, JunctionTreeEdge, Separator};
69pub use linear_chain_crf::{
70 EmissionFeature, FeatureFunction, IdentityFeature, LinearChainCRF, TransitionFeature,
71};
72pub use loopy_bp::{
73 bethe_free_energy, BetheFreeEnergy, CycleAnalysis, CycleDetector, LbpConvergenceMonitor,
74 LbpDampingPolicy, LbpIterStats, LogMessage, LoopyBeliefPropagation, LoopyBpConfig,
75 LoopyBpResult, UpdateSchedule,
76};
77pub use memory::{
78 BlockSparseFactor, CompressedFactor, FactorPool, LazyFactor, MemoryEstimate, PoolStats,
79 SparseFactor, StreamingFactorGraph,
80};
81pub use message_passing::{
82 ConvergenceStats, MaxProductAlgorithm, MessagePassingAlgorithm, SumProductAlgorithm,
83};
84pub use models::{BayesianNetwork, ConditionalRandomField, HiddenMarkovModel, MarkovRandomField};
85pub use parallel_message_passing::{ParallelMaxProduct, ParallelSumProduct};
86pub use parameter_learning::{
87 BaumWelchLearner, BayesianEstimator, MaximumLikelihoodEstimator, SimpleHMM,
88};
89pub use quantrs_hooks::{
90 AnnealingConfig, DistributionExport, DistributionMetadata, ModelExport, ModelStatistics,
91 QuantRSAssignment, QuantRSDistribution, QuantRSInferenceQuery, QuantRSModelExport,
92 QuantRSParameterLearning, QuantRSSamplingHook, QuantumAnnealing, QuantumInference,
93 QuantumSolution, QuantumSolutionMetadata,
94};
95pub use quantum_circuit::{
96 tlexpr_to_qaoa_circuit, IsingModel, QAOAConfig, QAOAResult, QUBOProblem, QuantumCircuitBuilder,
97};
98pub use quantum_simulation::{
99 run_qaoa, QuantumSimulationBackend, SimulatedState, SimulationConfig,
100};
101pub use sampling::{
102 Assignment, GibbsSampler, ImportanceSampler, LikelihoodWeighting, Particle, ParticleFilter,
103 ProposalDistribution, WeightedSample,
104};
105pub use tensor_network_bridge::{
106 factor_graph_to_tensor_network, linear_chain_to_mps, MatrixProductState, Tensor, TensorNetwork,
107 TensorNetworkStats,
108};
109pub use variable_elimination::VariableElimination;
110pub use variational::{BetheApproximation, MeanFieldInference, TreeReweightedBP};
111pub use vmp::{
112 categorical_kl, dirichlet_kl, gaussian_kl, gaussian_kl_fixed_precision,
113 BetaBernoulliObservation, BetaNP, CategoricalNP, DirichletNP, ExponentialFamily, Family,
114 GammaNP, GammaPoissonObservation, GaussianNP, MessageDirection, VariationalMessagePassing,
115 VariationalState, VmpConfig, VmpFactor, VmpMessage, VmpResult,
116};
117
118use scirs2_core::ndarray::ArrayD;
119use std::collections::HashMap;
120use tensorlogic_ir::TLExpr;
121
122pub fn expr_to_factor_graph(expr: &TLExpr) -> Result<FactorGraph> {
127 let mut graph = FactorGraph::new();
128
129 extract_factors(expr, &mut graph)?;
131
132 Ok(graph)
133}
134
135fn extract_factors(expr: &TLExpr, graph: &mut FactorGraph) -> Result<()> {
137 match expr {
138 TLExpr::Pred { name, args } => {
139 let var_names: Vec<String> = args
141 .iter()
142 .filter_map(|term| match term {
143 tensorlogic_ir::Term::Var(v) => Some(v.clone()),
144 _ => None,
145 })
146 .collect();
147
148 for var_name in &var_names {
150 if graph.get_variable(var_name).is_none() {
151 graph.add_variable(var_name.clone(), "default".to_string());
152 }
153 }
154
155 if !var_names.is_empty() {
156 graph.add_factor_from_predicate(name, &var_names)?;
157 }
158 }
159 TLExpr::And(left, right) => {
160 extract_factors(left, graph)?;
162 extract_factors(right, graph)?;
163 }
164 TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
165 graph.add_variable(var.clone(), domain.clone());
167 extract_factors(body, graph)?;
168 }
169 TLExpr::Imply(premise, conclusion) => {
170 extract_factors(premise, graph)?;
172 extract_factors(conclusion, graph)?;
173 }
174 TLExpr::Not(inner) => {
175 extract_factors(inner, graph)?;
177 }
178 _ => {
179 }
181 }
182
183 Ok(())
184}
185
186pub fn message_passing_reduce(
190 graph: &FactorGraph,
191 algorithm: &dyn MessagePassingAlgorithm,
192) -> Result<HashMap<String, ArrayD<f64>>> {
193 algorithm.run(graph)
194}
195
196pub fn marginalize(
200 joint_distribution: &ArrayD<f64>,
201 variable_idx: usize,
202 axes_to_sum: &[usize],
203) -> Result<ArrayD<f64>> {
204 use scirs2_core::ndarray::Axis;
205
206 let mut result = joint_distribution.clone();
207
208 for &axis in axes_to_sum.iter().rev() {
210 if axis != variable_idx {
211 result = result.sum_axis(Axis(axis));
212 }
213 }
214
215 Ok(result)
216}
217
218pub fn condition(
222 joint_distribution: &ArrayD<f64>,
223 evidence: &HashMap<usize, usize>,
224) -> Result<ArrayD<f64>> {
225 let mut result = joint_distribution.clone();
226
227 for (&var_idx, &value) in evidence {
229 result = result.index_axis_move(scirs2_core::ndarray::Axis(var_idx), value);
230 }
231
232 let sum: f64 = result.iter().sum();
234 if sum > 0.0 {
235 result /= sum;
236 }
237
238 Ok(result)
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;
    use tensorlogic_ir::Term;

    /// Asserts that `actual`, read in logical (row-major) order, matches `expected` element-wise.
    fn assert_values(actual: &ArrayD<f64>, expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (&got, &want) in actual.iter().zip(expected) {
            assert_abs_diff_eq!(got, want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_expr_to_factor_graph() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        let graph = expr_to_factor_graph(&expr).expect("unwrap");
        assert!(!graph.is_empty());
    }

    #[test]
    fn test_marginalize_simple() {
        let joint = Array::from_shape_vec(vec![2, 2], vec![0.25, 0.25, 0.25, 0.25])
            .expect("unwrap")
            .into_dyn();

        let marginal = marginalize(&joint, 0, &[0, 1]).expect("unwrap");

        assert_eq!(marginal.ndim(), 1);
        assert_abs_diff_eq!(marginal.sum(), 1.0, epsilon = 1e-10);
        assert_values(&marginal, &[0.5, 0.5]);
    }

    #[test]
    fn test_marginalize_keeps_requested_axis() {
        // Rows are axis 0: [[0.1, 0.2], [0.3, 0.4]]. Keeping axis 0 sums each row.
        let joint = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
            .expect("unwrap")
            .into_dyn();

        let marginal = marginalize(&joint, 0, &[0, 1]).expect("unwrap");

        assert_eq!(marginal.shape(), &[2]);
        assert_values(&marginal, &[0.3, 0.7]);
    }

    #[test]
    fn test_condition_simple() {
        let joint = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
            .expect("unwrap")
            .into_dyn();

        let mut evidence = HashMap::new();
        evidence.insert(1, 1);

        let conditional = condition(&joint, &evidence).expect("unwrap");

        assert_eq!(conditional.ndim(), 1);
        assert_abs_diff_eq!(conditional.sum(), 1.0, epsilon = 1e-10);
        // Column 1 is [0.2, 0.4], renormalized by 0.6.
        assert_values(&conditional, &[1.0 / 3.0, 2.0 / 3.0]);
    }

    #[test]
    fn test_condition_empty_evidence_normalizes() {
        let joint = Array::from_shape_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0])
            .expect("unwrap")
            .into_dyn();

        let conditional = condition(&joint, &HashMap::new()).expect("unwrap");

        assert_eq!(conditional.ndim(), 2);
        assert_values(&conditional, &[0.1, 0.2, 0.3, 0.4]);
    }

    #[test]
    fn test_condition_zero_mass_is_not_nan() {
        let joint = Array::from_shape_vec(vec![2, 2], vec![0.0, 0.0, 0.0, 0.0])
            .expect("unwrap")
            .into_dyn();

        let mut evidence = HashMap::new();
        evidence.insert(0, 0);

        let conditional = condition(&joint, &evidence).expect("unwrap");

        assert_values(&conditional, &[0.0, 0.0]);
    }

    #[test]
    fn test_factor_graph_construction() {
        let mut graph = FactorGraph::new();
        graph.add_variable("x".to_string(), "Domain1".to_string());
        graph.add_variable("y".to_string(), "Domain2".to_string());

        assert_eq!(graph.num_variables(), 2);
    }
}