tcpwarp/
proto.rs

1use super::*;
2use std::io;
3
/// Codec for the framed tcpwarp wire protocol.
///
/// The frame layout of every serializable message is documented on [`TcpWarpMessage`].
#[derive(Debug, Default, Clone, Copy)]
pub struct TcpWarpProto;
5
6impl Encoder for TcpWarpProto {
7    type Item = TcpWarpMessage;
8    type Error = io::Error;
9
10    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> io::Result<()> {
11        match item {
12            TcpWarpMessage::AddPorts(ports) => {
13                dst.reserve(1 + 2 + ports.len() * 2);
14                dst.put_u8(1);
15                dst.put_u16(ports.len() as u16);
16                for port in ports {
17                    dst.put_u16(port);
18                }
19            }
20            TcpWarpMessage::HostConnect {
21                connection_id,
22                host,
23                port,
24            } => {
25                let len = host.as_ref().map_or(0, |x| x.len());
26                dst.reserve(1 + 2 + 16 + 2 + len);
27                dst.put_u8(2);
28                dst.put_u16(len as u16);
29                dst.put_u128(connection_id.as_u128());
30                dst.put_u16(port);
31                if let Some(data) = host {
32                    dst.put_slice(data.as_bytes());
33                }
34            }
35            TcpWarpMessage::BytesClient {
36                connection_id,
37                data,
38            } => {
39                dst.reserve(1 + 16 + 4 + data.len());
40                dst.put_u8(3);
41                dst.put_u128(connection_id.as_u128());
42                dst.put_u32(data.len() as u32);
43                dst.put_slice(&data);
44            }
45            TcpWarpMessage::BytesHost {
46                connection_id,
47                data,
48            } => {
49                dst.reserve(1 + 16 + 4 + data.len());
50                dst.put_u8(4);
51                dst.put_u128(connection_id.as_u128());
52                dst.put_u32(data.len() as u32);
53                dst.put_slice(&data);
54            }
55            TcpWarpMessage::Connected { connection_id } => {
56                dst.reserve(1 + 16);
57                dst.put_u8(5);
58                dst.put_u128(connection_id.as_u128());
59            }
60            TcpWarpMessage::DisconnectHost { connection_id } => {
61                dst.reserve(1 + 16);
62                dst.put_u8(6);
63                dst.put_u128(connection_id.as_u128());
64            }
65            TcpWarpMessage::DisconnectClient { connection_id } => {
66                dst.reserve(1 + 16);
67                dst.put_u8(7);
68                dst.put_u128(connection_id.as_u128());
69            }
70            TcpWarpMessage::ConnectFailure { connection_id } => {
71                dst.reserve(1 + 16);
72                dst.put_u8(8);
73                dst.put_u128(connection_id.as_u128());
74            }
75            other => {
76                error!("unknown message: {:?}", other);
77            }
78        }
79
80        Ok(())
81    }
82}
83
84impl Decoder for TcpWarpProto {
85    type Item = TcpWarpMessage;
86    type Error = io::Error;
87
88    fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<TcpWarpMessage>> {
89        Ok(match src.get(0) {
90            Some(1) if src.len() > 2 => {
91                let len = u16::from_be_bytes(src[1..3].try_into().unwrap());
92                if src.len() == 3 && len == 0 {
93                    src.advance(3);
94                    return Ok(Some(TcpWarpMessage::AddPorts(vec![])));
95                }
96                if len as usize * 2 + 3 <= src.len() {
97                    src.advance(3);
98                    let data = src.split_to(len as usize * 2);
99                    let ports = data
100                        .chunks_exact(2)
101                        .map(|x| u16::from_be_bytes(x.try_into().unwrap()))
102                        .collect();
103                    Some(TcpWarpMessage::AddPorts(ports))
104                } else {
105                    None
106                }
107            }
108            Some(2) if src.len() > (16 + 2) => {
109                let len = u16::from_be_bytes(src[1..3].try_into().unwrap()) as usize;
110                if 2 + 16 + 2 + len < src.len() {
111                    src.advance(3);
112                    let header = src.split_to(18);
113                    let connection_id = Uuid::from_slice(&header[0..16]).unwrap();
114                    let port = u16::from_be_bytes(header[16..18].try_into().unwrap());
115                    let host = if len > 0 {
116                        String::from_utf8(src.split_to(len).to_vec()).ok()
117                    } else {
118                        None
119                    };
120                    Some(TcpWarpMessage::HostConnect {
121                        connection_id,
122                        host,
123                        port,
124                    })
125                } else {
126                    None
127                }
128            }
129            Some(3) if src.len() > (16 + 4 + 1) => {
130                let len = u32::from_be_bytes(src[17..21].try_into().unwrap()) as usize;
131                if len as usize + 16 + 4 < src.len() {
132                    src.advance(1);
133                    let header = src.split_to(20);
134                    let connection_id = Uuid::from_slice(&header[0..16]).unwrap();
135                    let data = src.split_to(len);
136                    Some(TcpWarpMessage::BytesClient {
137                        connection_id,
138                        data,
139                    })
140                } else {
141                    None
142                }
143            }
144            Some(4) if src.len() > (16 + 4 + 1) => {
145                let len = u32::from_be_bytes(src[17..21].try_into().unwrap()) as usize;
146                if len as usize + 16 + 4 < src.len() {
147                    src.advance(1);
148                    let header = src.split_to(20);
149                    let connection_id = Uuid::from_slice(&header[0..16]).unwrap();
150                    let data = src.split_to(len);
151                    Some(TcpWarpMessage::BytesHost {
152                        connection_id,
153                        data,
154                    })
155                } else {
156                    None
157                }
158            }
159            Some(5) if src.len() > 16 => {
160                src.advance(1);
161                let header = src.split_to(16);
162                let connection_id = Uuid::from_slice(&header).unwrap();
163                Some(TcpWarpMessage::Connected { connection_id })
164            }
165            Some(6) if src.len() > 16 => {
166                src.advance(1);
167                let header = src.split_to(16);
168                let connection_id = Uuid::from_slice(&header).unwrap();
169                Some(TcpWarpMessage::DisconnectHost { connection_id })
170            }
171            Some(7) if src.len() > 16 => {
172                src.advance(1);
173                let header = src.split_to(16);
174                let connection_id = Uuid::from_slice(&header).unwrap();
175                Some(TcpWarpMessage::DisconnectClient { connection_id })
176            }
177            Some(8) if src.len() > 16 => {
178                src.advance(1);
179                let header = src.split_to(16);
180                let connection_id = Uuid::from_slice(&header).unwrap();
181                Some(TcpWarpMessage::ConnectFailure { connection_id })
182            }
183            _ => {
184                debug!("looks like data is wrong [{}] {:?}", src.len(), src);
185                None
186            } // _ => None,
187        })
188    }
189}
190
191/// Command types.
192///
193/// Serialization scheme:
194/// - 1 - add ports u16 len * u16
195/// - 2 - host connect u16=(len + 2) u128 u16 len * u8
196/// - 3 - bytes client u128 u32 len * u8
197/// - 4 - bytes host u128 u32 len * u8
198/// - 5 - connected u128
199/// - 6 - disconnect host u128
200/// - 7 - disconnect client u128
201#[derive(Debug)]
202pub enum TcpWarpMessage {
203    AddPorts(Vec<u16>),
204    Connected {
205        connection_id: Uuid,
206    },
207    BytesClient {
208        connection_id: Uuid,
209        data: BytesMut,
210    },
211    BytesServer {
212        data: BytesMut,
213    },
214    BytesHost {
215        connection_id: Uuid,
216        data: BytesMut,
217    },
218    Connect {
219        connection_id: Uuid,
220        connection: TcpWarpPortConnection,
221        sender: Sender<TcpWarpMessage>,
222        connected_sender: oneshot::Sender<Result<(), io::Error>>,
223    },
224    ConnectForward {
225        connection_id: Uuid,
226        sender: Sender<TcpWarpMessage>,
227        connected_sender: oneshot::Sender<Result<(), io::Error>>,
228    },
229    ConnectFailure {
230        connection_id: Uuid,
231    },
232    Disconnect,
233    Listener(AbortHandle),
234    HostConnect {
235        connection_id: Uuid,
236        host: Option<String>,
237        port: u16,
238    },
239    DisconnectHost {
240        connection_id: Uuid,
241    },
242    DisconnectClient {
243        connection_id: Uuid,
244    },
245}
246
247pub struct TcpWarpProtoClient {
248    pub connection_id: Uuid,
249}
250
251impl Encoder for TcpWarpProtoClient {
252    type Item = BytesMut;
253    type Error = io::Error;
254
255    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> io::Result<()> {
256        dst.extend_from_slice(&item);
257        Ok(())
258    }
259}
260
261impl Decoder for TcpWarpProtoClient {
262    type Item = TcpWarpMessage;
263    type Error = io::Error;
264
265    fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<TcpWarpMessage>> {
266        if src.is_empty() {
267            return Ok(None);
268        }
269
270        Ok(Some(TcpWarpMessage::BytesClient {
271            connection_id: self.connection_id,
272            data: src.split(),
273        }))
274    }
275}
276
277pub struct TcpWarpProtoHost {
278    pub connection_id: Uuid,
279}
280
281impl Encoder for TcpWarpProtoHost {
282    type Item = BytesMut;
283    type Error = io::Error;
284
285    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> io::Result<()> {
286        dst.extend_from_slice(&item);
287        Ok(())
288    }
289}
290
291impl Decoder for TcpWarpProtoHost {
292    type Item = TcpWarpMessage;
293    type Error = io::Error;
294
295    fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<TcpWarpMessage>> {
296        if src.is_empty() {
297            return Ok(None);
298        }
299
300        Ok(Some(TcpWarpMessage::BytesHost {
301            connection_id: self.connection_id,
302            data: src.split(),
303        }))
304    }
305}