1#![deny(missing_debug_implementations)]
19#![deny(missing_docs)]
20#![cfg_attr(docsrs, feature(doc_cfg))]
21#![deny(clippy::std_instead_of_core)]
22#![deny(clippy::std_instead_of_alloc)]
23#![no_std]
24
25extern crate alloc;
26
27pub use openssl;
28
29#[cfg(any(feature = "std", test))]
30extern crate std;
31
32pub use turn_client_proto::api;
33
34use std::io::{Read, Write};
35
36use alloc::collections::VecDeque;
37use alloc::vec;
38use alloc::vec::Vec;
39
40use core::net::{IpAddr, SocketAddr};
41use core::time::Duration;
42
43use turn_client_proto::types::Instant;
44use turn_client_proto::types::TransportType;
45
46use tracing::{info, trace, warn};
47
48use turn_client_proto::api::*;
49use turn_client_proto::tcp::TurnClientTcp;
50use turn_client_proto::udp::TurnClientUdp;
51
52use openssl::ssl::{
53 HandshakeError, MidHandshakeSslStream, ShutdownResult, ShutdownState, Ssl, SslContext,
54 SslStream,
55};
56
57turn_client_proto::impl_client!(TcpOrUdp, (Udp, TurnClientUdp), (Tcp, TurnClientTcp));
58
59#[derive(Debug)]
61pub struct TurnClientOpensslTls {
62 protocol: TcpOrUdp,
63 ssl_context: SslContext,
64 sockets: Vec<Socket>,
65}
66
67#[derive(Debug)]
68struct Socket {
69 local_addr: SocketAddr,
70 remote_addr: SocketAddr,
71 handshake: HandshakeState,
72 pending_write: VecDeque<Data<'static>>,
73 shutdown: ShutdownState,
74}
75
76#[derive(Debug)]
77enum HandshakeState {
78 Init(Ssl, OsslBio),
79 Handshaking(MidHandshakeSslStream<OsslBio>),
80 Done(SslStream<OsslBio>),
81 Nothing,
82}
83
84impl HandshakeState {
85 fn complete(&mut self) -> Result<&mut SslStream<OsslBio>, std::io::Error> {
86 if let Self::Done(s) = self {
87 return Ok(s);
88 }
89 let taken = core::mem::replace(self, Self::Nothing);
90
91 let ret = match taken {
92 Self::Init(ssl, bio) => ssl.connect(bio),
93 Self::Handshaking(mid) => mid.handshake(),
94 Self::Done(_) | Self::Nothing => unreachable!(),
95 };
96
97 match ret {
98 Ok(s) => {
99 info!(
100 "SSL handshake completed with version {} cipher: {:?}",
101 s.ssl().version_str(),
102 s.ssl().current_cipher()
103 );
104 *self = Self::Done(s);
105 Ok(self.complete()?)
106 }
107 Err(HandshakeError::WouldBlock(mid)) => {
108 *self = Self::Handshaking(mid);
109 Err(std::io::Error::new(
110 std::io::ErrorKind::WouldBlock,
111 "Would Block",
112 ))
113 }
114 Err(HandshakeError::SetupFailure(e)) => {
115 warn!("Error during ssl setup: {e}");
116 Err(std::io::Error::new(
117 std::io::ErrorKind::ConnectionRefused,
118 e,
119 ))
120 }
121 Err(HandshakeError::Failure(mid)) => {
122 warn!("Failure during ssl setup: {}", mid.error());
123 *self = Self::Handshaking(mid);
124 Err(std::io::Error::new(
125 std::io::ErrorKind::ConnectionRefused,
126 "Failure to setup SSL parameters",
127 ))
128 }
129 }
130 }
131 fn inner_mut(&mut self) -> &mut OsslBio {
132 match self {
133 Self::Init(_ssl, stream) => stream,
134 Self::Handshaking(mid) => mid.get_mut(),
135 Self::Done(stream) => stream.get_mut(),
136 Self::Nothing => unreachable!(),
137 }
138 }
139}
140
/// In-memory replacement for a socket that is handed to OpenSSL.
///
/// Bytes received from the network are appended to `incoming` and consumed through the
/// [`std::io::Read`] implementation. Everything written through the [`std::io::Write`]
/// implementation is stored in `outgoing`, one buffer per `write()` call.
#[derive(Debug, Default)]
struct OsslBio {
    incoming: Vec<u8>,
    outgoing: VecDeque<Vec<u8>>,
}

impl OsslBio {
    /// Queue bytes received from the network for a later `read()`.
    fn push_incoming(&mut self, buf: &[u8]) {
        self.incoming.extend_from_slice(buf);
    }

    /// Remove and return the oldest buffer that is waiting to be sent, if any.
    fn pop_outgoing(&mut self) -> Option<Vec<u8>> {
        self.outgoing.pop_front()
    }
}

impl std::io::Write for OsslBio {
    /// Store a copy of `buf` as a new outgoing buffer. Never fails.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let chunk = Vec::from(buf);
        let written = chunk.len();
        self.outgoing.push_back(chunk);
        Ok(written)
    }

    /// Nothing is buffered beyond the outgoing queue, so there is nothing to flush.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}

impl std::io::Read for OsslBio {
    /// Copy as many queued incoming bytes as fit into `buf` and remove them from the queue.
    ///
    /// Returns `ErrorKind::WouldBlock` when no incoming bytes are queued.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if self.incoming.is_empty() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::WouldBlock,
                "Would Block",
            ));
        }

        let count = self.incoming.len().min(buf.len());
        let (dest, _rest) = buf.split_at_mut(count);
        dest.copy_from_slice(&self.incoming[..count]);
        // Drop the consumed prefix; this empties the queue when everything was read.
        self.incoming.drain(..count);

        Ok(count)
    }
}
190
191impl TurnClientOpensslTls {
192 pub fn allocate(
194 transport: TransportType,
195 local_addr: SocketAddr,
196 remote_addr: SocketAddr,
197 config: TurnConfig,
198 ssl_context: SslContext,
199 ) -> Self {
200 let ssl = Ssl::new(&ssl_context).expect("Cannot create ssl structure");
201
202 Self {
203 protocol: match transport {
204 TransportType::Udp => {
205 if config.allocation_transport() != TransportType::Udp {
206 panic!("Cannot create a TCP allocation with a UDP connection to the TURN server")
207 }
208 TcpOrUdp::Udp(TurnClientUdp::allocate(local_addr, remote_addr, config))
209 }
210 TransportType::Tcp => {
211 TcpOrUdp::Tcp(TurnClientTcp::allocate(local_addr, remote_addr, config))
212 }
213 },
214 ssl_context,
215 sockets: vec![Socket {
216 local_addr,
217 remote_addr,
218 handshake: HandshakeState::Init(ssl, OsslBio::default()),
219 pending_write: VecDeque::default(),
220 shutdown: ShutdownState::empty(),
221 }],
222 }
223 }
224
225 fn empty_transmit_queue(&mut self, now: Instant) {
226 while let Some(transmit) = self.protocol.poll_transmit(now) {
227 let Some(socket) = self.sockets.iter_mut().find(|socket| {
228 socket.local_addr == transmit.from && socket.remote_addr == transmit.to
229 }) else {
230 warn!(
231 "no socket for transmit from {} to {}",
232 transmit.from, transmit.to
233 );
234 continue;
235 };
236 match socket.handshake.complete() {
237 Ok(stream) => {
238 for data in socket.pending_write.drain(..) {
239 warn!("write early data, {} bytes", data.len());
240 stream.write_all(&data).unwrap()
241 }
242 stream.write_all(&transmit.data).unwrap()
243 }
244 Err(e) => {
245 if e.kind() == std::io::ErrorKind::WouldBlock {
246 warn!("early data ({} bytes), storing", transmit.data.len());
247 socket.pending_write.push_back(transmit.data);
248 } else {
249 warn!("Failure to send data: {e:?}");
250 continue;
251 }
252 }
253 }
254 }
255 }
256}
257
258impl TurnClientApi for TurnClientOpensslTls {
259 fn transport(&self) -> TransportType {
260 self.protocol.transport()
261 }
262
263 fn local_addr(&self) -> SocketAddr {
264 self.protocol.local_addr()
265 }
266
267 fn remote_addr(&self) -> SocketAddr {
268 self.protocol.remote_addr()
269 }
270
271 fn poll(&mut self, now: Instant) -> TurnPollRet {
272 let mut is_handshaking = false;
273 let mut have_outgoing = false;
274 for (idx, socket) in self.sockets.iter_mut().enumerate() {
275 let stream = match socket.handshake.complete() {
276 Ok(stream) => stream,
277 Err(e) => {
278 if e.kind() == std::io::ErrorKind::WouldBlock {
279 is_handshaking = true;
280 continue;
281 } else {
282 warn!("Openssl produced error: {e:?}");
283 return TurnPollRet::Closed;
284 }
285 }
286 };
287 socket.shutdown = stream.get_shutdown();
288 if !socket.handshake.inner_mut().outgoing.is_empty() {
289 have_outgoing = true;
290 continue;
291 }
292 if socket
293 .shutdown
294 .contains(ShutdownState::SENT | ShutdownState::RECEIVED)
295 {
296 let socket = self.sockets.swap_remove(idx);
297 if self.transport() == TransportType::Tcp {
298 return TurnPollRet::TcpClose {
299 local_addr: socket.local_addr,
300 remote_addr: socket.remote_addr,
301 };
302 } else {
303 have_outgoing = true;
304 break;
305 }
306 }
307 }
308 if have_outgoing {
309 return TurnPollRet::WaitUntil(now);
310 }
311 if is_handshaking {
312 return TurnPollRet::WaitUntil(now + Duration::from_millis(200));
314 }
315 let protocol_ret = self.protocol.poll(now);
316 if let TurnPollRet::TcpClose {
317 local_addr,
318 remote_addr,
319 } = protocol_ret
320 {
321 if let Some((idx, socket)) =
322 self.sockets.iter_mut().enumerate().find(|(_idx, socket)| {
323 socket.local_addr == local_addr && socket.remote_addr == remote_addr
324 })
325 {
326 if let Ok(stream) = socket.handshake.complete() {
327 let _ = stream.shutdown();
328 socket.shutdown = stream.get_shutdown();
329 } else {
330 self.sockets.swap_remove(idx);
331 }
332 return TurnPollRet::WaitUntil(now);
333 }
334 }
335 protocol_ret
336 }
337
338 fn relayed_addresses(&self) -> impl Iterator<Item = (TransportType, SocketAddr)> + '_ {
339 self.protocol.relayed_addresses()
340 }
341
342 fn permissions(
343 &self,
344 transport: TransportType,
345 relayed: SocketAddr,
346 ) -> impl Iterator<Item = IpAddr> + '_ {
347 self.protocol.permissions(transport, relayed)
348 }
349
350 fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Data<'static>>> {
351 let client_transport = self.transport();
352 for socket in self.sockets.iter_mut() {
353 if let Some(outgoing) = socket.handshake.inner_mut().pop_outgoing() {
354 return Some(Transmit::new(
355 outgoing.into_boxed_slice().into(),
356 client_transport,
357 socket.local_addr,
358 socket.remote_addr,
359 ));
360 }
361
362 let stream = match socket.handshake.complete() {
363 Ok(stream) => stream,
364 Err(e) => {
365 warn!("handshake error: {e:?}");
366 if let Some(outgoing) = socket.handshake.inner_mut().pop_outgoing() {
367 return Some(Transmit::new(
368 outgoing.into_boxed_slice().into(),
369 client_transport,
370 socket.local_addr,
371 socket.remote_addr,
372 ));
373 } else {
374 return None;
375 }
376 }
377 };
378 for data in socket.pending_write.drain(..) {
379 warn!("write early data, {} bytes", data.len());
380 stream.write_all(&data).unwrap()
381 }
382 }
383 self.empty_transmit_queue(now);
384 for socket in self.sockets.iter_mut() {
385 if let Some(outgoing) = socket.handshake.inner_mut().pop_outgoing() {
386 return Some(Transmit::new(
387 outgoing.into_boxed_slice().into(),
388 client_transport,
389 socket.local_addr,
390 socket.remote_addr,
391 ));
392 }
393 }
394 None
395 }
396
397 fn poll_event(&mut self) -> Option<TurnEvent> {
398 self.protocol.poll_event()
399 }
400
401 fn delete(&mut self, now: Instant) -> Result<(), DeleteError> {
402 self.protocol.delete(now)?;
403 self.empty_transmit_queue(now);
404 Ok(())
405 }
406
407 fn create_permission(
408 &mut self,
409 transport: TransportType,
410 peer_addr: IpAddr,
411 now: Instant,
412 ) -> Result<(), CreatePermissionError> {
413 self.protocol.create_permission(transport, peer_addr, now)?;
414 self.empty_transmit_queue(now);
415 Ok(())
416 }
417
418 fn have_permission(&self, transport: TransportType, to: IpAddr) -> bool {
419 self.protocol.have_permission(transport, to)
420 }
421
422 fn bind_channel(
423 &mut self,
424 transport: TransportType,
425 peer_addr: SocketAddr,
426 now: Instant,
427 ) -> Result<(), BindChannelError> {
428 self.protocol.bind_channel(transport, peer_addr, now)?;
429 self.empty_transmit_queue(now);
430 Ok(())
431 }
432
433 fn tcp_connect(&mut self, peer_addr: SocketAddr, now: Instant) -> Result<(), TcpConnectError> {
434 self.protocol.tcp_connect(peer_addr, now)?;
435
436 self.empty_transmit_queue(now);
437
438 Ok(())
439 }
440
441 fn allocated_tcp_socket(
442 &mut self,
443 id: u32,
444 five_tuple: Socket5Tuple,
445 peer_addr: SocketAddr,
446 local_addr: Option<SocketAddr>,
447 now: Instant,
448 ) -> Result<(), TcpAllocateError> {
449 self.protocol
450 .allocated_tcp_socket(id, five_tuple, peer_addr, local_addr, now)?;
451
452 if let Some(local_addr) = local_addr {
453 self.sockets.push(Socket {
454 local_addr,
455 remote_addr: self.remote_addr(),
456 handshake: HandshakeState::Init(
457 Ssl::new(&self.ssl_context).expect("Failed to create SSL"),
458 OsslBio::default(),
459 ),
460 pending_write: VecDeque::default(),
461 shutdown: ShutdownState::empty(),
462 });
463 }
464
465 self.empty_transmit_queue(now);
466
467 Ok(())
468 }
469
470 fn tcp_closed(&mut self, local_addr: SocketAddr, remote_addr: SocketAddr, now: Instant) {
471 let Some(socket) = self
472 .sockets
473 .iter_mut()
474 .find(|socket| socket.local_addr == local_addr && socket.remote_addr == remote_addr)
475 else {
476 warn!(
477 "Unknown socket local:{}, remote:{}",
478 local_addr, remote_addr
479 );
480 return;
481 };
482 self.protocol.tcp_closed(local_addr, remote_addr, now);
483 if let Ok(stream) = socket.handshake.complete() {
484 socket.shutdown |= match stream.shutdown() {
485 Ok(ShutdownResult::Sent) => ShutdownState::SENT,
486 Ok(ShutdownResult::Received) => ShutdownState::RECEIVED,
487 Err(e) => {
488 warn!("Failed to close TLS connection: {e:?}");
489 return;
490 }
491 }
492 }
493 }
494
495 fn send_to<T: AsRef<[u8]> + core::fmt::Debug>(
496 &mut self,
497 transport: TransportType,
498 to: SocketAddr,
499 data: T,
500 now: Instant,
501 ) -> Result<Option<TransmitBuild<DelayedMessageOrChannelSend<T>>>, SendError> {
502 let client_transport = self.transport();
503 if let Some(transmit) = self.protocol.send_to(transport, to, data, now)? {
504 let Some(socket) = self.sockets.iter_mut().find(|socket| {
505 socket.local_addr == transmit.from
506 && socket.remote_addr == transmit.to
507 && !socket.shutdown.contains(ShutdownState::SENT)
508 }) else {
509 warn!(
510 "no socket for transmit from {} to {}",
511 transmit.from, transmit.to
512 );
513 return Err(SendError::NoTcpSocket);
514 };
515 let stream = socket.handshake.complete().expect("No TLS connection yet");
516 let transmit = transmit.build();
517 for data in socket.pending_write.drain(..) {
518 stream.write_all(&data).unwrap()
519 }
520 if let Err(e) = stream.write_all(&transmit.data) {
521 self.protocol.protocol_error();
522 warn!("Error when writing plaintext: {e:?}");
523 return Err(SendError::NoAllocation);
524 }
525
526 if let Some(outgoing) = stream.get_mut().pop_outgoing() {
527 return Ok(Some(TransmitBuild::new(
528 DelayedMessageOrChannelSend::OwnedData(outgoing),
529 client_transport,
530 socket.local_addr,
531 socket.remote_addr,
532 )));
533 }
534 }
535
536 Ok(None)
537 }
538
539 #[tracing::instrument(
540 name = "turn_openssl_recv",
541 skip(self, transmit, now),
542 fields(
543 transport = %transmit.transport,
544 from = ?transmit.from,
545 data_len = transmit.data.as_ref().len()
546 )
547 )]
548 fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
549 &mut self,
550 transmit: Transmit<T>,
551 now: Instant,
552 ) -> TurnRecvRet<T> {
553 if self.transport() != transmit.transport {
555 return TurnRecvRet::Ignored(transmit);
556 }
557 let Some(socket) = self
558 .sockets
559 .iter_mut()
560 .find(|socket| socket.local_addr == transmit.to && socket.remote_addr == transmit.from)
561 else {
562 trace!(
563 "received data not directed at us ({} {:?}) but for {} {:?}!",
564 self.transport(),
565 self.local_addr(),
566 transmit.transport,
567 transmit.to,
568 );
569 return TurnRecvRet::Ignored(transmit);
570 };
571
572 socket
573 .handshake
574 .inner_mut()
575 .push_incoming(transmit.data.as_ref());
576
577 let stream = match socket.handshake.complete() {
578 Ok(stream) => stream,
579 Err(e) => {
580 if e.kind() == std::io::ErrorKind::WouldBlock {
581 return TurnRecvRet::Handled;
582 }
583 return TurnRecvRet::Ignored(transmit);
584 }
585 };
586
587 let mut out = vec![0; 2048];
588 let len = match stream.read(&mut out) {
589 Ok(len) => len,
590 Err(e) => {
591 if e.kind() != std::io::ErrorKind::WouldBlock {
592 self.protocol.protocol_error();
593 tracing::warn!("Error: {e}");
594 }
595 return TurnRecvRet::Ignored(transmit);
596 }
597 };
598 out.resize(len, 0);
599
600 let transmit = Transmit::new(out, transmit.transport, transmit.from, transmit.to);
601
602 match self.protocol.recv(transmit, now) {
603 TurnRecvRet::Ignored(_) => unreachable!(),
604 TurnRecvRet::PeerData(peer_data) => TurnRecvRet::PeerData(peer_data.into_owned()),
605 TurnRecvRet::Handled => TurnRecvRet::Handled,
606 TurnRecvRet::PeerIcmp {
607 transport,
608 peer,
609 icmp_type,
610 icmp_code,
611 icmp_data,
612 } => TurnRecvRet::PeerIcmp {
613 transport,
614 peer,
615 icmp_type,
616 icmp_code,
617 icmp_data,
618 },
619 }
620 }
621
622 fn poll_recv(&mut self, now: Instant) -> Option<TurnPeerData<Vec<u8>>> {
623 self.protocol.poll_recv(now)
624 }
625
626 fn protocol_error(&mut self) {
627 self.protocol.protocol_error()
628 }
629}