1use alloc::vec;
31use alloc::vec::Vec;
32use core::ops::Range;
33
34use stun_proto::agent::Transmit;
35use stun_types::message::{Message, MessageHeader};
36use tracing::{debug, trace};
37
38use crate::channel::ChannelData;
39
40#[derive(Debug)]
45pub enum IncomingTcp<T: AsRef<[u8]> + core::fmt::Debug> {
46 CompleteMessage(Transmit<T>, Range<usize>),
50 CompleteChannel(Transmit<T>, Range<usize>),
54 StoredMessage(Vec<u8>, Transmit<T>),
56 StoredChannel(Vec<u8>, Transmit<T>),
58}
59
60impl<T: AsRef<[u8]> + core::fmt::Debug> IncomingTcp<T> {
61 pub fn data(&self) -> &[u8] {
63 match self {
64 Self::CompleteMessage(transmit, range) => {
65 &transmit.data.as_ref()[range.start..range.end]
66 }
67 Self::CompleteChannel(transmit, range) => {
68 &transmit.data.as_ref()[range.start..range.end]
69 }
70 Self::StoredMessage(data, _transmit) => data,
71 Self::StoredChannel(data, _transmit) => data,
72 }
73 }
74
75 pub fn message(&self) -> Option<Message<'_>> {
77 if !matches!(
78 self,
79 Self::CompleteMessage(_, _) | Self::StoredMessage(_, _)
80 ) {
81 return None;
82 }
83 Message::from_bytes(self.data()).ok()
84 }
85
86 pub fn channel(&self) -> Option<ChannelData<'_>> {
88 if !matches!(
89 self,
90 Self::CompleteChannel(_, _) | Self::StoredChannel(_, _)
91 ) {
92 return None;
93 }
94 ChannelData::parse(self.data()).ok()
95 }
96}
97
98impl<T: AsRef<[u8]> + core::fmt::Debug> AsRef<[u8]> for IncomingTcp<T> {
99 fn as_ref(&self) -> &[u8] {
100 self.data()
101 }
102}
103
104#[derive(Debug)]
106pub enum StoredTcp {
107 Message(Vec<u8>),
109 Channel(Vec<u8>),
111}
112
113impl StoredTcp {
114 pub fn data(&self) -> &[u8] {
116 match self {
117 Self::Message(data) => data,
118 Self::Channel(data) => data,
119 }
120 }
121
122 fn into_incoming<T: AsRef<[u8]> + core::fmt::Debug>(
123 self,
124 transmit: Transmit<T>,
125 ) -> IncomingTcp<T> {
126 match self {
127 Self::Message(msg) => IncomingTcp::StoredMessage(msg, transmit),
128 Self::Channel(channel) => IncomingTcp::StoredChannel(channel, transmit),
129 }
130 }
131}
132
133impl AsRef<[u8]> for StoredTcp {
134 fn as_ref(&self) -> &[u8] {
135 self.data()
136 }
137}
138
139#[derive(Debug, Default)]
141pub struct TurnTcpBuffer {
142 tcp_buffer: Vec<u8>,
143}
144
145impl TurnTcpBuffer {
146 pub fn new() -> Self {
148 Self { tcp_buffer: vec![] }
149 }
150
151 #[tracing::instrument(
156 level = "trace",
157 skip(self, transmit),
158 fields(
159 transmit.data_len = transmit.data.as_ref().len(),
160 from = ?transmit.from
161 )
162 )]
163 pub fn incoming_tcp<T: AsRef<[u8]> + core::fmt::Debug>(
164 &mut self,
165 transmit: Transmit<T>,
166 ) -> Option<IncomingTcp<T>> {
167 if self.tcp_buffer.is_empty() {
168 let data = transmit.data.as_ref();
169 trace!("Trying to parse incoming data as a complete message/channel");
170 let Ok(hdr) = MessageHeader::from_bytes(data) else {
171 let Ok(channel) = ChannelData::parse(data) else {
172 self.tcp_buffer.extend_from_slice(data);
173 return None;
174 };
175 let channel_len = 4 + channel.data().len();
176 debug!(
177 channel.id = channel.id(),
178 channel.len = channel_len - 4,
179 "Incoming data contains a channel",
180 );
181 if channel_len < data.len() {
182 self.tcp_buffer.extend_from_slice(&data[channel_len..]);
183 }
184 return Some(IncomingTcp::CompleteChannel(transmit, 0..channel_len));
185 };
186 let msg_len = MessageHeader::LENGTH + hdr.data_length() as usize;
187 debug!(
188 msg.transaction = %hdr.transaction_id(),
189 msg.len = msg_len,
190 "Incoming data contains a message",
191 );
192 if data.len() < msg_len {
193 self.tcp_buffer.extend_from_slice(data);
194 return None;
195 }
196 if msg_len < data.len() {
197 self.tcp_buffer.extend_from_slice(&data[msg_len..]);
198 }
199 return Some(IncomingTcp::CompleteMessage(transmit, 0..msg_len));
200 }
201
202 self.tcp_buffer.extend_from_slice(transmit.data.as_ref());
203 self.poll_recv().map(|recv| recv.into_incoming(transmit))
204 }
205
206 #[tracing::instrument(
208 level = "trace",
209 skip(self),
210 fields(
211 buffered_len = self.tcp_buffer.len(),
212 )
213 )]
214 pub fn poll_recv(&mut self) -> Option<StoredTcp> {
215 let Ok(hdr) = MessageHeader::from_bytes(&self.tcp_buffer) else {
216 let Ok((id, channel_data_len)) = ChannelData::parse_header(&self.tcp_buffer) else {
217 trace!(
218 buffered.len = self.tcp_buffer.len(),
219 "cannot parse stored data"
220 );
221 return None;
222 };
223 let channel_len = 4 + channel_data_len;
224 if self.tcp_buffer.len() < channel_len {
225 trace!(
226 buffered.len = self.tcp_buffer.len(),
227 required = channel_len,
228 "need more bytes to complete channel data"
229 );
230 return None;
231 }
232 let (data, remaining) = self.tcp_buffer.split_at(channel_len);
233 let data_binding = data.to_vec();
234 debug!(
235 channel.id = id,
236 channel.len = channel_data_len,
237 remaining = remaining.len(),
238 "buffered data contains a channel",
239 );
240 self.tcp_buffer = remaining.to_vec();
241 return Some(StoredTcp::Channel(data_binding));
242 };
243 let msg_len = MessageHeader::LENGTH + hdr.data_length() as usize;
244 if self.tcp_buffer.len() < msg_len {
245 trace!(
246 buffered.len = self.tcp_buffer.len(),
247 required = msg_len,
248 "need more bytes to complete STUN message"
249 );
250 return None;
251 }
252 let (data, remaining) = self.tcp_buffer.split_at(msg_len);
253 let data_binding = data.to_vec();
254 debug!(
255 msg.transaction = %hdr.transaction_id(),
256 msg.len = msg_len,
257 remaining = remaining.len(),
258 "stored data contains a message",
259 );
260 self.tcp_buffer = remaining.to_vec();
261 Some(StoredTcp::Message(data_binding))
262 }
263}
264
#[cfg(test)]
mod tests {
    use core::net::SocketAddr;

    use stun_types::{
        attribute::Software,
        message::{Message, MessageWriteVec},
        prelude::{MessageWrite, MessageWriteExt},
        TransportType,
    };
    use tracing::info;

    use crate::message::ALLOCATE;

    use super::*;

    /// Returns the `(local, remote)` address pair used by every test transmit.
    fn generate_addresses() -> (SocketAddr, SocketAddr) {
        let local: SocketAddr = "192.168.0.1:1000".parse().unwrap();
        let remote: SocketAddr = "10.0.0.2:2000".parse().unwrap();
        (local, remote)
    }

    /// Wraps `data` in a TCP transmit sent from the remote address to the local one.
    fn tcp_transmit<T: AsRef<[u8]> + core::fmt::Debug>(data: T) -> Transmit<T> {
        let (local_addr, remote_addr) = generate_addresses();
        Transmit::new(data, TransportType::Tcp, remote_addr, local_addr)
    }

    /// Builds an ALLOCATE request carrying a SOFTWARE attribute and a FINGERPRINT.
    fn generate_message() -> Vec<u8> {
        let mut builder = Message::builder_request(ALLOCATE, MessageWriteVec::new());
        builder
            .add_attribute(&Software::new("turn-types").unwrap())
            .unwrap();
        builder.add_fingerprint().unwrap();
        builder.finish()
    }

    /// Wraps the output of [`generate_message`] in ChannelData on channel 0x4000.
    fn generate_message_in_channel() -> Vec<u8> {
        let inner = generate_message();
        let mut framed = vec![0; 4 + inner.len()];
        ChannelData::new(0x4000, &inner).write_into_unchecked(&mut framed);
        framed
    }

    /// Feeds all but the last byte of `data` one byte at a time, asserting that nothing
    /// is produced. Returns the result of feeding the final byte.
    fn feed_bytewise<'a>(
        tcp: &mut TurnTcpBuffer,
        data: &'a [u8],
    ) -> Option<IncomingTcp<&'a [u8]>> {
        let (head, last) = data.split_at(data.len() - 1);
        for byte in head.chunks(1) {
            assert!(tcp.incoming_tcp(tcp_transmit(byte)).is_none());
        }
        tcp.incoming_tcp(tcp_transmit(last))
    }

    #[test]
    fn test_incoming_tcp_complete_message() {
        let _init = crate::tests::test_init_log();
        let msg = generate_message();
        let mut tcp = TurnTcpBuffer::new();
        let ret = tcp.incoming_tcp(tcp_transmit(msg.clone())).unwrap();
        assert!(matches!(ret, IncomingTcp::CompleteMessage(_, _)));
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
    }

    #[test]
    fn test_incoming_tcp_complete_message_in_channel() {
        let _init = crate::tests::test_init_log();
        let msg = generate_message_in_channel();
        let mut tcp = TurnTcpBuffer::new();
        let ret = tcp.incoming_tcp(tcp_transmit(msg.clone())).unwrap();
        assert!(matches!(ret, IncomingTcp::CompleteChannel(_, _)));
        assert_eq!(ret.data(), &msg);
        assert!(ret.channel().is_some());
    }

    #[test]
    fn test_incoming_tcp_partial_message() {
        let _init = crate::tests::test_init_log();
        let msg = generate_message();
        info!("message: {msg:x?}");
        let mut tcp = TurnTcpBuffer::new();
        let ret = feed_bytewise(&mut tcp, &msg).unwrap();
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
        let IncomingTcp::StoredMessage(produced, _) = ret else {
            unreachable!();
        };
        assert_eq!(produced, msg);
    }

    #[test]
    fn test_incoming_tcp_partial_channel() {
        let _init = crate::tests::test_init_log();
        let channel = generate_message_in_channel();
        info!("message: {channel:x?}");
        let mut tcp = TurnTcpBuffer::new();
        let ret = feed_bytewise(&mut tcp, &channel).unwrap();
        assert_eq!(ret.data(), &channel);
        assert!(ret.channel().is_some());
        let IncomingTcp::StoredChannel(produced, _) = ret else {
            unreachable!()
        };
        assert_eq!(produced, channel);
    }

    #[test]
    fn test_incoming_tcp_message_and_channel() {
        let _init = crate::tests::test_init_log();
        let msg = generate_message();
        let channel = generate_message_in_channel();
        let input = [msg.as_slice(), channel.as_slice()].concat();
        let mut tcp = TurnTcpBuffer::new();

        let ret = tcp.incoming_tcp(tcp_transmit(input.clone())).unwrap();
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
        let IncomingTcp::CompleteMessage(transmit, msg_range) = ret else {
            unreachable!();
        };
        assert_eq!(msg_range, 0..msg.len());
        assert_eq!(transmit.data, input);

        let stored = tcp.poll_recv().unwrap();
        assert_eq!(stored.data(), &channel);
        let StoredTcp::Channel(produced) = stored else {
            unreachable!()
        };
        assert_eq!(produced, channel);
    }

    #[test]
    fn test_incoming_tcp_channel_and_message() {
        let _init = crate::tests::test_init_log();
        let msg = generate_message();
        let channel = generate_message_in_channel();
        let input = [channel.as_slice(), msg.as_slice()].concat();
        let mut tcp = TurnTcpBuffer::new();

        let ret = tcp.incoming_tcp(tcp_transmit(input.clone())).unwrap();
        assert_eq!(ret.data(), &channel);
        assert!(ret.channel().is_some());
        let IncomingTcp::CompleteChannel(transmit, channel_range) = ret else {
            unreachable!()
        };
        assert_eq!(channel_range, 0..channel.len());
        assert_eq!(transmit.data, input);

        let stored = tcp.poll_recv().unwrap();
        assert_eq!(stored.data(), &msg);
        let StoredTcp::Message(produced) = stored else {
            unreachable!()
        };
        assert_eq!(produced, msg);
    }
}