Skip to main content

turn_client_rustls/
lib.rs

1// Copyright (C) 2025 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8//
9// SPDX-License-Identifier: MIT OR Apache-2.0
10
//! # turn-client-rustls
12//!
13//! TLS TURN client using Rustls.
14//!
//! An implementation of a TURN client suitable for TLS over TCP connections.
16//!
17//! ## Crypto providers
18//!
//! `turn-client-rustls` does not enable any cryptographic provider in rustls.
//! It is the responsibility of the user (library or application) to enable and
//! use the cryptographic provider they wish to use (ring, aws-lc-rs, RustCrypto,
//! etc.).
23
24#![deny(missing_debug_implementations)]
25#![deny(missing_docs)]
26#![cfg_attr(docsrs, feature(doc_cfg))]
27#![deny(clippy::std_instead_of_core)]
28#![deny(clippy::std_instead_of_alloc)]
29#![no_std]
30
31extern crate alloc;
32
33#[cfg(any(feature = "std", test))]
34extern crate std;
35
36pub use rustls;
37
38use alloc::sync::Arc;
39use alloc::vec;
40use alloc::vec::Vec;
41use core::net::{IpAddr, SocketAddr};
42use core::time::Duration;
43use std::io::{Read, Write};
44
45use turn_client_proto::types::Instant;
46use turn_client_proto::types::TransportType;
47
48pub use turn_client_proto as proto;
49pub use turn_client_proto::api::*;
50
51use turn_client_proto::tcp::TurnClientTcp;
52
53use rustls::pki_types::ServerName;
54use rustls::{ClientConfig, ClientConnection};
55
56use tracing::{debug, trace, warn};
57
58/// A TURN client that communicates over TLS.
///
/// Every TCP connection used by the client (the initial connection created in `allocate`
/// plus one per completed TCP allocation) carries its own rustls client session.
59#[derive(Debug)]
60pub struct TurnClientRustls {
    /// TURN-over-TCP state machine that produces and consumes the plaintext messages.
61    protocol: TurnClientTcp,
    /// Configuration shared by every TLS session this client creates.
62    tls_config: Arc<ClientConfig>,
    /// Server name passed to rustls for every TLS session this client creates.
63    server_name: ServerName<'static>,
    /// TCP allocations reported by the protocol (`id`, five-tuple, peer address) that are
    /// still waiting for `allocated_tcp_socket` to supply the local socket address.
64    pending_allocates: Vec<(u32, Socket5Tuple, SocketAddr)>,
    /// One entry per TCP connection, looked up by local/remote address pair.
65    sockets: Vec<Socket>,
66}
67
/// State of a single TLS-over-TCP connection owned by [`TurnClientRustls`].
68#[derive(Debug)]
69struct Socket {
    /// Local address of the TCP connection.
70    local_addr: SocketAddr,
    /// Remote address of the TCP connection.
71    remote_addr: SocketAddr,
    /// rustls client session that encrypts/decrypts the TURN traffic on this connection.
72    tls: ClientConnection,
    /// Set once rustls reports that the peer has closed the TLS session.
73    peer_closed: bool,
    /// Set once our own `close_notify` has been queued on this session.
74    local_closed: bool,
75}
76
77impl TurnClientRustls {
78    /// Allocate an address on a TURN server to relay data to and from peers.
79    #[allow(clippy::too_many_arguments)]
80    pub fn allocate(
81        local_addr: SocketAddr,
82        remote_addr: SocketAddr,
83        config: TurnConfig,
84        server_name: ServerName<'static>,
85        tls_config: Arc<ClientConfig>,
86    ) -> Self {
87        Self {
88            protocol: TurnClientTcp::allocate(local_addr, remote_addr, config),
89            sockets: vec![Socket {
90                local_addr,
91                remote_addr,
92                tls: ClientConnection::new(tls_config.clone(), server_name.clone()).unwrap(),
93                local_closed: false,
94                peer_closed: false,
95            }],
96            tls_config,
97            server_name,
98            pending_allocates: vec![],
99        }
100    }
101
102    fn empty_transmit_queue(&mut self, now: Instant) {
103        while let Some(transmit) = self.protocol.poll_transmit(now) {
104            let Some(socket) = self.sockets.iter_mut().find(|socket| {
105                socket.local_addr == transmit.from && socket.remote_addr == transmit.to
106            }) else {
107                warn!(
108                    "no socket for transmit from {} to {}",
109                    transmit.from, transmit.to
110                );
111                continue;
112            };
113            socket.tls.writer().write_all(&transmit.data).unwrap();
114        }
115    }
116}
117
118impl TurnClientApi for TurnClientRustls {
119    fn transport(&self) -> TransportType {
120        self.protocol.transport()
121    }
122
123    fn local_addr(&self) -> SocketAddr {
124        self.protocol.local_addr()
125    }
126
127    fn remote_addr(&self) -> SocketAddr {
128        self.protocol.remote_addr()
129    }
130
131    fn poll(&mut self, now: Instant) -> TurnPollRet {
132        let mut is_handshaking = false;
133        let mut protocol_ret = TurnPollRet::Closed;
134        for (idx, socket) in self.sockets.iter_mut().enumerate() {
135            let io_state = match socket.tls.process_new_packets() {
136                Ok(io_state) => io_state,
137                Err(e) => {
138                    warn!("Error processing TLS: {e:?}");
139                    if socket.local_addr == self.protocol.local_addr()
140                        && socket.remote_addr == self.protocol.remote_addr()
141                    {
142                        self.protocol.protocol_error();
143                        return TurnPollRet::Closed;
144                    } else {
145                        // TODO: remove socket?
146                        continue;
147                    }
148                }
149            };
150            if io_state.peer_has_closed() {
151                socket.peer_closed = true;
152                if !socket.local_closed {
153                    socket.tls.send_close_notify();
154                    socket.local_closed = true;
155                    trace!("sending close notify");
156                    return TurnPollRet::WaitUntil(now);
157                }
158            }
159            let tls_write_bytes = io_state.tls_bytes_to_write();
160            if tls_write_bytes > 0 {
161                trace!("have {tls_write_bytes} bytes to write");
162                return TurnPollRet::WaitUntil(now);
163            }
164            if socket.peer_closed && socket.local_closed && !socket.tls.wants_write() {
165                let socket = self.sockets.remove(idx);
166                return TurnPollRet::TcpClose {
167                    local_addr: socket.local_addr,
168                    remote_addr: socket.remote_addr,
169                };
170            }
171            if socket.local_addr == self.protocol.local_addr()
172                && socket.remote_addr == self.protocol.remote_addr()
173            {
174                protocol_ret = self.protocol.poll(now);
175            }
176            is_handshaking |= socket.tls.is_handshaking();
177        }
178        match protocol_ret {
179            TurnPollRet::Closed => {
180                debug!("Closed");
181                return protocol_ret;
182            }
183            TurnPollRet::AllocateTcpSocket {
184                id,
185                socket,
186                peer_addr,
187            } => {
188                self.pending_allocates.push((id, socket, peer_addr));
189            }
190            _ => (),
191        }
192        if is_handshaking {
193            debug!("Currently handshaking, waiting for reply");
194            return TurnPollRet::WaitUntil(now + Duration::from_secs(60));
195        }
196        protocol_ret
197    }
198
199    fn relayed_addresses(&self) -> impl Iterator<Item = (TransportType, SocketAddr)> + '_ {
200        self.protocol.relayed_addresses()
201    }
202
203    fn permissions(
204        &self,
205        transport: TransportType,
206        relayed: SocketAddr,
207    ) -> impl Iterator<Item = IpAddr> + '_ {
208        self.protocol.permissions(transport, relayed)
209    }
210
211    fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Data<'static>>> {
212        let client_transport = self.transport();
213        for socket in self.sockets.iter_mut() {
214            if socket.tls.is_handshaking() {
215                if socket.tls.wants_write() {
216                    // TODO: avoid this allocation
217                    let mut out = vec![];
218                    match socket.tls.write_tls(&mut out) {
219                        Ok(_written) => {
220                            return Some(Transmit::new(
221                                Data::from(out.into_boxed_slice()),
222                                client_transport,
223                                socket.local_addr,
224                                socket.remote_addr,
225                            ))
226                        }
227                        Err(e) => {
228                            warn!("error during handshake: {e:?}");
229                            if socket.local_addr == self.protocol.local_addr()
230                                && socket.remote_addr == self.protocol.remote_addr()
231                            {
232                                self.protocol.protocol_error();
233                                return None;
234                            } else {
235                                // TODO: remove socket?
236                                continue;
237                            }
238                        }
239                    }
240                }
241                if socket.local_addr == self.protocol.local_addr()
242                    && socket.remote_addr == self.protocol.remote_addr()
243                {
244                    return None;
245                }
246            }
247        }
248        self.empty_transmit_queue(now);
249
250        for socket in self.sockets.iter_mut() {
251            if socket.tls.wants_write() {
252                // TODO: avoid this allocation
253                let mut out = vec![];
254                match socket.tls.write_tls(&mut out) {
255                    Ok(_written) => {
256                        return Some(Transmit::new(
257                            Data::from(out.into_boxed_slice()),
258                            client_transport,
259                            socket.local_addr,
260                            socket.remote_addr,
261                        ))
262                    }
263                    Err(e) => {
264                        warn!("error writing TLS: {e:?}");
265                        if socket.local_addr == self.protocol.local_addr()
266                            && socket.remote_addr == self.protocol.remote_addr()
267                        {
268                            self.protocol.protocol_error();
269                        } else {
270                            // TODO: remove socket?
271                            continue;
272                        }
273                    }
274                }
275            }
276        }
277        None
278    }
279
280    fn poll_event(&mut self) -> Option<TurnEvent> {
281        match self.protocol.poll_event()? {
282            TurnEvent::TcpConnected(peer_addr) => Some(TurnEvent::TcpConnected(peer_addr)),
283            TurnEvent::TcpConnectFailed(peer_addr) => Some(TurnEvent::TcpConnectFailed(peer_addr)),
284            event => Some(event),
285        }
286    }
287
288    fn delete(&mut self, now: Instant) -> Result<(), DeleteError> {
289        self.protocol.delete(now)?;
290
291        self.empty_transmit_queue(now);
292        Ok(())
293    }
294
295    fn create_permission(
296        &mut self,
297        transport: TransportType,
298        peer_addr: IpAddr,
299        now: Instant,
300    ) -> Result<(), CreatePermissionError> {
301        self.protocol.create_permission(transport, peer_addr, now)?;
302
303        self.empty_transmit_queue(now);
304
305        Ok(())
306    }
307
308    fn have_permission(&self, transport: TransportType, to: IpAddr) -> bool {
309        self.protocol.have_permission(transport, to)
310    }
311
312    fn bind_channel(
313        &mut self,
314        transport: TransportType,
315        peer_addr: SocketAddr,
316        now: Instant,
317    ) -> Result<(), BindChannelError> {
318        self.protocol.bind_channel(transport, peer_addr, now)?;
319
320        self.empty_transmit_queue(now);
321
322        Ok(())
323    }
324
325    fn tcp_connect(&mut self, peer_addr: SocketAddr, now: Instant) -> Result<(), TcpConnectError> {
326        self.protocol.tcp_connect(peer_addr, now)?;
327
328        self.empty_transmit_queue(now);
329
330        Ok(())
331    }
332
333    fn allocated_tcp_socket(
334        &mut self,
335        id: u32,
336        five_tuple: Socket5Tuple,
337        peer_addr: SocketAddr,
338        local_addr: Option<SocketAddr>,
339        now: Instant,
340    ) -> Result<(), TcpAllocateError> {
341        self.protocol
342            .allocated_tcp_socket(id, five_tuple, peer_addr, local_addr, now)?;
343
344        if let Some(local_addr) = local_addr {
345            if let Some(idx) = self
346                .pending_allocates
347                .iter()
348                .position(|pending| pending.1 == five_tuple)
349            {
350                self.pending_allocates.swap_remove(idx);
351                self.sockets.push(Socket {
352                    local_addr,
353                    remote_addr: self.remote_addr(),
354                    tls: ClientConnection::new(self.tls_config.clone(), self.server_name.clone())
355                        .unwrap(),
356                    local_closed: false,
357                    peer_closed: false,
358                });
359            }
360        }
361
362        self.empty_transmit_queue(now);
363        Ok(())
364    }
365
366    fn tcp_closed(&mut self, local_addr: SocketAddr, remote_addr: SocketAddr, now: Instant) {
367        let Some(socket) = self
368            .sockets
369            .iter_mut()
370            .find(|socket| socket.local_addr == local_addr && socket.remote_addr == remote_addr)
371        else {
372            warn!(
373                "Unknown socket local:{}, remote:{}",
374                local_addr, remote_addr
375            );
376            return;
377        };
378        self.protocol.tcp_closed(local_addr, remote_addr, now);
379        socket.tls.send_close_notify();
380        socket.local_closed = true;
381    }
382
383    fn send_to<T: AsRef<[u8]> + core::fmt::Debug>(
384        &mut self,
385        transport: TransportType,
386        to: SocketAddr,
387        data: T,
388        now: Instant,
389    ) -> Result<Option<TransmitBuild<DelayedMessageOrChannelSend<T>>>, SendError> {
390        if let Some(transmit) = self.protocol.send_to(transport, to, data, now)? {
391            let client_transport = self.transport();
392            let transmit = transmit.build();
393            let Some(socket) = self.sockets.iter_mut().find(|socket| {
394                socket.local_addr == transmit.from
395                    && socket.remote_addr == transmit.to
396                    && !socket.local_closed
397            }) else {
398                warn!(
399                    "no socket for transmit from {} to {}",
400                    transmit.from, transmit.to
401                );
402                return Err(SendError::NoTcpSocket);
403            };
404            if let Err(e) = socket.tls.writer().write_all(&transmit.data) {
405                warn!("Error when writing plaintext: {e:?}");
406                if socket.local_addr == self.protocol.local_addr()
407                    && socket.remote_addr == self.protocol.remote_addr()
408                {
409                    self.protocol.protocol_error();
410                    return Err(SendError::NoAllocation);
411                } else {
412                    return Err(SendError::NoTcpSocket);
413                }
414            }
415
416            if socket.tls.wants_write() {
417                let mut out = vec![];
418                match socket.tls.write_tls(&mut out) {
419                    Ok(_n) => {
420                        return Ok(Some(TransmitBuild::new(
421                            DelayedMessageOrChannelSend::OwnedData(out),
422                            client_transport,
423                            socket.local_addr,
424                            socket.remote_addr,
425                        )))
426                    }
427                    Err(e) => {
428                        warn!("Error when writing TLS records: {e:?}");
429                        if socket.local_addr == self.protocol.local_addr()
430                            && socket.remote_addr == self.protocol.remote_addr()
431                        {
432                            self.protocol.protocol_error();
433                            return Err(SendError::NoAllocation);
434                        } else {
435                            return Err(SendError::NoTcpSocket);
436                        }
437                    }
438                }
439            }
440        }
441
442        Ok(None)
443    }
444
445    #[tracing::instrument(
446        name = "turn_rustls_recv",
447        skip(self, transmit, now),
448        fields(
449            from = ?transmit.from,
450            data_len = transmit.data.as_ref().len()
451        )
452    )]
453    fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
454        &mut self,
455        transmit: Transmit<T>,
456        now: Instant,
457    ) -> TurnRecvRet<T> {
458        /* is this data for our client? */
459        if self.transport() != transmit.transport {
460            return TurnRecvRet::Ignored(transmit);
461        }
462        let Some(socket) = self
463            .sockets
464            .iter_mut()
465            .find(|socket| socket.local_addr == transmit.to && socket.remote_addr == transmit.from)
466        else {
467            trace!(
468                "received data not directed at us ({:?}) but for {:?}!",
469                self.local_addr(),
470                transmit.to
471            );
472            return TurnRecvRet::Ignored(transmit);
473        };
474        let mut data = std::io::Cursor::new(transmit.data.as_ref());
475
476        let io_state = match socket.tls.read_tls(&mut data) {
477            Ok(_written) => match socket.tls.process_new_packets() {
478                Ok(io_state) => io_state,
479                Err(e) => {
480                    self.protocol.protocol_error();
481                    warn!("Error processing TLS: {e:?}");
482                    return TurnRecvRet::Ignored(transmit);
483                }
484            },
485            Err(e) => {
486                warn!("Error receiving data: {e:?}");
487                self.protocol.protocol_error();
488                return TurnRecvRet::Ignored(transmit);
489            }
490        };
491        if io_state.plaintext_bytes_to_read() > 0 {
492            let mut out = vec![0; 2048];
493            let n = match socket.tls.reader().read(&mut out) {
494                Ok(n) => n,
495                Err(e) => {
496                    warn!("Error receiving data: {e:?}");
497                    self.protocol.protocol_error();
498                    return TurnRecvRet::Ignored(transmit);
499                }
500            };
501            out.resize(n, 0);
502            let transmit = Transmit::new(out, transmit.transport, transmit.from, transmit.to);
503
504            return match self.protocol.recv(transmit, now) {
505                TurnRecvRet::Ignored(_) => unreachable!(),
506                TurnRecvRet::PeerData(peer_data) => TurnRecvRet::PeerData(peer_data.into_owned()),
507                TurnRecvRet::Handled => TurnRecvRet::Handled,
508                TurnRecvRet::PeerIcmp {
509                    transport,
510                    peer,
511                    icmp_type,
512                    icmp_code,
513                    icmp_data,
514                } => TurnRecvRet::PeerIcmp {
515                    transport,
516                    peer,
517                    icmp_type,
518                    icmp_code,
519                    icmp_data,
520                },
521            };
522        }
523
524        TurnRecvRet::Handled
525    }
526
527    fn poll_recv(&mut self, now: Instant) -> Option<TurnPeerData<Vec<u8>>> {
528        self.protocol.poll_recv(now)
529    }
530
531    fn protocol_error(&mut self) {
532        self.protocol.protocol_error()
533    }
534}