turn_types/
tcp.rs

1// Copyright (C) 2025 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! Handle TURN over TCP.
10//!
11//! A TURN connection between a client and a server can have two types of data:
12//! - STUN [`Message`]s, and
13//! - [`ChannelData`]s
14//!
//! Unlike UDP, where each datagram inherently delimits a single message, TCP is a
//! stream-based protocol and the size of a message must be inferred from the contained data. This
//! module performs the relevant buffering of incoming data over a TCP connection and produces
//! [`Message`]s or [`ChannelData`] as they are completely received.
19
20use std::ops::Range;
21
22use stun_proto::agent::Transmit;
23use stun_types::message::{Message, MessageHeader};
24use tracing::{debug, trace};
25
26use crate::channel::ChannelData;
27
28/// Reply to [`TurnTcpBuffer::incoming_tcp()`]
29#[derive(Debug)]
30pub enum IncomingTcp<T: AsRef<[u8]> + std::fmt::Debug> {
31    /// Input data (with the provided range) contains a complete STUN Message.
32    ///
33    /// Any extra data after the range is stored for later processing.
34    CompleteMessage(Transmit<T>, Range<usize>),
35    /// Input data (with the provided range) contains a complete Channel data message.
36    ///
37    /// Any extra data after the range is stored for later processing.
38    CompleteChannel(Transmit<T>, Range<usize>),
39    /// A STUN message has been produced from the buffered data.
40    StoredMessage(Vec<u8>, Transmit<T>),
41    /// A Channel data message has been produced from the buffered data.
42    StoredChannel(Vec<u8>, Transmit<T>),
43}
44
45impl<T: AsRef<[u8]> + std::fmt::Debug> IncomingTcp<T> {
46    /// The byte slice for this incoming or stored data.
47    pub fn data(&self) -> &[u8] {
48        match self {
49            Self::CompleteMessage(transmit, range) => {
50                &transmit.data.as_ref()[range.start..range.end]
51            }
52            Self::CompleteChannel(transmit, range) => {
53                &transmit.data.as_ref()[range.start..range.end]
54            }
55            Self::StoredMessage(data, _transmit) => data,
56            Self::StoredChannel(data, _transmit) => data,
57        }
58    }
59
60    /// The [`Message`] contained in this incoming or stored data.
61    pub fn message(&self) -> Option<Message<'_>> {
62        if !matches!(
63            self,
64            Self::CompleteMessage(_, _) | Self::StoredMessage(_, _)
65        ) {
66            return None;
67        }
68        Message::from_bytes(self.data()).ok()
69    }
70
71    /// The [`ChannelData`] contained in this incoming or stored data.
72    pub fn channel(&self) -> Option<ChannelData<'_>> {
73        if !matches!(
74            self,
75            Self::CompleteChannel(_, _) | Self::StoredChannel(_, _)
76        ) {
77            return None;
78        }
79        ChannelData::parse(self.data()).ok()
80    }
81}
82
83impl<T: AsRef<[u8]> + std::fmt::Debug> AsRef<[u8]> for IncomingTcp<T> {
84    fn as_ref(&self) -> &[u8] {
85        self.data()
86    }
87}
88
/// A complete [`Message`] or [`ChannelData`] that was assembled from buffered TCP data.
///
/// The contained bytes are exactly one message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StoredTcp {
    /// The bytes of a complete STUN [`Message`].
    Message(Vec<u8>),
    /// The bytes of a complete [`ChannelData`].
    Channel(Vec<u8>),
}
97
98impl StoredTcp {
99    /// The byte slice for this stored data.
100    pub fn data(&self) -> &[u8] {
101        match self {
102            Self::Message(data) => data,
103            Self::Channel(data) => data,
104        }
105    }
106
107    fn into_incoming<T: AsRef<[u8]> + std::fmt::Debug>(
108        self,
109        transmit: Transmit<T>,
110    ) -> IncomingTcp<T> {
111        match self {
112            Self::Message(msg) => IncomingTcp::StoredMessage(msg, transmit),
113            Self::Channel(channel) => IncomingTcp::StoredChannel(channel, transmit),
114        }
115    }
116}
117
118impl AsRef<[u8]> for StoredTcp {
119    fn as_ref(&self) -> &[u8] {
120        self.data()
121    }
122}
123
/// Reassembles TURN traffic (STUN messages and ChannelData) from a TCP byte stream.
///
/// Bytes that do not yet form a complete message are kept until more data arrives.
#[derive(Default, Debug)]
pub struct TurnTcpBuffer {
    /// Received bytes that have not yet been returned as part of a complete message.
    tcp_buffer: Vec<u8>,
}
129
130impl TurnTcpBuffer {
131    /// Construct a new [`TurnTcpBuffer`].
132    pub fn new() -> Self {
133        Self { tcp_buffer: vec![] }
134    }
135
136    /// Provide incoming TCP data to parse.
137    ///
138    /// A return value of `None` indicates that the more data is required to provide a complete
139    /// STUN [`Message`] or a [`ChannelData`].
140    #[tracing::instrument(
141        level = "trace",
142        skip(self, transmit),
143        fields(
144            transmit.data_len = transmit.data.as_ref().len(),
145            from = ?transmit.from
146        )
147    )]
148    pub fn incoming_tcp<T: AsRef<[u8]> + std::fmt::Debug>(
149        &mut self,
150        transmit: Transmit<T>,
151    ) -> Option<IncomingTcp<T>> {
152        if self.tcp_buffer.is_empty() {
153            let data = transmit.data.as_ref();
154            trace!("Trying to parse incoming data as a complete message/channel");
155            let Ok(hdr) = MessageHeader::from_bytes(data) else {
156                let Ok(channel) = ChannelData::parse(data) else {
157                    self.tcp_buffer.extend_from_slice(data);
158                    return None;
159                };
160                let channel_len = 4 + channel.data().len();
161                debug!(
162                    channel.id = channel.id(),
163                    channel.len = channel_len - 4,
164                    "Incoming data contains a channel",
165                );
166                if channel_len < data.len() {
167                    self.tcp_buffer.extend_from_slice(&data[channel_len..]);
168                }
169                return Some(IncomingTcp::CompleteChannel(transmit, 0..channel_len));
170            };
171            let msg_len = MessageHeader::LENGTH + hdr.data_length() as usize;
172            debug!(
173                msg.transaction = %hdr.transaction_id(),
174                msg.len = msg_len,
175                "Incoming data contains a message",
176            );
177            if data.len() < msg_len {
178                self.tcp_buffer.extend_from_slice(data);
179                return None;
180            }
181            if msg_len < data.len() {
182                self.tcp_buffer.extend_from_slice(&data[msg_len..]);
183            }
184            return Some(IncomingTcp::CompleteMessage(transmit, 0..msg_len));
185        }
186
187        self.tcp_buffer.extend_from_slice(transmit.data.as_ref());
188        self.poll_recv().map(|recv| recv.into_incoming(transmit))
189    }
190
191    /// Return the next complete message (if any).
192    #[tracing::instrument(
193        level = "trace",
194        skip(self),
195        fields(
196            buffered_len = self.tcp_buffer.len(),
197        )
198    )]
199    pub fn poll_recv(&mut self) -> Option<StoredTcp> {
200        let Ok(hdr) = MessageHeader::from_bytes(&self.tcp_buffer) else {
201            let Ok((id, channel_data_len)) = ChannelData::parse_header(&self.tcp_buffer) else {
202                trace!(
203                    buffered.len = self.tcp_buffer.len(),
204                    "cannot parse stored data"
205                );
206                return None;
207            };
208            let channel_len = 4 + channel_data_len;
209            if self.tcp_buffer.len() < channel_len {
210                trace!(
211                    buffered.len = self.tcp_buffer.len(),
212                    required = channel_len,
213                    "need more bytes to complete channel data"
214                );
215                return None;
216            }
217            let (data, remaining) = self.tcp_buffer.split_at(channel_len);
218            let data_binding = data.to_vec();
219            debug!(
220                channel.id = id,
221                channel.len = channel_data_len,
222                remaining = remaining.len(),
223                "buffered data contains a channel",
224            );
225            self.tcp_buffer = remaining.to_vec();
226            return Some(StoredTcp::Channel(data_binding));
227        };
228        let msg_len = MessageHeader::LENGTH + hdr.data_length() as usize;
229        if self.tcp_buffer.len() < msg_len {
230            trace!(
231                buffered.len = self.tcp_buffer.len(),
232                required = msg_len,
233                "need more bytes to complete STUN message"
234            );
235            return None;
236        }
237        let (data, remaining) = self.tcp_buffer.split_at(msg_len);
238        let data_binding = data.to_vec();
239        debug!(
240            msg.transaction = %hdr.transaction_id(),
241            msg.len = msg_len,
242            remaining = remaining.len(),
243            "stored data contains a message",
244        );
245        self.tcp_buffer = remaining.to_vec();
246        Some(StoredTcp::Message(data_binding))
247    }
248}
249
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use stun_types::{
        attribute::Software,
        message::{Message, MessageWriteVec},
        prelude::{MessageWrite, MessageWriteExt},
        TransportType,
    };
    use tracing::info;

    use crate::message::ALLOCATE;

    use super::*;

    /// The (local, remote) addresses used by every test transmit.
    fn generate_addresses() -> (SocketAddr, SocketAddr) {
        let local = "192.168.0.1:1000".parse().unwrap();
        let remote = "10.0.0.2:2000".parse().unwrap();
        (local, remote)
    }

    /// Wrap `data` in a TCP [`Transmit`] sent from the remote address to the local address.
    fn tcp_transmit<T: AsRef<[u8]> + std::fmt::Debug>(data: T) -> Transmit<T> {
        let (local_addr, remote_addr) = generate_addresses();
        Transmit::new(data, TransportType::Tcp, remote_addr, local_addr)
    }

    /// A serialized ALLOCATE request with a SOFTWARE attribute and a FINGERPRINT.
    fn generate_message() -> Vec<u8> {
        let software = Software::new("turn-types").unwrap();
        let mut builder = Message::builder_request(ALLOCATE, MessageWriteVec::new());
        builder.add_attribute(&software).unwrap();
        builder.add_fingerprint().unwrap();
        builder.finish()
    }

    /// [`generate_message()`] wrapped in a ChannelData on channel 0x4000.
    fn generate_message_in_channel() -> Vec<u8> {
        let payload = generate_message();
        let mut framed = vec![0; 4 + payload.len()];
        ChannelData::new(0x4000, &payload).write_into_unchecked(&mut framed);
        framed
    }

    #[test]
    fn test_incoming_tcp_complete_message() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let ret = tcp.incoming_tcp(tcp_transmit(msg.clone())).unwrap();
        assert!(matches!(ret, IncomingTcp::CompleteMessage(_, _)));
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
    }

    #[test]
    fn test_incoming_tcp_complete_message_in_channel() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message_in_channel();
        let ret = tcp.incoming_tcp(tcp_transmit(msg.clone())).unwrap();
        assert!(matches!(ret, IncomingTcp::CompleteChannel(_, _)));
        assert_eq!(ret.data(), &msg);
        assert!(ret.channel().is_some());
    }

    #[test]
    fn test_incoming_tcp_partial_message() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        info!("message: {msg:x?}");
        // Feeding every byte except the last, one at a time, must never complete the message.
        let (head, last) = msg.split_at(msg.len() - 1);
        for byte in head.chunks(1) {
            assert!(tcp.incoming_tcp(tcp_transmit(byte)).is_none());
        }
        let ret = tcp.incoming_tcp(tcp_transmit(last)).unwrap();
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
        let IncomingTcp::StoredMessage(produced, _) = ret else {
            unreachable!();
        };
        assert_eq!(produced, msg);
    }

    #[test]
    fn test_incoming_tcp_partial_channel() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        let channel = generate_message_in_channel();
        info!("message: {channel:x?}");
        // Feeding every byte except the last, one at a time, must never complete the channel.
        let (head, last) = channel.split_at(channel.len() - 1);
        for byte in head.chunks(1) {
            assert!(tcp.incoming_tcp(tcp_transmit(byte)).is_none());
        }
        let ret = tcp.incoming_tcp(tcp_transmit(last)).unwrap();
        assert_eq!(ret.data(), &channel);
        assert!(ret.channel().is_some());
        let IncomingTcp::StoredChannel(produced, _) = ret else {
            unreachable!()
        };
        assert_eq!(produced, channel);
    }

    #[test]
    fn test_incoming_tcp_message_and_channel() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let channel = generate_message_in_channel();
        let input = [msg.as_slice(), channel.as_slice()].concat();
        let ret = tcp.incoming_tcp(tcp_transmit(input.clone())).unwrap();
        assert_eq!(ret.data(), &msg);
        assert!(ret.message().is_some());
        let IncomingTcp::CompleteMessage(transmit, msg_range) = ret else {
            unreachable!();
        };
        assert_eq!(msg_range, 0..msg.len());
        assert_eq!(transmit.data, input);
        // The trailing channel was buffered and is available without further input.
        let ret = tcp.poll_recv().unwrap();
        assert_eq!(ret.data(), &channel);
        let StoredTcp::Channel(produced) = ret else {
            unreachable!()
        };
        assert_eq!(produced, channel);
    }

    #[test]
    fn test_incoming_tcp_channel_and_message() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let channel = generate_message_in_channel();
        let input = [channel.as_slice(), msg.as_slice()].concat();
        let ret = tcp.incoming_tcp(tcp_transmit(input.clone())).unwrap();
        assert_eq!(ret.data(), &channel);
        assert!(ret.channel().is_some());
        let IncomingTcp::CompleteChannel(transmit, channel_range) = ret else {
            unreachable!()
        };
        assert_eq!(channel_range, 0..channel.len());
        assert_eq!(transmit.data, input);
        // The trailing message was buffered and is available without further input.
        let ret = tcp.poll_recv().unwrap();
        assert_eq!(ret.data(), &msg);
        let StoredTcp::Message(produced) = ret else {
            unreachable!()
        };
        assert_eq!(produced, msg);
    }
}