use crate::error::{FFTError, FFTResult};
use crate::sparse_fft::{SparseFFTAlgorithm, SparseFFTConfig, SparseFFTResult, WindowFunction};
use crate::sparse_fft_gpu::{GPUBackend, GPUSparseFFTConfig};
use crate::sparse_fft_gpu_memory::{
    init_cuda_device, init_hip_device, init_sycl_device, is_cuda_available, is_hip_available,
    is_sycl_available,
};
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::NumCast;
use scirs2_core::parallel_ops::*;
use scirs2_core::simd_ops::PlatformCapabilities;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::{Arc, Mutex};
use std::time::Instant;

/// Information about a single GPU device available for sparse FFT work.
#[derive(Debug, Clone)]
pub struct GPUDeviceInfo {
    /// Device ID within its backend (-1 for the CPU fallback)
    pub device_id: i32,
    /// Backend that owns this device (CUDA, HIP, SYCL, or CPU fallback)
    pub backend: GPUBackend,
    /// Human-readable device name
    pub device_name: String,
    /// Total device memory in bytes
    pub memory_total: usize,
    /// Free device memory in bytes
    pub memory_free: usize,
    /// Compute capability (backend-specific versioning scheme)
    pub compute_capability: f32,
    /// Number of compute units (SMs / CUs / EUs)
    pub compute_units: usize,
    /// Maximum number of threads per block supported by the device
    pub max_threads_per_block: usize,
    /// Whether the device can currently be used
    pub is_available: bool,
}

impl Default for GPUDeviceInfo {
    fn default() -> Self {
        Self {
            device_id: -1,
            backend: GPUBackend::CPUFallback,
            device_name: "Unknown Device".to_string(),
            memory_total: 0,
            memory_free: 0,
            compute_capability: 0.0,
            compute_units: 0,
            max_threads_per_block: 0,
            is_available: false,
        }
    }
}

/// Strategy for distributing work across the selected devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkloadDistribution {
    /// Split the signal into equal-sized chunks
    Equal,
    /// Size chunks proportionally to free device memory
    MemoryBased,
    /// Size chunks proportionally to compute capability and compute units
    ComputeBased,
    /// Use the ratios supplied in `MultiGPUConfig::manual_ratios`
    Manual,
    /// Choose devices based on recorded per-device performance history
    Adaptive,
}

/// Configuration for multi-GPU sparse FFT processing.
#[derive(Debug, Clone)]
pub struct MultiGPUConfig {
    /// Base sparse FFT configuration applied on every device
    pub base_config: SparseFFTConfig,
    /// Workload distribution strategy
    pub distribution: WorkloadDistribution,
    /// Manual chunk ratios (used when `distribution` is `Manual`)
    pub manual_ratios: Vec<f32>,
    /// Maximum number of devices to use (0 means use all available devices)
    pub max_devices: usize,
    /// Minimum signal size before work is split across multiple devices
    pub min_signal_size: usize,
    /// Overlap between adjacent chunks, in samples
    pub chunk_overlap: usize,
    /// Enable dynamic load balancing across devices
    pub enable_load_balancing: bool,
    /// Timeout for per-device operations, in milliseconds
    pub device_timeout_ms: u64,
}

impl Default for MultiGPUConfig {
    fn default() -> Self {
        Self {
            base_config: SparseFFTConfig::default(),
            distribution: WorkloadDistribution::ComputeBased,
            manual_ratios: Vec::new(),
            max_devices: 0,
            min_signal_size: 4096,
            chunk_overlap: 0,
            enable_load_balancing: true,
            device_timeout_ms: 5000,
        }
    }
}

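/// Multi-GPU sparse FFT processor.
///
/// Enumerates the available GPU backends (plus a CPU fallback), selects devices
/// according to the configured [`WorkloadDistribution`], splits large signals into
/// per-device chunks, and merges the per-chunk results.
///
/// The example below is an illustrative sketch rather than a doctest: the `use` path
/// is an assumption about the crate layout, and the reported devices depend on which
/// backends are available at runtime.
///
/// ```ignore
/// use scirs2_fft::sparse_fft_multi_gpu::{MultiGPUConfig, MultiGPUSparseFFT, WorkloadDistribution};
///
/// // Synthetic input signal; any numeric slice works.
/// let signal: Vec<f64> = (0..8192).map(|i| (i as f64 * 0.01).sin()).collect();
///
/// let config = MultiGPUConfig {
///     distribution: WorkloadDistribution::ComputeBased,
///     max_devices: 2,
///     ..MultiGPUConfig::default()
/// };
///
/// let mut processor = MultiGPUSparseFFT::new(config);
/// let result = processor.sparse_fft(&signal).expect("sparse FFT failed");
/// println!("found {} significant frequency components", result.values.len());
/// ```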
pub struct MultiGPUSparseFFT {
    /// Multi-GPU configuration
    _config: MultiGPUConfig,
    /// All enumerated devices
    devices: Vec<GPUDeviceInfo>,
    /// Indices (into `devices`) of the devices selected for processing
    selected_devices: Vec<usize>,
    /// Execution-time history per device ID, used for adaptive selection
    performance_history: Arc<Mutex<HashMap<i32, Vec<f64>>>>,
    /// Whether `initialize` has been called
    initialized: bool,
}

impl MultiGPUSparseFFT {
    /// Creates a new multi-GPU sparse FFT processor with the given configuration.
    pub fn new(config: MultiGPUConfig) -> Self {
        Self {
            _config: config,
            devices: Vec::new(),
            selected_devices: Vec::new(),
            performance_history: Arc::new(Mutex::new(HashMap::new())),
            initialized: false,
        }
    }

    /// Enumerates and selects devices. Called automatically by `sparse_fft` if needed.
    pub fn initialize(&mut self) -> FFTResult<()> {
        if self.initialized {
            return Ok(());
        }

        self.enumerate_devices()?;

        self.select_devices()?;

        self.initialized = true;
        Ok(())
    }

    fn enumerate_devices(&mut self) -> FFTResult<()> {
        self.devices.clear();

        if is_cuda_available() {
            self.enumerate_cuda_devices()?;
        }

        if is_hip_available() {
            self.enumerate_hip_devices()?;
        }

        if is_sycl_available() {
            self.enumerate_sycl_devices()?;
        }

        // Always register a CPU fallback device (placeholder memory figures).
        self.devices.push(GPUDeviceInfo {
            device_id: -1,
            backend: GPUBackend::CPUFallback,
            device_name: "CPU Fallback".to_string(),
            memory_total: 16 * 1024 * 1024 * 1024,
            memory_free: 8 * 1024 * 1024 * 1024,
            compute_capability: 1.0,
            compute_units: num_cpus::get(),
            max_threads_per_block: 1,
            is_available: true,
        });

        Ok(())
    }

    fn enumerate_cuda_devices(&mut self) -> FFTResult<()> {
        if init_cuda_device()? {
            self.devices.push(GPUDeviceInfo {
                device_id: 0,
                backend: GPUBackend::CUDA,
                device_name: "NVIDIA GPU (simulated)".to_string(),
                memory_total: 8 * 1024 * 1024 * 1024,
                memory_free: 6 * 1024 * 1024 * 1024,
                compute_capability: 8.6,
                compute_units: 68,
                max_threads_per_block: 1024,
                is_available: true,
            });
        }

        Ok(())
    }

    fn enumerate_hip_devices(&mut self) -> FFTResult<()> {
        if init_hip_device()? {
            self.devices.push(GPUDeviceInfo {
                device_id: 0,
                backend: GPUBackend::HIP,
                device_name: "AMD GPU (simulated)".to_string(),
                memory_total: 16 * 1024 * 1024 * 1024,
                memory_free: 12 * 1024 * 1024 * 1024,
                compute_capability: 10.3,
                compute_units: 40,
                max_threads_per_block: 256,
                is_available: true,
            });
        }

        Ok(())
    }

    fn enumerate_sycl_devices(&mut self) -> FFTResult<()> {
        if init_sycl_device()? {
            self.devices.push(GPUDeviceInfo {
                device_id: 0,
                backend: GPUBackend::SYCL,
                device_name: "Intel GPU (simulated)".to_string(),
                memory_total: 4 * 1024 * 1024 * 1024,
                memory_free: 3 * 1024 * 1024 * 1024,
                compute_capability: 1.2,
                compute_units: 96,
                max_threads_per_block: 512,
                is_available: true,
            });
        }

        Ok(())
    }

    fn select_devices(&mut self) -> FFTResult<()> {
        self.selected_devices.clear();

        let available_devices: Vec<(usize, &GPUDeviceInfo)> = self
            .devices
            .iter()
            .enumerate()
            .filter(|(_, device)| device.is_available)
            .collect();

        if available_devices.is_empty() {
            return Err(FFTError::ComputationError(
                "No available GPU devices found".to_string(),
            ));
        }

        let max_devices = if self._config.max_devices == 0 {
            available_devices.len()
        } else {
            self._config.max_devices.min(available_devices.len())
        };

        match self._config.distribution {
            WorkloadDistribution::Equal => {
                for i in 0..max_devices {
                    self.selected_devices.push(available_devices[i].0);
                }
            }
            WorkloadDistribution::ComputeBased => {
                let mut sorted_devices = available_devices;
                sorted_devices.sort_by(|a, b| {
                    b.1.compute_capability
                        .partial_cmp(&a.1.compute_capability)
                        .unwrap_or(std::cmp::Ordering::Equal)
                });

                for i in 0..max_devices {
                    self.selected_devices.push(sorted_devices[i].0);
                }
            }
            WorkloadDistribution::MemoryBased => {
                let mut sorted_devices = available_devices;
                sorted_devices.sort_by(|a, b| b.1.memory_free.cmp(&a.1.memory_free));

                for i in 0..max_devices {
                    self.selected_devices.push(sorted_devices[i].0);
                }
            }
            WorkloadDistribution::Manual => {
                for i in 0..max_devices {
                    self.selected_devices.push(available_devices[i].0);
                }
            }
            WorkloadDistribution::Adaptive => {
                let available_devices_clone: Vec<(usize, GPUDeviceInfo)> = available_devices
                    .iter()
                    .map(|(idx, device)| (*idx, (*device).clone()))
                    .collect();

                self.select_adaptive_devices_with_clone(available_devices_clone, max_devices)?;
            }
        }

        Ok(())
    }

    fn select_adaptive_devices_with_clone(
        &mut self,
        available_devices: Vec<(usize, GPUDeviceInfo)>,
        max_devices: usize,
    ) -> FFTResult<()> {
        let performance_history = self.performance_history.lock().unwrap();

        let mut device_scores: Vec<(usize, f64)> = available_devices
            .iter()
            .map(|(idx, device)| {
                let avg_performance = performance_history
                    .get(&device.device_id)
                    .map(|times| {
                        if times.is_empty() {
                            device.compute_capability as f64 * device.compute_units as f64
                        } else {
                            1.0 / (times.iter().sum::<f64>() / times.len() as f64)
                        }
                    })
                    .unwrap_or_else(|| {
                        device.compute_capability as f64 * device.compute_units as f64
                    });

                (*idx, avg_performance)
            })
            .collect();

        device_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        for i in 0..max_devices {
            self.selected_devices.push(device_scores[i].0);
        }

        Ok(())
    }

    /// Returns all enumerated devices.
    pub fn get_devices(&self) -> &[GPUDeviceInfo] {
        &self.devices
    }

    /// Returns the devices currently selected for processing.
    pub fn get_selected_devices(&self) -> Vec<&GPUDeviceInfo> {
        self.selected_devices
            .iter()
            .map(|&idx| &self.devices[idx])
            .collect()
    }

    /// Performs a sparse FFT, splitting the signal across devices when it is large
    /// enough and more than one device is selected.
    pub fn sparse_fft<T>(&mut self, signal: &[T]) -> FFTResult<SparseFFTResult>
    where
        T: NumCast + Copy + Debug + Send + Sync + 'static,
    {
        if !self.initialized {
            self.initialize()?;
        }

        let signal_len = signal.len();

        // Small signals (or a single selected device) are processed on one device.
        if signal_len < self._config.min_signal_size || self.selected_devices.len() <= 1 {
            return self.single_device_sparse_fft(signal);
        }

        self.multi_device_sparse_fft(signal)
    }

    fn single_device_sparse_fft<T>(&mut self, signal: &[T]) -> FFTResult<SparseFFTResult>
    where
        T: NumCast + Copy + Debug + 'static,
    {
        let device_idx = self.selected_devices.first().copied().unwrap_or(0);
        let device = &self.devices[device_idx];

        let gpu_config = GPUSparseFFTConfig {
            base_config: self._config.base_config.clone(),
            backend: device.backend,
            device_id: device.device_id,
            ..GPUSparseFFTConfig::default()
        };

        let mut processor = crate::sparse_fft_gpu::GPUSparseFFT::new(gpu_config);
        processor.sparse_fft(signal)
    }

    fn multi_device_sparse_fft<T>(&mut self, signal: &[T]) -> FFTResult<SparseFFTResult>
    where
        T: NumCast + Copy + Debug + Send + Sync + 'static,
    {
        let signal_len = signal.len();
        let num_devices = self.selected_devices.len();

        let chunk_sizes = self.calculate_chunk_sizes(signal_len, num_devices)?;

        let chunks = self.split_signal(signal, &chunk_sizes)?;

        // Process each chunk on its assigned device in parallel.
        let chunk_results: Result<Vec<_>, _> = chunks
            .par_iter()
            .zip(self.selected_devices.par_iter())
            .map(|(chunk, &device_idx)| {
                let device = &self.devices[device_idx];
                let start_time = Instant::now();

                let gpu_config = GPUSparseFFTConfig {
                    base_config: self._config.base_config.clone(),
                    backend: device.backend,
                    device_id: device.device_id,
                    ..GPUSparseFFTConfig::default()
                };

                let mut processor = crate::sparse_fft_gpu::GPUSparseFFT::new(gpu_config);
                let result = processor.sparse_fft(chunk);

                // Record execution time for adaptive selection, keeping only the
                // last 10 measurements per device.
                if result.is_ok() {
                    let execution_time = start_time.elapsed().as_secs_f64();
                    if let Ok(mut history) = self.performance_history.try_lock() {
                        history
                            .entry(device.device_id)
                            .or_default()
                            .push(execution_time);

                        if let Some(times) = history.get_mut(&device.device_id) {
                            if times.len() > 10 {
                                times.drain(0..times.len() - 10);
                            }
                        }
                    }
                }

                result
            })
            .collect();

        let chunk_results = chunk_results?;

        self.combine_chunk_results(chunk_results)
    }

    fn calculate_chunk_sizes(
        &self,
        signal_len: usize,
        num_devices: usize,
    ) -> FFTResult<Vec<usize>> {
        let mut chunk_sizes = Vec::with_capacity(num_devices);

        match self._config.distribution {
            WorkloadDistribution::Equal => {
                let base_size = signal_len / num_devices;
                let remainder = signal_len % num_devices;

                for i in 0..num_devices {
                    let size = if i < remainder {
                        base_size + 1
                    } else {
                        base_size
                    };
                    chunk_sizes.push(size);
                }
            }
            // Adaptive selection reuses compute-based sizing for the chunk split;
            // the performance history already influenced which devices were selected.
            WorkloadDistribution::ComputeBased | WorkloadDistribution::Adaptive => {
                let total_compute: f32 = self
                    .selected_devices
                    .iter()
                    .map(|&idx| {
                        self.devices[idx].compute_capability
                            * self.devices[idx].compute_units as f32
                    })
                    .sum();

                let mut remaining = signal_len;
                for (i, &device_idx) in self.selected_devices.iter().enumerate() {
                    let device = &self.devices[device_idx];
                    let device_compute = device.compute_capability * device.compute_units as f32;
                    let ratio = device_compute / total_compute;

                    let size = if i == num_devices - 1 {
                        remaining
                    } else {
                        let size = (signal_len as f32 * ratio) as usize;
                        remaining = remaining.saturating_sub(size);
                        size
                    };

                    chunk_sizes.push(size);
                }
            }
            WorkloadDistribution::MemoryBased => {
                let total_memory: usize = self
                    .selected_devices
                    .iter()
                    .map(|&idx| self.devices[idx].memory_free)
                    .sum();

                let mut remaining = signal_len;
                for (i, &device_idx) in self.selected_devices.iter().enumerate() {
                    let device = &self.devices[device_idx];
                    let ratio = device.memory_free as f32 / total_memory as f32;

                    let size = if i == num_devices - 1 {
                        remaining
                    } else {
                        let size = (signal_len as f32 * ratio) as usize;
                        remaining = remaining.saturating_sub(size);
                        size
                    };

                    chunk_sizes.push(size);
                }
            }
            WorkloadDistribution::Manual => {
                if self._config.manual_ratios.len() != num_devices {
                    return Err(FFTError::ValueError(
                        "Manual ratios length must match number of selected devices".to_string(),
                    ));
                }

                let total_ratio: f32 = self._config.manual_ratios.iter().sum();
                let mut remaining = signal_len;

                for (i, &ratio) in self._config.manual_ratios.iter().enumerate() {
                    let size = if i == num_devices - 1 {
                        remaining
                    } else {
                        let size = (signal_len as f32 * ratio / total_ratio) as usize;
                        remaining = remaining.saturating_sub(size);
                        size
                    };

                    chunk_sizes.push(size);
                }
            }
        }

        Ok(chunk_sizes)
    }

    fn split_signal<T>(&self, signal: &[T], chunk_sizes: &[usize]) -> FFTResult<Vec<Vec<T>>>
    where
        T: Copy,
    {
        let mut chunks = Vec::new();
        let mut offset = 0;

        for &chunk_size in chunk_sizes {
            if offset + chunk_size > signal.len() {
                return Err(FFTError::ValueError(
                    "Chunk sizes exceed signal length".to_string(),
                ));
            }

            let chunk_end = offset + chunk_size;
            let chunk = signal[offset..chunk_end].to_vec();
            chunks.push(chunk);
            offset = chunk_end;
        }

        Ok(chunks)
    }

    fn combine_chunk_results(
        &self,
        chunk_results: Vec<SparseFFTResult>,
    ) -> FFTResult<SparseFFTResult> {
        if chunk_results.is_empty() {
            return Err(FFTError::ComputationError(
                "No chunk results to combine".to_string(),
            ));
        }

        if chunk_results.len() == 1 {
            return Ok(chunk_results.into_iter().next().unwrap());
        }

        // The slowest chunk dominates the wall-clock time of the parallel run.
        let max_computation_time = chunk_results
            .iter()
            .map(|r| r.computation_time)
            .max()
            .unwrap_or_default();

        let mut combined_values = Vec::new();
        let mut combined_indices = Vec::new();
        let mut index_offset = 0;

        for result in chunk_results {
            let indices_len = result.indices.len();

            combined_values.extend(result.values);

            let adjusted_indices: Vec<usize> = result
                .indices
                .into_iter()
                .map(|idx| idx + index_offset)
                .collect();
            combined_indices.extend(adjusted_indices);

            index_offset += indices_len;
        }

        // Deduplicate by index and return the entries in ascending index order.
        let mut frequency_map: std::collections::HashMap<usize, Complex64> =
            std::collections::HashMap::new();

        for (idx, value) in combined_indices.iter().zip(combined_values.iter()) {
            frequency_map.insert(*idx, *value);
        }

        let mut sorted_entries: Vec<_> = frequency_map.into_iter().collect();
        sorted_entries.sort_by_key(|&(idx, _)| idx);

        let final_indices: Vec<usize> = sorted_entries.iter().map(|(idx, _)| *idx).collect();
        let final_values: Vec<Complex64> = sorted_entries.iter().map(|(_, val)| *val).collect();

        let total_estimated_sparsity = final_values.len();

        Ok(SparseFFTResult {
            values: final_values,
            indices: final_indices,
            estimated_sparsity: total_estimated_sparsity,
            computation_time: max_computation_time,
            algorithm: self._config.base_config.algorithm,
        })
    }

    /// Returns a copy of the recorded per-device execution times (seconds), keyed by device ID.
    pub fn get_performance_stats(&self) -> HashMap<i32, Vec<f64>> {
        self.performance_history.lock().unwrap().clone()
    }

    /// Clears the recorded performance history.
    pub fn reset_performance_history(&mut self) {
        self.performance_history.lock().unwrap().clear();
    }
}

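/// Convenience wrapper that builds a default [`MultiGPUConfig`] around the requested
/// sparsity, algorithm, and window function, then runs a multi-GPU sparse FFT.
///
/// The example below is an illustrative sketch rather than a doctest; the `use` paths
/// are assumptions about the crate layout and the signal is synthetic.
///
/// ```ignore
/// use scirs2_fft::sparse_fft::{SparseFFTAlgorithm, WindowFunction};
/// use scirs2_fft::sparse_fft_multi_gpu::multi_gpu_sparse_fft;
///
/// // Synthetic signal with two dominant frequency components.
/// let n = 4096;
/// let signal: Vec<f64> = (0..n)
///     .map(|i| {
///         let t = 2.0 * std::f64::consts::PI * i as f64 / n as f64;
///         (10.0 * t).sin() + 0.5 * (50.0 * t).sin()
///     })
///     .collect();
///
/// // Ask for the 4 most significant components.
/// let result = multi_gpu_sparse_fft(
///     &signal,
///     4,
///     Some(SparseFFTAlgorithm::Sublinear),
///     Some(WindowFunction::Hann),
/// )
/// .expect("multi-GPU sparse FFT failed");
/// assert_eq!(result.values.len(), result.indices.len());
/// ```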
#[allow(dead_code)]
pub fn multi_gpu_sparse_fft<T>(
    signal: &[T],
    k: usize,
    algorithm: Option<SparseFFTAlgorithm>,
    window_function: Option<WindowFunction>,
) -> FFTResult<SparseFFTResult>
where
    T: NumCast + Copy + Debug + Send + Sync + 'static,
{
    let base_config = SparseFFTConfig {
        sparsity: k,
        algorithm: algorithm.unwrap_or(SparseFFTAlgorithm::Sublinear),
        window_function: window_function.unwrap_or(WindowFunction::None),
        ..SparseFFTConfig::default()
    };

    let config = MultiGPUConfig {
        base_config,
        ..MultiGPUConfig::default()
    };

    let mut processor = MultiGPUSparseFFT::new(config);
    processor.sparse_fft(signal)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Builds a test signal as a sum of sine components given as (frequency, amplitude).
    fn create_sparse_signal(n: usize, frequencies: &[(usize, f64)]) -> Vec<f64> {
        let mut signal = vec![0.0; n];

        for i in 0..n {
            let t = 2.0 * PI * (i as f64) / (n as f64);
            for &(freq, amp) in frequencies {
                signal[i] += amp * (freq as f64 * t).sin();
            }
        }

        signal
    }

    #[test]
    fn test_multi_gpu_initialization() {
        let mut processor = MultiGPUSparseFFT::new(MultiGPUConfig::default());
        let result = processor.initialize();

        assert!(result.is_ok());
        assert!(!processor.get_devices().is_empty());

        let caps = PlatformCapabilities::detect();
        if !caps.cuda_available && !caps.gpu_available {
            eprintln!("GPU not available, verifying CPU fallback is present");
            assert!(processor
                .get_devices()
                .iter()
                .any(|d| d.backend == GPUBackend::CPUFallback));
        }
    }

    #[test]
    fn test_device_enumeration() {
        let mut processor = MultiGPUSparseFFT::new(MultiGPUConfig::default());
        processor.initialize().unwrap();

        let devices = processor.get_devices();
        assert!(!devices.is_empty());

        assert!(devices.iter().any(|d| d.backend == GPUBackend::CPUFallback));

        let caps = PlatformCapabilities::detect();
        if caps.cuda_available || caps.gpu_available {
            eprintln!("GPU available, checking for GPU devices in enumeration");
            assert!(!devices.is_empty());
        } else {
            eprintln!("GPU not available, verifying only CPU fallback present");
            assert_eq!(devices.len(), 1);
            assert_eq!(devices[0].backend, GPUBackend::CPUFallback);
        }
    }

    #[test]
    fn test_multi_gpu_sparse_fft() {
        let caps = PlatformCapabilities::detect();
        let n = if caps.cuda_available || caps.gpu_available {
            8192
        } else {
            eprintln!("GPU not available, using smaller size for CPU fallback");
            1024
        };

        let frequencies = vec![(10, 1.0), (50, 0.5), (100, 0.25)];
        let signal = create_sparse_signal(n, &frequencies);

        let result = multi_gpu_sparse_fft(
            &signal,
            6,
            Some(SparseFFTAlgorithm::Sublinear),
            Some(WindowFunction::Hann),
        );

        assert!(result.is_ok());
        let result = result.unwrap();
        assert!(!result.values.is_empty());
        assert_eq!(result.values.len(), result.indices.len());
    }

    #[test]
    fn test_chunk_size_calculation() {
        let config = MultiGPUConfig {
            distribution: WorkloadDistribution::Equal,
            ..MultiGPUConfig::default()
        };
        let mut processor = MultiGPUSparseFFT::new(config);

        processor.selected_devices = vec![0, 1, 2];

        let chunk_sizes = processor.calculate_chunk_sizes(1000, 3).unwrap();
        assert_eq!(chunk_sizes.len(), 3);
        assert_eq!(chunk_sizes.iter().sum::<usize>(), 1000);
    }

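    /// Additional check (added example): manual ratios should translate directly into
    /// proportional chunk sizes. Ratios are chosen to be exactly representable in f32.
    #[test]
    fn test_manual_chunk_size_calculation() {
        let config = MultiGPUConfig {
            distribution: WorkloadDistribution::Manual,
            manual_ratios: vec![1.0, 1.0, 2.0],
            ..MultiGPUConfig::default()
        };
        let processor = MultiGPUSparseFFT::new(config);

        // 1000 samples split 1:1:2 -> 250, 250, and the remainder (500).
        let chunk_sizes = processor.calculate_chunk_sizes(1000, 3).unwrap();
        assert_eq!(chunk_sizes, vec![250, 250, 500]);
        assert_eq!(chunk_sizes.iter().sum::<usize>(), 1000);
    }
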
    #[test]
    fn test_signal_splitting() {
        let processor = MultiGPUSparseFFT::new(MultiGPUConfig::default());
        let signal = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let chunk_sizes = vec![3, 3, 4];

        let chunks = processor.split_signal(&signal, &chunk_sizes).unwrap();
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], vec![1, 2, 3]);
        assert_eq!(chunks[1], vec![4, 5, 6]);
        assert_eq!(chunks[2], vec![7, 8, 9, 10]);
    }
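
    /// Additional check (added example): requesting more samples than the signal holds
    /// should be rejected with an error rather than panic.
    #[test]
    fn test_signal_splitting_rejects_oversized_chunks() {
        let processor = MultiGPUSparseFFT::new(MultiGPUConfig::default());
        let signal = vec![1, 2, 3, 4, 5];
        let chunk_sizes = vec![3, 4];

        assert!(processor.split_signal(&signal, &chunk_sizes).is_err());
    }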
}