1use torsh_core::device::DeviceType as CoreDeviceType;
4
5#[cfg(not(feature = "std"))]
6use alloc::{string::String, vec::Vec};
7
/// A compute device together with its descriptive [`DeviceInfo`].
///
/// Derived `PartialEq` compares every field, including the floating-point
/// performance figures inside `info`.
#[derive(Debug, Clone, PartialEq)]
pub struct Device {
    /// Numeric identifier of the device (`DeviceBuilder` defaults it to 0).
    pub id: usize,

    /// Device type as defined by `torsh_core`.
    pub device_type: CoreDeviceType,

    /// Human-readable device name.
    pub name: String,

    /// Capabilities, limits and free-form properties of the device.
    pub info: DeviceInfo,
}
23
24impl Device {
25 pub fn new(id: usize, device_type: CoreDeviceType, name: String, info: DeviceInfo) -> Self {
27 Self {
28 id,
29 device_type,
30 name,
31 info,
32 }
33 }
34
35 pub fn builder() -> DeviceBuilder {
37 DeviceBuilder::new()
38 }
39
40 pub const fn id(&self) -> usize {
42 self.id
43 }
44
45 pub const fn device_type(&self) -> CoreDeviceType {
47 self.device_type
48 }
49
50 pub fn name(&self) -> &str {
52 &self.name
53 }
54
55 pub fn info(&self) -> &DeviceInfo {
57 &self.info
58 }
59
60 pub fn supports_feature(&self, feature: DeviceFeature) -> bool {
62 self.info.features.contains(&feature)
63 }
64
65 pub fn cpu() -> crate::BackendResult<Self> {
67 DeviceBuilder::new()
68 .with_device_type(CoreDeviceType::Cpu)
69 .with_name("CPU".to_string())
70 .with_vendor("Generic".to_string())
71 .with_compute_units(num_cpus::get())
72 .build()
73 }
74}
75
/// Static description of a device: identity, memory, compute limits,
/// performance figures, supported features and free-form properties.
#[derive(Debug, Clone, PartialEq)]
pub struct DeviceInfo {
    /// Vendor name (`"Unknown"` by default).
    pub vendor: String,

    /// Driver version string (`"Unknown"` by default).
    pub driver_version: String,

    /// Total device memory. NOTE(review): appears to be in bytes (the tests use
    /// `8 * 1024 * 1024 * 1024` for an 8 GB device) — confirm.
    pub total_memory: usize,

    /// Currently available device memory, in the same unit as `total_memory`.
    pub available_memory: usize,

    /// Number of compute units (`Device::cpu` stores `num_cpus::get()` here).
    pub compute_units: usize,

    /// Maximum total work-group size.
    pub max_work_group_size: usize,

    /// Per-dimension work-group size limits (default `[256, 1, 1]`).
    pub max_work_group_dimensions: Vec<usize>,

    /// Clock frequency in MHz.
    pub clock_frequency_mhz: u32,

    /// Memory bandwidth in GB/s.
    pub memory_bandwidth_gbps: f32,

    /// Peak compute throughput in GFLOPS.
    pub peak_gflops: f32,

    /// Supported features, queried via `contains` by `Device::supports_feature`.
    pub features: Vec<DeviceFeature>,

    /// Free-form key/value pairs (e.g. `("warp_size", "32")`).
    pub properties: Vec<(String, String)>,
}
115
116impl Default for DeviceInfo {
117 fn default() -> Self {
118 Self {
119 vendor: "Unknown".to_string(),
120 driver_version: "Unknown".to_string(),
121 total_memory: 0,
122 available_memory: 0,
123 compute_units: 1,
124 max_work_group_size: 256,
125 max_work_group_dimensions: vec![256, 1, 1],
126 clock_frequency_mhz: 1000,
127 memory_bandwidth_gbps: 10.0,
128 peak_gflops: 100.0,
129 features: Vec::new(),
130 properties: Vec::new(),
131 }
132 }
133}
134
/// Optional capability a device may advertise in `DeviceInfo::features`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeviceFeature {
    /// Double-precision (64-bit) floating-point arithmetic.
    DoublePrecision,

    /// Half-precision (16-bit) floating-point arithmetic.
    HalfPrecision,

    /// Memory addressable by both host and device.
    UnifiedMemory,

    /// Atomic memory operations.
    AtomicOperations,

    /// Sub-group operations.
    SubGroups,

    /// `printf`-style output from device code.
    Printf,

    /// Profiling support.
    Profiling,

    /// Peer-to-peer (device-to-device) access.
    PeerToPeer,

    /// Concurrent execution of work.
    ConcurrentExecution,

    /// Asynchronous memory operations.
    AsyncMemory,

    /// Image/texture support.
    ImageSupport,

    /// Fast (relaxed-precision) math.
    FastMath,

    // NOTE(review): the variants from here down to `SpirvShaderPassthrough`
    // appear to mirror WebGPU/wgpu feature flags — confirm the mapping.
    /// Timestamp queries.
    TimestampQuery,

    /// Timestamp queries issued inside encoders.
    TimestampQueryInsideEncoders,

    /// Pipeline-statistics queries.
    PipelineStatistics,

    /// Buffers that can be mapped for host access.
    MappableBuffers,

    /// Arrays of buffer bindings.
    BufferArrays,

    /// Arrays of storage bindings.
    StorageArrays,

    /// Binding arrays without a fixed size.
    UnsizedBindingArray,

    /// Non-zero first instance in indirect draws.
    IndirectFirstInstance,

    /// `f16` in shaders.
    ShaderF16,

    /// `i16` in shaders.
    ShaderI16,

    /// Primitive-index built-in in shaders.
    ShaderPrimitiveIndex,

    /// Early depth test control in shaders.
    ShaderEarlyDepthTest,

    /// Multi-draw indirect.
    MultiDrawIndirect,

    /// Multi-draw indirect with a GPU-side count.
    MultiDrawIndirectCount,

    /// Multisampling.
    Multisampling,

    /// Clearing textures.
    ClearTexture,

    /// Passing SPIR-V shaders through unchanged.
    SpirvShaderPassthrough,

    /// Backend-specific feature identified by name; two `Custom` values are
    /// equal exactly when their names are equal.
    Custom(String),
}
229
/// Incremental builder for [`Device`]; device type and name are mandatory.
#[derive(Debug, Clone)]
pub struct DeviceBuilder {
    // Defaults to 0.
    id: usize,
    // `None` until set; `build` fails while unset.
    device_type: Option<CoreDeviceType>,
    // `None` until set; `build` fails while unset.
    name: Option<String>,
    // Starts from `DeviceInfo::default()`.
    info: DeviceInfo,
}
238
239impl DeviceBuilder {
240 pub fn new() -> Self {
241 Self {
242 id: 0,
243 device_type: None,
244 name: None,
245 info: DeviceInfo::default(),
246 }
247 }
248
249 pub fn with_id(mut self, id: usize) -> Self {
250 self.id = id;
251 self
252 }
253
254 pub fn with_device_type(mut self, device_type: CoreDeviceType) -> Self {
255 self.device_type = Some(device_type);
256 self
257 }
258
259 pub fn with_name(mut self, name: String) -> Self {
260 self.name = Some(name);
261 self
262 }
263
264 pub fn with_vendor(mut self, vendor: String) -> Self {
265 self.info.vendor = vendor;
266 self
267 }
268
269 pub fn with_driver_version(mut self, version: String) -> Self {
270 self.info.driver_version = version;
271 self
272 }
273
274 pub fn with_memory(mut self, total: usize, available: usize) -> Self {
275 self.info.total_memory = total;
276 self.info.available_memory = available;
277 self
278 }
279
280 pub fn with_compute_units(mut self, units: usize) -> Self {
281 self.info.compute_units = units;
282 self
283 }
284
285 pub fn with_performance(mut self, gflops: f32, bandwidth_gbps: f32) -> Self {
286 self.info.peak_gflops = gflops;
287 self.info.memory_bandwidth_gbps = bandwidth_gbps;
288 self
289 }
290
291 pub fn with_feature(mut self, feature: DeviceFeature) -> Self {
292 self.info.features.push(feature);
293 self
294 }
295
296 pub fn with_property(mut self, key: String, value: String) -> Self {
297 self.info.properties.push((key, value));
298 self
299 }
300
301 pub fn build(self) -> crate::BackendResult<Device> {
302 let device_type = self.device_type.ok_or_else(|| {
303 torsh_core::error::TorshError::BackendError("Device type is required".to_string())
304 })?;
305
306 let name = self.name.ok_or_else(|| {
307 torsh_core::error::TorshError::BackendError("Device name is required".to_string())
308 })?;
309
310 Ok(Device {
311 id: self.id,
312 device_type,
313 name,
314 info: self.info,
315 })
316 }
317}
318
319impl Default for DeviceBuilder {
320 fn default() -> Self {
321 Self::new()
322 }
323}
324
/// Backend category of a device, without the per-device index carried by
/// [`CoreDeviceType`].
///
/// `OpenCl`, `Vulkan` and `Custom` have no `CoreDeviceType` counterpart.
/// Converting them into a `CoreDeviceType` yields `Cpu`, and converting from a
/// `CoreDeviceType` never produces them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// Host CPU.
    Cpu,

    /// CUDA device.
    Cuda,

    /// Metal device.
    Metal,

    /// WebGPU device (`CoreDeviceType::Wgpu`).
    WebGpu,

    /// OpenCL device.
    OpenCl,

    /// Vulkan device.
    Vulkan,

    /// Any other backend.
    Custom,
}
349
350impl From<CoreDeviceType> for DeviceType {
351 fn from(core_type: CoreDeviceType) -> Self {
352 match core_type {
353 CoreDeviceType::Cpu => DeviceType::Cpu,
354 CoreDeviceType::Cuda(_) => DeviceType::Cuda,
355 CoreDeviceType::Metal(_) => DeviceType::Metal,
356 CoreDeviceType::Wgpu(_) => DeviceType::WebGpu,
357 }
358 }
359}
360
361impl From<DeviceType> for CoreDeviceType {
362 fn from(device_type: DeviceType) -> Self {
363 match device_type {
364 DeviceType::Cpu => CoreDeviceType::Cpu,
365 DeviceType::Cuda => CoreDeviceType::Cuda(0), DeviceType::Metal => CoreDeviceType::Metal(0), DeviceType::WebGpu => CoreDeviceType::Wgpu(0), DeviceType::OpenCl => CoreDeviceType::Cpu, DeviceType::Vulkan => CoreDeviceType::Cpu, DeviceType::Custom => CoreDeviceType::Cpu, }
372 }
373}
374
/// Declarative device filter: every criterion that is set (or non-empty) must
/// hold for `DeviceSelector::matches` to return `true`.
///
/// Only `Default` is derived; the boxed closure in `custom_filter` is not `Debug`.
#[derive(Default)]
pub struct DeviceSelector {
    /// Required device category, if any.
    pub device_type: Option<DeviceType>,

    /// Minimum `DeviceInfo::total_memory`, if any.
    pub min_memory: Option<usize>,

    /// Minimum `DeviceInfo::compute_units`, if any.
    pub min_compute_units: Option<usize>,

    /// Features the device must all support.
    pub required_features: Vec<DeviceFeature>,

    /// Exact (case-sensitive) vendor name the device must report, if any.
    pub preferred_vendor: Option<String>,

    /// Extra caller-supplied predicate; checked after all other criteria.
    // The boxed closure type is spelled out inline rather than via an alias.
    #[allow(clippy::type_complexity)]
    pub custom_filter: Option<Box<dyn Fn(&Device) -> bool + Send + Sync>>,
}
397
398impl DeviceSelector {
399 pub fn new() -> Self {
401 Self::default()
402 }
403
404 pub fn with_device_type(mut self, device_type: DeviceType) -> Self {
406 self.device_type = Some(device_type);
407 self
408 }
409
410 pub fn with_min_memory(mut self, min_memory: usize) -> Self {
412 self.min_memory = Some(min_memory);
413 self
414 }
415
416 pub fn with_min_compute_units(mut self, min_compute_units: usize) -> Self {
418 self.min_compute_units = Some(min_compute_units);
419 self
420 }
421
422 pub fn with_feature(mut self, feature: DeviceFeature) -> Self {
424 self.required_features.push(feature);
425 self
426 }
427
428 pub fn with_vendor(mut self, vendor: String) -> Self {
430 self.preferred_vendor = Some(vendor);
431 self
432 }
433
434 pub fn matches(&self, device: &Device) -> bool {
436 if let Some(required_type) = &self.device_type {
438 if device.device_type != (*required_type).into() {
439 return false;
440 }
441 }
442
443 if let Some(min_memory) = self.min_memory {
445 if device.info.total_memory < min_memory {
446 return false;
447 }
448 }
449
450 if let Some(min_compute_units) = self.min_compute_units {
452 if device.info.compute_units < min_compute_units {
453 return false;
454 }
455 }
456
457 for feature in &self.required_features {
459 if !device.supports_feature(feature.clone()) {
460 return false;
461 }
462 }
463
464 if let Some(ref preferred_vendor) = self.preferred_vendor {
466 if device.info.vendor != *preferred_vendor {
467 return false;
468 }
469 }
470
471 if let Some(ref filter) = self.custom_filter {
473 if !filter(device) {
474 return false;
475 }
476 }
477
478 true
479 }
480}
481
/// Backend-side interface for enumerating and querying devices.
///
/// Implementors must be `Send + Sync`.
pub trait DeviceManager: Send + Sync {
    /// Lists the devices visible to this manager.
    fn enumerate_devices(&self) -> crate::BackendResult<Vec<Device>>;

    /// Returns the description of device `device_id`.
    fn get_device_info(&self, device_id: usize) -> crate::BackendResult<DeviceInfo>;

    /// Checks `features` against device `device_id`.
    ///
    /// NOTE(review): presumably returns one flag per entry of `features`, in the
    /// same order — confirm with implementors.
    fn check_device_features(
        &self,
        device_id: usize,
        features: &[DeviceFeature],
    ) -> crate::BackendResult<Vec<bool>>;

    /// Returns recommended settings for using device `device_id`.
    fn get_optimal_device_config(
        &self,
        device_id: usize,
    ) -> crate::BackendResult<DeviceConfiguration>;

    /// Reports whether device `device_id` is usable.
    ///
    /// NOTE(review): the distinction between `Ok(false)` and `Err(_)` is left
    /// to implementors — confirm.
    fn validate_device(&self, device_id: usize) -> crate::BackendResult<bool>;

    /// Returns performance characteristics of device `device_id`.
    fn get_performance_info(&self, device_id: usize)
        -> crate::BackendResult<DevicePerformanceInfo>;
}
510
/// Recommended settings for using a device, as returned by
/// `DeviceManager::get_optimal_device_config`.
#[derive(Debug, Clone)]
pub struct DeviceConfiguration {
    /// Suggested allocation size. NOTE(review): presumably bytes — confirm.
    pub optimal_allocation_size: usize,

    /// Suggested work-group size as `(x, y, z)`.
    pub workgroup_size: (u32, u32, u32),

    /// Memory alignment. NOTE(review): presumably bytes — confirm.
    pub memory_alignment: usize,

    /// Maximum number of operations to run concurrently.
    pub max_concurrent_operations: u32,

    /// Additional backend-specific settings keyed by name.
    pub backend_specific: std::collections::HashMap<String, crate::backend::CapabilityValue>,
}
529
/// Performance characteristics of a device, as returned by
/// `DeviceManager::get_performance_info`.
#[derive(Debug, Clone)]
pub struct DevicePerformanceInfo {
    /// Memory bandwidth in GB/s.
    pub memory_bandwidth_gbps: f32,

    /// Compute throughput in GFLOPS.
    pub compute_throughput_gflops: f32,

    /// Memory latency in nanoseconds.
    pub memory_latency_ns: f32,

    /// Cache levels of the device.
    pub cache_hierarchy: Vec<CacheLevel>,

    /// Thermal readings; `None` presumably means unavailable — confirm.
    pub thermal_info: Option<ThermalInfo>,

    /// Power readings; `None` presumably means unavailable — confirm.
    pub power_info: Option<PowerInfo>,
}
551
/// One level of a device's cache hierarchy.
#[derive(Debug, Clone)]
pub struct CacheLevel {
    /// Cache level number. NOTE(review): presumably 1 = L1 — confirm.
    pub level: u8,
    /// Capacity in bytes.
    pub size_bytes: usize,
    /// Cache-line size in bytes.
    pub line_size_bytes: usize,
    /// Set associativity; `None` presumably means unknown — confirm.
    pub associativity: Option<usize>,
}
560
/// Thermal state of a device.
#[derive(Debug, Clone)]
pub struct ThermalInfo {
    /// Current temperature in degrees Celsius.
    pub current_temperature_celsius: f32,
    /// Maximum temperature in degrees Celsius.
    pub max_temperature_celsius: f32,
    /// Whether thermal throttling is currently active.
    pub thermal_throttling_active: bool,
}
568
/// Power state of a device.
#[derive(Debug, Clone)]
pub struct PowerInfo {
    /// Current draw in watts.
    pub current_power_watts: f32,
    /// Maximum draw in watts.
    pub max_power_watts: f32,
    /// Configured power limit in watts.
    pub power_limit_watts: f32,
}
576
/// Namespace for stateless helpers that validate, score and configure devices.
pub struct DeviceUtils;
579
580impl DeviceUtils {
581 pub const fn validate_device_id(device_id: usize, max_devices: usize) -> bool {
583 device_id < max_devices
584 }
585
586 pub fn calculate_device_score(device: &Device, requirements: &DeviceRequirements) -> f32 {
588 let mut score = 0.0;
589
590 if let Some(min_memory) = requirements.min_memory {
592 if device.info.total_memory >= min_memory {
593 score += 20.0;
594 score += (device.info.total_memory as f32 / min_memory as f32 - 1.0) * 5.0;
596 } else {
597 return 0.0; }
599 }
600
601 if let Some(min_compute_units) = requirements.min_compute_units {
603 if device.info.compute_units >= min_compute_units {
604 score += 15.0;
605 score += (device.info.compute_units as f32 / min_compute_units as f32 - 1.0) * 3.0;
606 } else {
607 return 0.0;
608 }
609 }
610
611 for required_feature in &requirements.required_features {
613 if device.supports_feature(required_feature.clone()) {
614 score += 10.0;
615 } else {
616 return 0.0; }
618 }
619
620 score += device.info.peak_gflops / 1000.0; score += device.info.memory_bandwidth_gbps / 100.0; match DeviceType::from(device.device_type) {
626 DeviceType::Cuda => score += 15.0, DeviceType::Metal => score += 10.0, DeviceType::WebGpu => score += 5.0, DeviceType::Cpu => score += 1.0, _ => score += 0.0,
631 }
632
633 score
634 }
635
636 pub fn meets_requirements(device: &Device, requirements: &DeviceRequirements) -> bool {
638 if let Some(min_memory) = requirements.min_memory {
640 if device.info.total_memory < min_memory {
641 return false;
642 }
643 }
644
645 if let Some(min_compute_units) = requirements.min_compute_units {
647 if device.info.compute_units < min_compute_units {
648 return false;
649 }
650 }
651
652 for required_feature in &requirements.required_features {
654 if !device.supports_feature(required_feature.clone()) {
655 return false;
656 }
657 }
658
659 if let Some(preferred_backend) = requirements.preferred_backend {
661 let device_backend = match DeviceType::from(device.device_type) {
662 DeviceType::Cpu => crate::backend::BackendType::Cpu,
663 DeviceType::Cuda => crate::backend::BackendType::Cuda,
664 DeviceType::Metal => crate::backend::BackendType::Metal,
665 DeviceType::WebGpu => crate::backend::BackendType::WebGpu,
666 _ => return false,
667 };
668 if device_backend != preferred_backend {
669 return false;
670 }
671 }
672
673 true
674 }
675
676 pub fn get_optimal_workgroup_size(device: &Device, operation_type: &str) -> (u32, u32, u32) {
678 match DeviceType::from(device.device_type) {
679 DeviceType::Cuda => {
680 match operation_type {
682 "matrix_mul" => (16, 16, 1),
683 "element_wise" => (256, 1, 1),
684 "reduction" => (512, 1, 1),
685 _ => (32, 32, 1),
686 }
687 }
688 DeviceType::Metal => {
689 match operation_type {
691 "matrix_mul" => (16, 16, 1),
692 "element_wise" => (256, 1, 1),
693 "reduction" => (256, 1, 1),
694 _ => (32, 32, 1),
695 }
696 }
697 DeviceType::WebGpu => {
698 match operation_type {
700 "matrix_mul" => (8, 8, 1),
701 "element_wise" => (64, 1, 1),
702 "reduction" => (64, 1, 1),
703 _ => (8, 8, 1),
704 }
705 }
706 _ => {
707 (1, 1, 1)
709 }
710 }
711 }
712}
713
/// Namespace for stateless helpers that discover devices across backends.
pub struct DeviceDiscovery;
716
717impl DeviceDiscovery {
718 pub fn discover_all() -> crate::BackendResult<Vec<(crate::backend::BackendType, Vec<Device>)>> {
720 let mut all_devices = Vec::new();
721
722 if let Ok(cpu_devices) = Self::discover_cpu_devices() {
724 all_devices.push((crate::backend::BackendType::Cpu, cpu_devices));
725 }
726
727 #[cfg(feature = "cuda")]
729 if let Ok(cuda_devices) = Self::discover_cuda_devices() {
730 if !cuda_devices.is_empty() {
731 all_devices.push((crate::backend::BackendType::Cuda, cuda_devices));
732 }
733 }
734
735 #[cfg(all(feature = "metal", target_os = "macos"))]
737 if let Ok(metal_devices) = Self::discover_metal_devices() {
738 if !metal_devices.is_empty() {
739 all_devices.push((crate::backend::BackendType::Metal, metal_devices));
740 }
741 }
742
743 #[cfg(feature = "webgpu")]
745 if let Ok(webgpu_devices) = Self::discover_webgpu_devices() {
746 if !webgpu_devices.is_empty() {
747 all_devices.push((crate::backend::BackendType::WebGpu, webgpu_devices));
748 }
749 }
750
751 Ok(all_devices)
752 }
753
754 pub fn find_best_device(
756 requirements: &DeviceRequirements,
757 ) -> crate::BackendResult<(crate::backend::BackendType, Device)> {
758 let all_devices = Self::discover_all()?;
759
760 let mut best_device = None;
761 let mut best_score = 0.0;
762
763 for (backend_type, devices) in all_devices {
764 for device in devices {
765 let score = Self::score_device(&device, requirements);
766 if score > best_score {
767 best_score = score;
768 best_device = Some((backend_type, device));
769 }
770 }
771 }
772
773 best_device.ok_or_else(|| {
774 torsh_core::error::TorshError::BackendError(
775 "No suitable device found for requirements".to_string(),
776 )
777 })
778 }
779
780 fn score_device(device: &Device, requirements: &DeviceRequirements) -> f32 {
782 DeviceUtils::calculate_device_score(device, requirements)
783 }
784
785 fn discover_cpu_devices() -> crate::BackendResult<Vec<Device>> {
787 let cpu_device = crate::cpu::CpuDevice::new(0, num_cpus::get())?;
788 Ok(vec![cpu_device.to_device()])
789 }
790
791 #[cfg(feature = "cuda")]
793 fn discover_cuda_devices() -> crate::BackendResult<Vec<Device>> {
794 Ok(vec![])
797 }
798
799 #[cfg(all(feature = "metal", target_os = "macos"))]
801 fn discover_metal_devices() -> crate::BackendResult<Vec<Device>> {
802 Ok(vec![])
805 }
806
807 #[cfg(feature = "webgpu")]
809 fn discover_webgpu_devices() -> crate::BackendResult<Vec<Device>> {
810 Ok(vec![])
813 }
814}
815
/// Requirements used to score and filter devices (see `DeviceUtils`).
#[derive(Debug, Clone, Default)]
pub struct DeviceRequirements {
    /// Minimum `DeviceInfo::total_memory`, if any.
    pub min_memory: Option<usize>,
    /// Minimum `DeviceInfo::compute_units`, if any.
    pub min_compute_units: Option<usize>,
    /// Features the device must all support.
    pub required_features: Vec<DeviceFeature>,
    /// Backend the device must belong to; enforced by
    /// `DeviceUtils::meets_requirements` but not by the scoring function.
    pub preferred_backend: Option<crate::backend::BackendType>,
    /// Power ceiling. NOTE(review): not read by any function in this file.
    pub max_power_consumption: Option<f32>,
    /// Temperature ceiling. NOTE(review): not read by any function in this file.
    pub max_temperature: Option<f32>,
}
826
827impl DeviceRequirements {
828 pub fn new() -> Self {
829 Self::default()
830 }
831
832 pub fn with_min_memory(mut self, memory: usize) -> Self {
833 self.min_memory = Some(memory);
834 self
835 }
836
837 pub fn with_min_compute_units(mut self, units: usize) -> Self {
838 self.min_compute_units = Some(units);
839 self
840 }
841
842 pub fn with_feature(mut self, feature: DeviceFeature) -> Self {
843 self.required_features.push(feature);
844 self
845 }
846
847 pub fn with_preferred_backend(mut self, backend: crate::backend::BackendType) -> Self {
848 self.preferred_backend = Some(backend);
849 self
850 }
851}
852
// Marker impl so `Device` can be used as a hash-map / hash-set key.
// NOTE(review): `DeviceInfo` contains `f32` fields. A NaN there makes a device
// unequal to itself, which violates `Eq`'s reflexivity contract.
impl Eq for Device {}
855
856impl std::hash::Hash for Device {
857 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
858 self.id.hash(state);
859 self.device_type.hash(state);
860 self.name.hash(state);
861 self.info.vendor.hash(state);
863 self.info.driver_version.hash(state);
864 self.info.total_memory.hash(state);
865 self.info.available_memory.hash(state);
866 self.info.compute_units.hash(state);
867 self.info.max_work_group_size.hash(state);
868 self.info.max_work_group_dimensions.hash(state);
869 self.info.clock_frequency_mhz.hash(state);
870 self.info.features.hash(state);
872 self.info.properties.hash(state);
873 }
874}
875
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a GPU-like `DeviceInfo` fixture shared by several tests.
    // NOTE(review): the memory sizes below exceed `usize::MAX` on 32-bit targets,
    // so this fixture only compiles on 64-bit platforms.
    fn create_test_device_info() -> DeviceInfo {
        DeviceInfo {
            vendor: "Test Vendor".to_string(),
            driver_version: "1.0.0".to_string(),
            total_memory: 8 * 1024 * 1024 * 1024, // 8 GiB, assuming byte units
            available_memory: 6 * 1024 * 1024 * 1024, // 6 GiB, assuming byte units
            compute_units: 32,
            max_work_group_size: 1024,
            max_work_group_dimensions: vec![1024, 1024, 64],
            clock_frequency_mhz: 1500,
            memory_bandwidth_gbps: 500.0,
            peak_gflops: 10000.0,
            features: vec![
                DeviceFeature::DoublePrecision,
                DeviceFeature::UnifiedMemory,
                DeviceFeature::AtomicOperations,
            ],
            properties: vec![
                ("compute_capability".to_string(), "7.5".to_string()),
                ("warp_size".to_string(), "32".to_string()),
            ],
        }
    }

    /// `Device::new` stores its arguments and the getters return them.
    #[test]
    fn test_device_creation() {
        let info = create_test_device_info();
        let device = Device::new(
            0,
            CoreDeviceType::Cuda(0),
            "Test GPU".to_string(),
            info.clone(),
        );

        assert_eq!(device.id(), 0);
        assert_eq!(device.name(), "Test GPU");
        assert_eq!(device.device_type(), CoreDeviceType::Cuda(0));
        assert_eq!(device.info().vendor, "Test Vendor");
        assert_eq!(device.info().compute_units, 32);
    }

    /// `supports_feature` reports exactly the features listed in the fixture.
    #[test]
    fn test_device_feature_support() {
        let info = create_test_device_info();
        let device = Device::new(1, CoreDeviceType::Cpu, "Test CPU".to_string(), info);

        assert!(device.supports_feature(DeviceFeature::DoublePrecision));
        assert!(device.supports_feature(DeviceFeature::UnifiedMemory));
        assert!(device.supports_feature(DeviceFeature::AtomicOperations));
        assert!(!device.supports_feature(DeviceFeature::HalfPrecision));
        assert!(!device.supports_feature(DeviceFeature::SubGroups));
    }

    /// Pins every default value of `DeviceInfo`.
    #[test]
    fn test_device_info_default() {
        let info = DeviceInfo::default();

        assert_eq!(info.vendor, "Unknown");
        assert_eq!(info.driver_version, "Unknown");
        assert_eq!(info.total_memory, 0);
        assert_eq!(info.available_memory, 0);
        assert_eq!(info.compute_units, 1);
        assert_eq!(info.max_work_group_size, 256);
        assert_eq!(info.max_work_group_dimensions, vec![256, 1, 1]);
        assert_eq!(info.clock_frequency_mhz, 1000);
        assert_eq!(info.memory_bandwidth_gbps, 10.0);
        assert_eq!(info.peak_gflops, 100.0);
        assert!(info.features.is_empty());
        assert!(info.properties.is_empty());
    }

    /// Round-trips between `DeviceType` and `CoreDeviceType`.
    #[test]
    fn test_device_type_conversion() {
        assert_eq!(DeviceType::from(CoreDeviceType::Cpu), DeviceType::Cpu);
        assert_eq!(DeviceType::from(CoreDeviceType::Cuda(0)), DeviceType::Cuda);
        assert_eq!(
            DeviceType::from(CoreDeviceType::Metal(0)),
            DeviceType::Metal
        );
        assert_eq!(
            DeviceType::from(CoreDeviceType::Wgpu(0)),
            DeviceType::WebGpu
        );

        assert_eq!(CoreDeviceType::from(DeviceType::Cpu), CoreDeviceType::Cpu);
        assert_eq!(
            CoreDeviceType::from(DeviceType::Cuda),
            CoreDeviceType::Cuda(0)
        );
        assert_eq!(
            CoreDeviceType::from(DeviceType::Metal),
            CoreDeviceType::Metal(0)
        );
        assert_eq!(
            CoreDeviceType::from(DeviceType::WebGpu),
            CoreDeviceType::Wgpu(0)
        );

        // Categories without a `CoreDeviceType` counterpart fall back to `Cpu`.
        assert_eq!(
            CoreDeviceType::from(DeviceType::OpenCl),
            CoreDeviceType::Cpu
        );
        assert_eq!(
            CoreDeviceType::from(DeviceType::Vulkan),
            CoreDeviceType::Cpu
        );
        assert_eq!(
            CoreDeviceType::from(DeviceType::Custom),
            CoreDeviceType::Cpu
        );
    }

    /// Distinct feature variants never compare equal.
    #[test]
    fn test_device_feature_variants() {
        let features = [
            DeviceFeature::DoublePrecision,
            DeviceFeature::HalfPrecision,
            DeviceFeature::UnifiedMemory,
            DeviceFeature::AtomicOperations,
            DeviceFeature::SubGroups,
            DeviceFeature::Printf,
            DeviceFeature::Profiling,
            DeviceFeature::PeerToPeer,
            DeviceFeature::ConcurrentExecution,
            DeviceFeature::AsyncMemory,
            DeviceFeature::ImageSupport,
            DeviceFeature::FastMath,
            DeviceFeature::Custom("CustomFeature".to_string()),
        ];

        // Compare every ordered pair of distinct positions.
        for (i, feature1) in features.iter().enumerate() {
            for (j, feature2) in features.iter().enumerate() {
                if i != j {
                    assert_ne!(feature1, feature2);
                }
            }
        }
    }

    /// A fresh selector has no criteria set.
    #[test]
    fn test_device_selector_creation() {
        let selector = DeviceSelector::new();

        assert_eq!(selector.device_type, None);
        assert_eq!(selector.min_memory, None);
        assert_eq!(selector.min_compute_units, None);
        assert!(selector.required_features.is_empty());
        assert_eq!(selector.preferred_vendor, None);
        assert!(selector.custom_filter.is_none());
    }

    /// Builder methods populate the matching selector fields.
    #[test]
    fn test_device_selector_builder() {
        let selector = DeviceSelector::new()
            .with_device_type(DeviceType::Cuda)
            .with_min_memory(4 * 1024 * 1024 * 1024) // 4 GiB
            .with_min_compute_units(16)
            .with_feature(DeviceFeature::DoublePrecision)
            .with_feature(DeviceFeature::AtomicOperations)
            .with_vendor("NVIDIA".to_string());

        assert_eq!(selector.device_type, Some(DeviceType::Cuda));
        assert_eq!(selector.min_memory, Some(4 * 1024 * 1024 * 1024));
        assert_eq!(selector.min_compute_units, Some(16));
        assert_eq!(selector.required_features.len(), 2);
        assert!(selector
            .required_features
            .contains(&DeviceFeature::DoublePrecision));
        assert!(selector
            .required_features
            .contains(&DeviceFeature::AtomicOperations));
        assert_eq!(selector.preferred_vendor, Some("NVIDIA".to_string()));
    }

    /// `matches` accepts a device meeting all criteria and rejects each single
    /// failing criterion.
    #[test]
    fn test_device_selector_matching() {
        let mut info = create_test_device_info();
        info.vendor = "NVIDIA".to_string();
        info.total_memory = 8 * 1024 * 1024 * 1024; // 8 GiB
        info.compute_units = 32;

        let device = Device::new(0, CoreDeviceType::Cuda(0), "RTX 4090".to_string(), info);

        // Satisfies every criterion.
        let selector1 = DeviceSelector::new()
            .with_device_type(DeviceType::Cuda)
            .with_min_memory(4 * 1024 * 1024 * 1024) // 4 GiB
            .with_min_compute_units(16)
            .with_feature(DeviceFeature::DoublePrecision)
            .with_vendor("NVIDIA".to_string());

        assert!(selector1.matches(&device));

        // Requires more memory (16 GiB) than the device has.
        let selector2 = DeviceSelector::new().with_min_memory(16 * 1024 * 1024 * 1024);
        assert!(!selector2.matches(&device));

        // Requires a feature the device lacks.
        let selector3 = DeviceSelector::new().with_feature(DeviceFeature::HalfPrecision);

        assert!(!selector3.matches(&device));

        // Requires a different vendor.
        let selector4 = DeviceSelector::new().with_vendor("AMD".to_string());

        assert!(!selector4.matches(&device));
    }

    /// `Custom` features compare by name and are found by `supports_feature`.
    #[test]
    fn test_custom_device_feature() {
        let custom_feature1 = DeviceFeature::Custom("TensorCores".to_string());
        let custom_feature2 = DeviceFeature::Custom("TensorCores".to_string());
        let custom_feature3 = DeviceFeature::Custom("RTCores".to_string());

        assert_eq!(custom_feature1, custom_feature2);
        assert_ne!(custom_feature1, custom_feature3);

        let mut info = DeviceInfo::default();
        info.features.push(custom_feature1.clone());

        let device = Device::new(0, CoreDeviceType::Cuda(0), "Custom GPU".to_string(), info);
        assert!(device.supports_feature(custom_feature1));
        assert!(!device.supports_feature(custom_feature3));
    }
}