1use crate::device::{Device, DeviceType};
7use std::marker::PhantomData;
8
/// Zero-sized marker trait that lifts a concrete [`DeviceType`] into the
/// type system, so device placement can be checked statically.
pub trait PhantomDevice: 'static + std::fmt::Debug + Send + Sync {
    /// The runtime device type this marker stands for.
    const DEVICE_TYPE: DeviceType;

    /// Returns `Self::DEVICE_TYPE` as a value (convenience for generic code).
    fn device_type() -> DeviceType {
        Self::DEVICE_TYPE
    }

    /// True when `Self` and `Other` map to the exact same `DeviceType`
    /// (including the device index for indexed backends).
    fn is_compatible<Other: PhantomDevice>() -> bool {
        Self::DEVICE_TYPE == Other::DEVICE_TYPE
    }

    /// Human-readable backend name (e.g. "CPU", "CUDA").
    fn device_name() -> &'static str;

    /// True for every marker except the CPU one.
    fn requires_gpu() -> bool {
        !matches!(Self::DEVICE_TYPE, DeviceType::Cpu)
    }

    /// True only for CUDA markers; all other backends report no
    /// peer-to-peer support here.
    fn supports_p2p() -> bool {
        matches!(Self::DEVICE_TYPE, DeviceType::Cuda(_))
    }
}
40
/// Phantom marker for the host CPU device.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PhantomCpu;

impl PhantomDevice for PhantomCpu {
    const DEVICE_TYPE: DeviceType = DeviceType::Cpu;

    fn device_name() -> &'static str {
        "CPU"
    }
}
52
/// Phantom marker for the CUDA device with the given `INDEX`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PhantomCuda<const INDEX: usize>;

impl<const INDEX: usize> PhantomDevice for PhantomCuda<INDEX> {
    const DEVICE_TYPE: DeviceType = DeviceType::Cuda(INDEX);

    // Note: the name does not include INDEX; all CUDA markers report "CUDA".
    fn device_name() -> &'static str {
        "CUDA"
    }
}
64
/// Phantom marker for the Metal device with the given `INDEX`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PhantomMetal<const INDEX: usize>;

impl<const INDEX: usize> PhantomDevice for PhantomMetal<INDEX> {
    const DEVICE_TYPE: DeviceType = DeviceType::Metal(INDEX);

    // Note: the name does not include INDEX; all Metal markers report "Metal".
    fn device_name() -> &'static str {
        "Metal"
    }
}
76
/// Phantom marker for the WebGPU device with the given `INDEX`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PhantomWgpu<const INDEX: usize>;

impl<const INDEX: usize> PhantomDevice for PhantomWgpu<INDEX> {
    const DEVICE_TYPE: DeviceType = DeviceType::Wgpu(INDEX);

    // Note: the name does not include INDEX; all WebGPU markers report "WebGPU".
    fn device_name() -> &'static str {
        "WebGPU"
    }
}
88
/// Runtime device handle whose device type is also tracked statically by
/// the phantom parameter `P`.
#[derive(Debug)]
pub struct DeviceHandle<P: PhantomDevice> {
    // Boxed trait object actually backing this handle.
    device: Box<dyn Device>,
    // Zero-sized marker tying the handle to `P`.
    _phantom: PhantomData<P>,
}
114
115impl<P: PhantomDevice> DeviceHandle<P> {
116 pub fn new(device: Box<dyn Device>) -> Result<Self, crate::error::TorshError> {
118 if device.device_type() != P::DEVICE_TYPE {
119 return Err(crate::error::TorshError::InvalidArgument(format!(
120 "Device type mismatch: expected {:?}, got {:?}",
121 P::DEVICE_TYPE,
122 device.device_type()
123 )));
124 }
125
126 Ok(Self {
127 device,
128 _phantom: PhantomData,
129 })
130 }
131
132 pub unsafe fn new_unchecked(device: Box<dyn Device>) -> Self {
137 Self {
138 device,
139 _phantom: PhantomData,
140 }
141 }
142
143 pub fn device(&self) -> &dyn Device {
145 self.device.as_ref()
146 }
147
148 pub fn device_mut(&mut self) -> &mut dyn Device {
150 self.device.as_mut()
151 }
152
153 pub fn phantom_device_type() -> DeviceType {
155 P::DEVICE_TYPE
156 }
157
158 pub const fn is_cpu() -> bool {
160 matches!(P::DEVICE_TYPE, DeviceType::Cpu)
161 }
162
163 pub const fn is_gpu() -> bool {
165 !matches!(P::DEVICE_TYPE, DeviceType::Cpu)
166 }
167
168 pub const fn is_cuda() -> bool {
170 matches!(P::DEVICE_TYPE, DeviceType::Cuda(_))
171 }
172
173 pub const fn is_metal() -> bool {
175 matches!(P::DEVICE_TYPE, DeviceType::Metal(_))
176 }
177
178 pub const fn is_wgpu() -> bool {
180 matches!(P::DEVICE_TYPE, DeviceType::Wgpu(_))
181 }
182
183 pub fn cast<Q: PhantomDevice>(
185 self,
186 ) -> Result<DeviceHandle<Q>, (Self, crate::error::TorshError)> {
187 if self.device.device_type() != Q::DEVICE_TYPE {
188 let error = crate::error::TorshError::InvalidArgument(format!(
189 "Cannot cast device from {:?} to {:?}",
190 P::DEVICE_TYPE,
191 Q::DEVICE_TYPE
192 ));
193 return Err((self, error));
194 }
195
196 Ok(DeviceHandle {
197 device: self.device,
198 _phantom: PhantomData,
199 })
200 }
201
202 pub unsafe fn cast_unchecked<Q: PhantomDevice>(self) -> DeviceHandle<Q> {
207 DeviceHandle {
208 device: self.device,
209 _phantom: PhantomData,
210 }
211 }
212
213 pub fn into_device(self) -> Box<dyn Device> {
215 self.device
216 }
217}
218
219impl<P: PhantomDevice> Clone for DeviceHandle<P> {
220 fn clone(&self) -> Self {
221 let cloned_device = self.device.clone_device().expect("Failed to clone device");
222
223 Self {
224 device: cloned_device,
225 _phantom: PhantomData,
226 }
227 }
228}
229
/// Compile-time compatibility relation between two phantom device types.
pub trait DeviceCompatible<Other> {
    /// Whether the two device types are compatible.
    const COMPATIBLE: bool;

    /// Human-readable explanation of the compatibility verdict.
    fn compatibility_info() -> &'static str;
}
241
/// Reflexive impl: every phantom device is compatible with itself.
impl<P: PhantomDevice> DeviceCompatible<P> for P {
    const COMPATIBLE: bool = true;

    fn compatibility_info() -> &'static str {
        "Same device type - always compatible"
    }
}
249
/// An operation that can run against a device handle typed by `P`.
pub trait DeviceOperation<P: PhantomDevice> {
    /// Value produced by a successful execution.
    type Output;

    /// Requirement set this operation imposes on the device.
    type Requirements: DeviceRequirements;

    /// Runs the operation against `device`.
    fn execute(device: &DeviceHandle<P>) -> Result<Self::Output, crate::error::TorshError>;

    /// Whether the operation is supported, as reported by the associated
    /// `Requirements` type.
    const SUPPORTED: bool = Self::Requirements::SATISFIED_BY_DEVICE;
}
269
/// Compile-time description of what a device must provide for an operation.
pub trait DeviceRequirements {
    /// Whether the requirement is satisfied.
    ///
    /// NOTE(review): this constant has no access to the actual device
    /// type, so implementors hard-code a value — confirm the design.
    const SATISFIED_BY_DEVICE: bool;

    /// Human-readable description of the requirement.
    fn description() -> &'static str;
}
278
279#[derive(Debug, Clone, Copy)]
281pub struct RequiresGpu;
282
283impl DeviceRequirements for RequiresGpu {
284 const SATISFIED_BY_DEVICE: bool = false; fn description() -> &'static str {
287 "Requires GPU device"
288 }
289}
290
291#[derive(Debug, Clone, Copy)]
293pub struct RequiresCpu;
294
295impl DeviceRequirements for RequiresCpu {
296 const SATISFIED_BY_DEVICE: bool = false; fn description() -> &'static str {
299 "Requires CPU device"
300 }
301}
302
303#[derive(Debug, Clone, Copy)]
305pub struct RequiresCuda;
306
307impl DeviceRequirements for RequiresCuda {
308 const SATISFIED_BY_DEVICE: bool = false; fn description() -> &'static str {
311 "Requires CUDA device"
312 }
313}
314
/// Requirement marker: any device is acceptable.
#[derive(Debug, Clone, Copy)]
pub struct NoRequirements;

impl DeviceRequirements for NoRequirements {
    // Trivially satisfied by every device.
    const SATISFIED_BY_DEVICE: bool = true;

    fn description() -> &'static str {
        "No specific device requirements"
    }
}
326
327#[derive(Debug)]
329pub struct SameDevice<P1: PhantomDevice, P2: PhantomDevice> {
330 _phantom: PhantomData<(P1, P2)>,
331}
332
333impl<P1: PhantomDevice, P2: PhantomDevice> SameDevice<P1, P2> {
334 pub fn is_satisfied() -> bool {
336 match (P1::DEVICE_TYPE, P2::DEVICE_TYPE) {
337 (DeviceType::Cpu, DeviceType::Cpu) => true,
338 (DeviceType::Cuda(a), DeviceType::Cuda(b)) => a == b,
339 (DeviceType::Metal(a), DeviceType::Metal(b)) => a == b,
340 (DeviceType::Wgpu(a), DeviceType::Wgpu(b)) => a == b,
341 _ => false,
342 }
343 }
344}
345
346#[derive(Debug)]
348pub struct TransferCompatible<P1: PhantomDevice, P2: PhantomDevice> {
349 _phantom: PhantomData<(P1, P2)>,
350}
351
352impl<P1: PhantomDevice, P2: PhantomDevice> TransferCompatible<P1, P2> {
353 pub const SUPPORTED: bool = true; pub fn transfer_cost() -> u32 {
358 match (P1::DEVICE_TYPE, P2::DEVICE_TYPE) {
359 (DeviceType::Cpu, DeviceType::Cpu) => 0,
360 (DeviceType::Cuda(a), DeviceType::Cuda(b)) if a == b => 0,
361 (DeviceType::Metal(a), DeviceType::Metal(b)) if a == b => 0,
362 (DeviceType::Wgpu(a), DeviceType::Wgpu(b)) if a == b => 0,
363 (DeviceType::Cpu, DeviceType::Cuda(_)) => 100,
364 (DeviceType::Cuda(_), DeviceType::Cpu) => 100,
365 (DeviceType::Cpu, DeviceType::Metal(_)) => 80,
366 (DeviceType::Metal(_), DeviceType::Cpu) => 80,
367 _ => 200, }
369 }
370}
371
372#[derive(Debug)]
374pub struct PhantomDeviceManager<P: PhantomDevice> {
375 handles: Vec<DeviceHandle<P>>,
376 _phantom: PhantomData<P>,
377}
378
379impl<P: PhantomDevice> PhantomDeviceManager<P> {
380 pub fn new() -> Self {
382 Self {
383 handles: Vec::new(),
384 _phantom: PhantomData,
385 }
386 }
387
388 pub fn add_device(&mut self, handle: DeviceHandle<P>) {
390 self.handles.push(handle);
391 }
392
393 pub fn device_count(&self) -> usize {
395 self.handles.len()
396 }
397
398 pub fn get_device(&self, index: usize) -> Option<&DeviceHandle<P>> {
400 self.handles.get(index)
401 }
402
403 pub fn get_device_mut(&mut self, index: usize) -> Option<&mut DeviceHandle<P>> {
405 self.handles.get_mut(index)
406 }
407
408 pub fn remove_device(&mut self, index: usize) -> Option<DeviceHandle<P>> {
410 if index < self.handles.len() {
411 Some(self.handles.remove(index))
412 } else {
413 None
414 }
415 }
416
417 pub fn devices(&self) -> &[DeviceHandle<P>] {
419 &self.handles
420 }
421
422 pub fn clear(&mut self) {
424 self.handles.clear();
425 }
426
427 pub fn execute_on_all<Op>(
429 &self,
430 _operation: Op,
431 ) -> Vec<Result<Op::Output, crate::error::TorshError>>
432 where
433 Op: DeviceOperation<P> + Clone,
434 {
435 self.handles
436 .iter()
437 .map(|handle| Op::execute(handle))
438 .collect()
439 }
440}
441
impl<P: PhantomDevice> Default for PhantomDeviceManager<P> {
    /// Equivalent to [`PhantomDeviceManager::new`]: an empty manager.
    fn default() -> Self {
        Self::new()
    }
}
447
448pub mod utils {
450 use super::*;
451
452 pub fn create_phantom_handle<P: PhantomDevice>(
454 device: Box<dyn Device>,
455 ) -> Result<DeviceHandle<P>, crate::error::TorshError> {
456 DeviceHandle::<P>::new(device)
457 }
458
459 pub fn check_phantom_compatibility<P1: PhantomDevice, P2: PhantomDevice>() -> bool {
461 P1::DEVICE_TYPE == P2::DEVICE_TYPE
462 }
463
464 pub fn phantom_transfer_cost<P1: PhantomDevice, P2: PhantomDevice>() -> u32 {
466 TransferCompatible::<P1, P2>::transfer_cost()
467 }
468
469 pub fn create_phantom_manager<P: PhantomDevice>() -> PhantomDeviceManager<P> {
471 PhantomDeviceManager::new()
472 }
473
474 pub fn verify_operation_support<P: PhantomDevice, Op: DeviceOperation<P>>() -> bool {
476 Op::SUPPORTED
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::core::Device;
    use std::any::Any;

    /// Minimal `Device` implementation that only records its type; used to
    /// exercise the phantom-typed APIs without real hardware.
    #[derive(Debug)]
    struct MockDevice {
        device_type: DeviceType,
    }

    impl MockDevice {
        fn new(device_type: DeviceType) -> Self {
            Self { device_type }
        }
    }

    impl Device for MockDevice {
        fn device_type(&self) -> DeviceType {
            self.device_type
        }

        fn name(&self) -> &str {
            "Mock Device"
        }

        fn is_available(&self) -> Result<bool, crate::error::TorshError> {
            Ok(true)
        }

        fn capabilities(
            &self,
        ) -> Result<crate::device::DeviceCapabilities, crate::error::TorshError> {
            crate::device::DeviceCapabilities::detect(self.device_type)
        }

        fn synchronize(&self) -> Result<(), crate::error::TorshError> {
            Ok(())
        }

        fn reset(&self) -> Result<(), crate::error::TorshError> {
            Ok(())
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn as_any_mut(&mut self) -> &mut dyn Any {
            self
        }

        fn clone_device(&self) -> Result<Box<dyn Device>, crate::error::TorshError> {
            Ok(Box::new(MockDevice::new(self.device_type)))
        }
    }

    // Each marker reports its declared DeviceType and name.
    #[test]
    fn test_phantom_device_markers() {
        assert_eq!(PhantomCpu::device_type(), DeviceType::Cpu);
        assert_eq!(PhantomCuda::<0>::device_type(), DeviceType::Cuda(0));
        assert_eq!(PhantomMetal::<1>::device_type(), DeviceType::Metal(1));
        assert_eq!(PhantomWgpu::<2>::device_type(), DeviceType::Wgpu(2));

        assert_eq!(PhantomCpu::device_name(), "CPU");
        assert_eq!(PhantomCuda::<0>::device_name(), "CUDA");
        assert_eq!(PhantomMetal::<0>::device_name(), "Metal");
        assert_eq!(PhantomWgpu::<0>::device_name(), "WebGPU");
    }

    // requires_gpu is false only for CPU; supports_p2p is true only for CUDA.
    #[test]
    fn test_phantom_device_properties() {
        assert!(!PhantomCpu::requires_gpu());
        assert!(PhantomCuda::<0>::requires_gpu());
        assert!(PhantomMetal::<0>::requires_gpu());
        assert!(PhantomWgpu::<0>::requires_gpu());

        assert!(!PhantomCpu::supports_p2p());
        assert!(PhantomCuda::<0>::supports_p2p());
        assert!(!PhantomMetal::<0>::supports_p2p());
        assert!(!PhantomWgpu::<0>::supports_p2p());
    }

    // A matching device type constructs successfully and the static
    // predicates agree with the phantom parameter.
    #[test]
    fn test_device_handle() {
        let mock_device = Box::new(MockDevice::new(DeviceType::Cpu));
        let _handle =
            DeviceHandle::<PhantomCpu>::new(mock_device).expect("DeviceHandle::new should succeed");

        assert_eq!(
            DeviceHandle::<PhantomCpu>::phantom_device_type(),
            DeviceType::Cpu
        );
        assert!(DeviceHandle::<PhantomCpu>::is_cpu());
        assert!(!DeviceHandle::<PhantomCpu>::is_gpu());
        assert!(!DeviceHandle::<PhantomCpu>::is_cuda());
    }

    // A CUDA-typed device must not construct a CPU-typed handle.
    #[test]
    fn test_device_handle_type_mismatch() {
        let mock_device = Box::new(MockDevice::new(DeviceType::Cuda(0)));
        let result = DeviceHandle::<PhantomCpu>::new(mock_device);
        assert!(result.is_err());
    }

    // Compatibility requires both backend AND index to match.
    #[test]
    fn test_device_compatibility() {
        assert!(PhantomCpu::is_compatible::<PhantomCpu>());
        assert!(!PhantomCpu::is_compatible::<PhantomCuda<0>>());
        assert!(PhantomCuda::<0>::is_compatible::<PhantomCuda<0>>());
        assert!(!PhantomCuda::<0>::is_compatible::<PhantomCuda<1>>());
    }

    // Basic add/get/remove lifecycle of the manager.
    #[test]
    fn test_phantom_device_manager() {
        let mut manager = PhantomDeviceManager::<PhantomCpu>::new();
        assert_eq!(manager.device_count(), 0);

        let mock_device = Box::new(MockDevice::new(DeviceType::Cpu));
        let handle =
            DeviceHandle::<PhantomCpu>::new(mock_device).expect("DeviceHandle::new should succeed");
        manager.add_device(handle);

        assert_eq!(manager.device_count(), 1);
        assert!(manager.get_device(0).is_some());
        assert!(manager.get_device(1).is_none());

        let removed = manager.remove_device(0);
        assert!(removed.is_some());
        assert_eq!(manager.device_count(), 0);
    }

    // Pins the transfer-cost heuristic values.
    #[test]
    fn test_transfer_cost_constants() {
        assert_eq!(
            TransferCompatible::<PhantomCpu, PhantomCpu>::transfer_cost(),
            0
        );
        assert_eq!(
            TransferCompatible::<PhantomCpu, PhantomCuda<0>>::transfer_cost(),
            100
        );
        assert_eq!(
            TransferCompatible::<PhantomCpu, PhantomMetal<0>>::transfer_cost(),
            80
        );
    }

    // The utils free functions mirror the underlying APIs.
    #[test]
    fn test_utils_functions() {
        assert!(utils::check_phantom_compatibility::<PhantomCpu, PhantomCpu>());
        assert!(!utils::check_phantom_compatibility::<
            PhantomCpu,
            PhantomCuda<0>,
        >());

        let cost = utils::phantom_transfer_cost::<PhantomCpu, PhantomCuda<0>>();
        assert_eq!(cost, 100);

        let manager = utils::create_phantom_manager::<PhantomCpu>();
        assert_eq!(manager.device_count(), 0);
    }
}
644
/// Fixed-size group of `N` device handles, all typed to the same phantom
/// device `P`.
#[derive(Debug)]
pub struct DeviceGroup<P: PhantomDevice, const N: usize> {
    devices: [DeviceHandle<P>; N],
}
657
658impl<P: PhantomDevice, const N: usize> DeviceGroup<P, N> {
659 pub fn new(devices: [DeviceHandle<P>; N]) -> Self {
661 Self { devices }
662 }
663
664 pub const fn device_count() -> usize {
666 N
667 }
668
669 pub fn get(&self, index: usize) -> Option<&DeviceHandle<P>> {
671 self.devices.get(index)
672 }
673
674 pub fn get_mut(&mut self, index: usize) -> Option<&mut DeviceHandle<P>> {
676 self.devices.get_mut(index)
677 }
678
679 pub fn iter(&self) -> impl Iterator<Item = &DeviceHandle<P>> {
681 self.devices.iter()
682 }
683
684 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut DeviceHandle<P>> {
686 self.devices.iter_mut()
687 }
688
689 pub fn parallel_execute<F, R>(&self, f: F) -> Vec<R>
691 where
692 F: Fn(&DeviceHandle<P>) -> R + Sync,
693 R: Send,
694 {
695 self.devices.iter().map(f).collect()
696 }
697
698 pub fn supports_p2p() -> bool {
700 P::supports_p2p()
701 }
702
703 pub fn group_type_name() -> String {
705 format!("DeviceGroup<{}, {}>", P::device_name(), N)
706 }
707}
708
/// Peer-to-peer capability between `Self` and the `Other` device marker.
pub trait PeerToPeerOps<Other: PhantomDevice>: PhantomDevice {
    /// Whether a direct P2P path exists between the two devices.
    const P2P_SUPPORTED: bool;

    /// Modeled P2P bandwidth (unitless heuristic; implementations return
    /// 0 when no transfer occurs).
    fn p2p_bandwidth() -> u32;

    /// Modeled P2P latency (unitless heuristic).
    fn p2p_latency() -> u32;
}
723
724impl<const I1: usize, const I2: usize> PeerToPeerOps<PhantomCuda<I2>> for PhantomCuda<I1> {
726 const P2P_SUPPORTED: bool = true;
727
728 fn p2p_bandwidth() -> u32 {
729 if I1 == I2 {
730 0 } else {
732 50_000 }
735 }
736
737 fn p2p_latency() -> u32 {
738 if I1 == I2 {
739 0
740 } else {
741 5 }
743 }
744}
745
746impl<const I1: usize, const I2: usize> PeerToPeerOps<PhantomMetal<I2>> for PhantomMetal<I1> {
748 const P2P_SUPPORTED: bool = I1 == I2; fn p2p_bandwidth() -> u32 {
751 if I1 == I2 {
752 400_000 } else {
754 0
755 }
756 }
757
758 fn p2p_latency() -> u32 {
759 if I1 == I2 {
760 1 } else {
762 0
763 }
764 }
765}
766
/// Static description of a multi-device communication topology.
pub trait DeviceTopology {
    /// Number of devices in the topology.
    const DEVICE_COUNT: usize;

    /// Whether the topology supports allreduce collectives.
    const SUPPORTS_ALLREDUCE: bool;

    /// Whether the topology supports broadcast collectives.
    const SUPPORTS_BROADCAST: bool;

    /// Topology name, e.g. "Ring".
    fn topology_name() -> &'static str;

    /// Modeled allreduce bandwidth (unitless heuristic).
    fn allreduce_bandwidth() -> u32;
}
787
788#[derive(Debug)]
790pub struct RingTopology<P: PhantomDevice, const N: usize> {
791 _phantom: PhantomData<P>,
792}
793
794impl<P: PhantomDevice, const N: usize> DeviceTopology for RingTopology<P, N> {
795 const DEVICE_COUNT: usize = N;
796 const SUPPORTS_ALLREDUCE: bool = true;
797 const SUPPORTS_BROADCAST: bool = true;
798
799 fn topology_name() -> &'static str {
800 "Ring"
801 }
802
803 fn allreduce_bandwidth() -> u32 {
804 if P::supports_p2p() {
806 25_000 } else {
808 5_000 }
810 }
811}
812
813#[derive(Debug)]
815pub struct TreeTopology<P: PhantomDevice, const N: usize> {
816 _phantom: PhantomData<P>,
817}
818
819impl<P: PhantomDevice, const N: usize> DeviceTopology for TreeTopology<P, N> {
820 const DEVICE_COUNT: usize = N;
821 const SUPPORTS_ALLREDUCE: bool = true;
822 const SUPPORTS_BROADCAST: bool = true;
823
824 fn topology_name() -> &'static str {
825 "Tree"
826 }
827
828 fn allreduce_bandwidth() -> u32 {
829 if P::supports_p2p() {
831 40_000 } else {
833 8_000 }
835 }
836}
837
838#[derive(Debug)]
840pub struct AllToAllTopology<P: PhantomDevice, const N: usize> {
841 _phantom: PhantomData<P>,
842}
843
844impl<P: PhantomDevice, const N: usize> DeviceTopology for AllToAllTopology<P, N> {
845 const DEVICE_COUNT: usize = N;
846 const SUPPORTS_ALLREDUCE: bool = true;
847 const SUPPORTS_BROADCAST: bool = true;
848
849 fn topology_name() -> &'static str {
850 "AllToAll"
851 }
852
853 fn allreduce_bandwidth() -> u32 {
854 if P::supports_p2p() {
856 100_000 } else {
858 15_000 }
860 }
861}
862
863#[derive(Debug)]
868pub struct TypedDeviceAffinity<P: PhantomDevice> {
869 device_handle: DeviceHandle<P>,
870 preferred_numa_node: Option<usize>,
871 cpu_affinity: Option<Vec<usize>>,
872}
873
874impl<P: PhantomDevice> TypedDeviceAffinity<P> {
875 pub fn new(device_handle: DeviceHandle<P>) -> Self {
877 Self {
878 device_handle,
879 preferred_numa_node: None,
880 cpu_affinity: None,
881 }
882 }
883
884 pub fn with_numa_node(mut self, node: usize) -> Self {
886 self.preferred_numa_node = Some(node);
887 self
888 }
889
890 pub fn with_cpu_affinity(mut self, cpus: Vec<usize>) -> Self {
892 self.cpu_affinity = Some(cpus);
893 self
894 }
895
896 pub fn device(&self) -> &DeviceHandle<P> {
898 &self.device_handle
899 }
900
901 pub fn numa_node(&self) -> Option<usize> {
903 self.preferred_numa_node
904 }
905
906 pub fn cpu_affinity(&self) -> Option<&[usize]> {
908 self.cpu_affinity.as_deref()
909 }
910
911 pub const fn is_cpu_device() -> bool {
913 matches!(P::DEVICE_TYPE, DeviceType::Cpu)
914 }
915
916 pub fn locality_score(&self, target_numa: usize) -> u32 {
918 match self.preferred_numa_node {
919 Some(node) if node == target_numa => 100,
920 Some(_) => 30, None => 50, }
923 }
924}
925
926#[derive(Debug)]
931pub struct CrossDeviceOp<PSrc: PhantomDevice, PDst: PhantomDevice> {
932 _phantom: PhantomData<(PSrc, PDst)>,
933}
934
935impl<PSrc: PhantomDevice, PDst: PhantomDevice> CrossDeviceOp<PSrc, PDst> {
936 pub const SUPPORTED: bool = TransferCompatible::<PSrc, PDst>::SUPPORTED;
938
939 pub fn transfer_cost() -> u32 {
941 TransferCompatible::<PSrc, PDst>::transfer_cost()
942 }
943
944 pub fn transfer_strategy() -> &'static str {
946 match (PSrc::DEVICE_TYPE, PDst::DEVICE_TYPE) {
947 (DeviceType::Cpu, DeviceType::Cpu) => "memcpy",
948 (DeviceType::Cuda(_), DeviceType::Cuda(_)) => "peer-to-peer",
949 (DeviceType::Cpu, DeviceType::Cuda(_)) => "pinned-transfer",
950 (DeviceType::Cuda(_), DeviceType::Cpu) => "staged-readback",
951 (DeviceType::Metal(_), DeviceType::Metal(_)) => "unified-memory",
952 (DeviceType::Cpu, DeviceType::Metal(_)) => "shared-memory",
953 _ => "staged-transfer",
954 }
955 }
956
957 pub const fn supports_zero_copy() -> bool {
959 matches!(
960 (PSrc::DEVICE_TYPE, PDst::DEVICE_TYPE),
961 (DeviceType::Metal(_), DeviceType::Metal(_))
962 )
963 }
964}
965
966pub mod compile_time {
968 use super::*;
969
970 pub fn assert_same_device<P1: PhantomDevice, P2: PhantomDevice>() {
972 if !SameDevice::<P1, P2>::is_satisfied() {
973 panic!("Device types must match");
974 }
975 }
976
977 pub fn assert_gpu<P: PhantomDevice>() {
979 if !P::requires_gpu() {
980 panic!("Operation requires GPU device");
981 }
982 }
983
984 pub fn assert_cpu<P: PhantomDevice>() {
986 if P::requires_gpu() {
987 panic!("Operation requires CPU device");
988 }
989 }
990
991 pub fn assert_p2p<P1, P2>()
993 where
994 P1: PhantomDevice + PeerToPeerOps<P2>,
995 P2: PhantomDevice,
996 {
997 if !P1::P2P_SUPPORTED {
998 panic!("P2P not supported between these device types");
999 }
1000 }
1001}
1002
#[cfg(test)]
mod advanced_tests {
    use super::*;
    use crate::device::implementations::CpuDevice;

    // Fixed-size groups expose compile-time count and checked indexing.
    #[test]
    fn test_device_group() {
        let cpu_device = Box::new(CpuDevice::new());
        let handle1 =
            DeviceHandle::<PhantomCpu>::new(cpu_device).expect("DeviceHandle::new should succeed");

        let cpu_device2 = Box::new(CpuDevice::new());
        let handle2 =
            DeviceHandle::<PhantomCpu>::new(cpu_device2).expect("DeviceHandle::new should succeed");

        let group = DeviceGroup::new([handle1, handle2]);
        assert_eq!(DeviceGroup::<PhantomCpu, 2>::device_count(), 2);
        assert!(group.get(0).is_some());
        assert!(group.get(1).is_some());
        assert!(group.get(2).is_none());
    }

    // Cross-index CUDA P2P reports support with non-zero bandwidth/latency.
    #[test]
    fn test_p2p_cuda() {
        assert!(PhantomCuda::<0>::supports_p2p());
        assert!(<PhantomCuda<0> as PeerToPeerOps<PhantomCuda<1>>>::P2P_SUPPORTED);

        let bandwidth = <PhantomCuda<0> as PeerToPeerOps<PhantomCuda<1>>>::p2p_bandwidth();
        assert!(bandwidth > 0);

        let latency = <PhantomCuda<0> as PeerToPeerOps<PhantomCuda<1>>>::p2p_latency();
        assert!(latency > 0);
    }

    // All three topologies report their constants; all-to-all bandwidth
    // is at least the tree's.
    #[test]
    fn test_device_topology() {
        assert_eq!(RingTopology::<PhantomCuda<0>, 4>::DEVICE_COUNT, 4);
        assert!(RingTopology::<PhantomCuda<0>, 4>::SUPPORTS_ALLREDUCE);
        assert!(RingTopology::<PhantomCuda<0>, 4>::SUPPORTS_BROADCAST);

        let bandwidth = RingTopology::<PhantomCuda<0>, 4>::allreduce_bandwidth();
        assert!(bandwidth > 0);

        assert_eq!(TreeTopology::<PhantomCuda<0>, 8>::DEVICE_COUNT, 8);
        let tree_bandwidth = TreeTopology::<PhantomCuda<0>, 8>::allreduce_bandwidth();
        assert!(tree_bandwidth > 0);

        assert_eq!(AllToAllTopology::<PhantomCuda<0>, 4>::DEVICE_COUNT, 4);
        let all2all_bandwidth = AllToAllTopology::<PhantomCuda<0>, 4>::allreduce_bandwidth();
        assert!(all2all_bandwidth >= tree_bandwidth);
    }

    // Builder-style hints round-trip through the accessors, and the
    // locality score distinguishes matching vs mismatching NUMA nodes.
    #[test]
    fn test_typed_device_affinity() {
        let cpu_device = Box::new(CpuDevice::new());
        let handle =
            DeviceHandle::<PhantomCpu>::new(cpu_device).expect("DeviceHandle::new should succeed");

        let affinity = TypedDeviceAffinity::new(handle)
            .with_numa_node(0)
            .with_cpu_affinity(vec![0, 1, 2, 3]);

        assert_eq!(affinity.numa_node(), Some(0));
        assert_eq!(affinity.cpu_affinity(), Some(&[0, 1, 2, 3][..]));
        assert_eq!(affinity.locality_score(0), 100);
        assert_eq!(affinity.locality_score(1), 30);
    }

    // Pins the CPU<->CUDA cost/strategy and Metal zero-copy behavior.
    #[test]
    fn test_cross_device_op() {
        assert!(CrossDeviceOp::<PhantomCpu, PhantomCuda<0>>::SUPPORTED);
        assert_eq!(
            CrossDeviceOp::<PhantomCpu, PhantomCuda<0>>::transfer_cost(),
            100
        );
        assert_eq!(
            CrossDeviceOp::<PhantomCpu, PhantomCuda<0>>::transfer_strategy(),
            "pinned-transfer"
        );

        assert!(CrossDeviceOp::<PhantomMetal<0>, PhantomMetal<0>>::supports_zero_copy());

        assert_eq!(CrossDeviceOp::<PhantomCpu, PhantomCpu>::transfer_cost(), 0);
    }

    // Type-level predicates resolve as expected for CPU and CUDA groups.
    #[test]
    fn test_compile_time_validation() {
        assert_eq!(
            CrossDeviceOp::<PhantomCpu, PhantomCpu>::transfer_strategy(),
            "memcpy"
        );

        assert!(DeviceGroup::<PhantomCuda<0>, 4>::supports_p2p());
        assert!(!DeviceGroup::<PhantomCpu, 4>::supports_p2p());
    }
}