1use alloc::string::String;
14use alloc::vec::Vec;
15use core::{net::SocketAddr, time::Duration};
16
17pub use stun_proto::agent::Transmit;
18use turn_types::prelude::DelayedTransmitBuild;
19use turn_types::stun::{attribute::ErrorCode, TransportType};
20use turn_types::transmit::{DelayedChannel, DelayedMessage, TransmitBuild};
21use turn_types::AddressFamily;
22use turn_types::Instant;
23
24pub trait TurnServerApi: Send + core::fmt::Debug {
26 fn add_user(&mut self, username: String, password: String);
28 fn listen_address(&self) -> SocketAddr;
30 fn set_nonce_expiry_duration(&mut self, expiry_duration: Duration);
33 fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
37 &mut self,
38 transmit: Transmit<T>,
39 now: Instant,
40 ) -> Option<TransmitBuild<DelayedMessageOrChannelSend<T>>>;
41 fn recv_icmp<T: AsRef<[u8]>>(
45 &mut self,
46 family: AddressFamily,
47 bytes: T,
48 now: Instant,
49 ) -> Option<Transmit<Vec<u8>>>;
50 fn poll(&mut self, now: Instant) -> TurnServerPollRet;
54 fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Vec<u8>>>;
56 #[allow(clippy::too_many_arguments)]
59 fn allocated_socket(
60 &mut self,
61 transport: TransportType,
62 listen_addr: SocketAddr,
63 client_addr: SocketAddr,
64 allocation_transport: TransportType,
65 family: AddressFamily,
66 socket_addr: Result<SocketAddr, SocketAllocateError>,
67 now: Instant,
68 );
69 fn tcp_connected(
72 &mut self,
73 relayed_addr: SocketAddr,
74 peer_addr: SocketAddr,
75 listen_addr: SocketAddr,
76 client_addr: SocketAddr,
77 socket_addr: Result<SocketAddr, TcpConnectError>,
78 now: Instant,
79 );
80}
81
82#[derive(Debug)]
84pub enum TurnServerPollRet {
85 WaitUntil(Instant),
87 AllocateSocket {
89 transport: TransportType,
91 listen_addr: SocketAddr,
93 client_addr: SocketAddr,
95 allocation_transport: TransportType,
97 family: AddressFamily,
99 },
100 TcpConnect {
102 relayed_addr: SocketAddr,
104 peer_addr: SocketAddr,
106 listen_addr: SocketAddr,
108 client_addr: SocketAddr,
110 },
111 TcpClose {
117 local_addr: SocketAddr,
119 remote_addr: SocketAddr,
121 },
122 SocketClose {
124 transport: TransportType,
126 listen_addr: SocketAddr,
128 },
129}
130
131#[derive(Debug, Clone, Copy, thiserror::Error, PartialEq, Eq)]
133pub enum SocketAllocateError {
134 #[error("The address family is not supported.")]
136 AddressFamilyNotSupported,
137 #[error("The server does not have the capacity to handle this request.")]
139 InsufficientCapacity,
140}
141
142impl SocketAllocateError {
143 pub fn into_error_code(self) -> u16 {
145 match self {
146 Self::AddressFamilyNotSupported => ErrorCode::ADDRESS_FAMILY_NOT_SUPPORTED,
147 Self::InsufficientCapacity => ErrorCode::INSUFFICIENT_CAPACITY,
148 }
149 }
150}
151
152#[derive(Debug, Clone, Copy, thiserror::Error, PartialEq, Eq)]
154pub enum TcpConnectError {
155 #[error("The server does not have the capacity to handle this request.")]
157 InsufficientCapacity,
158 #[error("Connection is forbidden by local policy.")]
160 Forbidden,
161 #[error("Timed out attempting to connect to the specifid peer.")]
163 TimedOut,
164 #[error("Failed for any other unspecified reason.")]
166 Failure,
167}
168
169impl TcpConnectError {
170 pub fn into_error_code(self) -> u16 {
172 match self {
173 Self::InsufficientCapacity => ErrorCode::INSUFFICIENT_CAPACITY,
174 Self::Forbidden => ErrorCode::FORBIDDEN,
175 Self::TimedOut | Self::Failure => ErrorCode::CONNECTION_TIMEOUT_OR_FAILURE,
176 }
177 }
178}
179
180#[derive(Debug)]
182pub enum DelayedMessageOrChannelSend<T: AsRef<[u8]> + core::fmt::Debug> {
183 Message(DelayedMessage<T>),
185 Channel(DelayedChannel<T>),
187 Owned(Vec<u8>),
189 Range(T, core::ops::Range<usize>),
191}
192
193impl<T: AsRef<[u8]> + core::fmt::Debug> DelayedTransmitBuild for DelayedMessageOrChannelSend<T> {
194 fn len(&self) -> usize {
195 match self {
196 Self::Message(msg) => msg.len(),
197 Self::Channel(channel) => channel.len(),
198 Self::Owned(v) => v.len(),
199 Self::Range(_data, range) => range.end - range.start,
200 }
201 }
202
203 fn build(self) -> Vec<u8> {
204 match self {
205 Self::Message(msg) => msg.build(),
206 Self::Channel(channel) => channel.build(),
207 Self::Owned(v) => v,
208 Self::Range(data, range) => data.as_ref()[range.start..range.end].to_vec(),
209 }
210 }
211 fn is_empty(&self) -> bool {
212 match self {
213 Self::Message(msg) => msg.is_empty(),
214 Self::Channel(channel) => channel.is_empty(),
215 Self::Owned(v) => v.is_empty(),
216 Self::Range(_data, range) => range.end == range.start,
217 }
218 }
219 fn write_into(self, data: &mut [u8]) -> usize {
220 match self {
221 Self::Message(msg) => msg.write_into(data),
222 Self::Channel(channel) => channel.write_into(data),
223 Self::Owned(v) => v.write_into(data),
224 Self::Range(src, range) => {
225 data.copy_from_slice(&src.as_ref()[range.start..range.end]);
226 range.end - range.start
227 }
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    use alloc::vec;

    use turn_types::attribute::Data as AData;
    use turn_types::attribute::XorPeerAddress;
    use turn_types::channel::ChannelData;
    use turn_types::stun::message::Message;

    use super::*;

    /// Fixed `(local, remote)` address pair shared by every test.
    fn generate_addresses() -> (SocketAddr, SocketAddr) {
        let local = "192.168.0.1:1000".parse().unwrap();
        let remote = "10.0.0.2:2000".parse().unwrap();
        (local, remote)
    }

    /// Wrap `payload` in a UDP [`TransmitBuild`] between the fixed test addresses.
    fn udp_transmit<T>(
        payload: DelayedMessageOrChannelSend<T>,
    ) -> TransmitBuild<DelayedMessageOrChannelSend<T>>
    where
        T: AsRef<[u8]> + core::fmt::Debug,
    {
        let (local_addr, remote_addr) = generate_addresses();
        TransmitBuild::new(payload, TransportType::Udp, local_addr, remote_addr)
    }

    #[test]
    fn test_delayed_message() {
        let payload = [5u8; 5];
        let peer_addr: SocketAddr = "127.0.0.1:1".parse().unwrap();
        let make = || {
            udp_transmit(DelayedMessageOrChannelSend::Message(
                DelayedMessage::for_server(peer_addr, payload),
            ))
        };
        // Both output paths must yield a message carrying the peer address and the payload.
        let verify = |bytes: &[u8]| {
            let msg = Message::from_bytes(bytes).unwrap();
            let addr = msg.attribute::<XorPeerAddress>().unwrap();
            assert_eq!(addr.addr(msg.transaction_id()), peer_addr);
            let out_data = msg.attribute::<AData>().unwrap();
            assert_eq!(out_data.data(), &payload[..]);
        };

        // Path 1: allocate the output with `build()`.
        let transmit = make();
        assert!(!transmit.data.is_empty());
        let len = transmit.data.len();
        let built = transmit.build();
        assert_eq!(len, built.data.len());
        verify(&built.data);

        // Path 2: write into a caller-provided buffer of exactly the reported length.
        let mut written = vec![0; len];
        make().write_into(&mut written);
        verify(&written);
    }

    #[test]
    fn test_delayed_channel() {
        let payload = [5u8; 5];
        let channel_id = 0x4567;
        let make = || {
            udp_transmit(DelayedMessageOrChannelSend::Channel(DelayedChannel::new(
                channel_id, payload,
            )))
        };
        // Both output paths must parse back to the same channel id and payload.
        let verify = |bytes: &[u8]| {
            let channel = ChannelData::parse(bytes).unwrap();
            assert_eq!(channel.id(), channel_id);
            assert_eq!(channel.data(), &payload[..]);
        };

        // Path 1: allocate the output with `build()`.
        let transmit = make();
        assert!(!transmit.data.is_empty());
        let len = transmit.data.len();
        let built = transmit.build();
        assert_eq!(len, built.data.len());
        verify(&built.data);

        // Path 2: write into a caller-provided buffer of exactly the reported length.
        let mut written = vec![0; len];
        make().write_into(&mut written);
        assert_eq!(len, written.len());
        verify(&written);
    }

    #[test]
    fn test_delayed_owned() {
        let payload = vec![7; 7];
        let make = || udp_transmit(DelayedMessageOrChannelSend::<Vec<u8>>::Owned(payload.clone()));

        // Path 1: `build()` hands back the owned bytes unchanged.
        let transmit = make();
        assert!(!transmit.data.is_empty());
        let len = transmit.data.len();
        let built = transmit.build();
        assert_eq!(len, built.data.len());
        assert_eq!(payload, built.data);

        // Path 2: `write_into()` copies the same bytes into the provided buffer.
        let mut written = vec![0; len];
        make().write_into(&mut written);
        assert_eq!(len, written.len());
        assert_eq!(payload, written);
    }

    #[test]
    fn test_delayed_range() {
        const LEN: usize = 4;
        let payload = vec![7u8; 7];
        let range = 2..6;
        let make =
            || udp_transmit(DelayedMessageOrChannelSend::Range(payload.clone(), range.clone()));

        // Path 1: `build()` yields only the bytes inside the range.
        let transmit = make();
        let len = transmit.data.len();
        assert_eq!(len, LEN);
        let built = transmit.build();
        assert_eq!(len, built.data.len());
        assert_eq!(payload[range.start..range.end], built.data);

        // Path 2: `write_into()` copies the same sub-slice into the provided buffer.
        let mut written = vec![0; len];
        make().write_into(&mut written);
        assert_eq!(len, written.len());
        assert_eq!(payload[range.start..range.end], written);
    }
}