webrtc_srtp/session/
mod.rs

1#[cfg(test)]
2mod session_rtcp_test;
3#[cfg(test)]
4mod session_rtp_test;
5
6use std::collections::{HashMap, HashSet};
7use std::marker::{Send, Sync};
8use std::sync::Arc;
9
10use bytes::Bytes;
11use tokio::sync::{mpsc, Mutex};
12use util::conn::Conn;
13use util::marshal::*;
14
15use crate::config::*;
16use crate::context::*;
17use crate::error::{Error, Result};
18use crate::option::*;
19use crate::stream::*;
20
/// Window size passed to `srtp_replay_protection` for the remote (decrypting) context
/// when the session config does not provide its own remote RTP options.
const DEFAULT_SESSION_SRTP_REPLAY_PROTECTION_WINDOW: usize = 64;
/// Window size passed to `srtcp_replay_protection` for the remote (decrypting) context
/// when the session config does not provide its own remote RTCP options.
const DEFAULT_SESSION_SRTCP_REPLAY_PROTECTION_WINDOW: usize = 64;
23
24/// Session implements io.ReadWriteCloser and provides a bi-directional SRTP session
25/// SRTP itself does not have a design like this, but it is common in most applications
26/// for local/remote to each have their own keying material. This provides those patterns
27/// instead of making everyone re-implement
28pub struct Session {
29    local_context: Arc<Mutex<Context>>,
30    streams_map: Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
31    #[allow(clippy::type_complexity)]
32    new_stream_rx: Arc<Mutex<mpsc::Receiver<(Arc<Stream>, Option<rtp::header::Header>)>>>,
33    close_stream_tx: mpsc::Sender<u32>,
34    close_session_tx: mpsc::Sender<()>,
35    pub(crate) udp_tx: Arc<dyn Conn + Send + Sync>,
36    is_rtp: bool,
37}
38
39impl Session {
40    pub async fn new(
41        conn: Arc<dyn Conn + Send + Sync>,
42        config: Config,
43        is_rtp: bool,
44    ) -> Result<Self> {
45        let local_context = Context::new(
46            &config.keys.local_master_key,
47            &config.keys.local_master_salt,
48            config.profile,
49            config.local_rtp_options,
50            config.local_rtcp_options,
51        )?;
52
53        let mut remote_context = Context::new(
54            &config.keys.remote_master_key,
55            &config.keys.remote_master_salt,
56            config.profile,
57            if config.remote_rtp_options.is_none() {
58                Some(srtp_replay_protection(
59                    DEFAULT_SESSION_SRTP_REPLAY_PROTECTION_WINDOW,
60                ))
61            } else {
62                config.remote_rtp_options
63            },
64            if config.remote_rtcp_options.is_none() {
65                Some(srtcp_replay_protection(
66                    DEFAULT_SESSION_SRTCP_REPLAY_PROTECTION_WINDOW,
67                ))
68            } else {
69                config.remote_rtcp_options
70            },
71        )?;
72
73        let streams_map = Arc::new(Mutex::new(HashMap::new()));
74        let (mut new_stream_tx, new_stream_rx) = mpsc::channel(8);
75        let (close_stream_tx, mut close_stream_rx) = mpsc::channel(8);
76        let (close_session_tx, mut close_session_rx) = mpsc::channel(8);
77        let udp_tx = Arc::clone(&conn);
78        let udp_rx = Arc::clone(&conn);
79        let cloned_streams_map = Arc::clone(&streams_map);
80        let cloned_close_stream_tx = close_stream_tx.clone();
81
82        tokio::spawn(async move {
83            let mut buf = vec![0u8; 8192];
84
85            loop {
86                let incoming_stream = Session::incoming(
87                    &udp_rx,
88                    &mut buf,
89                    &cloned_streams_map,
90                    &cloned_close_stream_tx,
91                    &mut new_stream_tx,
92                    &mut remote_context,
93                    is_rtp,
94                );
95                let close_stream = close_stream_rx.recv();
96                let close_session = close_session_rx.recv();
97
98                tokio::select! {
99                    result = incoming_stream => match result{
100                        Ok(()) => {},
101                        Err(err) => log::info!("{}", err),
102                    },
103                    opt = close_stream => if let Some(ssrc) = opt {
104                        Session::close_stream(&cloned_streams_map, ssrc).await
105                    },
106                    _ = close_session => break
107                }
108            }
109        });
110
111        Ok(Session {
112            local_context: Arc::new(Mutex::new(local_context)),
113            streams_map,
114            new_stream_rx: Arc::new(Mutex::new(new_stream_rx)),
115            close_stream_tx,
116            close_session_tx,
117            udp_tx,
118            is_rtp,
119        })
120    }
121
122    async fn close_stream(streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>, ssrc: u32) {
123        let mut streams = streams_map.lock().await;
124        streams.remove(&ssrc);
125    }
126
127    async fn incoming(
128        udp_rx: &Arc<dyn Conn + Send + Sync>,
129        buf: &mut [u8],
130        streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
131        close_stream_tx: &mpsc::Sender<u32>,
132        new_stream_tx: &mut mpsc::Sender<(Arc<Stream>, Option<rtp::header::Header>)>,
133        remote_context: &mut Context,
134        is_rtp: bool,
135    ) -> Result<()> {
136        let n = udp_rx.recv(buf).await?;
137        if n == 0 {
138            return Err(Error::SessionEof);
139        }
140
141        let decrypted = if is_rtp {
142            remote_context.decrypt_rtp(&buf[0..n])?
143        } else {
144            remote_context.decrypt_rtcp(&buf[0..n])?
145        };
146
147        let mut buf = &decrypted[..];
148        let (ssrcs, header) = if is_rtp {
149            let header = rtp::header::Header::unmarshal(&mut buf)?;
150            (vec![header.ssrc], Some(header))
151        } else {
152            let pkts = rtcp::packet::unmarshal(&mut buf)?;
153            (destination_ssrc(&pkts), None)
154        };
155
156        for ssrc in ssrcs {
157            let (stream, is_new) =
158                Session::get_or_create_stream(streams_map, close_stream_tx.clone(), is_rtp, ssrc)
159                    .await;
160
161            if is_new {
162                log::trace!(
163                    "srtp session got new {} stream {}",
164                    if is_rtp { "rtp" } else { "rtcp" },
165                    ssrc
166                );
167                new_stream_tx
168                    .send((Arc::clone(&stream), header.clone()))
169                    .await?;
170            }
171
172            match stream.buffer.write(&decrypted).await {
173                Ok(_) => {}
174                Err(err) => {
175                    // Silently drop data when the buffer is full.
176                    if util::Error::ErrBufferFull != err {
177                        return Err(err.into());
178                    }
179                }
180            }
181        }
182
183        Ok(())
184    }
185
186    async fn get_or_create_stream(
187        streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
188        close_stream_tx: mpsc::Sender<u32>,
189        is_rtp: bool,
190        ssrc: u32,
191    ) -> (Arc<Stream>, bool) {
192        let mut streams = streams_map.lock().await;
193
194        if let Some(stream) = streams.get(&ssrc) {
195            (Arc::clone(stream), false)
196        } else {
197            let stream = Arc::new(Stream::new(ssrc, close_stream_tx, is_rtp));
198            streams.insert(ssrc, Arc::clone(&stream));
199            (stream, true)
200        }
201    }
202
203    /// open on the given SSRC to create a stream, it can be used
204    /// if you want a certain SSRC, but don't want to wait for Accept
205    pub async fn open(&self, ssrc: u32) -> Arc<Stream> {
206        let (stream, _) = Session::get_or_create_stream(
207            &self.streams_map,
208            self.close_stream_tx.clone(),
209            self.is_rtp,
210            ssrc,
211        )
212        .await;
213
214        stream
215    }
216
217    /// accept returns a stream to handle RTCP for a single SSRC
218    pub async fn accept(&self) -> Result<(Arc<Stream>, Option<rtp::header::Header>)> {
219        let mut new_stream_rx = self.new_stream_rx.lock().await;
220
221        new_stream_rx
222            .recv()
223            .await
224            .ok_or(Error::SessionSrtpAlreadyClosed)
225    }
226
227    pub async fn close(&self) -> Result<()> {
228        self.close_session_tx.send(()).await?;
229
230        Ok(())
231    }
232
233    pub async fn write(&self, buf: &Bytes, is_rtp: bool) -> Result<usize> {
234        if self.is_rtp != is_rtp {
235            return Err(Error::SessionRtpRtcpTypeMismatch);
236        }
237
238        let encrypted = {
239            let mut local_context = self.local_context.lock().await;
240
241            if is_rtp {
242                local_context.encrypt_rtp(buf)?
243            } else {
244                local_context.encrypt_rtcp(buf)?
245            }
246        };
247
248        Ok(self.udp_tx.send(&encrypted).await?)
249    }
250
251    pub async fn write_rtp(&self, pkt: &rtp::packet::Packet) -> Result<usize> {
252        let raw = pkt.marshal()?;
253        self.write(&raw, true).await
254    }
255
256    pub async fn write_rtcp(
257        &self,
258        pkt: &(dyn rtcp::packet::Packet + Send + Sync),
259    ) -> Result<usize> {
260        let raw = pkt.marshal()?;
261        self.write(&raw, false).await
262    }
263}
264
265/// create a list of Destination SSRCs
266/// that's a superset of all Destinations in the slice.
267fn destination_ssrc(pkts: &[Box<dyn rtcp::packet::Packet + Send + Sync>]) -> Vec<u32> {
268    let mut ssrc_set = HashSet::new();
269    for p in pkts {
270        for ssrc in p.destination_ssrc() {
271            ssrc_set.insert(ssrc);
272        }
273    }
274    ssrc_set.into_iter().collect()
275}