1use crate::error::{PgmError, Result};
26use crate::factor::Factor;
27use crate::graph::FactorGraph;
28use scirs2_core::ndarray::ArrayD;
29use serde::{Deserialize, Serialize};
30use std::collections::HashMap;
31
32pub trait QuantRSDistribution {
36 fn to_quantrs_distribution(&self) -> Result<DistributionExport>;
43
44 fn from_quantrs_distribution(dist: &DistributionExport) -> Result<Self>
54 where
55 Self: Sized;
56
57 fn is_normalized(&self) -> bool;
59
60 fn support(&self) -> Vec<Vec<usize>>;
62}
63
/// Flat, serializable form of a discrete distribution over named variables.
///
/// Produced by `QuantRSDistribution::to_quantrs_distribution` and consumed by
/// `QuantRSDistribution::from_quantrs_distribution`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionExport {
    /// Variable names, in the same order as the axes of `shape`.
    pub variables: Vec<String>,
    /// Number of states per variable. The `Factor` exporter fills this with the
    /// array shape, so it is expected to equal `shape`.
    pub cardinalities: Vec<usize>,
    /// Values flattened in row-major (logical) order; `from_quantrs_distribution`
    /// rebuilds the array from this vector and `shape`.
    pub probabilities: Vec<f64>,
    /// Extents of the value array, one entry per variable.
    pub shape: Vec<usize>,
    /// Descriptive metadata (type label, normalization flag, tags).
    pub metadata: DistributionMetadata,
}
80
/// Descriptive information attached to a [`DistributionExport`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionMetadata {
    /// Free-form type label; the `Factor` exporter writes `"categorical"`.
    pub distribution_type: String,
    /// Whether the probabilities were judged to sum to one at export time.
    pub normalized: bool,
    /// Names of distribution parameters; the `Factor` exporter leaves this empty.
    pub parameter_names: Vec<String>,
    /// Arbitrary labels; the `Factor` exporter writes `["pgm", "factor"]`.
    pub tags: Vec<String>,
}
93
94impl QuantRSDistribution for Factor {
95 fn to_quantrs_distribution(&self) -> Result<DistributionExport> {
96 let cardinalities: Vec<usize> = self.values.shape().to_vec();
98
99 let probabilities: Vec<f64> = self.values.iter().copied().collect();
101
102 let sum: f64 = probabilities.iter().sum();
104 let normalized = (sum - 1.0).abs() < 1e-6;
105
106 Ok(DistributionExport {
107 variables: self.variables.clone(),
108 cardinalities,
109 probabilities,
110 shape: self.values.shape().to_vec(),
111 metadata: DistributionMetadata {
112 distribution_type: "categorical".to_string(),
113 normalized,
114 parameter_names: vec![],
115 tags: vec!["pgm".to_string(), "factor".to_string()],
116 },
117 })
118 }
119
120 fn from_quantrs_distribution(dist: &DistributionExport) -> Result<Self> {
121 let array = ArrayD::from_shape_vec(dist.shape.clone(), dist.probabilities.clone())
122 .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))?;
123
124 Factor::new("quantrs_import".to_string(), dist.variables.clone(), array)
125 }
126
127 fn is_normalized(&self) -> bool {
128 let sum: f64 = self.values.iter().sum();
129 (sum - 1.0).abs() < 1e-6
130 }
131
132 fn support(&self) -> Vec<Vec<usize>> {
133 let shape = self.values.shape();
134 let mut support = Vec::new();
135
136 fn generate_indices(shape: &[usize], current: Vec<usize>, support: &mut Vec<Vec<usize>>) {
137 if current.len() == shape.len() {
138 support.push(current);
139 return;
140 }
141
142 let dim = current.len();
143 for i in 0..shape[dim] {
144 let mut next = current.clone();
145 next.push(i);
146 generate_indices(shape, next, support);
147 }
148 }
149
150 generate_indices(shape, vec![], &mut support);
151 support
152 }
153}
154
155pub trait QuantRSModelExport {
157 fn to_quantrs_model(&self) -> Result<ModelExport>;
159
160 fn model_stats(&self) -> ModelStatistics;
162}
163
/// Serializable snapshot of a complete probabilistic model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelExport {
    /// Model kind label; the `FactorGraph` exporter writes `"factor_graph"`.
    pub model_type: String,
    /// Every variable in the model.
    pub variables: Vec<VariableDefinition>,
    /// Every factor in the model, each with its exported distribution.
    pub factors: Vec<FactorDefinition>,
    /// Graph structure (edges / cliques).
    pub structure: ModelStructure,
    /// Descriptive metadata such as name and creation time.
    pub metadata: ModelMetadata,
}
178
/// Exported description of a single random variable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VariableDefinition {
    /// Variable name as used in factor scopes.
    pub name: String,
    /// Domain label copied from the graph's variable (e.g. `"Binary"` in this file's tests).
    pub domain: String,
    /// Number of discrete states the variable can take.
    pub cardinality: usize,
    /// Optional human-readable state labels; the `FactorGraph` exporter sets `None`.
    pub domain_values: Option<Vec<String>>,
}
191
/// Exported description of a single factor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactorDefinition {
    /// Factor name.
    pub name: String,
    /// Names of the variables the factor is defined over, in axis order.
    pub scope: Vec<String>,
    /// The factor's values in flat interchange form.
    pub distribution: DistributionExport,
}
202
/// Structural description of an exported model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelStructure {
    /// Structure kind label; the `FactorGraph` exporter writes `"undirected"`.
    pub structure_type: String,
    /// Pairwise connections by variable name; the `FactorGraph` exporter leaves this empty.
    pub edges: Vec<(String, String)>,
    /// Variable groups; the `FactorGraph` exporter records the scope of every factor
    /// with more than one variable.
    pub cliques: Vec<Vec<String>>,
}
213
/// Descriptive metadata attached to a [`ModelExport`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Human-readable model name.
    pub name: String,
    /// Free-form description.
    pub description: String,
    /// Creation timestamp; the `FactorGraph` exporter writes an RFC 3339 UTC time.
    pub created_at: String,
    /// Arbitrary labels.
    pub tags: Vec<String>,
}
226
/// Summary statistics of a model, as returned by `QuantRSModelExport::model_stats`.
#[derive(Debug, Clone)]
pub struct ModelStatistics {
    /// Number of variables.
    pub num_variables: usize,
    /// Number of factors.
    pub num_factors: usize,
    /// Mean number of variables per factor (`0.0` when there are no factors).
    pub avg_factor_size: f64,
    /// Largest number of variables in any single factor (`0` when there are none).
    pub max_factor_size: usize,
    /// Treewidth, when known; the `FactorGraph` implementation does not compute it (`None`).
    pub treewidth: Option<usize>,
}
241
242impl QuantRSModelExport for FactorGraph {
243 fn to_quantrs_model(&self) -> Result<ModelExport> {
244 let variables: Vec<VariableDefinition> = self
246 .variables()
247 .map(|(name, var)| VariableDefinition {
248 name: name.clone(),
249 domain: var.domain.clone(),
250 cardinality: var.cardinality,
251 domain_values: None,
252 })
253 .collect();
254
255 let factors: Vec<FactorDefinition> = self
257 .factors()
258 .map(|factor| {
259 Ok(FactorDefinition {
260 name: factor.name.clone(),
261 scope: factor.variables.clone(),
262 distribution: factor.to_quantrs_distribution()?,
263 })
264 })
265 .collect::<Result<Vec<_>>>()?;
266
267 let edges = Vec::new();
269 let mut cliques = Vec::new();
270
271 for factor in self.factors() {
272 if factor.variables.len() > 1 {
273 cliques.push(factor.variables.clone());
274 }
275 }
276
277 Ok(ModelExport {
278 model_type: "factor_graph".to_string(),
279 variables,
280 factors,
281 structure: ModelStructure {
282 structure_type: "undirected".to_string(),
283 edges,
284 cliques,
285 },
286 metadata: ModelMetadata {
287 name: "Exported FactorGraph".to_string(),
288 description: "Factor graph exported from tensorlogic-quantrs-hooks".to_string(),
289 created_at: chrono::Utc::now().to_rfc3339(),
290 tags: vec!["pgm".to_string(), "factor_graph".to_string()],
291 },
292 })
293 }
294
295 fn model_stats(&self) -> ModelStatistics {
296 let num_variables = self.num_variables();
297 let num_factors = self.num_factors();
298
299 let avg_factor_size = if num_factors > 0 {
300 self.factors().map(|f| f.variables.len()).sum::<usize>() as f64 / num_factors as f64
301 } else {
302 0.0
303 };
304
305 let max_factor_size = self.factors().map(|f| f.variables.len()).max().unwrap_or(0);
306
307 ModelStatistics {
308 num_variables,
309 num_factors,
310 avg_factor_size,
311 max_factor_size,
312 treewidth: None,
313 }
314 }
315}
316
317pub trait QuantRSInferenceQuery {
319 fn query_marginal_quantrs(&self, variable: &str) -> Result<DistributionExport>;
321
322 fn query_conditional_quantrs(
324 &self,
325 variable: &str,
326 evidence: &HashMap<String, usize>,
327 ) -> Result<DistributionExport>;
328
329 fn query_map_quantrs(&self) -> Result<HashMap<String, usize>>;
331}
332
333pub trait QuantRSParameterLearning {
337 fn learn_parameters_ml(&mut self, data: &[QuantRSAssignment]) -> Result<()>;
339
340 fn learn_parameters_bayesian(
342 &mut self,
343 data: &[QuantRSAssignment],
344 priors: &HashMap<String, ArrayD<f64>>,
345 ) -> Result<()>;
346
347 fn get_parameters(&self) -> Result<Vec<DistributionExport>>;
349
350 fn set_parameters(&mut self, params: &[DistributionExport]) -> Result<()>;
352}
353
354#[derive(Debug, Clone, Serialize, Deserialize)]
356pub struct QuantRSAssignment {
357 pub assignments: HashMap<String, usize>,
359}
360
361impl QuantRSAssignment {
362 pub fn new(assignments: HashMap<String, usize>) -> Self {
364 Self { assignments }
365 }
366
367 pub fn get(&self, variable: &str) -> Option<usize> {
369 self.assignments.get(variable).copied()
370 }
371
372 pub fn from_hashmap(assignments: HashMap<String, usize>) -> Self {
374 Self { assignments }
375 }
376
377 pub fn to_hashmap(&self) -> HashMap<String, usize> {
379 self.assignments.clone()
380 }
381}
382
383pub trait QuantRSSamplingHook {
385 fn sample_quantrs(&self, num_samples: usize) -> Result<Vec<QuantRSAssignment>>;
387
388 fn log_likelihood(&self, assignment: &QuantRSAssignment) -> Result<f64>;
390
391 fn unnormalized_probability(&self, assignment: &QuantRSAssignment) -> Result<f64>;
393}
394
395pub mod utils {
397 use super::*;
398
399 pub fn export_to_json(graph: &FactorGraph) -> Result<String> {
401 let model = graph.to_quantrs_model()?;
402 serde_json::to_string_pretty(&model)
403 .map_err(|e| PgmError::InvalidGraph(format!("JSON serialization failed: {}", e)))
404 }
405
406 pub fn import_from_json(json: &str) -> Result<ModelExport> {
408 serde_json::from_str(json)
409 .map_err(|e| PgmError::InvalidGraph(format!("JSON deserialization failed: {}", e)))
410 }
411
412 pub fn mutual_information(joint: &DistributionExport, _var1: &str, _var2: &str) -> Result<f64> {
414 if joint.variables.len() != 2 {
415 return Err(PgmError::InvalidGraph(
416 "Joint distribution must have exactly 2 variables".to_string(),
417 ));
418 }
419
420 let mut mi = 0.0;
421 let n1 = joint.cardinalities[0];
422 let n2 = joint.cardinalities[1];
423
424 let mut p_x = vec![0.0; n1];
426 let mut p_y = vec![0.0; n2];
427
428 for (i, px) in p_x.iter_mut().enumerate().take(n1) {
429 for (j, py) in p_y.iter_mut().enumerate().take(n2) {
430 let idx = i * n2 + j;
431 *px += joint.probabilities[idx];
432 *py += joint.probabilities[idx];
433 }
434 }
435
436 for (i, &px_val) in p_x.iter().enumerate().take(n1) {
438 for (j, &py_val) in p_y.iter().enumerate().take(n2) {
439 let idx = i * n2 + j;
440 let p_xy = joint.probabilities[idx];
441 if p_xy > 1e-10 && px_val > 1e-10 && py_val > 1e-10 {
442 mi += p_xy * (p_xy / (px_val * py_val)).ln();
443 }
444 }
445 }
446
447 Ok(mi)
448 }
449
450 pub fn kl_divergence(p: &DistributionExport, q: &DistributionExport) -> Result<f64> {
452 if p.shape != q.shape {
453 return Err(PgmError::InvalidGraph(
454 "Distributions must have same shape".to_string(),
455 ));
456 }
457
458 let mut kl = 0.0;
459 for i in 0..p.probabilities.len() {
460 let pi = p.probabilities[i];
461 let qi = q.probabilities[i];
462
463 if pi > 1e-10 {
464 if qi < 1e-10 {
465 return Ok(f64::INFINITY);
466 }
467 kl += pi * (pi / qi).ln();
468 }
469 }
470
471 Ok(kl)
472 }
473}
474
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::FactorGraph;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Builds a categorical export whose `cardinalities` equal its `shape`.
    fn categorical(
        variables: &[&str],
        shape: Vec<usize>,
        probabilities: Vec<f64>,
    ) -> DistributionExport {
        DistributionExport {
            variables: variables.iter().map(|v| v.to_string()).collect(),
            cardinalities: shape.clone(),
            probabilities,
            shape,
            metadata: DistributionMetadata {
                distribution_type: "categorical".to_string(),
                normalized: true,
                parameter_names: vec![],
                tags: vec![],
            },
        }
    }

    /// Builds a factor over `variables` from row-major `values`.
    fn make_factor(name: &str, variables: &[&str], shape: Vec<usize>, values: Vec<f64>) -> Factor {
        Factor::new(
            name.to_string(),
            variables.iter().map(|v| v.to_string()).collect(),
            Array::from_shape_vec(shape, values).unwrap().into_dyn(),
        )
        .unwrap()
    }

    /// Two binary variables `x`, `y` joined by one uniform pairwise factor.
    fn uniform_pair_graph() -> FactorGraph {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);
        graph
            .add_factor(make_factor(
                "P(x,y)",
                &["x", "y"],
                vec![2, 2],
                vec![0.25, 0.25, 0.25, 0.25],
            ))
            .unwrap();
        graph
    }

    #[test]
    fn test_factor_to_quantrs_distribution() {
        let factor = make_factor("test", &["x", "y"], vec![2, 2], vec![0.25, 0.25, 0.25, 0.25]);

        let dist = factor.to_quantrs_distribution().unwrap();

        assert_eq!(dist.variables, vec!["x".to_string(), "y".to_string()]);
        assert_eq!(dist.probabilities.len(), 4);
        assert_eq!(dist.shape, vec![2, 2]);
        assert_eq!(dist.cardinalities, dist.shape);
        assert_eq!(dist.metadata.distribution_type, "categorical");
        assert!(dist.metadata.normalized);
    }

    #[test]
    fn test_unnormalized_factor_is_flagged() {
        let factor = make_factor("test", &["x"], vec![2], vec![2.0, 3.0]);

        assert!(!factor.is_normalized());
        assert!(!factor.to_quantrs_distribution().unwrap().metadata.normalized);
    }

    #[test]
    fn test_quantrs_distribution_roundtrip() {
        let factor = make_factor("test", &["x"], vec![2], vec![0.6, 0.4]);

        let dist = factor.to_quantrs_distribution().unwrap();
        let factor2 = Factor::from_quantrs_distribution(&dist).unwrap();

        assert_eq!(factor.variables, factor2.variables);
        assert_eq!(factor.values.shape(), factor2.values.shape());
        for (a, b) in factor.values.iter().zip(factor2.values.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_is_normalized() {
        let factor = make_factor("test", &["x"], vec![2], vec![0.7, 0.3]);

        assert!(factor.is_normalized());
    }

    #[test]
    fn test_support() {
        let factor = make_factor("test", &["x", "y"], vec![2, 2], vec![0.1, 0.2, 0.3, 0.4]);

        let support = factor.support();

        assert_eq!(
            support,
            vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]]
        );
    }

    #[test]
    fn test_factor_graph_to_quantrs_model() {
        let graph = uniform_pair_graph();

        let model = graph.to_quantrs_model().unwrap();

        assert_eq!(model.variables.len(), 2);
        assert_eq!(model.factors.len(), 1);
        assert_eq!(model.model_type, "factor_graph");
        assert_eq!(model.structure.structure_type, "undirected");
        assert_eq!(
            model.structure.cliques,
            vec![vec!["x".to_string(), "y".to_string()]]
        );
        assert_eq!(model.factors[0].scope, vec!["x".to_string(), "y".to_string()]);
    }

    #[test]
    fn test_json_roundtrip() {
        let graph = uniform_pair_graph();

        let json = utils::export_to_json(&graph).unwrap();
        let model = utils::import_from_json(&json).unwrap();

        assert_eq!(model.model_type, "factor_graph");
        assert_eq!(model.variables.len(), 2);
        assert_eq!(model.factors.len(), 1);
        assert_eq!(model.factors[0].distribution.probabilities, vec![0.25; 4]);
        assert!(utils::import_from_json("not json").is_err());
    }

    #[test]
    fn test_model_stats() {
        let graph = uniform_pair_graph();

        let stats = graph.model_stats();

        assert_eq!(stats.num_variables, 2);
        assert_eq!(stats.num_factors, 1);
        assert_abs_diff_eq!(stats.avg_factor_size, 2.0);
        assert_eq!(stats.max_factor_size, 2);
        assert!(stats.treewidth.is_none());
    }

    #[test]
    fn test_mutual_information_independent_is_zero() {
        let dist = categorical(&["x", "y"], vec![2, 2], vec![0.25, 0.25, 0.25, 0.25]);

        let mi = utils::mutual_information(&dist, "x", "y").unwrap();

        assert_abs_diff_eq!(mi, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_mutual_information_perfectly_correlated() {
        // x == y with uniform marginals: MI = ln 2 nats.
        let dist = categorical(&["x", "y"], vec![2, 2], vec![0.5, 0.0, 0.0, 0.5]);

        let mi = utils::mutual_information(&dist, "x", "y").unwrap();

        assert_abs_diff_eq!(mi, 2f64.ln(), epsilon = 1e-12);
    }

    #[test]
    fn test_mutual_information_requires_two_variables() {
        let dist = categorical(&["x"], vec![2], vec![0.5, 0.5]);

        assert!(utils::mutual_information(&dist, "x", "x").is_err());
    }

    #[test]
    fn test_kl_divergence() {
        let p = categorical(&["x"], vec![2], vec![0.7, 0.3]);
        let q = categorical(&["x"], vec![2], vec![0.5, 0.5]);

        let kl = utils::kl_divergence(&p, &q).unwrap();

        let expected = 0.7 * (0.7f64 / 0.5).ln() + 0.3 * (0.3f64 / 0.5).ln();
        assert!(kl > 0.0);
        assert_abs_diff_eq!(kl, expected, epsilon = 1e-12);
    }

    #[test]
    fn test_kl_divergence_of_identical_is_zero() {
        let p = categorical(&["x"], vec![2], vec![0.7, 0.3]);

        let kl = utils::kl_divergence(&p, &p).unwrap();

        assert_abs_diff_eq!(kl, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_kl_divergence_infinite_when_q_lacks_support() {
        let p = categorical(&["x"], vec![2], vec![0.5, 0.5]);
        let q = categorical(&["x"], vec![2], vec![1.0, 0.0]);

        let kl = utils::kl_divergence(&p, &q).unwrap();

        assert!(kl.is_infinite() && kl > 0.0);
    }

    #[test]
    fn test_kl_divergence_rejects_shape_mismatch() {
        let p = categorical(&["x"], vec![2], vec![0.5, 0.5]);
        let q = categorical(&["x"], vec![3], vec![0.2, 0.3, 0.5]);

        assert!(utils::kl_divergence(&p, &q).is_err());
    }
}