use std::{convert::TryFrom, time::Duration};
use bitflags::bitflags;
use bytes::{Buf, BufMut};
use log::warn;
use crate::{PacketParseError, SrtVersion};
use core::fmt;
/// An SRT-specific control extension, identified on the wire by the `u16`
/// type id returned from [`SrtControlPacket::type_id`].
///
/// `Debug` is implemented by hand further down to give a compact rendering.
#[derive(Clone, Eq, PartialEq)]
pub enum SrtControlPacket {
    /// Type id 0. No payload is read when this is parsed.
    Reject,
    /// Type id 1: handshake extension request.
    HandshakeRequest(SrtHandshake),
    /// Type id 2: handshake extension response.
    HandshakeResponse(SrtHandshake),
    /// Type id 3: key-material request.
    KeyManagerRequest(SrtKeyMessage),
    /// Type id 4: key-material response.
    KeyManagerResponse(SrtKeyMessage),
    /// Type id 5: UTF-8 stream id. It is sent as 4-byte groups whose bytes are
    /// written in reverse order; the last group is zero-padded.
    StreamId(String),
    /// Type id 6. Parsing and serializing this extension are not implemented here.
    Smoother,
}
/// Body of a key-material (KM) extension, as decoded by `SrtKeyMessage::parse`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct SrtKeyMessage {
    /// Packet type, taken from the low nibble of the first header byte.
    pub pt: PacketType,
    /// Which keys are present; each set bit accounts for one key in `wrapped_keys`.
    pub key_flags: KeyFlags,
    /// 32-bit KEKI header field, copied verbatim (presumably the key-encryption-key
    /// index; confirm against the SRT specification).
    pub keki: u32,
    /// Cipher byte of the header.
    pub cipher: CipherType,
    /// Authentication byte of the header; only `Auth::None` is representable.
    pub auth: Auth,
    /// Salt bytes; `parse` always produces a multiple of 4 bytes.
    pub salt: Vec<u8>,
    /// Wrapped key material: one key per set bit of `key_flags`, plus 8 extra bytes.
    pub wrapped_keys: Vec<u8>,
}
/// Authentication scheme byte of a key-material message.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Auth {
    /// No authentication (wire value 0).
    None = 0,
}

impl TryFrom<u8> for Auth {
    type Error = PacketParseError;

    /// Decodes the wire byte. Anything other than 0 is rejected with
    /// `PacketParseError::BadAuth` carrying the offending byte.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        if value == Auth::None as u8 {
            Ok(Auth::None)
        } else {
            Err(PacketParseError::BadAuth(value))
        }
    }
}
/// Stream-encapsulation byte of a key-material message.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StreamEncapsulation {
    /// Wire value 1.
    Udp = 1,
    /// Wire value 2; the only value `SrtKeyMessage::parse` accepts.
    Srt = 2,
}

impl TryFrom<u8> for StreamEncapsulation {
    type Error = PacketParseError;

    /// Maps 1 and 2 to their variants. Any other byte is rejected with
    /// `PacketParseError::BadStreamEncapsulation`.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        match value {
            1 => Ok(StreamEncapsulation::Udp),
            2 => Ok(StreamEncapsulation::Srt),
            other => Err(PacketParseError::BadStreamEncapsulation(other)),
        }
    }
}
/// Packet-type nibble (low 4 bits of the first byte) of a key-material message.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum PacketType {
    /// Wire value 1.
    MediaStream = 1,
    /// Wire value 2.
    KeyingMaterial = 2,
}
bitflags! {
    /// Key-presence flags of a key-material message. Only the two low bits are
    /// kept when parsing; each set bit corresponds to one wrapped key.
    pub struct KeyFlags : u8 {
        /// Even key present.
        const EVEN = 0b01;
        /// Odd key present.
        const ODD = 0b10;
    }
}
impl TryFrom<u8> for PacketType {
    type Error = PacketParseError;

    /// Decodes the packet-type nibble. Values other than 1 and 2 produce
    /// `PacketParseError::BadKeyPacketType` with the raw value.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let decoded = match value {
            1 => PacketType::MediaStream,
            2 => PacketType::KeyingMaterial,
            unknown => return Err(PacketParseError::BadKeyPacketType(unknown)),
        };
        Ok(decoded)
    }
}
/// Cipher byte of a key-material message header.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum CipherType {
    /// Wire value 0.
    None = 0,
    /// Wire value 1.
    ECB = 1,
    /// Wire value 2.
    CTR = 2,
    /// Wire value 3.
    CBC = 3,
}
/// SRT handshake extension body: three 32-bit big-endian words on the wire.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct SrtHandshake {
    /// SRT version (first word).
    pub version: SrtVersion,
    /// Capability flags (second word).
    pub flags: SrtShakeFlags,
    /// Latency carried in the upper 16 bits of the third word, in whole milliseconds.
    pub send_latency: Duration,
    /// Latency carried in the lower 16 bits of the third word, in whole milliseconds.
    pub recv_latency: Duration,
}
bitflags! {
    /// Capability flags exchanged in the SRT handshake extension. The bit values
    /// match libsrt's `SRT_OPT_*` constants of the same names.
    pub struct SrtShakeFlags: u32 {
        /// Sender-side timestamp-based packet delivery (TSBPD).
        const TSBPDSND = 0x1;
        /// Receiver-side TSBPD.
        const TSBPDRCV = 0x2;
        /// Encryption (HaiCrypt) capability.
        const HAICRYPT = 0x4;
        /// Too-late packet drop.
        const TLPKTDROP = 0x8;
        /// Periodic NAK reports.
        const NAKREPORT = 0x10;
        /// Retransmission flag in data packets.
        const REXMITFLG = 0x20;
        /// Stream (as opposed to message) transmission mode.
        const STREAM = 0x40;
        /// Packet-filter capability.
        const FILTERCAP = 0x80;
        /// Union of TSBPDSND, TSBPDRCV, HAICRYPT and REXMITFLG. Presumably this is
        /// the set this implementation supports; confirm at the use sites.
        const SUPPORTED = Self::TSBPDSND.bits | Self::TSBPDRCV.bits | Self::HAICRYPT.bits | Self::REXMITFLG.bits;
    }
}
impl SrtControlPacket {
    /// Parses an SRT control extension of type `packet_type` from `buf`.
    ///
    /// For a stream id (type 5) all remaining bytes of `buf` are consumed.
    ///
    /// # Errors
    ///
    /// - `BadSRTConfigExtensionType` for an unsupported type id (anything above 5,
    ///   including 6 / `Smoother`).
    /// - Errors from `SrtHandshake::parse` (types 1, 2) and `SrtKeyMessage::parse`
    ///   (types 3, 4).
    /// - `NotEnoughData` if a stream-id payload is not a multiple of 4 bytes.
    /// - `StreamTypeNotUTF8` if a stream id is not valid UTF-8.
    pub fn parse<T: Buf>(
        packet_type: u16,
        buf: &mut T,
    ) -> Result<SrtControlPacket, PacketParseError> {
        use self::SrtControlPacket::*;

        match packet_type {
            0 => Ok(Reject),
            1 => Ok(HandshakeRequest(SrtHandshake::parse(buf)?)),
            2 => Ok(HandshakeResponse(SrtHandshake::parse(buf)?)),
            3 => Ok(KeyManagerRequest(SrtKeyMessage::parse(buf)?)),
            4 => Ok(KeyManagerResponse(SrtKeyMessage::parse(buf)?)),
            5 => Ok(StreamId(Self::parse_stream_id(buf)?)),
            _ => Err(PacketParseError::BadSRTConfigExtensionType(packet_type)),
        }
    }

    /// Decodes a stream-id payload, the inverse of the `StreamId` arm of `serialize`.
    ///
    /// It consumes all of `buf` as 4-byte words and restores each word's byte order.
    /// It then strips the zero padding (1 to 3 NULs) from the final word.
    fn parse_stream_id<T: Buf>(buf: &mut T) -> Result<String, PacketParseError> {
        let len = buf.remaining();
        if len % 4 != 0 {
            return Err(PacketParseError::NotEnoughData);
        }
        // An empty stream id serializes to zero words. Reading a "last word"
        // below would panic, so handle it up front.
        if len == 0 {
            return Ok(String::new());
        }

        let mut bytes = Vec::with_capacity(len);
        while buf.remaining() > 4 {
            bytes.extend_from_slice(&buf.get_u32_le().to_be_bytes());
        }
        // Only the final word can carry padding. A final word without a trailing
        // NUL is fully used (id length is a multiple of 4) and must be kept whole.
        match buf.get_u32_le().to_be_bytes() {
            [a, 0, 0, 0] => bytes.push(a),
            [a, b, 0, 0] => bytes.extend_from_slice(&[a, b]),
            [a, b, c, 0] => bytes.extend_from_slice(&[a, b, c]),
            word => bytes.extend_from_slice(&word),
        }

        String::from_utf8(bytes).map_err(|e| PacketParseError::StreamTypeNotUTF8(e.utf8_error()))
    }

    /// Returns the wire type id of this extension (0 through 6).
    pub fn type_id(&self) -> u16 {
        use self::SrtControlPacket::*;

        match self {
            Reject => 0,
            HandshakeRequest(_) => 1,
            HandshakeResponse(_) => 2,
            KeyManagerRequest(_) => 3,
            KeyManagerResponse(_) => 4,
            StreamId(_) => 5,
            Smoother => 6,
        }
    }

    /// Writes the extension payload (not including any type/length header) into `into`.
    ///
    /// # Panics
    ///
    /// Panics for `Reject` and `Smoother`, whose serialization is not implemented.
    pub fn serialize<T: BufMut>(&self, into: &mut T) {
        use self::SrtControlPacket::*;

        match self {
            HandshakeRequest(s) | HandshakeResponse(s) => s.serialize(into),
            KeyManagerRequest(k) | KeyManagerResponse(k) => k.serialize(into),
            StreamId(sid) => {
                // Each 4-byte group is written byte-reversed. A short final group is
                // zero-padded at the front, which becomes trailing padding once the
                // reader reverses it back.
                let mut chunks = sid.as_bytes().chunks_exact(4);
                for chunk in &mut chunks {
                    into.put_slice(&[chunk[3], chunk[2], chunk[1], chunk[0]]);
                }
                match *chunks.remainder() {
                    [a, b, c] => into.put_slice(&[0, c, b, a]),
                    [a, b] => into.put_slice(&[0, 0, b, a]),
                    [a] => into.put_slice(&[0, 0, 0, a]),
                    _ => {}
                }
            }
            Reject | Smoother => unimplemented!("serializing {:?} is not supported", self),
        }
    }

    /// Size of the serialized payload in 32-bit words.
    ///
    /// # Panics
    ///
    /// Panics for `Reject` and `Smoother`, whose serialization is not implemented.
    pub fn size_words(&self) -> u16 {
        use self::SrtControlPacket::*;

        match self {
            HandshakeRequest(_) | HandshakeResponse(_) => 3,
            // 4 words of fixed header, then the salt and the wrapped keys.
            KeyManagerRequest(k) | KeyManagerResponse(k) => {
                4 + k.salt.len() as u16 / 4 + k.wrapped_keys.len() as u16 / 4
            }
            StreamId(sid) => ((sid.len() + 3) / 4) as u16,
            Reject | Smoother => unimplemented!("size of {:?} is not known", self),
        }
    }
}
impl SrtHandshake {
pub fn parse<T: Buf>(buf: &mut T) -> Result<SrtHandshake, PacketParseError> {
if buf.remaining() < 12 {
return Err(PacketParseError::NotEnoughData);
}
let version = SrtVersion::parse(buf.get_u32());
let shake_flags = buf.get_u32();
let flags = match SrtShakeFlags::from_bits(shake_flags) {
Some(i) => i,
None => {
warn!("Unrecognized SRT flags: 0b{:b}", shake_flags);
SrtShakeFlags::from_bits_truncate(shake_flags)
}
};
let peer_latency = buf.get_u16();
let latency = buf.get_u16();
Ok(SrtHandshake {
version,
flags,
send_latency: Duration::from_millis(u64::from(peer_latency)),
recv_latency: Duration::from_millis(u64::from(latency)),
})
}
pub fn serialize<T: BufMut>(&self, into: &mut T) {
into.put_u32(self.version.to_u32());
into.put_u32(self.flags.bits());
into.put_u16(self.send_latency.as_millis() as u16);
into.put_u16(self.recv_latency.as_millis() as u16);
}
}
impl SrtKeyMessage {
    /// Signature field: the letters `H`, `A`, `I`, each encoded as a 5-bit value
    /// (letter minus `'@'`) and packed into 15 bits, i.e. `0x2029`.
    const SIGN: u16 =
        ((b'H' - b'@') as u16) << 10 | ((b'A' - b'@') as u16) << 5 | (b'I' - b'@') as u16;

    /// Parses a key-material message: a 16-byte fixed header, then the salt, then
    /// the wrapped keys.
    ///
    /// # Errors
    ///
    /// - `NotEnoughData` if the header, salt or wrapped keys are truncated.
    /// - `BadSRTExtensionMessage` if the top bit of the first byte is set or the
    ///   version nibble is not 1.
    /// - `BadKeyPacketType`, `BadCipherKind`, `BadAuth`, `BadStreamEncapsulation`
    ///   from the corresponding `TryFrom<u8>` decoders.
    /// - `BadKeySign` if the signature is not `SIGN`.
    /// - `StreamEncapsulationNotSrt` if the encapsulation is not `Srt`.
    /// - `BadCryptoLength` unless the key length is 16, 24 or 32 bytes.
    pub fn parse(buf: &mut impl Buf) -> Result<SrtKeyMessage, PacketParseError> {
        // Fixed header: 4 words.
        if buf.remaining() < 4 * 4 {
            return Err(PacketParseError::NotEnoughData);
        }

        let vers_pt = buf.get_u8();
        // The top bit is reserved and must be zero.
        if (vers_pt & 0b1000_0000) != 0 {
            return Err(PacketParseError::BadSRTExtensionMessage);
        }
        let version = vers_pt >> 4;
        if version != 1 {
            return Err(PacketParseError::BadSRTExtensionMessage);
        }
        let pt = PacketType::try_from(vers_pt & 0b0000_1111)?;

        let sign = buf.get_u16();
        if sign != Self::SIGN {
            return Err(PacketParseError::BadKeySign(sign));
        }

        // Only the two low bits carry key-presence flags.
        let key_flags = KeyFlags::from_bits_truncate(buf.get_u8() & 0b0000_0011);
        let keki = buf.get_u32();
        let cipher = CipherType::try_from(buf.get_u8())?;
        let auth = Auth::try_from(buf.get_u8())?;

        let se = StreamEncapsulation::try_from(buf.get_u8())?;
        if se != StreamEncapsulation::Srt {
            return Err(PacketParseError::StreamEncapsulationNotSrt);
        }

        let _resv1 = buf.get_u8();
        let _resv2 = buf.get_u16();

        // Salt and key lengths are transmitted in units of 4 bytes.
        let salt_len = usize::from(buf.get_u8()) * 4;
        let key_len = usize::from(buf.get_u8()) * 4;
        match key_len {
            16 | 24 | 32 => {}
            e => return Err(PacketParseError::BadCryptoLength(e as u32)),
        }

        // One key per flag bit, plus 8 bytes of key-wrap overhead.
        let key_count = key_flags.bits().count_ones() as usize;
        let wrapped_len = key_len * key_count + 8;
        if buf.remaining() < salt_len + wrapped_len {
            return Err(PacketParseError::NotEnoughData);
        }

        let mut salt = vec![0; salt_len];
        buf.copy_to_slice(&mut salt);

        let mut wrapped_keys = vec![0; wrapped_len];
        buf.copy_to_slice(&mut wrapped_keys);

        Ok(SrtKeyMessage {
            pt,
            key_flags,
            keki,
            cipher,
            auth,
            salt,
            wrapped_keys,
        })
    }

    /// Writes the message in the layout `parse` reads.
    ///
    /// The key length header field is derived from `wrapped_keys` and `key_flags`.
    /// A message with no key flags set, or with fewer than 8 wrapped bytes, gets a
    /// key length of 0 instead of panicking.
    fn serialize<T: BufMut>(&self, into: &mut T) {
        // Version 1 in the high nibble, packet type in the low nibble.
        into.put_u8(1 << 4 | self.pt as u8);
        into.put_u16(Self::SIGN);
        into.put_u8(self.key_flags.bits());
        into.put_u32(self.keki);
        into.put_u8(self.cipher as u8);
        into.put_u8(self.auth as u8);
        into.put_u8(StreamEncapsulation::Srt as u8);
        into.put_u8(0); // reserved
        into.put_u16(0); // reserved
        into.put_u8((self.salt.len() / 4) as u8);
        into.put_u8((self.key_len() / 4) as u8);
        into.put_slice(&self.salt);
        into.put_slice(&self.wrapped_keys);
    }

    /// Length in bytes of one key, recovered from the wrapped-key buffer size.
    ///
    /// Returns 0 when no key flag is set, because no per-key length can be derived.
    fn key_len(&self) -> usize {
        let key_count = self.key_flags.bits().count_ones() as usize;
        if key_count == 0 {
            0
        } else {
            self.wrapped_keys.len().saturating_sub(8) / key_count
        }
    }
}
impl fmt::Debug for SrtControlPacket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SrtControlPacket::Reject => write!(f, "reject"),
SrtControlPacket::HandshakeRequest(req) => write!(f, "hsreq={:?}", req),
SrtControlPacket::HandshakeResponse(resp) => write!(f, "hsresp={:?}", resp),
SrtControlPacket::KeyManagerRequest(req) => write!(f, "kmreq={:?}", req),
SrtControlPacket::KeyManagerResponse(resp) => write!(f, "kmresp={:?}", resp),
SrtControlPacket::StreamId(sid) => write!(f, "streamid={}", sid),
SrtControlPacket::Smoother => write!(f, "smoother"),
}
}
}
impl TryFrom<u8> for CipherType {
    type Error = PacketParseError;

    /// Decodes the cipher byte (0 to 3). Larger values produce
    /// `PacketParseError::BadCipherKind` with the raw byte.
    fn try_from(from: u8) -> Result<CipherType, PacketParseError> {
        // Indexed by wire value.
        const BY_WIRE_VALUE: [CipherType; 4] = [
            CipherType::None,
            CipherType::ECB,
            CipherType::CTR,
            CipherType::CBC,
        ];

        BY_WIRE_VALUE
            .get(usize::from(from))
            .copied()
            .ok_or(PacketParseError::BadCipherKind(from))
    }
}
#[cfg(test)]
mod tests {
    use super::{SrtControlPacket, SrtHandshake, SrtShakeFlags};
    use crate::packet::ControlTypes;
    use crate::{protocol::TimeStamp, ControlPacket, Packet, SocketID, SrtVersion};
    use std::io::Cursor;
    use std::time::Duration;

    /// Serializes `packet` into a fresh byte vector and parses it back.
    fn round_trip(packet: &Packet) -> Packet {
        let mut wire: Vec<u8> = Vec::new();
        packet.serialize(&mut wire);
        Packet::parse(&mut Cursor::new(wire), false).unwrap()
    }

    #[test]
    fn deser_ser_shake() {
        let shake = SrtHandshake {
            version: SrtVersion::CURRENT,
            flags: SrtShakeFlags::empty(),
            send_latency: Duration::from_millis(4000),
            recv_latency: Duration::from_millis(3000),
        };
        let original = Packet::Control(ControlPacket {
            timestamp: TimeStamp::from_micros(123_141),
            dest_sockid: SocketID(123),
            control_type: ControlTypes::Srt(SrtControlPacket::HandshakeRequest(shake)),
        });

        assert_eq!(round_trip(&original), original);
    }

    #[test]
    fn ser_deser_sid() {
        let original = Packet::Control(ControlPacket {
            timestamp: TimeStamp::from_micros(123),
            dest_sockid: SocketID(1234),
            control_type: ControlTypes::Srt(SrtControlPacket::StreamId("Hellohelloheloo".into())),
        });

        assert_eq!(round_trip(&original), original);
    }
}