1use error::{Error, SapResult};
19use lazy_static::lazy_static;
20use murmur3::murmur3_32;
21use sdp::SessionDescription;
22use socket2::{Domain, Protocol, SockAddr, Socket, Type};
23use std::{
24 collections::HashMap,
25 io::Cursor,
26 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
27 time::{Duration, SystemTime, UNIX_EPOCH},
28};
29use tokio::{
30 net::UdpSocket,
31 select, spawn,
32 sync::{mpsc, oneshot},
33 time::interval,
34};
35
36pub mod error;
37
/// MIME type written into the payload-type field of announcements built by
/// `SessionAnnouncement::new` / `SessionAnnouncement::deletion`.
const DEFAULT_PAYLOAD_TYPE: &str = "application/sdp";
/// UDP port the SAP socket binds to and that outgoing SAP packets are sent to.
const DEFAULT_SAP_PORT: u16 = 9875;
/// IPv4 multicast group that is joined for receiving and used as destination for sending.
const DEFAULT_MULTICAST_ADDRESS: &str = "239.255.255.255";
41
42lazy_static! {
43 static ref HASH_SEED: u32 = SystemTime::now()
44 .duration_since(UNIX_EPOCH)
45 .expect("something is wrong with the system clock")
46 .as_secs() as u32;
47}
48
/// A single SAP packet: the header fields plus the SDP payload it carries.
#[derive(Debug, Clone)]
pub struct SessionAnnouncement {
    /// `true` for a session deletion packet, `false` for a regular announcement.
    pub deletion: bool,
    /// Encryption flag from the header; `decode_sap` rejects encrypted packets.
    pub encrypted: bool,
    /// Compression flag from the header; `decode_sap` rejects compressed packets.
    pub compressed: bool,
    /// 16-bit message identifier hash from the SAP header.
    pub msg_id_hash: u16,
    /// Authentication data following the header, if any (decoded lossily as UTF-8).
    pub auth_data: Option<String>,
    /// Address of the host that originated the announcement.
    pub originating_source: IpAddr,
    /// Payload MIME type (e.g. `application/sdp`), if the packet carries one.
    pub payload_type: Option<String>,
    /// The announced session description.
    pub sdp: SessionDescription,
}
60
61impl SessionAnnouncement {
62 pub fn new(sdp: SessionDescription) -> SapResult<Self> {
63 Ok(Self {
64 deletion: false,
65 encrypted: false,
66 compressed: false,
67 msg_id_hash: sdp_hash(&sdp),
68 auth_data: None,
69 originating_source: sdp.origin.unicast_address.parse()?,
70 payload_type: Some(DEFAULT_PAYLOAD_TYPE.to_owned()),
71 sdp,
72 })
73 }
74
75 pub fn deletion(sdp: SessionDescription) -> SapResult<Self> {
76 Ok(Self {
77 deletion: true,
78 encrypted: false,
79 compressed: false,
80 msg_id_hash: sdp_hash(&sdp),
81 auth_data: None,
82 originating_source: sdp.origin.unicast_address.parse()?,
83 payload_type: Some(DEFAULT_PAYLOAD_TYPE.to_owned()),
84 sdp,
85 })
86 }
87}
88
/// Background task that owns the SAP socket; fed with requests by [`Sap`] handles.
pub struct SapActor {
    /// Socket bound to the SAP port and joined to the multicast group (see `create_socket`).
    socket: UdpSocket,
    /// Destination address for outgoing SAP packets.
    multicast_addr: SocketAddr,
    /// Keyed by message id hash. NOTE(review): presumably the sessions this node announces —
    /// confirm against the actor's methods.
    active_sessions: HashMap<u16, SessionAnnouncement>,
    /// Keyed by message id hash. NOTE(review): presumably sessions announced by other hosts —
    /// confirm how (or whether) this map is populated.
    foreign_sessions: HashMap<u16, SessionAnnouncement>,
    /// Deletion packets (copies of announced sessions with `deletion` set), keyed by message
    /// id hash; `delete_session` sends and removes them.
    deletion_announcements: HashMap<u16, SessionAnnouncement>,
    /// Channel on which decoded incoming packets are delivered as [`Event`]s.
    event_tx: mpsc::Sender<Event>,
    /// Requests coming from [`Sap`] handles.
    msg_rx: mpsc::Receiver<Message>,
}
98
/// Notification about a SAP packet received on the multicast group.
///
/// One event is emitted per received and successfully decoded packet, so a session that is
/// re-announced periodically is reported repeatedly.
pub enum Event {
    /// A packet without the deletion flag was received.
    SessionFound(SessionAnnouncement),
    /// A packet with the deletion flag was received.
    SessionLost(SessionAnnouncement),
}
103
/// Requests sent from [`Sap`] handles to the actor; each carries a channel for the result.
enum Message {
    /// Start announcing the given session.
    AnnounceSession(Box<SessionAnnouncement>, oneshot::Sender<SapResult<()>>),
    /// Delete the session with the given message id hash.
    DeleteSession(u16, oneshot::Sender<SapResult<()>>),
}
108
109impl SapActor {
110 async fn run(mut self) {
111 let mut buf = [0; 1024];
112
113 loop {
114 select! {
115 Some(msg) = self.msg_rx.recv() => {
116 match msg {
117 Message::AnnounceSession(sa, tx) => {
118 tx.send(self.announce_session(*sa).await).ok();
119 },
120 Message::DeleteSession(hash, tx) => {
121 tx.send(self.delete_session(hash).await).ok();
122 },
123 }
124 },
125 Ok(len) = async {
126 log::debug!("receiving SAP broadcast message …");
127 let recv = self.socket.recv(&mut buf).await;
128 log::debug!("broadcast message received");
129 recv
130 } => self.forward_announcement(&buf[0..len]).await,
131 else => break,
132 }
133 }
134 }
135
136 async fn forward_announcement(&self, buf: &[u8]) {
137 log::debug!("forwarding SAP message");
138 match decode_sap(buf) {
139 Ok(sap) => {
140 let event = if sap.deletion {
141 Event::SessionLost(sap)
142 } else {
143 Event::SessionFound(sap)
144 };
145 if let Err(e) = self.event_tx.send(event).await {
146 log::error!("Error forwarding SAP message error: {e}");
147 } else {
148 log::debug!("SAP message forwarded");
149 }
150 }
151 Err(e) => {
152 log::error!("error decoding SAP message: {e}");
153 }
154 }
155 }
156
157 async fn announce_session(&mut self, announcement: SessionAnnouncement) -> SapResult<()> {
158 self.delete_session(announcement.msg_id_hash).await?;
159
160 let mut deletion_announcement = announcement.clone();
161 deletion_announcement.deletion = true;
162 self.deletion_announcements
163 .insert(deletion_announcement.msg_id_hash, deletion_announcement);
164
165 let mut interval = interval(Duration::from_secs(5));
166
167 loop {
168 select! {
172 _ = interval.tick() => self.send_announcement(&announcement).await?,
173 }
174 }
175 }
176
177 async fn delete_session(&mut self, hash: u16) -> SapResult<()> {
178 if let Some(deletion_announcement) = self.deletion_announcements.remove(&hash) {
179 log::info!("Deleting active session {hash}.");
180 let msg = encode_sap(&deletion_announcement);
181 self.socket.send_to(&msg, &self.multicast_addr).await?;
182 } else {
183 log::debug!("No session active, nothing to delete.");
184 }
185
186 Ok(())
187 }
188
189 async fn send_announcement(&self, announcement: &SessionAnnouncement) -> SapResult<()> {
190 log::info!("Broadcasting session description.");
191 let msg = encode_sap(announcement);
192 self.socket.send_to(&msg, &self.multicast_addr).await?;
193 Ok(())
194 }
195}
196
/// Cloneable handle for announcing and deleting sessions through the background SAP actor.
#[derive(Clone)]
pub struct Sap {
    /// Request channel to the actor spawned by [`Sap::new`].
    msg_tx: mpsc::Sender<Message>,
}
201
202impl Sap {
203 pub async fn new() -> SapResult<(Self, mpsc::Receiver<Event>)> {
204 let multicast_addr = SocketAddr::new(
205 IpAddr::V4(DEFAULT_MULTICAST_ADDRESS.parse()?),
206 DEFAULT_SAP_PORT,
207 );
208 let socket = create_socket().await?;
209
210 let active_sessions = HashMap::new();
211 let foreign_sessions = HashMap::new();
212 let deletion_announcements = HashMap::new();
213
214 let (event_tx, event_rx) = mpsc::channel(1);
215 let (msg_tx, msg_rx) = mpsc::channel(100);
216
217 let actor = SapActor {
218 socket,
219 multicast_addr,
220 active_sessions,
221 foreign_sessions,
222 deletion_announcements,
223 event_tx,
224 msg_rx,
225 };
226
227 spawn(actor.run());
228
229 Ok((Sap { msg_tx }, event_rx))
230 }
231
232 pub async fn announce_session(&self, sd: SessionDescription) -> SapResult<()> {
233 let sa = SessionAnnouncement::new(sd)?;
234 let (tx, rx) = oneshot::channel();
235 self.msg_tx
236 .send(Message::AnnounceSession(Box::new(sa), tx))
237 .await?;
238 rx.await?
239 }
240
241 pub async fn delete_session(&self, hash: u16) -> SapResult<()> {
242 let (tx, rx) = oneshot::channel();
243 self.msg_tx.send(Message::DeleteSession(hash, tx)).await?;
244 rx.await?
245 }
246}
247
248pub fn decode_sap(msg: &[u8]) -> SapResult<SessionAnnouncement> {
249 let mut min_length = 4;
250
251 if msg.len() < min_length {
252 return Err(Error::MalformedPacket(msg.to_owned()));
253 }
254
255 let header = msg[0];
256 let auth_len = msg[1];
257 let msg_id_hash = u16::from_be_bytes([msg[2], msg[3]]);
258
259 let ipv6 = (header & 0b00001000) >> 3 == 1;
260 let deletion = (header & 0b00000100) >> 2 == 1;
261 let encrypted = (header & 0b00000010) >> 1 == 1;
262 let compressed = header & 0b00000001 == 1;
263
264 if encrypted {
266 return Err(Error::NotImplemented("encryption"));
267 }
268 if compressed {
270 return Err(Error::NotImplemented("encryption"));
271 }
272
273 if ipv6 {
274 min_length += 16;
275 } else {
276 min_length += 4;
277 }
278
279 if msg.len() < min_length {
280 return Err(Error::MalformedPacket(msg.to_owned()));
281 }
282
283 let originating_source = if ipv6 {
284 let bits = u128::from_be_bytes([
285 msg[4], msg[5], msg[6], msg[7], msg[8], msg[9], msg[10], msg[11], msg[12], msg[13],
286 msg[14], msg[15], msg[16], msg[17], msg[18], msg[19],
287 ]);
288 IpAddr::V6(Ipv6Addr::from_bits(bits))
289 } else {
290 let bits = u32::from_be_bytes([msg[4], msg[5], msg[6], msg[7]]);
291 IpAddr::V4(Ipv4Addr::from_bits(bits))
292 };
293
294 let auth_data_start = min_length;
295
296 min_length += auth_len as usize;
297
298 if msg.len() <= min_length {
299 return Err(Error::MalformedPacket(msg.to_owned()));
300 }
301
302 let auth_data = if auth_len > 0 {
303 Some(String::from_utf8_lossy(&msg[auth_data_start..min_length]).to_string())
304 } else {
305 None
306 };
307
308 let payload = String::from_utf8_lossy(&msg[min_length..]).to_string();
309 let split: Vec<&str> = payload.split('\0').collect();
310
311 let payload_type = if split.len() >= 2 {
312 Some(split[0].to_owned())
313 } else {
314 None
315 };
316
317 let payload = if split.len() == 1 {
318 split[0]
319 } else {
320 &split[1..].join("\0")
321 };
322
323 let sdp = SessionDescription::unmarshal(&mut Cursor::new(payload))?;
324
325 Ok(SessionAnnouncement {
326 deletion,
327 encrypted,
328 compressed,
329 msg_id_hash,
330 auth_data,
331 originating_source,
332 payload_type,
333 sdp,
334 })
335}
336
337pub fn encode_sap(msg: &SessionAnnouncement) -> Vec<u8> {
338 let v = 1u8;
339 let (a, originating_source): (u8, &[u8]) = match msg.originating_source {
340 IpAddr::V4(addr) => (0u8, &addr.octets()),
341 IpAddr::V6(addr) => (1u8, &addr.octets()),
342 };
343 let r = 0u8;
344 let t = if msg.deletion { 1u8 } else { 0u8 };
345 let e = if msg.encrypted { 1u8 } else { 0u8 };
346 let c = if msg.compressed { 1u8 } else { 0u8 };
347 let header = v << 5 | a << 4 | r << 3 | t << 2 | e << 1 | c;
348 let auth_len = msg.auth_data.as_ref().map(|d| d.len()).unwrap_or(0) as u8;
349 let msg_id_hash = msg.msg_id_hash.to_be_bytes();
350
351 let mut data = Vec::new();
352 data.push(header);
353 data.push(auth_len);
354 data.extend_from_slice(&msg_id_hash);
355 data.extend_from_slice(originating_source);
356 if let Some(auth_data) = &msg.auth_data {
357 data.extend_from_slice(auth_data.as_bytes());
358 }
359 if let Some(payload_type) = &msg.payload_type {
360 data.extend_from_slice(payload_type.as_bytes());
361 data.push(b'\0');
362 }
363 log::info!("marshalling sdp ...");
364 data.extend_from_slice(msg.sdp.marshal().as_bytes());
365 log::info!("marshalling sdp done.");
366
367 data
368}
369
370fn sdp_hash(sdp: &SessionDescription) -> u16 {
371 log::info!("computing message hash ...");
372 let res = murmur3_32(&mut Cursor::new(sdp.marshal()), *HASH_SEED).unwrap_or(0) as u16;
373 log::info!("computing message hash done");
374 res
375}
376
377async fn create_socket() -> SapResult<UdpSocket> {
378 let multicast_addr: Ipv4Addr = DEFAULT_MULTICAST_ADDRESS.parse()?;
379 let local_ip = Ipv4Addr::UNSPECIFIED;
380 let local_addr = SocketAddr::new(IpAddr::V4(local_ip), DEFAULT_SAP_PORT);
381
382 let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
383 socket.set_reuse_address(true)?;
384 socket.set_nonblocking(true)?;
385 socket.bind(&SockAddr::from(local_addr))?;
386 socket.join_multicast_v4(&multicast_addr, &local_ip)?;
387
388 let socket = UdpSocket::from_std(socket.into())?;
389
390 Ok(socket)
391}
392
#[cfg(test)]
mod tests {

    use super::*;

    // Hashing a parsed SDP must yield a non-zero message id hash.
    #[test]
    fn sdp_gets_hashed_correctly() {
        let sdp = SessionDescription::unmarshal(&mut Cursor::new(
            "v=0
o=- 123456 123458 IN IP4 10.0.1.2
s=My sample flow
i=4 channels: c1, c2, c3, c4
t=0 0
a=recvonly
m=audio 5004 RTP/AVP 98
c=IN IP4 239.69.11.44/32
a=rtpmap:98 L24/48000/4
a=ptime:1
a=ts-refclk:ptp=IEEE1588-2008:00-11-22-FF-FE-33-44-55:0
a=mediaclk:direct=0",
        ))
        .unwrap();
        assert!(sdp_hash(&sdp) != 0);
    }

    // Encoding then decoding an IPv4 deletion packet without auth data or payload type must
    // reproduce every header field and the SDP text.
    #[test]
    fn encode_decode_roundtrip_is_successful() {
        let sdp = "v=0
o=- 123456 123458 IN IP4 10.0.1.2
s=My sample flow
i=4 channels: c1, c2, c3, c4
t=0 0
a=recvonly
m=audio 5004 RTP/AVP 98
c=IN IP4 239.69.11.44/32
a=rtpmap:98 L24/48000/4
a=ptime:1
a=ts-refclk:ptp=IEEE1588-2008:00-11-22-FF-FE-33-44-55:0
a=mediaclk:direct=0
";

        let sa = SessionAnnouncement {
            auth_data: None,
            payload_type: None,
            compressed: false,
            deletion: true,
            encrypted: false,
            msg_id_hash: 1234,
            originating_source: "127.0.0.1".parse().unwrap(),
            sdp: SessionDescription::unmarshal(&mut Cursor::new(sdp)).unwrap(),
        };

        let sa_msg = encode_sap(&sa);

        let decoded = decode_sap(&sa_msg).unwrap();

        assert_eq!(sa.auth_data, decoded.auth_data);
        assert_eq!(sa.compressed, decoded.compressed);
        assert_eq!(sa.deletion, decoded.deletion);
        assert_eq!(sa.encrypted, decoded.encrypted);
        assert_eq!(sa.msg_id_hash, decoded.msg_id_hash);
        assert_eq!(sa.originating_source, decoded.originating_source);
        assert_eq!(sa.payload_type, decoded.payload_type);
        // `marshal` emits CRLF line endings; strip the CRs before comparing with the LF input.
        assert_eq!(sa.sdp.marshal().replace('\r', ""), sdp);
    }
}