Skip to main content

scirs2_core/memory/
cross_device.rs

1//! # Cross-Device Memory Management
2//!
3//! This module provides unified memory management across different compute devices
4//! including CPU, GPU, and TPU with automatic data movement and synchronization.
5
6use crate::error::{CoreError, CoreResult};
7use crate::gpu::GpuContext;
8use std::any::TypeId;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex, RwLock};
11
12/// Error types for cross-device memory management
13#[derive(Debug, thiserror::Error)]
14pub enum CrossDeviceError {
15    /// Device not found
16    #[error("Device not found: {0}")]
17    DeviceNotFound(String),
18
19    /// Memory allocation failed
20    #[error("Memory allocation failed on device {device}: {reason}")]
21    AllocationFailed { device: String, reason: String },
22
23    /// Data transfer failed
24    #[error("Data transfer failed from {from} to {to}: {reason}")]
25    TransferFailed {
26        from: String,
27        to: String,
28        reason: String,
29    },
30
31    /// Synchronization failed
32    #[error("Device synchronization failed: {0}")]
33    SynchronizationFailed(String),
34
35    /// Invalid device type
36    #[error("Invalid device type: {0}")]
37    InvalidDeviceType(String),
38
39    /// Memory not found
40    #[error("Memory allocation not found: {0}")]
41    MemoryNotFound(String),
42}
43
44impl From<CrossDeviceError> for CoreError {
45    fn from(err: CrossDeviceError) -> Self {
46        CoreError::ComputationError(crate::error::ErrorContext::new(err.to_string()))
47    }
48}
49
/// Kinds of compute device the memory manager can address.
///
/// Every accelerator variant carries a `u32` that [`DeviceType::device_id`]
/// reports as the device ID.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// CPU memory (system RAM)
    Cpu,
    /// NVIDIA GPU (CUDA)
    CudaGpu(u32),
    /// AMD GPU (ROCm/OpenCL)
    RocmGpu(u32),
    /// Intel GPU (OpenCL)
    IntelGpu(u32),
    /// Apple Metal GPU
    MetalGpu(u32),
    /// Google TPU
    Tpu(u32),
    /// Generic OpenCL device
    OpenClDevice(u32),
}

impl DeviceType {
    /// Short, static name of the device family (the device ID is not included).
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Cpu => "CPU",
            Self::CudaGpu(_) => "CUDA_GPU",
            Self::RocmGpu(_) => "ROCM_GPU",
            Self::IntelGpu(_) => "INTEL_GPU",
            Self::MetalGpu(_) => "METAL_GPU",
            Self::Tpu(_) => "TPU",
            Self::OpenClDevice(_) => "OPENCL",
        }
    }

    /// Device ID carried by the variant; the CPU always reports `0`.
    pub fn device_id(&self) -> u32 {
        match *self {
            Self::Cpu => 0,
            Self::CudaGpu(index)
            | Self::RocmGpu(index)
            | Self::IntelGpu(index)
            | Self::MetalGpu(index)
            | Self::Tpu(index)
            | Self::OpenClDevice(index) => index,
        }
    }

    /// Whether this device family is treated as supporting unified memory.
    ///
    /// Only CUDA and ROCm GPUs answer `true`.
    pub fn supports_unified_memory(&self) -> bool {
        let is_cuda = matches!(self, Self::CudaGpu(_));
        let is_rocm = matches!(self, Self::RocmGpu(_));
        is_cuda || is_rocm
    }

    /// Whether a direct peer-to-peer copy between `self` and `other` is
    /// considered possible.
    ///
    /// This is `true` only when both sides are CUDA GPUs or both are ROCm
    /// GPUs; the device IDs are not compared.
    pub fn supports_p2p_transfer(&self, other: &DeviceType) -> bool {
        let both_cuda = matches!(self, Self::CudaGpu(_)) && matches!(other, Self::CudaGpu(_));
        let both_rocm = matches!(self, Self::RocmGpu(_)) && matches!(other, Self::RocmGpu(_));
        both_cuda || both_rocm
    }
}
110
111/// Memory allocation information
112#[derive(Debug, Clone)]
113pub struct MemoryAllocation {
114    /// Unique allocation ID
115    pub id: String,
116    /// Device where memory is allocated
117    pub device: DeviceType,
118    /// Size in bytes
119    pub size: usize,
120    /// Memory address (platform-specific)
121    pub address: usize,
122    /// Data type information
123    pub datatype: TypeId,
124    /// Creation timestamp
125    pub created_at: std::time::Instant,
126    /// Last access timestamp
127    pub last_accessed: std::time::Instant,
128    /// Reference count
129    pub ref_count: usize,
130}
131
132impl MemoryAllocation {
133    /// Create a new memory allocation record
134    pub fn new(
135        allocation_id: String,
136        device: DeviceType,
137        size: usize,
138        address: usize,
139        datatype: TypeId,
140    ) -> Self {
141        let now = std::time::Instant::now();
142        Self {
143            id: allocation_id,
144            device,
145            size,
146            address,
147            datatype,
148            created_at: now,
149            last_accessed: now,
150            ref_count: 1,
151        }
152    }
153
154    /// Update last access time
155    pub fn touch(&mut self) {
156        self.last_accessed = std::time::Instant::now();
157    }
158
159    /// Increment reference count
160    pub fn add_ref(&mut self) {
161        self.ref_count += 1;
162    }
163
164    /// Decrement reference count
165    pub fn remove_ref(&mut self) -> usize {
166        self.ref_count = self.ref_count.saturating_sub(1);
167        self.ref_count
168    }
169}
170
171/// Device interface trait
172pub trait Device: Send + Sync {
173    /// Get device type
174    fn device_type(&self) -> DeviceType;
175
176    /// Allocate memory on this device
177    fn allocate(&self, size: usize) -> CoreResult<usize>;
178
179    /// Deallocate memory on this device
180    fn deallocate(&self, address: usize) -> CoreResult<()>;
181
182    /// Copy data to this device from CPU
183    ///
184    /// # Safety
185    ///
186    /// The caller must ensure:
187    /// - `src` points to at least `size` bytes of valid memory
188    /// - `dst` is a valid device memory address with at least `size` bytes allocated
189    /// - The memory regions do not overlap
190    unsafe fn copy_from_host(&self, src: *const u8, dst: usize, size: usize) -> CoreResult<()>;
191
192    /// Copy data from this device to CPU
193    ///
194    /// # Safety
195    ///
196    /// The caller must ensure:
197    /// - `src` is a valid device memory address with at least `size` bytes allocated
198    /// - `dst` points to at least `size` bytes of valid writable memory
199    /// - The memory regions do not overlap
200    unsafe fn copy_to_host(&self, src: usize, dst: *mut u8, size: usize) -> CoreResult<()>;
201
202    /// Copy data between devices (if supported)
203    fn copy_peer(
204        &self,
205        src: usize,
206        dst_device: &dyn Device,
207        dst: usize,
208        size: usize,
209    ) -> CoreResult<()>;
210
211    /// Synchronize device operations
212    fn synchronize(&self) -> CoreResult<()>;
213
214    /// Get available memory in bytes
215    fn available_memory(&self) -> CoreResult<usize>;
216
217    /// Get total memory in bytes
218    fn total_memory(&self) -> CoreResult<usize>;
219}
220
221/// CPU device implementation
222pub struct CpuDevice {
223    device_type: DeviceType,
224    /// Tracks the layout used for each live allocation so that `deallocate`
225    /// can free memory correctly. The global allocator requires the exact
226    /// `Layout` that was used for `alloc` in order to `dealloc` without
227    /// undefined behavior, and that information is not recoverable from the
228    /// raw address alone.
229    allocations: Mutex<HashMap<usize, std::alloc::Layout>>,
230}
231
232impl CpuDevice {
233    /// Create a new CPU device
234    pub fn new() -> Self {
235        Self {
236            device_type: DeviceType::Cpu,
237            allocations: Mutex::new(HashMap::new()),
238        }
239    }
240}
241
242impl Default for CpuDevice {
243    fn default() -> Self {
244        Self::new()
245    }
246}
247
248impl Device for CpuDevice {
249    fn device_type(&self) -> DeviceType {
250        self.device_type.clone()
251    }
252
253    fn allocate(&self, size: usize) -> CoreResult<usize> {
254        let layout = std::alloc::Layout::from_size_align(size, 64).map_err(|e| {
255            CrossDeviceError::AllocationFailed {
256                device: "CPU".to_string(),
257                reason: e.to_string(),
258            }
259        })?;
260
261        // A zero-sized allocation has no backing storage to track or free.
262        if size == 0 {
263            return Ok(0);
264        }
265
266        unsafe {
267            let ptr = std::alloc::alloc(layout);
268            if ptr.is_null() {
269                Err(CrossDeviceError::AllocationFailed {
270                    device: "CPU".to_string(),
271                    reason: "Out of memory".to_string(),
272                }
273                .into())
274            } else {
275                // Record the layout so the matching `deallocate` can free it.
276                self.allocations
277                    .lock()
278                    .map_err(|_| CrossDeviceError::AllocationFailed {
279                        device: "CPU".to_string(),
280                        reason: "Allocation registry lock poisoned".to_string(),
281                    })?
282                    .insert(ptr as usize, layout);
283                Ok(ptr as usize)
284            }
285        }
286    }
287
288    fn deallocate(&self, address: usize) -> CoreResult<()> {
289        // A zero address corresponds to a zero-sized allocation, which has no
290        // backing storage to free.
291        if address == 0 {
292            return Ok(());
293        }
294
295        // Look up the layout recorded at allocation time. The global allocator
296        // requires the exact layout to free the block safely.
297        let layout = {
298            let mut allocations =
299                self.allocations
300                    .lock()
301                    .map_err(|_| CrossDeviceError::AllocationFailed {
302                        device: "CPU".to_string(),
303                        reason: "Allocation registry lock poisoned".to_string(),
304                    })?;
305            allocations.remove(&address)
306        };
307
308        match layout {
309            Some(layout) => {
310                // SAFETY: `address` was returned by `alloc` with this exact
311                // `layout` and has not been freed yet (we just removed it from
312                // the registry, so a double free is impossible).
313                unsafe {
314                    std::alloc::dealloc(address as *mut u8, layout);
315                }
316                Ok(())
317            }
318            None => Err(CrossDeviceError::MemoryNotFound(format!(
319                "CPU allocation at address {address:#x} is not tracked by this device"
320            ))
321            .into()),
322        }
323    }
324
325    unsafe fn copy_from_host(&self, src: *const u8, dst: usize, size: usize) -> CoreResult<()> {
326        std::ptr::copy_nonoverlapping(src, dst as *mut u8, size);
327        Ok(())
328    }
329
330    unsafe fn copy_to_host(&self, src: usize, dst: *mut u8, size: usize) -> CoreResult<()> {
331        std::ptr::copy_nonoverlapping(src as *const u8, dst, size);
332        Ok(())
333    }
334
335    fn copy_peer(
336        &self,
337        src: usize,
338        _dst_device: &dyn Device,
339        _dst: usize,
340        _size: usize,
341    ) -> CoreResult<()> {
342        Err(CrossDeviceError::TransferFailed {
343            from: "CPU".to_string(),
344            to: "unknown".to_string(),
345            reason: "Peer-to-peer not supported for CPU".to_string(),
346        }
347        .into())
348    }
349
350    fn synchronize(&self) -> CoreResult<()> {
351        // CPU operations are synchronous
352        Ok(())
353    }
354
355    fn available_memory(&self) -> CoreResult<usize> {
356        // Simple approximation - in reality would use platform-specific APIs
357        Ok(8 * 1024 * 1024 * 1024) // 8 GB
358    }
359
360    fn total_memory(&self) -> CoreResult<usize> {
361        Ok(16 * 1024 * 1024 * 1024) // 16 GB
362    }
363}
364
365/// GPU device wrapper.
366///
367/// Bridges the generic [`Device`] interface (which addresses memory through
368/// opaque `usize` handles) to the strongly-typed [`GpuContext`] buffer API.
369/// Because [`crate::gpu::GpuBuffer`] manages its device memory through RAII,
370/// this wrapper keeps each allocated buffer alive in a registry keyed by a
371/// stable handle. Copies are delegated to the real buffer, and `deallocate`
372/// drops the buffer (which releases the underlying device memory).
373pub struct GpuContextWrapper {
374    inner: Arc<GpuContext>,
375    device_type: DeviceType,
376    /// Live byte buffers keyed by the handle returned from `allocate`.
377    buffers: Mutex<HashMap<usize, crate::gpu::GpuBuffer<u8>>>,
378    /// Monotonic source of unique, non-zero allocation handles.
379    next_handle: Mutex<usize>,
380}
381
382impl GpuContextWrapper {
383    /// Create a new GPU device wrapper
384    pub fn new(gpu_device: Arc<GpuContext>, devicetype: DeviceType) -> Self {
385        Self {
386            inner: gpu_device,
387            device_type: devicetype,
388            buffers: Mutex::new(HashMap::new()),
389            next_handle: Mutex::new(1),
390        }
391    }
392}
393
394impl Device for GpuContextWrapper {
395    fn device_type(&self) -> DeviceType {
396        self.device_type.clone()
397    }
398
399    fn allocate(&self, size: usize) -> CoreResult<usize> {
400        // Allocate a real device buffer and retain it so the memory stays
401        // valid until `deallocate` is called. The returned handle is an opaque
402        // registry key, not a device pointer; callers must treat it as opaque.
403        let buffer = self.inner.create_buffer::<u8>(size);
404
405        let handle = {
406            let mut next_handle =
407                self.next_handle
408                    .lock()
409                    .map_err(|_| CrossDeviceError::AllocationFailed {
410                        device: self.device_type.as_str().to_string(),
411                        reason: "Handle counter lock poisoned".to_string(),
412                    })?;
413            let handle = *next_handle;
414            *next_handle = next_handle.wrapping_add(1).max(1);
415            handle
416        };
417
418        self.buffers
419            .lock()
420            .map_err(|_| CrossDeviceError::AllocationFailed {
421                device: self.device_type.as_str().to_string(),
422                reason: "Buffer registry lock poisoned".to_string(),
423            })?
424            .insert(handle, buffer);
425
426        Ok(handle)
427    }
428
429    fn deallocate(&self, address: usize) -> CoreResult<()> {
430        // Dropping the stored buffer releases the underlying device memory.
431        let removed = self
432            .buffers
433            .lock()
434            .map_err(|_| CrossDeviceError::AllocationFailed {
435                device: self.device_type.as_str().to_string(),
436                reason: "Buffer registry lock poisoned".to_string(),
437            })?
438            .remove(&address);
439
440        if removed.is_none() {
441            return Err(CrossDeviceError::MemoryNotFound(format!(
442                "GPU allocation handle {address} is not tracked by this device"
443            ))
444            .into());
445        }
446        Ok(())
447    }
448
449    unsafe fn copy_from_host(&self, src: *const u8, dst: usize, size: usize) -> CoreResult<()> {
450        if src.is_null() || size == 0 {
451            return Ok(());
452        }
453
454        let buffers = self
455            .buffers
456            .lock()
457            .map_err(|_| CrossDeviceError::TransferFailed {
458                from: "host".to_string(),
459                to: self.device_type.as_str().to_string(),
460                reason: "Buffer registry lock poisoned".to_string(),
461            })?;
462        let buffer = buffers
463            .get(&dst)
464            .ok_or_else(|| CrossDeviceError::TransferFailed {
465                from: "host".to_string(),
466                to: self.device_type.as_str().to_string(),
467                reason: format!("Unknown destination handle {dst}"),
468            })?;
469
470        // SAFETY: caller guarantees `src` points to at least `size` valid bytes.
471        let host_slice = std::slice::from_raw_parts(src, size);
472        buffer
473            .copy_from_host(host_slice)
474            .map_err(|e| CrossDeviceError::TransferFailed {
475                from: "host".to_string(),
476                to: self.device_type.as_str().to_string(),
477                reason: e.to_string(),
478            })?;
479        Ok(())
480    }
481
482    unsafe fn copy_to_host(&self, src: usize, dst: *mut u8, size: usize) -> CoreResult<()> {
483        if dst.is_null() || size == 0 {
484            return Ok(());
485        }
486
487        let buffers = self
488            .buffers
489            .lock()
490            .map_err(|_| CrossDeviceError::TransferFailed {
491                from: self.device_type.as_str().to_string(),
492                to: "host".to_string(),
493                reason: "Buffer registry lock poisoned".to_string(),
494            })?;
495        let buffer = buffers
496            .get(&src)
497            .ok_or_else(|| CrossDeviceError::TransferFailed {
498                from: self.device_type.as_str().to_string(),
499                to: "host".to_string(),
500                reason: format!("Unknown source handle {src}"),
501            })?;
502
503        // SAFETY: caller guarantees `dst` points to at least `size` writable bytes.
504        let host_slice = std::slice::from_raw_parts_mut(dst, size);
505        buffer
506            .copy_to_host(host_slice)
507            .map_err(|e| CrossDeviceError::TransferFailed {
508                from: self.device_type.as_str().to_string(),
509                to: "host".to_string(),
510                reason: e.to_string(),
511            })?;
512        Ok(())
513    }
514
515    fn copy_peer(
516        &self,
517        _src: usize,
518        _dst_device: &dyn Device,
519        _dst: usize,
520        _size: usize,
521    ) -> CoreResult<()> {
522        // Direct device-to-device (peer) transfer is not wired through this
523        // generic wrapper. Callers should route peer copies via the host using
524        // `copy_to_host` + `copy_from_host`, or use a backend-specific path.
525        Err(CrossDeviceError::TransferFailed {
526            from: self.device_type.as_str().to_string(),
527            to: "peer".to_string(),
528            reason: "Peer-to-peer transfer is not implemented for the generic GPU wrapper"
529                .to_string(),
530        }
531        .into())
532    }
533
534    fn synchronize(&self) -> CoreResult<()> {
535        // The high-level GpuContext buffer copy operations used above are
536        // synchronous (host<->device copies block until complete), so there is
537        // no outstanding asynchronous work to wait on here.
538        Ok(())
539    }
540
541    fn available_memory(&self) -> CoreResult<usize> {
542        self.inner.get_available_memory().ok_or_else(|| {
543            CrossDeviceError::DeviceNotFound("GPU memory info unavailable".to_string()).into()
544        })
545    }
546
547    fn total_memory(&self) -> CoreResult<usize> {
548        self.inner.get_total_memory().ok_or_else(|| {
549            CrossDeviceError::DeviceNotFound("GPU memory info unavailable".to_string()).into()
550        })
551    }
552}
553
554/// Cross-device memory manager
555pub struct CrossDeviceMemoryManager {
556    devices: RwLock<HashMap<DeviceType, Arc<dyn Device>>>,
557    allocations: RwLock<HashMap<String, MemoryAllocation>>,
558    allocation_counter: Mutex<u64>,
559    default_device: RwLock<Option<DeviceType>>,
560}
561
562impl CrossDeviceMemoryManager {
563    /// Create a new cross-device memory manager
564    pub fn new() -> Self {
565        Self {
566            devices: RwLock::new(HashMap::new()),
567            allocations: RwLock::new(HashMap::new()),
568            allocation_counter: Mutex::new(0),
569            default_device: RwLock::new(None),
570        }
571    }
572
573    /// Register a device with the manager
574    pub fn register_device(&self, device: Arc<dyn Device>) -> CoreResult<()> {
575        let device_type = device.device_type();
576        let mut devices = self.devices.write().expect("Operation failed");
577        devices.insert(device_type.clone(), device);
578
579        // Set as default if it's the first device
580        let mut default_device = self.default_device.write().expect("Operation failed");
581        if default_device.is_none() {
582            *default_device = Some(device_type);
583        }
584
585        Ok(())
586    }
587
588    /// Set the default device
589    pub fn set_default_device(&self, devicetype: DeviceType) -> CoreResult<()> {
590        let devices = self.devices.read().expect("Operation failed");
591        if !devices.contains_key(&devicetype) {
592            return Err(CrossDeviceError::DeviceNotFound(format!("{devicetype:?}")).into());
593        }
594
595        let mut default_device = self.default_device.write().expect("Operation failed");
596        *default_device = Some(devicetype);
597
598        Ok(())
599    }
600
601    /// Get the default device
602    pub fn get_default_device(&self) -> Option<DeviceType> {
603        self.default_device
604            .read()
605            .expect("Operation failed")
606            .clone()
607    }
608
609    /// Allocate memory on a specific device
610    pub fn allocate<T: 'static>(
611        self: &Arc<Self>,
612        device_type: &DeviceType,
613        count: usize,
614    ) -> CoreResult<CrossDeviceBuffer<T>> {
615        let devices = self.devices.read().expect("Operation failed");
616        let device = devices
617            .get(device_type)
618            .ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{device_type:?}")))?;
619
620        let size = count * std::mem::size_of::<T>();
621        let address = device.allocate(size)?;
622
623        let allocation_id = self.generate_allocation_id();
624        let allocation = MemoryAllocation::new(
625            allocation_id.clone(),
626            device_type.clone(),
627            size,
628            address,
629            TypeId::of::<T>(),
630        );
631
632        let mut allocations = self.allocations.write().expect("Operation failed");
633        allocations.insert(allocation_id.clone(), allocation);
634
635        Ok(CrossDeviceBuffer::new(
636            allocation_id,
637            device_type.clone(),
638            address,
639            count,
640            self.clone(),
641        ))
642    }
643
644    /// Allocate memory on the default device
645    pub fn allocate_default<T: 'static>(
646        self: &Arc<Self>,
647        count: usize,
648    ) -> CoreResult<CrossDeviceBuffer<T>> {
649        let default_device = self
650            .get_default_device()
651            .ok_or_else(|| CrossDeviceError::DeviceNotFound("No default device set".to_string()))?;
652
653        self.allocate(&default_device, count)
654    }
655
656    /// Transfer data between devices
657    pub fn transfer<T: 'static + Copy>(
658        self: &Arc<Self>,
659        src_buffer: &CrossDeviceBuffer<T>,
660        dst_device: &DeviceType,
661    ) -> CoreResult<CrossDeviceBuffer<T>> {
662        let devices = self.devices.read().expect("Operation failed");
663        let src_device = devices.get(&src_buffer.device_type).ok_or_else(|| {
664            CrossDeviceError::DeviceNotFound(format!("{0:?}", src_buffer.device_type))
665        })?;
666        let dst_device_obj = devices
667            .get(dst_device)
668            .ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{dst_device:?}")))?;
669
670        // Allocate memory on destination device
671        let dst_buffer = self.allocate::<T>(dst_device, src_buffer.count)?;
672
673        let size = src_buffer.count * std::mem::size_of::<T>();
674
675        // Try peer-to-peer transfer first
676        if src_buffer.device_type.supports_p2p_transfer(dst_device) {
677            src_device.copy_peer(
678                src_buffer.address,
679                dst_device_obj.as_ref(),
680                dst_buffer.address,
681                size,
682            )?;
683        } else {
684            // Fall back to CPU staging
685            let staging_buffer = self.allocate::<T>(&DeviceType::Cpu, src_buffer.count)?;
686
687            // Copy from source to CPU
688            unsafe {
689                src_device.copy_to_host(
690                    src_buffer.address,
691                    staging_buffer.address as *mut u8,
692                    size,
693                )?;
694            }
695
696            // Copy from CPU to destination
697            unsafe {
698                dst_device_obj.copy_from_host(
699                    staging_buffer.address as *const u8,
700                    dst_buffer.address,
701                    size,
702                )?;
703            }
704        }
705
706        Ok(dst_buffer)
707    }
708
709    /// Synchronize all devices
710    pub fn synchronize_all(&self) -> CoreResult<()> {
711        let devices = self.devices.read().expect("Operation failed");
712        for device in devices.values() {
713            device.synchronize()?;
714        }
715        Ok(())
716    }
717
718    /// Get memory statistics
719    pub fn get_memory_statistics(&self) -> MemoryStatistics {
720        let allocations = self.allocations.read().expect("Operation failed");
721        let devices = self.devices.read().expect("Operation failed");
722
723        let mut stats_by_device = HashMap::new();
724        let mut total_allocated = 0;
725        let mut total_allocations = 0;
726
727        for allocation in allocations.values() {
728            let device_stats =
729                stats_by_device
730                    .entry(allocation.device.clone())
731                    .or_insert(DeviceMemoryStats {
732                        device_type: allocation.device.clone(),
733                        allocated_bytes: 0,
734                        allocation_count: 0,
735                        available_bytes: 0,
736                        total_bytes: 0,
737                    });
738
739            device_stats.allocated_bytes += allocation.size;
740            device_stats.allocation_count += 1;
741            total_allocated += allocation.size;
742            total_allocations += 1;
743        }
744
745        // Update available/total memory from devices and ensure all devices are included
746        for (device_type, device) in devices.iter() {
747            let device_stats =
748                stats_by_device
749                    .entry(device_type.clone())
750                    .or_insert(DeviceMemoryStats {
751                        device_type: device_type.clone(),
752                        allocated_bytes: 0,
753                        allocation_count: 0,
754                        available_bytes: 0,
755                        total_bytes: 0,
756                    });
757
758            device_stats.available_bytes = device.available_memory().unwrap_or(0);
759            device_stats.total_bytes = device.total_memory().unwrap_or(0);
760        }
761
762        MemoryStatistics {
763            total_allocated_bytes: total_allocated,
764            total_allocations,
765            device_stats: stats_by_device.into_values().collect(),
766        }
767    }
768
769    /// Clean up unused allocations
770    pub fn cleanup_unused_allocations(&self, maxage: std::time::Duration) -> usize {
771        let mut allocations = self.allocations.write().expect("Operation failed");
772        let now = std::time::Instant::now();
773        let mut cleaned = 0;
774
775        allocations.retain(|_, allocation| {
776            if allocation.ref_count == 0 && now.duration_since(allocation.last_accessed) > maxage {
777                // In a real implementation, we'd call deallocate on the device
778                cleaned += 1;
779                false
780            } else {
781                true
782            }
783        });
784
785        cleaned
786    }
787
788    /// Generate unique allocation ID
789    fn generate_allocation_id(&self) -> String {
790        let counter = {
791            let mut counter = self.allocation_counter.lock().expect("Operation failed");
792            *counter += 1;
793            *counter
794        };
795
796        format!("{counter:016x}")
797    }
798
799    /// Internal method to remove allocation (called by CrossDeviceBuffer on drop)
800    pub(crate) fn remove_allocation(&self, allocationid: &str) {
801        let mut allocations = self.allocations.write().expect("Operation failed");
802        if let Some(allocation) = allocations.get_mut(allocationid) {
803            if allocation.remove_ref() == 0 {
804                allocations.remove(allocationid);
805            }
806        }
807    }
808
809    /// Internal method to touch allocation (update last access time)
810    pub(crate) fn touch_allocation(&self, allocationid: &str) {
811        let mut allocations = self.allocations.write().expect("Operation failed");
812        if let Some(allocation) = allocations.get_mut(allocationid) {
813            allocation.touch();
814        }
815    }
816}
817
818impl Default for CrossDeviceMemoryManager {
819    fn default() -> Self {
820        Self::new()
821    }
822}
823
824/// Cross-device buffer that manages memory across different devices
825pub struct CrossDeviceBuffer<T> {
826    allocation_id: String,
827    device_type: DeviceType,
828    address: usize,
829    count: usize,
830    manager: Arc<CrossDeviceMemoryManager>,
831    phantom: std::marker::PhantomData<T>,
832}
833
834impl<T> CrossDeviceBuffer<T> {
835    /// Create a new cross-device buffer
836    fn new(
837        allocation_id: String,
838        device_type: DeviceType,
839        address: usize,
840        count: usize,
841        manager: Arc<CrossDeviceMemoryManager>,
842    ) -> Self {
843        Self {
844            allocation_id,
845            device_type,
846            address,
847            count,
848            manager,
849            phantom: std::marker::PhantomData,
850        }
851    }
852
853    /// Get the device type this buffer is allocated on
854    pub const fn device_type(&self) -> &DeviceType {
855        &self.device_type
856    }
857
858    /// Get the number of elements in the buffer
859    pub fn len(&self) -> usize {
860        self.count
861    }
862
863    /// Check if the buffer is empty
864    pub fn is_empty(&self) -> bool {
865        self.count == 0
866    }
867
868    /// Get the size in bytes
869    pub fn size_bytes(&self) -> usize {
870        self.count * std::mem::size_of::<T>()
871    }
872
873    /// Get the raw address (device-specific)
874    pub fn raw_address(&self) -> usize {
875        self.manager.touch_allocation(&self.allocation_id);
876        self.address
877    }
878
879    /// Transfer this buffer to another device
880    pub fn to_device(&self, devicetype: &DeviceType) -> CoreResult<CrossDeviceBuffer<T>>
881    where
882        T: Copy + 'static,
883    {
884        self.manager.transfer(self, devicetype)
885    }
886
887    /// Copy data from host to this buffer
888    pub fn copy_from_host(&self, data: &[T]) -> CoreResult<()>
889    where
890        T: Copy,
891    {
892        if data.len() != self.count {
893            return Err(CrossDeviceError::InvalidDeviceType(format!(
894                "Data length {} doesn't match buffer capacity {}",
895                data.len(),
896                self.count
897            ))
898            .into());
899        }
900
901        let devices = self.manager.devices.read().expect("Operation failed");
902        let device = devices
903            .get(&self.device_type)
904            .ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{0:?}", self.device_type)))?;
905
906        unsafe {
907            device.copy_from_host(data.as_ptr() as *const u8, self.address, self.size_bytes())?;
908        }
909
910        self.manager.touch_allocation(&self.allocation_id);
911        Ok(())
912    }
913
914    /// Copy data from this buffer to host
915    pub fn copy_to_host(&self) -> CoreResult<Vec<T>>
916    where
917        T: Copy + Default,
918    {
919        let mut result = vec![T::default(); self.count];
920
921        let devices = self.manager.devices.read().expect("Operation failed");
922        let device = devices
923            .get(&self.device_type)
924            .ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{0:?}", self.device_type)))?;
925
926        unsafe {
927            device.copy_to_host(
928                self.address,
929                result.as_mut_ptr() as *mut u8,
930                self.size_bytes(),
931            )?;
932        }
933
934        self.manager.touch_allocation(&self.allocation_id);
935        Ok(result)
936    }
937}
938
939impl<T> Clone for CrossDeviceBuffer<T> {
940    fn clone(&self) -> Self {
941        // Increment reference count
942        {
943            let mut allocations = self.manager.allocations.write().expect("Operation failed");
944            if let Some(allocation) = allocations.get_mut(&self.allocation_id) {
945                allocation.add_ref();
946            }
947        }
948
949        Self {
950            allocation_id: self.allocation_id.clone(),
951            device_type: self.device_type.clone(),
952            address: self.address,
953            count: self.count,
954            manager: self.manager.clone(),
955            phantom: std::marker::PhantomData,
956        }
957    }
958}
959
960impl<T> Drop for CrossDeviceBuffer<T> {
961    fn drop(&mut self) {
962        self.manager.remove_allocation(&self.allocation_id);
963    }
964}
965
966/// Memory statistics
967#[derive(Debug, Clone)]
968pub struct MemoryStatistics {
969    /// Total bytes allocated across all devices
970    pub total_allocated_bytes: usize,
971    /// Total number of allocations
972    pub total_allocations: usize,
973    /// Statistics per device
974    pub device_stats: Vec<DeviceMemoryStats>,
975}
976
977/// Memory statistics for a specific device
978#[derive(Debug, Clone)]
979pub struct DeviceMemoryStats {
980    /// Device type
981    pub device_type: DeviceType,
982    /// Currently allocated bytes on this device
983    pub allocated_bytes: usize,
984    /// Number of active allocations
985    pub allocation_count: usize,
986    /// Available memory on device
987    pub available_bytes: usize,
988    /// Total memory on device
989    pub total_bytes: usize,
990}
991
992impl DeviceMemoryStats {
993    /// Get memory usage percentage
994    pub fn usage_percentage(&self) -> f64 {
995        if self.total_bytes == 0 {
996            0.0
997        } else {
998            (self.allocated_bytes as f64 / self.total_bytes as f64) * 100.0
999        }
1000    }
1001}
1002
1003/// Global cross-device memory manager instance
1004static GLOBAL_MANAGER: std::sync::OnceLock<Arc<CrossDeviceMemoryManager>> =
1005    std::sync::OnceLock::new();
1006
1007/// Get the global cross-device memory manager
1008#[allow(dead_code)]
1009pub fn global_manager() -> Arc<CrossDeviceMemoryManager> {
1010    GLOBAL_MANAGER
1011        .get_or_init(|| {
1012            let manager = Arc::new(CrossDeviceMemoryManager::new());
1013
1014            // Register CPU device by default
1015            let cpu_device = Arc::new(CpuDevice::new());
1016            let _ = manager.register_device(cpu_device);
1017
1018            manager
1019        })
1020        .clone()
1021}
1022
1023/// Initialize cross-device memory management with GPU devices
1024#[allow(dead_code)]
1025pub fn initialize_with_gpu_devices(gpudevices: Vec<Arc<GpuContext>>) -> CoreResult<()> {
1026    let manager = global_manager();
1027
1028    for (i, gpu_device) in gpudevices.into_iter().enumerate() {
1029        let device_type = DeviceType::CudaGpu(i as u32); // Assume CUDA for now
1030        let wrapper = Arc::new(GpuContextWrapper::new(gpu_device, device_type));
1031        manager.register_device(wrapper)?;
1032    }
1033
1034    Ok(())
1035}
1036
1037/// Convenience functions for cross-device memory management
1038pub mod utils {
1039    use super::*;
1040
1041    /// Allocate a buffer on the best available device
1042    pub fn allocate_optimal<T: 'static>(count: usize) -> CoreResult<CrossDeviceBuffer<T>> {
1043        let manager = global_manager();
1044        let stats = manager.get_memory_statistics();
1045
1046        // Find device with most available memory
1047        let best_device = stats
1048            .device_stats
1049            .iter()
1050            .max_by_key(|s| s.available_bytes)
1051            .map(|s| s.device_type.clone())
1052            .unwrap_or(DeviceType::Cpu);
1053
1054        manager.allocate(&best_device, count)
1055    }
1056
1057    /// Create a buffer with data from host
1058    pub fn create_buffer_with_data<T: Copy + 'static>(
1059        data: &[T],
1060        device_type: &DeviceType,
1061    ) -> CoreResult<CrossDeviceBuffer<T>> {
1062        let manager = global_manager();
1063        let buffer = manager.allocate(device_type, data.len())?;
1064        buffer.copy_from_host(data)?;
1065        Ok(buffer)
1066    }
1067
1068    /// Transfer data between any two devices
1069    pub fn transfer_data<T: Copy + 'static>(
1070        src_buffer: &CrossDeviceBuffer<T>,
1071        dst_device: &DeviceType,
1072    ) -> CoreResult<CrossDeviceBuffer<T>> {
1073        let manager = global_manager();
1074        manager.transfer(src_buffer, dst_device)
1075    }
1076}
1077
#[cfg(test)]
mod tests {
    use super::*;

    /// Each device type exposes the expected display name and device id.
    #[test]
    fn test_device_type_creation() {
        let cases = [
            (DeviceType::Cpu, "CPU", 0),
            (DeviceType::CudaGpu(0), "CUDA_GPU", 0),
            (DeviceType::Tpu(1), "TPU", 1),
        ];

        for (device, expected_name, expected_id) in cases {
            assert_eq!(device.as_str(), expected_name);
            assert_eq!(device.device_id(), expected_id);
        }
    }

    /// Unified-memory and peer-to-peer capability flags per device family.
    #[test]
    fn test_device_capabilities() {
        let host = DeviceType::Cpu;
        let nvidia = DeviceType::CudaGpu(0);
        let amd = DeviceType::RocmGpu(0);

        // Unified memory: reported by both GPU families, not by the CPU.
        assert!(!host.supports_unified_memory());
        assert!(nvidia.supports_unified_memory());
        assert!(amd.supports_unified_memory());

        // Peer-to-peer: only between two CUDA devices in these cases.
        assert!(nvidia.supports_p2p_transfer(&DeviceType::CudaGpu(1)));
        assert!(!nvidia.supports_p2p_transfer(&DeviceType::RocmGpu(0)));
        assert!(!host.supports_p2p_transfer(&DeviceType::CudaGpu(0)));
    }

    /// A new allocation record keeps its constructor arguments and starts
    /// with a reference count of one.
    #[test]
    fn test_memory_allocation_creation() {
        let record = MemoryAllocation::new(
            "test_alloc".to_string(),
            DeviceType::Cpu,
            1024,
            0x1000,
            TypeId::of::<f32>(),
        );

        assert_eq!(record.id, "test_alloc");
        assert_eq!(record.size, 1024);
        assert_eq!(record.address, 0x1000);
        assert_eq!(record.ref_count, 1);
    }

    /// The CPU device identifies itself and answers its basic queries.
    #[test]
    fn test_cpu_device() {
        let host = CpuDevice::new();
        assert_eq!(host.device_type(), DeviceType::Cpu);

        // Memory queries must succeed.
        assert!(host.available_memory().is_ok());
        assert!(host.total_memory().is_ok());

        // Synchronisation must succeed.
        assert!(host.synchronize().is_ok());
    }

    /// A fresh manager accepts the CPU device, makes it the default, and
    /// reports no allocations.
    #[test]
    fn test_cross_device_manager() {
        let manager = CrossDeviceMemoryManager::new();

        let host = Arc::new(CpuDevice::new());
        assert!(manager.register_device(host).is_ok());

        assert_eq!(manager.get_default_device(), Some(DeviceType::Cpu));

        let snapshot = manager.get_memory_statistics();
        assert_eq!(snapshot.total_allocations, 0);
        assert_eq!(snapshot.total_allocated_bytes, 0);
    }

    /// The global manager has the CPU as default device and at least one
    /// per-device statistics entry.
    #[test]
    fn test_global_manager() {
        let manager = global_manager();
        assert_eq!(manager.get_default_device(), Some(DeviceType::Cpu));

        let snapshot = manager.get_memory_statistics();
        assert!(!snapshot.device_stats.is_empty());
    }

    /// 1 KiB allocated out of 8 GiB is a positive share well below 1 percent.
    #[test]
    fn test_memory_statistics() {
        const GIB: usize = 1024 * 1024 * 1024;

        let sample = DeviceMemoryStats {
            device_type: DeviceType::Cpu,
            allocated_bytes: 1024,
            allocation_count: 1,
            available_bytes: 7 * GIB,
            total_bytes: 8 * GIB,
        };

        let percent = sample.usage_percentage();
        assert!(percent > 0.0 && percent < 1.0);
    }
}