trz_gateway_server/server/
tunnel.rs1use std::future::Ready;
2use std::future::ready;
3use std::io::ErrorKind;
4use std::sync::Arc;
5
6use axum::extract::Path;
7use axum::extract::WebSocketUpgrade;
8use axum::extract::ws::Message;
9use axum::extract::ws::WebSocket;
10use axum::http::Uri;
11use axum::response::IntoResponse;
12use futures::SinkExt;
13use futures::StreamExt;
14use hyper_util::rt::TokioIo;
15use nameth::NamedEnumValues as _;
16use nameth::NamedType as _;
17use nameth::nameth;
18use rustls::pki_types::DnsName;
19use rustls::pki_types::InvalidDnsNameError;
20use rustls::pki_types::ServerName;
21use tokio::io::AsyncRead;
22use tokio::io::AsyncWrite;
23use tokio_util::io::CopyToBytes;
24use tokio_util::io::SinkWriter;
25use tokio_util::io::StreamReader;
26use tonic::transport::Channel;
27use tracing::Instrument as _;
28use tracing::Span;
29use tracing::info;
30use tracing::info_span;
31use tracing::warn;
32use trz_gateway_common::id::ClientId;
33use trz_gateway_common::id::ClientName;
34
35use super::Server;
36
37impl Server {
38 pub async fn tunnel(
39 self: Arc<Self>,
40 client_id: Option<ClientId>,
41 Path(client_name): Path<ClientName>,
42 web_socket: WebSocketUpgrade,
43 ) -> impl IntoResponse {
44 let span = if let Some(client_id) = client_id {
45 info_span!("Tunnel", %client_name, %client_id)
46 } else {
47 info_span!("Tunnel", %client_name)
48 };
49 let _entered = span.clone().entered();
50 info!("Incoming tunnel");
51 web_socket.on_upgrade(move |web_socket| {
52 let _entered = span.clone().entered();
53 self.process_websocket(client_name, web_socket);
54 ready(())
55 })
56 }
57
58 fn process_websocket(self: Arc<Self>, client_name: ClientName, web_socket: WebSocket) {
59 let (sink, stream) = web_socket.split();
60
61 let reader = {
62 #[nameth]
63 #[derive(thiserror::Error, Debug)]
64 #[error("[{n}] {0}", n = Self::type_name())]
65 struct ReadError(axum::Error);
66
67 StreamReader::new(stream.map(|message| {
68 message.map(Message::into_data).map_err(|error| {
69 std::io::Error::new(ErrorKind::ConnectionAborted, ReadError(error))
70 })
71 }))
72 };
73
74 let writer = {
75 #[nameth]
76 #[derive(thiserror::Error, Debug)]
77 #[error("[{n}] {0}", n = Self::type_name())]
78 struct WriteError(axum::Error);
79
80 let sink = CopyToBytes::new(sink.with(|data| ready(Ok(Message::Binary(data)))))
81 .sink_map_err(|error| {
82 std::io::Error::new(ErrorKind::ConnectionAborted, WriteError(error))
83 });
84 SinkWriter::new(sink)
85 };
86
87 let stream = tokio::io::join(reader, writer);
88 tokio::spawn(
89 self.process_connection(client_name, stream)
90 .in_current_span(),
91 );
92 }
93
94 async fn process_connection(
95 self: Arc<Self>,
96 client_name: ClientName,
97 connection: impl AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
98 ) -> Result<(), TunnelError> {
99 let tls_stream = self
100 .tls_client
101 .connect(
102 ServerName::DnsName(DnsName::try_from(client_name.to_string())?),
103 connection,
104 )
105 .await
106 .map_err(TunnelError::TlsConnectError)?;
107
108 let endpoint = tonic::transport::Endpoint::new(format!(
112 "https://localhost/remote/tunnel/{client_name}"
113 ))
114 .map_err(|_| TunnelError::InvalidEndpoint)?;
115 let connector = make_single_use_connector(tls_stream)
116 .await
117 .map_err(TunnelError::SingleUseConnectorError)?;
118 let channel: Channel = endpoint
119 .connect_with_connector(tower::service_fn(connector))
120 .await
121 .map_err(TunnelError::GrpcConnectError)?;
122
123 self.connections.add(client_name, channel);
124 Ok(())
125 }
126}
127
128async fn make_single_use_connector<S: AsyncRead + AsyncWrite>(
129 stream: S,
130) -> std::io::Result<impl FnMut(Uri) -> Ready<std::io::Result<TokioIo<S>>>> {
131 let span = Span::current();
132 let mut single_use_connection = Some(TokioIo::new(stream));
133 let connector = move |_uri| {
134 span.in_scope(|| {
135 let Some(connection) = single_use_connection.take() else {
136 let error = std::io::Error::new(
137 ErrorKind::AddrInUse,
138 "The WebSocket was already used to create a channel",
139 );
140 warn!("{error}");
141 return ready(Err(error));
142 };
143 assert!(single_use_connection.is_none());
145 ready(Ok(connection))
146 })
147 };
148 Ok(connector)
149}
150
151#[nameth]
152#[derive(thiserror::Error, Debug)]
153pub enum TunnelError {
154 #[error("[{n}] Failed to create synthetic endpoint", n = self.name())]
155 InvalidEndpoint,
156
157 #[error("[{n}] {0}", n = self.name())]
158 InvalidDnsName(#[from] InvalidDnsNameError),
159
160 #[error("[{n}] {0}", n = self.name())]
161 TlsConnectError(std::io::Error),
162
163 #[error("[{n}] {0}", n = self.name())]
164 SingleUseConnectorError(std::io::Error),
165
166 #[error("[{n}] {0}", n = self.name())]
167 GrpcConnectError(tonic::transport::Error),
168}