1#![deny(missing_debug_implementations)]
25#![deny(missing_docs)]
26#![cfg_attr(docsrs, feature(doc_cfg))]
27#![deny(clippy::std_instead_of_core)]
28#![deny(clippy::std_instead_of_alloc)]
29#![no_std]
30
31extern crate alloc;
32
33#[cfg(any(feature = "std", test))]
34extern crate std;
35
36pub use rustls;
37
38use alloc::sync::Arc;
39use alloc::vec;
40use alloc::vec::Vec;
41use core::net::{IpAddr, SocketAddr};
42use core::time::Duration;
43use std::io::{Read, Write};
44
45use turn_client_proto::types::Instant;
46use turn_client_proto::types::TransportType;
47
48pub use turn_client_proto as proto;
49pub use turn_client_proto::api::*;
50
51use turn_client_proto::tcp::TurnClientTcp;
52
53use rustls::pki_types::ServerName;
54use rustls::{ClientConfig, ClientConnection};
55
56use tracing::{debug, trace, warn};
57
58#[derive(Debug)]
60pub struct TurnClientRustls {
61 protocol: TurnClientTcp,
62 tls_config: Arc<ClientConfig>,
63 server_name: ServerName<'static>,
64 pending_allocates: Vec<(u32, Socket5Tuple, SocketAddr)>,
65 sockets: Vec<Socket>,
66}
67
/// TLS state for a single TCP connection to the TURN server: either the connection
/// created by `TurnClientRustls::allocate` or one added by `allocated_tcp_socket`.
#[derive(Debug)]
struct Socket {
    /// Local address of the TCP connection.
    local_addr: SocketAddr,
    /// Remote address of the TCP connection (the TURN server's address).
    remote_addr: SocketAddr,
    /// rustls client connection that encrypts/decrypts the TURN traffic on this socket.
    tls: ClientConnection,
    /// Set once rustls reports `peer_has_closed()` for this connection.
    peer_closed: bool,
    /// Set once we have queued our own close_notify with `send_close_notify()`.
    local_closed: bool,
}
76
77impl TurnClientRustls {
78 #[allow(clippy::too_many_arguments)]
80 pub fn allocate(
81 local_addr: SocketAddr,
82 remote_addr: SocketAddr,
83 config: TurnConfig,
84 server_name: ServerName<'static>,
85 tls_config: Arc<ClientConfig>,
86 ) -> Self {
87 Self {
88 protocol: TurnClientTcp::allocate(local_addr, remote_addr, config),
89 sockets: vec![Socket {
90 local_addr,
91 remote_addr,
92 tls: ClientConnection::new(tls_config.clone(), server_name.clone()).unwrap(),
93 local_closed: false,
94 peer_closed: false,
95 }],
96 tls_config,
97 server_name,
98 pending_allocates: vec![],
99 }
100 }
101
102 fn empty_transmit_queue(&mut self, now: Instant) {
103 while let Some(transmit) = self.protocol.poll_transmit(now) {
104 let Some(socket) = self.sockets.iter_mut().find(|socket| {
105 socket.local_addr == transmit.from && socket.remote_addr == transmit.to
106 }) else {
107 warn!(
108 "no socket for transmit from {} to {}",
109 transmit.from, transmit.to
110 );
111 continue;
112 };
113 socket.tls.writer().write_all(&transmit.data).unwrap();
114 }
115 }
116}
117
118impl TurnClientApi for TurnClientRustls {
119 fn transport(&self) -> TransportType {
120 self.protocol.transport()
121 }
122
123 fn local_addr(&self) -> SocketAddr {
124 self.protocol.local_addr()
125 }
126
127 fn remote_addr(&self) -> SocketAddr {
128 self.protocol.remote_addr()
129 }
130
131 fn poll(&mut self, now: Instant) -> TurnPollRet {
132 let mut is_handshaking = false;
133 let mut protocol_ret = TurnPollRet::Closed;
134 for (idx, socket) in self.sockets.iter_mut().enumerate() {
135 let io_state = match socket.tls.process_new_packets() {
136 Ok(io_state) => io_state,
137 Err(e) => {
138 warn!("Error processing TLS: {e:?}");
139 if socket.local_addr == self.protocol.local_addr()
140 && socket.remote_addr == self.protocol.remote_addr()
141 {
142 self.protocol.protocol_error();
143 return TurnPollRet::Closed;
144 } else {
145 continue;
147 }
148 }
149 };
150 if io_state.peer_has_closed() {
151 socket.peer_closed = true;
152 if !socket.local_closed {
153 socket.tls.send_close_notify();
154 socket.local_closed = true;
155 trace!("sending close notify");
156 return TurnPollRet::WaitUntil(now);
157 }
158 }
159 let tls_write_bytes = io_state.tls_bytes_to_write();
160 if tls_write_bytes > 0 {
161 trace!("have {tls_write_bytes} bytes to write");
162 return TurnPollRet::WaitUntil(now);
163 }
164 if socket.peer_closed && socket.local_closed && !socket.tls.wants_write() {
165 let socket = self.sockets.remove(idx);
166 return TurnPollRet::TcpClose {
167 local_addr: socket.local_addr,
168 remote_addr: socket.remote_addr,
169 };
170 }
171 if socket.local_addr == self.protocol.local_addr()
172 && socket.remote_addr == self.protocol.remote_addr()
173 {
174 protocol_ret = self.protocol.poll(now);
175 }
176 is_handshaking |= socket.tls.is_handshaking();
177 }
178 match protocol_ret {
179 TurnPollRet::Closed => {
180 debug!("Closed");
181 return protocol_ret;
182 }
183 TurnPollRet::AllocateTcpSocket {
184 id,
185 socket,
186 peer_addr,
187 } => {
188 self.pending_allocates.push((id, socket, peer_addr));
189 }
190 _ => (),
191 }
192 if is_handshaking {
193 debug!("Currently handshaking, waiting for reply");
194 return TurnPollRet::WaitUntil(now + Duration::from_secs(60));
195 }
196 protocol_ret
197 }
198
199 fn relayed_addresses(&self) -> impl Iterator<Item = (TransportType, SocketAddr)> + '_ {
200 self.protocol.relayed_addresses()
201 }
202
203 fn permissions(
204 &self,
205 transport: TransportType,
206 relayed: SocketAddr,
207 ) -> impl Iterator<Item = IpAddr> + '_ {
208 self.protocol.permissions(transport, relayed)
209 }
210
211 fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Data<'static>>> {
212 let client_transport = self.transport();
213 for socket in self.sockets.iter_mut() {
214 if socket.tls.is_handshaking() {
215 if socket.tls.wants_write() {
216 let mut out = vec![];
218 match socket.tls.write_tls(&mut out) {
219 Ok(_written) => {
220 return Some(Transmit::new(
221 Data::from(out.into_boxed_slice()),
222 client_transport,
223 socket.local_addr,
224 socket.remote_addr,
225 ))
226 }
227 Err(e) => {
228 warn!("error during handshake: {e:?}");
229 if socket.local_addr == self.protocol.local_addr()
230 && socket.remote_addr == self.protocol.remote_addr()
231 {
232 self.protocol.protocol_error();
233 return None;
234 } else {
235 continue;
237 }
238 }
239 }
240 }
241 if socket.local_addr == self.protocol.local_addr()
242 && socket.remote_addr == self.protocol.remote_addr()
243 {
244 return None;
245 }
246 }
247 }
248 self.empty_transmit_queue(now);
249
250 for socket in self.sockets.iter_mut() {
251 if socket.tls.wants_write() {
252 let mut out = vec![];
254 match socket.tls.write_tls(&mut out) {
255 Ok(_written) => {
256 return Some(Transmit::new(
257 Data::from(out.into_boxed_slice()),
258 client_transport,
259 socket.local_addr,
260 socket.remote_addr,
261 ))
262 }
263 Err(e) => {
264 warn!("error writing TLS: {e:?}");
265 if socket.local_addr == self.protocol.local_addr()
266 && socket.remote_addr == self.protocol.remote_addr()
267 {
268 self.protocol.protocol_error();
269 } else {
270 continue;
272 }
273 }
274 }
275 }
276 }
277 None
278 }
279
280 fn poll_event(&mut self) -> Option<TurnEvent> {
281 match self.protocol.poll_event()? {
282 TurnEvent::TcpConnected(peer_addr) => Some(TurnEvent::TcpConnected(peer_addr)),
283 TurnEvent::TcpConnectFailed(peer_addr) => Some(TurnEvent::TcpConnectFailed(peer_addr)),
284 event => Some(event),
285 }
286 }
287
288 fn delete(&mut self, now: Instant) -> Result<(), DeleteError> {
289 self.protocol.delete(now)?;
290
291 self.empty_transmit_queue(now);
292 Ok(())
293 }
294
295 fn create_permission(
296 &mut self,
297 transport: TransportType,
298 peer_addr: IpAddr,
299 now: Instant,
300 ) -> Result<(), CreatePermissionError> {
301 self.protocol.create_permission(transport, peer_addr, now)?;
302
303 self.empty_transmit_queue(now);
304
305 Ok(())
306 }
307
308 fn have_permission(&self, transport: TransportType, to: IpAddr) -> bool {
309 self.protocol.have_permission(transport, to)
310 }
311
312 fn bind_channel(
313 &mut self,
314 transport: TransportType,
315 peer_addr: SocketAddr,
316 now: Instant,
317 ) -> Result<(), BindChannelError> {
318 self.protocol.bind_channel(transport, peer_addr, now)?;
319
320 self.empty_transmit_queue(now);
321
322 Ok(())
323 }
324
325 fn tcp_connect(&mut self, peer_addr: SocketAddr, now: Instant) -> Result<(), TcpConnectError> {
326 self.protocol.tcp_connect(peer_addr, now)?;
327
328 self.empty_transmit_queue(now);
329
330 Ok(())
331 }
332
333 fn allocated_tcp_socket(
334 &mut self,
335 id: u32,
336 five_tuple: Socket5Tuple,
337 peer_addr: SocketAddr,
338 local_addr: Option<SocketAddr>,
339 now: Instant,
340 ) -> Result<(), TcpAllocateError> {
341 self.protocol
342 .allocated_tcp_socket(id, five_tuple, peer_addr, local_addr, now)?;
343
344 if let Some(local_addr) = local_addr {
345 if let Some(idx) = self
346 .pending_allocates
347 .iter()
348 .position(|pending| pending.1 == five_tuple)
349 {
350 self.pending_allocates.swap_remove(idx);
351 self.sockets.push(Socket {
352 local_addr,
353 remote_addr: self.remote_addr(),
354 tls: ClientConnection::new(self.tls_config.clone(), self.server_name.clone())
355 .unwrap(),
356 local_closed: false,
357 peer_closed: false,
358 });
359 }
360 }
361
362 self.empty_transmit_queue(now);
363 Ok(())
364 }
365
366 fn tcp_closed(&mut self, local_addr: SocketAddr, remote_addr: SocketAddr, now: Instant) {
367 let Some(socket) = self
368 .sockets
369 .iter_mut()
370 .find(|socket| socket.local_addr == local_addr && socket.remote_addr == remote_addr)
371 else {
372 warn!(
373 "Unknown socket local:{}, remote:{}",
374 local_addr, remote_addr
375 );
376 return;
377 };
378 self.protocol.tcp_closed(local_addr, remote_addr, now);
379 socket.tls.send_close_notify();
380 socket.local_closed = true;
381 }
382
383 fn send_to<T: AsRef<[u8]> + core::fmt::Debug>(
384 &mut self,
385 transport: TransportType,
386 to: SocketAddr,
387 data: T,
388 now: Instant,
389 ) -> Result<Option<TransmitBuild<DelayedMessageOrChannelSend<T>>>, SendError> {
390 if let Some(transmit) = self.protocol.send_to(transport, to, data, now)? {
391 let client_transport = self.transport();
392 let transmit = transmit.build();
393 let Some(socket) = self.sockets.iter_mut().find(|socket| {
394 socket.local_addr == transmit.from
395 && socket.remote_addr == transmit.to
396 && !socket.local_closed
397 }) else {
398 warn!(
399 "no socket for transmit from {} to {}",
400 transmit.from, transmit.to
401 );
402 return Err(SendError::NoTcpSocket);
403 };
404 if let Err(e) = socket.tls.writer().write_all(&transmit.data) {
405 warn!("Error when writing plaintext: {e:?}");
406 if socket.local_addr == self.protocol.local_addr()
407 && socket.remote_addr == self.protocol.remote_addr()
408 {
409 self.protocol.protocol_error();
410 return Err(SendError::NoAllocation);
411 } else {
412 return Err(SendError::NoTcpSocket);
413 }
414 }
415
416 if socket.tls.wants_write() {
417 let mut out = vec![];
418 match socket.tls.write_tls(&mut out) {
419 Ok(_n) => {
420 return Ok(Some(TransmitBuild::new(
421 DelayedMessageOrChannelSend::OwnedData(out),
422 client_transport,
423 socket.local_addr,
424 socket.remote_addr,
425 )))
426 }
427 Err(e) => {
428 warn!("Error when writing TLS records: {e:?}");
429 if socket.local_addr == self.protocol.local_addr()
430 && socket.remote_addr == self.protocol.remote_addr()
431 {
432 self.protocol.protocol_error();
433 return Err(SendError::NoAllocation);
434 } else {
435 return Err(SendError::NoTcpSocket);
436 }
437 }
438 }
439 }
440 }
441
442 Ok(None)
443 }
444
445 #[tracing::instrument(
446 name = "turn_rustls_recv",
447 skip(self, transmit, now),
448 fields(
449 from = ?transmit.from,
450 data_len = transmit.data.as_ref().len()
451 )
452 )]
453 fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
454 &mut self,
455 transmit: Transmit<T>,
456 now: Instant,
457 ) -> TurnRecvRet<T> {
458 if self.transport() != transmit.transport {
460 return TurnRecvRet::Ignored(transmit);
461 }
462 let Some(socket) = self
463 .sockets
464 .iter_mut()
465 .find(|socket| socket.local_addr == transmit.to && socket.remote_addr == transmit.from)
466 else {
467 trace!(
468 "received data not directed at us ({:?}) but for {:?}!",
469 self.local_addr(),
470 transmit.to
471 );
472 return TurnRecvRet::Ignored(transmit);
473 };
474 let mut data = std::io::Cursor::new(transmit.data.as_ref());
475
476 let io_state = match socket.tls.read_tls(&mut data) {
477 Ok(_written) => match socket.tls.process_new_packets() {
478 Ok(io_state) => io_state,
479 Err(e) => {
480 self.protocol.protocol_error();
481 warn!("Error processing TLS: {e:?}");
482 return TurnRecvRet::Ignored(transmit);
483 }
484 },
485 Err(e) => {
486 warn!("Error receiving data: {e:?}");
487 self.protocol.protocol_error();
488 return TurnRecvRet::Ignored(transmit);
489 }
490 };
491 if io_state.plaintext_bytes_to_read() > 0 {
492 let mut out = vec![0; 2048];
493 let n = match socket.tls.reader().read(&mut out) {
494 Ok(n) => n,
495 Err(e) => {
496 warn!("Error receiving data: {e:?}");
497 self.protocol.protocol_error();
498 return TurnRecvRet::Ignored(transmit);
499 }
500 };
501 out.resize(n, 0);
502 let transmit = Transmit::new(out, transmit.transport, transmit.from, transmit.to);
503
504 return match self.protocol.recv(transmit, now) {
505 TurnRecvRet::Ignored(_) => unreachable!(),
506 TurnRecvRet::PeerData(peer_data) => TurnRecvRet::PeerData(peer_data.into_owned()),
507 TurnRecvRet::Handled => TurnRecvRet::Handled,
508 TurnRecvRet::PeerIcmp {
509 transport,
510 peer,
511 icmp_type,
512 icmp_code,
513 icmp_data,
514 } => TurnRecvRet::PeerIcmp {
515 transport,
516 peer,
517 icmp_type,
518 icmp_code,
519 icmp_data,
520 },
521 };
522 }
523
524 TurnRecvRet::Handled
525 }
526
527 fn poll_recv(&mut self, now: Instant) -> Option<TurnPeerData<Vec<u8>>> {
528 self.protocol.poll_recv(now)
529 }
530
531 fn protocol_error(&mut self) {
532 self.protocol.protocol_error()
533 }
534}