Skip to main content

sochdb_vector/simd/
dispatch.rs

//! Unified SIMD Dispatch Architecture
//!
//! This module provides compile-time and runtime CPU feature detection
//! and dispatch for SIMD operations.
//!
//! # Architecture
//!
//! The dispatch system supports two modes:
//! 1. **Compile-time dispatch** (`#[cfg]` attributes): used when the target is known at build time.
//! 2. **Runtime dispatch** (CPUID/feature detection): used for portable binaries.
//!
//! # Usage
//!
//! ```rust,ignore
//! use sochdb_vector::simd::dispatch::{cpu_features, simd_level, SimdLevel};
//!
//! let features = cpu_features();
//! if features.has_avx2 {
//!     println!("AVX2 is available!");
//! }
//!
//! match simd_level() {
//!     SimdLevel::Avx512 => println!("Using AVX-512"),
//!     SimdLevel::Avx2 => println!("Using AVX2"),
//!     SimdLevel::Neon => println!("Using NEON"),
//!     _ => println!("Using scalar fallback"),
//! }
//! ```
29
30use std::sync::OnceLock;
31
32/// CPU feature flags detected at runtime.
33#[derive(Debug, Clone, Copy, Default)]
34pub struct CpuFeatures {
35    /// SSE 4.1 support (x86)
36    pub has_sse4_1: bool,
37    /// AVX2 support (x86)
38    pub has_avx2: bool,
39    /// AVX-512F support (x86)
40    pub has_avx512f: bool,
41    /// AVX-512BW support (x86)
42    pub has_avx512bw: bool,
43    /// AVX-512 VNNI support (x86, for int8 acceleration)
44    pub has_vnni: bool,
45    /// NEON support (ARM, mandatory on aarch64)
46    pub has_neon: bool,
47    /// SVE support (ARM v8.2+)
48    pub has_sve: bool,
49    /// Dot product instruction support (ARM v8.2+)
50    pub has_dotprod: bool,
51}
52
53impl CpuFeatures {
54    /// Detect CPU features at runtime.
55    pub fn detect() -> Self {
56        #[cfg(target_arch = "x86_64")]
57        {
58            Self::detect_x86()
59        }
60
61        #[cfg(target_arch = "aarch64")]
62        {
63            Self::detect_arm()
64        }
65
66        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
67        {
68            Self::default()
69        }
70    }
71
72    #[cfg(target_arch = "x86_64")]
73    fn detect_x86() -> Self {
74        Self {
75            has_sse4_1: is_x86_feature_detected!("sse4.1"),
76            has_avx2: is_x86_feature_detected!("avx2"),
77            has_avx512f: is_x86_feature_detected!("avx512f"),
78            has_avx512bw: is_x86_feature_detected!("avx512bw"),
79            has_vnni: is_x86_feature_detected!("avx512vnni"),
80            has_neon: false,
81            has_sve: false,
82            has_dotprod: false,
83        }
84    }
85
86    #[cfg(target_arch = "aarch64")]
87    fn detect_arm() -> Self {
88        // NEON is mandatory on aarch64
89        Self {
90            has_sse4_1: false,
91            has_avx2: false,
92            has_avx512f: false,
93            has_avx512bw: false,
94            has_vnni: false,
95            has_neon: true,
96            // SVE and dotprod detection would require reading system registers
97            // For now, we rely on compile-time detection
98            has_sve: cfg!(target_feature = "sve"),
99            has_dotprod: cfg!(target_feature = "dotprod"),
100        }
101    }
102
103    /// Get the best SIMD level available.
104    pub fn best_level(&self) -> SimdLevel {
105        if self.has_avx512f && self.has_avx512bw {
106            SimdLevel::Avx512
107        } else if self.has_avx2 {
108            SimdLevel::Avx2
109        } else if self.has_neon {
110            SimdLevel::Neon
111        } else if self.has_sse4_1 {
112            SimdLevel::Sse4
113        } else {
114            SimdLevel::Scalar
115        }
116    }
117
118    /// Check if any SIMD acceleration is available.
119    pub fn has_simd(&self) -> bool {
120        self.has_avx2 || self.has_neon || self.has_sse4_1
121    }
122}
123
/// SIMD capability level.
///
/// Variants are declared from narrowest to widest, which is also the order
/// used by the derived `PartialOrd`/`Ord`. `Neon` sits between `Sse4` and
/// `Avx2` in that ordering even though it belongs to a different architecture.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum SimdLevel {
    /// No SIMD, scalar operations only
    Scalar = 0,
    /// SSE 4.1 (128-bit, x86)
    Sse4 = 1,
    /// NEON (128-bit, ARM)
    Neon = 2,
    /// AVX2 (256-bit, x86)
    Avx2 = 3,
    /// AVX-512 (512-bit, x86)
    Avx512 = 4,
}

impl SimdLevel {
    /// Register width in bits; `Scalar` is reported as a 64-bit register.
    pub const fn width_bits(self) -> usize {
        match self {
            Self::Scalar => 64,
            Self::Sse4 | Self::Neon => 128,
            Self::Avx2 => 256,
            Self::Avx512 => 512,
        }
    }

    /// Number of `u8` lanes per register (`Scalar` processes one element at a time).
    pub const fn u8_width(self) -> usize {
        if matches!(self, Self::Scalar) {
            1
        } else {
            self.width_bits() / 8
        }
    }

    /// Number of `u64` lanes per register (`Scalar` processes one element at a time).
    pub const fn u64_width(self) -> usize {
        if matches!(self, Self::Scalar) {
            1
        } else {
            self.width_bits() / 64
        }
    }

    /// Number of `f32` lanes per register (`Scalar` processes one element at a time).
    pub const fn f32_width(self) -> usize {
        if matches!(self, Self::Scalar) {
            1
        } else {
            self.width_bits() / 32
        }
    }

    /// Theoretical speed-up over scalar code for byte-wise work, i.e. the `u8` lane count.
    pub const fn speedup_factor(self) -> usize {
        self.u8_width()
    }

    /// Short display name of the level (e.g. `"AVX-512"`).
    pub const fn name(self) -> &'static str {
        match self {
            Self::Scalar => "Scalar",
            Self::Sse4 => "SSE4.1",
            Self::Neon => "NEON",
            Self::Avx2 => "AVX2",
            Self::Avx512 => "AVX-512",
        }
    }
}

impl std::fmt::Display for SimdLevel {
    /// Writes [`SimdLevel::name`] verbatim; width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name())
    }
}
207
208/// Global CPU features, detected once at first use.
209static CPU_FEATURES: OnceLock<CpuFeatures> = OnceLock::new();
210
211/// Get detected CPU features (cached).
212#[inline]
213pub fn cpu_features() -> &'static CpuFeatures {
214    CPU_FEATURES.get_or_init(CpuFeatures::detect)
215}
216
217/// Get best available SIMD level.
218#[inline]
219pub fn simd_level() -> SimdLevel {
220    cpu_features().best_level()
221}
222
223/// Check if SIMD acceleration is available.
224#[inline]
225pub fn simd_available() -> bool {
226    cpu_features().has_simd()
227}
228
229/// Get a human-readable description of SIMD capabilities.
230pub fn dispatch_info() -> String {
231    let features = cpu_features();
232    let level = features.best_level();
233
234    let mut info = format!(
235        "SIMD Level: {} ({}-bit)\n",
236        level.name(),
237        level.width_bits()
238    );
239
240    #[cfg(target_arch = "x86_64")]
241    {
242        info.push_str(&format!("  SSE4.1: {}\n", features.has_sse4_1));
243        info.push_str(&format!("  AVX2: {}\n", features.has_avx2));
244        info.push_str(&format!("  AVX-512F: {}\n", features.has_avx512f));
245        info.push_str(&format!("  AVX-512BW: {}\n", features.has_avx512bw));
246        info.push_str(&format!("  VNNI: {}\n", features.has_vnni));
247    }
248
249    #[cfg(target_arch = "aarch64")]
250    {
251        info.push_str(&format!("  NEON: {}\n", features.has_neon));
252        info.push_str(&format!("  SVE: {}\n", features.has_sve));
253        info.push_str(&format!("  DOTPROD: {}\n", features.has_dotprod));
254    }
255
256    info
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Every level, in declaration (and therefore `Ord`) order.
    const ALL_LEVELS: [SimdLevel; 5] = [
        SimdLevel::Scalar,
        SimdLevel::Sse4,
        SimdLevel::Neon,
        SimdLevel::Avx2,
        SimdLevel::Avx512,
    ];

    #[test]
    fn test_cpu_detection() {
        let features = cpu_features();
        let level = features.best_level();

        println!("Detected SIMD level: {:?}", level);
        println!("Features: {:?}", features);

        // Detection is cached: every call hands back the same instance.
        assert!(std::ptr::eq(features, cpu_features()));
        assert_eq!(simd_level(), level);
        assert_eq!(simd_available(), features.has_simd());

        #[cfg(target_arch = "x86_64")]
        {
            // ARM-only flags are never set on x86_64.
            assert!(!features.has_neon && !features.has_sve && !features.has_dotprod);
            assert_ne!(level, SimdLevel::Neon);
        }

        #[cfg(target_arch = "aarch64")]
        {
            assert!(features.has_neon, "NEON is mandatory on aarch64");
            assert!(!features.has_sse4_1 && !features.has_avx2 && !features.has_avx512f);
            assert_eq!(level, SimdLevel::Neon);
            assert!(simd_available());
        }
    }

    #[test]
    fn test_best_level_priority() {
        let none = CpuFeatures::default();
        assert_eq!(none.best_level(), SimdLevel::Scalar);
        assert!(!none.has_simd());

        let sse = CpuFeatures {
            has_sse4_1: true,
            ..CpuFeatures::default()
        };
        assert_eq!(sse.best_level(), SimdLevel::Sse4);
        assert!(sse.has_simd());

        let avx2 = CpuFeatures {
            has_sse4_1: true,
            has_avx2: true,
            ..CpuFeatures::default()
        };
        assert_eq!(avx2.best_level(), SimdLevel::Avx2);
        assert!(avx2.has_simd());

        // AVX-512 needs both F and BW; F alone falls back to AVX2.
        let avx512f_only = CpuFeatures {
            has_avx512f: true,
            ..avx2
        };
        assert_eq!(avx512f_only.best_level(), SimdLevel::Avx2);
        let avx512 = CpuFeatures {
            has_avx512bw: true,
            ..avx512f_only
        };
        assert_eq!(avx512.best_level(), SimdLevel::Avx512);
        assert!(avx512.has_simd());

        let neon = CpuFeatures {
            has_neon: true,
            ..CpuFeatures::default()
        };
        assert_eq!(neon.best_level(), SimdLevel::Neon);
        assert!(neon.has_simd());
    }

    #[test]
    fn test_simd_widths() {
        assert_eq!(SimdLevel::Scalar.u8_width(), 1);
        assert_eq!(SimdLevel::Sse4.u8_width(), 16);
        assert_eq!(SimdLevel::Neon.u8_width(), 16);
        assert_eq!(SimdLevel::Avx2.u8_width(), 32);
        assert_eq!(SimdLevel::Avx512.u8_width(), 64);
        assert_eq!(SimdLevel::Scalar.u64_width(), 1);
        assert_eq!(SimdLevel::Scalar.f32_width(), 1);

        for level in ALL_LEVELS.iter().copied() {
            assert_eq!(level.speedup_factor(), level.u8_width());
            if level != SimdLevel::Scalar {
                // Vector lane counts follow directly from the register width.
                assert_eq!(level.u8_width() * 8, level.width_bits());
                assert_eq!(level.f32_width() * 32, level.width_bits());
                assert_eq!(level.u64_width() * 64, level.width_bits());
            }
        }
    }

    #[test]
    fn test_level_ordering_and_names() {
        assert!(ALL_LEVELS.windows(2).all(|pair| pair[0] < pair[1]));
        for level in ALL_LEVELS.iter().copied() {
            assert_eq!(level.to_string(), level.name());
        }
    }

    #[test]
    fn test_dispatch_info() {
        let info = dispatch_info();
        println!("{}", info);

        let level = simd_level();
        let expected_header = format!(
            "SIMD Level: {} ({}-bit)\n",
            level.name(),
            level.width_bits()
        );
        assert!(info.starts_with(&expected_header));
    }
}