// turn_server/server/provider/mod.rs

1pub mod tcp;
2pub mod udp;
3
4use std::{net::SocketAddr, task::Poll, time::Duration};
5
6use anyhow::Result;
7use tokio::time::interval;
8
9use crate::{
10    Service,
11    config::Ssl,
12    server::{
13        Exchanger,
14        memory_pool::{Buffer, MemoryPool},
15    },
16    service::{Transport, session::Identifier},
17    statistics::{Statistics, Stats},
18};
19
20pub trait ProviderStream: Send + 'static {
21    fn read(&mut self) -> impl Future<Output = Result<Buffer>> + Send;
22    fn write(&mut self, buffer: &[u8]) -> impl Future<Output = Result<()>> + Send;
23    fn close(&mut self) -> impl Future<Output = ()> + Send;
24}
25
26#[allow(unused)]
27pub struct ServerOptions {
28    pub transport: Transport,
29    pub idle_timeout: u32,
30    pub listen: SocketAddr,
31    pub external: SocketAddr,
32    pub ssl: Option<Ssl>,
33    pub mtu: usize,
34}
35
36pub trait ProviderServer: Sized + Send {
37    type Stream: ProviderStream;
38
39    /// Bind the server to the specified address.
40    fn bind(options: &ServerOptions) -> impl Future<Output = Result<Self>> + Send;
41
42    /// Accept a new connection.
43    fn accept(&mut self) -> impl Future<Output = Result<Poll<(Self::Stream, SocketAddr)>>> + Send;
44
45    /// Get the local address of the listener.
46    fn local_addr(&self) -> Result<SocketAddr>;
47
48    /// Start the server.
49    fn start(
50        options: ServerOptions,
51        service: Service,
52        statistics: Statistics,
53        exchanger: Exchanger,
54    ) -> impl Future<Output = Result<()>> + Send {
55        let transport = options.transport;
56        let idle_timeout = options.idle_timeout as u64;
57
58        async move {
59            let mut listener = Self::bind(&options).await?;
60            let local_addr = listener.local_addr()?;
61
62            log::info!(
63                "server listening: listen={}, external={}, local addr={local_addr}, transport={transport:?}",
64                options.listen,
65                options.external,
66            );
67
68            while let Ok(poll) = listener.accept().await {
69                let Poll::Ready((mut socket, address)) = poll else {
70                    continue;
71                };
72
73                let id = Identifier {
74                    source: address,
75                    interface: local_addr,
76                    external: options.external,
77                    transport: options.transport,
78                };
79
80                let mut router = service.make_router(id);
81                let mut receiver = exchanger.get_receiver(id);
82                let reporter = statistics.get_reporter(transport);
83
84                let service = service.clone();
85                let exchanger = exchanger.clone();
86
87                tokio::spawn(async move {
88                    let mut interval = interval(Duration::from_secs(1));
89                    let mut read_delay = 0;
90
91                    loop {
92                        let mut response_buffer = MemoryPool::acquire();
93
94                        tokio::select! {
95                            Ok(buffer) = socket.read() => {
96                                read_delay = 0;
97
98                                if let Ok(Some(res)) = router.route(&buffer, &mut response_buffer).await
99                                {
100
101                                    if let Some(relay) = res.relay {
102                                        // The channel data needs to be aligned in multiples of 4 in
103                                        // tcp. If the channel data is forwarded to tcp, the alignment
104                                        // bit needs to be filled, because if the channel data comes
105                                        // from udp, it is not guaranteed to be aligned and needs to be
106                                        // checked.
107                                        if relay.transport == Transport::Tcp && res.method.is_none() {
108                                            let pad = response_buffer.len() % 4;
109                                            if pad > 0 {
110                                                response_buffer.extend_from_slice(&[0u8; 8][..(4 - pad)]);
111                                            }
112                                        }
113
114                                        exchanger.send(&relay, response_buffer);
115                                    } else {
116                                        if socket.write(&response_buffer).await.is_err() {
117                                            break;
118                                        }
119
120                                        reporter.send(
121                                            &id,
122                                            &[Stats::SendBytes(response_buffer.len()), Stats::SendPkts(1)],
123                                        );
124
125                                        if let Some(method) = res.method && method.is_error() {
126                                            reporter.send(&id, &[Stats::ErrorPkts(1)]);
127                                        }
128                                    }
129                                }
130                            }
131                            Some(bytes) = receiver.recv() => {
132                                if socket.write(&bytes).await.is_err() {
133                                    break;
134                                } else {
135                                    reporter.send(&id, &[Stats::SendBytes(bytes.len()), Stats::SendPkts(1)]);
136                                }
137                            }
138                            _ = interval.tick() => {
139                                read_delay += 1;
140
141                                if read_delay >= idle_timeout {
142                                    break;
143                                }
144                            }
145                            else => {
146                                break;
147                            }
148                        }
149                    }
150
151                    // close the socket
152                    socket.close().await;
153
154                    // When the socket connection is closed, the procedure to close the session is
155                    // process directly once, avoiding the connection being disconnected
156                    // directly without going through the closing
157                    // process.
158                    service.get_session_manager().refresh(&id, 0);
159
160                    exchanger.remove(&id);
161
162                    log::info!(
163                        "socket disconnect: addr={address:?}, interface={local_addr:?}, transport={transport:?}"
164                    );
165                });
166            }
167
168            log::error!("server shutdown: interface={local_addr:?}, transport={transport:?}");
169
170            Ok(())
171        }
172    }
173}