turn_types/
tcp.rs

1// Copyright (C) 2025 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! Handle TURN over TCP.
10//!
11//! A TURN connection between a client and a server can have two types of data:
12//! - STUN [`Message`]s, and
13//! - [`ChannelData`]s
14//!
//! Unlike a UDP connection, which inherently provides the size of every message, TCP is a
//! stream-based protocol and the size of a message must be inferred from the contained data. This
//! module performs the relevant buffering of incoming data over a TCP connection and produces
//! [`Message`]s or [`ChannelData`] as they are completely received.
19//!
//! The buffering performed by [`TurnTcpBuffer`] is only applicable to the TCP connection between
//! the TURN client and the TURN server when using a UDP allocation. With TURN-TCP ([RFC6062]), the
//! TURN client connects to the TURN server using a TCP control connection (optionally with TLS),
//! and data on each separate TCP data connection is forwarded as-is. The TURN-TCP control
//! connection only requires buffering of STUN Messages without any framing; this can also be
//! performed by [`TurnTcpBuffer`] if any [`ChannelData`] messages received on it are treated as
//! fatal TURN protocol errors.
27//!
28//! [RFC6062]: https://tools.ietf.org/html/rfc6062
29
30use alloc::vec;
31use alloc::vec::Vec;
32use core::ops::Range;
33
34use stun_proto::agent::Transmit;
35use stun_types::message::{Message, MessageHeader};
36use tracing::{debug, trace};
37
38use crate::channel::ChannelData;
39
40/// Reply to [`TurnTcpBuffer::incoming_tcp()`].
41///
42/// The `Transmit<T>` in each value  is always the original value passed to
43/// [`TurnTcpBuffer::incoming_tcp()`].
44#[derive(Debug)]
45pub enum IncomingTcp<T: AsRef<[u8]> + core::fmt::Debug> {
46    /// Input data (with the provided range) contains a complete STUN Message.
47    ///
48    /// Any extra data after the range is stored for later processing.
49    CompleteMessage(Transmit<T>, Range<usize>),
50    /// Input data (with the provided range) contains a complete Channel data message.
51    ///
52    /// Any extra data after the range is stored for later processing.
53    CompleteChannel(Transmit<T>, Range<usize>),
54    /// A STUN message has been produced from the buffered data.
55    StoredMessage(Vec<u8>, Transmit<T>),
56    /// A Channel data message has been produced from the buffered data.
57    StoredChannel(Vec<u8>, Transmit<T>),
58}
59
60impl<T: AsRef<[u8]> + core::fmt::Debug> IncomingTcp<T> {
61    /// The byte slice for this incoming, or stored message, or channel.
62    pub fn data(&self) -> &[u8] {
63        match self {
64            Self::CompleteMessage(transmit, range) => {
65                &transmit.data.as_ref()[range.start..range.end]
66            }
67            Self::CompleteChannel(transmit, range) => {
68                &transmit.data.as_ref()[range.start..range.end]
69            }
70            Self::StoredMessage(data, _transmit) => data,
71            Self::StoredChannel(data, _transmit) => data,
72        }
73    }
74
75    /// The [`Message`] contained in this incoming or stored data.
76    pub fn message(&self) -> Option<Message<'_>> {
77        if !matches!(
78            self,
79            Self::CompleteMessage(_, _) | Self::StoredMessage(_, _)
80        ) {
81            return None;
82        }
83        Message::from_bytes(self.data()).ok()
84    }
85
86    /// The [`ChannelData`] contained in this incoming or stored data.
87    pub fn channel(&self) -> Option<ChannelData<'_>> {
88        if !matches!(
89            self,
90            Self::CompleteChannel(_, _) | Self::StoredChannel(_, _)
91        ) {
92            return None;
93        }
94        ChannelData::parse(self.data()).ok()
95    }
96}
97
98impl<T: AsRef<[u8]> + core::fmt::Debug> AsRef<[u8]> for IncomingTcp<T> {
99    fn as_ref(&self) -> &[u8] {
100        self.data()
101    }
102}
103
104/// A stored [`Message`] or [`ChannelData`]
105#[derive(Debug)]
106pub enum StoredTcp {
107    /// A STUN [`Message`] has been received.
108    Message(Vec<u8>),
109    /// A [`ChannelData`] has been received.
110    Channel(Vec<u8>),
111}
112
113impl StoredTcp {
114    /// The byte slice for this stored data.
115    pub fn data(&self) -> &[u8] {
116        match self {
117            Self::Message(data) => data,
118            Self::Channel(data) => data,
119        }
120    }
121
122    fn into_incoming<T: AsRef<[u8]> + core::fmt::Debug>(
123        self,
124        transmit: Transmit<T>,
125    ) -> IncomingTcp<T> {
126        match self {
127            Self::Message(msg) => IncomingTcp::StoredMessage(msg, transmit),
128            Self::Channel(channel) => IncomingTcp::StoredChannel(channel, transmit),
129        }
130    }
131}
132
133impl AsRef<[u8]> for StoredTcp {
134    fn as_ref(&self) -> &[u8] {
135        self.data()
136    }
137}
138
139/// A TCP buffer for TURN messages.
140#[derive(Debug, Default)]
141pub struct TurnTcpBuffer {
142    tcp_buffer: Vec<u8>,
143}
144
145impl TurnTcpBuffer {
146    /// Construct a new [`TurnTcpBuffer`].
147    pub fn new() -> Self {
148        Self { tcp_buffer: vec![] }
149    }
150
151    /// Provide incoming TCP data to parse.
152    ///
153    /// A return value of `None` indicates that the more data is required to provide a complete
154    /// STUN [`Message`] or a [`ChannelData`].
155    #[tracing::instrument(
156        level = "trace",
157        skip(self, transmit),
158        fields(
159            transmit.data_len = transmit.data.as_ref().len(),
160            from = ?transmit.from
161        )
162    )]
163    pub fn incoming_tcp<T: AsRef<[u8]> + core::fmt::Debug>(
164        &mut self,
165        transmit: Transmit<T>,
166    ) -> Option<IncomingTcp<T>> {
167        if self.tcp_buffer.is_empty() {
168            let data = transmit.data.as_ref();
169            trace!("Trying to parse incoming data as a complete message/channel");
170            let Ok(hdr) = MessageHeader::from_bytes(data) else {
171                let Ok(channel) = ChannelData::parse(data) else {
172                    self.tcp_buffer.extend_from_slice(data);
173                    return None;
174                };
175                let channel_len = 4 + channel.data().len();
176                debug!(
177                    channel.id = channel.id(),
178                    channel.len = channel_len - 4,
179                    "Incoming data contains a channel",
180                );
181                if channel_len < data.len() {
182                    self.tcp_buffer.extend_from_slice(&data[channel_len..]);
183                }
184                return Some(IncomingTcp::CompleteChannel(transmit, 0..channel_len));
185            };
186            let msg_len = MessageHeader::LENGTH + hdr.data_length() as usize;
187            debug!(
188                msg.transaction = %hdr.transaction_id(),
189                msg.len = msg_len,
190                "Incoming data contains a message",
191            );
192            if data.len() < msg_len {
193                self.tcp_buffer.extend_from_slice(data);
194                return None;
195            }
196            if msg_len < data.len() {
197                self.tcp_buffer.extend_from_slice(&data[msg_len..]);
198            }
199            return Some(IncomingTcp::CompleteMessage(transmit, 0..msg_len));
200        }
201
202        self.tcp_buffer.extend_from_slice(transmit.data.as_ref());
203        self.poll_recv().map(|recv| recv.into_incoming(transmit))
204    }
205
206    /// Return the next complete message (if any).
207    #[tracing::instrument(
208        level = "trace",
209        skip(self),
210        fields(
211            buffered_len = self.tcp_buffer.len(),
212        )
213    )]
214    pub fn poll_recv(&mut self) -> Option<StoredTcp> {
215        let Ok(hdr) = MessageHeader::from_bytes(&self.tcp_buffer) else {
216            let Ok((id, channel_data_len)) = ChannelData::parse_header(&self.tcp_buffer) else {
217                trace!(
218                    buffered.len = self.tcp_buffer.len(),
219                    "cannot parse stored data"
220                );
221                return None;
222            };
223            let channel_len = 4 + channel_data_len;
224            if self.tcp_buffer.len() < channel_len {
225                trace!(
226                    buffered.len = self.tcp_buffer.len(),
227                    required = channel_len,
228                    "need more bytes to complete channel data"
229                );
230                return None;
231            }
232            let (data, remaining) = self.tcp_buffer.split_at(channel_len);
233            let data_binding = data.to_vec();
234            debug!(
235                channel.id = id,
236                channel.len = channel_data_len,
237                remaining = remaining.len(),
238                "buffered data contains a channel",
239            );
240            self.tcp_buffer = remaining.to_vec();
241            return Some(StoredTcp::Channel(data_binding));
242        };
243        let msg_len = MessageHeader::LENGTH + hdr.data_length() as usize;
244        if self.tcp_buffer.len() < msg_len {
245            trace!(
246                buffered.len = self.tcp_buffer.len(),
247                required = msg_len,
248                "need more bytes to complete STUN message"
249            );
250            return None;
251        }
252        let (data, remaining) = self.tcp_buffer.split_at(msg_len);
253        let data_binding = data.to_vec();
254        debug!(
255            msg.transaction = %hdr.transaction_id(),
256            msg.len = msg_len,
257            remaining = remaining.len(),
258            "stored data contains a message",
259        );
260        self.tcp_buffer = remaining.to_vec();
261        Some(StoredTcp::Message(data_binding))
262    }
263
264    /// Returns the underlying buffer.
265    pub fn into_inner(self) -> Vec<u8> {
266        self.tcp_buffer
267    }
268
269    /// The number of bytes contained in this buffer.
270    pub fn len(&self) -> usize {
271        self.tcp_buffer.len()
272    }
273
274    /// Whether the buffer currently contains 0 bytes of data.
275    pub fn is_empty(&self) -> bool {
276        self.tcp_buffer.is_empty()
277    }
278}
279
#[cfg(test)]
mod tests {
    use core::net::SocketAddr;

    use stun_types::{
        attribute::Software,
        message::{Message, MessageWriteVec},
        prelude::{MessageWrite, MessageWriteExt},
        TransportType,
    };
    use tracing::info;

    use crate::message::ALLOCATE;

    use super::*;

    fn generate_addresses() -> (SocketAddr, SocketAddr) {
        (
            "192.168.0.1:1000".parse().unwrap(),
            "10.0.0.2:2000".parse().unwrap(),
        )
    }

    fn generate_message() -> Vec<u8> {
        let mut msg = Message::builder_request(ALLOCATE, MessageWriteVec::new());
        msg.add_attribute(&Software::new("turn-types").unwrap())
            .unwrap();
        msg.add_fingerprint().unwrap();
        msg.finish()
    }

    fn generate_message_in_channel() -> Vec<u8> {
        let msg = generate_message();
        let channel = ChannelData::new(0x4000, &msg);
        let mut out = vec![0; msg.len() + 4];
        channel.write_into_unchecked(&mut out);
        out
    }

    #[test]
    fn test_incoming_tcp_complete_message() {
        let _init = crate::tests::test_init_log();
        let (local_addr, remote_addr) = generate_addresses();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let ret = tcp
            .incoming_tcp(Transmit::new(
                msg.clone(),
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ))
            .unwrap();
        assert!(matches!(ret, IncomingTcp::CompleteMessage(_, _)));
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
        assert!(tcp.into_inner().is_empty());
    }

    #[test]
    fn test_incoming_tcp_complete_message_in_channel() {
        let _init = crate::tests::test_init_log();
        let (local_addr, remote_addr) = generate_addresses();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message_in_channel();
        let ret = tcp
            .incoming_tcp(Transmit::new(
                msg.clone(),
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ))
            .unwrap();
        assert!(matches!(ret, IncomingTcp::CompleteChannel(_, _)));
        assert_eq!(ret.data(), &msg);
        assert!(ret.channel().is_some());
        assert!(tcp.into_inner().is_empty());
    }

    #[test]
    fn test_incoming_tcp_partial_message() {
        let _init = crate::tests::test_init_log();
        let (local_addr, remote_addr) = generate_addresses();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        info!("message: {msg:x?}");
        for i in 1..msg.len() {
            let ret = tcp.incoming_tcp(Transmit::new(
                &msg[i - 1..i],
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ));
            assert!(ret.is_none());

            let data = tcp.into_inner();
            assert_eq!(&data, &msg[..i]);
            tcp = TurnTcpBuffer::new();
            let ret = tcp.incoming_tcp(Transmit::new(
                &data,
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ));
            assert!(ret.is_none());
        }
        let ret = tcp
            .incoming_tcp(Transmit::new(
                &msg[msg.len() - 1..],
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ))
            .unwrap();
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
        let IncomingTcp::StoredMessage(produced, _) = ret else {
            unreachable!();
        };
        assert_eq!(produced, msg);
        assert!(tcp.into_inner().is_empty());
    }

    #[test]
    fn test_incoming_tcp_partial_channel() {
        let _init = crate::tests::test_init_log();
        let (local_addr, remote_addr) = generate_addresses();
        let mut tcp = TurnTcpBuffer::new();
        let channel = generate_message_in_channel();
        info!("message: {channel:x?}");
        for i in 1..channel.len() {
            let ret = tcp.incoming_tcp(Transmit::new(
                &channel[i - 1..i],
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ));
            assert!(ret.is_none());

            let data = tcp.into_inner();
            assert_eq!(&data, &channel[..i]);
            tcp = TurnTcpBuffer::new();
            let ret = tcp.incoming_tcp(Transmit::new(
                &data,
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ));
            assert!(ret.is_none());
        }
        let ret = tcp
            .incoming_tcp(Transmit::new(
                &channel[channel.len() - 1..],
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ))
            .unwrap();
        assert_eq!(ret.data(), &channel);
        assert!(ret.channel().is_some());
        let IncomingTcp::StoredChannel(produced, _) = ret else {
            unreachable!()
        };
        assert_eq!(produced, channel);
        assert!(tcp.into_inner().is_empty());
    }

    #[test]
    fn test_incoming_tcp_message_and_channel() {
        let _init = crate::tests::test_init_log();
        let (local_addr, remote_addr) = generate_addresses();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let channel = generate_message_in_channel();
        let mut input = msg.clone();
        input.extend_from_slice(&channel);
        let ret = tcp
            .incoming_tcp(Transmit::new(
                input.clone(),
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ))
            .unwrap();
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
        let IncomingTcp::CompleteMessage(transmit, msg_range) = ret else {
            unreachable!();
        };
        assert_eq!(msg_range, 0..msg.len());
        assert_eq!(transmit.data, input);
        let ret = tcp.poll_recv().unwrap();
        assert_eq!(ret.data(), &channel);
        let StoredTcp::Channel(produced) = ret else {
            unreachable!()
        };
        assert_eq!(produced, channel);
    }

    #[test]
    fn test_incoming_tcp_channel_and_message() {
        let _init = crate::tests::test_init_log();
        let (local_addr, remote_addr) = generate_addresses();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let channel = generate_message_in_channel();
        let mut input = channel.clone();
        input.extend_from_slice(&msg);
        let ret = tcp
            .incoming_tcp(Transmit::new(
                input.clone(),
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ))
            .unwrap();
        assert_eq!(ret.data(), &channel);
        assert!(ret.channel().is_some());
        let IncomingTcp::CompleteChannel(transmit, channel_range) = ret else {
            unreachable!()
        };
        assert_eq!(channel_range, 0..channel.len());
        assert_eq!(transmit.data, input);
        let ret = tcp.poll_recv().unwrap();
        assert_eq!(ret.data(), &msg);
        let StoredTcp::Message(produced) = ret else {
            unreachable!()
        };
        assert_eq!(produced, msg);
    }

    #[test]
    fn test_poll_recv_empty() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        assert!(tcp.is_empty());
        assert_eq!(tcp.len(), 0);
        assert!(tcp.poll_recv().is_none());
        assert!(tcp.is_empty());
    }

    #[test]
    fn test_incoming_tcp_partial_then_two_messages() {
        let _init = crate::tests::test_init_log();
        let (local_addr, remote_addr) = generate_addresses();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let split = msg.len() / 2;
        let ret = tcp.incoming_tcp(Transmit::new(
            &msg[..split],
            TransportType::Tcp,
            remote_addr,
            local_addr,
        ));
        assert!(ret.is_none());
        assert_eq!(tcp.len(), split);

        // The rest of the first message followed by a complete second message.
        let mut input = msg[split..].to_vec();
        input.extend_from_slice(&msg);
        let ret = tcp
            .incoming_tcp(Transmit::new(
                input,
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ))
            .unwrap();
        assert!(matches!(ret, IncomingTcp::StoredMessage(_, _)));
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
        // Only the second message remains buffered.
        assert_eq!(tcp.len(), msg.len());

        let ret = tcp.poll_recv().unwrap();
        let StoredTcp::Message(produced) = ret else {
            unreachable!()
        };
        assert_eq!(produced, msg);
        assert!(tcp.is_empty());
        assert!(tcp.poll_recv().is_none());
    }

    #[test]
    fn test_incoming_tcp_two_channels() {
        let _init = crate::tests::test_init_log();
        let (local_addr, remote_addr) = generate_addresses();
        let mut tcp = TurnTcpBuffer::new();
        let channel = generate_message_in_channel();
        let mut input = channel.clone();
        input.extend_from_slice(&channel);
        let ret = tcp
            .incoming_tcp(Transmit::new(
                input,
                TransportType::Tcp,
                remote_addr,
                local_addr,
            ))
            .unwrap();
        assert!(matches!(ret, IncomingTcp::CompleteChannel(_, _)));
        assert_eq!(ret.data(), &channel);
        assert!(ret.channel().is_some());
        assert!(ret.message().is_none());

        let ret = tcp.poll_recv().unwrap();
        assert_eq!(ret.data(), &channel);
        assert!(matches!(ret, StoredTcp::Channel(_)));
        assert!(tcp.poll_recv().is_none());
        assert!(tcp.is_empty());
    }
}