Skip to main content

turn_server_dimpl/
lib.rs

1// Copyright (C) 2025 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8//
9// SPDX-License-Identifier: MIT OR Apache-2.0
10
11//! # turn-server-dimpl
12//!
13//! A TURN server that can handle DTLS client connections using `dimpl`.
14//!
15//! `turn-server-dimpl` provides a sans-IO API for a TURN server communicating with many TURN clients.
16//!
17//! Relevant standards:
18//! - [RFC5766]: Traversal Using Relays around NAT (TURN).
19//! - [RFC6062]: Traversal Using Relays around NAT (TURN) Extensions for TCP Allocations
20//! - [RFC6156]: Traversal Using Relays around NAT (TURN) Extension for IPv6
21//! - [RFC8656]: Traversal Using Relays around NAT (TURN): Relay Extensions to Session
22//!   Traversal Utilities for NAT (STUN)
23//!
24//! [RFC5766]: https://datatracker.ietf.org/doc/html/rfc5766
25//! [RFC6062]: https://tools.ietf.org/html/rfc6062
26//! [RFC6156]: https://tools.ietf.org/html/rfc6156
27//! [RFC8656]: https://tools.ietf.org/html/rfc8656
28
29#![deny(missing_debug_implementations)]
30#![deny(missing_docs)]
31#![cfg_attr(docsrs, feature(doc_cfg))]
32#![deny(clippy::std_instead_of_core)]
33#![deny(clippy::std_instead_of_alloc)]
34#![no_std]
35
36extern crate alloc;
37
38#[cfg(any(feature = "std", test))]
39extern crate std;
40
41use alloc::collections::VecDeque;
42use alloc::string::String;
43use alloc::sync::Arc;
44use alloc::vec;
45use alloc::vec::Vec;
46use core::net::SocketAddr;
47use core::time::Duration;
48use turn_server_proto::types::prelude::DelayedTransmitBuild;
49use turn_server_proto::types::transmit::TransmitBuild;
50use turn_server_proto::types::AddressFamily;
51
52use turn_server_proto::api::Transmit;
53use turn_server_proto::types::Instant;
54use turn_server_proto::types::stun::TransportType;
55
56pub use turn_server_proto as proto;
57pub use turn_server_proto::api as api;
58
59use turn_server_proto::api::{
60    DelayedMessageOrChannelSend, SocketAllocateError, TurnServerApi, TurnServerPollRet,
61};
62use turn_server_proto::server::TurnServer;
63
64use tracing::{info, trace, warn};
65
66/// A TURN server that can handle TLS connections.
67#[derive(Debug)]
68pub struct DimplTurnServer {
69    server: TurnServer,
70    config: Arc<dimpl::Config>,
71    certificate: dimpl::DtlsCertificate,
72    clients: Vec<Client>,
73}
74
/// Per-remote-address DTLS session state.
///
/// `dimpl` works with [`std::time::Instant`] while the TURN server uses the sans-IO
/// [`Instant`]. `base_instant` and `base_now` are both captured when the client is created, so
/// they record the same moment on the two clocks and allow converting between them.
#[derive(Debug)]
struct Client {
    /// Remote address the DTLS datagrams arrive from; used to look the session up.
    client_addr: SocketAddr,
    /// The DTLS state machine for this client.
    dtls: dimpl::Dtls,
    /// `std` clock reading taken when the session was created.
    base_instant: std::time::Instant,
    /// TURN clock reading (the `now` passed to `recv`) taken when the session was created.
    base_now: Instant,
    /// Set once `dimpl` reports `Output::Connected`; not read by any code shown here.
    connected: bool,
    /// Encrypted datagrams produced by `dimpl` that still need to be sent to `client_addr`.
    pending_encrypted: VecDeque<Vec<u8>>,
    /// Decrypted application data from the client that has not yet been given to the TURN
    /// server.
    pending_incoming_plaintext: VecDeque<Vec<u8>>,
}
85
86impl Client {
87    fn poll(&mut self, now: Instant) -> Option<Instant> {
88        let _ = self.dtls.handle_timeout(
89            Instant::from_nanos((now - self.base_now).as_nanos() as i64).to_std(self.base_instant),
90        );
91        let mut out = [0; 2048];
92        let mut earliest_wait = None;
93        loop {
94            match self.dtls.poll_output(&mut out) {
95                dimpl::Output::Packet(p) => {
96                    self.pending_encrypted.push_back(p.to_vec());
97                    earliest_wait = Some(now);
98                }
99                dimpl::Output::Timeout(time) => {
100                    let wait = Instant::from_nanos((time - self.base_instant).as_nanos() as i64);
101                    if wait == now {
102                        let _ = self.dtls.handle_timeout(time);
103                        continue;
104                    }
105                    if earliest_wait.is_none_or(|earliest| earliest > wait) {
106                        earliest_wait = Some(wait);
107                    }
108                    break;
109                }
110                dimpl::Output::Connected => self.connected = true,
111                // TODO: validate certificate
112                dimpl::Output::PeerCert(_peer_cert) => (),
113                dimpl::Output::KeyingMaterial(_key, _srtp_profile) => (),
114                dimpl::Output::ApplicationData(app_data) => {
115                    self.pending_incoming_plaintext.push_back(app_data.to_vec());
116                }
117                _ => (),
118            }
119        }
120        earliest_wait
121    }
122
123    fn poll_plaintext(&mut self) -> Option<Vec<u8>> {
124        self.pending_incoming_plaintext.pop_front()
125    }
126
127    fn poll_encrypted(&mut self) -> Option<Vec<u8>> {
128        self.pending_encrypted.pop_front()
129    }
130}
131
132impl DimplTurnServer {
133    /// Construct a now Turn server that can handle TLS connections.
134    pub fn new(
135        transport: TransportType,
136        listen_addr: SocketAddr,
137        realm: String,
138        config: Arc<dimpl::Config>,
139        certificate: dimpl::DtlsCertificate,
140    ) -> Self {
141        Self {
142            server: TurnServer::new(transport, listen_addr, realm),
143            config,
144            certificate,
145            clients: vec![],
146        }
147    }
148}
149
150impl TurnServerApi for DimplTurnServer {
151    /// Add a user credentials that would be accepted by this [`TurnServer`].
152    fn add_user(&mut self, username: String, password: String) {
153        self.server.add_user(username, password)
154    }
155
156    /// The address that the [`TurnServer`] is listening on for incoming client connections.
157    fn listen_address(&self) -> SocketAddr {
158        self.server.listen_address()
159    }
160
161    /// Set the amount of time that a Nonce (used for authentication) will expire and a new Nonce
162    /// will need to be acquired by a client.
163    fn set_nonce_expiry_duration(&mut self, expiry_duration: Duration) {
164        self.server.set_nonce_expiry_duration(expiry_duration)
165    }
166
167    /// Provide received data to the [`TurnServer`].
168    ///
169    /// Any returned Transmit should be forwarded to the appropriate socket.
170    #[tracing::instrument(
171        name = "turn_server_dimpl_recv",
172        skip(self, transmit, now),
173        fields(
174            from = ?transmit.from,
175            data_len = transmit.data.as_ref().len()
176        )
177    )]
178    fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
179        &mut self,
180        transmit: Transmit<T>,
181        now: Instant,
182    ) -> Option<TransmitBuild<DelayedMessageOrChannelSend<T>>> {
183        let listen_address = self.listen_address();
184        if transmit.to == listen_address {
185            trace!("receiving TLS data: {:x?}", transmit.data.as_ref());
186            // incoming client
187            let client = match self
188                .clients
189                .iter_mut()
190                .find(|client| client.client_addr == transmit.from)
191            {
192                Some(client) => client,
193                None => {
194                    let len = self.clients.len();
195                    let base_instant = std::time::Instant::now();
196                    let mut dtls = dimpl::Dtls::new_auto(
197                        self.config.clone(),
198                        self.certificate.clone(),
199                        base_instant,
200                    );
201                    dtls.set_active(false);
202                    let mut client = Client {
203                        client_addr: transmit.from,
204                        dtls,
205                        base_instant,
206                        base_now: now,
207                        connected: false,
208                        pending_encrypted: VecDeque::default(),
209                        pending_incoming_plaintext: VecDeque::default(),
210                    };
211                    // start with poll to ensure that initial setup completes
212                    client.poll(now);
213                    self.clients.push(client);
214                    info!(
215                        "new connection from {} {}",
216                        transmit.transport, transmit.from
217                    );
218                    &mut self.clients[len]
219                }
220            };
221            match client.dtls.handle_packet(transmit.data.as_ref()) {
222                Ok(_) => (),
223                Err(e) => {
224                    warn!("error accepting TLS: {e}");
225                    return None;
226                }
227            };
228
229            client.poll(now);
230            while let Some(plaintext) = client.poll_plaintext() {
231                let Some(transmit) = self.server.recv(
232                    Transmit::new(plaintext, transmit.transport, transmit.from, transmit.to),
233                    now,
234                ) else {
235                    continue;
236                };
237
238                if transmit.from == listen_address && transmit.to == client.client_addr {
239                    client
240                        .dtls
241                        .send_application_data(&transmit.data.build())
242                        .unwrap();
243                    client.poll(now);
244                    let Some(data) = client.poll_encrypted() else {
245                        continue;
246                    };
247                    return Some(TransmitBuild::new(
248                        DelayedMessageOrChannelSend::Owned(data),
249                        transmit.transport,
250                        listen_address,
251                        client.client_addr,
252                    ));
253                } else {
254                    let transmit = transmit.build();
255                    return Some(TransmitBuild::new(
256                        DelayedMessageOrChannelSend::Owned(transmit.data),
257                        transmit.transport,
258                        transmit.from,
259                        transmit.to,
260                    ));
261                }
262            }
263            None
264        } else if let Some(transmit) = self.server.recv(transmit, now) {
265            // incoming allocated address
266            if transmit.from == listen_address {
267                let Some(client) = self
268                    .clients
269                    .iter_mut()
270                    .find(|client| transmit.to == client.client_addr)
271                else {
272                    return Some(transmit);
273                };
274
275                let _ = client.dtls.send_application_data(&transmit.data.build());
276                client.poll(now);
277                client.poll_encrypted().map(|encrypted| {
278                    TransmitBuild::new(
279                        DelayedMessageOrChannelSend::Owned(encrypted),
280                        transmit.transport,
281                        listen_address,
282                        client.client_addr,
283                    )
284                })
285            } else {
286                Some(transmit)
287            }
288        } else {
289            None
290        }
291    }
292
293    fn recv_icmp<T: AsRef<[u8]>>(
294        &mut self,
295        family: AddressFamily,
296        bytes: T,
297        now: Instant,
298    ) -> Option<Transmit<Vec<u8>>> {
299        let transmit = self.server.recv_icmp(family, bytes, now)?;
300        // incoming allocated address
301        let listen_address = self.listen_address();
302        if transmit.from == listen_address {
303            let Some(client) = self
304                .clients
305                .iter_mut()
306                .find(|client| transmit.to == client.client_addr)
307            else {
308                return Some(transmit);
309            };
310
311            client.dtls.send_application_data(&transmit.data).unwrap();
312            client.poll(now);
313            client.poll_encrypted().map(|encrypted| {
314                Transmit::new(
315                    encrypted,
316                    transmit.transport,
317                    listen_address,
318                    client.client_addr,
319                )
320            })
321        } else {
322            Some(transmit)
323        }
324    }
325
326    /// Poll the [`TurnServer`] in order to make further progress.
327    ///
328    /// The returned value indicates what the caller should do.
329    fn poll(&mut self, now: Instant) -> TurnServerPollRet {
330        let protocol_ret = self.server.poll(now);
331        let mut have_pending = false;
332        for client in self.clients.iter_mut() {
333            client.poll(now);
334            if !client.pending_encrypted.is_empty() {
335                have_pending = true;
336                continue;
337            }
338        }
339        if have_pending {
340            return TurnServerPollRet::WaitUntil(now);
341        }
342        protocol_ret
343    }
344
345    /// Poll for a new Transmit to send over a socket.
346    fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Vec<u8>>> {
347        let listen_address = self.listen_address();
348
349        for client in self.clients.iter_mut() {
350            if let Some(data) = client.poll_encrypted() {
351                return Some(Transmit::new(
352                    data,
353                    TransportType::Udp,
354                    listen_address,
355                    client.client_addr,
356                ));
357            }
358        }
359
360        while let Some(transmit) = self.server.poll_transmit(now) {
361            let Some(client) = self
362                .clients
363                .iter_mut()
364                .find(|client| transmit.to == client.client_addr)
365            else {
366                warn!("return transmit: {transmit:?}");
367                return Some(transmit);
368            };
369            client.dtls.send_application_data(&transmit.data).unwrap();
370            client.poll(now);
371
372            if let Some(data) = client.poll_encrypted() {
373                return Some(Transmit::new(
374                    data,
375                    TransportType::Udp,
376                    listen_address,
377                    client.client_addr,
378                ));
379            }
380        }
381        None
382    }
383
384    /// Notify the [`TurnServer`] that a socket has been allocated (or an error) in response to
385    /// [TurnServerPollRet::AllocateSocket].
386    fn allocated_socket(
387        &mut self,
388        transport: TransportType,
389        local_addr: SocketAddr,
390        remote_addr: SocketAddr,
391        allocation_transport: TransportType,
392        family: AddressFamily,
393        socket_addr: Result<SocketAddr, SocketAllocateError>,
394        now: Instant,
395    ) {
396        self.server.allocated_socket(
397            transport,
398            local_addr,
399            remote_addr,
400            allocation_transport,
401            family,
402            socket_addr,
403            now,
404        )
405    }
406
407    fn tcp_connected(
408        &mut self,
409        relayed_addr: SocketAddr,
410        peer_addr: SocketAddr,
411        listen_addr: SocketAddr,
412        client_addr: SocketAddr,
413        socket_addr: Result<SocketAddr, api::TcpConnectError>,
414        now: Instant,
415    ) {
416        self.server.tcp_connected(
417            relayed_addr,
418            peer_addr,
419            listen_addr,
420            client_addr,
421            socket_addr,
422            now,
423        )
424    }
425}
426
427
428#[cfg(test)]
429mod tests {
430    use tracing::subscriber::DefaultGuard;
431    use tracing_subscriber::layer::SubscriberExt;
432    use tracing_subscriber::Layer;
433
434    use super::*;
435
436    fn test_init_log() -> DefaultGuard {
437        crate::proto::types::debug_init();
438        let level_filter = std::env::var("TURN_LOG")
439            .or(std::env::var("RUST_LOG"))
440            .ok()
441            .and_then(|var| var.parse::<tracing_subscriber::filter::Targets>().ok())
442            .unwrap_or(
443                tracing_subscriber::filter::Targets::new().with_default(tracing::Level::TRACE),
444            );
445        let registry = tracing_subscriber::registry().with(
446            tracing_subscriber::fmt::layer()
447                .with_file(true)
448                .with_line_number(true)
449                .with_level(true)
450                .with_target(false)
451                .with_test_writer()
452                .with_filter(level_filter),
453        );
454        tracing::subscriber::set_default(registry)
455    }
456
457    fn generate_cert() -> dimpl::DtlsCertificate {
458        dimpl::certificate::generate_self_signed_certificate().unwrap()
459    }
460
461    #[test]
462    fn constructor() {
463        let _log = test_init_log();
464        let config = Arc::new(dimpl::Config::builder().build().unwrap());
465        let listen_addr = "127.0.0.1:3478".parse().unwrap();
466        let realm = String::from("realm");
467        let cert = generate_cert();
468        let server = DimplTurnServer::new(TransportType::Udp, listen_addr, realm, config, cert);
469        assert_eq!(server.listen_address(), listen_addr);
470    }
471}