// scirs2_optimize/quantum_classical/statevector.rs
use crate::error::OptimizeError;
use crate::quantum_classical::QcResult;

/// Multiplies two complex numbers stored as `(re, im)` pairs.
///
/// Computes `(a.0 + i*a.1) * (b.0 + i*b.1)` and returns the product in the
/// same `(re, im)` layout.
#[inline]
pub fn cmul(a: (f64, f64), b: (f64, f64)) -> (f64, f64) {
    let (a_re, a_im) = a;
    let (b_re, b_im) = b;
    (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re)
}

/// Adds two complex numbers stored as `(re, im)` pairs, component by component.
#[inline]
pub fn cadd(a: (f64, f64), b: (f64, f64)) -> (f64, f64) {
    let (a_re, a_im) = a;
    let (b_re, b_im) = b;
    (a_re + b_re, a_im + b_im)
}

/// Returns the squared modulus `re^2 + im^2` of a complex number stored as
/// an `(re, im)` pair.
#[inline]
pub fn cabs2(z: (f64, f64)) -> f64 {
    let (re, im) = z;
    re * re + im * im
}

/// State of a qubit register, held as one complex amplitude per basis state.
///
/// Both fields are public, so nothing here enforces that they stay
/// consistent with each other; callers that edit `amplitudes` directly are
/// responsible for keeping the length and normalisation sensible.
#[derive(Debug, Clone)]
pub struct Statevector {
    /// Complex amplitudes as `(re, im)` pairs, indexed by basis state.
    ///
    /// Bit `k` of the index is the value of qubit `k` (the gate methods
    /// select a qubit with the mask `1 << k`).
    pub amplitudes: Vec<(f64, f64)>,
    /// Number of qubits in the register; valid qubit indices are
    /// `0..n_qubits`.
    pub n_qubits: usize,
}

40impl Statevector {
41 pub fn zero_state(n: usize) -> QcResult<Self> {
43 if n == 0 {
44 return Err(OptimizeError::ValueError(
45 "Number of qubits must be at least 1".to_string(),
46 ));
47 }
48 if n > 30 {
49 return Err(OptimizeError::ValueError(format!(
50 "Too many qubits: {n}; maximum supported is 30"
51 )));
52 }
53 let dim = 1usize << n;
54 let mut amplitudes = vec![(0.0_f64, 0.0_f64); dim];
55 amplitudes[0] = (1.0, 0.0);
56 Ok(Self {
57 amplitudes,
58 n_qubits: n,
59 })
60 }
61
62 pub fn norm_squared(&self) -> f64 {
64 self.amplitudes.iter().map(|&z| cabs2(z)).sum()
65 }
66
67 pub fn norm(&self) -> f64 {
69 self.norm_squared().sqrt()
70 }
71
72 pub fn apply_hadamard(&mut self, qubit: usize) -> QcResult<()> {
78 self.check_qubit(qubit)?;
79 let inv_sqrt2 = std::f64::consts::FRAC_1_SQRT_2;
80 let dim = self.amplitudes.len();
81 let bit = 1usize << qubit;
82
83 for i in 0..dim {
84 if i & bit == 0 {
85 let j = i | bit; let (a, b) = (self.amplitudes[i], self.amplitudes[j]);
87 self.amplitudes[i] = ((a.0 + b.0) * inv_sqrt2, (a.1 + b.1) * inv_sqrt2);
88 self.amplitudes[j] = ((a.0 - b.0) * inv_sqrt2, (a.1 - b.1) * inv_sqrt2);
89 }
90 }
91 Ok(())
92 }
93
94 pub fn apply_rz(&mut self, qubit: usize, theta: f64) -> QcResult<()> {
101 self.check_qubit(qubit)?;
102 let half = theta / 2.0;
103 let phase0 = (half.cos(), -half.sin()); let phase1 = (half.cos(), half.sin()); let bit = 1usize << qubit;
106
107 for (i, amp) in self.amplitudes.iter_mut().enumerate() {
108 if i & bit == 0 {
109 *amp = cmul(*amp, phase0);
110 } else {
111 *amp = cmul(*amp, phase1);
112 }
113 }
114 Ok(())
115 }
116
117 pub fn apply_rx(&mut self, qubit: usize, theta: f64) -> QcResult<()> {
121 self.check_qubit(qubit)?;
122 let half = theta / 2.0;
123 let c = half.cos();
124 let s = half.sin();
125 let bit = 1usize << qubit;
126 let dim = self.amplitudes.len();
127
128 for i in 0..dim {
129 if i & bit == 0 {
130 let j = i | bit;
131 let (a, b) = (self.amplitudes[i], self.amplitudes[j]);
132 self.amplitudes[i] = cadd(
135 (a.0 * c, a.1 * c),
136 (b.1 * s, -b.0 * s), );
138 self.amplitudes[j] = cadd(
139 (a.1 * s, -a.0 * s), (b.0 * c, b.1 * c),
141 );
142 }
143 }
144 Ok(())
145 }
146
147 pub fn apply_ry(&mut self, qubit: usize, theta: f64) -> QcResult<()> {
151 self.check_qubit(qubit)?;
152 let half = theta / 2.0;
153 let c = half.cos();
154 let s = half.sin();
155 let bit = 1usize << qubit;
156 let dim = self.amplitudes.len();
157
158 for i in 0..dim {
159 if i & bit == 0 {
160 let j = i | bit;
161 let (a, b) = (self.amplitudes[i], self.amplitudes[j]);
162 self.amplitudes[i] = (a.0 * c - b.0 * s, a.1 * c - b.1 * s);
164 self.amplitudes[j] = (a.0 * s + b.0 * c, a.1 * s + b.1 * c);
165 }
166 }
167 Ok(())
168 }
169
170 pub fn apply_cnot(&mut self, control: usize, target: usize) -> QcResult<()> {
174 self.check_qubit(control)?;
175 self.check_qubit(target)?;
176 if control == target {
177 return Err(OptimizeError::ValueError(
178 "CNOT control and target must be different qubits".to_string(),
179 ));
180 }
181 let ctrl_bit = 1usize << control;
182 let tgt_bit = 1usize << target;
183 let dim = self.amplitudes.len();
184
185 for i in 0..dim {
186 if (i & ctrl_bit != 0) && (i & tgt_bit == 0) {
188 let j = i | tgt_bit; self.amplitudes.swap(i, j);
190 }
191 }
192 Ok(())
193 }
194
195 pub fn apply_rzz(&mut self, q1: usize, q2: usize, theta: f64) -> QcResult<()> {
203 self.check_qubit(q1)?;
204 self.check_qubit(q2)?;
205 if q1 == q2 {
206 return Err(OptimizeError::ValueError(
207 "Rzz: q1 and q2 must be different qubits".to_string(),
208 ));
209 }
210 let half = theta / 2.0;
211 let phase_same = (half.cos(), -half.sin()); let phase_diff = (half.cos(), half.sin()); let bit1 = 1usize << q1;
214 let bit2 = 1usize << q2;
215
216 for (i, amp) in self.amplitudes.iter_mut().enumerate() {
217 let b1 = (i & bit1) != 0;
218 let b2 = (i & bit2) != 0;
219 if b1 == b2 {
220 *amp = cmul(*amp, phase_same);
221 } else {
222 *amp = cmul(*amp, phase_diff);
223 }
224 }
225 Ok(())
226 }
227
228 pub fn expectation_zz(&self, q1: usize, q2: usize) -> QcResult<f64> {
232 self.check_qubit(q1)?;
233 self.check_qubit(q2)?;
234 let bit1 = 1usize << q1;
235 let bit2 = 1usize << q2;
236
237 let value = self
238 .amplitudes
239 .iter()
240 .enumerate()
241 .map(|(i, &)| {
242 let b1 = (i & bit1) != 0;
243 let b2 = (i & bit2) != 0;
244 let sign = if b1 == b2 { 1.0 } else { -1.0 };
245 sign * cabs2(amp)
246 })
247 .sum();
248 Ok(value)
249 }
250
251 pub fn expectation_z(&self, qubit: usize) -> QcResult<f64> {
255 self.check_qubit(qubit)?;
256 let bit = 1usize << qubit;
257
258 let value = self
259 .amplitudes
260 .iter()
261 .enumerate()
262 .map(|(i, &)| {
263 let sign = if i & bit == 0 { 1.0 } else { -1.0 };
264 sign * cabs2(amp)
265 })
266 .sum();
267 Ok(value)
268 }
269
270 pub fn most_probable_state(&self) -> usize {
272 self.amplitudes
273 .iter()
274 .enumerate()
275 .max_by(|(_, a), (_, b)| {
276 cabs2(**a)
277 .partial_cmp(&cabs2(**b))
278 .unwrap_or(std::cmp::Ordering::Equal)
279 })
280 .map(|(i, _)| i)
281 .unwrap_or(0)
282 }
283
284 pub fn index_to_bits(&self, idx: usize) -> Vec<bool> {
286 (0..self.n_qubits).map(|k| (idx >> k) & 1 == 1).collect()
287 }
288
289 fn check_qubit(&self, qubit: usize) -> QcResult<()> {
290 if qubit >= self.n_qubits {
291 return Err(OptimizeError::ValueError(format!(
292 "Qubit index {qubit} out of range for {}-qubit register",
293 self.n_qubits
294 )));
295 }
296 Ok(())
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for floating-point comparisons.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_zero_state_amplitude() {
        let sv = Statevector::zero_state(3).unwrap();
        assert_eq!(sv.amplitudes.len(), 8);
        // Amplitude 0 is exactly 1 + 0i ...
        assert!((sv.amplitudes[0].0 - 1.0).abs() < EPS);
        assert!(sv.amplitudes[0].1.abs() < EPS);
        // ... and every other amplitude is zero.
        for &amp in &sv.amplitudes[1..] {
            assert!(cabs2(amp) < EPS);
        }
    }

    #[test]
    fn test_hadamard_creates_plus_state() {
        let mut sv = Statevector::zero_state(1).unwrap();
        sv.apply_hadamard(0).unwrap();
        let inv_sqrt2 = std::f64::consts::FRAC_1_SQRT_2;
        assert!((sv.amplitudes[0].0 - inv_sqrt2).abs() < EPS);
        assert!((sv.amplitudes[1].0 - inv_sqrt2).abs() < EPS);
        assert!(sv.amplitudes[0].1.abs() < EPS);
        assert!(sv.amplitudes[1].1.abs() < EPS);
    }

    #[test]
    fn test_cnot_10_to_11() {
        let mut sv = Statevector::zero_state(2).unwrap();
        // Prepare index 2: qubit 1 set, qubit 0 clear.
        sv.amplitudes[0] = (0.0, 0.0);
        sv.amplitudes[2] = (1.0, 0.0);
        // Control = qubit 1, target = qubit 0.
        sv.apply_cnot(1, 0).unwrap();
        // The amplitude must have moved to index 3 (both qubits set).
        assert!(cabs2(sv.amplitudes[3]) > 1.0 - EPS);
        assert!(cabs2(sv.amplitudes[2]) < EPS);
    }

    #[test]
    fn test_rz_phase_rotation() {
        let mut sv = Statevector::zero_state(1).unwrap();
        // Rz(pi) multiplies amplitude 0 by exp(-i*pi/2) = -i.
        sv.apply_rz(0, std::f64::consts::PI).unwrap();
        assert!(sv.amplitudes[0].0.abs() < EPS);
        assert!((sv.amplitudes[0].1 + 1.0).abs() < EPS);
    }

    #[test]
    fn test_norm_preserved_after_gates() {
        let mut sv = Statevector::zero_state(3).unwrap();
        sv.apply_hadamard(0).unwrap();
        sv.apply_hadamard(1).unwrap();
        sv.apply_cnot(0, 1).unwrap();
        sv.apply_rz(2, 0.7).unwrap();
        sv.apply_rzz(0, 2, 1.2).unwrap();
        let norm = sv.norm_squared();
        assert!((norm - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_expectation_z_basis_states() {
        // Basis state 0 gives <Z> = +1.
        let sv0 = Statevector::zero_state(1).unwrap();
        let ez0 = sv0.expectation_z(0).unwrap();
        assert!((ez0 - 1.0).abs() < EPS);

        // Basis state 1 gives <Z> = -1.
        let mut sv1 = Statevector::zero_state(1).unwrap();
        sv1.amplitudes[0] = (0.0, 0.0);
        sv1.amplitudes[1] = (1.0, 0.0);
        let ez1 = sv1.expectation_z(0).unwrap();
        assert!((ez1 + 1.0).abs() < EPS);
    }

    #[test]
    fn test_expectation_zz() {
        // Index 0: both bits equal, so <ZZ> = +1.
        let sv = Statevector::zero_state(2).unwrap();
        let ezz = sv.expectation_zz(0, 1).unwrap();
        assert!((ezz - 1.0).abs() < EPS);

        // Index 2: the two bits differ, so <ZZ> = -1.
        let mut sv2 = Statevector::zero_state(2).unwrap();
        sv2.amplitudes[0] = (0.0, 0.0);
        sv2.amplitudes[2] = (1.0, 0.0);
        let ezz2 = sv2.expectation_zz(0, 1).unwrap();
        assert!((ezz2 + 1.0).abs() < EPS);
    }
}