Skip to main content

xenia_wire/
replay_window.rs

1// Copyright (c) 2024-2026 Tristan Stoltz / Luminous Dynamics
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! Replay-protection sliding window for AEAD-sealed streams.
5//!
6//! ## Why
7//!
8//! ChaCha20-Poly1305 + monotonic nonce prevents *encryption* reuse (the
9//! sender will never produce two ciphertexts with the same nonce under the
10//! same key), but it does NOT prevent a network attacker from capturing a
11//! sealed envelope and replaying it later — the receiver will accept it
12//! because AEAD verification still succeeds against the original key.
13//!
14//! For idempotent payloads (screen updates) replay is mostly cosmetic. For
15//! reverse-path input messages, replay is a real security hole: a
16//! captured `tap (504, 1122)` could be re-fired to re-execute the action.
17//!
18//! ## Design
19//!
//! Sliding window over received sequence numbers, keyed by `(source_id,
//! payload_type)`. The window width is [`DEFAULT_WINDOW_BITS`] (64) bits
//! unless configured otherwise via [`ReplayWindow::with_window_bits`] (up
//! to [`MAX_WINDOW_BITS`]); the receiver tracks the highest sequence number
//! seen so far plus a bitmap, one bit per sequence, covering the most
//! recent `window_bits` sequences. The first sequence seen on a stream is
//! always accepted and initializes the window. After that, a sequence is
//! accepted iff:
//!
//! 1. It is strictly higher than the highest seen so far (advance the
//!    window), OR
//! 2. It falls within the bitmap range AND the corresponding bit is unset
//!    (mark the bit, accept the message).
//!
//! Sequences that are too old — `window_bits` or more below the highest
//! seen (with the default 64-bit window and a highest of 100, sequence 36
//! is rejected while 37 is still accepted) — are rejected outright.
//! Duplicates within the bitmap range are rejected.
33//!
34//! ## Wraparound
35//!
36//! u64 sequence space is effectively unbounded — at 30 frames/sec this
37//! wraps in ~19 billion years. The implementation does not handle
38//! wraparound specifically because real session lifetime (governed by key
39//! rotation) is many orders of magnitude shorter.
40//!
41//! ## Multi-stream isolation
42//!
43//! Different `(source_id, payload_type)` tuples have independent windows.
44//! This is required because the forward-path frame stream and reverse-path
45//! input stream share a session key but maintain independent sequence
46//! counters via [`crate::Session::next_nonce`]. Replay protection is per
47//! tuple, not per session.
48//!
49//! ## Key epoch scoping (SPEC draft-02r1 §5.3)
50//!
51//! Windows are additionally scoped by a `key_epoch` byte that
52//! increments each time a new session key is installed. This matters
53//! because [`crate::Session::install_key`] resets the nonce counter
54//! to `0` on rekey — without per-epoch scoping, a counter-reset
55//! sender would produce low sequences that the receiver would
56//! reject against a still-high `highest` from the previous key.
57//!
58//! During the rekey grace period two per-epoch windows are live
59//! simultaneously for the same `(source_id, payload_type)` stream
60//! — one per key — and each envelope is routed to the window
61//! matching the key that verified its AEAD tag. When the previous
62//! key expires, [`ReplayWindow::drop_epoch`] removes that epoch's
63//! entries to bound memory.
64
65use std::collections::HashMap;
66
/// Default replay window width in bits.
///
/// 64 bits is the standard IPsec/DTLS replay window width. See SPEC §5.1.
/// Configurable per-session via [`crate::SessionBuilder::with_replay_window_bits`]
/// (draft-02r2 / alpha.5+) up to [`MAX_WINDOW_BITS`].
pub const DEFAULT_WINDOW_BITS: u32 = 64;

/// Maximum supported replay window width in bits.
///
/// 1024 bits = 128 bytes of bitmap per stream. Intended for high-jitter
/// transports where reordering deeper than the 64-packet default is
/// realistic. The upper bound keeps per-stream memory bounded; the
/// SPEC-specified maximum is 1024. The constant can be raised in a fork
/// with unusual requirements, but note that the panic message in
/// [`ReplayWindow::with_window_bits`] spells out "1024" literally and
/// would need updating as well.
pub const MAX_WINDOW_BITS: u32 = 1024;

/// Legacy alias for [`DEFAULT_WINDOW_BITS`], typed as `u64` rather than
/// `u32`. Kept for backwards-compatible public API; new code should use
/// `DEFAULT_WINDOW_BITS`. This is always the *default* width, not the
/// width a particular [`ReplayWindow`] was built with — use
/// [`ReplayWindow::window_bits`] for that.
pub const WINDOW_BITS: u64 = DEFAULT_WINDOW_BITS as u64;
86
/// Replay state for one `(source_id, payload_type, key_epoch)` stream:
/// the highest sequence accepted so far plus a bitmap covering the
/// `window_bits` sequences that end at (and include) that highest.
///
/// The bitmap is a vector of u64 words of length `window_bits / 64`,
/// read as a single little-endian bit string: bit offset `k` (word
/// `k / 64`, bit `k % 64`) records whether sequence `highest - k` has
/// already been seen.
#[derive(Debug, Clone)]
struct StreamWindow {
    /// Highest sequence number accepted so far. The bitmap spans
    /// `[highest - window_bits + 1, highest]`: bit offset 0 is `highest`,
    /// bit offset `window_bits - 1` is the oldest tracked sequence.
    highest: u64,
    /// Received-sequence bitmap. `bitmap[w]` holds offsets
    /// `[64*w, 64*(w+1))` measured down from `highest`. Once the window
    /// is initialized, bit 0 of `bitmap[0]` is always set (it is `highest`).
    bitmap: Vec<u64>,
    /// `false` until the first sequence on this stream has been recorded.
    initialized: bool,
}

impl StreamWindow {
    /// Build an empty, not-yet-initialized window whose bitmap holds
    /// `words` zeroed u64 words. The first accepted sequence is what
    /// seeds `highest` and bit 0.
    fn new(words: usize) -> Self {
        let bitmap = vec![0u64; words];
        StreamWindow {
            bitmap,
            highest: 0,
            initialized: false,
        }
    }
}
114
/// Sliding-window replay protection for multiple independent streams.
///
/// Streams are keyed by `(source_id, payload_type, key_epoch)` — see
/// module-level docs on why the key epoch matters across rekey. Use
/// [`Self::accept`] to atomically check-and-mark a sequence as received.
///
/// Every stream tracked by one `ReplayWindow` shares the same width,
/// fixed at construction. A stream's state is created lazily on its
/// first [`Self::accept`] call and persists until [`Self::drop_epoch`]
/// (for its epoch) or [`Self::clear`] removes it.
#[derive(Debug, Clone)]
pub struct ReplayWindow {
    /// Per-stream state, keyed by `(source_id, payload_type, key_epoch)`.
    streams: HashMap<(u64, u8, u8), StreamWindow>,
    /// Window width in bits shared by every stream; validated by
    /// `with_window_bits` (multiple of 64, between 64 and `MAX_WINDOW_BITS`).
    window_bits: u32,
    /// Bitmap length in u64 words for each stream (`window_bits / 64`).
    bitmap_words: usize,
}
126
127impl Default for ReplayWindow {
128    fn default() -> Self {
129        Self::new()
130    }
131}
132
133impl ReplayWindow {
134    /// Create an empty replay window with the default 64-bit width.
135    pub fn new() -> Self {
136        Self::with_window_bits(DEFAULT_WINDOW_BITS)
137    }
138
139    /// Create an empty replay window with a caller-chosen width.
140    ///
141    /// `bits` MUST be a multiple of 64, at least 64, at most
142    /// [`MAX_WINDOW_BITS`] (1024). Panics otherwise.
143    pub fn with_window_bits(bits: u32) -> Self {
144        assert!(
145            (DEFAULT_WINDOW_BITS..=MAX_WINDOW_BITS).contains(&bits)
146                && bits % DEFAULT_WINDOW_BITS == 0,
147            "replay window bits must be a multiple of 64 between 64 and 1024; got {bits}",
148        );
149        Self {
150            streams: HashMap::new(),
151            window_bits: bits,
152            bitmap_words: (bits / DEFAULT_WINDOW_BITS) as usize,
153        }
154    }
155
156    /// Current window width in bits.
157    pub fn window_bits(&self) -> u32 {
158        self.window_bits
159    }
160
161    /// Reset all tracked streams. Primarily useful for tests and for
162    /// session teardown. Rekey-driven cleanup is narrower — use
163    /// [`Self::drop_epoch`] to forget only the old key's windows while
164    /// preserving the current one.
165    pub fn clear(&mut self) {
166        self.streams.clear();
167    }
168
169    /// Atomically check whether `seq` is acceptable for the given
170    /// `(source_id, payload_type, key_epoch)` tuple, and mark it as
171    /// received if so.
172    ///
173    /// Returns `true` if the message should be processed (sequence is new
174    /// and within the window), `false` if it should be dropped (duplicate
175    /// or too old).
176    ///
177    /// `source_id` is the 6-byte random identifier from the AEAD nonce
178    /// interpreted as little-endian u64. `payload_type` is the nonce
179    /// byte 6. `key_epoch` is a receiver-local counter that advances on
180    /// every `install_key` call — the caller MUST pass the epoch of the
181    /// key that verified the AEAD tag, not (for example) the current
182    /// epoch if the previous key is what actually opened the envelope.
183    pub fn accept(&mut self, source_id: u64, payload_type: u8, key_epoch: u8, seq: u64) -> bool {
184        let window_bits_u64 = self.window_bits as u64;
185        let bitmap_words = self.bitmap_words;
186        let win = self
187            .streams
188            .entry((source_id, payload_type, key_epoch))
189            .or_insert_with(|| StreamWindow::new(bitmap_words));
190
191        if !win.initialized {
192            // First sequence for this stream: accept and initialize.
193            win.highest = seq;
194            win.bitmap.fill(0);
195            win.bitmap[0] = 1; // bit 0 (offset 0 = highest) = "seq seen"
196            win.initialized = true;
197            return true;
198        }
199
200        if seq > win.highest {
201            // New high sequence: shift the bitmap left by (seq - highest)
202            // bits. Bits shifted past the window edge are discarded.
203            let shift = seq - win.highest;
204            if shift >= window_bits_u64 {
205                // Jumped entirely past the old bitmap. Clear + seed.
206                win.bitmap.fill(0);
207                win.bitmap[0] = 1;
208            } else {
209                shift_bitmap_left(&mut win.bitmap, shift as u32);
210                // Seed bit 0 (the new highest) AFTER shifting.
211                win.bitmap[0] |= 1;
212            }
213            win.highest = seq;
214            true
215        } else {
216            // seq <= highest: check if it falls within the window and is
217            // unseen.
218            let offset = win.highest - seq;
219            if offset >= window_bits_u64 {
220                // Too old.
221                false
222            } else {
223                let word_idx = (offset / DEFAULT_WINDOW_BITS as u64) as usize;
224                let bit_idx = (offset % DEFAULT_WINDOW_BITS as u64) as u32;
225                let mask = 1u64 << bit_idx;
226                if win.bitmap[word_idx] & mask != 0 {
227                    false
228                } else {
229                    win.bitmap[word_idx] |= mask;
230                    true
231                }
232            }
233        }
234    }
235
236    /// Drop all stream state associated with a specific `key_epoch`.
237    /// Called by [`crate::Session::tick`] when the previous-key grace
238    /// period ends — at that point the old key's envelopes can no longer
239    /// verify anyway, so the old window is pure memory overhead and
240    /// should be reclaimed. Safe to call for an epoch that has no
241    /// entries (no-op).
242    pub fn drop_epoch(&mut self, key_epoch: u8) {
243        self.streams.retain(|(_, _, epoch), _| *epoch != key_epoch);
244    }
245
246    /// Number of distinct streams currently tracked. Mostly for tests and
247    /// observability; not part of the protection guarantee.
248    pub fn stream_count(&self) -> usize {
249        self.streams.len()
250    }
251}
252
/// Shift a multi-word bitmap left by `shift` bits, filling low bits
/// with zeros. `bitmap[0]` is the LOW word (covers bit offsets 0..64);
/// `bitmap[N]` is higher. A left shift moves bits toward higher offsets
/// — equivalent to `u64::<<` semantics extended across words, i.e. the
/// slice behaves like one `64 * bitmap.len()`-bit little-endian integer.
/// Bits shifted past the top word are discarded.
///
/// Precondition: `shift < bitmap.len() * 64` (checked in debug builds).
/// The caller (`accept`) handles the shift-past-end case by clearing the
/// bitmap instead; if the precondition is violated in a release build the
/// bitmap is zeroed, which is what a true wide shift would produce. An
/// empty bitmap or a zero shift is a no-op.
///
/// Runs in O(N) where N is the number of words. For the default
/// 1-word (64-bit) case this degenerates to a single `u64 << shift`.
#[inline]
fn shift_bitmap_left(bitmap: &mut [u64], shift: u32) {
    // Handle the trivial cases before asserting: for an empty slice every
    // shift (even 0) is "out of range", yet there is simply nothing to do.
    if bitmap.is_empty() || shift == 0 {
        return;
    }
    let len = bitmap.len();
    debug_assert!(
        (shift as usize) < len * u64::BITS as usize,
        "shift {} out of range for {}-word bitmap",
        shift,
        len
    );
    let word_shift = (shift / u64::BITS) as usize;
    let bit_shift = shift % u64::BITS;

    if word_shift >= len {
        // Out-of-contract shift (only reachable when debug assertions are
        // off): every bit slides off the top.
        bitmap.fill(0);
        return;
    }

    // Phase 1: move whole words up by `word_shift`, zero-filling the
    // vacated low words. `copy_within` handles the overlapping ranges.
    if word_shift > 0 {
        bitmap.copy_within(..len - word_shift, word_shift);
        bitmap[..word_shift].fill(0);
    }

    // Phase 2: shift by the remaining sub-word amount, carrying the top
    // `bit_shift` bits of each word into the next-higher word. Walk from
    // high to low so each `bitmap[i - 1]` is read before it is rewritten.
    if bit_shift > 0 {
        let carry_shift = u64::BITS - bit_shift;
        for i in (1..len).rev() {
            bitmap[i] = (bitmap[i] << bit_shift) | (bitmap[i - 1] >> carry_shift);
        }
        bitmap[0] <<= bit_shift;
    }
}
309
#[cfg(test)]
mod tests {
    use super::*;

    const SRC: u64 = 0xDEAD_BEEF_CAFE_BABE;
    const EPOCH: u8 = 0; // most single-epoch tests use epoch 0

    fn accept_default(w: &mut ReplayWindow, pld: u8, seq: u64) -> bool {
        w.accept(SRC, pld, EPOCH, seq)
    }

    #[test]
    fn first_sequence_accepted() {
        let mut w = ReplayWindow::new();
        assert!(accept_default(&mut w, 0x10, 0));
    }

    #[test]
    fn sequential_sequences_accepted() {
        let mut w = ReplayWindow::new();
        for seq in 0..100 {
            assert!(accept_default(&mut w, 0x10, seq), "seq {seq} should accept");
        }
    }

    #[test]
    fn duplicate_at_highest_rejected() {
        let mut w = ReplayWindow::new();
        assert!(accept_default(&mut w, 0x10, 5));
        assert!(
            !accept_default(&mut w, 0x10, 5),
            "duplicate at highest should reject"
        );
    }

    #[test]
    fn duplicate_within_window_rejected() {
        let mut w = ReplayWindow::new();
        for seq in 0..=5 {
            assert!(accept_default(&mut w, 0x10, seq));
        }
        assert!(!accept_default(&mut w, 0x10, 2));
        assert!(accept_default(&mut w, 0x10, 6));
    }

    #[test]
    fn out_of_order_within_window_accepted() {
        let mut w = ReplayWindow::new();
        assert!(accept_default(&mut w, 0x10, 10));
        assert!(accept_default(&mut w, 0x10, 7));
        assert!(!accept_default(&mut w, 0x10, 7));
        assert!(accept_default(&mut w, 0x10, 8));
    }

    #[test]
    fn too_old_sequence_rejected() {
        let mut w = ReplayWindow::new();
        assert!(accept_default(&mut w, 0x10, 100));
        assert!(!accept_default(&mut w, 0x10, 35));
        assert!(!accept_default(&mut w, 0x10, 36));
        assert!(accept_default(&mut w, 0x10, 37));
    }

    #[test]
    fn future_arrival_shifts_window_correctly() {
        let mut w = ReplayWindow::new();
        for seq in 0..=5 {
            assert!(accept_default(&mut w, 0x10, seq));
        }
        assert!(accept_default(&mut w, 0x10, 1000));
        for seq in 0..=5 {
            assert!(
                !accept_default(&mut w, 0x10, seq),
                "old seq {seq} after jump should reject"
            );
        }
        assert!(accept_default(&mut w, 0x10, 999));
        assert!(accept_default(&mut w, 0x10, 950));
        assert!(!accept_default(&mut w, 0x10, 936));
    }

    #[test]
    fn independent_streams_dont_interfere() {
        let mut w = ReplayWindow::new();
        assert!(accept_default(&mut w, 0x10, 5));
        assert!(accept_default(&mut w, 0x11, 5));
        assert!(!accept_default(&mut w, 0x10, 5));
        assert!(!accept_default(&mut w, 0x11, 5));
        assert_eq!(w.stream_count(), 2);
    }

    #[test]
    fn different_source_ids_dont_interfere() {
        let mut w = ReplayWindow::new();
        assert!(w.accept(0xAAAA_AAAA_AAAA_AAAA, 0x10, EPOCH, 100));
        assert!(w.accept(0xBBBB_BBBB_BBBB_BBBB, 0x10, EPOCH, 100));
        assert!(!w.accept(0xAAAA_AAAA_AAAA_AAAA, 0x10, EPOCH, 100));
        assert_eq!(w.stream_count(), 2);
    }

    #[test]
    fn window_edge_exactly_window_bits_below_rejected() {
        let mut w = ReplayWindow::new();
        assert!(accept_default(&mut w, 0x10, 100));
        assert!(!accept_default(&mut w, 0x10, 36));
        assert!(accept_default(&mut w, 0x10, 37));
    }

    #[test]
    fn clear_resets_all_streams() {
        let mut w = ReplayWindow::new();
        assert!(accept_default(&mut w, 0x10, 5));
        assert!(accept_default(&mut w, 0x11, 7));
        assert_eq!(w.stream_count(), 2);
        w.clear();
        assert_eq!(w.stream_count(), 0);
        assert!(accept_default(&mut w, 0x10, 5));
        assert!(accept_default(&mut w, 0x11, 7));
    }

    // ─── Per-key-epoch tests (SPEC §5.3) ───────────────────────────────

    #[test]
    fn independent_epochs_dont_interfere_even_with_same_stream() {
        // This is the bug-fix regression test. Before the epoch split,
        // the second accept for epoch=1 at seq=0 would be rejected
        // because highest=1000 from epoch=0 was stored in a window
        // keyed only by (source_id, pld_type).
        let mut w = ReplayWindow::new();
        for seq in 0..=1000 {
            assert!(w.accept(SRC, 0x10, 0, seq));
        }
        // Rekey: new epoch starts fresh at seq=0.
        assert!(
            w.accept(SRC, 0x10, 1, 0),
            "new-epoch seq=0 must be accepted despite old-epoch highest=1000"
        );
        assert!(w.accept(SRC, 0x10, 1, 1));
        assert!(w.accept(SRC, 0x10, 1, 2));
    }

    #[test]
    fn drop_epoch_removes_only_that_epoch() {
        let mut w = ReplayWindow::new();
        assert!(w.accept(SRC, 0x10, 0, 5));
        assert!(w.accept(SRC, 0x10, 1, 5));
        assert!(w.accept(SRC, 0x11, 0, 5));
        assert_eq!(w.stream_count(), 3);

        w.drop_epoch(0);
        assert_eq!(w.stream_count(), 1); // only (SRC, 0x10, 1) left

        // Re-accepting on the dropped epoch is fine — fresh state.
        assert!(w.accept(SRC, 0x10, 0, 5));
        // But the un-dropped epoch still sees its old state.
        assert!(!w.accept(SRC, 0x10, 1, 5));
    }

    #[test]
    fn drop_epoch_with_no_entries_is_noop() {
        let mut w = ReplayWindow::new();
        w.drop_epoch(42); // no-op, must not panic
        assert_eq!(w.stream_count(), 0);
    }

    // ─── Configurable / multi-word window tests (draft-02r2) ───────────

    #[test]
    fn window_bits_reports_configured_width() {
        assert_eq!(ReplayWindow::new().window_bits(), DEFAULT_WINDOW_BITS);
        assert_eq!(ReplayWindow::default().window_bits(), DEFAULT_WINDOW_BITS);
        assert_eq!(ReplayWindow::with_window_bits(256).window_bits(), 256);
        assert_eq!(
            ReplayWindow::with_window_bits(MAX_WINDOW_BITS).window_bits(),
            MAX_WINDOW_BITS
        );
    }

    #[test]
    #[should_panic(expected = "multiple of 64")]
    fn window_bits_not_multiple_of_64_panics() {
        let _ = ReplayWindow::with_window_bits(100);
    }

    #[test]
    #[should_panic(expected = "multiple of 64")]
    fn window_bits_below_minimum_panics() {
        let _ = ReplayWindow::with_window_bits(0);
    }

    #[test]
    #[should_panic(expected = "multiple of 64")]
    fn window_bits_above_maximum_panics() {
        let _ = ReplayWindow::with_window_bits(MAX_WINDOW_BITS + 64);
    }

    #[test]
    fn wide_window_accepts_reordering_beyond_64() {
        let mut w = ReplayWindow::with_window_bits(128);
        assert!(w.accept(SRC, 0x10, EPOCH, 200));
        // Offset 100 would be too old for a 64-bit window but fits in 128.
        assert!(w.accept(SRC, 0x10, EPOCH, 100));
        assert!(!w.accept(SRC, 0x10, EPOCH, 100));
        // Edge: offset 127 is the oldest tracked slot; offset 128 is too old.
        assert!(w.accept(SRC, 0x10, EPOCH, 73));
        assert!(!w.accept(SRC, 0x10, EPOCH, 72));
    }

    #[test]
    fn wide_window_shift_carries_across_word_boundary() {
        let mut w = ReplayWindow::with_window_bits(128);
        for seq in 0..=10 {
            assert!(w.accept(SRC, 0x10, EPOCH, seq));
        }
        // Advance by 70: a one-word move plus a 6-bit carry.
        assert!(w.accept(SRC, 0x10, EPOCH, 80));
        for seq in 0..=10 {
            assert!(
                !w.accept(SRC, 0x10, EPOCH, seq),
                "seq {seq} was seen before the shift and must still be rejected"
            );
        }
        // Slots just below the carried run and on either side of the word
        // boundary (offsets 69, 64, 63) were never seen.
        assert!(w.accept(SRC, 0x10, EPOCH, 11));
        assert!(w.accept(SRC, 0x10, EPOCH, 16));
        assert!(w.accept(SRC, 0x10, EPOCH, 17));
        assert!(!w.accept(SRC, 0x10, EPOCH, 17));
    }

    #[test]
    fn max_window_fills_gaps_out_of_order() {
        let mut w = ReplayWindow::with_window_bits(MAX_WINDOW_BITS);
        let top = u64::from(MAX_WINDOW_BITS) - 2;
        // Every even sequence first, in order (each advance shifts by 2).
        for seq in (0..=top).step_by(2) {
            assert!(w.accept(SRC, 0x10, EPOCH, seq));
        }
        // Then every odd sequence: late, but still inside the window.
        for seq in (1..top).step_by(2) {
            assert!(w.accept(SRC, 0x10, EPOCH, seq), "late odd seq {seq} should accept");
        }
        // Everything in [0, top] is now a duplicate.
        for seq in 0..=top {
            assert!(!w.accept(SRC, 0x10, EPOCH, seq), "seq {seq} should be a duplicate");
        }
    }

    #[test]
    fn wide_window_jump_past_entire_window_clears_state() {
        let mut w = ReplayWindow::with_window_bits(256);
        assert!(w.accept(SRC, 0x10, EPOCH, 5));
        assert!(w.accept(SRC, 0x10, EPOCH, 5 + 256));
        assert!(!w.accept(SRC, 0x10, EPOCH, 5), "offset 256 is too old");
        assert!(w.accept(SRC, 0x10, EPOCH, 6), "offset 255 is unseen");
        assert!(!w.accept(SRC, 0x10, EPOCH, 5 + 256));
    }

    #[test]
    fn shift_bitmap_left_matches_u128_shift() {
        let values: [u128; 3] = [1, 0xDEAD_BEEF_0123_4567_89AB_CDEF_FEED_FACE, u128::MAX];
        for &v in &values {
            for s in 0..128u32 {
                let mut words = [v as u64, (v >> 64) as u64];
                shift_bitmap_left(&mut words, s);
                let e = v << s;
                assert_eq!(words, [e as u64, (e >> 64) as u64], "v={v:#x} shift={s}");
            }
        }
    }
}