use crate::error::{CoreError, CoreResult};
use crate::gpu::GpuContext;
use std::any::TypeId;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};

/// Errors that can occur during cross-device memory management.
#[derive(Debug, thiserror::Error)]
pub enum CrossDeviceError {
    #[error("Device not found: {0}")]
    DeviceNotFound(String),

    #[error("Memory allocation failed on device {device}: {reason}")]
    AllocationFailed { device: String, reason: String },

    #[error("Data transfer failed from {from} to {to}: {reason}")]
    TransferFailed {
        from: String,
        to: String,
        reason: String,
    },

    #[error("Device synchronization failed: {0}")]
    SynchronizationFailed(String),

    #[error("Invalid device type: {0}")]
    InvalidDeviceType(String),

    #[error("Memory allocation not found: {0}")]
    MemoryNotFound(String),
}

impl From<CrossDeviceError> for CoreError {
    fn from(err: CrossDeviceError) -> Self {
        CoreError::ComputationError(crate::error::ErrorContext::new(err.to_string()))
    }
}

/// Identifies a compute device that can hold memory allocations.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// Host CPU.
    Cpu,
    /// NVIDIA CUDA GPU with the given device index.
    CudaGpu(u32),
    /// AMD ROCm GPU with the given device index.
    RocmGpu(u32),
    /// Intel GPU with the given device index.
    IntelGpu(u32),
    /// Apple Metal GPU with the given device index.
    MetalGpu(u32),
    /// Tensor processing unit with the given device index.
    Tpu(u32),
    /// Generic OpenCL device with the given device index.
    OpenClDevice(u32),
}

impl DeviceType {
    /// Returns a short, human-readable name for the device type.
    pub const fn as_str(&self) -> &'static str {
        match self {
            DeviceType::Cpu => "CPU",
            DeviceType::CudaGpu(_) => "CUDA_GPU",
            DeviceType::RocmGpu(_) => "ROCM_GPU",
            DeviceType::IntelGpu(_) => "INTEL_GPU",
            DeviceType::MetalGpu(_) => "METAL_GPU",
            DeviceType::Tpu(_) => "TPU",
            DeviceType::OpenClDevice(_) => "OPENCL",
        }
    }

    /// Returns the device index (0 for the CPU).
    pub fn device_id(&self) -> u32 {
        match self {
            DeviceType::Cpu => 0,
            DeviceType::CudaGpu(id)
            | DeviceType::RocmGpu(id)
            | DeviceType::IntelGpu(id)
            | DeviceType::MetalGpu(id)
            | DeviceType::Tpu(id)
            | DeviceType::OpenClDevice(id) => *id,
        }
    }

    /// Whether the device type supports unified (host-visible) memory.
    pub fn supports_unified_memory(&self) -> bool {
        matches!(self, DeviceType::CudaGpu(_) | DeviceType::RocmGpu(_))
    }

    /// Whether direct peer-to-peer transfers are supported between `self` and `other`.
    pub fn supports_p2p_transfer(&self, other: &DeviceType) -> bool {
        matches!(
            (self, other),
            (DeviceType::CudaGpu(_), DeviceType::CudaGpu(_))
                | (DeviceType::RocmGpu(_), DeviceType::RocmGpu(_))
        )
    }
}

/// Metadata describing a single allocation on a device.
#[derive(Debug, Clone)]
pub struct MemoryAllocation {
    pub id: String,
    pub device: DeviceType,
    pub size: usize,
    pub address: usize,
    pub datatype: TypeId,
    pub created_at: std::time::Instant,
    pub last_accessed: std::time::Instant,
    pub ref_count: usize,
}

impl MemoryAllocation {
    pub fn new(
        allocation_id: String,
        device: DeviceType,
        size: usize,
        address: usize,
        datatype: TypeId,
    ) -> Self {
        let now = std::time::Instant::now();
        Self {
            id: allocation_id,
            device,
            size,
            address,
            datatype,
            created_at: now,
            last_accessed: now,
            ref_count: 1,
        }
    }

    /// Updates the last-accessed timestamp.
    pub fn touch(&mut self) {
        self.last_accessed = std::time::Instant::now();
    }

    /// Increments the reference count.
    pub fn add_ref(&mut self) {
        self.ref_count += 1;
    }

    /// Decrements the reference count and returns the new value.
    pub fn remove_ref(&mut self) -> usize {
        self.ref_count = self.ref_count.saturating_sub(1);
        self.ref_count
    }
}

/// Abstraction over a device that can allocate memory and move data.
pub trait Device: Send + Sync {
    /// Returns the type of this device.
    fn device_type(&self) -> DeviceType;

    /// Allocates `size` bytes and returns an opaque device address.
    fn allocate(&self, size: usize) -> CoreResult<usize>;

    /// Releases a previously allocated device address.
    fn deallocate(&self, address: usize) -> CoreResult<()>;

    /// Copies `size` bytes from host memory at `src` into device memory at `dst`.
    ///
    /// # Safety
    ///
    /// `src` must be valid for reads of `size` bytes and `dst` must refer to a
    /// live allocation of at least `size` bytes on this device.
    unsafe fn copy_from_host(&self, src: *const u8, dst: usize, size: usize) -> CoreResult<()>;

    /// Copies `size` bytes from device memory at `src` into host memory at `dst`.
    ///
    /// # Safety
    ///
    /// `dst` must be valid for writes of `size` bytes and `src` must refer to a
    /// live allocation of at least `size` bytes on this device.
    unsafe fn copy_to_host(&self, src: usize, dst: *mut u8, size: usize) -> CoreResult<()>;

    /// Copies `size` bytes directly to another device (peer-to-peer).
    fn copy_peer(
        &self,
        src: usize,
        dst_device: &dyn Device,
        dst: usize,
        size: usize,
    ) -> CoreResult<()>;

    /// Blocks until all outstanding operations on this device have completed.
    fn synchronize(&self) -> CoreResult<()>;

    /// Returns the amount of currently available memory, in bytes.
    fn available_memory(&self) -> CoreResult<usize>;

    /// Returns the total memory capacity, in bytes.
    fn total_memory(&self) -> CoreResult<usize>;
}

/// CPU-backed device that allocates from the system allocator.
pub struct CpuDevice {
    device_type: DeviceType,
}

impl CpuDevice {
    pub fn new() -> Self {
        Self {
            device_type: DeviceType::Cpu,
        }
    }
}

impl Default for CpuDevice {
    fn default() -> Self {
        Self::new()
    }
}

impl Device for CpuDevice {
    fn device_type(&self) -> DeviceType {
        self.device_type.clone()
    }

    fn allocate(&self, size: usize) -> CoreResult<usize> {
        let layout = std::alloc::Layout::from_size_align(size, 64).map_err(|e| {
            CrossDeviceError::AllocationFailed {
                device: "CPU".to_string(),
                reason: e.to_string(),
            }
        })?;

        unsafe {
            let ptr = std::alloc::alloc(layout);
            if ptr.is_null() {
                Err(CrossDeviceError::AllocationFailed {
                    device: "CPU".to_string(),
                    reason: "Out of memory".to_string(),
                }
                .into())
            } else {
                Ok(ptr as usize)
            }
        }
    }

    fn deallocate(&self, address: usize) -> CoreResult<()> {
        // Freeing via `std::alloc::dealloc` requires the original `Layout`, which
        // is not tracked here, so the allocation is intentionally leaked.
        let _ = address;
        Ok(())
    }

    unsafe fn copy_from_host(&self, src: *const u8, dst: usize, size: usize) -> CoreResult<()> {
        std::ptr::copy_nonoverlapping(src, dst as *mut u8, size);
        Ok(())
    }

    unsafe fn copy_to_host(&self, src: usize, dst: *mut u8, size: usize) -> CoreResult<()> {
        std::ptr::copy_nonoverlapping(src as *const u8, dst, size);
        Ok(())
    }

    fn copy_peer(
        &self,
        _src: usize,
        _dst_device: &dyn Device,
        _dst: usize,
        _size: usize,
    ) -> CoreResult<()> {
        Err(CrossDeviceError::TransferFailed {
            from: "CPU".to_string(),
            to: "unknown".to_string(),
            reason: "Peer-to-peer not supported for CPU".to_string(),
        }
        .into())
    }

    fn synchronize(&self) -> CoreResult<()> {
        // CPU operations complete synchronously, so there is nothing to wait for.
        Ok(())
    }

    fn available_memory(&self) -> CoreResult<usize> {
        // Placeholder value; a real implementation would query the operating system.
        Ok(8 * 1024 * 1024 * 1024)
    }

    fn total_memory(&self) -> CoreResult<usize> {
        // Placeholder value; a real implementation would query the operating system.
        Ok(16 * 1024 * 1024 * 1024)
    }
}

/// Adapts a [`GpuContext`] to the [`Device`] trait.
pub struct GpuContextWrapper {
    inner: Arc<GpuContext>,
    device_type: DeviceType,
}

impl GpuContextWrapper {
    pub fn new(gpu_device: Arc<GpuContext>, devicetype: DeviceType) -> Self {
        Self {
            inner: gpu_device,
            device_type: devicetype,
        }
    }
}

impl Device for GpuContextWrapper {
    fn device_type(&self) -> DeviceType {
        self.device_type.clone()
    }

    fn allocate(&self, size: usize) -> CoreResult<usize> {
        // Placeholder: creates (and immediately drops) a GPU buffer and returns the
        // requested size as an opaque handle rather than a real device address.
        let _buffer = self.inner.create_buffer::<u8>(size);
        Ok(size)
    }

    fn deallocate(&self, _address: usize) -> CoreResult<()> {
        // Placeholder: buffers are released when the underlying context drops them.
        Ok(())
    }

    unsafe fn copy_from_host(&self, _src: *const u8, _dst: usize, _size: usize) -> CoreResult<()> {
        // Placeholder: host-to-device copies are not yet wired up for this wrapper.
        Ok(())
    }

    unsafe fn copy_to_host(&self, _src: usize, _dst: *mut u8, _size: usize) -> CoreResult<()> {
        // Placeholder: device-to-host copies are not yet wired up for this wrapper.
        Ok(())
    }

    fn copy_peer(
        &self,
        _src: usize,
        _dst_device: &dyn Device,
        _dst: usize,
        _size: usize,
    ) -> CoreResult<()> {
        // Placeholder: peer-to-peer copies are not yet wired up for this wrapper.
        Ok(())
    }

    fn synchronize(&self) -> CoreResult<()> {
        // Placeholder: no explicit device synchronization is performed.
        Ok(())
    }

    fn available_memory(&self) -> CoreResult<usize> {
        self.inner.get_available_memory().ok_or_else(|| {
            CrossDeviceError::DeviceNotFound("GPU memory info unavailable".to_string()).into()
        })
    }

    fn total_memory(&self) -> CoreResult<usize> {
        self.inner.get_total_memory().ok_or_else(|| {
            CrossDeviceError::DeviceNotFound("GPU memory info unavailable".to_string()).into()
        })
    }
}

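/// Central registry of devices and the allocations made on them.
///
/// A minimal usage sketch (not compiled as a doctest; it assumes only the
/// built-in CPU device and that this module's items are in scope):
///
/// ```ignore
/// let manager = Arc::new(CrossDeviceMemoryManager::new());
/// manager.register_device(Arc::new(CpuDevice::new()))?;
///
/// // Allocate room for 256 f64 values on the default (CPU) device.
/// let buffer = manager.allocate_default::<f64>(256)?;
/// assert_eq!(buffer.len(), 256);
///
/// // Inspect per-device usage.
/// let stats = manager.get_memory_statistics();
/// assert_eq!(stats.total_allocations, 1);
/// ```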
pub struct CrossDeviceMemoryManager {
    devices: RwLock<HashMap<DeviceType, Arc<dyn Device>>>,
    allocations: RwLock<HashMap<String, MemoryAllocation>>,
    allocation_counter: Mutex<u64>,
    default_device: RwLock<Option<DeviceType>>,
}

impl CrossDeviceMemoryManager {
    /// Creates an empty manager with no registered devices.
    pub fn new() -> Self {
        Self {
            devices: RwLock::new(HashMap::new()),
            allocations: RwLock::new(HashMap::new()),
            allocation_counter: Mutex::new(0),
            default_device: RwLock::new(None),
        }
    }

    /// Registers a device. The first registered device becomes the default.
    pub fn register_device(&self, device: Arc<dyn Device>) -> CoreResult<()> {
        let device_type = device.device_type();
        let mut devices = self.devices.write().expect("Operation failed");
        devices.insert(device_type.clone(), device);

        let mut default_device = self.default_device.write().expect("Operation failed");
        if default_device.is_none() {
            *default_device = Some(device_type);
        }

        Ok(())
    }

    /// Sets the default device used by [`Self::allocate_default`].
    pub fn set_default_device(&self, devicetype: DeviceType) -> CoreResult<()> {
        let devices = self.devices.read().expect("Operation failed");
        if !devices.contains_key(&devicetype) {
            return Err(CrossDeviceError::DeviceNotFound(format!("{devicetype:?}")).into());
        }

        let mut default_device = self.default_device.write().expect("Operation failed");
        *default_device = Some(devicetype);

        Ok(())
    }

    /// Returns the current default device, if any.
    pub fn get_default_device(&self) -> Option<DeviceType> {
        self.default_device
            .read()
            .expect("Operation failed")
            .clone()
    }

    /// Allocates space for `count` values of type `T` on the given device.
    pub fn allocate<T: 'static>(
        self: &Arc<Self>,
        device_type: &DeviceType,
        count: usize,
    ) -> CoreResult<CrossDeviceBuffer<T>> {
        let devices = self.devices.read().expect("Operation failed");
        let device = devices
            .get(device_type)
            .ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{device_type:?}")))?;

        let size = count * std::mem::size_of::<T>();
        let address = device.allocate(size)?;

        let allocation_id = self.generate_allocation_id();
        let allocation = MemoryAllocation::new(
            allocation_id.clone(),
            device_type.clone(),
            size,
            address,
            TypeId::of::<T>(),
        );

        let mut allocations = self.allocations.write().expect("Operation failed");
        allocations.insert(allocation_id.clone(), allocation);

        Ok(CrossDeviceBuffer::new(
            allocation_id,
            device_type.clone(),
            address,
            count,
            self.clone(),
        ))
    }

    /// Allocates space for `count` values of type `T` on the default device.
    pub fn allocate_default<T: 'static>(
        self: &Arc<Self>,
        count: usize,
    ) -> CoreResult<CrossDeviceBuffer<T>> {
        let default_device = self
            .get_default_device()
            .ok_or_else(|| CrossDeviceError::DeviceNotFound("No default device set".to_string()))?;

        self.allocate(&default_device, count)
    }

    /// Copies a buffer to another device, using peer-to-peer transfer when
    /// available and a host staging buffer otherwise.
    pub fn transfer<T: 'static + Copy>(
        self: &Arc<Self>,
        src_buffer: &CrossDeviceBuffer<T>,
        dst_device: &DeviceType,
    ) -> CoreResult<CrossDeviceBuffer<T>> {
        // Clone the device handles so the registry lock is released before the
        // nested `allocate` calls below (std's RwLock is not reentrant).
        let (src_device, dst_device_obj) = {
            let devices = self.devices.read().expect("Operation failed");
            let src_device = devices
                .get(&src_buffer.device_type)
                .ok_or_else(|| {
                    CrossDeviceError::DeviceNotFound(format!("{0:?}", src_buffer.device_type))
                })?
                .clone();
            let dst_device_obj = devices
                .get(dst_device)
                .ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{dst_device:?}")))?
                .clone();
            (src_device, dst_device_obj)
        };

        let dst_buffer = self.allocate::<T>(dst_device, src_buffer.count)?;

        let size = src_buffer.count * std::mem::size_of::<T>();

        if src_buffer.device_type.supports_p2p_transfer(dst_device) {
            // Direct device-to-device copy.
            src_device.copy_peer(
                src_buffer.address,
                dst_device_obj.as_ref(),
                dst_buffer.address,
                size,
            )?;
        } else {
            // Stage the data through host memory.
            let staging_buffer = self.allocate::<T>(&DeviceType::Cpu, src_buffer.count)?;

            unsafe {
                src_device.copy_to_host(
                    src_buffer.address,
                    staging_buffer.address as *mut u8,
                    size,
                )?;
            }

            unsafe {
                dst_device_obj.copy_from_host(
                    staging_buffer.address as *const u8,
                    dst_buffer.address,
                    size,
                )?;
            }
        }

        Ok(dst_buffer)
    }

    /// Blocks until every registered device has finished its pending work.
    pub fn synchronize_all(&self) -> CoreResult<()> {
        let devices = self.devices.read().expect("Operation failed");
        for device in devices.values() {
            device.synchronize()?;
        }
        Ok(())
    }

    /// Returns a snapshot of allocation counts and memory usage per device.
    pub fn get_memory_statistics(&self) -> MemoryStatistics {
        let allocations = self.allocations.read().expect("Operation failed");
        let devices = self.devices.read().expect("Operation failed");

        let mut stats_by_device = HashMap::new();
        let mut total_allocated = 0;
        let mut total_allocations = 0;

        for allocation in allocations.values() {
            let device_stats =
                stats_by_device
                    .entry(allocation.device.clone())
                    .or_insert(DeviceMemoryStats {
                        device_type: allocation.device.clone(),
                        allocated_bytes: 0,
                        allocation_count: 0,
                        available_bytes: 0,
                        total_bytes: 0,
                    });

            device_stats.allocated_bytes += allocation.size;
            device_stats.allocation_count += 1;
            total_allocated += allocation.size;
            total_allocations += 1;
        }

        // Fill in capacity figures for every registered device, including those
        // with no current allocations.
        for (device_type, device) in devices.iter() {
            let device_stats =
                stats_by_device
                    .entry(device_type.clone())
                    .or_insert(DeviceMemoryStats {
                        device_type: device_type.clone(),
                        allocated_bytes: 0,
                        allocation_count: 0,
                        available_bytes: 0,
                        total_bytes: 0,
                    });

            device_stats.available_bytes = device.available_memory().unwrap_or(0);
            device_stats.total_bytes = device.total_memory().unwrap_or(0);
        }

        MemoryStatistics {
            total_allocated_bytes: total_allocated,
            total_allocations,
            device_stats: stats_by_device.into_values().collect(),
        }
    }

    /// Removes unreferenced allocations older than `maxage` and returns how many
    /// were removed.
    pub fn cleanup_unused_allocations(&self, maxage: std::time::Duration) -> usize {
        let mut allocations = self.allocations.write().expect("Operation failed");
        let now = std::time::Instant::now();
        let mut cleaned = 0;

        allocations.retain(|_, allocation| {
            if allocation.ref_count == 0 && now.duration_since(allocation.last_accessed) > maxage {
                cleaned += 1;
                false
            } else {
                true
            }
        });

        cleaned
    }

    /// Generates a unique hexadecimal allocation identifier.
    fn generate_allocation_id(&self) -> String {
        let counter = {
            let mut counter = self.allocation_counter.lock().expect("Operation failed");
            *counter += 1;
            *counter
        };

        format!("{counter:016x}")
    }

    /// Drops one reference to an allocation, removing it when the count reaches zero.
    pub(crate) fn remove_allocation(&self, allocationid: &str) {
        let mut allocations = self.allocations.write().expect("Operation failed");
        if let Some(allocation) = allocations.get_mut(allocationid) {
            if allocation.remove_ref() == 0 {
                allocations.remove(allocationid);
            }
        }
    }

    /// Marks an allocation as recently used.
    pub(crate) fn touch_allocation(&self, allocationid: &str) {
        let mut allocations = self.allocations.write().expect("Operation failed");
        if let Some(allocation) = allocations.get_mut(allocationid) {
            allocation.touch();
        }
    }
}

impl Default for CrossDeviceMemoryManager {
    fn default() -> Self {
        Self::new()
    }
}

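/// A typed, reference-counted handle to memory that lives on a specific device.
///
/// A host round-trip sketch (not compiled as a doctest; assumes a manager with a
/// registered CPU device, as in the tests below):
///
/// ```ignore
/// let buffer = manager.allocate::<f32>(&DeviceType::Cpu, 3)?;
/// buffer.copy_from_host(&[1.0, 2.0, 3.0])?;
/// let values = buffer.copy_to_host()?;
/// assert_eq!(values, vec![1.0, 2.0, 3.0]);
/// ```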
pub struct CrossDeviceBuffer<T> {
    allocation_id: String,
    device_type: DeviceType,
    address: usize,
    count: usize,
    manager: Arc<CrossDeviceMemoryManager>,
    phantom: std::marker::PhantomData<T>,
}

impl<T> CrossDeviceBuffer<T> {
    fn new(
        allocation_id: String,
        device_type: DeviceType,
        address: usize,
        count: usize,
        manager: Arc<CrossDeviceMemoryManager>,
    ) -> Self {
        Self {
            allocation_id,
            device_type,
            address,
            count,
            manager,
            phantom: std::marker::PhantomData,
        }
    }

    /// Returns the device this buffer lives on.
    pub const fn device_type(&self) -> &DeviceType {
        &self.device_type
    }

    /// Returns the number of elements the buffer can hold.
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` if the buffer holds no elements.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Returns the buffer size in bytes.
    pub fn size_bytes(&self) -> usize {
        self.count * std::mem::size_of::<T>()
    }

    /// Returns the raw device address and marks the allocation as recently used.
    pub fn raw_address(&self) -> usize {
        self.manager.touch_allocation(&self.allocation_id);
        self.address
    }

    /// Copies this buffer to another device, returning the new buffer.
    pub fn to_device(&self, devicetype: &DeviceType) -> CoreResult<CrossDeviceBuffer<T>>
    where
        T: Copy + 'static,
    {
        self.manager.transfer(self, devicetype)
    }

    /// Copies `data` from host memory into this buffer. The slice length must
    /// match the buffer capacity exactly.
    pub fn copy_from_host(&self, data: &[T]) -> CoreResult<()>
    where
        T: Copy,
    {
        if data.len() != self.count {
            return Err(CrossDeviceError::InvalidDeviceType(format!(
                "Data length {} doesn't match buffer capacity {}",
                data.len(),
                self.count
            ))
            .into());
        }

        let devices = self.manager.devices.read().expect("Operation failed");
        let device = devices
            .get(&self.device_type)
            .ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{0:?}", self.device_type)))?;

        unsafe {
            device.copy_from_host(data.as_ptr() as *const u8, self.address, self.size_bytes())?;
        }

        self.manager.touch_allocation(&self.allocation_id);
        Ok(())
    }

    /// Copies the contents of this buffer back into a host `Vec`.
    pub fn copy_to_host(&self) -> CoreResult<Vec<T>>
    where
        T: Copy + Default,
    {
        let mut result = vec![T::default(); self.count];

        let devices = self.manager.devices.read().expect("Operation failed");
        let device = devices
            .get(&self.device_type)
            .ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{0:?}", self.device_type)))?;

        unsafe {
            device.copy_to_host(
                self.address,
                result.as_mut_ptr() as *mut u8,
                self.size_bytes(),
            )?;
        }

        self.manager.touch_allocation(&self.allocation_id);
        Ok(result)
    }
}

impl<T> Clone for CrossDeviceBuffer<T> {
    fn clone(&self) -> Self {
        // Cloning shares the underlying allocation, so bump its reference count.
        {
            let mut allocations = self.manager.allocations.write().expect("Operation failed");
            if let Some(allocation) = allocations.get_mut(&self.allocation_id) {
                allocation.add_ref();
            }
        }

        Self {
            allocation_id: self.allocation_id.clone(),
            device_type: self.device_type.clone(),
            address: self.address,
            count: self.count,
            manager: self.manager.clone(),
            phantom: std::marker::PhantomData,
        }
    }
}

impl<T> Drop for CrossDeviceBuffer<T> {
    fn drop(&mut self) {
        self.manager.remove_allocation(&self.allocation_id);
    }
}

/// Aggregate memory usage across all registered devices.
#[derive(Debug, Clone)]
pub struct MemoryStatistics {
    pub total_allocated_bytes: usize,
    pub total_allocations: usize,
    pub device_stats: Vec<DeviceMemoryStats>,
}

/// Memory usage for a single device.
#[derive(Debug, Clone)]
pub struct DeviceMemoryStats {
    pub device_type: DeviceType,
    pub allocated_bytes: usize,
    pub allocation_count: usize,
    pub available_bytes: usize,
    pub total_bytes: usize,
}

impl DeviceMemoryStats {
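    /// Percentage of the device's total memory that is currently allocated.
    /// For example, 2 GiB allocated out of 8 GiB total reports `25.0`; a device
    /// reporting zero total memory yields `0.0` rather than dividing by zero.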
    pub fn usage_percentage(&self) -> f64 {
        if self.total_bytes == 0 {
            0.0
        } else {
            (self.allocated_bytes as f64 / self.total_bytes as f64) * 100.0
        }
    }
}

static GLOBAL_MANAGER: std::sync::OnceLock<Arc<CrossDeviceMemoryManager>> =
    std::sync::OnceLock::new();

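/// Returns the lazily-initialized global memory manager.
///
/// The first call registers the built-in CPU device; GPU devices can be added
/// later via [`initialize_with_gpu_devices`]. A short sketch (not compiled as a
/// doctest):
///
/// ```ignore
/// let manager = global_manager();
/// assert_eq!(manager.get_default_device(), Some(DeviceType::Cpu));
/// ```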
#[allow(dead_code)]
pub fn global_manager() -> Arc<CrossDeviceMemoryManager> {
    GLOBAL_MANAGER
        .get_or_init(|| {
            let manager = Arc::new(CrossDeviceMemoryManager::new());

            // Always make the CPU available as a fallback device.
            let cpu_device = Arc::new(CpuDevice::new());
            let _ = manager.register_device(cpu_device);

            manager
        })
        .clone()
}

/// Registers the given GPU contexts with the global manager.
#[allow(dead_code)]
pub fn initialize_with_gpu_devices(gpudevices: Vec<Arc<GpuContext>>) -> CoreResult<()> {
    let manager = global_manager();

    for (i, gpu_device) in gpudevices.into_iter().enumerate() {
        // Each context is treated as a CUDA device, indexed by its position.
        let device_type = DeviceType::CudaGpu(i as u32);
        let wrapper = Arc::new(GpuContextWrapper::new(gpu_device, device_type));
        manager.register_device(wrapper)?;
    }

    Ok(())
}

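/// Convenience helpers built on top of the global memory manager.
///
/// A usage sketch (not compiled as a doctest; the values are illustrative):
///
/// ```ignore
/// // Place host data on a chosen device, then copy it to another device.
/// let buffer = utils::create_buffer_with_data(&[1.0f32, 2.0, 3.0], &DeviceType::Cpu)?;
/// let copy = utils::transfer_data(&buffer, &DeviceType::Cpu)?;
/// assert_eq!(copy.len(), buffer.len());
/// ```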
pub mod utils {
    use super::*;

    /// Allocates on the registered device that currently reports the most free memory.
    pub fn allocate_optimal<T: 'static>(count: usize) -> CoreResult<CrossDeviceBuffer<T>> {
        let manager = global_manager();
        let stats = manager.get_memory_statistics();

        // Prefer the device with the most available memory, falling back to the CPU.
        let best_device = stats
            .device_stats
            .iter()
            .max_by_key(|s| s.available_bytes)
            .map(|s| s.device_type.clone())
            .unwrap_or(DeviceType::Cpu);

        manager.allocate(&best_device, count)
    }

    /// Allocates a buffer on `device_type` and fills it with `data`.
    pub fn create_buffer_with_data<T: Copy + 'static>(
        data: &[T],
        device_type: &DeviceType,
    ) -> CoreResult<CrossDeviceBuffer<T>> {
        let manager = global_manager();
        let buffer = manager.allocate(device_type, data.len())?;
        buffer.copy_from_host(data)?;
        Ok(buffer)
    }

    /// Copies `src_buffer` to `dst_device` via the global manager.
    pub fn transfer_data<T: Copy + 'static>(
        src_buffer: &CrossDeviceBuffer<T>,
        dst_device: &DeviceType,
    ) -> CoreResult<CrossDeviceBuffer<T>> {
        let manager = global_manager();
        manager.transfer(src_buffer, dst_device)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_type_creation() {
        let cpu = DeviceType::Cpu;
        let gpu = DeviceType::CudaGpu(0);
        let tpu = DeviceType::Tpu(1);

        assert_eq!(cpu.as_str(), "CPU");
        assert_eq!(gpu.as_str(), "CUDA_GPU");
        assert_eq!(tpu.as_str(), "TPU");

        assert_eq!(cpu.device_id(), 0);
        assert_eq!(gpu.device_id(), 0);
        assert_eq!(tpu.device_id(), 1);
    }

    #[test]
    fn test_device_capabilities() {
        let cpu = DeviceType::Cpu;
        let cuda = DeviceType::CudaGpu(0);
        let rocm = DeviceType::RocmGpu(0);

        assert!(!cpu.supports_unified_memory());
        assert!(cuda.supports_unified_memory());
        assert!(rocm.supports_unified_memory());

        assert!(cuda.supports_p2p_transfer(&DeviceType::CudaGpu(1)));
        assert!(!cuda.supports_p2p_transfer(&DeviceType::RocmGpu(0)));
        assert!(!cpu.supports_p2p_transfer(&DeviceType::CudaGpu(0)));
    }

    #[test]
    fn test_memory_allocation_creation() {
        let allocation = MemoryAllocation::new(
            "test_alloc".to_string(),
            DeviceType::Cpu,
            1024,
            0x1000,
            TypeId::of::<f32>(),
        );

        assert_eq!(allocation.id, "test_alloc");
        assert_eq!(allocation.size, 1024);
        assert_eq!(allocation.address, 0x1000);
        assert_eq!(allocation.ref_count, 1);
    }

    #[test]
    fn test_cpu_device() {
        let cpu = CpuDevice::new();
        assert_eq!(cpu.device_type(), DeviceType::Cpu);

        assert!(cpu.available_memory().is_ok());
        assert!(cpu.total_memory().is_ok());

        assert!(cpu.synchronize().is_ok());
    }

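    // Round-trip check (a minimal sketch using only this module's API): data
    // written into a manager-allocated CPU buffer should come back unchanged.
    #[test]
    fn test_cpu_buffer_roundtrip() {
        let manager = Arc::new(CrossDeviceMemoryManager::new());
        manager
            .register_device(Arc::new(CpuDevice::new()))
            .expect("CPU device registration should succeed");

        let buffer = manager
            .allocate::<f32>(&DeviceType::Cpu, 4)
            .expect("CPU allocation should succeed");
        let data = [1.0f32, 2.0, 3.0, 4.0];

        buffer.copy_from_host(&data).expect("host-to-buffer copy");
        let roundtrip = buffer.copy_to_host().expect("buffer-to-host copy");

        assert_eq!(roundtrip, data);
    }
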
    #[test]
    fn test_cross_device_manager() {
        let manager = CrossDeviceMemoryManager::new();

        let cpu_device = Arc::new(CpuDevice::new());
        assert!(manager.register_device(cpu_device).is_ok());

        assert_eq!(manager.get_default_device(), Some(DeviceType::Cpu));

        let stats = manager.get_memory_statistics();
        assert_eq!(stats.total_allocations, 0);
        assert_eq!(stats.total_allocated_bytes, 0);
    }

    #[test]
    fn test_global_manager() {
        let manager = global_manager();
        assert_eq!(manager.get_default_device(), Some(DeviceType::Cpu));

        let stats = manager.get_memory_statistics();
        assert!(!stats.device_stats.is_empty());
    }

    #[test]
    fn test_memory_statistics() {
        let stats = DeviceMemoryStats {
            device_type: DeviceType::Cpu,
            allocated_bytes: 1024,
            allocation_count: 1,
            available_bytes: 7 * 1024 * 1024 * 1024,
            total_bytes: 8 * 1024 * 1024 * 1024,
        };

        let usage = stats.usage_percentage();
        assert!(usage > 0.0 && usage < 1.0);
    }
}