1mod builder;
2mod session;
3mod state;
4
5use std::{io, sync::Arc};
6
7use futures::{channel::mpsc, prelude::*};
8use srt_protocol::settings::ConnInitSettings;
9use tokio::{net::UdpSocket, sync::oneshot, task::JoinHandle};
10
11use crate::net::bind_socket;
12
13use super::{net::PacketSocket, options::*, watch};
14
15pub use builder::SrtListenerBuilder;
16pub use session::ConnectionRequest;
17pub use srt_protocol::statistics::ListenerStatistics;
18
/// Handle to a running SRT listener.
///
/// Created by [`SrtListener::bind`] / [`SrtListener::bind_with_socket`] (usually via
/// [`SrtListener::builder`]). Incoming connection requests are not delivered through
/// this handle but through the [`SrtIncoming`] returned alongside it.
#[derive(Debug)]
pub struct SrtListener {
    // Connection-initialization settings derived from the socket options at bind time.
    settings: ConnInitSettings,
    // Receiving side of the statistics channel fed by the background task.
    statistics_receiver: watch::Receiver<ListenerStatistics>,
    // Shutdown signal for the background task; taken (left `None`) when shutdown is requested.
    close_req: Option<oneshot::Sender<()>>,
    // Handle of the spawned task that runs the listener state loop.
    task: JoinHandle<()>,
}
26
/// Source of incoming connection requests for an [`SrtListener`].
///
/// Obtained together with the listener from [`SrtListener::bind`] /
/// [`SrtListener::bind_with_socket`].
#[derive(Debug)]
pub struct SrtIncoming {
    // Receiving half of the bounded channel that the listener's background task
    // pushes connection requests into.
    request_receiver: mpsc::Receiver<ConnectionRequest>,
}
31
32impl SrtListener {
33 pub fn builder() -> SrtListenerBuilder {
34 SrtListenerBuilder::default()
35 }
36
37 pub async fn bind(options: Valid<ListenerOptions>) -> Result<(Self, SrtIncoming), io::Error> {
38 let socket = bind_socket(&options.socket).await?;
39 Self::bind_with_socket(options, socket).await
40 }
41
42 pub async fn bind_with_socket(
43 options: Valid<ListenerOptions>,
44 socket: UdpSocket,
45 ) -> Result<(Self, SrtIncoming), io::Error> {
46 use state::SrtListenerState;
47 let socket_options = options.into_value().socket;
48 let local_address = socket.local_addr()?;
49 let socket = PacketSocket::from_socket(Arc::new(socket), 1024 * 1024);
50 let settings = ConnInitSettings::from(socket_options);
51 let (close_req, close_resp) = oneshot::channel();
52 let (request_sender, request_receiver) = mpsc::channel(100);
53 let (statistics_sender, statistics_receiver) = watch::channel();
54 let state = SrtListenerState::new(
55 socket,
56 local_address,
57 settings.clone(),
58 request_sender,
59 statistics_sender,
60 close_resp,
61 );
62 let task = tokio::spawn(async move {
63 state.run_loop().await;
64 });
65 Ok((
66 Self {
67 settings,
68 statistics_receiver,
69 close_req: Some(close_req),
70 task,
71 },
72 SrtIncoming { request_receiver },
73 ))
74 }
75
76 pub fn settings(&self) -> &ConnInitSettings {
77 &self.settings
78 }
79
80 pub fn statistics(&mut self) -> &mut (impl Stream<Item = ListenerStatistics> + Clone) {
81 &mut self.statistics_receiver
82 }
83
84 pub async fn close(&mut self) {
85 let _ = self.close_req.take().unwrap().send(());
86 (&mut self.task).await.unwrap();
87 }
88}
89
90impl SrtIncoming {
91 pub fn incoming(&mut self) -> &mut impl Stream<Item = ConnectionRequest> {
92 &mut self.request_receiver
93 }
94}
95
96impl Drop for SrtListener {
97 fn drop(&mut self) {
98 }
100}
101
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use anyhow::Result;
    use bytes::Bytes;
    use futures::{channel::oneshot, future::join_all, prelude::*};
    use log::{debug, info};

    use crate::{access::*, SrtSocket};

    use super::*;

    /// Ten callers connect to a plaintext listener. Odd-numbered callers request the
    /// "reject" stream id and must be refused. The others must receive one "hello"
    /// message followed by end-of-stream.
    #[tokio::test]
    async fn accept_reject() -> Result<()> {
        #[derive(Debug)]
        enum Select {
            Connection(Option<ConnectionRequest>),
            Statistics(Option<ListenerStatistics>),
            Finished,
        }

        let _ = pretty_env_logger::try_init();

        let (finished_send, finished_recv) = oneshot::channel();

        let listener = tokio::spawn(async {
            let (mut server, mut incoming) =
                SrtListener::builder().bind("127.0.0.1:4001").await.unwrap();
            let mut statistics = server.statistics().clone().fuse();

            let mut incoming = incoming.incoming().fuse();
            let mut fused_finish = finished_recv.fuse();
            loop {
                let selection = futures::select!(
                    request = incoming.next() => Select::Connection(request),
                    stats = statistics.next() => Select::Statistics(stats),
                    _ = fused_finish => Select::Finished,
                );
                match selection {
                    Select::Connection(Some(request)) => {
                        let stream_id = request.stream_id().unwrap();
                        if stream_id.eq(&"reject".parse().unwrap()) {
                            request.reject(RejectReason::User(42)).await.unwrap();
                        } else {
                            let mut sender = request.accept(None).await.unwrap();
                            let mut stream = stream::iter(
                                Some(Ok((Instant::now(), Bytes::from("hello")))).into_iter(),
                            );
                            tokio::spawn(async move {
                                sender.send_all(&mut stream).await.unwrap();
                                sender.close().await.unwrap();
                                info!("Sent");
                            });
                        }
                    }
                    Select::Statistics(Some(stats)) => debug!("{:?}", stats),
                    _ => {
                        break;
                    }
                }
            }
        });

        let mut join_handles = vec![];
        for i in 0..10 {
            join_handles.push(tokio::spawn(async move {
                info!("Calling: {}", i);
                let address = "127.0.0.1:4001";
                if i % 2 > 0 {
                    let result = SrtSocket::builder().call(address, Some("reject")).await;
                    assert!(result.is_err());
                    debug!("Rejected: {}", i);
                } else {
                    let stream_id = format!("{i}");
                    let mut receiver = SrtSocket::builder()
                        .call(address, Some(&stream_id))
                        .await
                        .unwrap();
                    info!("Accepted: {}", i);
                    let first = receiver.next().await;
                    assert_eq!(first.unwrap().unwrap().1, "hello");
                    let second = receiver.next().await;
                    assert!(second.is_none());
                    info!("Received: {}", i);
                }
            }));
        }

        // Propagate panics (e.g. failed assertions) from the caller tasks. Discarding
        // the `join_all` output would let the test pass even when a caller failed.
        for result in join_all(join_handles).await {
            result?;
        }
        info!("all finished");
        finished_send.send(()).unwrap();
        listener.await?;
        Ok(())
    }

    /// Same as `accept_reject`, but the listener requires encryption. Even-numbered
    /// callers ask for the "reject" stream id without a passphrase and must be
    /// refused. Odd-numbered callers use the matching passphrase and must receive
    /// "hello".
    #[tokio::test]
    async fn accept_reject_encryption() -> Result<()> {
        #[derive(Debug)]
        enum Select {
            Connection(Option<ConnectionRequest>),
            Statistics(Option<ListenerStatistics>),
            Finished,
        }

        let _ = pretty_env_logger::try_init();

        let (finished_send, finished_recv) = oneshot::channel();

        let listener = tokio::spawn(async {
            let (mut server, mut incoming) = SrtListener::builder()
                .encryption(0, "super secret passcode")
                .bind("127.0.0.1:4002")
                .await
                .unwrap();
            let mut statistics = server.statistics().clone().fuse();

            let mut incoming = incoming.incoming().fuse();
            let mut fused_finish = finished_recv.fuse();
            loop {
                let selection = futures::select!(
                    request = incoming.next() => Select::Connection(request),
                    stats = statistics.next() => Select::Statistics(stats),
                    _ = fused_finish => Select::Finished,
                );
                match selection {
                    Select::Connection(Some(request)) => {
                        let stream_id = request.stream_id().expect("stream_id");
                        if stream_id.eq(&"reject".parse().unwrap()) {
                            request
                                .reject(RejectReason::User(42))
                                .await
                                .expect("reject");
                        } else {
                            let mut sender = request.accept(None).await.expect("accept");
                            let mut stream = stream::iter(
                                Some(Ok((Instant::now(), Bytes::from("hello")))).into_iter(),
                            );
                            tokio::spawn(async move {
                                sender.send_all(&mut stream).await.expect("send_all");
                                sender.close().await.expect("close");
                                info!("Sent");
                            });
                        }
                    }
                    Select::Statistics(Some(stats)) => debug!("{:?}", stats),
                    _ => {
                        break;
                    }
                }
            }
        });

        let mut join_handles = vec![];
        for i in 0..10 {
            join_handles.push(tokio::spawn(async move {
                info!("Calling: {}", i);
                let address = "127.0.0.1:4002";
                if i % 2 == 0 {
                    let result = SrtSocket::builder().call(address, Some("reject")).await;
                    assert!(result.is_err());
                    info!("Rejected: {}", i);
                } else {
                    let stream_id = format!("{i}");
                    let mut receiver = SrtSocket::builder()
                        .encryption(0, "super secret passcode")
                        .call(address, Some(&stream_id))
                        .await
                        .expect("call");
                    info!("Accepted: {}", i);
                    let first = receiver.next().await;
                    assert_eq!(first.expect("next error").expect("next no data").1, "hello");
                    let second = receiver.next().await;
                    assert!(second.is_none());
                    info!("Received: {}", i);
                }
            }));
        }

        // Propagate panics from the caller tasks instead of silently discarding them.
        for result in join_all(join_handles).await {
            result?;
        }
        info!("all finished");
        finished_send.send(()).unwrap();
        listener.await?;
        Ok(())
    }

    /// Three receivers connect to one listener that streams data for ~7 seconds.
    /// Each receiver must stay connected (no premature timeout) for more than 5
    /// seconds.
    #[tokio::test]
    async fn multiplex_timeout() {
        use bytes::Bytes;
        use futures::{stream, SinkExt, StreamExt};
        use log::info;
        use tokio::time::sleep;

        use srt_protocol::options::*;

        async fn run_listener() -> Result<(), io::Error> {
            let port = 4444;
            let (_binding, mut incoming) = SrtListener::builder()
                .with(Sender {
                    drop_delay: Duration::from_secs(20),
                    peer_latency: Duration::from_secs(1),
                    buffer_size: ByteCount(8192 * 100),
                    ..Default::default()
                })
                .bind("127.0.0.1:4444")
                .await
                .unwrap();

            info!("SRT Multiplex Server is listening on port: {}", port);
            while let Some(request) = incoming.incoming().next().await {
                let mut srt_socket = request.accept(None).await.unwrap();

                tokio::spawn(async move {
                    let client_desc = format!(
                        "(ip_port: {}, sockid: {})",
                        srt_socket.settings().remote,
                        srt_socket.settings().remote_sockid.0
                    );

                    info!("New client connected: {}", client_desc);

                    let longer_than_peer_timeout = Duration::from_secs(7);
                    let start = Instant::now();
                    let mut stream = stream::unfold(0, |count| async move {
                        let res = Ok((Instant::now(), Bytes::copy_from_slice(&[0; 1316])));
                        sleep(Duration::from_millis(5)).await;
                        if start.elapsed() > longer_than_peer_timeout {
                            return None;
                        }
                        Some((res, count))
                    })
                    .boxed();

                    if let Err(e) = srt_socket.send_all(&mut stream).await {
                        info!("Send to client: {} error: {:?}", client_desc, e);
                    }
                    info!("Client {} disconnected", client_desc);

                    start.elapsed().as_secs() as i32
                });
            }
            Ok(())
        }

        async fn run_receiver(id: u32) -> Result<i32, io::Error> {
            let mut srt_socket = SrtSocket::builder()
                .with(Receiver {
                    buffer_size: ByteCount(8192 * 100),
                    latency: Duration::from_secs(1),
                    ..Default::default()
                })
                .call("127.0.0.1:4444", None)
                .await
                .unwrap();

            info!("Client {} connection opened", id);

            // Count packets as they arrive so the final log reports the true total
            // (starting at 1 over-reported it by one).
            let mut count = 0;
            let start = Instant::now();
            while let Some((_instant, _bytes)) = srt_socket.try_next().await? {
                count += 1;
                if count % 200 == 0 {
                    info!("{} received {:?} packets", id, count);
                }
            }
            info!("Client {} received {:?} packets", id, count);
            info!("Client {} connection closed", id);

            Ok(start.elapsed().as_secs() as i32)
        }

        let _listener_handle = tokio::spawn(run_listener());
        let join_handles = [
            tokio::spawn(run_receiver(1)),
            tokio::spawn(run_receiver(2)),
            tokio::spawn(run_receiver(3)),
        ];
        let min_elapsed_seconds = join_all(join_handles)
            .await
            .into_iter()
            .map(|r| r.unwrap().unwrap_or_default())
            .min()
            .unwrap_or_default();

        assert!(min_elapsed_seconds > 5);
    }
}
393}