1use alloc::collections::VecDeque;
12use alloc::string::String;
13use alloc::vec;
14use alloc::vec::Vec;
15use core::net::SocketAddr;
16use core::time::Duration;
17use std::io::{Read, Write};
18use turn_types::prelude::DelayedTransmitBuild;
19use turn_types::transmit::TransmitBuild;
20use turn_types::AddressFamily;
21
22use openssl::ssl::{HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslStream};
23use stun_proto::agent::Transmit;
24use stun_proto::Instant;
25use tracing::{info, trace, warn};
26use turn_types::stun::TransportType;
27
28use crate::api::{
29 DelayedMessageOrChannelSend, SocketAllocateError, TurnServerApi, TurnServerPollRet,
30};
31use crate::server::TurnServer;
32
33#[derive(Debug)]
35pub struct OpensslTurnServer {
36 server: TurnServer,
37 ssl_context: SslContext,
38 connections: Vec<(SocketAddr, HandshakeState)>,
39}
40
41#[derive(Debug)]
42enum HandshakeState {
43 Init(Ssl, OsslBio),
44 Handshaking(MidHandshakeSslStream<OsslBio>),
45 Done(SslStream<OsslBio>),
46 Nothing,
47}
48
49impl HandshakeState {
50 fn complete(&mut self) -> Result<&mut SslStream<OsslBio>, std::io::Error> {
51 if let Self::Done(s) = self {
52 return Ok(s);
53 }
54 let taken = core::mem::replace(self, Self::Nothing);
55
56 let ret = match taken {
57 Self::Init(ssl, bio) => ssl.accept(bio),
58 Self::Handshaking(mid) => mid.handshake(),
59 Self::Done(_) | Self::Nothing => unreachable!(),
60 };
61
62 match ret {
63 Ok(s) => {
64 info!(
65 "SSL handshake completed with version {} cipher: {:?}",
66 s.ssl().version_str(),
67 s.ssl().current_cipher()
68 );
69 *self = Self::Done(s);
70 Ok(self.complete()?)
71 }
72 Err(HandshakeError::WouldBlock(mid)) => {
73 *self = Self::Handshaking(mid);
74 Err(std::io::Error::new(
75 std::io::ErrorKind::WouldBlock,
76 "Would Block",
77 ))
78 }
79 Err(HandshakeError::SetupFailure(e)) => {
80 warn!("Error during ssl setup: {e}");
81 Err(std::io::Error::new(
82 std::io::ErrorKind::ConnectionRefused,
83 e,
84 ))
85 }
86 Err(HandshakeError::Failure(mid)) => {
87 warn!("Failure during ssl setup: {}", mid.error());
88 *self = Self::Handshaking(mid);
89 Err(std::io::Error::new(
90 std::io::ErrorKind::WouldBlock,
91 "Would Block",
92 ))
93 }
94 }
95 }
96 fn inner_mut(&mut self) -> &mut OsslBio {
97 match self {
98 Self::Init(_ssl, stream) => stream,
99 Self::Handshaking(mid) => mid.get_mut(),
100 Self::Done(stream) => stream.get_mut(),
101 Self::Nothing => unreachable!(),
102 }
103 }
104}
105
/// In-memory transport placed under an OpenSSL stream.
///
/// Ciphertext received from the client is buffered in `incoming` for OpenSSL to read. Every
/// chunk OpenSSL writes is queued in `outgoing` until it is handed out for sending.
#[derive(Default, Debug)]
struct OsslBio {
    /// Received ciphertext not yet consumed by OpenSSL.
    incoming: Vec<u8>,
    /// Chunks written by OpenSSL, one entry per `write()` call, oldest first.
    outgoing: VecDeque<Vec<u8>>,
}
111
112impl OsslBio {
113 fn push_incoming(&mut self, buf: &[u8]) {
114 self.incoming.extend_from_slice(buf)
115 }
116
117 fn pop_outgoing(&mut self) -> Option<Vec<u8>> {
118 self.outgoing.pop_front()
119 }
120}
121
122impl std::io::Write for OsslBio {
123 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
124 self.outgoing.push_back(buf.to_vec());
125 Ok(buf.len())
126 }
127
128 fn flush(&mut self) -> std::io::Result<()> {
129 Ok(())
130 }
131}
132
133impl std::io::Read for OsslBio {
134 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
135 let len = self.incoming.len();
136 let max = buf.len().min(len);
137
138 if len == 0 {
139 return Err(std::io::Error::new(
140 std::io::ErrorKind::WouldBlock,
141 "Would Block",
142 ));
143 }
144
145 buf[..max].copy_from_slice(&self.incoming[..max]);
146 if max == len {
147 self.incoming.truncate(0);
148 } else {
149 self.incoming.drain(..max);
150 }
151
152 Ok(max)
153 }
154}
155
156impl OpensslTurnServer {
157 pub fn new(
159 transport: TransportType,
160 listen_addr: SocketAddr,
161 realm: String,
162 ssl_context: SslContext,
163 ) -> Self {
164 Self {
165 server: TurnServer::new(transport, listen_addr, realm),
166 ssl_context,
167 connections: vec![],
168 }
169 }
170}
171
172impl TurnServerApi for OpensslTurnServer {
173 fn add_user(&mut self, username: String, password: String) {
175 self.server.add_user(username, password)
176 }
177
178 fn listen_address(&self) -> SocketAddr {
180 self.server.listen_address()
181 }
182
183 fn set_nonce_expiry_duration(&mut self, expiry_duration: Duration) {
186 self.server.set_nonce_expiry_duration(expiry_duration)
187 }
188
189 #[tracing::instrument(
193 name = "turn_server_openssl_recv",
194 skip(self, transmit, now),
195 fields(
196 from = ?transmit.from,
197 data_len = transmit.data.as_ref().len()
198 )
199 )]
200 fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
201 &mut self,
202 transmit: Transmit<T>,
203 now: Instant,
204 ) -> Option<TransmitBuild<DelayedMessageOrChannelSend<T>>> {
205 let listen_address = self.listen_address();
206 if transmit.transport == TransportType::Tcp && transmit.to == listen_address {
207 trace!("receiving TLS data: {:x?}", transmit.data.as_ref());
208 let (client_addr, conn) = match self
210 .connections
211 .iter_mut()
212 .find(|(client_addr, _conn)| *client_addr == transmit.from)
213 {
214 Some((client_addr, conn)) => (*client_addr, conn),
215 None => {
216 let len = self.connections.len();
217 let ssl = Ssl::new(&self.ssl_context).expect("Cannot create ssl structure");
218 self.connections
219 .push((transmit.from, HandshakeState::Init(ssl, OsslBio::default())));
220 info!("new connection from {}", transmit.from);
221 let ret = &mut self.connections[len];
222 (ret.0, &mut ret.1)
223 }
224 };
225 conn.inner_mut().push_incoming(transmit.data.as_ref());
226 let stream = match conn.complete() {
227 Ok(s) => s,
228 Err(e) => {
229 if e.kind() != std::io::ErrorKind::WouldBlock {
230 warn!("error accepting TLS: {e}");
231 }
232 return None;
233 }
234 };
235
236 let mut plaintext = vec![0; 2048];
237 let len = match stream.read(&mut plaintext) {
238 Ok(len) => len,
239 Err(e) => {
240 if e.kind() != std::io::ErrorKind::WouldBlock {
241 tracing::warn!("Error: {e}");
242 }
243 return None;
244 }
245 };
246 plaintext.resize(len, 0);
247
248 let transmit = self.server.recv(
249 Transmit::new(plaintext, transmit.transport, transmit.from, transmit.to),
250 now,
251 )?;
252
253 if transmit.transport == TransportType::Tcp
254 && transmit.from == listen_address
255 && transmit.to == client_addr
256 {
257 let plaintext = transmit.data.build();
258 stream.write_all(&plaintext).unwrap();
259 stream.get_mut().pop_outgoing().map(|data| {
260 TransmitBuild::new(
261 DelayedMessageOrChannelSend::Owned(data),
262 TransportType::Tcp,
263 listen_address,
264 client_addr,
265 )
266 })
267 } else {
268 let transmit = transmit.build();
269 Some(TransmitBuild::new(
270 DelayedMessageOrChannelSend::Owned(transmit.data),
271 transmit.transport,
272 transmit.from,
273 transmit.to,
274 ))
275 }
276 } else if let Some(transmit) = self.server.recv(transmit, now) {
277 if transmit.transport == TransportType::Tcp && transmit.from == listen_address {
279 let Some((client_addr, conn)) = self
280 .connections
281 .iter_mut()
282 .find(|(client_addr, _conn)| transmit.to == *client_addr)
283 else {
284 return Some(transmit);
285 };
286
287 let plaintext = transmit.data.build();
288 let stream = match conn.complete() {
289 Ok(s) => s,
290 Err(e) => {
291 if e.kind() != std::io::ErrorKind::WouldBlock {
292 warn!("error accepting TLS: {e}");
293 }
294 return None;
295 }
296 };
297 stream.write_all(&plaintext).unwrap();
298 stream.get_mut().pop_outgoing().map(|data| {
299 TransmitBuild::new(
300 DelayedMessageOrChannelSend::Owned(data),
301 TransportType::Tcp,
302 listen_address,
303 *client_addr,
304 )
305 })
306 } else {
307 Some(transmit)
308 }
309 } else {
310 None
311 }
312 }
313
314 fn recv_icmp<T: AsRef<[u8]>>(
315 &mut self,
316 family: AddressFamily,
317 bytes: T,
318 now: Instant,
319 ) -> Option<Transmit<Vec<u8>>> {
320 let transmit = self.server.recv_icmp(family, bytes, now)?;
321 let listen_address = self.listen_address();
323 if transmit.transport == TransportType::Tcp && transmit.from == listen_address {
324 let Some((client_addr, conn)) = self
325 .connections
326 .iter_mut()
327 .find(|(client_addr, _conn)| transmit.to == *client_addr)
328 else {
329 return Some(transmit);
330 };
331 let stream = match conn.complete() {
332 Ok(s) => s,
333 Err(e) => {
334 if e.kind() != std::io::ErrorKind::WouldBlock {
335 warn!("error accepting TLS: {e}");
336 }
337 return None;
338 }
339 };
340 stream.write_all(&transmit.data).unwrap();
341 stream
342 .get_mut()
343 .pop_outgoing()
344 .map(|data| Transmit::new(data, TransportType::Tcp, listen_address, *client_addr))
345 } else {
346 Some(transmit)
347 }
348 }
349
350 fn poll(&mut self, now: Instant) -> TurnServerPollRet {
354 let protocol_ret = self.server.poll(now);
355 let mut have_pending = false;
356 for (_client_addr, conn) in self.connections.iter_mut() {
357 let stream = match conn.complete() {
358 Ok(s) => s,
359 Err(_) => continue,
360 };
361 if !stream.get_mut().outgoing.is_empty() {
362 have_pending = true;
363 }
364 }
365 if have_pending {
366 return TurnServerPollRet::WaitUntil(now);
367 }
368 protocol_ret
369 }
370
371 fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Vec<u8>>> {
373 let listen_address = self.listen_address();
374
375 for (client_addr, conn) in self.connections.iter_mut() {
376 if let Some(data) = conn.inner_mut().pop_outgoing() {
377 return Some(Transmit::new(
378 data,
379 TransportType::Tcp,
380 listen_address,
381 *client_addr,
382 ));
383 }
384 }
385
386 while let Some(transmit) = self.server.poll_transmit(now) {
387 let Some((client_addr, conn)) = self
388 .connections
389 .iter_mut()
390 .find(|(client_addr, _conn)| transmit.to == *client_addr)
391 else {
392 warn!("return transmit: {transmit:?}");
393 return Some(transmit);
394 };
395 let stream = match conn.complete() {
396 Ok(s) => s,
397 Err(_) => continue,
399 };
400 stream.write_all(&transmit.data).unwrap();
401
402 if let Some(data) = conn.inner_mut().pop_outgoing() {
403 return Some(Transmit::new(
404 data,
405 TransportType::Tcp,
406 listen_address,
407 *client_addr,
408 ));
409 }
410 }
411 None
412 }
413
414 fn allocated_udp_socket(
417 &mut self,
418 transport: TransportType,
419 local_addr: SocketAddr,
420 remote_addr: SocketAddr,
421 family: AddressFamily,
422 socket_addr: Result<SocketAddr, SocketAllocateError>,
423 now: Instant,
424 ) {
425 self.server.allocated_udp_socket(
426 transport,
427 local_addr,
428 remote_addr,
429 family,
430 socket_addr,
431 now,
432 )
433 }
434}