turn_server_proto/
rustls.rs

1// Copyright (C) 2025 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! A TURN server that can handle TLS client connections.
10
11use alloc::string::String;
12use alloc::sync::Arc;
13use alloc::vec;
14use alloc::vec::Vec;
15use core::net::SocketAddr;
16use core::time::Duration;
17use std::io::{Read, Write};
18use turn_types::prelude::DelayedTransmitBuild;
19use turn_types::transmit::TransmitBuild;
20use turn_types::AddressFamily;
21
22use rustls::{ServerConfig, ServerConnection};
23use stun_proto::agent::Transmit;
24use stun_proto::Instant;
25use tracing::{info, trace, warn};
26use turn_types::stun::TransportType;
27
28use crate::api::{
29    DelayedMessageOrChannelSend, SocketAllocateError, TurnServerApi, TurnServerPollRet,
30};
31use crate::server::TurnServer;
32
33/// A TURN server that can handle TLS connections.
34#[derive(Debug)]
35pub struct RustlsTurnServer {
36    server: TurnServer,
37    config: Arc<ServerConfig>,
38    connections: Vec<(SocketAddr, ServerConnection)>,
39}
40
41impl RustlsTurnServer {
42    /// Construct a now Turn server that can handle TLS connections.
43    pub fn new(listen_addr: SocketAddr, realm: String, config: Arc<ServerConfig>) -> Self {
44        Self {
45            server: TurnServer::new(TransportType::Tcp, listen_addr, realm),
46            config,
47            connections: vec![],
48        }
49    }
50}
51
52impl TurnServerApi for RustlsTurnServer {
53    /// Add a user credentials that would be accepted by this [`TurnServer`].
54    fn add_user(&mut self, username: String, password: String) {
55        self.server.add_user(username, password)
56    }
57
58    /// The address that the [`TurnServer`] is listening on for incoming client connections.
59    fn listen_address(&self) -> SocketAddr {
60        self.server.listen_address()
61    }
62
63    /// Set the amount of time that a Nonce (used for authentication) will expire and a new Nonce
64    /// will need to be acquired by a client.
65    fn set_nonce_expiry_duration(&mut self, expiry_duration: Duration) {
66        self.server.set_nonce_expiry_duration(expiry_duration)
67    }
68
69    /// Provide received data to the [`TurnServer`].
70    ///
71    /// Any returned Transmit should be forwarded to the appropriate socket.
72    #[tracing::instrument(
73        name = "turn_server_rustls_recv",
74        skip(self, transmit, now),
75        fields(
76            from = ?transmit.from,
77            data_len = transmit.data.as_ref().len()
78        )
79    )]
80    fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
81        &mut self,
82        transmit: Transmit<T>,
83        now: Instant,
84    ) -> Option<TransmitBuild<DelayedMessageOrChannelSend<T>>> {
85        let listen_address = self.listen_address();
86        if transmit.transport == TransportType::Tcp && transmit.to == listen_address {
87            trace!("receiving TLS data: {:x?}", transmit.data.as_ref());
88            // incoming client
89            let (client_addr, conn) = match self
90                .connections
91                .iter_mut()
92                .find(|(client_addr, _conn)| *client_addr == transmit.from)
93            {
94                Some((client_addr, conn)) => (*client_addr, conn),
95                None => {
96                    let len = self.connections.len();
97                    self.connections.push((
98                        transmit.from,
99                        ServerConnection::new(self.config.clone()).unwrap(),
100                    ));
101                    info!("new connection from {}", transmit.from);
102                    let ret = &mut self.connections[len];
103                    (ret.0, &mut ret.1)
104                }
105            };
106            let mut input = std::io::Cursor::new(transmit.data.as_ref());
107            let io_state = match conn.read_tls(&mut input) {
108                Ok(_written) => match conn.process_new_packets() {
109                    Ok(io_state) => io_state,
110                    Err(e) => {
111                        warn!("Error processing incoming TLS: {e:?}");
112                        return None;
113                    }
114                },
115                Err(e) => {
116                    warn!("Error receiving data: {e:?}");
117                    return None;
118                }
119            };
120            if io_state.plaintext_bytes_to_read() == 0 {
121                return None;
122            }
123            let mut vec = vec![0; 2048];
124            let n = match conn.reader().read(&mut vec) {
125                Ok(n) => n,
126                Err(e) => {
127                    if e.kind() == std::io::ErrorKind::WouldBlock {
128                        return None;
129                    } else {
130                        warn!("TLS error: {e:?}");
131                        return None;
132                    }
133                }
134            };
135            tracing::error!("io_state: {io_state:?}, n: {n}");
136            vec.resize(n, 0);
137            let transmit = self.server.recv(
138                Transmit::new(vec, transmit.transport, transmit.from, transmit.to),
139                now,
140            )?;
141            if transmit.transport == TransportType::Tcp
142                && transmit.from == listen_address
143                && transmit.to == client_addr
144            {
145                let plaintext = transmit.data.build();
146                conn.writer().write_all(&plaintext).unwrap();
147                let mut out = vec![];
148                conn.write_tls(&mut out).unwrap();
149                Some(TransmitBuild::new(
150                    DelayedMessageOrChannelSend::Owned(out),
151                    TransportType::Tcp,
152                    listen_address,
153                    client_addr,
154                ))
155            } else {
156                let transmit = transmit.build();
157                Some(TransmitBuild::new(
158                    DelayedMessageOrChannelSend::Owned(transmit.data),
159                    transmit.transport,
160                    transmit.from,
161                    transmit.to,
162                ))
163            }
164        } else if let Some(transmit) = self.server.recv(transmit, now) {
165            // incoming allocated address
166            if transmit.transport == TransportType::Tcp && transmit.from == listen_address {
167                let Some((client_addr, conn)) = self
168                    .connections
169                    .iter_mut()
170                    .find(|(client_addr, _conn)| transmit.to == *client_addr)
171                else {
172                    return Some(transmit);
173                };
174                let plaintext = transmit.data.build();
175                conn.writer().write_all(&plaintext).unwrap();
176                let mut out = vec![];
177                conn.write_tls(&mut out).unwrap();
178                Some(TransmitBuild::new(
179                    DelayedMessageOrChannelSend::Owned(out),
180                    TransportType::Tcp,
181                    listen_address,
182                    *client_addr,
183                ))
184            } else {
185                Some(transmit)
186            }
187        } else {
188            None
189        }
190    }
191
192    fn recv_icmp<T: AsRef<[u8]>>(
193        &mut self,
194        family: AddressFamily,
195        bytes: T,
196        now: Instant,
197    ) -> Option<Transmit<Vec<u8>>> {
198        let transmit = self.server.recv_icmp(family, bytes, now)?;
199        // incoming allocated address
200        let listen_address = self.listen_address();
201        if transmit.transport == TransportType::Tcp && transmit.from == listen_address {
202            let Some((client_addr, conn)) = self
203                .connections
204                .iter_mut()
205                .find(|(client_addr, _conn)| transmit.to == *client_addr)
206            else {
207                return Some(transmit);
208            };
209            conn.writer().write_all(&transmit.data).unwrap();
210            let mut out = vec![];
211            conn.write_tls(&mut out).unwrap();
212            Some(Transmit::new(
213                out,
214                TransportType::Tcp,
215                listen_address,
216                *client_addr,
217            ))
218        } else {
219            Some(transmit)
220        }
221    }
222
223    /// Poll the [`TurnServer`] in order to make further progress.
224    ///
225    /// The returned value indicates what the caller should do.
226    fn poll(&mut self, now: Instant) -> TurnServerPollRet {
227        let protocol_ret = self.server.poll(now);
228        let mut have_pending = false;
229        for (_client_addr, conn) in self.connections.iter_mut() {
230            let io_state = match conn.process_new_packets() {
231                Ok(io_state) => io_state,
232                Err(e) => {
233                    warn!("Error processing TLS: {e:?}");
234                    continue;
235                }
236            };
237            if io_state.tls_bytes_to_write() > 0 {
238                have_pending = true;
239            }
240        }
241        if have_pending {
242            return TurnServerPollRet::WaitUntil(now);
243        }
244        protocol_ret
245    }
246
247    /// Poll for a new Transmit to send over a socket.
248    fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Vec<u8>>> {
249        let listen_address = self.listen_address();
250
251        while let Some(transmit) = self.server.poll_transmit(now) {
252            let Some((_client_addr, conn)) = self
253                .connections
254                .iter_mut()
255                .find(|(client_addr, _conn)| transmit.to == *client_addr)
256            else {
257                warn!("return transmit: {transmit:?}");
258                return Some(transmit);
259            };
260            conn.writer().write_all(&transmit.data).unwrap();
261        }
262
263        for (client_addr, conn) in self.connections.iter_mut() {
264            if !conn.wants_write() {
265                continue;
266            }
267            let mut vec = vec![];
268            let n = match conn.write_tls(&mut vec) {
269                Ok(n) => n,
270                Err(e) => {
271                    warn!("error writing TLS: {e:?}");
272                    continue;
273                }
274            };
275            vec.resize(n, 0);
276            warn!("return transmit: {vec:x?}");
277            return Some(Transmit::new(
278                vec,
279                TransportType::Tcp,
280                listen_address,
281                *client_addr,
282            ));
283        }
284        None
285    }
286
287    /// Notify the [`TurnServer`] that a UDP socket has been allocated (or an error) in response to
288    /// [TurnServerPollRet::AllocateSocketUdp].
289    fn allocated_udp_socket(
290        &mut self,
291        transport: TransportType,
292        local_addr: SocketAddr,
293        remote_addr: SocketAddr,
294        family: AddressFamily,
295        socket_addr: Result<SocketAddr, SocketAllocateError>,
296        now: Instant,
297    ) {
298        self.server.allocated_udp_socket(
299            transport,
300            local_addr,
301            remote_addr,
302            family,
303            socket_addr,
304            now,
305        )
306    }
307}