Skip to main content

ruqu_core/
tensor_network.rs

//! Matrix Product State (MPS) tensor network simulator.
//!
//! Represents an n-qubit quantum state as a chain of tensors:
//!   |psi> = Sum_{i1..in} A[1]^{i1} . A[2]^{i2} . ... . A[n]^{in} |i1 i2 ... in>
//!
//! Each A[k] has shape (chi_{k-1}, 2, chi_k), where chi is the bond dimension.
//! Product states have chi = 1. Entanglement increases the bond dimension up to a
//! configurable maximum; beyond it, truncation gives an approximate simulation
//! whose discarded weight is accumulated as a truncation-error estimate.

11use crate::error::{QuantumError, Result};
12use crate::gate::Gate;
13use crate::types::{Complex, MeasurementOutcome, QubitIndex};
14
15use rand::rngs::StdRng;
16use rand::{Rng, SeedableRng};
17
/// Configuration for the MPS simulator.
#[derive(Debug, Clone)]
pub struct MpsConfig {
    /// Maximum bond dimension (number of basis vectors kept on any bond).
    /// Higher values yield more accurate simulation at the cost of increased
    /// memory and computation time. Typical values: 64, 128, 256, 512, 1024.
    pub max_bond_dim: usize,
    /// Residual-norm cutoff for two-site decompositions.
    ///
    /// The decomposition is QR-based (Gram-Schmidt), not an SVD. After a column
    /// of the two-site matrix has been orthogonalised against the directions
    /// already kept, a remaining norm below this value is treated as negligible
    /// and discarded. The discarded weight counts toward the truncation error.
    /// This plays the role of a singular-value cutoff.
    pub truncation_threshold: f64,
}
28
29impl Default for MpsConfig {
30    fn default() -> Self {
31        Self {
32            max_bond_dim: 256,
33            truncation_threshold: 1e-10,
34        }
35    }
36}
37
38// ---------------------------------------------------------------------------
39// MPS Tensor
40// ---------------------------------------------------------------------------
41
42/// A single MPS tensor for qubit k.
43///
44/// Shape: (left_dim, 2, right_dim) stored as a flat `Vec<Complex>` in
45/// row-major order with index = left * (2 * right_dim) + phys * right_dim + right.
46#[derive(Clone)]
47struct MpsTensor {
48    data: Vec<Complex>,
49    left_dim: usize,
50    right_dim: usize,
51}
52
53impl MpsTensor {
54    /// Create a tensor initialized to zero.
55    fn new_zero(left_dim: usize, right_dim: usize) -> Self {
56        Self {
57            data: vec![Complex::ZERO; left_dim * 2 * right_dim],
58            left_dim,
59            right_dim,
60        }
61    }
62
63    /// Compute the flat index for element (left, phys, right).
64    #[inline]
65    fn index(&self, left: usize, phys: usize, right: usize) -> usize {
66        left * (2 * self.right_dim) + phys * self.right_dim + right
67    }
68
69    /// Read the element at (left, phys, right).
70    #[inline]
71    fn get(&self, left: usize, phys: usize, right: usize) -> Complex {
72        self.data[self.index(left, phys, right)]
73    }
74
75    /// Write the element at (left, phys, right).
76    #[inline]
77    fn set(&mut self, left: usize, phys: usize, right: usize, val: Complex) {
78        let idx = self.index(left, phys, right);
79        self.data[idx] = val;
80    }
81}
82
83// ---------------------------------------------------------------------------
84// MPS State
85// ---------------------------------------------------------------------------
86
87/// Matrix Product State quantum simulator.
88///
89/// Represents quantum states as a chain of tensors, enabling efficient
90/// simulation of circuits with bounded entanglement. Can handle hundreds
91/// to thousands of qubits when bond dimension stays manageable.
92pub struct MpsState {
93    num_qubits: usize,
94    tensors: Vec<MpsTensor>,
95    config: MpsConfig,
96    rng: StdRng,
97    measurement_record: Vec<MeasurementOutcome>,
98    /// Accumulated truncation error for confidence bounds.
99    total_truncation_error: f64,
100}
101
102// ---------------------------------------------------------------------------
103// Construction
104// ---------------------------------------------------------------------------
105
106impl MpsState {
107    /// Initialize the |00...0> product state.
108    ///
109    /// Each tensor has bond dimension 1 and physical dimension 2, with the
110    /// amplitude concentrated on the |0> basis state.
111    pub fn new(num_qubits: usize) -> Result<Self> {
112        Self::new_with_config(num_qubits, MpsConfig::default())
113    }
114
115    /// Initialize |00...0> with explicit configuration.
116    pub fn new_with_config(num_qubits: usize, config: MpsConfig) -> Result<Self> {
117        if num_qubits == 0 {
118            return Err(QuantumError::CircuitError(
119                "cannot create MPS with 0 qubits".into(),
120            ));
121        }
122        let mut tensors = Vec::with_capacity(num_qubits);
123        for _ in 0..num_qubits {
124            let mut t = MpsTensor::new_zero(1, 1);
125            // |0> component = 1, |1> component = 0
126            t.set(0, 0, 0, Complex::ONE);
127            tensors.push(t);
128        }
129        Ok(Self {
130            num_qubits,
131            tensors,
132            config,
133            rng: StdRng::from_entropy(),
134            measurement_record: Vec::new(),
135            total_truncation_error: 0.0,
136        })
137    }
138
139    /// Initialize |00...0> with a deterministic seed for reproducibility.
140    pub fn new_with_seed(num_qubits: usize, seed: u64, config: MpsConfig) -> Result<Self> {
141        let mut state = Self::new_with_config(num_qubits, config)?;
142        state.rng = StdRng::seed_from_u64(seed);
143        Ok(state)
144    }
145
146    // -------------------------------------------------------------------
147    // Accessors
148    // -------------------------------------------------------------------
149
150    pub fn num_qubits(&self) -> usize {
151        self.num_qubits
152    }
153
154    /// Current maximum bond dimension across all bonds in the MPS chain.
155    pub fn max_bond_dimension(&self) -> usize {
156        self.tensors
157            .iter()
158            .map(|t| t.left_dim.max(t.right_dim))
159            .max()
160            .unwrap_or(1)
161    }
162
163    /// Accumulated truncation error from bond-dimension truncations.
164    pub fn truncation_error(&self) -> f64 {
165        self.total_truncation_error
166    }
167
168    pub fn measurement_record(&self) -> &[MeasurementOutcome] {
169        &self.measurement_record
170    }
171
172    // -------------------------------------------------------------------
173    // Single-qubit gate
174    // -------------------------------------------------------------------
175
176    /// Apply a 2x2 unitary to a single qubit.
177    ///
178    /// Contracts the gate matrix with the physical index of tensor[qubit]:
179    ///   new_tensor(l, i', r) = Sum_i matrix[i'][i] * tensor(l, i, r)
180    ///
181    /// This does not change bond dimensions.
182    pub fn apply_single_qubit_gate(&mut self, qubit: usize, matrix: &[[Complex; 2]; 2]) {
183        let t = &self.tensors[qubit];
184        let left_dim = t.left_dim;
185        let right_dim = t.right_dim;
186        let mut new_t = MpsTensor::new_zero(left_dim, right_dim);
187
188        for l in 0..left_dim {
189            for r in 0..right_dim {
190                let v0 = t.get(l, 0, r);
191                let v1 = t.get(l, 1, r);
192                new_t.set(l, 0, r, matrix[0][0] * v0 + matrix[0][1] * v1);
193                new_t.set(l, 1, r, matrix[1][0] * v0 + matrix[1][1] * v1);
194            }
195        }
196        self.tensors[qubit] = new_t;
197    }
198
199    // -------------------------------------------------------------------
200    // Two-qubit gate (adjacent)
201    // -------------------------------------------------------------------
202
203    /// Apply a 4x4 unitary gate to two adjacent qubits.
204    ///
205    /// The algorithm:
206    /// 1. Contract tensors at q1 and q2 into a combined 4-index tensor.
207    /// 2. Apply the 4x4 gate matrix on the two physical indices.
208    /// 3. Reshape into a matrix and perform truncated QR decomposition.
209    /// 4. Split back into two MPS tensors, respecting max_bond_dim.
210    pub fn apply_two_qubit_gate_adjacent(
211        &mut self,
212        q1: usize,
213        q2: usize,
214        matrix: &[[Complex; 4]; 4],
215    ) -> Result<()> {
216        if q1 >= self.num_qubits || q2 >= self.num_qubits {
217            return Err(QuantumError::CircuitError(
218                "qubit index out of range for MPS".into(),
219            ));
220        }
221        // Ensure q1 < q2 for adjacent gate application.
222        let (qa, qb) = if q1 < q2 { (q1, q2) } else { (q2, q1) };
223        if qb - qa != 1 {
224            return Err(QuantumError::CircuitError(
225                "apply_two_qubit_gate_adjacent requires adjacent qubits".into(),
226            ));
227        }
228
229        let t_a = &self.tensors[qa];
230        let t_b = &self.tensors[qb];
231        let left_dim = t_a.left_dim;
232        let inner_dim = t_a.right_dim; // == t_b.left_dim
233        let right_dim = t_b.right_dim;
234
235        // Step 1: Contract over the shared bond index to form a 4-index tensor
236        // theta(l, ia, ib, r) = Sum_m A_a(l, ia, m) * A_b(m, ib, r)
237        let mut theta = vec![Complex::ZERO; left_dim * 2 * 2 * right_dim];
238        let theta_idx =
239            |l: usize, ia: usize, ib: usize, r: usize| -> usize {
240                l * (4 * right_dim) + ia * (2 * right_dim) + ib * right_dim + r
241            };
242
243        for l in 0..left_dim {
244            for ia in 0..2 {
245                for ib in 0..2 {
246                    for r in 0..right_dim {
247                        let mut sum = Complex::ZERO;
248                        for m in 0..inner_dim {
249                            sum += t_a.get(l, ia, m) * t_b.get(m, ib, r);
250                        }
251                        theta[theta_idx(l, ia, ib, r)] = sum;
252                    }
253                }
254            }
255        }
256
257        // Step 2: Apply the gate matrix on the physical indices.
258        // Gate index convention: row = ia' * 2 + ib', col = ia * 2 + ib
259        // If q1 > q2, the gate was specified with reversed qubit order;
260        // we must transpose the physical indices accordingly.
261        let swap_phys = q1 > q2;
262        let mut gated = vec![Complex::ZERO; left_dim * 2 * 2 * right_dim];
263        for l in 0..left_dim {
264            for r in 0..right_dim {
265                // Collect the 4 input values
266                let mut inp = [Complex::ZERO; 4];
267                for ia in 0..2 {
268                    for ib in 0..2 {
269                        let idx = if swap_phys { ib * 2 + ia } else { ia * 2 + ib };
270                        inp[idx] = theta[theta_idx(l, ia, ib, r)];
271                    }
272                }
273                // Apply gate
274                for ia_out in 0..2 {
275                    for ib_out in 0..2 {
276                        let row = if swap_phys {
277                            ib_out * 2 + ia_out
278                        } else {
279                            ia_out * 2 + ib_out
280                        };
281                        let mut val = Complex::ZERO;
282                        for c in 0..4 {
283                            val += matrix[row][c] * inp[c];
284                        }
285                        gated[theta_idx(l, ia_out, ib_out, r)] = val;
286                    }
287                }
288            }
289        }
290
291        // Step 3: Reshape into matrix of shape (left_dim * 2) x (2 * right_dim)
292        // and perform truncated decomposition.
293        let rows = left_dim * 2;
294        let cols = 2 * right_dim;
295        let mut mat = vec![Complex::ZERO; rows * cols];
296        for l in 0..left_dim {
297            for ia in 0..2 {
298                for ib in 0..2 {
299                    for r in 0..right_dim {
300                        let row = l * 2 + ia;
301                        let col = ib * right_dim + r;
302                        mat[row * cols + col] = gated[theta_idx(l, ia, ib, r)];
303                    }
304                }
305            }
306        }
307
308        let (q_mat, r_mat, new_bond, trunc_err) = Self::truncated_qr(
309            &mat,
310            rows,
311            cols,
312            self.config.max_bond_dim,
313            self.config.truncation_threshold,
314        );
315        self.total_truncation_error += trunc_err;
316
317        // Step 4: Reshape Q into tensor_a (left_dim, 2, new_bond)
318        //         and R into tensor_b (new_bond, 2, right_dim).
319        let mut new_a = MpsTensor::new_zero(left_dim, new_bond);
320        for l in 0..left_dim {
321            for ia in 0..2 {
322                for nb in 0..new_bond {
323                    let row = l * 2 + ia;
324                    new_a.set(l, ia, nb, q_mat[row * new_bond + nb]);
325                }
326            }
327        }
328
329        let mut new_b = MpsTensor::new_zero(new_bond, right_dim);
330        for nb in 0..new_bond {
331            for ib in 0..2 {
332                for r in 0..right_dim {
333                    let col = ib * right_dim + r;
334                    new_b.set(nb, ib, r, r_mat[nb * cols + col]);
335                }
336            }
337        }
338
339        self.tensors[qa] = new_a;
340        self.tensors[qb] = new_b;
341        Ok(())
342    }
343
344    // -------------------------------------------------------------------
345    // Two-qubit gate (general, possibly non-adjacent)
346    // -------------------------------------------------------------------
347
348    /// Apply a 4x4 gate to any pair of qubits.
349    ///
350    /// If the qubits are adjacent, delegates directly. Otherwise, uses SWAP
351    /// gates to move the qubits next to each other, applies the gate, then
352    /// swaps back to restore qubit ordering.
353    pub fn apply_two_qubit_gate(
354        &mut self,
355        q1: usize,
356        q2: usize,
357        matrix: &[[Complex; 4]; 4],
358    ) -> Result<()> {
359        if q1 == q2 {
360            return Err(QuantumError::CircuitError(
361                "two-qubit gate requires distinct qubits".into(),
362            ));
363        }
364        let diff = if q1 > q2 { q1 - q2 } else { q2 - q1 };
365        if diff == 1 {
366            return self.apply_two_qubit_gate_adjacent(q1, q2, matrix);
367        }
368
369        let swap_matrix = Self::swap_matrix();
370
371        // Move q1 adjacent to q2 via SWAP chain.
372        // We swap q1 toward q2, keeping track of its current position.
373        let (mut pos1, target_pos) = if q1 < q2 {
374            (q1, q2 - 1)
375        } else {
376            (q1, q2 + 1)
377        };
378
379        // Forward swaps: move pos1 toward target_pos
380        let forward_steps: Vec<usize> = if pos1 < target_pos {
381            (pos1..target_pos).collect()
382        } else {
383            (target_pos..pos1).rev().collect()
384        };
385
386        for &s in &forward_steps {
387            self.apply_two_qubit_gate_adjacent(s, s + 1, &swap_matrix)?;
388        }
389        pos1 = target_pos;
390
391        // Now pos1 and q2 are adjacent: apply the gate.
392        self.apply_two_qubit_gate_adjacent(pos1, q2, matrix)?;
393
394        // Reverse swaps to restore original qubit ordering.
395        for &s in forward_steps.iter().rev() {
396            self.apply_two_qubit_gate_adjacent(s, s + 1, &swap_matrix)?;
397        }
398
399        Ok(())
400    }
401
402    // -------------------------------------------------------------------
403    // Measurement
404    // -------------------------------------------------------------------
405
406    /// Measure a single qubit projectively.
407    ///
408    /// 1. Compute the probability of |0> by locally contracting the MPS.
409    /// 2. Sample the outcome.
410    /// 3. Collapse the tensor at the measured qubit by projecting.
411    /// 4. Renormalize.
412    pub fn measure(&mut self, qubit: usize) -> Result<MeasurementOutcome> {
413        if qubit >= self.num_qubits {
414            return Err(QuantumError::InvalidQubitIndex {
415                index: qubit as QubitIndex,
416                num_qubits: self.num_qubits as u32,
417            });
418        }
419
420        // Compute reduced density matrix element rho_00 and rho_11
421        // for the target qubit by contracting the MPS from both ends.
422        let (p0, p1) = self.qubit_probabilities(qubit);
423        let total = p0 + p1;
424        let p0_norm = if total > 0.0 { p0 / total } else { 0.5 };
425
426        let random: f64 = self.rng.gen();
427        let result = random >= p0_norm; // true => measured |1>
428        let prob = if result { 1.0 - p0_norm } else { p0_norm };
429
430        // Collapse: project the tensor at this qubit onto the measured state.
431        let t = &self.tensors[qubit];
432        let left_dim = t.left_dim;
433        let right_dim = t.right_dim;
434        let measured_phys: usize = if result { 1 } else { 0 };
435
436        let mut new_t = MpsTensor::new_zero(left_dim, right_dim);
437        for l in 0..left_dim {
438            for r in 0..right_dim {
439                new_t.set(l, measured_phys, r, t.get(l, measured_phys, r));
440            }
441        }
442
443        // Renormalize the projected tensor.
444        let mut norm_sq = 0.0;
445        for val in &new_t.data {
446            norm_sq += val.norm_sq();
447        }
448        if norm_sq > 0.0 {
449            let inv_norm = 1.0 / norm_sq.sqrt();
450            for val in new_t.data.iter_mut() {
451                *val = *val * inv_norm;
452            }
453        }
454
455        self.tensors[qubit] = new_t;
456
457        let outcome = MeasurementOutcome {
458            qubit: qubit as QubitIndex,
459            result,
460            probability: prob,
461        };
462        self.measurement_record.push(outcome.clone());
463        Ok(outcome)
464    }
465
466    // -------------------------------------------------------------------
467    // Gate dispatch
468    // -------------------------------------------------------------------
469
470    /// Apply a gate from the Gate enum, returning any measurement outcomes.
471    pub fn apply_gate(&mut self, gate: &Gate) -> Result<Vec<MeasurementOutcome>> {
472        for &q in gate.qubits().iter() {
473            if (q as usize) >= self.num_qubits {
474                return Err(QuantumError::InvalidQubitIndex {
475                    index: q,
476                    num_qubits: self.num_qubits as u32,
477                });
478            }
479        }
480
481        match gate {
482            Gate::Barrier => Ok(vec![]),
483
484            Gate::Measure(q) => {
485                let outcome = self.measure(*q as usize)?;
486                Ok(vec![outcome])
487            }
488
489            Gate::Reset(q) => {
490                let outcome = self.measure(*q as usize)?;
491                if outcome.result {
492                    let x = Gate::X(*q).matrix_1q().unwrap();
493                    self.apply_single_qubit_gate(*q as usize, &x);
494                }
495                Ok(vec![])
496            }
497
498            Gate::CNOT(q1, q2)
499            | Gate::CZ(q1, q2)
500            | Gate::SWAP(q1, q2)
501            | Gate::Rzz(q1, q2, _) => {
502                if q1 == q2 {
503                    return Err(QuantumError::CircuitError(format!(
504                        "two-qubit gate requires distinct qubits, got {} and {}",
505                        q1, q2
506                    )));
507                }
508                let matrix = gate.matrix_2q().unwrap();
509                self.apply_two_qubit_gate(*q1 as usize, *q2 as usize, &matrix)?;
510                Ok(vec![])
511            }
512
513            other => {
514                if let Some(matrix) = other.matrix_1q() {
515                    let q = other.qubits()[0];
516                    self.apply_single_qubit_gate(q as usize, &matrix);
517                    Ok(vec![])
518                } else {
519                    Err(QuantumError::CircuitError(format!(
520                        "unsupported gate for MPS: {:?}",
521                        other
522                    )))
523                }
524            }
525        }
526    }
527
528    // -------------------------------------------------------------------
529    // Internal: SWAP matrix
530    // -------------------------------------------------------------------
531
532    fn swap_matrix() -> [[Complex; 4]; 4] {
533        let c0 = Complex::ZERO;
534        let c1 = Complex::ONE;
535        [
536            [c1, c0, c0, c0],
537            [c0, c0, c1, c0],
538            [c0, c1, c0, c0],
539            [c0, c0, c0, c1],
540        ]
541    }
542
543    // -------------------------------------------------------------------
544    // Internal: qubit probability computation
545    // -------------------------------------------------------------------
546
547    /// Compute (prob_0, prob_1) for a single qubit by contracting the MPS.
548    ///
549    /// This builds a partial "environment" from the left and right boundaries,
550    /// then contracts through the target qubit tensor for each physical index.
551    fn qubit_probabilities(&self, qubit: usize) -> (f64, f64) {
552        // Left environment: contract tensors 0..qubit into a matrix.
553        // env_left has shape (bond_dim, bond_dim) representing
554        // Sum_{physical indices} conj(A) * A contracted from the left.
555        let bond_left = self.tensors[qubit].left_dim;
556        let mut env_left = vec![Complex::ZERO; bond_left * bond_left];
557        // Initialize to identity (boundary condition: left boundary = 1).
558        for i in 0..bond_left {
559            env_left[i * bond_left + i] = Complex::ONE;
560        }
561        // Contract from site 0 to qubit-1.
562        for site in 0..qubit {
563            let t = &self.tensors[site];
564            let dim_in = t.left_dim;
565            let dim_out = t.right_dim;
566            let mut new_env = vec![Complex::ZERO; dim_out * dim_out];
567            for ro in 0..dim_out {
568                for co in 0..dim_out {
569                    let mut sum = Complex::ZERO;
570                    for ri in 0..dim_in {
571                        for ci in 0..dim_in {
572                            let e = env_left[ri * dim_in + ci];
573                            if e.norm_sq() == 0.0 {
574                                continue;
575                            }
576                            for p in 0..2 {
577                                sum += e.conj() // env^*
578                                    * t.get(ri, p, ro).conj()
579                                    * t.get(ci, p, co);
580                            }
581                        }
582                    }
583                    new_env[ro * dim_out + co] = sum;
584                }
585            }
586            env_left = new_env;
587        }
588
589        // Right environment: contract tensors (qubit+1)..num_qubits.
590        let bond_right = self.tensors[qubit].right_dim;
591        let mut env_right = vec![Complex::ZERO; bond_right * bond_right];
592        for i in 0..bond_right {
593            env_right[i * bond_right + i] = Complex::ONE;
594        }
595        for site in (qubit + 1..self.num_qubits).rev() {
596            let t = &self.tensors[site];
597            let dim_in = t.right_dim;
598            let dim_out = t.left_dim;
599            let mut new_env = vec![Complex::ZERO; dim_out * dim_out];
600            for ro in 0..dim_out {
601                for co in 0..dim_out {
602                    let mut sum = Complex::ZERO;
603                    for ri in 0..dim_in {
604                        for ci in 0..dim_in {
605                            let e = env_right[ri * dim_in + ci];
606                            if e.norm_sq() == 0.0 {
607                                continue;
608                            }
609                            for p in 0..2 {
610                                sum += e.conj()
611                                    * t.get(ro, p, ri).conj()
612                                    * t.get(co, p, ci);
613                            }
614                        }
615                    }
616                    new_env[ro * dim_out + co] = sum;
617                }
618            }
619            env_right = new_env;
620        }
621
622        // Contract with the target qubit tensor for each physical index.
623        let t = &self.tensors[qubit];
624        let mut probs = [0.0f64; 2];
625        for phys in 0..2 {
626            let mut val = Complex::ZERO;
627            for l1 in 0..t.left_dim {
628                for l2 in 0..t.left_dim {
629                    let e_l = env_left[l1 * t.left_dim + l2];
630                    if e_l.norm_sq() == 0.0 {
631                        continue;
632                    }
633                    for r1 in 0..t.right_dim {
634                        for r2 in 0..t.right_dim {
635                            let e_r = env_right[r1 * t.right_dim + r2];
636                            if e_r.norm_sq() == 0.0 {
637                                continue;
638                            }
639                            val += e_l.conj()
640                                * t.get(l1, phys, r1).conj()
641                                * t.get(l2, phys, r2)
642                                * e_r;
643                        }
644                    }
645                }
646            }
647            probs[phys] = val.re; // Should be real for a valid density matrix
648        }
649
650        (probs[0].max(0.0), probs[1].max(0.0))
651    }
652
653    // -------------------------------------------------------------------
654    // Internal: Truncated QR decomposition
655    // -------------------------------------------------------------------
656
657    /// Perform modified Gram-Schmidt QR on a complex matrix, then truncate.
658    ///
659    /// Given matrix M of shape (rows x cols), computes M = Q * R where Q has
660    /// orthonormal columns and R is upper triangular. Truncates to at most
661    /// `max_rank` columns of Q (and rows of R), discarding columns whose
662    /// R diagonal magnitude falls below `threshold`.
663    ///
664    /// Returns (Q_flat, R_flat, rank, truncation_error).
665    fn truncated_qr(
666        mat: &[Complex],
667        rows: usize,
668        cols: usize,
669        max_rank: usize,
670        threshold: f64,
671    ) -> (Vec<Complex>, Vec<Complex>, usize, f64) {
672        let rank_bound = rows.min(cols).min(max_rank);
673
674        // Modified Gram-Schmidt: build Q column by column, R simultaneously.
675        let mut q_cols: Vec<Vec<Complex>> = Vec::with_capacity(rank_bound);
676        let mut r_data = vec![Complex::ZERO; rank_bound * cols];
677        let mut actual_rank = 0;
678        let mut trunc_error = 0.0;
679
680        for j in 0..cols.min(rank_bound + cols) {
681            if actual_rank >= rank_bound {
682                // Estimate truncation error from remaining columns.
683                if j < cols {
684                    for jj in j..cols {
685                        let mut col_norm_sq = 0.0;
686                        for i in 0..rows {
687                            col_norm_sq += mat[i * cols + jj].norm_sq();
688                        }
689                        trunc_error += col_norm_sq;
690                    }
691                    trunc_error = trunc_error.sqrt();
692                }
693                break;
694            }
695            if j >= cols {
696                break;
697            }
698
699            // Extract column j of the input matrix.
700            let mut v: Vec<Complex> = (0..rows).map(|i| mat[i * cols + j]).collect();
701
702            // Orthogonalize against existing Q columns.
703            for k in 0..actual_rank {
704                let mut dot = Complex::ZERO;
705                for i in 0..rows {
706                    dot += q_cols[k][i].conj() * v[i];
707                }
708                r_data[k * cols + j] = dot;
709                for i in 0..rows {
710                    v[i] = v[i] - dot * q_cols[k][i];
711                }
712            }
713
714            // Compute norm of residual.
715            let mut norm_sq = 0.0;
716            for i in 0..rows {
717                norm_sq += v[i].norm_sq();
718            }
719            let norm = norm_sq.sqrt();
720
721            if norm < threshold {
722                // Column is (nearly) linearly dependent; skip it.
723                trunc_error += norm;
724                continue;
725            }
726
727            // Normalize and store.
728            r_data[actual_rank * cols + j] = Complex::new(norm, 0.0);
729            let inv_norm = 1.0 / norm;
730            for i in 0..rows {
731                v[i] = v[i] * inv_norm;
732            }
733            q_cols.push(v);
734            actual_rank += 1;
735        }
736
737        // Ensure at least rank 1 to avoid degenerate tensors.
738        if actual_rank == 0 {
739            actual_rank = 1;
740            q_cols.push(vec![Complex::ZERO; rows]);
741            q_cols[0][0] = Complex::ONE;
742            // R remains zero.
743        }
744
745        // Flatten Q: shape (rows, actual_rank)
746        let mut q_flat = vec![Complex::ZERO; rows * actual_rank];
747        for i in 0..rows {
748            for k in 0..actual_rank {
749                q_flat[i * actual_rank + k] = q_cols[k][i];
750            }
751        }
752
753        // Trim R to shape (actual_rank, cols)
754        let mut r_flat = vec![Complex::ZERO; actual_rank * cols];
755        for k in 0..actual_rank {
756            for j in 0..cols {
757                r_flat[k * cols + j] = r_data[k * cols + j];
758            }
759        }
760
761        (q_flat, r_flat, actual_rank, trunc_error)
762    }
763}
764
765#[cfg(test)]
766mod tests {
767    use super::*;
768
769    #[test]
770    fn test_new_product_state() {
771        let mps = MpsState::new(4).unwrap();
772        assert_eq!(mps.num_qubits(), 4);
773        assert_eq!(mps.max_bond_dimension(), 1);
774        assert_eq!(mps.truncation_error(), 0.0);
775    }
776
777    #[test]
778    fn test_zero_qubits_errors() {
779        assert!(MpsState::new(0).is_err());
780    }
781
782    #[test]
783    fn test_single_qubit_x_gate() {
784        let mut mps = MpsState::new_with_seed(1, 42, MpsConfig::default()).unwrap();
785        // X gate: flips |0> to |1>
786        let x = [[Complex::ZERO, Complex::ONE], [Complex::ONE, Complex::ZERO]];
787        mps.apply_single_qubit_gate(0, &x);
788        // After X, tensor should have |1> = 1, |0> = 0
789        let t = &mps.tensors[0];
790        assert!(t.get(0, 0, 0).norm_sq() < 1e-20);
791        assert!((t.get(0, 1, 0).norm_sq() - 1.0).abs() < 1e-10);
792    }
793
794    #[test]
795    fn test_single_qubit_h_gate() {
796        let mut mps = MpsState::new_with_seed(1, 42, MpsConfig::default()).unwrap();
797        let h = std::f64::consts::FRAC_1_SQRT_2;
798        let hc = Complex::new(h, 0.0);
799        let h_gate = [[hc, hc], [hc, -hc]];
800        mps.apply_single_qubit_gate(0, &h_gate);
801        // After H|0>, both amplitudes should be 1/sqrt(2)
802        let t = &mps.tensors[0];
803        assert!((t.get(0, 0, 0).norm_sq() - 0.5).abs() < 1e-10);
804        assert!((t.get(0, 1, 0).norm_sq() - 0.5).abs() < 1e-10);
805    }
806
807    #[test]
808    fn test_cnot_creates_bell_state() {
809        let mut mps = MpsState::new_with_seed(2, 42, MpsConfig::default()).unwrap();
810        // Apply H to qubit 0
811        let h = std::f64::consts::FRAC_1_SQRT_2;
812        let hc = Complex::new(h, 0.0);
813        let h_gate = [[hc, hc], [hc, -hc]];
814        mps.apply_single_qubit_gate(0, &h_gate);
815
816        // Apply CNOT(0,1)
817        let c0 = Complex::ZERO;
818        let c1 = Complex::ONE;
819        let cnot = [
820            [c1, c0, c0, c0],
821            [c0, c1, c0, c0],
822            [c0, c0, c0, c1],
823            [c0, c0, c1, c0],
824        ];
825        mps.apply_two_qubit_gate(0, 1, &cnot).unwrap();
826        // Bond dimension should have increased from 1 to 2
827        assert!(mps.max_bond_dimension() >= 2);
828    }
829
830    #[test]
831    fn test_measurement_deterministic() {
832        // |0> state: measuring should always give 0
833        let mut mps = MpsState::new_with_seed(1, 42, MpsConfig::default()).unwrap();
834        let outcome = mps.measure(0).unwrap();
835        assert!(!outcome.result);
836        assert!((outcome.probability - 1.0).abs() < 1e-10);
837    }
838
839    #[test]
840    fn test_gate_dispatch() {
841        let mut mps = MpsState::new_with_seed(2, 42, MpsConfig::default()).unwrap();
842        let outcomes = mps.apply_gate(&Gate::H(0)).unwrap();
843        assert!(outcomes.is_empty());
844        let outcomes = mps.apply_gate(&Gate::CNOT(0, 1)).unwrap();
845        assert!(outcomes.is_empty());
846    }
847
848    #[test]
849    fn test_non_adjacent_two_qubit_gate() {
850        let mut mps = MpsState::new_with_seed(4, 42, MpsConfig::default()).unwrap();
851        // Apply CNOT between qubits 0 and 3 (non-adjacent)
852        let c0 = Complex::ZERO;
853        let c1 = Complex::ONE;
854        let cnot = [
855            [c1, c0, c0, c0],
856            [c0, c1, c0, c0],
857            [c0, c0, c0, c1],
858            [c0, c0, c1, c0],
859        ];
860        // Should not error even though qubits are non-adjacent
861        mps.apply_two_qubit_gate(0, 3, &cnot).unwrap();
862    }
863}