1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Zip};
8use scirs2_core::numeric::{Float, FromPrimitive, Zero};
9use scirs2_core::random::{Rng, SeedableRng};
10use serde::{Deserialize, Serialize};
11use std::fmt::Debug;
12
13use crate::error::{ClusteringError, Result};
14use crate::vq::euclidean_distance;
15
/// Hyper-parameters for [`QuantumKMeans`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantumConfig {
    /// Number of noisy centroid copies ("quantum states") kept per cluster; `fit`
    /// allocates `n_quantum_states * n_clusters` quantum centroids.
    pub n_quantum_states: usize,
    /// Factor every sample's amplitude is multiplied by on each optimization iteration.
    pub decoherence_factor: f64,
    /// Number of optimization iterations run by `fit`.
    pub quantum_iterations: usize,
    /// Fraction by which the cluster probabilities of consecutive samples (in input
    /// order) are pulled toward each other during the entanglement step.
    pub entanglement_strength: f64,
    /// Cluster probabilities above this value are boosted by a factor of 1.1 before
    /// renormalization in the measurement step.
    pub measurement_threshold: f64,
    /// Initial temperature; scales the random noise added to each sample's phase.
    pub temperature: f64,
    /// Factor the temperature is multiplied by after every iteration.
    pub cooling_rate: f64,
}
34
35impl Default for QuantumConfig {
36 fn default() -> Self {
37 Self {
38 n_quantum_states: 8,
39 decoherence_factor: 0.95,
40 quantum_iterations: 50,
41 entanglement_strength: 0.3,
42 measurement_threshold: 0.1,
43 temperature: 1.0,
44 cooling_rate: 0.95,
45 }
46 }
47}
48
49pub struct QuantumKMeans<F: Float> {
55 config: QuantumConfig,
56 n_clusters: usize,
57 quantum_centroids: Option<Array2<F>>,
58 quantum_amplitudes: Option<Array2<F>>,
59 classical_centroids: Option<Array2<F>>,
60 quantum_states: Vec<QuantumState<F>>,
61 initialized: bool,
62}
63
64#[derive(Debug, Clone)]
66pub struct QuantumState<F: Float> {
67 amplitude: F,
69 phase: F,
71 cluster_probabilities: Array1<F>,
73}
74
75impl<F: Float + FromPrimitive + Debug> QuantumKMeans<F> {
76 pub fn new(nclusters: usize, config: QuantumConfig) -> Self {
78 Self {
79 config,
80 n_clusters: nclusters,
81 quantum_centroids: None,
82 quantum_amplitudes: None,
83 classical_centroids: None,
84 quantum_states: Vec::new(),
85 initialized: false,
86 }
87 }
88
89 pub fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
91 let (n_samples, n_features) = data.dim();
92
93 if n_samples == 0 || n_features == 0 {
94 return Err(ClusteringError::InvalidInput(
95 "Data cannot be empty".to_string(),
96 ));
97 }
98
99 let mut quantum_centroids =
101 Array2::zeros((self.config.n_quantum_states * self.n_clusters, n_features));
102 let mut quantum_amplitudes = Array2::zeros((self.config.n_quantum_states, self.n_clusters));
103
104 let mut classical_centroids = Array2::zeros((self.n_clusters, n_features));
106 self.initialize_classical_centroids(&mut classical_centroids, data)?;
107
108 for quantum_state in 0..self.config.n_quantum_states {
110 for cluster in 0..self.n_clusters {
111 let idx = quantum_state * self.n_clusters + cluster;
112
113 let noise_scale = F::from(0.1).unwrap();
115 for feature in 0..n_features {
116 let noise = self.quantum_noise() * noise_scale;
117 quantum_centroids[[idx, feature]] =
118 classical_centroids[[cluster, feature]] + noise;
119 }
120
121 quantum_amplitudes[[quantum_state, cluster]] =
123 F::from(1.0 / (self.config.n_quantum_states as f64).sqrt()).unwrap();
124 }
125 }
126
127 self.quantum_states = Vec::with_capacity(n_samples);
129 for _ in 0..n_samples {
130 let amplitude = F::from(1.0 / (n_samples as f64).sqrt()).unwrap();
131 let phase = F::zero();
132 let cluster_probabilities = Array1::from_elem(
133 self.n_clusters,
134 F::from(1.0 / self.n_clusters as f64).unwrap(),
135 );
136
137 self.quantum_states.push(QuantumState {
138 amplitude,
139 phase,
140 cluster_probabilities,
141 });
142 }
143
144 self.quantum_centroids = Some(quantum_centroids);
145 self.quantum_amplitudes = Some(quantum_amplitudes);
146 self.classical_centroids = Some(classical_centroids);
147 self.initialized = true;
148
149 self.quantum_optimization(data)?;
151
152 Ok(())
153 }
154
155 fn initialize_classical_centroids(
157 &self,
158 centroids: &mut Array2<F>,
159 data: ArrayView2<F>,
160 ) -> Result<()> {
161 let n_samples = data.nrows();
162
163 centroids.row_mut(0).assign(&data.row(0));
165
166 for i in 1..self.n_clusters {
168 let mut distances = Array1::zeros(n_samples);
169 let mut total_distance = F::zero();
170
171 for j in 0..n_samples {
172 let mut min_dist = F::infinity();
173 for k in 0..i {
174 let dist = euclidean_distance(data.row(j), centroids.row(k));
175 if dist < min_dist {
176 min_dist = dist;
177 }
178 }
179 distances[j] = min_dist * min_dist;
180 total_distance = total_distance + distances[j];
181 }
182
183 let target = total_distance * F::from(0.5).unwrap();
185 let mut cumsum = F::zero();
186 for j in 0..n_samples {
187 cumsum = cumsum + distances[j];
188 if cumsum >= target {
189 centroids.row_mut(i).assign(&data.row(j));
190 break;
191 }
192 }
193 }
194
195 Ok(())
196 }
197
198 fn quantum_noise(&self) -> F {
200 let mut rng = scirs2_core::random::thread_rng();
202 F::from(rng.gen_range(-1.0..1.0)).unwrap()
203 }
204
205 fn quantum_optimization(&mut self, data: ArrayView2<F>) -> Result<()> {
207 let mut temperature = F::from(self.config.temperature).unwrap();
208 let cooling_rate = F::from(self.config.cooling_rate).unwrap();
209
210 for iteration in 0..self.config.quantum_iterations {
211 self.quantum_evolution_step(data)?;
213
214 self.apply_entanglement()?;
216
217 self.measure_and_decohere(temperature)?;
219
220 temperature = temperature * cooling_rate;
222
223 if iteration % 10 == 0 {
225 self.update_classical_centroids(data)?;
226 }
227 }
228
229 Ok(())
230 }
231
232 fn quantum_evolution_step(&mut self, data: ArrayView2<F>) -> Result<()> {
234 let quantum_centroids = self.quantum_centroids.as_ref().unwrap();
235 let quantum_amplitudes = self.quantum_amplitudes.as_ref().unwrap();
236
237 for (point_idx, point) in data.rows().into_iter().enumerate() {
238 let quantum_state = &mut self.quantum_states[point_idx];
239
240 for cluster in 0..self.n_clusters {
242 let mut total_amplitude = F::zero();
243
244 for quantum_idx in 0..self.config.n_quantum_states {
245 let centroid_idx = quantum_idx * self.n_clusters + cluster;
246 let centroid = quantum_centroids.row(centroid_idx);
247 let distance = euclidean_distance(point, centroid);
248
249 let amplitude = quantum_amplitudes[[quantum_idx, cluster]];
251 let quantum_weight =
252 amplitude * F::from((-distance.to_f64().unwrap()).exp()).unwrap();
253 total_amplitude = total_amplitude + quantum_weight;
254 }
255
256 quantum_state.cluster_probabilities[cluster] = total_amplitude;
257 }
258
259 let sum: F = quantum_state.cluster_probabilities.sum();
261 if sum > F::zero() {
262 quantum_state
263 .cluster_probabilities
264 .mapv_inplace(|x| x / sum);
265 }
266 }
267
268 Ok(())
269 }
270
271 fn apply_entanglement(&mut self) -> Result<()> {
273 let entanglement = F::from(self.config.entanglement_strength).unwrap();
274
275 for i in 0..(self.quantum_states.len() - 1) {
277 let (left, right) = self.quantum_states.split_at_mut(i + 1);
278 let state_i = &mut left[i];
279 let state_j = &mut right[0];
280
281 for cluster in 0..self.n_clusters {
283 let prob_i = state_i.cluster_probabilities[cluster];
284 let prob_j = state_j.cluster_probabilities[cluster];
285
286 let entangled_i = prob_i + entanglement * (prob_j - prob_i);
287 let entangled_j = prob_j + entanglement * (prob_i - prob_j);
288
289 state_i.cluster_probabilities[cluster] = entangled_i;
290 state_j.cluster_probabilities[cluster] = entangled_j;
291 }
292
293 let sum_i: F = state_i.cluster_probabilities.sum();
295 let sum_j: F = state_j.cluster_probabilities.sum();
296
297 if sum_i > F::zero() {
298 state_i.cluster_probabilities.mapv_inplace(|x| x / sum_i);
299 }
300 if sum_j > F::zero() {
301 state_j.cluster_probabilities.mapv_inplace(|x| x / sum_j);
302 }
303 }
304
305 Ok(())
306 }
307
308 fn measure_and_decohere(&mut self, temperature: F) -> Result<()> {
310 let decoherence = F::from(self.config.decoherence_factor).unwrap();
311 let threshold = F::from(self.config.measurement_threshold).unwrap();
312 let quantum_noise = self.quantum_noise();
313
314 for quantum_state in &mut self.quantum_states {
315 quantum_state.amplitude = quantum_state.amplitude * decoherence;
317
318 let thermal_noise = temperature * quantum_noise * F::from(0.01).unwrap();
320 quantum_state.phase = quantum_state.phase + thermal_noise;
321
322 for cluster in 0..self.n_clusters {
324 if quantum_state.cluster_probabilities[cluster] > threshold {
325 quantum_state.cluster_probabilities[cluster] =
327 quantum_state.cluster_probabilities[cluster] * F::from(1.1).unwrap();
328 }
329 }
330
331 let sum: F = quantum_state.cluster_probabilities.sum();
333 if sum > F::zero() {
334 quantum_state
335 .cluster_probabilities
336 .mapv_inplace(|x| x / sum);
337 }
338 }
339
340 Ok(())
341 }
342
343 fn update_classical_centroids(&mut self, data: ArrayView2<F>) -> Result<()> {
345 let classical_centroids = self.classical_centroids.as_mut().unwrap();
346 classical_centroids.fill(F::zero());
347
348 let mut cluster_weights = Array1::zeros(self.n_clusters);
349
350 for (point_idx, point) in data.rows().into_iter().enumerate() {
352 let quantum_state = &self.quantum_states[point_idx];
353
354 for cluster in 0..self.n_clusters {
355 let weight = quantum_state.cluster_probabilities[cluster];
356 cluster_weights[cluster] = cluster_weights[cluster] + weight;
357
358 Zip::from(classical_centroids.row_mut(cluster))
360 .and(point)
361 .for_each(|centroid_val, &point_val| {
362 *centroid_val = *centroid_val + weight * point_val;
363 });
364 }
365 }
366
367 for cluster in 0..self.n_clusters {
369 if cluster_weights[cluster] > F::zero() {
370 let mut row = classical_centroids.row_mut(cluster);
371 row.mapv_inplace(|x| x / cluster_weights[cluster]);
372 }
373 }
374
375 Ok(())
376 }
377
378 pub fn predict(&self, data: ArrayView2<F>) -> Result<Array1<usize>> {
380 if !self.initialized {
381 return Err(ClusteringError::InvalidInput(
382 "Model must be fitted before prediction".to_string(),
383 ));
384 }
385
386 let classical_centroids = self.classical_centroids.as_ref().unwrap();
387 let n_samples = data.nrows();
388 let mut labels = Array1::zeros(n_samples);
389
390 for (i, point) in data.rows().into_iter().enumerate() {
391 let mut min_distance = F::infinity();
392 let mut best_cluster = 0;
393
394 for cluster in 0..self.n_clusters {
395 let distance = euclidean_distance(point, classical_centroids.row(cluster));
396 if distance < min_distance {
397 min_distance = distance;
398 best_cluster = cluster;
399 }
400 }
401
402 labels[i] = best_cluster;
403 }
404
405 Ok(labels)
406 }
407
408 pub fn cluster_centers(&self) -> Option<&Array2<F>> {
410 self.classical_centroids.as_ref()
411 }
412
413 pub fn quantum_states(&self) -> &[QuantumState<F>] {
415 &self.quantum_states
416 }
417}
418
419pub fn quantum_kmeans<F: Float + FromPrimitive + Debug>(
421 data: ArrayView2<F>,
422 n_clusters: usize,
423 config: Option<QuantumConfig>,
424) -> Result<(Array2<F>, Array1<usize>)> {
425 let config = config.unwrap_or_default();
426 let mut clusterer = QuantumKMeans::new(n_clusters, config);
427 clusterer.fit(data)?;
428
429 let centroids = clusterer.cluster_centers().unwrap().clone();
430 let labels = clusterer.predict(data)?;
431
432 Ok((centroids, labels))
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Four points around the origin followed by two points around (5.5, 5.5).
    fn two_blob_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (6, 2),
            vec![0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 5.0, 5.0, 6.0, 6.0],
        )
        .unwrap()
    }

    #[test]
    fn test_quantum_config_default() {
        let config = QuantumConfig::default();
        assert_eq!(config.n_quantum_states, 8);
        assert_eq!(config.quantum_iterations, 50);
        assert!((config.decoherence_factor - 0.95).abs() < 1e-10);
        assert!((config.entanglement_strength - 0.3).abs() < 1e-10);
        assert!((config.measurement_threshold - 0.1).abs() < 1e-10);
        assert!((config.temperature - 1.0).abs() < 1e-10);
        assert!((config.cooling_rate - 0.95).abs() < 1e-10);
    }

    #[test]
    fn test_quantum_kmeans_creation() {
        let config = QuantumConfig::default();
        let clusterer = QuantumKMeans::<f64>::new(3, config);
        assert_eq!(clusterer.n_clusters, 3);
        assert!(!clusterer.initialized);
    }

    #[test]
    fn test_quantum_kmeans_simple() {
        let data = two_blob_data();
        let config = QuantumConfig {
            quantum_iterations: 10,
            ..Default::default()
        };

        let result = quantum_kmeans(data.view(), 2, Some(config));
        assert!(result.is_ok());

        let (centroids, labels) = result.unwrap();
        assert_eq!(centroids.nrows(), 2);
        assert_eq!(centroids.ncols(), 2);
        assert_eq!(labels.len(), 6);
        assert!(labels.iter().all(|&label| label < 2));
    }

    #[test]
    fn test_quantum_kmeans_separates_blobs() {
        let data = two_blob_data();
        let config = QuantumConfig {
            quantum_iterations: 10,
            ..Default::default()
        };

        let (_, labels) = quantum_kmeans(data.view(), 2, Some(config)).unwrap();
        assert!(labels.iter().take(4).all(|&label| label == labels[0]));
        assert_eq!(labels[4], labels[5]);
        assert_ne!(labels[0], labels[4]);
    }

    #[test]
    fn test_fit_rejects_empty_data() {
        let data = Array2::<f64>::zeros((0, 2));
        let mut clusterer = QuantumKMeans::<f64>::new(2, QuantumConfig::default());
        assert!(clusterer.fit(data.view()).is_err());
        assert!(!clusterer.initialized);
    }

    #[test]
    fn test_predict_before_fit_fails() {
        let clusterer = QuantumKMeans::<f64>::new(2, QuantumConfig::default());
        assert!(clusterer.predict(two_blob_data().view()).is_err());
    }

    #[test]
    fn test_fitted_states_are_normalized() {
        let data = two_blob_data();
        let config = QuantumConfig {
            quantum_iterations: 5,
            ..Default::default()
        };
        let mut clusterer = QuantumKMeans::new(2, config);
        clusterer.fit(data.view()).unwrap();

        assert!(clusterer.cluster_centers().is_some());
        assert_eq!(clusterer.quantum_states().len(), 6);
        for state in clusterer.quantum_states() {
            let total: f64 = state.cluster_probabilities.sum();
            assert!((total - 1.0).abs() < 1e-9);
        }
    }

    #[test]
    fn test_quantum_state() {
        let amplitude = 0.5f64;
        let phase = 0.0f64;
        let cluster_probs = Array1::from_vec(vec![0.3, 0.7]);

        let state = QuantumState {
            amplitude,
            phase,
            cluster_probabilities: cluster_probs,
        };

        assert!((state.amplitude - 0.5).abs() < 1e-10);
        assert_eq!(state.cluster_probabilities.len(), 2);
    }

    #[test]
    fn test_quantum_noise_generation() {
        let config = QuantumConfig::default();
        let clusterer = QuantumKMeans::<f64>::new(2, config);

        // Several draws so a single lucky sample cannot hide an out-of-range value.
        for _ in 0..100 {
            let noise = clusterer.quantum_noise();
            assert!((-1.0..=1.0).contains(&noise));
        }
    }
}