softaes/
lib.rs

1//! Software implementation of the AES round function.
2
3#![no_std]
4
5use core::{
6    cmp,
7    mem::MaybeUninit,
8    ops,
9    sync::atomic::{self, Ordering},
10};
11
/// Combined SubBytes + MixColumns lookup table used by the AES round function.
///
/// For an input byte `x` whose S-box output is `s`, entry `x` is
/// `(3·s) << 24 | s << 16 | s << 8 | (2·s)`, where `2·s` and `3·s` are products in
/// GF(2^8) modulo the AES polynomial `0x11b`. The table is derived from the S-box at
/// compile time instead of being spelled out entry by entry.
const LUT: [u32; 256] = {
    // The AES S-box.
    const SBOX: [u8; 256] = [
        0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
        0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
        0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
        0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
        0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
        0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
        0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
        0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
        0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
        0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
        0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
        0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
        0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
        0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
        0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
        0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
    ];

    let mut table = [0u32; 256];
    let mut x = 0;
    while x < 256 {
        let s = SBOX[x] as u32;
        // Doubling in GF(2^8): shift left, then reduce by 0x11b when bit 7 was set.
        let doubled = (s << 1) ^ ((s >> 7) * 0x11b);
        let tripled = doubled ^ s;
        table[x] = (tripled << 24) | (s << 16) | (s << 8) | doubled;
        x += 1;
    }
    table
};
46
/// A 128-bit AES block, stored as four 32-bit words aligned to 16 bytes.
#[derive(Copy, Clone, Debug, Default)]
#[repr(align(16))]
pub struct Block {
    w0: u32,
    w1: u32,
    w2: u32,
    w3: u32,
}

impl cmp::PartialEq for Block {
    /// Returns `true` when all four words of both blocks are equal.
    ///
    /// The word differences are OR-ed together and tested once, so the comparison
    /// contains no per-word early exit.
    #[inline(never)]
    fn eq(&self, other: &Block) -> bool {
        let difference = (self.w0 ^ other.w0)
            | (self.w1 ^ other.w1)
            | (self.w2 ^ other.w2)
            | (self.w3 ^ other.w3);
        difference == 0
    }
}

impl cmp::Eq for Block {}
66
67impl Block {
68    #[inline(always)]
69    pub fn from_bytes(input: &[u8; 16]) -> Block {
70        Block {
71            w0: u32::from_le_bytes([input[0], input[1], input[2], input[3]]),
72            w1: u32::from_le_bytes([input[4], input[5], input[6], input[7]]),
73            w2: u32::from_le_bytes([input[8], input[9], input[10], input[11]]),
74            w3: u32::from_le_bytes([input[12], input[13], input[14], input[15]]),
75        }
76    }
77
78    #[inline(always)]
79    pub fn from_slice(input: &[u8]) -> Block {
80        debug_assert!(input.len() == 16);
81        Block {
82            w0: u32::from_le_bytes([input[0], input[1], input[2], input[3]]),
83            w1: u32::from_le_bytes([input[4], input[5], input[6], input[7]]),
84            w2: u32::from_le_bytes([input[8], input[9], input[10], input[11]]),
85            w3: u32::from_le_bytes([input[12], input[13], input[14], input[15]]),
86        }
87    }
88
89    #[inline(always)]
90    pub fn from64x2(a: u64, b: u64) -> Block {
91        Block {
92            w0: b as u32,
93            w1: (b >> 32) as u32,
94            w2: a as u32,
95            w3: (a >> 32) as u32,
96        }
97    }
98
99    #[inline(always)]
100    pub fn to_bytes(&self) -> [u8; 16] {
101        let mut out: [u8; 16] = Default::default();
102        out[0..4].copy_from_slice(&self.w0.to_le_bytes());
103        out[4..8].copy_from_slice(&self.w1.to_le_bytes());
104        out[8..12].copy_from_slice(&self.w2.to_le_bytes());
105        out[12..16].copy_from_slice(&self.w3.to_le_bytes());
106        out
107    }
108
109    #[inline(always)]
110    pub fn xor(&self, other: &Block) -> Block {
111        Block {
112            w0: self.w0 ^ other.w0,
113            w1: self.w1 ^ other.w1,
114            w2: self.w2 ^ other.w2,
115            w3: self.w3 ^ other.w3,
116        }
117    }
118
119    #[inline(always)]
120    pub fn and(&self, other: &Block) -> Block {
121        Block {
122            w0: self.w0 & other.w0,
123            w1: self.w1 & other.w1,
124            w2: self.w2 & other.w2,
125            w3: self.w3 & other.w3,
126        }
127    }
128}
129
130impl ops::BitAnd for Block {
131    type Output = Block;
132
133    #[inline(always)]
134    fn bitand(self, rhs: Self) -> Self::Output {
135        self.and(&rhs)
136    }
137}
138
139impl ops::BitAnd for &Block {
140    type Output = Block;
141
142    #[inline(always)]
143    fn bitand(self, rhs: Self) -> Self::Output {
144        self.and(rhs)
145    }
146}
147
148impl ops::BitXor for Block {
149    type Output = Block;
150
151    #[inline(always)]
152    fn bitxor(self, rhs: Self) -> Self::Output {
153        self.xor(&rhs)
154    }
155}
156
157impl ops::BitXor for &Block {
158    type Output = Block;
159
160    #[inline(always)]
161    fn bitxor(self, rhs: Self) -> Self::Output {
162        self.xor(rhs)
163    }
164}
165
166pub struct SoftAes<const STRIDE: usize, const STRIDE_INV: usize>;
167
168#[repr(align(64))]
169struct AlignedT<const STRIDE_INV: usize>([[[MaybeUninit<u32>; STRIDE_INV]; 4]; 4]);
170
171#[repr(align(64))]
172struct AlignedOf([[MaybeUninit<u8>; 4]; 4]);
173
174impl<const STRIDE: usize, const STRIDE_INV: usize> SoftAes<STRIDE, STRIDE_INV> {
175    const STATIC_ASSERT_STRIDE: usize = (STRIDE_INV == 256 / STRIDE && STRIDE <= 256) as usize - 1;
176
177    fn _encrypt(ix0: [u8; 4], ix1: [u8; 4], ix2: [u8; 4], ix3: [u8; 4]) -> Block {
178        _ = Self::STATIC_ASSERT_STRIDE;
179
180        let mut t: AlignedT<STRIDE_INV> =
181            AlignedT([[[MaybeUninit::<u32>::uninit(); STRIDE_INV]; 4]; 4]);
182
183        let mut of: AlignedOf = AlignedOf([[MaybeUninit::<u8>::uninit(); 4]; 4]);
184
185        for j in 0..4 {
186            of.0[j][0].write((ix0[j] as usize % STRIDE) as _);
187            of.0[j][1].write((ix1[j] as usize % STRIDE) as _);
188            of.0[j][2].write((ix2[j] as usize % STRIDE) as _);
189            of.0[j][3].write((ix3[j] as usize % STRIDE) as _);
190        }
191        for i in 0usize..256 / STRIDE {
192            for j in 0usize..4 {
193                t.0[j][0][i]
194                    .write(LUT[(i * STRIDE) | unsafe { of.0[j][0].assume_init() } as usize]);
195                t.0[j][1][i]
196                    .write(LUT[(i * STRIDE) | unsafe { of.0[j][1].assume_init() } as usize]);
197                t.0[j][2][i]
198                    .write(LUT[(i * STRIDE) | unsafe { of.0[j][2].assume_init() } as usize]);
199                t.0[j][3][i]
200                    .write(LUT[(i * STRIDE) | unsafe { of.0[j][3].assume_init() } as usize]);
201            }
202        }
203
204        atomic::compiler_fence(Ordering::Acquire);
205
206        let mut w0 = unsafe { t.0[0][0][ix0[0] as usize / STRIDE].assume_init() };
207        w0 ^= unsafe { t.0[0][1][ix1[0] as usize / STRIDE].assume_init() }.rotate_left(8);
208        w0 ^= unsafe { t.0[0][2][ix1[0] as usize / STRIDE].assume_init() }.rotate_left(16);
209        w0 ^= unsafe { t.0[0][3][ix1[0] as usize / STRIDE].assume_init() }.rotate_left(24);
210
211        let mut w1 = unsafe { t.0[1][0][ix0[1] as usize / STRIDE].assume_init() };
212        w1 ^= unsafe { t.0[1][1][ix1[1] as usize / STRIDE].assume_init() }.rotate_left(8);
213        w1 ^= unsafe { t.0[1][2][ix1[1] as usize / STRIDE].assume_init() }.rotate_left(16);
214        w1 ^= unsafe { t.0[1][3][ix1[1] as usize / STRIDE].assume_init() }.rotate_left(24);
215
216        let mut w2 = unsafe { t.0[2][0][ix0[2] as usize / STRIDE].assume_init() };
217        w2 ^= unsafe { t.0[2][1][ix1[2] as usize / STRIDE].assume_init() }.rotate_left(8);
218        w2 ^= unsafe { t.0[2][2][ix1[2] as usize / STRIDE].assume_init() }.rotate_left(16);
219        w2 ^= unsafe { t.0[2][3][ix1[2] as usize / STRIDE].assume_init() }.rotate_left(24);
220
221        let mut w3 = unsafe { t.0[3][0][ix0[3] as usize / STRIDE].assume_init() };
222        w3 ^= unsafe { t.0[3][1][ix1[3] as usize / STRIDE].assume_init() }.rotate_left(8);
223        w3 ^= unsafe { t.0[3][2][ix1[3] as usize / STRIDE].assume_init() }.rotate_left(16);
224        w3 ^= unsafe { t.0[3][3][ix1[3] as usize / STRIDE].assume_init() }.rotate_left(24);
225
226        Block { w0, w1, w2, w3 }
227    }
228
229    /// AES forward round function.
230    /// `rk` is the round key.
231    #[inline]
232    pub fn block_encrypt(block: &Block, rk: &Block) -> Block {
233        let s0 = block.w0;
234        let s1 = block.w1;
235        let s2 = block.w2;
236        let s3 = block.w3;
237
238        let ix0: [u8; 4] = [s0 as _, s1 as _, s2 as _, s3 as _];
239        let ix1: [u8; 4] = [
240            (s1 >> 8) as _,
241            (s2 >> 8) as _,
242            (s3 >> 8) as _,
243            (s0 >> 8) as _,
244        ];
245        let ix2: [u8; 4] = [
246            (s2 >> 16) as _,
247            (s3 >> 16) as _,
248            (s0 >> 16) as _,
249            (s1 >> 16) as _,
250        ];
251        let ix3: [u8; 4] = [
252            (s3 >> 24) as _,
253            (s0 >> 24) as _,
254            (s1 >> 24) as _,
255            (s2 >> 24) as _,
256        ];
257        let mut out = Self::_encrypt(ix0, ix1, ix2, ix3);
258        out.w0 ^= rk.w0;
259        out.w1 ^= rk.w1;
260        out.w2 ^= rk.w2;
261        out.w3 ^= rk.w3;
262        out
263    }
264}
265
266/// Software AES implementation with a stride of 16 words (paranoid protection against side channels)
267pub type SoftAesSlow = SoftAes<16, { 256 / 16 }>;
268
269/// Software AES implementation with a stride of 64 words (practical protection against side channels)
270pub type SoftAesModerate = SoftAes<64, { 256 / 64 }>;
271
272/// Fast software AES implementation, but with minimal protection against side channels
273pub type SoftAesFast = SoftAes<256, { 256 / 256 }>;
274
275/// Fastest software AES implementation, but with no protection against side channels
276pub mod unprotected;
277
#[test]
fn test() {
    // Known-answer check: one round over the bytes 0..=15 with an all-0x01 round key.
    let mut plaintext = [0u8; 16];
    for (position, byte) in plaintext.iter_mut().enumerate() {
        *byte = position as u8;
    }
    let state = Block::from_bytes(&plaintext);
    let round_key = Block::from_bytes(&[0x01; 16]);

    let expected = Block::from_bytes(&[
        0x6b, 0x6b, 0x5d, 0x44, 0x2d, 0x6c, 0x32, 0x50, 0xb1, 0xd8, 0x5c, 0x60, 0x26, 0x9d, 0x20,
        0x5d,
    ]);
    let output = SoftAesFast::block_encrypt(&state, &round_key);
    assert_eq!(output, expected);
}