scirs2_spatial/quantum_inspired/algorithms/
quantum_search.rs1use crate::error::{SpatialError, SpatialResult};
7use scirs2_core::ndarray::{Array2, ArrayView1, ArrayView2};
8use scirs2_core::numeric::Complex64;
9use std::f64::consts::PI;
10
11use super::super::concepts::QuantumState;
13
14#[derive(Debug, Clone)]
42pub struct QuantumNearestNeighbor {
43 quantum_points: Vec<QuantumState>,
45 classical_points: Array2<f64>,
47 quantum_encoding: bool,
49 amplitude_amplification: bool,
51 grover_iterations: usize,
53}
54
55impl QuantumNearestNeighbor {
56 pub fn new(points: &ArrayView2<'_, f64>) -> SpatialResult<Self> {
64 let classical_points = points.to_owned();
65 let quantum_points = Vec::new(); Ok(Self {
68 quantum_points,
69 classical_points,
70 quantum_encoding: false,
71 amplitude_amplification: false,
72 grover_iterations: 3,
73 })
74 }
75
76 pub fn with_quantum_encoding(mut self, enabled: bool) -> Self {
84 self.quantum_encoding = enabled;
85
86 if enabled {
87 if let Ok(encoded) = self.encode_reference_points() {
89 self.quantum_points = encoded;
90 }
91 }
92
93 self
94 }
95
96 pub fn with_amplitude_amplification(mut self, enabled: bool) -> Self {
104 self.amplitude_amplification = enabled;
105 self
106 }
107
108 pub fn with_grover_iterations(mut self, iterations: usize) -> Self {
116 self.grover_iterations = iterations;
117 self
118 }
119
120 pub fn query_quantum(
133 &self,
134 query_point: &ArrayView1<f64>,
135 k: usize,
136 ) -> SpatialResult<(Vec<usize>, Vec<f64>)> {
137 let n_points = self.classical_points.nrows();
138
139 if k > n_points {
140 return Err(SpatialError::InvalidInput(
141 "k cannot be larger than number of points".to_string(),
142 ));
143 }
144
145 let mut distances = if self.quantum_encoding && !self.quantum_points.is_empty() {
146 self.quantum_distance_computation(query_point)?
148 } else {
149 self.classical_distance_computation(query_point)
151 };
152
153 if self.amplitude_amplification {
155 distances = self.apply_amplitude_amplification(distances)?;
156 }
157
158 let mut indexed_distances: Vec<(usize, f64)> = distances.into_iter().enumerate().collect();
160 indexed_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
161
162 let indices: Vec<usize> = indexed_distances
163 .iter()
164 .take(k)
165 .map(|(i_, _)| *i_)
166 .collect();
167 let dists: Vec<f64> = indexed_distances.iter().take(k).map(|(_, d)| *d).collect();
168
169 Ok((indices, dists))
170 }
171
172 fn encode_reference_points(&self) -> SpatialResult<Vec<QuantumState>> {
177 let (n_points, n_dims) = self.classical_points.dim();
178 let mut encoded_points = Vec::new();
179
180 for i in 0..n_points {
181 let point = self.classical_points.row(i);
182
183 let numqubits = (n_dims).next_power_of_two().trailing_zeros() as usize + 2;
185 let mut quantum_point = QuantumState::zero_state(numqubits);
186
187 for (dim, &coord) in point.iter().enumerate() {
189 if dim < numqubits - 1 {
190 let normalized_coord = (coord + 10.0) / 20.0; let angle = normalized_coord.clamp(0.0, 1.0) * PI;
193 quantum_point.phase_rotation(dim, angle)?;
194 }
195 }
196
197 for i in 0..numqubits - 1 {
199 quantum_point.controlled_rotation(i, i + 1, PI / 4.0)?;
200 }
201
202 encoded_points.push(quantum_point);
203 }
204
205 Ok(encoded_points)
206 }
207
208 fn quantum_distance_computation(
213 &self,
214 query_point: &ArrayView1<f64>,
215 ) -> SpatialResult<Vec<f64>> {
216 let n_dims = query_point.len();
217 let mut distances = Vec::new();
218
219 let numqubits = n_dims.next_power_of_two().trailing_zeros() as usize + 2;
221 let mut query_state = QuantumState::zero_state(numqubits);
222
223 for (dim, &coord) in query_point.iter().enumerate() {
224 if dim < numqubits - 1 {
225 let normalized_coord = (coord + 10.0) / 20.0;
226 let angle = normalized_coord.clamp(0.0, 1.0) * PI;
227 query_state.phase_rotation(dim, angle)?;
228 }
229 }
230
231 for i in 0..numqubits - 1 {
233 query_state.controlled_rotation(i, i + 1, PI / 4.0)?;
234 }
235
236 for quantum_ref in &self.quantum_points {
238 let fidelity =
239 QuantumNearestNeighbor::calculate_quantum_fidelity(&query_state, quantum_ref);
240
241 let quantum_distance = 1.0 - fidelity;
243 distances.push(quantum_distance);
244 }
245
246 Ok(distances)
247 }
248
249 fn classical_distance_computation(&self, query_point: &ArrayView1<f64>) -> Vec<f64> {
254 let mut distances = Vec::new();
255
256 for i in 0..self.classical_points.nrows() {
257 let ref_point = self.classical_points.row(i);
258 let distance: f64 = query_point
259 .iter()
260 .zip(ref_point.iter())
261 .map(|(&a, &b)| (a - b).powi(2))
262 .sum::<f64>()
263 .sqrt();
264
265 distances.push(distance);
266 }
267
268 distances
269 }
270
271 fn calculate_quantum_fidelity(state1: &QuantumState, state2: &QuantumState) -> f64 {
275 if state1.amplitudes.len() != state2.amplitudes.len() {
276 return 0.0;
277 }
278
279 let inner_product: Complex64 = state1
281 .amplitudes
282 .iter()
283 .zip(state2.amplitudes.iter())
284 .map(|(a, b)| a.conj() * b)
285 .sum();
286
287 inner_product.norm_sqr()
289 }
290
291 fn apply_amplitude_amplification(&self, mut distances: Vec<f64>) -> SpatialResult<Vec<f64>> {
296 if distances.is_empty() {
297 return Ok(distances);
298 }
299
300 let avg_distance: f64 = distances.iter().sum::<f64>() / distances.len() as f64;
302
303 for _ in 0..self.grover_iterations {
305 #[allow(clippy::manual_slice_fill)]
307 for distance in &mut distances {
308 *distance = 2.0 * avg_distance - *distance;
309 }
310
311 for distance in &mut distances {
313 if *distance < avg_distance {
314 *distance *= 0.9; }
316 }
317 }
318
319 let min_distance = distances.iter().fold(f64::INFINITY, |a, &b| a.min(b));
321 if min_distance < 0.0 {
322 for distance in &mut distances {
323 *distance -= min_distance;
324 }
325 }
326
327 Ok(distances)
328 }
329
330 pub fn len(&self) -> usize {
332 self.classical_points.nrows()
333 }
334
335 pub fn is_empty(&self) -> bool {
337 self.classical_points.nrows() == 0
338 }
339
340 pub fn classical_points(&self) -> &Array2<f64> {
342 &self.classical_points
343 }
344
345 pub fn is_quantum_enabled(&self) -> bool {
347 self.quantum_encoding
348 }
349
350 pub fn is_amplification_enabled(&self) -> bool {
352 self.amplitude_amplification
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr1, Array2};

    /// Three reference points on the diagonal: (0, 0), (1, 1), (2, 2).
    fn diagonal_points() -> Array2<f64> {
        Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0]).unwrap()
    }

    #[test]
    fn test_quantum_nearest_neighbor_creation() {
        let points = diagonal_points();
        let searcher = QuantumNearestNeighbor::new(&points.view()).unwrap();

        assert_eq!(searcher.len(), 3);
        assert!(!searcher.is_quantum_enabled());
        assert!(!searcher.is_amplification_enabled());
    }

    #[test]
    fn test_classical_search() {
        let points = diagonal_points();
        let searcher = QuantumNearestNeighbor::new(&points.view()).unwrap();

        let query = arr1(&[0.5, 0.5]);
        let (indices, distances) = searcher.query_quantum(&query.view(), 2).unwrap();

        assert_eq!(indices.len(), 2);
        assert_eq!(distances.len(), 2);
        assert!(indices.contains(&0));
        assert!(indices.contains(&1));
        // Both nearest points are sqrt(0.5) away from the query.
        for d in &distances {
            assert!((d - 0.5f64.sqrt()).abs() < 1e-12);
        }
    }

    #[test]
    fn test_classical_search_orders_all_points() {
        let points = diagonal_points();
        let searcher = QuantumNearestNeighbor::new(&points.view()).unwrap();

        let query = arr1(&[0.5, 0.5]);
        let (indices, distances) = searcher.query_quantum(&query.view(), 3).unwrap();

        // Equal distances keep index order; the far point comes last.
        assert_eq!(indices, vec![0, 1, 2]);
        assert!(distances.windows(2).all(|w| w[0] <= w[1]));
        assert!((distances[2] - 4.5f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_zero_k_returns_empty() {
        let points = diagonal_points();
        let searcher = QuantumNearestNeighbor::new(&points.view()).unwrap();

        let query = arr1(&[0.5, 0.5]);
        let (indices, distances) = searcher.query_quantum(&query.view(), 0).unwrap();

        assert!(indices.is_empty());
        assert!(distances.is_empty());
    }

    #[test]
    fn test_quantum_configuration() {
        let points = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();
        let searcher = QuantumNearestNeighbor::new(&points.view())
            .unwrap()
            .with_quantum_encoding(true)
            .with_amplitude_amplification(true)
            .with_grover_iterations(5);

        assert!(searcher.is_quantum_enabled());
        assert!(searcher.is_amplification_enabled());
        assert_eq!(searcher.grover_iterations, 5);
    }

    #[test]
    fn test_empty_points() {
        let points = Array2::from_shape_vec((0, 2), vec![]).unwrap();
        let searcher = QuantumNearestNeighbor::new(&points.view()).unwrap();

        assert!(searcher.is_empty());
        assert_eq!(searcher.len(), 0);
    }

    #[test]
    fn test_invalid_k() {
        let points = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();
        let searcher = QuantumNearestNeighbor::new(&points.view()).unwrap();

        let query = arr1(&[0.5, 0.5]);
        let result = searcher.query_quantum(&query.view(), 5);

        assert!(result.is_err());
    }
}