// tensorlogic_quantrs_hooks/tensor_network_bridge.rs

//! Tensor network bridge for quantum-classical hybrid inference.
//!
//! This module provides conversion between probabilistic graphical models
//! and quantum tensor network representations, enabling efficient computation
//! of marginals and partition functions.
//!
//! # Overview
//!
//! Tensor networks provide a natural bridge between:
//! - Classical PGMs (factor graphs, MRFs)
//! - Quantum states (MPS, PEPS, MERA)
//!
//! # Key Concepts
//!
//! - **MPS (Matrix Product State)**: 1D tensor network for linear chains
//! - **PEPS (Projected Entangled Pair State)**: 2D tensor network
//! - **Tensor Network Contraction**: Computing expectations and marginals
//!
//! # Example
//!
//! ```no_run
//! use tensorlogic_quantrs_hooks::tensor_network_bridge::{
//!     factor_graph_to_tensor_network, TensorNetwork,
//! };
//! use tensorlogic_quantrs_hooks::FactorGraph;
//!
//! let mut graph = FactorGraph::new();
//! graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
//! graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);
//!
//! let tn = factor_graph_to_tensor_network(&graph).unwrap();
//! println!("Tensor network with {} tensors", tn.num_tensors());
//! ```

use crate::error::{PgmError, Result};
use crate::graph::FactorGraph;
use crate::linear_chain_crf::LinearChainCRF;
use quantrs2_sim::Complex64;
use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayD, Axis};
use serde::{Deserialize, Serialize};

/// A tensor in the tensor network.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Tensor name/identifier
    pub name: String,
    /// Tensor data (n-dimensional array)
    pub data: ArrayD<Complex64>,
    /// Index labels for contraction
    pub indices: Vec<String>,
    /// Bond dimensions for each index
    pub bond_dims: Vec<usize>,
}

impl Tensor {
    /// Create a new tensor.
    pub fn new(name: String, data: ArrayD<Complex64>, indices: Vec<String>) -> Self {
        let bond_dims = data.shape().to_vec();
        Self {
            name,
            data,
            indices,
            bond_dims,
        }
    }

    /// Create from a real-valued array.
    pub fn from_real(name: String, data: ArrayD<f64>, indices: Vec<String>) -> Self {
        let complex_data = data.mapv(|x| Complex64::new(x, 0.0));
        Self::new(name, complex_data, indices)
    }

    /// Get the number of indices (rank).
    pub fn rank(&self) -> usize {
        self.indices.len()
    }

    /// Get bond dimension for a specific index.
    pub fn bond_dim(&self, index: &str) -> Option<usize> {
        self.indices
            .iter()
            .position(|i| i == index)
            .map(|pos| self.bond_dims[pos])
    }

    /// Contract two tensors over shared indices.
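    ///
    /// Indices with the same label are summed over; tensors sharing no
    /// labels combine via an outer product. A minimal usage sketch (the
    /// values are illustrative, not from the crate's test suite):
    ///
    /// ```no_run
    /// use tensorlogic_quantrs_hooks::tensor_network_bridge::Tensor;
    /// use scirs2_core::ndarray::ArrayD;
    ///
    /// let a = ArrayD::from_shape_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    /// let b = ArrayD::from_shape_vec(vec![2, 2], vec![5.0, 6.0, 7.0, 8.0]).unwrap();
    /// let t1 = Tensor::from_real("A".to_string(), a, vec!["i".to_string(), "j".to_string()]);
    /// let t2 = Tensor::from_real("B".to_string(), b, vec!["j".to_string(), "k".to_string()]);
    /// let c = t1.contract(&t2).unwrap();
    /// assert_eq!(c.indices, vec!["i".to_string(), "k".to_string()]);
    /// ```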
    pub fn contract(&self, other: &Tensor) -> Result<Tensor> {
        // Find shared indices
        let shared: Vec<(usize, usize)> = self
            .indices
            .iter()
            .enumerate()
            .filter_map(|(i, idx)| {
                other
                    .indices
                    .iter()
                    .position(|oidx| oidx == idx)
                    .map(|j| (i, j))
            })
            .collect();

        if shared.is_empty() {
            // Outer product
            return self.outer_product(other);
        }

        // Contract over shared indices using tensordot-like operation
        let result_indices: Vec<String> = self
            .indices
            .iter()
            .enumerate()
            .filter(|(i, _)| !shared.iter().any(|(si, _)| si == i))
            .map(|(_, idx)| idx.clone())
            .chain(
                other
                    .indices
                    .iter()
                    .enumerate()
                    .filter(|(j, _)| !shared.iter().any(|(_, sj)| sj == j))
                    .map(|(_, idx)| idx.clone()),
            )
            .collect();

        // Compute result shape
        let result_shape: Vec<usize> = self
            .bond_dims
            .iter()
            .enumerate()
            .filter(|(i, _)| !shared.iter().any(|(si, _)| si == i))
            .map(|(_, &d)| d)
            .chain(
                other
                    .bond_dims
                    .iter()
                    .enumerate()
                    .filter(|(j, _)| !shared.iter().any(|(_, sj)| sj == j))
                    .map(|(_, &d)| d),
            )
            .collect();
        // Contract the data over the shared index pairs
        let result_data = self.contract_data(other, &shared, &result_shape)?;

        Ok(Tensor {
            name: format!("{}*{}", self.name, other.name),
            data: result_data,
            indices: result_indices,
            bond_dims: result_shape,
        })
    }

    /// Contract tensor data over the shared index pairs (a naive tensordot:
    /// permute free axes first/last, flatten to matrices, multiply).
    fn contract_data(
        &self,
        other: &Tensor,
        shared: &[(usize, usize)],
        result_shape: &[usize],
    ) -> Result<ArrayD<Complex64>> {
        let err = |e: String| PgmError::InvalidDistribution(format!("Contraction failed: {}", e));
        let a_free: Vec<usize> = (0..self.rank())
            .filter(|i| !shared.iter().any(|(si, _)| si == i))
            .collect();
        let b_free: Vec<usize> = (0..other.rank())
            .filter(|j| !shared.iter().any(|(_, sj)| sj == j))
            .collect();
        let m: usize = a_free.iter().map(|&i| self.bond_dims[i]).product();
        let n: usize = b_free.iter().map(|&j| other.bond_dims[j]).product();
        let k: usize = shared.iter().map(|&(si, _)| self.bond_dims[si]).product();
        // Free axes of `self` become rows, free axes of `other` columns;
        // the shared axes are summed by a single matrix product.
        let a_axes: Vec<usize> = a_free
            .into_iter()
            .chain(shared.iter().map(|&(si, _)| si))
            .collect();
        let b_axes: Vec<usize> = shared.iter().map(|&(_, sj)| sj).chain(b_free).collect();
        let a_flat: Vec<Complex64> = self.data.clone().permuted_axes(a_axes).iter().cloned().collect();
        let b_flat: Vec<Complex64> = other.data.clone().permuted_axes(b_axes).iter().cloned().collect();
        let a = Array2::from_shape_vec((m, k), a_flat).map_err(|e| err(e.to_string()))?;
        let b = Array2::from_shape_vec((k, n), b_flat).map_err(|e| err(e.to_string()))?;
        ArrayD::from_shape_vec(result_shape.to_vec(), a.dot(&b).iter().cloned().collect())
            .map_err(|e| err(e.to_string()))
    }

    /// Compute outer product with another tensor.
    fn outer_product(&self, other: &Tensor) -> Result<Tensor> {
        let result_indices: Vec<String> = self
            .indices
            .iter()
            .chain(other.indices.iter())
            .cloned()
            .collect();

        let result_shape: Vec<usize> = self
            .bond_dims
            .iter()
            .chain(other.bond_dims.iter())
            .copied()
            .collect();

        let total_size: usize = result_shape.iter().product();
        let mut data = vec![Complex64::new(0.0, 0.0); total_size.max(1)];

        // Compute outer product
        for (i, &a) in self.data.iter().enumerate() {
            for (j, &b) in other.data.iter().enumerate() {
                data[i * other.data.len() + j] = a * b;
            }
        }

        Ok(Tensor {
            name: format!("{}⊗{}", self.name, other.name),
            data: ArrayD::from_shape_vec(result_shape.clone(), data).map_err(|e| {
                PgmError::InvalidDistribution(format!("Outer product failed: {}", e))
            })?,
            indices: result_indices,
            bond_dims: result_shape,
        })
    }
}

/// A tensor network representation.
#[derive(Debug, Clone)]
pub struct TensorNetwork {
    /// Tensors in the network
    tensors: Vec<Tensor>,
    /// Physical indices (observable variables)
    physical_indices: Vec<String>,
    /// Virtual/bond indices
    bond_indices: Vec<String>,
}

impl TensorNetwork {
    /// Create a new empty tensor network.
    pub fn new() -> Self {
        Self {
            tensors: Vec::new(),
            physical_indices: Vec::new(),
            bond_indices: Vec::new(),
        }
    }

    /// Add a tensor to the network.
    pub fn add_tensor(&mut self, tensor: Tensor) {
        self.tensors.push(tensor);
    }

    /// Add a physical index.
    pub fn add_physical_index(&mut self, index: String) {
        if !self.physical_indices.contains(&index) {
            self.physical_indices.push(index);
        }
    }

    /// Add a bond index.
    pub fn add_bond_index(&mut self, index: String) {
        if !self.bond_indices.contains(&index) {
            self.bond_indices.push(index);
        }
    }

    /// Get the number of tensors.
    pub fn num_tensors(&self) -> usize {
        self.tensors.len()
    }

    /// Get the number of physical indices.
    pub fn num_physical_indices(&self) -> usize {
        self.physical_indices.len()
    }

    /// Get the summed tensor sizes (sum over all tensors of the product of
    /// their bond dimensions).
    pub fn total_bond_dim(&self) -> usize {
        self.tensors
            .iter()
            .map(|t| t.bond_dims.iter().product::<usize>())
            .sum()
    }

    /// Contract the entire network to a single tensor.
    ///
    /// This uses a simple sequential contraction strategy.
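    /// Note that pairwise contraction order strongly affects cost; this
    /// naive left-to-right sweep can be far more expensive than an
    /// optimized contraction order for large networks.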
    pub fn contract(&self) -> Result<Tensor> {
        if self.tensors.is_empty() {
            return Err(PgmError::InvalidGraph(
                "Cannot contract empty tensor network".to_string(),
            ));
        }

        let mut result = self.tensors[0].clone();
        for tensor in self.tensors.iter().skip(1) {
            result = result.contract(tensor)?;
        }

        Ok(result)
    }

    /// Compute the partition function (sum over all index values).
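    ///
    /// For a network built from a factor graph this corresponds to
    /// Z = Σ_x Π_f φ_f(x_f): the fully contracted tensor summed over all
    /// remaining index values.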
    pub fn partition_function(&self) -> Result<Complex64> {
        let contracted = self.contract()?;
        Ok(contracted.data.iter().sum())
    }

    /// Compute marginal for a subset of physical indices.
    pub fn marginal(&self, indices: &[String]) -> Result<Tensor> {
        // Contract network, then sum out non-specified indices
        let contracted = self.contract()?;

        // Keep only specified indices
        let keep_positions: Vec<usize> = contracted
            .indices
            .iter()
            .enumerate()
            .filter_map(|(i, idx)| if indices.contains(idx) { Some(i) } else { None })
            .collect();

        if keep_positions.is_empty() {
            // Return scalar
            let sum: Complex64 = contracted.data.iter().sum();
            return Ok(Tensor::new(
                "marginal".to_string(),
                ArrayD::from_elem(vec![], sum),
                vec![],
            ));
        }

        // Sum out the non-kept indices (proper marginalization)
        let result_shape: Vec<usize> = keep_positions
            .iter()
            .map(|&i| contracted.bond_dims[i])
            .collect();
        let result_indices: Vec<String> = keep_positions
            .iter()
            .map(|&i| contracted.indices[i].clone())
            .collect();

        // Drop axes from highest to lowest so remaining positions stay valid
        let drop_axes: Vec<usize> = (0..contracted.bond_dims.len())
            .filter(|i| !keep_positions.contains(i))
            .collect();
        let mut data = contracted.data;
        for &axis in drop_axes.iter().rev() {
            data = data.sum_axis(Axis(axis));
        }

        Ok(Tensor {
            name: "marginal".to_string(),
            data,
            indices: result_indices,
            bond_dims: result_shape,
        })
    }
}

impl Default for TensorNetwork {
    fn default() -> Self {
        Self::new()
    }
}

/// Convert a factor graph to a tensor network.
///
/// Each factor becomes a tensor, and shared variables become bond indices.
pub fn factor_graph_to_tensor_network(graph: &FactorGraph) -> Result<TensorNetwork> {
    let mut tn = TensorNetwork::new();

    // Add physical indices for each variable
    for var_name in graph.variable_names() {
        tn.add_physical_index(var_name.clone());
    }

    // Convert each factor to a tensor
    for factor in graph.factors() {
        let indices = factor.variables.clone();
        let tensor = Tensor::from_real(factor.name.clone(), factor.values.clone(), indices);
        tn.add_tensor(tensor);
    }

    Ok(tn)
}

/// Matrix Product State (MPS) representation.
///
/// MPS is a 1D tensor network particularly suited for linear-chain structures.
///
/// |ψ⟩ = Σ_{s₁...sₙ} A\[1\]^{s₁} A\[2\]^{s₂} ... A\[n\]^{sₙ} |s₁...sₙ⟩
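///
/// A minimal construction sketch:
///
/// ```no_run
/// use tensorlogic_quantrs_hooks::tensor_network_bridge::MatrixProductState;
///
/// // |000⟩ on three sites with physical dimension 2
/// let mps = MatrixProductState::product_state(3, 2);
/// assert_eq!(mps.length(), 3);
/// assert_eq!(mps.max_bond_dim(), 1);
/// ```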
#[derive(Debug, Clone)]
pub struct MatrixProductState {
    /// Site tensors (each is [bond_left, physical, bond_right])
    pub tensors: Vec<Array3<Complex64>>,
    /// Physical dimensions at each site
    pub physical_dims: Vec<usize>,
    /// Bond dimensions
    pub bond_dims: Vec<usize>,
}

impl MatrixProductState {
    /// Create a new MPS with uniform physical dimension.
    pub fn new(length: usize, physical_dim: usize, bond_dim: usize) -> Self {
        let mut tensors = Vec::with_capacity(length);
        let mut bond_dims = Vec::with_capacity(length + 1);

        bond_dims.push(1); // Left boundary

        for i in 0..length {
            let left_dim = bond_dims[i];
            let right_dim = if i == length - 1 { 1 } else { bond_dim };
            bond_dims.push(right_dim);

            // Initialize with uniform values
            let tensor = Array3::from_shape_fn((left_dim, physical_dim, right_dim), |_| {
                Complex64::new(1.0 / (left_dim * physical_dim * right_dim) as f64, 0.0)
            });
            tensors.push(tensor);
        }

        Self {
            tensors,
            physical_dims: vec![physical_dim; length],
            bond_dims,
        }
    }

    /// Create an MPS in the product state |00...0⟩.
    pub fn product_state(length: usize, physical_dim: usize) -> Self {
        let mut tensors = Vec::with_capacity(length);

        for _ in 0..length {
            // |0⟩ state at each site
            let mut tensor = Array3::zeros((1, physical_dim, 1));
            tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
            tensors.push(tensor);
        }

        Self {
            tensors,
            physical_dims: vec![physical_dim; length],
            bond_dims: vec![1; length + 1],
        }
    }

    /// Get the length (number of sites).
    pub fn length(&self) -> usize {
        self.tensors.len()
    }

    /// Get the maximum bond dimension.
    pub fn max_bond_dim(&self) -> usize {
        *self.bond_dims.iter().max().unwrap_or(&1)
    }

    /// Contract the MPS to a full state vector of raw (unnormalized)
    /// amplitudes.
    ///
    /// Warning: This is exponentially expensive for large systems.
    pub fn to_state_vector(&self) -> Result<Array1<Complex64>> {
        if self.tensors.is_empty() {
            return Ok(Array1::from(vec![Complex64::new(1.0, 0.0)]));
        }

        let total_dim: usize = self.physical_dims.iter().product();
        let mut state = Array1::zeros(total_dim);

        // Enumerate all basis states
        for basis_idx in 0..total_dim {
            let mut indices = vec![0; self.tensors.len()];
            let mut temp = basis_idx;
            for (i, &dim) in self.physical_dims.iter().enumerate().rev() {
                indices[i] = temp % dim;
                temp /= dim;
            }

            // Contract along the chain, carrying a vector over the bond
            // index: v'[r] = Σ_l v[l] * A[l, s, r]
            let mut bond_vec = vec![Complex64::new(1.0, 0.0)];
            for (site, &phys_idx) in indices.iter().enumerate() {
                let tensor = &self.tensors[site];
                let right_dim = tensor.shape()[2];
                let mut next = vec![Complex64::new(0.0, 0.0); right_dim];
                for (left_idx, &v) in bond_vec.iter().enumerate() {
                    for (right_idx, slot) in next.iter_mut().enumerate() {
                        *slot += v * tensor[[left_idx, phys_idx, right_idx]];
                    }
                }
                bond_vec = next;
            }

            // The right boundary bond has dimension 1
            state[basis_idx] = bond_vec[0];
        }

        Ok(state)
    }

    /// Compute the norm of the MPS.
    pub fn norm(&self) -> f64 {
        match self.to_state_vector() {
            Ok(state) => state.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt(),
            Err(_) => 0.0,
        }
    }

    /// Compute the expectation value ⟨ψ|O|ψ⟩ / ⟨ψ|ψ⟩ of a local operator at a site.
    pub fn expectation_local(
        &self,
        site: usize,
        operator: &Array2<Complex64>,
    ) -> Result<Complex64> {
        if site >= self.tensors.len() {
            return Err(PgmError::VariableNotFound(format!(
                "Site {} out of range",
                site
            )));
        }

        // Simplified: expand the full state vector (inefficient for large MPS)
        let state = self.to_state_vector()?;

        let mut result = Complex64::new(0.0, 0.0);
        let num_sites = self.tensors.len();
        let total_dim: usize = self.physical_dims.iter().product();

        for basis_idx in 0..total_dim {
            // Decode basis state
            let mut indices = vec![0; num_sites];
            let mut temp = basis_idx;
            for (i, &dim) in self.physical_dims.iter().enumerate().rev() {
                indices[i] = temp % dim;
                temp /= dim;
            }

            // Apply operator at site
            for new_idx in 0..self.physical_dims[site] {
                let op_elem = operator[[new_idx, indices[site]]];
                if op_elem.norm_sqr() > 1e-20 {
                    // Compute new basis index
                    let mut new_basis_idx = 0;
                    let mut multiplier = 1;
                    for (i, &idx) in indices.iter().enumerate().rev() {
                        let idx_to_use = if i == site { new_idx } else { idx };
                        new_basis_idx += idx_to_use * multiplier;
                        multiplier *= self.physical_dims[i];
                    }

                    result += state[new_basis_idx].conj() * op_elem * state[basis_idx];
                }
            }
        }

        // Normalize, since to_state_vector returns raw amplitudes
        let norm_sqr: f64 = state.iter().map(|x| x.norm_sqr()).sum();
        if norm_sqr > 1e-20 {
            result /= norm_sqr;
        }

        Ok(result)
    }
}

/// Convert a Linear Chain CRF to a Matrix Product State.
///
/// The CRF's emission and transition potentials are embedded
/// (approximately) into the MPS site tensors; this is a heuristic
/// encoding, not an exact CRF-to-MPS mapping.
pub fn linear_chain_to_mps(
    crf: &LinearChainCRF,
    input_sequence: &[usize],
) -> Result<MatrixProductState> {
    let factor_graph = crf.to_factor_graph(input_sequence)?;
    let num_sites = input_sequence.len();

    if num_sites == 0 {
        return Err(PgmError::InvalidGraph("Empty sequence".to_string()));
    }

    // Get state dimension from CRF
    let num_states = factor_graph
        .get_variable("y_0")
        .map(|v| v.cardinality)
        .unwrap_or(2);

    // Build MPS from factors
    let mut mps = MatrixProductState::new(num_sites, num_states, num_states);

    // Populate tensors from emission and transition factors
    for t in 0..num_sites {
        let emission_name = format!("emission_{}", t);
        let transition_name = format!("transition_{}", t);

        // Get emission factor
        if let Some(emission) = factor_graph.get_factor_by_name(&emission_name) {
            // Emission factor is diagonal in the physical index
            for (s, &val) in emission.values.iter().enumerate() {
                if s < num_states {
                    mps.tensors[t][[0, s, 0]] = Complex64::new(val.sqrt(), 0.0);
                }
            }
        }

        // Get transition factor (if not first site)
        if t > 0 {
            if let Some(transition) = factor_graph.get_factor_by_name(&transition_name) {
                // Transition factor connects adjacent sites; write its square
                // root into the left-bond/physical block of the site tensor
                let tensor = &mut mps.tensors[t];
                let left_dim = tensor.shape()[0];
                for s_prev in 0..num_states.min(left_dim) {
                    for s_curr in 0..num_states {
                        if s_prev < transition.values.shape()[0]
                            && s_curr < transition.values.shape()[1]
                        {
                            let val = transition.values[[s_prev, s_curr]];
                            tensor[[s_prev, s_curr, 0]] = Complex64::new(val.sqrt(), 0.0);
                        }
                    }
                }
            }
        }
    }

    Ok(mps)
}

/// Statistics about a tensor network.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorNetworkStats {
    /// Number of tensors
    pub num_tensors: usize,
    /// Total number of elements
    pub total_elements: usize,
    /// Maximum tensor rank
    pub max_rank: usize,
    /// Average tensor rank
    pub avg_rank: f64,
    /// Number of physical indices
    pub num_physical_indices: usize,
    /// Number of bond indices
    pub num_bond_indices: usize,
}

impl TensorNetwork {
    /// Compute statistics about the tensor network.
    pub fn stats(&self) -> TensorNetworkStats {
        let num_tensors = self.tensors.len();
        let total_elements: usize = self.tensors.iter().map(|t| t.data.len()).sum();
        let max_rank = self.tensors.iter().map(|t| t.rank()).max().unwrap_or(0);
        let avg_rank = if num_tensors > 0 {
            self.tensors.iter().map(|t| t.rank()).sum::<usize>() as f64 / num_tensors as f64
        } else {
            0.0
        };

        TensorNetworkStats {
            num_tensors,
            total_elements,
            max_rank,
            avg_rank,
            num_physical_indices: self.physical_indices.len(),
            num_bond_indices: self.bond_indices.len(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::FactorGraph;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_tensor_creation() {
        let data = ArrayD::from_shape_vec(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Array creation failed");
        let tensor = Tensor::from_real(
            "test".to_string(),
            data,
            vec!["i".to_string(), "j".to_string()],
        );

        assert_eq!(tensor.rank(), 2);
        assert_eq!(tensor.bond_dim("i"), Some(2));
        assert_eq!(tensor.bond_dim("j"), Some(3));
    }

    #[test]
    fn test_tensor_network_creation() {
        let mut tn = TensorNetwork::new();
        let data = ArrayD::from_shape_vec(vec![2], vec![1.0, 0.0]).expect("Array creation failed");
        let tensor = Tensor::from_real("A".to_string(), data, vec!["x".to_string()]);

        tn.add_tensor(tensor);
        tn.add_physical_index("x".to_string());

        assert_eq!(tn.num_tensors(), 1);
        assert_eq!(tn.num_physical_indices(), 1);
    }

    #[test]
    fn test_factor_graph_to_tn() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);

        let tn = factor_graph_to_tensor_network(&graph);
        assert!(tn.is_ok());

        let tn = tn.expect("TN creation failed");
        assert_eq!(tn.num_physical_indices(), 2);
    }
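
    // A sanity check for `partition_function` (not from the original test
    // suite): for a network holding a single tensor, Z is the sum of that
    // tensor's entries.
    #[test]
    fn test_partition_function_single_tensor() {
        let data = ArrayD::from_shape_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0])
            .expect("Array creation failed");
        let mut tn = TensorNetwork::new();
        tn.add_tensor(Tensor::from_real(
            "A".to_string(),
            data,
            vec!["i".to_string(), "j".to_string()],
        ));

        let z = tn.partition_function().expect("Partition function failed");
        assert_abs_diff_eq!(z.re, 10.0, epsilon = 1e-9);
        assert_abs_diff_eq!(z.im, 0.0, epsilon = 1e-9);
    }
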
    #[test]
    fn test_mps_creation() {
        let mps = MatrixProductState::new(4, 2, 4);

        assert_eq!(mps.length(), 4);
        assert_eq!(mps.physical_dims.len(), 4);
        assert!(mps.max_bond_dim() <= 4);
    }

    #[test]
    fn test_mps_product_state() {
        let mps = MatrixProductState::product_state(3, 2);

        assert_eq!(mps.length(), 3);
        assert_eq!(mps.max_bond_dim(), 1);

        // Product state |000⟩ should have norm 1
        let norm = mps.norm();
        assert_abs_diff_eq!(norm, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_mps_to_state_vector() {
        let mps = MatrixProductState::product_state(2, 2);
        let state = mps.to_state_vector();

        assert!(state.is_ok());
        let state = state.expect("State vector failed");
        assert_eq!(state.len(), 4); // 2^2 = 4 basis states
    }

    #[test]
    fn test_tensor_network_stats() {
        let mut tn = TensorNetwork::new();
        let data1 =
            ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).expect("Array creation failed");
        let data2 =
            ArrayD::from_shape_vec(vec![3, 4], vec![1.0; 12]).expect("Array creation failed");

        tn.add_tensor(Tensor::from_real(
            "A".to_string(),
            data1,
            vec!["i".to_string(), "j".to_string()],
        ));
        tn.add_tensor(Tensor::from_real(
            "B".to_string(),
            data2,
            vec!["j".to_string(), "k".to_string()],
        ));

        let stats = tn.stats();
        assert_eq!(stats.num_tensors, 2);
        assert_eq!(stats.total_elements, 18);
        assert_eq!(stats.max_rank, 2);
    }
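
    // A sanity check for `marginal` (not from the original test suite):
    // summing out the second index of a 2x2 tensor leaves the row sums.
    #[test]
    fn test_marginal_sums_out_indices() {
        let data = ArrayD::from_shape_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0])
            .expect("Array creation failed");
        let mut tn = TensorNetwork::new();
        tn.add_tensor(Tensor::from_real(
            "A".to_string(),
            data,
            vec!["i".to_string(), "j".to_string()],
        ));

        let m = tn.marginal(&["i".to_string()]).expect("Marginal failed");
        assert_eq!(m.bond_dims, vec![2]);
        assert_abs_diff_eq!(m.data[[0]].re, 3.0, epsilon = 1e-9); // 1 + 2
        assert_abs_diff_eq!(m.data[[1]].re, 7.0, epsilon = 1e-9); // 3 + 4
    }
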
    #[test]
    fn test_tensor_outer_product() {
        let data1 = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).expect("Array creation failed");
        let data2 =
            ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).expect("Array creation failed");

        let t1 = Tensor::from_real("A".to_string(), data1, vec!["i".to_string()]);
        let t2 = Tensor::from_real("B".to_string(), data2, vec!["j".to_string()]);

        let result = t1.contract(&t2);
        assert!(result.is_ok());

        let result = result.expect("Contraction failed");
        assert_eq!(result.indices.len(), 2);
        assert_eq!(result.bond_dims, vec![2, 3]);
    }
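
    // A sanity check for `contract` (not from the original test suite):
    // contracting two rank-2 tensors over one shared index reproduces
    // matrix multiplication: [[1, 2], [3, 4]] · [[5, 6], [7, 8]] = [[19, 22], [43, 50]].
    #[test]
    fn test_tensor_contraction_matches_matmul() {
        let a = ArrayD::from_shape_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0])
            .expect("Array creation failed");
        let b = ArrayD::from_shape_vec(vec![2, 2], vec![5.0, 6.0, 7.0, 8.0])
            .expect("Array creation failed");
        let t1 = Tensor::from_real("A".to_string(), a, vec!["i".to_string(), "j".to_string()]);
        let t2 = Tensor::from_real("B".to_string(), b, vec!["j".to_string(), "k".to_string()]);

        let c = t1.contract(&t2).expect("Contraction failed");
        assert_eq!(c.bond_dims, vec![2, 2]);
        assert_abs_diff_eq!(c.data[[0, 0]].re, 19.0, epsilon = 1e-9);
        assert_abs_diff_eq!(c.data[[0, 1]].re, 22.0, epsilon = 1e-9);
        assert_abs_diff_eq!(c.data[[1, 0]].re, 43.0, epsilon = 1e-9);
        assert_abs_diff_eq!(c.data[[1, 1]].re, 50.0, epsilon = 1e-9);
    }
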
    #[test]
    fn test_mps_expectation() {
        let mps = MatrixProductState::product_state(2, 2);

        // Z operator: |0⟩ → +1, |1⟩ → -1
        let z_op = Array2::from_shape_vec(
            (2, 2),
            vec![
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(-1.0, 0.0),
            ],
        )
        .expect("Operator creation failed");

        let exp_val = mps.expectation_local(0, &z_op);
        assert!(exp_val.is_ok());

        // For |00⟩, ⟨Z⟩ at site 0 should be +1
        let exp_val = exp_val.expect("Expectation failed");
        assert_abs_diff_eq!(exp_val.re, 1.0, epsilon = 1e-6);
    }
}
805}