// tensorlogic_quantrs_hooks/quantrs_hooks.rs

//! QuantRS2 integration hooks for probabilistic graphical models.
//!
//! This module connects tensorlogic-quantrs-hooks to the QuantRS2 probabilistic
//! programming ecosystem. It defines traits and utilities that let PGM inference
//! interoperate with QuantRS2 distributions and models.
//!
//! # Architecture
//!
//! ```text
//! TensorLogic PGM ←→ QuantRS2 Distributions
//!       ↓                      ↓
//!   FactorGraph   ←→ Probabilistic Models
//!       ↓                      ↓
//!   Inference     ←→ Sampling/Optimization
//! ```
//!
//! # Integration Points
//!
//! 1. **Distribution Conversion**: Factor ↔ QuantRS distribution (`DistributionExport`)
//! 2. **Model Export**: FactorGraph → QuantRS model description (`ModelExport`)
//! 3. **Inference Queries**: Unified query interface
//! 4. **Parameter Learning**: Hooks into QuantRS optimizers
//! 5. **Sampling**: Bridge to QuantRS MCMC samplers

25use crate::error::{PgmError, Result};
26use crate::factor::Factor;
27use crate::graph::FactorGraph;
28use scirs2_core::ndarray::ArrayD;
29use serde::{Deserialize, Serialize};
30use std::collections::HashMap;
31
32/// Trait for converting between PGM factors and QuantRS distributions.
33///
34/// This enables seamless integration with QuantRS2's probabilistic modeling framework.
35pub trait QuantRSDistribution {
36    /// Convert a factor to a QuantRS-compatible distribution.
37    ///
38    /// # Returns
39    ///
40    /// A normalized probability distribution that can be used with QuantRS2 samplers
41    /// and inference algorithms.
42    fn to_quantrs_distribution(&self) -> Result<DistributionExport>;
43
44    /// Create a factor from a QuantRS distribution.
45    ///
46    /// # Arguments
47    ///
48    /// * `dist` - The QuantRS distribution to convert
49    ///
50    /// # Returns
51    ///
52    /// A Factor representation suitable for PGM inference.
53    fn from_quantrs_distribution(dist: &DistributionExport) -> Result<Self>
54    where
55        Self: Sized;
56
57    /// Check if the distribution is normalized.
58    fn is_normalized(&self) -> bool;
59
60    /// Get the support (valid values) of the distribution.
61    fn support(&self) -> Vec<Vec<usize>>;
62}
63
64/// Exported distribution format compatible with QuantRS2.
65///
66/// This structure can be serialized and used across the COOLJAPAN ecosystem.
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct DistributionExport {
69    /// Variable names
70    pub variables: Vec<String>,
71    /// Domain sizes (cardinalities) for each variable
72    pub cardinalities: Vec<usize>,
73    /// Probability values (flattened tensor)
74    pub probabilities: Vec<f64>,
75    /// Shape of the probability tensor
76    pub shape: Vec<usize>,
77    /// Metadata for integration
78    pub metadata: DistributionMetadata,
79}
80
81/// Metadata for distribution export.
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct DistributionMetadata {
84    /// Distribution type (e.g., "categorical", "gaussian", "conditional")
85    pub distribution_type: String,
86    /// Whether the distribution is normalized
87    pub normalized: bool,
88    /// Optional parameter names
89    pub parameter_names: Vec<String>,
90    /// Optional tags for categorization
91    pub tags: Vec<String>,
92}
93
94impl QuantRSDistribution for Factor {
95    fn to_quantrs_distribution(&self) -> Result<DistributionExport> {
96        // Get cardinalities from shape
97        let cardinalities: Vec<usize> = self.values.shape().to_vec();
98
99        // Flatten values
100        let probabilities: Vec<f64> = self.values.iter().copied().collect();
101
102        // Check normalization
103        let sum: f64 = probabilities.iter().sum();
104        let normalized = (sum - 1.0).abs() < 1e-6;
105
106        Ok(DistributionExport {
107            variables: self.variables.clone(),
108            cardinalities,
109            probabilities,
110            shape: self.values.shape().to_vec(),
111            metadata: DistributionMetadata {
112                distribution_type: "categorical".to_string(),
113                normalized,
114                parameter_names: vec![],
115                tags: vec!["pgm".to_string(), "factor".to_string()],
116            },
117        })
118    }
119
120    fn from_quantrs_distribution(dist: &DistributionExport) -> Result<Self> {
121        let array = ArrayD::from_shape_vec(dist.shape.clone(), dist.probabilities.clone())
122            .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))?;
123
124        Factor::new("quantrs_import".to_string(), dist.variables.clone(), array)
125    }
126
127    fn is_normalized(&self) -> bool {
128        let sum: f64 = self.values.iter().sum();
129        (sum - 1.0).abs() < 1e-6
130    }
131
132    fn support(&self) -> Vec<Vec<usize>> {
133        let shape = self.values.shape();
134        let mut support = Vec::new();
135
136        fn generate_indices(shape: &[usize], current: Vec<usize>, support: &mut Vec<Vec<usize>>) {
137            if current.len() == shape.len() {
138                support.push(current);
139                return;
140            }
141
142            let dim = current.len();
143            for i in 0..shape[dim] {
144                let mut next = current.clone();
145                next.push(i);
146                generate_indices(shape, next, support);
147            }
148        }
149
150        generate_indices(shape, vec![], &mut support);
151        support
152    }
153}
154
155/// Trait for models that can export to QuantRS2 format.
156pub trait QuantRSModelExport {
157    /// Export the model to a QuantRS-compatible format.
158    fn to_quantrs_model(&self) -> Result<ModelExport>;
159
160    /// Get model statistics for QuantRS integration.
161    fn model_stats(&self) -> ModelStatistics;
162}
163
164/// Exported model format compatible with QuantRS2.
165#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct ModelExport {
167    /// Model type (e.g., "bayesian_network", "markov_random_field")
168    pub model_type: String,
169    /// Variable definitions
170    pub variables: Vec<VariableDefinition>,
171    /// Factor definitions
172    pub factors: Vec<FactorDefinition>,
173    /// Model structure (edges, dependencies)
174    pub structure: ModelStructure,
175    /// Metadata
176    pub metadata: ModelMetadata,
177}
178
179/// Variable definition for export.
180#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct VariableDefinition {
182    /// Variable name
183    pub name: String,
184    /// Domain type
185    pub domain: String,
186    /// Cardinality (number of possible values)
187    pub cardinality: usize,
188    /// Optional domain values
189    pub domain_values: Option<Vec<String>>,
190}
191
192/// Factor definition for export.
193#[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct FactorDefinition {
195    /// Factor name
196    pub name: String,
197    /// Scope (variables involved)
198    pub scope: Vec<String>,
199    /// Distribution export
200    pub distribution: DistributionExport,
201}
202
203/// Model structure definition.
204#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct ModelStructure {
206    /// Type of structure ("directed", "undirected", "factor_graph")
207    pub structure_type: String,
208    /// Edges (for directed/undirected graphs)
209    pub edges: Vec<(String, String)>,
210    /// Cliques (for MRFs)
211    pub cliques: Vec<Vec<String>>,
212}
213
214/// Model metadata.
215#[derive(Debug, Clone, Serialize, Deserialize)]
216pub struct ModelMetadata {
217    /// Model name
218    pub name: String,
219    /// Description
220    pub description: String,
221    /// Creation timestamp
222    pub created_at: String,
223    /// Tags
224    pub tags: Vec<String>,
225}
226
/// Size summary of a model, reported to QuantRS2.
#[derive(Debug, Clone)]
pub struct ModelStatistics {
    /// How many variables the model has.
    pub num_variables: usize,
    /// How many factors the model has.
    pub num_factors: usize,
    /// Mean number of variables per factor scope.
    pub avg_factor_size: f64,
    /// Largest number of variables in any single factor scope.
    pub max_factor_size: usize,
    /// Treewidth of the model, when it has been computed.
    pub treewidth: Option<usize>,
}
241
242impl QuantRSModelExport for FactorGraph {
243    fn to_quantrs_model(&self) -> Result<ModelExport> {
244        // Export variables
245        let variables: Vec<VariableDefinition> = self
246            .variables()
247            .map(|(name, var)| VariableDefinition {
248                name: name.clone(),
249                domain: var.domain.clone(),
250                cardinality: var.cardinality,
251                domain_values: None,
252            })
253            .collect();
254
255        // Export factors
256        let factors: Vec<FactorDefinition> = self
257            .factors()
258            .map(|factor| {
259                Ok(FactorDefinition {
260                    name: factor.name.clone(),
261                    scope: factor.variables.clone(),
262                    distribution: factor.to_quantrs_distribution()?,
263                })
264            })
265            .collect::<Result<Vec<_>>>()?;
266
267        // Build structure
268        let edges = Vec::new();
269        let mut cliques = Vec::new();
270
271        for factor in self.factors() {
272            if factor.variables.len() > 1 {
273                cliques.push(factor.variables.clone());
274            }
275        }
276
277        Ok(ModelExport {
278            model_type: "factor_graph".to_string(),
279            variables,
280            factors,
281            structure: ModelStructure {
282                structure_type: "undirected".to_string(),
283                edges,
284                cliques,
285            },
286            metadata: ModelMetadata {
287                name: "Exported FactorGraph".to_string(),
288                description: "Factor graph exported from tensorlogic-quantrs-hooks".to_string(),
289                created_at: chrono::Utc::now().to_rfc3339(),
290                tags: vec!["pgm".to_string(), "factor_graph".to_string()],
291            },
292        })
293    }
294
295    fn model_stats(&self) -> ModelStatistics {
296        let num_variables = self.num_variables();
297        let num_factors = self.num_factors();
298
299        let avg_factor_size = if num_factors > 0 {
300            self.factors().map(|f| f.variables.len()).sum::<usize>() as f64 / num_factors as f64
301        } else {
302            0.0
303        };
304
305        let max_factor_size = self.factors().map(|f| f.variables.len()).max().unwrap_or(0);
306
307        ModelStatistics {
308            num_variables,
309            num_factors,
310            avg_factor_size,
311            max_factor_size,
312            treewidth: None,
313        }
314    }
315}
316
317/// Trait for probabilistic inference queries compatible with QuantRS2.
318pub trait QuantRSInferenceQuery {
319    /// Execute a marginal query and return QuantRS-compatible distribution.
320    fn query_marginal_quantrs(&self, variable: &str) -> Result<DistributionExport>;
321
322    /// Execute a conditional query.
323    fn query_conditional_quantrs(
324        &self,
325        variable: &str,
326        evidence: &HashMap<String, usize>,
327    ) -> Result<DistributionExport>;
328
329    /// Execute a MAP (maximum a posteriori) query.
330    fn query_map_quantrs(&self) -> Result<HashMap<String, usize>>;
331}
332
333/// Parameter learning interface for QuantRS integration.
334///
335/// This trait enables parameter estimation using QuantRS2 optimization algorithms.
336pub trait QuantRSParameterLearning {
337    /// Learn parameters from data using maximum likelihood estimation.
338    fn learn_parameters_ml(&mut self, data: &[QuantRSAssignment]) -> Result<()>;
339
340    /// Learn parameters using Bayesian estimation with priors.
341    fn learn_parameters_bayesian(
342        &mut self,
343        data: &[QuantRSAssignment],
344        priors: &HashMap<String, ArrayD<f64>>,
345    ) -> Result<()>;
346
347    /// Get current parameters as QuantRS distributions.
348    fn get_parameters(&self) -> Result<Vec<DistributionExport>>;
349
350    /// Set parameters from QuantRS distributions.
351    fn set_parameters(&mut self, params: &[DistributionExport]) -> Result<()>;
352}
353
354/// Assignment of values to variables (for learning and QuantRS integration).
355#[derive(Debug, Clone, Serialize, Deserialize)]
356pub struct QuantRSAssignment {
357    /// Variable assignments
358    pub assignments: HashMap<String, usize>,
359}
360
361impl QuantRSAssignment {
362    /// Create a new assignment.
363    pub fn new(assignments: HashMap<String, usize>) -> Self {
364        Self { assignments }
365    }
366
367    /// Get the value assigned to a variable.
368    pub fn get(&self, variable: &str) -> Option<usize> {
369        self.assignments.get(variable).copied()
370    }
371
372    /// Create from a simple HashMap (compatibility with sampling module).
373    pub fn from_hashmap(assignments: HashMap<String, usize>) -> Self {
374        Self { assignments }
375    }
376
377    /// Convert to a simple HashMap (compatibility with sampling module).
378    pub fn to_hashmap(&self) -> HashMap<String, usize> {
379        self.assignments.clone()
380    }
381}
382
383/// Hook for MCMC sampling integration with QuantRS2.
384pub trait QuantRSSamplingHook {
385    /// Generate samples using QuantRS2-compatible sampler.
386    fn sample_quantrs(&self, num_samples: usize) -> Result<Vec<QuantRSAssignment>>;
387
388    /// Compute log-likelihood for QuantRS integration.
389    fn log_likelihood(&self, assignment: &QuantRSAssignment) -> Result<f64>;
390
391    /// Compute unnormalized probability (potential).
392    fn unnormalized_probability(&self, assignment: &QuantRSAssignment) -> Result<f64>;
393}
394
395/// Utility functions for QuantRS integration.
396pub mod utils {
397    use super::*;
398
399    /// Convert a factor graph to JSON for QuantRS export.
400    pub fn export_to_json(graph: &FactorGraph) -> Result<String> {
401        let model = graph.to_quantrs_model()?;
402        serde_json::to_string_pretty(&model)
403            .map_err(|e| PgmError::InvalidGraph(format!("JSON serialization failed: {}", e)))
404    }
405
406    /// Import a factor graph from JSON.
407    pub fn import_from_json(json: &str) -> Result<ModelExport> {
408        serde_json::from_str(json)
409            .map_err(|e| PgmError::InvalidGraph(format!("JSON deserialization failed: {}", e)))
410    }
411
412    /// Compute mutual information between two variables using QuantRS format.
413    pub fn mutual_information(joint: &DistributionExport, _var1: &str, _var2: &str) -> Result<f64> {
414        if joint.variables.len() != 2 {
415            return Err(PgmError::InvalidGraph(
416                "Joint distribution must have exactly 2 variables".to_string(),
417            ));
418        }
419
420        let mut mi = 0.0;
421        let n1 = joint.cardinalities[0];
422        let n2 = joint.cardinalities[1];
423
424        // Compute marginals
425        let mut p_x = vec![0.0; n1];
426        let mut p_y = vec![0.0; n2];
427
428        for (i, px) in p_x.iter_mut().enumerate().take(n1) {
429            for (j, py) in p_y.iter_mut().enumerate().take(n2) {
430                let idx = i * n2 + j;
431                *px += joint.probabilities[idx];
432                *py += joint.probabilities[idx];
433            }
434        }
435
436        // Compute MI
437        for (i, &px_val) in p_x.iter().enumerate().take(n1) {
438            for (j, &py_val) in p_y.iter().enumerate().take(n2) {
439                let idx = i * n2 + j;
440                let p_xy = joint.probabilities[idx];
441                if p_xy > 1e-10 && px_val > 1e-10 && py_val > 1e-10 {
442                    mi += p_xy * (p_xy / (px_val * py_val)).ln();
443                }
444            }
445        }
446
447        Ok(mi)
448    }
449
450    /// Compute KL divergence between two distributions.
451    pub fn kl_divergence(p: &DistributionExport, q: &DistributionExport) -> Result<f64> {
452        if p.shape != q.shape {
453            return Err(PgmError::InvalidGraph(
454                "Distributions must have same shape".to_string(),
455            ));
456        }
457
458        let mut kl = 0.0;
459        for i in 0..p.probabilities.len() {
460            let pi = p.probabilities[i];
461            let qi = q.probabilities[i];
462
463            if pi > 1e-10 {
464                if qi < 1e-10 {
465                    return Ok(f64::INFINITY);
466                }
467                kl += pi * (pi / qi).ln();
468            }
469        }
470
471        Ok(kl)
472    }
473}
474
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::FactorGraph;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Build a factor over `vars` from a table of the given `shape`.
    fn make_factor(name: &str, vars: &[&str], shape: Vec<usize>, data: Vec<f64>) -> Factor {
        let values = Array::from_shape_vec(shape, data)
            .expect("table data must match the requested shape")
            .into_dyn();
        let vars = vars.iter().map(|v| v.to_string()).collect();
        Factor::new(name.to_string(), vars, values).expect("factor should be valid")
    }

    /// Two binary variables `x`, `y` joined by one uniform pairwise factor.
    fn uniform_pair_graph() -> FactorGraph {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);
        let joint = make_factor("P(x,y)", &["x", "y"], vec![2, 2], vec![0.25; 4]);
        graph.add_factor(joint).unwrap();
        graph
    }

    /// A normalized categorical export whose cardinalities equal its shape.
    fn categorical(
        variables: &[&str],
        shape: Vec<usize>,
        probabilities: Vec<f64>,
    ) -> DistributionExport {
        DistributionExport {
            variables: variables.iter().map(|v| v.to_string()).collect(),
            cardinalities: shape.clone(),
            probabilities,
            shape,
            metadata: DistributionMetadata {
                distribution_type: "categorical".to_string(),
                normalized: true,
                parameter_names: Vec::new(),
                tags: Vec::new(),
            },
        }
    }

    #[test]
    fn test_factor_to_quantrs_distribution() {
        let factor = make_factor("test", &["x", "y"], vec![2, 2], vec![0.25; 4]);

        let dist = factor.to_quantrs_distribution().unwrap();

        assert_eq!(dist.variables.len(), 2);
        assert_eq!(dist.probabilities.len(), 4);
        assert!(dist.metadata.normalized);
    }

    #[test]
    fn test_quantrs_distribution_roundtrip() {
        let factor = make_factor("test", &["x"], vec![2], vec![0.6, 0.4]);

        let exported = factor.to_quantrs_distribution().unwrap();
        let restored = Factor::from_quantrs_distribution(&exported).unwrap();

        assert_eq!(factor.variables, restored.variables);
        assert_eq!(factor.values.shape(), restored.values.shape());
    }

    #[test]
    fn test_is_normalized() {
        let factor = make_factor("test", &["x"], vec![2], vec![0.7, 0.3]);
        assert!(factor.is_normalized());
    }

    #[test]
    fn test_support() {
        let factor = make_factor("test", &["x", "y"], vec![2, 2], vec![0.1, 0.2, 0.3, 0.4]);

        let expected: Vec<Vec<usize>> = vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]];
        assert_eq!(factor.support(), expected);
    }

    #[test]
    fn test_factor_graph_to_quantrs_model() {
        let model = uniform_pair_graph().to_quantrs_model().unwrap();

        assert_eq!(model.variables.len(), 2);
        assert_eq!(model.factors.len(), 1);
        assert_eq!(model.model_type, "factor_graph");
    }

    #[test]
    fn test_model_stats() {
        let stats = uniform_pair_graph().model_stats();

        assert_eq!(stats.num_variables, 2);
        assert_eq!(stats.num_factors, 1);
        assert_abs_diff_eq!(stats.avg_factor_size, 2.0);
        assert_eq!(stats.max_factor_size, 2);
    }

    #[test]
    fn test_mutual_information() {
        let joint = categorical(&["x", "y"], vec![2, 2], vec![0.25; 4]);

        let mi = utils::mutual_information(&joint, "x", "y").unwrap();

        assert_abs_diff_eq!(mi, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_kl_divergence() {
        let p = categorical(&["x"], vec![2], vec![0.7, 0.3]);
        let q = categorical(&["x"], vec![2], vec![0.5, 0.5]);

        let kl = utils::kl_divergence(&p, &q).unwrap();

        assert!(kl > 0.0);
    }
}