Skip to main content

tensorlogic_infer/
simd.rs

1//! SIMD (Single Instruction, Multiple Data) optimization utilities.
2//!
3//! This module provides SIMD optimization infrastructure:
4//! - **Platform detection** (AVX2, AVX-512, NEON, etc.)
5//! - **Vectorization hints** for the compiler
6//! - **SIMD-friendly data layouts** and alignment
7//! - **Automatic vectorization** checking
8//! - **Performance benchmarking** for SIMD operations
9//!
10//! ## Example
11//!
12//! ```rust,ignore
13//! use tensorlogic_infer::{SimdCapabilities, SimdOptimizer, AlignedBuffer};
14//!
15//! // Check SIMD capabilities
16//! let caps = SimdCapabilities::detect();
17//! println!("AVX2 supported: {}", caps.has_avx2());
18//!
19//! // Create aligned buffer for SIMD operations
//! let buffer = AlignedBuffer::<f32>::new(1024)?;
21//!
22//! // Optimize operations for SIMD
23//! let optimizer = SimdOptimizer::new(caps);
24//! let optimized_graph = optimizer.optimize(&graph)?;
25//! ```
26
27use serde::{Deserialize, Serialize};
28use std::alloc::{alloc, dealloc, Layout};
29use std::ptr::NonNull;
30use thiserror::Error;
31
/// SIMD optimization errors.
///
/// This is the error type returned by the fallible operations of this module, for example
/// `AlignedBuffer::new_with_alignment`.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum SimdError {
    /// An instruction set that is not supported was requested; the payload names it.
    #[error("Unsupported SIMD instruction set: {0}")]
    UnsupportedInstructionSet(String),

    /// An alignment requirement was invalid or could not be met.
    ///
    /// `actual` may be 0 when there is no meaningful observed value (this is what
    /// `AlignedBuffer` reports for an invalid requested alignment).
    #[error("Alignment error: required {required}, got {actual}")]
    AlignmentError { required: usize, actual: usize },

    /// Two sizes that were expected to match did not.
    #[error("Buffer size mismatch: expected {expected}, got {actual}")]
    SizeMismatch { expected: usize, actual: usize },

    /// Any other failure, described by the message (e.g. an allocation failure).
    #[error("SIMD operation failed: {0}")]
    OperationFailed(String),
}
47
/// SIMD instruction set.
///
/// The declaration order of the variants is not a capability ranking; see
/// `SimdCapabilities::best_instruction_set` for the order in which they are preferred.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SimdInstructionSet {
    /// No SIMD support (scalar execution)
    None,
    /// SSE (Streaming SIMD Extensions)
    SSE,
    /// SSE2
    SSE2,
    /// SSE3
    SSE3,
    /// SSSE3
    SSSE3,
    /// SSE4.1
    SSE41,
    /// SSE4.2
    SSE42,
    /// AVX (Advanced Vector Extensions)
    AVX,
    /// AVX2
    AVX2,
    /// AVX-512 (detected via the `avx512f` feature on x86_64)
    AVX512,
    /// ARM NEON
    NEON,
    /// SVE (Scalable Vector Extension); its vector length is hardware-dependent
    SVE,
}
76
77impl SimdInstructionSet {
78    /// Get the vector width in bytes.
79    pub fn vector_width_bytes(&self) -> usize {
80        match self {
81            SimdInstructionSet::None => 1,
82            SimdInstructionSet::SSE
83            | SimdInstructionSet::SSE2
84            | SimdInstructionSet::SSE3
85            | SimdInstructionSet::SSSE3
86            | SimdInstructionSet::SSE41
87            | SimdInstructionSet::SSE42
88            | SimdInstructionSet::NEON => 16,
89            SimdInstructionSet::AVX | SimdInstructionSet::AVX2 => 32,
90            SimdInstructionSet::AVX512 => 64,
91            SimdInstructionSet::SVE => 128, // Can be up to 2048 bits
92        }
93    }
94
95    /// Get the number of f32 elements per vector.
96    pub fn f32_lanes(&self) -> usize {
97        self.vector_width_bytes() / std::mem::size_of::<f32>()
98    }
99
100    /// Get the number of f64 elements per vector.
101    pub fn f64_lanes(&self) -> usize {
102        self.vector_width_bytes() / std::mem::size_of::<f64>()
103    }
104
105    /// Get the preferred alignment in bytes.
106    pub fn preferred_alignment(&self) -> usize {
107        match self {
108            SimdInstructionSet::None => std::mem::align_of::<f64>(),
109            SimdInstructionSet::SSE
110            | SimdInstructionSet::SSE2
111            | SimdInstructionSet::SSE3
112            | SimdInstructionSet::SSSE3
113            | SimdInstructionSet::SSE41
114            | SimdInstructionSet::SSE42
115            | SimdInstructionSet::NEON => 16,
116            SimdInstructionSet::AVX | SimdInstructionSet::AVX2 => 32,
117            SimdInstructionSet::AVX512 => 64,
118            SimdInstructionSet::SVE => 128,
119        }
120    }
121
122    /// Get the instruction set name.
123    pub fn name(&self) -> &'static str {
124        match self {
125            SimdInstructionSet::None => "None",
126            SimdInstructionSet::SSE => "SSE",
127            SimdInstructionSet::SSE2 => "SSE2",
128            SimdInstructionSet::SSE3 => "SSE3",
129            SimdInstructionSet::SSSE3 => "SSSE3",
130            SimdInstructionSet::SSE41 => "SSE4.1",
131            SimdInstructionSet::SSE42 => "SSE4.2",
132            SimdInstructionSet::AVX => "AVX",
133            SimdInstructionSet::AVX2 => "AVX2",
134            SimdInstructionSet::AVX512 => "AVX-512",
135            SimdInstructionSet::NEON => "NEON",
136            SimdInstructionSet::SVE => "SVE",
137        }
138    }
139}
140
/// SIMD capabilities detection.
///
/// Built by [`SimdCapabilities::detect`]. The cache fields hold fixed per-architecture
/// defaults written into the detection code; they are not measured from the running CPU.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SimdCapabilities {
    /// Available instruction sets (for unrecognised architectures this is
    /// `[SimdInstructionSet::None]`)
    pub instruction_sets: Vec<SimdInstructionSet>,

    /// CPU architecture
    pub architecture: CpuArchitecture,

    /// Number of CPU cores, as reported by `num_cpus::get()`
    pub num_cores: usize,

    /// Cache line size (bytes)
    pub cache_line_size: usize,

    /// L1 cache size (bytes)
    pub l1_cache_size: usize,

    /// L2 cache size (bytes)
    pub l2_cache_size: usize,

    /// L3 cache size (bytes)
    pub l3_cache_size: usize,
}
165
/// CPU architecture.
///
/// `SimdCapabilities::detect` only ever produces `X86_64`, `AArch64`, or (via its
/// `Default` fallback) `Other`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CpuArchitecture {
    /// 64-bit x86 (`target_arch = "x86_64"`)
    X86_64,
    /// 64-bit ARM (`target_arch = "aarch64"`)
    AArch64,
    /// ARM other than AArch64 (presumably 32-bit ARM; not produced by detection)
    ARM,
    /// Any other or unrecognised architecture
    Other,
}
174
175impl SimdCapabilities {
176    /// Detect SIMD capabilities of the current CPU.
177    pub fn detect() -> Self {
178        #[cfg(target_arch = "x86_64")]
179        {
180            Self::detect_x86_64()
181        }
182
183        #[cfg(target_arch = "aarch64")]
184        {
185            Self::detect_aarch64()
186        }
187
188        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
189        {
190            Self::default()
191        }
192    }
193
194    #[cfg(target_arch = "x86_64")]
195    fn detect_x86_64() -> Self {
196        let mut instruction_sets = Vec::new();
197
198        // Check for SSE family
199        if is_x86_feature_detected!("sse") {
200            instruction_sets.push(SimdInstructionSet::SSE);
201        }
202        if is_x86_feature_detected!("sse2") {
203            instruction_sets.push(SimdInstructionSet::SSE2);
204        }
205        if is_x86_feature_detected!("sse3") {
206            instruction_sets.push(SimdInstructionSet::SSE3);
207        }
208        if is_x86_feature_detected!("ssse3") {
209            instruction_sets.push(SimdInstructionSet::SSSE3);
210        }
211        if is_x86_feature_detected!("sse4.1") {
212            instruction_sets.push(SimdInstructionSet::SSE41);
213        }
214        if is_x86_feature_detected!("sse4.2") {
215            instruction_sets.push(SimdInstructionSet::SSE42);
216        }
217
218        // Check for AVX family
219        if is_x86_feature_detected!("avx") {
220            instruction_sets.push(SimdInstructionSet::AVX);
221        }
222        if is_x86_feature_detected!("avx2") {
223            instruction_sets.push(SimdInstructionSet::AVX2);
224        }
225        if is_x86_feature_detected!("avx512f") {
226            instruction_sets.push(SimdInstructionSet::AVX512);
227        }
228
229        Self {
230            instruction_sets,
231            architecture: CpuArchitecture::X86_64,
232            num_cores: num_cpus::get(),
233            cache_line_size: 64, // Common for x86_64
234            l1_cache_size: 32 * 1024,
235            l2_cache_size: 256 * 1024,
236            l3_cache_size: 8 * 1024 * 1024,
237        }
238    }
239
240    #[cfg(target_arch = "aarch64")]
241    fn detect_aarch64() -> Self {
242        let instruction_sets = vec![SimdInstructionSet::NEON];
243
244        // Note: SVE detection would require runtime feature detection
245        // which is not yet stable in Rust
246
247        Self {
248            instruction_sets,
249            architecture: CpuArchitecture::AArch64,
250            num_cores: num_cpus::get(),
251            cache_line_size: 64,
252            l1_cache_size: 64 * 1024,
253            l2_cache_size: 512 * 1024,
254            l3_cache_size: 4 * 1024 * 1024,
255        }
256    }
257
258    /// Check if a specific instruction set is available.
259    pub fn has_instruction_set(&self, isa: SimdInstructionSet) -> bool {
260        self.instruction_sets.contains(&isa)
261    }
262
263    /// Check if AVX2 is available.
264    pub fn has_avx2(&self) -> bool {
265        self.has_instruction_set(SimdInstructionSet::AVX2)
266    }
267
268    /// Check if AVX-512 is available.
269    pub fn has_avx512(&self) -> bool {
270        self.has_instruction_set(SimdInstructionSet::AVX512)
271    }
272
273    /// Check if NEON is available.
274    pub fn has_neon(&self) -> bool {
275        self.has_instruction_set(SimdInstructionSet::NEON)
276    }
277
278    /// Get the best available instruction set.
279    pub fn best_instruction_set(&self) -> SimdInstructionSet {
280        // Prefer more advanced instruction sets
281        if self.has_avx512() {
282            SimdInstructionSet::AVX512
283        } else if self.has_avx2() {
284            SimdInstructionSet::AVX2
285        } else if self.has_instruction_set(SimdInstructionSet::AVX) {
286            SimdInstructionSet::AVX
287        } else if self.has_instruction_set(SimdInstructionSet::SSE42) {
288            SimdInstructionSet::SSE42
289        } else if self.has_neon() {
290            SimdInstructionSet::NEON
291        } else {
292            SimdInstructionSet::None
293        }
294    }
295
296    /// Get the recommended vectorization factor for a given element size.
297    pub fn recommended_vector_size(&self, element_size: usize) -> usize {
298        let best_isa = self.best_instruction_set();
299        best_isa.vector_width_bytes() / element_size
300    }
301}
302
303impl Default for SimdCapabilities {
304    fn default() -> Self {
305        Self {
306            instruction_sets: vec![SimdInstructionSet::None],
307            architecture: CpuArchitecture::Other,
308            num_cores: num_cpus::get(),
309            cache_line_size: 64,
310            l1_cache_size: 32 * 1024,
311            l2_cache_size: 256 * 1024,
312            l3_cache_size: 8 * 1024 * 1024,
313        }
314    }
315}
316
317/// Aligned buffer for SIMD operations.
318pub struct AlignedBuffer<T> {
319    ptr: NonNull<T>,
320    len: usize,
321    alignment: usize,
322}
323
324impl<T> AlignedBuffer<T> {
325    /// Create a new aligned buffer with the specified alignment.
326    pub fn new_with_alignment(len: usize, alignment: usize) -> Result<Self, SimdError> {
327        if alignment == 0 || !alignment.is_power_of_two() {
328            return Err(SimdError::AlignmentError {
329                required: alignment,
330                actual: 0,
331            });
332        }
333
334        let size = len * std::mem::size_of::<T>();
335        let layout =
336            Layout::from_size_align(size, alignment).map_err(|_| SimdError::AlignmentError {
337                required: alignment,
338                actual: 0,
339            })?;
340
341        let ptr = unsafe { alloc(layout) as *mut T };
342        if ptr.is_null() {
343            return Err(SimdError::OperationFailed("Allocation failed".to_string()));
344        }
345
346        Ok(Self {
347            ptr: NonNull::new(ptr).expect("ptr is non-null after null check above"),
348            len,
349            alignment,
350        })
351    }
352
353    /// Create a new aligned buffer with default alignment (64 bytes).
354    pub fn new(len: usize) -> Result<Self, SimdError> {
355        Self::new_with_alignment(len, 64)
356    }
357
358    /// Get the length of the buffer.
359    pub fn len(&self) -> usize {
360        self.len
361    }
362
363    /// Check if the buffer is empty.
364    pub fn is_empty(&self) -> bool {
365        self.len == 0
366    }
367
368    /// Get the alignment of the buffer.
369    pub fn alignment(&self) -> usize {
370        self.alignment
371    }
372
373    /// Get a slice of the buffer.
374    pub fn as_slice(&self) -> &[T] {
375        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
376    }
377
378    /// Get a mutable slice of the buffer.
379    pub fn as_mut_slice(&mut self) -> &mut [T] {
380        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
381    }
382
383    /// Get the raw pointer.
384    pub fn as_ptr(&self) -> *const T {
385        self.ptr.as_ptr()
386    }
387
388    /// Get the mutable raw pointer.
389    pub fn as_mut_ptr(&mut self) -> *mut T {
390        self.ptr.as_ptr()
391    }
392
393    /// Check if the buffer is properly aligned.
394    pub fn is_aligned(&self) -> bool {
395        (self.ptr.as_ptr() as usize) % self.alignment == 0
396    }
397}
398
399impl<T> Drop for AlignedBuffer<T> {
400    fn drop(&mut self) {
401        let size = self.len * std::mem::size_of::<T>();
402        let layout =
403            Layout::from_size_align(size, self.alignment).expect("alignment is valid power of two");
404        unsafe {
405            dealloc(self.ptr.as_ptr() as *mut u8, layout);
406        }
407    }
408}
409
410unsafe impl<T: Send> Send for AlignedBuffer<T> {}
411unsafe impl<T: Sync> Sync for AlignedBuffer<T> {}
412
/// Marker for loops that are intended to be auto-vectorized.
///
/// This function is a no-op. Its body is empty and it emits no compiler directive, so calling
/// it does not by itself change whether the surrounding loop is vectorized. It only documents
/// intent at the call site.
///
/// To actually help the optimizer, iterate over exact-size slices or iterators
/// (e.g. `chunks_exact`) so that bounds checks can be elided.
#[inline(always)]
pub fn vectorize_hint() {}
419
/// Check if a pointer is aligned for SIMD operations.
///
/// Returns `true` when the address of `ptr` is a multiple of `alignment` bytes.
/// An `alignment` of 0 is not a meaningful requirement; it yields `false` instead of panicking
/// with a division by zero.
pub fn is_simd_aligned<T>(ptr: *const T, alignment: usize) -> bool {
    (ptr as usize).checked_rem(alignment) == Some(0)
}
424
/// Get the number of bytes to advance `ptr` so that its address becomes a multiple of
/// `alignment`.
///
/// Returns 0 if `ptr` is already aligned. An `alignment` of 0 imposes no requirement and also
/// returns 0, instead of panicking with a division by zero. Unlike `pointer::align_offset`,
/// the result is in bytes (not elements of `T`), and non-power-of-two alignments are accepted.
pub fn alignment_offset<T>(ptr: *const T, alignment: usize) -> usize {
    if alignment == 0 {
        return 0;
    }
    match (ptr as usize) % alignment {
        0 => 0,
        rem => alignment - rem,
    }
}
435
/// SIMD optimization hints.
///
/// A plain bundle of tuning parameters. Nothing in this type applies them; consumers read
/// the fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SimdOptimizationHints {
    /// Preferred vector size in bytes (filled from `SimdInstructionSet::vector_width_bytes`)
    pub vector_size: usize,

    /// Preferred alignment in bytes (filled from `SimdInstructionSet::preferred_alignment`)
    pub alignment: usize,

    /// Enable loop unrolling
    pub unroll_loops: bool,

    /// Unroll factor (1 when unrolling is disabled via `none()`)
    pub unroll_factor: usize,

    /// Enable data prefetching
    pub prefetch: bool,

    /// Prefetch distance (cache lines)
    pub prefetch_distance: usize,
}
457
458impl Default for SimdOptimizationHints {
459    fn default() -> Self {
460        let caps = SimdCapabilities::detect();
461        let best_isa = caps.best_instruction_set();
462
463        Self {
464            vector_size: best_isa.vector_width_bytes(),
465            alignment: best_isa.preferred_alignment(),
466            unroll_loops: true,
467            unroll_factor: 4,
468            prefetch: true,
469            prefetch_distance: 8,
470        }
471    }
472}
473
474impl SimdOptimizationHints {
475    /// Create hints for a specific instruction set.
476    pub fn for_instruction_set(isa: SimdInstructionSet) -> Self {
477        Self {
478            vector_size: isa.vector_width_bytes(),
479            alignment: isa.preferred_alignment(),
480            ..Default::default()
481        }
482    }
483
484    /// Disable all optimizations.
485    pub fn none() -> Self {
486        Self {
487            vector_size: 1,
488            alignment: std::mem::align_of::<f64>(),
489            unroll_loops: false,
490            unroll_factor: 1,
491            prefetch: false,
492            prefetch_distance: 0,
493        }
494    }
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_instruction_set_width() {
        assert_eq!(SimdInstructionSet::SSE.vector_width_bytes(), 16);
        assert_eq!(SimdInstructionSet::AVX.vector_width_bytes(), 32);
        assert_eq!(SimdInstructionSet::AVX2.vector_width_bytes(), 32);
        assert_eq!(SimdInstructionSet::AVX512.vector_width_bytes(), 64);
        assert_eq!(SimdInstructionSet::NEON.vector_width_bytes(), 16);
    }

    #[test]
    fn test_simd_instruction_set_lanes() {
        assert_eq!(SimdInstructionSet::AVX2.f32_lanes(), 8);
        assert_eq!(SimdInstructionSet::AVX2.f64_lanes(), 4);
        assert_eq!(SimdInstructionSet::AVX512.f32_lanes(), 16);
    }

    #[test]
    fn test_simd_capabilities_detection() {
        let caps = SimdCapabilities::detect();
        assert!(!caps.instruction_sets.is_empty());
        assert!(caps.num_cores > 0);
        assert!(caps.cache_line_size > 0);
    }

    #[test]
    fn test_simd_capabilities_best() {
        let caps = SimdCapabilities::detect();
        let best = caps.best_instruction_set();

        // On unrecognised architectures `detect()` falls back to `Default`, which reports
        // `[SimdInstructionSet::None]`. A non-empty list therefore does not imply SIMD support;
        // only a real instruction set does.
        let has_real_simd = caps
            .instruction_sets
            .iter()
            .any(|&isa| isa != SimdInstructionSet::None);

        if has_real_simd {
            assert_ne!(best, SimdInstructionSet::None);
        }
    }

    #[test]
    fn test_aligned_buffer_creation() {
        let buffer = AlignedBuffer::<f32>::new(1024).expect("unwrap");
        assert_eq!(buffer.len(), 1024);
        assert_eq!(buffer.alignment(), 64);
        assert!(buffer.is_aligned());
    }

    #[test]
    fn test_aligned_buffer_custom_alignment() {
        let buffer = AlignedBuffer::<f32>::new_with_alignment(512, 32).expect("unwrap");
        assert_eq!(buffer.len(), 512);
        assert_eq!(buffer.alignment(), 32);
        assert!(buffer.is_aligned());
    }

    #[test]
    fn test_aligned_buffer_slice() {
        let mut buffer = AlignedBuffer::<f32>::new(10).expect("unwrap");
        let slice = buffer.as_mut_slice();
        slice[0] = 1.0;
        slice[1] = 2.0;

        let const_slice = buffer.as_slice();
        assert_eq!(const_slice[0], 1.0);
        assert_eq!(const_slice[1], 2.0);
    }

    #[test]
    fn test_is_simd_aligned() {
        let buffer = AlignedBuffer::<f32>::new(1024).expect("unwrap");
        assert!(is_simd_aligned(buffer.as_ptr(), 64));
        assert!(is_simd_aligned(buffer.as_ptr(), 32));
        assert!(is_simd_aligned(buffer.as_ptr(), 16));
    }

    #[test]
    fn test_alignment_offset() {
        let buffer = AlignedBuffer::<u8>::new(1024).expect("unwrap");
        let offset = alignment_offset(buffer.as_ptr(), 64);
        assert_eq!(offset, 0); // Already aligned
    }

    #[test]
    fn test_simd_optimization_hints_default() {
        let hints = SimdOptimizationHints::default();
        assert!(hints.vector_size > 0);
        assert!(hints.alignment > 0);
        assert!(hints.unroll_loops);
    }

    #[test]
    fn test_simd_optimization_hints_for_isa() {
        let hints = SimdOptimizationHints::for_instruction_set(SimdInstructionSet::AVX2);
        assert_eq!(hints.vector_size, 32);
        assert_eq!(hints.alignment, 32);
    }

    #[test]
    fn test_simd_optimization_hints_none() {
        let hints = SimdOptimizationHints::none();
        assert_eq!(hints.vector_size, 1);
        assert!(!hints.unroll_loops);
        assert!(!hints.prefetch);
    }

    #[test]
    fn test_instruction_set_name() {
        assert_eq!(SimdInstructionSet::AVX2.name(), "AVX2");
        assert_eq!(SimdInstructionSet::NEON.name(), "NEON");
        assert_eq!(SimdInstructionSet::AVX512.name(), "AVX-512");
    }
}