use crate::error::SpatialResult;
use crate::memory_pool::DistancePool;
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use std::path::Path;
use std::process::Command;
use std::sync::Arc;

/// Device names plus `(total, available)` memory in bytes per detected device.
type GpuDeviceInfoResult = Result<(Vec<String>, Vec<(usize, usize)>), Box<dyn std::error::Error>>;

#[derive(Debug, Clone)]
pub struct GpuCapabilities {
    /// Whether any GPU device was detected.
    pub gpu_available: bool,
    /// Number of detected GPU devices.
    pub device_count: usize,
    /// Total device memory in bytes.
    pub total_memory: usize,
    /// Currently available device memory in bytes.
    pub available_memory: usize,
    /// CUDA compute capability as `(major, minor)`, if known.
    pub compute_capability: Option<(u32, u32)>,
    /// Maximum threads per block supported by the device.
    pub max_threads_per_block: usize,
    /// Maximum blocks per grid supported by the device.
    pub max_blocks_per_grid: usize,
    /// Human-readable names of the detected devices.
    pub device_names: Vec<String>,
    /// Backends usable on this machine.
    pub supported_backends: Vec<GpuBackend>,
}

impl Default for GpuCapabilities {
    fn default() -> Self {
        Self {
            gpu_available: false,
            device_count: 0,
            total_memory: 0,
            available_memory: 0,
            compute_capability: None,
            max_threads_per_block: 1024,
            max_blocks_per_grid: 65535,
            device_names: Vec::new(),
            supported_backends: Vec::new(),
        }
    }
}

/// GPU compute backends, in rough order of preference.
#[derive(Debug, Clone, PartialEq)]
pub enum GpuBackend {
    /// NVIDIA CUDA.
    Cuda,
    /// AMD ROCm/HIP.
    Rocm,
    /// Intel Level Zero (oneAPI).
    LevelZero,
    /// Cross-vendor Vulkan compute.
    Vulkan,
    /// Optimized CPU (SIMD) fallback.
    CpuFallback,
}

/// Handle to the best available compute device and its capabilities.
pub struct GpuDevice {
    capabilities: GpuCapabilities,
    preferred_backend: GpuBackend,
    #[allow(dead_code)]
    memory_pool: Arc<DistancePool>,
}

impl GpuDevice {
    /// Detect available hardware and pick the best backend.
    pub fn new() -> SpatialResult<Self> {
        let capabilities = Self::detect_capabilities()?;
        let preferred_backend = Self::select_optimal_backend(&capabilities);
        let memory_pool = Arc::new(DistancePool::new(1000));

        Ok(Self {
            capabilities,
            preferred_backend,
            memory_pool,
        })
    }

    fn detect_capabilities() -> SpatialResult<GpuCapabilities> {
        let mut caps = GpuCapabilities::default();

        #[cfg(feature = "cuda")]
        {
            if Self::check_cuda_available() {
                caps.gpu_available = true;
                caps.device_count = Self::get_cuda_device_count();
                caps.supported_backends.push(GpuBackend::Cuda);

                if let Ok((names, memory_info)) = Self::get_cuda_device_info() {
                    caps.device_names = names;
                    if let Some((total, available)) = memory_info.first() {
                        caps.total_memory = *total;
                        caps.available_memory = *available;
                    }
                }

                caps.max_threads_per_block = 1024;
                // Max grid dimension (x) on compute capability 3.0+ devices.
                caps.max_blocks_per_grid = 2_147_483_647;
                caps.compute_capability = Self::get_cuda_compute_capability();
            }
        }

        #[cfg(feature = "rocm")]
        {
            if Self::check_rocm_available() {
                caps.gpu_available = true;
                let rocm_count = Self::get_rocm_device_count();
                if rocm_count > caps.device_count {
                    caps.device_count = rocm_count;
                }
                caps.supported_backends.push(GpuBackend::Rocm);

                if let Ok((names, memory_info)) = Self::get_rocm_device_info() {
                    if caps.device_names.is_empty() {
                        caps.device_names = names;
                    } else {
                        caps.device_names.extend(names);
                    }
                    if let Some((total, available)) = memory_info.first() {
                        // Keep CUDA's numbers if they were already recorded.
                        if caps.total_memory == 0 {
                            caps.total_memory = *total;
                            caps.available_memory = *available;
                        }
                    }
                }

                caps.max_threads_per_block = 1024;
                caps.max_blocks_per_grid = 2_147_483_647;
            }
        }

        #[cfg(feature = "vulkan")]
        {
            if Self::check_vulkan_available() {
                caps.gpu_available = true;
                caps.supported_backends.push(GpuBackend::Vulkan);

                if let Ok((names, memory_info)) = Self::get_vulkan_device_info() {
                    if caps.device_names.is_empty() {
                        caps.device_names = names;
                    } else {
                        caps.device_names.extend(names);
                    }
                    if let Some((total, available)) = memory_info.first() {
                        if caps.total_memory == 0 {
                            caps.total_memory = *total;
                            caps.available_memory = *available;
                        }
                    }
                }
            }
        }

        // The CPU fallback is always available.
        caps.supported_backends.push(GpuBackend::CpuFallback);

        Ok(caps)
    }

    fn select_optimal_backend(caps: &GpuCapabilities) -> GpuBackend {
        // Preference order: CUDA, then ROCm, then Vulkan, then the CPU.
        if caps.supported_backends.contains(&GpuBackend::Cuda) {
            GpuBackend::Cuda
        } else if caps.supported_backends.contains(&GpuBackend::Rocm) {
            GpuBackend::Rocm
        } else if caps.supported_backends.contains(&GpuBackend::Vulkan) {
            GpuBackend::Vulkan
        } else {
            GpuBackend::CpuFallback
        }
    }

    /// Whether a GPU device was detected.
    pub fn is_gpu_available(&self) -> bool {
        self.capabilities.gpu_available
    }

    /// Capabilities of the detected device(s).
    pub fn capabilities(&self) -> &GpuCapabilities {
        &self.capabilities
    }

    /// Pick a block size rounded to the backend's execution width.
    pub fn optimal_block_size(&self, problem_size: usize) -> usize {
        match self.preferred_backend {
            GpuBackend::Cuda => {
                // Round to a multiple of the 32-thread warp.
                let warp_size = 32;
                let optimal = (problem_size / warp_size).max(1) * warp_size;
                optimal.min(self.capabilities.max_threads_per_block)
            }
            GpuBackend::Rocm => {
                // Round to a multiple of the 64-thread wavefront.
                let wavefront_size = 64;
                let optimal = (problem_size / wavefront_size).max(1) * wavefront_size;
                optimal.min(self.capabilities.max_threads_per_block)
            }
            _ => 256.min(self.capabilities.max_threads_per_block),
        }
    }

    #[cfg(feature = "cuda")]
    fn check_cuda_available() -> bool {
        // First ask nvidia-smi; it is the most reliable probe when present.
        if let Ok(output) = Command::new("nvidia-smi")
            .arg("--query-gpu=count")
            .arg("--format=csv,noheader,nounits")
            .output()
        {
            if output.status.success() {
                if let Ok(count_str) = String::from_utf8(output.stdout) {
                    if let Ok(count) = count_str.trim().parse::<u32>() {
                        return count > 0;
                    }
                }
            }
        }

        // Otherwise look for the driver library in the usual locations.
        #[cfg(target_os = "linux")]
        {
            Path::new("/usr/local/cuda/lib64/libcuda.so").exists()
                || Path::new("/usr/lib/x86_64-linux-gnu/libcuda.so").exists()
                || Path::new("/usr/lib64/libcuda.so").exists()
        }

        #[cfg(target_os = "windows")]
        {
            Path::new("C:\\Windows\\System32\\nvcuda.dll").exists()
        }

        #[cfg(not(any(target_os = "linux", target_os = "windows")))]
        false
    }

    #[cfg(feature = "cuda")]
    fn get_cuda_device_count() -> usize {
        if let Ok(output) = Command::new("nvidia-smi")
            .arg("--query-gpu=count")
            .arg("--format=csv,noheader,nounits")
            .output()
        {
            if output.status.success() {
                if let Ok(count_str) = String::from_utf8(output.stdout) {
                    if let Ok(count) = count_str.trim().parse::<usize>() {
                        return count;
                    }
                }
            }
        }

        // Fall back to counting the device lines of `nvidia-smi -L`.
        if let Ok(output) = Command::new("nvidia-smi").arg("-L").output() {
            if output.status.success() {
                if let Ok(list_str) = String::from_utf8(output.stdout) {
                    return list_str
                        .lines()
                        .filter(|line| line.starts_with("GPU "))
                        .count();
                }
            }
        }

        0
    }

    #[cfg(feature = "rocm")]
    fn check_rocm_available() -> bool {
        if let Ok(output) = Command::new("rocm-smi").arg("--showid").output() {
            if output.status.success() {
                return true;
            }
        }

        // Otherwise look for the HIP runtime library.
        #[cfg(target_os = "linux")]
        {
            Path::new("/opt/rocm/lib/libhip.so").exists()
                || Path::new("/usr/lib/libhip.so").exists()
                || Path::new("/usr/lib/x86_64-linux-gnu/libhip.so").exists()
        }

        #[cfg(not(target_os = "linux"))]
        false
    }

    #[cfg(feature = "rocm")]
    fn get_rocm_device_count() -> usize {
        if let Ok(output) = Command::new("rocm-smi").arg("--showid").output() {
            if output.status.success() {
                if let Ok(list_str) = String::from_utf8(output.stdout) {
                    return list_str
                        .lines()
                        .filter(|line| line.contains("GPU") || line.contains("card"))
                        .count();
                }
            }
        }

        // Rough fallback: count DRM card entries (may include non-AMD GPUs).
        #[cfg(target_os = "linux")]
        {
            use std::fs;
            if let Ok(entries) = fs::read_dir("/sys/class/drm") {
                let count = entries
                    .filter_map(Result::ok)
                    .filter(|entry| {
                        if let Ok(name) = entry.file_name().into_string() {
                            name.starts_with("card") && !name.contains('-')
                        } else {
                            false
                        }
                    })
                    .count();
                if count > 0 {
                    return count;
                }
            }
        }

        0
    }

    #[cfg(feature = "vulkan")]
    fn check_vulkan_available() -> bool {
        if let Ok(output) = Command::new("vulkaninfo").arg("--summary").output() {
            if output.status.success() {
                if let Ok(info_str) = String::from_utf8(output.stdout) {
                    return info_str.contains("VK_QUEUE_COMPUTE_BIT")
                        || info_str.contains("deviceType");
                }
            }
        }

        // Otherwise look for the Vulkan loader.
        #[cfg(target_os = "linux")]
        {
            Path::new("/usr/lib/libvulkan.so").exists()
                || Path::new("/usr/lib/x86_64-linux-gnu/libvulkan.so").exists()
                || Path::new("/usr/local/lib/libvulkan.so").exists()
        }

        #[cfg(target_os = "windows")]
        {
            Path::new("C:\\Windows\\System32\\vulkan-1.dll").exists()
        }

        // On macOS, check for MoltenVK, or Metal as evidence that a
        // translation layer could work.
        #[cfg(target_os = "macos")]
        {
            Path::new("/usr/local/lib/libvulkan.dylib").exists()
                || Path::new("/System/Library/Frameworks/Metal.framework/Metal").exists()
        }

        #[cfg(not(any(target_os = "linux", target_os = "windows", target_os = "macos")))]
        false
    }

    #[cfg(feature = "cuda")]
    fn get_cuda_device_info() -> GpuDeviceInfoResult {
        let mut device_names = Vec::new();
        let mut memory_info = Vec::new();

        if let Ok(output) = Command::new("nvidia-smi")
            .arg("--query-gpu=name")
            .arg("--format=csv,noheader,nounits")
            .output()
        {
            if output.status.success() {
                if let Ok(names_str) = String::from_utf8(output.stdout) {
                    device_names = names_str.lines().map(|s| s.trim().to_string()).collect();
                }
            }
        }

        if let Ok(output) = Command::new("nvidia-smi")
            .arg("--query-gpu=memory.total,memory.free")
            .arg("--format=csv,noheader,nounits")
            .output()
        {
            if output.status.success() {
                if let Ok(memory_str) = String::from_utf8(output.stdout) {
                    for line in memory_str.lines() {
                        let parts: Vec<&str> = line.split(',').collect();
                        if parts.len() >= 2 {
                            if let (Ok(total), Ok(free)) = (
                                parts[0].trim().parse::<usize>(),
                                parts[1].trim().parse::<usize>(),
                            ) {
                                // nvidia-smi reports MiB; convert to bytes.
                                memory_info.push((total * 1024 * 1024, free * 1024 * 1024));
                            }
                        }
                    }
                }
            }
        }

        Ok((device_names, memory_info))
    }

    #[cfg(feature = "cuda")]
    fn get_cuda_compute_capability() -> Option<(u32, u32)> {
        if let Ok(output) = Command::new("nvidia-smi")
            .arg("--query-gpu=compute_cap")
            .arg("--format=csv,noheader,nounits")
            .output()
        {
            if output.status.success() {
                if let Ok(cap_str) = String::from_utf8(output.stdout) {
                    if let Some(line) = cap_str.lines().next() {
                        // Parse "major.minor", e.g. "8.6".
                        let parts: Vec<&str> = line.trim().split('.').collect();
                        if parts.len() >= 2 {
                            if let (Ok(major), Ok(minor)) =
                                (parts[0].parse::<u32>(), parts[1].parse::<u32>())
                            {
                                return Some((major, minor));
                            }
                        }
                    }
                }
            }
        }

        None
    }

    #[cfg(feature = "rocm")]
    fn get_rocm_device_info() -> GpuDeviceInfoResult {
        let mut device_names = Vec::new();
        let mut memory_info = Vec::new();

        if let Ok(output) = Command::new("rocm-smi").arg("--showproductname").output() {
            if output.status.success() {
                if let Ok(info_str) = String::from_utf8(output.stdout) {
                    for line in info_str.lines() {
                        if line.contains("Card series:") {
                            if let Some(name) = line.split(':').nth(1) {
                                device_names.push(name.trim().to_string());
                            }
                        }
                    }
                }
            }
        }

        if let Ok(output) = Command::new("rocm-smi")
            .arg("--showmeminfo")
            .arg("vram")
            .output()
        {
            if output.status.success() {
                if let Ok(memory_str) = String::from_utf8(output.stdout) {
                    for line in memory_str.lines() {
                        if line.contains("Total memory") || line.contains("Used memory") {
                            if let Some(mem_part) = line
                                .split_whitespace()
                                .find(|s| s.ends_with("MB") || s.ends_with("GB"))
                            {
                                if let Ok(mem_val) = mem_part
                                    .trim_end_matches("MB")
                                    .trim_end_matches("GB")
                                    .parse::<usize>()
                                {
                                    let bytes = if mem_part.ends_with("GB") {
                                        mem_val * 1024 * 1024 * 1024
                                    } else {
                                        mem_val * 1024 * 1024
                                    };
                                    // Rough estimate: assume half of VRAM is free.
                                    memory_info.push((bytes, bytes / 2));
                                }
                            }
                        }
                    }
                }
            }
        }

        // If rocm-smi told us nothing, report a conservative placeholder.
        if device_names.is_empty() && memory_info.is_empty() {
            device_names.push("AMD GPU (ROCm)".to_string());
            memory_info.push((8 * 1024 * 1024 * 1024, 6 * 1024 * 1024 * 1024));
        }

        Ok((device_names, memory_info))
    }

    #[cfg(feature = "vulkan")]
    fn get_vulkan_device_info() -> GpuDeviceInfoResult {
        let mut device_names = Vec::new();
        let mut memory_info = Vec::new();

        if let Ok(output) = Command::new("vulkaninfo").arg("--summary").output() {
            if output.status.success() {
                if let Ok(info_str) = String::from_utf8(output.stdout) {
                    for line in info_str.lines() {
                        if line.contains("deviceName") {
                            if let Some(name_part) = line.split('=').nth(1) {
                                device_names.push(name_part.trim().to_string());
                            }
                        } else if line.contains("heapSize") {
                            if let Some(mem_part) = line.split('=').nth(1) {
                                if let Ok(mem_val) = mem_part.trim().parse::<usize>() {
                                    // Assume three quarters of the heap is usable.
                                    memory_info.push((mem_val, mem_val * 3 / 4));
                                }
                            }
                        }
                    }
                }
            }
        }

        // Conservative placeholder when vulkaninfo is unavailable.
        if device_names.is_empty() {
            device_names.push("Vulkan Device".to_string());
            memory_info.push((4 * 1024 * 1024 * 1024, 3 * 1024 * 1024 * 1024));
        }

        Ok((device_names, memory_info))
    }

    #[cfg(not(feature = "cuda"))]
    #[allow(dead_code)]
    fn get_cuda_device_info() -> GpuDeviceInfoResult {
        Ok((Vec::new(), Vec::new()))
    }

    #[cfg(not(feature = "cuda"))]
    #[allow(dead_code)]
    fn get_cuda_compute_capability() -> Option<(u32, u32)> {
        None
    }

    #[cfg(not(feature = "rocm"))]
    #[allow(dead_code)]
    fn get_rocm_device_info() -> GpuDeviceInfoResult {
        Ok((Vec::new(), Vec::new()))
    }

    #[cfg(not(feature = "vulkan"))]
    #[allow(dead_code)]
    fn get_vulkan_device_info() -> GpuDeviceInfoResult {
        Ok((Vec::new(), Vec::new()))
    }
}

impl Default for GpuDevice {
    fn default() -> Self {
        // Never panic: fall back to a CPU-only device on detection failure.
        Self::new().unwrap_or_else(|_| Self {
            capabilities: GpuCapabilities::default(),
            preferred_backend: GpuBackend::CpuFallback,
            memory_pool: Arc::new(DistancePool::new(1000)),
        })
    }
}

/// GPU-accelerated pairwise distance matrix computation with CPU fallback.
pub struct GpuDistanceMatrix {
    device: Arc<GpuDevice>,
    batch_size: usize,
    use_mixed_precision: bool,
}

impl GpuDistanceMatrix {
    pub fn new() -> SpatialResult<Self> {
        let device = Arc::new(GpuDevice::new()?);
        Ok(Self {
            device,
            batch_size: 1024,
            use_mixed_precision: true,
        })
    }

    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    pub fn with_mixed_precision(mut self, use_mixed_precision: bool) -> Self {
        self.use_mixed_precision = use_mixed_precision;
        self
    }

    /// Compute the full distance matrix on the preferred backend.
    pub async fn compute_parallel(
        &self,
        points: &ArrayView2<'_, f64>,
    ) -> SpatialResult<Array2<f64>> {
        if !self.device.is_gpu_available() {
            return self.compute_cpu_fallback(points).await;
        }

        match self.device.preferred_backend {
            GpuBackend::Cuda => self.compute_cuda(points).await,
            GpuBackend::Rocm => self.compute_rocm(points).await,
            GpuBackend::Vulkan => self.compute_vulkan(points).await,
            GpuBackend::CpuFallback => self.compute_cpu_fallback(points).await,
            // No Level Zero kernels yet; use the CPU path.
            GpuBackend::LevelZero => self.compute_cpu_fallback(points).await,
        }
    }

    // The backend-specific kernels are not implemented yet; each path
    // currently delegates to the SIMD CPU fallback.
    async fn compute_cuda(&self, points: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f64>> {
        self.compute_cpu_fallback(points).await
    }

    async fn compute_rocm(&self, points: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f64>> {
        self.compute_cpu_fallback(points).await
    }

    async fn compute_vulkan(&self, points: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f64>> {
        self.compute_cpu_fallback(points).await
    }

    async fn compute_cpu_fallback(
        &self,
        points: &ArrayView2<'_, f64>,
    ) -> SpatialResult<Array2<f64>> {
        use crate::simd_distance::parallel_pdist;

        let condensed = parallel_pdist(points, "euclidean")?;

        // Expand the condensed upper-triangular form into a symmetric matrix.
        let n = points.nrows();
        let mut matrix = Array2::zeros((n, n));

        let mut idx = 0;
        for i in 0..n {
            for j in (i + 1)..n {
                matrix[[i, j]] = condensed[idx];
                matrix[[j, i]] = condensed[idx];
                idx += 1;
            }
        }

        Ok(matrix)
    }
}

/// GPU-accelerated k-means clustering with CPU fallback.
pub struct GpuKMeans {
    device: Arc<GpuDevice>,
    k: usize,
    max_iterations: usize,
    tolerance: f64,
    batch_size: usize,
}

impl GpuKMeans {
    pub fn new(k: usize) -> SpatialResult<Self> {
        let device = Arc::new(GpuDevice::new()?);
        Ok(Self {
            device,
            k,
            max_iterations: 100,
            tolerance: 1e-6,
            batch_size: 1024,
        })
    }

    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations;
        self
    }

    pub fn with_tolerance(mut self, tolerance: f64) -> Self {
        self.tolerance = tolerance;
        self
    }

    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Fit k-means, returning `(centroids, assignments)`.
    pub async fn fit(
        &self,
        points: &ArrayView2<'_, f64>,
    ) -> SpatialResult<(Array2<f64>, Array1<usize>)> {
        if !self.device.is_gpu_available() {
            return self.fit_cpu_fallback(points).await;
        }

        match self.device.preferred_backend {
            GpuBackend::Cuda => self.fit_cuda(points).await,
            GpuBackend::Rocm => self.fit_rocm(points).await,
            GpuBackend::Vulkan => self.fit_vulkan(points).await,
            GpuBackend::CpuFallback => self.fit_cpu_fallback(points).await,
            // No Level Zero kernels yet; use the CPU path.
            GpuBackend::LevelZero => self.fit_cpu_fallback(points).await,
        }
    }

    // Backend-specific k-means kernels are not implemented yet; each path
    // currently delegates to the SIMD CPU fallback.
    async fn fit_cuda(
        &self,
        points: &ArrayView2<'_, f64>,
    ) -> SpatialResult<(Array2<f64>, Array1<usize>)> {
        self.fit_cpu_fallback(points).await
    }

    async fn fit_rocm(
        &self,
        points: &ArrayView2<'_, f64>,
    ) -> SpatialResult<(Array2<f64>, Array1<usize>)> {
        self.fit_cpu_fallback(points).await
    }

    async fn fit_vulkan(
        &self,
        points: &ArrayView2<'_, f64>,
    ) -> SpatialResult<(Array2<f64>, Array1<usize>)> {
        self.fit_cpu_fallback(points).await
    }

    async fn fit_cpu_fallback(
        &self,
        points: &ArrayView2<'_, f64>,
    ) -> SpatialResult<(Array2<f64>, Array1<usize>)> {
        use crate::simd_distance::advanced_simd_clustering::AdvancedSimdKMeans;

        let advanced_kmeans = AdvancedSimdKMeans::new(self.k)
            .with_mixed_precision(true)
            .with_block_size(256);

        advanced_kmeans.fit(points)
    }
}

/// GPU-accelerated k-nearest-neighbor search with CPU fallback.
pub struct GpuNearestNeighbors {
    device: Arc<GpuDevice>,
    #[allow(dead_code)]
    build_batch_size: usize,
    #[allow(dead_code)]
    query_batch_size: usize,
}

impl GpuNearestNeighbors {
    pub fn new() -> SpatialResult<Self> {
        let device = Arc::new(GpuDevice::new()?);
        Ok(Self {
            device,
            build_batch_size: 1024,
            query_batch_size: 256,
        })
    }

    /// Find the `k` nearest data points for each query point, returning
    /// `(indices, distances)`.
    pub async fn knn_search(
        &self,
        query_points: &ArrayView2<'_, f64>,
        data_points: &ArrayView2<'_, f64>,
        k: usize,
    ) -> SpatialResult<(Array2<usize>, Array2<f64>)> {
        if !self.device.is_gpu_available() {
            return self
                .knn_search_cpu_fallback(query_points, data_points, k)
                .await;
        }

        match self.device.preferred_backend {
            GpuBackend::Cuda => self.knn_search_cuda(query_points, data_points, k).await,
            GpuBackend::Rocm => self.knn_search_rocm(query_points, data_points, k).await,
            GpuBackend::Vulkan => self.knn_search_vulkan(query_points, data_points, k).await,
            GpuBackend::CpuFallback => {
                self.knn_search_cpu_fallback(query_points, data_points, k)
                    .await
            }
            _ => {
                self.knn_search_cpu_fallback(query_points, data_points, k)
                    .await
            }
        }
    }

    // Backend-specific k-NN kernels are not implemented yet; each path
    // currently delegates to the SIMD CPU fallback.
    async fn knn_search_cuda(
        &self,
        query_points: &ArrayView2<'_, f64>,
        data_points: &ArrayView2<'_, f64>,
        k: usize,
    ) -> SpatialResult<(Array2<usize>, Array2<f64>)> {
        self.knn_search_cpu_fallback(query_points, data_points, k)
            .await
    }

    async fn knn_search_rocm(
        &self,
        query_points: &ArrayView2<'_, f64>,
        data_points: &ArrayView2<'_, f64>,
        k: usize,
    ) -> SpatialResult<(Array2<usize>, Array2<f64>)> {
        self.knn_search_cpu_fallback(query_points, data_points, k)
            .await
    }

    async fn knn_search_vulkan(
        &self,
        query_points: &ArrayView2<'_, f64>,
        data_points: &ArrayView2<'_, f64>,
        k: usize,
    ) -> SpatialResult<(Array2<usize>, Array2<f64>)> {
        self.knn_search_cpu_fallback(query_points, data_points, k)
            .await
    }

    async fn knn_search_cpu_fallback(
        &self,
        query_points: &ArrayView2<'_, f64>,
        data_points: &ArrayView2<'_, f64>,
        k: usize,
    ) -> SpatialResult<(Array2<usize>, Array2<f64>)> {
        use crate::simd_distance::advanced_simd_clustering::AdvancedSimdNearestNeighbors;

        let advanced_nn = AdvancedSimdNearestNeighbors::new();
        advanced_nn.simd_knn_advanced_fast(query_points, data_points, k)
    }
}

impl Default for GpuNearestNeighbors {
    fn default() -> Self {
        Self::new().unwrap_or_else(|_| Self {
            device: Arc::new(GpuDevice::default()),
            build_batch_size: 1024,
            query_batch_size: 256,
        })
    }
}

/// Chooses between CPU, GPU, and hybrid execution based on problem size.
pub struct HybridProcessor {
    gpu_device: Arc<GpuDevice>,
    cpu_threshold: usize,
    gpu_threshold: usize,
}

impl HybridProcessor {
    pub fn new() -> SpatialResult<Self> {
        let gpu_device = Arc::new(GpuDevice::new()?);
        Ok(Self {
            gpu_device,
            // Below this many points, stay on the CPU.
            cpu_threshold: 1000,
            // Above this many points, go GPU-only.
            gpu_threshold: 100_000,
        })
    }

    pub fn choose_strategy(&self, dataset_size: usize) -> ProcessingStrategy {
        if !self.gpu_device.is_gpu_available() {
            return ProcessingStrategy::CpuOnly;
        }

        if dataset_size < self.cpu_threshold {
            ProcessingStrategy::CpuOnly
        } else if dataset_size < self.gpu_threshold {
            ProcessingStrategy::Hybrid
        } else {
            ProcessingStrategy::GpuOnly
        }
    }

    pub fn optimal_batch_sizes(&self, total_size: usize) -> (usize, usize) {
        // Conservative capacity estimate: f64 slots with a large safety factor.
        let gpu_capability = self.gpu_device.capabilities().total_memory / (8 * 1024);
        // Give the CPU a quarter of the work, with a floor of 1000 points.
        let cpu_batch = (total_size / 4).max(1000);
        // Give the GPU the remainder, capped by its estimated capacity.
        let gpu_batch = (total_size * 3 / 4).min(gpu_capability);
        (cpu_batch, gpu_batch)
    }
}

impl Default for HybridProcessor {
    fn default() -> Self {
        Self::new().unwrap_or_else(|_| Self {
            gpu_device: Arc::new(GpuDevice::default()),
            cpu_threshold: 1000,
            gpu_threshold: 100_000,
        })
    }
}

/// Execution strategy selected by [`HybridProcessor::choose_strategy`].
#[derive(Debug, Clone, PartialEq)]
pub enum ProcessingStrategy {
    /// Process everything on the CPU.
    CpuOnly,
    /// Process everything on the GPU.
    GpuOnly,
    /// Split the work between CPU and GPU.
    Hybrid,
}

static GLOBAL_GPU_DEVICE: std::sync::OnceLock<GpuDevice> = std::sync::OnceLock::new();

/// Lazily initialized process-wide GPU device handle.
#[allow(dead_code)]
pub fn global_gpu_device() -> &'static GpuDevice {
    GLOBAL_GPU_DEVICE.get_or_init(GpuDevice::default)
}

/// Whether GPU acceleration is available on this machine.
#[allow(dead_code)]
pub fn is_gpu_acceleration_available() -> bool {
    global_gpu_device().is_gpu_available()
}

/// Capabilities of the global GPU device.
#[allow(dead_code)]
pub fn get_gpu_capabilities() -> &'static GpuCapabilities {
    global_gpu_device().capabilities()
}

/// Print a human-readable summary of the detected GPU configuration.
#[allow(dead_code)]
pub fn report_gpu_status() {
    let device = global_gpu_device();
    let caps = device.capabilities();

    println!("GPU Acceleration Status:");
    println!("  Available: {}", caps.gpu_available);
    println!("  Device Count: {}", caps.device_count);

    if caps.gpu_available {
        println!(
            "  Total Memory: {:.1} GB",
            caps.total_memory as f64 / (1024.0 * 1024.0 * 1024.0)
        );
        println!(
            "  Available Memory: {:.1} GB",
            caps.available_memory as f64 / (1024.0 * 1024.0 * 1024.0)
        );
        println!("  Max Threads/Block: {}", caps.max_threads_per_block);
        println!("  Supported Backends: {:?}", caps.supported_backends);

        for (i, name) in caps.device_names.iter().enumerate() {
            println!("  Device {i}: {name}");
        }
    } else {
        println!("  Reason: No compatible GPU devices found");
        println!("  Fallback: Using optimized CPU SIMD operations");
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_gpu_device_creation() {
        let device = GpuDevice::new();
        assert!(device.is_ok());

        let device = device.unwrap();
        assert!(!device.capabilities().supported_backends.is_empty());
    }

    #[test]
    fn test_processing_strategy_selection() {
        let processor = HybridProcessor::new().unwrap();

        let strategy = processor.choose_strategy(500);
        assert_eq!(strategy, ProcessingStrategy::CpuOnly);

        // Large inputs go to the GPU when one exists, otherwise to the CPU.
        let strategy = processor.choose_strategy(200000);
        assert!(matches!(
            strategy,
            ProcessingStrategy::GpuOnly | ProcessingStrategy::CpuOnly
        ));
    }
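
    // Added sketch of a deterministic check for HybridProcessor::optimal_batch_sizes,
    // using only the API defined in this module: the CPU batch is floored at
    // 1000 points and the GPU batch never exceeds three quarters of the total.
    #[test]
    fn test_optimal_batch_sizes() {
        let processor = HybridProcessor::default();
        let total = 10_000;
        let (cpu_batch, gpu_batch) = processor.optimal_batch_sizes(total);
        assert!(cpu_batch >= 1000);
        assert!(gpu_batch <= total * 3 / 4);
    }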

    #[test]
    #[ignore]
    fn test_gpu_distance_matrix() {
        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];

        let gpu_matrix = GpuDistanceMatrix::new().unwrap();
        let points_view = points.view();
        // Note: compute_parallel is async; without an executor this only
        // constructs the future and never polls it.
        let _result = gpu_matrix.compute_parallel(&points_view);
    }
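
    // Added sketch of a deterministic builder check for GpuDistanceMatrix;
    // the private fields are visible here because `tests` is a child module.
    #[test]
    fn test_distance_matrix_builder_config() {
        let matrix = GpuDistanceMatrix::new()
            .unwrap()
            .with_batch_size(2048)
            .with_mixed_precision(false);
        assert_eq!(matrix.batch_size, 2048);
        assert!(!matrix.use_mixed_precision);
    }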

    #[test]
    #[ignore]
    fn test_gpu_kmeans() {
        let points = array![
            [0.0, 0.0],
            [0.1, 0.1],
            [0.0, 0.1], // cluster near the origin
            [5.0, 5.0],
            [5.1, 5.1],
            [5.0, 5.1], // cluster near (5, 5)
        ];

        let gpu_kmeans = GpuKMeans::new(2).unwrap();
        let points_view = points.view();
        // Note: fit is async; without an executor this only constructs the
        // future and never polls it.
        let _result = gpu_kmeans.fit(&points_view);
    }
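
    // Added sketch of a deterministic builder check for GpuKMeans: the with_*
    // methods only record configuration, so this runs on any machine.
    #[test]
    fn test_gpu_kmeans_builder_config() {
        let kmeans = GpuKMeans::new(3)
            .unwrap()
            .with_max_iterations(50)
            .with_tolerance(1e-4)
            .with_batch_size(512);
        assert_eq!(kmeans.k, 3);
        assert_eq!(kmeans.max_iterations, 50);
        assert!((kmeans.tolerance - 1e-4).abs() < f64::EPSILON);
        assert_eq!(kmeans.batch_size, 512);
    }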

    #[test]
    #[ignore]
    fn test_gpu_nearest_neighbors() {
        let data_points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let query_points = array![[0.1, 0.1], [0.9, 0.9]];

        let gpu_nn = GpuNearestNeighbors::new().unwrap();
        let query_view = query_points.view();
        let data_view = data_points.view();
        // Note: knn_search is async; without an executor this only constructs
        // the future and never polls it.
        let _result = gpu_nn.knn_search(&query_view, &data_view, 2);
    }
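
    // Added sketch of a deterministic check of the private backend preference
    // order used by GpuDevice::select_optimal_backend (CUDA > ROCm > Vulkan > CPU).
    #[test]
    fn test_backend_selection_order() {
        let mut caps = GpuCapabilities::default();
        caps.supported_backends = vec![GpuBackend::CpuFallback];
        assert_eq!(
            GpuDevice::select_optimal_backend(&caps),
            GpuBackend::CpuFallback
        );

        caps.supported_backends =
            vec![GpuBackend::Vulkan, GpuBackend::Cuda, GpuBackend::CpuFallback];
        assert_eq!(GpuDevice::select_optimal_backend(&caps), GpuBackend::Cuda);
    }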

    #[test]
    fn test_global_gpu_functions() {
        let device = global_gpu_device();
        assert!(!device.capabilities.device_names.is_empty() || !device.capabilities.gpu_available);

        report_gpu_status();
        let _caps = get_gpu_capabilities();
        let _available = is_gpu_acceleration_available();
    }
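
    // Added sanity check for optimal_block_size: whatever backend was
    // detected, the result must be positive and within the device's
    // thread-per-block limit.
    #[test]
    fn test_optimal_block_size_bounds() {
        let device = GpuDevice::default();
        for problem_size in [1usize, 31, 32, 1000, 1 << 20] {
            let block = device.optimal_block_size(problem_size);
            assert!(block >= 1);
            assert!(block <= device.capabilities().max_threads_per_block);
        }
    }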
}