scirs2_core/memory/
cross_device.rs

1//! # Cross-Device Memory Management
2//!
3//! This module provides unified memory management across different compute devices
4//! including CPU, GPU, and TPU with automatic data movement and synchronization.
5
6use crate::error::{CoreError, CoreResult};
7use crate::gpu::GpuContext;
8use std::any::TypeId;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex, RwLock};
11
12/// Error types for cross-device memory management
13#[derive(Debug, thiserror::Error)]
14pub enum CrossDeviceError {
15    /// Device not found
16    #[error("Device not found: {0}")]
17    DeviceNotFound(String),
18
19    /// Memory allocation failed
20    #[error("Memory allocation failed on device {device}: {reason}")]
21    AllocationFailed { device: String, reason: String },
22
23    /// Data transfer failed
24    #[error("Data transfer failed from {from} to {to}: {reason}")]
25    TransferFailed {
26        from: String,
27        to: String,
28        reason: String,
29    },
30
31    /// Synchronization failed
32    #[error("Device synchronization failed: {0}")]
33    SynchronizationFailed(String),
34
35    /// Invalid device type
36    #[error("Invalid device type: {0}")]
37    InvalidDeviceType(String),
38
39    /// Memory not found
40    #[error("Memory allocation not found: {0}")]
41    MemoryNotFound(String),
42}
43
44impl From<CrossDeviceError> for CoreError {
45    fn from(err: CrossDeviceError) -> Self {
46        CoreError::ComputationError(crate::error::ErrorContext::new(err.to_string()))
47    }
48}
49
50/// Device types supported by the memory manager
51#[derive(Debug, Clone, PartialEq, Eq, Hash)]
52pub enum DeviceType {
53    /// CPU memory (system RAM)
54    Cpu,
55    /// NVIDIA GPU (CUDA)
56    CudaGpu(u32),
57    /// AMD GPU (ROCm/OpenCL)
58    RocmGpu(u32),
59    /// Intel GPU (OpenCL)
60    IntelGpu(u32),
61    /// Apple Metal GPU
62    MetalGpu(u32),
63    /// Google TPU
64    Tpu(u32),
65    /// Generic OpenCL device
66    OpenClDevice(u32),
67}
68
69impl DeviceType {
70    /// Get a string representation of the device type
71    pub const fn as_str(&self) -> &'static str {
72        match self {
73            DeviceType::Cpu => "CPU",
74            DeviceType::CudaGpu(_) => "CUDA_GPU",
75            DeviceType::RocmGpu(_) => "ROCM_GPU",
76            DeviceType::IntelGpu(_) => "INTEL_GPU",
77            DeviceType::MetalGpu(_) => "METAL_GPU",
78            DeviceType::Tpu(_) => "TPU",
79            DeviceType::OpenClDevice(_) => "OPENCL",
80        }
81    }
82
83    /// Get device ID
84    pub fn device_id(&self) -> u32 {
85        match self {
86            DeviceType::Cpu => 0,
87            DeviceType::CudaGpu(id)
88            | DeviceType::RocmGpu(id)
89            | DeviceType::IntelGpu(id)
90            | DeviceType::MetalGpu(id)
91            | DeviceType::Tpu(id)
92            | DeviceType::OpenClDevice(id) => *id,
93        }
94    }
95
96    /// Check if device supports unified memory
97    pub fn supports_unified_memory(&self) -> bool {
98        matches!(self, DeviceType::CudaGpu(_) | DeviceType::RocmGpu(_))
99    }
100
101    /// Check if device supports peer-to-peer transfer
102    pub fn supports_p2p_transfer(&self, other: &DeviceType) -> bool {
103        matches!(
104            (self, other),
105            (DeviceType::CudaGpu(_), DeviceType::CudaGpu(_))
106                | (DeviceType::RocmGpu(_), DeviceType::RocmGpu(_))
107        )
108    }
109}
110
111/// Memory allocation information
112#[derive(Debug, Clone)]
113pub struct MemoryAllocation {
114    /// Unique allocation ID
115    pub id: String,
116    /// Device where memory is allocated
117    pub device: DeviceType,
118    /// Size in bytes
119    pub size: usize,
120    /// Memory address (platform-specific)
121    pub address: usize,
122    /// Data type information
123    pub datatype: TypeId,
124    /// Creation timestamp
125    pub created_at: std::time::Instant,
126    /// Last access timestamp
127    pub last_accessed: std::time::Instant,
128    /// Reference count
129    pub ref_count: usize,
130}
131
132impl MemoryAllocation {
133    /// Create a new memory allocation record
134    pub fn new(
135        allocation_id: String,
136        device: DeviceType,
137        size: usize,
138        address: usize,
139        datatype: TypeId,
140    ) -> Self {
141        let now = std::time::Instant::now();
142        Self {
143            id: allocation_id,
144            device,
145            size,
146            address,
147            datatype,
148            created_at: now,
149            last_accessed: now,
150            ref_count: 1,
151        }
152    }
153
154    /// Update last access time
155    pub fn touch(&mut self) {
156        self.last_accessed = std::time::Instant::now();
157    }
158
159    /// Increment reference count
160    pub fn add_ref(&mut self) {
161        self.ref_count += 1;
162    }
163
164    /// Decrement reference count
165    pub fn remove_ref(&mut self) -> usize {
166        self.ref_count = self.ref_count.saturating_sub(1);
167        self.ref_count
168    }
169}
170
171/// Device interface trait
172pub trait Device: Send + Sync {
173    /// Get device type
174    fn device_type(&self) -> DeviceType;
175
176    /// Allocate memory on this device
177    fn allocate(&self, size: usize) -> CoreResult<usize>;
178
179    /// Deallocate memory on this device
180    fn deallocate(&self, address: usize) -> CoreResult<()>;
181
182    /// Copy data to this device from CPU
183    ///
184    /// # Safety
185    ///
186    /// The caller must ensure:
187    /// - `src` points to at least `size` bytes of valid memory
188    /// - `dst` is a valid device memory address with at least `size` bytes allocated
189    /// - The memory regions do not overlap
190    unsafe fn copy_from_host(&self, src: *const u8, dst: usize, size: usize) -> CoreResult<()>;
191
192    /// Copy data from this device to CPU
193    ///
194    /// # Safety
195    ///
196    /// The caller must ensure:
197    /// - `src` is a valid device memory address with at least `size` bytes allocated
198    /// - `dst` points to at least `size` bytes of valid writable memory
199    /// - The memory regions do not overlap
200    unsafe fn copy_to_host(&self, src: usize, dst: *mut u8, size: usize) -> CoreResult<()>;
201
202    /// Copy data between devices (if supported)
203    fn copy_peer(
204        &self,
205        src: usize,
206        dst_device: &dyn Device,
207        dst: usize,
208        size: usize,
209    ) -> CoreResult<()>;
210
211    /// Synchronize device operations
212    fn synchronize(&self) -> CoreResult<()>;
213
214    /// Get available memory in bytes
215    fn available_memory(&self) -> CoreResult<usize>;
216
217    /// Get total memory in bytes
218    fn total_memory(&self) -> CoreResult<usize>;
219}
220
221/// CPU device implementation
222pub struct CpuDevice {
223    device_type: DeviceType,
224}
225
226impl CpuDevice {
227    /// Create a new CPU device
228    pub fn new() -> Self {
229        Self {
230            device_type: DeviceType::Cpu,
231        }
232    }
233}
234
235impl Default for CpuDevice {
236    fn default() -> Self {
237        Self::new()
238    }
239}
240
241impl Device for CpuDevice {
242    fn device_type(&self) -> DeviceType {
243        self.device_type.clone()
244    }
245
246    fn allocate(&self, size: usize) -> CoreResult<usize> {
247        let layout = std::alloc::Layout::from_size_align(size, 64).map_err(|e| {
248            CrossDeviceError::AllocationFailed {
249                device: "CPU".to_string(),
250                reason: e.to_string(),
251            }
252        })?;
253
254        unsafe {
255            let ptr = std::alloc::alloc(layout);
256            if ptr.is_null() {
257                Err(CrossDeviceError::AllocationFailed {
258                    device: "CPU".to_string(),
259                    reason: "Out of memory".to_string(),
260                }
261                .into())
262            } else {
263                Ok(ptr as usize)
264            }
265        }
266    }
267
268    fn deallocate(&self, address: usize) -> CoreResult<()> {
269        // Note: In a real implementation, we'd need to track the layout
270        // For now, we'll skip the actual deallocation
271        let _ = address;
272        Ok(())
273    }
274
275    unsafe fn copy_from_host(&self, src: *const u8, dst: usize, size: usize) -> CoreResult<()> {
276        std::ptr::copy_nonoverlapping(src, dst as *mut u8, size);
277        Ok(())
278    }
279
280    unsafe fn copy_to_host(&self, src: usize, dst: *mut u8, size: usize) -> CoreResult<()> {
281        std::ptr::copy_nonoverlapping(src as *const u8, dst, size);
282        Ok(())
283    }
284
285    fn copy_peer(
286        &self,
287        src: usize,
288        _dst_device: &dyn Device,
289        _dst: usize,
290        _size: usize,
291    ) -> CoreResult<()> {
292        Err(CrossDeviceError::TransferFailed {
293            from: "CPU".to_string(),
294            to: "unknown".to_string(),
295            reason: "Peer-to-peer not supported for CPU".to_string(),
296        }
297        .into())
298    }
299
300    fn synchronize(&self) -> CoreResult<()> {
301        // CPU operations are synchronous
302        Ok(())
303    }
304
305    fn available_memory(&self) -> CoreResult<usize> {
306        // Simple approximation - in reality would use platform-specific APIs
307        Ok(8 * 1024 * 1024 * 1024) // 8 GB
308    }
309
310    fn total_memory(&self) -> CoreResult<usize> {
311        Ok(16 * 1024 * 1024 * 1024) // 16 GB
312    }
313}
314
315/// GPU device wrapper
316pub struct GpuContextWrapper {
317    inner: Arc<GpuContext>,
318    device_type: DeviceType,
319}
320
321impl GpuContextWrapper {
322    /// Create a new GPU device wrapper
323    pub fn new(gpu_device: Arc<GpuContext>, devicetype: DeviceType) -> Self {
324        Self {
325            inner: gpu_device,
326            device_type: devicetype,
327        }
328    }
329}
330
331impl Device for GpuContextWrapper {
332    fn device_type(&self) -> DeviceType {
333        self.device_type.clone()
334    }
335
336    fn allocate(&self, size: usize) -> CoreResult<usize> {
337        // Use the GPU device's buffer allocation
338        let _buffer = self.inner.create_buffer::<u8>(size);
339        // In a real implementation, we'd extract the actual device pointer
340        // For now, we'll use a placeholder based on buffer properties
341        Ok(size) // Return the size as a placeholder ID
342    }
343
344    fn deallocate(&self, address: usize) -> CoreResult<()> {
345        // GPU buffers are automatically freed when dropped
346        Ok(())
347    }
348
349    unsafe fn copy_from_host(&self, src: *const u8, _dst: usize, size: usize) -> CoreResult<()> {
350        // Would use GPU-specific memory copy operations
351        Ok(())
352    }
353
354    unsafe fn copy_to_host(&self, src: usize, _dst: *mut u8, size: usize) -> CoreResult<()> {
355        // Would use GPU-specific memory copy operations
356        Ok(())
357    }
358
359    fn copy_peer(
360        &self,
361        src: usize,
362        _dst_device: &dyn Device,
363        _dst: usize,
364        _size: usize,
365    ) -> CoreResult<()> {
366        // Would implement GPU-to-GPU transfers
367        Ok(())
368    }
369
370    fn synchronize(&self) -> CoreResult<()> {
371        // Would synchronize GPU streams/queues
372        Ok(())
373    }
374
375    fn available_memory(&self) -> CoreResult<usize> {
376        self.inner.get_available_memory().ok_or_else(|| {
377            CrossDeviceError::DeviceNotFound("GPU memory info unavailable".to_string()).into()
378        })
379    }
380
381    fn total_memory(&self) -> CoreResult<usize> {
382        self.inner.get_total_memory().ok_or_else(|| {
383            CrossDeviceError::DeviceNotFound("GPU memory info unavailable".to_string()).into()
384        })
385    }
386}
387
388/// Cross-device memory manager
389pub struct CrossDeviceMemoryManager {
390    devices: RwLock<HashMap<DeviceType, Arc<dyn Device>>>,
391    allocations: RwLock<HashMap<String, MemoryAllocation>>,
392    allocation_counter: Mutex<u64>,
393    default_device: RwLock<Option<DeviceType>>,
394}
395
396impl CrossDeviceMemoryManager {
397    /// Create a new cross-device memory manager
398    pub fn new() -> Self {
399        Self {
400            devices: RwLock::new(HashMap::new()),
401            allocations: RwLock::new(HashMap::new()),
402            allocation_counter: Mutex::new(0),
403            default_device: RwLock::new(None),
404        }
405    }
406
407    /// Register a device with the manager
408    pub fn register_device(&self, device: Arc<dyn Device>) -> CoreResult<()> {
409        let device_type = device.device_type();
410        let mut devices = self.devices.write().expect("Operation failed");
411        devices.insert(device_type.clone(), device);
412
413        // Set as default if it's the first device
414        let mut default_device = self.default_device.write().expect("Operation failed");
415        if default_device.is_none() {
416            *default_device = Some(device_type);
417        }
418
419        Ok(())
420    }
421
422    /// Set the default device
423    pub fn set_default_device(&self, devicetype: DeviceType) -> CoreResult<()> {
424        let devices = self.devices.read().expect("Operation failed");
425        if !devices.contains_key(&devicetype) {
426            return Err(CrossDeviceError::DeviceNotFound(format!("{devicetype:?}")).into());
427        }
428
429        let mut default_device = self.default_device.write().expect("Operation failed");
430        *default_device = Some(devicetype);
431
432        Ok(())
433    }
434
435    /// Get the default device
436    pub fn get_default_device(&self) -> Option<DeviceType> {
437        self.default_device
438            .read()
439            .expect("Operation failed")
440            .clone()
441    }
442
443    /// Allocate memory on a specific device
444    pub fn allocate<T: 'static>(
445        self: &Arc<Self>,
446        device_type: &DeviceType,
447        count: usize,
448    ) -> CoreResult<CrossDeviceBuffer<T>> {
449        let devices = self.devices.read().expect("Operation failed");
450        let device = devices
451            .get(device_type)
452            .ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{device_type:?}")))?;
453
454        let size = count * std::mem::size_of::<T>();
455        let address = device.allocate(size)?;
456
457        let allocation_id = self.generate_allocation_id();
458        let allocation = MemoryAllocation::new(
459            allocation_id.clone(),
460            device_type.clone(),
461            size,
462            address,
463            TypeId::of::<T>(),
464        );
465
466        let mut allocations = self.allocations.write().expect("Operation failed");
467        allocations.insert(allocation_id.clone(), allocation);
468
469        Ok(CrossDeviceBuffer::new(
470            allocation_id,
471            device_type.clone(),
472            address,
473            count,
474            self.clone(),
475        ))
476    }
477
478    /// Allocate memory on the default device
479    pub fn allocate_default<T: 'static>(
480        self: &Arc<Self>,
481        count: usize,
482    ) -> CoreResult<CrossDeviceBuffer<T>> {
483        let default_device = self
484            .get_default_device()
485            .ok_or_else(|| CrossDeviceError::DeviceNotFound("No default device set".to_string()))?;
486
487        self.allocate(&default_device, count)
488    }
489
490    /// Transfer data between devices
491    pub fn transfer<T: 'static + Copy>(
492        self: &Arc<Self>,
493        src_buffer: &CrossDeviceBuffer<T>,
494        dst_device: &DeviceType,
495    ) -> CoreResult<CrossDeviceBuffer<T>> {
496        let devices = self.devices.read().expect("Operation failed");
497        let src_device = devices.get(&src_buffer.device_type).ok_or_else(|| {
498            CrossDeviceError::DeviceNotFound(format!("{0:?}", src_buffer.device_type))
499        })?;
500        let dst_device_obj = devices
501            .get(dst_device)
502            .ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{dst_device:?}")))?;
503
504        // Allocate memory on destination device
505        let dst_buffer = self.allocate::<T>(dst_device, src_buffer.count)?;
506
507        let size = src_buffer.count * std::mem::size_of::<T>();
508
509        // Try peer-to-peer transfer first
510        if src_buffer.device_type.supports_p2p_transfer(dst_device) {
511            src_device.copy_peer(
512                src_buffer.address,
513                dst_device_obj.as_ref(),
514                dst_buffer.address,
515                size,
516            )?;
517        } else {
518            // Fall back to CPU staging
519            let staging_buffer = self.allocate::<T>(&DeviceType::Cpu, src_buffer.count)?;
520
521            // Copy from source to CPU
522            unsafe {
523                src_device.copy_to_host(
524                    src_buffer.address,
525                    staging_buffer.address as *mut u8,
526                    size,
527                )?;
528            }
529
530            // Copy from CPU to destination
531            unsafe {
532                dst_device_obj.copy_from_host(
533                    staging_buffer.address as *const u8,
534                    dst_buffer.address,
535                    size,
536                )?;
537            }
538        }
539
540        Ok(dst_buffer)
541    }
542
543    /// Synchronize all devices
544    pub fn synchronize_all(&self) -> CoreResult<()> {
545        let devices = self.devices.read().expect("Operation failed");
546        for device in devices.values() {
547            device.synchronize()?;
548        }
549        Ok(())
550    }
551
552    /// Get memory statistics
553    pub fn get_memory_statistics(&self) -> MemoryStatistics {
554        let allocations = self.allocations.read().expect("Operation failed");
555        let devices = self.devices.read().expect("Operation failed");
556
557        let mut stats_by_device = HashMap::new();
558        let mut total_allocated = 0;
559        let mut total_allocations = 0;
560
561        for allocation in allocations.values() {
562            let device_stats =
563                stats_by_device
564                    .entry(allocation.device.clone())
565                    .or_insert(DeviceMemoryStats {
566                        device_type: allocation.device.clone(),
567                        allocated_bytes: 0,
568                        allocation_count: 0,
569                        available_bytes: 0,
570                        total_bytes: 0,
571                    });
572
573            device_stats.allocated_bytes += allocation.size;
574            device_stats.allocation_count += 1;
575            total_allocated += allocation.size;
576            total_allocations += 1;
577        }
578
579        // Update available/total memory from devices and ensure all devices are included
580        for (device_type, device) in devices.iter() {
581            let device_stats =
582                stats_by_device
583                    .entry(device_type.clone())
584                    .or_insert(DeviceMemoryStats {
585                        device_type: device_type.clone(),
586                        allocated_bytes: 0,
587                        allocation_count: 0,
588                        available_bytes: 0,
589                        total_bytes: 0,
590                    });
591
592            device_stats.available_bytes = device.available_memory().unwrap_or(0);
593            device_stats.total_bytes = device.total_memory().unwrap_or(0);
594        }
595
596        MemoryStatistics {
597            total_allocated_bytes: total_allocated,
598            total_allocations,
599            device_stats: stats_by_device.into_values().collect(),
600        }
601    }
602
603    /// Clean up unused allocations
604    pub fn cleanup_unused_allocations(&self, maxage: std::time::Duration) -> usize {
605        let mut allocations = self.allocations.write().expect("Operation failed");
606        let now = std::time::Instant::now();
607        let mut cleaned = 0;
608
609        allocations.retain(|_, allocation| {
610            if allocation.ref_count == 0 && now.duration_since(allocation.last_accessed) > maxage {
611                // In a real implementation, we'd call deallocate on the device
612                cleaned += 1;
613                false
614            } else {
615                true
616            }
617        });
618
619        cleaned
620    }
621
622    /// Generate unique allocation ID
623    fn generate_allocation_id(&self) -> String {
624        let counter = {
625            let mut counter = self.allocation_counter.lock().expect("Operation failed");
626            *counter += 1;
627            *counter
628        };
629
630        format!("{counter:016x}")
631    }
632
633    /// Internal method to remove allocation (called by CrossDeviceBuffer on drop)
634    pub(crate) fn remove_allocation(&self, allocationid: &str) {
635        let mut allocations = self.allocations.write().expect("Operation failed");
636        if let Some(allocation) = allocations.get_mut(allocationid) {
637            if allocation.remove_ref() == 0 {
638                allocations.remove(allocationid);
639            }
640        }
641    }
642
643    /// Internal method to touch allocation (update last access time)
644    pub(crate) fn touch_allocation(&self, allocationid: &str) {
645        let mut allocations = self.allocations.write().expect("Operation failed");
646        if let Some(allocation) = allocations.get_mut(allocationid) {
647            allocation.touch();
648        }
649    }
650}
651
652impl Default for CrossDeviceMemoryManager {
653    fn default() -> Self {
654        Self::new()
655    }
656}
657
658/// Cross-device buffer that manages memory across different devices
659pub struct CrossDeviceBuffer<T> {
660    allocation_id: String,
661    device_type: DeviceType,
662    address: usize,
663    count: usize,
664    manager: Arc<CrossDeviceMemoryManager>,
665    phantom: std::marker::PhantomData<T>,
666}
667
668impl<T> CrossDeviceBuffer<T> {
669    /// Create a new cross-device buffer
670    fn new(
671        allocation_id: String,
672        device_type: DeviceType,
673        address: usize,
674        count: usize,
675        manager: Arc<CrossDeviceMemoryManager>,
676    ) -> Self {
677        Self {
678            allocation_id,
679            device_type,
680            address,
681            count,
682            manager,
683            phantom: std::marker::PhantomData,
684        }
685    }
686
687    /// Get the device type this buffer is allocated on
688    pub const fn device_type(&self) -> &DeviceType {
689        &self.device_type
690    }
691
692    /// Get the number of elements in the buffer
693    pub fn len(&self) -> usize {
694        self.count
695    }
696
697    /// Check if the buffer is empty
698    pub fn is_empty(&self) -> bool {
699        self.count == 0
700    }
701
702    /// Get the size in bytes
703    pub fn size_bytes(&self) -> usize {
704        self.count * std::mem::size_of::<T>()
705    }
706
707    /// Get the raw address (device-specific)
708    pub fn raw_address(&self) -> usize {
709        self.manager.touch_allocation(&self.allocation_id);
710        self.address
711    }
712
713    /// Transfer this buffer to another device
714    pub fn to_device(&self, devicetype: &DeviceType) -> CoreResult<CrossDeviceBuffer<T>>
715    where
716        T: Copy + 'static,
717    {
718        self.manager.transfer(self, devicetype)
719    }
720
721    /// Copy data from host to this buffer
722    pub fn copy_from_host(&self, data: &[T]) -> CoreResult<()>
723    where
724        T: Copy,
725    {
726        if data.len() != self.count {
727            return Err(CrossDeviceError::InvalidDeviceType(format!(
728                "Data length {} doesn't match buffer capacity {}",
729                data.len(),
730                self.count
731            ))
732            .into());
733        }
734
735        let devices = self.manager.devices.read().expect("Operation failed");
736        let device = devices
737            .get(&self.device_type)
738            .ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{0:?}", self.device_type)))?;
739
740        unsafe {
741            device.copy_from_host(data.as_ptr() as *const u8, self.address, self.size_bytes())?;
742        }
743
744        self.manager.touch_allocation(&self.allocation_id);
745        Ok(())
746    }
747
748    /// Copy data from this buffer to host
749    pub fn copy_to_host(&self) -> CoreResult<Vec<T>>
750    where
751        T: Copy + Default,
752    {
753        let mut result = vec![T::default(); self.count];
754
755        let devices = self.manager.devices.read().expect("Operation failed");
756        let device = devices
757            .get(&self.device_type)
758            .ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{0:?}", self.device_type)))?;
759
760        unsafe {
761            device.copy_to_host(
762                self.address,
763                result.as_mut_ptr() as *mut u8,
764                self.size_bytes(),
765            )?;
766        }
767
768        self.manager.touch_allocation(&self.allocation_id);
769        Ok(result)
770    }
771}
772
773impl<T> Clone for CrossDeviceBuffer<T> {
774    fn clone(&self) -> Self {
775        // Increment reference count
776        {
777            let mut allocations = self.manager.allocations.write().expect("Operation failed");
778            if let Some(allocation) = allocations.get_mut(&self.allocation_id) {
779                allocation.add_ref();
780            }
781        }
782
783        Self {
784            allocation_id: self.allocation_id.clone(),
785            device_type: self.device_type.clone(),
786            address: self.address,
787            count: self.count,
788            manager: self.manager.clone(),
789            phantom: std::marker::PhantomData,
790        }
791    }
792}
793
794impl<T> Drop for CrossDeviceBuffer<T> {
795    fn drop(&mut self) {
796        self.manager.remove_allocation(&self.allocation_id);
797    }
798}
799
800/// Memory statistics
801#[derive(Debug, Clone)]
802pub struct MemoryStatistics {
803    /// Total bytes allocated across all devices
804    pub total_allocated_bytes: usize,
805    /// Total number of allocations
806    pub total_allocations: usize,
807    /// Statistics per device
808    pub device_stats: Vec<DeviceMemoryStats>,
809}
810
811/// Memory statistics for a specific device
812#[derive(Debug, Clone)]
813pub struct DeviceMemoryStats {
814    /// Device type
815    pub device_type: DeviceType,
816    /// Currently allocated bytes on this device
817    pub allocated_bytes: usize,
818    /// Number of active allocations
819    pub allocation_count: usize,
820    /// Available memory on device
821    pub available_bytes: usize,
822    /// Total memory on device
823    pub total_bytes: usize,
824}
825
826impl DeviceMemoryStats {
827    /// Get memory usage percentage
828    pub fn usage_percentage(&self) -> f64 {
829        if self.total_bytes == 0 {
830            0.0
831        } else {
832            (self.allocated_bytes as f64 / self.total_bytes as f64) * 100.0
833        }
834    }
835}
836
837/// Global cross-device memory manager instance
838static GLOBAL_MANAGER: std::sync::OnceLock<Arc<CrossDeviceMemoryManager>> =
839    std::sync::OnceLock::new();
840
841/// Get the global cross-device memory manager
842#[allow(dead_code)]
843pub fn global_manager() -> Arc<CrossDeviceMemoryManager> {
844    GLOBAL_MANAGER
845        .get_or_init(|| {
846            let manager = Arc::new(CrossDeviceMemoryManager::new());
847
848            // Register CPU device by default
849            let cpu_device = Arc::new(CpuDevice::new());
850            let _ = manager.register_device(cpu_device);
851
852            manager
853        })
854        .clone()
855}
856
857/// Initialize cross-device memory management with GPU devices
858#[allow(dead_code)]
859pub fn initialize_with_gpu_devices(gpudevices: Vec<Arc<GpuContext>>) -> CoreResult<()> {
860    let manager = global_manager();
861
862    for (i, gpu_device) in gpudevices.into_iter().enumerate() {
863        let device_type = DeviceType::CudaGpu(i as u32); // Assume CUDA for now
864        let wrapper = Arc::new(GpuContextWrapper::new(gpu_device, device_type));
865        manager.register_device(wrapper)?;
866    }
867
868    Ok(())
869}
870
871/// Convenience functions for cross-device memory management
872pub mod utils {
873    use super::*;
874
875    /// Allocate a buffer on the best available device
876    pub fn allocate_optimal<T: 'static>(count: usize) -> CoreResult<CrossDeviceBuffer<T>> {
877        let manager = global_manager();
878        let stats = manager.get_memory_statistics();
879
880        // Find device with most available memory
881        let best_device = stats
882            .device_stats
883            .iter()
884            .max_by_key(|s| s.available_bytes)
885            .map(|s| s.device_type.clone())
886            .unwrap_or(DeviceType::Cpu);
887
888        manager.allocate(&best_device, count)
889    }
890
891    /// Create a buffer with data from host
892    pub fn create_buffer_with_data<T: Copy + 'static>(
893        data: &[T],
894        device_type: &DeviceType,
895    ) -> CoreResult<CrossDeviceBuffer<T>> {
896        let manager = global_manager();
897        let buffer = manager.allocate(device_type, data.len())?;
898        buffer.copy_from_host(data)?;
899        Ok(buffer)
900    }
901
902    /// Transfer data between any two devices
903    pub fn transfer_data<T: Copy + 'static>(
904        src_buffer: &CrossDeviceBuffer<T>,
905        dst_device: &DeviceType,
906    ) -> CoreResult<CrossDeviceBuffer<T>> {
907        let manager = global_manager();
908        manager.transfer(src_buffer, dst_device)
909    }
910}
911
912#[cfg(test)]
913mod tests {
914    use super::*;
915
916    #[test]
917    fn test_device_type_creation() {
918        let cpu = DeviceType::Cpu;
919        let gpu = DeviceType::CudaGpu(0);
920        let tpu = DeviceType::Tpu(1);
921
922        assert_eq!(cpu.as_str(), "CPU");
923        assert_eq!(gpu.as_str(), "CUDA_GPU");
924        assert_eq!(tpu.as_str(), "TPU");
925
926        assert_eq!(cpu.device_id(), 0);
927        assert_eq!(gpu.device_id(), 0);
928        assert_eq!(tpu.device_id(), 1);
929    }
930
931    #[test]
932    fn test_device_capabilities() {
933        let cpu = DeviceType::Cpu;
934        let cuda = DeviceType::CudaGpu(0);
935        let rocm = DeviceType::RocmGpu(0);
936
937        assert!(!cpu.supports_unified_memory());
938        assert!(cuda.supports_unified_memory());
939        assert!(rocm.supports_unified_memory());
940
941        assert!(cuda.supports_p2p_transfer(&DeviceType::CudaGpu(1)));
942        assert!(!cuda.supports_p2p_transfer(&DeviceType::RocmGpu(0)));
943        assert!(!cpu.supports_p2p_transfer(&DeviceType::CudaGpu(0)));
944    }
945
946    #[test]
947    fn test_memory_allocation_creation() {
948        let allocation = MemoryAllocation::new(
949            "test_alloc".to_string(),
950            DeviceType::Cpu,
951            1024,
952            0x1000,
953            TypeId::of::<f32>(),
954        );
955
956        assert_eq!(allocation.id, "test_alloc");
957        assert_eq!(allocation.size, 1024);
958        assert_eq!(allocation.address, 0x1000);
959        assert_eq!(allocation.ref_count, 1);
960    }
961
962    #[test]
963    fn test_cpu_device() {
964        let cpu = CpuDevice::new();
965        assert_eq!(cpu.device_type(), DeviceType::Cpu);
966
967        // Test memory info
968        assert!(cpu.available_memory().is_ok());
969        assert!(cpu.total_memory().is_ok());
970
971        // Test synchronization
972        assert!(cpu.synchronize().is_ok());
973    }
974
975    #[test]
976    fn test_cross_device_manager() {
977        let manager = CrossDeviceMemoryManager::new();
978
979        // Register CPU device
980        let cpu_device = Arc::new(CpuDevice::new());
981        assert!(manager.register_device(cpu_device).is_ok());
982
983        // Check default device
984        assert_eq!(manager.get_default_device(), Some(DeviceType::Cpu));
985
986        // Get initial statistics
987        let stats = manager.get_memory_statistics();
988        assert_eq!(stats.total_allocations, 0);
989        assert_eq!(stats.total_allocated_bytes, 0);
990    }
991
992    #[test]
993    fn test_global_manager() {
994        let manager = global_manager();
995        assert_eq!(manager.get_default_device(), Some(DeviceType::Cpu));
996
997        let stats = manager.get_memory_statistics();
998        assert!(!stats.device_stats.is_empty());
999    }
1000
1001    #[test]
1002    fn test_memory_statistics() {
1003        let stats = DeviceMemoryStats {
1004            device_type: DeviceType::Cpu,
1005            allocated_bytes: 1024,
1006            allocation_count: 1,
1007            available_bytes: 7 * 1024 * 1024 * 1024,
1008            total_bytes: 8 * 1024 * 1024 * 1024,
1009        };
1010
1011        let usage = stats.usage_percentage();
1012        assert!(usage > 0.0 && usage < 1.0);
1013    }
1014}