// sap_rs/lib.rs

/*
 *  Copyright (C) 2024 Michael Bachmann
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */
17
18use error::{Error, SapResult};
19use lazy_static::lazy_static;
20use murmur3::murmur3_32;
21use sdp::SessionDescription;
22use socket2::{Domain, Protocol, SockAddr, Socket, Type};
23use std::{
24    io::Cursor,
25    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
26    time::{Duration, SystemTime, UNIX_EPOCH},
27};
28use tokio::{net::UdpSocket, select, spawn, sync::mpsc, time::interval};
29
30pub mod error;
31
/// Payload type that `SessionAnnouncement::new` puts into new announcements.
const DEFAULT_PAYLOAD_TYPE: &str = "application/sdp";
/// UDP port that announcements are sent to and that the SAP socket binds to.
const DEFAULT_SAP_PORT: u16 = 9_875;
/// Multicast group that announcements are sent to and that the SAP socket joins.
const DEFAULT_MULTICAST_ADDRESS: &str = "239.255.255.255";
35
36lazy_static! {
37    static ref HASH_SEED: u32 = SystemTime::now()
38        .duration_since(UNIX_EPOCH)
39        .expect("something is wrong with the system clock")
40        .as_secs() as u32;
41}
42
43#[derive(Debug, Clone)]
44pub struct SessionAnnouncement {
45    pub deletion: bool,
46    pub encrypted: bool,
47    pub compressed: bool,
48    pub msg_id_hash: u16,
49    pub auth_data: Option<String>,
50    pub originating_source: IpAddr,
51    pub payload_type: Option<String>,
52    pub sdp: SessionDescription,
53}
54
55impl SessionAnnouncement {
56    pub fn new(sdp: SessionDescription) -> SapResult<Self> {
57        Ok(Self {
58            deletion: false,
59            encrypted: false,
60            compressed: false,
61            msg_id_hash: sdp_hash(&sdp),
62            auth_data: None,
63            originating_source: sdp.origin.unicast_address.parse()?,
64            payload_type: Some(DEFAULT_PAYLOAD_TYPE.to_owned()),
65            sdp,
66        })
67    }
68}
69
70pub struct Sap {
71    socket: UdpSocket,
72    multicast_addr: SocketAddr,
73    deletion_announcement: Option<SessionAnnouncement>,
74}
75
76impl Sap {
77    pub async fn new() -> SapResult<Self> {
78        let multicast_addr = SocketAddr::new(
79            IpAddr::V4(DEFAULT_MULTICAST_ADDRESS.parse()?),
80            DEFAULT_SAP_PORT,
81        );
82        let socket = create_socket().await?;
83
84        Ok(Sap {
85            socket,
86            multicast_addr,
87            deletion_announcement: None,
88        })
89    }
90
91    pub async fn discover_sessions(self) -> mpsc::Receiver<SapResult<SessionAnnouncement>> {
92        let mut buf = [0; 1024];
93
94        let (tx, rx) = mpsc::channel(10);
95
96        spawn(async move {
97            loop {
98                match self.socket.recv(&mut buf).await {
99                    Ok(len) => {
100                        let msg = decode_sap(&buf[..len]);
101                        if let Err(e) = tx.send(msg).await {
102                            log::error!("Error forwarding SAP message error: {e}");
103                            break;
104                        }
105                    }
106                    Err(e) => {
107                        if let Err(e) = tx.send(Err(Error::IoError(e))).await {
108                            log::error!("Error forwarding SAP message error: {e}");
109                        }
110                        break;
111                    }
112                }
113            }
114        });
115
116        rx
117    }
118
119    pub async fn announce_session(&mut self, announcement: SessionAnnouncement) -> SapResult<()> {
120        self.delete_session().await?;
121
122        let mut deletion_announcement = announcement.clone();
123        deletion_announcement.deletion = true;
124        self.deletion_announcement = Some(deletion_announcement);
125
126        let mut interval = interval(Duration::from_secs(5));
127
128        loop {
129            // TODO receive other announcements and update delay
130            // TODO send announcement in according intervals
131            //
132            select! {
133                _ = interval.tick() => self.send_announcement(&announcement).await?,
134            }
135        }
136    }
137
138    pub async fn delete_session(&mut self) -> SapResult<()> {
139        if let Some(deletion_announcement) = self.deletion_announcement.take() {
140            log::info!("Deleting active session.");
141            let msg = encode_sap(&deletion_announcement);
142            self.socket.send_to(&msg, &self.multicast_addr).await?;
143        } else {
144            log::debug!("No session active, nothing to delete.");
145        }
146
147        Ok(())
148    }
149
150    async fn send_announcement(&self, announcement: &SessionAnnouncement) -> SapResult<()> {
151        log::info!("Broadcasting session description.");
152        let msg = encode_sap(announcement);
153        self.socket.send_to(&msg, &self.multicast_addr).await?;
154        Ok(())
155    }
156}
157
158pub fn decode_sap(msg: &[u8]) -> SapResult<SessionAnnouncement> {
159    let mut min_length = 4;
160
161    if msg.len() < min_length {
162        return Err(Error::MalformedPacket(msg.to_owned()));
163    }
164
165    let header = msg[0];
166    let auth_len = msg[1];
167    let msg_id_hash = u16::from_be_bytes([msg[2], msg[3]]);
168
169    let ipv6 = (header & 0b00001000) >> 3 == 1;
170    let deletion = (header & 0b00000100) >> 2 == 1;
171    let encrypted = (header & 0b00000010) >> 1 == 1;
172    let compressed = header & 0b00000001 == 1;
173
174    // TODO implement decryption
175    if encrypted {
176        return Err(Error::NotImplemented("encryption"));
177    }
178    // TODO implement decompression
179    if compressed {
180        return Err(Error::NotImplemented("encryption"));
181    }
182
183    if ipv6 {
184        min_length += 16;
185    } else {
186        min_length += 4;
187    }
188
189    if msg.len() < min_length {
190        return Err(Error::MalformedPacket(msg.to_owned()));
191    }
192
193    let originating_source = if ipv6 {
194        let bits = u128::from_be_bytes([
195            msg[4], msg[5], msg[6], msg[7], msg[8], msg[9], msg[10], msg[11], msg[12], msg[13],
196            msg[14], msg[15], msg[16], msg[17], msg[18], msg[19],
197        ]);
198        IpAddr::V6(Ipv6Addr::from_bits(bits))
199    } else {
200        let bits = u32::from_be_bytes([msg[4], msg[5], msg[6], msg[7]]);
201        IpAddr::V4(Ipv4Addr::from_bits(bits))
202    };
203
204    let auth_data_start = min_length;
205
206    min_length += auth_len as usize;
207
208    if msg.len() <= min_length {
209        return Err(Error::MalformedPacket(msg.to_owned()));
210    }
211
212    let auth_data = if auth_len > 0 {
213        Some(String::from_utf8_lossy(&msg[auth_data_start..min_length]).to_string())
214    } else {
215        None
216    };
217
218    let payload = String::from_utf8_lossy(&msg[min_length..]).to_string();
219    let split: Vec<&str> = payload.split('\0').collect();
220
221    let payload_type = if split.len() >= 2 {
222        Some(split[0].to_owned())
223    } else {
224        None
225    };
226
227    let payload = if split.len() == 1 {
228        split[0]
229    } else {
230        &split[1..].join("\0")
231    };
232
233    let sdp = SessionDescription::unmarshal(&mut Cursor::new(payload))?;
234
235    Ok(SessionAnnouncement {
236        deletion,
237        encrypted,
238        compressed,
239        msg_id_hash,
240        auth_data,
241        originating_source,
242        payload_type,
243        sdp,
244    })
245}
246
247pub fn encode_sap(msg: &SessionAnnouncement) -> Vec<u8> {
248    let v = 1u8;
249    let (a, originating_source): (u8, &[u8]) = match msg.originating_source {
250        IpAddr::V4(addr) => (0u8, &addr.octets()),
251        IpAddr::V6(addr) => (1u8, &addr.octets()),
252    };
253    let r = 0u8;
254    let t = if msg.deletion { 1u8 } else { 0u8 };
255    let e = if msg.encrypted { 1u8 } else { 0u8 };
256    let c = if msg.compressed { 1u8 } else { 0u8 };
257    let header = v << 5 | a << 4 | r << 3 | t << 2 | e << 1 | c;
258    let auth_len = msg
259        .auth_data
260        .as_ref()
261        .map(|d| d.as_bytes().len())
262        .unwrap_or(0) as u8;
263    let msg_id_hash = msg.msg_id_hash.to_be_bytes();
264
265    let mut data = Vec::new();
266    data.push(header);
267    data.push(auth_len);
268    data.extend_from_slice(&msg_id_hash);
269    data.extend_from_slice(originating_source);
270    if let Some(auth_data) = &msg.auth_data {
271        data.extend_from_slice(auth_data.as_bytes());
272    }
273    if let Some(payload_type) = &msg.payload_type {
274        data.extend_from_slice(payload_type.as_bytes());
275        data.push(b'\0');
276    }
277    data.extend_from_slice(msg.sdp.marshal().as_bytes());
278
279    data
280}
281
282fn sdp_hash(sdp: &SessionDescription) -> u16 {
283    murmur3_32(&mut Cursor::new(sdp.marshal()), *HASH_SEED).unwrap_or(0) as u16
284}
285
286async fn create_socket() -> SapResult<UdpSocket> {
287    let multicast_addr: Ipv4Addr = DEFAULT_MULTICAST_ADDRESS.parse()?;
288    let local_ip = Ipv4Addr::UNSPECIFIED;
289    let local_addr = SocketAddr::new(IpAddr::V4(local_ip), DEFAULT_SAP_PORT);
290
291    let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
292    socket.set_reuse_address(true)?;
293    socket.set_reuse_port(true)?;
294    socket.set_nonblocking(true)?;
295    socket.bind(&SockAddr::from(local_addr))?;
296    socket.join_multicast_v4(&multicast_addr, &local_ip)?;
297
298    let tokio_socket = UdpSocket::from_std(socket.into())?;
299
300    Ok(tokio_socket)
301}
302
#[cfg(test)]
mod tests {

    use super::*;

    /// Sample session description shared by the tests.
    const SAMPLE_SDP: &str = "v=0
o=- 123456 123458 IN IP4 10.0.1.2
s=My sample flow
i=4 channels: c1, c2, c3, c4
t=0 0
a=recvonly
m=audio 5004 RTP/AVP 98
c=IN IP4 239.69.11.44/32
a=rtpmap:98 L24/48000/4
a=ptime:1
a=ts-refclk:ptp=IEEE1588-2008:00-11-22-FF-FE-33-44-55:0
a=mediaclk:direct=0
";

    /// Parses [`SAMPLE_SDP`].
    fn sample_sdp() -> SessionDescription {
        SessionDescription::unmarshal(&mut Cursor::new(SAMPLE_SDP)).unwrap()
    }

    #[test]
    fn sdp_gets_hashed_correctly() {
        assert!(sdp_hash(&sample_sdp()) != 0);
    }

    #[test]
    fn encode_decode_roundtrip_is_successful() {
        let sa = SessionAnnouncement {
            auth_data: None,
            payload_type: None,
            compressed: false,
            deletion: true,
            encrypted: false,
            msg_id_hash: 1234,
            originating_source: "127.0.0.1".parse().unwrap(),
            sdp: sample_sdp(),
        };

        let sa_msg = encode_sap(&sa);

        let decoded = decode_sap(&sa_msg).unwrap();

        assert_eq!(sa.auth_data, decoded.auth_data);
        assert_eq!(sa.compressed, decoded.compressed);
        assert_eq!(sa.deletion, decoded.deletion);
        assert_eq!(sa.encrypted, decoded.encrypted);
        assert_eq!(sa.msg_id_hash, decoded.msg_id_hash);
        assert_eq!(sa.originating_source, decoded.originating_source);
        assert_eq!(sa.payload_type, decoded.payload_type);
        // Check the *decoded* SDP; comparing `sa.sdp` would never exercise the payload.
        assert_eq!(decoded.sdp.marshal().replace('\r', ""), SAMPLE_SDP);
    }

    #[test]
    fn announcement_with_payload_type_roundtrips() {
        let sa = SessionAnnouncement::new(sample_sdp()).unwrap();

        let decoded = decode_sap(&encode_sap(&sa)).unwrap();

        assert_eq!(decoded.payload_type.as_deref(), Some(DEFAULT_PAYLOAD_TYPE));
        assert_eq!(
            decoded.originating_source,
            IpAddr::V4(Ipv4Addr::new(10, 0, 1, 2))
        );
        assert_eq!(decoded.msg_id_hash, sa.msg_id_hash);
        assert!(!decoded.deletion);
        assert_eq!(decoded.sdp.marshal(), sa.sdp.marshal());
    }
}