1use error::{Error, SapResult};
19use lazy_static::lazy_static;
20use murmur3::murmur3_32;
21use sdp::SessionDescription;
22use socket2::{Domain, Protocol, SockAddr, Socket, Type};
23use std::{
24 io::Cursor,
25 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
26 time::{Duration, SystemTime, UNIX_EPOCH},
27};
28use tokio::{net::UdpSocket, select, spawn, sync::mpsc, time::interval};
29
30pub mod error;
31
/// MIME type written into the SAP payload-type field by [`SessionAnnouncement::new`].
const DEFAULT_PAYLOAD_TYPE: &str = "application/sdp";
/// UDP port the SAP socket binds to, and the destination port of outgoing announcements.
const DEFAULT_SAP_PORT: u16 = 9875;
/// IPv4 multicast group the SAP socket joins and sends announcements/deletions to.
const DEFAULT_MULTICAST_ADDRESS: &str = "239.255.255.255";
35
36lazy_static! {
37 static ref HASH_SEED: u32 = SystemTime::now()
38 .duration_since(UNIX_EPOCH)
39 .expect("something is wrong with the system clock")
40 .as_secs() as u32;
41}
42
/// A single SAP packet's contents: the SAP header fields plus the SDP payload.
#[derive(Debug, Clone)]
pub struct SessionAnnouncement {
    /// Message-type flag: `true` for a session deletion, `false` for an announcement.
    pub deletion: bool,
    /// Encryption flag from the header. [`decode_sap`] rejects packets with it set;
    /// [`encode_sap`] only copies it into the header bit and does not encrypt anything.
    pub encrypted: bool,
    /// Compression flag from the header. [`decode_sap`] rejects packets with it set;
    /// [`encode_sap`] only copies it into the header bit and does not compress anything.
    pub compressed: bool,
    /// 16-bit message identifier hash carried in the header.
    pub msg_id_hash: u16,
    /// Optional authentication data; [`decode_sap`] decodes it lossily as UTF-8.
    pub auth_data: Option<String>,
    /// Address of the host that originated the announcement (IPv4 or IPv6).
    pub originating_source: IpAddr,
    /// Optional MIME type of the payload (e.g. `application/sdp`); `None` when absent.
    pub payload_type: Option<String>,
    /// The session description carried as the packet payload.
    pub sdp: SessionDescription,
}
54
55impl SessionAnnouncement {
56 pub fn new(sdp: SessionDescription) -> SapResult<Self> {
57 Ok(Self {
58 deletion: false,
59 encrypted: false,
60 compressed: false,
61 msg_id_hash: sdp_hash(&sdp),
62 auth_data: None,
63 originating_source: sdp.origin.unicast_address.parse()?,
64 payload_type: Some(DEFAULT_PAYLOAD_TYPE.to_owned()),
65 sdp,
66 })
67 }
68}
69
/// A SAP endpoint: a UDP socket joined to the SAP multicast group, plus the
/// bookkeeping needed to withdraw the session it is currently announcing.
pub struct Sap {
    /// Socket bound to the SAP port and joined to the multicast group (see `create_socket`).
    socket: UdpSocket,
    /// Destination address for outgoing announcements and deletions.
    multicast_addr: SocketAddr,
    /// Deletion packet for the session currently being announced, if any.
    /// `Sap::delete_session` sends and clears it.
    deletion_announcement: Option<SessionAnnouncement>,
}
75
76impl Sap {
77 pub async fn new() -> SapResult<Self> {
78 let multicast_addr = SocketAddr::new(
79 IpAddr::V4(DEFAULT_MULTICAST_ADDRESS.parse()?),
80 DEFAULT_SAP_PORT,
81 );
82 let socket = create_socket().await?;
83
84 Ok(Sap {
85 socket,
86 multicast_addr,
87 deletion_announcement: None,
88 })
89 }
90
91 pub async fn discover_sessions(self) -> mpsc::Receiver<SapResult<SessionAnnouncement>> {
92 let mut buf = [0; 1024];
93
94 let (tx, rx) = mpsc::channel(10);
95
96 spawn(async move {
97 loop {
98 match self.socket.recv(&mut buf).await {
99 Ok(len) => {
100 let msg = decode_sap(&buf[..len]);
101 if let Err(e) = tx.send(msg).await {
102 log::error!("Error forwarding SAP message error: {e}");
103 break;
104 }
105 }
106 Err(e) => {
107 if let Err(e) = tx.send(Err(Error::IoError(e))).await {
108 log::error!("Error forwarding SAP message error: {e}");
109 }
110 break;
111 }
112 }
113 }
114 });
115
116 rx
117 }
118
119 pub async fn announce_session(&mut self, announcement: SessionAnnouncement) -> SapResult<()> {
120 self.delete_session().await?;
121
122 let mut deletion_announcement = announcement.clone();
123 deletion_announcement.deletion = true;
124 self.deletion_announcement = Some(deletion_announcement);
125
126 let mut interval = interval(Duration::from_secs(5));
127
128 loop {
129 select! {
133 _ = interval.tick() => self.send_announcement(&announcement).await?,
134 }
135 }
136 }
137
138 pub async fn delete_session(&mut self) -> SapResult<()> {
139 if let Some(deletion_announcement) = self.deletion_announcement.take() {
140 log::info!("Deleting active session.");
141 let msg = encode_sap(&deletion_announcement);
142 self.socket.send_to(&msg, &self.multicast_addr).await?;
143 } else {
144 log::debug!("No session active, nothing to delete.");
145 }
146
147 Ok(())
148 }
149
150 async fn send_announcement(&self, announcement: &SessionAnnouncement) -> SapResult<()> {
151 log::info!("Broadcasting session description.");
152 let msg = encode_sap(announcement);
153 self.socket.send_to(&msg, &self.multicast_addr).await?;
154 Ok(())
155 }
156}
157
158pub fn decode_sap(msg: &[u8]) -> SapResult<SessionAnnouncement> {
159 let mut min_length = 4;
160
161 if msg.len() < min_length {
162 return Err(Error::MalformedPacket(msg.to_owned()));
163 }
164
165 let header = msg[0];
166 let auth_len = msg[1];
167 let msg_id_hash = u16::from_be_bytes([msg[2], msg[3]]);
168
169 let ipv6 = (header & 0b00001000) >> 3 == 1;
170 let deletion = (header & 0b00000100) >> 2 == 1;
171 let encrypted = (header & 0b00000010) >> 1 == 1;
172 let compressed = header & 0b00000001 == 1;
173
174 if encrypted {
176 return Err(Error::NotImplemented("encryption"));
177 }
178 if compressed {
180 return Err(Error::NotImplemented("encryption"));
181 }
182
183 if ipv6 {
184 min_length += 16;
185 } else {
186 min_length += 4;
187 }
188
189 if msg.len() < min_length {
190 return Err(Error::MalformedPacket(msg.to_owned()));
191 }
192
193 let originating_source = if ipv6 {
194 let bits = u128::from_be_bytes([
195 msg[4], msg[5], msg[6], msg[7], msg[8], msg[9], msg[10], msg[11], msg[12], msg[13],
196 msg[14], msg[15], msg[16], msg[17], msg[18], msg[19],
197 ]);
198 IpAddr::V6(Ipv6Addr::from_bits(bits))
199 } else {
200 let bits = u32::from_be_bytes([msg[4], msg[5], msg[6], msg[7]]);
201 IpAddr::V4(Ipv4Addr::from_bits(bits))
202 };
203
204 let auth_data_start = min_length;
205
206 min_length += auth_len as usize;
207
208 if msg.len() <= min_length {
209 return Err(Error::MalformedPacket(msg.to_owned()));
210 }
211
212 let auth_data = if auth_len > 0 {
213 Some(String::from_utf8_lossy(&msg[auth_data_start..min_length]).to_string())
214 } else {
215 None
216 };
217
218 let payload = String::from_utf8_lossy(&msg[min_length..]).to_string();
219 let split: Vec<&str> = payload.split('\0').collect();
220
221 let payload_type = if split.len() >= 2 {
222 Some(split[0].to_owned())
223 } else {
224 None
225 };
226
227 let payload = if split.len() == 1 {
228 split[0]
229 } else {
230 &split[1..].join("\0")
231 };
232
233 let sdp = SessionDescription::unmarshal(&mut Cursor::new(payload))?;
234
235 Ok(SessionAnnouncement {
236 deletion,
237 encrypted,
238 compressed,
239 msg_id_hash,
240 auth_data,
241 originating_source,
242 payload_type,
243 sdp,
244 })
245}
246
247pub fn encode_sap(msg: &SessionAnnouncement) -> Vec<u8> {
248 let v = 1u8;
249 let (a, originating_source): (u8, &[u8]) = match msg.originating_source {
250 IpAddr::V4(addr) => (0u8, &addr.octets()),
251 IpAddr::V6(addr) => (1u8, &addr.octets()),
252 };
253 let r = 0u8;
254 let t = if msg.deletion { 1u8 } else { 0u8 };
255 let e = if msg.encrypted { 1u8 } else { 0u8 };
256 let c = if msg.compressed { 1u8 } else { 0u8 };
257 let header = v << 5 | a << 4 | r << 3 | t << 2 | e << 1 | c;
258 let auth_len = msg
259 .auth_data
260 .as_ref()
261 .map(|d| d.as_bytes().len())
262 .unwrap_or(0) as u8;
263 let msg_id_hash = msg.msg_id_hash.to_be_bytes();
264
265 let mut data = Vec::new();
266 data.push(header);
267 data.push(auth_len);
268 data.extend_from_slice(&msg_id_hash);
269 data.extend_from_slice(originating_source);
270 if let Some(auth_data) = &msg.auth_data {
271 data.extend_from_slice(auth_data.as_bytes());
272 }
273 if let Some(payload_type) = &msg.payload_type {
274 data.extend_from_slice(payload_type.as_bytes());
275 data.push(b'\0');
276 }
277 data.extend_from_slice(msg.sdp.marshal().as_bytes());
278
279 data
280}
281
282fn sdp_hash(sdp: &SessionDescription) -> u16 {
283 murmur3_32(&mut Cursor::new(sdp.marshal()), *HASH_SEED).unwrap_or(0) as u16
284}
285
286async fn create_socket() -> SapResult<UdpSocket> {
287 let multicast_addr: Ipv4Addr = DEFAULT_MULTICAST_ADDRESS.parse()?;
288 let local_ip = Ipv4Addr::UNSPECIFIED;
289 let local_addr = SocketAddr::new(IpAddr::V4(local_ip), DEFAULT_SAP_PORT);
290
291 let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
292 socket.set_reuse_address(true)?;
293 socket.set_reuse_port(true)?;
294 socket.set_nonblocking(true)?;
295 socket.bind(&SockAddr::from(local_addr))?;
296 socket.join_multicast_v4(&multicast_addr, &local_ip)?;
297
298 let tokio_socket = UdpSocket::from_std(socket.into())?;
299
300 Ok(tokio_socket)
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `text` as an SDP document, panicking if it is invalid.
    fn parse_sdp(text: &str) -> SessionDescription {
        SessionDescription::unmarshal(&mut Cursor::new(text)).unwrap()
    }

    #[test]
    fn sdp_gets_hashed_correctly() {
        let sdp = parse_sdp(
            "v=0
o=- 123456 123458 IN IP4 10.0.1.2
s=My sample flow
i=4 channels: c1, c2, c3, c4
t=0 0
a=recvonly
m=audio 5004 RTP/AVP 98
c=IN IP4 239.69.11.44/32
a=rtpmap:98 L24/48000/4
a=ptime:1
a=ts-refclk:ptp=IEEE1588-2008:00-11-22-FF-FE-33-44-55:0
a=mediaclk:direct=0",
        );

        assert_ne!(sdp_hash(&sdp), 0);
    }

    #[test]
    fn encode_decode_roundtrip_is_successful() {
        let sdp_text = "v=0
o=- 123456 123458 IN IP4 10.0.1.2
s=My sample flow
i=4 channels: c1, c2, c3, c4
t=0 0
a=recvonly
m=audio 5004 RTP/AVP 98
c=IN IP4 239.69.11.44/32
a=rtpmap:98 L24/48000/4
a=ptime:1
a=ts-refclk:ptp=IEEE1588-2008:00-11-22-FF-FE-33-44-55:0
a=mediaclk:direct=0
";

        let original = SessionAnnouncement {
            deletion: true,
            encrypted: false,
            compressed: false,
            msg_id_hash: 1234,
            auth_data: None,
            originating_source: "127.0.0.1".parse().unwrap(),
            payload_type: None,
            sdp: parse_sdp(sdp_text),
        };

        let decoded = decode_sap(&encode_sap(&original)).unwrap();

        // Header flags.
        assert_eq!(original.deletion, decoded.deletion);
        assert_eq!(original.encrypted, decoded.encrypted);
        assert_eq!(original.compressed, decoded.compressed);
        // Identification and optional fields.
        assert_eq!(original.msg_id_hash, decoded.msg_id_hash);
        assert_eq!(original.originating_source, decoded.originating_source);
        assert_eq!(original.auth_data, decoded.auth_data);
        assert_eq!(original.payload_type, decoded.payload_type);
        // SDP marshalling emits CRLF line endings; compare with them stripped.
        assert_eq!(original.sdp.marshal().replace('\r', ""), sdp_text);
    }
}