1use crate::error::{FFTError, FFTResult};
7use crate::sparse_fft::{
8 SparseFFTAlgorithm, SparseFFTConfig, SparseFFTResult, SparsityEstimationMethod, WindowFunction,
9};
10use scirs2_core::gpu::{GpuBackend, GpuDevice};
11use scirs2_core::numeric::Complex64;
12use scirs2_core::numeric::NumCast;
13use scirs2_core::simd_ops::PlatformCapabilities;
14use std::fmt::Debug;
15use std::time::Instant;
16
#[allow(dead_code)]
/// Describes one GPU buffer by its size and an identifier.
///
/// `Debug` is derived so descriptors can appear in diagnostics. `Clone` is
/// intentionally not derived: a descriptor is consumed by value when it is
/// freed, so duplicating one would allow the same buffer to be freed twice.
#[derive(Debug)]
pub struct BufferDescriptor {
    /// Size of the buffer in bytes (the copy helpers in this file compare it
    /// against host byte counts).
    size: usize,
    /// Identifier of the buffer.
    id: u64,
}
23
/// Where a buffer is requested to live.
///
/// The common value-type traits are derived so the location can be printed,
/// copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferLocation {
    /// Device side; presumably GPU-resident memory (confirm against the real
    /// memory manager once it exists).
    Device,
    /// Host side; presumably CPU-accessible memory (confirm as above).
    Host,
}
29
/// Role a buffer plays in a computation.
///
/// The common value-type traits are derived so the role can be printed,
/// copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferType {
    /// Buffer holding input data (used in this file for the signal buffer).
    Input,
    /// Buffer holding results (used in this file for output values/indices).
    Output,
    /// General-purpose buffer (used in this file by `FftGpuContext::allocate`).
    Work,
}
36
#[allow(dead_code)]
/// Handle for a GPU stream, identified by a numeric id.
///
/// `Debug` is derived so the handle can appear in diagnostics.
#[derive(Debug)]
pub struct GpuStream {
    /// Identifier of the stream.
    id: u64,
}
42
43impl GpuStream {
44 pub fn new(_deviceid: i32) -> FFTResult<Self> {
45 Err(FFTError::NotImplementedError(
46 "GPU streams need to be implemented with scirs2-core::gpu abstractions".to_string(),
47 ))
48 }
49}
50
/// Stateless handle through which GPU buffers are allocated and freed.
///
/// `Debug` is derived so the handle can appear in diagnostics.
#[derive(Debug)]
pub struct GpuMemoryManager;
53
54impl GpuMemoryManager {
55 pub fn allocate(
56 &self,
57 _size: usize,
58 _location: BufferLocation,
59 _buffer_type: BufferType,
60 ) -> FFTResult<BufferDescriptor> {
61 Err(FFTError::NotImplementedError(
62 "GPU memory management needs to be implemented with scirs2-core::gpu abstractions"
63 .to_string(),
64 ))
65 }
66
67 pub fn free(_descriptor: BufferDescriptor) -> FFTResult<()> {
68 Err(FFTError::NotImplementedError(
69 "GPU memory management needs to be implemented with scirs2-core::gpu abstractions"
70 .to_string(),
71 ))
72 }
73}
74
75#[allow(dead_code)]
77pub fn get_global_memory_manager() -> FFTResult<GpuMemoryManager> {
78 Err(FFTError::NotImplementedError(
79 "GPU memory management needs to be implemented with scirs2-core::gpu abstractions"
80 .to_string(),
81 ))
82}
83
84#[allow(dead_code)]
86pub fn ensure_gpu_available() -> FFTResult<bool> {
87 let caps = PlatformCapabilities::detect();
88 Ok(caps.cuda_available || caps.gpu_available)
89}
90
/// Information about a single GPU device as tracked by this module.
pub struct GpuDeviceInfo {
    /// Device handle; `GpuDeviceInfo::new` builds it with `GpuDevice::new`.
    pub device: GpuDevice,
    /// Flag reported by `is_available`; `GpuDeviceInfo::new` sets it to `true`.
    pub initialized: bool,
}
98
99impl GpuDeviceInfo {
100 pub fn new(_deviceid: usize) -> FFTResult<Self> {
102 let device = GpuDevice::new(GpuBackend::default(), _deviceid);
103 Ok(Self {
104 device,
105 initialized: true,
106 })
107 }
108
109 pub fn is_available(&self) -> bool {
111 self.initialized
112 }
113}
114
#[allow(dead_code)]
/// State that `FftGpuContext::new` sets up for one GPU device.
pub struct FftGpuContext {
    /// Context obtained from `scirs2_core::gpu::GpuContext::new`.
    core_context: scirs2_core::gpu::GpuContext,
    /// Device id exactly as passed to `FftGpuContext::new`.
    device_id: i32,
    /// Device information built for the same device id.
    device_info: GpuDeviceInfo,
    /// Stream created for the same device id.
    stream: GpuStream,
    /// Set to `true` by `FftGpuContext::new` once every field has been built.
    initialized: bool,
}
129
130impl FftGpuContext {
131 pub fn new(deviceid: i32) -> FFTResult<Self> {
133 let gpu_backend = scirs2_core::gpu::GpuBackend::Cuda;
135 let core_context = scirs2_core::gpu::GpuContext::new(gpu_backend)
136 .map_err(|e| FFTError::ComputationError(e.to_string()))?;
137
138 let device_info = GpuDeviceInfo::new(deviceid as usize)?;
140
141 let stream = GpuStream::new(deviceid)?;
143
144 Ok(Self {
145 core_context,
146 device_id: deviceid,
147 device_info,
148 stream,
149 initialized: true,
150 })
151 }
152
153 pub fn device_info(&self) -> &GpuDeviceInfo {
155 &self.device_info
156 }
157
158 pub fn stream(&self) -> &GpuStream {
160 &self.stream
161 }
162
163 pub fn allocate(&self, sizebytes: usize) -> FFTResult<BufferDescriptor> {
165 let manager = get_global_memory_manager()?;
169
170 manager.allocate(sizebytes, BufferLocation::Device, BufferType::Work)
171 }
172
173 pub fn free(&self, descriptor: BufferDescriptor) -> FFTResult<()> {
175 let _manager = get_global_memory_manager()?;
179
180 GpuMemoryManager::free(descriptor)
181 }
182
183 pub fn copy_host_to_device<T>(
185 &self,
186 host_data: &[T],
187 device_buffer: &BufferDescriptor,
188 ) -> FFTResult<()> {
189 let host_size_bytes = std::mem::size_of_val(host_data);
193 let device_size_bytes = device_buffer.size;
194
195 if host_size_bytes > device_size_bytes {
196 return Err(FFTError::DimensionError(format!(
197 "Host buffer size ({host_size_bytes} bytes) exceeds device buffer size ({device_size_bytes} bytes)"
198 )));
199 }
200
201 Ok(())
202 }
203
204 pub fn copy_device_to_host<T>(
206 &self,
207 device_buffer: &BufferDescriptor,
208 host_data: &mut [T],
209 ) -> FFTResult<()> {
210 let host_size_bytes = std::mem::size_of_val(host_data);
214 let device_size_bytes = device_buffer.size;
215
216 if device_size_bytes > host_size_bytes {
217 return Err(FFTError::DimensionError(format!(
218 "Device buffer size ({device_size_bytes} bytes) exceeds host buffer size ({host_size_bytes} bytes)"
219 )));
220 }
221
222 Ok(())
223 }
224}
225
/// Sparse FFT processor that owns a GPU context and its working buffers.
pub struct GpuSparseFFT {
    /// GPU context created for the device id passed to `GpuSparseFFT::new`.
    context: FftGpuContext,
    /// Configuration supplied at construction; selects the algorithm and sparsity.
    config: SparseFFTConfig,
    /// Buffer for the complex input signal; `None` until buffers are initialized.
    input_buffer: Option<BufferDescriptor>,
    /// Buffer for the output values; `None` until buffers are initialized.
    output_values_buffer: Option<BufferDescriptor>,
    /// Buffer for the output indices; `None` until buffers are initialized.
    output_indices_buffer: Option<BufferDescriptor>,
}
239
240impl GpuSparseFFT {
241 pub fn new(_deviceid: i32, config: SparseFFTConfig) -> FFTResult<Self> {
243 let context = FftGpuContext::new(_deviceid)?;
245
246 Ok(Self {
247 context,
248 config,
249 input_buffer: None,
250 output_values_buffer: None,
251 output_indices_buffer: None,
252 })
253 }
254
255 fn initialize_buffers(&mut self, signalsize: usize) -> FFTResult<()> {
257 self.free_buffers()?;
259
260 let memory_manager = get_global_memory_manager()?;
262
263 let input_buffer = memory_manager.allocate(
265 signalsize * std::mem::size_of::<Complex64>(),
266 BufferLocation::Device,
267 BufferType::Input,
268 )?;
269 self.input_buffer = Some(input_buffer);
270
271 let max_components = self.config.sparsity.min(signalsize);
273
274 let output_values_buffer = memory_manager.allocate(
275 max_components * std::mem::size_of::<Complex64>(),
276 BufferLocation::Device,
277 BufferType::Output,
278 )?;
279 self.output_values_buffer = Some(output_values_buffer);
280
281 let output_indices_buffer = memory_manager.allocate(
282 max_components * std::mem::size_of::<usize>(),
283 BufferLocation::Device,
284 BufferType::Output,
285 )?;
286 self.output_indices_buffer = Some(output_indices_buffer);
287
288 Ok(())
289 }
290
291 fn free_buffers(&mut self) -> FFTResult<()> {
293 if let Ok(_memory_manager) = get_global_memory_manager() {
294 if let Some(buffer) = self.input_buffer.take() {
295 GpuMemoryManager::free(buffer)?;
296 }
297
298 if let Some(buffer) = self.output_values_buffer.take() {
299 GpuMemoryManager::free(buffer)?;
300 }
301
302 if let Some(buffer) = self.output_indices_buffer.take() {
303 GpuMemoryManager::free(buffer)?;
304 }
305 }
306
307 Ok(())
308 }
309
310 pub fn sparse_fft<T>(&mut self, signal: &[T]) -> FFTResult<SparseFFTResult>
312 where
313 T: NumCast + Copy + Debug + 'static,
314 {
315 let start = Instant::now();
316
317 self.initialize_buffers(signal.len())?;
319
320 let signal_complex: Vec<Complex64> = signal
322 .iter()
323 .map(|&val| {
324 let val_f64 = NumCast::from(val).ok_or_else(|| {
325 FFTError::ValueError(format!("Could not convert {val:?} to f64"))
326 })?;
327 Ok(Complex64::new(val_f64, 0.0))
328 })
329 .collect::<FFTResult<Vec<_>>>()?;
330
331 if let Some(input_buffer) = &self.input_buffer {
333 self.context
334 .copy_host_to_device(&signal_complex, input_buffer)?;
335 } else {
336 return Err(FFTError::MemoryError(
337 "Input buffer not initialized".to_string(),
338 ));
339 }
340
341 let result = match self.config.algorithm {
343 SparseFFTAlgorithm::Sublinear => crate::execute_cuda_sublinear_sparse_fft(
344 &signal_complex,
345 self.config.sparsity,
346 self.config.algorithm,
347 )?,
348 SparseFFTAlgorithm::CompressedSensing => {
349 crate::execute_cuda_compressed_sensing_sparse_fft(
350 &signal_complex,
351 self.config.sparsity,
352 )?
353 }
354 SparseFFTAlgorithm::Iterative => {
355 crate::execute_cuda_iterative_sparse_fft(
356 &signal_complex,
357 self.config.sparsity,
358 100, )?
360 }
361 SparseFFTAlgorithm::FrequencyPruning => {
362 crate::execute_cuda_frequency_pruning_sparse_fft(
363 &signal_complex,
364 self.config.sparsity,
365 0.01, )?
367 }
368 SparseFFTAlgorithm::SpectralFlatness => {
369 crate::execute_cuda_spectral_flatness_sparse_fft(
370 &signal_complex,
371 self.config.sparsity,
372 self.config.flatness_threshold,
373 )?
374 }
375 _ => {
377 let mut cpu_processor = crate::sparse_fft::SparseFFT::new(self.config.clone());
378 let mut cpu_result = cpu_processor.sparse_fft(&signal_complex)?;
379
380 cpu_result.computation_time = start.elapsed();
382 cpu_result.algorithm = self.config.algorithm;
383
384 cpu_result
385 }
386 };
387
388 Ok(result)
389 }
390}
391
392impl Drop for GpuSparseFFT {
393 fn drop(&mut self) {
394 let _ = self.free_buffers();
396 }
397}
398
399#[allow(clippy::too_many_arguments)]
416#[allow(dead_code)]
417pub fn cuda_sparse_fft<T>(
418 signal: &[T],
419 k: usize,
420 device_id: i32,
421 algorithm: Option<SparseFFTAlgorithm>,
422 window_function: Option<WindowFunction>,
423) -> FFTResult<SparseFFTResult>
424where
425 T: NumCast + Copy + Debug + 'static,
426{
427 if !ensure_gpu_available()? {
429 return Err(FFTError::ComputationError(
430 "GPU is not available. Either GPU features are not enabled or GPU hardware/drivers are not available.".to_string()
431 ));
432 }
433
434 let config = SparseFFTConfig {
436 estimation_method: SparsityEstimationMethod::Manual,
437 sparsity: k,
438 algorithm: algorithm.unwrap_or(SparseFFTAlgorithm::Sublinear),
439 window_function: window_function.unwrap_or(WindowFunction::None),
440 ..SparseFFTConfig::default()
441 };
442
443 let mut processor = GpuSparseFFT::new(device_id, config)?;
448 processor.sparse_fft(signal)
449}
450
451#[allow(clippy::too_many_arguments)]
467#[allow(dead_code)]
468pub fn cuda_batch_sparse_fft<T>(
469 signals: &[Vec<T>],
470 k: usize,
471 device_id: i32,
472 algorithm: Option<SparseFFTAlgorithm>,
473 window_function: Option<WindowFunction>,
474) -> FFTResult<Vec<SparseFFTResult>>
475where
476 T: NumCast + Copy + Debug + 'static,
477{
478 let config = SparseFFTConfig {
480 estimation_method: SparsityEstimationMethod::Manual,
481 sparsity: k,
482 algorithm: algorithm.unwrap_or(SparseFFTAlgorithm::Sublinear),
483 window_function: window_function.unwrap_or(WindowFunction::None),
484 ..SparseFFTConfig::default()
485 };
486
487 let mut processor = GpuSparseFFT::new(device_id, config)?;
489
490 let mut results = Vec::with_capacity(signals.len());
492 for signal in signals {
493 results.push(processor.sparse_fft(signal)?);
494 }
495
496 Ok(results)
497}
498
499#[allow(dead_code)]
501pub fn get_cuda_devices() -> FFTResult<Vec<GpuDeviceInfo>> {
502 if !ensure_gpu_available().unwrap_or(false) {
506 return Ok(Vec::new());
507 }
508
509 let devices = vec![GpuDeviceInfo::new(0)?];
511
512 Ok(devices)
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse_fft_gpu_memory::AllocationStrategy;
    use std::f64::consts::PI;

    /// Builds a real signal of `n` samples as a sum of sinusoids.
    ///
    /// Each `(freq, amp)` pair contributes `amp * sin(freq * 2*pi*i/n)` to
    /// sample `i`.
    fn create_sparse_signal(n: usize, frequencies: &[(usize, f64)]) -> Vec<f64> {
        (0..n)
            .map(|i| {
                let t = 2.0 * PI * (i as f64) / (n as f64);
                frequencies
                    .iter()
                    .fold(0.0_f64, |acc, &(freq, amp)| {
                        acc + amp * (freq as f64 * t).sin()
                    })
            })
            .collect()
    }

    /// Checks device enumeration and context creation, tolerating machines
    /// without a GPU.
    #[test]
    fn test_cuda_initialization() {
        if !ensure_gpu_available().unwrap_or(false) {
            eprintln!("GPU not available, using mock initialization test");
            let devices = get_cuda_devices().expect("device query must not fail without a GPU");
            // Without a GPU the device list is defined to be empty.
            assert!(devices.is_empty());
            return;
        }

        // Initialization may legitimately fail (e.g. already initialized), so
        // the result is ignored.
        let _ = crate::sparse_fft_gpu_memory::init_global_memory_manager(
            crate::sparse_fft_gpu::GPUBackend::CUDA,
            0,
            AllocationStrategy::CacheBySize,
            1024 * 1024 * 1024,
        );

        let devices = get_cuda_devices().expect("CUDA devices query should succeed");
        if devices.is_empty() {
            return;
        }

        match FftGpuContext::new(0) {
            Ok(context) => {
                assert_eq!(context.device_id, 0);
                assert!(context.initialized);
            }
            Err(_) => {
                eprintln!("GPU context creation failed - no GPU hardware available");
            }
        }
    }

    /// Runs a single sparse FFT, falling back to the CPU implementation when no
    /// GPU is present.
    #[test]
    fn test_cuda_sparse_fft() {
        let n = 256;
        let frequencies = vec![(3, 1.0), (7, 0.5), (15, 0.25)];
        let signal = create_sparse_signal(n, &frequencies);

        if !ensure_gpu_available().unwrap_or(false) {
            eprintln!("GPU not available, using CPU fallback for sparse FFT");
            let config = SparseFFTConfig {
                estimation_method: SparsityEstimationMethod::Manual,
                sparsity: 6,
                algorithm: SparseFFTAlgorithm::Sublinear,
                window_function: WindowFunction::Hann,
                ..SparseFFTConfig::default()
            };
            let mut processor = crate::sparse_fft::algorithms::SparseFFT::new(config);
            let result = processor.sparse_fft(&signal).unwrap();
            assert!(!result.values.is_empty());
            assert_eq!(result.algorithm, SparseFFTAlgorithm::Sublinear);
            return;
        }

        match cuda_sparse_fft(
            &signal,
            6,
            0,
            Some(SparseFFTAlgorithm::Sublinear),
            Some(WindowFunction::Hann),
        ) {
            Ok(result) => {
                assert!(!result.values.is_empty());
                assert_eq!(result.algorithm, SparseFFTAlgorithm::Sublinear);
            }
            Err(e) => {
                // Only GPU-availability failures are acceptable here.
                let message = e.to_string();
                assert!(message.contains("GPU") || message.contains("not available"));
                eprintln!("GPU test skipped: {e}");
            }
        }
    }

    /// Runs a batch of sparse FFTs, falling back to the CPU implementation when
    /// no GPU is present.
    #[test]
    fn test_cuda_batch_processing() {
        let n = 128;
        let signals = vec![
            create_sparse_signal(n, &[(3, 1.0), (7, 0.5)]),
            create_sparse_signal(n, &[(5, 1.0), (10, 0.7)]),
            create_sparse_signal(n, &[(2, 0.8), (12, 0.6)]),
        ];

        if !ensure_gpu_available().unwrap_or(false) {
            eprintln!("GPU not available, using CPU fallback for batch processing");
            let config = SparseFFTConfig {
                estimation_method: SparsityEstimationMethod::Manual,
                sparsity: 4,
                algorithm: SparseFFTAlgorithm::Sublinear,
                window_function: WindowFunction::None,
                ..SparseFFTConfig::default()
            };
            let mut processor = crate::sparse_fft::algorithms::SparseFFT::new(config);
            let mut results = Vec::new();
            for signal in &signals {
                results.push(processor.sparse_fft(signal).unwrap());
            }
            assert_eq!(results.len(), signals.len());
            return;
        }

        match cuda_batch_sparse_fft(&signals, 4, 0, Some(SparseFFTAlgorithm::Sublinear), None) {
            Ok(results) => {
                assert_eq!(results.len(), signals.len());
                for result in results {
                    assert!(!result.values.is_empty());
                }
            }
            Err(e) => {
                // Only GPU-availability failures are acceptable here.
                let message = e.to_string();
                assert!(message.contains("GPU") || message.contains("not available"));
                eprintln!("GPU batch test skipped: {e}");
            }
        }
    }
}
672
673