1use std::time::Duration;
26
27use axum::{
28 extract::{
29 ws::{Message, WebSocket},
30 State, WebSocketUpgrade,
31 },
32 response::Response,
33 routing::get,
34 Router,
35};
36use futures_util::{SinkExt, StreamExt};
37use tokio::net::TcpListener;
38use tokio_tungstenite::{connect_async, tungstenite::Message as TgMessage};
39use tracing::{info, warn};
40
/// Configuration handed to every proxied connection.
///
/// Stored as the router state in `run_ws` and cloned into each request by the
/// `State<WsState>` extractor in `ws_upgrade`.
#[derive(Clone, Debug)]
pub struct WsState {
    /// WebSocket URL that each accepted client is relayed to.
    pub upstream_ws_url: String,
    /// Maximum time allowed for establishing the upstream connection.
    pub connect_timeout: Duration,
}
49
50pub async fn run_ws(
53 port: u16,
54 upstream_ws_url: String,
55 connect_timeout: Duration,
56) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
57 let state = WsState {
58 upstream_ws_url,
59 connect_timeout,
60 };
61 let app = Router::new().route("/", get(ws_upgrade)).with_state(state);
62 let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
63 let listener = TcpListener::bind(&addr).await?;
64 info!("tidepool WS proxy listening on ws://{addr}");
65 axum::serve(listener, app).await?;
66 Ok(())
67}
68
69async fn ws_upgrade(ws: WebSocketUpgrade, State(state): State<WsState>) -> Response {
70 ws.on_upgrade(move |socket| proxy_connection(socket, state))
71}
72
73async fn proxy_connection(client_socket: WebSocket, state: WsState) {
74 let upstream_url = state.upstream_ws_url.clone();
75
76 let upstream =
78 match tokio::time::timeout(state.connect_timeout, connect_async(&upstream_url)).await {
79 Ok(Ok((ws_stream, _resp))) => ws_stream,
80 Ok(Err(e)) => {
81 warn!(err = %e, upstream = %upstream_url, "upstream WS connect failed");
82 let _ = close_client(client_socket).await;
83 return;
84 }
85 Err(_) => {
86 warn!(upstream = %upstream_url, "upstream WS connect timed out");
87 let _ = close_client(client_socket).await;
88 return;
89 }
90 };
91
92 let (mut client_sink, mut client_stream) = client_socket.split();
93 let (mut upstream_sink, mut upstream_stream) = upstream.split();
94
95 let pump_a = async move {
97 while let Some(Ok(msg)) = client_stream.next().await {
98 let Some(out) = axum_to_tungstenite(msg) else {
99 continue;
100 };
101 let was_close = matches!(out, TgMessage::Close(_));
102 if upstream_sink.send(out).await.is_err() || was_close {
103 break;
104 }
105 }
106 };
107
108 let pump_b = async move {
110 while let Some(Ok(msg)) = upstream_stream.next().await {
111 let Some(out) = tungstenite_to_axum(msg) else {
112 continue;
113 };
114 let was_close = matches!(out, Message::Close(_));
115 if client_sink.send(out).await.is_err() || was_close {
116 break;
117 }
118 }
119 };
120
121 tokio::select! {
123 () = pump_a => {}
124 () = pump_b => {}
125 }
126}
127
128async fn close_client(mut socket: WebSocket) -> Result<(), axum::Error> {
129 socket.send(Message::Close(None)).await
130}
131
132#[allow(clippy::implicit_clone)]
140fn axum_to_tungstenite(msg: Message) -> Option<TgMessage> {
141 match msg {
142 Message::Text(s) => Some(TgMessage::Text(s.to_string())),
143 Message::Binary(b) => Some(TgMessage::Binary(b.to_vec())),
144 Message::Close(Some(c)) => Some(TgMessage::Close(Some(
145 tokio_tungstenite::tungstenite::protocol::CloseFrame {
146 code: c.code.into(),
147 reason: c.reason.to_string().into(),
148 },
149 ))),
150 Message::Close(None) => Some(TgMessage::Close(None)),
151 Message::Ping(_) | Message::Pong(_) => None,
156 }
157}
158
159#[allow(clippy::implicit_clone)]
160fn tungstenite_to_axum(msg: TgMessage) -> Option<Message> {
161 match msg {
162 TgMessage::Text(s) => Some(Message::Text(s.to_string().into())),
163 TgMessage::Binary(b) => Some(Message::Binary(b.to_vec().into())),
164 TgMessage::Close(Some(c)) => Some(Message::Close(Some(axum::extract::ws::CloseFrame {
165 code: c.code.into(),
166 reason: c.reason.to_string().into(),
167 }))),
168 TgMessage::Close(None) => Some(Message::Close(None)),
169 TgMessage::Ping(_) | TgMessage::Pong(_) | TgMessage::Frame(_) => None,
170 }
171}