1use crate::error::{ClusteringError, Result};
7use scirs2_core::ndarray::{Array2, ArrayView1, ArrayView2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use serde::{Deserialize, Serialize};
10
11use super::core::{GpuConfig, GpuContext};
12use super::memory::{GpuMemoryManager, MemoryTransfer};
13
14#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
16pub enum DistanceMetric {
17 Euclidean,
19 Manhattan,
21 Cosine,
23 Minkowski(f64),
25 SquaredEuclidean,
27 Chebyshev,
29 Hamming,
31}
32
33impl Default for DistanceMetric {
34 fn default() -> Self {
35 DistanceMetric::Euclidean
36 }
37}
38
39impl std::fmt::Display for DistanceMetric {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 DistanceMetric::Euclidean => write!(f, "euclidean"),
43 DistanceMetric::Manhattan => write!(f, "manhattan"),
44 DistanceMetric::Cosine => write!(f, "cosine"),
45 DistanceMetric::Minkowski(p) => write!(f, "minkowski(p={})", p),
46 DistanceMetric::SquaredEuclidean => write!(f, "squared_euclidean"),
47 DistanceMetric::Chebyshev => write!(f, "chebyshev"),
48 DistanceMetric::Hamming => write!(f, "hamming"),
49 }
50 }
51}
52
53#[derive(Debug)]
55pub struct GpuDistanceMatrix<F: Float> {
56 context: GpuContext,
58 metric: DistanceMetric,
60 gpu_data: Option<GpuArray<F>>,
62 tile_size: usize,
64 use_shared_memory: bool,
66 memory_manager: GpuMemoryManager,
68}
69
70#[derive(Debug)]
72pub struct GpuArray<F: Float> {
73 device_ptr: usize,
75 shape: [usize; 2],
77 element_size: usize,
79 on_device: bool,
81 _phantom: std::marker::PhantomData<F>,
82}
83
84impl<F: Float + FromPrimitive + Send + Sync> GpuDistanceMatrix<F> {
85 pub fn new(
87 gpu_config: GpuConfig,
88 metric: DistanceMetric,
89 tile_size: Option<usize>,
90 ) -> Result<Self> {
91 let device = Self::detect_gpu_device(&gpu_config)?;
92 let context = GpuContext::new(device, gpu_config)?;
93
94 let optimal_tile_size =
95 tile_size.unwrap_or_else(|| Self::calculate_optimal_tile_size(&context));
96
97 let memory_manager = GpuMemoryManager::new(256, 100);
98
99 Ok(Self {
100 context,
101 metric,
102 gpu_data: None,
103 tile_size: optimal_tile_size,
104 use_shared_memory: true,
105 memory_manager,
106 })
107 }
108
109 pub fn preload_data(&mut self, data: ArrayView2<F>) -> Result<()> {
111 let shape = [data.nrows(), data.ncols()];
112 let mut gpu_data = GpuArray::allocate(shape)?;
113 gpu_data.copy_from_host(data)?;
114 self.gpu_data = Some(gpu_data);
115 Ok(())
116 }
117
118 pub fn compute_distance_matrix(&mut self, data: ArrayView2<F>) -> Result<Array2<F>> {
120 let n_samples = data.nrows();
121 let mut result = Array2::zeros((n_samples, n_samples));
122
123 if !self.context.is_gpu_accelerated() {
124 return self.compute_distance_matrix_cpu(data);
126 }
127
128 if self.gpu_data.is_none() {
130 self.preload_data(data)?;
131 }
132
133 for i in (0..n_samples).step_by(self.tile_size) {
135 for j in (0..n_samples).step_by(self.tile_size) {
136 let i_end = (i + self.tile_size).min(n_samples);
137 let j_end = (j + self.tile_size).min(n_samples);
138
139 let tile_result = self.compute_distance_tile(i, i_end, j, j_end)?;
140
141 for (ii, row) in tile_result.rows().into_iter().enumerate() {
143 for (jj, &val) in row.iter().enumerate() {
144 if i + ii < n_samples && j + jj < n_samples {
145 result[[i + ii, j + jj]] = val;
146 }
147 }
148 }
149 }
150 }
151
152 Ok(result)
153 }
154
155 pub fn compute_distances_to_centroids(
157 &mut self,
158 data: ArrayView2<F>,
159 centroids: ArrayView2<F>,
160 ) -> Result<Array2<F>> {
161 let n_samples = data.nrows();
162 let n_centroids = centroids.nrows();
163 let mut result = Array2::zeros((n_samples, n_centroids));
164
165 if !self.context.is_gpu_accelerated() {
166 return self.compute_distances_to_centroids_cpu(data, centroids);
167 }
168
169 for i in (0..n_samples).step_by(self.tile_size) {
171 let i_end = (i + self.tile_size).min(n_samples);
172
173 for j in (0..n_centroids).step_by(self.tile_size) {
174 let j_end = (j + self.tile_size).min(n_centroids);
175
176 let tile_result =
177 self.compute_centroid_distance_tile(data, centroids, i, i_end, j, j_end)?;
178
179 for (ii, row) in tile_result.rows().into_iter().enumerate() {
181 for (jj, &val) in row.iter().enumerate() {
182 if i + ii < n_samples && j + jj < n_centroids {
183 result[[i + ii, j + jj]] = val;
184 }
185 }
186 }
187 }
188 }
189
190 Ok(result)
191 }
192
193 pub fn find_k_nearest(
195 &mut self,
196 query: ArrayView1<F>,
197 data: ArrayView2<F>,
198 k: usize,
199 ) -> Result<(Vec<usize>, Vec<F>)> {
200 if k == 0 || k > data.nrows() {
201 return Err(ClusteringError::InvalidInput(
202 "Invalid k value for k-nearest neighbors".to_string(),
203 ));
204 }
205
206 let distances = self.compute_point_distances(query, data)?;
207
208 let mut indexed_distances: Vec<(usize, F)> =
210 distances.iter().enumerate().map(|(i, &d)| (i, d)).collect();
211
212 indexed_distances
213 .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
214
215 let indices = indexed_distances.iter().take(k).map(|(i, _)| *i).collect();
216 let distances = indexed_distances.iter().take(k).map(|(_, d)| *d).collect();
217
218 Ok((indices, distances))
219 }
220
221 fn compute_point_distances(
223 &mut self,
224 query: ArrayView1<F>,
225 data: ArrayView2<F>,
226 ) -> Result<Vec<F>> {
227 let n_samples = data.nrows();
228 let mut distances = vec![F::zero(); n_samples];
229
230 for (i, data_point) in data.rows().into_iter().enumerate() {
231 distances[i] = self.compute_single_distance(query, data_point)?;
232 }
233
234 Ok(distances)
235 }
236
237 fn compute_single_distance(&self, point1: ArrayView1<F>, point2: ArrayView1<F>) -> Result<F> {
239 if point1.len() != point2.len() {
240 return Err(ClusteringError::InvalidInput(
241 "Points must have same dimensionality".to_string(),
242 ));
243 }
244
245 let distance = match self.metric {
246 DistanceMetric::Euclidean => {
247 let sum_sq: F = point1
248 .iter()
249 .zip(point2.iter())
250 .map(|(&a, &b)| (a - b) * (a - b))
251 .fold(F::zero(), |acc, x| acc + x);
252 sum_sq.sqrt()
253 }
254 DistanceMetric::SquaredEuclidean => point1
255 .iter()
256 .zip(point2.iter())
257 .map(|(&a, &b)| (a - b) * (a - b))
258 .fold(F::zero(), |acc, x| acc + x),
259 DistanceMetric::Manhattan => point1
260 .iter()
261 .zip(point2.iter())
262 .map(|(&a, &b)| (a - b).abs())
263 .fold(F::zero(), |acc, x| acc + x),
264 DistanceMetric::Cosine => {
265 let dot_product = point1
266 .iter()
267 .zip(point2.iter())
268 .map(|(&a, &b)| a * b)
269 .fold(F::zero(), |acc, x| acc + x);
270
271 let norm1 = point1
272 .iter()
273 .map(|&x| x * x)
274 .fold(F::zero(), |acc, x| acc + x)
275 .sqrt();
276
277 let norm2 = point2
278 .iter()
279 .map(|&x| x * x)
280 .fold(F::zero(), |acc, x| acc + x)
281 .sqrt();
282
283 if norm1 == F::zero() || norm2 == F::zero() {
284 F::one()
285 } else {
286 F::one() - (dot_product / (norm1 * norm2))
287 }
288 }
289 DistanceMetric::Chebyshev => point1
290 .iter()
291 .zip(point2.iter())
292 .map(|(&a, &b)| (a - b).abs())
293 .fold(F::zero(), |acc, x| if x > acc { x } else { acc }),
294 DistanceMetric::Minkowski(p) => {
295 let p_f = F::from(p).unwrap_or(F::one());
296 let sum: F = point1
297 .iter()
298 .zip(point2.iter())
299 .map(|(&a, &b)| (a - b).abs().powf(p_f))
300 .fold(F::zero(), |acc, x| acc + x);
301 sum.powf(F::one() / p_f)
302 }
303 DistanceMetric::Hamming => {
304 let threshold = F::from(0.5).unwrap_or(F::zero());
306 let count = point1
307 .iter()
308 .zip(point2.iter())
309 .filter(|(&a, &b)| (a - b).abs() > threshold)
310 .count();
311 F::from(count).unwrap_or(F::zero())
312 }
313 };
314
315 Ok(distance)
316 }
317
318 pub fn compute_distance_matrix_cpu(&self, data: ArrayView2<F>) -> Result<Array2<F>> {
320 let n_samples = data.nrows();
321 let mut result = Array2::zeros((n_samples, n_samples));
322
323 for i in 0..n_samples {
324 for j in i..n_samples {
325 let distance = self.compute_single_distance(data.row(i), data.row(j))?;
326 result[[i, j]] = distance;
327 result[[j, i]] = distance;
328 }
329 }
330
331 Ok(result)
332 }
333
334 fn compute_distances_to_centroids_cpu(
336 &self,
337 data: ArrayView2<F>,
338 centroids: ArrayView2<F>,
339 ) -> Result<Array2<F>> {
340 let n_samples = data.nrows();
341 let n_centroids = centroids.nrows();
342 let mut result = Array2::zeros((n_samples, n_centroids));
343
344 for i in 0..n_samples {
345 for j in 0..n_centroids {
346 let distance = self.compute_single_distance(data.row(i), centroids.row(j))?;
347 result[[i, j]] = distance;
348 }
349 }
350
351 Ok(result)
352 }
353
354 fn compute_distance_tile(
356 &self,
357 _i_start: usize,
358 _i_end: usize,
359 _j_start: usize,
360 _j_end: usize,
361 ) -> Result<Array2<F>> {
362 Ok(Array2::zeros((1, 1)))
365 }
366
367 fn compute_centroid_distance_tile(
368 &self,
369 _data: ArrayView2<F>,
370 _centroids: ArrayView2<F>,
371 _i_start: usize,
372 _i_end: usize,
373 _j_start: usize,
374 _j_end: usize,
375 ) -> Result<Array2<F>> {
376 Ok(Array2::zeros((1, 1)))
379 }
380
381 fn detect_gpu_device(config: &GpuConfig) -> Result<super::core::GpuDevice> {
383 Ok(super::core::GpuDevice::new(
385 0,
386 "Stub GPU".to_string(),
387 8_000_000_000,
388 6_000_000_000,
389 "1.0".to_string(),
390 1024,
391 config.preferred_backend,
392 true,
393 ))
394 }
395
396 fn calculate_optimal_tile_size(context: &GpuContext) -> usize {
398 let (total_memory, available_memory) = context.memory_info();
400 let compute_units = context.device.compute_units as usize;
401
402 let memory_based = (available_memory / (8 * std::mem::size_of::<F>())).min(1024);
404 let compute_based = (compute_units * 32).min(512);
405
406 memory_based.min(compute_based).max(32)
407 }
408}
409
410impl<F: Float> GpuArray<F> {
411 pub fn allocate(shape: [usize; 2]) -> Result<Self> {
413 let element_size = std::mem::size_of::<F>();
414 let total_size = shape[0] * shape[1] * element_size;
415
416 let device_ptr = 0x2000_0000; Ok(Self {
420 device_ptr,
421 shape,
422 element_size,
423 on_device: true,
424 _phantom: std::marker::PhantomData,
425 })
426 }
427
428 pub fn copy_from_host(&mut self, _data: ArrayView2<F>) -> Result<()> {
430 self.on_device = true;
432 Ok(())
433 }
434
435 pub fn copy_to_host(&self) -> Result<Array2<F>> {
437 Ok(Array2::zeros((self.shape[0], self.shape[1])))
439 }
440
441 pub fn shape(&self) -> [usize; 2] {
443 self.shape
444 }
445
446 pub fn is_on_device(&self) -> bool {
448 self.on_device
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr1, Array2};

    /// Builds an `f64` engine for `metric` with the default configuration.
    fn engine(metric: DistanceMetric) -> GpuDistanceMatrix<f64> {
        GpuDistanceMatrix::new(GpuConfig::default(), metric, None)
            .expect("engine construction failed")
    }

    #[test]
    fn test_distance_metrics() {
        let point1 = arr1(&[1.0, 2.0, 3.0]);
        let point2 = arr1(&[4.0, 5.0, 6.0]);

        let cases = [
            (DistanceMetric::Euclidean, 5.196152422706632),
            (DistanceMetric::SquaredEuclidean, 27.0),
            (DistanceMetric::Manhattan, 9.0),
            (DistanceMetric::Chebyshev, 3.0),
            (DistanceMetric::Minkowski(2.0), 5.196152422706632),
            (DistanceMetric::Hamming, 3.0),
        ];
        for &(metric, expected) in cases.iter() {
            let distance = engine(metric)
                .compute_single_distance(point1.view(), point2.view())
                .expect("distance failed");
            assert!(
                (distance - expected).abs() < 1e-9,
                "{}: got {}, expected {}",
                metric,
                distance,
                expected
            );
        }
    }

    #[test]
    fn test_cosine_and_hamming_edge_cases() {
        let cosine = engine(DistanceMetric::Cosine);
        let x = arr1(&[1.0, 0.0]);
        let y = arr1(&[0.0, 1.0]);
        let zero = arr1(&[0.0, 0.0]);
        let d = cosine.compute_single_distance(x.view(), y.view()).expect("cosine failed");
        assert!((d - 1.0).abs() < 1e-12);
        let d = cosine.compute_single_distance(x.view(), zero.view()).expect("cosine failed");
        assert!((d - 1.0).abs() < 1e-12);

        let hamming = engine(DistanceMetric::Hamming);
        let a = arr1(&[0.0, 1.0, 2.0]);
        let b = arr1(&[0.0, 1.2, 3.0]);
        let d = hamming.compute_single_distance(a.view(), b.view()).expect("hamming failed");
        assert!((d - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_dimension_mismatch_is_rejected() {
        let a = arr1(&[1.0, 2.0]);
        let b = arr1(&[1.0, 2.0, 3.0]);
        assert!(engine(DistanceMetric::Euclidean)
            .compute_single_distance(a.view(), b.view())
            .is_err());
    }

    #[test]
    fn test_gpu_array_allocation() {
        let array = GpuArray::<f32>::allocate([100, 50]).expect("allocation failed");
        assert_eq!(array.shape(), [100, 50]);
        assert!(array.is_on_device());
    }

    #[test]
    fn test_distance_matrix_cpu_fallback() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("shape mismatch");

        let result = engine(DistanceMetric::Euclidean)
            .compute_distance_matrix_cpu(data.view())
            .expect("matrix failed");
        assert_eq!(result.shape(), &[3, 3]);

        for i in 0..3 {
            assert!(result[[i, i]].abs() < 1e-10);
            for j in 0..3 {
                assert!((result[[i, j]] - result[[j, i]]).abs() < 1e-12);
            }
        }
        assert!((result[[0, 1]] - 8.0_f64.sqrt()).abs() < 1e-10);
        assert!((result[[0, 2]] - 32.0_f64.sqrt()).abs() < 1e-10);
        assert!((result[[1, 2]] - 8.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_k_nearest_neighbors() {
        let query = arr1(&[1.0, 1.0]);
        let data = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 2.0, 2.0, 3.0, 3.0, 1.0, 1.0])
            .expect("shape mismatch");

        let mut matrix = engine(DistanceMetric::Euclidean);
        let (indices, distances) = matrix
            .find_k_nearest(query.view(), data.view(), 2)
            .expect("k-NN failed");

        // Rows 0 and 1 tie at sqrt(2); the lower index wins.
        assert_eq!(indices, vec![3, 0]);
        assert_eq!(distances.len(), 2);
        assert!(distances[0].abs() < 1e-12);
        assert!((distances[1] - 2.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_k_nearest_rejects_invalid_k() {
        let query = arr1(&[1.0, 1.0]);
        let data = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 2.0, 2.0])
            .expect("shape mismatch");
        let mut matrix = engine(DistanceMetric::Euclidean);
        assert!(matrix.find_k_nearest(query.view(), data.view(), 0).is_err());
        assert!(matrix.find_k_nearest(query.view(), data.view(), 3).is_err());
    }

    #[test]
    fn test_metric_display_and_default() {
        assert_eq!(DistanceMetric::default(), DistanceMetric::Euclidean);
        assert_eq!(DistanceMetric::Euclidean.to_string(), "euclidean");
        assert_eq!(DistanceMetric::SquaredEuclidean.to_string(), "squared_euclidean");
        assert_eq!(DistanceMetric::Minkowski(3.0).to_string(), "minkowski(p=3)");
        assert_eq!(DistanceMetric::Hamming.to_string(), "hamming");
    }
}