Skip to main content

rustalign_simd/
sse_wrapper.rs

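//! SSE wrapper types and operations
//!
//! This module provides wrappers around SIMD intrinsics that match
//! the C++ sse_wrap.h interface. Runtime dispatch for SSE2/SSE4.2/AVX2 is the
//! intended design.
//!
//! NOTE(review): no runtime dispatch appears in this file. The operations
//! choose between x86_64 SSE2 intrinsics and scalar fallbacks at compile time
//! via `#[cfg(target_arch = "x86_64")]`.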
5
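// CPU feature detection imports - only available on x86_64
// NOTE(review): no imports follow this comment; the x86_64 code paths import
// from `std::arch::x86_64` locally inside the functions that need it.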
7
/// SIMD register type wrapper
///
/// Stands in for the `__m128i` (SSE2) and `__m256i` (AVX2) register types,
/// matching the C++ `SSERegI` typedef.
///
/// The value is kept as 32 plain bytes. [`SseReg::set1_epi16`] fills all
/// 32 bytes and [`SseReg::extract_epi16`] can read any of the 16 stored
/// lanes; every other operation works on the low 128 bits only and leaves
/// the upper 16 bytes of its result zeroed.
#[derive(Copy, Clone, Debug)]
pub struct SseReg {
    /// Internal storage - uses u8 array for portability
    data: [u8; 32],
}

impl SseReg {
    /// Size in bytes of the low 128-bit half used by the 128-bit operations.
    const LOW_HALF: usize = 16;

    /// Create a new zero register
    pub fn zero() -> Self {
        Self { data: [0; 32] }
    }

    /// Set all 16-bit elements to the same value
    ///
    /// All 16 lanes of the 32-byte storage are written (little-endian),
    /// not only the eight lanes of the low 128 bits.
    pub fn set1_epi16(value: i16) -> Self {
        let pair = value.to_le_bytes();
        let mut result = Self::zero();
        for lane in result.data.chunks_exact_mut(2) {
            lane.copy_from_slice(&pair);
        }
        result
    }

    /// Load 128 bits into the low half of a register
    ///
    /// No alignment is required of `ptr`. The upper 16 bytes of the result
    /// are zero. A plain byte copy is used on every target, so no `unsafe`
    /// is needed.
    pub fn load_128(ptr: &[u8; 16]) -> Self {
        let mut result = Self::zero();
        result.data[..Self::LOW_HALF].copy_from_slice(ptr);
        result
    }

    /// Store the low 128 bits of the register into `ptr`
    ///
    /// No alignment is required of `ptr`.
    pub fn store_128(&self, ptr: &mut [u8; 16]) {
        ptr.copy_from_slice(&self.data[..Self::LOW_HALF]);
    }

    /// Extract a 16-bit element by index
    ///
    /// Lanes are little-endian byte pairs; `index` 0..8 addresses the low
    /// 128 bits and 8..16 the upper half of the storage.
    ///
    /// # Panics
    ///
    /// Panics if `index` is negative or not less than 16.
    pub fn extract_epi16(&self, index: i32) -> i16 {
        assert!(
            (0..16).contains(&index),
            "extract_epi16: lane index {} out of range 0..16",
            index
        );
        // The assert above guarantees the cast is lossless and both byte
        // offsets are inside the 32-byte storage.
        let byte = index as usize * 2;
        i16::from_le_bytes([self.data[byte], self.data[byte + 1]])
    }

    /// Read the low 128 bits of the storage into an SSE register.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    fn low_m128i(&self) -> std::arch::x86_64::__m128i {
        use std::arch::x86_64::{__m128i, _mm_loadu_si128};
        // SAFETY: `data` is 32 bytes long, so the 16 bytes read from its
        // start are in bounds; `_mm_loadu_si128` has no alignment
        // requirement; SSE2 is enabled by default on x86_64 targets.
        unsafe { _mm_loadu_si128(self.data.as_ptr() as *const __m128i) }
    }

    /// Build a register whose low 128 bits are `reg` and whose upper half is zero.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    fn from_m128i(reg: std::arch::x86_64::__m128i) -> Self {
        use std::arch::x86_64::{__m128i, _mm_storeu_si128};
        let mut out = Self::zero();
        // SAFETY: `out.data` is 32 bytes long, so the 16 bytes written at its
        // start are in bounds; `_mm_storeu_si128` has no alignment
        // requirement; SSE2 is enabled by default on x86_64 targets.
        unsafe { _mm_storeu_si128(out.data.as_mut_ptr() as *mut __m128i, reg) };
        out
    }

    /// Apply `op` to each pair of corresponding low 16-bit lanes (scalar path).
    ///
    /// Only the eight lanes of the low 128 bits are processed; the upper half
    /// of the result stays zero.
    #[cfg(not(target_arch = "x86_64"))]
    fn zip_lanes_128(&self, other: &Self, op: impl Fn(i16, i16) -> i16) -> Self {
        let mut out = Self::zero();
        let lhs = self.data[..Self::LOW_HALF].chunks_exact(2);
        let rhs = other.data[..Self::LOW_HALF].chunks_exact(2);
        let dst = out.data[..Self::LOW_HALF].chunks_exact_mut(2);
        for ((dst, a), b) in dst.zip(lhs).zip(rhs) {
            let a = i16::from_le_bytes([a[0], a[1]]);
            let b = i16::from_le_bytes([b[0], b[1]]);
            dst.copy_from_slice(&op(a, b).to_le_bytes());
        }
        out
    }

    /// Compare 16-bit integers for equality
    ///
    /// Each of the eight low lanes of the result is all ones (`-1`) where the
    /// inputs are equal and `0` otherwise.
    #[cfg(target_arch = "x86_64")]
    pub fn cmpeq_epi16(&self, other: &Self) -> Self {
        use std::arch::x86_64::_mm_cmpeq_epi16;
        // SAFETY: register-only SSE2 instruction; SSE2 is enabled by default
        // on x86_64 targets.
        Self::from_m128i(unsafe { _mm_cmpeq_epi16(self.low_m128i(), other.low_m128i()) })
    }

    /// Compare 16-bit integers for equality (scalar fallback)
    ///
    /// Each of the eight low lanes of the result is all ones (`-1`) where the
    /// inputs are equal and `0` otherwise.
    #[cfg(not(target_arch = "x86_64"))]
    pub fn cmpeq_epi16(&self, other: &Self) -> Self {
        self.zip_lanes_128(other, |a, b| if a == b { -1 } else { 0 })
    }

    /// Add signed 16-bit integers with saturation
    ///
    /// Operates on the eight low lanes; results clamp to the `i16` range.
    #[cfg(target_arch = "x86_64")]
    pub fn adds_epi16(&self, other: &Self) -> Self {
        use std::arch::x86_64::_mm_adds_epi16;
        // SAFETY: register-only SSE2 instruction; SSE2 is enabled by default
        // on x86_64 targets.
        Self::from_m128i(unsafe { _mm_adds_epi16(self.low_m128i(), other.low_m128i()) })
    }

    /// Add signed 16-bit integers with saturation (scalar fallback)
    ///
    /// Operates on the eight low lanes; results clamp to the `i16` range.
    #[cfg(not(target_arch = "x86_64"))]
    pub fn adds_epi16(&self, other: &Self) -> Self {
        self.zip_lanes_128(other, |a, b| a.saturating_add(b))
    }

    /// Maximum of signed 16-bit integers
    ///
    /// Operates lane-wise on the eight low lanes.
    #[cfg(target_arch = "x86_64")]
    pub fn max_epi16(&self, other: &Self) -> Self {
        use std::arch::x86_64::_mm_max_epi16;
        // SAFETY: register-only SSE2 instruction; SSE2 is enabled by default
        // on x86_64 targets.
        Self::from_m128i(unsafe { _mm_max_epi16(self.low_m128i(), other.low_m128i()) })
    }

    /// Maximum of signed 16-bit integers (scalar fallback)
    ///
    /// Operates lane-wise on the eight low lanes.
    #[cfg(not(target_arch = "x86_64"))]
    pub fn max_epi16(&self, other: &Self) -> Self {
        self.zip_lanes_128(other, |a, b| a.max(b))
    }

    /// Compute horizontal maximum of i16 elements
    ///
    /// Considers the eight lanes of the low 128 bits only.
    /// This matches the C++ sse_max_score_i16 macro.
    pub fn hmax_epi16(&self) -> i16 {
        self.data[..Self::LOW_HALF]
            .chunks_exact(2)
            .map(|pair| i16::from_le_bytes([pair[0], pair[1]]))
            .fold(i16::MIN, i16::max)
    }

    /// Bitwise XOR
    ///
    /// Operates on the low 128 bits; the upper half of the result is zero.
    #[cfg(target_arch = "x86_64")]
    pub fn xor(&self, other: &Self) -> Self {
        use std::arch::x86_64::_mm_xor_si128;
        // SAFETY: register-only SSE2 instruction; SSE2 is enabled by default
        // on x86_64 targets.
        Self::from_m128i(unsafe { _mm_xor_si128(self.low_m128i(), other.low_m128i()) })
    }

    /// Bitwise XOR (scalar fallback)
    ///
    /// Operates on the low 128 bits; the upper half of the result is zero.
    #[cfg(not(target_arch = "x86_64"))]
    pub fn xor(&self, other: &Self) -> Self {
        let mut out = Self::zero();
        let lhs = &self.data[..Self::LOW_HALF];
        let rhs = &other.data[..Self::LOW_HALF];
        for ((dst, a), b) in out.data[..Self::LOW_HALF].iter_mut().zip(lhs).zip(rhs) {
            *dst = a ^ b;
        }
        out
    }
}

/// The default register is the all-zero register.
impl Default for SseReg {
    fn default() -> Self {
        Self::zero()
    }
}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a register whose eight low 16-bit lanes hold `lanes`.
    fn reg_from_lanes(lanes: [i16; 8]) -> SseReg {
        let mut bytes = [0u8; 16];
        for (dst, lane) in bytes.chunks_exact_mut(2).zip(lanes.iter()) {
            dst.copy_from_slice(&lane.to_le_bytes());
        }
        SseReg::load_128(&bytes)
    }

    /// Read back the eight low 16-bit lanes of `reg`.
    fn low_lanes(reg: &SseReg) -> [i16; 8] {
        let mut lanes = [0i16; 8];
        for (idx, slot) in lanes.iter_mut().enumerate() {
            *slot = reg.extract_epi16(idx as i32);
        }
        lanes
    }

    #[test]
    fn test_sse_reg_zero() {
        assert_eq!(low_lanes(&SseReg::zero()), [0; 8]);
    }

    #[test]
    fn test_sse_reg_set1() {
        assert_eq!(low_lanes(&SseReg::set1_epi16(-1)), [-1; 8]);
    }

    #[test]
    fn test_sse_reg_set1_42() {
        assert_eq!(low_lanes(&SseReg::set1_epi16(42)), [42; 8]);
    }

    #[test]
    fn test_sse_cmpeq_epi16() {
        let a = SseReg::set1_epi16(42);
        let b = SseReg::set1_epi16(42);
        // Equal lanes compare to all ones (0xffff, i.e. -1 as i16).
        assert_eq!(low_lanes(&a.cmpeq_epi16(&b)), [-1; 8]);
    }

    #[test]
    fn test_sse_cmpeq_epi16_not_equal() {
        let a = SseReg::set1_epi16(42);
        let b = SseReg::set1_epi16(43);
        // Differing lanes compare to zero.
        assert_eq!(low_lanes(&a.cmpeq_epi16(&b)), [0; 8]);
    }

    #[test]
    fn test_sse_cmpeq_epi16_mixed_lanes() {
        // Distinct values per lane catch lane-order mix-ups that uniform
        // registers cannot.
        let a = reg_from_lanes([1, -2, 30000, -30000, 0, 7, i16::MAX, i16::MIN]);
        let b = reg_from_lanes([1, 2, 10000, -10000, 0, -7, 1, -1]);
        assert_eq!(low_lanes(&a.cmpeq_epi16(&b)), [-1, 0, 0, 0, -1, 0, 0, 0]);
    }

    #[test]
    fn test_sse_adds_epi16() {
        let a = SseReg::set1_epi16(100);
        let b = SseReg::set1_epi16(50);
        assert_eq!(low_lanes(&a.adds_epi16(&b)), [150; 8]);
    }

    #[test]
    fn test_sse_adds_epi16_saturate() {
        let a = SseReg::set1_epi16(30000);
        let b = SseReg::set1_epi16(10000);
        // Should saturate at i16::MAX = 32767
        assert_eq!(low_lanes(&a.adds_epi16(&b)), [i16::MAX; 8]);
    }

    #[test]
    fn test_sse_adds_epi16_saturate_negative() {
        let a = SseReg::set1_epi16(-30000);
        let b = SseReg::set1_epi16(-10000);
        // Should saturate at i16::MIN = -32768
        assert_eq!(low_lanes(&a.adds_epi16(&b)), [i16::MIN; 8]);
    }

    #[test]
    fn test_sse_adds_epi16_mixed_lanes() {
        let a = reg_from_lanes([1, -2, 30000, -30000, 0, 7, i16::MAX, i16::MIN]);
        let b = reg_from_lanes([1, 2, 10000, -10000, 0, -7, 1, -1]);
        assert_eq!(
            low_lanes(&a.adds_epi16(&b)),
            [2, 0, i16::MAX, i16::MIN, 0, 0, i16::MAX, i16::MIN]
        );
    }

    #[test]
    fn test_sse_max_epi16() {
        let a = SseReg::set1_epi16(42);
        let b = SseReg::set1_epi16(100);
        assert_eq!(low_lanes(&a.max_epi16(&b)), [100; 8]);
    }

    #[test]
    fn test_sse_max_epi16_mixed_lanes() {
        // The comparison must be signed: -1 beats i16::MIN, 2 beats -2.
        let a = reg_from_lanes([1, -2, 30000, -30000, 0, 7, i16::MAX, i16::MIN]);
        let b = reg_from_lanes([1, 2, 10000, -10000, 0, -7, 1, -1]);
        assert_eq!(
            low_lanes(&a.max_epi16(&b)),
            [1, 2, 30000, -10000, 0, 7, i16::MAX, -1]
        );
    }

    #[test]
    fn test_sse_hmax_epi16() {
        let reg = reg_from_lanes([0, 10, 20, 30, 40, 50, 60, 70]);
        assert_eq!(reg.hmax_epi16(), 70); // max of 0, 10, 20, ..., 70
    }

    #[test]
    fn test_sse_hmax_epi16_negative_unordered() {
        // Maximum is neither first nor last, and every lane is negative.
        let reg = reg_from_lanes([-5, -300, -1, i16::MIN, -2, -77, -9, -4]);
        assert_eq!(reg.hmax_epi16(), -1);
    }

    #[test]
    fn test_sse_xor() {
        let a = SseReg::set1_epi16(-1);
        let b = SseReg::set1_epi16(-1);
        assert_eq!(low_lanes(&a.xor(&b)), [0; 8]);
    }

    #[test]
    fn test_sse_xor_mixed_lanes() {
        let lhs = [0x0f0f, 0x00ff, -1, 0, 0x1234, 1, i16::MAX, i16::MIN];
        let rhs = [0x00ff, 0x00ff, 0x5555, 0, 0x4321, 2, -1, -1];
        let mut expected = [0i16; 8];
        for (slot, (a, b)) in expected.iter_mut().zip(lhs.iter().zip(rhs.iter())) {
            *slot = a ^ b;
        }
        let result = reg_from_lanes(lhs).xor(&reg_from_lanes(rhs));
        assert_eq!(low_lanes(&result), expected);
    }

    #[test]
    fn test_sse_load_store() {
        let input: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
        let reg = SseReg::load_128(&input);
        let mut output = [0u8; 16];
        reg.store_128(&mut output);
        assert_eq!(input, output);
    }
}