
tensorlogic_quantrs_hooks/quantrs_hooks.rs

//! QuantRS2 integration hooks for probabilistic graphical models.
//!
//! This module provides integration between tensorlogic-quantrs-hooks and the QuantRS2
//! probabilistic programming ecosystem. It defines traits and utilities for seamless
//! interoperability between PGM inference and QuantRS2 distributions and models.
//!
//! # Architecture
//!
//! ```text
//! TensorLogic PGM ←→ QuantRS2 Distributions
//!       ↓                      ↓
//!   FactorGraph ←→ Probabilistic Models
//!       ↓                      ↓
//!   Inference   ←→  Sampling/Optimization
//! ```
//!
//! # Integration Points
//!
//! 1. **Distribution Conversion**: Factor ↔ QuantRS Distribution
//! 2. **Model Export**: FactorGraph → QuantRS ProbabilisticModel
//! 3. **Inference Queries**: Unified query interface
//! 4. **Parameter Learning**: Hook into QuantRS optimizers
//! 5. **Sampling**: Bridge to QuantRS MCMC samplers
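//!
//! # Example
//!
//! A minimal sketch of integration point 2 (model export). It assumes that
//! `Factor` and the `QuantRSModelExport` trait are re-exported at the crate
//! root like the items used in the trait examples below; the block is marked
//! `ignore` because those re-exports are not confirmed here.
//!
//! ```ignore
//! use tensorlogic_quantrs_hooks::{Factor, FactorGraph, QuantRSModelExport};
//! use scirs2_core::ndarray::Array;
//!
//! // Build a two-variable factor graph with a single uniform factor.
//! let mut graph = FactorGraph::new();
//! graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
//! graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);
//!
//! let values = Array::from_shape_vec(vec![2, 2], vec![0.25; 4])
//!     .unwrap()
//!     .into_dyn();
//! let factor = Factor::new(
//!     "P(x,y)".to_string(),
//!     vec!["x".to_string(), "y".to_string()],
//!     values,
//! )
//! .unwrap();
//! graph.add_factor(factor).unwrap();
//!
//! // Export to the QuantRS-compatible model format.
//! let model = graph.to_quantrs_model().unwrap();
//! assert_eq!(model.model_type, "factor_graph");
//! ```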

use crate::error::{PgmError, Result};
use crate::factor::Factor;
use crate::graph::FactorGraph;
use scirs2_core::ndarray::ArrayD;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Trait for converting between PGM factors and QuantRS distributions.
///
/// This enables seamless integration with QuantRS2's probabilistic modeling framework.
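///
/// # Example
///
/// A minimal round-trip sketch. It assumes `Factor` and this trait are
/// re-exported at the crate root, as the other crate-level examples in this
/// module do for their items; the block is marked `ignore` because that
/// re-export is not confirmed here.
///
/// ```ignore
/// use tensorlogic_quantrs_hooks::{Factor, QuantRSDistribution};
/// use scirs2_core::ndarray::Array;
///
/// let values = Array::from_shape_vec(vec![2], vec![0.6, 0.4]).unwrap().into_dyn();
/// let factor = Factor::new("p_x".to_string(), vec!["x".to_string()], values).unwrap();
///
/// // Factor -> QuantRS-compatible export -> Factor
/// let dist = factor.to_quantrs_distribution().unwrap();
/// let roundtrip = Factor::from_quantrs_distribution(&dist).unwrap();
/// assert_eq!(factor.values.shape(), roundtrip.values.shape());
/// ```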
pub trait QuantRSDistribution {
    /// Convert a factor to a QuantRS-compatible distribution.
    ///
    /// # Returns
    ///
    /// A normalized probability distribution that can be used with QuantRS2 samplers
    /// and inference algorithms.
    fn to_quantrs_distribution(&self) -> Result<DistributionExport>;

    /// Create a factor from a QuantRS distribution.
    ///
    /// # Arguments
    ///
    /// * `dist` - The QuantRS distribution to convert
    ///
    /// # Returns
    ///
    /// A Factor representation suitable for PGM inference.
    fn from_quantrs_distribution(dist: &DistributionExport) -> Result<Self>
    where
        Self: Sized;

    /// Check if the distribution is normalized.
    fn is_normalized(&self) -> bool;

    /// Get the support (valid values) of the distribution.
    fn support(&self) -> Vec<Vec<usize>>;
}

/// Exported distribution format compatible with QuantRS2.
///
/// This structure can be serialized and used across the COOLJAPAN ecosystem.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionExport {
    /// Variable names
    pub variables: Vec<String>,
    /// Domain sizes (cardinalities) for each variable
    pub cardinalities: Vec<usize>,
    /// Probability values (flattened tensor)
    pub probabilities: Vec<f64>,
    /// Shape of the probability tensor
    pub shape: Vec<usize>,
    /// Metadata for integration
    pub metadata: DistributionMetadata,
}

/// Metadata for distribution export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionMetadata {
    /// Distribution type (e.g., "categorical", "gaussian", "conditional")
    pub distribution_type: String,
    /// Whether the distribution is normalized
    pub normalized: bool,
    /// Optional parameter names
    pub parameter_names: Vec<String>,
    /// Optional tags for categorization
    pub tags: Vec<String>,
}

impl QuantRSDistribution for Factor {
    fn to_quantrs_distribution(&self) -> Result<DistributionExport> {
        // Get cardinalities from shape
        let cardinalities: Vec<usize> = self.values.shape().to_vec();

        // Flatten values
        let probabilities: Vec<f64> = self.values.iter().copied().collect();

        // Check normalization
        let sum: f64 = probabilities.iter().sum();
        let normalized = (sum - 1.0).abs() < 1e-6;

        Ok(DistributionExport {
            variables: self.variables.clone(),
            cardinalities,
            probabilities,
            shape: self.values.shape().to_vec(),
            metadata: DistributionMetadata {
                distribution_type: "categorical".to_string(),
                normalized,
                parameter_names: vec![],
                tags: vec!["pgm".to_string(), "factor".to_string()],
            },
        })
    }

    fn from_quantrs_distribution(dist: &DistributionExport) -> Result<Self> {
        let array = ArrayD::from_shape_vec(dist.shape.clone(), dist.probabilities.clone())
            .map_err(|e| PgmError::InvalidGraph(format!("Array creation failed: {}", e)))?;

        Factor::new("quantrs_import".to_string(), dist.variables.clone(), array)
    }

    fn is_normalized(&self) -> bool {
        let sum: f64 = self.values.iter().sum();
        (sum - 1.0).abs() < 1e-6
    }

    fn support(&self) -> Vec<Vec<usize>> {
        let shape = self.values.shape();
        let mut support = Vec::new();

        fn generate_indices(shape: &[usize], current: Vec<usize>, support: &mut Vec<Vec<usize>>) {
            if current.len() == shape.len() {
                support.push(current);
                return;
            }

            let dim = current.len();
            for i in 0..shape[dim] {
                let mut next = current.clone();
                next.push(i);
                generate_indices(shape, next, support);
            }
        }

        generate_indices(shape, vec![], &mut support);
        support
    }
}

/// Trait for models that can export to QuantRS2 format.
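///
/// # Example
///
/// A minimal sketch of reading model statistics. It assumes this trait is
/// re-exported at the crate root alongside `FactorGraph`; the block is marked
/// `ignore` because that re-export is not confirmed here.
///
/// ```ignore
/// use tensorlogic_quantrs_hooks::{FactorGraph, QuantRSModelExport};
///
/// let mut graph = FactorGraph::new();
/// graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
///
/// let stats = graph.model_stats();
/// assert_eq!(stats.num_variables, 1);
/// assert_eq!(stats.num_factors, 0);
/// ```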
pub trait QuantRSModelExport {
    /// Export the model to a QuantRS-compatible format.
    fn to_quantrs_model(&self) -> Result<ModelExport>;

    /// Get model statistics for QuantRS integration.
    fn model_stats(&self) -> ModelStatistics;
}

/// Exported model format compatible with QuantRS2.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelExport {
    /// Model type (e.g., "bayesian_network", "markov_random_field")
    pub model_type: String,
    /// Variable definitions
    pub variables: Vec<VariableDefinition>,
    /// Factor definitions
    pub factors: Vec<FactorDefinition>,
    /// Model structure (edges, dependencies)
    pub structure: ModelStructure,
    /// Metadata
    pub metadata: ModelMetadata,
}

/// Variable definition for export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VariableDefinition {
    /// Variable name
    pub name: String,
    /// Domain type
    pub domain: String,
    /// Cardinality (number of possible values)
    pub cardinality: usize,
    /// Optional domain values
    pub domain_values: Option<Vec<String>>,
}

/// Factor definition for export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactorDefinition {
    /// Factor name
    pub name: String,
    /// Scope (variables involved)
    pub scope: Vec<String>,
    /// Distribution export
    pub distribution: DistributionExport,
}

/// Model structure definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelStructure {
    /// Type of structure ("directed", "undirected", "factor_graph")
    pub structure_type: String,
    /// Edges (for directed/undirected graphs)
    pub edges: Vec<(String, String)>,
    /// Cliques (for MRFs)
    pub cliques: Vec<Vec<String>>,
}

/// Model metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Model name
    pub name: String,
    /// Description
    pub description: String,
    /// Creation timestamp
    pub created_at: String,
    /// Tags
    pub tags: Vec<String>,
}

/// Model statistics for QuantRS integration.
#[derive(Debug, Clone)]
pub struct ModelStatistics {
    /// Number of variables
    pub num_variables: usize,
    /// Number of factors
    pub num_factors: usize,
    /// Average factor size
    pub avg_factor_size: f64,
    /// Maximum factor size
    pub max_factor_size: usize,
    /// Treewidth (if computed)
    pub treewidth: Option<usize>,
}

impl QuantRSModelExport for FactorGraph {
    fn to_quantrs_model(&self) -> Result<ModelExport> {
        // Export variables
        let variables: Vec<VariableDefinition> = self
            .variables()
            .map(|(name, var)| VariableDefinition {
                name: name.clone(),
                domain: var.domain.clone(),
                cardinality: var.cardinality,
                domain_values: None,
            })
            .collect();

        // Export factors
        let factors: Vec<FactorDefinition> = self
            .factors()
            .map(|factor| {
                Ok(FactorDefinition {
                    name: factor.name.clone(),
                    scope: factor.variables.clone(),
                    distribution: factor.to_quantrs_distribution()?,
                })
            })
            .collect::<Result<Vec<_>>>()?;

        // Build structure
        let edges = Vec::new();
        let mut cliques = Vec::new();

        for factor in self.factors() {
            if factor.variables.len() > 1 {
                cliques.push(factor.variables.clone());
            }
        }

        Ok(ModelExport {
            model_type: "factor_graph".to_string(),
            variables,
            factors,
            structure: ModelStructure {
                structure_type: "undirected".to_string(),
                edges,
                cliques,
            },
            metadata: ModelMetadata {
                name: "Exported FactorGraph".to_string(),
                description: "Factor graph exported from tensorlogic-quantrs-hooks".to_string(),
                created_at: chrono::Utc::now().to_rfc3339(),
                tags: vec!["pgm".to_string(), "factor_graph".to_string()],
            },
        })
    }

    fn model_stats(&self) -> ModelStatistics {
        let num_variables = self.num_variables();
        let num_factors = self.num_factors();

        let avg_factor_size = if num_factors > 0 {
            self.factors().map(|f| f.variables.len()).sum::<usize>() as f64 / num_factors as f64
        } else {
            0.0
        };

        let max_factor_size = self.factors().map(|f| f.variables.len()).max().unwrap_or(0);

        ModelStatistics {
            num_variables,
            num_factors,
            avg_factor_size,
            max_factor_size,
            treewidth: None,
        }
    }
}

/// Trait for probabilistic inference queries compatible with QuantRS2.
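///
/// # Example
///
/// A sketch of how a caller might consume an implementor generically. No
/// implementation is provided in this module, so the block is marked `ignore`
/// and the helper function is purely illustrative.
///
/// ```ignore
/// use tensorlogic_quantrs_hooks::QuantRSInferenceQuery;
///
/// // Hypothetical helper: fetch the first entry of a marginal table.
/// fn first_marginal_value<M: QuantRSInferenceQuery>(model: &M, var: &str) -> Option<f64> {
///     let dist = model.query_marginal_quantrs(var).ok()?;
///     dist.probabilities.first().copied()
/// }
/// ```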
pub trait QuantRSInferenceQuery {
    /// Execute a marginal query and return a QuantRS-compatible distribution.
    fn query_marginal_quantrs(&self, variable: &str) -> Result<DistributionExport>;

    /// Execute a conditional query.
    fn query_conditional_quantrs(
        &self,
        variable: &str,
        evidence: &HashMap<String, usize>,
    ) -> Result<DistributionExport>;

    /// Execute a MAP (maximum a posteriori) query.
    fn query_map_quantrs(&self) -> Result<HashMap<String, usize>>;
}

/// Parameter learning interface for QuantRS integration.
///
/// This trait enables parameter estimation using QuantRS2 optimization algorithms.
pub trait QuantRSParameterLearning {
    /// Learn parameters from data using maximum likelihood estimation.
    fn learn_parameters_ml(&mut self, data: &[QuantRSAssignment]) -> Result<()>;

    /// Learn parameters using Bayesian estimation with priors.
    fn learn_parameters_bayesian(
        &mut self,
        data: &[QuantRSAssignment],
        priors: &HashMap<String, ArrayD<f64>>,
    ) -> Result<()>;

    /// Get current parameters as QuantRS distributions.
    fn get_parameters(&self) -> Result<Vec<DistributionExport>>;

    /// Set parameters from QuantRS distributions.
    fn set_parameters(&mut self, params: &[DistributionExport]) -> Result<()>;
}

/// Assignment of values to variables (for learning and QuantRS integration).
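///
/// # Example
///
/// Building an assignment from a plain `HashMap`. This assumes the type is
/// re-exported at the crate root; the block is marked `ignore` because that
/// re-export is not confirmed here.
///
/// ```ignore
/// use std::collections::HashMap;
/// use tensorlogic_quantrs_hooks::QuantRSAssignment;
///
/// let mut values = HashMap::new();
/// values.insert("x".to_string(), 1);
/// values.insert("y".to_string(), 0);
///
/// let assignment = QuantRSAssignment::from_hashmap(values);
/// assert_eq!(assignment.get("x"), Some(1));
/// assert_eq!(assignment.get("z"), None);
/// ```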
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantRSAssignment {
    /// Variable assignments
    pub assignments: HashMap<String, usize>,
}

impl QuantRSAssignment {
    /// Create a new assignment.
    pub fn new(assignments: HashMap<String, usize>) -> Self {
        Self { assignments }
    }

    /// Get the value assigned to a variable.
    pub fn get(&self, variable: &str) -> Option<usize> {
        self.assignments.get(variable).copied()
    }

    /// Create from a simple HashMap (compatibility with sampling module).
    pub fn from_hashmap(assignments: HashMap<String, usize>) -> Self {
        Self { assignments }
    }

    /// Convert to a simple HashMap (compatibility with sampling module).
    pub fn to_hashmap(&self) -> HashMap<String, usize> {
        self.assignments.clone()
    }
}

/// Hook for MCMC sampling integration with QuantRS2.
pub trait QuantRSSamplingHook {
    /// Generate samples using a QuantRS2-compatible sampler.
    fn sample_quantrs(&self, num_samples: usize) -> Result<Vec<QuantRSAssignment>>;

    /// Compute log-likelihood for QuantRS integration.
    fn log_likelihood(&self, assignment: &QuantRSAssignment) -> Result<f64>;

    /// Compute unnormalized probability (potential).
    fn unnormalized_probability(&self, assignment: &QuantRSAssignment) -> Result<f64>;
}

// ============================================================================
// Quantum Computing Integration Traits
// ============================================================================

/// Configuration for quantum annealing optimization.
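///
/// # Example
///
/// Building a configuration with the fluent helpers defined below (uses the
/// same crate-root import as the `QuantumAnnealing` example later in this
/// module):
///
/// ```no_run
/// use tensorlogic_quantrs_hooks::AnnealingConfig;
///
/// let config = AnnealingConfig::new(500, 20.0)
///     .with_samples(1_000)
///     .with_temperature(5.0, 0.05);
/// assert_eq!(config.num_steps, 500);
/// assert_eq!(config.num_samples, 1_000);
/// ```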
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnnealingConfig {
    /// Number of annealing steps
    pub num_steps: usize,
    /// Total annealing time
    pub annealing_time: f64,
    /// Number of samples per run
    pub num_samples: usize,
    /// Initial temperature (for simulated annealing)
    pub initial_temperature: f64,
    /// Final temperature (for simulated annealing)
    pub final_temperature: f64,
}

impl Default for AnnealingConfig {
    fn default() -> Self {
        Self {
            num_steps: 100,
            annealing_time: 10.0,
            num_samples: 100,
            initial_temperature: 10.0,
            final_temperature: 0.01,
        }
    }
}

impl AnnealingConfig {
    /// Create a new annealing configuration.
    pub fn new(num_steps: usize, annealing_time: f64) -> Self {
        Self {
            num_steps,
            annealing_time,
            ..Default::default()
        }
    }

    /// Set the number of samples.
    pub fn with_samples(mut self, num_samples: usize) -> Self {
        self.num_samples = num_samples;
        self
    }

    /// Set the temperature schedule.
    pub fn with_temperature(mut self, initial: f64, final_temp: f64) -> Self {
        self.initial_temperature = initial;
        self.final_temperature = final_temp;
        self
    }
}

/// Solution from quantum annealing or QAOA.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantumSolution {
    /// Variable assignments
    pub assignments: HashMap<String, usize>,
    /// Objective value (energy)
    pub objective_value: f64,
    /// Solution quality indicator (lower is better)
    pub quality: f64,
    /// Number of iterations/shots used
    pub iterations: usize,
    /// Additional metadata
    pub metadata: QuantumSolutionMetadata,
}

/// Metadata for quantum solutions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantumSolutionMetadata {
    /// Algorithm used
    pub algorithm: String,
    /// Number of QAOA layers (if applicable)
    pub num_layers: Option<usize>,
    /// Optimal parameters found
    pub optimal_params: Option<Vec<f64>>,
    /// Time taken in seconds
    pub time_seconds: Option<f64>,
}

impl QuantumSolution {
    /// Create a new quantum solution.
    pub fn new(assignments: HashMap<String, usize>, objective_value: f64, algorithm: &str) -> Self {
        Self {
            assignments,
            objective_value,
            quality: objective_value.abs(),
            iterations: 1,
            metadata: QuantumSolutionMetadata {
                algorithm: algorithm.to_string(),
                num_layers: None,
                optimal_params: None,
                time_seconds: None,
            },
        }
    }

    /// Get variable assignment.
    pub fn get(&self, variable: &str) -> Option<usize> {
        self.assignments.get(variable).copied()
    }
}

/// Trait for quantum-enhanced inference on factor graphs.
///
/// This trait provides methods for using quantum algorithms (QAOA, quantum annealing)
/// to perform inference tasks on probabilistic graphical models.
///
/// # Example
///
/// ```no_run
/// use tensorlogic_quantrs_hooks::{FactorGraph, QuantumInference};
/// use std::collections::HashMap;
///
/// let mut graph = FactorGraph::new();
/// graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
/// graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);
///
/// // Solve using QAOA
/// let solution = graph.solve_qaoa(2).unwrap();
/// println!("Best assignment: {:?}", solution);
/// ```
pub trait QuantumInference {
    /// Solve the optimization problem using QAOA (Quantum Approximate Optimization Algorithm).
    ///
    /// QAOA maps the factor graph to a quantum circuit and finds the optimal
    /// variable assignment that maximizes the joint probability (or minimizes energy).
    ///
    /// # Arguments
    ///
    /// * `num_layers` - Number of QAOA layers (p parameter). More layers give
    ///   better approximation but require more quantum resources.
    ///
    /// # Returns
    ///
    /// A map from variable names to their optimal values.
    fn solve_qaoa(&self, num_layers: usize) -> Result<HashMap<String, usize>>;

    /// Compute marginal distributions using quantum sampling.
    ///
    /// This method uses quantum circuits to sample from the joint distribution
    /// and estimates marginal probabilities from the samples.
    ///
    /// # Arguments
    ///
    /// * `num_shots` - Number of measurement shots for sampling.
    ///
    /// # Returns
    ///
    /// A map from variable names to their marginal probability distributions.
    fn quantum_marginals(&self, num_shots: usize) -> Result<HashMap<String, ArrayD<f64>>>;

    /// Compute the partition function using quantum amplitude estimation.
    ///
    /// This is useful for computing normalized probabilities and
    /// free energy.
    fn quantum_partition_function(&self) -> Result<f64>;
}

/// Trait for quantum annealing optimization.
///
/// Quantum annealing is a metaheuristic that uses quantum fluctuations
/// to find the global minimum of an objective function.
///
/// # Example
///
/// ```no_run
/// use tensorlogic_quantrs_hooks::{FactorGraph, QuantumAnnealing, AnnealingConfig};
/// use tensorlogic_quantrs_hooks::quantum_circuit::QUBOProblem;
///
/// let mut graph = FactorGraph::new();
/// graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
///
/// // Convert to QUBO
/// let qubo = graph.to_qubo().unwrap();
///
/// // Run annealing
/// let config = AnnealingConfig::default();
/// let solution = graph.anneal(&config).unwrap();
/// ```
pub trait QuantumAnnealing {
    /// Convert the factor graph to a QUBO (Quadratic Unconstrained Binary Optimization) problem.
    ///
    /// QUBO is the natural formulation for quantum annealing.
    fn to_qubo(&self) -> Result<crate::quantum_circuit::QUBOProblem>;

    /// Run quantum annealing to find the optimal assignment.
    ///
    /// # Arguments
    ///
    /// * `config` - Annealing configuration parameters.
    ///
    /// # Returns
    ///
    /// The optimal solution found by annealing.
    fn anneal(&self, config: &AnnealingConfig) -> Result<QuantumSolution>;

    /// Run multiple annealing runs and return the best solution.
    ///
    /// # Arguments
    ///
    /// * `config` - Annealing configuration parameters.
    /// * `num_runs` - Number of independent annealing runs.
    ///
    /// # Returns
    ///
    /// The best solution across all runs.
    fn anneal_multiple(&self, config: &AnnealingConfig, num_runs: usize)
        -> Result<QuantumSolution>;
}

// Implement QuantumInference for FactorGraph
impl QuantumInference for FactorGraph {
    fn solve_qaoa(&self, num_layers: usize) -> Result<HashMap<String, usize>> {
        use crate::quantum_circuit::{factor_graph_to_qubo, QAOAConfig};
        use crate::quantum_simulation::{run_qaoa, QuantumSimulationBackend};

        let qubo = factor_graph_to_qubo(self)?;
        let config = QAOAConfig::new(num_layers);
        let backend = QuantumSimulationBackend::new();
        let result = run_qaoa(&qubo, &config, &backend)?;

        // Convert result to HashMap
        let var_names: Vec<String> = self.variable_names().cloned().collect();
        let mut assignments: HashMap<String, usize> = HashMap::new();

        let solution: &Vec<usize> = &result.best_solution;
        for (idx, &value) in solution.iter().enumerate() {
            if idx < var_names.len() {
                let var_name: &String = &var_names[idx];
                assignments.insert(var_name.clone(), value);
            }
        }

        Ok(assignments)
    }

    fn quantum_marginals(&self, num_shots: usize) -> Result<HashMap<String, ArrayD<f64>>> {
        use crate::quantum_simulation::{QuantumSimulationBackend, SimulationConfig};

        // Create backend and run quantum sampling
        let config = SimulationConfig::with_shots(num_shots);
        let backend = QuantumSimulationBackend::with_config(config);
        let samples = backend.quantum_sample(self, num_shots)?;

        // Compute marginals from samples
        let mut counts: HashMap<String, Vec<usize>> = HashMap::new();
        let var_names: Vec<String> = self.variable_names().cloned().collect();

        for var in &var_names {
            counts.insert(var.clone(), vec![0, 0]); // Binary variables
        }

        for sample in &samples {
            for (var, &value) in sample {
                if let Some(count) = counts.get_mut(var) {
                    if value < count.len() {
                        count[value] += 1;
                    }
                }
            }
        }

        // Convert counts to probabilities
        let mut marginals: HashMap<String, ArrayD<f64>> = HashMap::new();
        let total = samples.len() as f64;

        for (var, count_vec) in counts {
            let probs: Vec<f64> = count_vec.iter().map(|&c| c as f64 / total).collect();
            let shape = vec![probs.len()];
            let arrd = ArrayD::from_shape_vec(shape, probs)
                .map_err(|e| PgmError::InvalidDistribution(format!("Reshape failed: {}", e)))?;
            marginals.insert(var, arrd);
        }

        Ok(marginals)
    }

    fn quantum_partition_function(&self) -> Result<f64> {
        // Simplified: sum over all configurations
        // In practice, would use quantum amplitude estimation
        let mut z = 0.0;
        let var_names: Vec<String> = self.variable_names().cloned().collect();
        let cardinalities: Vec<usize> = var_names
            .iter()
            .filter_map(|name| self.get_variable(name).map(|v| v.cardinality))
            .collect();

        let total_configs: usize = cardinalities.iter().product();

        for config_idx in 0..total_configs {
            let mut assignment = HashMap::new();
            let mut temp = config_idx;

            for (i, &card) in cardinalities.iter().enumerate().rev() {
                assignment.insert(var_names[i].clone(), temp % card);
                temp /= card;
            }

            // Compute unnormalized probability for this configuration
            let mut prob = 1.0;
            for factor in self.factors() {
                let mut indices = Vec::new();
                for var in &factor.variables {
                    if let Some(&val) = assignment.get(var) {
                        indices.push(val);
                    }
                }
                if !indices.is_empty() {
                    prob *= factor.values[indices.as_slice()];
                }
            }

            z += prob;
        }

        Ok(z)
    }
}

// Implement QuantumAnnealing for FactorGraph
impl QuantumAnnealing for FactorGraph {
    fn to_qubo(&self) -> Result<crate::quantum_circuit::QUBOProblem> {
        crate::quantum_circuit::factor_graph_to_qubo(self)
    }

    fn anneal(&self, config: &AnnealingConfig) -> Result<QuantumSolution> {
        // Use classical simulated annealing as placeholder
        // Full quantum annealing would require hardware integration
        use scirs2_core::random::thread_rng;

        let qubo = self.to_qubo()?;
        let num_vars = qubo.num_variables;
        let var_names: Vec<String> = self.variable_names().cloned().collect();

        // Initialize random solution using f64 and converting
        let mut rng = thread_rng();
        let mut best_solution: Vec<usize> = (0..num_vars)
            .map(|_| if rng.random::<f64>() < 0.5 { 0 } else { 1 })
            .collect();

        // Compute initial value
        let compute_value = |sol: &[usize]| -> f64 {
            let mut val = qubo.offset;
            for i in 0..num_vars {
                val += qubo.linear[i] * sol[i] as f64;
                for j in (i + 1)..num_vars {
                    val += qubo.quadratic[[i, j]] * (sol[i] * sol[j]) as f64;
                }
            }
            val
        };

        let mut best_value = compute_value(&best_solution);

        // Simulated annealing loop
        let mut current = best_solution.clone();
        let mut current_value = best_value;

        for step in 0..config.num_steps {
            // Interpolate the temperature linearly from the configured initial
            // temperature to the configured final temperature.
            let frac = step as f64 / config.num_steps as f64;
            let temp = config.initial_temperature
                + (config.final_temperature - config.initial_temperature) * frac;

            // Flip a random bit using f64 random
            let flip_idx = (rng.random::<f64>() * num_vars as f64) as usize % num_vars;
            current[flip_idx] = 1 - current[flip_idx];

            let new_value = compute_value(&current);
            let delta = new_value - current_value;

            if delta < 0.0 || rng.random::<f64>() < (-delta / temp.max(1e-10)).exp() {
                current_value = new_value;
                if current_value < best_value {
                    best_value = current_value;
                    best_solution = current.clone();
                }
            } else {
                // Revert flip
                current[flip_idx] = 1 - current[flip_idx];
            }
        }

        // Convert solution to HashMap
        let mut assignments: HashMap<String, usize> = HashMap::new();
        for (idx, &val) in best_solution.iter().enumerate() {
            if idx < var_names.len() {
                let var_name: &String = &var_names[idx];
                assignments.insert(var_name.clone(), val);
            }
        }

        Ok(QuantumSolution {
            assignments,
            objective_value: best_value,
            quality: best_value.abs(),
            iterations: config.num_steps,
            metadata: QuantumSolutionMetadata {
                algorithm: "simulated_annealing".to_string(),
                num_layers: None,
                optimal_params: None,
                time_seconds: None,
            },
        })
    }

    fn anneal_multiple(
        &self,
        config: &AnnealingConfig,
        num_runs: usize,
    ) -> Result<QuantumSolution> {
        let mut best_solution: Option<QuantumSolution> = None;

        for _ in 0..num_runs {
            let solution = self.anneal(config)?;

            match &best_solution {
                None => best_solution = Some(solution),
                Some(best) => {
                    if solution.objective_value < best.objective_value {
                        best_solution = Some(solution);
                    }
                }
            }
        }

        best_solution.ok_or_else(|| PgmError::InvalidGraph("No solution found".to_string()))
    }
}

/// Utility functions for QuantRS integration.
pub mod utils {
    use super::*;

    /// Convert a factor graph to JSON for QuantRS export.
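    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore` because it assumes this `utils`
    /// module is reachable at `tensorlogic_quantrs_hooks::quantrs_hooks::utils`,
    /// which is not confirmed here):
    ///
    /// ```ignore
    /// use tensorlogic_quantrs_hooks::FactorGraph;
    /// use tensorlogic_quantrs_hooks::quantrs_hooks::utils;
    ///
    /// let mut graph = FactorGraph::new();
    /// graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
    ///
    /// let json = utils::export_to_json(&graph).unwrap();
    /// assert!(json.contains("factor_graph"));
    /// ```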
    pub fn export_to_json(graph: &FactorGraph) -> Result<String> {
        let model = graph.to_quantrs_model()?;
        serde_json::to_string_pretty(&model)
            .map_err(|e| PgmError::InvalidGraph(format!("JSON serialization failed: {}", e)))
    }

    /// Import a QuantRS model export from JSON.
    pub fn import_from_json(json: &str) -> Result<ModelExport> {
        serde_json::from_str(json)
            .map_err(|e| PgmError::InvalidGraph(format!("JSON deserialization failed: {}", e)))
    }

    /// Compute mutual information between two variables using QuantRS format.
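    ///
    /// Computes `I(X; Y) = Σₓᵧ p(x, y) · ln( p(x, y) / (p(x) · p(y)) )` in nats,
    /// deriving the marginals `p(x)` and `p(y)` from the joint table.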
    pub fn mutual_information(joint: &DistributionExport, _var1: &str, _var2: &str) -> Result<f64> {
        if joint.variables.len() != 2 {
            return Err(PgmError::InvalidGraph(
                "Joint distribution must have exactly 2 variables".to_string(),
            ));
        }

        let mut mi = 0.0;
        let n1 = joint.cardinalities[0];
        let n2 = joint.cardinalities[1];

        // Compute marginals
        let mut p_x = vec![0.0; n1];
        let mut p_y = vec![0.0; n2];

        for (i, px) in p_x.iter_mut().enumerate().take(n1) {
            for (j, py) in p_y.iter_mut().enumerate().take(n2) {
                let idx = i * n2 + j;
                *px += joint.probabilities[idx];
                *py += joint.probabilities[idx];
            }
        }

        // Compute MI
        for (i, &px_val) in p_x.iter().enumerate().take(n1) {
            for (j, &py_val) in p_y.iter().enumerate().take(n2) {
                let idx = i * n2 + j;
                let p_xy = joint.probabilities[idx];
                if p_xy > 1e-10 && px_val > 1e-10 && py_val > 1e-10 {
                    mi += p_xy * (p_xy / (px_val * py_val)).ln();
                }
            }
        }

        Ok(mi)
    }

    /// Compute KL divergence between two distributions.
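    ///
    /// Computes `D_KL(P ‖ Q) = Σᵢ pᵢ · ln(pᵢ / qᵢ)` over the flattened
    /// probability vectors, returning `f64::INFINITY` when `Q` assigns
    /// (near-)zero mass to an outcome that `P` does not.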
    pub fn kl_divergence(p: &DistributionExport, q: &DistributionExport) -> Result<f64> {
        if p.shape != q.shape {
            return Err(PgmError::InvalidGraph(
                "Distributions must have same shape".to_string(),
            ));
        }

        let mut kl = 0.0;
        for i in 0..p.probabilities.len() {
            let pi = p.probabilities[i];
            let qi = q.probabilities[i];

            if pi > 1e-10 {
                if qi < 1e-10 {
                    return Ok(f64::INFINITY);
                }
                kl += pi * (pi / qi).ln();
            }
        }

        Ok(kl)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::FactorGraph;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_factor_to_quantrs_distribution() {
        let values = Array::from_shape_vec(vec![2, 2], vec![0.25, 0.25, 0.25, 0.25])
            .unwrap()
            .into_dyn();
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string(), "y".to_string()],
            values,
        )
        .unwrap();

        let dist = factor.to_quantrs_distribution().unwrap();

        assert_eq!(dist.variables.len(), 2);
        assert_eq!(dist.probabilities.len(), 4);
        assert!(dist.metadata.normalized);
    }

    #[test]
    fn test_quantrs_distribution_roundtrip() {
        let values = Array::from_shape_vec(vec![2], vec![0.6, 0.4])
            .unwrap()
            .into_dyn();
        let factor = Factor::new("test".to_string(), vec!["x".to_string()], values).unwrap();

        let dist = factor.to_quantrs_distribution().unwrap();
        let factor2 = Factor::from_quantrs_distribution(&dist).unwrap();

        assert_eq!(factor.variables, factor2.variables);
        assert_eq!(factor.values.shape(), factor2.values.shape());
    }

    #[test]
    fn test_is_normalized() {
        let values = Array::from_shape_vec(vec![2], vec![0.7, 0.3])
            .unwrap()
            .into_dyn();
        let factor = Factor::new("test".to_string(), vec!["x".to_string()], values).unwrap();

        assert!(factor.is_normalized());
    }

    #[test]
    fn test_support() {
        let values = Array::from_shape_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4])
            .unwrap()
            .into_dyn();
        let factor = Factor::new(
            "test".to_string(),
            vec!["x".to_string(), "y".to_string()],
            values,
        )
        .unwrap();

        let support = factor.support();
        assert_eq!(support.len(), 4);
        assert_eq!(support[0], vec![0, 0]);
        assert_eq!(support[1], vec![0, 1]);
        assert_eq!(support[2], vec![1, 0]);
        assert_eq!(support[3], vec![1, 1]);
    }

    #[test]
    fn test_factor_graph_to_quantrs_model() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);

        let factor = Factor::new(
            "P(x,y)".to_string(),
            vec!["x".to_string(), "y".to_string()],
            Array::from_shape_vec(vec![2, 2], vec![0.25, 0.25, 0.25, 0.25])
                .unwrap()
                .into_dyn(),
        )
        .unwrap();
        graph.add_factor(factor).unwrap();

        let model = graph.to_quantrs_model().unwrap();

        assert_eq!(model.variables.len(), 2);
        assert_eq!(model.factors.len(), 1);
        assert_eq!(model.model_type, "factor_graph");
    }

    #[test]
    fn test_model_stats() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);

        let factor = Factor::new(
            "P(x,y)".to_string(),
            vec!["x".to_string(), "y".to_string()],
            Array::from_shape_vec(vec![2, 2], vec![0.25, 0.25, 0.25, 0.25])
                .unwrap()
                .into_dyn(),
        )
        .unwrap();
        graph.add_factor(factor).unwrap();

        let stats = graph.model_stats();

        assert_eq!(stats.num_variables, 2);
        assert_eq!(stats.num_factors, 1);
        assert_abs_diff_eq!(stats.avg_factor_size, 2.0);
        assert_eq!(stats.max_factor_size, 2);
    }

    #[test]
    fn test_mutual_information() {
        let dist = DistributionExport {
            variables: vec!["x".to_string(), "y".to_string()],
            cardinalities: vec![2, 2],
            probabilities: vec![0.25, 0.25, 0.25, 0.25],
            shape: vec![2, 2],
            metadata: DistributionMetadata {
                distribution_type: "categorical".to_string(),
                normalized: true,
                parameter_names: vec![],
                tags: vec![],
            },
        };

        let mi = utils::mutual_information(&dist, "x", "y").unwrap();

        assert_abs_diff_eq!(mi, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_kl_divergence() {
        let p = DistributionExport {
            variables: vec!["x".to_string()],
            cardinalities: vec![2],
            probabilities: vec![0.7, 0.3],
            shape: vec![2],
            metadata: DistributionMetadata {
                distribution_type: "categorical".to_string(),
                normalized: true,
                parameter_names: vec![],
                tags: vec![],
            },
        };

        let q = DistributionExport {
            variables: vec!["x".to_string()],
            cardinalities: vec![2],
            probabilities: vec![0.5, 0.5],
            shape: vec![2],
            metadata: DistributionMetadata {
                distribution_type: "categorical".to_string(),
                normalized: true,
                parameter_names: vec![],
                tags: vec![],
            },
        };

        let kl = utils::kl_divergence(&p, &q).unwrap();

        assert!(kl > 0.0);
    }
}