scirs2_fft/sparse_fft_gpu_memory.rs

//! Memory management for GPU-accelerated sparse FFT
//!
//! This module provides memory management utilities for GPU-accelerated sparse FFT
//! implementations, including buffer allocation, reuse, and transfer optimization.
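//!
//! # Example
//!
//! A minimal usage sketch of the global memory manager with the CPU fallback
//! backend (not compiled here; it assumes this module is exposed as
//! `scirs2_fft::sparse_fft_gpu_memory`):
//!
//! ```ignore
//! use scirs2_core::numeric::Complex64;
//! use scirs2_fft::sparse_fft_gpu::GPUBackend;
//! use scirs2_fft::sparse_fft_gpu_memory::*;
//!
//! // Initialize the global manager with size-based caching and a 1 MiB limit.
//! init_global_memory_manager(
//!     GPUBackend::CPUFallback,
//!     -1,
//!     AllocationStrategy::CacheBySize,
//!     1024 * 1024,
//! )?;
//!
//! let manager = get_global_memory_manager()?;
//! let mut manager = manager.lock().unwrap();
//!
//! // Allocate a host buffer for 1024 Complex64 elements (16 bytes each).
//! let buffer = manager.allocate_buffer(
//!     1024,
//!     std::mem::size_of::<Complex64>(),
//!     BufferLocation::Host,
//!     BufferType::Input,
//! )?;
//!
//! // ... fill the buffer and run the sparse FFT ...
//!
//! manager.release_buffer(buffer)?;
//! ```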

use crate::error::{FFTError, FFTResult};
use crate::sparse_fft_gpu::GPUBackend;
use scirs2_core::numeric::Complex64;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

// CUDA support temporarily disabled until cudarc dependency is enabled
// #[cfg(feature = "cuda")]
// use cudarc::driver::{CudaDevice, DevicePtr, DriverError};

// HIP support temporarily disabled until hiprt dependency is enabled
// #[cfg(feature = "hip")]
// use hiprt::{hipDevice_t, hipDeviceptr_t, hipError_t};

#[cfg(any(feature = "cuda", feature = "hip", feature = "sycl"))]
use std::sync::OnceLock;

// CUDA support temporarily disabled until cudarc dependency is enabled
#[cfg(feature = "cuda")]
static CUDA_DEVICE: OnceLock<Option<Arc<u8>>> = OnceLock::new(); // Placeholder type

// HIP support temporarily disabled until hiprt dependency is enabled
#[cfg(feature = "hip")]
static HIP_DEVICE: OnceLock<Option<u8>> = OnceLock::new(); // Placeholder type

#[cfg(feature = "sycl")]
static SYCL_DEVICE: OnceLock<Option<SyclDevice>> = OnceLock::new();

/// Placeholder SYCL device type
#[cfg(feature = "sycl")]
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct SyclDevice {
    device_id: i32,
    device_name: String,
}

/// Placeholder SYCL device pointer type
#[cfg(feature = "sycl")]
pub type SyclDevicePtr = *mut std::os::raw::c_void;

/// Memory buffer location
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferLocation {
    /// Host (CPU) memory
    Host,
    /// Device (GPU) memory
    Device,
    /// Pinned host memory (page-locked for faster transfers)
    PinnedHost,
    /// Unified memory (accessible from both CPU and GPU)
    Unified,
}

/// Memory buffer type
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferType {
    /// Input signal buffer
    Input,
    /// Output signal buffer
    Output,
    /// Work buffer for intermediate results
    Work,
    /// FFT plan buffer
    Plan,
}

/// Buffer allocation strategy
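///
/// With `CacheBySize`, releasing a buffer parks it in a per-size cache so a later
/// request for the same size can reuse it; with `AlwaysAllocate`, a release frees
/// the accounting immediately. A minimal sketch (not compiled here):
///
/// ```ignore
/// let mut mgr = GPUMemoryManager::new(
///     GPUBackend::CPUFallback,
///     -1,
///     AllocationStrategy::CacheBySize,
///     0, // 0 disables the memory-limit check
/// );
/// let a = mgr.allocate_buffer(1024, 16, BufferLocation::Host, BufferType::Work)?;
/// mgr.release_buffer(a)?; // parked in the cache; memory stays accounted
/// // Same size, type, and location: served from the cache instead of a fresh allocation.
/// let b = mgr.allocate_buffer(1024, 16, BufferLocation::Host, BufferType::Work)?;
/// ```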
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllocationStrategy {
    /// Allocate once, reuse for same sizes
    CacheBySize,
    /// Allocate for each operation
    AlwaysAllocate,
    /// Preallocate a fixed size buffer
    Preallocate,
    /// Use a pool of buffers
    BufferPool,
}

/// Memory buffer descriptor
#[derive(Debug, Clone)]
pub struct BufferDescriptor {
    /// Size of the buffer in elements
    pub size: usize,
    /// Element size in bytes
    pub element_size: usize,
    /// Buffer location
    pub location: BufferLocation,
    /// Buffer type
    pub buffer_type: BufferType,
    /// Buffer ID
    pub id: usize,
    /// GPU backend used for this buffer
    pub backend: GPUBackend,
    /// Device memory pointer (for CUDA)
    #[cfg(feature = "cuda")]
    cuda_device_ptr: Option<*mut u8>, // Placeholder type for disabled CUDA
    /// Device memory pointer (for HIP)
    #[cfg(feature = "hip")]
    hip_device_ptr: Option<*mut u8>, // Placeholder type for disabled HIP
    /// Device memory pointer (for SYCL)
    #[cfg(feature = "sycl")]
    sycl_device_ptr: Option<SyclDevicePtr>,
    /// Host memory pointer (for CPU fallback or pinned memory)
    host_ptr: Option<*mut std::os::raw::c_void>,
}

// SAFETY: BufferDescriptor manages memory through proper allocation/deallocation
// Raw pointers are only used within controlled contexts
unsafe impl Send for BufferDescriptor {}
unsafe impl Sync for BufferDescriptor {}

/// Initialize CUDA device (call once at startup)
#[cfg(feature = "cuda")]
#[allow(dead_code)]
pub fn init_cuda_device() -> FFTResult<bool> {
    let device_result = CUDA_DEVICE.get_or_init(|| {
        // CUDA device initialization temporarily disabled until cudarc dependency is enabled
        /*
        match CudaDevice::new(0) {
            Ok(device) => Some(Arc::new(device)),
            Err(_) => None,
        }
        */
        None // Placeholder - no CUDA device available
    });

    Ok(device_result.is_some())
}

/// Initialize CUDA device (no-op without CUDA feature)
#[cfg(not(feature = "cuda"))]
#[allow(dead_code)]
pub fn init_cuda_device() -> FFTResult<bool> {
    Ok(false)
}

/// Initialize HIP device (call once at startup)
#[cfg(feature = "hip")]
#[allow(dead_code)]
pub fn init_hip_device() -> FFTResult<bool> {
    // HIP support temporarily disabled until hiprt dependency is enabled, so
    // report the device as unavailable rather than returning an error; this
    // lets `init_gpu_backend` fall through to the next backend.
    Ok(false)
}

/// Initialize HIP device (no-op without HIP feature)
#[cfg(not(feature = "hip"))]
#[allow(dead_code)]
pub fn init_hip_device() -> FFTResult<bool> {
    Ok(false)
}

/// Initialize SYCL device (call once at startup)
#[cfg(feature = "sycl")]
#[allow(dead_code)]
pub fn init_sycl_device() -> FFTResult<bool> {
    let device_result = SYCL_DEVICE.get_or_init(|| {
        // In a real SYCL implementation, this would:
        // 1. Query available SYCL devices
        // 2. Select the best device (GPU preferred, then CPU)
        // 3. Create a SYCL context and queue

        // For now, we'll create a placeholder device
        Some(SyclDevice {
            device_id: 0,
            device_name: "Generic SYCL Device".to_string(),
        })
    });

    Ok(device_result.is_some())
}

/// Initialize SYCL device (no-op without SYCL feature)
#[cfg(not(feature = "sycl"))]
#[allow(dead_code)]
pub fn init_sycl_device() -> FFTResult<bool> {
    Ok(false)
}

/// Check if CUDA is available
#[cfg(feature = "cuda")]
#[allow(dead_code)]
pub fn is_cuda_available() -> bool {
    CUDA_DEVICE.get().map(|d| d.is_some()).unwrap_or(false)
}

/// Check if CUDA is available (always false without CUDA feature)
#[cfg(not(feature = "cuda"))]
#[allow(dead_code)]
pub fn is_cuda_available() -> bool {
    false
}

/// Check if HIP is available
#[cfg(feature = "hip")]
#[allow(dead_code)]
pub fn is_hip_available() -> bool {
    HIP_DEVICE.get().map(|d| d.is_some()).unwrap_or(false)
}

/// Check if HIP is available (always false without HIP feature)
#[cfg(not(feature = "hip"))]
#[allow(dead_code)]
pub fn is_hip_available() -> bool {
    false
}

/// Check if SYCL is available
#[cfg(feature = "sycl")]
#[allow(dead_code)]
pub fn is_sycl_available() -> bool {
    SYCL_DEVICE.get().map(|d| d.is_some()).unwrap_or(false)
}

/// Check if SYCL is available (always false without SYCL feature)
#[cfg(not(feature = "sycl"))]
#[allow(dead_code)]
pub fn is_sycl_available() -> bool {
    false
}

/// Check if any GPU backend is available
#[allow(dead_code)]
pub fn is_gpu_available() -> bool {
    is_cuda_available() || is_hip_available() || is_sycl_available()
}

/// Initialize the best available GPU backend
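///
/// Backends are probed in order (CUDA, then HIP, then SYCL) and the CPU fallback
/// is returned when none is available. A minimal selection sketch (not compiled
/// here; with the current placeholder devices this resolves to `CPUFallback`
/// unless a backend feature provides a real device):
///
/// ```ignore
/// match init_gpu_backend()? {
///     GPUBackend::CUDA => println!("using CUDA"),
///     GPUBackend::HIP => println!("using HIP"),
///     GPUBackend::SYCL => println!("using SYCL"),
///     GPUBackend::CPUFallback => println!("no GPU available, falling back to CPU"),
/// }
/// ```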
#[allow(dead_code)]
pub fn init_gpu_backend() -> FFTResult<GPUBackend> {
    // Try CUDA first (usually fastest)
    if init_cuda_device()? {
        return Ok(GPUBackend::CUDA);
    }

    // Then try HIP (AMD GPUs)
    if init_hip_device()? {
        return Ok(GPUBackend::HIP);
    }

    // Then try SYCL (cross-platform, Intel GPUs, etc.)
    if init_sycl_device()? {
        return Ok(GPUBackend::SYCL);
    }

    // Fall back to CPU
    Ok(GPUBackend::CPUFallback)
}

impl BufferDescriptor {
    /// Create a new buffer descriptor with specified backend
    pub fn new(
        size: usize,
        element_size: usize,
        location: BufferLocation,
        buffer_type: BufferType,
        id: usize,
        backend: GPUBackend,
    ) -> FFTResult<Self> {
        let mut descriptor = Self {
            size,
            element_size,
            location,
            buffer_type,
            id,
            backend,
            #[cfg(feature = "cuda")]
            cuda_device_ptr: None,
            #[cfg(feature = "hip")]
            hip_device_ptr: None,
            #[cfg(feature = "sycl")]
            sycl_device_ptr: None,
            host_ptr: None,
        };

        descriptor.allocate()?;
        Ok(descriptor)
    }

    /// Create a new buffer descriptor with auto-detected backend
    pub fn new_auto(
        size: usize,
        element_size: usize,
        location: BufferLocation,
        buffer_type: BufferType,
        id: usize,
    ) -> FFTResult<Self> {
        let backend = init_gpu_backend()?;
        Self::new(size, element_size, location, buffer_type, id, backend)
    }

    /// Allocate the actual memory based on location and backend
    fn allocate(&mut self) -> FFTResult<()> {
        let total_size = self.size * self.element_size;

        match self.location {
            BufferLocation::Device => {
                match self.backend {
                    GPUBackend::CUDA => {
                        #[cfg(feature = "cuda")]
                        {
                            if let Some(_device) = CUDA_DEVICE.get().and_then(|d| d.as_ref()) {
                                // CUDA API calls temporarily disabled until cudarc dependency is enabled
                                /*
                                let device_mem = device.alloc::<u8>(total_size).map_err(|e| {
                                    FFTError::ComputationError(format!(
                                        "Failed to allocate CUDA memory: {:?}",
                                        e
                                    ))
                                })?;
                                self.cuda_device_ptr = Some(device_mem);
                                return Ok(());
                                */
                            }
                        }

                        // Fallback to host memory if CUDA is not available
                        self.backend = GPUBackend::CPUFallback;
                        self.location = BufferLocation::Host;
                        self.allocate_host_memory(total_size)?;
                    }
                    GPUBackend::HIP => {
                        #[cfg(feature = "hip")]
                        {
                            if HIP_DEVICE.get().map(|d| d.is_some()).unwrap_or(false) {
                                // use hiprt::*; // Temporarily disabled
                                // HIP API calls temporarily disabled until hiprt dependency is available
                                /*
                                unsafe {
                                    let mut device_ptr: hipDeviceptr_t = std::ptr::null_mut();
                                    let result = hipMalloc(&mut device_ptr, total_size);
                                    if result == hipError_t::hipSuccess {
                                        self.hip_device_ptr = Some(device_ptr);
                                        return Ok(());
                                    } else {
                                        return Err(FFTError::ComputationError(format!(
                                            "Failed to allocate HIP memory: {:?}",
                                            result
                                        )));
                                    }
                                }
                                */
                            }
                        }

                        // Fallback to host memory if HIP is not available
                        self.backend = GPUBackend::CPUFallback;
                        self.location = BufferLocation::Host;
                        self.allocate_host_memory(total_size)?;
                    }
                    GPUBackend::SYCL => {
                        #[cfg(feature = "sycl")]
                        {
                            if SYCL_DEVICE.get().map(|d| d.is_some()).unwrap_or(false) {
                                // In a real SYCL implementation, this would:
                                // 1. Use sycl::malloc_device() to allocate device memory
                                // 2. Store the device pointer for later use
                                // 3. Handle allocation errors appropriately

                                // For placeholder implementation, simulate successful allocation
                                // Allocate a boxed byte slice so the pointer refers to
                                // `total_size` bytes of data (not to a `Vec` header).
                                let device_ptr = Box::into_raw(vec![0u8; total_size].into_boxed_slice())
                                    as *mut std::os::raw::c_void;
                                self.sycl_device_ptr = Some(device_ptr);
                                return Ok(());
                            }
                        }

                        // Fallback to host memory if SYCL is not available
                        self.backend = GPUBackend::CPUFallback;
                        self.location = BufferLocation::Host;
                        self.allocate_host_memory(total_size)?;
                    }
                    GPUBackend::CPUFallback => {
                        self.location = BufferLocation::Host;
                        self.allocate_host_memory(total_size)?;
                    }
                }
            }
            BufferLocation::Host | BufferLocation::PinnedHost | BufferLocation::Unified => {
                self.allocate_host_memory(total_size)?;
            }
        }

        Ok(())
    }

    /// Allocate host memory
    fn allocate_host_memory(&mut self, size: usize) -> FFTResult<()> {
        let vec = vec![0u8; size];
        let boxed_slice = vec.into_boxed_slice();
        let ptr = Box::into_raw(boxed_slice) as *mut std::os::raw::c_void;
        self.host_ptr = Some(ptr);
        Ok(())
    }

    /// Get host pointer and size
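    ///
    /// A minimal sketch (not compiled here) of zero-initializing the allocation
    /// through the returned raw pointer; `buffer` is assumed to be a host-resident
    /// `BufferDescriptor`:
    ///
    /// ```ignore
    /// let (ptr, len_bytes) = buffer.get_host_ptr();
    /// // Safety: the pointer refers to `len_bytes` bytes owned by `buffer`.
    /// unsafe { std::ptr::write_bytes(ptr as *mut u8, 0, len_bytes) };
    /// ```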
    pub fn get_host_ptr(&self) -> (*mut std::os::raw::c_void, usize) {
        match self.host_ptr {
            Some(ptr) => (ptr, self.size * self.element_size),
            None => {
                // This shouldn't happen with proper allocation
                panic!("Attempted to get host pointer from unallocated buffer");
            }
        }
    }

    /// Get device pointer (CUDA)
    #[cfg(feature = "cuda")]
    pub fn get_cuda_device_ptr(&self) -> Option<*mut u8> {
        self.cuda_device_ptr
    }

    /// Get device pointer (HIP)
    #[cfg(feature = "hip")]
    pub fn get_hip_device_ptr(&self) -> Option<*mut u8> {
        self.hip_device_ptr
    }

    /// Get device pointer (SYCL)
    #[cfg(feature = "sycl")]
    pub fn get_sycl_device_ptr(&self) -> Option<SyclDevicePtr> {
        self.sycl_device_ptr
    }

    /// Check if this buffer has GPU memory allocated
    pub fn has_device_memory(&self) -> bool {
        match self.backend {
            GPUBackend::CUDA => {
                #[cfg(feature = "cuda")]
                return self.cuda_device_ptr.is_some();
                #[cfg(not(feature = "cuda"))]
                return false;
            }
            GPUBackend::HIP => {
                #[cfg(feature = "hip")]
                return self.hip_device_ptr.is_some();
                #[cfg(not(feature = "hip"))]
                return false;
            }
            GPUBackend::SYCL => {
                #[cfg(feature = "sycl")]
                return self.sycl_device_ptr.is_some();
                #[cfg(not(feature = "sycl"))]
                return false;
            }
            _ => false,
        }
    }

    /// Copy data from host to device
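    ///
    /// The transfer operates on raw bytes. A minimal sketch (not compiled here) of
    /// staging a `Complex64` slice; `buffer` is assumed to have been allocated with
    /// `element_size = std::mem::size_of::<Complex64>()`:
    ///
    /// ```ignore
    /// let data = vec![Complex64::new(1.0, 0.0); buffer.size];
    /// let bytes = unsafe {
    ///     std::slice::from_raw_parts(
    ///         data.as_ptr() as *const u8,
    ///         data.len() * std::mem::size_of::<Complex64>(),
    ///     )
    /// };
    /// buffer.copy_host_to_device(bytes)?;
    /// ```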
    pub fn copy_host_to_device(&self, hostdata: &[u8]) -> FFTResult<()> {
        match self.location {
            BufferLocation::Device => {
                match self.backend {
                    GPUBackend::CUDA => {
                        #[cfg(feature = "cuda")]
                        {
                            if let (Some(_device_ptr), Some(_device)) = (
                                self.cuda_device_ptr.as_ref(),
                                CUDA_DEVICE.get().and_then(|d| d.as_ref()),
                            ) {
                                // CUDA API calls temporarily disabled until cudarc dependency is enabled
                                /*
                                device.htod_copy(hostdata, device_ptr).map_err(|e| {
                                    FFTError::ComputationError(format!(
                                        "Failed to copy data to CUDA GPU: {:?}",
                                        e
                                    ))
                                })?;
                                return Ok(());
                                */
                            }
                        }

                        // Fallback to host memory
                        self.copy_to_host_memory(hostdata)?;
                    }
                    GPUBackend::HIP => {
                        #[cfg(feature = "hip")]
                        {
                            if let Some(_device_ptr) = self.hip_device_ptr {
                                // use hiprt::*; // Temporarily disabled
                                // HIP API calls temporarily disabled until hiprt dependency is available
                                /*
                                unsafe {
                                    let result = hipMemcpyHtoD(
                                        device_ptr,
                                        hostdata.as_ptr() as *const std::os::raw::c_void,
                                        hostdata.len(),
                                    );
                                    if result == hipError_t::hipSuccess {
                                        return Ok(());
                                    } else {
                                        return Err(FFTError::ComputationError(format!(
                                            "Failed to copy data to HIP GPU: {:?}",
                                            result
                                        )));
                                    }
                                }
                                */
                            }
                        }

                        // Fallback to host memory
                        self.copy_to_host_memory(hostdata)?;
                    }
                    GPUBackend::SYCL => {
                        #[cfg(feature = "sycl")]
                        {
                            if let Some(device_ptr) = self.sycl_device_ptr {
                                // In a real SYCL implementation, this would:
                                // 1. Use sycl::queue::memcpy() or similar to copy data
                                // 2. Handle synchronization appropriately
                                // 3. Return appropriate error codes

                                // For placeholder implementation, simulate the copy
                                unsafe {
                                    std::ptr::copy_nonoverlapping(
                                        hostdata.as_ptr(),
                                        device_ptr as *mut u8,
                                        hostdata.len(),
                                    );
                                }
                                return Ok(());
                            }
                        }

                        // Fallback to host memory
                        self.copy_to_host_memory(hostdata)?;
                    }
                    _ => {
                        // CPU fallback
                        self.copy_to_host_memory(hostdata)?;
                    }
                }
            }
            BufferLocation::Host | BufferLocation::PinnedHost | BufferLocation::Unified => {
                self.copy_to_host_memory(hostdata)?;
            }
        }

        Ok(())
    }

    /// Helper to copy data to host memory
    fn copy_to_host_memory(&self, hostdata: &[u8]) -> FFTResult<()> {
        if let Some(host_ptr) = self.host_ptr {
            // Copies `hostdata.len()` bytes; callers must ensure this does not
            // exceed the allocated buffer size (`size * element_size`).
            unsafe {
                std::ptr::copy_nonoverlapping(
                    hostdata.as_ptr(),
                    host_ptr as *mut u8,
                    hostdata.len(),
                );
            }
        }
        Ok(())
    }

    /// Copy data from device to host
    pub fn copy_device_to_host(&self, hostdata: &mut [u8]) -> FFTResult<()> {
        match self.location {
            BufferLocation::Device => {
                match self.backend {
                    GPUBackend::CUDA => {
                        #[cfg(feature = "cuda")]
                        {
                            if let (Some(_device_ptr), Some(_device)) = (
                                self.cuda_device_ptr.as_ref(),
                                CUDA_DEVICE.get().and_then(|d| d.as_ref()),
                            ) {
                                // CUDA API calls temporarily disabled until cudarc dependency is enabled
                                /*
                                device.dtoh_copy(device_ptr, hostdata).map_err(|e| {
                                    FFTError::ComputationError(format!(
                                        "Failed to copy data from CUDA GPU: {:?}",
                                        e
                                    ))
                                })?;
                                return Ok(());
                                */
                            }
                        }

                        // Fallback to host memory
                        self.copy_from_host_memory(hostdata)?;
                    }
                    GPUBackend::HIP => {
                        #[cfg(feature = "hip")]
                        {
                            if let Some(_device_ptr) = self.hip_device_ptr {
                                // use hiprt::*; // Temporarily disabled
                                // HIP API calls temporarily disabled until hiprt dependency is available
                                /*
                                unsafe {
                                    let result = hipMemcpyDtoH(
                                        hostdata.as_mut_ptr() as *mut std::os::raw::c_void,
                                        device_ptr,
                                        hostdata.len(),
                                    );
                                    if result == hipError_t::hipSuccess {
                                        return Ok(());
                                    } else {
                                        return Err(FFTError::ComputationError(format!(
                                            "Failed to copy data from HIP GPU: {:?}",
                                            result
                                        )));
                                    }
                                }
                                */
                            }
                        }

                        // Fallback to host memory
                        self.copy_from_host_memory(hostdata)?;
                    }
                    GPUBackend::SYCL => {
                        #[cfg(feature = "sycl")]
                        {
                            if let Some(device_ptr) = self.sycl_device_ptr {
                                // In a real SYCL implementation, this would:
                                // 1. Use sycl::queue::memcpy() to copy from device to host
                                // 2. Handle synchronization and error checking
                                // 3. Wait for completion if needed

                                // For placeholder implementation, simulate the copy
                                unsafe {
                                    std::ptr::copy_nonoverlapping(
                                        device_ptr as *const u8,
                                        hostdata.as_mut_ptr(),
                                        hostdata.len(),
                                    );
                                }
                                return Ok(());
                            }
                        }

                        // Fallback to host memory
                        self.copy_from_host_memory(hostdata)?;
                    }
                    _ => {
                        // CPU fallback
                        self.copy_from_host_memory(hostdata)?;
                    }
                }
            }
            BufferLocation::Host | BufferLocation::PinnedHost | BufferLocation::Unified => {
                self.copy_from_host_memory(hostdata)?;
            }
        }

        Ok(())
    }

    /// Helper to copy data from host memory
    fn copy_from_host_memory(&self, hostdata: &mut [u8]) -> FFTResult<()> {
        if let Some(host_ptr) = self.host_ptr {
            // Copies `hostdata.len()` bytes out of the buffer; callers must ensure
            // this does not exceed the allocated buffer size (`size * element_size`).
            unsafe {
                std::ptr::copy_nonoverlapping(
                    host_ptr as *const u8,
                    hostdata.as_mut_ptr(),
                    hostdata.len(),
                );
            }
        }
        Ok(())
    }
}

impl Drop for BufferDescriptor {
    fn drop(&mut self) {
        // Clean up host memory
        if let Some(ptr) = self.host_ptr.take() {
            unsafe {
                // Convert back to Box<[u8]> to drop properly
                let vec_size = self.size * self.element_size;
                let _ = Box::from_raw(std::slice::from_raw_parts_mut(ptr as *mut u8, vec_size));
            }
        }

        // Clean up device memory based on backend
        match self.backend {
            GPUBackend::CUDA => {
                // CUDA device memory is automatically dropped when DevicePtr goes out of scope
                #[cfg(feature = "cuda")]
                {
                    self.cuda_device_ptr.take();
                }
            }
            GPUBackend::HIP => {
                // Clean up HIP device memory
                #[cfg(feature = "hip")]
                {
                    if let Some(_device_ptr) = self.hip_device_ptr.take() {
                        // use hiprt::*; // Temporarily disabled
                        // HIP API calls temporarily disabled until hiprt dependency is available
                        /*
                        unsafe {
                            let _ = hipFree(device_ptr);
                        }
                        */
                    }
                }
            }
            GPUBackend::SYCL => {
                // Clean up SYCL device memory
                #[cfg(feature = "sycl")]
                {
                    if let Some(device_ptr) = self.sycl_device_ptr.take() {
                        // In a real SYCL implementation, this would:
                        // 1. Use sycl::free() to deallocate device memory
                        // 2. Handle any synchronization requirements
                        // 3. Clean up associated SYCL resources

                        // For placeholder implementation, free the boxed byte slice
                        // that `allocate` created for this buffer.
                        unsafe {
                            let total_size = self.size * self.element_size;
                            let _ = Box::from_raw(std::slice::from_raw_parts_mut(
                                device_ptr as *mut u8,
                                total_size,
                            ));
                        }
                    }
                }
            }
            _ => {
                // No GPU memory to clean up for CPU fallback
            }
        }
    }
}

/// GPU Memory manager for sparse FFT operations
pub struct GPUMemoryManager {
    /// GPU backend
    backend: GPUBackend,
    /// Current device ID
    _device_id: i32,
    /// Allocation strategy
    allocation_strategy: AllocationStrategy,
    /// Maximum memory usage in bytes
    max_memory: usize,
    /// Current memory usage in bytes
    current_memory: usize,
    /// Buffer cache by size
    buffer_cache: HashMap<usize, Vec<BufferDescriptor>>,
    /// Next buffer ID
    next_buffer_id: usize,
}

impl GPUMemoryManager {
    /// Create a new GPU memory manager
    pub fn new(
        backend: GPUBackend,
        device_id: i32,
        allocation_strategy: AllocationStrategy,
        max_memory: usize,
    ) -> Self {
        Self {
            backend,
            _device_id: device_id,
            allocation_strategy,
            max_memory,
            current_memory: 0,
            buffer_cache: HashMap::new(),
            next_buffer_id: 0,
        }
    }

    /// Get backend name
    pub fn backend_name(&self) -> &'static str {
        match self.backend {
            GPUBackend::CUDA => "CUDA",
            GPUBackend::HIP => "HIP",
            GPUBackend::SYCL => "SYCL",
            GPUBackend::CPUFallback => "CPU",
        }
    }

    /// Allocate a buffer of specified size and type
    pub fn allocate_buffer(
        &mut self,
        size: usize,
        element_size: usize,
        location: BufferLocation,
        buffer_type: BufferType,
    ) -> FFTResult<BufferDescriptor> {
        let total_size = size * element_size;

        // Check if we're going to exceed the memory limit
        if self.max_memory > 0 && self.current_memory + total_size > self.max_memory {
            return Err(FFTError::MemoryError(format!(
                "Memory limit exceeded: cannot allocate {} bytes (current usage: {} bytes, limit: {} bytes)",
                total_size, self.current_memory, self.max_memory
            )));
        }

        // If using a cache strategy, check if we have an available buffer
        if self.allocation_strategy == AllocationStrategy::CacheBySize {
            if let Some(buffers) = self.buffer_cache.get_mut(&size) {
                if let Some(descriptor) = buffers
                    .iter()
                    .position(|b| b.buffer_type == buffer_type && b.location == location)
                    .map(|idx| buffers.remove(idx))
                {
                    return Ok(descriptor);
                }
            }
        }

        // Allocate a new buffer with proper memory allocation
        let buffer_id = self.next_buffer_id;
        self.next_buffer_id += 1;

        // Create descriptor with actual memory allocation
        let descriptor = BufferDescriptor::new(
            size,
            element_size,
            location,
            buffer_type,
            buffer_id,
            self.backend,
        )?;

        // Only count the memory once the allocation has actually succeeded
        self.current_memory += total_size;

        Ok(descriptor)
    }

    /// Release a buffer
    pub fn release_buffer(&mut self, descriptor: BufferDescriptor) -> FFTResult<()> {
        let buffer_size = descriptor.size * descriptor.element_size;

        // If using cache strategy, add to cache but don't decrement memory (it's still allocated)
        if self.allocation_strategy == AllocationStrategy::CacheBySize {
            self.buffer_cache
                .entry(descriptor.size)
                .or_default()
                .push(descriptor);
        } else {
            // Actually free the buffer and decrement memory usage
            self.current_memory = self.current_memory.saturating_sub(buffer_size);
        }

        Ok(())
    }

    /// Clear the buffer cache
    pub fn clear_cache(&mut self) -> FFTResult<()> {
        // Free all cached buffers and update memory usage
        for (_, buffers) in self.buffer_cache.drain() {
            for descriptor in buffers {
                let buffer_size = descriptor.size * descriptor.element_size;
                self.current_memory = self.current_memory.saturating_sub(buffer_size);
                // The BufferDescriptor's Drop implementation will handle actual memory cleanup
            }
        }

        Ok(())
    }

    /// Get current memory usage
    pub fn current_memory_usage(&self) -> usize {
        self.current_memory
    }

    /// Get memory limit
    pub fn memory_limit(&self) -> usize {
        self.max_memory
    }
}

/// Global memory manager singleton
static GLOBAL_MEMORY_MANAGER: Mutex<Option<Arc<Mutex<GPUMemoryManager>>>> = Mutex::new(None);

/// Initialize global memory manager
#[allow(dead_code)]
pub fn init_global_memory_manager(
    backend: GPUBackend,
    device_id: i32,
    allocation_strategy: AllocationStrategy,
    max_memory: usize,
) -> FFTResult<()> {
    let mut global = GLOBAL_MEMORY_MANAGER.lock().unwrap();
    *global = Some(Arc::new(Mutex::new(GPUMemoryManager::new(
        backend,
        device_id,
        allocation_strategy,
        max_memory,
    ))));
    Ok(())
}

/// Get global memory manager
#[allow(dead_code)]
pub fn get_global_memory_manager() -> FFTResult<Arc<Mutex<GPUMemoryManager>>> {
    {
        let global = GLOBAL_MEMORY_MANAGER.lock().unwrap();
        if let Some(ref manager) = *global {
            return Ok(manager.clone());
        }
        // Release the lock here so that initializing a default manager below
        // does not deadlock on GLOBAL_MEMORY_MANAGER.
    }

    // Create a default memory manager if none exists
    init_global_memory_manager(
        GPUBackend::CPUFallback,
        -1,
        AllocationStrategy::CacheBySize,
        0,
    )?;
    get_global_memory_manager()
}

/// Memory-efficient GPU sparse FFT computation
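///
/// A minimal call sketch (not compiled here). Note that the current implementation
/// is a placeholder that returns a zero-filled spectrum of the same length as the
/// input rather than performing chunked GPU processing:
///
/// ```ignore
/// let signal: Vec<f64> = vec![0.0; 1024];
/// // Process with a 1 MiB memory budget.
/// let spectrum = memory_efficient_gpu_sparse_fft(&signal, 1024 * 1024)?;
/// assert_eq!(spectrum.len(), signal.len());
/// ```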
#[allow(dead_code)]
pub fn memory_efficient_gpu_sparse_fft<T>(
    signal: &[T],
    _max_memory: usize,
) -> FFTResult<Vec<Complex64>>
where
    T: Clone + 'static,
{
    // Get the global memory manager
    let manager = get_global_memory_manager()?;
    let _manager = manager.lock().unwrap();

    // Determine optimal chunk size based on available memory
    let signal_len = signal.len();
    // let _element_size = std::mem::size_of::<Complex64>();

    // In a real implementation, this would perform chunked processing.
    // For now, just return a zero-filled result of the same length.
    let result = vec![Complex64::new(0.0, 0.0); signal_len];

    Ok(result)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_manager_allocation() {
        let mut manager = GPUMemoryManager::new(
            GPUBackend::CPUFallback,
            -1,
            AllocationStrategy::AlwaysAllocate,
            1024 * 1024, // 1MB limit
        );

        // Allocate a buffer
        let buffer = manager
            .allocate_buffer(
                1024,
                8, // element size in bytes for this test (Complex64 itself is 16 bytes)
                BufferLocation::Host,
                BufferType::Input,
            )
            .unwrap();

        assert_eq!(buffer.size, 1024);
        assert_eq!(buffer.element_size, 8);
        assert_eq!(buffer.location, BufferLocation::Host);
        assert_eq!(buffer.buffer_type, BufferType::Input);
        assert_eq!(manager.current_memory_usage(), 1024 * 8);

        // Release buffer
        manager.release_buffer(buffer).unwrap();
        assert_eq!(manager.current_memory_usage(), 0);
    }

    #[test]
    fn test_memory_manager_cache() {
        let mut manager = GPUMemoryManager::new(
            GPUBackend::CPUFallback,
            -1,
            AllocationStrategy::CacheBySize,
            1024 * 1024, // 1MB limit
        );

        // Allocate a buffer
        let buffer1 = manager
            .allocate_buffer(
                1024,
                8, // element size in bytes for this test (Complex64 itself is 16 bytes)
                BufferLocation::Host,
                BufferType::Input,
            )
            .unwrap();

        // Release to cache
        manager.release_buffer(buffer1).unwrap();

        // Memory usage should not decrease when using CacheBySize
        assert_eq!(manager.current_memory_usage(), 1024 * 8);

        // Allocate same size buffer, should get from cache
        let buffer2 = manager
            .allocate_buffer(1024, 8, BufferLocation::Host, BufferType::Input)
            .unwrap();

        // Memory should not increase since we're reusing
        assert_eq!(manager.current_memory_usage(), 1024 * 8);

        // Release the second buffer back to cache
        manager.release_buffer(buffer2).unwrap();

        // Memory should still be allocated (cached)
        assert_eq!(manager.current_memory_usage(), 1024 * 8);

        // Clear cache - now this should free the cached memory
        manager.clear_cache().unwrap();
        assert_eq!(manager.current_memory_usage(), 0);
    }

    #[test]
    fn test_global_memory_manager() {
        // Initialize global memory manager
        init_global_memory_manager(
            GPUBackend::CPUFallback,
            -1,
            AllocationStrategy::CacheBySize,
            1024 * 1024,
        )
        .unwrap();

        // Get global memory manager
        let manager = get_global_memory_manager().unwrap();
        let mut manager = manager.lock().unwrap();

        // Allocate a buffer
        let buffer = manager
            .allocate_buffer(1024, 8, BufferLocation::Host, BufferType::Input)
            .unwrap();

        assert_eq!(buffer.size, 1024);
        manager.release_buffer(buffer).unwrap();
    }
}