Skip to main content

scirs2_fft/
sparse_fft_multi_gpu.rs

1//! Multi-GPU Sparse FFT Implementation
2//!
3//! This module provides advanced multi-GPU support for sparse FFT operations,
4//! allowing parallel processing across multiple GPU devices for maximum performance.
5
6use crate::error::{FFTError, FFTResult};
7use crate::sparse_fft::{SparseFFTAlgorithm, SparseFFTConfig, SparseFFTResult, WindowFunction};
8use crate::sparse_fft_gpu::{GPUBackend, GPUSparseFFTConfig};
9use crate::sparse_fft_gpu_memory::{
10    init_cuda_device, init_hip_device, init_sycl_device, is_cuda_available, is_hip_available,
11    is_sycl_available,
12};
13use scirs2_core::numeric::Complex64;
14use scirs2_core::numeric::NumCast;
15use scirs2_core::parallel_ops::*;
16use scirs2_core::simd_ops::PlatformCapabilities;
17use std::collections::HashMap;
18use std::fmt::Debug;
19use std::sync::{Arc, Mutex};
20use std::time::Instant;
21
22/// Information about an available GPU device
23#[derive(Debug, Clone)]
24pub struct GPUDeviceInfo {
25    /// Device ID
26    pub device_id: i32,
27    /// Backend type
28    pub backend: GPUBackend,
29    /// Device name/description
30    pub device_name: String,
31    /// Available memory in bytes
32    pub memory_total: usize,
33    /// Free memory in bytes
34    pub memory_free: usize,
35    /// Compute capability or equivalent
36    pub compute_capability: f32,
37    /// Number of compute units/SMs
38    pub compute_units: usize,
39    /// Maximum threads per block
40    pub max_threads_per_block: usize,
41    /// Is this device currently available
42    pub is_available: bool,
43}
44
45impl Default for GPUDeviceInfo {
46    fn default() -> Self {
47        Self {
48            device_id: -1,
49            backend: GPUBackend::CPUFallback,
50            device_name: "Unknown Device".to_string(),
51            memory_total: 0,
52            memory_free: 0,
53            compute_capability: 0.0,
54            compute_units: 0,
55            max_threads_per_block: 0,
56            is_available: false,
57        }
58    }
59}
60
/// Strategy used by the multi-GPU sparse FFT both to pick devices and to
/// decide how many signal samples each selected device receives.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkloadDistribution {
    /// Every selected device receives (nearly) the same number of samples.
    Equal,
    /// Devices are ranked and sized by their free memory.
    MemoryBased,
    /// Devices are ranked by compute capability and sized by
    /// compute capability × compute units.
    ComputeBased,
    /// Chunk sizes follow the user-supplied `MultiGPUConfig::manual_ratios`.
    Manual,
    /// Devices are ranked using recorded execution times where available.
    Adaptive,
}
75
76/// Multi-GPU configuration
77#[derive(Debug, Clone)]
78pub struct MultiGPUConfig {
79    /// Base sparse FFT configuration
80    pub base_config: SparseFFTConfig,
81    /// Workload distribution strategy
82    pub distribution: WorkloadDistribution,
83    /// Manual distribution ratios (if using Manual distribution)
84    pub manual_ratios: Vec<f32>,
85    /// Maximum number of devices to use (0 = use all available)
86    pub max_devices: usize,
87    /// Minimum signal size to enable multi-GPU processing
88    pub min_signal_size: usize,
89    /// Overlap between chunks for boundary handling
90    pub chunk_overlap: usize,
91    /// Enable load balancing between devices
92    pub enable_load_balancing: bool,
93    /// Timeout for device operations in milliseconds
94    pub device_timeout_ms: u64,
95}
96
97impl Default for MultiGPUConfig {
98    fn default() -> Self {
99        Self {
100            base_config: SparseFFTConfig::default(),
101            distribution: WorkloadDistribution::ComputeBased,
102            manual_ratios: Vec::new(),
103            max_devices: 0,        // Use all available
104            min_signal_size: 4096, // Only use multi-GPU for larger signals
105            chunk_overlap: 0,
106            enable_load_balancing: true,
107            device_timeout_ms: 5000,
108        }
109    }
110}
111
112/// Multi-GPU sparse FFT processor
113pub struct MultiGPUSparseFFT {
114    /// Configuration
115    _config: MultiGPUConfig,
116    /// Available devices
117    devices: Vec<GPUDeviceInfo>,
118    /// Device selection for current operation
119    selected_devices: Vec<usize>,
120    /// Performance history for adaptive load balancing
121    performance_history: Arc<Mutex<HashMap<i32, Vec<f64>>>>,
122    /// Is multi-GPU initialized
123    initialized: bool,
124}
125
126impl MultiGPUSparseFFT {
127    /// Create a new multi-GPU sparse FFT processor
128    pub fn new(config: MultiGPUConfig) -> Self {
129        Self {
130            _config: config,
131            devices: Vec::new(),
132            selected_devices: Vec::new(),
133            performance_history: Arc::new(Mutex::new(HashMap::new())),
134            initialized: false,
135        }
136    }
137
138    /// Initialize multi-GPU system and enumerate devices
139    pub fn initialize(&mut self) -> FFTResult<()> {
140        if self.initialized {
141            return Ok(());
142        }
143
144        // Enumerate all available GPU devices
145        self.enumerate_devices()?;
146
147        // Select devices based on configuration
148        self.select_devices()?;
149
150        self.initialized = true;
151        Ok(())
152    }
153
154    /// Enumerate all available GPU devices
155    fn enumerate_devices(&mut self) -> FFTResult<()> {
156        self.devices.clear();
157
158        // Enumerate CUDA devices
159        if is_cuda_available() {
160            self.enumerate_cuda_devices()?;
161        }
162
163        // Enumerate HIP devices
164        if is_hip_available() {
165            self.enumerate_hip_devices()?;
166        }
167
168        // Enumerate SYCL devices
169        if is_sycl_available() {
170            self.enumerate_sycl_devices()?;
171        }
172
173        // Add CPU fallback as last resort
174        #[cfg(target_pointer_width = "32")]
175        let (memory_total, memory_free) = (1024 * 1024 * 1024, 512 * 1024 * 1024); // 1GB total, 512MB free for 32-bit
176        #[cfg(target_pointer_width = "64")]
177        let (memory_total, memory_free) =
178            (16usize * 1024 * 1024 * 1024, 8usize * 1024 * 1024 * 1024); // 16GB total, 8GB free for 64-bit
179
180        self.devices.push(GPUDeviceInfo {
181            device_id: -1,
182            backend: GPUBackend::CPUFallback,
183            device_name: "CPU Fallback".to_string(),
184            memory_total,
185            memory_free,
186            compute_capability: 1.0,
187            compute_units: num_cpus::get(),
188            max_threads_per_block: 1,
189            is_available: true,
190        });
191
192        Ok(())
193    }
194
195    /// Enumerate CUDA devices
196    fn enumerate_cuda_devices(&mut self) -> FFTResult<()> {
197        // Initialize CUDA if available
198        if init_cuda_device()? {
199            // In a real implementation, this would query actual CUDA devices
200            // For now, simulate one CUDA device
201            #[cfg(target_pointer_width = "32")]
202            let (memory_total, memory_free) = (512 * 1024 * 1024, 384 * 1024 * 1024); // 512MB total, 384MB free for 32-bit
203            #[cfg(target_pointer_width = "64")]
204            let (memory_total, memory_free) =
205                (8usize * 1024 * 1024 * 1024, 6usize * 1024 * 1024 * 1024); // 8GB total, 6GB free for 64-bit
206
207            self.devices.push(GPUDeviceInfo {
208                device_id: 0,
209                backend: GPUBackend::CUDA,
210                device_name: "NVIDIA GPU (simulated)".to_string(),
211                memory_total,
212                memory_free,
213                compute_capability: 8.6,
214                compute_units: 68,
215                max_threads_per_block: 1024,
216                is_available: true,
217            });
218        }
219
220        Ok(())
221    }
222
223    /// Enumerate HIP devices
224    fn enumerate_hip_devices(&mut self) -> FFTResult<()> {
225        // Initialize HIP if available
226        if init_hip_device()? {
227            // In a real implementation, this would query actual HIP devices
228            // For now, simulate one HIP device
229            #[cfg(target_pointer_width = "32")]
230            let (memory_total, memory_free) = (1024 * 1024 * 1024, 768 * 1024 * 1024); // 1GB total, 768MB free for 32-bit
231            #[cfg(target_pointer_width = "64")]
232            let (memory_total, memory_free) =
233                (16usize * 1024 * 1024 * 1024, 12usize * 1024 * 1024 * 1024); // 16GB total, 12GB free for 64-bit
234
235            self.devices.push(GPUDeviceInfo {
236                device_id: 0,
237                backend: GPUBackend::HIP,
238                device_name: "AMD GPU (simulated)".to_string(),
239                memory_total,
240                memory_free,
241                compute_capability: 10.3, // GFX103x equivalent
242                compute_units: 40,
243                max_threads_per_block: 256,
244                is_available: true,
245            });
246        }
247
248        Ok(())
249    }
250
251    /// Enumerate SYCL devices
252    fn enumerate_sycl_devices(&mut self) -> FFTResult<()> {
253        // Initialize SYCL if available
254        if init_sycl_device()? {
255            // In a real implementation, this would query actual SYCL devices
256            // For now, simulate one SYCL device
257            #[cfg(target_pointer_width = "32")]
258            let (memory_total, memory_free) = (256 * 1024 * 1024, 192 * 1024 * 1024); // 256MB total, 192MB free for 32-bit
259            #[cfg(target_pointer_width = "64")]
260            let (memory_total, memory_free) =
261                (4usize * 1024 * 1024 * 1024, 3usize * 1024 * 1024 * 1024); // 4GB total, 3GB free for 64-bit
262
263            self.devices.push(GPUDeviceInfo {
264                device_id: 0,
265                backend: GPUBackend::SYCL,
266                device_name: "Intel GPU (simulated)".to_string(),
267                memory_total,
268                memory_free,
269                compute_capability: 1.2, // Intel GPU equivalent
270                compute_units: 96,
271                max_threads_per_block: 512,
272                is_available: true,
273            });
274        }
275
276        Ok(())
277    }
278
279    /// Select devices based on configuration
280    fn select_devices(&mut self) -> FFTResult<()> {
281        self.selected_devices.clear();
282
283        // Filter available devices
284        let available_devices: Vec<(usize, &GPUDeviceInfo)> = self
285            .devices
286            .iter()
287            .enumerate()
288            .filter(|(_, device)| device.is_available)
289            .collect();
290
291        if available_devices.is_empty() {
292            return Err(FFTError::ComputationError(
293                "No available GPU devices found".to_string(),
294            ));
295        }
296
297        // Determine how many devices to use
298        let max_devices = if self._config.max_devices == 0 {
299            available_devices.len()
300        } else {
301            self._config.max_devices.min(available_devices.len())
302        };
303
304        // Select devices based on strategy
305        match self._config.distribution {
306            WorkloadDistribution::Equal => {
307                // Use first N available devices
308                for i in 0..max_devices {
309                    self.selected_devices.push(available_devices[i].0);
310                }
311            }
312            WorkloadDistribution::ComputeBased => {
313                // Sort by compute capability and select top N
314                let mut sorted_devices = available_devices;
315                sorted_devices.sort_by(|a, b| {
316                    b.1.compute_capability
317                        .partial_cmp(&a.1.compute_capability)
318                        .unwrap_or(std::cmp::Ordering::Equal)
319                });
320
321                for i in 0..max_devices {
322                    self.selected_devices.push(sorted_devices[i].0);
323                }
324            }
325            WorkloadDistribution::MemoryBased => {
326                // Sort by available memory and select top N
327                let mut sorted_devices = available_devices;
328                sorted_devices.sort_by_key(|item| std::cmp::Reverse(item.1.memory_free));
329
330                for i in 0..max_devices {
331                    self.selected_devices.push(sorted_devices[i].0);
332                }
333            }
334            WorkloadDistribution::Manual => {
335                // Use manual selection (for now, just use first N devices)
336                for i in 0..max_devices {
337                    self.selected_devices.push(available_devices[i].0);
338                }
339            }
340            WorkloadDistribution::Adaptive => {
341                // Clone available devices to avoid borrow issues
342                let available_devices_clone: Vec<(usize, GPUDeviceInfo)> = available_devices
343                    .iter()
344                    .map(|(idx, device)| (*idx, (*device).clone()))
345                    .collect();
346
347                // Use performance history to select best devices
348                self.select_adaptive_devices_with_clone(available_devices_clone, max_devices)?;
349            }
350        }
351
352        Ok(())
353    }
354
355    /// Select devices based on adaptive performance history
356    fn select_adaptive_devices_with_clone(
357        &mut self,
358        available_devices: Vec<(usize, GPUDeviceInfo)>,
359        max_devices: usize,
360    ) -> FFTResult<()> {
361        let performance_history = self.performance_history.lock().expect("Operation failed");
362
363        // Calculate average performance for each device
364        let mut device_scores: Vec<(usize, f64)> = available_devices
365            .iter()
366            .map(|(idx, device)| {
367                let avg_performance = performance_history
368                    .get(&device.device_id)
369                    .map(|times| {
370                        if times.is_empty() {
371                            // Default score based on device capabilities
372                            device.compute_capability as f64 * device.compute_units as f64
373                        } else {
374                            // Higher score for faster _devices (lower execution times)
375                            1.0 / (times.iter().sum::<f64>() / times.len() as f64)
376                        }
377                    })
378                    .unwrap_or_else(|| {
379                        // Default score for _devices without history
380                        device.compute_capability as f64 * device.compute_units as f64
381                    });
382
383                (*idx, avg_performance)
384            })
385            .collect();
386
387        // Sort by performance score (descending)
388        device_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
389
390        // Select top N _devices
391        for i in 0..max_devices {
392            self.selected_devices.push(device_scores[i].0);
393        }
394
395        Ok(())
396    }
397
398    /// Get information about available devices
399    pub fn get_devices(&self) -> &[GPUDeviceInfo] {
400        &self.devices
401    }
402
403    /// Get information about selected devices
404    pub fn get_selected_devices(&self) -> Vec<&GPUDeviceInfo> {
405        self.selected_devices
406            .iter()
407            .map(|&idx| &self.devices[idx])
408            .collect()
409    }
410
411    /// Perform multi-GPU sparse FFT
412    pub fn sparse_fft<T>(&mut self, signal: &[T]) -> FFTResult<SparseFFTResult>
413    where
414        T: NumCast + Copy + Debug + Send + Sync + 'static,
415    {
416        if !self.initialized {
417            self.initialize()?;
418        }
419
420        let signal_len = signal.len();
421
422        // Check if signal is large enough for multi-GPU processing
423        if signal_len < self._config.min_signal_size || self.selected_devices.len() <= 1 {
424            // Fall back to single-device processing
425            return self.single_device_sparse_fft(signal);
426        }
427
428        // Distribute workload across selected devices
429        self.multi_device_sparse_fft(signal)
430    }
431
432    /// Single-device sparse FFT fallback
433    fn single_device_sparse_fft<T>(&mut self, signal: &[T]) -> FFTResult<SparseFFTResult>
434    where
435        T: NumCast + Copy + Debug + 'static,
436    {
437        // Use the best available device
438        let device_idx = self.selected_devices.first().copied().unwrap_or(0);
439        let device = &self.devices[device_idx];
440
441        // Create GPU configuration for the selected device
442        let gpu_config = GPUSparseFFTConfig {
443            base_config: self._config.base_config.clone(),
444            backend: device.backend,
445            device_id: device.device_id,
446            ..GPUSparseFFTConfig::default()
447        };
448
449        // Create GPU processor and perform computation
450        let mut processor = crate::sparse_fft_gpu::GPUSparseFFT::new(gpu_config);
451        processor.sparse_fft(signal)
452    }
453
454    /// Multi-device sparse FFT implementation
455    fn multi_device_sparse_fft<T>(&mut self, signal: &[T]) -> FFTResult<SparseFFTResult>
456    where
457        T: NumCast + Copy + Debug + Send + Sync + 'static,
458    {
459        let signal_len = signal.len();
460        let num_devices = self.selected_devices.len();
461
462        // Calculate chunk sizes based on distribution strategy
463        let chunk_sizes = self.calculate_chunk_sizes(signal_len, num_devices)?;
464
465        // Split signal into chunks
466        let chunks = self.split_signal(signal, &chunk_sizes)?;
467
468        // Process chunks in parallel across devices
469        let chunk_results: Result<Vec<_>, _> = chunks
470            .par_iter()
471            .zip(self.selected_devices.par_iter())
472            .map(|(chunk, &device_idx)| {
473                let device = &self.devices[device_idx];
474                let start_time = Instant::now();
475
476                // Create GPU configuration for this device
477                let gpu_config = GPUSparseFFTConfig {
478                    base_config: self._config.base_config.clone(),
479                    backend: device.backend,
480                    device_id: device.device_id,
481                    ..GPUSparseFFTConfig::default()
482                };
483
484                // Process chunk
485                let mut processor = crate::sparse_fft_gpu::GPUSparseFFT::new(gpu_config);
486                let result = processor.sparse_fft(chunk);
487
488                // Record performance for adaptive selection
489                if result.is_ok() {
490                    let execution_time = start_time.elapsed().as_secs_f64();
491                    if let Ok(mut history) = self.performance_history.try_lock() {
492                        history
493                            .entry(device.device_id)
494                            .or_default()
495                            .push(execution_time);
496
497                        // Keep only recent history (last 10 measurements)
498                        if let Some(times) = history.get_mut(&device.device_id) {
499                            if times.len() > 10 {
500                                times.drain(0..times.len() - 10);
501                            }
502                        }
503                    }
504                }
505
506                result
507            })
508            .collect();
509
510        let chunk_results = chunk_results?;
511
512        // Combine results from all chunks
513        self.combine_chunk_results(chunk_results)
514    }
515
516    /// Calculate chunk sizes for workload distribution
517    fn calculate_chunk_sizes(
518        &self,
519        signal_len: usize,
520        num_devices: usize,
521    ) -> FFTResult<Vec<usize>> {
522        let mut chunk_sizes = Vec::with_capacity(num_devices);
523
524        match self._config.distribution {
525            WorkloadDistribution::Equal => {
526                let base_size = signal_len / num_devices;
527                let remainder = signal_len % num_devices;
528
529                for i in 0..num_devices {
530                    let size = if i < remainder {
531                        base_size + 1
532                    } else {
533                        base_size
534                    };
535                    chunk_sizes.push(size);
536                }
537            }
538            WorkloadDistribution::ComputeBased => {
539                // Distribute based on compute capability
540                let total_compute: f32 = self
541                    .selected_devices
542                    .iter()
543                    .map(|&idx| {
544                        self.devices[idx].compute_capability
545                            * self.devices[idx].compute_units as f32
546                    })
547                    .sum();
548
549                let mut remaining = signal_len;
550                for (i, &device_idx) in self.selected_devices.iter().enumerate() {
551                    let device = &self.devices[device_idx];
552                    let device_compute = device.compute_capability * device.compute_units as f32;
553                    let ratio = device_compute / total_compute;
554
555                    let size = if i == num_devices - 1 {
556                        remaining // Give remainder to last device
557                    } else {
558                        let size = (signal_len as f32 * ratio) as usize;
559                        remaining = remaining.saturating_sub(size);
560                        size
561                    };
562
563                    chunk_sizes.push(size);
564                }
565            }
566            WorkloadDistribution::MemoryBased => {
567                // Distribute based on available memory
568                let total_memory: usize = self
569                    .selected_devices
570                    .iter()
571                    .map(|&idx| self.devices[idx].memory_free)
572                    .sum();
573
574                let mut remaining = signal_len;
575                for (i, &device_idx) in self.selected_devices.iter().enumerate() {
576                    let device = &self.devices[device_idx];
577                    let ratio = device.memory_free as f32 / total_memory as f32;
578
579                    let size = if i == num_devices - 1 {
580                        remaining
581                    } else {
582                        let size = (signal_len as f32 * ratio) as usize;
583                        remaining = remaining.saturating_sub(size);
584                        size
585                    };
586
587                    chunk_sizes.push(size);
588                }
589            }
590            WorkloadDistribution::Manual => {
591                if self._config.manual_ratios.len() != num_devices {
592                    return Err(FFTError::ValueError(
593                        "Manual ratios length must match number of selected _devices".to_string(),
594                    ));
595                }
596
597                let total_ratio: f32 = self._config.manual_ratios.iter().sum();
598                let mut remaining = signal_len;
599
600                for (i, &ratio) in self._config.manual_ratios.iter().enumerate() {
601                    let size = if i == num_devices - 1 {
602                        remaining
603                    } else {
604                        let size = (signal_len as f32 * ratio / total_ratio) as usize;
605                        remaining = remaining.saturating_sub(size);
606                        size
607                    };
608
609                    chunk_sizes.push(size);
610                }
611            }
612            WorkloadDistribution::Adaptive => {
613                // Use performance history to determine optimal distribution
614                // For now, fall back to compute-based distribution
615                return self.calculate_chunk_sizes(signal_len, num_devices);
616            }
617        }
618
619        Ok(chunk_sizes)
620    }
621
622    /// Split signal into chunks based on calculated sizes
623    fn split_signal<T>(&self, signal: &[T], chunksizes: &[usize]) -> FFTResult<Vec<Vec<T>>>
624    where
625        T: Copy,
626    {
627        let mut chunks = Vec::new();
628        let mut offset = 0;
629
630        for &chunk_size in chunksizes {
631            if offset + chunk_size > signal.len() {
632                return Err(FFTError::ValueError(
633                    "Chunk sizes exceed signal length".to_string(),
634                ));
635            }
636
637            let chunk_end = offset + chunk_size;
638            let chunk = signal[offset..chunk_end].to_vec();
639            chunks.push(chunk);
640            offset = chunk_end;
641        }
642
643        Ok(chunks)
644    }
645
646    /// Combine results from multiple chunks
647    fn combine_chunk_results(
648        &self,
649        chunk_results: Vec<SparseFFTResult>,
650    ) -> FFTResult<SparseFFTResult> {
651        if chunk_results.is_empty() {
652            return Err(FFTError::ComputationError(
653                "No chunk _results to combine".to_string(),
654            ));
655        }
656
657        if chunk_results.len() == 1 {
658            return Ok(chunk_results.into_iter().next().expect("Operation failed"));
659        }
660
661        // Use the computation time from the slowest device
662        let max_computation_time = chunk_results
663            .iter()
664            .map(|r| r.computation_time)
665            .max()
666            .unwrap_or_default();
667
668        // Combine frequency components from all chunks
669        let mut combined_values = Vec::new();
670        let mut combined_indices = Vec::new();
671        let mut index_offset = 0;
672
673        for result in chunk_results {
674            // Store the indices length before moving
675            let indices_len = result.indices.len();
676
677            // Add values from this chunk
678            combined_values.extend(result.values);
679
680            // Adjust indices to account for chunk offset
681            let adjusted_indices: Vec<usize> = result
682                .indices
683                .into_iter()
684                .map(|idx| idx + index_offset)
685                .collect();
686            combined_indices.extend(adjusted_indices);
687
688            // Update offset for next chunk
689            // This is a simplified approach - in practice, frequency domain combining is more complex
690            index_offset += indices_len;
691        }
692
693        // Remove duplicates and sort
694        let mut frequency_map: std::collections::HashMap<usize, Complex64> =
695            std::collections::HashMap::new();
696
697        for (idx, value) in combined_indices.iter().zip(combined_values.iter()) {
698            frequency_map.insert(*idx, *value);
699        }
700
701        let mut sorted_entries: Vec<_> = frequency_map.into_iter().collect();
702        sorted_entries.sort_by_key(|&(idx_, _)| idx_);
703
704        let final_indices: Vec<usize> = sorted_entries.iter().map(|(idx_, _)| *idx_).collect();
705        let final_values: Vec<Complex64> = sorted_entries.iter().map(|(_, val)| *val).collect();
706
707        // Calculate combined sparsity
708        let total_estimated_sparsity = final_values.len();
709
710        Ok(SparseFFTResult {
711            values: final_values,
712            indices: final_indices,
713            estimated_sparsity: total_estimated_sparsity,
714            computation_time: max_computation_time,
715            algorithm: self._config.base_config.algorithm,
716        })
717    }
718
719    /// Get performance statistics for each device
720    pub fn get_performance_stats(&self) -> HashMap<i32, Vec<f64>> {
721        self.performance_history
722            .lock()
723            .expect("Operation failed")
724            .clone()
725    }
726
727    /// Reset performance history
728    pub fn reset_performance_history(&mut self) {
729        self.performance_history
730            .lock()
731            .expect("Operation failed")
732            .clear();
733    }
734}
735
736/// Convenience function for multi-GPU sparse FFT with default configuration
737#[allow(dead_code)]
738pub fn multi_gpu_sparse_fft<T>(
739    signal: &[T],
740    k: usize,
741    algorithm: Option<SparseFFTAlgorithm>,
742    window_function: Option<WindowFunction>,
743) -> FFTResult<SparseFFTResult>
744where
745    T: NumCast + Copy + Debug + Send + Sync + 'static,
746{
747    let base_config = SparseFFTConfig {
748        sparsity: k,
749        algorithm: algorithm.unwrap_or(SparseFFTAlgorithm::Sublinear),
750        window_function: window_function.unwrap_or(WindowFunction::None),
751        ..SparseFFTConfig::default()
752    };
753
754    let config = MultiGPUConfig {
755        base_config,
756        ..MultiGPUConfig::default()
757    };
758
759    let mut processor = MultiGPUSparseFFT::new(config);
760    processor.sparse_fft(signal)
761}
762
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Sum-of-sines test signal: sample `i` of `n` equals
    /// `Σ amp · sin(2π · freq · i / n)` over the `(freq, amp)` pairs.
    fn create_sparse_signal(n: usize, frequencies: &[(usize, f64)]) -> Vec<f64> {
        (0..n)
            .map(|i| {
                let phase = 2.0 * PI * (i as f64) / (n as f64);
                frequencies
                    .iter()
                    .fold(0.0, |sample, &(freq, amp)| {
                        sample + amp * (freq as f64 * phase).sin()
                    })
            })
            .collect()
    }

    /// Whether the platform reports a CUDA or other GPU.
    fn gpu_reported() -> bool {
        let caps = PlatformCapabilities::detect();
        caps.cuda_available || caps.gpu_available
    }

    #[test]
    fn test_multi_gpu_initialization() {
        let mut processor = MultiGPUSparseFFT::new(MultiGPUConfig::default());

        // The CPU fallback is always enumerated, so this succeeds without a GPU.
        let outcome = processor.initialize();
        assert!(outcome.is_ok());
        assert!(!processor.get_devices().is_empty());

        if !gpu_reported() {
            eprintln!("GPU not available, verifying CPU fallback is present");
            let has_fallback = processor
                .get_devices()
                .iter()
                .any(|d| d.backend == GPUBackend::CPUFallback);
            assert!(has_fallback);
        }
    }

    #[test]
    fn test_device_enumeration() {
        let mut processor = MultiGPUSparseFFT::new(MultiGPUConfig::default());
        processor.initialize().expect("Operation failed");

        let devices = processor.get_devices();
        assert!(!devices.is_empty());
        assert!(devices.iter().any(|d| d.backend == GPUBackend::CPUFallback));

        if gpu_reported() {
            eprintln!("GPU available, checking for GPU devices in enumeration");
            assert!(!devices.is_empty());
        } else {
            eprintln!("GPU not available, verifying only CPU fallback present");
            assert_eq!(devices.len(), 1);
            assert_eq!(devices[0].backend, GPUBackend::CPUFallback);
        }
    }

    #[test]
    fn test_multi_gpu_sparse_fft() {
        // Large enough to exercise the multi-device path when a GPU exists.
        let n = if gpu_reported() {
            8192
        } else {
            eprintln!("GPU not available, using smaller size for CPU fallback");
            1024
        };

        let signal = create_sparse_signal(n, &[(10, 1.0), (50, 0.5), (100, 0.25)]);

        let outcome = multi_gpu_sparse_fft(
            &signal,
            6,
            Some(SparseFFTAlgorithm::Sublinear),
            Some(WindowFunction::Hann),
        );
        assert!(outcome.is_ok());

        let spectrum = outcome.expect("Operation failed");
        assert!(!spectrum.values.is_empty());
        assert_eq!(spectrum.values.len(), spectrum.indices.len());
    }

    #[test]
    fn test_chunk_size_calculation() {
        let mut processor = MultiGPUSparseFFT::new(MultiGPUConfig {
            distribution: WorkloadDistribution::Equal,
            ..MultiGPUConfig::default()
        });

        // Pretend three devices were selected; Equal sizing never inspects them.
        processor.selected_devices = vec![0, 1, 2];

        let sizes = processor
            .calculate_chunk_sizes(1000, 3)
            .expect("Operation failed");
        assert_eq!(sizes.len(), 3);
        assert_eq!(sizes.iter().sum::<usize>(), 1000);
    }

    #[test]
    fn test_signal_splitting() {
        let processor = MultiGPUSparseFFT::new(MultiGPUConfig::default());
        let signal: Vec<i32> = (1..=10).collect();

        let chunks = processor
            .split_signal(&signal, &[3, 3, 4])
            .expect("Operation failed");
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], vec![1, 2, 3]);
        assert_eq!(chunks[1], vec![4, 5, 6]);
        assert_eq!(chunks[2], vec![7, 8, 9, 10]);
    }
}