scirs2_fft/sparse_fft_multi_gpu.rs

//! Multi-GPU Sparse FFT Implementation
//!
//! This module provides advanced multi-GPU support for sparse FFT operations,
//! allowing parallel processing across multiple GPU devices for maximum performance.
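//!
//! A minimal usage sketch (illustrative only; it assumes this module is exposed
//! as `scirs2_fft::sparse_fft_multi_gpu` and that the `sparse_fft` module is
//! public, so the block is marked `ignore` rather than compiled as a doctest):
//!
//! ```ignore
//! use scirs2_fft::sparse_fft::{SparseFFTAlgorithm, WindowFunction};
//! use scirs2_fft::sparse_fft_multi_gpu::{multi_gpu_sparse_fft, MultiGPUConfig, MultiGPUSparseFFT};
//!
//! // Convenience path: default multi-GPU configuration, expecting ~6 components.
//! let signal: Vec<f64> = (0..8192).map(|i| (i as f64 * 0.1).sin()).collect();
//! let result = multi_gpu_sparse_fft(
//!     &signal,
//!     6,
//!     Some(SparseFFTAlgorithm::Sublinear),
//!     Some(WindowFunction::Hann),
//! )
//! .expect("multi-GPU sparse FFT failed");
//! println!("found {} significant frequencies", result.indices.len());
//!
//! // Explicit path: build a processor to inspect the enumerated devices first.
//! let mut processor = MultiGPUSparseFFT::new(MultiGPUConfig::default());
//! processor.initialize().expect("device enumeration failed");
//! for device in processor.get_devices() {
//!     println!("{} ({:?})", device.device_name, device.backend);
//! }
//! ```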

use crate::error::{FFTError, FFTResult};
use crate::sparse_fft::{SparseFFTAlgorithm, SparseFFTConfig, SparseFFTResult, WindowFunction};
use crate::sparse_fft_gpu::{GPUBackend, GPUSparseFFTConfig};
use crate::sparse_fft_gpu_memory::{
    init_cuda_device, init_hip_device, init_sycl_device, is_cuda_available, is_hip_available,
    is_sycl_available,
};
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::NumCast;
use scirs2_core::parallel_ops::*;
use scirs2_core::simd_ops::PlatformCapabilities;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::{Arc, Mutex};
use std::time::Instant;

/// Information about an available GPU device
#[derive(Debug, Clone)]
pub struct GPUDeviceInfo {
    /// Device ID
    pub device_id: i32,
    /// Backend type
    pub backend: GPUBackend,
    /// Device name/description
    pub device_name: String,
    /// Total memory in bytes
    pub memory_total: usize,
    /// Free memory in bytes
    pub memory_free: usize,
    /// Compute capability or equivalent
    pub compute_capability: f32,
    /// Number of compute units/SMs
    pub compute_units: usize,
    /// Maximum threads per block
    pub max_threads_per_block: usize,
    /// Whether this device is currently available
    pub is_available: bool,
}

impl Default for GPUDeviceInfo {
    fn default() -> Self {
        Self {
            device_id: -1,
            backend: GPUBackend::CPUFallback,
            device_name: "Unknown Device".to_string(),
            memory_total: 0,
            memory_free: 0,
            compute_capability: 0.0,
            compute_units: 0,
            max_threads_per_block: 0,
            is_available: false,
        }
    }
}

/// Multi-GPU workload distribution strategy
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkloadDistribution {
    /// Equal distribution across all devices
    Equal,
    /// Distribution based on device memory capacity
    MemoryBased,
    /// Distribution based on device compute capability
    ComputeBased,
    /// Manual distribution with specified ratios
    Manual,
    /// Adaptive distribution based on runtime performance
    Adaptive,
}

/// Multi-GPU configuration
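///
/// A minimal construction sketch (illustrative; the field values below are
/// arbitrary and only demonstrate overriding selected defaults):
///
/// ```ignore
/// use scirs2_fft::sparse_fft_multi_gpu::{MultiGPUConfig, WorkloadDistribution};
///
/// let config = MultiGPUConfig {
///     distribution: WorkloadDistribution::MemoryBased,
///     max_devices: 2,          // cap at two devices even if more are available
///     min_signal_size: 16_384, // stay single-device for smaller signals
///     ..MultiGPUConfig::default()
/// };
/// ```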
#[derive(Debug, Clone)]
pub struct MultiGPUConfig {
    /// Base sparse FFT configuration
    pub base_config: SparseFFTConfig,
    /// Workload distribution strategy
    pub distribution: WorkloadDistribution,
    /// Manual distribution ratios (if using Manual distribution)
    pub manual_ratios: Vec<f32>,
    /// Maximum number of devices to use (0 = use all available)
    pub max_devices: usize,
    /// Minimum signal size to enable multi-GPU processing
    pub min_signal_size: usize,
    /// Overlap between chunks for boundary handling
    pub chunk_overlap: usize,
    /// Enable load balancing between devices
    pub enable_load_balancing: bool,
    /// Timeout for device operations in milliseconds
    pub device_timeout_ms: u64,
}

impl Default for MultiGPUConfig {
    fn default() -> Self {
        Self {
            base_config: SparseFFTConfig::default(),
            distribution: WorkloadDistribution::ComputeBased,
            manual_ratios: Vec::new(),
            max_devices: 0,        // Use all available
            min_signal_size: 4096, // Only use multi-GPU for larger signals
            chunk_overlap: 0,
            enable_load_balancing: true,
            device_timeout_ms: 5000,
        }
    }
}

/// Multi-GPU sparse FFT processor
pub struct MultiGPUSparseFFT {
    /// Configuration
    _config: MultiGPUConfig,
    /// Available devices
    devices: Vec<GPUDeviceInfo>,
    /// Device selection for current operation
    selected_devices: Vec<usize>,
    /// Performance history for adaptive load balancing
    performance_history: Arc<Mutex<HashMap<i32, Vec<f64>>>>,
    /// Is multi-GPU initialized
    initialized: bool,
}

impl MultiGPUSparseFFT {
    /// Create a new multi-GPU sparse FFT processor
    pub fn new(config: MultiGPUConfig) -> Self {
        Self {
            _config: config,
            devices: Vec::new(),
            selected_devices: Vec::new(),
            performance_history: Arc::new(Mutex::new(HashMap::new())),
            initialized: false,
        }
    }

    /// Initialize multi-GPU system and enumerate devices
    pub fn initialize(&mut self) -> FFTResult<()> {
        if self.initialized {
            return Ok(());
        }

        // Enumerate all available GPU devices
        self.enumerate_devices()?;

        // Select devices based on configuration
        self.select_devices()?;

        self.initialized = true;
        Ok(())
    }

    /// Enumerate all available GPU devices
    fn enumerate_devices(&mut self) -> FFTResult<()> {
        self.devices.clear();

        // Enumerate CUDA devices
        if is_cuda_available() {
            self.enumerate_cuda_devices()?;
        }

        // Enumerate HIP devices
        if is_hip_available() {
            self.enumerate_hip_devices()?;
        }

        // Enumerate SYCL devices
        if is_sycl_available() {
            self.enumerate_sycl_devices()?;
        }

        // Add CPU fallback as last resort
        self.devices.push(GPUDeviceInfo {
            device_id: -1,
            backend: GPUBackend::CPUFallback,
            device_name: "CPU Fallback".to_string(),
            memory_total: 16 * 1024 * 1024 * 1024, // Assume 16GB RAM
            memory_free: 8 * 1024 * 1024 * 1024,   // Assume half available
            compute_capability: 1.0,
            compute_units: num_cpus::get(),
            max_threads_per_block: 1,
            is_available: true,
        });

        Ok(())
    }

    /// Enumerate CUDA devices
    fn enumerate_cuda_devices(&mut self) -> FFTResult<()> {
        // Initialize CUDA if available
        if init_cuda_device()? {
            // In a real implementation, this would query actual CUDA devices
            // For now, simulate one CUDA device
            self.devices.push(GPUDeviceInfo {
                device_id: 0,
                backend: GPUBackend::CUDA,
                device_name: "NVIDIA GPU (simulated)".to_string(),
                memory_total: 8 * 1024 * 1024 * 1024, // 8GB
                memory_free: 6 * 1024 * 1024 * 1024,  // 6GB free
                compute_capability: 8.6,
                compute_units: 68,
                max_threads_per_block: 1024,
                is_available: true,
            });
        }

        Ok(())
    }

    /// Enumerate HIP devices
    fn enumerate_hip_devices(&mut self) -> FFTResult<()> {
        // Initialize HIP if available
        if init_hip_device()? {
            // In a real implementation, this would query actual HIP devices
            // For now, simulate one HIP device
            self.devices.push(GPUDeviceInfo {
                device_id: 0,
                backend: GPUBackend::HIP,
                device_name: "AMD GPU (simulated)".to_string(),
                memory_total: 16 * 1024 * 1024 * 1024, // 16GB
                memory_free: 12 * 1024 * 1024 * 1024,  // 12GB free
                compute_capability: 10.3,              // GFX103x equivalent
                compute_units: 40,
                max_threads_per_block: 256,
                is_available: true,
            });
        }

        Ok(())
    }

    /// Enumerate SYCL devices
    fn enumerate_sycl_devices(&mut self) -> FFTResult<()> {
        // Initialize SYCL if available
        if init_sycl_device()? {
            // In a real implementation, this would query actual SYCL devices
            // For now, simulate one SYCL device
            self.devices.push(GPUDeviceInfo {
                device_id: 0,
                backend: GPUBackend::SYCL,
                device_name: "Intel GPU (simulated)".to_string(),
                memory_total: 4 * 1024 * 1024 * 1024, // 4GB
                memory_free: 3 * 1024 * 1024 * 1024,  // 3GB free
                compute_capability: 1.2,              // Intel GPU equivalent
                compute_units: 96,
                max_threads_per_block: 512,
                is_available: true,
            });
        }

        Ok(())
    }

    /// Select devices based on configuration
    fn select_devices(&mut self) -> FFTResult<()> {
        self.selected_devices.clear();

        // Filter available devices
        let available_devices: Vec<(usize, &GPUDeviceInfo)> = self
            .devices
            .iter()
            .enumerate()
            .filter(|(_, device)| device.is_available)
            .collect();

        if available_devices.is_empty() {
            return Err(FFTError::ComputationError(
                "No available GPU devices found".to_string(),
            ));
        }

        // Determine how many devices to use
        let max_devices = if self._config.max_devices == 0 {
            available_devices.len()
        } else {
            self._config.max_devices.min(available_devices.len())
        };

        // Select devices based on strategy
        match self._config.distribution {
            WorkloadDistribution::Equal => {
                // Use first N available devices
                for i in 0..max_devices {
                    self.selected_devices.push(available_devices[i].0);
                }
            }
            WorkloadDistribution::ComputeBased => {
                // Sort by compute capability and select top N
                let mut sorted_devices = available_devices;
                sorted_devices.sort_by(|a, b| {
                    b.1.compute_capability
                        .partial_cmp(&a.1.compute_capability)
                        .unwrap_or(std::cmp::Ordering::Equal)
                });

                for i in 0..max_devices {
                    self.selected_devices.push(sorted_devices[i].0);
                }
            }
            WorkloadDistribution::MemoryBased => {
                // Sort by available memory and select top N
                let mut sorted_devices = available_devices;
                sorted_devices.sort_by(|a, b| b.1.memory_free.cmp(&a.1.memory_free));

                for i in 0..max_devices {
                    self.selected_devices.push(sorted_devices[i].0);
                }
            }
            WorkloadDistribution::Manual => {
                // Use manual selection (for now, just use first N devices)
                for i in 0..max_devices {
                    self.selected_devices.push(available_devices[i].0);
                }
            }
            WorkloadDistribution::Adaptive => {
                // Clone available devices to avoid borrow issues
                let available_devices_clone: Vec<(usize, GPUDeviceInfo)> = available_devices
                    .iter()
                    .map(|(idx, device)| (*idx, (*device).clone()))
                    .collect();

                // Use performance history to select best devices
                self.select_adaptive_devices_with_clone(available_devices_clone, max_devices)?;
            }
        }

        Ok(())
    }

    /// Select devices based on adaptive performance history
    fn select_adaptive_devices_with_clone(
        &mut self,
        available_devices: Vec<(usize, GPUDeviceInfo)>,
        max_devices: usize,
    ) -> FFTResult<()> {
        let performance_history = self.performance_history.lock().unwrap();

        // Calculate average performance for each device
        let mut device_scores: Vec<(usize, f64)> = available_devices
            .iter()
            .map(|(idx, device)| {
                let avg_performance = performance_history
                    .get(&device.device_id)
                    .map(|times| {
                        if times.is_empty() {
                            // Default score based on device capabilities
                            device.compute_capability as f64 * device.compute_units as f64
                        } else {
                            // Higher score for faster devices (lower execution times)
                            1.0 / (times.iter().sum::<f64>() / times.len() as f64)
                        }
                    })
                    .unwrap_or_else(|| {
                        // Default score for devices without history
                        device.compute_capability as f64 * device.compute_units as f64
                    });

                (*idx, avg_performance)
            })
            .collect();

        // Sort by performance score (descending)
        device_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Select top N devices
        for i in 0..max_devices {
            self.selected_devices.push(device_scores[i].0);
        }

        Ok(())
    }

    /// Get information about available devices
    pub fn get_devices(&self) -> &[GPUDeviceInfo] {
        &self.devices
    }

    /// Get information about selected devices
    pub fn get_selected_devices(&self) -> Vec<&GPUDeviceInfo> {
        self.selected_devices
            .iter()
            .map(|&idx| &self.devices[idx])
            .collect()
    }

    /// Perform multi-GPU sparse FFT
    pub fn sparse_fft<T>(&mut self, signal: &[T]) -> FFTResult<SparseFFTResult>
    where
        T: NumCast + Copy + Debug + Send + Sync + 'static,
    {
        if !self.initialized {
            self.initialize()?;
        }

        let signal_len = signal.len();

        // Check if signal is large enough for multi-GPU processing
        if signal_len < self._config.min_signal_size || self.selected_devices.len() <= 1 {
            // Fall back to single-device processing
            return self.single_device_sparse_fft(signal);
        }

        // Distribute workload across selected devices
        self.multi_device_sparse_fft(signal)
    }

    /// Single-device sparse FFT fallback
    fn single_device_sparse_fft<T>(&mut self, signal: &[T]) -> FFTResult<SparseFFTResult>
    where
        T: NumCast + Copy + Debug + 'static,
    {
        // Use the best available device
        let device_idx = self.selected_devices.first().copied().unwrap_or(0);
        let device = &self.devices[device_idx];

        // Create GPU configuration for the selected device
        let gpu_config = GPUSparseFFTConfig {
            base_config: self._config.base_config.clone(),
            backend: device.backend,
            device_id: device.device_id,
            ..GPUSparseFFTConfig::default()
        };

        // Create GPU processor and perform computation
        let mut processor = crate::sparse_fft_gpu::GPUSparseFFT::new(gpu_config);
        processor.sparse_fft(signal)
    }

    /// Multi-device sparse FFT implementation
    fn multi_device_sparse_fft<T>(&mut self, signal: &[T]) -> FFTResult<SparseFFTResult>
    where
        T: NumCast + Copy + Debug + Send + Sync + 'static,
    {
        let signal_len = signal.len();
        let num_devices = self.selected_devices.len();

        // Calculate chunk sizes based on distribution strategy
        let chunk_sizes = self.calculate_chunk_sizes(signal_len, num_devices)?;

        // Split signal into chunks
        let chunks = self.split_signal(signal, &chunk_sizes)?;

        // Process chunks in parallel across devices
        let chunk_results: Result<Vec<_>, _> = chunks
            .par_iter()
            .zip(self.selected_devices.par_iter())
            .map(|(chunk, &device_idx)| {
                let device = &self.devices[device_idx];
                let start_time = Instant::now();

                // Create GPU configuration for this device
                let gpu_config = GPUSparseFFTConfig {
                    base_config: self._config.base_config.clone(),
                    backend: device.backend,
                    device_id: device.device_id,
                    ..GPUSparseFFTConfig::default()
                };

                // Process chunk
                let mut processor = crate::sparse_fft_gpu::GPUSparseFFT::new(gpu_config);
                let result = processor.sparse_fft(chunk);

                // Record performance for adaptive selection
                if result.is_ok() {
                    let execution_time = start_time.elapsed().as_secs_f64();
                    if let Ok(mut history) = self.performance_history.try_lock() {
                        history
                            .entry(device.device_id)
                            .or_default()
                            .push(execution_time);

                        // Keep only recent history (last 10 measurements)
                        if let Some(times) = history.get_mut(&device.device_id) {
                            if times.len() > 10 {
                                times.drain(0..times.len() - 10);
                            }
                        }
                    }
                }

                result
            })
            .collect();

        let chunk_results = chunk_results?;

        // Combine results from all chunks
        self.combine_chunk_results(chunk_results)
    }

    /// Calculate chunk sizes for workload distribution
    fn calculate_chunk_sizes(
        &self,
        signal_len: usize,
        num_devices: usize,
    ) -> FFTResult<Vec<usize>> {
        let mut chunk_sizes = Vec::with_capacity(num_devices);

        match self._config.distribution {
            WorkloadDistribution::Equal => {
                let base_size = signal_len / num_devices;
                let remainder = signal_len % num_devices;

                for i in 0..num_devices {
                    let size = if i < remainder {
                        base_size + 1
                    } else {
                        base_size
                    };
                    chunk_sizes.push(size);
                }
            }
            // ComputeBased distributes by relative compute throughput; Adaptive
            // currently falls back to the same sizing (the previous recursive
            // fallback never terminated because the strategy never changed).
            WorkloadDistribution::ComputeBased | WorkloadDistribution::Adaptive => {
                let total_compute: f32 = self
                    .selected_devices
                    .iter()
                    .map(|&idx| {
                        self.devices[idx].compute_capability
                            * self.devices[idx].compute_units as f32
                    })
                    .sum();

                let mut remaining = signal_len;
                for (i, &device_idx) in self.selected_devices.iter().enumerate() {
                    let device = &self.devices[device_idx];
                    let device_compute = device.compute_capability * device.compute_units as f32;
                    let ratio = device_compute / total_compute;

                    let size = if i == num_devices - 1 {
                        remaining // Give remainder to last device
                    } else {
                        let size = (signal_len as f32 * ratio) as usize;
                        remaining = remaining.saturating_sub(size);
                        size
                    };

                    chunk_sizes.push(size);
                }
            }
            WorkloadDistribution::MemoryBased => {
                // Distribute based on available memory
                let total_memory: usize = self
                    .selected_devices
                    .iter()
                    .map(|&idx| self.devices[idx].memory_free)
                    .sum();

                let mut remaining = signal_len;
                for (i, &device_idx) in self.selected_devices.iter().enumerate() {
                    let device = &self.devices[device_idx];
                    let ratio = device.memory_free as f32 / total_memory as f32;

                    let size = if i == num_devices - 1 {
                        remaining
                    } else {
                        let size = (signal_len as f32 * ratio) as usize;
                        remaining = remaining.saturating_sub(size);
                        size
                    };

                    chunk_sizes.push(size);
                }
            }
            WorkloadDistribution::Manual => {
                if self._config.manual_ratios.len() != num_devices {
                    return Err(FFTError::ValueError(
                        "Manual ratios length must match number of selected devices".to_string(),
                    ));
                }

                let total_ratio: f32 = self._config.manual_ratios.iter().sum();
                let mut remaining = signal_len;

                for (i, &ratio) in self._config.manual_ratios.iter().enumerate() {
                    let size = if i == num_devices - 1 {
                        remaining
                    } else {
                        let size = (signal_len as f32 * ratio / total_ratio) as usize;
                        remaining = remaining.saturating_sub(size);
                        size
                    };

                    chunk_sizes.push(size);
                }
            }
        }

        Ok(chunk_sizes)
    }

    /// Split signal into chunks based on calculated sizes
    fn split_signal<T>(&self, signal: &[T], chunk_sizes: &[usize]) -> FFTResult<Vec<Vec<T>>>
    where
        T: Copy,
    {
        let mut chunks = Vec::new();
        let mut offset = 0;

        for &chunk_size in chunk_sizes {
            if offset + chunk_size > signal.len() {
                return Err(FFTError::ValueError(
                    "Chunk sizes exceed signal length".to_string(),
                ));
            }

            let chunk_end = offset + chunk_size;
            let chunk = signal[offset..chunk_end].to_vec();
            chunks.push(chunk);
            offset = chunk_end;
        }

        Ok(chunks)
    }

    /// Combine results from multiple chunks
    fn combine_chunk_results(
        &self,
        chunk_results: Vec<SparseFFTResult>,
    ) -> FFTResult<SparseFFTResult> {
        if chunk_results.is_empty() {
            return Err(FFTError::ComputationError(
                "No chunk results to combine".to_string(),
            ));
        }

        if chunk_results.len() == 1 {
            return Ok(chunk_results.into_iter().next().unwrap());
        }

        // Use the computation time from the slowest device
        let max_computation_time = chunk_results
            .iter()
            .map(|r| r.computation_time)
            .max()
            .unwrap_or_default();

        // Combine frequency components from all chunks
        let mut combined_values = Vec::new();
        let mut combined_indices = Vec::new();
        let mut index_offset = 0;

        for result in chunk_results {
            // Store the indices length before moving
            let indices_len = result.indices.len();

            // Add values from this chunk
            combined_values.extend(result.values);

            // Adjust indices to account for chunk offset
            let adjusted_indices: Vec<usize> = result
                .indices
                .into_iter()
                .map(|idx| idx + index_offset)
                .collect();
            combined_indices.extend(adjusted_indices);

            // Update offset for next chunk
            // This is a simplified approach - in practice, frequency domain combining is more complex
            index_offset += indices_len;
        }

        // Remove duplicates and sort
        let mut frequency_map: HashMap<usize, Complex64> = HashMap::new();

        for (idx, value) in combined_indices.iter().zip(combined_values.iter()) {
            frequency_map.insert(*idx, *value);
        }

        let mut sorted_entries: Vec<_> = frequency_map.into_iter().collect();
        sorted_entries.sort_by_key(|&(idx, _)| idx);

        let final_indices: Vec<usize> = sorted_entries.iter().map(|(idx, _)| *idx).collect();
        let final_values: Vec<Complex64> = sorted_entries.iter().map(|(_, val)| *val).collect();

        // Calculate combined sparsity
        let total_estimated_sparsity = final_values.len();

        Ok(SparseFFTResult {
            values: final_values,
            indices: final_indices,
            estimated_sparsity: total_estimated_sparsity,
            computation_time: max_computation_time,
            algorithm: self._config.base_config.algorithm,
        })
    }

    /// Get performance statistics for each device
    pub fn get_performance_stats(&self) -> HashMap<i32, Vec<f64>> {
        self.performance_history.lock().unwrap().clone()
    }

    /// Reset performance history
    pub fn reset_performance_history(&mut self) {
        self.performance_history.lock().unwrap().clear();
    }
}

/// Convenience function for multi-GPU sparse FFT with default configuration
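///
/// A brief call sketch (illustrative; assumes the module paths used below are
/// public, so it is marked `ignore` rather than compiled as a doctest):
///
/// ```ignore
/// use scirs2_fft::sparse_fft::SparseFFTAlgorithm;
/// use scirs2_fft::sparse_fft_multi_gpu::multi_gpu_sparse_fft;
///
/// let signal: Vec<f64> = (0..4096).map(|i| (i as f64 * 0.05).sin()).collect();
/// // Look for up to 4 significant components; no window function.
/// let result = multi_gpu_sparse_fft(&signal, 4, Some(SparseFFTAlgorithm::Sublinear), None)
///     .expect("sparse FFT failed");
/// assert_eq!(result.values.len(), result.indices.len());
/// ```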
#[allow(dead_code)]
pub fn multi_gpu_sparse_fft<T>(
    signal: &[T],
    k: usize,
    algorithm: Option<SparseFFTAlgorithm>,
    window_function: Option<WindowFunction>,
) -> FFTResult<SparseFFTResult>
where
    T: NumCast + Copy + Debug + Send + Sync + 'static,
{
    let base_config = SparseFFTConfig {
        sparsity: k,
        algorithm: algorithm.unwrap_or(SparseFFTAlgorithm::Sublinear),
        window_function: window_function.unwrap_or(WindowFunction::None),
        ..SparseFFTConfig::default()
    };

    let config = MultiGPUConfig {
        base_config,
        ..MultiGPUConfig::default()
    };

    let mut processor = MultiGPUSparseFFT::new(config);
    processor.sparse_fft(signal)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    // Helper function to create a sparse signal
    fn create_sparse_signal(n: usize, frequencies: &[(usize, f64)]) -> Vec<f64> {
        let mut signal = vec![0.0; n];

        for i in 0..n {
            let t = 2.0 * PI * (i as f64) / (n as f64);
            for &(freq, amp) in frequencies {
                signal[i] += amp * (freq as f64 * t).sin();
            }
        }

        signal
    }

    #[test]
    fn test_multi_gpu_initialization() {
        let mut processor = MultiGPUSparseFFT::new(MultiGPUConfig::default());
        let result = processor.initialize();

        // Should succeed even if no GPU devices available (CPU fallback)
        assert!(result.is_ok());
        assert!(!processor.get_devices().is_empty());

        // Check if GPU is available
        let caps = PlatformCapabilities::detect();
        if !caps.cuda_available && !caps.gpu_available {
            eprintln!("GPU not available, verifying CPU fallback is present");
            assert!(processor
                .get_devices()
                .iter()
                .any(|d| d.backend == GPUBackend::CPUFallback));
        }
    }

    #[test]
    fn test_device_enumeration() {
        let mut processor = MultiGPUSparseFFT::new(MultiGPUConfig::default());
        processor.initialize().unwrap();

        let devices = processor.get_devices();
        assert!(!devices.is_empty());

        // Should have at least CPU fallback
        assert!(devices.iter().any(|d| d.backend == GPUBackend::CPUFallback));

        // Check if GPU is available
        let caps = PlatformCapabilities::detect();
        if caps.cuda_available || caps.gpu_available {
            eprintln!("GPU available, checking for GPU devices in enumeration");
            // May have GPU devices
            assert!(!devices.is_empty());
        } else {
            eprintln!("GPU not available, verifying only CPU fallback present");
            assert_eq!(devices.len(), 1);
            assert_eq!(devices[0].backend, GPUBackend::CPUFallback);
        }
    }

    #[test]
    fn test_multi_gpu_sparse_fft() {
        // Check if GPU is available
        let caps = PlatformCapabilities::detect();
        let n = if caps.cuda_available || caps.gpu_available {
            8192 // Large size for multi-GPU if available
        } else {
            eprintln!("GPU not available, using smaller size for CPU fallback");
            1024 // Smaller size for CPU fallback
        };

        let frequencies = vec![(10, 1.0), (50, 0.5), (100, 0.25)];
        let signal = create_sparse_signal(n, &frequencies);

        let result = multi_gpu_sparse_fft(
            &signal,
            6,
            Some(SparseFFTAlgorithm::Sublinear),
            Some(WindowFunction::Hann),
        );

        assert!(result.is_ok());
        let result = result.unwrap();
        assert!(!result.values.is_empty());
        assert_eq!(result.values.len(), result.indices.len());
    }

    #[test]
    fn test_chunk_size_calculation() {
        let config = MultiGPUConfig {
            distribution: WorkloadDistribution::Equal,
            ..MultiGPUConfig::default()
        };
        let mut processor = MultiGPUSparseFFT::new(config);

        // Simulate device setup
        processor.selected_devices = vec![0, 1, 2];

        let chunk_sizes = processor.calculate_chunk_sizes(1000, 3).unwrap();
        assert_eq!(chunk_sizes.len(), 3);
        assert_eq!(chunk_sizes.iter().sum::<usize>(), 1000);
    }
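
    // Illustrative additions (not part of the original suite): exercise Manual
    // workload distribution and split_signal's bounds checking; the ratio and
    // signal values below are arbitrary.
    #[test]
    fn test_manual_chunk_size_calculation() {
        let config = MultiGPUConfig {
            distribution: WorkloadDistribution::Manual,
            manual_ratios: vec![0.5, 0.3, 0.2],
            ..MultiGPUConfig::default()
        };
        let processor = MultiGPUSparseFFT::new(config);

        let chunk_sizes = processor.calculate_chunk_sizes(1000, 3).unwrap();
        assert_eq!(chunk_sizes.len(), 3);
        assert_eq!(chunk_sizes.iter().sum::<usize>(), 1000);
        // Larger ratios should receive at least as large a chunk.
        assert!(chunk_sizes[0] >= chunk_sizes[1] && chunk_sizes[1] >= chunk_sizes[2]);

        // A ratio list that does not match the device count is rejected.
        let bad_config = MultiGPUConfig {
            distribution: WorkloadDistribution::Manual,
            manual_ratios: vec![1.0],
            ..MultiGPUConfig::default()
        };
        let bad_processor = MultiGPUSparseFFT::new(bad_config);
        assert!(bad_processor.calculate_chunk_sizes(1000, 3).is_err());
    }

    #[test]
    fn test_signal_splitting_rejects_oversized_chunks() {
        let processor = MultiGPUSparseFFT::new(MultiGPUConfig::default());
        let signal = vec![1, 2, 3, 4, 5];
        // Requested chunks total 7 samples but the signal only has 5.
        let chunk_sizes = vec![3, 4];

        assert!(processor.split_signal(&signal, &chunk_sizes).is_err());
    }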

    #[test]
    fn test_signal_splitting() {
        let processor = MultiGPUSparseFFT::new(MultiGPUConfig::default());
        let signal = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let chunk_sizes = vec![3, 3, 4];

        let chunks = processor.split_signal(&signal, &chunk_sizes).unwrap();
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], vec![1, 2, 3]);
        assert_eq!(chunks[1], vec![4, 5, 6]);
        assert_eq!(chunks[2], vec![7, 8, 9, 10]);
    }
    }
851}