Skip to main content

tensorlogic_infer/
simd.rs

1//! SIMD (Single Instruction, Multiple Data) optimization utilities.
2//!
3//! This module provides SIMD optimization infrastructure:
4//! - **Platform detection** (AVX2, AVX-512, NEON, etc.)
5//! - **Vectorization hints** for the compiler
6//! - **SIMD-friendly data layouts** and alignment
7//! - **Automatic vectorization** checking
8//! - **Performance benchmarking** for SIMD operations
9//!
10//! ## Example
11//!
12//! ```rust,ignore
13//! use tensorlogic_infer::{SimdCapabilities, SimdOptimizer, AlignedBuffer};
14//!
15//! // Check SIMD capabilities
16//! let caps = SimdCapabilities::detect();
17//! println!("AVX2 supported: {}", caps.has_avx2());
18//!
19//! // Create aligned buffer for SIMD operations
//! let buffer = AlignedBuffer::<f32>::new(1024)?;
21//!
22//! // Optimize operations for SIMD
23//! let optimizer = SimdOptimizer::new(caps);
24//! let optimized_graph = optimizer.optimize(&graph)?;
25//! ```
26
27use serde::{Deserialize, Serialize};
28use std::alloc::{alloc, dealloc, Layout};
29use std::ptr::NonNull;
30use thiserror::Error;
31
32/// SIMD optimization errors.
33#[derive(Error, Debug, Clone, PartialEq)]
34pub enum SimdError {
35    #[error("Unsupported SIMD instruction set: {0}")]
36    UnsupportedInstructionSet(String),
37
38    #[error("Alignment error: required {required}, got {actual}")]
39    AlignmentError { required: usize, actual: usize },
40
41    #[error("Buffer size mismatch: expected {expected}, got {actual}")]
42    SizeMismatch { expected: usize, actual: usize },
43
44    #[error("SIMD operation failed: {0}")]
45    OperationFailed(String),
46}
47
48/// SIMD instruction set.
49#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
50pub enum SimdInstructionSet {
51    /// No SIMD support
52    None,
53    /// SSE (Streaming SIMD Extensions)
54    SSE,
55    /// SSE2
56    SSE2,
57    /// SSE3
58    SSE3,
59    /// SSSE3
60    SSSE3,
61    /// SSE4.1
62    SSE41,
63    /// SSE4.2
64    SSE42,
65    /// AVX (Advanced Vector Extensions)
66    AVX,
67    /// AVX2
68    AVX2,
69    /// AVX-512
70    AVX512,
71    /// ARM NEON
72    NEON,
73    /// SVE (Scalable Vector Extension)
74    SVE,
75}
76
77impl SimdInstructionSet {
78    /// Get the vector width in bytes.
79    pub fn vector_width_bytes(&self) -> usize {
80        match self {
81            SimdInstructionSet::None => 1,
82            SimdInstructionSet::SSE
83            | SimdInstructionSet::SSE2
84            | SimdInstructionSet::SSE3
85            | SimdInstructionSet::SSSE3
86            | SimdInstructionSet::SSE41
87            | SimdInstructionSet::SSE42
88            | SimdInstructionSet::NEON => 16,
89            SimdInstructionSet::AVX | SimdInstructionSet::AVX2 => 32,
90            SimdInstructionSet::AVX512 => 64,
91            SimdInstructionSet::SVE => 128, // Can be up to 2048 bits
92        }
93    }
94
95    /// Get the number of f32 elements per vector.
96    pub fn f32_lanes(&self) -> usize {
97        self.vector_width_bytes() / std::mem::size_of::<f32>()
98    }
99
100    /// Get the number of f64 elements per vector.
101    pub fn f64_lanes(&self) -> usize {
102        self.vector_width_bytes() / std::mem::size_of::<f64>()
103    }
104
105    /// Get the preferred alignment in bytes.
106    pub fn preferred_alignment(&self) -> usize {
107        match self {
108            SimdInstructionSet::None => std::mem::align_of::<f64>(),
109            SimdInstructionSet::SSE
110            | SimdInstructionSet::SSE2
111            | SimdInstructionSet::SSE3
112            | SimdInstructionSet::SSSE3
113            | SimdInstructionSet::SSE41
114            | SimdInstructionSet::SSE42
115            | SimdInstructionSet::NEON => 16,
116            SimdInstructionSet::AVX | SimdInstructionSet::AVX2 => 32,
117            SimdInstructionSet::AVX512 => 64,
118            SimdInstructionSet::SVE => 128,
119        }
120    }
121
122    /// Get the instruction set name.
123    pub fn name(&self) -> &'static str {
124        match self {
125            SimdInstructionSet::None => "None",
126            SimdInstructionSet::SSE => "SSE",
127            SimdInstructionSet::SSE2 => "SSE2",
128            SimdInstructionSet::SSE3 => "SSE3",
129            SimdInstructionSet::SSSE3 => "SSSE3",
130            SimdInstructionSet::SSE41 => "SSE4.1",
131            SimdInstructionSet::SSE42 => "SSE4.2",
132            SimdInstructionSet::AVX => "AVX",
133            SimdInstructionSet::AVX2 => "AVX2",
134            SimdInstructionSet::AVX512 => "AVX-512",
135            SimdInstructionSet::NEON => "NEON",
136            SimdInstructionSet::SVE => "SVE",
137        }
138    }
139}
140
141/// SIMD capabilities detection.
142#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
143pub struct SimdCapabilities {
144    /// Available instruction sets
145    pub instruction_sets: Vec<SimdInstructionSet>,
146
147    /// CPU architecture
148    pub architecture: CpuArchitecture,
149
150    /// Number of CPU cores
151    pub num_cores: usize,
152
153    /// Cache line size (bytes)
154    pub cache_line_size: usize,
155
156    /// L1 cache size (bytes)
157    pub l1_cache_size: usize,
158
159    /// L2 cache size (bytes)
160    pub l2_cache_size: usize,
161
162    /// L3 cache size (bytes)
163    pub l3_cache_size: usize,
164}
165
166/// CPU architecture.
167#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
168pub enum CpuArchitecture {
169    X86_64,
170    AArch64,
171    ARM,
172    Other,
173}
174
175impl SimdCapabilities {
176    /// Detect SIMD capabilities of the current CPU.
177    pub fn detect() -> Self {
178        #[cfg(target_arch = "x86_64")]
179        {
180            Self::detect_x86_64()
181        }
182
183        #[cfg(target_arch = "aarch64")]
184        {
185            Self::detect_aarch64()
186        }
187
188        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
189        {
190            Self::default()
191        }
192    }
193
194    #[cfg(target_arch = "x86_64")]
195    fn detect_x86_64() -> Self {
196        let mut instruction_sets = Vec::new();
197
198        // Check for SSE family
199        if is_x86_feature_detected!("sse") {
200            instruction_sets.push(SimdInstructionSet::SSE);
201        }
202        if is_x86_feature_detected!("sse2") {
203            instruction_sets.push(SimdInstructionSet::SSE2);
204        }
205        if is_x86_feature_detected!("sse3") {
206            instruction_sets.push(SimdInstructionSet::SSE3);
207        }
208        if is_x86_feature_detected!("ssse3") {
209            instruction_sets.push(SimdInstructionSet::SSSE3);
210        }
211        if is_x86_feature_detected!("sse4.1") {
212            instruction_sets.push(SimdInstructionSet::SSE41);
213        }
214        if is_x86_feature_detected!("sse4.2") {
215            instruction_sets.push(SimdInstructionSet::SSE42);
216        }
217
218        // Check for AVX family
219        if is_x86_feature_detected!("avx") {
220            instruction_sets.push(SimdInstructionSet::AVX);
221        }
222        if is_x86_feature_detected!("avx2") {
223            instruction_sets.push(SimdInstructionSet::AVX2);
224        }
225        if is_x86_feature_detected!("avx512f") {
226            instruction_sets.push(SimdInstructionSet::AVX512);
227        }
228
229        Self {
230            instruction_sets,
231            architecture: CpuArchitecture::X86_64,
232            num_cores: num_cpus::get(),
233            cache_line_size: 64, // Common for x86_64
234            l1_cache_size: 32 * 1024,
235            l2_cache_size: 256 * 1024,
236            l3_cache_size: 8 * 1024 * 1024,
237        }
238    }
239
240    #[cfg(target_arch = "aarch64")]
241    fn detect_aarch64() -> Self {
242        let instruction_sets = vec![SimdInstructionSet::NEON];
243
244        // Note: SVE detection would require runtime feature detection
245        // which is not yet stable in Rust
246
247        Self {
248            instruction_sets,
249            architecture: CpuArchitecture::AArch64,
250            num_cores: num_cpus::get(),
251            cache_line_size: 64,
252            l1_cache_size: 64 * 1024,
253            l2_cache_size: 512 * 1024,
254            l3_cache_size: 4 * 1024 * 1024,
255        }
256    }
257
258    /// Check if a specific instruction set is available.
259    pub fn has_instruction_set(&self, isa: SimdInstructionSet) -> bool {
260        self.instruction_sets.contains(&isa)
261    }
262
263    /// Check if AVX2 is available.
264    pub fn has_avx2(&self) -> bool {
265        self.has_instruction_set(SimdInstructionSet::AVX2)
266    }
267
268    /// Check if AVX-512 is available.
269    pub fn has_avx512(&self) -> bool {
270        self.has_instruction_set(SimdInstructionSet::AVX512)
271    }
272
273    /// Check if NEON is available.
274    pub fn has_neon(&self) -> bool {
275        self.has_instruction_set(SimdInstructionSet::NEON)
276    }
277
278    /// Get the best available instruction set.
279    pub fn best_instruction_set(&self) -> SimdInstructionSet {
280        // Prefer more advanced instruction sets
281        if self.has_avx512() {
282            SimdInstructionSet::AVX512
283        } else if self.has_avx2() {
284            SimdInstructionSet::AVX2
285        } else if self.has_instruction_set(SimdInstructionSet::AVX) {
286            SimdInstructionSet::AVX
287        } else if self.has_instruction_set(SimdInstructionSet::SSE42) {
288            SimdInstructionSet::SSE42
289        } else if self.has_neon() {
290            SimdInstructionSet::NEON
291        } else {
292            SimdInstructionSet::None
293        }
294    }
295
296    /// Get the recommended vectorization factor for a given element size.
297    pub fn recommended_vector_size(&self, element_size: usize) -> usize {
298        let best_isa = self.best_instruction_set();
299        best_isa.vector_width_bytes() / element_size
300    }
301}
302
303impl Default for SimdCapabilities {
304    fn default() -> Self {
305        Self {
306            instruction_sets: vec![SimdInstructionSet::None],
307            architecture: CpuArchitecture::Other,
308            num_cores: num_cpus::get(),
309            cache_line_size: 64,
310            l1_cache_size: 32 * 1024,
311            l2_cache_size: 256 * 1024,
312            l3_cache_size: 8 * 1024 * 1024,
313        }
314    }
315}
316
317/// Aligned buffer for SIMD operations.
318pub struct AlignedBuffer<T> {
319    ptr: NonNull<T>,
320    len: usize,
321    alignment: usize,
322}
323
324impl<T> AlignedBuffer<T> {
325    /// Create a new aligned buffer with the specified alignment.
326    pub fn new_with_alignment(len: usize, alignment: usize) -> Result<Self, SimdError> {
327        if alignment == 0 || !alignment.is_power_of_two() {
328            return Err(SimdError::AlignmentError {
329                required: alignment,
330                actual: 0,
331            });
332        }
333
334        let size = len * std::mem::size_of::<T>();
335        let layout =
336            Layout::from_size_align(size, alignment).map_err(|_| SimdError::AlignmentError {
337                required: alignment,
338                actual: 0,
339            })?;
340
341        let ptr = unsafe { alloc(layout) as *mut T };
342        if ptr.is_null() {
343            return Err(SimdError::OperationFailed("Allocation failed".to_string()));
344        }
345
346        Ok(Self {
347            ptr: NonNull::new(ptr).unwrap(),
348            len,
349            alignment,
350        })
351    }
352
353    /// Create a new aligned buffer with default alignment (64 bytes).
354    pub fn new(len: usize) -> Result<Self, SimdError> {
355        Self::new_with_alignment(len, 64)
356    }
357
358    /// Get the length of the buffer.
359    pub fn len(&self) -> usize {
360        self.len
361    }
362
363    /// Check if the buffer is empty.
364    pub fn is_empty(&self) -> bool {
365        self.len == 0
366    }
367
368    /// Get the alignment of the buffer.
369    pub fn alignment(&self) -> usize {
370        self.alignment
371    }
372
373    /// Get a slice of the buffer.
374    pub fn as_slice(&self) -> &[T] {
375        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
376    }
377
378    /// Get a mutable slice of the buffer.
379    pub fn as_mut_slice(&mut self) -> &mut [T] {
380        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
381    }
382
383    /// Get the raw pointer.
384    pub fn as_ptr(&self) -> *const T {
385        self.ptr.as_ptr()
386    }
387
388    /// Get the mutable raw pointer.
389    pub fn as_mut_ptr(&mut self) -> *mut T {
390        self.ptr.as_ptr()
391    }
392
393    /// Check if the buffer is properly aligned.
394    pub fn is_aligned(&self) -> bool {
395        (self.ptr.as_ptr() as usize) % self.alignment == 0
396    }
397}
398
399impl<T> Drop for AlignedBuffer<T> {
400    fn drop(&mut self) {
401        let size = self.len * std::mem::size_of::<T>();
402        let layout = Layout::from_size_align(size, self.alignment).unwrap();
403        unsafe {
404            dealloc(self.ptr.as_ptr() as *mut u8, layout);
405        }
406    }
407}
408
409unsafe impl<T: Send> Send for AlignedBuffer<T> {}
410unsafe impl<T: Sync> Sync for AlignedBuffer<T> {}
411
/// No-op placeholder for a loop-vectorization hint.
///
/// The body is empty, so calling this function has no effect on code
/// generation. In particular, it does not instruct the compiler to vectorize
/// the surrounding loop. Stable Rust has no loop-vectorization pragma;
/// auto-vectorization is left to the optimizer.
#[inline(always)]
pub fn vectorize_hint() {}
418
/// Check whether `ptr`'s address is a multiple of `alignment` bytes.
///
/// `alignment` need not be a power of two. An `alignment` of 0 is not a
/// meaningful requirement. It returns `false` instead of panicking on division
/// by zero, so callers conservatively take their unaligned code path.
pub fn is_simd_aligned<T>(ptr: *const T, alignment: usize) -> bool {
    (ptr as usize).checked_rem(alignment) == Some(0)
}
423
/// Get the number of bytes `ptr` must be advanced to become aligned to `alignment`.
///
/// Returns 0 when `ptr` is already aligned. The result is a byte count, not
/// an element count. Apply it to a byte pointer (e.g. `ptr.cast::<u8>()`)
/// rather than via `ptr.add`. `alignment` need not be a power of two.
///
/// # Panics
///
/// Panics if `alignment` is 0.
pub fn alignment_offset<T>(ptr: *const T, alignment: usize) -> usize {
    assert!(alignment != 0, "alignment must be non-zero");
    let rem = (ptr as usize) % alignment;
    if rem == 0 {
        0
    } else {
        alignment - rem
    }
}
434
435/// SIMD optimization hints.
436#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
437pub struct SimdOptimizationHints {
438    /// Preferred vector size
439    pub vector_size: usize,
440
441    /// Preferred alignment
442    pub alignment: usize,
443
444    /// Enable loop unrolling
445    pub unroll_loops: bool,
446
447    /// Unroll factor
448    pub unroll_factor: usize,
449
450    /// Enable data prefetching
451    pub prefetch: bool,
452
453    /// Prefetch distance (cache lines)
454    pub prefetch_distance: usize,
455}
456
457impl Default for SimdOptimizationHints {
458    fn default() -> Self {
459        let caps = SimdCapabilities::detect();
460        let best_isa = caps.best_instruction_set();
461
462        Self {
463            vector_size: best_isa.vector_width_bytes(),
464            alignment: best_isa.preferred_alignment(),
465            unroll_loops: true,
466            unroll_factor: 4,
467            prefetch: true,
468            prefetch_distance: 8,
469        }
470    }
471}
472
473impl SimdOptimizationHints {
474    /// Create hints for a specific instruction set.
475    pub fn for_instruction_set(isa: SimdInstructionSet) -> Self {
476        Self {
477            vector_size: isa.vector_width_bytes(),
478            alignment: isa.preferred_alignment(),
479            ..Default::default()
480        }
481    }
482
483    /// Disable all optimizations.
484    pub fn none() -> Self {
485        Self {
486            vector_size: 1,
487            alignment: std::mem::align_of::<f64>(),
488            unroll_loops: false,
489            unroll_factor: 1,
490            prefetch: false,
491            prefetch_distance: 0,
492        }
493    }
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_instruction_set_width() {
        assert_eq!(SimdInstructionSet::SSE.vector_width_bytes(), 16);
        assert_eq!(SimdInstructionSet::AVX.vector_width_bytes(), 32);
        assert_eq!(SimdInstructionSet::AVX2.vector_width_bytes(), 32);
        assert_eq!(SimdInstructionSet::AVX512.vector_width_bytes(), 64);
        assert_eq!(SimdInstructionSet::NEON.vector_width_bytes(), 16);
    }

    #[test]
    fn test_simd_instruction_set_lanes() {
        assert_eq!(SimdInstructionSet::AVX2.f32_lanes(), 8);
        assert_eq!(SimdInstructionSet::AVX2.f64_lanes(), 4);
        assert_eq!(SimdInstructionSet::AVX512.f32_lanes(), 16);
    }

    #[test]
    fn test_simd_capabilities_detection() {
        let caps = SimdCapabilities::detect();
        assert!(!caps.instruction_sets.is_empty());
        assert!(caps.num_cores > 0);
        assert!(caps.cache_line_size > 0);
    }

    #[test]
    fn test_simd_capabilities_best() {
        let caps = SimdCapabilities::detect();
        let best = caps.best_instruction_set();

        // `best_instruction_set` may only fall back to `None` when no real
        // SIMD instruction set was detected. A non-empty list is not enough:
        // the fallback capabilities contain `SimdInstructionSet::None` itself.
        let has_real_simd = caps
            .instruction_sets
            .iter()
            .any(|&isa| isa != SimdInstructionSet::None);
        if has_real_simd {
            assert_ne!(best, SimdInstructionSet::None);
        }
    }

    #[test]
    fn test_aligned_buffer_creation() {
        let buffer = AlignedBuffer::<f32>::new(1024).unwrap();
        assert_eq!(buffer.len(), 1024);
        assert_eq!(buffer.alignment(), 64);
        assert!(buffer.is_aligned());
    }

    #[test]
    fn test_aligned_buffer_custom_alignment() {
        let buffer = AlignedBuffer::<f32>::new_with_alignment(512, 32).unwrap();
        assert_eq!(buffer.len(), 512);
        assert_eq!(buffer.alignment(), 32);
        assert!(buffer.is_aligned());
    }

    #[test]
    fn test_aligned_buffer_slice() {
        let mut buffer = AlignedBuffer::<f32>::new(10).unwrap();
        let slice = buffer.as_mut_slice();
        slice[0] = 1.0;
        slice[1] = 2.0;

        let const_slice = buffer.as_slice();
        assert_eq!(const_slice[0], 1.0);
        assert_eq!(const_slice[1], 2.0);
    }

    #[test]
    fn test_is_simd_aligned() {
        let buffer = AlignedBuffer::<f32>::new(1024).unwrap();
        assert!(is_simd_aligned(buffer.as_ptr(), 64));
        assert!(is_simd_aligned(buffer.as_ptr(), 32));
        assert!(is_simd_aligned(buffer.as_ptr(), 16));
    }

    #[test]
    fn test_alignment_offset() {
        let buffer = AlignedBuffer::<u8>::new(1024).unwrap();
        let offset = alignment_offset(buffer.as_ptr(), 64);
        assert_eq!(offset, 0); // Already aligned
    }

    #[test]
    fn test_simd_optimization_hints_default() {
        let hints = SimdOptimizationHints::default();
        assert!(hints.vector_size > 0);
        assert!(hints.alignment > 0);
        assert!(hints.unroll_loops);
    }

    #[test]
    fn test_simd_optimization_hints_for_isa() {
        let hints = SimdOptimizationHints::for_instruction_set(SimdInstructionSet::AVX2);
        assert_eq!(hints.vector_size, 32);
        assert_eq!(hints.alignment, 32);
    }

    #[test]
    fn test_simd_optimization_hints_none() {
        let hints = SimdOptimizationHints::none();
        assert_eq!(hints.vector_size, 1);
        assert!(!hints.unroll_loops);
        assert!(!hints.prefetch);
    }

    #[test]
    fn test_instruction_set_name() {
        assert_eq!(SimdInstructionSet::AVX2.name(), "AVX2");
        assert_eq!(SimdInstructionSet::NEON.name(), "NEON");
        assert_eq!(SimdInstructionSet::AVX512.name(), "AVX-512");
    }
}