1use std::ops::Range;
21
22use stun_proto::agent::Transmit;
23use stun_types::message::{Message, MessageHeader};
24use tracing::{debug, trace};
25
26use crate::channel::ChannelData;
27
28#[derive(Debug)]
30pub enum IncomingTcp<T: AsRef<[u8]> + std::fmt::Debug> {
31 CompleteMessage(Transmit<T>, Range<usize>),
35 CompleteChannel(Transmit<T>, Range<usize>),
39 StoredMessage(Vec<u8>, Transmit<T>),
41 StoredChannel(Vec<u8>, Transmit<T>),
43}
44
45impl<T: AsRef<[u8]> + std::fmt::Debug> IncomingTcp<T> {
46 pub fn data(&self) -> &[u8] {
48 match self {
49 Self::CompleteMessage(transmit, range) => {
50 &transmit.data.as_ref()[range.start..range.end]
51 }
52 Self::CompleteChannel(transmit, range) => {
53 &transmit.data.as_ref()[range.start..range.end]
54 }
55 Self::StoredMessage(data, _transmit) => data,
56 Self::StoredChannel(data, _transmit) => data,
57 }
58 }
59
60 pub fn message(&self) -> Option<Message<'_>> {
62 if !matches!(
63 self,
64 Self::CompleteMessage(_, _) | Self::StoredMessage(_, _)
65 ) {
66 return None;
67 }
68 Message::from_bytes(self.data()).ok()
69 }
70
71 pub fn channel(&self) -> Option<ChannelData<'_>> {
73 if !matches!(
74 self,
75 Self::CompleteChannel(_, _) | Self::StoredChannel(_, _)
76 ) {
77 return None;
78 }
79 ChannelData::parse(self.data()).ok()
80 }
81}
82
83impl<T: AsRef<[u8]> + std::fmt::Debug> AsRef<[u8]> for IncomingTcp<T> {
84 fn as_ref(&self) -> &[u8] {
85 self.data()
86 }
87}
88
/// An owned STUN message or channel data item extracted from buffered TCP bytes.
#[derive(Debug)]
pub enum StoredTcp {
    /// The bytes of a STUN message.
    Message(Vec<u8>),
    /// The bytes of a channel data item, header included.
    Channel(Vec<u8>),
}
97
98impl StoredTcp {
99 pub fn data(&self) -> &[u8] {
101 match self {
102 Self::Message(data) => data,
103 Self::Channel(data) => data,
104 }
105 }
106
107 fn into_incoming<T: AsRef<[u8]> + std::fmt::Debug>(
108 self,
109 transmit: Transmit<T>,
110 ) -> IncomingTcp<T> {
111 match self {
112 Self::Message(msg) => IncomingTcp::StoredMessage(msg, transmit),
113 Self::Channel(channel) => IncomingTcp::StoredChannel(channel, transmit),
114 }
115 }
116}
117
118impl AsRef<[u8]> for StoredTcp {
119 fn as_ref(&self) -> &[u8] {
120 self.data()
121 }
122}
123
/// Accumulates bytes received over a TCP stream until they form a complete STUN message or
/// channel data item.
#[derive(Debug, Default)]
pub struct TurnTcpBuffer {
    /// Received bytes that have not yet been handed back to the caller.
    tcp_buffer: Vec<u8>,
}
129
130impl TurnTcpBuffer {
131 pub fn new() -> Self {
133 Self { tcp_buffer: vec![] }
134 }
135
136 #[tracing::instrument(
141 level = "trace",
142 skip(self, transmit),
143 fields(
144 transmit.data_len = transmit.data.as_ref().len(),
145 from = ?transmit.from
146 )
147 )]
148 pub fn incoming_tcp<T: AsRef<[u8]> + std::fmt::Debug>(
149 &mut self,
150 transmit: Transmit<T>,
151 ) -> Option<IncomingTcp<T>> {
152 if self.tcp_buffer.is_empty() {
153 let data = transmit.data.as_ref();
154 trace!("Trying to parse incoming data as a complete message/channel");
155 let Ok(hdr) = MessageHeader::from_bytes(data) else {
156 let Ok(channel) = ChannelData::parse(data) else {
157 self.tcp_buffer.extend_from_slice(data);
158 return None;
159 };
160 let channel_len = 4 + channel.data().len();
161 debug!(
162 channel.id = channel.id(),
163 channel.len = channel_len - 4,
164 "Incoming data contains a channel",
165 );
166 if channel_len < data.len() {
167 self.tcp_buffer.extend_from_slice(&data[channel_len..]);
168 }
169 return Some(IncomingTcp::CompleteChannel(transmit, 0..channel_len));
170 };
171 let msg_len = MessageHeader::LENGTH + hdr.data_length() as usize;
172 debug!(
173 msg.transaction = %hdr.transaction_id(),
174 msg.len = msg_len,
175 "Incoming data contains a message",
176 );
177 if data.len() < msg_len {
178 self.tcp_buffer.extend_from_slice(data);
179 return None;
180 }
181 if msg_len < data.len() {
182 self.tcp_buffer.extend_from_slice(&data[msg_len..]);
183 }
184 return Some(IncomingTcp::CompleteMessage(transmit, 0..msg_len));
185 }
186
187 self.tcp_buffer.extend_from_slice(transmit.data.as_ref());
188 self.poll_recv().map(|recv| recv.into_incoming(transmit))
189 }
190
191 #[tracing::instrument(
193 level = "trace",
194 skip(self),
195 fields(
196 buffered_len = self.tcp_buffer.len(),
197 )
198 )]
199 pub fn poll_recv(&mut self) -> Option<StoredTcp> {
200 let Ok(hdr) = MessageHeader::from_bytes(&self.tcp_buffer) else {
201 let Ok((id, channel_data_len)) = ChannelData::parse_header(&self.tcp_buffer) else {
202 trace!(
203 buffered.len = self.tcp_buffer.len(),
204 "cannot parse stored data"
205 );
206 return None;
207 };
208 let channel_len = 4 + channel_data_len;
209 if self.tcp_buffer.len() < channel_len {
210 trace!(
211 buffered.len = self.tcp_buffer.len(),
212 required = channel_len,
213 "need more bytes to complete channel data"
214 );
215 return None;
216 }
217 let (data, remaining) = self.tcp_buffer.split_at(channel_len);
218 let data_binding = data.to_vec();
219 debug!(
220 channel.id = id,
221 channel.len = channel_data_len,
222 remaining = remaining.len(),
223 "buffered data contains a channel",
224 );
225 self.tcp_buffer = remaining.to_vec();
226 return Some(StoredTcp::Channel(data_binding));
227 };
228 let msg_len = MessageHeader::LENGTH + hdr.data_length() as usize;
229 if self.tcp_buffer.len() < msg_len {
230 trace!(
231 buffered.len = self.tcp_buffer.len(),
232 required = msg_len,
233 "need more bytes to complete STUN message"
234 );
235 return None;
236 }
237 let (data, remaining) = self.tcp_buffer.split_at(msg_len);
238 let data_binding = data.to_vec();
239 debug!(
240 msg.transaction = %hdr.transaction_id(),
241 msg.len = msg_len,
242 remaining = remaining.len(),
243 "stored data contains a message",
244 );
245 self.tcp_buffer = remaining.to_vec();
246 Some(StoredTcp::Message(data_binding))
247 }
248}
249
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use stun_types::{
        attribute::Software,
        message::{Message, MessageWriteVec},
        prelude::{MessageWrite, MessageWriteExt},
        TransportType,
    };
    use tracing::info;

    use crate::message::ALLOCATE;

    use super::*;

    /// The `(local, remote)` socket addresses used by every test.
    fn generate_addresses() -> (SocketAddr, SocketAddr) {
        let local = "192.168.0.1:1000".parse().unwrap();
        let remote = "10.0.0.2:2000".parse().unwrap();
        (local, remote)
    }

    /// Wrap `data` in a TCP [`Transmit`] sent from the remote address to the local address.
    fn tcp_transmit<T: AsRef<[u8]> + std::fmt::Debug>(data: T) -> Transmit<T> {
        let (local_addr, remote_addr) = generate_addresses();
        Transmit::new(data, TransportType::Tcp, remote_addr, local_addr)
    }

    /// Serialize an ALLOCATE request carrying a SOFTWARE attribute and a fingerprint.
    fn generate_message() -> Vec<u8> {
        let mut builder = Message::builder_request(ALLOCATE, MessageWriteVec::new());
        let software = Software::new("turn-types").unwrap();
        builder.add_attribute(&software).unwrap();
        builder.add_fingerprint().unwrap();
        builder.finish()
    }

    /// Serialize channel data on channel `0x4000` whose payload is [`generate_message`].
    fn generate_message_in_channel() -> Vec<u8> {
        let payload = generate_message();
        let mut out = vec![0; payload.len() + 4];
        ChannelData::new(0x4000, &payload).write_into_unchecked(&mut out);
        out
    }

    #[test]
    fn test_incoming_tcp_complete_message() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let ret = tcp.incoming_tcp(tcp_transmit(msg.clone())).unwrap();
        assert!(matches!(ret, IncomingTcp::CompleteMessage(_, _)));
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
    }

    #[test]
    fn test_incoming_tcp_complete_message_in_channel() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message_in_channel();
        let ret = tcp.incoming_tcp(tcp_transmit(msg.clone())).unwrap();
        assert!(matches!(ret, IncomingTcp::CompleteChannel(_, _)));
        assert_eq!(ret.data(), &msg);
        assert!(ret.channel().is_some());
    }

    #[test]
    fn test_incoming_tcp_partial_message() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        info!("message: {msg:x?}");
        // Every byte except the last is fed individually and must not produce anything.
        let (head, last) = msg.split_at(msg.len() - 1);
        for byte in head.chunks(1) {
            let ret = tcp.incoming_tcp(tcp_transmit(byte));
            assert!(ret.is_none());
        }
        let ret = tcp.incoming_tcp(tcp_transmit(last)).unwrap();
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
        let IncomingTcp::StoredMessage(produced, _) = ret else {
            unreachable!();
        };
        assert_eq!(produced, msg);
    }

    #[test]
    fn test_incoming_tcp_partial_channel() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        let channel = generate_message_in_channel();
        info!("message: {channel:x?}");
        // Every byte except the last is fed individually and must not produce anything.
        let (head, last) = channel.split_at(channel.len() - 1);
        for byte in head.chunks(1) {
            let ret = tcp.incoming_tcp(tcp_transmit(byte));
            assert!(ret.is_none());
        }
        let ret = tcp.incoming_tcp(tcp_transmit(last)).unwrap();
        assert_eq!(ret.data(), &channel);
        assert!(ret.channel().is_some());
        let IncomingTcp::StoredChannel(produced, _) = ret else {
            unreachable!()
        };
        assert_eq!(produced, channel);
    }

    #[test]
    fn test_incoming_tcp_message_and_channel() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let channel = generate_message_in_channel();
        let input = [msg.as_slice(), channel.as_slice()].concat();
        let ret = tcp.incoming_tcp(tcp_transmit(input.clone())).unwrap();
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
        let IncomingTcp::CompleteMessage(transmit, msg_range) = ret else {
            unreachable!();
        };
        assert_eq!(msg_range, 0..msg.len());
        assert_eq!(transmit.data, input);
        // The trailing channel was buffered and is available through `poll_recv`.
        let ret = tcp.poll_recv().unwrap();
        assert_eq!(ret.data(), &channel);
        let StoredTcp::Channel(produced) = ret else {
            unreachable!()
        };
        assert_eq!(produced, channel);
    }

    #[test]
    fn test_incoming_tcp_channel_and_message() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let channel = generate_message_in_channel();
        let input = [channel.as_slice(), msg.as_slice()].concat();
        let ret = tcp.incoming_tcp(tcp_transmit(input.clone())).unwrap();
        assert_eq!(ret.data(), &channel);
        assert!(ret.channel().is_some());
        let IncomingTcp::CompleteChannel(transmit, channel_range) = ret else {
            unreachable!()
        };
        assert_eq!(channel_range, 0..channel.len());
        assert_eq!(transmit.data, input);
        // The trailing message was buffered and is available through `poll_recv`.
        let ret = tcp.poll_recv().unwrap();
        assert_eq!(ret.data(), &msg);
        let StoredTcp::Message(produced) = ret else {
            unreachable!()
        };
        assert_eq!(produced, msg);
    }
}