Skip to main content

turn_server_proto/
api.rs

1// Copyright (C) 2025 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8//
9// SPDX-License-Identifier: MIT OR Apache-2.0
10
11//! API for TURN servers.
12
13use alloc::string::String;
14use alloc::vec::Vec;
15use core::{net::SocketAddr, time::Duration};
16
17pub use stun_proto::agent::Transmit;
18use turn_types::prelude::DelayedTransmitBuild;
19use turn_types::stun::{attribute::ErrorCode, TransportType};
20use turn_types::transmit::{DelayedChannel, DelayedMessage, TransmitBuild};
21use turn_types::AddressFamily;
22use turn_types::Instant;
23
24/// API for TURN servers.
25pub trait TurnServerApi: Send + core::fmt::Debug {
26    /// Add a user credentials that would be accepted by this [`TurnServerApi`].
27    fn add_user(&mut self, username: String, password: String);
28    /// The address that the [`TurnServerApi`] is listening on for incoming client connections.
29    fn listen_address(&self) -> SocketAddr;
30    /// Set the amount of time that a Nonce (used for authentication) will expire and a new Nonce
31    /// will need to be acquired by a client.
32    fn set_nonce_expiry_duration(&mut self, expiry_duration: Duration);
33    /// Provide received data to the [`TurnServerApi`].
34    ///
35    /// Any returned Transmit should be forwarded to the appropriate socket.
36    fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
37        &mut self,
38        transmit: Transmit<T>,
39        now: Instant,
40    ) -> Option<TransmitBuild<DelayedMessageOrChannelSend<T>>>;
41    /// Provide a received ICMP packet to the [`TurnServerApi`].
42    ///
43    /// Any returned Transmit should be forwarded to the appropriate socket.
44    fn recv_icmp<T: AsRef<[u8]>>(
45        &mut self,
46        family: AddressFamily,
47        bytes: T,
48        now: Instant,
49    ) -> Option<Transmit<Vec<u8>>>;
50    /// Poll the [`TurnServerApi`] in order to make further progress.
51    ///
52    /// The returned value indicates what the caller should do.
53    fn poll(&mut self, now: Instant) -> TurnServerPollRet;
54    /// Poll for a new Transmit to send over a socket.
55    fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Vec<u8>>>;
56    /// Notify the [`TurnServerApi`] that a UDP socket has been allocated (or an error) in response to
57    /// [TurnServerPollRet::AllocateSocket].
58    #[allow(clippy::too_many_arguments)]
59    fn allocated_socket(
60        &mut self,
61        transport: TransportType,
62        listen_addr: SocketAddr,
63        client_addr: SocketAddr,
64        allocation_transport: TransportType,
65        family: AddressFamily,
66        socket_addr: Result<SocketAddr, SocketAllocateError>,
67        now: Instant,
68    );
69    /// Indicate that a TCP connection has been configured (or an error) for a client to
70    /// connect over TCP with a peer.
71    fn tcp_connected(
72        &mut self,
73        relayed_addr: SocketAddr,
74        peer_addr: SocketAddr,
75        listen_addr: SocketAddr,
76        client_addr: SocketAddr,
77        socket_addr: Result<SocketAddr, TcpConnectError>,
78        now: Instant,
79    );
80}
81
82/// Return value for [poll](TurnServerApi::poll).
83#[derive(Debug)]
84pub enum TurnServerPollRet {
85    /// Wait until the specified time before calling poll() again.
86    WaitUntil(Instant),
87    /// Allocate a listening socket for a client specified by the client's network 5-tuple.
88    AllocateSocket {
89        /// The transport of the client asking for an allocation.
90        transport: TransportType,
91        /// The TURN server address of the client asking for an allocation.
92        listen_addr: SocketAddr,
93        /// The client local address of the client asking for an allocation.
94        client_addr: SocketAddr,
95        /// The requested allocation transport.
96        allocation_transport: TransportType,
97        /// The address family of the request for an allocation.
98        family: AddressFamily,
99    },
100    /// Connect to a peer over TCP.
101    TcpConnect {
102        /// The relayed address to connect from.
103        relayed_addr: SocketAddr,
104        /// The peer to connect to.
105        peer_addr: SocketAddr,
106        /// The TURN server address of the client asking for an allocation.
107        listen_addr: SocketAddr,
108        /// The client's local address (TURN server remote) of the client asking for an allocation.
109        client_addr: SocketAddr,
110    },
111    /// Close a TCP connection between the TURN server and a peer/client.
112    ///
113    /// The connection can be in progress of being setup.
114    ///
115    /// Connections to both TURN clients and peers can be signalled through this variant.
116    TcpClose {
117        /// The socket address local to the TURN server.
118        local_addr: SocketAddr,
119        /// The socket address of the remote peer.
120        remote_addr: SocketAddr,
121    },
122    /// Stop listening on a socket.
123    SocketClose {
124        /// The transport of the socket.
125        transport: TransportType,
126        /// The listening socket address.
127        listen_addr: SocketAddr,
128    },
129}
130
131/// Errors that can be conveyed when allocating a socket for a client.
132#[derive(Debug, Clone, Copy, thiserror::Error, PartialEq, Eq)]
133pub enum SocketAllocateError {
134    /// The requested Address family is not supported.
135    #[error("The address family is not supported.")]
136    AddressFamilyNotSupported,
137    /// The server does not have the capacity to handle this request.
138    #[error("The server does not have the capacity to handle this request.")]
139    InsufficientCapacity,
140}
141
142impl SocketAllocateError {
143    /// Convert this error into an error code for the `ErrorCode` or `AddressErrorCode` attributes.
144    pub fn into_error_code(self) -> u16 {
145        match self {
146            Self::AddressFamilyNotSupported => ErrorCode::ADDRESS_FAMILY_NOT_SUPPORTED,
147            Self::InsufficientCapacity => ErrorCode::INSUFFICIENT_CAPACITY,
148        }
149    }
150}
151
152/// Errors that can be conveyed when allocating a socket for a client.
153#[derive(Debug, Clone, Copy, thiserror::Error, PartialEq, Eq)]
154pub enum TcpConnectError {
155    /// The server does not have the capacity to handle this request.
156    #[error("The server does not have the capacity to handle this request.")]
157    InsufficientCapacity,
158    /// Connection is forbidden by local policy.
159    #[error("Connection is forbidden by local policy.")]
160    Forbidden,
161    /// Timed Out attempting to connect to the specified peer.
162    #[error("Timed out attempting to connect to the specifid peer.")]
163    TimedOut,
164    /// Faild for any other unspecified reason.
165    #[error("Failed for any other unspecified reason.")]
166    Failure,
167}
168
169impl TcpConnectError {
170    /// Convert this error into an error code for the `ErrorCode` or `AddressErrorCode` attributes.
171    pub fn into_error_code(self) -> u16 {
172        match self {
173            Self::InsufficientCapacity => ErrorCode::INSUFFICIENT_CAPACITY,
174            Self::Forbidden => ErrorCode::FORBIDDEN,
175            Self::TimedOut | Self::Failure => ErrorCode::CONNECTION_TIMEOUT_OR_FAILURE,
176        }
177    }
178}
179
180/// Transmission data that needs to be constructed before transmit.
181#[derive(Debug)]
182pub enum DelayedMessageOrChannelSend<T: AsRef<[u8]> + core::fmt::Debug> {
183    /// A STUN Message.
184    Message(DelayedMessage<T>),
185    /// A Turn Channel Data.
186    Channel(DelayedChannel<T>),
187    /// An already constructed piece of data.
188    Owned(Vec<u8>),
189    /// A subset of the incoming data.
190    Range(T, core::ops::Range<usize>),
191}
192
193impl<T: AsRef<[u8]> + core::fmt::Debug> DelayedTransmitBuild for DelayedMessageOrChannelSend<T> {
194    fn len(&self) -> usize {
195        match self {
196            Self::Message(msg) => msg.len(),
197            Self::Channel(channel) => channel.len(),
198            Self::Owned(v) => v.len(),
199            Self::Range(_data, range) => range.end - range.start,
200        }
201    }
202
203    fn build(self) -> Vec<u8> {
204        match self {
205            Self::Message(msg) => msg.build(),
206            Self::Channel(channel) => channel.build(),
207            Self::Owned(v) => v,
208            Self::Range(data, range) => data.as_ref()[range.start..range.end].to_vec(),
209        }
210    }
211    fn is_empty(&self) -> bool {
212        match self {
213            Self::Message(msg) => msg.is_empty(),
214            Self::Channel(channel) => channel.is_empty(),
215            Self::Owned(v) => v.is_empty(),
216            Self::Range(_data, range) => range.end == range.start,
217        }
218    }
219    fn write_into(self, data: &mut [u8]) -> usize {
220        match self {
221            Self::Message(msg) => msg.write_into(data),
222            Self::Channel(channel) => channel.write_into(data),
223            Self::Owned(v) => v.write_into(data),
224            Self::Range(src, range) => {
225                data.copy_from_slice(&src.as_ref()[range.start..range.end]);
226                range.end - range.start
227            }
228        }
229    }
230}
231
#[cfg(test)]
mod tests {
    use alloc::vec;

    use turn_types::attribute::Data as AData;
    use turn_types::attribute::XorPeerAddress;
    use turn_types::channel::ChannelData;
    use turn_types::stun::message::Message;

    use super::*;

    /// Local and remote addresses used for every transmit in these tests.
    fn generate_addresses() -> (SocketAddr, SocketAddr) {
        (
            "192.168.0.1:1000".parse().unwrap(),
            "10.0.0.2:2000".parse().unwrap(),
        )
    }

    #[test]
    fn test_delayed_message() {
        let (local_addr, remote_addr) = generate_addresses();
        let data = [5; 5];
        let peer_addr = "127.0.0.1:1".parse().unwrap();
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::Message(DelayedMessage::for_server(peer_addr, data)),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        assert!(!transmit.data.is_empty());
        let len = transmit.data.len();
        let out = transmit.build();
        assert_eq!(len, out.data.len());
        let msg = Message::from_bytes(&out.data).unwrap();
        let addr = msg.attribute::<XorPeerAddress>().unwrap();
        assert_eq!(addr.addr(msg.transaction_id()), peer_addr);
        let out_data = msg.attribute::<AData>().unwrap();
        assert_eq!(out_data.data(), data.as_ref());
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::Message(DelayedMessage::for_server(peer_addr, data)),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        let mut out2 = vec![0; len];
        transmit.write_into(&mut out2);
        let msg = Message::from_bytes(&out2).unwrap();
        let addr = msg.attribute::<XorPeerAddress>().unwrap();
        assert_eq!(addr.addr(msg.transaction_id()), peer_addr);
        let out_data = msg.attribute::<AData>().unwrap();
        assert_eq!(out_data.data(), data.as_ref());
    }

    #[test]
    fn test_delayed_channel() {
        let (local_addr, remote_addr) = generate_addresses();
        let data = [5; 5];
        let channel_id = 0x4567;
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::Channel(DelayedChannel::new(channel_id, data)),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        assert!(!transmit.data.is_empty());
        let len = transmit.data.len();
        let out = transmit.build();
        assert_eq!(len, out.data.len());
        let channel = ChannelData::parse(&out.data).unwrap();
        assert_eq!(channel.id(), channel_id);
        assert_eq!(channel.data(), data.as_ref());
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::Channel(DelayedChannel::new(channel_id, data)),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        let mut out2 = vec![0; len];
        transmit.write_into(&mut out2);
        assert_eq!(len, out2.len());
        let channel = ChannelData::parse(&out2).unwrap();
        assert_eq!(channel.id(), channel_id);
        assert_eq!(channel.data(), data.as_ref());
    }

    #[test]
    fn test_delayed_owned() {
        let (local_addr, remote_addr) = generate_addresses();
        let data = vec![7; 7];
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::<Vec<u8>>::Owned(data.clone()),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        assert!(!transmit.data.is_empty());
        let len = transmit.data.len();
        let out = transmit.build();
        assert_eq!(len, out.data.len());
        assert_eq!(data, out.data);
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::<Vec<u8>>::Owned(data.clone()),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        let mut out2 = vec![0; len];
        transmit.write_into(&mut out2);
        assert_eq!(len, out2.len());
        assert_eq!(data, out2);
    }

    #[test]
    fn test_delayed_range() {
        let (local_addr, remote_addr) = generate_addresses();
        // Distinct byte values so that selecting the wrong sub-slice (e.g. an off-by-one in the
        // range handling) is detected; a uniform buffer would make every slice look identical.
        let data: Vec<u8> = (0..7).collect();
        let range = 2..6;
        const LEN: usize = 4;
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::Range(data.clone(), range.clone()),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        assert!(!transmit.data.is_empty());
        let len = transmit.data.len();
        assert_eq!(len, LEN);
        let out = transmit.build();
        assert_eq!(len, out.data.len());
        assert_eq!(out.data, [2u8, 3, 4, 5]);
        assert_eq!(data[range.start..range.end], out.data);
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::Range(data.clone(), range.clone()),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        let mut out2 = vec![0; len];
        transmit.write_into(&mut out2);
        assert_eq!(len, out2.len());
        assert_eq!(out2, [2u8, 3, 4, 5]);
        assert_eq!(data[range.start..range.end], out2);
    }
}