// torsh_core/device/phantom.rs

//! Phantom device types for compile-time safety
//!
//! This module provides phantom types and zero-cost abstractions that enable
//! compile-time device type safety and validation without runtime overhead.

use crate::device::{Device, DeviceType};
use std::marker::PhantomData;

9/// Phantom device marker trait for compile-time device type information
10///
11/// This trait is used to mark types with specific device information at compile time,
12/// enabling type-safe device operations without runtime overhead.
13pub trait PhantomDevice: 'static + std::fmt::Debug + Send + Sync {
14    /// The device type this phantom represents
15    const DEVICE_TYPE: DeviceType;
16
17    /// Get the device type (compile-time constant)
18    fn device_type() -> DeviceType {
19        Self::DEVICE_TYPE
20    }
21
22    /// Check if this phantom device is compatible with another
23    fn is_compatible<Other: PhantomDevice>() -> bool {
24        Self::DEVICE_TYPE == Other::DEVICE_TYPE
25    }
26
27    /// Get the device name as a compile-time string
28    fn device_name() -> &'static str;
29
30    /// Check if this device requires GPU features
31    fn requires_gpu() -> bool {
32        !matches!(Self::DEVICE_TYPE, DeviceType::Cpu)
33    }
34
35    /// Check if this device supports peer-to-peer operations
36    fn supports_p2p() -> bool {
37        matches!(Self::DEVICE_TYPE, DeviceType::Cuda(_))
38    }
39}
40
41/// CPU phantom device marker
42#[derive(Debug, Clone, Copy, PartialEq, Eq)]
43pub struct PhantomCpu;
44
45impl PhantomDevice for PhantomCpu {
46    const DEVICE_TYPE: DeviceType = DeviceType::Cpu;
47
48    fn device_name() -> &'static str {
49        "CPU"
50    }
51}
52
53/// CUDA phantom device marker
54#[derive(Debug, Clone, Copy, PartialEq, Eq)]
55pub struct PhantomCuda<const INDEX: usize>;
56
57impl<const INDEX: usize> PhantomDevice for PhantomCuda<INDEX> {
58    const DEVICE_TYPE: DeviceType = DeviceType::Cuda(INDEX);
59
60    fn device_name() -> &'static str {
61        "CUDA"
62    }
63}
64
65/// Metal phantom device marker
66#[derive(Debug, Clone, Copy, PartialEq, Eq)]
67pub struct PhantomMetal<const INDEX: usize>;
68
69impl<const INDEX: usize> PhantomDevice for PhantomMetal<INDEX> {
70    const DEVICE_TYPE: DeviceType = DeviceType::Metal(INDEX);
71
72    fn device_name() -> &'static str {
73        "Metal"
74    }
75}
76
77/// WebGPU phantom device marker
78#[derive(Debug, Clone, Copy, PartialEq, Eq)]
79pub struct PhantomWgpu<const INDEX: usize>;
80
81impl<const INDEX: usize> PhantomDevice for PhantomWgpu<INDEX> {
82    const DEVICE_TYPE: DeviceType = DeviceType::Wgpu(INDEX);
83
84    fn device_name() -> &'static str {
85        "WebGPU"
86    }
87}
88
/// Type-safe device handle with phantom device information.
///
/// Wraps a boxed runtime [`Device`] while the phantom parameter `P` records
/// the expected device type at the type level; mismatches are rejected by
/// `DeviceHandle::new`, and most type queries compile down to constants.
///
/// # Examples
///
/// ```ignore
/// use torsh_core::device::{DeviceHandle, PhantomCpu, PhantomCuda};
///
/// // CPU device handle
/// let cpu_handle = DeviceHandle::<PhantomCpu>::new(cpu_device);
///
/// // CUDA device handle with index 0
/// let cuda_handle = DeviceHandle::<PhantomCuda<0>>::new(cuda_device);
///
/// // Compile-time device type checking (associated functions)
/// assert!(DeviceHandle::<PhantomCpu>::is_cpu());
/// assert!(DeviceHandle::<PhantomCuda<0>>::is_gpu());
/// ```
#[derive(Debug)]
pub struct DeviceHandle<P: PhantomDevice> {
    // Runtime device object; its `device_type()` is validated against
    // `P::DEVICE_TYPE` in `new`.
    device: Box<dyn Device>,
    // Zero-sized marker tying this handle to the phantom device type `P`.
    _phantom: PhantomData<P>,
}
114
115impl<P: PhantomDevice> DeviceHandle<P> {
116    /// Create a new device handle with phantom type information
117    pub fn new(device: Box<dyn Device>) -> Result<Self, crate::error::TorshError> {
118        if device.device_type() != P::DEVICE_TYPE {
119            return Err(crate::error::TorshError::InvalidArgument(format!(
120                "Device type mismatch: expected {:?}, got {:?}",
121                P::DEVICE_TYPE,
122                device.device_type()
123            )));
124        }
125
126        Ok(Self {
127            device,
128            _phantom: PhantomData,
129        })
130    }
131
132    /// Create an unchecked device handle (unsafe)
133    ///
134    /// # Safety
135    /// The caller must ensure that the device type matches the phantom type.
136    pub unsafe fn new_unchecked(device: Box<dyn Device>) -> Self {
137        Self {
138            device,
139            _phantom: PhantomData,
140        }
141    }
142
143    /// Get the underlying device
144    pub fn device(&self) -> &dyn Device {
145        self.device.as_ref()
146    }
147
148    /// Get the underlying device mutably
149    pub fn device_mut(&mut self) -> &mut dyn Device {
150        self.device.as_mut()
151    }
152
153    /// Get the phantom device type
154    pub fn phantom_device_type() -> DeviceType {
155        P::DEVICE_TYPE
156    }
157
158    /// Check if this is a CPU device (compile-time)
159    pub const fn is_cpu() -> bool {
160        matches!(P::DEVICE_TYPE, DeviceType::Cpu)
161    }
162
163    /// Check if this is a GPU device (compile-time)
164    pub const fn is_gpu() -> bool {
165        !matches!(P::DEVICE_TYPE, DeviceType::Cpu)
166    }
167
168    /// Check if this is a CUDA device (compile-time)
169    pub const fn is_cuda() -> bool {
170        matches!(P::DEVICE_TYPE, DeviceType::Cuda(_))
171    }
172
173    /// Check if this is a Metal device (compile-time)
174    pub const fn is_metal() -> bool {
175        matches!(P::DEVICE_TYPE, DeviceType::Metal(_))
176    }
177
178    /// Check if this is a WebGPU device (compile-time)
179    pub const fn is_wgpu() -> bool {
180        matches!(P::DEVICE_TYPE, DeviceType::Wgpu(_))
181    }
182
183    /// Convert to a different phantom device type (with runtime check)
184    pub fn cast<Q: PhantomDevice>(
185        self,
186    ) -> Result<DeviceHandle<Q>, (Self, crate::error::TorshError)> {
187        if self.device.device_type() != Q::DEVICE_TYPE {
188            let error = crate::error::TorshError::InvalidArgument(format!(
189                "Cannot cast device from {:?} to {:?}",
190                P::DEVICE_TYPE,
191                Q::DEVICE_TYPE
192            ));
193            return Err((self, error));
194        }
195
196        Ok(DeviceHandle {
197            device: self.device,
198            _phantom: PhantomData,
199        })
200    }
201
202    /// Convert to a different phantom device type (unsafe, no runtime check)
203    ///
204    /// # Safety
205    /// The caller must ensure that the device type matches the target phantom type.
206    pub unsafe fn cast_unchecked<Q: PhantomDevice>(self) -> DeviceHandle<Q> {
207        DeviceHandle {
208            device: self.device,
209            _phantom: PhantomData,
210        }
211    }
212
213    /// Extract the underlying device, consuming the handle
214    pub fn into_device(self) -> Box<dyn Device> {
215        self.device
216    }
217}
218
219impl<P: PhantomDevice> Clone for DeviceHandle<P> {
220    fn clone(&self) -> Self {
221        let cloned_device = self.device.clone_device().expect("Failed to clone device");
222
223        Self {
224            device: cloned_device,
225            _phantom: PhantomData,
226        }
227    }
228}
229
/// Compile-time device compatibility checker.
///
/// Implementations state, at the type level, whether two device types may be
/// combined in operations that require compatible devices.
pub trait DeviceCompatible<Other> {
    /// Compile-time compatibility flag.
    const COMPATIBLE: bool;

    /// Static description of why the pair is (in)compatible.
    fn compatibility_info() -> &'static str;
}
241
242impl<P: PhantomDevice> DeviceCompatible<P> for P {
243    const COMPATIBLE: bool = true;
244
245    fn compatibility_info() -> &'static str {
246        "Same device type - always compatible"
247    }
248}
249
250// DeviceCompatible is already implemented generically above
251
252/// Type-level operation constraints
253///
254/// This trait allows operations to specify their device requirements at the type level,
255/// enabling compile-time validation of device compatibility.
256pub trait DeviceOperation<P: PhantomDevice> {
257    /// The result type of this operation
258    type Output;
259
260    /// Device requirements for this operation
261    type Requirements: DeviceRequirements;
262
263    /// Execute the operation on the given device
264    fn execute(device: &DeviceHandle<P>) -> Result<Self::Output, crate::error::TorshError>;
265
266    /// Check if the operation is supported on this device type (compile-time)
267    const SUPPORTED: bool = Self::Requirements::SATISFIED_BY_DEVICE;
268}
269
/// Device requirements trait for compile-time requirement checking.
pub trait DeviceRequirements {
    /// Compile-time flag: whether the requirement is satisfied.
    const SATISFIED_BY_DEVICE: bool;

    /// Human-readable description of the requirement.
    fn description() -> &'static str;
}
278
279/// Requirement for GPU device
280#[derive(Debug, Clone, Copy)]
281pub struct RequiresGpu;
282
283impl DeviceRequirements for RequiresGpu {
284    const SATISFIED_BY_DEVICE: bool = false; // Will be specialized for GPU types
285
286    fn description() -> &'static str {
287        "Requires GPU device"
288    }
289}
290
291/// Requirement for CPU device
292#[derive(Debug, Clone, Copy)]
293pub struct RequiresCpu;
294
295impl DeviceRequirements for RequiresCpu {
296    const SATISFIED_BY_DEVICE: bool = false; // Will be specialized for CPU type
297
298    fn description() -> &'static str {
299        "Requires CPU device"
300    }
301}
302
303/// Requirement for CUDA device
304#[derive(Debug, Clone, Copy)]
305pub struct RequiresCuda;
306
307impl DeviceRequirements for RequiresCuda {
308    const SATISFIED_BY_DEVICE: bool = false; // Will be specialized for CUDA types
309
310    fn description() -> &'static str {
311        "Requires CUDA device"
312    }
313}
314
315/// No specific device requirements
316#[derive(Debug, Clone, Copy)]
317pub struct NoRequirements;
318
319impl DeviceRequirements for NoRequirements {
320    const SATISFIED_BY_DEVICE: bool = true;
321
322    fn description() -> &'static str {
323        "No specific device requirements"
324    }
325}
326
327/// Device constraint that requires two devices to be the same type
328#[derive(Debug)]
329pub struct SameDevice<P1: PhantomDevice, P2: PhantomDevice> {
330    _phantom: PhantomData<(P1, P2)>,
331}
332
333impl<P1: PhantomDevice, P2: PhantomDevice> SameDevice<P1, P2> {
334    /// Check if the constraint is satisfied
335    pub fn is_satisfied() -> bool {
336        match (P1::DEVICE_TYPE, P2::DEVICE_TYPE) {
337            (DeviceType::Cpu, DeviceType::Cpu) => true,
338            (DeviceType::Cuda(a), DeviceType::Cuda(b)) => a == b,
339            (DeviceType::Metal(a), DeviceType::Metal(b)) => a == b,
340            (DeviceType::Wgpu(a), DeviceType::Wgpu(b)) => a == b,
341            _ => false,
342        }
343    }
344}
345
346/// Device constraint that allows transfer between compatible devices
347#[derive(Debug)]
348pub struct TransferCompatible<P1: PhantomDevice, P2: PhantomDevice> {
349    _phantom: PhantomData<(P1, P2)>,
350}
351
352impl<P1: PhantomDevice, P2: PhantomDevice> TransferCompatible<P1, P2> {
353    /// Check if transfer is supported (compile-time)
354    pub const SUPPORTED: bool = true; // All devices support some form of transfer
355
356    /// Get the estimated transfer cost
357    pub fn transfer_cost() -> u32 {
358        match (P1::DEVICE_TYPE, P2::DEVICE_TYPE) {
359            (DeviceType::Cpu, DeviceType::Cpu) => 0,
360            (DeviceType::Cuda(a), DeviceType::Cuda(b)) if a == b => 0,
361            (DeviceType::Metal(a), DeviceType::Metal(b)) if a == b => 0,
362            (DeviceType::Wgpu(a), DeviceType::Wgpu(b)) if a == b => 0,
363            (DeviceType::Cpu, DeviceType::Cuda(_)) => 100,
364            (DeviceType::Cuda(_), DeviceType::Cpu) => 100,
365            (DeviceType::Cpu, DeviceType::Metal(_)) => 80,
366            (DeviceType::Metal(_), DeviceType::Cpu) => 80,
367            _ => 200, // Cross-GPU transfers
368        }
369    }
370}
371
/// Type-safe device manager that maintains phantom type information.
#[derive(Debug)]
pub struct PhantomDeviceManager<P: PhantomDevice> {
    // Handles managed by this instance; all share the phantom type `P`.
    handles: Vec<DeviceHandle<P>>,
    // NOTE(review): redundant — `P` already appears in `handles` — but kept
    // for struct-layout stability.
    _phantom: PhantomData<P>,
}
378
379impl<P: PhantomDevice> PhantomDeviceManager<P> {
380    /// Create a new phantom device manager
381    pub fn new() -> Self {
382        Self {
383            handles: Vec::new(),
384            _phantom: PhantomData,
385        }
386    }
387
388    /// Add a device handle
389    pub fn add_device(&mut self, handle: DeviceHandle<P>) {
390        self.handles.push(handle);
391    }
392
393    /// Get the number of managed devices
394    pub fn device_count(&self) -> usize {
395        self.handles.len()
396    }
397
398    /// Get a device handle by index
399    pub fn get_device(&self, index: usize) -> Option<&DeviceHandle<P>> {
400        self.handles.get(index)
401    }
402
403    /// Get a mutable device handle by index
404    pub fn get_device_mut(&mut self, index: usize) -> Option<&mut DeviceHandle<P>> {
405        self.handles.get_mut(index)
406    }
407
408    /// Remove a device handle by index
409    pub fn remove_device(&mut self, index: usize) -> Option<DeviceHandle<P>> {
410        if index < self.handles.len() {
411            Some(self.handles.remove(index))
412        } else {
413            None
414        }
415    }
416
417    /// Get all device handles
418    pub fn devices(&self) -> &[DeviceHandle<P>] {
419        &self.handles
420    }
421
422    /// Clear all devices
423    pub fn clear(&mut self) {
424        self.handles.clear();
425    }
426
427    /// Execute an operation on all devices
428    pub fn execute_on_all<Op>(
429        &self,
430        _operation: Op,
431    ) -> Vec<Result<Op::Output, crate::error::TorshError>>
432    where
433        Op: DeviceOperation<P> + Clone,
434    {
435        self.handles
436            .iter()
437            .map(|handle| Op::execute(handle))
438            .collect()
439    }
440}
441
442impl<P: PhantomDevice> Default for PhantomDeviceManager<P> {
443    fn default() -> Self {
444        Self::new()
445    }
446}
447
448/// Utility functions for phantom device operations
449pub mod utils {
450    use super::*;
451
452    /// Create a type-safe device handle from a runtime device
453    pub fn create_phantom_handle<P: PhantomDevice>(
454        device: Box<dyn Device>,
455    ) -> Result<DeviceHandle<P>, crate::error::TorshError> {
456        DeviceHandle::<P>::new(device)
457    }
458
459    /// Check device compatibility at runtime with phantom type information
460    pub fn check_phantom_compatibility<P1: PhantomDevice, P2: PhantomDevice>() -> bool {
461        P1::DEVICE_TYPE == P2::DEVICE_TYPE
462    }
463
464    /// Get the transfer cost between two phantom device types
465    pub fn phantom_transfer_cost<P1: PhantomDevice, P2: PhantomDevice>() -> u32 {
466        TransferCompatible::<P1, P2>::transfer_cost()
467    }
468
469    /// Create a device manager for a specific phantom device type
470    pub fn create_phantom_manager<P: PhantomDevice>() -> PhantomDeviceManager<P> {
471        PhantomDeviceManager::new()
472    }
473
474    /// Verify that an operation is supported on a phantom device type
475    pub fn verify_operation_support<P: PhantomDevice, Op: DeviceOperation<P>>() -> bool {
476        Op::SUPPORTED
477    }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::core::Device;
    use std::any::Any;

    // Mock device for testing: reports a fixed `DeviceType` and succeeds on
    // every lifecycle call, so handle validation can be exercised without
    // real hardware.
    #[derive(Debug)]
    struct MockDevice {
        device_type: DeviceType,
    }

    impl MockDevice {
        fn new(device_type: DeviceType) -> Self {
            Self { device_type }
        }
    }

    impl Device for MockDevice {
        fn device_type(&self) -> DeviceType {
            self.device_type
        }

        fn name(&self) -> &str {
            "Mock Device"
        }

        fn is_available(&self) -> Result<bool, crate::error::TorshError> {
            Ok(true)
        }

        fn capabilities(
            &self,
        ) -> Result<crate::device::DeviceCapabilities, crate::error::TorshError> {
            crate::device::DeviceCapabilities::detect(self.device_type)
        }

        fn synchronize(&self) -> Result<(), crate::error::TorshError> {
            Ok(())
        }

        fn reset(&self) -> Result<(), crate::error::TorshError> {
            Ok(())
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn as_any_mut(&mut self) -> &mut dyn Any {
            self
        }

        fn clone_device(&self) -> Result<Box<dyn Device>, crate::error::TorshError> {
            Ok(Box::new(MockDevice::new(self.device_type)))
        }
    }

    // Marker constants and static names resolve to the expected values.
    #[test]
    fn test_phantom_device_markers() {
        assert_eq!(PhantomCpu::device_type(), DeviceType::Cpu);
        assert_eq!(PhantomCuda::<0>::device_type(), DeviceType::Cuda(0));
        assert_eq!(PhantomMetal::<1>::device_type(), DeviceType::Metal(1));
        assert_eq!(PhantomWgpu::<2>::device_type(), DeviceType::Wgpu(2));

        assert_eq!(PhantomCpu::device_name(), "CPU");
        assert_eq!(PhantomCuda::<0>::device_name(), "CUDA");
        assert_eq!(PhantomMetal::<0>::device_name(), "Metal");
        assert_eq!(PhantomWgpu::<0>::device_name(), "WebGPU");
    }

    // GPU/P2P capability flags: only CPU is non-GPU; only CUDA claims P2P.
    #[test]
    fn test_phantom_device_properties() {
        assert!(!PhantomCpu::requires_gpu());
        assert!(PhantomCuda::<0>::requires_gpu());
        assert!(PhantomMetal::<0>::requires_gpu());
        assert!(PhantomWgpu::<0>::requires_gpu());

        assert!(!PhantomCpu::supports_p2p());
        assert!(PhantomCuda::<0>::supports_p2p());
        assert!(!PhantomMetal::<0>::supports_p2p());
        assert!(!PhantomWgpu::<0>::supports_p2p());
    }

    // A matching runtime device constructs successfully, and associated
    // compile-time queries agree with the phantom type.
    #[test]
    fn test_device_handle() {
        let mock_device = Box::new(MockDevice::new(DeviceType::Cpu));
        let _handle =
            DeviceHandle::<PhantomCpu>::new(mock_device).expect("DeviceHandle::new should succeed");

        assert_eq!(
            DeviceHandle::<PhantomCpu>::phantom_device_type(),
            DeviceType::Cpu
        );
        assert!(DeviceHandle::<PhantomCpu>::is_cpu());
        assert!(!DeviceHandle::<PhantomCpu>::is_gpu());
        assert!(!DeviceHandle::<PhantomCpu>::is_cuda());
    }

    // A mismatched runtime device is rejected at construction time.
    #[test]
    fn test_device_handle_type_mismatch() {
        let mock_device = Box::new(MockDevice::new(DeviceType::Cuda(0)));
        let result = DeviceHandle::<PhantomCpu>::new(mock_device);
        assert!(result.is_err());
    }

    // Compatibility is exact: same family AND same index.
    #[test]
    fn test_device_compatibility() {
        assert!(PhantomCpu::is_compatible::<PhantomCpu>());
        assert!(!PhantomCpu::is_compatible::<PhantomCuda<0>>());
        assert!(PhantomCuda::<0>::is_compatible::<PhantomCuda<0>>());
        assert!(!PhantomCuda::<0>::is_compatible::<PhantomCuda<1>>());
    }

    // Manager add/get/remove round-trip.
    #[test]
    fn test_phantom_device_manager() {
        let mut manager = PhantomDeviceManager::<PhantomCpu>::new();
        assert_eq!(manager.device_count(), 0);

        let mock_device = Box::new(MockDevice::new(DeviceType::Cpu));
        let handle =
            DeviceHandle::<PhantomCpu>::new(mock_device).expect("DeviceHandle::new should succeed");
        manager.add_device(handle);

        assert_eq!(manager.device_count(), 1);
        assert!(manager.get_device(0).is_some());
        assert!(manager.get_device(1).is_none());

        let removed = manager.remove_device(0);
        assert!(removed.is_some());
        assert_eq!(manager.device_count(), 0);
    }

    // Spot-checks of the static transfer-cost table.
    #[test]
    fn test_transfer_cost_constants() {
        assert_eq!(
            TransferCompatible::<PhantomCpu, PhantomCpu>::transfer_cost(),
            0
        );
        assert_eq!(
            TransferCompatible::<PhantomCpu, PhantomCuda<0>>::transfer_cost(),
            100
        );
        assert_eq!(
            TransferCompatible::<PhantomCpu, PhantomMetal<0>>::transfer_cost(),
            80
        );
    }

    // The `utils` wrappers delegate to the underlying implementations.
    #[test]
    fn test_utils_functions() {
        assert!(utils::check_phantom_compatibility::<PhantomCpu, PhantomCpu>());
        assert!(!utils::check_phantom_compatibility::<
            PhantomCpu,
            PhantomCuda<0>,
        >());

        let cost = utils::phantom_transfer_cost::<PhantomCpu, PhantomCuda<0>>();
        assert_eq!(cost, 100);

        let manager = utils::create_phantom_manager::<PhantomCpu>();
        assert_eq!(manager.device_count(), 0);
    }
}
644
645// ============================================================================
646// Advanced Phantom Type Features - Multi-GPU and Topology Support
647// ============================================================================
648
649/// Device group for multi-GPU operations
650///
651/// This type represents a group of devices of the same type,
652/// enabling type-safe multi-GPU operations with compile-time guarantees.
653#[derive(Debug)]
654pub struct DeviceGroup<P: PhantomDevice, const N: usize> {
655    devices: [DeviceHandle<P>; N],
656}
657
658impl<P: PhantomDevice, const N: usize> DeviceGroup<P, N> {
659    /// Create a new device group
660    pub fn new(devices: [DeviceHandle<P>; N]) -> Self {
661        Self { devices }
662    }
663
664    /// Get the number of devices in this group
665    pub const fn device_count() -> usize {
666        N
667    }
668
669    /// Get a device by index
670    pub fn get(&self, index: usize) -> Option<&DeviceHandle<P>> {
671        self.devices.get(index)
672    }
673
674    /// Get a device mutably by index
675    pub fn get_mut(&mut self, index: usize) -> Option<&mut DeviceHandle<P>> {
676        self.devices.get_mut(index)
677    }
678
679    /// Iterate over all devices
680    pub fn iter(&self) -> impl Iterator<Item = &DeviceHandle<P>> {
681        self.devices.iter()
682    }
683
684    /// Iterate over all devices mutably
685    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut DeviceHandle<P>> {
686        self.devices.iter_mut()
687    }
688
689    /// Execute an operation on all devices in parallel
690    pub fn parallel_execute<F, R>(&self, f: F) -> Vec<R>
691    where
692        F: Fn(&DeviceHandle<P>) -> R + Sync,
693        R: Send,
694    {
695        self.devices.iter().map(f).collect()
696    }
697
698    /// Check if all devices support peer-to-peer operations
699    pub fn supports_p2p() -> bool {
700        P::supports_p2p()
701    }
702
703    /// Get the device group type name
704    pub fn group_type_name() -> String {
705        format!("DeviceGroup<{}, {}>", P::device_name(), N)
706    }
707}
708
709/// Peer-to-peer operation trait for type-safe P2P operations
710///
711/// This trait ensures that P2P operations can only be performed between
712/// compatible devices at compile time.
713pub trait PeerToPeerOps<Other: PhantomDevice>: PhantomDevice {
714    /// Whether P2P is supported between these device types
715    const P2P_SUPPORTED: bool;
716
717    /// Get the P2P bandwidth estimate (MB/s)
718    fn p2p_bandwidth() -> u32;
719
720    /// Get P2P latency estimate (microseconds)
721    fn p2p_latency() -> u32;
722}
723
724// Implement P2P for CUDA devices
725impl<const I1: usize, const I2: usize> PeerToPeerOps<PhantomCuda<I2>> for PhantomCuda<I1> {
726    const P2P_SUPPORTED: bool = true;
727
728    fn p2p_bandwidth() -> u32 {
729        if I1 == I2 {
730            0 // Same device - no transfer needed
731        } else {
732            // NVLink or PCIe bandwidth
733            50_000 // 50 GB/s estimate for NVLink
734        }
735    }
736
737    fn p2p_latency() -> u32 {
738        if I1 == I2 {
739            0
740        } else {
741            5 // ~5 microseconds for NVLink
742        }
743    }
744}
745
746// Implement P2P for Metal devices (Apple Silicon with unified memory)
747impl<const I1: usize, const I2: usize> PeerToPeerOps<PhantomMetal<I2>> for PhantomMetal<I1> {
748    const P2P_SUPPORTED: bool = I1 == I2; // Only same device on Apple Silicon
749
750    fn p2p_bandwidth() -> u32 {
751        if I1 == I2 {
752            400_000 // 400 GB/s unified memory bandwidth (M1 Max/Ultra)
753        } else {
754            0
755        }
756    }
757
758    fn p2p_latency() -> u32 {
759        if I1 == I2 {
760            1 // Sub-microsecond on unified memory
761        } else {
762            0
763        }
764    }
765}
766
/// Device topology constraint for compile-time topology validation.
///
/// Describes a fixed device arrangement so collective operations can verify
/// their topology requirements at compile time.
pub trait DeviceTopology {
    /// Number of devices in this topology.
    const DEVICE_COUNT: usize;

    /// Whether efficient all-reduce is available.
    const SUPPORTS_ALLREDUCE: bool;

    /// Whether efficient broadcast is available.
    const SUPPORTS_BROADCAST: bool;

    /// Static topology name.
    fn topology_name() -> &'static str;

    /// Estimated all-reduce bandwidth in MB/s.
    fn allreduce_bandwidth() -> u32;
}
787
788/// Ring topology for distributed operations
789#[derive(Debug)]
790pub struct RingTopology<P: PhantomDevice, const N: usize> {
791    _phantom: PhantomData<P>,
792}
793
794impl<P: PhantomDevice, const N: usize> DeviceTopology for RingTopology<P, N> {
795    const DEVICE_COUNT: usize = N;
796    const SUPPORTS_ALLREDUCE: bool = true;
797    const SUPPORTS_BROADCAST: bool = true;
798
799    fn topology_name() -> &'static str {
800        "Ring"
801    }
802
803    fn allreduce_bandwidth() -> u32 {
804        // Ring all-reduce bandwidth depends on inter-device bandwidth
805        if P::supports_p2p() {
806            25_000 // 25 GB/s for NVLink ring
807        } else {
808            5_000 // 5 GB/s for PCIe ring
809        }
810    }
811}
812
813/// Tree topology for distributed operations
814#[derive(Debug)]
815pub struct TreeTopology<P: PhantomDevice, const N: usize> {
816    _phantom: PhantomData<P>,
817}
818
819impl<P: PhantomDevice, const N: usize> DeviceTopology for TreeTopology<P, N> {
820    const DEVICE_COUNT: usize = N;
821    const SUPPORTS_ALLREDUCE: bool = true;
822    const SUPPORTS_BROADCAST: bool = true;
823
824    fn topology_name() -> &'static str {
825        "Tree"
826    }
827
828    fn allreduce_bandwidth() -> u32 {
829        // Tree has logarithmic depth, higher bandwidth for small N
830        if P::supports_p2p() {
831            40_000 // 40 GB/s for NVLink tree
832        } else {
833            8_000 // 8 GB/s for PCIe tree
834        }
835    }
836}
837
838/// All-to-all topology (fully connected) for maximum bandwidth
839#[derive(Debug)]
840pub struct AllToAllTopology<P: PhantomDevice, const N: usize> {
841    _phantom: PhantomData<P>,
842}
843
844impl<P: PhantomDevice, const N: usize> DeviceTopology for AllToAllTopology<P, N> {
845    const DEVICE_COUNT: usize = N;
846    const SUPPORTS_ALLREDUCE: bool = true;
847    const SUPPORTS_BROADCAST: bool = true;
848
849    fn topology_name() -> &'static str {
850        "AllToAll"
851    }
852
853    fn allreduce_bandwidth() -> u32 {
854        // All-to-all provides maximum bandwidth
855        if P::supports_p2p() {
856            100_000 // 100 GB/s for fully connected NVLink
857        } else {
858            15_000 // 15 GB/s for fully connected PCIe
859        }
860    }
861}
862
863/// Enhanced device affinity with compile-time validation
864///
865/// This struct provides type-safe device affinity management with
866/// compile-time guarantees about device compatibility and locality.
867#[derive(Debug)]
868pub struct TypedDeviceAffinity<P: PhantomDevice> {
869    device_handle: DeviceHandle<P>,
870    preferred_numa_node: Option<usize>,
871    cpu_affinity: Option<Vec<usize>>,
872}
873
874impl<P: PhantomDevice> TypedDeviceAffinity<P> {
875    /// Create a new typed device affinity
876    pub fn new(device_handle: DeviceHandle<P>) -> Self {
877        Self {
878            device_handle,
879            preferred_numa_node: None,
880            cpu_affinity: None,
881        }
882    }
883
884    /// Set the preferred NUMA node for this device
885    pub fn with_numa_node(mut self, node: usize) -> Self {
886        self.preferred_numa_node = Some(node);
887        self
888    }
889
890    /// Set CPU affinity for operations on this device
891    pub fn with_cpu_affinity(mut self, cpus: Vec<usize>) -> Self {
892        self.cpu_affinity = Some(cpus);
893        self
894    }
895
896    /// Get the device handle
897    pub fn device(&self) -> &DeviceHandle<P> {
898        &self.device_handle
899    }
900
901    /// Get the preferred NUMA node
902    pub fn numa_node(&self) -> Option<usize> {
903        self.preferred_numa_node
904    }
905
906    /// Get the CPU affinity
907    pub fn cpu_affinity(&self) -> Option<&[usize]> {
908        self.cpu_affinity.as_deref()
909    }
910
911    /// Check if this device is local to a specific NUMA node (compile-time hint)
912    pub const fn is_cpu_device() -> bool {
913        matches!(P::DEVICE_TYPE, DeviceType::Cpu)
914    }
915
916    /// Get locality score (0-100, higher is better for NUMA)
917    pub fn locality_score(&self, target_numa: usize) -> u32 {
918        match self.preferred_numa_node {
919            Some(node) if node == target_numa => 100,
920            Some(_) => 30, // Remote NUMA node
921            None => 50,    // Unknown locality
922        }
923    }
924}
925
926/// Type-safe cross-device operation builder
927///
928/// This builder ensures that cross-device operations are only created
929/// between compatible devices with compile-time validation.
930#[derive(Debug)]
931pub struct CrossDeviceOp<PSrc: PhantomDevice, PDst: PhantomDevice> {
932    _phantom: PhantomData<(PSrc, PDst)>,
933}
934
935impl<PSrc: PhantomDevice, PDst: PhantomDevice> CrossDeviceOp<PSrc, PDst> {
936    /// Check if the operation is supported (compile-time)
937    pub const SUPPORTED: bool = TransferCompatible::<PSrc, PDst>::SUPPORTED;
938
939    /// Get the estimated transfer cost
940    pub fn transfer_cost() -> u32 {
941        TransferCompatible::<PSrc, PDst>::transfer_cost()
942    }
943
944    /// Get transfer strategy recommendation
945    pub fn transfer_strategy() -> &'static str {
946        match (PSrc::DEVICE_TYPE, PDst::DEVICE_TYPE) {
947            (DeviceType::Cpu, DeviceType::Cpu) => "memcpy",
948            (DeviceType::Cuda(_), DeviceType::Cuda(_)) => "peer-to-peer",
949            (DeviceType::Cpu, DeviceType::Cuda(_)) => "pinned-transfer",
950            (DeviceType::Cuda(_), DeviceType::Cpu) => "staged-readback",
951            (DeviceType::Metal(_), DeviceType::Metal(_)) => "unified-memory",
952            (DeviceType::Cpu, DeviceType::Metal(_)) => "shared-memory",
953            _ => "staged-transfer",
954        }
955    }
956
957    /// Check if zero-copy transfer is possible (compile-time hint)
958    pub const fn supports_zero_copy() -> bool {
959        matches!(
960            (PSrc::DEVICE_TYPE, PDst::DEVICE_TYPE),
961            (DeviceType::Metal(_), DeviceType::Metal(_))
962        )
963    }
964}
965
966/// Compile-time device validation utilities
967pub mod compile_time {
968    use super::*;
969
970    /// Assert that two phantom devices are the same type
971    pub fn assert_same_device<P1: PhantomDevice, P2: PhantomDevice>() {
972        if !SameDevice::<P1, P2>::is_satisfied() {
973            panic!("Device types must match");
974        }
975    }
976
977    /// Assert that a device is GPU
978    pub fn assert_gpu<P: PhantomDevice>() {
979        if !P::requires_gpu() {
980            panic!("Operation requires GPU device");
981        }
982    }
983
984    /// Assert that a device is CPU
985    pub fn assert_cpu<P: PhantomDevice>() {
986        if P::requires_gpu() {
987            panic!("Operation requires CPU device");
988        }
989    }
990
991    /// Assert that P2P is supported between devices
992    pub fn assert_p2p<P1, P2>()
993    where
994        P1: PhantomDevice + PeerToPeerOps<P2>,
995        P2: PhantomDevice,
996    {
997        if !P1::P2P_SUPPORTED {
998            panic!("P2P not supported between these device types");
999        }
1000    }
1001}
1002
#[cfg(test)]
mod advanced_tests {
    use super::*;
    use crate::device::implementations::CpuDevice;

    /// Verifies DeviceGroup construction from two CPU handles and its
    /// bounds-checked `get` accessor.
    #[test]
    fn test_device_group() {
        let cpu_device = Box::new(CpuDevice::new());
        let handle1 =
            DeviceHandle::<PhantomCpu>::new(cpu_device).expect("DeviceHandle::new should succeed");

        let cpu_device2 = Box::new(CpuDevice::new());
        let handle2 =
            DeviceHandle::<PhantomCpu>::new(cpu_device2).expect("DeviceHandle::new should succeed");

        let group = DeviceGroup::new([handle1, handle2]);
        // Count comes from the const generic parameter, not the array.
        assert_eq!(DeviceGroup::<PhantomCpu, 2>::device_count(), 2);
        assert!(group.get(0).is_some());
        assert!(group.get(1).is_some());
        // Out-of-range index must return None rather than panic.
        assert!(group.get(2).is_none());
    }

    /// Checks the compile-time P2P capability constants and that the
    /// bandwidth/latency estimates between two CUDA devices are non-zero.
    #[test]
    fn test_p2p_cuda() {
        // Compile-time checks
        assert!(PhantomCuda::<0>::supports_p2p());
        assert!(<PhantomCuda<0> as PeerToPeerOps<PhantomCuda<1>>>::P2P_SUPPORTED);

        let bandwidth = <PhantomCuda<0> as PeerToPeerOps<PhantomCuda<1>>>::p2p_bandwidth();
        assert!(bandwidth > 0);

        let latency = <PhantomCuda<0> as PeerToPeerOps<PhantomCuda<1>>>::p2p_latency();
        assert!(latency > 0);
    }

    /// Exercises the ring, tree, and all-to-all topology types: device
    /// counts, collective-op support flags, and relative bandwidths.
    #[test]
    fn test_device_topology() {
        // Ring topology
        assert_eq!(RingTopology::<PhantomCuda<0>, 4>::DEVICE_COUNT, 4);
        assert!(RingTopology::<PhantomCuda<0>, 4>::SUPPORTS_ALLREDUCE);
        assert!(RingTopology::<PhantomCuda<0>, 4>::SUPPORTS_BROADCAST);

        let bandwidth = RingTopology::<PhantomCuda<0>, 4>::allreduce_bandwidth();
        assert!(bandwidth > 0);

        // Tree topology
        assert_eq!(TreeTopology::<PhantomCuda<0>, 8>::DEVICE_COUNT, 8);
        let tree_bandwidth = TreeTopology::<PhantomCuda<0>, 8>::allreduce_bandwidth();
        assert!(tree_bandwidth > 0);

        // All-to-all topology
        assert_eq!(AllToAllTopology::<PhantomCuda<0>, 4>::DEVICE_COUNT, 4);
        let all2all_bandwidth = AllToAllTopology::<PhantomCuda<0>, 4>::allreduce_bandwidth();
        assert!(all2all_bandwidth >= tree_bandwidth); // All-to-all should be faster
    }

    /// Validates the TypedDeviceAffinity builder and the locality scores
    /// (100 = preferred node, 30 = remote node).
    #[test]
    fn test_typed_device_affinity() {
        let cpu_device = Box::new(CpuDevice::new());
        let handle =
            DeviceHandle::<PhantomCpu>::new(cpu_device).expect("DeviceHandle::new should succeed");

        let affinity = TypedDeviceAffinity::new(handle)
            .with_numa_node(0)
            .with_cpu_affinity(vec![0, 1, 2, 3]);

        assert_eq!(affinity.numa_node(), Some(0));
        assert_eq!(affinity.cpu_affinity(), Some(&[0, 1, 2, 3][..]));
        assert_eq!(affinity.locality_score(0), 100); // Perfect locality
        assert_eq!(affinity.locality_score(1), 30); // Remote NUMA
    }

    /// Pins the CrossDeviceOp cost/strategy table for representative device
    /// pairs, plus the Metal zero-copy hint.
    #[test]
    fn test_cross_device_op() {
        // CPU to CUDA transfer
        assert!(CrossDeviceOp::<PhantomCpu, PhantomCuda<0>>::SUPPORTED);
        assert_eq!(
            CrossDeviceOp::<PhantomCpu, PhantomCuda<0>>::transfer_cost(),
            100
        );
        assert_eq!(
            CrossDeviceOp::<PhantomCpu, PhantomCuda<0>>::transfer_strategy(),
            "pinned-transfer"
        );

        // Metal unified memory
        assert!(CrossDeviceOp::<PhantomMetal<0>, PhantomMetal<0>>::supports_zero_copy());

        // Same device
        assert_eq!(CrossDeviceOp::<PhantomCpu, PhantomCpu>::transfer_cost(), 0);
    }

    /// Smoke test for same-device strategy and per-group P2P support flags.
    #[test]
    fn test_compile_time_validation() {
        // These should compile without panicking
        assert_eq!(
            CrossDeviceOp::<PhantomCpu, PhantomCpu>::transfer_strategy(),
            "memcpy"
        );

        // Device group supports P2P for CUDA
        assert!(DeviceGroup::<PhantomCuda<0>, 4>::supports_p2p());
        assert!(!DeviceGroup::<PhantomCpu, 4>::supports_p2p());
    }
}