1use crate::error::{PgmError, Result};
26use crate::factor::Factor;
27use crate::graph::FactorGraph;
28use scirs2_core::ndarray::ArrayD;
29use serde::{Deserialize, Serialize};
30use std::collections::HashMap;
31
32pub trait QuantRSDistribution {
36 fn to_quantrs_distribution(&self) -> Result<DistributionExport>;
43
44 fn from_quantrs_distribution(dist: &DistributionExport) -> Result<Self>
54 where
55 Self: Sized;
56
57 fn is_normalized(&self) -> bool;
59
60 fn support(&self) -> Vec<Vec<usize>>;
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct DistributionExport {
69 pub variables: Vec<String>,
71 pub cardinalities: Vec<usize>,
73 pub probabilities: Vec<f64>,
75 pub shape: Vec<usize>,
77 pub metadata: DistributionMetadata,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct DistributionMetadata {
84 pub distribution_type: String,
86 pub normalized: bool,
88 pub parameter_names: Vec<String>,
90 pub tags: Vec<String>,
92}
93
94impl QuantRSDistribution for Factor {
95 fn to_quantrs_distribution(&self) -> Result<DistributionExport> {
96 let cardinalities: Vec<usize> = self.values.shape().to_vec();
98
99 let probabilities: Vec<f64> = self.values.iter().copied().collect();
101
102 let sum: f64 = probabilities.iter().sum();
104 let normalized = (sum - 1.0).abs() < 1e-6;
105
106 Ok(DistributionExport {
107 variables: self.variables.clone(),
108 cardinalities,
109 probabilities,
110 shape: self.values.shape().to_vec(),
111 metadata: DistributionMetadata {
112 distribution_type: "categorical".to_string(),
113 normalized,
114 parameter_names: vec![],
115 tags: vec!["pgm".to_string(), "factor".to_string()],
116 },
117 })
118 }
119
120 fn from_quantrs_distribution(dist: &DistributionExport) -> Result<Self> {
121 let array = ArrayD::from_shape_vec(dist.shape.clone(), dist.probabilities.clone())
122 .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))?;
123
124 Factor::new("quantrs_import".to_string(), dist.variables.clone(), array)
125 }
126
127 fn is_normalized(&self) -> bool {
128 let sum: f64 = self.values.iter().sum();
129 (sum - 1.0).abs() < 1e-6
130 }
131
132 fn support(&self) -> Vec<Vec<usize>> {
133 let shape = self.values.shape();
134 let mut support = Vec::new();
135
136 fn generate_indices(shape: &[usize], current: Vec<usize>, support: &mut Vec<Vec<usize>>) {
137 if current.len() == shape.len() {
138 support.push(current);
139 return;
140 }
141
142 let dim = current.len();
143 for i in 0..shape[dim] {
144 let mut next = current.clone();
145 next.push(i);
146 generate_indices(shape, next, support);
147 }
148 }
149
150 generate_indices(shape, vec![], &mut support);
151 support
152 }
153}
154
155pub trait QuantRSModelExport {
157 fn to_quantrs_model(&self) -> Result<ModelExport>;
159
160 fn model_stats(&self) -> ModelStatistics;
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct ModelExport {
167 pub model_type: String,
169 pub variables: Vec<VariableDefinition>,
171 pub factors: Vec<FactorDefinition>,
173 pub structure: ModelStructure,
175 pub metadata: ModelMetadata,
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct VariableDefinition {
182 pub name: String,
184 pub domain: String,
186 pub cardinality: usize,
188 pub domain_values: Option<Vec<String>>,
190}
191
192#[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct FactorDefinition {
195 pub name: String,
197 pub scope: Vec<String>,
199 pub distribution: DistributionExport,
201}
202
203#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct ModelStructure {
206 pub structure_type: String,
208 pub edges: Vec<(String, String)>,
210 pub cliques: Vec<Vec<String>>,
212}
213
214#[derive(Debug, Clone, Serialize, Deserialize)]
216pub struct ModelMetadata {
217 pub name: String,
219 pub description: String,
221 pub created_at: String,
223 pub tags: Vec<String>,
225}
226
227#[derive(Debug, Clone)]
229pub struct ModelStatistics {
230 pub num_variables: usize,
232 pub num_factors: usize,
234 pub avg_factor_size: f64,
236 pub max_factor_size: usize,
238 pub treewidth: Option<usize>,
240}
241
242impl QuantRSModelExport for FactorGraph {
243 fn to_quantrs_model(&self) -> Result<ModelExport> {
244 let variables: Vec<VariableDefinition> = self
246 .variables()
247 .map(|(name, var)| VariableDefinition {
248 name: name.clone(),
249 domain: var.domain.clone(),
250 cardinality: var.cardinality,
251 domain_values: None,
252 })
253 .collect();
254
255 let factors: Vec<FactorDefinition> = self
257 .factors()
258 .map(|factor| {
259 Ok(FactorDefinition {
260 name: factor.name.clone(),
261 scope: factor.variables.clone(),
262 distribution: factor.to_quantrs_distribution()?,
263 })
264 })
265 .collect::<Result<Vec<_>>>()?;
266
267 let edges = Vec::new();
269 let mut cliques = Vec::new();
270
271 for factor in self.factors() {
272 if factor.variables.len() > 1 {
273 cliques.push(factor.variables.clone());
274 }
275 }
276
277 Ok(ModelExport {
278 model_type: "factor_graph".to_string(),
279 variables,
280 factors,
281 structure: ModelStructure {
282 structure_type: "undirected".to_string(),
283 edges,
284 cliques,
285 },
286 metadata: ModelMetadata {
287 name: "Exported FactorGraph".to_string(),
288 description: "Factor graph exported from tensorlogic-quantrs-hooks".to_string(),
289 created_at: chrono::Utc::now().to_rfc3339(),
290 tags: vec!["pgm".to_string(), "factor_graph".to_string()],
291 },
292 })
293 }
294
295 fn model_stats(&self) -> ModelStatistics {
296 let num_variables = self.num_variables();
297 let num_factors = self.num_factors();
298
299 let avg_factor_size = if num_factors > 0 {
300 self.factors().map(|f| f.variables.len()).sum::<usize>() as f64 / num_factors as f64
301 } else {
302 0.0
303 };
304
305 let max_factor_size = self.factors().map(|f| f.variables.len()).max().unwrap_or(0);
306
307 ModelStatistics {
308 num_variables,
309 num_factors,
310 avg_factor_size,
311 max_factor_size,
312 treewidth: None,
313 }
314 }
315}
316
317pub trait QuantRSInferenceQuery {
319 fn query_marginal_quantrs(&self, variable: &str) -> Result<DistributionExport>;
321
322 fn query_conditional_quantrs(
324 &self,
325 variable: &str,
326 evidence: &HashMap<String, usize>,
327 ) -> Result<DistributionExport>;
328
329 fn query_map_quantrs(&self) -> Result<HashMap<String, usize>>;
331}
332
333pub trait QuantRSParameterLearning {
337 fn learn_parameters_ml(&mut self, data: &[QuantRSAssignment]) -> Result<()>;
339
340 fn learn_parameters_bayesian(
342 &mut self,
343 data: &[QuantRSAssignment],
344 priors: &HashMap<String, ArrayD<f64>>,
345 ) -> Result<()>;
346
347 fn get_parameters(&self) -> Result<Vec<DistributionExport>>;
349
350 fn set_parameters(&mut self, params: &[DistributionExport]) -> Result<()>;
352}
353
354#[derive(Debug, Clone, Serialize, Deserialize)]
356pub struct QuantRSAssignment {
357 pub assignments: HashMap<String, usize>,
359}
360
361impl QuantRSAssignment {
362 pub fn new(assignments: HashMap<String, usize>) -> Self {
364 Self { assignments }
365 }
366
367 pub fn get(&self, variable: &str) -> Option<usize> {
369 self.assignments.get(variable).copied()
370 }
371
372 pub fn from_hashmap(assignments: HashMap<String, usize>) -> Self {
374 Self { assignments }
375 }
376
377 pub fn to_hashmap(&self) -> HashMap<String, usize> {
379 self.assignments.clone()
380 }
381}
382
383pub trait QuantRSSamplingHook {
385 fn sample_quantrs(&self, num_samples: usize) -> Result<Vec<QuantRSAssignment>>;
387
388 fn log_likelihood(&self, assignment: &QuantRSAssignment) -> Result<f64>;
390
391 fn unnormalized_probability(&self, assignment: &QuantRSAssignment) -> Result<f64>;
393}
394
395#[derive(Debug, Clone, Serialize, Deserialize)]
401pub struct AnnealingConfig {
402 pub num_steps: usize,
404 pub annealing_time: f64,
406 pub num_samples: usize,
408 pub initial_temperature: f64,
410 pub final_temperature: f64,
412}
413
414impl Default for AnnealingConfig {
415 fn default() -> Self {
416 Self {
417 num_steps: 100,
418 annealing_time: 10.0,
419 num_samples: 100,
420 initial_temperature: 10.0,
421 final_temperature: 0.01,
422 }
423 }
424}
425
426impl AnnealingConfig {
427 pub fn new(num_steps: usize, annealing_time: f64) -> Self {
429 Self {
430 num_steps,
431 annealing_time,
432 ..Default::default()
433 }
434 }
435
436 pub fn with_samples(mut self, num_samples: usize) -> Self {
438 self.num_samples = num_samples;
439 self
440 }
441
442 pub fn with_temperature(mut self, initial: f64, final_temp: f64) -> Self {
444 self.initial_temperature = initial;
445 self.final_temperature = final_temp;
446 self
447 }
448}
449
450#[derive(Debug, Clone, Serialize, Deserialize)]
452pub struct QuantumSolution {
453 pub assignments: HashMap<String, usize>,
455 pub objective_value: f64,
457 pub quality: f64,
459 pub iterations: usize,
461 pub metadata: QuantumSolutionMetadata,
463}
464
465#[derive(Debug, Clone, Serialize, Deserialize)]
467pub struct QuantumSolutionMetadata {
468 pub algorithm: String,
470 pub num_layers: Option<usize>,
472 pub optimal_params: Option<Vec<f64>>,
474 pub time_seconds: Option<f64>,
476}
477
478impl QuantumSolution {
479 pub fn new(assignments: HashMap<String, usize>, objective_value: f64, algorithm: &str) -> Self {
481 Self {
482 assignments,
483 objective_value,
484 quality: objective_value.abs(),
485 iterations: 1,
486 metadata: QuantumSolutionMetadata {
487 algorithm: algorithm.to_string(),
488 num_layers: None,
489 optimal_params: None,
490 time_seconds: None,
491 },
492 }
493 }
494
495 pub fn get(&self, variable: &str) -> Option<usize> {
497 self.assignments.get(variable).copied()
498 }
499}
500
501pub trait QuantumInference {
521 fn solve_qaoa(&self, num_layers: usize) -> Result<HashMap<String, usize>>;
535
536 fn quantum_marginals(&self, num_shots: usize) -> Result<HashMap<String, ArrayD<f64>>>;
549
550 fn quantum_partition_function(&self) -> Result<f64>;
555}
556
557pub trait QuantumAnnealing {
579 fn to_qubo(&self) -> Result<crate::quantum_circuit::QUBOProblem>;
583
584 fn anneal(&self, config: &AnnealingConfig) -> Result<QuantumSolution>;
594
595 fn anneal_multiple(&self, config: &AnnealingConfig, num_runs: usize)
606 -> Result<QuantumSolution>;
607}
608
609impl QuantumInference for FactorGraph {
611 fn solve_qaoa(&self, num_layers: usize) -> Result<HashMap<String, usize>> {
612 use crate::quantum_circuit::{factor_graph_to_qubo, QAOAConfig};
613 use crate::quantum_simulation::{run_qaoa, QuantumSimulationBackend};
614
615 let qubo = factor_graph_to_qubo(self)?;
616 let config = QAOAConfig::new(num_layers);
617 let backend = QuantumSimulationBackend::new();
618 let result = run_qaoa(&qubo, &config, &backend)?;
619
620 let var_names: Vec<String> = self.variable_names().cloned().collect();
622 let mut assignments: HashMap<String, usize> = HashMap::new();
623
624 let solution: &Vec<usize> = &result.best_solution;
625 for (idx, &value) in solution.iter().enumerate() {
626 if idx < var_names.len() {
627 let var_name: &String = &var_names[idx];
628 assignments.insert(var_name.clone(), value);
629 }
630 }
631
632 Ok(assignments)
633 }
634
635 fn quantum_marginals(&self, num_shots: usize) -> Result<HashMap<String, ArrayD<f64>>> {
636 use crate::quantum_simulation::{QuantumSimulationBackend, SimulationConfig};
637
638 let config = SimulationConfig::with_shots(num_shots);
640 let backend = QuantumSimulationBackend::with_config(config);
641 let samples = backend.quantum_sample(self, num_shots)?;
642
643 let mut counts: HashMap<String, Vec<usize>> = HashMap::new();
645 let var_names: Vec<String> = self.variable_names().cloned().collect();
646
647 for var in &var_names {
648 counts.insert(var.clone(), vec![0, 0]); }
650
651 for sample in &samples {
652 for (var, &value) in sample {
653 if let Some(count) = counts.get_mut(var) {
654 if value < count.len() {
655 count[value] += 1;
656 }
657 }
658 }
659 }
660
661 let mut marginals: HashMap<String, ArrayD<f64>> = HashMap::new();
663 let total = samples.len() as f64;
664
665 for (var, count_vec) in counts {
666 let probs: Vec<f64> = count_vec.iter().map(|&c| c as f64 / total).collect();
667 let shape = vec![probs.len()];
668 let arrd = ArrayD::from_shape_vec(shape, probs)
669 .map_err(|e| PgmError::InvalidDistribution(format!("Reshape failed: {}", e)))?;
670 marginals.insert(var, arrd);
671 }
672
673 Ok(marginals)
674 }
675
676 fn quantum_partition_function(&self) -> Result<f64> {
677 let mut z = 0.0;
680 let var_names: Vec<String> = self.variable_names().cloned().collect();
681 let cardinalities: Vec<usize> = var_names
682 .iter()
683 .filter_map(|name| self.get_variable(name).map(|v| v.cardinality))
684 .collect();
685
686 let total_configs: usize = cardinalities.iter().product();
687
688 for config_idx in 0..total_configs {
689 let mut assignment = HashMap::new();
690 let mut temp = config_idx;
691
692 for (i, &card) in cardinalities.iter().enumerate().rev() {
693 assignment.insert(var_names[i].clone(), temp % card);
694 temp /= card;
695 }
696
697 let mut prob = 1.0;
699 for factor in self.factors() {
700 let mut indices = Vec::new();
701 for var in &factor.variables {
702 if let Some(&val) = assignment.get(var) {
703 indices.push(val);
704 }
705 }
706 if !indices.is_empty() {
707 prob *= factor.values[indices.as_slice()];
708 }
709 }
710
711 z += prob;
712 }
713
714 Ok(z)
715 }
716}
717
718impl QuantumAnnealing for FactorGraph {
720 fn to_qubo(&self) -> Result<crate::quantum_circuit::QUBOProblem> {
721 crate::quantum_circuit::factor_graph_to_qubo(self)
722 }
723
724 fn anneal(&self, config: &AnnealingConfig) -> Result<QuantumSolution> {
725 use scirs2_core::random::thread_rng;
728
729 let qubo = self.to_qubo()?;
730 let num_vars = qubo.num_variables;
731 let var_names: Vec<String> = self.variable_names().cloned().collect();
732
733 let mut rng = thread_rng();
735 let mut best_solution: Vec<usize> = (0..num_vars)
736 .map(|_| if rng.random::<f64>() < 0.5 { 0 } else { 1 })
737 .collect();
738
739 let compute_value = |sol: &[usize]| -> f64 {
741 let mut val = qubo.offset;
742 for i in 0..num_vars {
743 val += qubo.linear[i] * sol[i] as f64;
744 for j in (i + 1)..num_vars {
745 val += qubo.quadratic[[i, j]] * (sol[i] * sol[j]) as f64;
746 }
747 }
748 val
749 };
750
751 let mut best_value = compute_value(&best_solution);
752
753 let mut current = best_solution.clone();
755 let mut current_value = best_value;
756
757 for step in 0..config.num_steps {
758 let temp = config.annealing_time * (1.0 - step as f64 / config.num_steps as f64);
759
760 let flip_idx = (rng.random::<f64>() * num_vars as f64) as usize % num_vars;
762 current[flip_idx] = 1 - current[flip_idx];
763
764 let new_value = compute_value(¤t);
765 let delta = new_value - current_value;
766
767 if delta < 0.0 || rng.random::<f64>() < (-delta / temp.max(1e-10)).exp() {
768 current_value = new_value;
769 if current_value < best_value {
770 best_value = current_value;
771 best_solution = current.clone();
772 }
773 } else {
774 current[flip_idx] = 1 - current[flip_idx];
776 }
777 }
778
779 let mut assignments: HashMap<String, usize> = HashMap::new();
781 for (idx, &val) in best_solution.iter().enumerate() {
782 if idx < var_names.len() {
783 let var_name: &String = &var_names[idx];
784 assignments.insert(var_name.clone(), val);
785 }
786 }
787
788 Ok(QuantumSolution {
789 assignments,
790 objective_value: best_value,
791 quality: best_value.abs(),
792 iterations: config.num_steps,
793 metadata: QuantumSolutionMetadata {
794 algorithm: "simulated_annealing".to_string(),
795 num_layers: None,
796 optimal_params: None,
797 time_seconds: None,
798 },
799 })
800 }
801
802 fn anneal_multiple(
803 &self,
804 config: &AnnealingConfig,
805 num_runs: usize,
806 ) -> Result<QuantumSolution> {
807 let mut best_solution: Option<QuantumSolution> = None;
808
809 for _ in 0..num_runs {
810 let solution = self.anneal(config)?;
811
812 match &best_solution {
813 None => best_solution = Some(solution),
814 Some(best) => {
815 if solution.objective_value < best.objective_value {
816 best_solution = Some(solution);
817 }
818 }
819 }
820 }
821
822 best_solution.ok_or_else(|| PgmError::InvalidGraph("No solution found".to_string()))
823 }
824}
825
826pub mod utils {
828 use super::*;
829
830 pub fn export_to_json(graph: &FactorGraph) -> Result<String> {
832 let model = graph.to_quantrs_model()?;
833 serde_json::to_string_pretty(&model)
834 .map_err(|e| PgmError::InvalidGraph(format!("JSON serialization failed: {}", e)))
835 }
836
837 pub fn import_from_json(json: &str) -> Result<ModelExport> {
839 serde_json::from_str(json)
840 .map_err(|e| PgmError::InvalidGraph(format!("JSON deserialization failed: {}", e)))
841 }
842
843 pub fn mutual_information(joint: &DistributionExport, _var1: &str, _var2: &str) -> Result<f64> {
845 if joint.variables.len() != 2 {
846 return Err(PgmError::InvalidGraph(
847 "Joint distribution must have exactly 2 variables".to_string(),
848 ));
849 }
850
851 let mut mi = 0.0;
852 let n1 = joint.cardinalities[0];
853 let n2 = joint.cardinalities[1];
854
855 let mut p_x = vec![0.0; n1];
857 let mut p_y = vec![0.0; n2];
858
859 for (i, px) in p_x.iter_mut().enumerate().take(n1) {
860 for (j, py) in p_y.iter_mut().enumerate().take(n2) {
861 let idx = i * n2 + j;
862 *px += joint.probabilities[idx];
863 *py += joint.probabilities[idx];
864 }
865 }
866
867 for (i, &px_val) in p_x.iter().enumerate().take(n1) {
869 for (j, &py_val) in p_y.iter().enumerate().take(n2) {
870 let idx = i * n2 + j;
871 let p_xy = joint.probabilities[idx];
872 if p_xy > 1e-10 && px_val > 1e-10 && py_val > 1e-10 {
873 mi += p_xy * (p_xy / (px_val * py_val)).ln();
874 }
875 }
876 }
877
878 Ok(mi)
879 }
880
881 pub fn kl_divergence(p: &DistributionExport, q: &DistributionExport) -> Result<f64> {
883 if p.shape != q.shape {
884 return Err(PgmError::InvalidGraph(
885 "Distributions must have same shape".to_string(),
886 ));
887 }
888
889 let mut kl = 0.0;
890 for i in 0..p.probabilities.len() {
891 let pi = p.probabilities[i];
892 let qi = q.probabilities[i];
893
894 if pi > 1e-10 {
895 if qi < 1e-10 {
896 return Ok(f64::INFINITY);
897 }
898 kl += pi * (pi / qi).ln();
899 }
900 }
901
902 Ok(kl)
903 }
904}
905
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::FactorGraph;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Factor `name` over `vars`, filled row-major from `data` with `shape`.
    fn make_factor(name: &str, vars: &[&str], shape: Vec<usize>, data: Vec<f64>) -> Factor {
        let values = Array::from_shape_vec(shape, data).unwrap().into_dyn();
        Factor::new(
            name.to_string(),
            vars.iter().map(|v| v.to_string()).collect(),
            values,
        )
        .unwrap()
    }

    /// Binary variables `x`, `y` joined by a uniform pairwise factor `P(x,y)`.
    fn uniform_pair_graph() -> FactorGraph {
        let mut graph = FactorGraph::new();
        for name in ["x", "y"] {
            graph.add_variable_with_card(name.to_string(), "Binary".to_string(), 2);
        }
        graph
            .add_factor(make_factor("P(x,y)", &["x", "y"], vec![2, 2], vec![0.25; 4]))
            .unwrap();
        graph
    }

    /// Normalized categorical export with empty parameter names and tags.
    fn categorical(vars: &[&str], shape: Vec<usize>, probabilities: Vec<f64>) -> DistributionExport {
        DistributionExport {
            variables: vars.iter().map(|v| v.to_string()).collect(),
            cardinalities: shape.clone(),
            probabilities,
            shape,
            metadata: DistributionMetadata {
                distribution_type: "categorical".to_string(),
                normalized: true,
                parameter_names: vec![],
                tags: vec![],
            },
        }
    }

    #[test]
    fn test_factor_to_quantrs_distribution() {
        let factor = make_factor("test", &["x", "y"], vec![2, 2], vec![0.25; 4]);

        let dist = factor.to_quantrs_distribution().unwrap();

        assert_eq!(dist.variables.len(), 2);
        assert_eq!(dist.probabilities.len(), 4);
        assert!(dist.metadata.normalized);
    }

    #[test]
    fn test_quantrs_distribution_roundtrip() {
        let original = make_factor("test", &["x"], vec![2], vec![0.6, 0.4]);

        let exported = original.to_quantrs_distribution().unwrap();
        let restored = Factor::from_quantrs_distribution(&exported).unwrap();

        assert_eq!(original.variables, restored.variables);
        assert_eq!(original.values.shape(), restored.values.shape());
    }

    #[test]
    fn test_is_normalized() {
        let factor = make_factor("test", &["x"], vec![2], vec![0.7, 0.3]);
        assert!(factor.is_normalized());
    }

    #[test]
    fn test_support() {
        let factor = make_factor("test", &["x", "y"], vec![2, 2], vec![0.1, 0.2, 0.3, 0.4]);

        let support = factor.support();

        assert_eq!(support.len(), 4);
        let expected = [vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]];
        for (got, want) in support.iter().zip(expected.iter()) {
            assert_eq!(got, want);
        }
    }

    #[test]
    fn test_factor_graph_to_quantrs_model() {
        let model = uniform_pair_graph().to_quantrs_model().unwrap();

        assert_eq!(model.variables.len(), 2);
        assert_eq!(model.factors.len(), 1);
        assert_eq!(model.model_type, "factor_graph");
    }

    #[test]
    fn test_model_stats() {
        let stats = uniform_pair_graph().model_stats();

        assert_eq!(stats.num_variables, 2);
        assert_eq!(stats.num_factors, 1);
        assert_abs_diff_eq!(stats.avg_factor_size, 2.0);
        assert_eq!(stats.max_factor_size, 2);
    }

    #[test]
    fn test_mutual_information() {
        let joint = categorical(&["x", "y"], vec![2, 2], vec![0.25; 4]);

        let mi = utils::mutual_information(&joint, "x", "y").unwrap();

        assert_abs_diff_eq!(mi, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_kl_divergence() {
        let p = categorical(&["x"], vec![2], vec![0.7, 0.3]);
        let q = categorical(&["x"], vec![2], vec![0.5, 0.5]);

        let kl = utils::kl_divergence(&p, &q).unwrap();

        assert!(kl > 0.0);
    }
}