1use crate::error::{CoreError, CoreResult};
7use crate::gpu::GpuContext;
8use std::any::TypeId;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex, RwLock};
11
/// Errors raised by the cross-device memory management layer.
///
/// Converted into `CoreError::ComputationError` (carrying the `Display` text) by the
/// `From<CrossDeviceError> for CoreError` impl in this module.
#[derive(Debug, thiserror::Error)]
pub enum CrossDeviceError {
    /// No device is registered under the requested type; also used when a device
    /// cannot report its memory information.
    #[error("Device not found: {0}")]
    DeviceNotFound(String),

    /// A device could not allocate memory, or an allocation/handle registry lock was
    /// poisoned.
    #[error("Memory allocation failed on device {device}: {reason}")]
    AllocationFailed { device: String, reason: String },

    /// Copying data host→device, device→host, or device→device failed.
    #[error("Data transfer failed from {from} to {to}: {reason}")]
    TransferFailed {
        from: String,
        to: String,
        reason: String,
    },

    /// Device synchronization failed.
    #[error("Device synchronization failed: {0}")]
    SynchronizationFailed(String),

    /// An invalid device type was supplied.
    // NOTE(review): `CrossDeviceBuffer::copy_from_host` also uses this variant to report a
    // data-length mismatch, which is not a device-type problem.
    #[error("Invalid device type: {0}")]
    InvalidDeviceType(String),

    /// The address/handle passed to `deallocate` is not tracked by that device.
    #[error("Memory allocation not found: {0}")]
    MemoryNotFound(String),
}
43
44impl From<CrossDeviceError> for CoreError {
45 fn from(err: CrossDeviceError) -> Self {
46 CoreError::ComputationError(crate::error::ErrorContext::new(err.to_string()))
47 }
48}
49
/// Identifies a compute device: its backend family and, for accelerators, the
/// ordinal within that family.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// The host CPU.
    Cpu,
    /// CUDA GPU with the given ordinal.
    CudaGpu(u32),
    /// ROCm GPU with the given ordinal.
    RocmGpu(u32),
    /// Intel GPU with the given ordinal.
    IntelGpu(u32),
    /// Metal GPU with the given ordinal.
    MetalGpu(u32),
    /// TPU with the given ordinal.
    Tpu(u32),
    /// OpenCL device with the given ordinal.
    OpenClDevice(u32),
}

impl DeviceType {
    /// Returns the backend family name (without the ordinal), e.g. `"CUDA_GPU"`.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Cpu => "CPU",
            Self::CudaGpu(_) => "CUDA_GPU",
            Self::RocmGpu(_) => "ROCM_GPU",
            Self::IntelGpu(_) => "INTEL_GPU",
            Self::MetalGpu(_) => "METAL_GPU",
            Self::Tpu(_) => "TPU",
            Self::OpenClDevice(_) => "OPENCL",
        }
    }

    /// Returns the device ordinal; the CPU always reports `0`.
    pub fn device_id(&self) -> u32 {
        match *self {
            Self::Cpu => 0,
            Self::CudaGpu(ordinal)
            | Self::RocmGpu(ordinal)
            | Self::IntelGpu(ordinal)
            | Self::MetalGpu(ordinal)
            | Self::Tpu(ordinal)
            | Self::OpenClDevice(ordinal) => ordinal,
        }
    }

    /// Whether this backend family is treated as supporting unified memory
    /// (CUDA and ROCm only).
    pub fn supports_unified_memory(&self) -> bool {
        match self {
            Self::CudaGpu(_) | Self::RocmGpu(_) => true,
            Self::Cpu
            | Self::IntelGpu(_)
            | Self::MetalGpu(_)
            | Self::Tpu(_)
            | Self::OpenClDevice(_) => false,
        }
    }

    /// Whether a direct copy between `self` and `other` is considered possible:
    /// both must be CUDA devices, or both must be ROCm devices.
    pub fn supports_p2p_transfer(&self, other: &DeviceType) -> bool {
        let both_cuda = matches!(self, Self::CudaGpu(_)) && matches!(other, Self::CudaGpu(_));
        let both_rocm = matches!(self, Self::RocmGpu(_)) && matches!(other, Self::RocmGpu(_));
        both_cuda || both_rocm
    }
}
110
111#[derive(Debug, Clone)]
113pub struct MemoryAllocation {
114 pub id: String,
116 pub device: DeviceType,
118 pub size: usize,
120 pub address: usize,
122 pub datatype: TypeId,
124 pub created_at: std::time::Instant,
126 pub last_accessed: std::time::Instant,
128 pub ref_count: usize,
130}
131
132impl MemoryAllocation {
133 pub fn new(
135 allocation_id: String,
136 device: DeviceType,
137 size: usize,
138 address: usize,
139 datatype: TypeId,
140 ) -> Self {
141 let now = std::time::Instant::now();
142 Self {
143 id: allocation_id,
144 device,
145 size,
146 address,
147 datatype,
148 created_at: now,
149 last_accessed: now,
150 ref_count: 1,
151 }
152 }
153
154 pub fn touch(&mut self) {
156 self.last_accessed = std::time::Instant::now();
157 }
158
159 pub fn add_ref(&mut self) {
161 self.ref_count += 1;
162 }
163
164 pub fn remove_ref(&mut self) -> usize {
166 self.ref_count = self.ref_count.saturating_sub(1);
167 self.ref_count
168 }
169}
170
/// Backend-neutral interface over a memory-owning compute device.
///
/// Addresses are opaque `usize` values produced by `allocate`. Their meaning is
/// backend-specific: `CpuDevice` hands out host pointers, while `GpuContextWrapper`
/// hands out registry handles.
pub trait Device: Send + Sync {
    /// Returns the type/ordinal this device identifies itself as.
    fn device_type(&self) -> DeviceType;

    /// Allocates `size` bytes and returns an opaque address for them.
    fn allocate(&self, size: usize) -> CoreResult<usize>;

    /// Releases an allocation previously returned by `allocate`.
    fn deallocate(&self, address: usize) -> CoreResult<()>;

    /// Copies `size` bytes from host memory at `src` into the allocation at `dst`.
    ///
    /// # Safety
    ///
    /// `src` must be valid for reads of `size` bytes. `dst` must be an address returned
    /// by this device's `allocate` covering at least `size` bytes. For host-backed
    /// devices the two ranges must not overlap.
    // NOTE(review): implementations differ on null/zero-size inputs (the GPU wrapper
    // treats them as no-ops) — confirm the intended contract.
    unsafe fn copy_from_host(&self, src: *const u8, dst: usize, size: usize) -> CoreResult<()>;

    /// Copies `size` bytes from the allocation at `src` into host memory at `dst`.
    ///
    /// # Safety
    ///
    /// `dst` must be valid for writes of `size` bytes. `src` must be an address returned
    /// by this device's `allocate` covering at least `size` bytes. For host-backed
    /// devices the two ranges must not overlap.
    unsafe fn copy_to_host(&self, src: usize, dst: *mut u8, size: usize) -> CoreResult<()>;

    /// Copies `size` bytes from this device's allocation `src` directly into `dst`
    /// on `dst_device`, without staging through the host.
    ///
    /// Both implementations in this module currently return `TransferFailed`.
    fn copy_peer(
        &self,
        src: usize,
        dst_device: &dyn Device,
        dst: usize,
        size: usize,
    ) -> CoreResult<()>;

    /// Waits for outstanding work on the device (presumably — both implementations in
    /// this module are no-ops that return `Ok(())`).
    fn synchronize(&self) -> CoreResult<()>;

    /// Bytes currently available on the device, as reported by the backend.
    fn available_memory(&self) -> CoreResult<usize>;

    /// Total memory capacity of the device in bytes, as reported by the backend.
    fn total_memory(&self) -> CoreResult<usize>;
}
220
221pub struct CpuDevice {
223 device_type: DeviceType,
224 allocations: Mutex<HashMap<usize, std::alloc::Layout>>,
230}
231
232impl CpuDevice {
233 pub fn new() -> Self {
235 Self {
236 device_type: DeviceType::Cpu,
237 allocations: Mutex::new(HashMap::new()),
238 }
239 }
240}
241
242impl Default for CpuDevice {
243 fn default() -> Self {
244 Self::new()
245 }
246}
247
248impl Device for CpuDevice {
249 fn device_type(&self) -> DeviceType {
250 self.device_type.clone()
251 }
252
253 fn allocate(&self, size: usize) -> CoreResult<usize> {
254 let layout = std::alloc::Layout::from_size_align(size, 64).map_err(|e| {
255 CrossDeviceError::AllocationFailed {
256 device: "CPU".to_string(),
257 reason: e.to_string(),
258 }
259 })?;
260
261 if size == 0 {
263 return Ok(0);
264 }
265
266 unsafe {
267 let ptr = std::alloc::alloc(layout);
268 if ptr.is_null() {
269 Err(CrossDeviceError::AllocationFailed {
270 device: "CPU".to_string(),
271 reason: "Out of memory".to_string(),
272 }
273 .into())
274 } else {
275 self.allocations
277 .lock()
278 .map_err(|_| CrossDeviceError::AllocationFailed {
279 device: "CPU".to_string(),
280 reason: "Allocation registry lock poisoned".to_string(),
281 })?
282 .insert(ptr as usize, layout);
283 Ok(ptr as usize)
284 }
285 }
286 }
287
288 fn deallocate(&self, address: usize) -> CoreResult<()> {
289 if address == 0 {
292 return Ok(());
293 }
294
295 let layout = {
298 let mut allocations =
299 self.allocations
300 .lock()
301 .map_err(|_| CrossDeviceError::AllocationFailed {
302 device: "CPU".to_string(),
303 reason: "Allocation registry lock poisoned".to_string(),
304 })?;
305 allocations.remove(&address)
306 };
307
308 match layout {
309 Some(layout) => {
310 unsafe {
314 std::alloc::dealloc(address as *mut u8, layout);
315 }
316 Ok(())
317 }
318 None => Err(CrossDeviceError::MemoryNotFound(format!(
319 "CPU allocation at address {address:#x} is not tracked by this device"
320 ))
321 .into()),
322 }
323 }
324
325 unsafe fn copy_from_host(&self, src: *const u8, dst: usize, size: usize) -> CoreResult<()> {
326 std::ptr::copy_nonoverlapping(src, dst as *mut u8, size);
327 Ok(())
328 }
329
330 unsafe fn copy_to_host(&self, src: usize, dst: *mut u8, size: usize) -> CoreResult<()> {
331 std::ptr::copy_nonoverlapping(src as *const u8, dst, size);
332 Ok(())
333 }
334
335 fn copy_peer(
336 &self,
337 src: usize,
338 _dst_device: &dyn Device,
339 _dst: usize,
340 _size: usize,
341 ) -> CoreResult<()> {
342 Err(CrossDeviceError::TransferFailed {
343 from: "CPU".to_string(),
344 to: "unknown".to_string(),
345 reason: "Peer-to-peer not supported for CPU".to_string(),
346 }
347 .into())
348 }
349
350 fn synchronize(&self) -> CoreResult<()> {
351 Ok(())
353 }
354
355 fn available_memory(&self) -> CoreResult<usize> {
356 Ok(8 * 1024 * 1024 * 1024) }
359
360 fn total_memory(&self) -> CoreResult<usize> {
361 Ok(16 * 1024 * 1024 * 1024) }
363}
364
365pub struct GpuContextWrapper {
374 inner: Arc<GpuContext>,
375 device_type: DeviceType,
376 buffers: Mutex<HashMap<usize, crate::gpu::GpuBuffer<u8>>>,
378 next_handle: Mutex<usize>,
380}
381
382impl GpuContextWrapper {
383 pub fn new(gpu_device: Arc<GpuContext>, devicetype: DeviceType) -> Self {
385 Self {
386 inner: gpu_device,
387 device_type: devicetype,
388 buffers: Mutex::new(HashMap::new()),
389 next_handle: Mutex::new(1),
390 }
391 }
392}
393
394impl Device for GpuContextWrapper {
395 fn device_type(&self) -> DeviceType {
396 self.device_type.clone()
397 }
398
399 fn allocate(&self, size: usize) -> CoreResult<usize> {
400 let buffer = self.inner.create_buffer::<u8>(size);
404
405 let handle = {
406 let mut next_handle =
407 self.next_handle
408 .lock()
409 .map_err(|_| CrossDeviceError::AllocationFailed {
410 device: self.device_type.as_str().to_string(),
411 reason: "Handle counter lock poisoned".to_string(),
412 })?;
413 let handle = *next_handle;
414 *next_handle = next_handle.wrapping_add(1).max(1);
415 handle
416 };
417
418 self.buffers
419 .lock()
420 .map_err(|_| CrossDeviceError::AllocationFailed {
421 device: self.device_type.as_str().to_string(),
422 reason: "Buffer registry lock poisoned".to_string(),
423 })?
424 .insert(handle, buffer);
425
426 Ok(handle)
427 }
428
429 fn deallocate(&self, address: usize) -> CoreResult<()> {
430 let removed = self
432 .buffers
433 .lock()
434 .map_err(|_| CrossDeviceError::AllocationFailed {
435 device: self.device_type.as_str().to_string(),
436 reason: "Buffer registry lock poisoned".to_string(),
437 })?
438 .remove(&address);
439
440 if removed.is_none() {
441 return Err(CrossDeviceError::MemoryNotFound(format!(
442 "GPU allocation handle {address} is not tracked by this device"
443 ))
444 .into());
445 }
446 Ok(())
447 }
448
449 unsafe fn copy_from_host(&self, src: *const u8, dst: usize, size: usize) -> CoreResult<()> {
450 if src.is_null() || size == 0 {
451 return Ok(());
452 }
453
454 let buffers = self
455 .buffers
456 .lock()
457 .map_err(|_| CrossDeviceError::TransferFailed {
458 from: "host".to_string(),
459 to: self.device_type.as_str().to_string(),
460 reason: "Buffer registry lock poisoned".to_string(),
461 })?;
462 let buffer = buffers
463 .get(&dst)
464 .ok_or_else(|| CrossDeviceError::TransferFailed {
465 from: "host".to_string(),
466 to: self.device_type.as_str().to_string(),
467 reason: format!("Unknown destination handle {dst}"),
468 })?;
469
470 let host_slice = std::slice::from_raw_parts(src, size);
472 buffer
473 .copy_from_host(host_slice)
474 .map_err(|e| CrossDeviceError::TransferFailed {
475 from: "host".to_string(),
476 to: self.device_type.as_str().to_string(),
477 reason: e.to_string(),
478 })?;
479 Ok(())
480 }
481
482 unsafe fn copy_to_host(&self, src: usize, dst: *mut u8, size: usize) -> CoreResult<()> {
483 if dst.is_null() || size == 0 {
484 return Ok(());
485 }
486
487 let buffers = self
488 .buffers
489 .lock()
490 .map_err(|_| CrossDeviceError::TransferFailed {
491 from: self.device_type.as_str().to_string(),
492 to: "host".to_string(),
493 reason: "Buffer registry lock poisoned".to_string(),
494 })?;
495 let buffer = buffers
496 .get(&src)
497 .ok_or_else(|| CrossDeviceError::TransferFailed {
498 from: self.device_type.as_str().to_string(),
499 to: "host".to_string(),
500 reason: format!("Unknown source handle {src}"),
501 })?;
502
503 let host_slice = std::slice::from_raw_parts_mut(dst, size);
505 buffer
506 .copy_to_host(host_slice)
507 .map_err(|e| CrossDeviceError::TransferFailed {
508 from: self.device_type.as_str().to_string(),
509 to: "host".to_string(),
510 reason: e.to_string(),
511 })?;
512 Ok(())
513 }
514
515 fn copy_peer(
516 &self,
517 _src: usize,
518 _dst_device: &dyn Device,
519 _dst: usize,
520 _size: usize,
521 ) -> CoreResult<()> {
522 Err(CrossDeviceError::TransferFailed {
526 from: self.device_type.as_str().to_string(),
527 to: "peer".to_string(),
528 reason: "Peer-to-peer transfer is not implemented for the generic GPU wrapper"
529 .to_string(),
530 }
531 .into())
532 }
533
534 fn synchronize(&self) -> CoreResult<()> {
535 Ok(())
539 }
540
541 fn available_memory(&self) -> CoreResult<usize> {
542 self.inner.get_available_memory().ok_or_else(|| {
543 CrossDeviceError::DeviceNotFound("GPU memory info unavailable".to_string()).into()
544 })
545 }
546
547 fn total_memory(&self) -> CoreResult<usize> {
548 self.inner.get_total_memory().ok_or_else(|| {
549 CrossDeviceError::DeviceNotFound("GPU memory info unavailable".to_string()).into()
550 })
551 }
552}
553
554pub struct CrossDeviceMemoryManager {
556 devices: RwLock<HashMap<DeviceType, Arc<dyn Device>>>,
557 allocations: RwLock<HashMap<String, MemoryAllocation>>,
558 allocation_counter: Mutex<u64>,
559 default_device: RwLock<Option<DeviceType>>,
560}
561
562impl CrossDeviceMemoryManager {
563 pub fn new() -> Self {
565 Self {
566 devices: RwLock::new(HashMap::new()),
567 allocations: RwLock::new(HashMap::new()),
568 allocation_counter: Mutex::new(0),
569 default_device: RwLock::new(None),
570 }
571 }
572
573 pub fn register_device(&self, device: Arc<dyn Device>) -> CoreResult<()> {
575 let device_type = device.device_type();
576 let mut devices = self.devices.write().expect("Operation failed");
577 devices.insert(device_type.clone(), device);
578
579 let mut default_device = self.default_device.write().expect("Operation failed");
581 if default_device.is_none() {
582 *default_device = Some(device_type);
583 }
584
585 Ok(())
586 }
587
588 pub fn set_default_device(&self, devicetype: DeviceType) -> CoreResult<()> {
590 let devices = self.devices.read().expect("Operation failed");
591 if !devices.contains_key(&devicetype) {
592 return Err(CrossDeviceError::DeviceNotFound(format!("{devicetype:?}")).into());
593 }
594
595 let mut default_device = self.default_device.write().expect("Operation failed");
596 *default_device = Some(devicetype);
597
598 Ok(())
599 }
600
601 pub fn get_default_device(&self) -> Option<DeviceType> {
603 self.default_device
604 .read()
605 .expect("Operation failed")
606 .clone()
607 }
608
609 pub fn allocate<T: 'static>(
611 self: &Arc<Self>,
612 device_type: &DeviceType,
613 count: usize,
614 ) -> CoreResult<CrossDeviceBuffer<T>> {
615 let devices = self.devices.read().expect("Operation failed");
616 let device = devices
617 .get(device_type)
618 .ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{device_type:?}")))?;
619
620 let size = count * std::mem::size_of::<T>();
621 let address = device.allocate(size)?;
622
623 let allocation_id = self.generate_allocation_id();
624 let allocation = MemoryAllocation::new(
625 allocation_id.clone(),
626 device_type.clone(),
627 size,
628 address,
629 TypeId::of::<T>(),
630 );
631
632 let mut allocations = self.allocations.write().expect("Operation failed");
633 allocations.insert(allocation_id.clone(), allocation);
634
635 Ok(CrossDeviceBuffer::new(
636 allocation_id,
637 device_type.clone(),
638 address,
639 count,
640 self.clone(),
641 ))
642 }
643
644 pub fn allocate_default<T: 'static>(
646 self: &Arc<Self>,
647 count: usize,
648 ) -> CoreResult<CrossDeviceBuffer<T>> {
649 let default_device = self
650 .get_default_device()
651 .ok_or_else(|| CrossDeviceError::DeviceNotFound("No default device set".to_string()))?;
652
653 self.allocate(&default_device, count)
654 }
655
656 pub fn transfer<T: 'static + Copy>(
658 self: &Arc<Self>,
659 src_buffer: &CrossDeviceBuffer<T>,
660 dst_device: &DeviceType,
661 ) -> CoreResult<CrossDeviceBuffer<T>> {
662 let devices = self.devices.read().expect("Operation failed");
663 let src_device = devices.get(&src_buffer.device_type).ok_or_else(|| {
664 CrossDeviceError::DeviceNotFound(format!("{0:?}", src_buffer.device_type))
665 })?;
666 let dst_device_obj = devices
667 .get(dst_device)
668 .ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{dst_device:?}")))?;
669
670 let dst_buffer = self.allocate::<T>(dst_device, src_buffer.count)?;
672
673 let size = src_buffer.count * std::mem::size_of::<T>();
674
675 if src_buffer.device_type.supports_p2p_transfer(dst_device) {
677 src_device.copy_peer(
678 src_buffer.address,
679 dst_device_obj.as_ref(),
680 dst_buffer.address,
681 size,
682 )?;
683 } else {
684 let staging_buffer = self.allocate::<T>(&DeviceType::Cpu, src_buffer.count)?;
686
687 unsafe {
689 src_device.copy_to_host(
690 src_buffer.address,
691 staging_buffer.address as *mut u8,
692 size,
693 )?;
694 }
695
696 unsafe {
698 dst_device_obj.copy_from_host(
699 staging_buffer.address as *const u8,
700 dst_buffer.address,
701 size,
702 )?;
703 }
704 }
705
706 Ok(dst_buffer)
707 }
708
709 pub fn synchronize_all(&self) -> CoreResult<()> {
711 let devices = self.devices.read().expect("Operation failed");
712 for device in devices.values() {
713 device.synchronize()?;
714 }
715 Ok(())
716 }
717
718 pub fn get_memory_statistics(&self) -> MemoryStatistics {
720 let allocations = self.allocations.read().expect("Operation failed");
721 let devices = self.devices.read().expect("Operation failed");
722
723 let mut stats_by_device = HashMap::new();
724 let mut total_allocated = 0;
725 let mut total_allocations = 0;
726
727 for allocation in allocations.values() {
728 let device_stats =
729 stats_by_device
730 .entry(allocation.device.clone())
731 .or_insert(DeviceMemoryStats {
732 device_type: allocation.device.clone(),
733 allocated_bytes: 0,
734 allocation_count: 0,
735 available_bytes: 0,
736 total_bytes: 0,
737 });
738
739 device_stats.allocated_bytes += allocation.size;
740 device_stats.allocation_count += 1;
741 total_allocated += allocation.size;
742 total_allocations += 1;
743 }
744
745 for (device_type, device) in devices.iter() {
747 let device_stats =
748 stats_by_device
749 .entry(device_type.clone())
750 .or_insert(DeviceMemoryStats {
751 device_type: device_type.clone(),
752 allocated_bytes: 0,
753 allocation_count: 0,
754 available_bytes: 0,
755 total_bytes: 0,
756 });
757
758 device_stats.available_bytes = device.available_memory().unwrap_or(0);
759 device_stats.total_bytes = device.total_memory().unwrap_or(0);
760 }
761
762 MemoryStatistics {
763 total_allocated_bytes: total_allocated,
764 total_allocations,
765 device_stats: stats_by_device.into_values().collect(),
766 }
767 }
768
769 pub fn cleanup_unused_allocations(&self, maxage: std::time::Duration) -> usize {
771 let mut allocations = self.allocations.write().expect("Operation failed");
772 let now = std::time::Instant::now();
773 let mut cleaned = 0;
774
775 allocations.retain(|_, allocation| {
776 if allocation.ref_count == 0 && now.duration_since(allocation.last_accessed) > maxage {
777 cleaned += 1;
779 false
780 } else {
781 true
782 }
783 });
784
785 cleaned
786 }
787
788 fn generate_allocation_id(&self) -> String {
790 let counter = {
791 let mut counter = self.allocation_counter.lock().expect("Operation failed");
792 *counter += 1;
793 *counter
794 };
795
796 format!("{counter:016x}")
797 }
798
799 pub(crate) fn remove_allocation(&self, allocationid: &str) {
801 let mut allocations = self.allocations.write().expect("Operation failed");
802 if let Some(allocation) = allocations.get_mut(allocationid) {
803 if allocation.remove_ref() == 0 {
804 allocations.remove(allocationid);
805 }
806 }
807 }
808
809 pub(crate) fn touch_allocation(&self, allocationid: &str) {
811 let mut allocations = self.allocations.write().expect("Operation failed");
812 if let Some(allocation) = allocations.get_mut(allocationid) {
813 allocation.touch();
814 }
815 }
816}
817
818impl Default for CrossDeviceMemoryManager {
819 fn default() -> Self {
820 Self::new()
821 }
822}
823
824pub struct CrossDeviceBuffer<T> {
826 allocation_id: String,
827 device_type: DeviceType,
828 address: usize,
829 count: usize,
830 manager: Arc<CrossDeviceMemoryManager>,
831 phantom: std::marker::PhantomData<T>,
832}
833
834impl<T> CrossDeviceBuffer<T> {
835 fn new(
837 allocation_id: String,
838 device_type: DeviceType,
839 address: usize,
840 count: usize,
841 manager: Arc<CrossDeviceMemoryManager>,
842 ) -> Self {
843 Self {
844 allocation_id,
845 device_type,
846 address,
847 count,
848 manager,
849 phantom: std::marker::PhantomData,
850 }
851 }
852
853 pub const fn device_type(&self) -> &DeviceType {
855 &self.device_type
856 }
857
858 pub fn len(&self) -> usize {
860 self.count
861 }
862
863 pub fn is_empty(&self) -> bool {
865 self.count == 0
866 }
867
868 pub fn size_bytes(&self) -> usize {
870 self.count * std::mem::size_of::<T>()
871 }
872
873 pub fn raw_address(&self) -> usize {
875 self.manager.touch_allocation(&self.allocation_id);
876 self.address
877 }
878
879 pub fn to_device(&self, devicetype: &DeviceType) -> CoreResult<CrossDeviceBuffer<T>>
881 where
882 T: Copy + 'static,
883 {
884 self.manager.transfer(self, devicetype)
885 }
886
887 pub fn copy_from_host(&self, data: &[T]) -> CoreResult<()>
889 where
890 T: Copy,
891 {
892 if data.len() != self.count {
893 return Err(CrossDeviceError::InvalidDeviceType(format!(
894 "Data length {} doesn't match buffer capacity {}",
895 data.len(),
896 self.count
897 ))
898 .into());
899 }
900
901 let devices = self.manager.devices.read().expect("Operation failed");
902 let device = devices
903 .get(&self.device_type)
904 .ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{0:?}", self.device_type)))?;
905
906 unsafe {
907 device.copy_from_host(data.as_ptr() as *const u8, self.address, self.size_bytes())?;
908 }
909
910 self.manager.touch_allocation(&self.allocation_id);
911 Ok(())
912 }
913
914 pub fn copy_to_host(&self) -> CoreResult<Vec<T>>
916 where
917 T: Copy + Default,
918 {
919 let mut result = vec![T::default(); self.count];
920
921 let devices = self.manager.devices.read().expect("Operation failed");
922 let device = devices
923 .get(&self.device_type)
924 .ok_or_else(|| CrossDeviceError::DeviceNotFound(format!("{0:?}", self.device_type)))?;
925
926 unsafe {
927 device.copy_to_host(
928 self.address,
929 result.as_mut_ptr() as *mut u8,
930 self.size_bytes(),
931 )?;
932 }
933
934 self.manager.touch_allocation(&self.allocation_id);
935 Ok(result)
936 }
937}
938
939impl<T> Clone for CrossDeviceBuffer<T> {
940 fn clone(&self) -> Self {
941 {
943 let mut allocations = self.manager.allocations.write().expect("Operation failed");
944 if let Some(allocation) = allocations.get_mut(&self.allocation_id) {
945 allocation.add_ref();
946 }
947 }
948
949 Self {
950 allocation_id: self.allocation_id.clone(),
951 device_type: self.device_type.clone(),
952 address: self.address,
953 count: self.count,
954 manager: self.manager.clone(),
955 phantom: std::marker::PhantomData,
956 }
957 }
958}
959
960impl<T> Drop for CrossDeviceBuffer<T> {
961 fn drop(&mut self) {
962 self.manager.remove_allocation(&self.allocation_id);
963 }
964}
965
966#[derive(Debug, Clone)]
968pub struct MemoryStatistics {
969 pub total_allocated_bytes: usize,
971 pub total_allocations: usize,
973 pub device_stats: Vec<DeviceMemoryStats>,
975}
976
977#[derive(Debug, Clone)]
979pub struct DeviceMemoryStats {
980 pub device_type: DeviceType,
982 pub allocated_bytes: usize,
984 pub allocation_count: usize,
986 pub available_bytes: usize,
988 pub total_bytes: usize,
990}
991
992impl DeviceMemoryStats {
993 pub fn usage_percentage(&self) -> f64 {
995 if self.total_bytes == 0 {
996 0.0
997 } else {
998 (self.allocated_bytes as f64 / self.total_bytes as f64) * 100.0
999 }
1000 }
1001}
1002
1003static GLOBAL_MANAGER: std::sync::OnceLock<Arc<CrossDeviceMemoryManager>> =
1005 std::sync::OnceLock::new();
1006
1007#[allow(dead_code)]
1009pub fn global_manager() -> Arc<CrossDeviceMemoryManager> {
1010 GLOBAL_MANAGER
1011 .get_or_init(|| {
1012 let manager = Arc::new(CrossDeviceMemoryManager::new());
1013
1014 let cpu_device = Arc::new(CpuDevice::new());
1016 let _ = manager.register_device(cpu_device);
1017
1018 manager
1019 })
1020 .clone()
1021}
1022
1023#[allow(dead_code)]
1025pub fn initialize_with_gpu_devices(gpudevices: Vec<Arc<GpuContext>>) -> CoreResult<()> {
1026 let manager = global_manager();
1027
1028 for (i, gpu_device) in gpudevices.into_iter().enumerate() {
1029 let device_type = DeviceType::CudaGpu(i as u32); let wrapper = Arc::new(GpuContextWrapper::new(gpu_device, device_type));
1031 manager.register_device(wrapper)?;
1032 }
1033
1034 Ok(())
1035}
1036
1037pub mod utils {
1039 use super::*;
1040
1041 pub fn allocate_optimal<T: 'static>(count: usize) -> CoreResult<CrossDeviceBuffer<T>> {
1043 let manager = global_manager();
1044 let stats = manager.get_memory_statistics();
1045
1046 let best_device = stats
1048 .device_stats
1049 .iter()
1050 .max_by_key(|s| s.available_bytes)
1051 .map(|s| s.device_type.clone())
1052 .unwrap_or(DeviceType::Cpu);
1053
1054 manager.allocate(&best_device, count)
1055 }
1056
1057 pub fn create_buffer_with_data<T: Copy + 'static>(
1059 data: &[T],
1060 device_type: &DeviceType,
1061 ) -> CoreResult<CrossDeviceBuffer<T>> {
1062 let manager = global_manager();
1063 let buffer = manager.allocate(device_type, data.len())?;
1064 buffer.copy_from_host(data)?;
1065 Ok(buffer)
1066 }
1067
1068 pub fn transfer_data<T: Copy + 'static>(
1070 src_buffer: &CrossDeviceBuffer<T>,
1071 dst_device: &DeviceType,
1072 ) -> CoreResult<CrossDeviceBuffer<T>> {
1073 let manager = global_manager();
1074 manager.transfer(src_buffer, dst_device)
1075 }
1076}
1077
#[cfg(test)]
mod tests {
    use super::*;

    /// Each variant reports its family name and ordinal.
    #[test]
    fn test_device_type_creation() {
        let expectations = [
            (DeviceType::Cpu, "CPU", 0),
            (DeviceType::CudaGpu(0), "CUDA_GPU", 0),
            (DeviceType::Tpu(1), "TPU", 1),
        ];

        for (device, name, ordinal) in expectations {
            assert_eq!(device.as_str(), name);
            assert_eq!(device.device_id(), ordinal);
        }
    }

    /// Unified memory and P2P support follow the backend family.
    #[test]
    fn test_device_capabilities() {
        let cpu = DeviceType::Cpu;
        let cuda = DeviceType::CudaGpu(0);
        let rocm = DeviceType::RocmGpu(0);

        assert!(!cpu.supports_unified_memory());
        assert!(cuda.supports_unified_memory());
        assert!(rocm.supports_unified_memory());

        assert!(cuda.supports_p2p_transfer(&DeviceType::CudaGpu(1)));
        assert!(!cuda.supports_p2p_transfer(&rocm));
        assert!(!cpu.supports_p2p_transfer(&cuda));
    }

    /// A new record stores its inputs and starts with a single reference.
    #[test]
    fn test_memory_allocation_creation() {
        let record = MemoryAllocation::new(
            "test_alloc".to_string(),
            DeviceType::Cpu,
            1024,
            0x1000,
            TypeId::of::<f32>(),
        );

        assert_eq!(record.id, "test_alloc");
        assert_eq!(record.size, 1024);
        assert_eq!(record.address, 0x1000);
        assert_eq!(record.ref_count, 1);
    }

    /// The CPU device identifies itself and answers basic queries.
    #[test]
    fn test_cpu_device() {
        let device = CpuDevice::new();

        assert_eq!(device.device_type(), DeviceType::Cpu);
        assert!(device.available_memory().is_ok());
        assert!(device.total_memory().is_ok());
        assert!(device.synchronize().is_ok());
    }

    /// Registering the first device makes it the default; no allocations exist yet.
    #[test]
    fn test_cross_device_manager() {
        let manager = CrossDeviceMemoryManager::new();
        assert!(manager.register_device(Arc::new(CpuDevice::new())).is_ok());
        assert_eq!(manager.get_default_device(), Some(DeviceType::Cpu));

        let snapshot = manager.get_memory_statistics();
        assert_eq!(snapshot.total_allocations, 0);
        assert_eq!(snapshot.total_allocated_bytes, 0);
    }

    /// The global manager comes pre-populated with the CPU device.
    #[test]
    fn test_global_manager() {
        let manager = global_manager();
        assert_eq!(manager.get_default_device(), Some(DeviceType::Cpu));
        assert!(!manager.get_memory_statistics().device_stats.is_empty());
    }

    /// 1 KiB out of 8 GiB is a tiny but non-zero percentage.
    #[test]
    fn test_memory_statistics() {
        let stats = DeviceMemoryStats {
            device_type: DeviceType::Cpu,
            allocated_bytes: 1024,
            allocation_count: 1,
            available_bytes: 7 * 1024 * 1024 * 1024,
            total_bytes: 8 * 1024 * 1024 * 1024,
        };

        let usage = stats.usage_percentage();
        assert!(usage > 0.0 && usage < 1.0);
    }
}