turn_types/
tcp.rs

1// Copyright (C) 2025 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! Handle TURN over TCP.
10//!
11//! A TURN connection between a client and a server can have two types of data:
12//! - STUN [`Message`]s, and
13//! - [`ChannelData`]s
14//!
//! Unlike UDP, where every datagram inherently carries its own size, TCP is a stream-based
//! protocol, so the size of each message must be inferred from the data itself. This module
//! buffers incoming data on a TCP connection and produces [`Message`]s or [`ChannelData`] once
//! they have been completely received.
19//!
//! The buffering performed by [`TurnTcpBuffer`] only applies to the TCP connection between the
//! TURN client and the TURN server when a UDP allocation is used. TURN-TCP ([RFC6062]) requires
//! the TURN client to connect to the TURN server over TCP (optionally with TLS), while data on
//! each separate TCP data connection is forwarded as-is. The TURN-TCP control connection only
//! needs buffering of unframed STUN Messages; [`TurnTcpBuffer`] can also be used for it if any
//! [`ChannelData`] messages received are treated as fatal TURN protocol errors.
27//!
28//! [RFC6062]: https://tools.ietf.org/html/rfc6062
29
30use alloc::vec;
31use alloc::vec::Vec;
32use core::ops::Range;
33
34use stun_proto::agent::Transmit;
35use stun_types::message::{Message, MessageHeader};
36use tracing::{debug, trace};
37
38use crate::channel::ChannelData;
39
40/// Reply to [`TurnTcpBuffer::incoming_tcp()`].
41///
42/// The `Transmit<T>` in each value  is always the original value passed to
43/// [`TurnTcpBuffer::incoming_tcp()`].
44#[derive(Debug)]
45pub enum IncomingTcp<T: AsRef<[u8]> + core::fmt::Debug> {
46    /// Input data (with the provided range) contains a complete STUN Message.
47    ///
48    /// Any extra data after the range is stored for later processing.
49    CompleteMessage(Transmit<T>, Range<usize>),
50    /// Input data (with the provided range) contains a complete Channel data message.
51    ///
52    /// Any extra data after the range is stored for later processing.
53    CompleteChannel(Transmit<T>, Range<usize>),
54    /// A STUN message has been produced from the buffered data.
55    StoredMessage(Vec<u8>, Transmit<T>),
56    /// A Channel data message has been produced from the buffered data.
57    StoredChannel(Vec<u8>, Transmit<T>),
58}
59
60impl<T: AsRef<[u8]> + core::fmt::Debug> IncomingTcp<T> {
61    /// The byte slice for this incoming, or stored message, or channel.
62    pub fn data(&self) -> &[u8] {
63        match self {
64            Self::CompleteMessage(transmit, range) => {
65                &transmit.data.as_ref()[range.start..range.end]
66            }
67            Self::CompleteChannel(transmit, range) => {
68                &transmit.data.as_ref()[range.start..range.end]
69            }
70            Self::StoredMessage(data, _transmit) => data,
71            Self::StoredChannel(data, _transmit) => data,
72        }
73    }
74
75    /// The [`Message`] contained in this incoming or stored data.
76    pub fn message(&self) -> Option<Message<'_>> {
77        if !matches!(
78            self,
79            Self::CompleteMessage(_, _) | Self::StoredMessage(_, _)
80        ) {
81            return None;
82        }
83        Message::from_bytes(self.data()).ok()
84    }
85
86    /// The [`ChannelData`] contained in this incoming or stored data.
87    pub fn channel(&self) -> Option<ChannelData<'_>> {
88        if !matches!(
89            self,
90            Self::CompleteChannel(_, _) | Self::StoredChannel(_, _)
91        ) {
92            return None;
93        }
94        ChannelData::parse(self.data()).ok()
95    }
96}
97
98impl<T: AsRef<[u8]> + core::fmt::Debug> AsRef<[u8]> for IncomingTcp<T> {
99    fn as_ref(&self) -> &[u8] {
100        self.data()
101    }
102}
103
/// An owned [`Message`] or [`ChannelData`] extracted from the buffered TCP stream.
#[derive(Debug)]
pub enum StoredTcp {
    /// The bytes of a complete STUN [`Message`].
    Message(Vec<u8>),
    /// The bytes of a complete [`ChannelData`] message, including its header.
    Channel(Vec<u8>),
}
112
113impl StoredTcp {
114    /// The byte slice for this stored data.
115    pub fn data(&self) -> &[u8] {
116        match self {
117            Self::Message(data) => data,
118            Self::Channel(data) => data,
119        }
120    }
121
122    fn into_incoming<T: AsRef<[u8]> + core::fmt::Debug>(
123        self,
124        transmit: Transmit<T>,
125    ) -> IncomingTcp<T> {
126        match self {
127            Self::Message(msg) => IncomingTcp::StoredMessage(msg, transmit),
128            Self::Channel(channel) => IncomingTcp::StoredChannel(channel, transmit),
129        }
130    }
131}
132
133impl AsRef<[u8]> for StoredTcp {
134    fn as_ref(&self) -> &[u8] {
135        self.data()
136    }
137}
138
/// Reassembles TURN traffic (STUN messages and ChannelData) from a TCP byte stream.
#[derive(Debug, Default)]
pub struct TurnTcpBuffer {
    /// Received bytes that have not yet been returned as a complete message or channel.
    tcp_buffer: Vec<u8>,
}
144
145impl TurnTcpBuffer {
146    /// Construct a new [`TurnTcpBuffer`].
147    pub fn new() -> Self {
148        Self { tcp_buffer: vec![] }
149    }
150
151    /// Provide incoming TCP data to parse.
152    ///
153    /// A return value of `None` indicates that the more data is required to provide a complete
154    /// STUN [`Message`] or a [`ChannelData`].
155    #[tracing::instrument(
156        level = "trace",
157        skip(self, transmit),
158        fields(
159            transmit.data_len = transmit.data.as_ref().len(),
160            from = ?transmit.from
161        )
162    )]
163    pub fn incoming_tcp<T: AsRef<[u8]> + core::fmt::Debug>(
164        &mut self,
165        transmit: Transmit<T>,
166    ) -> Option<IncomingTcp<T>> {
167        if self.tcp_buffer.is_empty() {
168            let data = transmit.data.as_ref();
169            trace!("Trying to parse incoming data as a complete message/channel");
170            let Ok(hdr) = MessageHeader::from_bytes(data) else {
171                let Ok(channel) = ChannelData::parse(data) else {
172                    self.tcp_buffer.extend_from_slice(data);
173                    return None;
174                };
175                let channel_len = 4 + channel.data().len();
176                debug!(
177                    channel.id = channel.id(),
178                    channel.len = channel_len - 4,
179                    "Incoming data contains a channel",
180                );
181                if channel_len < data.len() {
182                    self.tcp_buffer.extend_from_slice(&data[channel_len..]);
183                }
184                return Some(IncomingTcp::CompleteChannel(transmit, 0..channel_len));
185            };
186            let msg_len = MessageHeader::LENGTH + hdr.data_length() as usize;
187            debug!(
188                msg.transaction = %hdr.transaction_id(),
189                msg.len = msg_len,
190                "Incoming data contains a message",
191            );
192            if data.len() < msg_len {
193                self.tcp_buffer.extend_from_slice(data);
194                return None;
195            }
196            if msg_len < data.len() {
197                self.tcp_buffer.extend_from_slice(&data[msg_len..]);
198            }
199            return Some(IncomingTcp::CompleteMessage(transmit, 0..msg_len));
200        }
201
202        self.tcp_buffer.extend_from_slice(transmit.data.as_ref());
203        self.poll_recv().map(|recv| recv.into_incoming(transmit))
204    }
205
206    /// Return the next complete message (if any).
207    #[tracing::instrument(
208        level = "trace",
209        skip(self),
210        fields(
211            buffered_len = self.tcp_buffer.len(),
212        )
213    )]
214    pub fn poll_recv(&mut self) -> Option<StoredTcp> {
215        let Ok(hdr) = MessageHeader::from_bytes(&self.tcp_buffer) else {
216            let Ok((id, channel_data_len)) = ChannelData::parse_header(&self.tcp_buffer) else {
217                trace!(
218                    buffered.len = self.tcp_buffer.len(),
219                    "cannot parse stored data"
220                );
221                return None;
222            };
223            let channel_len = 4 + channel_data_len;
224            if self.tcp_buffer.len() < channel_len {
225                trace!(
226                    buffered.len = self.tcp_buffer.len(),
227                    required = channel_len,
228                    "need more bytes to complete channel data"
229                );
230                return None;
231            }
232            let (data, remaining) = self.tcp_buffer.split_at(channel_len);
233            let data_binding = data.to_vec();
234            debug!(
235                channel.id = id,
236                channel.len = channel_data_len,
237                remaining = remaining.len(),
238                "buffered data contains a channel",
239            );
240            self.tcp_buffer = remaining.to_vec();
241            return Some(StoredTcp::Channel(data_binding));
242        };
243        let msg_len = MessageHeader::LENGTH + hdr.data_length() as usize;
244        if self.tcp_buffer.len() < msg_len {
245            trace!(
246                buffered.len = self.tcp_buffer.len(),
247                required = msg_len,
248                "need more bytes to complete STUN message"
249            );
250            return None;
251        }
252        let (data, remaining) = self.tcp_buffer.split_at(msg_len);
253        let data_binding = data.to_vec();
254        debug!(
255            msg.transaction = %hdr.transaction_id(),
256            msg.len = msg_len,
257            remaining = remaining.len(),
258            "stored data contains a message",
259        );
260        self.tcp_buffer = remaining.to_vec();
261        Some(StoredTcp::Message(data_binding))
262    }
263}
264
#[cfg(test)]
mod tests {
    use core::net::SocketAddr;

    use stun_types::{
        attribute::Software,
        message::{Message, MessageWriteVec},
        prelude::{MessageWrite, MessageWriteExt},
        TransportType,
    };
    use tracing::info;

    use crate::message::ALLOCATE;

    use super::*;

    /// The `(local, remote)` address pair shared by every test.
    fn generate_addresses() -> (SocketAddr, SocketAddr) {
        let local: SocketAddr = "192.168.0.1:1000".parse().unwrap();
        let remote: SocketAddr = "10.0.0.2:2000".parse().unwrap();
        (local, remote)
    }

    /// Wrap `data` in a TCP [`Transmit`] sent from the remote address to the local address.
    fn tcp_transmit<T: AsRef<[u8]> + core::fmt::Debug>(data: T) -> Transmit<T> {
        let (local_addr, remote_addr) = generate_addresses();
        Transmit::new(data, TransportType::Tcp, remote_addr, local_addr)
    }

    /// A serialized ALLOCATE request with a SOFTWARE attribute and a FINGERPRINT.
    fn generate_message() -> Vec<u8> {
        let mut builder = Message::builder_request(ALLOCATE, MessageWriteVec::new());
        let software = Software::new("turn-types").unwrap();
        builder.add_attribute(&software).unwrap();
        builder.add_fingerprint().unwrap();
        builder.finish()
    }

    /// The output of [`generate_message`] wrapped in a ChannelData on channel 0x4000.
    fn generate_message_in_channel() -> Vec<u8> {
        let msg = generate_message();
        let mut out = vec![0; 4 + msg.len()];
        ChannelData::new(0x4000, &msg).write_into_unchecked(&mut out);
        out
    }

    #[test]
    fn test_incoming_tcp_complete_message() {
        let _init = crate::tests::test_init_log();
        let msg = generate_message();
        let mut tcp = TurnTcpBuffer::new();
        let ret = tcp.incoming_tcp(tcp_transmit(msg.clone())).unwrap();
        assert!(matches!(ret, IncomingTcp::CompleteMessage(_, _)));
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
    }

    #[test]
    fn test_incoming_tcp_complete_message_in_channel() {
        let _init = crate::tests::test_init_log();
        let msg = generate_message_in_channel();
        let mut tcp = TurnTcpBuffer::new();
        let ret = tcp.incoming_tcp(tcp_transmit(msg.clone())).unwrap();
        assert!(matches!(ret, IncomingTcp::CompleteChannel(_, _)));
        assert_eq!(ret.data(), &msg);
        assert!(ret.channel().is_some());
    }

    #[test]
    fn test_incoming_tcp_partial_message() {
        let _init = crate::tests::test_init_log();
        let msg = generate_message();
        info!("message: {msg:x?}");
        let mut tcp = TurnTcpBuffer::new();
        // Feed every byte but the last one individually; none may complete the message.
        let (head, tail) = msg.split_at(msg.len() - 1);
        for byte in head.chunks(1) {
            assert!(tcp.incoming_tcp(tcp_transmit(byte)).is_none());
        }
        let ret = tcp.incoming_tcp(tcp_transmit(tail)).unwrap();
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
        let IncomingTcp::StoredMessage(produced, _) = ret else {
            unreachable!();
        };
        assert_eq!(produced, msg);
    }

    #[test]
    fn test_incoming_tcp_partial_channel() {
        let _init = crate::tests::test_init_log();
        let channel = generate_message_in_channel();
        info!("message: {channel:x?}");
        let mut tcp = TurnTcpBuffer::new();
        // Feed every byte but the last one individually; none may complete the channel.
        let (head, tail) = channel.split_at(channel.len() - 1);
        for byte in head.chunks(1) {
            assert!(tcp.incoming_tcp(tcp_transmit(byte)).is_none());
        }
        let ret = tcp.incoming_tcp(tcp_transmit(tail)).unwrap();
        assert_eq!(ret.data(), &channel);
        assert!(ret.channel().is_some());
        let IncomingTcp::StoredChannel(produced, _) = ret else {
            unreachable!()
        };
        assert_eq!(produced, channel);
    }

    #[test]
    fn test_incoming_tcp_message_and_channel() {
        let _init = crate::tests::test_init_log();
        let msg = generate_message();
        let channel = generate_message_in_channel();
        let input = [msg.as_slice(), channel.as_slice()].concat();
        let mut tcp = TurnTcpBuffer::new();
        let ret = tcp.incoming_tcp(tcp_transmit(input.clone())).unwrap();
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
        let IncomingTcp::CompleteMessage(transmit, msg_range) = ret else {
            unreachable!();
        };
        assert_eq!(msg_range, 0..msg.len());
        assert_eq!(transmit.data, input);
        let stored = tcp.poll_recv().unwrap();
        assert_eq!(stored.data(), &channel);
        let StoredTcp::Channel(produced) = stored else {
            unreachable!()
        };
        assert_eq!(produced, channel);
    }

    #[test]
    fn test_incoming_tcp_channel_and_message() {
        let _init = crate::tests::test_init_log();
        let msg = generate_message();
        let channel = generate_message_in_channel();
        let input = [channel.as_slice(), msg.as_slice()].concat();
        let mut tcp = TurnTcpBuffer::new();
        let ret = tcp.incoming_tcp(tcp_transmit(input.clone())).unwrap();
        assert_eq!(ret.data(), &channel);
        assert!(ret.channel().is_some());
        let IncomingTcp::CompleteChannel(transmit, channel_range) = ret else {
            unreachable!()
        };
        assert_eq!(channel_range, 0..channel.len());
        assert_eq!(transmit.data, input);
        let stored = tcp.poll_recv().unwrap();
        assert_eq!(stored.data(), &msg);
        let StoredTcp::Message(produced) = stored else {
            unreachable!()
        };
        assert_eq!(produced, msg);
    }
}