// rustalign_simd/sse_wrapper.rs
/// Byte-backed register value used by the SSE wrapper operations.
///
/// The backing buffer is 32 bytes long. The 128-bit operations implemented on
/// this type (`load_128`, `store_128`, `cmpeq_epi16`, `adds_epi16`,
/// `max_epi16`, `xor`, `hmax_epi16`) read and write only the first 16 bytes,
/// i.e. eight little-endian 16-bit lanes. `set1_epi16` and `extract_epi16`
/// address all sixteen lanes of the buffer.
#[derive(Copy, Clone, Debug)]
pub struct SseReg {
    /// Raw register bytes; 16-bit lanes are stored little-endian.
    data: [u8; 32],
}
17
18impl SseReg {
19 pub fn zero() -> Self {
21 Self { data: [0; 32] }
22 }
23
24 pub fn set1_epi16(value: i16) -> Self {
26 let mut result = Self::zero();
27 let bytes = value.to_le_bytes();
28 for i in 0..16 {
29 result.data[i * 2] = bytes[0];
30 result.data[i * 2 + 1] = bytes[1];
31 }
32 result
33 }
34
35 #[cfg(target_arch = "x86_64")]
37 pub fn load_128(ptr: &[u8; 16]) -> Self {
38 use std::arch::x86_64::*;
39 unsafe {
40 let reg = _mm_loadu_si128(ptr.as_ptr() as *const __m128i);
41 let mut result = Self::zero();
42 std::ptr::copy_nonoverlapping(
43 ® as *const _ as *const u8,
44 result.data.as_mut_ptr(),
45 16,
46 );
47 result
48 }
49 }
50
51 #[cfg(not(target_arch = "x86_64"))]
53 pub fn load_128(ptr: &[u8; 16]) -> Self {
54 let mut result = Self::zero();
55 result.data[..16].copy_from_slice(ptr);
56 result
57 }
58
59 #[cfg(target_arch = "x86_64")]
61 pub fn store_128(&self, ptr: &mut [u8; 16]) {
62 use std::arch::x86_64::*;
63 unsafe {
64 let reg = _mm_loadu_si128(self.data.as_ptr() as *const __m128i);
65 _mm_storeu_si128(ptr.as_mut_ptr() as *mut __m128i, reg);
66 }
67 }
68
69 #[cfg(not(target_arch = "x86_64"))]
71 pub fn store_128(&self, ptr: &mut [u8; 16]) {
72 ptr.copy_from_slice(&self.data[..16]);
73 }
74
75 pub fn extract_epi16(&self, index: i32) -> i16 {
77 let idx = index as usize;
78 let byte_idx = idx * 2;
79 i16::from_le_bytes([self.data[byte_idx], self.data[byte_idx + 1]])
80 }
81
82 #[cfg(target_arch = "x86_64")]
84 pub fn cmpeq_epi16(&self, other: &Self) -> Self {
85 use std::arch::x86_64::*;
86 unsafe {
87 let a = _mm_loadu_si128(self.data.as_ptr() as *const __m128i);
88 let b = _mm_loadu_si128(other.data.as_ptr() as *const __m128i);
89 let result = _mm_cmpeq_epi16(a, b);
90 let mut out = Self::zero();
91 std::ptr::copy_nonoverlapping(
92 &result as *const _ as *const u8,
93 out.data.as_mut_ptr(),
94 16,
95 );
96 out
97 }
98 }
99
100 #[cfg(not(target_arch = "x86_64"))]
102 pub fn cmpeq_epi16(&self, other: &Self) -> Self {
103 let mut result = Self::zero();
104 for i in 0..8 {
105 let a = self.extract_epi16(i as i32);
106 let b = other.extract_epi16(i as i32);
107 let val = if a == b { 0xffff } else { 0 };
108 result.data[i * 2] = (val & 0xff) as u8;
109 result.data[i * 2 + 1] = ((val >> 8) & 0xff) as u8;
110 }
111 result
112 }
113
114 #[cfg(target_arch = "x86_64")]
116 pub fn adds_epi16(&self, other: &Self) -> Self {
117 use std::arch::x86_64::*;
118 unsafe {
119 let a = _mm_loadu_si128(self.data.as_ptr() as *const __m128i);
120 let b = _mm_loadu_si128(other.data.as_ptr() as *const __m128i);
121 let result = _mm_adds_epi16(a, b);
122 let mut out = Self::zero();
123 std::ptr::copy_nonoverlapping(
124 &result as *const _ as *const u8,
125 out.data.as_mut_ptr(),
126 16,
127 );
128 out
129 }
130 }
131
132 #[cfg(not(target_arch = "x86_64"))]
134 pub fn adds_epi16(&self, other: &Self) -> Self {
135 let mut result = Self::zero();
136 for i in 0..8 {
137 let a = self.extract_epi16(i as i32);
138 let b = other.extract_epi16(i as i32);
139 let val = a.saturating_add(b);
140 result.data[i * 2] = (val & 0xff) as u8;
141 result.data[i * 2 + 1] = ((val >> 8) & 0xff) as u8;
142 }
143 result
144 }
145
146 #[cfg(target_arch = "x86_64")]
148 pub fn max_epi16(&self, other: &Self) -> Self {
149 use std::arch::x86_64::*;
150 unsafe {
151 let a = _mm_loadu_si128(self.data.as_ptr() as *const __m128i);
152 let b = _mm_loadu_si128(other.data.as_ptr() as *const __m128i);
153 let result = _mm_max_epi16(a, b);
154 let mut out = Self::zero();
155 std::ptr::copy_nonoverlapping(
156 &result as *const _ as *const u8,
157 out.data.as_mut_ptr(),
158 16,
159 );
160 out
161 }
162 }
163
164 #[cfg(not(target_arch = "x86_64"))]
166 pub fn max_epi16(&self, other: &Self) -> Self {
167 let mut result = Self::zero();
168 for i in 0..8 {
169 let a = self.extract_epi16(i as i32);
170 let b = other.extract_epi16(i as i32);
171 let val = a.max(b);
172 result.data[i * 2] = (val & 0xff) as u8;
173 result.data[i * 2 + 1] = ((val >> 8) & 0xff) as u8;
174 }
175 result
176 }
177
178 pub fn hmax_epi16(&self) -> i16 {
182 let mut result = self.extract_epi16(0);
183 for i in 1..8 {
184 result = result.max(self.extract_epi16(i));
185 }
186 result
187 }
188
189 #[cfg(target_arch = "x86_64")]
191 pub fn xor(&self, other: &Self) -> Self {
192 use std::arch::x86_64::*;
193 unsafe {
194 let a = _mm_loadu_si128(self.data.as_ptr() as *const __m128i);
195 let b = _mm_loadu_si128(other.data.as_ptr() as *const __m128i);
196 let result = _mm_xor_si128(a, b);
197 let mut out = Self::zero();
198 std::ptr::copy_nonoverlapping(
199 &result as *const _ as *const u8,
200 out.data.as_mut_ptr(),
201 16,
202 );
203 out
204 }
205 }
206
207 #[cfg(not(target_arch = "x86_64"))]
209 pub fn xor(&self, other: &Self) -> Self {
210 let mut result = Self::zero();
211 for i in 0..16 {
212 result.data[i] = self.data[i] ^ other.data[i];
213 }
214 result
215 }
216}
217
218impl Default for SseReg {
219 fn default() -> Self {
220 Self::zero()
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that each of the eight low 16-bit lanes of `reg` equals `expected`.
    fn assert_low_lanes(reg: &SseReg, expected: i16) {
        for lane in 0..8 {
            assert_eq!(reg.extract_epi16(lane), expected, "lane {}", lane);
        }
    }

    #[test]
    fn test_sse_reg_zero() {
        let reg = SseReg::zero();
        assert_eq!(reg.extract_epi16(0), 0);
        assert_eq!(reg.extract_epi16(7), 0);
    }

    #[test]
    fn test_sse_reg_set1() {
        assert_low_lanes(&SseReg::set1_epi16(-1), -1);
    }

    #[test]
    fn test_sse_reg_set1_42() {
        assert_low_lanes(&SseReg::set1_epi16(42), 42);
    }

    #[test]
    fn test_sse_cmpeq_epi16() {
        let a = SseReg::set1_epi16(42);
        let b = SseReg::set1_epi16(42);
        // Equal lanes compare to all bits set, which reads back as -1.
        assert_low_lanes(&a.cmpeq_epi16(&b), -1);
    }

    #[test]
    fn test_sse_cmpeq_epi16_not_equal() {
        let a = SseReg::set1_epi16(42);
        let b = SseReg::set1_epi16(43);
        assert_low_lanes(&a.cmpeq_epi16(&b), 0);
    }

    #[test]
    fn test_sse_adds_epi16() {
        let a = SseReg::set1_epi16(100);
        let b = SseReg::set1_epi16(50);
        assert_low_lanes(&a.adds_epi16(&b), 150);
    }

    #[test]
    fn test_sse_adds_epi16_saturate() {
        let a = SseReg::set1_epi16(30000);
        let b = SseReg::set1_epi16(10000);
        // 40000 does not fit in i16, so the sum clamps to i16::MAX.
        assert_low_lanes(&a.adds_epi16(&b), 32767);
    }

    #[test]
    fn test_sse_max_epi16() {
        let a = SseReg::set1_epi16(42);
        let b = SseReg::set1_epi16(100);
        assert_low_lanes(&a.max_epi16(&b), 100);
    }

    #[test]
    fn test_sse_hmax_epi16() {
        // Low lanes hold 0, 10, 20, ..., 70; the second half stays zero.
        let mut data = [0u8; 32];
        for (lane, pair) in data.chunks_exact_mut(2).take(8).enumerate() {
            pair.copy_from_slice(&((lane * 10) as i16).to_le_bytes());
        }
        let reg = SseReg { data };
        assert_eq!(reg.hmax_epi16(), 70);
    }

    #[test]
    fn test_sse_xor() {
        let a = SseReg::set1_epi16(-1);
        let b = SseReg::set1_epi16(-1);
        assert_low_lanes(&a.xor(&b), 0);
    }

    #[test]
    fn test_sse_load_store() {
        let input: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
        let reg = SseReg::load_128(&input);
        let mut output = [0u8; 16];
        reg.store_128(&mut output);
        assert_eq!(input, output);
    }
}