1use super::*;
2use std::io;
3
/// Stateless codec for [`TcpWarpMessage`] frames.
///
/// The `Encoder` and `Decoder` implementations in this file define the wire
/// format: every frame starts with a one-byte tag (1 through 8), followed by
/// tag-specific fields written with the big-endian `put_*` helpers.
pub struct TcpWarpProto;
5
6impl Encoder for TcpWarpProto {
7 type Item = TcpWarpMessage;
8 type Error = io::Error;
9
10 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> io::Result<()> {
11 match item {
12 TcpWarpMessage::AddPorts(ports) => {
13 dst.reserve(1 + 2 + ports.len() * 2);
14 dst.put_u8(1);
15 dst.put_u16(ports.len() as u16);
16 for port in ports {
17 dst.put_u16(port);
18 }
19 }
20 TcpWarpMessage::HostConnect {
21 connection_id,
22 host,
23 port,
24 } => {
25 let len = host.as_ref().map_or(0, |x| x.len());
26 dst.reserve(1 + 2 + 16 + 2 + len);
27 dst.put_u8(2);
28 dst.put_u16(len as u16);
29 dst.put_u128(connection_id.as_u128());
30 dst.put_u16(port);
31 if let Some(data) = host {
32 dst.put_slice(data.as_bytes());
33 }
34 }
35 TcpWarpMessage::BytesClient {
36 connection_id,
37 data,
38 } => {
39 dst.reserve(1 + 16 + 4 + data.len());
40 dst.put_u8(3);
41 dst.put_u128(connection_id.as_u128());
42 dst.put_u32(data.len() as u32);
43 dst.put_slice(&data);
44 }
45 TcpWarpMessage::BytesHost {
46 connection_id,
47 data,
48 } => {
49 dst.reserve(1 + 16 + 4 + data.len());
50 dst.put_u8(4);
51 dst.put_u128(connection_id.as_u128());
52 dst.put_u32(data.len() as u32);
53 dst.put_slice(&data);
54 }
55 TcpWarpMessage::Connected { connection_id } => {
56 dst.reserve(1 + 16);
57 dst.put_u8(5);
58 dst.put_u128(connection_id.as_u128());
59 }
60 TcpWarpMessage::DisconnectHost { connection_id } => {
61 dst.reserve(1 + 16);
62 dst.put_u8(6);
63 dst.put_u128(connection_id.as_u128());
64 }
65 TcpWarpMessage::DisconnectClient { connection_id } => {
66 dst.reserve(1 + 16);
67 dst.put_u8(7);
68 dst.put_u128(connection_id.as_u128());
69 }
70 TcpWarpMessage::ConnectFailure { connection_id } => {
71 dst.reserve(1 + 16);
72 dst.put_u8(8);
73 dst.put_u128(connection_id.as_u128());
74 }
75 other => {
76 error!("unknown message: {:?}", other);
77 }
78 }
79
80 Ok(())
81 }
82}
83
84impl Decoder for TcpWarpProto {
85 type Item = TcpWarpMessage;
86 type Error = io::Error;
87
88 fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<TcpWarpMessage>> {
89 Ok(match src.get(0) {
90 Some(1) if src.len() > 2 => {
91 let len = u16::from_be_bytes(src[1..3].try_into().unwrap());
92 if src.len() == 3 && len == 0 {
93 src.advance(3);
94 return Ok(Some(TcpWarpMessage::AddPorts(vec![])));
95 }
96 if len as usize * 2 + 3 <= src.len() {
97 src.advance(3);
98 let data = src.split_to(len as usize * 2);
99 let ports = data
100 .chunks_exact(2)
101 .map(|x| u16::from_be_bytes(x.try_into().unwrap()))
102 .collect();
103 Some(TcpWarpMessage::AddPorts(ports))
104 } else {
105 None
106 }
107 }
108 Some(2) if src.len() > (16 + 2) => {
109 let len = u16::from_be_bytes(src[1..3].try_into().unwrap()) as usize;
110 if 2 + 16 + 2 + len < src.len() {
111 src.advance(3);
112 let header = src.split_to(18);
113 let connection_id = Uuid::from_slice(&header[0..16]).unwrap();
114 let port = u16::from_be_bytes(header[16..18].try_into().unwrap());
115 let host = if len > 0 {
116 String::from_utf8(src.split_to(len).to_vec()).ok()
117 } else {
118 None
119 };
120 Some(TcpWarpMessage::HostConnect {
121 connection_id,
122 host,
123 port,
124 })
125 } else {
126 None
127 }
128 }
129 Some(3) if src.len() > (16 + 4 + 1) => {
130 let len = u32::from_be_bytes(src[17..21].try_into().unwrap()) as usize;
131 if len as usize + 16 + 4 < src.len() {
132 src.advance(1);
133 let header = src.split_to(20);
134 let connection_id = Uuid::from_slice(&header[0..16]).unwrap();
135 let data = src.split_to(len);
136 Some(TcpWarpMessage::BytesClient {
137 connection_id,
138 data,
139 })
140 } else {
141 None
142 }
143 }
144 Some(4) if src.len() > (16 + 4 + 1) => {
145 let len = u32::from_be_bytes(src[17..21].try_into().unwrap()) as usize;
146 if len as usize + 16 + 4 < src.len() {
147 src.advance(1);
148 let header = src.split_to(20);
149 let connection_id = Uuid::from_slice(&header[0..16]).unwrap();
150 let data = src.split_to(len);
151 Some(TcpWarpMessage::BytesHost {
152 connection_id,
153 data,
154 })
155 } else {
156 None
157 }
158 }
159 Some(5) if src.len() > 16 => {
160 src.advance(1);
161 let header = src.split_to(16);
162 let connection_id = Uuid::from_slice(&header).unwrap();
163 Some(TcpWarpMessage::Connected { connection_id })
164 }
165 Some(6) if src.len() > 16 => {
166 src.advance(1);
167 let header = src.split_to(16);
168 let connection_id = Uuid::from_slice(&header).unwrap();
169 Some(TcpWarpMessage::DisconnectHost { connection_id })
170 }
171 Some(7) if src.len() > 16 => {
172 src.advance(1);
173 let header = src.split_to(16);
174 let connection_id = Uuid::from_slice(&header).unwrap();
175 Some(TcpWarpMessage::DisconnectClient { connection_id })
176 }
177 Some(8) if src.len() > 16 => {
178 src.advance(1);
179 let header = src.split_to(16);
180 let connection_id = Uuid::from_slice(&header).unwrap();
181 Some(TcpWarpMessage::ConnectFailure { connection_id })
182 }
183 _ => {
184 debug!("looks like data is wrong [{}] {:?}", src.len(), src);
185 None
186 } })
188 }
189}
190
191#[derive(Debug)]
202pub enum TcpWarpMessage {
203 AddPorts(Vec<u16>),
204 Connected {
205 connection_id: Uuid,
206 },
207 BytesClient {
208 connection_id: Uuid,
209 data: BytesMut,
210 },
211 BytesServer {
212 data: BytesMut,
213 },
214 BytesHost {
215 connection_id: Uuid,
216 data: BytesMut,
217 },
218 Connect {
219 connection_id: Uuid,
220 connection: TcpWarpPortConnection,
221 sender: Sender<TcpWarpMessage>,
222 connected_sender: oneshot::Sender<Result<(), io::Error>>,
223 },
224 ConnectForward {
225 connection_id: Uuid,
226 sender: Sender<TcpWarpMessage>,
227 connected_sender: oneshot::Sender<Result<(), io::Error>>,
228 },
229 ConnectFailure {
230 connection_id: Uuid,
231 },
232 Disconnect,
233 Listener(AbortHandle),
234 HostConnect {
235 connection_id: Uuid,
236 host: Option<String>,
237 port: u16,
238 },
239 DisconnectHost {
240 connection_id: Uuid,
241 },
242 DisconnectClient {
243 connection_id: Uuid,
244 },
245}
246
/// Pass-through codec bound to one connection.
///
/// Its `Encoder` copies outgoing bytes verbatim; its `Decoder` wraps whatever
/// bytes are buffered in `TcpWarpMessage::BytesClient`.
pub struct TcpWarpProtoClient {
    /// Id stamped onto every `BytesClient` message this codec decodes.
    pub connection_id: Uuid,
}
250
251impl Encoder for TcpWarpProtoClient {
252 type Item = BytesMut;
253 type Error = io::Error;
254
255 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> io::Result<()> {
256 dst.extend_from_slice(&item);
257 Ok(())
258 }
259}
260
261impl Decoder for TcpWarpProtoClient {
262 type Item = TcpWarpMessage;
263 type Error = io::Error;
264
265 fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<TcpWarpMessage>> {
266 if src.is_empty() {
267 return Ok(None);
268 }
269
270 Ok(Some(TcpWarpMessage::BytesClient {
271 connection_id: self.connection_id,
272 data: src.split(),
273 }))
274 }
275}
276
/// Pass-through codec bound to one connection.
///
/// Its `Encoder` copies outgoing bytes verbatim; its `Decoder` wraps whatever
/// bytes are buffered in `TcpWarpMessage::BytesHost`.
pub struct TcpWarpProtoHost {
    /// Id stamped onto every `BytesHost` message this codec decodes.
    pub connection_id: Uuid,
}
280
281impl Encoder for TcpWarpProtoHost {
282 type Item = BytesMut;
283 type Error = io::Error;
284
285 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> io::Result<()> {
286 dst.extend_from_slice(&item);
287 Ok(())
288 }
289}
290
291impl Decoder for TcpWarpProtoHost {
292 type Item = TcpWarpMessage;
293 type Error = io::Error;
294
295 fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<TcpWarpMessage>> {
296 if src.is_empty() {
297 return Ok(None);
298 }
299
300 Ok(Some(TcpWarpMessage::BytesHost {
301 connection_id: self.connection_id,
302 data: src.split(),
303 }))
304 }
305}