1use alloc::string::String;
12use alloc::vec::Vec;
13use core::{net::SocketAddr, time::Duration};
14
15use stun_proto::agent::Transmit;
16use stun_proto::Instant;
17use turn_types::prelude::DelayedTransmitBuild;
18use turn_types::stun::{attribute::ErrorCode, TransportType};
19use turn_types::transmit::{DelayedChannel, DelayedMessage, TransmitBuild};
20use turn_types::AddressFamily;
21
22pub trait TurnServerApi: Send + core::fmt::Debug {
24 fn add_user(&mut self, username: String, password: String);
26 fn listen_address(&self) -> SocketAddr;
28 fn set_nonce_expiry_duration(&mut self, expiry_duration: Duration);
31 fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
35 &mut self,
36 transmit: Transmit<T>,
37 now: Instant,
38 ) -> Option<TransmitBuild<DelayedMessageOrChannelSend<T>>>;
39 fn recv_icmp<T: AsRef<[u8]>>(
43 &mut self,
44 family: AddressFamily,
45 bytes: T,
46 now: Instant,
47 ) -> Option<Transmit<Vec<u8>>>;
48 fn poll(&mut self, now: Instant) -> TurnServerPollRet;
52 fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Vec<u8>>>;
54 fn allocated_udp_socket(
57 &mut self,
58 transport: TransportType,
59 local_addr: SocketAddr,
60 remote_addr: SocketAddr,
61 family: AddressFamily,
62 socket_addr: Result<SocketAddr, SocketAllocateError>,
63 now: Instant,
64 );
65}
66
67#[derive(Debug)]
69pub enum TurnServerPollRet {
70 WaitUntil(Instant),
72 AllocateSocketUdp {
74 transport: TransportType,
76 local_addr: SocketAddr,
78 remote_addr: SocketAddr,
80 family: AddressFamily,
82 },
83}
84
85#[derive(Debug, Clone, Copy, thiserror::Error, PartialEq, Eq)]
87pub enum SocketAllocateError {
88 #[error("The address family is not supported.")]
90 AddressFamilyNotSupported,
91 #[error("The server does not have the capacity to handle this request.")]
93 InsufficientCapacity,
94}
95
96impl SocketAllocateError {
97 pub fn into_error_code(self) -> u16 {
99 match self {
100 Self::AddressFamilyNotSupported => ErrorCode::ADDRESS_FAMILY_NOT_SUPPORTED,
101 Self::InsufficientCapacity => ErrorCode::INSUFFICIENT_CAPACITY,
102 }
103 }
104}
105
106#[derive(Debug)]
108pub enum DelayedMessageOrChannelSend<T: AsRef<[u8]> + core::fmt::Debug> {
109 Message(DelayedMessage<T>),
111 Channel(DelayedChannel<T>),
113 Owned(Vec<u8>),
115 Range(T, core::ops::Range<usize>),
117}
118
119impl<T: AsRef<[u8]> + core::fmt::Debug> DelayedTransmitBuild for DelayedMessageOrChannelSend<T> {
120 fn len(&self) -> usize {
121 match self {
122 Self::Message(msg) => msg.len(),
123 Self::Channel(channel) => channel.len(),
124 Self::Owned(v) => v.len(),
125 Self::Range(_data, range) => range.end - range.start,
126 }
127 }
128
129 fn build(self) -> Vec<u8> {
130 match self {
131 Self::Message(msg) => msg.build(),
132 Self::Channel(channel) => channel.build(),
133 Self::Owned(v) => v,
134 Self::Range(data, range) => data.as_ref()[range.start..range.end].to_vec(),
135 }
136 }
137 fn is_empty(&self) -> bool {
138 match self {
139 Self::Message(msg) => msg.is_empty(),
140 Self::Channel(channel) => channel.is_empty(),
141 Self::Owned(v) => v.is_empty(),
142 Self::Range(_data, range) => range.end == range.start,
143 }
144 }
145 fn write_into(self, data: &mut [u8]) -> usize {
146 match self {
147 Self::Message(msg) => msg.write_into(data),
148 Self::Channel(channel) => channel.write_into(data),
149 Self::Owned(v) => v.write_into(data),
150 Self::Range(src, range) => {
151 data.copy_from_slice(&src.as_ref()[range.start..range.end]);
152 range.end - range.start
153 }
154 }
155 }
156}
157
#[cfg(test)]
mod tests {
    use alloc::vec;

    use turn_types::attribute::Data as AData;
    use turn_types::attribute::XorPeerAddress;
    use turn_types::channel::ChannelData;
    use turn_types::stun::message::Message;

    use super::*;

    /// Fixed `(local, remote)` address pair shared by the transmit tests.
    fn generate_addresses() -> (SocketAddr, SocketAddr) {
        (
            "192.168.0.1:1000".parse().unwrap(),
            "10.0.0.2:2000".parse().unwrap(),
        )
    }

    /// A delayed STUN message builds, and writes, to a message carrying the peer address and
    /// the data, and its reported length matches the built size.
    #[test]
    fn test_delayed_message() {
        let (local_addr, remote_addr) = generate_addresses();
        let data = [5; 5];
        let peer_addr = "127.0.0.1:1".parse().unwrap();
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::Message(DelayedMessage::for_server(peer_addr, data)),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        assert!(!transmit.data.is_empty());
        let len = transmit.data.len();
        let out = transmit.build();
        assert_eq!(len, out.data.len());
        let msg = Message::from_bytes(&out.data).unwrap();
        let addr = msg.attribute::<XorPeerAddress>().unwrap();
        assert_eq!(addr.addr(msg.transaction_id()), peer_addr);
        let out_data = msg.attribute::<AData>().unwrap();
        assert_eq!(out_data.data(), data.as_ref());
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::Message(DelayedMessage::for_server(peer_addr, data)),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        let mut out2 = vec![0; len];
        transmit.write_into(&mut out2);
        let msg = Message::from_bytes(&out2).unwrap();
        let addr = msg.attribute::<XorPeerAddress>().unwrap();
        assert_eq!(addr.addr(msg.transaction_id()), peer_addr);
        let out_data = msg.attribute::<AData>().unwrap();
        assert_eq!(out_data.data(), data.as_ref());
    }

    /// A delayed channel send builds, and writes, to ChannelData with the expected id and
    /// payload.
    #[test]
    fn test_delayed_channel() {
        let (local_addr, remote_addr) = generate_addresses();
        let data = [5; 5];
        let channel_id = 0x4567;
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::Channel(DelayedChannel::new(channel_id, data)),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        assert!(!transmit.data.is_empty());
        let len = transmit.data.len();
        let out = transmit.build();
        assert_eq!(len, out.data.len());
        let channel = ChannelData::parse(&out.data).unwrap();
        assert_eq!(channel.id(), channel_id);
        assert_eq!(channel.data(), data.as_ref());
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::Channel(DelayedChannel::new(channel_id, data)),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        let mut out2 = vec![0; len];
        transmit.write_into(&mut out2);
        assert_eq!(len, out2.len());
        let channel = ChannelData::parse(&out2).unwrap();
        assert_eq!(channel.id(), channel_id);
        assert_eq!(channel.data(), data.as_ref());
    }

    /// Owned bytes are passed through unchanged by both `build` and `write_into`.
    #[test]
    fn test_delayed_owned() {
        let (local_addr, remote_addr) = generate_addresses();
        let data = vec![7; 7];
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::<Vec<u8>>::Owned(data.clone()),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        assert!(!transmit.data.is_empty());
        let len = transmit.data.len();
        let out = transmit.build();
        assert_eq!(len, out.data.len());
        assert_eq!(data, out.data);
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::<Vec<u8>>::Owned(data.clone()),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        let mut out2 = vec![0; len];
        transmit.write_into(&mut out2);
        assert_eq!(len, out2.len());
        assert_eq!(data, out2);
    }

    /// A range selects exactly the bytes `data[range]` for both `build` and `write_into`.
    #[test]
    fn test_delayed_range() {
        let (local_addr, remote_addr) = generate_addresses();
        let data = vec![7; 7];
        let range = 2..6;
        const LEN: usize = 4;
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::Range(data.clone(), range.clone()),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        assert!(!transmit.data.is_empty());
        let len = transmit.data.len();
        assert_eq!(len, LEN);
        let out = transmit.build();
        assert_eq!(len, out.data.len());
        assert_eq!(data[range.start..range.end], out.data);
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::Range(data.clone(), range.clone()),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        let mut out2 = vec![0; len];
        transmit.write_into(&mut out2);
        assert_eq!(len, out2.len());
        assert_eq!(data[range.start..range.end], out2);
    }

    /// An empty range reports zero length and builds to no bytes.
    #[test]
    fn test_delayed_range_empty() {
        let send = DelayedMessageOrChannelSend::Range(vec![7u8; 7], 3..3);
        assert!(send.is_empty());
        assert_eq!(send.len(), 0);
        assert!(send.build().is_empty());
    }

    /// Each allocation error maps to its dedicated STUN error code.
    #[test]
    fn test_socket_allocate_error_code() {
        assert_eq!(
            SocketAllocateError::AddressFamilyNotSupported.into_error_code(),
            ErrorCode::ADDRESS_FAMILY_NOT_SUPPORTED
        );
        assert_eq!(
            SocketAllocateError::InsufficientCapacity.into_error_code(),
            ErrorCode::INSUFFICIENT_CAPACITY
        );
    }
}