Skip to main content

scirs2_cluster/gpu/
distance.rs

1//! GPU-accelerated distance computations
2//!
3//! This module provides GPU-accelerated distance matrix computations and
4//! various distance metrics optimized for GPU hardware.
5
6use crate::error::{ClusteringError, Result};
7use scirs2_core::ndarray::{Array2, ArrayView1, ArrayView2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use serde::{Deserialize, Serialize};
10
11use super::core::{GpuConfig, GpuContext};
12use super::memory::{GpuMemoryManager, MemoryTransfer};
13
14/// Distance metrics supported by GPU acceleration
15#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
16pub enum DistanceMetric {
17    /// Euclidean distance (L2 norm)
18    Euclidean,
19    /// Manhattan distance (L1 norm)
20    Manhattan,
21    /// Cosine distance
22    Cosine,
23    /// Minkowski distance with custom p
24    Minkowski(f64),
25    /// Squared Euclidean distance (faster, no sqrt)
26    SquaredEuclidean,
27    /// Chebyshev distance (L norm)
28    Chebyshev,
29    /// Hamming distance (for binary data)
30    Hamming,
31}
32
33impl Default for DistanceMetric {
34    fn default() -> Self {
35        DistanceMetric::Euclidean
36    }
37}
38
39impl std::fmt::Display for DistanceMetric {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            DistanceMetric::Euclidean => write!(f, "euclidean"),
43            DistanceMetric::Manhattan => write!(f, "manhattan"),
44            DistanceMetric::Cosine => write!(f, "cosine"),
45            DistanceMetric::Minkowski(p) => write!(f, "minkowski(p={})", p),
46            DistanceMetric::SquaredEuclidean => write!(f, "squared_euclidean"),
47            DistanceMetric::Chebyshev => write!(f, "chebyshev"),
48            DistanceMetric::Hamming => write!(f, "hamming"),
49        }
50    }
51}
52
53/// Enhanced GPU distance matrix for fast nearest neighbor computations
54#[derive(Debug)]
55pub struct GpuDistanceMatrix<F: Float> {
56    /// GPU context
57    context: GpuContext,
58    /// Distance metric
59    metric: DistanceMetric,
60    /// Pre-loaded GPU data
61    gpu_data: Option<GpuArray<F>>,
62    /// Tile size for blocked computations
63    tile_size: usize,
64    /// Whether to use shared memory optimization
65    use_shared_memory: bool,
66    /// Memory manager
67    memory_manager: GpuMemoryManager,
68}
69
70/// GPU array abstraction
71#[derive(Debug)]
72pub struct GpuArray<F: Float> {
73    /// Device pointer
74    device_ptr: usize,
75    /// Array shape (rows, cols)
76    shape: [usize; 2],
77    /// Data type size in bytes
78    element_size: usize,
79    /// Whether data is currently on device
80    on_device: bool,
81    _phantom: std::marker::PhantomData<F>,
82}
83
84impl<F: Float + FromPrimitive + Send + Sync> GpuDistanceMatrix<F> {
85    /// Create new GPU distance matrix
86    pub fn new(
87        gpu_config: GpuConfig,
88        metric: DistanceMetric,
89        tile_size: Option<usize>,
90    ) -> Result<Self> {
91        let device = Self::detect_gpu_device(&gpu_config)?;
92        let context = GpuContext::new(device, gpu_config)?;
93
94        let optimal_tile_size =
95            tile_size.unwrap_or_else(|| Self::calculate_optimal_tile_size(&context));
96
97        let memory_manager = GpuMemoryManager::new(256, 100);
98
99        Ok(Self {
100            context,
101            metric,
102            gpu_data: None,
103            tile_size: optimal_tile_size,
104            use_shared_memory: true,
105            memory_manager,
106        })
107    }
108
109    /// Preload data to GPU for repeated distance computations
110    pub fn preload_data(&mut self, data: ArrayView2<F>) -> Result<()> {
111        let shape = [data.nrows(), data.ncols()];
112        let mut gpu_data = GpuArray::allocate(shape)?;
113        gpu_data.copy_from_host(data)?;
114        self.gpu_data = Some(gpu_data);
115        Ok(())
116    }
117
118    /// Compute full distance matrix
119    pub fn compute_distance_matrix(&mut self, data: ArrayView2<F>) -> Result<Array2<F>> {
120        let n_samples = data.nrows();
121        let mut result = Array2::zeros((n_samples, n_samples));
122
123        if !self.context.is_gpu_accelerated() {
124            // CPU fallback
125            return self.compute_distance_matrix_cpu(data);
126        }
127
128        // Use preloaded data if available
129        if self.gpu_data.is_none() {
130            self.preload_data(data)?;
131        }
132
133        // GPU computation with tiling
134        for i in (0..n_samples).step_by(self.tile_size) {
135            for j in (0..n_samples).step_by(self.tile_size) {
136                let i_end = (i + self.tile_size).min(n_samples);
137                let j_end = (j + self.tile_size).min(n_samples);
138
139                let tile_result = self.compute_distance_tile(i, i_end, j, j_end)?;
140
141                // Copy results back to host
142                for (ii, row) in tile_result.rows().into_iter().enumerate() {
143                    for (jj, &val) in row.iter().enumerate() {
144                        if i + ii < n_samples && j + jj < n_samples {
145                            result[[i + ii, j + jj]] = val;
146                        }
147                    }
148                }
149            }
150        }
151
152        Ok(result)
153    }
154
155    /// Compute distances from points to centroids
156    pub fn compute_distances_to_centroids(
157        &mut self,
158        data: ArrayView2<F>,
159        centroids: ArrayView2<F>,
160    ) -> Result<Array2<F>> {
161        let n_samples = data.nrows();
162        let n_centroids = centroids.nrows();
163        let mut result = Array2::zeros((n_samples, n_centroids));
164
165        if !self.context.is_gpu_accelerated() {
166            return self.compute_distances_to_centroids_cpu(data, centroids);
167        }
168
169        // GPU implementation
170        for i in (0..n_samples).step_by(self.tile_size) {
171            let i_end = (i + self.tile_size).min(n_samples);
172
173            for j in (0..n_centroids).step_by(self.tile_size) {
174                let j_end = (j + self.tile_size).min(n_centroids);
175
176                let tile_result =
177                    self.compute_centroid_distance_tile(data, centroids, i, i_end, j, j_end)?;
178
179                // Copy results
180                for (ii, row) in tile_result.rows().into_iter().enumerate() {
181                    for (jj, &val) in row.iter().enumerate() {
182                        if i + ii < n_samples && j + jj < n_centroids {
183                            result[[i + ii, j + jj]] = val;
184                        }
185                    }
186                }
187            }
188        }
189
190        Ok(result)
191    }
192
193    /// Find k nearest neighbors
194    pub fn find_k_nearest(
195        &mut self,
196        query: ArrayView1<F>,
197        data: ArrayView2<F>,
198        k: usize,
199    ) -> Result<(Vec<usize>, Vec<F>)> {
200        if k == 0 || k > data.nrows() {
201            return Err(ClusteringError::InvalidInput(
202                "Invalid k value for k-nearest neighbors".to_string(),
203            ));
204        }
205
206        let distances = self.compute_point_distances(query, data)?;
207
208        // Sort and get top k
209        let mut indexed_distances: Vec<(usize, F)> =
210            distances.iter().enumerate().map(|(i, &d)| (i, d)).collect();
211
212        indexed_distances
213            .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
214
215        let indices = indexed_distances.iter().take(k).map(|(i, _)| *i).collect();
216        let distances = indexed_distances.iter().take(k).map(|(_, d)| *d).collect();
217
218        Ok((indices, distances))
219    }
220
221    /// Compute distances from a single point to all data points
222    fn compute_point_distances(
223        &mut self,
224        query: ArrayView1<F>,
225        data: ArrayView2<F>,
226    ) -> Result<Vec<F>> {
227        let n_samples = data.nrows();
228        let mut distances = vec![F::zero(); n_samples];
229
230        for (i, data_point) in data.rows().into_iter().enumerate() {
231            distances[i] = self.compute_single_distance(query, data_point)?;
232        }
233
234        Ok(distances)
235    }
236
237    /// Compute distance between two points
238    fn compute_single_distance(&self, point1: ArrayView1<F>, point2: ArrayView1<F>) -> Result<F> {
239        if point1.len() != point2.len() {
240            return Err(ClusteringError::InvalidInput(
241                "Points must have same dimensionality".to_string(),
242            ));
243        }
244
245        let distance = match self.metric {
246            DistanceMetric::Euclidean => {
247                let sum_sq: F = point1
248                    .iter()
249                    .zip(point2.iter())
250                    .map(|(&a, &b)| (a - b) * (a - b))
251                    .fold(F::zero(), |acc, x| acc + x);
252                sum_sq.sqrt()
253            }
254            DistanceMetric::SquaredEuclidean => point1
255                .iter()
256                .zip(point2.iter())
257                .map(|(&a, &b)| (a - b) * (a - b))
258                .fold(F::zero(), |acc, x| acc + x),
259            DistanceMetric::Manhattan => point1
260                .iter()
261                .zip(point2.iter())
262                .map(|(&a, &b)| (a - b).abs())
263                .fold(F::zero(), |acc, x| acc + x),
264            DistanceMetric::Cosine => {
265                let dot_product = point1
266                    .iter()
267                    .zip(point2.iter())
268                    .map(|(&a, &b)| a * b)
269                    .fold(F::zero(), |acc, x| acc + x);
270
271                let norm1 = point1
272                    .iter()
273                    .map(|&x| x * x)
274                    .fold(F::zero(), |acc, x| acc + x)
275                    .sqrt();
276
277                let norm2 = point2
278                    .iter()
279                    .map(|&x| x * x)
280                    .fold(F::zero(), |acc, x| acc + x)
281                    .sqrt();
282
283                if norm1 == F::zero() || norm2 == F::zero() {
284                    F::one()
285                } else {
286                    F::one() - (dot_product / (norm1 * norm2))
287                }
288            }
289            DistanceMetric::Chebyshev => point1
290                .iter()
291                .zip(point2.iter())
292                .map(|(&a, &b)| (a - b).abs())
293                .fold(F::zero(), |acc, x| if x > acc { x } else { acc }),
294            DistanceMetric::Minkowski(p) => {
295                let p_f = F::from(p).unwrap_or(F::one());
296                let sum: F = point1
297                    .iter()
298                    .zip(point2.iter())
299                    .map(|(&a, &b)| (a - b).abs().powf(p_f))
300                    .fold(F::zero(), |acc, x| acc + x);
301                sum.powf(F::one() / p_f)
302            }
303            DistanceMetric::Hamming => {
304                // For continuous data, use threshold-based Hamming
305                let threshold = F::from(0.5).unwrap_or(F::zero());
306                let count = point1
307                    .iter()
308                    .zip(point2.iter())
309                    .filter(|(&a, &b)| (a - b).abs() > threshold)
310                    .count();
311                F::from(count).unwrap_or(F::zero())
312            }
313        };
314
315        Ok(distance)
316    }
317
318    /// CPU fallback for distance matrix computation
319    pub fn compute_distance_matrix_cpu(&self, data: ArrayView2<F>) -> Result<Array2<F>> {
320        let n_samples = data.nrows();
321        let mut result = Array2::zeros((n_samples, n_samples));
322
323        for i in 0..n_samples {
324            for j in i..n_samples {
325                let distance = self.compute_single_distance(data.row(i), data.row(j))?;
326                result[[i, j]] = distance;
327                result[[j, i]] = distance;
328            }
329        }
330
331        Ok(result)
332    }
333
334    /// CPU fallback for centroid distances
335    fn compute_distances_to_centroids_cpu(
336        &self,
337        data: ArrayView2<F>,
338        centroids: ArrayView2<F>,
339    ) -> Result<Array2<F>> {
340        let n_samples = data.nrows();
341        let n_centroids = centroids.nrows();
342        let mut result = Array2::zeros((n_samples, n_centroids));
343
344        for i in 0..n_samples {
345            for j in 0..n_centroids {
346                let distance = self.compute_single_distance(data.row(i), centroids.row(j))?;
347                result[[i, j]] = distance;
348            }
349        }
350
351        Ok(result)
352    }
353
354    /// Stub implementations for GPU computations
355    fn compute_distance_tile(
356        &self,
357        _i_start: usize,
358        _i_end: usize,
359        _j_start: usize,
360        _j_end: usize,
361    ) -> Result<Array2<F>> {
362        // This would contain the actual GPU kernel launch
363        // For now, return empty array as stub
364        Ok(Array2::zeros((1, 1)))
365    }
366
367    fn compute_centroid_distance_tile(
368        &self,
369        _data: ArrayView2<F>,
370        _centroids: ArrayView2<F>,
371        _i_start: usize,
372        _i_end: usize,
373        _j_start: usize,
374        _j_end: usize,
375    ) -> Result<Array2<F>> {
376        // This would contain the actual GPU kernel launch
377        // For now, return empty array as stub
378        Ok(Array2::zeros((1, 1)))
379    }
380
381    /// Detect available GPU device
382    fn detect_gpu_device(config: &GpuConfig) -> Result<super::core::GpuDevice> {
383        // Stub implementation - would detect actual GPU devices
384        Ok(super::core::GpuDevice::new(
385            0,
386            "Stub GPU".to_string(),
387            8_000_000_000,
388            6_000_000_000,
389            "1.0".to_string(),
390            1024,
391            config.preferred_backend,
392            true,
393        ))
394    }
395
396    /// Calculate optimal tile size based on GPU capabilities
397    fn calculate_optimal_tile_size(context: &GpuContext) -> usize {
398        // Calculate based on available memory and compute units
399        let (total_memory, available_memory) = context.memory_info();
400        let compute_units = context.device.compute_units as usize;
401
402        // Simple heuristic: balance memory usage and parallelism
403        let memory_based = (available_memory / (8 * std::mem::size_of::<F>())).min(1024);
404        let compute_based = (compute_units * 32).min(512);
405
406        memory_based.min(compute_based).max(32)
407    }
408}
409
410impl<F: Float> GpuArray<F> {
411    /// Allocate GPU array
412    pub fn allocate(shape: [usize; 2]) -> Result<Self> {
413        let element_size = std::mem::size_of::<F>();
414        let total_size = shape[0] * shape[1] * element_size;
415
416        // Stub allocation - would allocate actual GPU memory
417        let device_ptr = 0x2000_0000; // Fake pointer
418
419        Ok(Self {
420            device_ptr,
421            shape,
422            element_size,
423            on_device: true,
424            _phantom: std::marker::PhantomData,
425        })
426    }
427
428    /// Copy data from host to device
429    pub fn copy_from_host(&mut self, _data: ArrayView2<F>) -> Result<()> {
430        // Stub implementation - would perform actual memory transfer
431        self.on_device = true;
432        Ok(())
433    }
434
435    /// Copy data from device to host
436    pub fn copy_to_host(&self) -> Result<Array2<F>> {
437        // Stub implementation - would perform actual memory transfer
438        Ok(Array2::zeros((self.shape[0], self.shape[1])))
439    }
440
441    /// Get array shape
442    pub fn shape(&self) -> [usize; 2] {
443        self.shape
444    }
445
446    /// Check if data is on device
447    pub fn is_on_device(&self) -> bool {
448        self.on_device
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds a matrix for `metric` with the default configuration.
    fn matrix_with(metric: DistanceMetric) -> GpuDistanceMatrix<f64> {
        GpuDistanceMatrix::<f64>::new(GpuConfig::default(), metric, None)
            .expect("default configuration should construct a distance matrix")
    }

    #[test]
    fn test_distance_metrics() {
        let point1 = scirs2_core::ndarray::arr1(&[1.0, 2.0, 3.0]);
        let point2 = scirs2_core::ndarray::arr1(&[4.0, 5.0, 6.0]);

        // Every coordinate differs by exactly 3.
        let expectations = [
            (DistanceMetric::Euclidean, 5.196152422706632),
            (DistanceMetric::SquaredEuclidean, 27.0),
            (DistanceMetric::Manhattan, 9.0),
            (DistanceMetric::Chebyshev, 3.0),
            (DistanceMetric::Minkowski(3.0), 81.0_f64.powf(1.0 / 3.0)),
            (DistanceMetric::Hamming, 3.0),
        ];

        for (metric, expected) in expectations {
            let distance = matrix_with(metric)
                .compute_single_distance(point1.view(), point2.view())
                .expect("points of equal length should yield a distance");
            assert!(
                (distance - expected).abs() < 1e-10,
                "{}: expected {}, got {}",
                metric,
                expected,
                distance
            );
        }
    }

    #[test]
    fn test_mismatched_dimensions_are_rejected() {
        let point1 = scirs2_core::ndarray::arr1(&[1.0, 2.0, 3.0]);
        let point2 = scirs2_core::ndarray::arr1(&[1.0, 2.0]);

        let matrix = matrix_with(DistanceMetric::Euclidean);
        assert!(matrix
            .compute_single_distance(point1.view(), point2.view())
            .is_err());
    }

    #[test]
    fn test_gpu_array_allocation() {
        let array = GpuArray::<f32>::allocate([100, 50]).expect("allocation should succeed");
        assert_eq!(array.shape(), [100, 50]);
        assert!(array.is_on_device());
    }

    #[test]
    fn test_distance_matrix_cpu_fallback() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("six values should fill a 3x2 array");

        let matrix = matrix_with(DistanceMetric::Euclidean);

        let result = matrix
            .compute_distance_matrix_cpu(data.view())
            .expect("CPU distance matrix should succeed");
        assert_eq!(result.shape(), &[3, 3]);

        // The diagonal is zero and the matrix is symmetric.
        for i in 0..3 {
            assert!((result[[i, i]] - 0.0).abs() < 1e-10);
            for j in 0..3 {
                assert!((result[[i, j]] - result[[j, i]]).abs() < 1e-10);
            }
        }

        // Consecutive rows differ by (2, 2); the outer rows differ by (4, 4).
        assert!((result[[0, 1]] - 8.0_f64.sqrt()).abs() < 1e-10);
        assert!((result[[1, 2]] - 8.0_f64.sqrt()).abs() < 1e-10);
        assert!((result[[0, 2]] - 32.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_k_nearest_neighbors() {
        let query = scirs2_core::ndarray::arr1(&[1.0, 1.0]);
        let data = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 2.0, 2.0, 3.0, 3.0, 1.0, 1.0])
            .expect("eight values should fill a 4x2 array");

        let mut matrix = matrix_with(DistanceMetric::Euclidean);

        let (indices, distances) = matrix
            .find_k_nearest(query.view(), data.view(), 2)
            .expect("k within range should succeed");
        assert_eq!(indices.len(), 2);
        assert_eq!(distances.len(), 2);
        assert_eq!(indices[0], 3); // Exact match should be first
        assert!(distances[0].abs() < 1e-10);

        // Rows 0 and 1 are equally far away; the lower index wins the tie.
        assert_eq!(indices[1], 0);
        assert!((distances[1] - 2.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_k_nearest_rejects_invalid_k() {
        let query = scirs2_core::ndarray::arr1(&[1.0, 1.0]);
        let data = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 2.0, 2.0])
            .expect("four values should fill a 2x2 array");

        let mut matrix = matrix_with(DistanceMetric::Euclidean);

        assert!(matrix.find_k_nearest(query.view(), data.view(), 0).is_err());
        assert!(matrix.find_k_nearest(query.view(), data.view(), 3).is_err());
    }
}