Skip to main content

vck_common/
xts.rs

1// SPDX-FileCopyrightText: 2026 JC-Lab <joseph@jc-lab.net>
2//
3// SPDX-License-Identifier: Apache-2.0
4
5//! Shared AES-256-XTS volume sector cipher used by both the kernel driver and
6//! the UEFI loader, so their on-disk crypto agrees by construction.
7//!
8//! **Tweak convention (authoritative):** the XTS tweak for a sector is its
9//! **data-region-relative** sector number, i.e. `rel = absolute_lba - offset_sector`,
10//! where `rel == 0` is the first encryptable sector. This matches the
11//! `EncryptedOffset` semantics (also data-region relative). Callers MUST map
12//! absolute LBAs to `rel` before invoking these methods, and MUST NOT call them
13//! for sectors inside header/footer metadata regions (those pass through in
14//! plaintext).
15//!
16//! Keys are two independent 256-bit halves (`key1` = data key, `key2` = tweak
17//! key), giving AES-256-XTS.
18//!
19//! # Performance
20//!
21//! All sectors are processed through an 8-block parallel XTS path that keeps 8
22//! independent AES operations in flight simultaneously. On x86-64 with AES-NI
23//! (detected at runtime by the `aes` crate) this saturates the throughput
24//! pipeline (~1 cycle per 16-byte block) instead of being latency-bound
25//! (~7 cycles per block for AES-256). Sectors are always a multiple of 16 bytes
26//! (512, 4096, …) so ciphertext stealing never applies.
27//!
28//! # Kernel stack safety
29//!
30//! The driver runs this crypto on a constrained kernel stack (a system-thread
31//! stack of ~24 KiB, and an IOCTL callout stack of 32 KiB). The crate is built
32//! for the driver WITHOUT `-C target-feature=+aes`, so the `aes` crate's
33//! fully-unrolled AES-NI `encrypt8`/`decrypt8` stay behind a runtime-dispatch
34//! call boundary instead of being inlined into (and ballooning the frames of)
35//! the deep storage/metadata call chain. The per-sector entry points are
36//! additionally marked `#[inline(never)]` so their AES frames can never combine
37//! with a caller's frame. (A prior build with `+aes` inlined the unrolled AES
38//! into the IOCTL path and double-faulted on stack overflow.)
39
40use aes::cipher::{BlockCipherDecrypt, BlockCipherEncrypt, KeyInit};
41use aes::{Aes256, Block};
42
43use crate::{VckError, VckResult};
44
/// Number of AES-XTS blocks processed in one parallel batch.
///
/// Matches the AES-NI backend's `ParBlocks = 8`, giving the AES unit enough
/// independent blocks to fill its ~7-cycle pipeline.
const BATCH: usize = 8;

/// A full-volume sector cipher.
///
/// The default JVCK suite uses [`XtsVolumeCipher`] (AES-256-XTS), but a vendor
/// can supply a different full-volume-encryption algorithm by implementing this
/// trait and selecting it from the volume metadata (see the driver's cipher
/// factory). All sector numbers are **data-region relative** (`rel = lba -
/// offset_sector`); see the module docs for the tweak convention.
///
/// Implementors must provide the two per-sector methods. The multi-sector
/// `*_area` methods have default implementations that walk the buffer one
/// sector at a time. Override them when the algorithm has a faster batched path.
pub trait VolumeCipher: Send + Sync {
    /// Encrypt one sector in place. `rel_sector` is data-region relative.
    fn encrypt_sector(&self, rel_sector: u64, sector: &mut [u8]);

    /// Decrypt one sector in place (inverse of [`encrypt_sector`](Self::encrypt_sector)).
    fn decrypt_sector(&self, rel_sector: u64, sector: &mut [u8]);

    /// Encrypt a contiguous buffer of `sector_size`-byte sectors. The first
    /// sector of the buffer is data-region-relative sector `first_rel_sector`.
    ///
    /// The default implementation calls [`encrypt_sector`](Self::encrypt_sector)
    /// on each `sector_size`-byte chunk with consecutive sector numbers. A
    /// trailing chunk shorter than `sector_size` is handed over as a short sector.
    ///
    /// # Panics
    ///
    /// The default implementation panics if `sector_size` is zero.
    fn encrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
        for (si, sector) in buf.chunks_mut(sector_size).enumerate() {
            self.encrypt_sector(first_rel_sector + si as u64, sector);
        }
    }

    /// Decrypt a contiguous buffer (inverse of [`encrypt_area`](Self::encrypt_area)).
    ///
    /// The default implementation calls [`decrypt_sector`](Self::decrypt_sector)
    /// on each `sector_size`-byte chunk with consecutive sector numbers.
    ///
    /// # Panics
    ///
    /// The default implementation panics if `sector_size` is zero.
    fn decrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
        for (si, sector) in buf.chunks_mut(sector_size).enumerate() {
            self.decrypt_sector(first_rel_sector + si as u64, sector);
        }
    }
}
66
67pub struct XtsVolumeCipher {
68    /// Data cipher for the AES-XTS payload blocks.
69    cipher_1: Aes256,
70    /// Tweak cipher (initial tweak = AES_K2(sector_number)).
71    cipher_2: Aes256,
72}
73
74impl VolumeCipher for XtsVolumeCipher {
75    fn encrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
76        XtsVolumeCipher::encrypt_sector(self, rel_sector, sector)
77    }
78    fn decrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
79        XtsVolumeCipher::decrypt_sector(self, rel_sector, sector)
80    }
81    fn encrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
82        XtsVolumeCipher::encrypt_area(self, buf, sector_size, first_rel_sector)
83    }
84    fn decrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
85        XtsVolumeCipher::decrypt_area(self, buf, sector_size, first_rel_sector)
86    }
87}
88
89/// GF(2^128) multiplication by the primitive element alpha in the XTS field
90/// (little-endian byte order, primitive polynomial x^128 + x^7 + x^2 + x + 1).
91#[inline(always)]
92fn gf128_mul(t: Block) -> Block {
93    let lo = u64::from_le_bytes(t[..8].try_into().unwrap());
94    let hi = u64::from_le_bytes(t[8..].try_into().unwrap());
95    let carry = if hi >> 63 != 0 { 0x87u64 } else { 0u64 };
96    let mut out = Block::default();
97    out[..8].copy_from_slice(&((lo << 1) ^ carry).to_le_bytes());
98    out[8..].copy_from_slice(&((hi << 1) | (lo >> 63)).to_le_bytes());
99    out
100}
101
102impl XtsVolumeCipher {
103    pub fn new(key1: &[u8; 32], key2: &[u8; 32]) -> VckResult<Self> {
104        let cipher_1 =
105            Aes256::new_from_slice(key1).map_err(|_| VckError::CryptoFailed("invalid XTS key1"))?;
106        let cipher_2 =
107            Aes256::new_from_slice(key2).map_err(|_| VckError::CryptoFailed("invalid XTS key2"))?;
108        Ok(Self { cipher_1, cipher_2 })
109    }
110
111    /// Encrypt one sector in place. `rel_sector` is data-region relative.
112    pub fn encrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
113        self.encrypt_sector_inner(rel_sector, sector);
114    }
115
116    /// Decrypt one sector in place. `rel_sector` is data-region relative.
117    pub fn decrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
118        self.decrypt_sector_inner(rel_sector, sector);
119    }
120
121    /// Encrypt a contiguous buffer of `sector_size`-byte sectors starting at
122    /// data-region-relative sector `first_rel_sector`.
123    pub fn encrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
124        for (si, sector) in buf.chunks_mut(sector_size).enumerate() {
125            self.encrypt_sector_inner(first_rel_sector + si as u64, sector);
126        }
127    }
128
129    /// Decrypt a contiguous buffer (inverse of [`encrypt_area`]).
130    pub fn decrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
131        for (si, sector) in buf.chunks_mut(sector_size).enumerate() {
132            self.decrypt_sector_inner(first_rel_sector + si as u64, sector);
133        }
134    }
135
136    /// `#[inline(never)]` bounds this function's (AES-heavy) stack frame so it
137    /// cannot merge with a deep caller's frame on the kernel stack.
138    #[inline(never)]
139    fn encrypt_sector_inner(&self, rel_sector: u64, sector: &mut [u8]) {
140        // T_0 = AES_K2(sector_number as 128-bit little-endian)
141        let mut tw: Block = (rel_sector as u128).to_le_bytes().into();
142        self.cipher_2.encrypt_block(&mut tw);
143
144        let n = sector.len() / 16;
145        let mut off = 0;
146
147        // 8-block parallel path: all 8 AES operations are independent so the
148        // CPU can keep the AES-NI units fully pipelined.
149        while off + BATCH <= n {
150            let mut ts = [Block::default(); BATCH];
151            ts[0] = tw;
152            for i in 1..BATCH {
153                ts[i] = gf128_mul(ts[i - 1]);
154            }
155            tw = gf128_mul(ts[BATCH - 1]);
156
157            let mut batch = [Block::default(); BATCH];
158            for i in 0..BATCH {
159                let src = &sector[(off + i) * 16..(off + i + 1) * 16];
160                for j in 0..16 {
161                    batch[i][j] = src[j] ^ ts[i][j];
162                }
163            }
164            self.cipher_1.encrypt_blocks(&mut batch);
165            for i in 0..BATCH {
166                let dst = &mut sector[(off + i) * 16..(off + i + 1) * 16];
167                for j in 0..16 {
168                    dst[j] = batch[i][j] ^ ts[i][j];
169                }
170            }
171            off += BATCH;
172        }
173
174        // Scalar tail for sectors whose block count is not a multiple of BATCH.
175        while off < n {
176            let block = &mut sector[off * 16..(off + 1) * 16];
177            for j in 0..16 {
178                block[j] ^= tw[j];
179            }
180            let mut ga: Block = Block::try_from(&block[..]).unwrap();
181            self.cipher_1.encrypt_block(&mut ga);
182            block.copy_from_slice(&ga);
183            for j in 0..16 {
184                block[j] ^= tw[j];
185            }
186            tw = gf128_mul(tw);
187            off += 1;
188        }
189    }
190
191    #[inline(never)]
192    fn decrypt_sector_inner(&self, rel_sector: u64, sector: &mut [u8]) {
193        // Tweak is always encrypted with K2 (even during decryption).
194        let mut tw: Block = (rel_sector as u128).to_le_bytes().into();
195        self.cipher_2.encrypt_block(&mut tw);
196
197        let n = sector.len() / 16;
198        let mut off = 0;
199
200        while off + BATCH <= n {
201            let mut ts = [Block::default(); BATCH];
202            ts[0] = tw;
203            for i in 1..BATCH {
204                ts[i] = gf128_mul(ts[i - 1]);
205            }
206            tw = gf128_mul(ts[BATCH - 1]);
207
208            let mut batch = [Block::default(); BATCH];
209            for i in 0..BATCH {
210                let src = &sector[(off + i) * 16..(off + i + 1) * 16];
211                for j in 0..16 {
212                    batch[i][j] = src[j] ^ ts[i][j];
213                }
214            }
215            self.cipher_1.decrypt_blocks(&mut batch);
216            for i in 0..BATCH {
217                let dst = &mut sector[(off + i) * 16..(off + i + 1) * 16];
218                for j in 0..16 {
219                    dst[j] = batch[i][j] ^ ts[i][j];
220                }
221            }
222            off += BATCH;
223        }
224
225        while off < n {
226            let block = &mut sector[off * 16..(off + 1) * 16];
227            for j in 0..16 {
228                block[j] ^= tw[j];
229            }
230            let mut ga: Block = Block::try_from(&block[..]).unwrap();
231            self.cipher_1.decrypt_block(&mut ga);
232            block.copy_from_slice(&ga);
233            for j in 0..16 {
234                block[j] ^= tw[j];
235            }
236            tw = gf128_mul(tw);
237            off += 1;
238        }
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;
    // Reference implementation for cross-checking standards compliance.
    use xts_mode::{get_tweak_default, Xts128};

    const KEY1: [u8; 32] = [0x11; 32];
    const KEY2: [u8; 32] = [0x22; 32];

    /// Build the `xts-mode` reference cipher for the same key pair.
    fn reference() -> Xts128<Aes256> {
        let c1 = Aes256::new_from_slice(&KEY1).unwrap();
        let c2 = Aes256::new_from_slice(&KEY2).unwrap();
        Xts128::new(c1, c2)
    }

    #[test]
    fn sector_roundtrip() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let plain: Vec<u8> = (0..512).map(|i| i as u8).collect();
        let mut buf = plain.clone();
        c.encrypt_sector(42, &mut buf);
        assert_ne!(buf, plain, "ciphertext must differ from plaintext");
        c.decrypt_sector(42, &mut buf);
        assert_eq!(buf, plain);
    }

    #[test]
    fn tweak_depends_on_sector() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let plain = [0xABu8; 512];
        let mut a = plain;
        let mut b = plain;
        c.encrypt_sector(0, &mut a);
        c.encrypt_sector(1, &mut b);
        assert_ne!(a, b, "same plaintext at different sectors must differ");
    }

    /// Our parallel path must produce byte-identical ciphertext to the standard
    /// `xts-mode` implementation (data-region-relative sector as the tweak).
    #[test]
    fn matches_xts_mode_reference() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let xts = reference();
        let sector_size = 512usize;
        let first = 7u64;
        let plain: Vec<u8> = (0..sector_size * 3).map(|i| (i * 7) as u8).collect();

        let mut ours = plain.clone();
        c.encrypt_area(&mut ours, sector_size, first);

        let mut refer = plain.clone();
        for s in 0..3u64 {
            let start = s as usize * sector_size;
            xts.encrypt_sector(
                &mut refer[start..start + sector_size],
                get_tweak_default((first + s) as u128),
            );
        }
        assert_eq!(ours, refer, "parallel XTS must match xts-mode reference");

        c.decrypt_area(&mut ours, sector_size, first);
        assert_eq!(ours, plain);
    }

    /// Sector sizes that are not a whole number of batches exercise the scalar
    /// tail, including the tweak hand-off from the batched path. For example,
    /// 192 bytes = 12 blocks = one batch + 4 tail blocks. Every size is checked
    /// against the `xts-mode` reference, not just round-tripped.
    #[test]
    fn tail_paths_match_xts_mode_reference() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let xts = reference();
        let first = 1000u64;
        for &sector_size in &[16usize, 64, 128, 192, 240, 4096] {
            let plain: Vec<u8> = (0..sector_size * 3).map(|i| (i * 13 + 5) as u8).collect();

            let mut ours = plain.clone();
            c.encrypt_area(&mut ours, sector_size, first);

            let mut refer = plain.clone();
            for (s, chunk) in refer.chunks_mut(sector_size).enumerate() {
                xts.encrypt_sector(chunk, get_tweak_default((first + s as u64) as u128));
            }
            assert_eq!(ours, refer, "mismatch vs xts-mode at sector_size {sector_size}");

            c.decrypt_area(&mut ours, sector_size, first);
            assert_eq!(ours, plain, "round-trip failed at sector_size {sector_size}");
        }
    }

    /// Calls through `&dyn VolumeCipher` must reach the XTS implementation and
    /// agree with the inherent methods.
    #[test]
    fn trait_object_dispatches_to_xts() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let obj: &dyn VolumeCipher = &c;
        let plain: Vec<u8> = (0..1024).map(|i| (i * 3) as u8).collect();

        let mut direct = plain.clone();
        let mut via_trait = plain.clone();
        c.encrypt_area(&mut direct, 512, 9);
        obj.encrypt_area(&mut via_trait, 512, 9);
        assert_eq!(direct, via_trait);

        obj.decrypt_sector(9, &mut via_trait[..512]);
        obj.decrypt_sector(10, &mut via_trait[512..]);
        assert_eq!(via_trait, plain);
    }

    /// Round-trip with a sector size whose block count is not a multiple of
    /// BATCH (64 bytes = 4 blocks < 8) exercises the scalar tail.
    #[test]
    fn small_sector_roundtrip() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let sector_size = 64usize;
        let plain: Vec<u8> = (0..sector_size * 5).map(|i| i as u8).collect();
        let mut buf = plain.clone();
        c.encrypt_area(&mut buf, sector_size, 0);
        assert_ne!(buf, plain);
        c.decrypt_area(&mut buf, sector_size, 0);
        assert_eq!(buf, plain);
    }
}