trz_gateway_server/server/
tunnel.rs

1use std::future::Ready;
2use std::future::ready;
3use std::io::ErrorKind;
4use std::sync::Arc;
5
6use axum::extract::Path;
7use axum::extract::WebSocketUpgrade;
8use axum::extract::ws::Message;
9use axum::extract::ws::WebSocket;
10use axum::http::Uri;
11use axum::response::IntoResponse;
12use futures::SinkExt;
13use futures::StreamExt;
14use hyper_util::rt::TokioIo;
15use nameth::NamedEnumValues as _;
16use nameth::NamedType as _;
17use nameth::nameth;
18use rustls::pki_types::DnsName;
19use rustls::pki_types::InvalidDnsNameError;
20use rustls::pki_types::ServerName;
21use tokio::io::AsyncRead;
22use tokio::io::AsyncWrite;
23use tokio_util::io::CopyToBytes;
24use tokio_util::io::SinkWriter;
25use tokio_util::io::StreamReader;
26use tonic::transport::Channel;
27use tracing::Instrument as _;
28use tracing::Span;
29use tracing::info;
30use tracing::info_span;
31use tracing::warn;
32use trz_gateway_common::id::ClientId;
33use trz_gateway_common::id::ClientName;
34
35use super::Server;
36
37impl Server {
38    pub async fn tunnel(
39        self: Arc<Self>,
40        client_id: Option<ClientId>,
41        Path(client_name): Path<ClientName>,
42        web_socket: WebSocketUpgrade,
43    ) -> impl IntoResponse {
44        let span = if let Some(client_id) = client_id {
45            info_span!("Tunnel", %client_name, %client_id)
46        } else {
47            info_span!("Tunnel", %client_name)
48        };
49        let _entered = span.clone().entered();
50        info!("Incoming tunnel");
51        web_socket.on_upgrade(move |web_socket| {
52            let _entered = span.clone().entered();
53            self.process_websocket(client_name, web_socket);
54            ready(())
55        })
56    }
57
58    fn process_websocket(self: Arc<Self>, client_name: ClientName, web_socket: WebSocket) {
59        let (sink, stream) = web_socket.split();
60
61        let reader = {
62            #[nameth]
63            #[derive(thiserror::Error, Debug)]
64            #[error("[{n}] {0}", n = Self::type_name())]
65            struct ReadError(axum::Error);
66
67            StreamReader::new(stream.map(|message| {
68                message.map(Message::into_data).map_err(|error| {
69                    std::io::Error::new(ErrorKind::ConnectionAborted, ReadError(error))
70                })
71            }))
72        };
73
74        let writer = {
75            #[nameth]
76            #[derive(thiserror::Error, Debug)]
77            #[error("[{n}] {0}", n = Self::type_name())]
78            struct WriteError(axum::Error);
79
80            let sink = CopyToBytes::new(sink.with(|data| ready(Ok(Message::Binary(data)))))
81                .sink_map_err(|error| {
82                    std::io::Error::new(ErrorKind::ConnectionAborted, WriteError(error))
83                });
84            SinkWriter::new(sink)
85        };
86
87        let stream = tokio::io::join(reader, writer);
88        tokio::spawn(
89            self.process_connection(client_name, stream)
90                .in_current_span(),
91        );
92    }
93
94    async fn process_connection(
95        self: Arc<Self>,
96        client_name: ClientName,
97        connection: impl AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
98    ) -> Result<(), TunnelError> {
99        let tls_stream = self
100            .tls_client
101            .connect(
102                ServerName::DnsName(DnsName::try_from(client_name.to_string())?),
103                connection,
104            )
105            .await
106            .map_err(TunnelError::TlsConnectError)?;
107
108        // The endpoint is irrelevant: gRPC isn't actually connecting to this endpoint.
109        // Instead we are manually providing the connection using 'connect_with_connector'.
110        // The connection used by gRPC is the bi-directional stream based on the WebSocket.
111        let endpoint = tonic::transport::Endpoint::new(format!(
112            "https://localhost/remote/tunnel/{client_name}"
113        ))
114        .map_err(|_| TunnelError::InvalidEndpoint)?;
115        let connector = make_single_use_connector(tls_stream)
116            .await
117            .map_err(TunnelError::SingleUseConnectorError)?;
118        let channel: Channel = endpoint
119            .connect_with_connector(tower::service_fn(connector))
120            .await
121            .map_err(TunnelError::GrpcConnectError)?;
122
123        self.connections.add(client_name, channel);
124        Ok(())
125    }
126}
127
128async fn make_single_use_connector<S: AsyncRead + AsyncWrite>(
129    stream: S,
130) -> std::io::Result<impl FnMut(Uri) -> Ready<std::io::Result<TokioIo<S>>>> {
131    let span = Span::current();
132    let mut single_use_connection = Some(TokioIo::new(stream));
133    let connector = move |_uri| {
134        span.in_scope(|| {
135            let Some(connection) = single_use_connection.take() else {
136                let error = std::io::Error::new(
137                    ErrorKind::AddrInUse,
138                    "The WebSocket was already used to create a channel",
139                );
140                warn!("{error}");
141                return ready(Err(error));
142            };
143            // `single_use_connection` has been consumed and is now empty.
144            assert!(single_use_connection.is_none());
145            ready(Ok(connection))
146        })
147    };
148    Ok(connector)
149}
150
151#[nameth]
152#[derive(thiserror::Error, Debug)]
153pub enum TunnelError {
154    #[error("[{n}] Failed to create synthetic endpoint", n = self.name())]
155    InvalidEndpoint,
156
157    #[error("[{n}] {0}", n = self.name())]
158    InvalidDnsName(#[from] InvalidDnsNameError),
159
160    #[error("[{n}] {0}", n = self.name())]
161    TlsConnectError(std::io::Error),
162
163    #[error("[{n}] {0}", n = self.name())]
164    SingleUseConnectorError(std::io::Error),
165
166    #[error("[{n}] {0}", n = self.name())]
167    GrpcConnectError(tonic::transport::Error),
168}