1use crate::error::{ClusteringError, Result};
7use serde::{Deserialize, Serialize};
8
9use super::core::GpuBackend;
10
/// Distance metric that a GPU distance kernel can be requested for.
///
/// Not every variant has a matching kernel in the sources generated by this
/// file; see the per-variant notes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceKernelType {
    /// Square root of the summed squared per-feature differences.
    Euclidean,
    /// Summed squared per-feature differences (no square root). This is the
    /// value returned by `Default`.
    SquaredEuclidean,
    /// Summed absolute per-feature differences.
    Manhattan,
    /// One minus the cosine similarity of the two feature vectors.
    Cosine,
    /// Chebyshev distance.
    /// NOTE(review): no kernel for this metric is visible in this file.
    Chebyshev,
    /// Minkowski distance.
    /// NOTE(review): no kernel for this metric is visible in this file.
    Minkowski,
    /// Hamming distance.
    /// NOTE(review): no kernel for this metric is visible in this file.
    Hamming,
}
33
34impl Default for DistanceKernelType {
35 fn default() -> Self {
36 DistanceKernelType::SquaredEuclidean
37 }
38}
39
/// Returns the CUDA C source of the pairwise distance-matrix kernels.
///
/// The source defines four `extern "C"` kernels. Each reads a row-major
/// `n_samples x n_features` `float` matrix (`data`) and writes a row-major
/// `n_samples x n_samples` `float` matrix (`distances`):
///
/// * `squared_euclidean_distance_matrix` - shared-memory tiled variant. It
///   takes an extra `tile_size` argument, expects `tile_size x tile_size`
///   thread blocks and `2 * tile_size * n_features * sizeof(float)` bytes of
///   dynamic shared memory, and writes every `(row, col)` entry itself.
/// * `euclidean_distance_matrix`, `cosine_distance_matrix`,
///   `manhattan_distance_matrix` - compute only entries with `row <= col`
///   and mirror each result into the lower triangle.
///
/// In the tiled kernel every thread of a block takes part in staging the
/// tiles and reaches the barrier before any out-of-range thread exits, so
/// both tiles are completely filled even in partially occupied edge blocks.
pub fn generate_cuda_distance_matrix_kernel() -> String {
    r#"
extern "C" __global__ void squared_euclidean_distance_matrix(
    const float* __restrict__ data,
    float* __restrict__ distances,
    const int n_samples,
    const int n_features,
    const int tile_size
) {
    // Shared memory for tiling
    extern __shared__ float shared_mem[];
    float* tile_a = shared_mem;
    float* tile_b = shared_mem + tile_size * n_features;

    const int row = blockIdx.y * tile_size + threadIdx.y;
    const int col = blockIdx.x * tile_size + threadIdx.x;

    // Load tiles into shared memory. No thread leaves before the barrier:
    // lanes along x stride over the features of row sample threadIdx.y and
    // lanes along y stride over the features of column sample threadIdx.x,
    // so every feature of every in-range sample is staged exactly once.
    if (threadIdx.y < tile_size && row < n_samples) {
        for (int f = threadIdx.x; f < n_features; f += blockDim.x) {
            tile_a[threadIdx.y * n_features + f] = data[row * n_features + f];
        }
    }
    if (threadIdx.x < tile_size && col < n_samples) {
        for (int f = threadIdx.y; f < n_features; f += blockDim.y) {
            tile_b[threadIdx.x * n_features + f] = data[col * n_features + f];
        }
    }

    __syncthreads();

    // Threads outside the matrix (or outside the tile) only helped loading.
    if (row >= n_samples || col >= n_samples) return;
    if (threadIdx.x >= tile_size || threadIdx.y >= tile_size) return;

    // Compute squared Euclidean distance
    float sum = 0.0f;
    for (int k = 0; k < n_features; k++) {
        float diff = tile_a[threadIdx.y * n_features + k] - tile_b[threadIdx.x * n_features + k];
        sum += diff * diff;
    }

    distances[row * n_samples + col] = sum;
}

extern "C" __global__ void euclidean_distance_matrix(
    const float* __restrict__ data,
    float* __restrict__ distances,
    const int n_samples,
    const int n_features
) {
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row >= n_samples || col >= n_samples) return;
    if (row > col) {
        // Only compute upper triangle, copy to lower
        return;
    }

    float sum = 0.0f;
    for (int k = 0; k < n_features; k++) {
        float diff = data[row * n_features + k] - data[col * n_features + k];
        sum += diff * diff;
    }

    float dist = sqrtf(sum);
    distances[row * n_samples + col] = dist;
    distances[col * n_samples + row] = dist; // Symmetric
}

extern "C" __global__ void cosine_distance_matrix(
    const float* __restrict__ data,
    float* __restrict__ distances,
    const int n_samples,
    const int n_features
) {
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row >= n_samples || col >= n_samples) return;
    if (row > col) return;

    float dot_product = 0.0f;
    float norm_a = 0.0f;
    float norm_b = 0.0f;

    for (int k = 0; k < n_features; k++) {
        float a = data[row * n_features + k];
        float b = data[col * n_features + k];
        dot_product += a * b;
        norm_a += a * a;
        norm_b += b * b;
    }

    norm_a = sqrtf(norm_a);
    norm_b = sqrtf(norm_b);

    float cosine_sim = (norm_a > 0 && norm_b > 0) ? (dot_product / (norm_a * norm_b)) : 0.0f;
    float dist = 1.0f - cosine_sim;

    distances[row * n_samples + col] = dist;
    distances[col * n_samples + row] = dist;
}

extern "C" __global__ void manhattan_distance_matrix(
    const float* __restrict__ data,
    float* __restrict__ distances,
    const int n_samples,
    const int n_features
) {
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row >= n_samples || col >= n_samples) return;
    if (row > col) return;

    float sum = 0.0f;
    for (int k = 0; k < n_features; k++) {
        float diff = data[row * n_features + k] - data[col * n_features + k];
        sum += fabsf(diff);
    }

    distances[row * n_samples + col] = sum;
    distances[col * n_samples + row] = sum;
}
"#
    .to_string()
}
173
/// Returns the CUDA C source of the k-means kernels.
///
/// The source defines two `extern "C"` kernels operating on row-major
/// `float` matrices:
///
/// * `kmeans_assign_labels` - one thread per sample; writes the index of the
///   centroid with the smallest squared Euclidean distance to `labels` and
///   that distance to `distances`.
/// * `kmeans_compute_centroids` - one block per centroid (`blockIdx.x`).
///   Threads stride over the feature columns, so any block size works and
///   `n_features` may exceed it. Each thread sums the samples labelled with
///   the block's centroid, writes the mean (or `0` for an empty cluster) to
///   `new_centroids`, and the member count to `counts`. No thread exits ahead
///   of a block-wide barrier because the kernel does not use one.
pub fn generate_cuda_kmeans_assign_kernel() -> String {
    r#"
extern "C" __global__ void kmeans_assign_labels(
    const float* __restrict__ data,
    const float* __restrict__ centroids,
    int* __restrict__ labels,
    float* __restrict__ distances,
    const int n_samples,
    const int n_centroids,
    const int n_features
) {
    const int sample_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (sample_idx >= n_samples) return;

    float min_dist = 1e38f; // Large number
    int min_label = 0;

    for (int c = 0; c < n_centroids; c++) {
        float dist = 0.0f;
        for (int f = 0; f < n_features; f++) {
            float diff = data[sample_idx * n_features + f] - centroids[c * n_features + f];
            dist += diff * diff;
        }

        if (dist < min_dist) {
            min_dist = dist;
            min_label = c;
        }
    }

    labels[sample_idx] = min_label;
    distances[sample_idx] = min_dist;
}

extern "C" __global__ void kmeans_compute_centroids(
    const float* __restrict__ data,
    const int* __restrict__ labels,
    float* __restrict__ new_centroids,
    int* __restrict__ counts,
    const int n_samples,
    const int n_centroids,
    const int n_features
) {
    // One block per centroid; this check is uniform across the block.
    const int centroid_idx = blockIdx.x;
    if (centroid_idx >= n_centroids) return;

    // Each thread owns whole feature columns (strided by blockDim.x), so no
    // atomics or block-wide barrier are needed and n_features may be larger
    // than the block.
    for (int f = threadIdx.x; f < n_features; f += blockDim.x) {
        float sum = 0.0f;
        int count = 0;

        // Sum points in this cluster
        for (int i = 0; i < n_samples; i++) {
            if (labels[i] == centroid_idx) {
                sum += data[i * n_features + f];
                count++;
            }
        }

        // Normalize; an empty cluster keeps a zero centroid.
        new_centroids[centroid_idx * n_features + f] = (count > 0) ? (sum / (float)count) : 0.0f;

        if (f == 0) {
            counts[centroid_idx] = count;
        }
    }
}
"#
    .to_string()
}
254
/// Returns the CUDA C source of the batched point-to-centroid distance kernels.
///
/// The source defines two `extern "C"` kernels that write the squared
/// Euclidean distance between point `p` and centroid `c` to
/// `distances[p * n_centroids + c]`:
///
/// * `batch_squared_euclidean` - `float` inputs.
/// * `batch_squared_euclidean_tc` - `half` inputs converted to `float` with
///   `__half2float`; per its own comment it is a simplified stand-in that
///   does not use WMMA intrinsics.
///
/// The source starts with `#include <cuda_fp16.h>`, which declares `half` and
/// `__half2float`; without it the half-precision kernel does not compile.
/// NOTE(review): when compiling at runtime the CUDA include directory has to
/// be on the compiler's include path - confirm against the build/launch code.
pub fn generate_cuda_batch_distance_kernel() -> String {
    r#"
#include <cuda_fp16.h>

extern "C" __global__ void batch_squared_euclidean(
    const float* __restrict__ points,
    const float* __restrict__ centroids,
    float* __restrict__ distances,
    const int n_points,
    const int n_centroids,
    const int n_features
) {
    const int point_idx = blockIdx.y * blockDim.y + threadIdx.y;
    const int centroid_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (point_idx >= n_points || centroid_idx >= n_centroids) return;

    float sum = 0.0f;
    for (int k = 0; k < n_features; k++) {
        float diff = points[point_idx * n_features + k] - centroids[centroid_idx * n_features + k];
        sum += diff * diff;
    }

    distances[point_idx * n_centroids + centroid_idx] = sum;
}

// Tensor core accelerated version (for supported GPUs)
extern "C" __global__ void batch_squared_euclidean_tc(
    const half* __restrict__ points,
    const half* __restrict__ centroids,
    float* __restrict__ distances,
    const int n_points,
    const int n_centroids,
    const int n_features
) {
    // Simplified version - actual implementation would use WMMA intrinsics
    const int point_idx = blockIdx.y * blockDim.y + threadIdx.y;
    const int centroid_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (point_idx >= n_points || centroid_idx >= n_centroids) return;

    float sum = 0.0f;
    for (int k = 0; k < n_features; k++) {
        float p = __half2float(points[point_idx * n_features + k]);
        float c = __half2float(centroids[centroid_idx * n_features + k]);
        float diff = p - c;
        sum += diff * diff;
    }

    distances[point_idx * n_centroids + centroid_idx] = sum;
}
"#
    .to_string()
}
308
/// Returns the OpenCL C source shared by the distance-matrix and k-means paths.
///
/// Three `__kernel` entry points are defined:
/// `squared_euclidean_distance_matrix`, `euclidean_distance_matrix` (both
/// compute entries with `row <= col` and mirror them) and
/// `kmeans_assign_labels` (one work-item per sample).
pub fn generate_opencl_distance_matrix_kernel() -> String {
    const OPENCL_SOURCE: &str = r#"
__kernel void squared_euclidean_distance_matrix(
    __global const float* data,
    __global float* distances,
    const int n_samples,
    const int n_features
) {
    const int row = get_global_id(0);
    const int col = get_global_id(1);

    if (row >= n_samples || col >= n_samples) return;
    if (row > col) return; // Compute upper triangle only

    float sum = 0.0f;
    for (int k = 0; k < n_features; k++) {
        float diff = data[row * n_features + k] - data[col * n_features + k];
        sum += diff * diff;
    }

    distances[row * n_samples + col] = sum;
    distances[col * n_samples + row] = sum; // Symmetric
}

__kernel void euclidean_distance_matrix(
    __global const float* data,
    __global float* distances,
    const int n_samples,
    const int n_features
) {
    const int row = get_global_id(0);
    const int col = get_global_id(1);

    if (row >= n_samples || col >= n_samples) return;
    if (row > col) return;

    float sum = 0.0f;
    for (int k = 0; k < n_features; k++) {
        float diff = data[row * n_features + k] - data[col * n_features + k];
        sum += diff * diff;
    }

    float dist = sqrt(sum);
    distances[row * n_samples + col] = dist;
    distances[col * n_samples + row] = dist;
}

__kernel void kmeans_assign_labels(
    __global const float* data,
    __global const float* centroids,
    __global int* labels,
    __global float* distances,
    const int n_samples,
    const int n_centroids,
    const int n_features
) {
    const int sample_idx = get_global_id(0);

    if (sample_idx >= n_samples) return;

    float min_dist = 1e38f;
    int min_label = 0;

    for (int c = 0; c < n_centroids; c++) {
        float dist = 0.0f;
        for (int f = 0; f < n_features; f++) {
            float diff = data[sample_idx * n_features + f] - centroids[c * n_features + f];
            dist += diff * diff;
        }

        if (dist < min_dist) {
            min_dist = dist;
            min_label = c;
        }
    }

    labels[sample_idx] = min_label;
    distances[sample_idx] = min_dist;
}
"#;

    OPENCL_SOURCE.to_owned()
}
396
/// Returns the Metal Shading Language source for the distance and k-means kernels.
///
/// Four `kernel` functions are defined:
/// `squared_euclidean_distance_matrix`, `euclidean_distance_matrix` (both
/// compute entries with `row <= col` and mirror them),
/// `kmeans_assign_labels` and `batch_squared_euclidean`. Scalar sizes are
/// passed as `constant uint&` buffer arguments.
pub fn generate_metal_distance_kernel() -> String {
    const METAL_SOURCE: &str = r#"
#include <metal_stdlib>
using namespace metal;

kernel void squared_euclidean_distance_matrix(
    device const float* data [[buffer(0)]],
    device float* distances [[buffer(1)]],
    constant uint& n_samples [[buffer(2)]],
    constant uint& n_features [[buffer(3)]],
    uint2 gid [[thread_position_in_grid]]
) {
    uint row = gid.y;
    uint col = gid.x;

    if (row >= n_samples || col >= n_samples) return;
    if (row > col) return;

    float sum = 0.0f;
    for (uint k = 0; k < n_features; k++) {
        float diff = data[row * n_features + k] - data[col * n_features + k];
        sum += diff * diff;
    }

    distances[row * n_samples + col] = sum;
    distances[col * n_samples + row] = sum;
}

kernel void euclidean_distance_matrix(
    device const float* data [[buffer(0)]],
    device float* distances [[buffer(1)]],
    constant uint& n_samples [[buffer(2)]],
    constant uint& n_features [[buffer(3)]],
    uint2 gid [[thread_position_in_grid]]
) {
    uint row = gid.y;
    uint col = gid.x;

    if (row >= n_samples || col >= n_samples) return;
    if (row > col) return;

    float sum = 0.0f;
    for (uint k = 0; k < n_features; k++) {
        float diff = data[row * n_features + k] - data[col * n_features + k];
        sum += diff * diff;
    }

    float dist = sqrt(sum);
    distances[row * n_samples + col] = dist;
    distances[col * n_samples + row] = dist;
}

kernel void kmeans_assign_labels(
    device const float* data [[buffer(0)]],
    device const float* centroids [[buffer(1)]],
    device int* labels [[buffer(2)]],
    device float* distances [[buffer(3)]],
    constant uint& n_samples [[buffer(4)]],
    constant uint& n_centroids [[buffer(5)]],
    constant uint& n_features [[buffer(6)]],
    uint gid [[thread_position_in_grid]]
) {
    uint sample_idx = gid;

    if (sample_idx >= n_samples) return;

    float min_dist = 1e38f;
    int min_label = 0;

    for (uint c = 0; c < n_centroids; c++) {
        float dist = 0.0f;
        for (uint f = 0; f < n_features; f++) {
            float diff = data[sample_idx * n_features + f] - centroids[c * n_features + f];
            dist += diff * diff;
        }

        if (dist < min_dist) {
            min_dist = dist;
            min_label = (int)c;
        }
    }

    labels[sample_idx] = min_label;
    distances[sample_idx] = min_dist;
}

kernel void batch_squared_euclidean(
    device const float* points [[buffer(0)]],
    device const float* centroids [[buffer(1)]],
    device float* distances [[buffer(2)]],
    constant uint& n_points [[buffer(3)]],
    constant uint& n_centroids [[buffer(4)]],
    constant uint& n_features [[buffer(5)]],
    uint2 gid [[thread_position_in_grid]]
) {
    uint point_idx = gid.y;
    uint centroid_idx = gid.x;

    if (point_idx >= n_points || centroid_idx >= n_centroids) return;

    float sum = 0.0f;
    for (uint k = 0; k < n_features; k++) {
        float diff = points[point_idx * n_features + k] - centroids[centroid_idx * n_features + k];
        sum += diff * diff;
    }

    distances[point_idx * n_centroids + centroid_idx] = sum;
}
"#;

    METAL_SOURCE.to_owned()
}
513
/// Returns the HIP source used for the ROCm backend.
///
/// After including `<hip/hip_runtime.h>` the source defines two
/// `extern "C"` kernels: `squared_euclidean_distance_matrix` (computes
/// entries with `row <= col` and mirrors them) and `kmeans_assign_labels`
/// (one thread per sample).
pub fn generate_rocm_distance_kernel() -> String {
    const HIP_SOURCE: &str = r#"
#include <hip/hip_runtime.h>

extern "C" __global__ void squared_euclidean_distance_matrix(
    const float* __restrict__ data,
    float* __restrict__ distances,
    const int n_samples,
    const int n_features
) {
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row >= n_samples || col >= n_samples) return;
    if (row > col) return;

    float sum = 0.0f;
    for (int k = 0; k < n_features; k++) {
        float diff = data[row * n_features + k] - data[col * n_features + k];
        sum += diff * diff;
    }

    distances[row * n_samples + col] = sum;
    distances[col * n_samples + row] = sum;
}

extern "C" __global__ void kmeans_assign_labels(
    const float* __restrict__ data,
    const float* __restrict__ centroids,
    int* __restrict__ labels,
    float* __restrict__ distances,
    const int n_samples,
    const int n_centroids,
    const int n_features
) {
    const int sample_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (sample_idx >= n_samples) return;

    float min_dist = 1e38f;
    int min_label = 0;

    for (int c = 0; c < n_centroids; c++) {
        float dist = 0.0f;
        for (int f = 0; f < n_features; f++) {
            float diff = data[sample_idx * n_features + f] - centroids[c * n_features + f];
            dist += diff * diff;
        }

        if (dist < min_dist) {
            min_dist = dist;
            min_label = c;
        }
    }

    labels[sample_idx] = min_label;
    distances[sample_idx] = min_dist;
}
"#;

    HIP_SOURCE.to_owned()
}
580
/// Launch parameters for a GPU kernel.
///
/// `calculate_kernel_config` builds one from a problem size and backend;
/// `Default` gives a single 16x16 block with no shared memory.
#[derive(Debug, Clone)]
pub struct KernelConfig {
    /// Block extent as `(x, y, z)`; `calculate_kernel_config` divides the
    /// sample count by the x and y extents to obtain `grid_size`.
    pub block_size: (usize, usize, usize),
    /// Number of blocks as `(x, y, z)`.
    pub grid_size: (usize, usize, usize),
    /// Shared-memory size in bytes (`calculate_kernel_config` computes it
    /// from `size_of::<f32>()`).
    pub shared_mem_size: usize,
    /// Tensor-core flag; `calculate_kernel_config` sets it for the CUDA and
    /// ROCm backends.
    pub use_tensor_cores: bool,
    /// Element type the kernel operates on.
    pub data_type: KernelDataType,
}
599
600impl Default for KernelConfig {
601 fn default() -> Self {
602 Self {
603 block_size: (16, 16, 1),
604 grid_size: (1, 1, 1),
605 shared_mem_size: 0,
606 use_tensor_cores: false,
607 data_type: KernelDataType::Float32,
608 }
609 }
610}
611
/// Element type a kernel is configured to operate on.
///
/// Both `KernelConfig::default()` and `calculate_kernel_config` in this file
/// select `Float32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum KernelDataType {
    /// 16-bit floating point.
    Float16,
    /// 32-bit floating point.
    Float32,
    /// 64-bit floating point.
    Float64,
    /// 8-bit integer.
    Int8,
}
624
625pub fn get_kernel_source(backend: GpuBackend, kernel_type: DistanceKernelType) -> Result<String> {
627 match backend {
628 GpuBackend::Cuda => Ok(generate_cuda_distance_matrix_kernel()),
629 GpuBackend::OpenCl => Ok(generate_opencl_distance_matrix_kernel()),
630 GpuBackend::Metal => Ok(generate_metal_distance_kernel()),
631 GpuBackend::Rocm => Ok(generate_rocm_distance_kernel()),
632 GpuBackend::CpuFallback => Err(ClusteringError::InvalidInput(
633 "CPU fallback does not use GPU kernels".to_string(),
634 )),
635 _ => Err(ClusteringError::InvalidInput(format!(
636 "Backend {:?} not supported for kernel generation",
637 backend
638 ))),
639 }
640}
641
642pub fn get_kmeans_kernel_source(backend: GpuBackend) -> Result<String> {
644 match backend {
645 GpuBackend::Cuda => Ok(generate_cuda_kmeans_assign_kernel()),
646 GpuBackend::OpenCl => Ok(generate_opencl_distance_matrix_kernel()), GpuBackend::Metal => Ok(generate_metal_distance_kernel()),
648 GpuBackend::Rocm => Ok(generate_rocm_distance_kernel()),
649 GpuBackend::CpuFallback => Err(ClusteringError::InvalidInput(
650 "CPU fallback does not use GPU kernels".to_string(),
651 )),
652 _ => Err(ClusteringError::InvalidInput(format!(
653 "Backend {:?} not supported for kernel generation",
654 backend
655 ))),
656 }
657}
658
659pub fn calculate_kernel_config(
661 n_samples: usize,
662 n_features: usize,
663 backend: GpuBackend,
664) -> KernelConfig {
665 let block_size = match backend {
666 GpuBackend::Cuda | GpuBackend::Rocm => {
667 if n_samples <= 256 {
669 (16, 16, 1)
670 } else {
671 (32, 32, 1)
672 }
673 }
674 GpuBackend::Metal => {
675 (16, 16, 1)
677 }
678 GpuBackend::OpenCl => {
679 (16, 16, 1)
681 }
682 _ => (16, 16, 1),
683 };
684
685 let grid_size = (
686 (n_samples + block_size.0 - 1) / block_size.0,
687 (n_samples + block_size.1 - 1) / block_size.1,
688 1,
689 );
690
691 let tile_size = block_size.0;
693 let shared_mem_size = 2 * tile_size * n_features * std::mem::size_of::<f32>();
694
695 KernelConfig {
696 block_size,
697 grid_size,
698 shared_mem_size,
699 use_tensor_cores: matches!(backend, GpuBackend::Cuda | GpuBackend::Rocm),
700 data_type: KernelDataType::Float32,
701 }
702}
703
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `source` contains every string in `needles`.
    fn assert_mentions(source: &str, needles: &[&str]) {
        for needle in needles {
            assert!(source.contains(*needle));
        }
    }

    /// The CUDA distance source names all expected kernels.
    #[test]
    fn test_cuda_kernel_generation() {
        assert_mentions(
            &generate_cuda_distance_matrix_kernel(),
            &[
                "squared_euclidean_distance_matrix",
                "euclidean_distance_matrix",
                "cosine_distance_matrix",
            ],
        );
    }

    /// The CUDA k-means source names both k-means kernels.
    #[test]
    fn test_cuda_kmeans_kernel_generation() {
        assert_mentions(
            &generate_cuda_kmeans_assign_kernel(),
            &["kmeans_assign_labels", "kmeans_compute_centroids"],
        );
    }

    /// The OpenCL source uses the `__kernel` qualifier.
    #[test]
    fn test_opencl_kernel_generation() {
        assert_mentions(
            &generate_opencl_distance_matrix_kernel(),
            &["__kernel", "squared_euclidean_distance_matrix"],
        );
    }

    /// The Metal source opens the `metal` namespace.
    #[test]
    fn test_metal_kernel_generation() {
        assert_mentions(
            &generate_metal_distance_kernel(),
            &["using namespace metal", "squared_euclidean_distance_matrix"],
        );
    }

    /// The ROCm source includes the HIP runtime header.
    #[test]
    fn test_rocm_kernel_generation() {
        assert_mentions(
            &generate_rocm_distance_kernel(),
            &["hip/hip_runtime.h", "squared_euclidean_distance_matrix"],
        );
    }

    /// A GPU backend yields source; the CPU fallback is rejected.
    #[test]
    fn test_get_kernel_source() {
        let metric = DistanceKernelType::Euclidean;

        assert!(get_kernel_source(GpuBackend::Cuda, metric).is_ok());
        assert!(get_kernel_source(GpuBackend::CpuFallback, metric).is_err());
    }

    /// A CUDA configuration has non-empty extents and tensor cores enabled.
    #[test]
    fn test_kernel_config_calculation() {
        let KernelConfig {
            block_size,
            grid_size,
            use_tensor_cores,
            ..
        } = calculate_kernel_config(1000, 50, GpuBackend::Cuda);

        assert!(block_size.0 > 0);
        assert!(grid_size.0 > 0);
        assert!(use_tensor_cores);
    }

    /// The default configuration is a 16x16 block of `Float32` data.
    #[test]
    fn test_kernel_config_default() {
        let defaults = KernelConfig::default();

        assert_eq!(defaults.block_size, (16, 16, 1));
        assert_eq!(defaults.data_type, KernelDataType::Float32);
    }
}