scirs2_spatial/quantum_inspired/concepts/
mod.rs1use crate::error::{SpatialError, SpatialResult};
8use scirs2_core::ndarray::Array1;
9use scirs2_core::numeric::Complex64;
10use scirs2_core::random::Rng;
11use std::f64::consts::SQRT_2;
12
/// Complex probability amplitude attached to a single computational basis state.
pub type QuantumAmplitude = Complex64;

16#[derive(Debug, Clone)]
43pub struct QuantumState {
44 pub amplitudes: Array1<QuantumAmplitude>,
46 pub numqubits: usize,
48}
49
50impl QuantumState {
51 pub fn new(amplitudes: Array1<QuantumAmplitude>) -> SpatialResult<Self> {
62 let num_states = amplitudes.len();
63 if !num_states.is_power_of_two() {
64 return Err(SpatialError::InvalidInput(
65 "Number of amplitudes must be a power of 2".to_string(),
66 ));
67 }
68
69 let numqubits = (num_states as f64).log2() as usize;
70
71 Ok(Self {
72 amplitudes,
73 numqubits,
74 })
75 }
76
77 pub fn zero_state(numqubits: usize) -> Self {
88 let num_states = 1 << numqubits;
89 let mut amplitudes = Array1::zeros(num_states);
90 amplitudes[0] = Complex64::new(1.0, 0.0);
91
92 Self {
93 amplitudes,
94 numqubits,
95 }
96 }
97
98 pub fn uniform_superposition(numqubits: usize) -> Self {
109 let num_states = 1 << numqubits;
110 let amplitude = Complex64::new(1.0 / (num_states as f64).sqrt(), 0.0);
111 let amplitudes = Array1::from_elem(num_states, amplitude);
112
113 Self {
114 amplitudes,
115 numqubits,
116 }
117 }
118
119 pub fn measure(&self) -> usize {
128 let mut rng = scirs2_core::random::rng();
129
130 let probabilities: Vec<f64> = self.amplitudes.iter().map(|amp| amp.norm_sqr()).collect();
132
133 let mut cumulative = 0.0;
135 let random_value = rng.gen_range(0.0..1.0);
136
137 for (i, &prob) in probabilities.iter().enumerate() {
138 cumulative += prob;
139 if random_value <= cumulative {
140 return i;
141 }
142 }
143
144 probabilities.len() - 1
146 }
147
148 pub fn probability(&self, state: usize) -> f64 {
159 if state >= self.amplitudes.len() {
160 0.0
161 } else {
162 self.amplitudes[state].norm_sqr()
163 }
164 }
165
166 pub fn hadamard(&mut self, qubit: usize) -> SpatialResult<()> {
178 if qubit >= self.numqubits {
179 return Err(SpatialError::InvalidInput(format!(
180 "Qubit index {qubit} out of range"
181 )));
182 }
183
184 let mut new_amplitudes = self.amplitudes.clone();
185 let qubit_mask = 1 << qubit;
186
187 for i in 0..self.amplitudes.len() {
188 let j = i ^ qubit_mask; if i < j {
190 let amp_i = self.amplitudes[i];
191 let amp_j = self.amplitudes[j];
192
193 new_amplitudes[i] = (amp_i + amp_j) / SQRT_2;
194 new_amplitudes[j] = (amp_i - amp_j) / SQRT_2;
195 }
196 }
197
198 self.amplitudes = new_amplitudes;
199 Ok(())
200 }
201
202 pub fn phase_rotation(&mut self, qubit: usize, angle: f64) -> SpatialResult<()> {
214 if qubit >= self.numqubits {
215 return Err(SpatialError::InvalidInput(format!(
216 "Qubit index {qubit} out of range"
217 )));
218 }
219
220 let phase = Complex64::new(0.0, angle).exp();
221 let qubit_mask = 1 << qubit;
222
223 for i in 0..self.amplitudes.len() {
224 if (i & qubit_mask) != 0 {
225 self.amplitudes[i] *= phase;
226 }
227 }
228
229 Ok(())
230 }
231
232 pub fn controlled_rotation(
245 &mut self,
246 control: usize,
247 target: usize,
248 angle: f64,
249 ) -> SpatialResult<()> {
250 if control >= self.numqubits || target >= self.numqubits {
251 return Err(SpatialError::InvalidInput(
252 "Qubit indices out of range".to_string(),
253 ));
254 }
255
256 let control_mask = 1 << control;
257 let target_mask = 1 << target;
258 let cos_half = (angle / 2.0).cos();
259 let sin_half = (angle / 2.0).sin();
260
261 let mut new_amplitudes = self.amplitudes.clone();
262
263 for i in 0..self.amplitudes.len() {
264 if (i & control_mask) != 0 {
265 let j = i ^ target_mask; if i < j {
268 let amp_i = self.amplitudes[i];
269 let amp_j = self.amplitudes[j];
270
271 new_amplitudes[i] = Complex64::new(cos_half, 0.0) * amp_i
272 - Complex64::new(0.0, sin_half) * amp_j;
273 new_amplitudes[j] = Complex64::new(0.0, sin_half) * amp_i
274 + Complex64::new(cos_half, 0.0) * amp_j;
275 }
276 }
277 }
278
279 self.amplitudes = new_amplitudes;
280 Ok(())
281 }
282
283 pub fn pauli_x(&mut self, qubit: usize) -> SpatialResult<()> {
293 if qubit >= self.numqubits {
294 return Err(SpatialError::InvalidInput(format!(
295 "Qubit index {qubit} out of range"
296 )));
297 }
298
299 let qubit_mask = 1 << qubit;
300 let mut new_amplitudes = self.amplitudes.clone();
301
302 for i in 0..self.amplitudes.len() {
303 let j = i ^ qubit_mask; new_amplitudes[i] = self.amplitudes[j];
305 }
306
307 self.amplitudes = new_amplitudes;
308 Ok(())
309 }
310
311 pub fn pauli_y(&mut self, qubit: usize) -> SpatialResult<()> {
321 if qubit >= self.numqubits {
322 return Err(SpatialError::InvalidInput(format!(
323 "Qubit index {qubit} out of range"
324 )));
325 }
326
327 let qubit_mask = 1 << qubit;
328 let mut new_amplitudes = self.amplitudes.clone();
329 let i_complex = Complex64::new(0.0, 1.0);
330
331 for i in 0..self.amplitudes.len() {
332 let j = i ^ qubit_mask; if (i & qubit_mask) == 0 {
334 new_amplitudes[j] = i_complex * self.amplitudes[i];
336 new_amplitudes[i] = Complex64::new(0.0, 0.0);
337 } else {
338 new_amplitudes[j] = -i_complex * self.amplitudes[i];
340 new_amplitudes[i] = Complex64::new(0.0, 0.0);
341 }
342 }
343
344 self.amplitudes = new_amplitudes;
345 Ok(())
346 }
347
348 pub fn pauli_z(&mut self, qubit: usize) -> SpatialResult<()> {
358 if qubit >= self.numqubits {
359 return Err(SpatialError::InvalidInput(format!(
360 "Qubit index {qubit} out of range"
361 )));
362 }
363
364 let qubit_mask = 1 << qubit;
365
366 for i in 0..self.amplitudes.len() {
367 if (i & qubit_mask) != 0 {
368 self.amplitudes[i] *= -1.0;
369 }
370 }
371
372 Ok(())
373 }
374
375 pub fn num_qubits(&self) -> usize {
377 self.numqubits
378 }
379
380 pub fn num_states(&self) -> usize {
382 self.amplitudes.len()
383 }
384
385 pub fn is_normalized(&self) -> bool {
387 let norm_squared: f64 = self.amplitudes.iter().map(|amp| amp.norm_sqr()).sum();
388 (norm_squared - 1.0).abs() < 1e-10
389 }
390
391 pub fn normalize(&mut self) {
393 let norm: f64 = self
394 .amplitudes
395 .iter()
396 .map(|amp| amp.norm_sqr())
397 .sum::<f64>()
398 .sqrt();
399 if norm > 1e-10 {
400 for amp in self.amplitudes.iter_mut() {
401 *amp /= norm;
402 }
403 }
404 }
405
406 pub fn amplitude(&self, state: usize) -> Option<QuantumAmplitude> {
408 self.amplitudes.get(state).copied()
409 }
410
411 pub fn set_amplitude(
413 &mut self,
414 state: usize,
415 amplitude: QuantumAmplitude,
416 ) -> SpatialResult<()> {
417 if state >= self.amplitudes.len() {
418 return Err(SpatialError::InvalidInput(
419 "State index out of range".to_string(),
420 ));
421 }
422 self.amplitudes[state] = amplitude;
423 Ok(())
424 }
425}
426
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_1_SQRT_2, PI};

    /// Tolerance for floating-point comparisons of probabilities and amplitudes.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_zero_state_creation() {
        let state = QuantumState::zero_state(2);
        assert_eq!(state.num_qubits(), 2);
        assert_eq!(state.num_states(), 4);
        assert_eq!(state.probability(0), 1.0);
        assert_eq!(state.probability(1), 0.0);
        assert!(state.is_normalized());
    }

    #[test]
    fn test_uniform_superposition() {
        let state = QuantumState::uniform_superposition(2);
        assert_eq!(state.num_qubits(), 2);
        assert_eq!(state.num_states(), 4);

        for i in 0..4 {
            assert!((state.probability(i) - 0.25).abs() < EPS);
        }
        assert!(state.is_normalized());
    }

    #[test]
    fn test_new_rejects_non_power_of_two() {
        let three = Array1::from_vec(vec![Complex64::new(1.0, 0.0); 3]);
        assert!(QuantumState::new(three).is_err());

        let empty: Array1<Complex64> = Array1::from_vec(Vec::new());
        assert!(QuantumState::new(empty).is_err());
    }

    #[test]
    fn test_new_infers_qubit_count() {
        let amplitudes = Array1::from_elem(8, Complex64::new(1.0, 0.0));
        let state = QuantumState::new(amplitudes).unwrap();
        assert_eq!(state.num_qubits(), 3);
        assert_eq!(state.num_states(), 8);
    }

    #[test]
    fn test_hadamard_gate() {
        let mut state = QuantumState::zero_state(1);
        state.hadamard(0).unwrap();

        assert!((state.probability(0) - 0.5).abs() < EPS);
        assert!((state.probability(1) - 0.5).abs() < EPS);
        assert!(state.is_normalized());
    }

    #[test]
    fn test_hadamard_is_self_inverse() {
        let mut state = QuantumState::zero_state(1);
        state.hadamard(0).unwrap();
        state.hadamard(0).unwrap();

        assert!((state.probability(0) - 1.0).abs() < EPS);
        assert!(state.probability(1).abs() < EPS);
    }

    #[test]
    fn test_pauli_x_gate() {
        let mut state = QuantumState::zero_state(1);
        state.pauli_x(0).unwrap();

        assert_eq!(state.probability(0), 0.0);
        assert_eq!(state.probability(1), 1.0);
        assert!(state.is_normalized());
    }

    #[test]
    fn test_pauli_z_gate() {
        let mut state = QuantumState::uniform_superposition(1);
        state.pauli_z(0).unwrap();

        let untouched = state.amplitude(0).unwrap();
        let flipped = state.amplitude(1).unwrap();
        assert!((untouched.re - FRAC_1_SQRT_2).abs() < EPS);
        assert!((flipped.re + FRAC_1_SQRT_2).abs() < EPS);
        assert!(flipped.im.abs() < EPS);
        assert!(state.is_normalized());
    }

    #[test]
    fn test_phase_rotation() {
        let mut state = QuantumState::uniform_superposition(1);
        state.phase_rotation(0, PI).unwrap();

        assert!(state.is_normalized());
        assert!((state.probability(0) - 0.5).abs() < EPS);
        assert!((state.probability(1) - 0.5).abs() < EPS);
    }

    #[test]
    fn test_controlled_rotation() {
        let mut state = QuantumState::zero_state(2);
        state.hadamard(0).unwrap();
        state.controlled_rotation(0, 1, PI).unwrap();

        assert!(state.is_normalized());
        assert!((state.probability(0) - 0.5).abs() < EPS);
        assert!((state.probability(3) - 0.5).abs() < EPS);
    }

    #[test]
    fn test_measurement() {
        let state = QuantumState::zero_state(2);
        let result = state.measure();
        assert_eq!(result, 0);
    }

    #[test]
    fn test_invalid_qubit_index() {
        let mut state = QuantumState::zero_state(2);
        assert!(state.hadamard(2).is_err());
        assert!(state.pauli_x(2).is_err());
        assert!(state.pauli_y(2).is_err());
        assert!(state.pauli_z(2).is_err());
        assert!(state.phase_rotation(2, PI).is_err());
        assert!(state.controlled_rotation(2, 0, PI).is_err());
        assert!(state.controlled_rotation(0, 2, PI).is_err());
    }

    #[test]
    fn test_amplitude_access() {
        let mut state = QuantumState::zero_state(2);
        assert_eq!(state.amplitude(0), Some(Complex64::new(1.0, 0.0)));
        assert_eq!(state.amplitude(1), Some(Complex64::new(0.0, 0.0)));
        assert_eq!(state.amplitude(10), None);

        assert!(state.set_amplitude(1, Complex64::new(0.5, 0.0)).is_ok());
        assert_eq!(state.amplitude(1), Some(Complex64::new(0.5, 0.0)));
        assert!(state.set_amplitude(4, Complex64::new(1.0, 0.0)).is_err());
    }

    #[test]
    fn test_normalization() {
        let amplitudes = Array1::from_vec(vec![
            Complex64::new(2.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
        ]);
        let mut state = QuantumState::new(amplitudes).unwrap();

        assert!(!state.is_normalized());
        state.normalize();
        assert!(state.is_normalized());
        assert!((state.probability(0) - 0.5).abs() < EPS);
        assert!((state.probability(1) - 0.5).abs() < EPS);
    }
}