Skip to main content

soe_protocol/
rc4.rs

//! The RC4 stream cipher, as used by the SOE protocol for optional data encryption.
//!
//! The cipher state is maintained for the entirety of a session, rather than being
//! reset per block of data.

/// The length of the RC4 key state buffer (the size of the permutation `S`).
pub const KEY_STATE_LENGTH: usize = 256;

/// A reusable RC4 key state. The state is advanced as data is transformed, so a
/// single [`Rc4KeyState`] represents one continuous cipher stream.
///
/// Cloning a state forks the stream: the clone continues from the same position,
/// independently of the original.
#[derive(Clone)]
pub struct Rc4KeyState {
    /// The permutation of `0..=255` produced by the key schedule and mutated as the
    /// stream advances.
    state: [u8; KEY_STATE_LENGTH],
    /// The first stream index (`i` in the usual RC4 description).
    index1: usize,
    /// The second stream index (`j` in the usual RC4 description).
    index2: usize,
}

impl Rc4KeyState {
    /// Creates a new key state by scheduling the given key bytes (the RC4 key-scheduling
    /// algorithm). Both stream indices start at zero.
    ///
    /// # Panics
    /// Panics if `key` is empty or longer than 256 bytes (the RC4 key-state length).
    pub fn new(key: &[u8]) -> Self {
        assert!(
            !key.is_empty() && key.len() <= KEY_STATE_LENGTH,
            "key length must be in 1..={KEY_STATE_LENGTH}"
        );

        // Start from the identity permutation; `i < 256`, so the cast is lossless.
        let mut state = [0u8; KEY_STATE_LENGTH];
        for (i, slot) in state.iter_mut().enumerate() {
            *slot = i as u8;
        }

        // Mix the key into the permutation, cycling through the key bytes.
        let mut swap_index1: usize = 0;
        let mut swap_index2: usize = 0;
        for i in 0..KEY_STATE_LENGTH {
            swap_index2 =
                (swap_index2 + state[i] as usize + key[swap_index1] as usize) % KEY_STATE_LENGTH;
            state.swap(i, swap_index2);
            swap_index1 = (swap_index1 + 1) % key.len();
        }

        Self {
            state,
            index1: 0,
            index2: 0,
        }
    }

    /// Returns the current internal permutation bytes (for inspection/testing).
    pub fn key_state(&self) -> &[u8; KEY_STATE_LENGTH] {
        &self.state
    }

    /// Returns the two stream indices as `(index1, index2)`.
    pub fn indices(&self) -> (usize, usize) {
        (self.index1, self.index2)
    }

    /// Performs one step of the state update: bumps `index1`, adds the permutation
    /// entry at `index1` into `index2`, then swaps the two entries.
    #[inline]
    fn increment(&mut self) {
        self.index1 = (self.index1 + 1) % KEY_STATE_LENGTH;
        self.index2 = (self.index2 + self.state[self.index1] as usize) % KEY_STATE_LENGTH;
        self.state.swap(self.index1, self.index2);
    }

    /// Advances the state by one step and returns the resulting keystream byte.
    ///
    /// This is the single source of truth for keystream generation, shared by
    /// [`transform`](Self::transform) and [`transform_in_place`](Self::transform_in_place).
    #[inline]
    fn next_keystream_byte(&mut self) -> u8 {
        self.increment();
        let xor_index = (self.state[self.index1] as usize + self.state[self.index2] as usize)
            % KEY_STATE_LENGTH;
        self.state[xor_index]
    }

    /// Transforms `input` into `output` using (and advancing) this key state.
    ///
    /// RC4 is symmetric, so the same operation both encrypts and decrypts. Exactly
    /// `input.len()` keystream bytes are consumed; any bytes of `output` beyond
    /// `input.len()` are left untouched.
    ///
    /// # Panics
    /// Panics if `output` is shorter than `input`.
    pub fn transform(&mut self, input: &[u8], output: &mut [u8]) {
        assert!(
            output.len() >= input.len(),
            "output buffer must be at least as long as the input buffer"
        );

        // `zip` stops at the end of `input`, which the assertion guarantees is the
        // shorter side, so no per-byte bounds checks are needed.
        for (out, &byte) in output.iter_mut().zip(input) {
            *out = byte ^ self.next_keystream_byte();
        }
    }

    /// Transforms `buffer` in place using (and advancing) this key state.
    ///
    /// Produces the same bytes as [`transform`](Self::transform) would, written back
    /// over the input.
    pub fn transform_in_place(&mut self, buffer: &mut [u8]) {
        for byte in buffer.iter_mut() {
            *byte ^= self.next_keystream_byte();
        }
    }

    /// Advances the key state by `amount` steps without transforming any data.
    ///
    /// This discards `amount` keystream bytes. It leaves the state exactly as
    /// transforming `amount` bytes of data would.
    pub fn advance(&mut self, amount: usize) {
        for _ in 0..amount {
            self.increment();
        }
    }
}

105impl std::fmt::Debug for Rc4KeyState {
106    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107        f.debug_struct("Rc4KeyState")
108            .field("index1", &self.index1)
109            .field("index2", &self.index2)
110            .finish_non_exhaustive()
111    }
112}
113
#[cfg(test)]
mod tests {
    // Unit tests for `Rc4KeyState`. They cover known-answer vectors, stream continuity
    // across calls, both transform entry points, and the documented panic preconditions.
    use super::*;

    // Wikipedia RC4 test vectors, ported from Rc4CipherTests.cs
    struct Vector {
        key: &'static str,
        plain: &'static str,
        cipher: &'static [u8],
    }

    const VECTORS: &[Vector] = &[
        Vector {
            key: "Key",
            plain: "Plaintext",
            cipher: &[0xBB, 0xF3, 0x16, 0xE8, 0xD9, 0x40, 0xAF, 0x0A, 0xD3],
        },
        Vector {
            key: "Wiki",
            plain: "pedia",
            cipher: &[0x10, 0x21, 0xBF, 0x04, 0x20],
        },
        Vector {
            key: "Secret",
            plain: "Attack at dawn",
            cipher: &[
                0x45, 0xA0, 0x1F, 0x64, 0x5F, 0xC3, 0x5B, 0x38, 0x35, 0x52, 0x54, 0x4B, 0x9B, 0xF5,
            ],
        },
    ];

    #[test]
    fn test_encryption() {
        for v in VECTORS {
            let mut state = Rc4KeyState::new(v.key.as_bytes());
            let mut out = vec![0u8; v.plain.len()];
            state.transform(v.plain.as_bytes(), &mut out);
            assert_eq!(out, v.cipher, "key={}", v.key);
        }
    }

    #[test]
    fn test_transform_in_place() {
        for v in VECTORS {
            let mut state = Rc4KeyState::new(v.key.as_bytes());
            let mut buffer = v.plain.as_bytes().to_vec();
            state.transform_in_place(&mut buffer);
            assert_eq!(buffer, v.cipher, "key={}", v.key);
        }
    }

    #[test]
    fn test_transform_leaves_excess_output_untouched() {
        let v = &VECTORS[0];
        let mut state = Rc4KeyState::new(v.key.as_bytes());
        let mut out = vec![0xAAu8; v.plain.len() + 3];
        state.transform(v.plain.as_bytes(), &mut out);
        assert_eq!(&out[..v.plain.len()], v.cipher);
        assert_eq!(&out[v.plain.len()..], &[0xAAu8; 3][..]);
    }

    #[test]
    fn test_round_trip() {
        for v in VECTORS {
            let mut enc = Rc4KeyState::new(v.key.as_bytes());
            let mut dec = Rc4KeyState::new(v.key.as_bytes());
            let mut encrypted = vec![0u8; v.plain.len()];
            let mut decrypted = vec![0u8; v.plain.len()];
            enc.transform(v.plain.as_bytes(), &mut encrypted);
            dec.transform(&encrypted, &mut decrypted);
            assert_eq!(decrypted, v.plain.as_bytes());
        }
    }

    #[test]
    fn test_existing_key_state() {
        // Transforming in two halves with one state must match transforming whole.
        for v in VECTORS {
            let half = v.cipher.len() / 2;
            let mut state = Rc4KeyState::new(v.key.as_bytes());
            let mut decrypted = vec![0u8; v.cipher.len()];
            state.transform(&v.cipher[..half], &mut decrypted[..half]);
            let mut tail = vec![0u8; v.cipher.len() - half];
            state.transform(&v.cipher[half..], &mut tail);
            decrypted[half..].copy_from_slice(&tail);
            assert_eq!(decrypted, v.plain.as_bytes());
        }
    }

    #[test]
    fn test_advance() {
        let key = VECTORS[0].key.as_bytes();
        let mut values1 = [1u8, 2, 3];
        let mut values2 = [1u8, 2, 3];

        let mut state1 = Rc4KeyState::new(key);
        let mut state2 = Rc4KeyState::new(key);

        let copy = values1;
        state1.transform(&copy, &mut values1);

        state2.advance(2);
        let tail = [values2[2]];
        let mut out = [0u8];
        state2.transform(&tail, &mut out);
        values2[2] = out[0];

        assert_eq!(values1[2], values2[2]);
    }

    #[test]
    fn test_advance_zero_is_noop() {
        let mut state = Rc4KeyState::new(VECTORS[0].key.as_bytes());
        let before = state.clone();
        state.advance(0);
        assert_eq!(state.indices(), before.indices());
        assert_eq!(state.key_state(), before.key_state());
    }

    #[test]
    #[should_panic(expected = "output buffer must be at least as long as the input buffer")]
    fn test_transform_rejects_short_output() {
        let mut state = Rc4KeyState::new(b"Key");
        let mut out = [0u8; 2];
        state.transform(b"abc", &mut out);
    }

    #[test]
    #[should_panic(expected = "key length must be in 1..=256")]
    fn test_new_rejects_empty_key() {
        let _ = Rc4KeyState::new(&[]);
    }

    #[test]
    #[should_panic(expected = "key length must be in 1..=256")]
    fn test_new_rejects_oversized_key() {
        let _ = Rc4KeyState::new(&[0u8; KEY_STATE_LENGTH + 1]);
    }

    #[test]
    fn test_new_accepts_maximum_key_length() {
        let state = Rc4KeyState::new(&[0x5Au8; KEY_STATE_LENGTH]);
        assert_eq!(state.indices(), (0, 0));
    }

    #[test]
    fn test_key_state_is_permutation() {
        // Scheduling and stepping only ever swap entries, so every byte value must
        // appear exactly once.
        let mut state = Rc4KeyState::new(VECTORS[2].key.as_bytes());
        state.advance(1000);
        let mut seen = [false; KEY_STATE_LENGTH];
        for &b in state.key_state() {
            assert!(!seen[b as usize], "duplicate value {b}");
            seen[b as usize] = true;
        }
        assert!(seen.iter().all(|&s| s));
    }

    // Ported from Rc4KeyStateTests.cs
    #[test]
    fn clone_creates_full_copy() {
        let mut state = Rc4KeyState::new(&[0, 1, 2, 3, 4]);
        state.advance(7);
        let copied = state.clone();
        assert_eq!(state.key_state(), copied.key_state());
        assert_eq!(state.indices(), copied.indices());
    }

    #[test]
    fn debug_omits_key_state() {
        let state = Rc4KeyState::new(b"Key");
        assert_eq!(format!("{state:?}"), "Rc4KeyState { index1: 0, index2: 0, .. }");
    }
}