Skip to main content

scirs2_optimize/quantum_classical/
statevector.rs

1//! N-qubit statevector simulator
2//!
3//! Represents quantum states as dense complex amplitude vectors of size 2^n,
4//! and applies single- and two-qubit gates exactly.
5
6use crate::error::OptimizeError;
7use crate::quantum_classical::QcResult;
8
/// Multiply two complex numbers stored as `(re, im)` tuples.
///
/// Implements `(a + ib)(c + id) = (ac - bd) + i(ad + bc)`.
#[inline]
pub fn cmul(a: (f64, f64), b: (f64, f64)) -> (f64, f64) {
    let (a_re, a_im) = a;
    let (b_re, b_im) = b;
    let re = a_re * b_re - a_im * b_im;
    let im = a_re * b_im + a_im * b_re;
    (re, im)
}
14
/// Add two complex numbers stored as `(re, im)` tuples.
///
/// Implements `(a + ib) + (c + id) = (a + c) + i(b + d)`.
#[inline]
pub fn cadd(a: (f64, f64), b: (f64, f64)) -> (f64, f64) {
    let ((a_re, a_im), (b_re, b_im)) = (a, b);
    (a_re + b_re, a_im + b_im)
}
20
/// Squared magnitude of a complex number stored as `(re, im)`.
///
/// Returns `|z|² = re² + im²`, which is the Born-rule probability when `z`
/// is a quantum amplitude.
#[inline]
pub fn cabs2(z: (f64, f64)) -> f64 {
    let (re, im) = z;
    re * re + im * im
}
26
27/// N-qubit statevector: stores 2^n complex amplitudes as (re, im) pairs.
28///
29/// Qubit indexing convention: qubit 0 is the *least* significant bit.
30/// So the basis state |b_{n-1} ... b_1 b_0⟩ corresponds to index
31/// `b_0 + 2*b_1 + ... + 2^(n-1)*b_{n-1}`.
32#[derive(Debug, Clone)]
33pub struct Statevector {
34    /// Complex amplitudes: `amplitudes[k] = (re, im)` for basis state `|k⟩`
35    pub amplitudes: Vec<(f64, f64)>,
36    /// Number of qubits
37    pub n_qubits: usize,
38}
39
40impl Statevector {
41    /// Create the zero state |0...0⟩ for `n` qubits.
42    pub fn zero_state(n: usize) -> QcResult<Self> {
43        if n == 0 {
44            return Err(OptimizeError::ValueError(
45                "Number of qubits must be at least 1".to_string(),
46            ));
47        }
48        if n > 30 {
49            return Err(OptimizeError::ValueError(format!(
50                "Too many qubits: {n}; maximum supported is 30"
51            )));
52        }
53        let dim = 1usize << n;
54        let mut amplitudes = vec![(0.0_f64, 0.0_f64); dim];
55        amplitudes[0] = (1.0, 0.0);
56        Ok(Self {
57            amplitudes,
58            n_qubits: n,
59        })
60    }
61
62    /// Total norm squared: should remain 1.0 after any unitary operation.
63    pub fn norm_squared(&self) -> f64 {
64        self.amplitudes.iter().map(|&z| cabs2(z)).sum()
65    }
66
67    /// Total norm (Euclidean).
68    pub fn norm(&self) -> f64 {
69        self.norm_squared().sqrt()
70    }
71
72    /// Apply Hadamard gate to `qubit`.
73    ///
74    /// H = (1/√2) [[1, 1], [1, -1]]
75    ///
76    /// Pairs basis states that differ only in bit `qubit`.
77    pub fn apply_hadamard(&mut self, qubit: usize) -> QcResult<()> {
78        self.check_qubit(qubit)?;
79        let inv_sqrt2 = std::f64::consts::FRAC_1_SQRT_2;
80        let dim = self.amplitudes.len();
81        let bit = 1usize << qubit;
82
83        for i in 0..dim {
84            if i & bit == 0 {
85                let j = i | bit; // partner with bit `qubit` set
86                let (a, b) = (self.amplitudes[i], self.amplitudes[j]);
87                self.amplitudes[i] = ((a.0 + b.0) * inv_sqrt2, (a.1 + b.1) * inv_sqrt2);
88                self.amplitudes[j] = ((a.0 - b.0) * inv_sqrt2, (a.1 - b.1) * inv_sqrt2);
89            }
90        }
91        Ok(())
92    }
93
94    /// Apply Rz(θ) gate to `qubit`.
95    ///
96    /// Rz(θ) = [[e^{-iθ/2}, 0], [0, e^{iθ/2}]]
97    ///
98    /// States with bit `qubit` = 0 get phase e^{-iθ/2}; states with bit = 1
99    /// get phase e^{+iθ/2}.
100    pub fn apply_rz(&mut self, qubit: usize, theta: f64) -> QcResult<()> {
101        self.check_qubit(qubit)?;
102        let half = theta / 2.0;
103        let phase0 = (half.cos(), -half.sin()); // e^{-iθ/2}
104        let phase1 = (half.cos(), half.sin()); // e^{+iθ/2}
105        let bit = 1usize << qubit;
106
107        for (i, amp) in self.amplitudes.iter_mut().enumerate() {
108            if i & bit == 0 {
109                *amp = cmul(*amp, phase0);
110            } else {
111                *amp = cmul(*amp, phase1);
112            }
113        }
114        Ok(())
115    }
116
117    /// Apply Rx(θ) gate to `qubit`.
118    ///
119    /// Rx(θ) = [[cos(θ/2), -i·sin(θ/2)], [-i·sin(θ/2), cos(θ/2)]]
120    pub fn apply_rx(&mut self, qubit: usize, theta: f64) -> QcResult<()> {
121        self.check_qubit(qubit)?;
122        let half = theta / 2.0;
123        let c = half.cos();
124        let s = half.sin();
125        let bit = 1usize << qubit;
126        let dim = self.amplitudes.len();
127
128        for i in 0..dim {
129            if i & bit == 0 {
130                let j = i | bit;
131                let (a, b) = (self.amplitudes[i], self.amplitudes[j]);
132                // |0⟩ → c|0⟩ - i·s|1⟩
133                // |1⟩ → -i·s|0⟩ + c|1⟩
134                self.amplitudes[i] = cadd(
135                    (a.0 * c, a.1 * c),
136                    (b.1 * s, -b.0 * s), // -i*s * b = (b.im*s, -b.re*s)
137                );
138                self.amplitudes[j] = cadd(
139                    (a.1 * s, -a.0 * s), // -i*s * a
140                    (b.0 * c, b.1 * c),
141                );
142            }
143        }
144        Ok(())
145    }
146
147    /// Apply Ry(θ) gate to `qubit`.
148    ///
149    /// Ry(θ) = [[cos(θ/2), -sin(θ/2)], [sin(θ/2), cos(θ/2)]]
150    pub fn apply_ry(&mut self, qubit: usize, theta: f64) -> QcResult<()> {
151        self.check_qubit(qubit)?;
152        let half = theta / 2.0;
153        let c = half.cos();
154        let s = half.sin();
155        let bit = 1usize << qubit;
156        let dim = self.amplitudes.len();
157
158        for i in 0..dim {
159            if i & bit == 0 {
160                let j = i | bit;
161                let (a, b) = (self.amplitudes[i], self.amplitudes[j]);
162                // |0⟩ → c|0⟩ - s|1⟩,  |1⟩ → s|0⟩ + c|1⟩
163                self.amplitudes[i] = (a.0 * c - b.0 * s, a.1 * c - b.1 * s);
164                self.amplitudes[j] = (a.0 * s + b.0 * c, a.1 * s + b.1 * c);
165            }
166        }
167        Ok(())
168    }
169
170    /// Apply CNOT gate with `control` qubit and `target` qubit.
171    ///
172    /// Flips the target qubit when the control qubit is |1⟩.
173    pub fn apply_cnot(&mut self, control: usize, target: usize) -> QcResult<()> {
174        self.check_qubit(control)?;
175        self.check_qubit(target)?;
176        if control == target {
177            return Err(OptimizeError::ValueError(
178                "CNOT control and target must be different qubits".to_string(),
179            ));
180        }
181        let ctrl_bit = 1usize << control;
182        let tgt_bit = 1usize << target;
183        let dim = self.amplitudes.len();
184
185        for i in 0..dim {
186            // Only process states where control is |1⟩ and target is |0⟩
187            if (i & ctrl_bit != 0) && (i & tgt_bit == 0) {
188                let j = i | tgt_bit; // same state but with target flipped to |1⟩
189                self.amplitudes.swap(i, j);
190            }
191        }
192        Ok(())
193    }
194
195    /// Apply Rzz(θ) gate to qubits `q1` and `q2`.
196    ///
197    /// Rzz(θ) = e^{-iθ/2 · Z⊗Z}
198    ///
199    /// For basis state |b1, b2⟩:
200    /// - If b1 XOR b2 = 0 (same bits): phase e^{-iθ/2}
201    /// - If b1 XOR b2 = 1 (different bits): phase e^{+iθ/2}
202    pub fn apply_rzz(&mut self, q1: usize, q2: usize, theta: f64) -> QcResult<()> {
203        self.check_qubit(q1)?;
204        self.check_qubit(q2)?;
205        if q1 == q2 {
206            return Err(OptimizeError::ValueError(
207                "Rzz: q1 and q2 must be different qubits".to_string(),
208            ));
209        }
210        let half = theta / 2.0;
211        let phase_same = (half.cos(), -half.sin()); // e^{-iθ/2} when ZZ eigenvalue = +1
212        let phase_diff = (half.cos(), half.sin()); // e^{+iθ/2} when ZZ eigenvalue = -1
213        let bit1 = 1usize << q1;
214        let bit2 = 1usize << q2;
215
216        for (i, amp) in self.amplitudes.iter_mut().enumerate() {
217            let b1 = (i & bit1) != 0;
218            let b2 = (i & bit2) != 0;
219            if b1 == b2 {
220                *amp = cmul(*amp, phase_same);
221            } else {
222                *amp = cmul(*amp, phase_diff);
223            }
224        }
225        Ok(())
226    }
227
228    /// Compute ⟨Z_i Z_j⟩ expectation value.
229    ///
230    /// Z_i Z_j has eigenvalue +1 when bits i and j are equal, -1 otherwise.
231    pub fn expectation_zz(&self, q1: usize, q2: usize) -> QcResult<f64> {
232        self.check_qubit(q1)?;
233        self.check_qubit(q2)?;
234        let bit1 = 1usize << q1;
235        let bit2 = 1usize << q2;
236
237        let value = self
238            .amplitudes
239            .iter()
240            .enumerate()
241            .map(|(i, &amp)| {
242                let b1 = (i & bit1) != 0;
243                let b2 = (i & bit2) != 0;
244                let sign = if b1 == b2 { 1.0 } else { -1.0 };
245                sign * cabs2(amp)
246            })
247            .sum();
248        Ok(value)
249    }
250
251    /// Compute ⟨Z_k⟩ expectation value.
252    ///
253    /// Z_k has eigenvalue +1 when bit k is 0, and -1 when bit k is 1.
254    pub fn expectation_z(&self, qubit: usize) -> QcResult<f64> {
255        self.check_qubit(qubit)?;
256        let bit = 1usize << qubit;
257
258        let value = self
259            .amplitudes
260            .iter()
261            .enumerate()
262            .map(|(i, &amp)| {
263                let sign = if i & bit == 0 { 1.0 } else { -1.0 };
264                sign * cabs2(amp)
265            })
266            .sum();
267        Ok(value)
268    }
269
270    /// Return the index of the basis state with the highest probability.
271    pub fn most_probable_state(&self) -> usize {
272        self.amplitudes
273            .iter()
274            .enumerate()
275            .max_by(|(_, a), (_, b)| {
276                cabs2(**a)
277                    .partial_cmp(&cabs2(**b))
278                    .unwrap_or(std::cmp::Ordering::Equal)
279            })
280            .map(|(i, _)| i)
281            .unwrap_or(0)
282    }
283
284    /// Decode a basis state index into a bitstring (qubit 0 = LSB).
285    pub fn index_to_bits(&self, idx: usize) -> Vec<bool> {
286        (0..self.n_qubits).map(|k| (idx >> k) & 1 == 1).collect()
287    }
288
289    fn check_qubit(&self, qubit: usize) -> QcResult<()> {
290        if qubit >= self.n_qubits {
291            return Err(OptimizeError::ValueError(format!(
292                "Qubit index {qubit} out of range for {}-qubit register",
293                self.n_qubits
294            )));
295        }
296        Ok(())
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-10;

    #[test]
    fn test_zero_state_amplitude() {
        let sv = Statevector::zero_state(3).unwrap();
        assert_eq!(sv.amplitudes.len(), 8);
        assert!((sv.amplitudes[0].0 - 1.0).abs() < EPS);
        assert!(sv.amplitudes[0].1.abs() < EPS);
        for &amp in &sv.amplitudes[1..] {
            assert!(cabs2(amp) < EPS);
        }
    }

    #[test]
    fn test_hadamard_creates_plus_state() {
        let mut sv = Statevector::zero_state(1).unwrap();
        sv.apply_hadamard(0).unwrap();
        let inv_sqrt2 = std::f64::consts::FRAC_1_SQRT_2;
        assert!((sv.amplitudes[0].0 - inv_sqrt2).abs() < EPS);
        assert!((sv.amplitudes[1].0 - inv_sqrt2).abs() < EPS);
        assert!(sv.amplitudes[0].1.abs() < EPS);
        assert!(sv.amplitudes[1].1.abs() < EPS);
    }

    #[test]
    fn test_hadamard_is_self_inverse() {
        let mut sv = Statevector::zero_state(1).unwrap();
        sv.apply_hadamard(0).unwrap();
        sv.apply_hadamard(0).unwrap();
        assert!((sv.amplitudes[0].0 - 1.0).abs() < EPS);
        assert!(cabs2(sv.amplitudes[1]) < EPS);
    }

    #[test]
    fn test_cnot_10_to_11() {
        // Prepare |10⟩ (qubit 1 = 1, qubit 0 = 0 → index 2) by setting the
        // amplitude directly instead of building an X gate from H·Rz(π)·H.
        let mut sv = Statevector::zero_state(2).unwrap();
        sv.amplitudes[0] = (0.0, 0.0);
        sv.amplitudes[2] = (1.0, 0.0); // index 2 = |10⟩
        sv.apply_cnot(1, 0).unwrap();
        // After CNOT: control=1(=bit 1), target=0 → |10⟩ → |11⟩ (index 3)
        assert!(cabs2(sv.amplitudes[3]) > 1.0 - EPS);
        assert!(cabs2(sv.amplitudes[2]) < EPS);
    }

    #[test]
    fn test_rz_phase_rotation() {
        // Rz(π)|0⟩ should give e^{-iπ/2}|0⟩ = -i|0⟩
        let mut sv = Statevector::zero_state(1).unwrap();
        sv.apply_rz(0, std::f64::consts::PI).unwrap();
        assert!(sv.amplitudes[0].0.abs() < EPS);
        assert!((sv.amplitudes[0].1 + 1.0).abs() < EPS); // -i
    }

    #[test]
    fn test_rx_sign_convention() {
        // Rx(θ)|0⟩ = cos(θ/2)|0⟩ - i·sin(θ/2)|1⟩
        let theta = 0.8_f64;
        let half = theta / 2.0;
        let mut sv = Statevector::zero_state(1).unwrap();
        sv.apply_rx(0, theta).unwrap();
        assert!((sv.amplitudes[0].0 - half.cos()).abs() < EPS);
        assert!(sv.amplitudes[0].1.abs() < EPS);
        assert!(sv.amplitudes[1].0.abs() < EPS);
        assert!((sv.amplitudes[1].1 + half.sin()).abs() < EPS);
    }

    #[test]
    fn test_ry_sign_convention() {
        // Ry(θ)|0⟩ = cos(θ/2)|0⟩ + sin(θ/2)|1⟩, so ⟨Z⟩ = cos θ
        let theta = 0.6_f64;
        let half = theta / 2.0;
        let mut sv = Statevector::zero_state(1).unwrap();
        sv.apply_ry(0, theta).unwrap();
        assert!((sv.amplitudes[0].0 - half.cos()).abs() < EPS);
        assert!((sv.amplitudes[1].0 - half.sin()).abs() < EPS);
        assert!(sv.amplitudes[0].1.abs() < EPS);
        assert!(sv.amplitudes[1].1.abs() < EPS);
        let ez = sv.expectation_z(0).unwrap();
        assert!((ez - theta.cos()).abs() < EPS);
    }

    #[test]
    fn test_rzz_phase_on_differing_bits() {
        // |01⟩ (index 1) has differing bits → phase e^{+iπ/2} = +i
        let mut sv = Statevector::zero_state(2).unwrap();
        sv.amplitudes[0] = (0.0, 0.0);
        sv.amplitudes[1] = (1.0, 0.0);
        sv.apply_rzz(0, 1, std::f64::consts::PI).unwrap();
        assert!(sv.amplitudes[1].0.abs() < EPS);
        assert!((sv.amplitudes[1].1 - 1.0).abs() < EPS);
    }

    #[test]
    fn test_norm_preserved_after_gates() {
        let mut sv = Statevector::zero_state(3).unwrap();
        sv.apply_hadamard(0).unwrap();
        sv.apply_hadamard(1).unwrap();
        sv.apply_cnot(0, 1).unwrap();
        sv.apply_rz(2, 0.7).unwrap();
        sv.apply_rzz(0, 2, 1.2).unwrap();
        sv.apply_rx(1, 0.4).unwrap();
        sv.apply_ry(2, -1.1).unwrap();
        let norm = sv.norm_squared();
        assert!((norm - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_expectation_z_basis_states() {
        // |0⟩ → ⟨Z⟩ = +1
        let sv0 = Statevector::zero_state(1).unwrap();
        let ez0 = sv0.expectation_z(0).unwrap();
        assert!((ez0 - 1.0).abs() < EPS);

        // |1⟩ → ⟨Z⟩ = -1
        let mut sv1 = Statevector::zero_state(1).unwrap();
        sv1.amplitudes[0] = (0.0, 0.0);
        sv1.amplitudes[1] = (1.0, 0.0);
        let ez1 = sv1.expectation_z(0).unwrap();
        assert!((ez1 + 1.0).abs() < EPS);
    }

    #[test]
    fn test_expectation_zz() {
        // |00⟩ → ⟨Z0 Z1⟩ = +1
        let sv = Statevector::zero_state(2).unwrap();
        let ezz = sv.expectation_zz(0, 1).unwrap();
        assert!((ezz - 1.0).abs() < EPS);

        // |10⟩ → bit0=0, bit1=1 → different → ⟨ZZ⟩ = -1
        let mut sv2 = Statevector::zero_state(2).unwrap();
        sv2.amplitudes[0] = (0.0, 0.0);
        sv2.amplitudes[2] = (1.0, 0.0); // index 2 = bit1=1,bit0=0
        let ezz2 = sv2.expectation_zz(0, 1).unwrap();
        assert!((ezz2 + 1.0).abs() < EPS);
    }

    #[test]
    fn test_bell_state_correlations() {
        // H(0) then CNOT(0→1) gives (|00⟩ + |11⟩)/√2
        let mut sv = Statevector::zero_state(2).unwrap();
        sv.apply_hadamard(0).unwrap();
        sv.apply_cnot(0, 1).unwrap();
        assert!((sv.expectation_zz(0, 1).unwrap() - 1.0).abs() < EPS);
        assert!(sv.expectation_z(0).unwrap().abs() < EPS);
        assert!(sv.expectation_z(1).unwrap().abs() < EPS);
        // Z_0 Z_0 = I, so the same-qubit expectation equals the norm squared.
        assert!((sv.expectation_zz(0, 0).unwrap() - 1.0).abs() < EPS);
    }

    #[test]
    fn test_most_probable_state_and_bits() {
        let sv0 = Statevector::zero_state(2).unwrap();
        assert_eq!(sv0.most_probable_state(), 0);

        // Ry(π) on qubit 1 maps |00⟩ → |10⟩ (index 2)
        let mut sv = Statevector::zero_state(2).unwrap();
        sv.apply_ry(1, std::f64::consts::PI).unwrap();
        let idx = sv.most_probable_state();
        assert_eq!(idx, 2);
        assert_eq!(sv.index_to_bits(idx), vec![false, true]);
    }

    #[test]
    fn test_invalid_arguments_are_rejected() {
        assert!(Statevector::zero_state(0).is_err());
        assert!(Statevector::zero_state(31).is_err());

        let mut sv = Statevector::zero_state(2).unwrap();
        assert!(sv.apply_hadamard(2).is_err());
        assert!(sv.apply_rx(5, 0.1).is_err());
        assert!(sv.apply_ry(2, 0.1).is_err());
        assert!(sv.apply_rz(3, 0.1).is_err());
        assert!(sv.apply_cnot(1, 1).is_err());
        assert!(sv.apply_cnot(0, 2).is_err());
        assert!(sv.apply_rzz(0, 0, 0.3).is_err());
        assert!(sv.expectation_z(2).is_err());
        assert!(sv.expectation_zz(0, 2).is_err());

        // Rejected calls must leave the state untouched.
        assert!((sv.amplitudes[0].0 - 1.0).abs() < EPS);
        assert!((sv.norm_squared() - 1.0).abs() < EPS);
    }
}