1use crate::error::{QuantumError, Result};
12use crate::gate::Gate;
13use crate::types::{Complex, MeasurementOutcome, QubitIndex};
14
15use rand::rngs::StdRng;
16use rand::{Rng, SeedableRng};
17
/// Tunable limits for the matrix-product-state simulator.
#[derive(Debug, Clone)]
pub struct MpsConfig {
    /// Upper bound on the bond dimension kept when a two-qubit gate is
    /// factored back into two site tensors.
    pub max_bond_dim: usize,
    /// Residual norm below which a column is treated as linearly dependent
    /// and discarded during factorization.
    pub truncation_threshold: f64,
}
28
29impl Default for MpsConfig {
30 fn default() -> Self {
31 Self {
32 max_bond_dim: 256,
33 truncation_threshold: 1e-10,
34 }
35 }
36}
37
38#[derive(Clone)]
47struct MpsTensor {
48 data: Vec<Complex>,
49 left_dim: usize,
50 right_dim: usize,
51}
52
53impl MpsTensor {
54 fn new_zero(left_dim: usize, right_dim: usize) -> Self {
56 Self {
57 data: vec![Complex::ZERO; left_dim * 2 * right_dim],
58 left_dim,
59 right_dim,
60 }
61 }
62
63 #[inline]
65 fn index(&self, left: usize, phys: usize, right: usize) -> usize {
66 left * (2 * self.right_dim) + phys * self.right_dim + right
67 }
68
69 #[inline]
71 fn get(&self, left: usize, phys: usize, right: usize) -> Complex {
72 self.data[self.index(left, phys, right)]
73 }
74
75 #[inline]
77 fn set(&mut self, left: usize, phys: usize, right: usize, val: Complex) {
78 let idx = self.index(left, phys, right);
79 self.data[idx] = val;
80 }
81}
82
83pub struct MpsState {
93 num_qubits: usize,
94 tensors: Vec<MpsTensor>,
95 config: MpsConfig,
96 rng: StdRng,
97 measurement_record: Vec<MeasurementOutcome>,
98 total_truncation_error: f64,
100}
101
102impl MpsState {
107 pub fn new(num_qubits: usize) -> Result<Self> {
112 Self::new_with_config(num_qubits, MpsConfig::default())
113 }
114
115 pub fn new_with_config(num_qubits: usize, config: MpsConfig) -> Result<Self> {
117 if num_qubits == 0 {
118 return Err(QuantumError::CircuitError(
119 "cannot create MPS with 0 qubits".into(),
120 ));
121 }
122 let mut tensors = Vec::with_capacity(num_qubits);
123 for _ in 0..num_qubits {
124 let mut t = MpsTensor::new_zero(1, 1);
125 t.set(0, 0, 0, Complex::ONE);
127 tensors.push(t);
128 }
129 Ok(Self {
130 num_qubits,
131 tensors,
132 config,
133 rng: StdRng::from_entropy(),
134 measurement_record: Vec::new(),
135 total_truncation_error: 0.0,
136 })
137 }
138
139 pub fn new_with_seed(num_qubits: usize, seed: u64, config: MpsConfig) -> Result<Self> {
141 let mut state = Self::new_with_config(num_qubits, config)?;
142 state.rng = StdRng::seed_from_u64(seed);
143 Ok(state)
144 }
145
146 pub fn num_qubits(&self) -> usize {
151 self.num_qubits
152 }
153
154 pub fn max_bond_dimension(&self) -> usize {
156 self.tensors
157 .iter()
158 .map(|t| t.left_dim.max(t.right_dim))
159 .max()
160 .unwrap_or(1)
161 }
162
163 pub fn truncation_error(&self) -> f64 {
165 self.total_truncation_error
166 }
167
168 pub fn measurement_record(&self) -> &[MeasurementOutcome] {
169 &self.measurement_record
170 }
171
172 pub fn apply_single_qubit_gate(&mut self, qubit: usize, matrix: &[[Complex; 2]; 2]) {
183 let t = &self.tensors[qubit];
184 let left_dim = t.left_dim;
185 let right_dim = t.right_dim;
186 let mut new_t = MpsTensor::new_zero(left_dim, right_dim);
187
188 for l in 0..left_dim {
189 for r in 0..right_dim {
190 let v0 = t.get(l, 0, r);
191 let v1 = t.get(l, 1, r);
192 new_t.set(l, 0, r, matrix[0][0] * v0 + matrix[0][1] * v1);
193 new_t.set(l, 1, r, matrix[1][0] * v0 + matrix[1][1] * v1);
194 }
195 }
196 self.tensors[qubit] = new_t;
197 }
198
199 pub fn apply_two_qubit_gate_adjacent(
211 &mut self,
212 q1: usize,
213 q2: usize,
214 matrix: &[[Complex; 4]; 4],
215 ) -> Result<()> {
216 if q1 >= self.num_qubits || q2 >= self.num_qubits {
217 return Err(QuantumError::CircuitError(
218 "qubit index out of range for MPS".into(),
219 ));
220 }
221 let (qa, qb) = if q1 < q2 { (q1, q2) } else { (q2, q1) };
223 if qb - qa != 1 {
224 return Err(QuantumError::CircuitError(
225 "apply_two_qubit_gate_adjacent requires adjacent qubits".into(),
226 ));
227 }
228
229 let t_a = &self.tensors[qa];
230 let t_b = &self.tensors[qb];
231 let left_dim = t_a.left_dim;
232 let inner_dim = t_a.right_dim; let right_dim = t_b.right_dim;
234
235 let mut theta = vec![Complex::ZERO; left_dim * 2 * 2 * right_dim];
238 let theta_idx =
239 |l: usize, ia: usize, ib: usize, r: usize| -> usize {
240 l * (4 * right_dim) + ia * (2 * right_dim) + ib * right_dim + r
241 };
242
243 for l in 0..left_dim {
244 for ia in 0..2 {
245 for ib in 0..2 {
246 for r in 0..right_dim {
247 let mut sum = Complex::ZERO;
248 for m in 0..inner_dim {
249 sum += t_a.get(l, ia, m) * t_b.get(m, ib, r);
250 }
251 theta[theta_idx(l, ia, ib, r)] = sum;
252 }
253 }
254 }
255 }
256
257 let swap_phys = q1 > q2;
262 let mut gated = vec![Complex::ZERO; left_dim * 2 * 2 * right_dim];
263 for l in 0..left_dim {
264 for r in 0..right_dim {
265 let mut inp = [Complex::ZERO; 4];
267 for ia in 0..2 {
268 for ib in 0..2 {
269 let idx = if swap_phys { ib * 2 + ia } else { ia * 2 + ib };
270 inp[idx] = theta[theta_idx(l, ia, ib, r)];
271 }
272 }
273 for ia_out in 0..2 {
275 for ib_out in 0..2 {
276 let row = if swap_phys {
277 ib_out * 2 + ia_out
278 } else {
279 ia_out * 2 + ib_out
280 };
281 let mut val = Complex::ZERO;
282 for c in 0..4 {
283 val += matrix[row][c] * inp[c];
284 }
285 gated[theta_idx(l, ia_out, ib_out, r)] = val;
286 }
287 }
288 }
289 }
290
291 let rows = left_dim * 2;
294 let cols = 2 * right_dim;
295 let mut mat = vec![Complex::ZERO; rows * cols];
296 for l in 0..left_dim {
297 for ia in 0..2 {
298 for ib in 0..2 {
299 for r in 0..right_dim {
300 let row = l * 2 + ia;
301 let col = ib * right_dim + r;
302 mat[row * cols + col] = gated[theta_idx(l, ia, ib, r)];
303 }
304 }
305 }
306 }
307
308 let (q_mat, r_mat, new_bond, trunc_err) = Self::truncated_qr(
309 &mat,
310 rows,
311 cols,
312 self.config.max_bond_dim,
313 self.config.truncation_threshold,
314 );
315 self.total_truncation_error += trunc_err;
316
317 let mut new_a = MpsTensor::new_zero(left_dim, new_bond);
320 for l in 0..left_dim {
321 for ia in 0..2 {
322 for nb in 0..new_bond {
323 let row = l * 2 + ia;
324 new_a.set(l, ia, nb, q_mat[row * new_bond + nb]);
325 }
326 }
327 }
328
329 let mut new_b = MpsTensor::new_zero(new_bond, right_dim);
330 for nb in 0..new_bond {
331 for ib in 0..2 {
332 for r in 0..right_dim {
333 let col = ib * right_dim + r;
334 new_b.set(nb, ib, r, r_mat[nb * cols + col]);
335 }
336 }
337 }
338
339 self.tensors[qa] = new_a;
340 self.tensors[qb] = new_b;
341 Ok(())
342 }
343
344 pub fn apply_two_qubit_gate(
354 &mut self,
355 q1: usize,
356 q2: usize,
357 matrix: &[[Complex; 4]; 4],
358 ) -> Result<()> {
359 if q1 == q2 {
360 return Err(QuantumError::CircuitError(
361 "two-qubit gate requires distinct qubits".into(),
362 ));
363 }
364 let diff = if q1 > q2 { q1 - q2 } else { q2 - q1 };
365 if diff == 1 {
366 return self.apply_two_qubit_gate_adjacent(q1, q2, matrix);
367 }
368
369 let swap_matrix = Self::swap_matrix();
370
371 let (mut pos1, target_pos) = if q1 < q2 {
374 (q1, q2 - 1)
375 } else {
376 (q1, q2 + 1)
377 };
378
379 let forward_steps: Vec<usize> = if pos1 < target_pos {
381 (pos1..target_pos).collect()
382 } else {
383 (target_pos..pos1).rev().collect()
384 };
385
386 for &s in &forward_steps {
387 self.apply_two_qubit_gate_adjacent(s, s + 1, &swap_matrix)?;
388 }
389 pos1 = target_pos;
390
391 self.apply_two_qubit_gate_adjacent(pos1, q2, matrix)?;
393
394 for &s in forward_steps.iter().rev() {
396 self.apply_two_qubit_gate_adjacent(s, s + 1, &swap_matrix)?;
397 }
398
399 Ok(())
400 }
401
402 pub fn measure(&mut self, qubit: usize) -> Result<MeasurementOutcome> {
413 if qubit >= self.num_qubits {
414 return Err(QuantumError::InvalidQubitIndex {
415 index: qubit as QubitIndex,
416 num_qubits: self.num_qubits as u32,
417 });
418 }
419
420 let (p0, p1) = self.qubit_probabilities(qubit);
423 let total = p0 + p1;
424 let p0_norm = if total > 0.0 { p0 / total } else { 0.5 };
425
426 let random: f64 = self.rng.gen();
427 let result = random >= p0_norm; let prob = if result { 1.0 - p0_norm } else { p0_norm };
429
430 let t = &self.tensors[qubit];
432 let left_dim = t.left_dim;
433 let right_dim = t.right_dim;
434 let measured_phys: usize = if result { 1 } else { 0 };
435
436 let mut new_t = MpsTensor::new_zero(left_dim, right_dim);
437 for l in 0..left_dim {
438 for r in 0..right_dim {
439 new_t.set(l, measured_phys, r, t.get(l, measured_phys, r));
440 }
441 }
442
443 let mut norm_sq = 0.0;
445 for val in &new_t.data {
446 norm_sq += val.norm_sq();
447 }
448 if norm_sq > 0.0 {
449 let inv_norm = 1.0 / norm_sq.sqrt();
450 for val in new_t.data.iter_mut() {
451 *val = *val * inv_norm;
452 }
453 }
454
455 self.tensors[qubit] = new_t;
456
457 let outcome = MeasurementOutcome {
458 qubit: qubit as QubitIndex,
459 result,
460 probability: prob,
461 };
462 self.measurement_record.push(outcome.clone());
463 Ok(outcome)
464 }
465
466 pub fn apply_gate(&mut self, gate: &Gate) -> Result<Vec<MeasurementOutcome>> {
472 for &q in gate.qubits().iter() {
473 if (q as usize) >= self.num_qubits {
474 return Err(QuantumError::InvalidQubitIndex {
475 index: q,
476 num_qubits: self.num_qubits as u32,
477 });
478 }
479 }
480
481 match gate {
482 Gate::Barrier => Ok(vec![]),
483
484 Gate::Measure(q) => {
485 let outcome = self.measure(*q as usize)?;
486 Ok(vec![outcome])
487 }
488
489 Gate::Reset(q) => {
490 let outcome = self.measure(*q as usize)?;
491 if outcome.result {
492 let x = Gate::X(*q).matrix_1q().unwrap();
493 self.apply_single_qubit_gate(*q as usize, &x);
494 }
495 Ok(vec![])
496 }
497
498 Gate::CNOT(q1, q2)
499 | Gate::CZ(q1, q2)
500 | Gate::SWAP(q1, q2)
501 | Gate::Rzz(q1, q2, _) => {
502 if q1 == q2 {
503 return Err(QuantumError::CircuitError(format!(
504 "two-qubit gate requires distinct qubits, got {} and {}",
505 q1, q2
506 )));
507 }
508 let matrix = gate.matrix_2q().unwrap();
509 self.apply_two_qubit_gate(*q1 as usize, *q2 as usize, &matrix)?;
510 Ok(vec![])
511 }
512
513 other => {
514 if let Some(matrix) = other.matrix_1q() {
515 let q = other.qubits()[0];
516 self.apply_single_qubit_gate(q as usize, &matrix);
517 Ok(vec![])
518 } else {
519 Err(QuantumError::CircuitError(format!(
520 "unsupported gate for MPS: {:?}",
521 other
522 )))
523 }
524 }
525 }
526 }
527
528 fn swap_matrix() -> [[Complex; 4]; 4] {
533 let c0 = Complex::ZERO;
534 let c1 = Complex::ONE;
535 [
536 [c1, c0, c0, c0],
537 [c0, c0, c1, c0],
538 [c0, c1, c0, c0],
539 [c0, c0, c0, c1],
540 ]
541 }
542
543 fn qubit_probabilities(&self, qubit: usize) -> (f64, f64) {
552 let bond_left = self.tensors[qubit].left_dim;
556 let mut env_left = vec![Complex::ZERO; bond_left * bond_left];
557 for i in 0..bond_left {
559 env_left[i * bond_left + i] = Complex::ONE;
560 }
561 for site in 0..qubit {
563 let t = &self.tensors[site];
564 let dim_in = t.left_dim;
565 let dim_out = t.right_dim;
566 let mut new_env = vec![Complex::ZERO; dim_out * dim_out];
567 for ro in 0..dim_out {
568 for co in 0..dim_out {
569 let mut sum = Complex::ZERO;
570 for ri in 0..dim_in {
571 for ci in 0..dim_in {
572 let e = env_left[ri * dim_in + ci];
573 if e.norm_sq() == 0.0 {
574 continue;
575 }
576 for p in 0..2 {
577 sum += e.conj() * t.get(ri, p, ro).conj()
579 * t.get(ci, p, co);
580 }
581 }
582 }
583 new_env[ro * dim_out + co] = sum;
584 }
585 }
586 env_left = new_env;
587 }
588
589 let bond_right = self.tensors[qubit].right_dim;
591 let mut env_right = vec![Complex::ZERO; bond_right * bond_right];
592 for i in 0..bond_right {
593 env_right[i * bond_right + i] = Complex::ONE;
594 }
595 for site in (qubit + 1..self.num_qubits).rev() {
596 let t = &self.tensors[site];
597 let dim_in = t.right_dim;
598 let dim_out = t.left_dim;
599 let mut new_env = vec![Complex::ZERO; dim_out * dim_out];
600 for ro in 0..dim_out {
601 for co in 0..dim_out {
602 let mut sum = Complex::ZERO;
603 for ri in 0..dim_in {
604 for ci in 0..dim_in {
605 let e = env_right[ri * dim_in + ci];
606 if e.norm_sq() == 0.0 {
607 continue;
608 }
609 for p in 0..2 {
610 sum += e.conj()
611 * t.get(ro, p, ri).conj()
612 * t.get(co, p, ci);
613 }
614 }
615 }
616 new_env[ro * dim_out + co] = sum;
617 }
618 }
619 env_right = new_env;
620 }
621
622 let t = &self.tensors[qubit];
624 let mut probs = [0.0f64; 2];
625 for phys in 0..2 {
626 let mut val = Complex::ZERO;
627 for l1 in 0..t.left_dim {
628 for l2 in 0..t.left_dim {
629 let e_l = env_left[l1 * t.left_dim + l2];
630 if e_l.norm_sq() == 0.0 {
631 continue;
632 }
633 for r1 in 0..t.right_dim {
634 for r2 in 0..t.right_dim {
635 let e_r = env_right[r1 * t.right_dim + r2];
636 if e_r.norm_sq() == 0.0 {
637 continue;
638 }
639 val += e_l.conj()
640 * t.get(l1, phys, r1).conj()
641 * t.get(l2, phys, r2)
642 * e_r;
643 }
644 }
645 }
646 }
647 probs[phys] = val.re; }
649
650 (probs[0].max(0.0), probs[1].max(0.0))
651 }
652
653 fn truncated_qr(
666 mat: &[Complex],
667 rows: usize,
668 cols: usize,
669 max_rank: usize,
670 threshold: f64,
671 ) -> (Vec<Complex>, Vec<Complex>, usize, f64) {
672 let rank_bound = rows.min(cols).min(max_rank);
673
674 let mut q_cols: Vec<Vec<Complex>> = Vec::with_capacity(rank_bound);
676 let mut r_data = vec![Complex::ZERO; rank_bound * cols];
677 let mut actual_rank = 0;
678 let mut trunc_error = 0.0;
679
680 for j in 0..cols.min(rank_bound + cols) {
681 if actual_rank >= rank_bound {
682 if j < cols {
684 for jj in j..cols {
685 let mut col_norm_sq = 0.0;
686 for i in 0..rows {
687 col_norm_sq += mat[i * cols + jj].norm_sq();
688 }
689 trunc_error += col_norm_sq;
690 }
691 trunc_error = trunc_error.sqrt();
692 }
693 break;
694 }
695 if j >= cols {
696 break;
697 }
698
699 let mut v: Vec<Complex> = (0..rows).map(|i| mat[i * cols + j]).collect();
701
702 for k in 0..actual_rank {
704 let mut dot = Complex::ZERO;
705 for i in 0..rows {
706 dot += q_cols[k][i].conj() * v[i];
707 }
708 r_data[k * cols + j] = dot;
709 for i in 0..rows {
710 v[i] = v[i] - dot * q_cols[k][i];
711 }
712 }
713
714 let mut norm_sq = 0.0;
716 for i in 0..rows {
717 norm_sq += v[i].norm_sq();
718 }
719 let norm = norm_sq.sqrt();
720
721 if norm < threshold {
722 trunc_error += norm;
724 continue;
725 }
726
727 r_data[actual_rank * cols + j] = Complex::new(norm, 0.0);
729 let inv_norm = 1.0 / norm;
730 for i in 0..rows {
731 v[i] = v[i] * inv_norm;
732 }
733 q_cols.push(v);
734 actual_rank += 1;
735 }
736
737 if actual_rank == 0 {
739 actual_rank = 1;
740 q_cols.push(vec![Complex::ZERO; rows]);
741 q_cols[0][0] = Complex::ONE;
742 }
744
745 let mut q_flat = vec![Complex::ZERO; rows * actual_rank];
747 for i in 0..rows {
748 for k in 0..actual_rank {
749 q_flat[i * actual_rank + k] = q_cols[k][i];
750 }
751 }
752
753 let mut r_flat = vec![Complex::ZERO; actual_rank * cols];
755 for k in 0..actual_rank {
756 for j in 0..cols {
757 r_flat[k * cols + j] = r_data[k * cols + j];
758 }
759 }
760
761 (q_flat, r_flat, actual_rank, trunc_error)
762 }
763}
764
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a register with the default configuration and a fixed seed.
    fn seeded(num_qubits: usize) -> MpsState {
        MpsState::new_with_seed(num_qubits, 42, MpsConfig::default()).unwrap()
    }

    /// Hadamard gate as a 2x2 matrix.
    fn hadamard() -> [[Complex; 2]; 2] {
        let amp = Complex::new(std::f64::consts::FRAC_1_SQRT_2, 0.0);
        [[amp, amp], [amp, -amp]]
    }

    /// CNOT gate with the first qubit as control.
    fn cnot() -> [[Complex; 4]; 4] {
        let (o, i) = (Complex::ZERO, Complex::ONE);
        [
            [i, o, o, o],
            [o, i, o, o],
            [o, o, o, i],
            [o, o, i, o],
        ]
    }

    #[test]
    fn test_new_product_state() {
        let mps = MpsState::new(4).unwrap();
        assert_eq!(mps.num_qubits(), 4);
        assert_eq!(mps.max_bond_dimension(), 1);
        assert_eq!(mps.truncation_error(), 0.0);
    }

    #[test]
    fn test_zero_qubits_errors() {
        assert!(MpsState::new(0).is_err());
    }

    #[test]
    fn test_single_qubit_x_gate() {
        let mut mps = seeded(1);
        let pauli_x = [[Complex::ZERO, Complex::ONE], [Complex::ONE, Complex::ZERO]];
        mps.apply_single_qubit_gate(0, &pauli_x);

        // |0> must have been flipped to |1>.
        let site = &mps.tensors[0];
        assert!(site.get(0, 0, 0).norm_sq() < 1e-20);
        assert!((site.get(0, 1, 0).norm_sq() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_single_qubit_h_gate() {
        let mut mps = seeded(1);
        mps.apply_single_qubit_gate(0, &hadamard());

        // Equal weight on both basis states.
        let site = &mps.tensors[0];
        assert!((site.get(0, 0, 0).norm_sq() - 0.5).abs() < 1e-10);
        assert!((site.get(0, 1, 0).norm_sq() - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_cnot_creates_bell_state() {
        let mut mps = seeded(2);
        mps.apply_single_qubit_gate(0, &hadamard());
        mps.apply_two_qubit_gate(0, 1, &cnot()).unwrap();

        // An entangled pair needs a bond of dimension at least two.
        assert!(mps.max_bond_dimension() >= 2);
    }

    #[test]
    fn test_measurement_deterministic() {
        let mut mps = seeded(1);
        let outcome = mps.measure(0).unwrap();
        assert!(!outcome.result);
        assert!((outcome.probability - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_gate_dispatch() {
        let mut mps = seeded(2);
        let after_h = mps.apply_gate(&Gate::H(0)).unwrap();
        assert!(after_h.is_empty());
        let after_cnot = mps.apply_gate(&Gate::CNOT(0, 1)).unwrap();
        assert!(after_cnot.is_empty());
    }

    #[test]
    fn test_non_adjacent_two_qubit_gate() {
        let mut mps = seeded(4);
        // Only checks that routing through SWAPs succeeds.
        mps.apply_two_qubit_gate(0, 3, &cnot()).unwrap();
    }
}