1use alloc::vec;
31use alloc::vec::Vec;
32use core::ops::Range;
33
34use stun_proto::agent::Transmit;
35use stun_types::message::{Message, MessageHeader};
36use tracing::{debug, trace};
37
38use crate::channel::ChannelData;
39
40#[derive(Debug)]
45pub enum IncomingTcp<T: AsRef<[u8]> + core::fmt::Debug> {
46 CompleteMessage(Transmit<T>, Range<usize>),
50 CompleteChannel(Transmit<T>, Range<usize>),
54 StoredMessage(Vec<u8>, Transmit<T>),
56 StoredChannel(Vec<u8>, Transmit<T>),
58}
59
60impl<T: AsRef<[u8]> + core::fmt::Debug> IncomingTcp<T> {
61 pub fn data(&self) -> &[u8] {
63 match self {
64 Self::CompleteMessage(transmit, range) => {
65 &transmit.data.as_ref()[range.start..range.end]
66 }
67 Self::CompleteChannel(transmit, range) => {
68 &transmit.data.as_ref()[range.start..range.end]
69 }
70 Self::StoredMessage(data, _transmit) => data,
71 Self::StoredChannel(data, _transmit) => data,
72 }
73 }
74
75 pub fn message(&self) -> Option<Message<'_>> {
77 if !matches!(
78 self,
79 Self::CompleteMessage(_, _) | Self::StoredMessage(_, _)
80 ) {
81 return None;
82 }
83 Message::from_bytes(self.data()).ok()
84 }
85
86 pub fn channel(&self) -> Option<ChannelData<'_>> {
88 if !matches!(
89 self,
90 Self::CompleteChannel(_, _) | Self::StoredChannel(_, _)
91 ) {
92 return None;
93 }
94 ChannelData::parse(self.data()).ok()
95 }
96}
97
98impl<T: AsRef<[u8]> + core::fmt::Debug> AsRef<[u8]> for IncomingTcp<T> {
99 fn as_ref(&self) -> &[u8] {
100 self.data()
101 }
102}
103
104#[derive(Debug)]
106pub enum StoredTcp {
107 Message(Vec<u8>),
109 Channel(Vec<u8>),
111}
112
113impl StoredTcp {
114 pub fn data(&self) -> &[u8] {
116 match self {
117 Self::Message(data) => data,
118 Self::Channel(data) => data,
119 }
120 }
121
122 fn into_incoming<T: AsRef<[u8]> + core::fmt::Debug>(
123 self,
124 transmit: Transmit<T>,
125 ) -> IncomingTcp<T> {
126 match self {
127 Self::Message(msg) => IncomingTcp::StoredMessage(msg, transmit),
128 Self::Channel(channel) => IncomingTcp::StoredChannel(channel, transmit),
129 }
130 }
131}
132
133impl AsRef<[u8]> for StoredTcp {
134 fn as_ref(&self) -> &[u8] {
135 self.data()
136 }
137}
138
139#[derive(Debug, Default)]
141pub struct TurnTcpBuffer {
142 tcp_buffer: Vec<u8>,
143}
144
145impl TurnTcpBuffer {
146 pub fn new() -> Self {
148 Self { tcp_buffer: vec![] }
149 }
150
151 #[tracing::instrument(
156 level = "trace",
157 skip(self, transmit),
158 fields(
159 transmit.data_len = transmit.data.as_ref().len(),
160 from = ?transmit.from
161 )
162 )]
163 pub fn incoming_tcp<T: AsRef<[u8]> + core::fmt::Debug>(
164 &mut self,
165 transmit: Transmit<T>,
166 ) -> Option<IncomingTcp<T>> {
167 if self.tcp_buffer.is_empty() {
168 let data = transmit.data.as_ref();
169 trace!("Trying to parse incoming data as a complete message/channel");
170 let Ok(hdr) = MessageHeader::from_bytes(data) else {
171 let Ok(channel) = ChannelData::parse(data) else {
172 self.tcp_buffer.extend_from_slice(data);
173 return None;
174 };
175 let channel_len = 4 + channel.data().len();
176 debug!(
177 channel.id = channel.id(),
178 channel.len = channel_len - 4,
179 "Incoming data contains a channel",
180 );
181 if channel_len < data.len() {
182 self.tcp_buffer.extend_from_slice(&data[channel_len..]);
183 }
184 return Some(IncomingTcp::CompleteChannel(transmit, 0..channel_len));
185 };
186 let msg_len = MessageHeader::LENGTH + hdr.data_length() as usize;
187 debug!(
188 msg.transaction = %hdr.transaction_id(),
189 msg.len = msg_len,
190 "Incoming data contains a message",
191 );
192 if data.len() < msg_len {
193 self.tcp_buffer.extend_from_slice(data);
194 return None;
195 }
196 if msg_len < data.len() {
197 self.tcp_buffer.extend_from_slice(&data[msg_len..]);
198 }
199 return Some(IncomingTcp::CompleteMessage(transmit, 0..msg_len));
200 }
201
202 self.tcp_buffer.extend_from_slice(transmit.data.as_ref());
203 self.poll_recv().map(|recv| recv.into_incoming(transmit))
204 }
205
206 #[tracing::instrument(
208 level = "trace",
209 skip(self),
210 fields(
211 buffered_len = self.tcp_buffer.len(),
212 )
213 )]
214 pub fn poll_recv(&mut self) -> Option<StoredTcp> {
215 let Ok(hdr) = MessageHeader::from_bytes(&self.tcp_buffer) else {
216 let Ok((id, channel_data_len)) = ChannelData::parse_header(&self.tcp_buffer) else {
217 trace!(
218 buffered.len = self.tcp_buffer.len(),
219 "cannot parse stored data"
220 );
221 return None;
222 };
223 let channel_len = 4 + channel_data_len;
224 if self.tcp_buffer.len() < channel_len {
225 trace!(
226 buffered.len = self.tcp_buffer.len(),
227 required = channel_len,
228 "need more bytes to complete channel data"
229 );
230 return None;
231 }
232 let (data, remaining) = self.tcp_buffer.split_at(channel_len);
233 let data_binding = data.to_vec();
234 debug!(
235 channel.id = id,
236 channel.len = channel_data_len,
237 remaining = remaining.len(),
238 "buffered data contains a channel",
239 );
240 self.tcp_buffer = remaining.to_vec();
241 return Some(StoredTcp::Channel(data_binding));
242 };
243 let msg_len = MessageHeader::LENGTH + hdr.data_length() as usize;
244 if self.tcp_buffer.len() < msg_len {
245 trace!(
246 buffered.len = self.tcp_buffer.len(),
247 required = msg_len,
248 "need more bytes to complete STUN message"
249 );
250 return None;
251 }
252 let (data, remaining) = self.tcp_buffer.split_at(msg_len);
253 let data_binding = data.to_vec();
254 debug!(
255 msg.transaction = %hdr.transaction_id(),
256 msg.len = msg_len,
257 remaining = remaining.len(),
258 "stored data contains a message",
259 );
260 self.tcp_buffer = remaining.to_vec();
261 Some(StoredTcp::Message(data_binding))
262 }
263
264 pub fn into_inner(self) -> Vec<u8> {
266 self.tcp_buffer
267 }
268
269 pub fn len(&self) -> usize {
271 self.tcp_buffer.len()
272 }
273
274 pub fn is_empty(&self) -> bool {
276 self.tcp_buffer.is_empty()
277 }
278}
279
#[cfg(test)]
mod tests {
    use core::net::SocketAddr;

    use stun_types::{
        attribute::Software,
        message::{Message, MessageWriteVec},
        prelude::{MessageWrite, MessageWriteExt},
        TransportType,
    };
    use tracing::info;

    use crate::message::ALLOCATE;

    use super::*;

    /// The (local, remote) address pair used for every test transmit.
    fn generate_addresses() -> (SocketAddr, SocketAddr) {
        (
            "192.168.0.1:1000".parse().unwrap(),
            "10.0.0.2:2000".parse().unwrap(),
        )
    }

    /// A serialized ALLOCATE request carrying a SOFTWARE attribute and a FINGERPRINT.
    fn generate_message() -> Vec<u8> {
        let mut msg = Message::builder_request(ALLOCATE, MessageWriteVec::new());
        msg.add_attribute(&Software::new("turn-types").unwrap())
            .unwrap();
        msg.add_fingerprint().unwrap();
        msg.finish()
    }

    /// The message from `generate_message` wrapped in a ChannelData frame on channel 0x4000.
    fn generate_message_in_channel() -> Vec<u8> {
        let msg = generate_message();
        let channel = ChannelData::new(0x4000, &msg);
        let mut out = vec![0; msg.len() + 4];
        channel.write_into_unchecked(&mut out);
        out
    }

    #[test]
    fn test_incoming_tcp_complete_message() {
        let _init = crate::tests::test_init_log();
        let (local_addr, remote_addr) = generate_addresses();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let ret = tcp
            .incoming_tcp(Transmit::new(
                msg.clone(),
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ))
            .unwrap();
        assert!(matches!(ret, IncomingTcp::CompleteMessage(_, _)));
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
        assert!(tcp.into_inner().is_empty());
    }

    #[test]
    fn test_incoming_tcp_complete_message_in_channel() {
        let _init = crate::tests::test_init_log();
        let (local_addr, remote_addr) = generate_addresses();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message_in_channel();
        let ret = tcp
            .incoming_tcp(Transmit::new(
                msg.clone(),
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ))
            .unwrap();
        assert!(matches!(ret, IncomingTcp::CompleteChannel(_, _)));
        assert_eq!(ret.data(), &msg);
        assert!(ret.channel().is_some());
        assert!(tcp.into_inner().is_empty());
    }

    #[test]
    fn test_incoming_tcp_partial_message() {
        let _init = crate::tests::test_init_log();
        let (local_addr, remote_addr) = generate_addresses();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        info!("message: {msg:x?}");
        for i in 1..msg.len() {
            let ret = tcp.incoming_tcp(Transmit::new(
                &msg[i - 1..i],
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ));
            assert!(ret.is_none());

            let data = tcp.into_inner();
            assert_eq!(&data, &msg[..i]);
            tcp = TurnTcpBuffer::new();
            let ret = tcp.incoming_tcp(Transmit::new(
                &data,
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ));
            assert!(ret.is_none());
        }
        let ret = tcp
            .incoming_tcp(Transmit::new(
                &msg[msg.len() - 1..],
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ))
            .unwrap();
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
        let IncomingTcp::StoredMessage(produced, _) = ret else {
            unreachable!();
        };
        assert_eq!(produced, msg);
        assert!(tcp.into_inner().is_empty());
    }

    #[test]
    fn test_incoming_tcp_partial_channel() {
        let _init = crate::tests::test_init_log();
        let (local_addr, remote_addr) = generate_addresses();
        let mut tcp = TurnTcpBuffer::new();
        let channel = generate_message_in_channel();
        info!("message: {channel:x?}");
        for i in 1..channel.len() {
            let ret = tcp.incoming_tcp(Transmit::new(
                &channel[i - 1..i],
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ));
            assert!(ret.is_none());

            let data = tcp.into_inner();
            assert_eq!(&data, &channel[..i]);
            tcp = TurnTcpBuffer::new();
            let ret = tcp.incoming_tcp(Transmit::new(
                &data,
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ));
            assert!(ret.is_none());
        }
        let ret = tcp
            .incoming_tcp(Transmit::new(
                &channel[channel.len() - 1..],
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ))
            .unwrap();
        assert_eq!(ret.data(), &channel);
        assert!(ret.channel().is_some());
        let IncomingTcp::StoredChannel(produced, _) = ret else {
            unreachable!()
        };
        assert_eq!(produced, channel);
        assert!(tcp.into_inner().is_empty());
    }

    #[test]
    fn test_incoming_tcp_message_and_channel() {
        let _init = crate::tests::test_init_log();
        let (local_addr, remote_addr) = generate_addresses();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let channel = generate_message_in_channel();
        let mut input = msg.clone();
        input.extend_from_slice(&channel);
        let ret = tcp
            .incoming_tcp(Transmit::new(
                input.clone(),
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ))
            .unwrap();
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
        let IncomingTcp::CompleteMessage(transmit, msg_range) = ret else {
            unreachable!();
        };
        assert_eq!(msg_range, 0..msg.len());
        assert_eq!(transmit.data, input);
        let ret = tcp.poll_recv().unwrap();
        assert_eq!(ret.data(), &channel);
        let StoredTcp::Channel(produced) = ret else {
            unreachable!()
        };
        assert_eq!(produced, channel);
    }

    #[test]
    fn test_incoming_tcp_channel_and_message() {
        let _init = crate::tests::test_init_log();
        let (local_addr, remote_addr) = generate_addresses();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let channel = generate_message_in_channel();
        let mut input = channel.clone();
        input.extend_from_slice(&msg);
        let ret = tcp
            .incoming_tcp(Transmit::new(
                input.clone(),
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ))
            .unwrap();
        assert_eq!(ret.data(), &channel);
        assert!(ret.channel().is_some());
        let IncomingTcp::CompleteChannel(transmit, channel_range) = ret else {
            unreachable!()
        };
        assert_eq!(channel_range, 0..channel.len());
        assert_eq!(transmit.data, input);
        let ret = tcp.poll_recv().unwrap();
        assert_eq!(ret.data(), &msg);
        let StoredTcp::Message(produced) = ret else {
            unreachable!()
        };
        assert_eq!(produced, msg);
    }

    /// A read that completes a buffered message and also carries a further
    /// frame must return the message and keep the trailing frame intact for
    /// `poll_recv`.
    #[test]
    fn test_incoming_tcp_stored_message_then_channel() {
        let _init = crate::tests::test_init_log();
        let (local_addr, remote_addr) = generate_addresses();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let channel = generate_message_in_channel();
        let split = msg.len() / 2;

        // Only half of the message: nothing can be produced and everything is buffered.
        let ret = tcp.incoming_tcp(Transmit::new(
            &msg[..split],
            TransportType::Tcp,
            remote_addr,
            local_addr,
        ));
        assert!(ret.is_none());
        assert_eq!(tcp.len(), split);
        assert!(!tcp.is_empty());

        // The rest of the message followed by a complete channel in a single read.
        let mut rest = msg[split..].to_vec();
        rest.extend_from_slice(&channel);
        let ret = tcp
            .incoming_tcp(Transmit::new(
                rest.as_slice(),
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ))
            .unwrap();
        assert!(matches!(ret, IncomingTcp::StoredMessage(_, _)));
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());

        // The trailing channel stays buffered until explicitly polled.
        assert_eq!(tcp.len(), channel.len());
        let stored = tcp.poll_recv().unwrap();
        assert!(matches!(stored, StoredTcp::Channel(_)));
        assert_eq!(stored.data(), &channel);
        assert!(tcp.is_empty());
        assert!(tcp.poll_recv().is_none());
    }
}