1use std::{
2 fmt::{self, Display, Formatter},
3 {collections::BTreeMap, convert::TryFrom, time::Duration},
4};
5
6use bitflags::bitflags;
7use bytes::{Buf, BufMut};
8use log::warn;
9
10use crate::{options::SrtVersion, packet::PacketParseError};
11
12#[derive(Clone, Eq, PartialEq)]
15pub enum SrtControlPacket {
16 Reject,
19
20 HandshakeRequest(SrtHandshake),
23
24 HandshakeResponse(SrtHandshake),
27
28 KeyRefreshRequest(KeyingMaterialMessage),
31
32 KeyRefreshResponse(KeyingMaterialMessage),
35
36 StreamId(String),
39
40 Congestion(String),
43
44 Filter(FilterSpec),
49
50 Group {
52 ty: GroupType,
53 flags: GroupFlags,
54 weight: u16,
55 },
56}
57
/// Packet-filter configuration: `key -> value` settings kept in key order.
///
/// Serialized as the text `k1:v1,k2:v2,...`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FilterSpec(pub BTreeMap<String, String>);
60
/// Kind of socket group advertised in a `Group` extension.
///
/// Wire bytes 0-4 map onto the named variants; any other byte is kept in
/// [`GroupType::Unrecognized`] so it can be written back unchanged.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GroupType {
    /// Wire value 0.
    Undefined,
    /// Wire value 1.
    Broadcast,
    /// Wire value 2.
    MainBackup,
    /// Wire value 3.
    Balancing,
    /// Wire value 4.
    Multicast,
    /// Any other wire value, preserved verbatim.
    Unrecognized(u8),
}
70
71bitflags! {
72 #[derive(Clone, Copy, Eq, PartialEq, Debug)]
73 pub struct GroupFlags: u8 {
74 const MSG_SYNC = 1 << 6;
75 }
76}
77
78#[derive(Clone, Eq, PartialEq)]
104pub struct KeyingMaterialMessage {
105 pub pt: PacketType, pub key_flags: KeyFlags,
107 pub keki: u32,
108 pub cipher: CipherType,
109 pub auth: Auth,
110 pub salt: Vec<u8>,
111 pub wrapped_keys: Vec<u8>,
112}
113
114impl From<GroupType> for u8 {
115 fn from(from: GroupType) -> u8 {
116 match from {
117 GroupType::Undefined => 0,
118 GroupType::Broadcast => 1,
119 GroupType::MainBackup => 2,
120 GroupType::Balancing => 3,
121 GroupType::Multicast => 4,
122 GroupType::Unrecognized(u) => u,
123 }
124 }
125}
126
127impl From<u8> for GroupType {
128 fn from(from: u8) -> GroupType {
129 match from {
130 0 => GroupType::Undefined,
131 1 => GroupType::Broadcast,
132 2 => GroupType::MainBackup,
133 3 => GroupType::Balancing,
134 4 => GroupType::Multicast,
135 u => GroupType::Unrecognized(u),
136 }
137 }
138}
139
140impl fmt::Debug for KeyingMaterialMessage {
141 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
142 f.debug_struct("KeyingMaterialMessage")
143 .field("pt", &self.pt)
144 .field("key_flags", &self.key_flags)
145 .field("keki", &self.keki)
146 .field("cipher", &self.cipher)
147 .field("auth", &self.auth)
148 .finish()
149 }
150}
151
/// Authentication byte of a keying-material message.
///
/// Only wire value 0 has a variant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Auth {
    /// No authentication (wire value 0).
    None = 0,
}
156
157impl TryFrom<u8> for Auth {
158 type Error = PacketParseError;
159 fn try_from(value: u8) -> Result<Self, Self::Error> {
160 match value {
161 0 => Ok(Auth::None),
162 e => Err(PacketParseError::BadAuth(e)),
163 }
164 }
165}
166
/// Stream-encapsulation byte of a keying-material message.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamEncapsulation {
    /// Wire value 1.
    Udp = 1,
    /// Wire value 2; the only value accepted when parsing keying material.
    Srt = 2,
}
172
173impl TryFrom<u8> for StreamEncapsulation {
174 type Error = PacketParseError;
175 fn try_from(value: u8) -> Result<Self, Self::Error> {
176 Ok(match value {
177 1 => StreamEncapsulation::Udp,
178 2 => StreamEncapsulation::Srt,
179 e => return Err(PacketParseError::BadStreamEncapsulation(e)),
180 })
181 }
182}
183
/// Packet-type nibble of a keying-material message.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PacketType {
    /// Wire value 1.
    MediaStream = 1,
    /// Wire value 2.
    KeyingMaterial = 2,
}
191
192bitflags! {
193 #[derive(Clone, Copy, Eq, PartialEq, Debug)]
194 pub struct KeyFlags : u8 {
195 const EVEN = 0b01;
196 const ODD = 0b10;
197 }
198}
199
200impl TryFrom<u8> for PacketType {
201 type Error = PacketParseError;
202 fn try_from(value: u8) -> Result<Self, Self::Error> {
203 match value {
204 1 => Ok(PacketType::MediaStream),
205 2 => Ok(PacketType::KeyingMaterial),
206 err => Err(PacketParseError::BadKeyPacketType(err)),
207 }
208 }
209}
210
/// Cipher byte of a keying-material message.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CipherType {
    /// Wire value 0: no cipher.
    None = 0,
    /// Wire value 1.
    Ecb = 1,
    /// Wire value 2.
    Ctr = 2,
    /// Wire value 3.
    Cbc = 3,
}
219
220#[derive(Debug, Clone, Copy, Eq, PartialEq)]
222pub struct SrtHandshake {
223 pub version: SrtVersion,
226
227 pub flags: SrtShakeFlags,
229
230 pub send_latency: Duration,
234
235 pub recv_latency: Duration,
239}
240
241bitflags! {
242 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
243 pub struct SrtShakeFlags: u32 {
244 const TSBPDSND = 0x1;
246
247 const TSBPDRCV = 0x2;
249
250 const HAICRYPT = 0x4;
253
254 const TLPKTDROP = 0x8;
256
257 const NAKREPORT = 0x10;
259
260 const REXMITFLG = 0x20;
262
263 const STREAM = 0x40;
265
266 const PACKET_FILTER = 0x80;
268
269 const SUPPORTED = Self::TSBPDSND.bits() | Self::TSBPDRCV.bits() | Self::HAICRYPT.bits() | Self::REXMITFLG.bits();
271 }
272}
273
274fn le_bytes_to_string(le_bytes: &mut impl Buf) -> Result<String, PacketParseError> {
275 if le_bytes.remaining() % 4 != 0 {
276 return Err(PacketParseError::NotEnoughData);
277 }
278
279 let mut str_bytes = Vec::with_capacity(le_bytes.remaining());
280
281 while le_bytes.remaining() > 4 {
282 str_bytes.extend(le_bytes.get_u32_le().to_be_bytes());
283 }
284
285 match le_bytes.get_u32_le().to_be_bytes() {
287 [a, 0, 0, 0] => str_bytes.push(a),
288 [a, b, 0, 0] => str_bytes.extend([a, b]),
289 [a, b, c, 0] => str_bytes.extend([a, b, c]),
290 [a, b, c, d] => str_bytes.extend([a, b, c, d]),
291 }
292
293 String::from_utf8(str_bytes).map_err(|e| PacketParseError::StreamTypeNotUtf8(e.utf8_error()))
294}
295
296fn string_to_le_bytes(str: &str, into: &mut impl BufMut) {
297 let mut chunks = str.as_bytes().chunks_exact(4);
298
299 while let Some(&[a, b, c, d]) = chunks.next() {
300 into.put(&[d, c, b, a][..]);
301 }
302
303 match *chunks.remainder() {
305 [a, b, c] => into.put(&[0, c, b, a][..]),
306 [a, b] => into.put(&[0, 0, b, a][..]),
307 [a] => into.put(&[0, 0, 0, a][..]),
308 [] => {} _ => unreachable!(),
310 }
311}
312
313impl Display for FilterSpec {
314 fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
315 for (i, (k, v)) in self.0.iter().enumerate() {
316 write!(f, "{k}:{v}")?;
317 if i != self.0.len() - 1 {
318 write!(f, ",")?;
319 }
320 }
321 Ok(())
322 }
323}
324
325impl SrtControlPacket {
326 pub fn parse<T: Buf>(
327 packet_type: u16,
328 buf: &mut T,
329 ) -> Result<SrtControlPacket, PacketParseError> {
330 use self::SrtControlPacket::*;
331
332 match packet_type {
333 0 => Ok(Reject),
334 1 => Ok(HandshakeRequest(SrtHandshake::parse(buf)?)),
335 2 => Ok(HandshakeResponse(SrtHandshake::parse(buf)?)),
336 3 => Ok(KeyRefreshRequest(KeyingMaterialMessage::parse(buf)?)),
337 4 => Ok(KeyRefreshResponse(KeyingMaterialMessage::parse(buf)?)),
338 5 => {
339 le_bytes_to_string(buf).map(StreamId)
342 }
343 6 => le_bytes_to_string(buf).map(Congestion),
344 7 => {
346 let filter_str = le_bytes_to_string(buf)?;
347 Ok(Filter(FilterSpec(
348 filter_str
349 .split(',')
350 .map(|kv| {
351 let mut colon_split_iter = kv.split(':');
352 let k = colon_split_iter
353 .next()
354 .ok_or_else(|| PacketParseError::BadFilter(filter_str.clone()))?;
355 let v = colon_split_iter
356 .next()
357 .ok_or_else(|| PacketParseError::BadFilter(filter_str.clone()))?;
358 if colon_split_iter.next().is_some() {
360 return Err(PacketParseError::BadFilter(filter_str.clone()));
361 }
362 Ok((k.to_string(), v.to_string()))
363 })
364 .collect::<Result<_, _>>()?,
365 )))
366 }
367 8 => {
368 let ty = buf.get_u8().into();
369 let flags = GroupFlags::from_bits_truncate(buf.get_u8());
370 let weight = buf.get_u16_le();
371 Ok(Group { ty, flags, weight })
372 }
373 _ => Err(PacketParseError::UnsupportedSrtExtensionType(packet_type)),
374 }
375 }
376
377 pub fn type_id(&self) -> u16 {
379 use self::SrtControlPacket::*;
380
381 match self {
382 Reject => 0,
383 HandshakeRequest(_) => 1,
384 HandshakeResponse(_) => 2,
385 KeyRefreshRequest(_) => 3,
386 KeyRefreshResponse(_) => 4,
387 StreamId(_) => 5,
388 Congestion(_) => 6,
389 Filter(_) => 7,
390 Group { .. } => 8,
391 }
392 }
393 pub fn serialize<T: BufMut>(&self, into: &mut T) {
394 use self::SrtControlPacket::*;
395
396 match self {
397 HandshakeRequest(s) | HandshakeResponse(s) => {
398 s.serialize(into);
399 }
400 KeyRefreshRequest(k) | KeyRefreshResponse(k) => {
401 k.serialize(into);
402 }
403 Filter(filter) => {
404 string_to_le_bytes(&format!("{filter}"), into);
405 }
406 Group { ty, flags, weight } => {
407 into.put_u8((*ty).into());
408 into.put_u8(flags.bits());
409 into.put_u16_le(*weight);
410 }
411 Reject => {}
412 StreamId(str) | Congestion(str) => {
413 string_to_le_bytes(str, into);
416 }
417 }
418 }
419 pub fn size_words(&self) -> u16 {
421 use self::SrtControlPacket::*;
422
423 match self {
424 HandshakeRequest(_) | HandshakeResponse(_) => 3,
426 KeyRefreshRequest(ref k) | KeyRefreshResponse(ref k) => {
428 4 + k.salt.len() as u16 / 4 + k.wrapped_keys.len() as u16 / 4
429 }
430 Congestion(str) | StreamId(str) => ((str.len() + 3) / 4) as u16, Group { .. } => 1,
433 Filter(filter) => ((format!("{filter}").len() + 3) / 4) as u16, _ => unimplemented!("{:?}", self),
435 }
436 }
437}
438
439impl SrtHandshake {
440 pub fn parse<T: Buf>(buf: &mut T) -> Result<SrtHandshake, PacketParseError> {
441 if buf.remaining() < 12 {
442 return Err(PacketParseError::NotEnoughData);
443 }
444
445 let version = SrtVersion::parse(buf.get_u32());
446
447 let shake_flags = buf.get_u32();
448 let flags = match SrtShakeFlags::from_bits(shake_flags) {
449 Some(i) => i,
450 None => {
451 warn!("Unrecognized SRT flags: 0b{:b}", shake_flags);
452 SrtShakeFlags::from_bits_truncate(shake_flags)
453 }
454 };
455 let peer_latency = buf.get_u16();
456 let latency = buf.get_u16();
457
458 Ok(SrtHandshake {
459 version,
460 flags,
461 send_latency: Duration::from_millis(u64::from(peer_latency)),
462 recv_latency: Duration::from_millis(u64::from(latency)),
463 })
464 }
465
466 pub fn serialize<T: BufMut>(&self, into: &mut T) {
467 into.put_u32(self.version.to_u32());
468 into.put_u32(self.flags.bits());
469 into.put_u16(self.send_latency.as_millis() as u16); into.put_u16(self.recv_latency.as_millis() as u16); }
475}
476
477impl KeyingMaterialMessage {
478 const SIGN: u16 =
481 ((b'H' - b'@') as u16) << 10 | ((b'A' - b'@') as u16) << 5 | (b'I' - b'@') as u16;
482
483 pub fn parse(buf: &mut impl Buf) -> Result<KeyingMaterialMessage, PacketParseError> {
484 if buf.remaining() < 4 * 4 {
493 return Err(PacketParseError::NotEnoughData);
494 }
495
496 let vers_pt = buf.get_u8();
497
498 if (vers_pt & 0b1000_0000) != 0 {
500 return Err(PacketParseError::BadSrtExtensionMessage);
501 }
502
503 let version = vers_pt >> 4;
505
506 if version != 1 {
507 return Err(PacketParseError::BadSrtExtensionMessage);
508 }
509
510 let pt = PacketType::try_from(vers_pt & 0b0000_1111)?;
512
513 let sign = buf.get_u16();
515
516 if sign != Self::SIGN {
517 return Err(PacketParseError::BadKeySign(sign));
518 }
519
520 let key_flags = KeyFlags::from_bits_truncate(buf.get_u8() & 0b0000_0011);
522
523 let keki = buf.get_u32();
525
526 let cipher = CipherType::try_from(buf.get_u8())?;
533 let auth = Auth::try_from(buf.get_u8())?;
534 let se = StreamEncapsulation::try_from(buf.get_u8())?;
535 if se != StreamEncapsulation::Srt {
536 return Err(PacketParseError::StreamEncapsulationNotSrt);
537 }
538
539 let _resv1 = buf.get_u8();
540
541 let _resv2 = buf.get_u16();
548 let salt_len = usize::from(buf.get_u8()) * 4;
549 let key_len = usize::from(buf.get_u8()) * 4;
550
551 match key_len {
553 16 | 24 | 32 => {}
555 e => return Err(PacketParseError::BadCryptoLength(e as u32)),
557 }
558
559 if buf.remaining() < salt_len + key_len * (key_flags.bits().count_ones() as usize) + 8 {
563 return Err(PacketParseError::NotEnoughData);
564 }
565
566 let mut salt = vec![];
572 for _ in 0..salt_len / 4 {
573 salt.extend_from_slice(&buf.get_u32().to_be_bytes()[..]);
574 }
575
576 let mut wrapped_keys = vec![];
578
579 for _ in 0..(key_len * key_flags.bits().count_ones() as usize + 8) / 4 {
580 wrapped_keys.extend_from_slice(&buf.get_u32().to_be_bytes()[..]);
581 }
582
583 Ok(KeyingMaterialMessage {
584 pt,
585 key_flags,
586 keki,
587 cipher,
588 auth,
589 salt,
590 wrapped_keys,
591 })
592 }
593
594 fn serialize<T: BufMut>(&self, into: &mut T) {
595 into.put_u8(1 << 4 | self.pt as u8);
603
604 into.put_u16(Self::SIGN);
605
606 into.put_u8(self.key_flags.bits());
608
609 into.put_u32(self.keki);
611
612 into.put_u8(self.cipher as u8);
618 into.put_u8(self.auth as u8);
619 into.put_u8(StreamEncapsulation::Srt as u8);
620 into.put_u8(0); into.put_u16(0); into.put_u8((self.salt.len() / 4) as u8);
629
630 let key_len = (self.wrapped_keys.len() - 8) / self.key_flags.bits().count_ones() as usize;
632 into.put_u8((key_len / 4) as u8);
633
634 into.put(&self.salt[..]);
636
637 for num in self.wrapped_keys[..].chunks(4) {
642 into.put_u32(u32::from_be_bytes([num[0], num[1], num[2], num[3]]));
643 }
644 }
645}
646
647impl fmt::Debug for SrtControlPacket {
648 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
649 match self {
650 SrtControlPacket::Reject => write!(f, "reject"),
651 SrtControlPacket::HandshakeRequest(req) => write!(f, "hsreq={req:?}"),
652 SrtControlPacket::HandshakeResponse(resp) => write!(f, "hsresp={resp:?}"),
653 SrtControlPacket::KeyRefreshRequest(req) => write!(f, "kmreq={req:?}"),
654 SrtControlPacket::KeyRefreshResponse(resp) => write!(f, "kmresp={resp:?}"),
655 SrtControlPacket::StreamId(sid) => write!(f, "streamid={sid}"),
656 SrtControlPacket::Congestion(ctype) => write!(f, "congestion={ctype}"),
657 SrtControlPacket::Filter(filter) => write!(f, "filter={filter:?}"),
658 SrtControlPacket::Group { ty, flags, weight } => {
659 write!(f, "group=({ty:?}, {flags:?}, {weight:?})")
660 }
661 }
662 }
663}
664
665impl TryFrom<u8> for CipherType {
666 type Error = PacketParseError;
667 fn try_from(from: u8) -> Result<CipherType, PacketParseError> {
668 match from {
669 0 => Ok(CipherType::None),
670 1 => Ok(CipherType::Ecb),
671 2 => Ok(CipherType::Ctr),
672 3 => Ok(CipherType::Cbc),
673 e => Err(PacketParseError::BadCipherKind(e)),
674 }
675 }
676}
677
#[cfg(test)]
mod tests {
    use super::{KeyingMaterialMessage, SrtControlPacket, SrtHandshake, SrtShakeFlags};

    use crate::{options::*, packet::*};

    use std::{io::Cursor, time::Duration};

    /// Serializes `packet` and parses the resulting bytes back into a packet.
    fn round_trip(packet: &Packet) -> Packet {
        let mut wire = Vec::new();
        packet.serialize(&mut wire);
        Packet::parse(&mut Cursor::new(wire), false).unwrap()
    }

    #[test]
    fn deser_ser_shake() {
        let shake = SrtHandshake {
            version: SrtVersion::CURRENT,
            flags: SrtShakeFlags::empty(),
            send_latency: Duration::from_millis(4000),
            recv_latency: Duration::from_millis(3000),
        };
        let handshake = Packet::Control(ControlPacket {
            timestamp: TimeStamp::from_micros(123_141),
            dest_sockid: SocketId(123),
            control_type: ControlTypes::Srt(SrtControlPacket::HandshakeRequest(shake)),
        });

        assert_eq!(handshake, round_trip(&handshake));
    }

    #[test]
    fn ser_deser_sid() {
        let stream_id = SrtControlPacket::StreamId("Hellohelloheloo".into());
        let sid = Packet::Control(ControlPacket {
            timestamp: TimeStamp::from_micros(123),
            dest_sockid: SocketId(1234),
            control_type: ControlTypes::Srt(stream_id),
        });

        assert_eq!(sid, round_trip(&sid));
    }

    #[test]
    fn srt_key_message_debug() {
        let km = KeyingMaterialMessage {
            pt: PacketType::KeyingMaterial,
            key_flags: KeyFlags::EVEN,
            keki: 0,
            cipher: CipherType::Ctr,
            auth: Auth::None,
            salt: b"\x00\x00\x00\x00\x00\x00\x00\x00\x85\x2c\x3c\xcd\x02\x65\x1a\x22".to_vec(),
            wrapped_keys: b"U\x06\xe9\xfd\xdfd\xf1'nr\xf4\xe9f\x81#(\xb7\xb5D\x19{\x9b\xcdx"
                .to_vec(),
        };

        let expected = "KeyingMaterialMessage { pt: KeyingMaterial, key_flags: KeyFlags(EVEN), keki: 0, cipher: Ctr, auth: None }";
        assert_eq!(format!("{km:?}"), expected);
    }
}