1use anyhow::Result;
11use axum::{
12 body::Body,
13 extract::{
14 ws::{Message, WebSocket, WebSocketUpgrade},
15 State,
16 },
17 http::{Request, StatusCode},
18 response::{IntoResponse, Response},
19 routing::get,
20 Router,
21};
22use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
23use bytes::Bytes;
24use futures_util::{SinkExt, StreamExt};
25use parking_lot::RwLock;
26use serde::{Deserialize, Serialize};
27use std::{collections::HashMap, sync::Arc};
28use tokio::sync::mpsc;
29use tokio_stream::wrappers::ReceiverStream;
30
31#[derive(Debug, Deserialize)]
37#[serde(tag = "type", rename_all = "snake_case")]
38enum ClientFrame {
39 Register { subdomain: String, token: String },
40 ResHead { id: String, status: u16, headers: HashMap<String, String> },
41 ResBody { id: String, data: String }, ResEnd { id: String },
43 ResErr { id: String, message: String },
44}
45
46#[derive(Debug, Serialize)]
48#[serde(tag = "type", rename_all = "snake_case")]
49enum RelayFrame<'a> {
50 Ack { subdomain: &'a str },
51 Deny { reason: &'a str },
52 Req {
53 id: &'a str,
54 method: &'a str,
55 path: &'a str,
56 headers: &'a HashMap<String, String>,
57 body: &'a str, },
59}
60
61#[derive(Clone)]
66pub struct RelayState {
67 tunnels: Arc<RwLock<HashMap<String, TunnelHandle>>>,
68 allowed_token: Arc<String>,
69 }
72
73#[derive(Clone)]
74struct TunnelHandle {
75 tx: mpsc::Sender<TunnelRequest>,
76}
77
78struct TunnelRequest {
79 id: String,
80 method: String,
81 path: String,
82 headers: HashMap<String, String>,
83 body: Bytes,
84 res_tx: mpsc::Sender<ResponseChunk>,
85}
86
87#[derive(Debug)]
88enum ResponseChunk {
89 Head { status: u16, headers: HashMap<String, String> },
90 Body (Bytes),
91 End,
92 Err (String),
93}
94
95pub async fn run_relay_server(port: u16, token: String) -> Result<()> {
100 let state = RelayState {
101 tunnels: Arc::new(RwLock::new(HashMap::new())),
102 allowed_token: Arc::new(token),
103 };
104
105 let app = Router::new()
106 .route("/tunnel", get(ws_handler))
107 .fallback(proxy_handler)
108 .with_state(state);
109
110 let addr = format!("0.0.0.0:{port}");
111 println!(" ◆ shunt relay listening on {addr}");
112 let listener = tokio::net::TcpListener::bind(&addr).await?;
113 axum::serve(listener, app).await?;
114 Ok(())
115}
116
117async fn ws_handler(
122 ws: WebSocketUpgrade,
123 State(state): State<RelayState>,
124) -> Response {
125 ws.on_upgrade(move |socket| handle_tunnel(socket, state))
126}
127
128async fn handle_tunnel(socket: WebSocket, state: RelayState) {
129 let (mut sink, mut stream) = socket.split();
130
131 let subdomain = loop {
133 match stream.next().await {
134 Some(Ok(Message::Text(text))) => {
135 match serde_json::from_str::<ClientFrame>(&text) {
136 Ok(ClientFrame::Register { subdomain, token }) => {
137 if token != *state.allowed_token {
138 let _ = sink.send(Message::Text(
139 serde_json::to_string(&RelayFrame::Deny { reason: "invalid token" }).unwrap()
140 )).await;
141 return;
142 }
143 let _ = sink.send(Message::Text(
144 serde_json::to_string(&RelayFrame::Ack { subdomain: &subdomain }).unwrap()
145 )).await;
146 break subdomain;
147 }
148 _ => { return; } }
150 }
151 _ => return,
152 }
153 };
154
155 let (tunnel_tx, mut tunnel_rx) = mpsc::channel::<TunnelRequest>(16);
157 state.tunnels.write().insert(subdomain.clone(), TunnelHandle { tx: tunnel_tx });
158 println!(" ◆ tunnel registered: {subdomain}");
159
160 let pending: Arc<RwLock<HashMap<String, mpsc::Sender<ResponseChunk>>>> =
163 Arc::new(RwLock::new(HashMap::new()));
164
165 let (ws_tx, mut ws_rx) = mpsc::channel::<Message>(64);
167
168 let ws_tx_clone = ws_tx.clone();
170 tokio::spawn(async move {
171 while let Some(msg) = ws_rx.recv().await {
172 if sink.send(msg).await.is_err() { break; }
173 }
174 });
175
176 let ws_tx2 = ws_tx_clone.clone();
178 let pending2 = pending.clone();
179 tokio::spawn(async move {
180 while let Some(req) = tunnel_rx.recv().await {
181 pending2.write().insert(req.id.clone(), req.res_tx);
182 let body_b64 = B64.encode(&req.body);
183 let frame = RelayFrame::Req {
184 id: &req.id,
185 method: &req.method,
186 path: &req.path,
187 headers: &req.headers,
188 body: &body_b64,
189 };
190 let text = serde_json::to_string(&frame).unwrap();
191 if ws_tx2.send(Message::Text(text)).await.is_err() { break; }
192 }
193 });
194
195 while let Some(Ok(msg)) = stream.next().await {
197 let text = match msg {
198 Message::Text(t) => t,
199 Message::Close(_) => break,
200 _ => continue,
201 };
202 let frame = match serde_json::from_str::<ClientFrame>(&text) {
203 Ok(f) => f,
204 Err(_) => continue,
205 };
206 match frame {
207 ClientFrame::ResHead { id, status, headers } => {
208 let tx = pending.read().get(&id).cloned();
210 if let Some(tx) = tx {
211 let _ = tx.send(ResponseChunk::Head { status, headers }).await;
212 }
213 }
214 ClientFrame::ResBody { id, data } => {
215 let tx = pending.read().get(&id).cloned();
216 if let Some(tx) = tx {
217 if let Ok(bytes) = B64.decode(&data) {
218 let _ = tx.send(ResponseChunk::Body(Bytes::from(bytes))).await;
219 }
220 }
221 }
222 ClientFrame::ResEnd { id } => {
223 let tx = pending.write().remove(&id);
224 if let Some(tx) = tx {
225 let _ = tx.send(ResponseChunk::End).await;
226 }
227 }
228 ClientFrame::ResErr { id, message } => {
229 let tx = pending.write().remove(&id);
230 if let Some(tx) = tx {
231 let _ = tx.send(ResponseChunk::Err(message)).await;
232 }
233 }
234 ClientFrame::Register { .. } => {} }
236 }
237
238 state.tunnels.write().remove(&subdomain);
240 println!(" · tunnel disconnected: {subdomain}");
241}
242
243async fn proxy_handler(
248 State(state): State<RelayState>,
249 req: Request<Body>,
250) -> Response {
251 let subdomain = match extract_subdomain(req.headers()) {
253 Some(s) => s,
254 None => return (StatusCode::BAD_REQUEST, "missing Host header").into_response(),
255 };
256
257 let handle = state.tunnels.read().get(&subdomain).cloned();
259 let handle = match handle {
260 Some(h) => h,
261 None => return (
262 StatusCode::BAD_GATEWAY,
263 format!("no tunnel connected for '{subdomain}'"),
264 ).into_response(),
265 };
266
267 let id = uuid::Uuid::new_v4().to_string();
269 let method = req.method().to_string();
270 let path = req.uri().path_and_query()
271 .map(|p| p.as_str().to_owned())
272 .unwrap_or_else(|| "/".to_owned());
273 let headers: HashMap<String, String> = req.headers().iter()
274 .filter_map(|(k, v)| {
275 let key = k.as_str().to_lowercase();
276 if matches!(key.as_str(), "host" | "connection" | "transfer-encoding" | "upgrade") {
278 return None;
279 }
280 v.to_str().ok().map(|v| (key, v.to_owned()))
281 })
282 .collect();
283 let body = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
284 Ok(b) => b,
285 Err(_) => return (StatusCode::BAD_REQUEST, "failed to read body").into_response(),
286 };
287
288 let (res_tx, res_rx) = mpsc::channel::<ResponseChunk>(32);
290 let tunnel_req = TunnelRequest { id, method, path, headers, body, res_tx };
291 if handle.tx.send(tunnel_req).await.is_err() {
292 return (StatusCode::BAD_GATEWAY, "tunnel send failed").into_response();
293 }
294
295 let mut rx = res_rx;
297 let (status, res_headers) = match rx.recv().await {
298 Some(ResponseChunk::Head { status, headers }) => (status, headers),
299 Some(ResponseChunk::Err(e)) => return (StatusCode::BAD_GATEWAY, e).into_response(),
300 _ => return (StatusCode::BAD_GATEWAY, "no response from tunnel").into_response(),
301 };
302
303 let stream = ReceiverStream::new(rx).filter_map(|chunk| async move {
305 match chunk {
306 ResponseChunk::Body(b) => Some(Ok::<_, std::convert::Infallible>(b)),
307 ResponseChunk::End | ResponseChunk::Head { .. } | ResponseChunk::Err(_) => None,
308 }
309 });
310
311 let mut builder = Response::builder()
312 .status(status);
313 for (k, v) in &res_headers {
314 builder = builder.header(k, v);
315 }
316 builder.body(Body::from_stream(stream)).unwrap_or_else(|_| {
317 (StatusCode::INTERNAL_SERVER_ERROR, "response build failed").into_response()
318 })
319}
320
321fn extract_subdomain(headers: &axum::http::HeaderMap) -> Option<String> {
326 let host = headers.get("host")?.to_str().ok()?;
327 let host = host.split(':').next()?;
330 let subdomain = host.split('.').next()?;
331 if subdomain.is_empty() { return None; }
332 Some(subdomain.to_owned())
333}