Skip to main content

rustdom_x/
m128.rs

1use std::fmt;
2
3// ---- AES S-boxes ----
4
/// AES forward S-box: `SBOX[b]` is the SubBytes substitution for byte `b`.
///
/// Laid out as 16 rows of 16 entries; the row is selected by the high nibble of
/// the index and the column by the low nibble. Consumed by `make_enc_tables`.
const SBOX: [u8; 256] = [
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
];
23
/// AES inverse S-box: `INV_SBOX[b]` is the InvSubBytes substitution for byte `b`.
///
/// Same 16x16 layout as `SBOX` (high nibble selects the row). Consumed by
/// `make_dec_tables`.
const INV_SBOX: [u8; 256] = [
    0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
    0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
    0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
    0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
    0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
    0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
    0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
    0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
    0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
    0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
    0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
    0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
    0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
    0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
    0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
    0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d,
];
42
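// ---- Precomputed T-tables (SubBytes + MixColumns folded into one lookup) ----
//
// ENC_TABLE.0 ..= ENC_TABLE.3: encryption tables, one per state row (0..=3).
// DEC_TABLE.0 ..= DEC_TABLE.3: decryption tables, one per state row (0..=3).
// Each entry packs the 4-byte column contribution of one input byte into a u32,
// with output row 0 in the least-significant byte (little-endian).
// ShiftRows / InvShiftRows is not part of the tables; it is applied by the byte
// indexing in `aesenc` / `aesdec`.
// The tables are generated at compile time, so nothing is built at runtime.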
49
/// Multiplies `a` by 2 in GF(2^8), reducing with the AES polynomial (low byte 0x1b).
///
/// Returns `a << 1`, XORed with `0x1b` when the top bit of `a` was set.
const fn xtime(a: u8) -> u8 {
    // `a >> 7` is 1 exactly when the top bit is set, so this selects either
    // 0x1b or 0x00 as the reduction term without branching.
    let reduction = 0x1b * (a >> 7);
    (a << 1) ^ reduction
}
57
58const fn make_enc_tables() -> ([u32; 256], [u32; 256], [u32; 256], [u32; 256]) {
59    let mut t0 = [0u32; 256];
60    let mut t1 = [0u32; 256];
61    let mut t2 = [0u32; 256];
62    let mut t3 = [0u32; 256];
63    let mut i = 0usize;
64    while i < 256 {
65        let s = SBOX[i];
66        let x2 = xtime(s);
67        let x3 = x2 ^ s;
68        // MixColumns column vector for byte in row 0: [2s, s, s, 3s]
69        // Rotated for rows 1-3.
70        t0[i] = (x2 as u32) | ((s as u32) << 8) | ((s as u32) << 16) | ((x3 as u32) << 24);
71        t1[i] = (x3 as u32) | ((x2 as u32) << 8) | ((s as u32) << 16) | ((s as u32) << 24);
72        t2[i] = (s as u32) | ((x3 as u32) << 8) | ((x2 as u32) << 16) | ((s as u32) << 24);
73        t3[i] = (s as u32) | ((s as u32) << 8) | ((x3 as u32) << 16) | ((x2 as u32) << 24);
74        i += 1;
75    }
76    (t0, t1, t2, t3)
77}
78
79const fn make_dec_tables() -> ([u32; 256], [u32; 256], [u32; 256], [u32; 256]) {
80    let mut t0 = [0u32; 256];
81    let mut t1 = [0u32; 256];
82    let mut t2 = [0u32; 256];
83    let mut t3 = [0u32; 256];
84    let mut i = 0usize;
85    while i < 256 {
86        let s = INV_SBOX[i];
87        let x2 = xtime(s);
88        let x4 = xtime(x2);
89        let x8 = xtime(x4);
90        let x9 = x8 ^ s;
91        let xb = x8 ^ x2 ^ s;
92        let xd = x8 ^ x4 ^ s;
93        let xe = x8 ^ x4 ^ x2;
94        // InvMixColumns column vector for byte in row 0: [0e, 09, 0d, 0b]
95        // Rotated for rows 1-3.
96        t0[i] = (xe as u32) | ((x9 as u32) << 8) | ((xd as u32) << 16) | ((xb as u32) << 24);
97        t1[i] = (xb as u32) | ((xe as u32) << 8) | ((x9 as u32) << 16) | ((xd as u32) << 24);
98        t2[i] = (xd as u32) | ((xb as u32) << 8) | ((xe as u32) << 16) | ((x9 as u32) << 24);
99        t3[i] = (x9 as u32) | ((xd as u32) << 8) | ((xb as u32) << 16) | ((xe as u32) << 24);
100        i += 1;
101    }
102    (t0, t1, t2, t3)
103}
104
/// Encryption T-tables `(T0, T1, T2, T3)`, evaluated from `make_enc_tables` in a
/// constant context. Table `N` is indexed with the state byte taken from row `N`.
static ENC_TABLE: ([u32; 256], [u32; 256], [u32; 256], [u32; 256]) = make_enc_tables();
/// Decryption T-tables `(T0, T1, T2, T3)`, evaluated from `make_dec_tables` in a
/// constant context. Table `N` is indexed with the state byte taken from row `N`.
static DEC_TABLE: ([u32; 256], [u32; 256], [u32; 256], [u32; 256]) = make_dec_tables();
107
108// ---- AES round functions ----
109
110/// One AES encryption round: SubBytes + ShiftRows + MixColumns + AddRoundKey.
111/// Matches _mm_aesenc_si128 semantics exactly.
112fn aesenc(state: u128, key: u128) -> u128 {
113    let b = state.to_le_bytes();
114
115    // Column-major layout: b[c*4 + r] = byte at (col=c, row=r).
116    // ShiftRows rotates row r left by r, so output col c row r reads
117    // from input col (c + r) % 4, row r.
118    // T-tables fold SubBytes + MixColumns into a single lookup per byte.
119    let mut out = [0u32; 4];
120    for c in 0..4usize {
121        let b0 = b[((c) % 4) * 4] as usize; // row 0
122        let b1 = b[((c + 1) % 4) * 4 + 1] as usize; // row 1
123        let b2 = b[((c + 2) % 4) * 4 + 2] as usize; // row 2
124        let b3 = b[((c + 3) % 4) * 4 + 3] as usize; // row 3
125        out[c] = ENC_TABLE.0[b0] ^ ENC_TABLE.1[b1] ^ ENC_TABLE.2[b2] ^ ENC_TABLE.3[b3];
126    }
127
128    let mut result = [0u8; 16];
129    for c in 0..4 {
130        result[c * 4..c * 4 + 4].copy_from_slice(&out[c].to_le_bytes());
131    }
132    u128::from_le_bytes(result) ^ key
133}
134
135/// One AES decryption round: InvShiftRows + InvSubBytes + InvMixColumns + AddRoundKey.
136/// Matches _mm_aesdec_si128 semantics exactly.
137fn aesdec(state: u128, key: u128) -> u128 {
138    let b = state.to_le_bytes();
139
140    // InvShiftRows rotates row r right by r, so output col c row r reads
141    // from input col (c + 4 - r) % 4, row r.
142    let mut out = [0u32; 4];
143    for c in 0..4usize {
144        let b0 = b[((c) % 4) * 4] as usize; // row 0, no shift
145        let b1 = b[((c + 3) % 4) * 4 + 1] as usize; // row 1, right 1
146        let b2 = b[((c + 2) % 4) * 4 + 2] as usize; // row 2, right 2
147        let b3 = b[((c + 1) % 4) * 4 + 3] as usize; // row 3, right 3
148        out[c] = DEC_TABLE.0[b0] ^ DEC_TABLE.1[b1] ^ DEC_TABLE.2[b2] ^ DEC_TABLE.3[b3];
149    }
150
151    let mut result = [0u8; 16];
152    for c in 0..4 {
153        result[c * 4..c * 4 + 4].copy_from_slice(&out[c].to_le_bytes());
154    }
155    u128::from_le_bytes(result) ^ key
156}
157
158// ---- m128i ----
159#[cfg(target_arch = "x86_64")]
160use std::arch::x86_64::*;
161
162#[cfg(target_arch = "aarch64")]
163use std::arch::aarch64::*;
164
165#[allow(nonstandard_style)]
166#[derive(Copy, Clone, PartialEq, Eq)]
167pub struct m128i(pub u128);
168
169impl m128i {
170    pub fn aesenc(&self, key: m128i) -> m128i {
171        // 1. Try x86 Hardware (Intel/AMD)
172        #[cfg(target_arch = "x86_64")]
173        {
174            if is_x86_feature_detected!("aes") {
175                return unsafe { self.hw_aesenc_x86(key) };
176            }
177        }
178
179        // 2. Try ARM Hardware (Android/iOS)
180        #[cfg(target_arch = "aarch64")]
181        {
182            if std::arch::is_aarch64_feature_detected!("aes") {
183                return unsafe { self.hw_aesenc_aarch64(key) };
184            }
185        }
186
187        // 3. Fallback to Software (Old/Cheap devices)
188        self.sw_aesenc(key)
189    }
190
191    pub fn aesdec(&self, key: m128i) -> m128i {
192        #[cfg(target_arch = "x86_64")]
193        {
194            if is_x86_feature_detected!("aes") {
195                return unsafe { self.hw_aesdec_x86(key) };
196            }
197        }
198
199        #[cfg(target_arch = "aarch64")]
200        {
201            if std::arch::is_aarch64_feature_detected!("aes") {
202                return unsafe { self.hw_aesdec_aarch64(key) };
203            }
204        }
205
206        self.sw_aesdec(key)
207    }
208
209    // --- x86 Hardware Implementation ---
210    #[cfg(target_arch = "x86_64")]
211    #[target_feature(enable = "aes")]
212    unsafe fn hw_aesenc_x86(&self, key: m128i) -> m128i {
213        unsafe {
214            let s = _mm_loadu_si128(&self.0 as *const u128 as *const __m128i);
215            let k = _mm_loadu_si128(&key.0 as *const u128 as *const __m128i);
216            let r = _mm_aesenc_si128(s, k);
217            let mut out = 0u128;
218            _mm_storeu_si128(&mut out as *mut u128 as *mut __m128i, r);
219            m128i(out)
220        }
221    }
222
223    #[cfg(target_arch = "x86_64")]
224    #[target_feature(enable = "aes")]
225    unsafe fn hw_aesdec_x86(&self, key: m128i) -> m128i {
226        unsafe {
227            let s = _mm_loadu_si128(&self.0 as *const u128 as *const __m128i);
228            let k = _mm_loadu_si128(&key.0 as *const u128 as *const __m128i);
229            let r = _mm_aesdec_si128(s, k);
230            let mut out = 0u128;
231            _mm_storeu_si128(&mut out as *mut u128 as *mut __m128i, r);
232            m128i(out)
233        }
234    }
235
236    // --- ARM Hardware Implementation (iOS / Android) ---
237    #[cfg(target_arch = "aarch64")]
238    #[target_feature(enable = "aes")]
239    unsafe fn hw_aesenc_aarch64(&self, key: m128i) -> m128i {
240        let s = vreinterpretq_u8_u128(self.0);
241        let k = vreinterpretq_u8_u128(key.0);
242        let zero = vdupq_n_u8(0);
243
244        // RandomX expects x86 aesenc behavior:
245        // vaeseq (SubBytes + ShiftRows) -> vaesmcq (MixColumns) -> XOR key
246        let mut out = vaeseq_u8(s, zero);
247        out = vaesmcq_u8(out);
248        let res = veorq_u8(out, k);
249
250        m128i(vgetq_lane_u128(vreinterpretq_u128_u8(res), 0))
251    }
252
253    #[cfg(target_arch = "aarch64")]
254    #[target_feature(enable = "aes")]
255    unsafe fn hw_aesdec_aarch64(&self, key: m128i) -> m128i {
256        let s = vreinterpretq_u8_u128(self.0);
257        let k = vreinterpretq_u8_u128(key.0);
258        let zero = vdupq_n_u8(0);
259
260        // vaesdq (InvSubBytes + InvShiftRows) -> vaesimcq (InvMixColumns) -> XOR key
261        let mut out = vaesdq_u8(s, zero);
262        out = vaesimcq_u8(out);
263        let res = veorq_u8(out, k);
264
265        m128i(vgetq_lane_u128(vreinterpretq_u128_u8(res), 0))
266    }
267
268    // --- Software Fallback (Using your T-Tables) ---
269    fn sw_aesenc(&self, key: m128i) -> m128i {
270        // This is your original software 'aesenc' function
271        m128i(super::m128::aesenc(self.0, key.0))
272    }
273
274    fn sw_aesdec(&self, key: m128i) -> m128i {
275        // This is your original software 'aesdec' function
276        m128i(super::m128::aesdec(self.0, key.0))
277    }
278
279    pub fn zero() -> m128i {
280        m128i(0)
281    }
282
283    pub fn from_u8(bytes: &[u8]) -> m128i {
284        debug_assert_eq!(bytes.len(), 16);
285        m128i(u128::from_le_bytes(bytes.try_into().unwrap()))
286    }
287
288    pub fn from_i32(i3: i32, i2: i32, i1: i32, i0: i32) -> m128i {
289        let v = ((i3 as u128) << 96)
290            | (((i2 as u32) as u128) << 64)
291            | (((i1 as u32) as u128) << 32)
292            | ((i0 as u32) as u128);
293        m128i(v)
294    }
295
296    pub fn from_u64(u1: u64, u0: u64) -> m128i {
297        m128i(((u1 as u128) << 64) | (u0 as u128))
298    }
299
300    pub fn as_i64(&self) -> (i64, i64) {
301        let lo = self.0 as i64;
302        let hi = (self.0 >> 64) as i64;
303        (hi, lo)
304    }
305
306    /// Converts the two lower i32 lanes to f64, matching _mm_cvtepi32_pd semantics.
307    pub fn lower_to_m128d(&self) -> m128d {
308        let i0 = self.0 as i32 as f64;
309        let i1 = (self.0 >> 32) as i32 as f64;
310        m128d::from_f64(i1, i0)
311    }
312
313    pub fn as_m128d(&self) -> m128d {
314        let lo = self.0 as u64;
315        let hi = (self.0 >> 64) as u64;
316        m128d::from_u64(hi, lo)
317    }
318}
319
320fn format_m128i(m: &m128i, f: &mut fmt::Formatter<'_>) -> fmt::Result {
321    let (hi, lo) = m.as_i64();
322    f.write_fmt(format_args!("({:x},{:x})", hi, lo))
323}
324
325impl fmt::LowerHex for m128i {
326    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
327        format_m128i(self, f)
328    }
329}
330
331impl fmt::Debug for m128i {
332    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
333        format_m128i(self, f)
334    }
335}
336
337// ---- m128d ----
338
/// Two packed `f64` lanes, modelled on the x86 `__m128d` type.
///
/// The lanes are stored as raw IEEE-754 bit patterns inside a `u128`: the low
/// lane in bits 0..64 and the high lane in bits 64..128.
#[allow(nonstandard_style)]
#[derive(Copy, Clone)]
pub struct m128d(pub u128);

impl m128d {
    /// Returns the vector whose lanes are both `0.0`.
    pub fn zero() -> m128d {
        Self::from_f64(0.0, 0.0)
    }

    /// Builds a vector from the raw bit patterns of the high (`h`) and low (`l`) lanes.
    pub fn from_u64(h: u64, l: u64) -> m128d {
        let high = f64::from_bits(h);
        let low = f64::from_bits(l);
        Self::from_f64(high, low)
    }

    /// Builds a vector from the high (`h`) and low (`l`) lane values.
    pub fn from_f64(h: f64, l: f64) -> m128d {
        let upper = u128::from(h.to_bits()) << 64;
        let lower = u128::from(l.to_bits());
        m128d(upper | lower)
    }

    /// Returns the lanes as `(high, low)`.
    pub fn as_f64(&self) -> (f64, f64) {
        // The `as u64` casts deliberately truncate to the selected 64-bit lane.
        let upper_bits = (self.0 >> 64) as u64;
        let lower_bits = self.0 as u64;
        (f64::from_bits(upper_bits), f64::from_bits(lower_bits))
    }

    /// Returns the raw bit patterns of the lanes as `(high, low)`.
    pub fn as_u64(&self) -> (u64, u64) {
        let (high, low) = self.as_f64();
        (high.to_bits(), low.to_bits())
    }

    /// Matches _mm_shuffle_pd(self, other, 1):
    /// low lane = high of self, high lane = low of other.
    pub fn shuffle_1(&self, other: &m128d) -> m128d {
        let new_low = self.as_f64().0;
        let new_high = other.as_f64().1;
        Self::from_f64(new_high, new_low)
    }

    /// Returns the lane-wise square root.
    pub fn sqrt(&self) -> m128d {
        let (high, low) = self.as_f64();
        Self::from_f64(high.sqrt(), low.sqrt())
    }
}
382
383impl PartialEq for m128d {
384    fn eq(&self, other: &Self) -> bool {
385        self.0 == other.0
386    }
387}
388
389impl Eq for m128d {}
390
391impl std::ops::Add for m128d {
392    type Output = Self;
393    fn add(self, other: Self) -> Self {
394        let (h1, l1) = self.as_f64();
395        let (h2, l2) = other.as_f64();
396        m128d::from_f64(h1 + h2, l1 + l2)
397    }
398}
399
400impl std::ops::Sub for m128d {
401    type Output = Self;
402    fn sub(self, other: Self) -> Self {
403        let (h1, l1) = self.as_f64();
404        let (h2, l2) = other.as_f64();
405        m128d::from_f64(h1 - h2, l1 - l2)
406    }
407}
408
409impl std::ops::Mul for m128d {
410    type Output = Self;
411    fn mul(self, rhs: Self) -> Self {
412        let (h1, l1) = self.as_f64();
413        let (h2, l2) = rhs.as_f64();
414        m128d::from_f64(h1 * h2, l1 * l2)
415    }
416}
417
418impl std::ops::Div for m128d {
419    type Output = Self;
420    fn div(self, rhs: Self) -> Self {
421        let (h1, l1) = self.as_f64();
422        let (h2, l2) = rhs.as_f64();
423        m128d::from_f64(h1 / h2, l1 / l2)
424    }
425}
426
427impl std::ops::BitXor for m128d {
428    type Output = Self;
429    fn bitxor(self, rhs: Self) -> Self {
430        m128d(self.0 ^ rhs.0)
431    }
432}
433
434impl std::ops::BitAnd for m128d {
435    type Output = Self;
436    fn bitand(self, rhs: Self) -> Self {
437        m128d(self.0 & rhs.0)
438    }
439}
440
441impl std::ops::BitOr for m128d {
442    type Output = Self;
443    fn bitor(self, rhs: Self) -> Self {
444        m128d(self.0 | rhs.0)
445    }
446}
447
448fn format_m128d(m: &m128d, f: &mut fmt::Formatter<'_>) -> fmt::Result {
449    let (hi, lo) = m.as_f64();
450    f.write_fmt(format_args!("({},{})", lo, hi))
451}
452
453impl fmt::LowerHex for m128d {
454    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
455        let (hi, lo) = self.as_f64();
456        f.write_fmt(format_args!("({:x},{:x})", hi.to_bits(), lo.to_bits()))
457    }
458}
459
460impl fmt::Debug for m128d {
461    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
462        format_m128d(self, f)
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Direct wrappers around the AES-NI round instructions, used as the
    /// reference for the software T-table implementation.
    ///
    /// Gated on `x86_64` only, because the intrinsics are imported from
    /// `std::arch::x86_64`. CPU support is checked at runtime via `available`,
    /// so the comparisons also run on default builds (no `-C target-feature`).
    #[cfg(target_arch = "x86_64")]
    mod hw {
        use std::arch::x86_64::{__m128i, _mm_aesdec_si128, _mm_aesenc_si128};

        /// Returns `true` when the running CPU supports AES-NI.
        pub fn available() -> bool {
            is_x86_feature_detected!("aes")
        }

        /// # Safety
        /// The caller must ensure `available()` returned `true`.
        #[target_feature(enable = "aes")]
        pub unsafe fn hw_aesenc(state: u128, key: u128) -> u128 {
            let s = u128_to_m128i(state);
            let k = u128_to_m128i(key);
            // SAFETY: the caller guarantees AES-NI is available.
            let r = unsafe { _mm_aesenc_si128(s, k) };
            m128i_to_u128(r)
        }

        /// # Safety
        /// The caller must ensure `available()` returned `true`.
        #[target_feature(enable = "aes")]
        pub unsafe fn hw_aesdec(state: u128, key: u128) -> u128 {
            let s = u128_to_m128i(state);
            let k = u128_to_m128i(key);
            // SAFETY: the caller guarantees AES-NI is available.
            let r = unsafe { _mm_aesdec_si128(s, k) };
            m128i_to_u128(r)
        }

        fn u128_to_m128i(v: u128) -> __m128i {
            // SAFETY: both types are 16 bytes of plain data and every bit
            // pattern is valid for either of them.
            unsafe { std::mem::transmute::<u128, __m128i>(v) }
        }

        fn m128i_to_u128(v: __m128i) -> u128 {
            // SAFETY: both types are 16 bytes of plain data and every bit
            // pattern is valid for either of them.
            unsafe { std::mem::transmute::<__m128i, u128>(v) }
        }
    }

    /// `(state, key)` pairs exercised by the comparison tests.
    const TEST_VECTORS: &[(u128, u128)] = &[
        (
            0x0000_0000_0000_0000_0000_0000_0000_0000,
            0x0000_0000_0000_0000_0000_0000_0000_0000,
        ),
        (
            0xffff_ffff_ffff_ffff_ffff_ffff_ffff_ffff,
            0xffff_ffff_ffff_ffff_ffff_ffff_ffff_ffff,
        ),
        (
            0x3243f6a8885a308d313198a2e0370734,
            0x2b7e151628aed2a6abf7158809cf4f3c,
        ),
        (
            0x0000_0000_0000_0000_0000_0000_0000_0000,
            0xdead_beef_cafe_babe_1234_5678_9abc_def0,
        ),
        (
            0xdead_beef_cafe_babe_1234_5678_9abc_def0,
            0x0000_0000_0000_0000_0000_0000_0000_0000,
        ),
        (
            0xaaaa_aaaa_aaaa_aaaa_aaaa_aaaa_aaaa_aaaa,
            0x5555_5555_5555_5555_5555_5555_5555_5555,
        ),
        (
            0x0f0e_0d0c_0b0a_0908_0706_0504_0302_0100,
            0x1f1e_1d1c_1b1a_1918_1716_1514_1312_1110,
        ),
        (
            0x6bc1bee22e409f96e93d7e117393172a,
            0xae2d8a571e03ac9c9eb76fac45af8e51,
        ),
        (
            0x0000_0000_0000_0000_0000_0000_0000_0001,
            0x0000_0000_0000_0000_0000_0000_0000_0000,
        ),
        (
            0x0000_0000_0000_0000_0000_0000_0000_0000,
            0x8000_0000_0000_0000_0000_0000_0000_0000,
        ),
    ];

    /// Round key 1 of the FIPS-197 Appendix B example, in column-major byte order.
    const FIPS_ROUND_KEY_1: [u8; 16] = [
        0xa0, 0xfa, 0xfe, 0x17, 0x88, 0x54, 0x2c, 0xb1, 0x23, 0xa3, 0x39, 0x39, 0x2a, 0x6c, 0x76,
        0x05,
    ];

    /// Packs 16 column-major state bytes into the `u128` layout that `aesenc`
    /// and `aesdec` read through `to_le_bytes`.
    fn block(bytes: [u8; 16]) -> u128 {
        u128::from_le_bytes(bytes)
    }

    /// Known-answer test that needs no AES hardware: round 1 of the FIPS-197
    /// Appendix B cipher example.
    #[test]
    fn aesenc_matches_fips197_round1() {
        let start_of_round_1 = block([
            0x19, 0x3d, 0xe3, 0xbe, 0xa0, 0xf4, 0xe2, 0x2b, 0x9a, 0xc6, 0x8d, 0x2a, 0xe9, 0xf8,
            0x48, 0x08,
        ]);
        let start_of_round_2 = block([
            0xa4, 0x9c, 0x7f, 0xf2, 0x68, 0x9f, 0x35, 0x2b, 0x6b, 0x5b, 0xea, 0x43, 0x02, 0x6a,
            0x50, 0x49,
        ]);
        assert_eq!(
            aesenc(start_of_round_1, block(FIPS_ROUND_KEY_1)),
            start_of_round_2
        );
    }

    /// Known-answer test that needs no AES hardware. The input is
    /// ShiftRows(SubBytes(M)), where M is the post-MixColumns state of round 1
    /// in FIPS-197 Appendix B, so the round must yield InvMixColumns(M) XOR key,
    /// i.e. that example's post-ShiftRows state XOR key.
    #[test]
    fn aesdec_matches_fips197_derived_vector() {
        let input = block([
            0xf2, 0x1f, 0x66, 0x29, 0xe1, 0x41, 0xf7, 0xd9, 0x52, 0x6f, 0x0c, 0xb8, 0x34, 0x33,
            0xd4, 0xda,
        ]);
        let after_shift_rows = block([
            0xd4, 0xbf, 0x5d, 0x30, 0xe0, 0xb4, 0x52, 0xae, 0xb8, 0x41, 0x11, 0xf1, 0x1e, 0x27,
            0x98, 0xe5,
        ]);
        assert_eq!(aesdec(input, 0), after_shift_rows);

        let key = block(FIPS_ROUND_KEY_1);
        assert_eq!(aesdec(input, key), after_shift_rows ^ key);
    }

    /// Whatever backend `m128i` dispatches to on this machine must agree with
    /// the software rounds.
    #[test]
    fn dispatch_matches_software() {
        for (i, &(state, key)) in TEST_VECTORS.iter().enumerate() {
            assert_eq!(
                m128i(state).aesenc(m128i(key)).0,
                aesenc(state, key),
                "vector {i} aesenc dispatch mismatch"
            );
            assert_eq!(
                m128i(state).aesdec(m128i(key)).0,
                aesdec(state, key),
                "vector {i} aesdec dispatch mismatch"
            );
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn aesenc_matches_hardware() {
        if !hw::available() {
            eprintln!("[aesenc] AES-NI not available; skipping hardware comparison");
            return;
        }
        let mut all_pass = true;
        for (i, &(state, key)) in TEST_VECTORS.iter().enumerate() {
            let sw = aesenc(state, key);
            // SAFETY: AES-NI availability was checked at the top of the test.
            let hw = unsafe { hw::hw_aesenc(state, key) };
            if sw != hw {
                eprintln!(
                    "[aesenc] vector {i} FAIL\n  state={state:032x}\n  key  ={key:032x}\n  sw   ={sw:032x}\n  hw   ={hw:032x}\n  diff ={:032x}",
                    sw ^ hw
                );
                all_pass = false;
            } else {
                eprintln!("[aesenc] vector {i} ok -> {sw:032x}");
            }
        }
        assert!(all_pass, "aesenc mismatches (see stderr)");
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn aesdec_matches_hardware() {
        if !hw::available() {
            eprintln!("[aesdec] AES-NI not available; skipping hardware comparison");
            return;
        }
        let mut all_pass = true;
        for (i, &(state, key)) in TEST_VECTORS.iter().enumerate() {
            let sw = aesdec(state, key);
            // SAFETY: AES-NI availability was checked at the top of the test.
            let hw = unsafe { hw::hw_aesdec(state, key) };
            if sw != hw {
                eprintln!(
                    "[aesdec] vector {i} FAIL\n  state={state:032x}\n  key  ={key:032x}\n  sw   ={sw:032x}\n  hw   ={hw:032x}\n  diff ={:032x}",
                    sw ^ hw
                );
                all_pass = false;
            } else {
                eprintln!("[aesdec] vector {i} ok -> {sw:032x}");
            }
        }
        assert!(all_pass, "aesdec mismatches (see stderr)");
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn enc_dec_consistency() {
        if !hw::available() {
            eprintln!("[enc_dec] AES-NI not available; skipping hardware comparison");
            return;
        }
        for (i, &(state, key)) in TEST_VECTORS.iter().enumerate() {
            let sw_enc = aesenc(state, key);
            let sw_dec = aesdec(sw_enc, key);
            // SAFETY: AES-NI availability was checked at the top of the test.
            let hw_enc = unsafe { hw::hw_aesenc(state, key) };
            // SAFETY: AES-NI availability was checked at the top of the test.
            let hw_dec = unsafe { hw::hw_aesdec(hw_enc, key) };
            assert_eq!(sw_dec, hw_dec, "vector {i} enc_dec mismatch");
        }
    }

    #[test]
    fn aesenc_perf_smoke() {
        let start = std::time::Instant::now();
        let mut state = 0x3243f6a8885a308d313198a2e0370734u128;
        let key = 0x2b7e151628aed2a6abf7158809cf4f3cu128;
        for _ in 0..1_000_000 {
            state = aesenc(state, key);
        }
        let elapsed = start.elapsed();
        eprintln!(
            "1M aesenc: {:?}  ({} ns/op)",
            elapsed,
            elapsed.as_nanos() / 1_000_000
        );
        let _ = state;
    }
}