1use crate::error::{LinkError, Result};
2use crate::protocol::Message;
3use crate::serialization::{BinarySerializer, BinaryFormat};
4use bytes::Bytes;
5
6#[cfg(feature = "async")]
7use async_trait::async_trait;
8
9pub trait Transport {
10 fn send(&mut self, message: &Message) -> Result<()>;
11 fn receive(&mut self) -> Result<Option<Message>>;
12 fn close(&mut self) -> Result<()>;
13 fn is_connected(&self) -> bool;
14}
15
/// Asynchronous counterpart of [`Transport`], usable from async runtimes.
///
/// The `Send + Sync` bound lets implementations be shared across tasks.
#[cfg(feature = "async")]
#[async_trait]
pub trait AsyncTransport: Send + Sync {
    /// Serializes and transmits one message.
    async fn send(&mut self, message: &Message) -> Result<()>;

    /// Waits for the next incoming message.
    ///
    /// `Ok(None)` means no message was produced. The exact meaning is
    /// implementation-specific.
    async fn receive(&mut self) -> Result<Option<Message>>;

    /// Shuts the transport down.
    async fn close(&mut self) -> Result<()>;

    /// Reports whether the transport still holds an open connection.
    fn is_connected(&self) -> bool;
}
24
25pub struct MemoryTransport {
26 serializer: BinarySerializer,
27 send_buffer: Vec<Bytes>,
28 receive_buffer: Vec<Bytes>,
29 connected: bool,
30}
31
32impl MemoryTransport {
33 pub fn new(format: BinaryFormat) -> Self {
34 Self {
35 serializer: BinarySerializer::new(format),
36 send_buffer: Vec::new(),
37 receive_buffer: Vec::new(),
38 connected: true,
39 }
40 }
41
42 pub fn create_pair(format: BinaryFormat) -> (Self, Self) {
43 let t1 = Self::new(format);
44 let t2 = Self::new(format);
45 (t1, t2)
46 }
47
48 pub fn connect_to(&mut self, other: &mut Self) {
49 std::mem::swap(&mut self.send_buffer, &mut other.receive_buffer);
50 std::mem::swap(&mut self.receive_buffer, &mut other.send_buffer);
51 }
52
53 pub fn get_send_buffer(&self) -> &[Bytes] {
54 &self.send_buffer
55 }
56
57 pub fn get_receive_buffer(&self) -> &[Bytes] {
58 &self.receive_buffer
59 }
60}
61
62impl Transport for MemoryTransport {
63 fn send(&mut self, message: &Message) -> Result<()> {
64 if !self.connected {
65 return Err(LinkError::ConnectionClosed);
66 }
67
68 let data = self.serializer.serialize_message(message)?;
69 self.send_buffer.push(data);
70 Ok(())
71 }
72
73 fn receive(&mut self) -> Result<Option<Message>> {
74 if !self.connected {
75 return Err(LinkError::ConnectionClosed);
76 }
77
78 if self.receive_buffer.is_empty() {
79 return Ok(None);
80 }
81
82 let data = self.receive_buffer.remove(0);
83 let message = self.serializer.deserialize_message(&data)?;
84 Ok(Some(message))
85 }
86
87 fn close(&mut self) -> Result<()> {
88 self.connected = false;
89 self.send_buffer.clear();
90 self.receive_buffer.clear();
91 Ok(())
92 }
93
94 fn is_connected(&self) -> bool {
95 self.connected
96 }
97}
98
99pub struct StdioTransport {
100 serializer: BinarySerializer,
101 connected: bool,
102}
103
104impl StdioTransport {
105 pub fn new(format: BinaryFormat) -> Self {
106 Self {
107 serializer: BinarySerializer::new(format),
108 connected: true,
109 }
110 }
111}
112
113impl Transport for StdioTransport {
114 fn send(&mut self, message: &Message) -> Result<()> {
115 if !self.connected {
116 return Err(LinkError::ConnectionClosed);
117 }
118
119 use std::io::Write;
120
121 let data = self.serializer.serialize_message(message)?;
122 let len = data.len() as u32;
123
124 let mut stdout = std::io::stdout();
125 stdout.write_all(&len.to_le_bytes())?;
126 stdout.write_all(&data)?;
127 stdout.flush()?;
128
129 Ok(())
130 }
131
132 fn receive(&mut self) -> Result<Option<Message>> {
133 if !self.connected {
134 return Err(LinkError::ConnectionClosed);
135 }
136
137 use std::io::Read;
138
139 let mut stdin = std::io::stdin();
140 let mut len_bytes = [0u8; 4];
141
142 match stdin.read_exact(&mut len_bytes) {
143 Ok(_) => {},
144 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
145 return Ok(None);
146 }
147 Err(e) => return Err(e.into()),
148 }
149
150 let len = u32::from_le_bytes(len_bytes) as usize;
151 let mut buffer = vec![0u8; len];
152
153 stdin.read_exact(&mut buffer)?;
154
155 let message = self.serializer.deserialize_message(&buffer)?;
156 Ok(Some(message))
157 }
158
159 fn close(&mut self) -> Result<()> {
160 self.connected = false;
161 Ok(())
162 }
163
164 fn is_connected(&self) -> bool {
165 self.connected
166 }
167}
168
#[cfg(feature = "websocket")]
pub mod websocket {
    use super::*;
    use futures_util::{SinkExt, StreamExt};
    use tokio::net::TcpStream;
    use tokio_tungstenite::{
        tungstenite::{Error as WsError, Message as WsMessage},
        WebSocketStream,
    };

    /// Async transport that carries serialized messages as binary WebSocket frames.
    ///
    /// `stream` becomes `None` once the connection is closed, either locally
    /// through `close` or because the peer closed it.
    pub struct WebSocketTransport {
        serializer: BinarySerializer,
        stream: Option<WebSocketStream<TcpStream>>,
    }

    impl WebSocketTransport {
        /// Wraps an already-established WebSocket stream, encoding messages with `format`.
        pub fn new(format: BinaryFormat, stream: WebSocketStream<TcpStream>) -> Self {
            Self {
                serializer: BinarySerializer::new(format),
                stream: Some(stream),
            }
        }
    }

    #[async_trait]
    impl AsyncTransport for WebSocketTransport {
        /// Serializes `message` and sends it as a single binary frame.
        ///
        /// # Errors
        /// `LinkError::ConnectionClosed` if the stream is gone; WebSocket
        /// failures are reported as `LinkError::Transport`.
        async fn send(&mut self, message: &Message) -> Result<()> {
            let stream = self.stream.as_mut()
                .ok_or(LinkError::ConnectionClosed)?;

            let data = self.serializer.serialize_message(message)?;
            stream.send(WsMessage::Binary(data.to_vec())).await
                .map_err(|e| LinkError::Transport(e.to_string()))?;

            Ok(())
        }

        /// Waits for the next binary frame and deserializes it.
        ///
        /// Ping/Pong control frames are skipped. Other non-binary data frames
        /// yield `Ok(None)`.
        ///
        /// # Errors
        /// `LinkError::ConnectionClosed` when the peer closes the connection or
        /// the stream ends. The transport is then marked disconnected. Other
        /// WebSocket failures are reported as `LinkError::Transport`.
        async fn receive(&mut self) -> Result<Option<Message>> {
            loop {
                let stream = self.stream.as_mut()
                    .ok_or(LinkError::ConnectionClosed)?;

                match stream.next().await {
                    Some(Ok(WsMessage::Binary(data))) => {
                        let message = self.serializer.deserialize_message(&data)?;
                        return Ok(Some(message));
                    }
                    // Keepalive control frames carry no application data. Surfacing them
                    // as `Ok(None)` would look like "no message" to callers, so keep
                    // waiting for a data frame (tungstenite queues Pong replies itself).
                    Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => continue,
                    Some(Ok(WsMessage::Close(_))) | None => {
                        self.stream = None;
                        return Err(LinkError::ConnectionClosed);
                    }
                    Some(Ok(_)) => return Ok(None),
                    // These tungstenite errors mean the connection is finished; report
                    // them like a peer close so `is_connected` stops returning `true`.
                    Some(Err(WsError::ConnectionClosed)) | Some(Err(WsError::AlreadyClosed)) => {
                        self.stream = None;
                        return Err(LinkError::ConnectionClosed);
                    }
                    Some(Err(e)) => return Err(LinkError::Transport(e.to_string())),
                }
            }
        }

        /// Starts the WebSocket close handshake and drops the stream.
        ///
        /// The stream is taken before awaiting. The transport is therefore
        /// marked disconnected even if the close handshake fails.
        async fn close(&mut self) -> Result<()> {
            if let Some(mut stream) = self.stream.take() {
                stream.close(None).await
                    .map_err(|e| LinkError::Transport(e.to_string()))?;
            }
            Ok(())
        }

        /// Returns `true` while the underlying stream is held.
        fn is_connected(&self) -> bool {
            self.stream.is_some()
        }
    }
}
241
/// Coarse failure categories for transport operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportError {
    /// No open connection ("Not connected").
    NotConnected,
    /// Sending a message failed ("Send failed").
    SendFailed,
    /// Receiving a message failed ("Receive failed").
    ReceiveFailed,
    /// Closing the connection failed ("Close failed").
    CloseFailed,
}

impl std::fmt::Display for TransportError {
    /// Writes a short, fixed, human-readable description of the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match *self {
            Self::NotConnected => "Not connected",
            Self::SendFailed => "Send failed",
            Self::ReceiveFailed => "Receive failed",
            Self::CloseFailed => "Close failed",
        };
        f.write_str(text)
    }
}

impl std::error::Error for TransportError {}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// A message sent before `connect_to` is delivered to the peer exactly once.
    #[test]
    fn test_memory_transport() {
        let mut transport1 = MemoryTransport::new(BinaryFormat::MessagePack);
        let mut transport2 = MemoryTransport::new(BinaryFormat::MessagePack);

        let message = Message::ping(1);
        transport1.send(&message).unwrap();

        transport1.connect_to(&mut transport2);

        let received = transport2.receive().unwrap().unwrap();
        assert_eq!(message.header.msg_type, received.header.msg_type);
        // The only queued frame has been consumed.
        assert!(transport2.receive().unwrap().is_none());
    }

    /// After `close`, the transport reports disconnected and refuses I/O.
    #[test]
    fn test_transport_close() {
        let mut transport = MemoryTransport::new(BinaryFormat::Json);

        assert!(transport.is_connected());

        transport.close().unwrap();

        assert!(!transport.is_connected());

        let message = Message::ping(1);
        assert!(transport.send(&message).is_err());
        assert!(transport.receive().is_err());
    }
}