use std::fmt::{self, Debug, Formatter};
use std::net::{IpAddr, Ipv4Addr};
use bitflags::bitflags;
use bytes::{Buf, BufMut};
use log::warn;
use crate::protocol::{TimeSpan, TimeStamp};
use crate::{MsgNumber, SeqNumber, SocketID};
mod srt;
pub use self::srt::{CipherType, SrtControlPacket, SrtHandshake, SrtKeyMessage, SrtShakeFlags};
use super::PacketParseError;
/// A decoded control packet: the common header fields plus the
/// type-specific body.
#[derive(PartialEq, Eq, Clone)]
pub struct ControlPacket {
    /// Header timestamp; carried on the wire as a 32-bit microsecond count.
    pub timestamp: TimeStamp,
    /// Destination socket ID from the header.
    pub dest_sockid: SocketID,
    /// Control type and its type-specific payload.
    pub control_type: ControlTypes,
}
/// Body of a control packet, selected by the 15-bit control type in the
/// header.
// Some variants (e.g. `Handshake`, which embeds optional SRT extension
// packets) are far larger than the payload-less ones; boxing them would
// change the public variant types, so the size lint is silenced instead.
#[allow(clippy::large_enum_variant)]
#[derive(PartialEq, Eq, Clone)]
pub enum ControlTypes {
    /// Connection handshake (control type 0x0).
    Handshake(HandshakeControlInfo),
    /// Keep-alive (control type 0x1); no payload.
    KeepAlive,
    /// Acknowledgement (control type 0x2).
    Ack(AckControlInfo),
    /// Negative acknowledgement (control type 0x3): the raw, undecoded
    /// 32-bit words of the loss list.
    Nak(Vec<u32>),
    /// Shutdown (control type 0x5); no payload.
    Shutdown,
    /// ACK of an ACK (control type 0x6); the value travels in the header's
    /// additional-info field.
    Ack2(i32),
    /// Message drop request (control type 0x7).
    DropRequest {
        /// Message number, carried in the header's additional-info field.
        msg_to_drop: MsgNumber,
        /// First sequence number of the range.
        first: SeqNumber,
        /// Last sequence number of the range.
        last: SeqNumber,
    },
    /// SRT-specific packet (control type 0x7FFF); its subtype travels in the
    /// header's reserved field.
    Srt(SrtControlPacket),
}
bitflags! {
    /// HSv5 extension flags, carried in the low 16 bits of the handshake type
    /// field of non-induction handshakes. Each set bit announces an extension
    /// block following the fixed handshake fields.
    struct ExtFlags: u16 {
        /// SRT handshake extension (request/response) present.
        const HS = 1 << 0;
        /// Key-material extension present.
        const KM = 1 << 1;
        /// Config extension present.
        const CONFIG = 1 << 2;
    }
}
/// Version-specific part of a handshake.
// `V5` holds three optional SRT extension packets and dwarfs `V4`; the size
// disparity is accepted rather than boxing the public fields.
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HandshakeVSInfo {
    /// UDT version-4 handshake: the type field carries the socket type.
    V4(SocketType),
    /// Version-5 (SRT) handshake.
    V5 {
        /// Encryption key length in bytes; 0 disables encryption. Parsing
        /// accepts 0, 16, 24 or 32 and maps anything else to 0.
        crypto_size: u8,
        /// SRT handshake extension, if announced.
        ext_hs: Option<SrtControlPacket>,
        /// Key-material extension, if announced.
        ext_km: Option<SrtControlPacket>,
        /// Config extension, if announced.
        ext_config: Option<SrtControlPacket>,
    },
}
/// Body of a handshake control packet.
#[derive(PartialEq, Eq, Clone)]
pub struct HandshakeControlInfo {
    /// Initial packet sequence number.
    pub init_seq_num: SeqNumber,
    /// Maximum packet size advertised by the sender.
    pub max_packet_size: u32,
    /// Maximum flow window size advertised by the sender.
    pub max_flow_size: u32,
    /// Which phase of the handshake this packet is.
    pub shake_type: ShakeType,
    /// Socket ID of the sending side (shown as `from(..)` by `Debug`).
    pub socket_id: SocketID,
    /// SYN cookie value.
    pub syn_cookie: i32,
    /// Peer address as carried in the 16-byte address field.
    pub peer_addr: IpAddr,
    /// Version-specific data (HSv4 socket type or HSv5 extensions).
    pub info: HandshakeVSInfo,
}
/// Body of an ACK control packet.
///
/// Only `ack_number` is mandatory on the wire. Each optional field is `None`
/// when the received body ended before it.
#[derive(PartialEq, Eq, Clone)]
pub struct AckControlInfo {
    /// ACK sequence number, carried in the header's additional-info field.
    pub ack_seq_num: i32,
    /// Packet sequence number reported by this ACK.
    pub ack_number: SeqNumber,
    /// Round-trip time (microseconds on the wire).
    pub rtt: Option<TimeSpan>,
    /// Round-trip time variance (microseconds on the wire).
    pub rtt_variance: Option<TimeSpan>,
    /// Available buffer space reported by the peer (unit not fixed here).
    pub buffer_available: Option<i32>,
    /// Packet receive rate reported by the peer (unit not fixed here).
    pub packet_recv_rate: Option<u32>,
    /// Estimated link capacity reported by the peer (unit not fixed here).
    pub est_link_cap: Option<i32>,
}
/// Socket type carried in the type field of a version-4 handshake.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SocketType {
    /// Stream socket; wire value 1.
    Stream = 1,
    /// Datagram socket; wire value 2.
    Datagram = 2,
}
/// Handshake phase; the discriminant is the signed 32-bit wire value.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ShakeType {
    /// Wire value 1.
    Induction = 1,
    /// Wire value 0.
    Waveahand = 0,
    /// Wire value -1.
    Conclusion = -1,
    /// Wire value -2.
    Agreement = -2,
}
impl HandshakeVSInfo {
    /// Build the 32-bit word that follows the version in a serialized
    /// handshake.
    ///
    /// * HSv4: the socket type.
    /// * HSv5: the upper 16 bits hold `crypto_size / 8`. The lower 16 bits
    ///   hold `SRT_MAGIC_CODE` for induction handshakes, and otherwise the
    ///   `ExtFlags` announcing which extension blocks follow.
    ///
    /// # Panics
    ///
    /// Panics if `shake_type` is `Induction` while any SRT extension is
    /// present.
    fn type_flags(&self, shake_type: ShakeType) -> u32 {
        let (crypto_size, ext_hs, ext_km, ext_config) = match self {
            HandshakeVSInfo::V4(socket_type) => return *socket_type as u32,
            HandshakeVSInfo::V5 {
                crypto_size,
                ext_hs,
                ext_km,
                ext_config,
            } => (*crypto_size, ext_hs, ext_km, ext_config),
        };

        let is_induction = shake_type == ShakeType::Induction;
        if is_induction && (ext_hs.is_some() || ext_km.is_some() || ext_config.is_some()) {
            panic!("Handshake is both induction and has SRT extensions, not valid");
        }

        let low_bits = if is_induction {
            SRT_MAGIC_CODE
        } else {
            let mut flags = ExtFlags::empty();
            for &(present, flag) in &[
                (ext_hs.is_some(), ExtFlags::HS),
                (ext_km.is_some(), ExtFlags::KM),
                (ext_config.is_some(), ExtFlags::CONFIG),
            ] {
                if present {
                    flags |= flag;
                }
            }
            flags.bits()
        };

        ((u32::from(crypto_size) >> 3) << 16) | u32::from(low_bits)
    }

    /// Handshake version number written to the wire: 4 or 5.
    pub fn version(&self) -> u32 {
        if let HandshakeVSInfo::V4(_) = self {
            4
        } else {
            5
        }
    }
}
impl SocketType {
    /// Convert a wire value into a `SocketType`.
    ///
    /// # Errors
    ///
    /// Returns the unrecognised value unchanged when it is neither 1 nor 2.
    pub fn from_u16(num: u16) -> Result<SocketType, u16> {
        if num == 1 {
            Ok(SocketType::Stream)
        } else if num == 2 {
            Ok(SocketType::Datagram)
        } else {
            Err(num)
        }
    }
}
impl ControlPacket {
pub fn parse(buf: &mut impl Buf) -> Result<ControlPacket, PacketParseError> {
let control_type = buf.get_u16() << 1 >> 1;
let reserved = buf.get_u16();
let add_info = buf.get_i32();
let timestamp = TimeStamp::from_micros(buf.get_u32());
let dest_sockid = buf.get_u32();
Ok(ControlPacket {
timestamp,
dest_sockid: SocketID(dest_sockid),
control_type: ControlTypes::deserialize(control_type, reserved, add_info, buf)?,
})
}
pub fn serialize<T: BufMut>(&self, into: &mut T) {
into.put_u16(self.control_type.id_byte() | (0b1 << 15));
into.put_u16(self.control_type.reserved());
into.put_i32(self.control_type.additional_info());
into.put_u32(self.timestamp.as_micros());
into.put_u32(self.dest_sockid.0);
self.control_type.serialize(into);
}
}
impl Debug for ControlPacket {
    /// Renders as `{<control type> ts=<seconds, 4 decimals>s dst=<socket id>}`.
    fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
        let seconds = self.timestamp.as_secs_f64();
        write!(
            f,
            "{{{kind:?} ts={secs:.4}s dst={dst:?}}}",
            kind = self.control_type,
            secs = seconds,
            dst = self.dest_sockid,
        )
    }
}
/// Value placed in the low 16 bits of the handshake type field of HSv5
/// induction handshakes, in place of extension flags. On parse, its absence
/// only produces a warning.
const SRT_MAGIC_CODE: u16 = 0x4a17;
impl ControlTypes {
fn deserialize<T: Buf>(
packet_type: u16,
reserved: u16,
extra_info: i32,
mut buf: T,
) -> Result<ControlTypes, PacketParseError> {
match packet_type {
0x0 => {
if buf.remaining() < 8 * 4 + 16 {
return Err(PacketParseError::NotEnoughData);
}
let udt_version = buf.get_i32();
if udt_version != 4 && udt_version != 5 {
return Err(PacketParseError::BadUDTVersion(udt_version));
}
let crypto_size = buf.get_u16() << 3;
let type_ext_socket_type = buf.get_u16();
let init_seq_num = SeqNumber::new_truncate(buf.get_u32());
let max_packet_size = buf.get_u32();
let max_flow_size = buf.get_u32();
let shake_type = match ShakeType::from_i32(buf.get_i32()) {
Ok(ct) => ct,
Err(err_ct) => return Err(PacketParseError::BadConnectionType(err_ct)),
};
let socket_id = SocketID(buf.get_u32());
let syn_cookie = buf.get_i32();
let mut ip_buf: [u8; 16] = [0; 16];
buf.copy_to_slice(&mut ip_buf);
let peer_addr = if ip_buf[4..] == [0; 12][..] {
IpAddr::from(Ipv4Addr::new(ip_buf[3], ip_buf[2], ip_buf[1], ip_buf[0]))
} else {
IpAddr::from(ip_buf)
};
let info = match udt_version {
4 => HandshakeVSInfo::V4(match SocketType::from_u16(type_ext_socket_type) {
Ok(t) => t,
Err(e) => return Err(PacketParseError::BadSocketType(e)),
}),
5 => {
let crypto_size = match crypto_size {
0 | 16 | 24 | 32 => crypto_size as u8,
c => {
warn!(
"Unrecognized crypto key length: {}, disabling encryption. Should be 16, 24, or 32 bytes",
c
);
0
}
};
if shake_type == ShakeType::Induction {
if type_ext_socket_type != SRT_MAGIC_CODE {
warn!("HSv5 induction response did not have SRT_MAGIC_CODE, which is suspicious")
}
HandshakeVSInfo::V5 {
crypto_size,
ext_hs: None,
ext_km: None,
ext_config: None,
}
} else {
let extensions = match ExtFlags::from_bits(type_ext_socket_type) {
Some(i) => i,
None => {
warn!(
"Unnecessary bits in extensions flags: {:b}",
type_ext_socket_type
);
ExtFlags::from_bits_truncate(type_ext_socket_type)
}
};
let ext_hs = if extensions.contains(ExtFlags::HS) {
if buf.remaining() < 4 {
return Err(PacketParseError::NotEnoughData);
}
let pack_type = buf.get_u16();
let _pack_size = buf.get_u16();
match pack_type {
1 | 2 => Some(SrtControlPacket::parse(pack_type, &mut buf)?),
e => return Err(PacketParseError::BadSRTHsExtensionType(e)),
}
} else {
None
};
let ext_km = if extensions.contains(ExtFlags::KM) {
if buf.remaining() < 4 {
return Err(PacketParseError::NotEnoughData);
}
let pack_type = buf.get_u16();
let _pack_size = buf.get_u16();
match pack_type {
3 | 4 => Some(SrtControlPacket::parse(pack_type, &mut buf)?),
e => return Err(PacketParseError::BadSRTKmExtensionType(e)),
}
} else {
None
};
let ext_config = if extensions.contains(ExtFlags::CONFIG) {
if buf.remaining() < 4 {
return Err(PacketParseError::NotEnoughData);
}
let pack_type = buf.get_u16();
let _pack_size = buf.get_u16();
match pack_type {
5 | 6 => Some(SrtControlPacket::parse(pack_type, &mut buf)?),
e => {
return Err(PacketParseError::BadSRTConfigExtensionType(e))
}
}
} else {
None
};
HandshakeVSInfo::V5 {
crypto_size,
ext_hs,
ext_km,
ext_config,
}
}
}
_ => unreachable!(),
};
Ok(ControlTypes::Handshake(HandshakeControlInfo {
init_seq_num,
max_packet_size,
max_flow_size,
shake_type,
socket_id,
syn_cookie,
peer_addr,
info,
}))
}
0x1 => Ok(ControlTypes::KeepAlive),
0x2 => {
if buf.remaining() < 4 {
return Err(PacketParseError::NotEnoughData);
}
let ack_number = SeqNumber::new_truncate(buf.get_u32());
let opt_read_next_u32 = |buf: &mut T| {
if buf.remaining() >= 4 {
Some(buf.get_u32())
} else {
None
}
};
let opt_read_next_i32 = |buf: &mut T| {
if buf.remaining() >= 4 {
Some(buf.get_i32())
} else {
None
}
};
let rtt = opt_read_next_i32(&mut buf).map(TimeSpan::from_micros);
let rtt_variance = opt_read_next_i32(&mut buf).map(TimeSpan::from_micros);
let buffer_available = opt_read_next_i32(&mut buf);
let packet_recv_rate = opt_read_next_u32(&mut buf);
let est_link_cap = opt_read_next_i32(&mut buf);
Ok(ControlTypes::Ack(AckControlInfo {
ack_seq_num: extra_info,
ack_number,
rtt,
rtt_variance,
buffer_available,
packet_recv_rate,
est_link_cap,
}))
}
0x3 => {
let mut loss_info = Vec::new();
while buf.remaining() >= 4 {
loss_info.push(buf.get_u32());
}
Ok(ControlTypes::Nak(loss_info))
}
0x5 => Ok(ControlTypes::Shutdown),
0x6 => {
Ok(ControlTypes::Ack2(extra_info))
}
0x7 => {
if buf.remaining() < 2 * 4 {
return Err(PacketParseError::NotEnoughData);
}
Ok(ControlTypes::DropRequest {
msg_to_drop: MsgNumber::new_truncate(extra_info as u32),
first: SeqNumber::new_truncate(buf.get_u32()),
last: SeqNumber::new_truncate(buf.get_u32()),
})
}
0x7FFF => {
Ok(ControlTypes::Srt(SrtControlPacket::parse(
reserved, &mut buf,
)?))
}
x => Err(PacketParseError::BadControlType(x)),
}
}
fn id_byte(&self) -> u16 {
match *self {
ControlTypes::Handshake(_) => 0x0,
ControlTypes::KeepAlive => 0x1,
ControlTypes::Ack { .. } => 0x2,
ControlTypes::Nak(_) => 0x3,
ControlTypes::Shutdown => 0x5,
ControlTypes::Ack2(_) => 0x6,
ControlTypes::DropRequest { .. } => 0x7,
ControlTypes::Srt(_) => 0x7FFF,
}
}
fn additional_info(&self) -> i32 {
match self {
ControlTypes::DropRequest { msg_to_drop: a, .. } => a.as_raw() as i32,
ControlTypes::Ack2(a) | ControlTypes::Ack(AckControlInfo { ack_seq_num: a, .. }) => *a,
_ => 0,
}
}
fn reserved(&self) -> u16 {
match self {
ControlTypes::Srt(srt) => srt.type_id(),
_ => 0,
}
}
fn serialize<T: BufMut>(&self, into: &mut T) {
match self {
ControlTypes::Handshake(ref c) => {
into.put_u32(c.info.version());
into.put_u32(c.info.type_flags(c.shake_type));
into.put_u32(c.init_seq_num.as_raw());
into.put_u32(c.max_packet_size);
into.put_u32(c.max_flow_size);
into.put_i32(c.shake_type as i32);
into.put_u32(c.socket_id.0);
into.put_i32(c.syn_cookie);
match c.peer_addr {
IpAddr::V4(four) => {
let mut v = Vec::from(&four.octets()[..]);
v.reverse();
into.put(&v[..]);
into.put(&[0; 12][..]);
}
IpAddr::V6(six) => {
let mut v = Vec::from(&six.octets()[..]);
v.reverse();
into.put(&v[..]);
}
}
if let HandshakeVSInfo::V5 {
ref ext_hs,
ref ext_km,
ref ext_config,
..
} = c.info
{
for ext in [ext_hs, ext_km, ext_config]
.iter()
.filter_map(|&s| s.as_ref())
{
into.put_u16(ext.type_id());
into.put_u16(ext.size_words());
ext.serialize(into);
}
}
}
ControlTypes::Ack(AckControlInfo {
ack_number,
rtt,
rtt_variance,
buffer_available,
packet_recv_rate,
est_link_cap,
..
}) => {
into.put_u32(ack_number.as_raw());
into.put_i32(rtt.map(|t| t.as_micros()).unwrap_or(10_000));
into.put_i32(rtt_variance.map(|t| t.as_micros()).unwrap_or(50_000));
into.put_i32(buffer_available.unwrap_or(8175));
into.put_u32(packet_recv_rate.unwrap_or(10_000));
into.put_i32(est_link_cap.unwrap_or(1_000));
}
ControlTypes::Nak(ref n) => {
for &loss in n {
into.put_u32(loss);
}
}
ControlTypes::DropRequest { .. } => unimplemented!(),
ControlTypes::Ack2(_) | ControlTypes::Shutdown | ControlTypes::KeepAlive => {
into.put_u32(0x0);
}
ControlTypes::Srt(srt) => {
srt.serialize(into);
}
};
}
}
impl Debug for ControlTypes {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
match self {
ControlTypes::Handshake(hs) => write!(f, "{:?}", hs),
ControlTypes::KeepAlive => write!(f, "KeepAlive"),
ControlTypes::Ack(AckControlInfo {
ack_seq_num,
ack_number,
rtt,
rtt_variance,
buffer_available,
packet_recv_rate,
est_link_cap,
}) => {
write!(f, "Ack(asn={} an={}", ack_seq_num, ack_number,)?;
if let Some(rtt) = rtt {
write!(f, " rtt={}", rtt.as_micros())?;
}
if let Some(rttvar) = rtt_variance {
write!(f, " rttvar={}", rttvar.as_micros())?;
}
if let Some(buf) = buffer_available {
write!(f, " buf_av={}", buf)?;
}
if let Some(prr) = packet_recv_rate {
write!(f, " pack_rr={}", prr)?;
}
if let Some(link_cap) = est_link_cap {
write!(f, " link_cap={}", link_cap)?;
}
write!(f, ")")?;
Ok(())
}
ControlTypes::Nak(nak) => {
write!(f, "Nak({:?})", nak)
}
ControlTypes::Shutdown => write!(f, "Shutdown"),
ControlTypes::Ack2(ackno) => write!(f, "Ack2({})", ackno),
ControlTypes::DropRequest {
msg_to_drop,
first,
last,
} => write!(f, "DropReq(msg={} {}-{})", msg_to_drop, first, last),
ControlTypes::Srt(srt) => write!(f, "{:?}", srt),
}
}
}
impl Debug for HandshakeControlInfo {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
write!(f, "HS {:?} from({:?})", self.shake_type, self.socket_id)
}
}
impl ShakeType {
    /// Convert the signed wire value of the handshake type field.
    ///
    /// # Errors
    ///
    /// Returns the value unchanged if it is not one of 1, 0, -1 or -2.
    pub fn from_i32(num: i32) -> Result<ShakeType, i32> {
        // Each variant's discriminant is its wire value, so match on that.
        let known = [
            ShakeType::Induction,
            ShakeType::Waveahand,
            ShakeType::Conclusion,
            ShakeType::Agreement,
        ];
        known
            .iter()
            .find(|&&candidate| candidate as i32 == num)
            .cloned()
            .ok_or(num)
    }
}
// Unit tests for control-packet parsing and serialization: synthetic
// round-trips plus captured raw packets given as hex strings.
#[cfg(test)]
mod test {
use super::*;
use crate::{SeqNumber, SocketID, SrtVersion};
use std::io::Cursor;
use std::time::Duration;
// Round-trips an HSv5 conclusion handshake carrying an SRT
// handshake-response extension through serialize + parse.
#[test]
fn handshake_ser_des_test() {
let pack = ControlPacket {
timestamp: TimeStamp::from_micros(0),
dest_sockid: SocketID(0),
control_type: ControlTypes::Handshake(HandshakeControlInfo {
init_seq_num: SeqNumber::new_truncate(1_827_131),
max_packet_size: 1500,
max_flow_size: 25600,
shake_type: ShakeType::Conclusion,
socket_id: SocketID(1231),
syn_cookie: 0,
peer_addr: "127.0.0.1".parse().unwrap(),
info: HandshakeVSInfo::V5 {
crypto_size: 0,
ext_hs: Some(SrtControlPacket::HandshakeResponse(SrtHandshake {
version: SrtVersion::CURRENT,
flags: SrtShakeFlags::NAKREPORT | SrtShakeFlags::TSBPDSND,
peer_latency: Duration::from_millis(3000),
latency: Duration::from_millis(12345),
})),
ext_km: None,
ext_config: None,
},
}),
};
let mut buf = vec![];
pack.serialize(&mut buf);
let des = ControlPacket::parse(&mut Cursor::new(buf)).unwrap();
assert_eq!(pack, des);
}
// Round-trips a full ACK with every optional field present.
#[test]
fn ack_ser_des_test() {
let pack = ControlPacket {
timestamp: TimeStamp::from_micros(113_703),
dest_sockid: SocketID(2_453_706_529),
control_type: ControlTypes::Ack(AckControlInfo {
ack_seq_num: 1,
ack_number: SeqNumber::new_truncate(282_049_186),
rtt: Some(TimeSpan::from_micros(10_002)),
rtt_variance: Some(TimeSpan::from_micros(1000)),
buffer_available: Some(1314),
packet_recv_rate: Some(0),
est_link_cap: Some(0),
}),
};
let mut buf = vec![];
pack.serialize(&mut buf);
let des = ControlPacket::parse(&mut Cursor::new(buf)).unwrap();
assert_eq!(pack, des);
}
// Round-trips an ACK2 and checks that its value lands in the low two bytes
// of the header's additional-info field (header bytes 4..8).
#[test]
fn ack2_ser_des_test() {
let pack = ControlPacket {
timestamp: TimeStamp::from_micros(125_812),
dest_sockid: SocketID(8313),
control_type: ControlTypes::Ack2(831),
};
assert_eq!(pack.control_type.additional_info(), 831);
let mut buf = vec![];
pack.serialize(&mut buf);
assert_eq!((u32::from(buf[6]) << 8) + u32::from(buf[7]), 831);
let des = ControlPacket::parse(&mut Cursor::new(buf)).unwrap();
assert_eq!(pack, des);
}
// Parses a captured SRT-specific (0x7FFF) control packet whose reserved
// field is 0; the body is expected to decode as `SrtControlPacket::Reject`.
#[test]
fn raw_srt_packet_test() {
let packet_data =
hex::decode("FFFF000000000000000189702BFFEFF2000103010000001E00000078").unwrap();
let packet = ControlPacket::parse(&mut Cursor::new(packet_data)).unwrap();
assert_eq!(
packet,
ControlPacket {
timestamp: TimeStamp::from_micros(100_720),
dest_sockid: SocketID(738_193_394),
control_type: ControlTypes::Srt(SrtControlPacket::Reject)
}
)
}
// Parses a captured HSv5 conclusion handshake with an HS-request extension,
// then checks that re-serializing reproduces the captured bytes exactly
// (including the byte-reversed 127.0.0.1 peer address `0100007F`).
#[test]
fn raw_handshake_srt() {
let packet_data = hex::decode("8000000000000000000F9EC400000000000000050000000144BEA60D000005DC00002000FFFFFFFF3D6936B6E3E405DD0100007F00000000000000000000000000010003000103010000002F00780000").unwrap();
let packet = ControlPacket::parse(&mut Cursor::new(&packet_data[..])).unwrap();
assert_eq!(
packet,
ControlPacket {
timestamp: TimeStamp::from_micros(1_023_684),
dest_sockid: SocketID(0),
control_type: ControlTypes::Handshake(HandshakeControlInfo {
init_seq_num: SeqNumber(1_153_345_037),
max_packet_size: 1500,
max_flow_size: 8192,
shake_type: ShakeType::Conclusion,
socket_id: SocketID(1_030_305_462),
syn_cookie: -471_595_555,
peer_addr: "127.0.0.1".parse().unwrap(),
info: HandshakeVSInfo::V5 {
crypto_size: 0,
ext_hs: Some(SrtControlPacket::HandshakeRequest(SrtHandshake {
version: SrtVersion::new(1, 3, 1),
flags: SrtShakeFlags::TSBPDSND
| SrtShakeFlags::TSBPDRCV
| SrtShakeFlags::HAICRYPT
| SrtShakeFlags::TLPKTDROP
| SrtShakeFlags::REXMITFLG,
peer_latency: Duration::from_millis(120),
latency: Duration::new(0, 0)
})),
ext_km: None,
ext_config: None
}
})
}
);
let mut buf = vec![];
packet.serialize(&mut buf);
assert_eq!(&buf[..], &packet_data[..]);
}
// Same as `raw_handshake_srt`, but the capture also carries a key-material
// extension; checks decoding and byte-exact re-serialization.
#[test]
fn raw_handshake_crypto() {
let packet_data = hex::decode("800000000000000000175E8A0000000000000005000000036FEFB8D8000005DC00002000FFFFFFFF35E790ED5D16CCEA0100007F00000000000000000000000000010003000103010000002F01F401F40003000E122029010000000002000200000004049D75B0AC924C6E4C9EC40FEB4FE973DB1D215D426C18A2871EBF77E2646D9BAB15DBD7689AEF60EC").unwrap();
let packet = ControlPacket::parse(&mut Cursor::new(&packet_data[..])).unwrap();
assert_eq!(
packet,
ControlPacket {
timestamp: TimeStamp::from_micros(1_531_530),
dest_sockid: SocketID(0),
control_type: ControlTypes::Handshake(HandshakeControlInfo {
init_seq_num: SeqNumber(1_877_981_400),
max_packet_size: 1_500,
max_flow_size: 8_192,
shake_type: ShakeType::Conclusion,
socket_id: SocketID(904_368_365),
syn_cookie: 1_561_775_338,
peer_addr: "127.0.0.1".parse().unwrap(),
info: HandshakeVSInfo::V5 {
crypto_size: 0,
ext_hs: Some(SrtControlPacket::HandshakeRequest(SrtHandshake {
version: SrtVersion::new(1, 3, 1),
flags: SrtShakeFlags::TSBPDSND
| SrtShakeFlags::TSBPDRCV
| SrtShakeFlags::HAICRYPT
| SrtShakeFlags::TLPKTDROP
| SrtShakeFlags::REXMITFLG,
peer_latency: Duration::from_millis(500),
latency: Duration::from_millis(500)
})),
ext_km: Some(SrtControlPacket::KeyManagerRequest(SrtKeyMessage {
pt: 2,
sign: 8_233,
keki: 0,
cipher: CipherType::CTR,
auth: 0,
se: 2,
salt: hex::decode("9D75B0AC924C6E4C9EC40FEB4FE973DB").unwrap(),
even_key: Some(
hex::decode("1D215D426C18A2871EBF77E2646D9BAB").unwrap()
),
odd_key: None,
wrap_data: *b"\x15\xDB\xD7\x68\x9A\xEF\x60\xEC",
})),
ext_config: None
}
})
}
);
let mut buf = vec![];
packet.serialize(&mut buf);
assert_eq!(&buf[..], &packet_data[..])
}
// Smoke test: a second captured handshake with HS and KM extensions must
// parse without error (the decoded value is not checked).
#[test]
fn raw_handshake_crypto_pt2() {
let packet_data = hex::decode("8000000000000000000000000C110D94000000050000000374B7526E000005DC00002000FFFFFFFF18C1CED1F3819B720100007F00000000000000000000000000020003000103010000003F03E803E80004000E12202901000000000200020000000404D3B3D84BE1188A4EBDA4DA16EA65D522D82DE544E1BE06B6ED8128BF15AA4E18EC50EAA95546B101").unwrap();
let _packet = ControlPacket::parse(&mut Cursor::new(&packet_data[..])).unwrap();
}
// Smoke test: an ACK whose body has only four words (fewer than a full ACK)
// must parse, with the missing trailing fields left as `None`.
#[test]
fn short_ack() {
let packet_data =
hex::decode("800200000000000e000246e5d96d5e1a389c24780000452900007bb000001fa9")
.unwrap();
let _cp = ControlPacket::parse(&mut Cursor::new(packet_data)).unwrap();
}
}