srt_protocol/packet/control/
srt.rs

1use std::{
2    fmt::{self, Display, Formatter},
3    {collections::BTreeMap, convert::TryFrom, time::Duration},
4};
5
6use bitflags::bitflags;
7use bytes::{Buf, BufMut};
8use log::warn;
9
10use crate::{options::SrtVersion, packet::PacketParseError};
11
12/// The SRT-specific control packets
13/// These are `Packet::Custom` types
14#[derive(Clone, Eq, PartialEq)]
15pub enum SrtControlPacket {
16    /// SRT handshake reject
17    /// ID = 0
18    Reject,
19
20    /// SRT handshake request
21    /// ID = 1
22    HandshakeRequest(SrtHandshake),
23
24    /// SRT handshake response
25    /// ID = 2
26    HandshakeResponse(SrtHandshake),
27
28    /// Key manager request
29    /// ID = 3
30    KeyRefreshRequest(KeyingMaterialMessage),
31
32    /// Key manager response
33    /// ID = 4
34    KeyRefreshResponse(KeyingMaterialMessage),
35
36    /// Stream identifier
37    /// ID = 5
38    StreamId(String),
39
40    /// Congestion control type. Often "live" or "file"
41    /// ID = 6
42    Congestion(String),
43
44    /// ID = 7
45    /// Filter seems to be a string of
46    /// comma-separted key-value pairs like:
47    /// a:b,c:d
48    Filter(FilterSpec),
49
50    // ID = 8
51    Group {
52        ty: GroupType,
53        flags: GroupFlags,
54        weight: u16,
55    },
56}
57
/// Packet-filter configuration as an ordered map of `key -> value` pairs.
///
/// On the wire it is the text `k1:v1,k2:v2,...` (see the `Display` impl), with keys
/// emitted in sorted order because the map is a `BTreeMap`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FilterSpec(pub BTreeMap<String, String>);
60
/// Group type carried in the SRT group extension (ID 8).
///
/// Wire values 0-4 map to the named variants. Any other byte is preserved in
/// `Unrecognized`, so it round-trips unchanged through the `From` conversions below.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GroupType {
    /// Wire value 0
    Undefined,
    /// Wire value 1
    Broadcast,
    /// Wire value 2
    MainBackup,
    /// Wire value 3
    Balancing,
    /// Wire value 4
    Multicast,
    /// Any other wire value, kept verbatim
    Unrecognized(u8),
}
70
71bitflags! {
72    #[derive(Clone, Copy, Eq, PartialEq, Debug)]
73    pub struct GroupFlags: u8 {
74        const MSG_SYNC = 1 << 6;
75    }
76}
77
78/// from https://github.com/Haivision/srt/blob/2ef4ef003c2006df1458de6d47fbe3d2338edf69/haicrypt/hcrypt_msg.h#L76-L96
79/// or https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-00#section-3.2.2
80///
81/// HaiCrypt KMmsg (Keying Material Message):
82///
83/// ```ignore,
84///        0                   1                   2                   3
85///        0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
86///       +-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-+
87/// +0x00 |0|Vers |   PT  |             Sign              |    resv   |KF |
88///       +-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-+
89/// +0x04 |                              KEKI                             |
90///       +-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-+
91/// +0x08 |    Cipher     |      Auth     |      SE       |     Resv1     |
92///       +-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-+
93/// +0x0C |             Resv2             |     Slen/4    |     Klen/4    |
94///       +-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-+
95/// +0x10 |                              Salt                             |
96///       |                              ...                              |
97///       +-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-+
98///       |                              Wrap                             |
99///       |                              ...                              |
100///       +-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-+
101/// ```
102///
103#[derive(Clone, Eq, PartialEq)]
104pub struct KeyingMaterialMessage {
105    pub pt: PacketType, // TODO: i think this is always KeyingMaterial....
106    pub key_flags: KeyFlags,
107    pub keki: u32,
108    pub cipher: CipherType,
109    pub auth: Auth,
110    pub salt: Vec<u8>,
111    pub wrapped_keys: Vec<u8>,
112}
113
114impl From<GroupType> for u8 {
115    fn from(from: GroupType) -> u8 {
116        match from {
117            GroupType::Undefined => 0,
118            GroupType::Broadcast => 1,
119            GroupType::MainBackup => 2,
120            GroupType::Balancing => 3,
121            GroupType::Multicast => 4,
122            GroupType::Unrecognized(u) => u,
123        }
124    }
125}
126
127impl From<u8> for GroupType {
128    fn from(from: u8) -> GroupType {
129        match from {
130            0 => GroupType::Undefined,
131            1 => GroupType::Broadcast,
132            2 => GroupType::MainBackup,
133            3 => GroupType::Balancing,
134            4 => GroupType::Multicast,
135            u => GroupType::Unrecognized(u),
136        }
137    }
138}
139
140impl fmt::Debug for KeyingMaterialMessage {
141    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
142        f.debug_struct("KeyingMaterialMessage")
143            .field("pt", &self.pt)
144            .field("key_flags", &self.key_flags)
145            .field("keki", &self.keki)
146            .field("cipher", &self.cipher)
147            .field("auth", &self.auth)
148            .finish()
149    }
150}
151
/// The `Auth` byte of a keying material message; only "no authentication" exists.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Auth {
    /// No message authentication (wire value 0)
    None = 0x0,
}
156
157impl TryFrom<u8> for Auth {
158    type Error = PacketParseError;
159    fn try_from(value: u8) -> Result<Self, Self::Error> {
160        match value {
161            0 => Ok(Auth::None),
162            e => Err(PacketParseError::BadAuth(e)),
163        }
164    }
165}
166
/// The `SE` (stream encapsulation) byte of a keying material message.
///
/// `KeyingMaterialMessage::parse` only accepts `Srt`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StreamEncapsulation {
    /// Wire value 1
    Udp = 0x1,
    /// Wire value 2
    Srt = 0x2,
}
172
173impl TryFrom<u8> for StreamEncapsulation {
174    type Error = PacketParseError;
175    fn try_from(value: u8) -> Result<Self, Self::Error> {
176        Ok(match value {
177            1 => StreamEncapsulation::Udp,
178            2 => StreamEncapsulation::Srt,
179            e => return Err(PacketParseError::BadStreamEncapsulation(e)),
180        })
181    }
182}
183
/// Type of a HaiCrypt message, stored in the low 4 bits of its first byte.
///
/// See hcrypt_msg.h:43 in the reference implementation. Value 7 is reserved there to
/// discriminate MPEG-TS packets (0x47 = sync byte).
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum PacketType {
    /// Media Stream Message (MSmsg)
    MediaStream = 0x1,
    /// Keying Material Message (KMmsg)
    KeyingMaterial = 0x2,
}
191
192bitflags! {
193    #[derive(Clone, Copy, Eq, PartialEq, Debug)]
194    pub struct KeyFlags : u8 {
195        const EVEN = 0b01;
196        const ODD = 0b10;
197    }
198}
199
200impl TryFrom<u8> for PacketType {
201    type Error = PacketParseError;
202    fn try_from(value: u8) -> Result<Self, Self::Error> {
203        match value {
204            1 => Ok(PacketType::MediaStream),
205            2 => Ok(PacketType::KeyingMaterial),
206            err => Err(PacketParseError::BadKeyPacketType(err)),
207        }
208    }
209}
210
/// The `Cipher` byte of a keying material message.
///
/// Values from
/// <https://github.com/Haivision/srt/blob/2ef4ef003c2006df1458de6d47fbe3d2338edf69/haicrypt/hcrypt_msg.h#L121-L124>
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum CipherType {
    /// Wire value 0
    None = 0x0,
    /// Wire value 1
    Ecb = 0x1,
    /// Wire value 2
    Ctr = 0x2,
    /// Wire value 3
    Cbc = 0x3,
}
219
220/// The SRT handshake object
221#[derive(Debug, Clone, Copy, Eq, PartialEq)]
222pub struct SrtHandshake {
223    /// The SRT version
224    /// Serialized just as the u32 that SrtVersion serialized to
225    pub version: SrtVersion,
226
227    /// SRT connection init flags
228    pub flags: SrtShakeFlags,
229
230    /// The peer's TSBPD latency (latency to send at)
231    /// This is serialized as the upper 16 bits of the third 32-bit word
232    /// source: https://github.com/Haivision/srt/blob/4f7f2beb2e1e306111b9b11402049a90cb6d3787/srtcore/core.cpp#L1341-L1353
233    pub send_latency: Duration,
234
235    /// The TSBPD latency (latency to recv at)
236    /// This is serialized as the lower 16 bits of the third 32-bit word
237    /// see csrtcc.cpp:132 in the reference implementation
238    pub recv_latency: Duration,
239}
240
241bitflags! {
242    #[derive(Copy, Clone, Debug, Eq, PartialEq)]
243    pub struct SrtShakeFlags: u32 {
244        /// Timestamp-based Packet delivery real-time data sender
245        const TSBPDSND = 0x1;
246
247        /// Timestamp-based Packet delivery real-time data receiver
248        const TSBPDRCV = 0x2;
249
250        /// HaiCrypt AES-128/192/256-CTR
251        /// also represents if it supports the encryption flags in the data packet
252        const HAICRYPT = 0x4;
253
254        /// Drop real-time data packets too late to be processed in time
255        const TLPKTDROP = 0x8;
256
257        /// Periodic NAK report
258        const NAKREPORT = 0x10;
259
260        /// One bit in payload packet msgno is "retransmitted" flag
261        const REXMITFLG = 0x20;
262
263        /// This entity supports stream ID packets
264        const STREAM = 0x40;
265
266        /// Again not sure... TODO:
267        const PACKET_FILTER = 0x80;
268
269        // currently implemented flags
270        const SUPPORTED = Self::TSBPDSND.bits() | Self::TSBPDRCV.bits() | Self::HAICRYPT.bits() | Self::REXMITFLG.bits();
271    }
272}
273
274fn le_bytes_to_string(le_bytes: &mut impl Buf) -> Result<String, PacketParseError> {
275    if le_bytes.remaining() % 4 != 0 {
276        return Err(PacketParseError::NotEnoughData);
277    }
278
279    let mut str_bytes = Vec::with_capacity(le_bytes.remaining());
280
281    while le_bytes.remaining() > 4 {
282        str_bytes.extend(le_bytes.get_u32_le().to_be_bytes());
283    }
284
285    // make sure to skip padding bytes if any for the last word
286    match le_bytes.get_u32_le().to_be_bytes() {
287        [a, 0, 0, 0] => str_bytes.push(a),
288        [a, b, 0, 0] => str_bytes.extend([a, b]),
289        [a, b, c, 0] => str_bytes.extend([a, b, c]),
290        [a, b, c, d] => str_bytes.extend([a, b, c, d]),
291    }
292
293    String::from_utf8(str_bytes).map_err(|e| PacketParseError::StreamTypeNotUtf8(e.utf8_error()))
294}
295
296fn string_to_le_bytes(str: &str, into: &mut impl BufMut) {
297    let mut chunks = str.as_bytes().chunks_exact(4);
298
299    while let Some(&[a, b, c, d]) = chunks.next() {
300        into.put(&[d, c, b, a][..]);
301    }
302
303    // add padding bytes for the final word if needed
304    match *chunks.remainder() {
305        [a, b, c] => into.put(&[0, c, b, a][..]),
306        [a, b] => into.put(&[0, 0, b, a][..]),
307        [a] => into.put(&[0, 0, 0, a][..]),
308        [] => {} // exact multiple of 4
309        _ => unreachable!(),
310    }
311}
312
313impl Display for FilterSpec {
314    fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
315        for (i, (k, v)) in self.0.iter().enumerate() {
316            write!(f, "{k}:{v}")?;
317            if i != self.0.len() - 1 {
318                write!(f, ",")?;
319            }
320        }
321        Ok(())
322    }
323}
324
325impl SrtControlPacket {
326    pub fn parse<T: Buf>(
327        packet_type: u16,
328        buf: &mut T,
329    ) -> Result<SrtControlPacket, PacketParseError> {
330        use self::SrtControlPacket::*;
331
332        match packet_type {
333            0 => Ok(Reject),
334            1 => Ok(HandshakeRequest(SrtHandshake::parse(buf)?)),
335            2 => Ok(HandshakeResponse(SrtHandshake::parse(buf)?)),
336            3 => Ok(KeyRefreshRequest(KeyingMaterialMessage::parse(buf)?)),
337            4 => Ok(KeyRefreshResponse(KeyingMaterialMessage::parse(buf)?)),
338            5 => {
339                // the stream id string is stored as 32-bit little endian words
340                // https://tools.ietf.org/html/draft-sharabayko-mops-srt-01#section-3.2.1.3
341                le_bytes_to_string(buf).map(StreamId)
342            }
343            6 => le_bytes_to_string(buf).map(Congestion),
344            // Filter
345            7 => {
346                let filter_str = le_bytes_to_string(buf)?;
347                Ok(Filter(FilterSpec(
348                    filter_str
349                        .split(',')
350                        .map(|kv| {
351                            let mut colon_split_iter = kv.split(':');
352                            let k = colon_split_iter
353                                .next()
354                                .ok_or_else(|| PacketParseError::BadFilter(filter_str.clone()))?;
355                            let v = colon_split_iter
356                                .next()
357                                .ok_or_else(|| PacketParseError::BadFilter(filter_str.clone()))?;
358                            // only one colon
359                            if colon_split_iter.next().is_some() {
360                                return Err(PacketParseError::BadFilter(filter_str.clone()));
361                            }
362                            Ok((k.to_string(), v.to_string()))
363                        })
364                        .collect::<Result<_, _>>()?,
365                )))
366            }
367            8 => {
368                let ty = buf.get_u8().into();
369                let flags = GroupFlags::from_bits_truncate(buf.get_u8());
370                let weight = buf.get_u16_le();
371                Ok(Group { ty, flags, weight })
372            }
373            _ => Err(PacketParseError::UnsupportedSrtExtensionType(packet_type)),
374        }
375    }
376
377    /// Get the value to fill the reserved area with
378    pub fn type_id(&self) -> u16 {
379        use self::SrtControlPacket::*;
380
381        match self {
382            Reject => 0,
383            HandshakeRequest(_) => 1,
384            HandshakeResponse(_) => 2,
385            KeyRefreshRequest(_) => 3,
386            KeyRefreshResponse(_) => 4,
387            StreamId(_) => 5,
388            Congestion(_) => 6,
389            Filter(_) => 7,
390            Group { .. } => 8,
391        }
392    }
393    pub fn serialize<T: BufMut>(&self, into: &mut T) {
394        use self::SrtControlPacket::*;
395
396        match self {
397            HandshakeRequest(s) | HandshakeResponse(s) => {
398                s.serialize(into);
399            }
400            KeyRefreshRequest(k) | KeyRefreshResponse(k) => {
401                k.serialize(into);
402            }
403            Filter(filter) => {
404                string_to_le_bytes(&format!("{filter}"), into);
405            }
406            Group { ty, flags, weight } => {
407                into.put_u8((*ty).into());
408                into.put_u8(flags.bits());
409                into.put_u16_le(*weight);
410            }
411            Reject => {}
412            StreamId(str) | Congestion(str) => {
413                // the stream id string and congestion string is stored as 32-bit little endian words
414                // https://tools.ietf.org/html/draft-sharabayko-mops-srt-01#section-3.2.1.3
415                string_to_le_bytes(str, into);
416            }
417        }
418    }
419    // size in 32-bit words
420    pub fn size_words(&self) -> u16 {
421        use self::SrtControlPacket::*;
422
423        match self {
424            // 3 32-bit words, version, flags, latency
425            HandshakeRequest(_) | HandshakeResponse(_) => 3,
426            // 4 32-bit words + salt + key + wrap [2]
427            KeyRefreshRequest(ref k) | KeyRefreshResponse(ref k) => {
428                4 + k.salt.len() as u16 / 4 + k.wrapped_keys.len() as u16 / 4
429            }
430            Congestion(str) | StreamId(str) => ((str.len() + 3) / 4) as u16, // round up to nearest multiple of 4
431            // 1 32-bit word packed with type, flags, and weight
432            Group { .. } => 1,
433            Filter(filter) => ((format!("{filter}").len() + 3) / 4) as u16, // TODO: not optimial performace, but probably okay
434            _ => unimplemented!("{:?}", self),
435        }
436    }
437}
438
439impl SrtHandshake {
440    pub fn parse<T: Buf>(buf: &mut T) -> Result<SrtHandshake, PacketParseError> {
441        if buf.remaining() < 12 {
442            return Err(PacketParseError::NotEnoughData);
443        }
444
445        let version = SrtVersion::parse(buf.get_u32());
446
447        let shake_flags = buf.get_u32();
448        let flags = match SrtShakeFlags::from_bits(shake_flags) {
449            Some(i) => i,
450            None => {
451                warn!("Unrecognized SRT flags: 0b{:b}", shake_flags);
452                SrtShakeFlags::from_bits_truncate(shake_flags)
453            }
454        };
455        let peer_latency = buf.get_u16();
456        let latency = buf.get_u16();
457
458        Ok(SrtHandshake {
459            version,
460            flags,
461            send_latency: Duration::from_millis(u64::from(peer_latency)),
462            recv_latency: Duration::from_millis(u64::from(latency)),
463        })
464    }
465
466    pub fn serialize<T: BufMut>(&self, into: &mut T) {
467        into.put_u32(self.version.to_u32());
468        into.put_u32(self.flags.bits());
469        // upper 16 bits are peer latency
470        into.put_u16(self.send_latency.as_millis() as u16); // TODO: handle overflow
471
472        // lower 16 is latency
473        into.put_u16(self.recv_latency.as_millis() as u16); // TODO: handle overflow
474    }
475}
476
477impl KeyingMaterialMessage {
478    // from hcrypt_msg.h:39
479    // also const traits aren't a thing yet, so u16::from can't be used
480    const SIGN: u16 =
481        ((b'H' - b'@') as u16) << 10 | ((b'A' - b'@') as u16) << 5 | (b'I' - b'@') as u16;
482
483    pub fn parse(buf: &mut impl Buf) -> Result<KeyingMaterialMessage, PacketParseError> {
484        // first 32-bit word:
485        //
486        //  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
487        // +-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-+
488        // |0|Vers |   PT  |             Sign              |    resv   |KF |
489
490        // make sure there is enough data left in the buffer to at least get to the key flags and length, which tells us how long the packet will be
491        // that's 4x32bit words
492        if buf.remaining() < 4 * 4 {
493            return Err(PacketParseError::NotEnoughData);
494        }
495
496        let vers_pt = buf.get_u8();
497
498        // make sure the first bit is zero
499        if (vers_pt & 0b1000_0000) != 0 {
500            return Err(PacketParseError::BadSrtExtensionMessage);
501        }
502
503        // upper 4 bits are version
504        let version = vers_pt >> 4;
505
506        if version != 1 {
507            return Err(PacketParseError::BadSrtExtensionMessage);
508        }
509
510        // lower 4 bits are pt
511        let pt = PacketType::try_from(vers_pt & 0b0000_1111)?;
512
513        // next 16 bis are sign
514        let sign = buf.get_u16();
515
516        if sign != Self::SIGN {
517            return Err(PacketParseError::BadKeySign(sign));
518        }
519
520        // next 6 bits is reserved, then two bits of KF
521        let key_flags = KeyFlags::from_bits_truncate(buf.get_u8() & 0b0000_0011);
522
523        // second 32-bit word: keki
524        let keki = buf.get_u32();
525
526        // third 32-bit word:
527        //
528        //  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
529        // +-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-+
530        // |    Cipher     |      Auth     |      SE       |     Resv1     |
531
532        let cipher = CipherType::try_from(buf.get_u8())?;
533        let auth = Auth::try_from(buf.get_u8())?;
534        let se = StreamEncapsulation::try_from(buf.get_u8())?;
535        if se != StreamEncapsulation::Srt {
536            return Err(PacketParseError::StreamEncapsulationNotSrt);
537        }
538
539        let _resv1 = buf.get_u8();
540
541        // fourth 32-bit word:
542        //
543        //  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
544        // +-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-+
545        // |             Resv2             |     Slen/4    |     Klen/4    |
546
547        let _resv2 = buf.get_u16();
548        let salt_len = usize::from(buf.get_u8()) * 4;
549        let key_len = usize::from(buf.get_u8()) * 4;
550
551        // acceptable key lengths are 16, 24, and 32
552        match key_len {
553            // OK
554            16 | 24 | 32 => {}
555            // not
556            e => return Err(PacketParseError::BadCryptoLength(e as u32)),
557        }
558
559        // get the size of the packet to make sure that there is enough space
560
561        // salt + keys (there's a 1 for each in key flags, it's already been anded with 0b11 so max is 2), wrap data is 8 long
562        if buf.remaining() < salt_len + key_len * (key_flags.bits().count_ones() as usize) + 8 {
563            return Err(PacketParseError::NotEnoughData);
564        }
565
566        // the reference implmentation converts the whole thing to network order (bit endian) (in 32-bit words)
567        // so we need to make sure to do the same. Source:
568        // https://github.com/Haivision/srt/blob/2ef4ef003c2006df1458de6d47fbe3d2338edf69/srtcore/crypto.cpp#L115
569
570        // after this, is the salt
571        let mut salt = vec![];
572        for _ in 0..salt_len / 4 {
573            salt.extend_from_slice(&buf.get_u32().to_be_bytes()[..]);
574        }
575
576        // then key[s]
577        let mut wrapped_keys = vec![];
578
579        for _ in 0..(key_len * key_flags.bits().count_ones() as usize + 8) / 4 {
580            wrapped_keys.extend_from_slice(&buf.get_u32().to_be_bytes()[..]);
581        }
582
583        Ok(KeyingMaterialMessage {
584            pt,
585            key_flags,
586            keki,
587            cipher,
588            auth,
589            salt,
590            wrapped_keys,
591        })
592    }
593
594    fn serialize<T: BufMut>(&self, into: &mut T) {
595        // first 32-bit word:
596        //
597        //  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
598        // +-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-+
599        // |0|Vers |   PT  |             Sign              |    resv   |KF |
600
601        // version is 1
602        into.put_u8(1 << 4 | self.pt as u8);
603
604        into.put_u16(Self::SIGN);
605
606        // rightmost bit of KF is even, other is odd
607        into.put_u8(self.key_flags.bits());
608
609        // second 32-bit word: keki
610        into.put_u32(self.keki);
611
612        // third 32-bit word:
613        //
614        //  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
615        // +-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-+
616        // |    Cipher     |      Auth     |      SE       |     Resv1     |
617        into.put_u8(self.cipher as u8);
618        into.put_u8(self.auth as u8);
619        into.put_u8(StreamEncapsulation::Srt as u8);
620        into.put_u8(0); // resv1
621
622        // fourth 32-bit word:
623        //
624        //  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
625        // +-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-|-+-+-+-+-+-+-+-+
626        // |             Resv2             |     Slen/4    |     Klen/4    |
627        into.put_u16(0); // resv2
628        into.put_u8((self.salt.len() / 4) as u8);
629
630        // this unwrap is okay because we already panic above if both are None
631        let key_len = (self.wrapped_keys.len() - 8) / self.key_flags.bits().count_ones() as usize;
632        into.put_u8((key_len / 4) as u8);
633
634        // put the salt then key[s]
635        into.put(&self.salt[..]);
636
637        // the reference implmentation converts the whole thing to network order (big endian) (in 32-bit words)
638        // so we need to make sure to do the same. Source:
639        // https://github.com/Haivision/srt/blob/2ef4ef003c2006df1458de6d47fbe3d2338edf69/srtcore/crypto.cpp#L115
640
641        for num in self.wrapped_keys[..].chunks(4) {
642            into.put_u32(u32::from_be_bytes([num[0], num[1], num[2], num[3]]));
643        }
644    }
645}
646
647impl fmt::Debug for SrtControlPacket {
648    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
649        match self {
650            SrtControlPacket::Reject => write!(f, "reject"),
651            SrtControlPacket::HandshakeRequest(req) => write!(f, "hsreq={req:?}"),
652            SrtControlPacket::HandshakeResponse(resp) => write!(f, "hsresp={resp:?}"),
653            SrtControlPacket::KeyRefreshRequest(req) => write!(f, "kmreq={req:?}"),
654            SrtControlPacket::KeyRefreshResponse(resp) => write!(f, "kmresp={resp:?}"),
655            SrtControlPacket::StreamId(sid) => write!(f, "streamid={sid}"),
656            SrtControlPacket::Congestion(ctype) => write!(f, "congestion={ctype}"),
657            SrtControlPacket::Filter(filter) => write!(f, "filter={filter:?}"),
658            SrtControlPacket::Group { ty, flags, weight } => {
659                write!(f, "group=({ty:?}, {flags:?}, {weight:?})")
660            }
661        }
662    }
663}
664
665impl TryFrom<u8> for CipherType {
666    type Error = PacketParseError;
667    fn try_from(from: u8) -> Result<CipherType, PacketParseError> {
668        match from {
669            0 => Ok(CipherType::None),
670            1 => Ok(CipherType::Ecb),
671            2 => Ok(CipherType::Ctr),
672            3 => Ok(CipherType::Cbc),
673            e => Err(PacketParseError::BadCipherKind(e)),
674        }
675    }
676}
677
#[cfg(test)]
mod tests {
    use super::{
        FilterSpec, GroupFlags, GroupType, KeyingMaterialMessage, SrtControlPacket, SrtHandshake,
        SrtShakeFlags,
    };

    use crate::{options::*, packet::*};

    use std::{collections::BTreeMap, io::Cursor, time::Duration};

    /// Keying material message shared by the KM tests (one 16-byte even key).
    fn sample_km() -> KeyingMaterialMessage {
        let salt = b"\x00\x00\x00\x00\x00\x00\x00\x00\x85\x2c\x3c\xcd\x02\x65\x1a\x22";
        let wrapped = b"U\x06\xe9\xfd\xdfd\xf1'nr\xf4\xe9f\x81#(\xb7\xb5D\x19{\x9b\xcdx";

        KeyingMaterialMessage {
            pt: PacketType::KeyingMaterial,
            key_flags: KeyFlags::EVEN,
            keki: 0,
            cipher: CipherType::Ctr,
            auth: Auth::None,
            salt: salt[..].into(),
            wrapped_keys: wrapped[..].into(),
        }
    }

    /// Serializes `pack`, checks the length matches `size_words`, and parses it back.
    fn round_trip(pack: &SrtControlPacket) -> SrtControlPacket {
        let mut buf = Vec::new();
        pack.serialize(&mut buf);
        assert_eq!(buf.len(), usize::from(pack.size_words()) * 4);
        SrtControlPacket::parse(pack.type_id(), &mut &buf[..]).unwrap()
    }

    #[test]
    fn deser_ser_shake() {
        let handshake = Packet::Control(ControlPacket {
            timestamp: TimeStamp::from_micros(123_141),
            dest_sockid: SocketId(123),
            control_type: ControlTypes::Srt(SrtControlPacket::HandshakeRequest(SrtHandshake {
                version: SrtVersion::CURRENT,
                flags: SrtShakeFlags::empty(),
                send_latency: Duration::from_millis(4000),
                recv_latency: Duration::from_millis(3000),
            })),
        });

        let mut buf = Vec::new();
        handshake.serialize(&mut buf);

        let deserialized = Packet::parse(&mut Cursor::new(buf), false).unwrap();

        assert_eq!(handshake, deserialized);
    }

    #[test]
    fn ser_deser_sid() {
        let sid = Packet::Control(ControlPacket {
            timestamp: TimeStamp::from_micros(123),
            dest_sockid: SocketId(1234),
            control_type: ControlTypes::Srt(SrtControlPacket::StreamId("Hellohelloheloo".into())),
        });

        let mut buf = Vec::new();
        sid.serialize(&mut buf);

        let deser = Packet::parse(&mut Cursor::new(buf), false).unwrap();

        assert_eq!(sid, deser);
    }

    #[test]
    fn ser_deser_strings_every_padding_length() {
        // lengths 1..=8 hit 3, 2, 1 and 0 padding bytes in the final 32-bit word
        let text = "abcdefgh";
        for len in 1..=text.len() {
            let s = text[..len].to_string();
            for pack in [
                SrtControlPacket::StreamId(s.clone()),
                SrtControlPacket::Congestion(s),
            ] {
                assert_eq!(pack, round_trip(&pack));
            }
        }
    }

    #[test]
    fn ser_deser_group() {
        let group = SrtControlPacket::Group {
            ty: GroupType::MainBackup,
            flags: GroupFlags::MSG_SYNC,
            weight: 300,
        };
        assert_eq!(group, round_trip(&group));
    }

    #[test]
    fn ser_deser_filter() {
        let mut spec = BTreeMap::new();
        spec.insert("cols".to_string(), "10".to_string());
        spec.insert("rows".to_string(), "5".to_string());
        let filter = SrtControlPacket::Filter(FilterSpec(spec));
        assert_eq!(filter, round_trip(&filter));
    }

    #[test]
    fn ser_deser_key_message() {
        let km = sample_km();

        let mut buf = Vec::new();
        km.serialize(&mut buf);

        let packet = SrtControlPacket::KeyRefreshRequest(km.clone());
        assert_eq!(buf.len(), usize::from(packet.size_words()) * 4);

        let deser = KeyingMaterialMessage::parse(&mut &buf[..]).unwrap();
        assert_eq!(km, deser);
    }

    #[test]
    fn srt_key_message_debug() {
        let km = sample_km();

        assert_eq!(format!("{km:?}"), "KeyingMaterialMessage { pt: KeyingMaterial, key_flags: KeyFlags(EVEN), keki: 0, cipher: Ctr, auth: None }")
    }
}