Skip to main content

shape_wire/transport/
tcp.rs

//! TCP transport implementation with length-prefixed framing.
//!
//! Each frame is a 4-byte big-endian length header followed by the payload as encoded by
//! `super::framing` (`encode_framed` on send, `decode_framed` on receive).
//! Frames are limited to 64 MB ([`MAX_PAYLOAD_SIZE`]) to prevent accidental memory exhaustion.
5
6use super::{Connection, Transport, TransportError};
7use std::io::{Read, Write};
8use std::net::TcpStream;
9use std::time::Duration;
10
/// Upper bound, in bytes, on payloads and frames handled by this transport (64 MiB).
///
/// Checked before sending and against the length header of every received frame.
pub const MAX_PAYLOAD_SIZE: usize = 64 << 20;
13
/// TCP-based transport using length-prefixed framing.
///
/// Configure it through the public fields, or start from [`Default::default`] and override
/// individual fields with struct-update syntax.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TcpTransport {
    /// Timeout for establishing new connections.
    ///
    /// [`Transport::send`] also uses this value as the write timeout on the connection it opens.
    pub connect_timeout: Duration,
    /// Read timeout applied to connections created by [`Transport::send`].
    ///
    /// `None` means reads block indefinitely.
    pub read_timeout: Option<Duration>,
}
21
22impl Default for TcpTransport {
23    fn default() -> Self {
24        Self {
25            connect_timeout: Duration::from_secs(10),
26            read_timeout: Some(Duration::from_secs(30)),
27        }
28    }
29}
30
31impl Transport for TcpTransport {
32    fn send(&self, destination: &str, payload: &[u8]) -> Result<Vec<u8>, TransportError> {
33        if payload.len() > MAX_PAYLOAD_SIZE {
34            return Err(TransportError::PayloadTooLarge {
35                size: payload.len(),
36                max: MAX_PAYLOAD_SIZE,
37            });
38        }
39
40        let mut stream = TcpStream::connect_timeout(
41            &destination
42                .parse()
43                .map_err(|e| TransportError::ConnectionFailed(format!("{}", e)))?,
44            self.connect_timeout,
45        )
46        .map_err(|e| TransportError::ConnectionFailed(format!("{}: {}", destination, e)))?;
47
48        stream.set_read_timeout(self.read_timeout).ok();
49        stream.set_write_timeout(Some(self.connect_timeout)).ok();
50
51        write_length_prefixed(&mut stream, payload)?;
52        read_length_prefixed(&mut stream)
53    }
54
55    fn connect(&self, destination: &str) -> Result<Box<dyn Connection>, TransportError> {
56        let stream = TcpStream::connect_timeout(
57            &destination
58                .parse()
59                .map_err(|e| TransportError::ConnectionFailed(format!("{}", e)))?,
60            self.connect_timeout,
61        )
62        .map_err(|e| TransportError::ConnectionFailed(format!("{}: {}", destination, e)))?;
63
64        Ok(Box::new(TcpConnection {
65            stream,
66            max_payload: MAX_PAYLOAD_SIZE,
67        }))
68    }
69}
70
/// A persistent TCP connection with length-prefixed framing.
///
/// Obtained from [`Transport::connect`] on a [`TcpTransport`], or by wrapping an existing
/// stream with [`TcpConnection::from_stream`].
#[derive(Debug)]
pub struct TcpConnection {
    /// Underlying socket that frames are written to and read from.
    stream: TcpStream,
    /// Largest payload, in bytes, that `send` accepts before returning `PayloadTooLarge`.
    max_payload: usize,
}
76
77impl TcpConnection {
78    /// Wrap an already-connected `TcpStream` into a `TcpConnection`.
79    pub fn from_stream(stream: TcpStream) -> Self {
80        Self {
81            stream,
82            max_payload: MAX_PAYLOAD_SIZE,
83        }
84    }
85}
86
87impl Connection for TcpConnection {
88    fn send(&mut self, payload: &[u8]) -> Result<(), TransportError> {
89        if payload.len() > self.max_payload {
90            return Err(TransportError::PayloadTooLarge {
91                size: payload.len(),
92                max: self.max_payload,
93            });
94        }
95        write_length_prefixed(&mut self.stream, payload)
96    }
97
98    fn recv(&mut self, timeout: Option<Duration>) -> Result<Vec<u8>, TransportError> {
99        self.stream
100            .set_read_timeout(timeout)
101            .map_err(TransportError::Io)?;
102        read_length_prefixed(&mut self.stream)
103    }
104
105    fn close(&mut self) -> Result<(), TransportError> {
106        self.stream
107            .shutdown(std::net::Shutdown::Both)
108            .map_err(TransportError::Io)
109    }
110}
111
112// ---------------------------------------------------------------------------
113// Length-prefixed framing helpers
114// ---------------------------------------------------------------------------
115
116/// Write a length-prefixed frame: compress, then write 4-byte BE length + framed payload.
117pub fn write_length_prefixed(stream: &mut TcpStream, data: &[u8]) -> Result<(), TransportError> {
118    let framed = super::framing::encode_framed(data);
119    let len = framed.len() as u32;
120    stream
121        .write_all(&len.to_be_bytes())
122        .map_err(|e| TransportError::SendFailed(format!("write frame length: {}", e)))?;
123    stream
124        .write_all(&framed)
125        .map_err(|e| TransportError::SendFailed(format!("write frame payload: {}", e)))?;
126    stream
127        .flush()
128        .map_err(|e| TransportError::SendFailed(format!("flush: {}", e)))?;
129    Ok(())
130}
131
132/// Read a length-prefixed frame: read raw bytes, then decompress.
133pub fn read_length_prefixed(stream: &mut TcpStream) -> Result<Vec<u8>, TransportError> {
134    let mut len_buf = [0u8; 4];
135    stream
136        .read_exact(&mut len_buf)
137        .map_err(|e| TransportError::ReceiveFailed(format!("read frame length: {}", e)))?;
138    let len = u32::from_be_bytes(len_buf) as usize;
139
140    if len > MAX_PAYLOAD_SIZE {
141        return Err(TransportError::PayloadTooLarge {
142            size: len,
143            max: MAX_PAYLOAD_SIZE,
144        });
145    }
146
147    let mut buf = vec![0u8; len];
148    stream
149        .read_exact(&mut buf)
150        .map_err(|e| TransportError::ReceiveFailed(format!("read frame payload: {}", e)))?;
151    super::framing::decode_framed(&buf)
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{SocketAddr, TcpListener};
    use std::thread::{self, JoinHandle};

    /// Read timeout for test sockets, so a misbehaving peer fails the test instead of
    /// hanging it.
    const IO_TIMEOUT: Duration = Duration::from_secs(5);

    /// Start a one-connection loopback server on an ephemeral port.
    ///
    /// The server thread accepts a single client, reads one frame, and writes back
    /// `respond(frame)`. Returns the address to connect to and the handle to join.
    fn serve_one_frame<F>(respond: F) -> (SocketAddr, JoinHandle<()>)
    where
        F: FnOnce(Vec<u8>) -> Vec<u8> + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = thread::spawn(move || {
            let (mut peer, _) = listener.accept().unwrap();
            peer.set_read_timeout(Some(IO_TIMEOUT)).unwrap();
            let request = read_length_prefixed(&mut peer).unwrap();
            let reply = respond(request);
            write_length_prefixed(&mut peer, &reply).unwrap();
        });
        (addr, handle)
    }

    #[test]
    fn test_length_prefixed_roundtrip() {
        let (addr, server) = serve_one_frame(|frame| frame);
        let payload = b"hello transport";

        let mut client = TcpStream::connect(addr).unwrap();
        client.set_read_timeout(Some(IO_TIMEOUT)).unwrap();
        write_length_prefixed(&mut client, payload).unwrap();
        let echoed = read_length_prefixed(&mut client).unwrap();

        assert_eq!(&echoed, payload);
        server.join().unwrap();
    }

    #[test]
    fn test_tcp_transport_one_shot() {
        let (addr, server) = serve_one_frame(|frame| {
            let mut reply = b"reply:".to_vec();
            reply.extend_from_slice(&frame);
            reply
        });

        let reply = TcpTransport::default()
            .send(&addr.to_string(), b"ping")
            .unwrap();
        assert_eq!(&reply, b"reply:ping");
        server.join().unwrap();
    }

    #[test]
    fn test_tcp_connection_send_recv() {
        let (addr, server) = serve_one_frame(|frame| frame);

        let mut conn = TcpTransport::default()
            .connect(&addr.to_string())
            .unwrap();
        conn.send(b"test data").unwrap();
        let echoed = conn.recv(Some(IO_TIMEOUT)).unwrap();
        assert_eq!(&echoed, b"test data");

        conn.close().unwrap();
        server.join().unwrap();
    }

    #[test]
    fn test_payload_too_large() {
        let oversized = vec![0u8; MAX_PAYLOAD_SIZE + 1];
        let outcome = TcpTransport::default().send("127.0.0.1:1", &oversized);
        assert!(matches!(
            outcome,
            Err(TransportError::PayloadTooLarge { .. })
        ));
    }

    #[test]
    fn test_connection_refused() {
        let impatient = TcpTransport {
            connect_timeout: Duration::from_millis(100),
            ..TcpTransport::default()
        };
        assert!(impatient.connect("127.0.0.1:1").is_err());
    }
}