shadowsocks_service/net/
packet_window.rs

// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.

//! Packet window
//!
//! <https://github.com/WireGuard/wireguard-go/blob/master/replay/replay.go>
8
/// log2 of the number of bits held by one ring block (`1 << 6 == 64`, i.e. one `u64`)
const BLOCK_BIT_LOG: u64 = 6;
/// Number of bits per ring block; must be a power of 2
const BLOCK_BITS: u64 = 1 << BLOCK_BIT_LOG;
/// Number of blocks in the ring; must be a power of 2
const RING_BLOCKS: u64 = 1 << 7;
/// Largest distance behind the newest `packet_id` that is still accepted
/// (one block fewer than the ring holds)
const WINDOW_SIZE: u64 = (RING_BLOCKS - 1) * BLOCK_BITS;
/// Mask that reduces a block number to an index into the ring
const BLOCK_MASK: u64 = RING_BLOCKS - 1;
/// Mask that extracts the bit position of a `packet_id` inside its block
const BIT_MASK: u64 = BLOCK_BITS - 1;
15
16/// Packet window for checking `packet_id` is in the sliding window
17#[derive(Debug, Clone)]
18pub struct PacketWindowFilter {
19    last_packet_id: u64,
20    packet_ring: [u64; RING_BLOCKS as usize],
21}
22
23impl Default for PacketWindowFilter {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl PacketWindowFilter {
30    /// Create an empty filter
31    pub fn new() -> Self {
32        Self {
33            last_packet_id: 0,
34            packet_ring: [0u64; RING_BLOCKS as usize],
35        }
36    }
37
38    /// Reset filter to the initial state
39    pub fn reset(&mut self) {
40        self.last_packet_id = 0;
41        self.packet_ring[0] = 0;
42    }
43
44    /// Check and remember the `packet_id`
45    ///
46    /// Overlimit `packet_id >= limit` are always rejected
47    pub fn validate_packet_id(&mut self, packet_id: u64, limit: u64) -> bool {
48        if packet_id >= limit {
49            return false;
50        }
51
52        let mut index_block = packet_id >> BLOCK_BIT_LOG;
53        if packet_id > self.last_packet_id {
54            // Move the window forward
55
56            let current = self.last_packet_id >> BLOCK_BIT_LOG;
57            let mut diff = index_block - current;
58            if diff > RING_BLOCKS {
59                // Clear the whole filter
60                diff = RING_BLOCKS;
61            }
62            for d in 1..=diff {
63                let i = current + d;
64                self.packet_ring[(i & BLOCK_MASK) as usize] = 0;
65            }
66            self.last_packet_id = packet_id;
67        } else if self.last_packet_id - packet_id > WINDOW_SIZE {
68            // Behind the current window
69            return false;
70        }
71
72        // Check and set bit
73        index_block &= BLOCK_MASK;
74        let index_bit = packet_id & BIT_MASK;
75        let old = self.packet_ring[index_block as usize];
76        let new = old | (1 << index_bit);
77        self.packet_ring[index_block as usize] = new;
78        old != new
79    }
80}
81
#[cfg(test)]
mod test {
    use super::*;

    /// Value passed as `limit` to every `validate_packet_id` call in this module
    const REJECT_AFTER_MESSAGES: u64 = u64::MAX - (1u64 << 13);
    /// Smallest newest id for which `packet_id` 0 lies behind the window
    const T_LIM: u64 = WINDOW_SIZE + 1;

    /// Filter under test together with a running check counter used in failure messages
    struct Harness {
        filter: PacketWindowFilter,
        test_number: i32,
    }

    impl Harness {
        /// Validate `n` and panic, naming the check number, if the verdict is not `expected`
        fn expect(&mut self, n: u64, expected: bool) {
            self.test_number += 1;
            let accepted = self.filter.validate_packet_id(n, REJECT_AFTER_MESSAGES);
            if accepted != expected {
                panic!("Test {} failed, {} {}", self.test_number, n, expected);
            }
        }

        /// Announce a bulk scenario, then reset the filter and the check counter
        fn begin_bulk(&mut self, title: &str) {
            println!("{}", title);
            self.filter.reset();
            self.test_number = 0;
        }
    }

    #[test]
    fn test_packet_window() {
        let mut harness = Harness {
            filter: PacketWindowFilter::new(),
            test_number: 0,
        };

        // Order matters: every check runs against the state left behind by the previous ones.
        let ordered_checks: [(u64, bool); 34] = [
            (0, true),                           // 1
            (1, true),                           // 2
            (1, false),                          // 3
            (9, true),                           // 4
            (8, true),                           // 5
            (7, true),                           // 6
            (7, false),                          // 7
            (T_LIM, true),                       // 8
            (T_LIM - 1, true),                   // 9
            (T_LIM - 1, false),                  // 10
            (T_LIM - 2, true),                   // 11
            (2, true),                           // 12
            (2, false),                          // 13
            (T_LIM + 16, true),                  // 14
            (3, false),                          // 15
            (T_LIM + 16, false),                 // 16
            (T_LIM * 4, true),                   // 17
            (T_LIM * 4 - (T_LIM - 1), true),     // 18
            (10, false),                         // 19
            (T_LIM * 4 - T_LIM, false),          // 20
            (T_LIM * 4 - (T_LIM + 1), false),    // 21
            (T_LIM * 4 - (T_LIM - 2), true),     // 22
            (T_LIM * 4 + 1 - T_LIM, false),      // 23
            (0, false),                          // 24
            (REJECT_AFTER_MESSAGES, false),      // 25
            (REJECT_AFTER_MESSAGES - 1, true),   // 26
            (REJECT_AFTER_MESSAGES, false),      // 27
            (REJECT_AFTER_MESSAGES - 1, false),  // 28
            (REJECT_AFTER_MESSAGES - 2, true),   // 29
            (REJECT_AFTER_MESSAGES + 1, false),  // 30
            (REJECT_AFTER_MESSAGES + 2, false),  // 31
            (REJECT_AFTER_MESSAGES - 2, false),  // 32
            (REJECT_AFTER_MESSAGES - 3, true),   // 33
            (0, false),                          // 34
        ];
        for &(packet_id, expected) in ordered_checks.iter() {
            harness.expect(packet_id, expected);
        }

        // Ascending fill of the whole window; 0 is still inside it afterwards.
        harness.begin_bulk("Bulk test 1");
        (1..=WINDOW_SIZE).for_each(|id| harness.expect(id, true));
        harness.expect(0, true);
        harness.expect(0, false);

        // Ascending fill shifted by one; 1 stays inside the window, 0 falls out.
        harness.begin_bulk("Bulk test 2");
        (2..=WINDOW_SIZE + 1).for_each(|id| harness.expect(id, true));
        harness.expect(1, true);
        harness.expect(0, false);

        // Descending fill starting from the newest id.
        harness.begin_bulk("Bulk test 3");
        (1..=WINDOW_SIZE + 1)
            .rev()
            .for_each(|id| harness.expect(id, true));

        // Descending fill shifted by one; 0 is behind the window.
        harness.begin_bulk("Bulk test 4");
        (2..=WINDOW_SIZE + 2)
            .rev()
            .for_each(|id| harness.expect(id, true));
        harness.expect(0, false);

        // Descending fill, then advance by one so that 0 falls out.
        harness.begin_bulk("Bulk test 5");
        (1..=WINDOW_SIZE)
            .rev()
            .for_each(|id| harness.expect(id, true));
        harness.expect(WINDOW_SIZE + 1, true);
        harness.expect(0, false);

        // Descending fill, accept 0 while it is still inside, then advance by one.
        harness.begin_bulk("Bulk test 6");
        (1..=WINDOW_SIZE)
            .rev()
            .for_each(|id| harness.expect(id, true));
        harness.expect(0, true);
        harness.expect(WINDOW_SIZE + 1, true);
    }
}