Skip to main content

torsh_backend/
device.rs

1//! Device abstraction and management
2
3use torsh_core::device::DeviceType as CoreDeviceType;
4
5#[cfg(not(feature = "std"))]
6use alloc::{string::String, vec::Vec};
7
/// Device identifier and properties
///
/// A concrete compute device exposed by a backend. It pairs a backend-local
/// `id` and a [`CoreDeviceType`] with a human-readable name and a
/// [`DeviceInfo`] capability record.
///
/// `PartialEq` is derived, so it also compares the `f32` fields inside `info`.
/// `Eq` and `Hash` are implemented manually elsewhere in this file.
#[derive(Debug, Clone, PartialEq)]
pub struct Device {
    /// Unique device ID within the backend
    pub id: usize,

    /// Device type (CPU, CUDA, Metal, etc.)
    pub device_type: CoreDeviceType,

    /// Human-readable device name
    pub name: String,

    /// Device information and capabilities
    pub info: DeviceInfo,
}
23
24impl Device {
25    /// Create a new device
26    pub fn new(id: usize, device_type: CoreDeviceType, name: String, info: DeviceInfo) -> Self {
27        Self {
28            id,
29            device_type,
30            name,
31            info,
32        }
33    }
34
35    /// Builder pattern for creating devices
36    pub fn builder() -> DeviceBuilder {
37        DeviceBuilder::new()
38    }
39
40    /// Get the device ID
41    pub const fn id(&self) -> usize {
42        self.id
43    }
44
45    /// Get the device type
46    pub const fn device_type(&self) -> CoreDeviceType {
47        self.device_type
48    }
49
50    /// Get the device name
51    pub fn name(&self) -> &str {
52        &self.name
53    }
54
55    /// Get device information
56    pub fn info(&self) -> &DeviceInfo {
57        &self.info
58    }
59
60    /// Check if this device supports the given feature
61    pub fn supports_feature(&self, feature: DeviceFeature) -> bool {
62        self.info.features.contains(&feature)
63    }
64
65    /// Create a default CPU device
66    pub fn cpu() -> crate::BackendResult<Self> {
67        DeviceBuilder::new()
68            .with_device_type(CoreDeviceType::Cpu)
69            .with_name("CPU".to_string())
70            .with_vendor("Generic".to_string())
71            .with_compute_units(num_cpus::get())
72            .build()
73    }
74}
75
/// Detailed device information
///
/// Capability and performance record attached to every [`Device`]. Values are
/// whatever the discovering backend reports. `DeviceInfo::default()` supplies
/// placeholder values for anything left unset.
#[derive(Debug, Clone, PartialEq)]
pub struct DeviceInfo {
    /// Vendor name (e.g., "NVIDIA", "AMD", "Apple")
    pub vendor: String,

    /// Driver version
    pub driver_version: String,

    /// Total memory in bytes
    pub total_memory: usize,

    /// Available memory in bytes
    pub available_memory: usize,

    /// Number of compute units (cores, SMs, etc.)
    pub compute_units: usize,

    /// Maximum work group size
    pub max_work_group_size: usize,

    /// Maximum work group size along each dimension. The default holds three
    /// entries (`[256, 1, 1]`).
    pub max_work_group_dimensions: Vec<usize>,

    /// Clock frequency in MHz
    pub clock_frequency_mhz: u32,

    /// Memory bandwidth in GB/s
    pub memory_bandwidth_gbps: f32,

    /// Peak compute performance in GFLOPS
    pub peak_gflops: f32,

    /// Supported features, as queried by `Device::supports_feature`
    pub features: Vec<DeviceFeature>,

    /// Additional vendor-specific properties as `(key, value)` pairs. Insertion
    /// order is kept and duplicate keys are not rejected.
    pub properties: Vec<(String, String)>,
}
115
116impl Default for DeviceInfo {
117    fn default() -> Self {
118        Self {
119            vendor: "Unknown".to_string(),
120            driver_version: "Unknown".to_string(),
121            total_memory: 0,
122            available_memory: 0,
123            compute_units: 1,
124            max_work_group_size: 256,
125            max_work_group_dimensions: vec![256, 1, 1],
126            clock_frequency_mhz: 1000,
127            memory_bandwidth_gbps: 10.0,
128            peak_gflops: 100.0,
129            features: Vec::new(),
130            properties: Vec::new(),
131        }
132    }
133}
134
/// Device features and capabilities
///
/// Optional capabilities a device may advertise in `DeviceInfo::features`.
/// Equality is structural (derived), so `Custom` features match only when
/// their strings are identical.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeviceFeature {
    /// Supports double precision floating point
    DoublePrecision,

    /// Supports half precision floating point
    HalfPrecision,

    /// Supports unified memory between host and device
    UnifiedMemory,

    /// Supports atomic operations
    AtomicOperations,

    /// Supports sub-groups/warps
    SubGroups,

    /// Supports printf in kernels
    Printf,

    /// Supports profiling and debugging
    Profiling,

    /// Supports peer-to-peer memory access
    PeerToPeer,

    /// Supports concurrent kernel execution
    ConcurrentExecution,

    /// Supports asynchronous memory operations
    AsyncMemory,

    /// Supports texture/image operations
    ImageSupport,

    /// Supports fast math optimizations
    FastMath,

    // WebGPU-specific features (the remaining named variants below)
    /// Supports timestamp queries for performance measurement
    TimestampQuery,

    /// Supports timestamp queries inside encoders
    TimestampQueryInsideEncoders,

    /// Supports pipeline statistics queries
    PipelineStatistics,

    /// Supports mappable primary buffers
    MappableBuffers,

    /// Supports buffer binding arrays
    BufferArrays,

    /// Supports storage resource binding arrays
    StorageArrays,

    /// Supports unsized binding arrays
    UnsizedBindingArray,

    /// Supports indirect first instance parameter
    IndirectFirstInstance,

    /// Supports 16-bit floating point in shaders
    ShaderF16,

    /// Supports 16-bit integers in shaders
    ShaderI16,

    /// Supports shader primitive index
    ShaderPrimitiveIndex,

    /// Supports early depth test in shaders
    ShaderEarlyDepthTest,

    /// Supports multi-draw indirect
    MultiDrawIndirect,

    /// Supports multi-draw indirect with count
    MultiDrawIndirectCount,

    /// Supports multisampled shading
    Multisampling,

    /// Supports texture clear operations
    ClearTexture,

    /// Supports SPIR-V shader passthrough
    SpirvShaderPassthrough,

    /// Custom vendor-specific feature, identified by name (exact string match)
    Custom(String),
}
229
/// Device builder for constructing devices with validation
///
/// Obtain one via [`Device::builder`] or [`DeviceBuilder::new`]. The only
/// validation `build` performs is that both a device type and a name have
/// been set. Capability fields start from `DeviceInfo::default()`.
#[derive(Debug, Clone)]
pub struct DeviceBuilder {
    // Backend-local device ID; starts at 0.
    id: usize,
    // Required: `build` returns an error while this is `None`.
    device_type: Option<CoreDeviceType>,
    // Required: `build` returns an error while this is `None`.
    name: Option<String>,
    // Capability record, edited in place by the `with_*` setters.
    info: DeviceInfo,
}
238
239impl DeviceBuilder {
240    pub fn new() -> Self {
241        Self {
242            id: 0,
243            device_type: None,
244            name: None,
245            info: DeviceInfo::default(),
246        }
247    }
248
249    pub fn with_id(mut self, id: usize) -> Self {
250        self.id = id;
251        self
252    }
253
254    pub fn with_device_type(mut self, device_type: CoreDeviceType) -> Self {
255        self.device_type = Some(device_type);
256        self
257    }
258
259    pub fn with_name(mut self, name: String) -> Self {
260        self.name = Some(name);
261        self
262    }
263
264    pub fn with_vendor(mut self, vendor: String) -> Self {
265        self.info.vendor = vendor;
266        self
267    }
268
269    pub fn with_driver_version(mut self, version: String) -> Self {
270        self.info.driver_version = version;
271        self
272    }
273
274    pub fn with_memory(mut self, total: usize, available: usize) -> Self {
275        self.info.total_memory = total;
276        self.info.available_memory = available;
277        self
278    }
279
280    pub fn with_compute_units(mut self, units: usize) -> Self {
281        self.info.compute_units = units;
282        self
283    }
284
285    pub fn with_performance(mut self, gflops: f32, bandwidth_gbps: f32) -> Self {
286        self.info.peak_gflops = gflops;
287        self.info.memory_bandwidth_gbps = bandwidth_gbps;
288        self
289    }
290
291    pub fn with_feature(mut self, feature: DeviceFeature) -> Self {
292        self.info.features.push(feature);
293        self
294    }
295
296    pub fn with_property(mut self, key: String, value: String) -> Self {
297        self.info.properties.push((key, value));
298        self
299    }
300
301    pub fn build(self) -> crate::BackendResult<Device> {
302        let device_type = self.device_type.ok_or_else(|| {
303            torsh_core::error::TorshError::BackendError("Device type is required".to_string())
304        })?;
305
306        let name = self.name.ok_or_else(|| {
307            torsh_core::error::TorshError::BackendError("Device name is required".to_string())
308        })?;
309
310        Ok(Device {
311            id: self.id,
312            device_type,
313            name,
314            info: self.info,
315        })
316    }
317}
318
319impl Default for DeviceBuilder {
320    fn default() -> Self {
321        Self::new()
322    }
323}
324
/// Device type enumeration
///
/// Backend family of a device, without a device ordinal. Converting from a
/// `CoreDeviceType` only ever produces `Cpu`, `Cuda`, `Metal` or `WebGpu`.
/// Converting `OpenCl`, `Vulkan` or `Custom` into a `CoreDeviceType` falls
/// back to `Cpu`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// CPU device
    Cpu,

    /// NVIDIA CUDA GPU
    Cuda,

    /// Apple Metal GPU
    Metal,

    /// WebGPU device
    WebGpu,

    /// OpenCL device
    OpenCl,

    /// Vulkan Compute device
    Vulkan,

    /// Custom device type
    Custom,
}
349
350impl From<CoreDeviceType> for DeviceType {
351    fn from(core_type: CoreDeviceType) -> Self {
352        match core_type {
353            CoreDeviceType::Cpu => DeviceType::Cpu,
354            CoreDeviceType::Cuda(_) => DeviceType::Cuda,
355            CoreDeviceType::Metal(_) => DeviceType::Metal,
356            CoreDeviceType::Wgpu(_) => DeviceType::WebGpu,
357        }
358    }
359}
360
361impl From<DeviceType> for CoreDeviceType {
362    fn from(device_type: DeviceType) -> Self {
363        match device_type {
364            DeviceType::Cpu => CoreDeviceType::Cpu,
365            DeviceType::Cuda => CoreDeviceType::Cuda(0), // Default to device 0
366            DeviceType::Metal => CoreDeviceType::Metal(0), // Default to device 0
367            DeviceType::WebGpu => CoreDeviceType::Wgpu(0), // Default to device 0
368            DeviceType::OpenCl => CoreDeviceType::Cpu,   // Fallback
369            DeviceType::Vulkan => CoreDeviceType::Cpu,   // Fallback
370            DeviceType::Custom => CoreDeviceType::Cpu,   // Fallback
371        }
372    }
373}
374
/// Device selection criteria
///
/// A field left as `None` (or empty) imposes no constraint. Every field that is
/// set acts as a hard filter in `DeviceSelector::matches`. That includes
/// `device_type` and `preferred_vendor`, despite the "preferred" wording.
///
/// `Debug` is not derived, presumably because `custom_filter` holds a boxed
/// closure (`dyn Fn` does not implement `Debug`).
#[derive(Default)]
pub struct DeviceSelector {
    /// Preferred device type (a hard filter in `matches`)
    pub device_type: Option<DeviceType>,

    /// Minimum memory requirement in bytes, compared against `DeviceInfo::total_memory`
    pub min_memory: Option<usize>,

    /// Minimum compute units
    pub min_compute_units: Option<usize>,

    /// Required features; every one must be reported by the device
    pub required_features: Vec<DeviceFeature>,

    /// Preferred vendor; `matches` requires exact, case-sensitive equality
    pub preferred_vendor: Option<String>,

    /// Custom selection function; the device is rejected when it returns `false`
    #[allow(clippy::type_complexity)]
    pub custom_filter: Option<Box<dyn Fn(&Device) -> bool + Send + Sync>>,
}
397
398impl DeviceSelector {
399    /// Create a new device selector
400    pub fn new() -> Self {
401        Self::default()
402    }
403
404    /// Set preferred device type
405    pub fn with_device_type(mut self, device_type: DeviceType) -> Self {
406        self.device_type = Some(device_type);
407        self
408    }
409
410    /// Set minimum memory requirement
411    pub fn with_min_memory(mut self, min_memory: usize) -> Self {
412        self.min_memory = Some(min_memory);
413        self
414    }
415
416    /// Set minimum compute units
417    pub fn with_min_compute_units(mut self, min_compute_units: usize) -> Self {
418        self.min_compute_units = Some(min_compute_units);
419        self
420    }
421
422    /// Add required feature
423    pub fn with_feature(mut self, feature: DeviceFeature) -> Self {
424        self.required_features.push(feature);
425        self
426    }
427
428    /// Set preferred vendor
429    pub fn with_vendor(mut self, vendor: String) -> Self {
430        self.preferred_vendor = Some(vendor);
431        self
432    }
433
434    /// Check if a device matches this selector
435    pub fn matches(&self, device: &Device) -> bool {
436        // Check device type
437        if let Some(required_type) = &self.device_type {
438            if device.device_type != (*required_type).into() {
439                return false;
440            }
441        }
442
443        // Check memory
444        if let Some(min_memory) = self.min_memory {
445            if device.info.total_memory < min_memory {
446                return false;
447            }
448        }
449
450        // Check compute units
451        if let Some(min_compute_units) = self.min_compute_units {
452            if device.info.compute_units < min_compute_units {
453                return false;
454            }
455        }
456
457        // Check required features
458        for feature in &self.required_features {
459            if !device.supports_feature(feature.clone()) {
460                return false;
461            }
462        }
463
464        // Check vendor
465        if let Some(ref preferred_vendor) = self.preferred_vendor {
466            if device.info.vendor != *preferred_vendor {
467                return false;
468            }
469        }
470
471        // Apply custom filter
472        if let Some(ref filter) = self.custom_filter {
473            if !filter(device) {
474                return false;
475            }
476        }
477
478        true
479    }
480}
481
/// Unified device management interface for all backends
///
/// Every method takes `&self` and returns a `crate::BackendResult`.
/// Implementors must be `Send + Sync`. Device IDs are backend-local.
/// Presumably they correspond to `Device::id` of the entries returned by
/// `enumerate_devices`; confirm per backend.
pub trait DeviceManager: Send + Sync {
    /// Enumerate all available devices for this backend type
    fn enumerate_devices(&self) -> crate::BackendResult<Vec<Device>>;

    /// Get detailed device information by ID
    /// (behavior for unknown IDs is implementation-defined)
    fn get_device_info(&self, device_id: usize) -> crate::BackendResult<DeviceInfo>;

    /// Check if a device supports specific features.
    ///
    /// NOTE(review): presumably returns one flag per entry of `features`, in the
    /// same order. Confirm against implementations.
    fn check_device_features(
        &self,
        device_id: usize,
        features: &[DeviceFeature],
    ) -> crate::BackendResult<Vec<bool>>;

    /// Get optimal device configuration for the backend
    fn get_optimal_device_config(
        &self,
        device_id: usize,
    ) -> crate::BackendResult<DeviceConfiguration>;

    /// Validate device availability and readiness.
    ///
    /// NOTE(review): the distinction between `Ok(false)` and `Err(_)` is
    /// implementation-defined.
    fn validate_device(&self, device_id: usize) -> crate::BackendResult<bool>;

    /// Get device performance characteristics
    fn get_performance_info(&self, device_id: usize)
        -> crate::BackendResult<DevicePerformanceInfo>;
}
510
/// Device configuration for optimal performance
///
/// Returned by `DeviceManager::get_optimal_device_config`.
#[derive(Debug, Clone)]
pub struct DeviceConfiguration {
    /// Optimal memory allocation size (presumably in bytes; confirm per backend)
    pub optimal_allocation_size: usize,

    /// Recommended workgroup/thread block size as a three-component tuple,
    /// the same shape `DeviceUtils::get_optimal_workgroup_size` returns
    pub workgroup_size: (u32, u32, u32),

    /// Memory alignment requirements (presumably in bytes; confirm per backend)
    pub memory_alignment: usize,

    /// Concurrent operation limits
    pub max_concurrent_operations: u32,

    /// Backend-specific configuration, keyed by setting name
    pub backend_specific: std::collections::HashMap<String, crate::backend::CapabilityValue>,
}
529
/// Device performance characteristics
///
/// Returned by `DeviceManager::get_performance_info`. Optional sections are
/// `None` when the backend cannot report them.
#[derive(Debug, Clone)]
pub struct DevicePerformanceInfo {
    /// Memory bandwidth in GB/s
    pub memory_bandwidth_gbps: f32,

    /// Compute throughput in GFLOPS
    pub compute_throughput_gflops: f32,

    /// Memory latency in nanoseconds
    pub memory_latency_ns: f32,

    /// Cache hierarchy information, one entry per level
    /// (NOTE(review): ordering, e.g. L1 first, is not enforced here)
    pub cache_hierarchy: Vec<CacheLevel>,

    /// Thermal information (if available)
    pub thermal_info: Option<ThermalInfo>,

    /// Power consumption information (if available)
    pub power_info: Option<PowerInfo>,
}
551
/// Cache level information
#[derive(Debug, Clone)]
pub struct CacheLevel {
    /// Cache level number (e.g. 1 for L1). The numbering convention is not
    /// enforced by this type.
    pub level: u8,
    /// Total capacity of this cache level, in bytes.
    pub size_bytes: usize,
    /// Cache line size, in bytes.
    pub line_size_bytes: usize,
    /// Set associativity (number of ways), if reported. NOTE(review): whether
    /// `None` means "unknown" or "fully associative" is not specified here.
    pub associativity: Option<usize>,
}
560
/// Thermal monitoring information
#[derive(Debug, Clone)]
pub struct ThermalInfo {
    /// Current device temperature, in degrees Celsius.
    pub current_temperature_celsius: f32,
    /// Maximum rated temperature, in degrees Celsius.
    pub max_temperature_celsius: f32,
    /// Whether the device is currently throttling for thermal reasons.
    pub thermal_throttling_active: bool,
}
568
/// Power consumption information
#[derive(Debug, Clone)]
pub struct PowerInfo {
    /// Current power draw, in watts.
    pub current_power_watts: f32,
    /// Maximum power draw, in watts.
    pub max_power_watts: f32,
    /// Configured power limit, in watts.
    pub power_limit_watts: f32,
}
576
/// Common device management utilities that can be shared across backends
///
/// A stateless namespace: every function is an associated function.
pub struct DeviceUtils;
579
580impl DeviceUtils {
581    /// Validate device configuration parameters
582    pub const fn validate_device_id(device_id: usize, max_devices: usize) -> bool {
583        device_id < max_devices
584    }
585
586    /// Calculate device score for selection algorithms
587    pub fn calculate_device_score(device: &Device, requirements: &DeviceRequirements) -> f32 {
588        let mut score = 0.0;
589
590        // Memory requirement scoring
591        if let Some(min_memory) = requirements.min_memory {
592            if device.info.total_memory >= min_memory {
593                score += 20.0;
594                // Bonus for having more memory than required
595                score += (device.info.total_memory as f32 / min_memory as f32 - 1.0) * 5.0;
596            } else {
597                return 0.0; // Disqualify if insufficient memory
598            }
599        }
600
601        // Compute units requirement scoring
602        if let Some(min_compute_units) = requirements.min_compute_units {
603            if device.info.compute_units >= min_compute_units {
604                score += 15.0;
605                score += (device.info.compute_units as f32 / min_compute_units as f32 - 1.0) * 3.0;
606            } else {
607                return 0.0;
608            }
609        }
610
611        // Features requirement scoring
612        for required_feature in &requirements.required_features {
613            if device.supports_feature(required_feature.clone()) {
614                score += 10.0;
615            } else {
616                return 0.0; // Disqualify if missing required feature
617            }
618        }
619
620        // Performance scores
621        score += device.info.peak_gflops / 1000.0; // Bonus for compute performance
622        score += device.info.memory_bandwidth_gbps / 100.0; // Bonus for memory bandwidth
623
624        // Backend preference
625        match DeviceType::from(device.device_type) {
626            DeviceType::Cuda => score += 15.0,  // Prefer CUDA
627            DeviceType::Metal => score += 10.0, // Then Metal
628            DeviceType::WebGpu => score += 5.0, // Then WebGPU
629            DeviceType::Cpu => score += 1.0,    // CPU as fallback
630            _ => score += 0.0,
631        }
632
633        score
634    }
635
636    /// Check if device meets minimum requirements
637    pub fn meets_requirements(device: &Device, requirements: &DeviceRequirements) -> bool {
638        // Check memory requirement
639        if let Some(min_memory) = requirements.min_memory {
640            if device.info.total_memory < min_memory {
641                return false;
642            }
643        }
644
645        // Check compute units requirement
646        if let Some(min_compute_units) = requirements.min_compute_units {
647            if device.info.compute_units < min_compute_units {
648                return false;
649            }
650        }
651
652        // Check required features
653        for required_feature in &requirements.required_features {
654            if !device.supports_feature(required_feature.clone()) {
655                return false;
656            }
657        }
658
659        // Check backend preference
660        if let Some(preferred_backend) = requirements.preferred_backend {
661            let device_backend = match DeviceType::from(device.device_type) {
662                DeviceType::Cpu => crate::backend::BackendType::Cpu,
663                DeviceType::Cuda => crate::backend::BackendType::Cuda,
664                DeviceType::Metal => crate::backend::BackendType::Metal,
665                DeviceType::WebGpu => crate::backend::BackendType::WebGpu,
666                _ => return false,
667            };
668            if device_backend != preferred_backend {
669                return false;
670            }
671        }
672
673        true
674    }
675
676    /// Get optimal workgroup/thread block size for device
677    pub fn get_optimal_workgroup_size(device: &Device, operation_type: &str) -> (u32, u32, u32) {
678        match DeviceType::from(device.device_type) {
679            DeviceType::Cuda => {
680                // CUDA optimal sizes
681                match operation_type {
682                    "matrix_mul" => (16, 16, 1),
683                    "element_wise" => (256, 1, 1),
684                    "reduction" => (512, 1, 1),
685                    _ => (32, 32, 1),
686                }
687            }
688            DeviceType::Metal => {
689                // Metal optimal sizes
690                match operation_type {
691                    "matrix_mul" => (16, 16, 1),
692                    "element_wise" => (256, 1, 1),
693                    "reduction" => (256, 1, 1),
694                    _ => (32, 32, 1),
695                }
696            }
697            DeviceType::WebGpu => {
698                // WebGPU optimal sizes
699                match operation_type {
700                    "matrix_mul" => (8, 8, 1),
701                    "element_wise" => (64, 1, 1),
702                    "reduction" => (64, 1, 1),
703                    _ => (8, 8, 1),
704                }
705            }
706            _ => {
707                // Default fallback
708                (1, 1, 1)
709            }
710        }
711    }
712}
713
/// Common device discovery utilities
///
/// A stateless namespace: every function is an associated function.
pub struct DeviceDiscovery;
716
717impl DeviceDiscovery {
718    /// Discover all available devices across all backends
719    pub fn discover_all() -> crate::BackendResult<Vec<(crate::backend::BackendType, Vec<Device>)>> {
720        let mut all_devices = Vec::new();
721
722        // CPU devices (always available)
723        if let Ok(cpu_devices) = Self::discover_cpu_devices() {
724            all_devices.push((crate::backend::BackendType::Cpu, cpu_devices));
725        }
726
727        // CUDA devices
728        #[cfg(feature = "cuda")]
729        if let Ok(cuda_devices) = Self::discover_cuda_devices() {
730            if !cuda_devices.is_empty() {
731                all_devices.push((crate::backend::BackendType::Cuda, cuda_devices));
732            }
733        }
734
735        // Metal devices
736        #[cfg(all(feature = "metal", target_os = "macos"))]
737        if let Ok(metal_devices) = Self::discover_metal_devices() {
738            if !metal_devices.is_empty() {
739                all_devices.push((crate::backend::BackendType::Metal, metal_devices));
740            }
741        }
742
743        // WebGPU devices
744        #[cfg(feature = "webgpu")]
745        if let Ok(webgpu_devices) = Self::discover_webgpu_devices() {
746            if !webgpu_devices.is_empty() {
747                all_devices.push((crate::backend::BackendType::WebGpu, webgpu_devices));
748            }
749        }
750
751        Ok(all_devices)
752    }
753
754    /// Find the best device based on requirements
755    pub fn find_best_device(
756        requirements: &DeviceRequirements,
757    ) -> crate::BackendResult<(crate::backend::BackendType, Device)> {
758        let all_devices = Self::discover_all()?;
759
760        let mut best_device = None;
761        let mut best_score = 0.0;
762
763        for (backend_type, devices) in all_devices {
764            for device in devices {
765                let score = Self::score_device(&device, requirements);
766                if score > best_score {
767                    best_score = score;
768                    best_device = Some((backend_type, device));
769                }
770            }
771        }
772
773        best_device.ok_or_else(|| {
774            torsh_core::error::TorshError::BackendError(
775                "No suitable device found for requirements".to_string(),
776            )
777        })
778    }
779
780    /// Score a device based on requirements
781    fn score_device(device: &Device, requirements: &DeviceRequirements) -> f32 {
782        DeviceUtils::calculate_device_score(device, requirements)
783    }
784
785    /// Discover CPU devices
786    fn discover_cpu_devices() -> crate::BackendResult<Vec<Device>> {
787        let cpu_device = crate::cpu::CpuDevice::new(0, num_cpus::get())?;
788        Ok(vec![cpu_device.to_device()])
789    }
790
791    /// Discover CUDA devices
792    #[cfg(feature = "cuda")]
793    fn discover_cuda_devices() -> crate::BackendResult<Vec<Device>> {
794        // Implementation would query CUDA runtime for available devices
795        // For now, return empty vector
796        Ok(vec![])
797    }
798
799    /// Discover Metal devices
800    #[cfg(all(feature = "metal", target_os = "macos"))]
801    fn discover_metal_devices() -> crate::BackendResult<Vec<Device>> {
802        // Implementation would query Metal framework for available devices
803        // For now, return empty vector
804        Ok(vec![])
805    }
806
807    /// Discover WebGPU devices
808    #[cfg(feature = "webgpu")]
809    fn discover_webgpu_devices() -> crate::BackendResult<Vec<Device>> {
810        // Implementation would query WebGPU for available adapters
811        // For now, return empty vector
812        Ok(vec![])
813    }
814}
815
/// Device requirements for selection
///
/// Consumed by `DeviceUtils::calculate_device_score`,
/// `DeviceUtils::meets_requirements` and `DeviceDiscovery::find_best_device`.
/// Fields left as `None` (or empty) impose no constraint.
#[derive(Debug, Clone, Default)]
pub struct DeviceRequirements {
    /// Minimum total device memory, in bytes.
    pub min_memory: Option<usize>,
    /// Minimum number of compute units.
    pub min_compute_units: Option<usize>,
    /// Features the device must report; a missing one disqualifies it.
    pub required_features: Vec<DeviceFeature>,
    /// Backend the device must belong to in `DeviceUtils::meets_requirements`.
    /// NOTE(review): not consulted by `DeviceUtils::calculate_device_score`, and
    /// therefore not by `DeviceDiscovery::find_best_device`.
    pub preferred_backend: Option<crate::backend::BackendType>,
    /// NOTE(review): not read by the scoring or filtering code in this file.
    pub max_power_consumption: Option<f32>,
    /// NOTE(review): not read by the scoring or filtering code in this file.
    pub max_temperature: Option<f32>,
}
826
827impl DeviceRequirements {
828    pub fn new() -> Self {
829        Self::default()
830    }
831
832    pub fn with_min_memory(mut self, memory: usize) -> Self {
833        self.min_memory = Some(memory);
834        self
835    }
836
837    pub fn with_min_compute_units(mut self, units: usize) -> Self {
838        self.min_compute_units = Some(units);
839        self
840    }
841
842    pub fn with_feature(mut self, feature: DeviceFeature) -> Self {
843        self.required_features.push(feature);
844        self
845    }
846
847    pub fn with_preferred_backend(mut self, backend: crate::backend::BackendType) -> Self {
848        self.preferred_backend = Some(backend);
849        self
850    }
851}
852
// Manual implementations for Device to work around f32 fields in DeviceInfo
//
// NOTE(review): `Device`'s derived `PartialEq` also compares `DeviceInfo`'s
// `f32` fields (`memory_bandwidth_gbps`, `peak_gflops`). A device holding a NaN
// there is therefore not equal to itself, which breaks the reflexivity `Eq`
// promises. This impl assumes those values are never NaN.
impl Eq for Device {}
855
856impl std::hash::Hash for Device {
857    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
858        self.id.hash(state);
859        self.device_type.hash(state);
860        self.name.hash(state);
861        // Skip DeviceInfo fields that contain f32 since they don't implement Hash
862        self.info.vendor.hash(state);
863        self.info.driver_version.hash(state);
864        self.info.total_memory.hash(state);
865        self.info.available_memory.hash(state);
866        self.info.compute_units.hash(state);
867        self.info.max_work_group_size.hash(state);
868        self.info.max_work_group_dimensions.hash(state);
869        self.info.clock_frequency_mhz.hash(state);
870        // Skip memory_bandwidth_gbps and peak_gflops (f32 fields)
871        self.info.features.hash(state);
872        self.info.properties.hash(state);
873    }
874}
875
876#[cfg(test)]
877mod tests {
878    use super::*;
879
880    fn create_test_device_info() -> DeviceInfo {
881        DeviceInfo {
882            vendor: "Test Vendor".to_string(),
883            driver_version: "1.0.0".to_string(),
884            total_memory: 8 * 1024 * 1024 * 1024,     // 8GB
885            available_memory: 6 * 1024 * 1024 * 1024, // 6GB
886            compute_units: 32,
887            max_work_group_size: 1024,
888            max_work_group_dimensions: vec![1024, 1024, 64],
889            clock_frequency_mhz: 1500,
890            memory_bandwidth_gbps: 500.0,
891            peak_gflops: 10000.0,
892            features: vec![
893                DeviceFeature::DoublePrecision,
894                DeviceFeature::UnifiedMemory,
895                DeviceFeature::AtomicOperations,
896            ],
897            properties: vec![
898                ("compute_capability".to_string(), "7.5".to_string()),
899                ("warp_size".to_string(), "32".to_string()),
900            ],
901        }
902    }
903
904    #[test]
905    fn test_device_creation() {
906        let info = create_test_device_info();
907        let device = Device::new(
908            0,
909            CoreDeviceType::Cuda(0),
910            "Test GPU".to_string(),
911            info.clone(),
912        );
913
914        assert_eq!(device.id(), 0);
915        assert_eq!(device.name(), "Test GPU");
916        assert_eq!(device.device_type(), CoreDeviceType::Cuda(0));
917        assert_eq!(device.info().vendor, "Test Vendor");
918        assert_eq!(device.info().compute_units, 32);
919    }
920
921    #[test]
922    fn test_device_feature_support() {
923        let info = create_test_device_info();
924        let device = Device::new(1, CoreDeviceType::Cpu, "Test CPU".to_string(), info);
925
926        assert!(device.supports_feature(DeviceFeature::DoublePrecision));
927        assert!(device.supports_feature(DeviceFeature::UnifiedMemory));
928        assert!(device.supports_feature(DeviceFeature::AtomicOperations));
929        assert!(!device.supports_feature(DeviceFeature::HalfPrecision));
930        assert!(!device.supports_feature(DeviceFeature::SubGroups));
931    }
932
933    #[test]
934    fn test_device_info_default() {
935        let info = DeviceInfo::default();
936
937        assert_eq!(info.vendor, "Unknown");
938        assert_eq!(info.driver_version, "Unknown");
939        assert_eq!(info.total_memory, 0);
940        assert_eq!(info.available_memory, 0);
941        assert_eq!(info.compute_units, 1);
942        assert_eq!(info.max_work_group_size, 256);
943        assert_eq!(info.max_work_group_dimensions, vec![256, 1, 1]);
944        assert_eq!(info.clock_frequency_mhz, 1000);
945        assert_eq!(info.memory_bandwidth_gbps, 10.0);
946        assert_eq!(info.peak_gflops, 100.0);
947        assert!(info.features.is_empty());
948        assert!(info.properties.is_empty());
949    }
950
951    #[test]
952    fn test_device_type_conversion() {
953        assert_eq!(DeviceType::from(CoreDeviceType::Cpu), DeviceType::Cpu);
954        assert_eq!(DeviceType::from(CoreDeviceType::Cuda(0)), DeviceType::Cuda);
955        assert_eq!(
956            DeviceType::from(CoreDeviceType::Metal(0)),
957            DeviceType::Metal
958        );
959        assert_eq!(
960            DeviceType::from(CoreDeviceType::Wgpu(0)),
961            DeviceType::WebGpu
962        );
963
964        assert_eq!(CoreDeviceType::from(DeviceType::Cpu), CoreDeviceType::Cpu);
965        assert_eq!(
966            CoreDeviceType::from(DeviceType::Cuda),
967            CoreDeviceType::Cuda(0)
968        );
969        assert_eq!(
970            CoreDeviceType::from(DeviceType::Metal),
971            CoreDeviceType::Metal(0)
972        );
973        assert_eq!(
974            CoreDeviceType::from(DeviceType::WebGpu),
975            CoreDeviceType::Wgpu(0)
976        );
977
978        // Fallback conversions
979        assert_eq!(
980            CoreDeviceType::from(DeviceType::OpenCl),
981            CoreDeviceType::Cpu
982        );
983        assert_eq!(
984            CoreDeviceType::from(DeviceType::Vulkan),
985            CoreDeviceType::Cpu
986        );
987        assert_eq!(
988            CoreDeviceType::from(DeviceType::Custom),
989            CoreDeviceType::Cpu
990        );
991    }
992
993    #[test]
994    fn test_device_feature_variants() {
995        let features = [
996            DeviceFeature::DoublePrecision,
997            DeviceFeature::HalfPrecision,
998            DeviceFeature::UnifiedMemory,
999            DeviceFeature::AtomicOperations,
1000            DeviceFeature::SubGroups,
1001            DeviceFeature::Printf,
1002            DeviceFeature::Profiling,
1003            DeviceFeature::PeerToPeer,
1004            DeviceFeature::ConcurrentExecution,
1005            DeviceFeature::AsyncMemory,
1006            DeviceFeature::ImageSupport,
1007            DeviceFeature::FastMath,
1008            DeviceFeature::Custom("CustomFeature".to_string()),
1009        ];
1010
1011        // Ensure all features are distinct
1012        for (i, feature1) in features.iter().enumerate() {
1013            for (j, feature2) in features.iter().enumerate() {
1014                if i != j {
1015                    assert_ne!(feature1, feature2);
1016                }
1017            }
1018        }
1019    }
1020
1021    #[test]
1022    fn test_device_selector_creation() {
1023        let selector = DeviceSelector::new();
1024
1025        assert_eq!(selector.device_type, None);
1026        assert_eq!(selector.min_memory, None);
1027        assert_eq!(selector.min_compute_units, None);
1028        assert!(selector.required_features.is_empty());
1029        assert_eq!(selector.preferred_vendor, None);
1030        assert!(selector.custom_filter.is_none());
1031    }
1032
1033    #[test]
1034    fn test_device_selector_builder() {
1035        let selector = DeviceSelector::new()
1036            .with_device_type(DeviceType::Cuda)
1037            .with_min_memory(4 * 1024 * 1024 * 1024) // 4GB
1038            .with_min_compute_units(16)
1039            .with_feature(DeviceFeature::DoublePrecision)
1040            .with_feature(DeviceFeature::AtomicOperations)
1041            .with_vendor("NVIDIA".to_string());
1042
1043        assert_eq!(selector.device_type, Some(DeviceType::Cuda));
1044        assert_eq!(selector.min_memory, Some(4 * 1024 * 1024 * 1024));
1045        assert_eq!(selector.min_compute_units, Some(16));
1046        assert_eq!(selector.required_features.len(), 2);
1047        assert!(selector
1048            .required_features
1049            .contains(&DeviceFeature::DoublePrecision));
1050        assert!(selector
1051            .required_features
1052            .contains(&DeviceFeature::AtomicOperations));
1053        assert_eq!(selector.preferred_vendor, Some("NVIDIA".to_string()));
1054    }
1055
1056    #[test]
1057    fn test_device_selector_matching() {
1058        let mut info = create_test_device_info();
1059        info.vendor = "NVIDIA".to_string();
1060        info.total_memory = 8 * 1024 * 1024 * 1024; // 8GB
1061        info.compute_units = 32;
1062
1063        let device = Device::new(0, CoreDeviceType::Cuda(0), "RTX 4090".to_string(), info);
1064
1065        // Should match
1066        let selector1 = DeviceSelector::new()
1067            .with_device_type(DeviceType::Cuda)
1068            .with_min_memory(4 * 1024 * 1024 * 1024) // 4GB
1069            .with_min_compute_units(16)
1070            .with_feature(DeviceFeature::DoublePrecision)
1071            .with_vendor("NVIDIA".to_string());
1072
1073        assert!(selector1.matches(&device));
1074
1075        // Should not match - insufficient memory
1076        let selector2 = DeviceSelector::new().with_min_memory(16 * 1024 * 1024 * 1024); // 16GB
1077
1078        assert!(!selector2.matches(&device));
1079
1080        // Should not match - missing feature
1081        let selector3 = DeviceSelector::new().with_feature(DeviceFeature::HalfPrecision);
1082
1083        assert!(!selector3.matches(&device));
1084
1085        // Should not match - wrong vendor
1086        let selector4 = DeviceSelector::new().with_vendor("AMD".to_string());
1087
1088        assert!(!selector4.matches(&device));
1089    }
1090
1091    #[test]
1092    fn test_custom_device_feature() {
1093        let custom_feature1 = DeviceFeature::Custom("TensorCores".to_string());
1094        let custom_feature2 = DeviceFeature::Custom("TensorCores".to_string());
1095        let custom_feature3 = DeviceFeature::Custom("RTCores".to_string());
1096
1097        assert_eq!(custom_feature1, custom_feature2);
1098        assert_ne!(custom_feature1, custom_feature3);
1099
1100        let mut info = DeviceInfo::default();
1101        info.features.push(custom_feature1.clone());
1102
1103        let device = Device::new(0, CoreDeviceType::Cuda(0), "Custom GPU".to_string(), info);
1104        assert!(device.supports_feature(custom_feature1));
1105        assert!(!device.supports_feature(custom_feature3));
1106    }
1107}