turn_server_proto/
api.rs

1// Copyright (C) 2025 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! API for TURN servers.
10
11use alloc::string::String;
12use alloc::vec::Vec;
13use core::{net::SocketAddr, time::Duration};
14
15use stun_proto::agent::Transmit;
16use stun_proto::Instant;
17use turn_types::prelude::DelayedTransmitBuild;
18use turn_types::stun::{attribute::ErrorCode, TransportType};
19use turn_types::transmit::{DelayedChannel, DelayedMessage, TransmitBuild};
20use turn_types::AddressFamily;
21
22/// API for TURN servers.
23pub trait TurnServerApi: Send + core::fmt::Debug {
24    /// Add a user credentials that would be accepted by this [`TurnServerApi`].
25    fn add_user(&mut self, username: String, password: String);
26    /// The address that the [`TurnServerApi`] is listening on for incoming client connections.
27    fn listen_address(&self) -> SocketAddr;
28    /// Set the amount of time that a Nonce (used for authentication) will expire and a new Nonce
29    /// will need to be acquired by a client.
30    fn set_nonce_expiry_duration(&mut self, expiry_duration: Duration);
31    /// Provide received data to the [`TurnServerApi`].
32    ///
33    /// Any returned Transmit should be forwarded to the appropriate socket.
34    fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
35        &mut self,
36        transmit: Transmit<T>,
37        now: Instant,
38    ) -> Option<TransmitBuild<DelayedMessageOrChannelSend<T>>>;
39    /// Provide a received ICMP packet to the [`TurnServerApi`].
40    ///
41    /// Any returned Transmit should be forwarded to the appropriate socket.
42    fn recv_icmp<T: AsRef<[u8]>>(
43        &mut self,
44        family: AddressFamily,
45        bytes: T,
46        now: Instant,
47    ) -> Option<Transmit<Vec<u8>>>;
48    /// Poll the [`TurnServerApi`] in order to make further progress.
49    ///
50    /// The returned value indicates what the caller should do.
51    fn poll(&mut self, now: Instant) -> TurnServerPollRet;
52    /// Poll for a new Transmit to send over a socket.
53    fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Vec<u8>>>;
54    /// Notify the [`TurnServerApi`] that a UDP socket has been allocated (or an error) in response to
55    /// [TurnServerPollRet::AllocateSocketUdp].
56    fn allocated_udp_socket(
57        &mut self,
58        transport: TransportType,
59        local_addr: SocketAddr,
60        remote_addr: SocketAddr,
61        family: AddressFamily,
62        socket_addr: Result<SocketAddr, SocketAllocateError>,
63        now: Instant,
64    );
65}
66
67/// Return value for [poll](TurnServerApi::poll).
68#[derive(Debug)]
69pub enum TurnServerPollRet {
70    /// Wait until the specified time before calling poll() again.
71    WaitUntil(Instant),
72    /// Allocate a UDP socket for a client specified by the client's network 5-tuple.
73    AllocateSocketUdp {
74        /// The transport of the client asking for an allocation.
75        transport: TransportType,
76        /// The TURN server address of the client asking for an allocation.
77        local_addr: SocketAddr,
78        /// The client local address of the client asking for an allocation.
79        remote_addr: SocketAddr,
80        /// The address family of the request for an allocation.
81        family: AddressFamily,
82    },
83}
84
85/// Errors that can be conveyed when allocating a socket for a client.
86#[derive(Debug, Clone, Copy, thiserror::Error, PartialEq, Eq)]
87pub enum SocketAllocateError {
88    /// The requested Address family is not supported.
89    #[error("The address family is not supported.")]
90    AddressFamilyNotSupported,
91    /// The server does not have the capacity to handle this request.
92    #[error("The server does not have the capacity to handle this request.")]
93    InsufficientCapacity,
94}
95
96impl SocketAllocateError {
97    /// Convert this error into an error code for the `ErrorCode` or `AddressErrorCode` attributes.
98    pub fn into_error_code(self) -> u16 {
99        match self {
100            Self::AddressFamilyNotSupported => ErrorCode::ADDRESS_FAMILY_NOT_SUPPORTED,
101            Self::InsufficientCapacity => ErrorCode::INSUFFICIENT_CAPACITY,
102        }
103    }
104}
105
106/// Transmission data that needs to be constructed before transmit.
107#[derive(Debug)]
108pub enum DelayedMessageOrChannelSend<T: AsRef<[u8]> + core::fmt::Debug> {
109    /// A STUN Message.
110    Message(DelayedMessage<T>),
111    /// A Turn Channel Data.
112    Channel(DelayedChannel<T>),
113    /// An already constructed piece of data.
114    Owned(Vec<u8>),
115    /// A subset of the incoming data.
116    Range(T, core::ops::Range<usize>),
117}
118
119impl<T: AsRef<[u8]> + core::fmt::Debug> DelayedTransmitBuild for DelayedMessageOrChannelSend<T> {
120    fn len(&self) -> usize {
121        match self {
122            Self::Message(msg) => msg.len(),
123            Self::Channel(channel) => channel.len(),
124            Self::Owned(v) => v.len(),
125            Self::Range(_data, range) => range.end - range.start,
126        }
127    }
128
129    fn build(self) -> Vec<u8> {
130        match self {
131            Self::Message(msg) => msg.build(),
132            Self::Channel(channel) => channel.build(),
133            Self::Owned(v) => v,
134            Self::Range(data, range) => data.as_ref()[range.start..range.end].to_vec(),
135        }
136    }
137    fn is_empty(&self) -> bool {
138        match self {
139            Self::Message(msg) => msg.is_empty(),
140            Self::Channel(channel) => channel.is_empty(),
141            Self::Owned(v) => v.is_empty(),
142            Self::Range(_data, range) => range.end == range.start,
143        }
144    }
145    fn write_into(self, data: &mut [u8]) -> usize {
146        match self {
147            Self::Message(msg) => msg.write_into(data),
148            Self::Channel(channel) => channel.write_into(data),
149            Self::Owned(v) => v.write_into(data),
150            Self::Range(src, range) => {
151                data.copy_from_slice(&src.as_ref()[range.start..range.end]);
152                range.end - range.start
153            }
154        }
155    }
156}
157
#[cfg(test)]
mod tests {
    use alloc::vec;

    use turn_types::attribute::Data as AData;
    use turn_types::attribute::XorPeerAddress;
    use turn_types::channel::ChannelData;
    use turn_types::stun::message::Message;

    use super::*;

    /// A fixed `(local, remote)` address pair shared by every test.
    fn generate_addresses() -> (SocketAddr, SocketAddr) {
        let local: SocketAddr = "192.168.0.1:1000".parse().unwrap();
        let remote: SocketAddr = "10.0.0.2:2000".parse().unwrap();
        (local, remote)
    }

    #[test]
    fn test_delayed_message() {
        let (local_addr, remote_addr) = generate_addresses();
        let payload: [u8; 5] = [5; 5];
        let peer_addr: SocketAddr = "127.0.0.1:1".parse().unwrap();
        // Building consumes the transmit, so each output path gets a fresh, identical one.
        let make_transmit = || {
            TransmitBuild::new(
                DelayedMessageOrChannelSend::Message(DelayedMessage::for_server(
                    peer_addr, payload,
                )),
                TransportType::Udp,
                local_addr,
                remote_addr,
            )
        };
        // Both output paths must yield a STUN message carrying the peer address and payload.
        let verify = |bytes: &[u8]| {
            let msg = Message::from_bytes(bytes).unwrap();
            let xor_peer = msg.attribute::<XorPeerAddress>().unwrap();
            assert_eq!(xor_peer.addr(msg.transaction_id()), peer_addr);
            let attr_data = msg.attribute::<AData>().unwrap();
            assert_eq!(attr_data.data(), payload.as_ref());
        };

        // Allocating path.
        let transmit = make_transmit();
        assert!(!transmit.data.is_empty());
        let expected_len = transmit.data.len();
        let built = transmit.build();
        assert_eq!(expected_len, built.data.len());
        verify(&built.data[..]);

        // In-place path into an exactly sized buffer.
        let mut buf = vec![0; expected_len];
        make_transmit().write_into(&mut buf);
        verify(&buf[..]);
    }

    #[test]
    fn test_delayed_channel() {
        let (local_addr, remote_addr) = generate_addresses();
        let payload: [u8; 5] = [5; 5];
        let channel_id = 0x4567;
        let make_transmit = || {
            TransmitBuild::new(
                DelayedMessageOrChannelSend::Channel(DelayedChannel::new(channel_id, payload)),
                TransportType::Udp,
                local_addr,
                remote_addr,
            )
        };

        // Allocating path.
        let transmit = make_transmit();
        assert!(!transmit.data.is_empty());
        let expected_len = transmit.data.len();
        let built = transmit.build();
        assert_eq!(expected_len, built.data.len());
        let parsed = ChannelData::parse(&built.data).unwrap();
        assert_eq!(parsed.id(), channel_id);
        assert_eq!(parsed.data(), payload.as_ref());

        // In-place path into an exactly sized buffer.
        let mut buf = vec![0; expected_len];
        make_transmit().write_into(&mut buf);
        assert_eq!(expected_len, buf.len());
        let parsed = ChannelData::parse(&buf).unwrap();
        assert_eq!(parsed.id(), channel_id);
        assert_eq!(parsed.data(), payload.as_ref());
    }

    #[test]
    fn test_delayed_owned() {
        let (local_addr, remote_addr) = generate_addresses();
        let bytes: Vec<u8> = vec![7; 7];
        let make_transmit = || {
            TransmitBuild::new(
                DelayedMessageOrChannelSend::<Vec<u8>>::Owned(bytes.clone()),
                TransportType::Udp,
                local_addr,
                remote_addr,
            )
        };

        // Allocating path.
        let transmit = make_transmit();
        assert!(!transmit.data.is_empty());
        let expected_len = transmit.data.len();
        let built = transmit.build();
        assert_eq!(expected_len, built.data.len());
        assert_eq!(bytes, built.data);

        // In-place path into an exactly sized buffer.
        let mut buf = vec![0; expected_len];
        make_transmit().write_into(&mut buf);
        assert_eq!(expected_len, buf.len());
        assert_eq!(bytes, buf);
    }

    #[test]
    fn test_delayed_range() {
        let (local_addr, remote_addr) = generate_addresses();
        let bytes: Vec<u8> = vec![7; 7];
        let range: core::ops::Range<usize> = 2..6;
        const LEN: usize = 4;
        let make_transmit = || {
            TransmitBuild::new(
                DelayedMessageOrChannelSend::Range(bytes.clone(), range.clone()),
                TransportType::Udp,
                local_addr,
                remote_addr,
            )
        };
        let expected: &[u8] = &bytes[range.clone()];

        // Allocating path.
        let transmit = make_transmit();
        let expected_len = transmit.data.len();
        assert_eq!(expected_len, LEN);
        let built = transmit.build();
        assert_eq!(expected_len, built.data.len());
        assert_eq!(expected, &built.data[..]);

        // In-place path into an exactly sized buffer.
        let mut buf = vec![0; expected_len];
        make_transmit().write_into(&mut buf);
        assert_eq!(expected_len, buf.len());
        assert_eq!(expected, &buf[..]);
    }
}