srt_protocol/protocol/encryption/
mod.rs1pub mod key;
2pub mod stream;
3mod wrap;
4
5use std::fmt::Debug;
6
7use bytes::BytesMut;
8
9use crate::{packet::*, settings::*};
10
11use stream::KeyMaterialError;
12
13#[derive(Debug, Eq, PartialEq)]
14pub enum DecryptionError {
15 UnexpectedUnencryptedPacket(DataPacket),
17 UnexpectedEncryptedPacket(DataPacket),
18 EncryptionFailure,
19 DecryptionFailure,
20}
21
22#[derive(Debug)]
23pub struct Decryption(Option<(StreamEncryptionKeys, KeySettings)>);
24
25impl Decryption {
26 pub fn new(settings: Option<CipherSettings>) -> Self {
27 Self(settings.map(|settings| (settings.stream_keys, settings.key_settings)))
28 }
29
30 pub fn decrypt(&self, packet: DataPacket) -> Result<(usize, DataPacket), DecryptionError> {
31 use DecryptionError::*;
32 let mut packet = packet;
33 match (packet.encryption, &self.0) {
34 (DataEncryption::None, None) => Ok((0, packet)),
35 (DataEncryption::None, Some(_)) => Err(UnexpectedUnencryptedPacket(packet)),
36 (DataEncryption::Even | DataEncryption::Odd, None) => {
37 Err(UnexpectedEncryptedPacket(packet))
38 }
39 (selected_sek, Some((stream_keys, _))) => {
40 let mut data = BytesMut::with_capacity(packet.payload.len());
42 data.extend_from_slice(&packet.payload[..]);
43 let bytes = stream_keys
44 .decrypt(selected_sek, packet.seq_number, &mut data)
45 .ok_or(DecryptionFailure)?;
46 packet.encryption = DataEncryption::None;
47 packet.payload = data.freeze();
48 Ok((bytes, packet))
49 }
50 }
51 }
52
53 pub fn refresh_key_material(
54 &mut self,
55 keying_material: KeyingMaterialMessage,
56 ) -> Result<Option<KeyingMaterialMessage>, KeyMaterialError> {
57 let (stream_keys, key_settings) = self.0.as_mut().ok_or(KeyMaterialError::NoKeys)?;
58 *stream_keys = StreamEncryptionKeys::unwrap_from(key_settings, &keying_material)?;
59 Ok(Some(keying_material))
60 }
61}
62
63#[derive(Debug)]
64pub struct Encryption(Option<EncryptionState>);
65
66#[derive(Debug)]
67struct EncryptionState {
68 key_settings: KeySettings,
69 key_refresh: KeyMaterialRefreshSettings,
70 stream_keys: StreamEncryptionKeys,
71 active_sek: DataEncryption,
72 packets_until_pre_announcement: usize,
73 packets_until_transmit: usize,
74 packets_until_key_switch: usize,
75 last_key_material: Option<KeyingMaterialMessage>,
76}
77
78impl EncryptionState {
79 fn try_encrypt_packet(&mut self, mut packet: DataPacket) -> Option<(usize, DataPacket)> {
80 let mut data = BytesMut::with_capacity(packet.payload.len());
82 data.extend_from_slice(&packet.payload[..]);
83 let bytes = self
84 .stream_keys
85 .encrypt(self.active_sek, packet.seq_number, &mut data)?;
86 packet.encryption = self.active_sek;
87 packet.payload = data.freeze();
88 Some((bytes, packet))
89 }
90
91 fn try_schedule_pre_announcment(&mut self) {
92 if self.packets_until_pre_announcement == 0 {
93 self.packets_until_pre_announcement = self.key_refresh.period();
94 self.packets_until_transmit = 0;
95
96 if self.last_key_material.is_none() {
97 self.last_key_material = self
98 .stream_keys
99 .commission_next_key(self.active_sek, &self.key_settings);
100 }
101 }
102 }
103
104 fn try_send_key_material(&mut self) -> Option<KeyingMaterialMessage> {
105 let km = self.last_key_material.as_ref()?;
106 if self.packets_until_transmit == 0 {
107 self.packets_until_transmit =
108 std::cmp::min(self.key_refresh.pre_announcement_period(), 1_000);
109 Some(km.clone())
110 } else {
111 self.packets_until_transmit -= 1;
112 None
113 }
114 }
115
116 fn try_switch_stream_keys(&mut self) {
117 use DataEncryption::*;
118 if self.packets_until_key_switch == 0 {
119 self.packets_until_key_switch = self.key_refresh.period();
120 if self.last_key_material.is_none() {
121 self.active_sek = match self.active_sek {
122 Even => Odd,
123 Odd => Even,
124 None => None,
125 };
126 }
127 }
128 }
129}
130
131impl Encryption {
132 pub fn new(settings: Option<CipherSettings>) -> Self {
133 Self(settings.map(|settings| EncryptionState {
134 key_settings: settings.key_settings,
135 key_refresh: settings.key_refresh.clone(),
136 stream_keys: settings.stream_keys,
137 active_sek: DataEncryption::Even,
138
139 packets_until_pre_announcement: settings.key_refresh.period()
140 - settings.key_refresh.pre_announcement_period(),
141 packets_until_transmit: 0,
142 packets_until_key_switch: settings.key_refresh.period(),
143 last_key_material: None,
144 }))
145 }
146
147 pub fn encrypt(
148 &mut self,
149 packet: DataPacket,
150 ) -> Option<(usize, DataPacket, Option<KeyingMaterialMessage>)> {
151 match &mut self.0 {
152 Some(this) => {
153 let (bytes, packet) = this.try_encrypt_packet(packet)?;
154
155 this.try_schedule_pre_announcment();
156 this.try_switch_stream_keys();
157 let km = this.try_send_key_material();
158
159 this.packets_until_pre_announcement -= 1;
160 this.packets_until_key_switch -= 1;
161
162 Some((bytes, packet, km))
163 }
164 None => Some((0, packet, None)),
165 }
166 }
167
168 pub fn handle_key_refresh_response(
169 &mut self,
170 keying_material: KeyingMaterialMessage,
171 ) -> Result<(), KeyMaterialError> {
172 use KeyMaterialError::*;
173 if let Some(settings) = self.0.as_mut() {
174 let expected_key_material = settings.last_key_material.as_ref().ok_or(NoKeys)?;
175 if keying_material == *expected_key_material {
176 settings.packets_until_transmit = 0;
177 settings.last_key_material = None;
178 } else {
179 return Err(InvalidRefreshResponse(keying_material));
180 }
181 }
182 Ok(())
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    fn key_settings() -> KeySettings {
        KeySettings {
            key_size: KeySize::AES192,
            passphrase: "1234567890".into(),
        }
    }

    fn new_settings() -> CipherSettings {
        CipherSettings::new_random(&key_settings(), &Default::default())
    }

    /// Random keys with a short refresh cycle: a period of 3_000 packets and a
    /// pre-announcement 1_000 packets before each key switch.
    fn short_refresh_settings() -> CipherSettings {
        CipherSettings {
            key_refresh: KeyMaterialRefreshSettings::new(3_000, 1_000).unwrap(),
            ..new_settings()
        }
    }

    /// Encrypts copies of `packet` until `encryption` announces key material.
    ///
    /// With `short_refresh_settings` the first announcement rides on packet 2_000
    /// (zero-based), i.e. after `period - pre_announcement_period` packets.
    fn announced_key_material(
        encryption: &mut Encryption,
        packet: &DataPacket,
    ) -> KeyingMaterialMessage {
        (0..=2_000)
            .find_map(|_| encryption.encrypt(packet.clone()).unwrap().2)
            .expect("key material is announced within the first 2_001 packets")
    }

    fn data_packet(encryption: DataEncryption, payload: &str) -> DataPacket {
        DataPacket {
            seq_number: SeqNumber(3),
            message_loc: PacketLocation::ONLY,
            in_order_delivery: false,
            encryption,
            retransmitted: false,
            message_number: MsgNumber(1),
            timestamp: TimeStamp::MIN,
            dest_sockid: SocketId(0),
            payload: bytes::Bytes::copy_from_slice(payload.as_bytes()),
        }
    }

    #[test]
    fn round_trip() {
        let settings = new_settings();
        let original_packet = data_packet(DataEncryption::None, "test round_trip");

        let mut encryption = Encryption::new(Some(settings.clone()));
        let (bytes, encrypted_packet, key_material) =
            encryption.encrypt(original_packet.clone()).unwrap();
        assert_eq!(bytes, original_packet.payload.len());
        assert_ne!(encrypted_packet, original_packet);
        assert_eq!(key_material, None);

        let decryption = Decryption::new(Some(settings));
        let (bytes, decrypted_packet) = decryption.decrypt(encrypted_packet).unwrap();
        assert_eq!(bytes, original_packet.payload.len());
        assert_eq!(decrypted_packet, original_packet);
    }

    #[test]
    fn disabled_encryption_passes_packets_through() {
        let mut encryption = Encryption::new(None);
        let packet = data_packet(DataEncryption::None, "test passthrough");
        assert_eq!(
            encryption.encrypt(packet.clone()),
            Some((0, packet.clone(), None))
        );
    }

    #[test]
    fn decryption_failure() {
        use DecryptionError::*;
        let with_keys = |with_keys| {
            if with_keys {
                Decryption::new(Some(new_settings()))
            } else {
                Decryption::new(None)
            }
        };

        let new_packet = |encryption| data_packet(encryption, "test decryption_falureR");

        let packet = new_packet(DataEncryption::None);
        assert_eq!(
            with_keys(true).decrypt(packet.clone()),
            Err(UnexpectedUnencryptedPacket(packet))
        );

        let packet = new_packet(DataEncryption::Even);
        assert_eq!(
            with_keys(false).decrypt(packet.clone()),
            Err(UnexpectedEncryptedPacket(packet))
        );

        let packet = new_packet(DataEncryption::Odd);
        assert_eq!(
            with_keys(false).decrypt(packet.clone()),
            Err(UnexpectedEncryptedPacket(packet))
        );

        let packet = new_packet(DataEncryption::None);
        assert_eq!(with_keys(false).decrypt(packet.clone()), Ok((0, packet)));
    }

    #[test]
    fn refresh_key_material() {
        let settings = short_refresh_settings();
        let mut encryption = Encryption::new(Some(settings.clone()));
        let mut decryption = Decryption::new(Some(settings.clone()));
        let original_packet = data_packet(DataEncryption::None, "test refresh_key_material");

        let count = settings.key_refresh.period() - settings.key_refresh.pre_announcement_period();
        for i in 0..count {
            let (_, packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
            assert_eq!(km, None);
            assert_eq!(packet.encryption, DataEncryption::Even, "{i:?}");
        }

        let (_, first_packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
        assert_ne!(km, None);
        assert_eq!(first_packet.encryption, DataEncryption::Even);

        let key_material = km.unwrap();
        let response = decryption.refresh_key_material(key_material.clone());
        assert_eq!(response, Ok(Some(key_material.clone())));

        assert_eq!(encryption.handle_key_refresh_response(key_material), Ok(()));

        for i in 0..settings.key_refresh.pre_announcement_period() {
            let (_, packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
            assert_eq!(km, None, "{i:?}");
            assert_eq!(packet.encryption, DataEncryption::Even);
        }

        let (_, second_packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
        assert_eq!(km, None);
        assert_eq!(second_packet.encryption, DataEncryption::Odd);

        let (bytes, decrypted_packet) = decryption.decrypt(first_packet).unwrap();
        assert_eq!(bytes, original_packet.payload.len());
        assert_eq!(decrypted_packet, original_packet);

        let (bytes, decrypted_packet) = decryption.decrypt(second_packet).unwrap();
        assert_eq!(bytes, original_packet.payload.len());
        assert_eq!(decrypted_packet, original_packet);

        // The next pre-announcement rides on the `count`-th packet after the key
        // switch. `second_packet` was the first of those, so `count - 2` more Odd
        // packets come before `third_packet`.
        let count = settings.key_refresh.period() - settings.key_refresh.pre_announcement_period();
        for _ in 0..count - 2 {
            let (_, packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
            assert_eq!(km, None);
            assert_eq!(packet.encryption, DataEncryption::Odd);
        }

        let (_, third_packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
        assert_ne!(km, None);
        assert_eq!(third_packet.encryption, DataEncryption::Odd);

        let key_material = km.unwrap();
        let response = decryption.refresh_key_material(key_material.clone());
        assert_eq!(response, Ok(Some(key_material)));

        let (bytes, decrypted_packet) = decryption.decrypt(third_packet).unwrap();
        assert_eq!(bytes, original_packet.payload.len());
        assert_eq!(decrypted_packet, original_packet);
    }

    #[test]
    fn duplicate_refresh_response_is_rejected() {
        let mut encryption = Encryption::new(Some(short_refresh_settings()));
        let packet = data_packet(DataEncryption::None, "test duplicate_refresh_response");
        let key_material = announced_key_material(&mut encryption, &packet);

        assert_eq!(
            encryption.handle_key_refresh_response(key_material.clone()),
            Ok(())
        );
        // Nothing is outstanding any more, so a second echo has nothing to match.
        assert_eq!(
            encryption.handle_key_refresh_response(key_material),
            Err(KeyMaterialError::NoKeys)
        );
    }

    #[test]
    fn mismatched_refresh_response_is_rejected() {
        let packet = data_packet(DataEncryption::None, "test mismatched_refresh_response");
        let mut encryption = Encryption::new(Some(short_refresh_settings()));
        let mut other = Encryption::new(Some(short_refresh_settings()));

        let _own = announced_key_material(&mut encryption, &packet);
        // Independently generated random keys, so this cannot match `_own`.
        let foreign = announced_key_material(&mut other, &packet);

        assert_eq!(
            encryption.handle_key_refresh_response(foreign.clone()),
            Err(KeyMaterialError::InvalidRefreshResponse(foreign))
        );
    }

    #[test]
    fn retry_refresh_key_material() {
        let settings = CipherSettings {
            key_refresh: KeyMaterialRefreshSettings::new(44_000, 20_000).unwrap(),
            ..new_settings()
        };
        let mut encryption = Encryption::new(Some(settings.clone()));
        let original_packet = data_packet(DataEncryption::None, "test refresh_key_material");

        // Unconfirmed key material is re-sent every 1_001 packets once announced.
        let mut km_resp = None;
        let count = (0..settings.key_refresh.period() - 10_000)
            .filter_map(|_| encryption.encrypt(original_packet.clone()).unwrap().2)
            .inspect(|km| km_resp = Some(km.clone()))
            .count();

        assert_eq!(count, 10);

        encryption
            .handle_key_refresh_response(km_resp.unwrap())
            .unwrap();

        // Once confirmed, nothing is re-sent until the next pre-announcement is due.
        let remaining = 10_000
            + (settings.key_refresh.period() - settings.key_refresh.pre_announcement_period());
        let count = (0..remaining)
            .filter_map(|_| encryption.encrypt(original_packet.clone()).unwrap().2)
            .count();

        assert_eq!(count, 0);
    }
}