Skip to main content

turn_server/server/provider/
mod.rs

1pub mod tcp;
2pub mod udp;
3
4use std::{net::SocketAddr, task::Poll, time::Duration};
5
6use anyhow::Result;
7use tokio::time::interval;
8
9use crate::{
10    Service,
11    config::Ssl,
12    server::{Switch, buffer::Buffer},
13    service::{Transport, session::Identifier},
14    statistics::{Statistics, Stats},
15};
16
17pub trait ProviderStream: Send + 'static {
18    fn read(&mut self) -> impl Future<Output = Result<Buffer>> + Send;
19    fn write(&mut self, buffer: &[u8]) -> impl Future<Output = Result<()>> + Send;
20    fn close(&mut self) -> impl Future<Output = ()> + Send;
21}
22
23#[allow(unused)]
24pub struct ServerOptions {
25    pub transport: Transport,
26    pub idle_timeout: u32,
27    pub listen: SocketAddr,
28    pub external: SocketAddr,
29    pub ssl: Option<Ssl>,
30    pub mtu: usize,
31}
32
33pub trait ProviderServer: Sized + Send {
34    type Stream: ProviderStream;
35
36    /// Bind the server to the specified address.
37    fn bind(options: &ServerOptions) -> impl Future<Output = Result<Self>> + Send;
38
39    /// Accept a new connection.
40    fn accept(&mut self) -> impl Future<Output = Result<Poll<(Self::Stream, SocketAddr)>>> + Send;
41
42    /// Get the local address of the listener.
43    fn local_addr(&self) -> Result<SocketAddr>;
44
45    /// Start the server.
46    fn start(
47        options: ServerOptions,
48        service: Service,
49        statistics: Statistics,
50        switch: Switch,
51    ) -> impl Future<Output = Result<()>> + Send {
52        let transport = options.transport;
53        let idle_timeout = options.idle_timeout as u64;
54
55        async move {
56            let mut listener = Self::bind(&options).await?;
57            let local_addr = listener.local_addr()?;
58
59            log::info!(
60                "server listening: listen={}, external={}, local addr={local_addr}, transport={transport:?}",
61                options.listen,
62                options.external,
63            );
64
65            while let Ok(poll) = listener.accept().await {
66                let Poll::Ready((mut socket, address)) = poll else {
67                    continue;
68                };
69
70                let id = Identifier {
71                    source: address,
72                    interface: local_addr,
73                    external: options.external,
74                    transport: options.transport,
75                };
76
77                let mut router = service.make_router(id);
78                let mut receiver = switch.get_receiver(id);
79                let reporter = statistics.get_reporter(transport);
80
81                let service = service.clone();
82                let switch = switch.clone();
83
84                tokio::spawn(async move {
85                    let mut interval = interval(Duration::from_secs(1));
86                    let mut read_delay = 0;
87
88                    loop {
89                        let mut response_buffer = Buffer::new();
90
91                        tokio::select! {
92                            Ok(buffer) = socket.read() => {
93                                read_delay = 0;
94
95                                if let Ok(Some(res)) = router.route(&buffer, &mut response_buffer).await
96                                {
97
98                                    if let Some(relay) = res.relay {
99                                        // The channel data needs to be aligned in multiples of 4 in
100                                        // tcp. If the channel data is forwarded to tcp, the alignment
101                                        // bit needs to be filled, because if the channel data comes
102                                        // from udp, it is not guaranteed to be aligned and needs to be
103                                        // checked.
104                                        if relay.transport == Transport::Tcp && res.method.is_none() {
105                                            let pad = response_buffer.len() % 4;
106                                            if pad > 0 {
107                                                response_buffer.extend_from_slice(&[0u8; 8][..(4 - pad)]);
108                                            }
109                                        }
110
111                                        switch.send(&relay, response_buffer);
112                                    } else {
113                                        if socket.write(&response_buffer).await.is_err() {
114                                            break;
115                                        }
116
117                                        reporter.send(
118                                            &id,
119                                            &[Stats::SendBytes(response_buffer.len()), Stats::SendPkts(1)],
120                                        );
121
122                                        if let Some(method) = res.method && method.is_error() {
123                                            reporter.send(&id, &[Stats::ErrorPkts(1)]);
124                                        }
125                                    }
126                                }
127                            }
128                            Some(bytes) = receiver.recv() => {
129                                if socket.write(&bytes).await.is_err() {
130                                    break;
131                                } else {
132                                    reporter.send(&id, &[Stats::SendBytes(bytes.len()), Stats::SendPkts(1)]);
133                                }
134                            }
135                            _ = interval.tick() => {
136                                read_delay += 1;
137
138                                if read_delay >= idle_timeout {
139                                    break;
140                                }
141                            }
142                            else => {
143                                break;
144                            }
145                        }
146                    }
147
148                    // close the socket
149                    socket.close().await;
150
151                    // When the socket connection is closed, the procedure to close the session is
152                    // process directly once, avoiding the connection being disconnected
153                    // directly without going through the closing
154                    // process.
155                    service.get_session_manager().refresh(&id, 0);
156
157                    switch.remove(&id);
158
159                    log::info!(
160                        "socket disconnect: addr={address:?}, interface={local_addr:?}, transport={transport:?}"
161                    );
162                });
163            }
164
165            log::error!("server shutdown: interface={local_addr:?}, transport={transport:?}");
166
167            Ok(())
168        }
169    }
170}