1use alloc::string::String;
12use alloc::sync::Arc;
13use alloc::vec;
14use alloc::vec::Vec;
15use core::net::SocketAddr;
16use core::time::Duration;
17use std::io::{Read, Write};
18use turn_types::prelude::DelayedTransmitBuild;
19use turn_types::transmit::TransmitBuild;
20use turn_types::AddressFamily;
21
22use rustls::{ServerConfig, ServerConnection};
23use stun_proto::agent::Transmit;
24use stun_proto::Instant;
25use tracing::{info, trace, warn};
26use turn_types::stun::TransportType;
27
28use crate::api::{
29 DelayedMessageOrChannelSend, SocketAllocateError, TurnServerApi, TurnServerPollRet,
30};
31use crate::server::TurnServer;
32
/// A TURN server that terminates TLS on its TCP listening address.
///
/// Wraps a plain [`TurnServer`] and uses rustls for the TLS layer. TLS data
/// received on the listening address is decrypted before it is handed to the
/// inner server. Replies addressed to a TLS client are encrypted before they
/// are returned to the caller.
#[derive(Debug)]
pub struct RustlsTurnServer {
    /// The wrapped TURN server that handles the decrypted STUN/TURN traffic.
    server: TurnServer,
    /// rustls configuration used to create every new client connection.
    config: Arc<ServerConfig>,
    /// One rustls connection per client, paired with the client's remote
    /// address (looked up with a linear search).
    connections: Vec<(SocketAddr, ServerConnection)>,
}
40
41impl RustlsTurnServer {
42 pub fn new(listen_addr: SocketAddr, realm: String, config: Arc<ServerConfig>) -> Self {
44 Self {
45 server: TurnServer::new(TransportType::Tcp, listen_addr, realm),
46 config,
47 connections: vec![],
48 }
49 }
50}
51
52impl TurnServerApi for RustlsTurnServer {
53 fn add_user(&mut self, username: String, password: String) {
55 self.server.add_user(username, password)
56 }
57
58 fn listen_address(&self) -> SocketAddr {
60 self.server.listen_address()
61 }
62
63 fn set_nonce_expiry_duration(&mut self, expiry_duration: Duration) {
66 self.server.set_nonce_expiry_duration(expiry_duration)
67 }
68
69 #[tracing::instrument(
73 name = "turn_server_rustls_recv",
74 skip(self, transmit, now),
75 fields(
76 from = ?transmit.from,
77 data_len = transmit.data.as_ref().len()
78 )
79 )]
80 fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
81 &mut self,
82 transmit: Transmit<T>,
83 now: Instant,
84 ) -> Option<TransmitBuild<DelayedMessageOrChannelSend<T>>> {
85 let listen_address = self.listen_address();
86 if transmit.transport == TransportType::Tcp && transmit.to == listen_address {
87 trace!("receiving TLS data: {:x?}", transmit.data.as_ref());
88 let (client_addr, conn) = match self
90 .connections
91 .iter_mut()
92 .find(|(client_addr, _conn)| *client_addr == transmit.from)
93 {
94 Some((client_addr, conn)) => (*client_addr, conn),
95 None => {
96 let len = self.connections.len();
97 self.connections.push((
98 transmit.from,
99 ServerConnection::new(self.config.clone()).unwrap(),
100 ));
101 info!("new connection from {}", transmit.from);
102 let ret = &mut self.connections[len];
103 (ret.0, &mut ret.1)
104 }
105 };
106 let mut input = std::io::Cursor::new(transmit.data.as_ref());
107 let io_state = match conn.read_tls(&mut input) {
108 Ok(_written) => match conn.process_new_packets() {
109 Ok(io_state) => io_state,
110 Err(e) => {
111 warn!("Error processing incoming TLS: {e:?}");
112 return None;
113 }
114 },
115 Err(e) => {
116 warn!("Error receiving data: {e:?}");
117 return None;
118 }
119 };
120 if io_state.plaintext_bytes_to_read() == 0 {
121 return None;
122 }
123 let mut vec = vec![0; 2048];
124 let n = match conn.reader().read(&mut vec) {
125 Ok(n) => n,
126 Err(e) => {
127 if e.kind() == std::io::ErrorKind::WouldBlock {
128 return None;
129 } else {
130 warn!("TLS error: {e:?}");
131 return None;
132 }
133 }
134 };
135 tracing::error!("io_state: {io_state:?}, n: {n}");
136 vec.resize(n, 0);
137 let transmit = self.server.recv(
138 Transmit::new(vec, transmit.transport, transmit.from, transmit.to),
139 now,
140 )?;
141 if transmit.transport == TransportType::Tcp
142 && transmit.from == listen_address
143 && transmit.to == client_addr
144 {
145 let plaintext = transmit.data.build();
146 conn.writer().write_all(&plaintext).unwrap();
147 let mut out = vec![];
148 conn.write_tls(&mut out).unwrap();
149 Some(TransmitBuild::new(
150 DelayedMessageOrChannelSend::Owned(out),
151 TransportType::Tcp,
152 listen_address,
153 client_addr,
154 ))
155 } else {
156 let transmit = transmit.build();
157 Some(TransmitBuild::new(
158 DelayedMessageOrChannelSend::Owned(transmit.data),
159 transmit.transport,
160 transmit.from,
161 transmit.to,
162 ))
163 }
164 } else if let Some(transmit) = self.server.recv(transmit, now) {
165 if transmit.transport == TransportType::Tcp && transmit.from == listen_address {
167 let Some((client_addr, conn)) = self
168 .connections
169 .iter_mut()
170 .find(|(client_addr, _conn)| transmit.to == *client_addr)
171 else {
172 return Some(transmit);
173 };
174 let plaintext = transmit.data.build();
175 conn.writer().write_all(&plaintext).unwrap();
176 let mut out = vec![];
177 conn.write_tls(&mut out).unwrap();
178 Some(TransmitBuild::new(
179 DelayedMessageOrChannelSend::Owned(out),
180 TransportType::Tcp,
181 listen_address,
182 *client_addr,
183 ))
184 } else {
185 Some(transmit)
186 }
187 } else {
188 None
189 }
190 }
191
192 fn recv_icmp<T: AsRef<[u8]>>(
193 &mut self,
194 family: AddressFamily,
195 bytes: T,
196 now: Instant,
197 ) -> Option<Transmit<Vec<u8>>> {
198 let transmit = self.server.recv_icmp(family, bytes, now)?;
199 let listen_address = self.listen_address();
201 if transmit.transport == TransportType::Tcp && transmit.from == listen_address {
202 let Some((client_addr, conn)) = self
203 .connections
204 .iter_mut()
205 .find(|(client_addr, _conn)| transmit.to == *client_addr)
206 else {
207 return Some(transmit);
208 };
209 conn.writer().write_all(&transmit.data).unwrap();
210 let mut out = vec![];
211 conn.write_tls(&mut out).unwrap();
212 Some(Transmit::new(
213 out,
214 TransportType::Tcp,
215 listen_address,
216 *client_addr,
217 ))
218 } else {
219 Some(transmit)
220 }
221 }
222
223 fn poll(&mut self, now: Instant) -> TurnServerPollRet {
227 let protocol_ret = self.server.poll(now);
228 let mut have_pending = false;
229 for (_client_addr, conn) in self.connections.iter_mut() {
230 let io_state = match conn.process_new_packets() {
231 Ok(io_state) => io_state,
232 Err(e) => {
233 warn!("Error processing TLS: {e:?}");
234 continue;
235 }
236 };
237 if io_state.tls_bytes_to_write() > 0 {
238 have_pending = true;
239 }
240 }
241 if have_pending {
242 return TurnServerPollRet::WaitUntil(now);
243 }
244 protocol_ret
245 }
246
247 fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Vec<u8>>> {
249 let listen_address = self.listen_address();
250
251 while let Some(transmit) = self.server.poll_transmit(now) {
252 let Some((_client_addr, conn)) = self
253 .connections
254 .iter_mut()
255 .find(|(client_addr, _conn)| transmit.to == *client_addr)
256 else {
257 warn!("return transmit: {transmit:?}");
258 return Some(transmit);
259 };
260 conn.writer().write_all(&transmit.data).unwrap();
261 }
262
263 for (client_addr, conn) in self.connections.iter_mut() {
264 if !conn.wants_write() {
265 continue;
266 }
267 let mut vec = vec![];
268 let n = match conn.write_tls(&mut vec) {
269 Ok(n) => n,
270 Err(e) => {
271 warn!("error writing TLS: {e:?}");
272 continue;
273 }
274 };
275 vec.resize(n, 0);
276 warn!("return transmit: {vec:x?}");
277 return Some(Transmit::new(
278 vec,
279 TransportType::Tcp,
280 listen_address,
281 *client_addr,
282 ));
283 }
284 None
285 }
286
287 fn allocated_udp_socket(
290 &mut self,
291 transport: TransportType,
292 local_addr: SocketAddr,
293 remote_addr: SocketAddr,
294 family: AddressFamily,
295 socket_addr: Result<SocketAddr, SocketAllocateError>,
296 now: Instant,
297 ) {
298 self.server.allocated_udp_socket(
299 transport,
300 local_addr,
301 remote_addr,
302 family,
303 socket_addr,
304 now,
305 )
306 }
307}