/// Number of entries in the RC4 permutation: one slot per possible byte value.
pub const KEY_STATE_LENGTH: usize = u8::MAX as usize + 1;

/// RC4 keystream generator: the byte permutation plus the two PRGA indices.
///
/// `Clone` copies the whole state, so a clone continues the keystream from the
/// same position independently of the original.
#[derive(Clone)]
pub struct Rc4KeyState {
    /// Permutation of all byte values, scrambled by the key schedule in `new`.
    state: [u8; KEY_STATE_LENGTH],
    /// PRGA index `i`.
    index1: usize,
    /// PRGA index `j`.
    index2: usize,
}
17
18impl Rc4KeyState {
19 pub fn new(key: &[u8]) -> Self {
24 assert!(
25 !key.is_empty() && key.len() <= KEY_STATE_LENGTH,
26 "key length must be in 1..={KEY_STATE_LENGTH}"
27 );
28
29 let mut state = [0u8; KEY_STATE_LENGTH];
30 for (i, slot) in state.iter_mut().enumerate() {
31 *slot = i as u8;
32 }
33
34 let mut swap_index1: usize = 0;
35 let mut swap_index2: usize = 0;
36 for i in 0..KEY_STATE_LENGTH {
37 swap_index2 =
38 (swap_index2 + state[i] as usize + key[swap_index1] as usize) % KEY_STATE_LENGTH;
39 state.swap(i, swap_index2);
40 swap_index1 = (swap_index1 + 1) % key.len();
41 }
42
43 Self {
44 state,
45 index1: 0,
46 index2: 0,
47 }
48 }
49
50 pub fn key_state(&self) -> &[u8; KEY_STATE_LENGTH] {
52 &self.state
53 }
54
55 pub fn indices(&self) -> (usize, usize) {
57 (self.index1, self.index2)
58 }
59
60 #[inline]
61 fn increment(&mut self) {
62 self.index1 = (self.index1 + 1) % KEY_STATE_LENGTH;
63 self.index2 = (self.index2 + self.state[self.index1] as usize) % KEY_STATE_LENGTH;
64 self.state.swap(self.index1, self.index2);
65 }
66
67 pub fn transform(&mut self, input: &[u8], output: &mut [u8]) {
74 assert!(
75 output.len() >= input.len(),
76 "output buffer must be at least as long as the input buffer"
77 );
78
79 for (i, &byte) in input.iter().enumerate() {
80 self.increment();
81 let xor_index = (self.state[self.index1] as usize + self.state[self.index2] as usize)
82 % KEY_STATE_LENGTH;
83 output[i] = byte ^ self.state[xor_index];
84 }
85 }
86
87 pub fn transform_in_place(&mut self, buffer: &mut [u8]) {
89 for byte in buffer.iter_mut() {
90 self.increment();
91 let xor_index = (self.state[self.index1] as usize + self.state[self.index2] as usize)
92 % KEY_STATE_LENGTH;
93 *byte ^= self.state[xor_index];
94 }
95 }
96
97 pub fn advance(&mut self, amount: usize) {
99 for _ in 0..amount {
100 self.increment();
101 }
102 }
103}
104
105impl std::fmt::Debug for Rc4KeyState {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 f.debug_struct("Rc4KeyState")
108 .field("index1", &self.index1)
109 .field("index2", &self.index2)
110 .finish_non_exhaustive()
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// A known-answer RC4 test case: key, plaintext, and expected ciphertext.
    struct Vector {
        key: &'static str,
        plain: &'static str,
        cipher: &'static [u8],
    }

    const VECTORS: &[Vector] = &[
        Vector {
            key: "Key",
            plain: "Plaintext",
            cipher: &[0xBB, 0xF3, 0x16, 0xE8, 0xD9, 0x40, 0xAF, 0x0A, 0xD3],
        },
        Vector {
            key: "Wiki",
            plain: "pedia",
            cipher: &[0x10, 0x21, 0xBF, 0x04, 0x20],
        },
        Vector {
            key: "Secret",
            plain: "Attack at dawn",
            cipher: &[
                0x45, 0xA0, 0x1F, 0x64, 0x5F, 0xC3, 0x5B, 0x38, 0x35, 0x52, 0x54, 0x4B, 0x9B, 0xF5,
            ],
        },
    ];

    /// Encrypting each plaintext yields the expected ciphertext.
    #[test]
    fn test_encryption() {
        for v in VECTORS {
            let mut state = Rc4KeyState::new(v.key.as_bytes());
            let mut out = vec![0u8; v.plain.len()];
            state.transform(v.plain.as_bytes(), &mut out);
            assert_eq!(out, v.cipher, "key={}", v.key);
        }
    }

    /// The in-place variant produces the same ciphertext as `transform`.
    #[test]
    fn test_transform_in_place() {
        for v in VECTORS {
            let mut state = Rc4KeyState::new(v.key.as_bytes());
            let mut buffer = v.plain.as_bytes().to_vec();
            state.transform_in_place(&mut buffer);
            assert_eq!(buffer, v.cipher, "key={}", v.key);
        }
    }

    /// Decrypting with a fresh state under the same key restores the plaintext.
    #[test]
    fn test_round_trip() {
        for v in VECTORS {
            let mut enc = Rc4KeyState::new(v.key.as_bytes());
            let mut dec = Rc4KeyState::new(v.key.as_bytes());
            let mut encrypted = vec![0u8; v.plain.len()];
            let mut decrypted = vec![0u8; v.plain.len()];
            enc.transform(v.plain.as_bytes(), &mut encrypted);
            dec.transform(&encrypted, &mut decrypted);
            assert_eq!(decrypted, v.plain.as_bytes());
        }
    }

    /// Splitting the input across two calls on the same state is equivalent to
    /// one call over the whole input: the keystream position carries over.
    #[test]
    fn test_existing_key_state() {
        for v in VECTORS {
            let half = v.cipher.len() / 2;
            let mut state = Rc4KeyState::new(v.key.as_bytes());
            let mut decrypted = vec![0u8; v.cipher.len()];
            state.transform(&v.cipher[..half], &mut decrypted[..half]);
            let mut tail = vec![0u8; v.cipher.len() - half];
            state.transform(&v.cipher[half..], &mut tail);
            decrypted[half..].copy_from_slice(&tail);
            assert_eq!(decrypted, v.plain.as_bytes());
        }
    }

    /// Skipping N keystream bytes with `advance` lines up with byte N of a
    /// full `transform`.
    #[test]
    fn test_advance() {
        let key = VECTORS[0].key.as_bytes();
        let mut values1 = [1u8, 2, 3];
        let mut values2 = [1u8, 2, 3];

        let mut state1 = Rc4KeyState::new(key);
        let mut state2 = Rc4KeyState::new(key);

        // Encrypt all three bytes with state1 (input copied because the output
        // buffer is the same array).
        let copy = values1;
        state1.transform(&copy, &mut values1);

        // Skip the first two keystream bytes, then encrypt only the last byte.
        state2.advance(2);
        let tail = [values2[2]];
        let mut out = [0u8];
        state2.transform(&tail, &mut out);
        values2[2] = out[0];

        assert_eq!(values1[2], values2[2]);
    }

    /// `clone` copies the permutation and both indices.
    #[test]
    fn clone_creates_full_copy() {
        let mut state = Rc4KeyState::new(&[0, 1, 2, 3, 4]);
        state.advance(7);
        let copied = state.clone();
        assert_eq!(state.key_state(), copied.key_state());
        assert_eq!(state.indices(), copied.indices());
    }
}