shape_wire/transport/
tcp.rs1use super::{Connection, Transport, TransportError};
7use std::io::{Read, Write};
8use std::net::TcpStream;
9use std::time::Duration;
10
/// Size limit in bytes (64 MiB) for this transport's frames.
///
/// The `send` methods reject larger payloads up front, and `read_length_prefixed` rejects
/// any incoming frame whose length prefix announces more than this.
pub const MAX_PAYLOAD_SIZE: usize = 64 << 20;
13
/// Blocking TCP transport exchanging length-prefixed frames.
///
/// `send` opens a new connection for every request/reply exchange, while `connect`
/// returns a persistent `TcpConnection`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TcpTransport {
    /// Upper bound on establishing the TCP connection.
    ///
    /// The one-shot `send` path also uses it as the socket write timeout.
    pub connect_timeout: Duration,
    /// Read timeout applied by the one-shot `send` path while waiting for the reply.
    ///
    /// `None` blocks indefinitely.
    pub read_timeout: Option<Duration>,
}
21
22impl Default for TcpTransport {
23 fn default() -> Self {
24 Self {
25 connect_timeout: Duration::from_secs(10),
26 read_timeout: Some(Duration::from_secs(30)),
27 }
28 }
29}
30
31impl Transport for TcpTransport {
32 fn send(&self, destination: &str, payload: &[u8]) -> Result<Vec<u8>, TransportError> {
33 if payload.len() > MAX_PAYLOAD_SIZE {
34 return Err(TransportError::PayloadTooLarge {
35 size: payload.len(),
36 max: MAX_PAYLOAD_SIZE,
37 });
38 }
39
40 let mut stream = TcpStream::connect_timeout(
41 &destination
42 .parse()
43 .map_err(|e| TransportError::ConnectionFailed(format!("{}", e)))?,
44 self.connect_timeout,
45 )
46 .map_err(|e| TransportError::ConnectionFailed(format!("{}: {}", destination, e)))?;
47
48 stream.set_read_timeout(self.read_timeout).ok();
49 stream.set_write_timeout(Some(self.connect_timeout)).ok();
50
51 write_length_prefixed(&mut stream, payload)?;
52 read_length_prefixed(&mut stream)
53 }
54
55 fn connect(&self, destination: &str) -> Result<Box<dyn Connection>, TransportError> {
56 let stream = TcpStream::connect_timeout(
57 &destination
58 .parse()
59 .map_err(|e| TransportError::ConnectionFailed(format!("{}", e)))?,
60 self.connect_timeout,
61 )
62 .map_err(|e| TransportError::ConnectionFailed(format!("{}: {}", destination, e)))?;
63
64 Ok(Box::new(TcpConnection {
65 stream,
66 max_payload: MAX_PAYLOAD_SIZE,
67 }))
68 }
69}
70
/// A persistent TCP connection that exchanges length-prefixed frames.
#[derive(Debug)]
pub struct TcpConnection {
    /// The connected socket that frames are written to and read from.
    stream: TcpStream,
    /// Largest payload, in bytes, that `send` accepts on this connection.
    ///
    /// Incoming frames are bounded separately by `MAX_PAYLOAD_SIZE` in `read_length_prefixed`.
    max_payload: usize,
}
76
77impl TcpConnection {
78 pub fn from_stream(stream: TcpStream) -> Self {
80 Self {
81 stream,
82 max_payload: MAX_PAYLOAD_SIZE,
83 }
84 }
85}
86
87impl Connection for TcpConnection {
88 fn send(&mut self, payload: &[u8]) -> Result<(), TransportError> {
89 if payload.len() > self.max_payload {
90 return Err(TransportError::PayloadTooLarge {
91 size: payload.len(),
92 max: self.max_payload,
93 });
94 }
95 write_length_prefixed(&mut self.stream, payload)
96 }
97
98 fn recv(&mut self, timeout: Option<Duration>) -> Result<Vec<u8>, TransportError> {
99 self.stream
100 .set_read_timeout(timeout)
101 .map_err(TransportError::Io)?;
102 read_length_prefixed(&mut self.stream)
103 }
104
105 fn close(&mut self) -> Result<(), TransportError> {
106 self.stream
107 .shutdown(std::net::Shutdown::Both)
108 .map_err(TransportError::Io)
109 }
110}
111
112pub fn write_length_prefixed(stream: &mut TcpStream, data: &[u8]) -> Result<(), TransportError> {
118 let framed = super::framing::encode_framed(data);
119 let len = framed.len() as u32;
120 stream
121 .write_all(&len.to_be_bytes())
122 .map_err(|e| TransportError::SendFailed(format!("write frame length: {}", e)))?;
123 stream
124 .write_all(&framed)
125 .map_err(|e| TransportError::SendFailed(format!("write frame payload: {}", e)))?;
126 stream
127 .flush()
128 .map_err(|e| TransportError::SendFailed(format!("flush: {}", e)))?;
129 Ok(())
130}
131
132pub fn read_length_prefixed(stream: &mut TcpStream) -> Result<Vec<u8>, TransportError> {
134 let mut len_buf = [0u8; 4];
135 stream
136 .read_exact(&mut len_buf)
137 .map_err(|e| TransportError::ReceiveFailed(format!("read frame length: {}", e)))?;
138 let len = u32::from_be_bytes(len_buf) as usize;
139
140 if len > MAX_PAYLOAD_SIZE {
141 return Err(TransportError::PayloadTooLarge {
142 size: len,
143 max: MAX_PAYLOAD_SIZE,
144 });
145 }
146
147 let mut buf = vec![0u8; len];
148 stream
149 .read_exact(&mut buf)
150 .map_err(|e| TransportError::ReceiveFailed(format!("read frame payload: {}", e)))?;
151 super::framing::decode_framed(&buf)
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{SocketAddr, TcpListener};
    use std::thread::JoinHandle;

    /// Binds a loopback listener on an ephemeral port and spawns a thread that accepts one
    /// connection, reads one frame, and writes back `respond(frame)`.
    ///
    /// Returns the listener address and the server thread handle. Joining the handle
    /// surfaces any panic raised on the server side.
    fn spawn_one_shot_server<F>(respond: F) -> (SocketAddr, JoinHandle<()>)
    where
        F: FnOnce(Vec<u8>) -> Vec<u8> + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = std::thread::spawn(move || {
            let (mut peer, _) = listener.accept().unwrap();
            peer.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
            let request = read_length_prefixed(&mut peer).unwrap();
            write_length_prefixed(&mut peer, &respond(request)).unwrap();
        });
        (addr, handle)
    }

    #[test]
    fn test_length_prefixed_roundtrip() {
        let payload = b"hello transport";
        let (addr, server) = spawn_one_shot_server(|request| request);

        let mut client = TcpStream::connect(addr).unwrap();
        client
            .set_read_timeout(Some(Duration::from_secs(5)))
            .unwrap();

        write_length_prefixed(&mut client, payload).unwrap();
        let echoed = read_length_prefixed(&mut client).unwrap();

        assert_eq!(&echoed, payload);
        server.join().unwrap();
    }

    #[test]
    fn test_tcp_transport_one_shot() {
        let (addr, server) = spawn_one_shot_server(|request| {
            let mut reply = b"reply:".to_vec();
            reply.extend_from_slice(&request);
            reply
        });

        let reply = TcpTransport::default()
            .send(&addr.to_string(), b"ping")
            .unwrap();

        assert_eq!(&reply, b"reply:ping");
        server.join().unwrap();
    }

    #[test]
    fn test_tcp_connection_send_recv() {
        let (addr, server) = spawn_one_shot_server(|request| request);

        let transport = TcpTransport::default();
        let mut conn = transport.connect(&addr.to_string()).unwrap();

        conn.send(b"test data").unwrap();
        let echoed = conn.recv(Some(Duration::from_secs(5))).unwrap();
        assert_eq!(&echoed, b"test data");

        conn.close().unwrap();
        server.join().unwrap();
    }

    #[test]
    fn test_payload_too_large() {
        let oversized = vec![0u8; MAX_PAYLOAD_SIZE + 1];
        // The size check runs before any connection attempt, so port 1 is never contacted.
        let outcome = TcpTransport::default().send("127.0.0.1:1", &oversized);
        assert!(matches!(
            outcome,
            Err(TransportError::PayloadTooLarge { .. })
        ));
    }

    #[test]
    fn test_connection_refused() {
        let impatient = TcpTransport {
            connect_timeout: Duration::from_millis(100),
            ..Default::default()
        };
        assert!(impatient.connect("127.0.0.1:1").is_err());
    }
}