1#![deny(missing_debug_implementations)]
30#![deny(missing_docs)]
31#![cfg_attr(docsrs, feature(doc_cfg))]
32#![deny(clippy::std_instead_of_core)]
33#![deny(clippy::std_instead_of_alloc)]
34#![no_std]
35
36extern crate alloc;
37
38#[cfg(any(feature = "std", test))]
39extern crate std;
40
41use alloc::string::String;
42use alloc::sync::Arc;
43use alloc::vec;
44use alloc::vec::Vec;
45use core::net::SocketAddr;
46use core::time::Duration;
47use std::io::{Read, Write};
48
49use turn_server_proto::types::prelude::DelayedTransmitBuild;
50use turn_server_proto::types::transmit::TransmitBuild;
51use turn_server_proto::types::AddressFamily;
52
53use turn_server_proto::api::Transmit;
54use turn_server_proto::server::TurnServer;
55use turn_server_proto::types::stun::TransportType;
56use turn_server_proto::types::Instant;
57
58pub use turn_server_proto as proto;
59pub use turn_server_proto::api;
60
61use turn_server_proto::api::{
62 DelayedMessageOrChannelSend, SocketAllocateError, TurnServerApi, TurnServerPollRet,
63};
64
65use tracing::{info, trace, warn};
66
67use rustls::{ServerConfig, ServerConnection};
68
69#[derive(Debug)]
71pub struct RustlsTurnServer {
72 server: TurnServer,
73 config: Arc<ServerConfig>,
74 clients: Vec<Client>,
75}
76
77#[derive(Debug)]
78struct Client {
79 client_addr: SocketAddr,
80 tls: ServerConnection,
81 local_closed: bool,
82 peer_closed: bool,
83}
84
85impl RustlsTurnServer {
86 pub fn new(listen_addr: SocketAddr, realm: String, config: Arc<ServerConfig>) -> Self {
88 Self {
89 server: TurnServer::new(TransportType::Tcp, listen_addr, realm),
90 config,
91 clients: vec![],
92 }
93 }
94}
95
96impl TurnServerApi for RustlsTurnServer {
97 fn add_user(&mut self, username: String, password: String) {
99 self.server.add_user(username, password)
100 }
101
102 fn listen_address(&self) -> SocketAddr {
104 self.server.listen_address()
105 }
106
107 fn set_nonce_expiry_duration(&mut self, expiry_duration: Duration) {
110 self.server.set_nonce_expiry_duration(expiry_duration)
111 }
112
113 #[tracing::instrument(
117 name = "turn_server_rustls_recv",
118 skip(self, transmit, now),
119 fields(
120 from = ?transmit.from,
121 data_len = transmit.data.as_ref().len()
122 )
123 )]
124 fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
125 &mut self,
126 transmit: Transmit<T>,
127 now: Instant,
128 ) -> Option<TransmitBuild<DelayedMessageOrChannelSend<T>>> {
129 let listen_address = self.listen_address();
130 if transmit.transport == TransportType::Tcp && transmit.to == listen_address {
131 trace!("receiving TLS data: {:x?}", transmit.data.as_ref());
132 let client = match self
134 .clients
135 .iter_mut()
136 .find(|client| client.client_addr == transmit.from)
137 {
138 Some(client) => client,
139 None => {
140 if transmit.data.as_ref().is_empty() {
141 return None;
142 }
143 let len = self.clients.len();
144 self.clients.push(Client {
145 client_addr: transmit.from,
146 tls: ServerConnection::new(self.config.clone()).unwrap(),
147 local_closed: false,
148 peer_closed: false,
149 });
150 info!("new connection from {}", transmit.from);
151 &mut self.clients[len]
152 }
153 };
154 let mut input = std::io::Cursor::new(transmit.data.as_ref());
155 let io_state = match client.tls.read_tls(&mut input) {
156 Ok(_written) => match client.tls.process_new_packets() {
157 Ok(io_state) => io_state,
158 Err(e) => {
159 warn!("Error processing incoming TLS: {e:?}");
160 return None;
161 }
162 },
163 Err(e) => {
164 warn!("Error receiving data: {e:?}");
165 return None;
166 }
167 };
168 if io_state.peer_has_closed() {
169 client.peer_closed = true;
170 if !client.local_closed {
171 client.tls.send_close_notify();
172 client.local_closed = true;
173 let mut out = vec![];
174 client.tls.write_tls(&mut out).unwrap();
175 let client_addr = client.client_addr;
176 info!("client {client_addr} TLS closed");
177 return Some(TransmitBuild::new(
178 DelayedMessageOrChannelSend::Owned(out),
179 TransportType::Tcp,
180 listen_address,
181 client_addr,
182 ));
183 } else {
184 return None;
185 }
186 }
187 if io_state.plaintext_bytes_to_read() == 0 {
188 return None;
189 }
190 let mut vec = vec![0; 2048];
191 let n = match client.tls.reader().read(&mut vec) {
192 Ok(n) => n,
193 Err(e) => {
194 if e.kind() == std::io::ErrorKind::WouldBlock {
195 return None;
196 } else {
197 warn!("TLS error: {e:?}");
198 return None;
199 }
200 }
201 };
202 trace!("io_state: {io_state:?}, n: {n}");
203 vec.resize(n, 0);
204 let transmit = self.server.recv(
205 Transmit::new(vec, transmit.transport, transmit.from, transmit.to),
206 now,
207 )?;
208 if transmit.transport == TransportType::Tcp
209 && transmit.from == listen_address
210 && transmit.to == client.client_addr
211 {
212 let plaintext = transmit.data.build();
213 client.tls.writer().write_all(&plaintext).unwrap();
214 let mut out = vec![];
215 client.tls.write_tls(&mut out).unwrap();
216 Some(TransmitBuild::new(
217 DelayedMessageOrChannelSend::Owned(out),
218 TransportType::Tcp,
219 listen_address,
220 client.client_addr,
221 ))
222 } else {
223 let transmit = transmit.build();
224 Some(TransmitBuild::new(
225 DelayedMessageOrChannelSend::Owned(transmit.data),
226 transmit.transport,
227 transmit.from,
228 transmit.to,
229 ))
230 }
231 } else if let Some(transmit) = self.server.recv(transmit, now) {
232 if transmit.transport == TransportType::Tcp && transmit.from == listen_address {
234 let Some(client) = self
235 .clients
236 .iter_mut()
237 .find(|client| transmit.to == client.client_addr)
238 else {
239 return Some(transmit);
240 };
241 let plaintext = transmit.data.build();
242 client.tls.writer().write_all(&plaintext).unwrap();
243 let mut out = vec![];
244 client.tls.write_tls(&mut out).unwrap();
245 Some(TransmitBuild::new(
246 DelayedMessageOrChannelSend::Owned(out),
247 TransportType::Tcp,
248 listen_address,
249 client.client_addr,
250 ))
251 } else {
252 Some(transmit)
253 }
254 } else {
255 None
256 }
257 }
258
259 fn recv_icmp<T: AsRef<[u8]>>(
260 &mut self,
261 family: AddressFamily,
262 bytes: T,
263 now: Instant,
264 ) -> Option<Transmit<Vec<u8>>> {
265 let transmit = self.server.recv_icmp(family, bytes, now)?;
266 let listen_address = self.listen_address();
268 if transmit.transport == TransportType::Tcp && transmit.from == listen_address {
269 let Some(client) = self
270 .clients
271 .iter_mut()
272 .find(|client| transmit.to == client.client_addr)
273 else {
274 return Some(transmit);
275 };
276 client.tls.writer().write_all(&transmit.data).unwrap();
277 let mut out = vec![];
278 client.tls.write_tls(&mut out).unwrap();
279 Some(Transmit::new(
280 out,
281 TransportType::Tcp,
282 listen_address,
283 client.client_addr,
284 ))
285 } else {
286 Some(transmit)
287 }
288 }
289
290 fn poll(&mut self, now: Instant) -> TurnServerPollRet {
294 let protocol_ret = self.server.poll(now);
295 let mut have_pending = false;
296 for (idx, client) in self.clients.iter_mut().enumerate() {
297 trace!("client: {client:?}");
298 let io_state = match client.tls.process_new_packets() {
299 Ok(io_state) => io_state,
300 Err(e) => {
301 warn!("Error processing TLS: {e:?}");
302 continue;
303 }
304 };
305 trace!("{io_state:?}");
306 if io_state.tls_bytes_to_write() > 0 {
307 have_pending = true;
308 continue;
309 } else if !client.peer_closed && io_state.peer_has_closed() {
310 client.peer_closed = true;
311 if !client.local_closed {
312 client.tls.send_close_notify();
313 client.local_closed = true;
314 have_pending = true;
315 continue;
316 }
317 }
318 if client.local_closed && client.peer_closed && !client.tls.wants_write() {
319 let client = self.clients.remove(idx);
320 return TurnServerPollRet::TcpClose {
321 local_addr: self.server.listen_address(),
322 remote_addr: client.client_addr,
323 };
324 }
325 }
326 if let TurnServerPollRet::TcpClose {
327 local_addr,
328 remote_addr,
329 } = protocol_ret
330 {
331 let Some(client) = self
332 .clients
333 .iter_mut()
334 .find(|client| client.client_addr == remote_addr)
335 else {
336 return TurnServerPollRet::TcpClose {
337 local_addr,
338 remote_addr,
339 };
340 };
341 client.tls.send_close_notify();
342 client.local_closed = true;
343 return TurnServerPollRet::WaitUntil(now);
344 }
345 if have_pending {
346 return TurnServerPollRet::WaitUntil(now);
347 }
348 protocol_ret
349 }
350
351 fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Vec<u8>>> {
353 let listen_address = self.listen_address();
354
355 while let Some(transmit) = self.server.poll_transmit(now) {
356 if let Some(client) = self
357 .clients
358 .iter_mut()
359 .find(|client| transmit.to == client.client_addr)
360 {
361 if transmit.data.is_empty() {
362 if !client.local_closed {
363 warn!("client {} closed", client.client_addr);
364 client.tls.send_close_notify();
365 client.local_closed = true;
366 }
367 } else {
368 client.tls.writer().write_all(&transmit.data).unwrap();
369 }
370 } else {
371 warn!("return transmit: {transmit:?}");
372 return Some(transmit);
373 };
374 }
375
376 for client in self.clients.iter_mut() {
377 trace!("client: {client:?}");
378 let client_addr = client.client_addr;
379 if !client.tls.wants_write() {
380 continue;
381 }
382 let mut vec = vec![];
383 let n = match client.tls.write_tls(&mut vec) {
384 Ok(n) => n,
385 Err(e) => {
386 warn!("error writing TLS: {e:?}");
387 continue;
388 }
389 };
390 vec.resize(n, 0);
391 warn!("return transmit: {vec:x?}");
392 return Some(Transmit::new(
393 vec,
394 TransportType::Tcp,
395 listen_address,
396 client_addr,
397 ));
398 }
399 None
400 }
401
402 fn allocated_socket(
405 &mut self,
406 transport: TransportType,
407 local_addr: SocketAddr,
408 remote_addr: SocketAddr,
409 allocation_transport: TransportType,
410 family: AddressFamily,
411 socket_addr: Result<SocketAddr, SocketAllocateError>,
412 now: Instant,
413 ) {
414 self.server.allocated_socket(
415 transport,
416 local_addr,
417 remote_addr,
418 allocation_transport,
419 family,
420 socket_addr,
421 now,
422 )
423 }
424
425 fn tcp_connected(
426 &mut self,
427 relayed_addr: SocketAddr,
428 peer_addr: SocketAddr,
429 listen_addr: SocketAddr,
430 client_addr: SocketAddr,
431 socket_addr: Result<SocketAddr, crate::api::TcpConnectError>,
432 now: Instant,
433 ) {
434 self.server.tcp_connected(
435 relayed_addr,
436 peer_addr,
437 listen_addr,
438 client_addr,
439 socket_addr,
440 now,
441 )
442 }
443}