// Build guard: CPU feature detection in this file is only implemented for
// x86_64 and aarch64, so any other target architecture is rejected at compile
// time with an explicit message instead of failing later with a type error.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
compile_error!("signinum-core only supports x86_64 and aarch64 targets");
5
6use core::sync::atomic::{AtomicU8, Ordering};
7
/// Identifies one concrete backend.
///
/// This is the resolved counterpart of [`BackendRequest`]: it has the same
/// backend variants but no `Auto` variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BackendKind {
    /// The CPU backend.
    Cpu,
    /// The Metal backend.
    Metal,
    /// The CUDA backend.
    Cuda,
}
18
/// Backend selection requested by a caller.
///
/// The `Default` value is [`BackendRequest::Auto`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum BackendRequest {
    /// No particular backend is requested. This is the default.
    #[default]
    Auto,
    /// Request the CPU backend.
    Cpu,
    /// Request the Metal backend.
    Metal,
    /// Request the CUDA backend.
    Cuda,
}

impl BackendRequest {
    /// Alias for [`BackendRequest::Auto`].
    pub const ACCELERATED: Self = BackendRequest::Auto;
    /// Alias for [`BackendRequest::Cpu`].
    pub const CPU_ONLY: Self = BackendRequest::Cpu;
    /// Alias for [`BackendRequest::Metal`].
    pub const STRICT_METAL: Self = BackendRequest::Metal;
    /// Alias for [`BackendRequest::Cuda`].
    pub const STRICT_CUDA: Self = BackendRequest::Cuda;
}
44
45#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
47pub struct CpuFeatures {
48 pub avx2: bool,
50 pub sse41: bool,
52 pub neon: bool,
54}
55
56impl CpuFeatures {
57 pub fn detect() -> Self {
59 static DETECTED: AtomicU8 = AtomicU8::new(0);
60
61 let cached = DETECTED.load(Ordering::Acquire);
62 if cached != 0 {
63 return Self::from_cache_byte(cached);
64 }
65
66 let detected = Self::detect_uncached();
67 let encoded = detected.to_cache_byte();
68 let _ = DETECTED.compare_exchange(0, encoded, Ordering::AcqRel, Ordering::Acquire);
69 Self::from_cache_byte(DETECTED.load(Ordering::Acquire))
70 }
71
72 fn detect_uncached() -> Self {
73 #[cfg(target_arch = "x86_64")]
74 {
75 Self {
76 avx2: detect_x86_avx2(),
77 sse41: detect_x86_sse41(),
78 neon: false,
79 }
80 }
81
82 #[cfg(target_arch = "aarch64")]
83 {
84 Self {
85 avx2: false,
86 sse41: false,
87 neon: true,
88 }
89 }
90 }
91
92 const fn to_cache_byte(self) -> u8 {
93 let mut encoded = 1_u8;
94 if self.avx2 {
95 encoded |= 1 << 1;
96 }
97 if self.sse41 {
98 encoded |= 1 << 2;
99 }
100 if self.neon {
101 encoded |= 1 << 3;
102 }
103 encoded
104 }
105
106 const fn from_cache_byte(encoded: u8) -> Self {
107 let bits = encoded.saturating_sub(1);
108 Self {
109 avx2: (bits & (1 << 1)) != 0,
110 sse41: (bits & (1 << 2)) != 0,
111 neon: (bits & (1 << 3)) != 0,
112 }
113 }
114}
115
116#[cfg(target_arch = "x86_64")]
117fn detect_x86_sse41() -> bool {
118 let features = core::arch::x86_64::__cpuid(1);
119 (features.ecx & (1 << 19)) != 0
120}
121
122#[cfg(target_arch = "x86_64")]
123fn detect_x86_avx2() -> bool {
124 let leaf1 = core::arch::x86_64::__cpuid(1);
125 let osxsave = (leaf1.ecx & (1 << 27)) != 0;
126 let avx = (leaf1.ecx & (1 << 28)) != 0;
127 if !(osxsave && avx) {
128 return false;
129 }
130
131 let xcr0 = unsafe { core::arch::x86_64::_xgetbv(0) };
133 let xmm_enabled = (xcr0 & 0b10) != 0;
134 let ymm_enabled = (xcr0 & 0b100) != 0;
135 if !(xmm_enabled && ymm_enabled) {
136 return false;
137 }
138
139 let max_leaf = core::arch::x86_64::__cpuid(0).eax;
140 if max_leaf < 7 {
141 return false;
142 }
143
144 let leaf7 = core::arch::x86_64::__cpuid_count(7, 0);
145 (leaf7.ebx & (1 << 5)) != 0
146}
147
148#[derive(Debug, Clone, Copy, PartialEq, Eq)]
150pub struct BackendCapabilities {
151 pub cpu: CpuFeatures,
153 pub metal: bool,
155 pub cuda: bool,
157}
158
159impl BackendCapabilities {
160 #[must_use]
166 pub fn compile_time_defaults() -> Self {
167 Self {
168 cpu: CpuFeatures::detect(),
169 metal: cfg!(target_os = "macos"),
170 cuda: false,
171 }
172 }
173
174 #[must_use]
176 pub const fn supports(self, request: BackendRequest) -> bool {
177 match request {
178 BackendRequest::Auto | BackendRequest::Cpu => true,
179 BackendRequest::Metal => self.metal,
180 BackendRequest::Cuda => self.cuda,
181 }
182 }
183
184 #[must_use]
190 pub fn resolve(self, request: BackendRequest) -> Option<BackendKind> {
191 match request {
192 BackendRequest::Auto | BackendRequest::Cpu => Some(BackendKind::Cpu),
193 BackendRequest::Metal if self.metal => Some(BackendKind::Metal),
194 BackendRequest::Cuda if self.cuda => Some(BackendKind::Cuda),
195 BackendRequest::Metal | BackendRequest::Cuda => None,
196 }
197 }
198
199 #[must_use]
202 pub const fn first_available_accelerator(self) -> Option<BackendKind> {
203 if self.metal {
204 Some(BackendKind::Metal)
205 } else if self.cuda {
206 Some(BackendKind::Cuda)
207 } else {
208 None
209 }
210 }
211}