1use crate::error::{QuantumError, Result};
7use crate::gate::Gate;
8use crate::types::*;
9
10use rand::rngs::StdRng;
11use rand::Rng;
12use rand::SeedableRng;
13
/// Largest register size accepted by the `QuantumState` constructors.
///
/// A state of `n` qubits stores `2^n` amplitudes, so this cap bounds the
/// state-vector allocation at `2^25` entries.
pub const MAX_QUBITS: u32 = 25;
16
17pub struct QuantumState {
19 amplitudes: Vec<Complex>,
20 num_qubits: u32,
21 rng: StdRng,
22 measurement_record: Vec<MeasurementOutcome>,
23}
24
25impl QuantumState {
30 pub fn new(num_qubits: u32) -> Result<Self> {
32 if num_qubits == 0 {
33 return Err(QuantumError::CircuitError(
34 "cannot create quantum state with 0 qubits".into(),
35 ));
36 }
37 if num_qubits > MAX_QUBITS {
38 return Err(QuantumError::QubitLimitExceeded {
39 requested: num_qubits,
40 maximum: MAX_QUBITS,
41 });
42 }
43 let n = 1usize << num_qubits;
44 let mut amplitudes = vec![Complex::ZERO; n];
45 amplitudes[0] = Complex::ONE;
46 Ok(Self {
47 amplitudes,
48 num_qubits,
49 rng: StdRng::from_entropy(),
50 measurement_record: Vec::new(),
51 })
52 }
53
54 pub fn new_with_seed(num_qubits: u32, seed: u64) -> Result<Self> {
56 if num_qubits > MAX_QUBITS {
57 return Err(QuantumError::QubitLimitExceeded {
58 requested: num_qubits,
59 maximum: MAX_QUBITS,
60 });
61 }
62 let n = 1usize << num_qubits;
63 let mut amplitudes = vec![Complex::ZERO; n];
64 amplitudes[0] = Complex::ONE;
65 Ok(Self {
66 amplitudes,
67 num_qubits,
68 rng: StdRng::seed_from_u64(seed),
69 measurement_record: Vec::new(),
70 })
71 }
72
73 pub fn from_amplitudes(amps: Vec<Complex>, num_qubits: u32) -> Result<Self> {
77 if num_qubits > MAX_QUBITS {
78 return Err(QuantumError::QubitLimitExceeded {
79 requested: num_qubits,
80 maximum: MAX_QUBITS,
81 });
82 }
83 let expected = 1usize << num_qubits;
84 if amps.len() != expected {
85 return Err(QuantumError::InvalidStateVector {
86 length: amps.len(),
87 num_qubits,
88 });
89 }
90 Ok(Self {
91 amplitudes: amps,
92 num_qubits,
93 rng: StdRng::from_entropy(),
94 measurement_record: Vec::new(),
95 })
96 }
97
98 pub fn num_qubits(&self) -> u32 {
103 self.num_qubits
104 }
105
106 pub fn num_amplitudes(&self) -> usize {
107 self.amplitudes.len()
108 }
109
110 pub fn state_vector(&self) -> &[Complex] {
111 &self.amplitudes
112 }
113
114 pub fn amplitudes_mut(&mut self) -> &mut [Complex] {
119 &mut self.amplitudes
120 }
121
122 pub fn probabilities(&self) -> Vec<f64> {
124 self.amplitudes.iter().map(|a| a.norm_sq()).collect()
125 }
126
127 pub fn probability_of_qubit(&self, qubit: QubitIndex) -> f64 {
129 let qubit_bit = 1usize << qubit;
130 let mut p1 = 0.0;
131 for (i, amp) in self.amplitudes.iter().enumerate() {
132 if i & qubit_bit != 0 {
133 p1 += amp.norm_sq();
134 }
135 }
136 p1
137 }
138
139 pub fn measurement_record(&self) -> &[MeasurementOutcome] {
140 &self.measurement_record
141 }
142
143 pub fn estimate_memory(num_qubits: u32) -> usize {
145 (1usize << num_qubits) * std::mem::size_of::<Complex>()
146 }
147
148 pub(crate) fn rng_mut(&mut self) -> &mut StdRng {
150 &mut self.rng
151 }
152
153 pub fn apply_gate(&mut self, gate: &Gate) -> Result<Vec<MeasurementOutcome>> {
159 for &q in gate.qubits().iter() {
161 self.validate_qubit(q)?;
162 }
163
164 match gate {
165 Gate::Barrier => Ok(vec![]),
166
167 Gate::Measure(q) => {
168 let outcome = self.measure(*q)?;
169 Ok(vec![outcome])
170 }
171
172 Gate::Reset(q) => {
173 self.reset_qubit(*q)?;
174 Ok(vec![])
175 }
176
177 Gate::CNOT(q1, q2)
179 | Gate::CZ(q1, q2)
180 | Gate::SWAP(q1, q2)
181 | Gate::Rzz(q1, q2, _) => {
182 if q1 == q2 {
183 return Err(QuantumError::CircuitError(format!(
184 "two-qubit gate requires distinct qubits, got {} and {}",
185 q1, q2
186 )));
187 }
188 let matrix = gate.matrix_2q().unwrap();
189 self.apply_two_qubit_gate(*q1, *q2, &matrix);
190 Ok(vec![])
191 }
192
193 other => {
195 if let Some(matrix) = other.matrix_1q() {
196 let q = other.qubits()[0];
197 self.apply_single_qubit_gate(q, &matrix);
198 Ok(vec![])
199 } else {
200 Err(QuantumError::CircuitError(format!(
201 "unsupported gate: {:?}",
202 other
203 )))
204 }
205 }
206 }
207 }
208
209 pub fn apply_single_qubit_gate(&mut self, qubit: QubitIndex, matrix: &[[Complex; 2]; 2]) {
218 let step = 1usize << qubit;
219 let n = self.amplitudes.len();
220
221 let mut block_start = 0;
222 while block_start < n {
223 for i in block_start..block_start + step {
224 let j = i + step;
225 let a = self.amplitudes[i]; let b = self.amplitudes[j]; self.amplitudes[i] = matrix[0][0] * a + matrix[0][1] * b;
228 self.amplitudes[j] = matrix[1][0] * a + matrix[1][1] * b;
229 }
230 block_start += step << 1;
231 }
232 }
233
234 pub fn apply_two_qubit_gate(
242 &mut self,
243 q1: QubitIndex,
244 q2: QubitIndex,
245 matrix: &[[Complex; 4]; 4],
246 ) {
247 let q1_bit = 1usize << q1;
248 let q2_bit = 1usize << q2;
249 let n = self.amplitudes.len();
250
251 for base in 0..n {
252 if base & q1_bit != 0 || base & q2_bit != 0 {
255 continue;
256 }
257
258 let idxs = [
259 base, base | q2_bit, base | q1_bit, base | q1_bit | q2_bit, ];
264
265 let vals = [
266 self.amplitudes[idxs[0]],
267 self.amplitudes[idxs[1]],
268 self.amplitudes[idxs[2]],
269 self.amplitudes[idxs[3]],
270 ];
271
272 for r in 0..4 {
273 self.amplitudes[idxs[r]] = matrix[r][0] * vals[0]
274 + matrix[r][1] * vals[1]
275 + matrix[r][2] * vals[2]
276 + matrix[r][3] * vals[3];
277 }
278 }
279 }
280
281 pub fn measure(&mut self, qubit: QubitIndex) -> Result<MeasurementOutcome> {
292 self.validate_qubit(qubit)?;
293
294 let qubit_bit = 1usize << qubit;
295 let n = self.amplitudes.len();
296
297 let mut p0: f64 = 0.0;
299 for i in 0..n {
300 if i & qubit_bit == 0 {
301 p0 += self.amplitudes[i].norm_sq();
302 }
303 }
304
305 let random: f64 = self.rng.gen();
306 let result = random >= p0; let prob = if result { 1.0 - p0 } else { p0 };
308
309 let norm_factor = if prob > 0.0 { 1.0 / prob.sqrt() } else { 0.0 };
311
312 for i in 0..n {
314 let bit_is_one = i & qubit_bit != 0;
315 if bit_is_one == result {
316 self.amplitudes[i] = self.amplitudes[i] * norm_factor;
317 } else {
318 self.amplitudes[i] = Complex::ZERO;
319 }
320 }
321
322 let outcome = MeasurementOutcome {
323 qubit,
324 result,
325 probability: prob,
326 };
327 self.measurement_record.push(outcome.clone());
328 Ok(outcome)
329 }
330
331 pub fn measure_all(&mut self) -> Result<Vec<MeasurementOutcome>> {
333 let mut outcomes = Vec::with_capacity(self.num_qubits as usize);
334 for q in 0..self.num_qubits {
335 outcomes.push(self.measure(q)?);
336 }
337 Ok(outcomes)
338 }
339
340 pub fn reset_qubit(&mut self, qubit: QubitIndex) -> Result<()> {
348 let outcome = self.measure(qubit)?;
349 if outcome.result {
350 let x_matrix = Gate::X(qubit).matrix_1q().unwrap();
352 self.apply_single_qubit_gate(qubit, &x_matrix);
353 }
354 Ok(())
355 }
356
357 pub fn expectation_value(&self, pauli: &PauliString) -> f64 {
366 let n = self.amplitudes.len();
367 let mut result = Complex::ZERO;
368
369 for i in 0..n {
370 let mut j = i;
371 let mut phase = Complex::ONE;
372
373 for &(qubit, op) in &pauli.ops {
374 let bit = (i >> qubit) & 1;
375 match op {
376 PauliOp::I => {}
377 PauliOp::X => {
378 j ^= 1usize << qubit;
379 }
380 PauliOp::Y => {
381 j ^= 1usize << qubit;
382 if bit == 0 {
384 phase = phase * Complex::I;
385 } else {
386 phase = phase * Complex::new(0.0, -1.0);
387 }
388 }
389 PauliOp::Z => {
390 if bit == 1 {
391 phase = -phase;
392 }
393 }
394 }
395 }
396
397 result += self.amplitudes[j].conj() * phase * self.amplitudes[i];
399 }
400
401 result.re
403 }
404
405 pub fn expectation_hamiltonian(&self, h: &Hamiltonian) -> f64 {
407 h.terms
408 .iter()
409 .map(|(coeff, ps)| coeff * self.expectation_value(ps))
410 .sum()
411 }
412
413 pub fn normalize(&mut self) {
419 let norm_sq: f64 = self.amplitudes.iter().map(|a| a.norm_sq()).sum();
420 if norm_sq > 0.0 {
421 let inv_norm = 1.0 / norm_sq.sqrt();
422 for a in self.amplitudes.iter_mut() {
423 *a = *a * inv_norm;
424 }
425 }
426 }
427
428 pub fn fidelity(&self, other: &QuantumState) -> f64 {
430 if self.num_qubits != other.num_qubits {
431 return 0.0;
432 }
433 let mut inner = Complex::ZERO;
434 for (a, b) in self.amplitudes.iter().zip(other.amplitudes.iter()) {
435 inner += a.conj() * *b;
436 }
437 inner.norm_sq()
438 }
439
440 fn validate_qubit(&self, qubit: QubitIndex) -> Result<()> {
445 if qubit >= self.num_qubits {
446 return Err(QuantumError::InvalidQubitIndex {
447 index: qubit,
448 num_qubits: self.num_qubits,
449 });
450 }
451 Ok(())
452 }
453}