1use super::*;
2
/// Server side of the TCP tunnel: accepts tunnel-client connections and
/// relays their traffic to host sockets.
///
/// `Debug` and `Clone` are derived so the configuration can be logged and
/// handed to other tasks.
#[derive(Debug, Clone)]
pub struct TcpWarpServer {
    /// Socket address the listener binds to.
    listen_address: SocketAddr,
    /// IP address used for a host connection when the client names no host.
    connect_address: IpAddr,
}
7
8impl TcpWarpServer {
9 pub fn new(listen_address: SocketAddr, connect_address: IpAddr) -> Self {
10 Self {
11 listen_address,
12 connect_address,
13 }
14 }
15
16 pub async fn listen(&self) -> Result<(), Box<dyn Error>> {
17 let mut listener = TcpListener::bind(&self.listen_address).await?;
18 let mut incoming = listener.incoming();
19 let connect_address = self.connect_address;
20
21 while let Some(Ok(stream)) = incoming.next().await {
22 spawn(async move {
23 if let Err(e) = process(stream, connect_address).await {
24 println!("failed to process connection; error = {}", e);
25 }
26 });
27 }
28 Ok(())
29 }
30}
31
32async fn process(stream: TcpStream, connect_address: IpAddr) -> Result<(), Box<dyn Error>> {
33 let mut transport = Framed::new(stream, TcpWarpProto);
34
35 transport.send(TcpWarpMessage::AddPorts(vec![])).await?;
36
37 let (mut wtransport, mut rtransport) = transport.split();
38
39 let (sender, mut receiver) = channel(100);
40
41 let mut connections = HashMap::new();
42
43 let forward_task = async move {
44 debug!("in receiver task process");
45 while let Some(message) = receiver.next().await {
46 debug!("received in fw message: {:?}", message);
47 let message = match message {
48 TcpWarpMessage::ConnectForward {
49 connection_id,
50 sender,
51 connected_sender,
52 } => {
53 debug!("adding connection: {}", connection_id);
54 if let Err(err) = connected_sender.send(Ok(())) {
55 error!("connected sender errored: {:?}", err);
56 }
57 connections.insert(connection_id.clone(), sender.clone());
58 TcpWarpMessage::Connected { connection_id }
59 }
60 TcpWarpMessage::DisconnectClient { ref connection_id } => {
61 debug!(
62 "{} client connection disconnected, handle server disconnect",
63 connection_id
64 );
65 if let Some(mut sender) = connections.remove(connection_id) {
66 if let Err(err) = sender.send(message).await {
67 error!("cannot send to channel: {}", err);
68 }
69 } else {
70 error!("connection not found: {}", connection_id);
71 }
72 debug!("connections in pool: {}", connections.len());
73 continue;
74 }
75 TcpWarpMessage::BytesClient {
76 connection_id,
77 data,
78 } => {
79 if let Some(sender) = connections.get_mut(&connection_id) {
80 debug!(
81 "forward message to host port of connection: {}",
82 connection_id
83 );
84 if let Err(err) = sender.send(TcpWarpMessage::BytesServer { data }).await {
85 error!("cannot send to channel: {}", err);
86 };
87 } else {
88 error!("connection not found: {}", connection_id);
89 }
90 continue;
91 }
92 regular_message => regular_message,
93 };
94 debug!("sending message {:?} from server to tunnel client", message);
95 wtransport.send(message).await?
96 }
97
98 debug!("no more messages, closing forward to tunnel client task");
99 wtransport.close().await?;
100 receiver.close();
101
102 Ok::<(), io::Error>(())
103 };
104
105 let processing_task = async move {
106 while let Some(Ok(message)) = rtransport.next().await {
107 debug!("server received from tunnel client {:?}", message);
108 if let Err(err) =
109 process_client_to_host_message(message, sender.clone(), connect_address).await
110 {
111 error!("error in processing: {}", err);
112 }
113 }
114
115 debug!("processing task for client to host tunnel finished");
116
117 Ok::<(), io::Error>(())
118 };
119
120 let (_, _) = try_join!(forward_task, processing_task)?;
121
122 debug!("finished process of tunnel connection");
123
124 Ok(())
125}
126
127async fn process_client_to_host_message(
128 message: TcpWarpMessage,
129 mut client_sender: Sender<TcpWarpMessage>,
130 connect_address: IpAddr,
131) -> Result<(), io::Error> {
132 match message {
133 TcpWarpMessage::HostConnect {
134 connection_id,
135 host,
136 port,
137 } => {
138 let client_sender_ = client_sender.clone();
139 spawn(async move {
140 let connect_address = connect_address.to_string();
141 let socket_address = format!(
142 "{}:{}",
143 host.unwrap_or_else(|| connect_address.to_string()),
144 port
145 );
146 debug!("host connection to {}", socket_address);
147 if let Err(err) =
148 process_host_connection(client_sender_, connection_id, socket_address).await
149 {
150 error!(
151 "failed connection {} {}: {}",
152 connect_address, connection_id, err
153 );
154 }
155 });
156 }
157 TcpWarpMessage::DisconnectClient { .. } => {
158 if let Err(err) = client_sender.send(message).await {
159 error!(
160 "cannot send message DisconnectClient to forward channel: {}",
161 err
162 );
163 }
164 }
165 TcpWarpMessage::BytesClient { .. } => {
166 if let Err(err) = client_sender.send(message).await {
167 error!(
168 "cannot send message BytesClient to forward channel: {}",
169 err
170 );
171 }
172 }
173 other_message => warn!("unsupported message: {:?}", other_message),
174 }
175 Ok(())
176}
177
178async fn process_host_connection<S: ToSocketAddrs>(
179 mut client_sender: Sender<TcpWarpMessage>,
180 connection_id: Uuid,
181 socket_address: S,
182) -> Result<(), Box<dyn Error>> {
183 debug!("{} new connection", connection_id);
184
185 let stream = match TcpStream::connect(socket_address).await {
186 Ok(stream) => stream,
187 Err(err) => {
188 client_sender
189 .send(TcpWarpMessage::ConnectFailure { connection_id })
190 .await?;
191 return Err(err.into());
192 }
193 };
194
195 let (mut wtransport, mut rtransport) =
196 Framed::new(stream, TcpWarpProtoHost { connection_id }).split();
197
198 let (host_sender, mut host_receiver) = channel(100);
199
200 let forward_task = async move {
201 debug!("{} in receiver task process_host_connection", connection_id);
202
203 while let Some(message) = host_receiver.next().await {
204 debug!("{} just received a message: {:?}", connection_id, message);
205 match message {
206 TcpWarpMessage::DisconnectClient { .. } => break,
207 TcpWarpMessage::BytesServer { data } => wtransport.send(data).await?,
208 _ => (),
209 }
210 }
211
212 debug!(
213 "{} no more messages, closing process host forward task",
214 connection_id
215 );
216 wtransport.close().await?;
217 host_receiver.close();
218 debug!("{} closed write transport", connection_id);
219
220 Ok::<(), io::Error>(())
221 };
222
223 let (connected_sender, connected_receiver) = oneshot::channel();
224
225 client_sender
226 .send(TcpWarpMessage::ConnectForward {
227 connection_id,
228 sender: host_sender,
229 connected_sender,
230 })
231 .await?;
232
233 debug!("{} sended connect to client", connection_id);
234
235 let mut client_sender_ = client_sender.clone();
236
237 let processing_task = async move {
238 if let Err(err) = connected_receiver.await {
239 error!("{} connection error: {}", connection_id, err);
240 }
241 while let Some(Ok(message)) = rtransport.next().await {
242 if let Err(err) = client_sender_.send(message).await {
243 error!("{} {}", connection_id, err);
244 }
245 }
246
247 let message = TcpWarpMessage::DisconnectHost { connection_id };
248
249 debug!("{} sending disconnect host message", connection_id);
250
251 if let Err(err) = client_sender_.send(message).await {
252 error!("{} err: {}", connection_id, err);
253 }
254
255 debug!("{} host connection processing task done", connection_id);
256
257 Ok::<(), io::Error>(())
258 };
259
260 try_join!(forward_task, processing_task)?;
261
262 debug!("{} disconnect, processing task done", connection_id);
263
264 Ok(())
265}