turn_server_proto/
openssl.rs

1// Copyright (C) 2025 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! A TURN server that can handle TLS client connections.
10
11use alloc::collections::VecDeque;
12use alloc::string::String;
13use alloc::vec;
14use alloc::vec::Vec;
15use core::net::SocketAddr;
16use core::time::Duration;
17use std::io::{Read, Write};
18use turn_types::prelude::DelayedTransmitBuild;
19use turn_types::transmit::TransmitBuild;
20use turn_types::AddressFamily;
21
22use openssl::ssl::{HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslStream};
23use stun_proto::agent::Transmit;
24use stun_proto::Instant;
25use tracing::{info, trace, warn};
26use turn_types::stun::TransportType;
27
28use crate::api::{
29    DelayedMessageOrChannelSend, SocketAllocateError, TurnServerApi, TurnServerPollRet,
30};
31use crate::server::TurnServer;
32
33/// A TURN server that can handle TLS connections.
34#[derive(Debug)]
35pub struct OpensslTurnServer {
36    server: TurnServer,
37    ssl_context: SslContext,
38    connections: Vec<(SocketAddr, HandshakeState)>,
39}
40
41#[derive(Debug)]
42enum HandshakeState {
43    Init(Ssl, OsslBio),
44    Handshaking(MidHandshakeSslStream<OsslBio>),
45    Done(SslStream<OsslBio>),
46    Nothing,
47}
48
49impl HandshakeState {
50    fn complete(&mut self) -> Result<&mut SslStream<OsslBio>, std::io::Error> {
51        if let Self::Done(s) = self {
52            return Ok(s);
53        }
54        let taken = core::mem::replace(self, Self::Nothing);
55
56        let ret = match taken {
57            Self::Init(ssl, bio) => ssl.accept(bio),
58            Self::Handshaking(mid) => mid.handshake(),
59            Self::Done(_) | Self::Nothing => unreachable!(),
60        };
61
62        match ret {
63            Ok(s) => {
64                info!(
65                    "SSL handshake completed with version {} cipher: {:?}",
66                    s.ssl().version_str(),
67                    s.ssl().current_cipher()
68                );
69                *self = Self::Done(s);
70                Ok(self.complete()?)
71            }
72            Err(HandshakeError::WouldBlock(mid)) => {
73                *self = Self::Handshaking(mid);
74                Err(std::io::Error::new(
75                    std::io::ErrorKind::WouldBlock,
76                    "Would Block",
77                ))
78            }
79            Err(HandshakeError::SetupFailure(e)) => {
80                warn!("Error during ssl setup: {e}");
81                Err(std::io::Error::new(
82                    std::io::ErrorKind::ConnectionRefused,
83                    e,
84                ))
85            }
86            Err(HandshakeError::Failure(mid)) => {
87                warn!("Failure during ssl setup: {}", mid.error());
88                *self = Self::Handshaking(mid);
89                Err(std::io::Error::new(
90                    std::io::ErrorKind::WouldBlock,
91                    "Would Block",
92                ))
93            }
94        }
95    }
96    fn inner_mut(&mut self) -> &mut OsslBio {
97        match self {
98            Self::Init(_ssl, stream) => stream,
99            Self::Handshaking(mid) => mid.get_mut(),
100            Self::Done(stream) => stream.get_mut(),
101            Self::Nothing => unreachable!(),
102        }
103    }
104}
105
/// In-memory byte transport that a TLS session reads from and writes to instead of a socket.
#[derive(Debug, Default)]
struct OsslBio {
    /// Bytes received from the client that OpenSSL has not consumed yet.
    incoming: Vec<u8>,
    /// Chunks written by OpenSSL, oldest first, waiting to be sent to the client.
    outgoing: VecDeque<Vec<u8>>,
}
111
112impl OsslBio {
113    fn push_incoming(&mut self, buf: &[u8]) {
114        self.incoming.extend_from_slice(buf)
115    }
116
117    fn pop_outgoing(&mut self) -> Option<Vec<u8>> {
118        self.outgoing.pop_front()
119    }
120}
121
122impl std::io::Write for OsslBio {
123    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
124        self.outgoing.push_back(buf.to_vec());
125        Ok(buf.len())
126    }
127
128    fn flush(&mut self) -> std::io::Result<()> {
129        Ok(())
130    }
131}
132
133impl std::io::Read for OsslBio {
134    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
135        let len = self.incoming.len();
136        let max = buf.len().min(len);
137
138        if len == 0 {
139            return Err(std::io::Error::new(
140                std::io::ErrorKind::WouldBlock,
141                "Would Block",
142            ));
143        }
144
145        buf[..max].copy_from_slice(&self.incoming[..max]);
146        if max == len {
147            self.incoming.truncate(0);
148        } else {
149            self.incoming.drain(..max);
150        }
151
152        Ok(max)
153    }
154}
155
156impl OpensslTurnServer {
157    /// Construct a now Turn server that can handle TLS connections.
158    pub fn new(
159        transport: TransportType,
160        listen_addr: SocketAddr,
161        realm: String,
162        ssl_context: SslContext,
163    ) -> Self {
164        Self {
165            server: TurnServer::new(transport, listen_addr, realm),
166            ssl_context,
167            connections: vec![],
168        }
169    }
170}
171
172impl TurnServerApi for OpensslTurnServer {
173    /// Add a user credentials that would be accepted by this [`TurnServer`].
174    fn add_user(&mut self, username: String, password: String) {
175        self.server.add_user(username, password)
176    }
177
178    /// The address that the [`TurnServer`] is listening on for incoming client connections.
179    fn listen_address(&self) -> SocketAddr {
180        self.server.listen_address()
181    }
182
183    /// Set the amount of time that a Nonce (used for authentication) will expire and a new Nonce
184    /// will need to be acquired by a client.
185    fn set_nonce_expiry_duration(&mut self, expiry_duration: Duration) {
186        self.server.set_nonce_expiry_duration(expiry_duration)
187    }
188
189    /// Provide received data to the [`TurnServer`].
190    ///
191    /// Any returned Transmit should be forwarded to the appropriate socket.
192    #[tracing::instrument(
193        name = "turn_server_openssl_recv",
194        skip(self, transmit, now),
195        fields(
196            from = ?transmit.from,
197            data_len = transmit.data.as_ref().len()
198        )
199    )]
200    fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
201        &mut self,
202        transmit: Transmit<T>,
203        now: Instant,
204    ) -> Option<TransmitBuild<DelayedMessageOrChannelSend<T>>> {
205        let listen_address = self.listen_address();
206        if transmit.transport == TransportType::Tcp && transmit.to == listen_address {
207            trace!("receiving TLS data: {:x?}", transmit.data.as_ref());
208            // incoming client
209            let (client_addr, conn) = match self
210                .connections
211                .iter_mut()
212                .find(|(client_addr, _conn)| *client_addr == transmit.from)
213            {
214                Some((client_addr, conn)) => (*client_addr, conn),
215                None => {
216                    let len = self.connections.len();
217                    let ssl = Ssl::new(&self.ssl_context).expect("Cannot create ssl structure");
218                    self.connections
219                        .push((transmit.from, HandshakeState::Init(ssl, OsslBio::default())));
220                    info!("new connection from {}", transmit.from);
221                    let ret = &mut self.connections[len];
222                    (ret.0, &mut ret.1)
223                }
224            };
225            conn.inner_mut().push_incoming(transmit.data.as_ref());
226            let stream = match conn.complete() {
227                Ok(s) => s,
228                Err(e) => {
229                    if e.kind() != std::io::ErrorKind::WouldBlock {
230                        warn!("error accepting TLS: {e}");
231                    }
232                    return None;
233                }
234            };
235
236            let mut plaintext = vec![0; 2048];
237            let len = match stream.read(&mut plaintext) {
238                Ok(len) => len,
239                Err(e) => {
240                    if e.kind() != std::io::ErrorKind::WouldBlock {
241                        tracing::warn!("Error: {e}");
242                    }
243                    return None;
244                }
245            };
246            plaintext.resize(len, 0);
247
248            let transmit = self.server.recv(
249                Transmit::new(plaintext, transmit.transport, transmit.from, transmit.to),
250                now,
251            )?;
252
253            if transmit.transport == TransportType::Tcp
254                && transmit.from == listen_address
255                && transmit.to == client_addr
256            {
257                let plaintext = transmit.data.build();
258                stream.write_all(&plaintext).unwrap();
259                stream.get_mut().pop_outgoing().map(|data| {
260                    TransmitBuild::new(
261                        DelayedMessageOrChannelSend::Owned(data),
262                        TransportType::Tcp,
263                        listen_address,
264                        client_addr,
265                    )
266                })
267            } else {
268                let transmit = transmit.build();
269                Some(TransmitBuild::new(
270                    DelayedMessageOrChannelSend::Owned(transmit.data),
271                    transmit.transport,
272                    transmit.from,
273                    transmit.to,
274                ))
275            }
276        } else if let Some(transmit) = self.server.recv(transmit, now) {
277            // incoming allocated address
278            if transmit.transport == TransportType::Tcp && transmit.from == listen_address {
279                let Some((client_addr, conn)) = self
280                    .connections
281                    .iter_mut()
282                    .find(|(client_addr, _conn)| transmit.to == *client_addr)
283                else {
284                    return Some(transmit);
285                };
286
287                let plaintext = transmit.data.build();
288                let stream = match conn.complete() {
289                    Ok(s) => s,
290                    Err(e) => {
291                        if e.kind() != std::io::ErrorKind::WouldBlock {
292                            warn!("error accepting TLS: {e}");
293                        }
294                        return None;
295                    }
296                };
297                stream.write_all(&plaintext).unwrap();
298                stream.get_mut().pop_outgoing().map(|data| {
299                    TransmitBuild::new(
300                        DelayedMessageOrChannelSend::Owned(data),
301                        TransportType::Tcp,
302                        listen_address,
303                        *client_addr,
304                    )
305                })
306            } else {
307                Some(transmit)
308            }
309        } else {
310            None
311        }
312    }
313
314    fn recv_icmp<T: AsRef<[u8]>>(
315        &mut self,
316        family: AddressFamily,
317        bytes: T,
318        now: Instant,
319    ) -> Option<Transmit<Vec<u8>>> {
320        let transmit = self.server.recv_icmp(family, bytes, now)?;
321        // incoming allocated address
322        let listen_address = self.listen_address();
323        if transmit.transport == TransportType::Tcp && transmit.from == listen_address {
324            let Some((client_addr, conn)) = self
325                .connections
326                .iter_mut()
327                .find(|(client_addr, _conn)| transmit.to == *client_addr)
328            else {
329                return Some(transmit);
330            };
331            let stream = match conn.complete() {
332                Ok(s) => s,
333                Err(e) => {
334                    if e.kind() != std::io::ErrorKind::WouldBlock {
335                        warn!("error accepting TLS: {e}");
336                    }
337                    return None;
338                }
339            };
340            stream.write_all(&transmit.data).unwrap();
341            stream
342                .get_mut()
343                .pop_outgoing()
344                .map(|data| Transmit::new(data, TransportType::Tcp, listen_address, *client_addr))
345        } else {
346            Some(transmit)
347        }
348    }
349
350    /// Poll the [`TurnServer`] in order to make further progress.
351    ///
352    /// The returned value indicates what the caller should do.
353    fn poll(&mut self, now: Instant) -> TurnServerPollRet {
354        let protocol_ret = self.server.poll(now);
355        let mut have_pending = false;
356        for (_client_addr, conn) in self.connections.iter_mut() {
357            let stream = match conn.complete() {
358                Ok(s) => s,
359                Err(_) => continue,
360            };
361            if !stream.get_mut().outgoing.is_empty() {
362                have_pending = true;
363            }
364        }
365        if have_pending {
366            return TurnServerPollRet::WaitUntil(now);
367        }
368        protocol_ret
369    }
370
371    /// Poll for a new Transmit to send over a socket.
372    fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Vec<u8>>> {
373        let listen_address = self.listen_address();
374
375        for (client_addr, conn) in self.connections.iter_mut() {
376            if let Some(data) = conn.inner_mut().pop_outgoing() {
377                return Some(Transmit::new(
378                    data,
379                    TransportType::Tcp,
380                    listen_address,
381                    *client_addr,
382                ));
383            }
384        }
385
386        while let Some(transmit) = self.server.poll_transmit(now) {
387            let Some((client_addr, conn)) = self
388                .connections
389                .iter_mut()
390                .find(|(client_addr, _conn)| transmit.to == *client_addr)
391            else {
392                warn!("return transmit: {transmit:?}");
393                return Some(transmit);
394            };
395            let stream = match conn.complete() {
396                Ok(s) => s,
397                // FIXME: how to deal with early data
398                Err(_) => continue,
399            };
400            stream.write_all(&transmit.data).unwrap();
401
402            if let Some(data) = conn.inner_mut().pop_outgoing() {
403                return Some(Transmit::new(
404                    data,
405                    TransportType::Tcp,
406                    listen_address,
407                    *client_addr,
408                ));
409            }
410        }
411        None
412    }
413
414    /// Notify the [`TurnServer`] that a UDP socket has been allocated (or an error) in response to
415    /// [TurnServerPollRet::AllocateSocketUdp].
416    fn allocated_udp_socket(
417        &mut self,
418        transport: TransportType,
419        local_addr: SocketAddr,
420        remote_addr: SocketAddr,
421        family: AddressFamily,
422        socket_addr: Result<SocketAddr, SocketAllocateError>,
423        now: Instant,
424    ) {
425        self.server.allocated_udp_socket(
426            transport,
427            local_addr,
428            remote_addr,
429            family,
430            socket_addr,
431            now,
432        )
433    }
434}