// srt_tokio/listener/mod.rs

1mod builder;
2mod session;
3mod state;
4
5use std::{io, sync::Arc};
6
7use futures::{channel::mpsc, prelude::*};
8use srt_protocol::settings::ConnInitSettings;
9use tokio::{net::UdpSocket, sync::oneshot, task::JoinHandle};
10
11use crate::net::bind_socket;
12
13use super::{net::PacketSocket, options::*, watch};
14
15pub use builder::SrtListenerBuilder;
16pub use session::ConnectionRequest;
17pub use srt_protocol::statistics::ListenerStatistics;
18
/// Handle to a running SRT listener.
///
/// Created by [`SrtListener::bind`] or [`SrtListener::bind_with_socket`], which also
/// return the [`SrtIncoming`] side used to receive connection requests.
#[derive(Debug)]
pub struct SrtListener {
    /// Connection-initialisation settings derived from the socket options the listener
    /// was bound with.
    settings: ConnInitSettings,
    /// Receives the statistics published by the background listener task.
    statistics_receiver: watch::Receiver<ListenerStatistics>,
    /// One-shot shutdown signal for the background task; `None` once `close` has taken it.
    close_req: Option<oneshot::Sender<()>>,
    /// Handle of the spawned background task running the listener state loop.
    task: JoinHandle<()>,
}
26
/// Receiving side of an [`SrtListener`]: yields incoming connection requests.
#[derive(Debug)]
pub struct SrtIncoming {
    /// Receiving end of the bounded channel the listener task pushes requests into.
    request_receiver: mpsc::Receiver<ConnectionRequest>,
}
31
32impl SrtListener {
33    pub fn builder() -> SrtListenerBuilder {
34        SrtListenerBuilder::default()
35    }
36
37    pub async fn bind(options: Valid<ListenerOptions>) -> Result<(Self, SrtIncoming), io::Error> {
38        let socket = bind_socket(&options.socket).await?;
39        Self::bind_with_socket(options, socket).await
40    }
41
42    pub async fn bind_with_socket(
43        options: Valid<ListenerOptions>,
44        socket: UdpSocket,
45    ) -> Result<(Self, SrtIncoming), io::Error> {
46        use state::SrtListenerState;
47        let socket_options = options.into_value().socket;
48        let local_address = socket.local_addr()?;
49        let socket = PacketSocket::from_socket(Arc::new(socket), 1024 * 1024);
50        let settings = ConnInitSettings::from(socket_options);
51        let (close_req, close_resp) = oneshot::channel();
52        let (request_sender, request_receiver) = mpsc::channel(100);
53        let (statistics_sender, statistics_receiver) = watch::channel();
54        let state = SrtListenerState::new(
55            socket,
56            local_address,
57            settings.clone(),
58            request_sender,
59            statistics_sender,
60            close_resp,
61        );
62        let task = tokio::spawn(async move {
63            state.run_loop().await;
64        });
65        Ok((
66            Self {
67                settings,
68                statistics_receiver,
69                close_req: Some(close_req),
70                task,
71            },
72            SrtIncoming { request_receiver },
73        ))
74    }
75
76    pub fn settings(&self) -> &ConnInitSettings {
77        &self.settings
78    }
79
80    pub fn statistics(&mut self) -> &mut (impl Stream<Item = ListenerStatistics> + Clone) {
81        &mut self.statistics_receiver
82    }
83
84    pub async fn close(&mut self) {
85        let _ = self.close_req.take().unwrap().send(());
86        (&mut self.task).await.unwrap();
87    }
88}
89
90impl SrtIncoming {
91    pub fn incoming(&mut self) -> &mut impl Stream<Item = ConnectionRequest> {
92        &mut self.request_receiver
93    }
94}
95
impl Drop for SrtListener {
    /// Performs no explicit shutdown; prefer calling [`SrtListener::close`] first.
    ///
    /// After this runs, the fields are dropped normally:
    /// - Dropping `close_req` (if `close` never took it) closes the oneshot channel
    ///   without sending a value.
    /// - Dropping `task` (a Tokio `JoinHandle`) detaches the background task rather
    ///   than aborting it.
    ///
    /// NOTE(review): whether the listener loop treats the closed channel as a shutdown
    /// request is decided in `state::SrtListenerState::run_loop` — confirm there.
    fn drop(&mut self) {
        // TODO: we probably need to use a std::sync primitive to block until closed
    }
}
101
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use anyhow::Result;
    use bytes::Bytes;
    use futures::{channel::oneshot, future::join_all, prelude::*};
    use log::{debug, info};

    use crate::{access::*, SrtSocket};

    use super::*;

    /// Events multiplexed by the `accept_reject*` listener loops.
    #[derive(Debug)]
    enum Select {
        Connection(Option<ConnectionRequest>),
        Statistics(Option<ListenerStatistics>),
        Finished,
    }

    /// Drives a bound listener until `finished` fires or one of its streams ends.
    ///
    /// Behaviour per request:
    /// - Stream id `"reject"`: rejected with `RejectReason::User(42)`.
    /// - Any other stream id: accepted, sent a single `"hello"` message, then closed.
    ///
    /// Statistics updates are logged at debug level.
    async fn serve_accept_reject(
        mut server: SrtListener,
        mut incoming: SrtIncoming,
        finished: oneshot::Receiver<()>,
    ) {
        let mut statistics = server.statistics().clone().fuse();

        let mut incoming = incoming.incoming().fuse();
        let mut fused_finish = finished.fuse();
        loop {
            let selection = futures::select!(
                request = incoming.next() => Select::Connection(request),
                stats = statistics.next() => Select::Statistics(stats),
                _ = fused_finish => Select::Finished,
            );
            match selection {
                Select::Connection(Some(request)) => {
                    let stream_id = request.stream_id().expect("stream_id");
                    if stream_id.eq(&"reject".parse().unwrap()) {
                        request
                            .reject(RejectReason::User(42))
                            .await
                            .expect("reject");
                    } else {
                        let mut sender = request.accept(None).await.expect("accept");
                        let mut stream = stream::iter(
                            Some(Ok((Instant::now(), Bytes::from("hello")))).into_iter(),
                        );
                        tokio::spawn(async move {
                            sender.send_all(&mut stream).await.expect("send_all");
                            sender.close().await.expect("close");
                            info!("Sent");
                        });
                    }
                }
                Select::Statistics(Some(stats)) => debug!("{:?}", stats),
                // The finish signal fired, or one of the streams terminated.
                Select::Connection(None) | Select::Statistics(None) | Select::Finished => {
                    break;
                }
            }
        }
    }

    #[tokio::test]
    async fn accept_reject() -> Result<()> {
        let _ = pretty_env_logger::try_init();

        let (finished_send, finished_recv) = oneshot::channel();

        let listener = tokio::spawn(async move {
            let (server, incoming) = SrtListener::builder().bind("127.0.0.1:4001").await.unwrap();
            serve_accept_reject(server, incoming, finished_recv).await;
        });

        // connect 10 clients to it
        let mut join_handles = vec![];
        for i in 0..10 {
            join_handles.push(tokio::spawn(async move {
                info!("Calling: {}", i);
                let address = "127.0.0.1:4001";
                if i % 2 > 0 {
                    let result = SrtSocket::builder().call(address, Some("reject")).await;
                    assert!(result.is_err());
                    debug!("Rejected: {}", i);
                } else {
                    let stream_id = format!("{i}");
                    let mut receiver = SrtSocket::builder()
                        .call(address, Some(&stream_id))
                        .await
                        .unwrap();
                    info!("Accepted: {}", i);
                    let first = receiver.next().await;
                    assert_eq!(first.unwrap().unwrap().1, "hello");
                    let second = receiver.next().await;
                    assert!(second.is_none());
                    info!("Received: {}", i);
                }
            }));
        }

        // Wait for every client. Propagate any client panic (e.g. a failed assertion)
        // instead of silently discarding the JoinError.
        for result in join_all(join_handles).await {
            result?;
        }
        info!("all finished");
        // close the multiplex server when all is done
        finished_send.send(()).unwrap();
        listener.await?;
        Ok(())
    }

    #[tokio::test]
    async fn accept_reject_encryption() -> Result<()> {
        let _ = pretty_env_logger::try_init();

        let (finished_send, finished_recv) = oneshot::channel();

        let listener = tokio::spawn(async move {
            let (server, incoming) = SrtListener::builder()
                .encryption(0, "super secret passcode")
                .bind("127.0.0.1:4002")
                .await
                .unwrap();
            serve_accept_reject(server, incoming, finished_recv).await;
        });

        // connect 10 clients to it
        let mut join_handles = vec![];
        for i in 0..10 {
            join_handles.push(tokio::spawn(async move {
                info!("Calling: {}", i);
                let address = "127.0.0.1:4002";
                if i % 2 == 0 {
                    let result = SrtSocket::builder().call(address, Some("reject")).await;
                    assert!(result.is_err());
                    info!("Rejected: {}", i);
                } else {
                    let stream_id = format!("{i}");
                    let mut receiver = SrtSocket::builder()
                        .encryption(0, "super secret passcode")
                        .call(address, Some(&stream_id))
                        .await
                        .expect("call");
                    info!("Accepted: {}", i);
                    let first = receiver.next().await;
                    assert_eq!(first.expect("next error").expect("next no data").1, "hello");
                    let second = receiver.next().await;
                    assert!(second.is_none());
                    info!("Received: {}", i);
                }
            }));
        }

        // Wait for every client. Propagate any client panic (e.g. a failed assertion)
        // instead of silently discarding the JoinError.
        for result in join_all(join_handles).await {
            result?;
        }
        info!("all finished");
        // close the multiplex server when all is done
        finished_send.send(()).unwrap();
        listener.await?;
        Ok(())
    }

    #[tokio::test]
    async fn multiplex_timeout() {
        use bytes::Bytes;
        use futures::{stream, SinkExt, StreamExt};
        use log::info;
        use tokio::time::sleep;

        use srt_protocol::options::*;

        async fn run_listener() -> Result<(), io::Error> {
            let port = 4444;
            let (_binding, mut incoming) = SrtListener::builder()
                .with(Sender {
                    drop_delay: Duration::from_secs(20),
                    peer_latency: Duration::from_secs(1),
                    buffer_size: ByteCount(8192 * 100),
                    ..Default::default()
                })
                .bind("127.0.0.1:4444")
                .await
                .unwrap();

            info!("SRT Multiplex Server is listening on port: {}", port);
            while let Some(request) = incoming.incoming().next().await {
                let mut srt_socket = request.accept(None).await.unwrap();

                tokio::spawn(async move {
                    let client_desc = format!(
                        "(ip_port: {}, sockid: {})",
                        srt_socket.settings().remote,
                        srt_socket.settings().remote_sockid.0
                    );

                    info!("New client connected: {}", client_desc);

                    let longer_than_peer_timeout = Duration::from_secs(7);
                    let start = Instant::now();
                    let mut stream = stream::unfold(0, |count| async move {
                        let res = Ok((Instant::now(), Bytes::copy_from_slice(&[0; 1316])));
                        sleep(Duration::from_millis(5)).await;
                        if start.elapsed() > longer_than_peer_timeout {
                            return None;
                        }
                        Some((res, count))
                    })
                    .boxed();

                    if let Err(e) = srt_socket.send_all(&mut stream).await {
                        info!("Send to client: {} error: {:?}", client_desc, e);
                    }
                    info!("Client {} disconnected", client_desc);

                    start.elapsed().as_secs() as i32
                });
            }
            Ok(())
        }

        async fn run_receiver(id: u32) -> Result<i32, io::Error> {
            let mut srt_socket = SrtSocket::builder()
                .with(Receiver {
                    buffer_size: ByteCount(8192 * 100),
                    latency: Duration::from_secs(1),
                    ..Default::default()
                })
                .call("127.0.0.1:4444", None)
                .await
                .unwrap();

            info!("Client {} connection opened", id);

            // `count` is the number of packets received so far.
            let mut count = 0;
            let start = Instant::now();
            while let Some((_instant, _bytes)) = srt_socket.try_next().await? {
                count += 1;
                if count % 200 == 0 {
                    info!("{} received {:?} packets", id, count);
                }
            }
            info!("Client {} received {:?} packets", id, count);
            info!("Client {} connection closed", id);

            Ok(start.elapsed().as_secs() as i32)
        }

        let _listener_handle = tokio::spawn(run_listener());
        let join_handles = [
            tokio::spawn(run_receiver(1)),
            tokio::spawn(run_receiver(2)),
            tokio::spawn(run_receiver(3)),
        ];
        let min_elapsed_seconds = join_all(join_handles)
            .await
            .into_iter()
            .map(|r| r.unwrap().unwrap_or_default())
            .min()
            .unwrap_or_default();

        // clients should have still received data well past the default peer timeout
        assert!(min_elapsed_seconds > 5);
    }
}