// srt_protocol/protocol/encryption/mod.rs

1pub mod key;
2pub mod stream;
3mod wrap;
4
5use std::fmt::Debug;
6
7use bytes::BytesMut;
8
9use crate::{packet::*, settings::*};
10
11use stream::KeyMaterialError;
12
/// Errors returned by [`Decryption::decrypt`].
#[derive(Debug, Eq, PartialEq)]
pub enum DecryptionError {
    /// A plaintext packet arrived although decryption keys are configured.
    /// The rejected packet is handed back to the caller.
    UnexpectedUnencryptedPacket(DataPacket),
    /// An encrypted packet (even or odd key) arrived but no decryption keys are
    /// configured, i.e. we were asked to decrypt while the key was `None`.
    /// The rejected packet is handed back to the caller.
    UnexpectedEncryptedPacket(DataPacket),
    /// NOTE(review): not produced anywhere in this module; presumably reserved for
    /// an encryption path elsewhere — confirm before relying on it.
    EncryptionFailure,
    /// The stream keys failed to decrypt the packet payload.
    DecryptionFailure,
}
21
22#[derive(Debug)]
23pub struct Decryption(Option<(StreamEncryptionKeys, KeySettings)>);
24
25impl Decryption {
26    pub fn new(settings: Option<CipherSettings>) -> Self {
27        Self(settings.map(|settings| (settings.stream_keys, settings.key_settings)))
28    }
29
30    pub fn decrypt(&self, packet: DataPacket) -> Result<(usize, DataPacket), DecryptionError> {
31        use DecryptionError::*;
32        let mut packet = packet;
33        match (packet.encryption, &self.0) {
34            (DataEncryption::None, None) => Ok((0, packet)),
35            (DataEncryption::None, Some(_)) => Err(UnexpectedUnencryptedPacket(packet)),
36            (DataEncryption::Even | DataEncryption::Odd, None) => {
37                Err(UnexpectedEncryptedPacket(packet))
38            }
39            (selected_sek, Some((stream_keys, _))) => {
40                // this requires an extra copy here...maybe DataPacket should have a BytesMut in it instead...
41                let mut data = BytesMut::with_capacity(packet.payload.len());
42                data.extend_from_slice(&packet.payload[..]);
43                let bytes = stream_keys
44                    .decrypt(selected_sek, packet.seq_number, &mut data)
45                    .ok_or(DecryptionFailure)?;
46                packet.encryption = DataEncryption::None;
47                packet.payload = data.freeze();
48                Ok((bytes, packet))
49            }
50        }
51    }
52
53    pub fn refresh_key_material(
54        &mut self,
55        keying_material: KeyingMaterialMessage,
56    ) -> Result<Option<KeyingMaterialMessage>, KeyMaterialError> {
57        let (stream_keys, key_settings) = self.0.as_mut().ok_or(KeyMaterialError::NoKeys)?;
58        *stream_keys = StreamEncryptionKeys::unwrap_from(key_settings, &keying_material)?;
59        Ok(Some(keying_material))
60    }
61}
62
/// Sender-side encryption state for a connection.
///
/// Holds `None` when the connection is unencrypted, in which case `encrypt` passes
/// packets through unchanged.
#[derive(Debug)]
pub struct Encryption(Option<EncryptionState>);
65
/// Sender state of an encrypted connection: the keys in use plus the packet
/// countdowns that drive periodic key refresh.
#[derive(Debug)]
struct EncryptionState {
    // Key size and passphrase; passed to `commission_next_key` when a new key is prepared.
    key_settings: KeySettings,
    // Refresh schedule: `period()` and `pre_announcement_period()`, both in packets.
    key_refresh: KeyMaterialRefreshSettings,
    // Stream encryption keys used to encrypt outgoing payloads.
    stream_keys: StreamEncryptionKeys,
    // Which stream key (Even/Odd) outgoing packets are currently encrypted with.
    active_sek: DataEncryption,
    // Packets left before the next key is commissioned and announced.
    packets_until_pre_announcement: usize,
    // Packets left before pending key material is (re)sent; 0 means "send on the next packet".
    packets_until_transmit: usize,
    // Packets left before the active key may flip between Even and Odd.
    packets_until_key_switch: usize,
    // Key material announced to the peer and not yet confirmed; `None` once confirmed.
    last_key_material: Option<KeyingMaterialMessage>,
}
77
78impl EncryptionState {
79    fn try_encrypt_packet(&mut self, mut packet: DataPacket) -> Option<(usize, DataPacket)> {
80        // this requires an extra copy here...maybe DataPacket should have a BytesMut in it instead...
81        let mut data = BytesMut::with_capacity(packet.payload.len());
82        data.extend_from_slice(&packet.payload[..]);
83        let bytes = self
84            .stream_keys
85            .encrypt(self.active_sek, packet.seq_number, &mut data)?;
86        packet.encryption = self.active_sek;
87        packet.payload = data.freeze();
88        Some((bytes, packet))
89    }
90
91    fn try_schedule_pre_announcment(&mut self) {
92        if self.packets_until_pre_announcement == 0 {
93            self.packets_until_pre_announcement = self.key_refresh.period();
94            self.packets_until_transmit = 0;
95
96            if self.last_key_material.is_none() {
97                self.last_key_material = self
98                    .stream_keys
99                    .commission_next_key(self.active_sek, &self.key_settings);
100            }
101        }
102    }
103
104    fn try_send_key_material(&mut self) -> Option<KeyingMaterialMessage> {
105        let km = self.last_key_material.as_ref()?;
106        if self.packets_until_transmit == 0 {
107            self.packets_until_transmit =
108                std::cmp::min(self.key_refresh.pre_announcement_period(), 1_000);
109            Some(km.clone())
110        } else {
111            self.packets_until_transmit -= 1;
112            None
113        }
114    }
115
116    fn try_switch_stream_keys(&mut self) {
117        use DataEncryption::*;
118        if self.packets_until_key_switch == 0 {
119            self.packets_until_key_switch = self.key_refresh.period();
120            if self.last_key_material.is_none() {
121                self.active_sek = match self.active_sek {
122                    Even => Odd,
123                    Odd => Even,
124                    None => None,
125                };
126            }
127        }
128    }
129}
130
131impl Encryption {
132    pub fn new(settings: Option<CipherSettings>) -> Self {
133        Self(settings.map(|settings| EncryptionState {
134            key_settings: settings.key_settings,
135            key_refresh: settings.key_refresh.clone(),
136            stream_keys: settings.stream_keys,
137            active_sek: DataEncryption::Even,
138
139            packets_until_pre_announcement: settings.key_refresh.period()
140                - settings.key_refresh.pre_announcement_period(),
141            packets_until_transmit: 0,
142            packets_until_key_switch: settings.key_refresh.period(),
143            last_key_material: None,
144        }))
145    }
146
147    pub fn encrypt(
148        &mut self,
149        packet: DataPacket,
150    ) -> Option<(usize, DataPacket, Option<KeyingMaterialMessage>)> {
151        match &mut self.0 {
152            Some(this) => {
153                let (bytes, packet) = this.try_encrypt_packet(packet)?;
154
155                this.try_schedule_pre_announcment();
156                this.try_switch_stream_keys();
157                let km = this.try_send_key_material();
158
159                this.packets_until_pre_announcement -= 1;
160                this.packets_until_key_switch -= 1;
161
162                Some((bytes, packet, km))
163            }
164            None => Some((0, packet, None)),
165        }
166    }
167
168    pub fn handle_key_refresh_response(
169        &mut self,
170        keying_material: KeyingMaterialMessage,
171    ) -> Result<(), KeyMaterialError> {
172        use KeyMaterialError::*;
173        if let Some(settings) = self.0.as_mut() {
174            let expected_key_material = settings.last_key_material.as_ref().ok_or(NoKeys)?;
175            if keying_material == *expected_key_material {
176                settings.packets_until_transmit = 0;
177                settings.last_key_material = None;
178            } else {
179                return Err(InvalidRefreshResponse(keying_material));
180            }
181        }
182        Ok(())
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    fn key_settings() -> KeySettings {
        KeySettings {
            key_size: KeySize::AES192,
            passphrase: "1234567890".into(),
        }
    }

    fn new_settings() -> CipherSettings {
        CipherSettings::new_random(&key_settings(), &Default::default())
    }

    fn data_packet(encryption: DataEncryption, payload: &str) -> DataPacket {
        DataPacket {
            seq_number: SeqNumber(3),
            message_loc: PacketLocation::ONLY,
            in_order_delivery: false,
            encryption,
            retransmitted: false,
            message_number: MsgNumber(1),
            timestamp: TimeStamp::MIN,
            dest_sockid: SocketId(0),
            payload: bytes::Bytes::copy_from_slice(payload.as_bytes()),
        }
    }

    #[test]
    fn round_trip() {
        let settings = new_settings();
        let original_packet = data_packet(DataEncryption::None, "test round_trip");

        let mut encryption = Encryption::new(Some(settings.clone()));
        let (bytes, encrypted_packet, key_material) =
            encryption.encrypt(original_packet.clone()).unwrap();
        assert_eq!(bytes, original_packet.payload.len());
        assert_ne!(encrypted_packet, original_packet);
        assert_eq!(key_material, None);

        let decryption = Decryption::new(Some(settings));
        let (bytes, decrypted_packet) = decryption.decrypt(encrypted_packet).unwrap();
        assert_eq!(bytes, original_packet.payload.len());
        assert_eq!(decrypted_packet, original_packet);
    }

    #[test]
    fn decryption_failure() {
        use DecryptionError::*;
        let with_keys = |with_keys| {
            if with_keys {
                Decryption::new(Some(new_settings()))
            } else {
                Decryption::new(None)
            }
        };

        let new_packet = |encryption| data_packet(encryption, "test decryption_failure");

        let packet = new_packet(DataEncryption::None);
        assert_eq!(
            with_keys(true).decrypt(packet.clone()),
            Err(UnexpectedUnencryptedPacket(packet))
        );

        let packet = new_packet(DataEncryption::Even);
        assert_eq!(
            with_keys(false).decrypt(packet.clone()),
            Err(UnexpectedEncryptedPacket(packet))
        );

        let packet = new_packet(DataEncryption::Odd);
        assert_eq!(
            with_keys(false).decrypt(packet.clone()),
            Err(UnexpectedEncryptedPacket(packet))
        );

        let packet = new_packet(DataEncryption::None);
        assert_eq!(with_keys(false).decrypt(packet.clone()), Ok((0, packet)));
    }

    #[test]
    fn refresh_key_material() {
        let settings = CipherSettings {
            key_refresh: KeyMaterialRefreshSettings::new(3_000, 1_000).unwrap(),
            ..new_settings()
        };
        let mut encryption = Encryption::new(Some(settings.clone()));
        let mut decryption = Decryption::new(Some(settings.clone()));
        let original_packet = data_packet(DataEncryption::None, "test refresh_key_material");

        let count = settings.key_refresh.period() - settings.key_refresh.pre_announcement_period();
        for i in 0..count {
            let (_, packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
            assert_eq!(km, None);
            assert_eq!(packet.encryption, DataEncryption::Even, "{i:?}");
        }

        let (_, first_packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
        assert_ne!(km, None);
        assert_eq!(first_packet.encryption, DataEncryption::Even);

        let key_material = km.unwrap();
        let response = decryption.refresh_key_material(key_material.clone());
        assert_eq!(response, Ok(Some(key_material.clone())));

        assert_eq!(encryption.handle_key_refresh_response(key_material), Ok(()));

        for i in 0..settings.key_refresh.pre_announcement_period() {
            let (_, packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
            assert_eq!(km, None, "{i:?}");
            assert_eq!(packet.encryption, DataEncryption::Even);
        }

        let (_, second_packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
        assert_eq!(km, None);
        assert_eq!(second_packet.encryption, DataEncryption::Odd);

        let (bytes, decrypted_packet) = decryption.decrypt(first_packet).unwrap();
        assert_eq!(bytes, original_packet.payload.len());
        assert_eq!(decrypted_packet, original_packet);

        let (bytes, decrypted_packet) = decryption.decrypt(second_packet).unwrap();
        assert_eq!(bytes, original_packet.payload.len());
        assert_eq!(decrypted_packet, original_packet);

        let count = settings.key_refresh.period() - settings.key_refresh.pre_announcement_period();
        for _ in 1..count - 1 {
            let (_, packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
            assert_eq!(km, None);
            assert_eq!(packet.encryption, DataEncryption::Odd);
        }

        let (_, third_packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
        assert_ne!(km, None);
        assert_eq!(third_packet.encryption, DataEncryption::Odd);

        let key_material = km.unwrap();
        let response = decryption.refresh_key_material(key_material.clone());
        assert_eq!(response, Ok(Some(key_material)));

        let (bytes, decrypted_packet) = decryption.decrypt(third_packet).unwrap();
        assert_eq!(bytes, original_packet.payload.len());
        assert_eq!(decrypted_packet, original_packet);
    }

    #[test]
    fn retry_refresh_key_material() {
        let settings = CipherSettings {
            key_refresh: KeyMaterialRefreshSettings::new(44_000, 20_000).unwrap(),
            ..new_settings()
        };
        let mut encryption = Encryption::new(Some(settings.clone()));
        let original_packet =
            data_packet(DataEncryption::None, "test retry_refresh_key_material");

        // Without a response the announced key material keeps being re-sent.
        let mut sent = 0;
        let mut last_key_material = None;
        for _ in 0..settings.key_refresh.period() - 10_000 {
            let (_, _, km) = encryption.encrypt(original_packet.clone()).unwrap();
            if let Some(km) = km {
                sent += 1;
                last_key_material = Some(km);
            }
        }
        assert_eq!(sent, 10);

        encryption
            .handle_key_refresh_response(last_key_material.unwrap())
            .unwrap();

        // none received after the response
        let sent = (0..10_000
            + (settings.key_refresh.period() - settings.key_refresh.pre_announcement_period()))
            .filter(|_| {
                let (_, _, km) = encryption.encrypt(original_packet.clone()).unwrap();
                km.is_some()
            })
            .count();
        assert_eq!(sent, 0);
    }

    #[test]
    fn unexpected_key_refresh_response() {
        let key_refresh = KeyMaterialRefreshSettings::new(3_000, 1_000).unwrap();
        let new_encryption = || {
            Encryption::new(Some(CipherSettings {
                key_refresh: key_refresh.clone(),
                ..new_settings()
            }))
        };
        let original_packet =
            data_packet(DataEncryption::None, "test unexpected_key_refresh_response");
        // Encrypts packets until the first key material announcement is produced.
        let first_announcement = |encryption: &mut Encryption| {
            (0..key_refresh.period())
                .find_map(|_| encryption.encrypt(original_packet.clone()).unwrap().2)
                .expect("key material is announced within one refresh period")
        };

        let stray = first_announcement(&mut new_encryption());

        // Encryption disabled: nothing was announced, so responses are ignored.
        assert_eq!(
            Encryption::new(None).handle_key_refresh_response(stray.clone()),
            Ok(())
        );

        // Nothing announced yet.
        let mut encryption = new_encryption();
        assert_eq!(
            encryption.handle_key_refresh_response(stray.clone()),
            Err(KeyMaterialError::NoKeys)
        );

        // Announced, but the response does not match what was sent.
        let announced = first_announcement(&mut encryption);
        assert_ne!(announced, stray);
        assert_eq!(
            encryption.handle_key_refresh_response(stray.clone()),
            Err(KeyMaterialError::InvalidRefreshResponse(stray))
        );

        // The matching response is accepted once; afterwards nothing is outstanding.
        assert_eq!(
            encryption.handle_key_refresh_response(announced.clone()),
            Ok(())
        );
        assert_eq!(
            encryption.handle_key_refresh_response(announced),
            Err(KeyMaterialError::NoKeys)
        );
    }
}