//! HTTP(S) plane of the siphon server (`siphon_server/http_plane.rs`): accepts
//! public connections and forwards each request to the client tunnel selected
//! by the request's subdomain.

use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio_rustls::TlsAcceptor;

use siphon_protocol::ServerMessage;

use crate::router::Router;
use crate::state::ResponseRegistry;

25pub struct HttpPlane {
27 router: Arc<Router>,
28 base_domain: String,
29 stream_id_counter: AtomicU64,
30 response_registry: ResponseRegistry,
32 tls_acceptor: Option<TlsAcceptor>,
34}
35
36impl HttpPlane {
37 pub fn new(
38 router: Arc<Router>,
39 base_domain: String,
40 response_registry: ResponseRegistry,
41 tls_acceptor: Option<TlsAcceptor>,
42 ) -> Arc<Self> {
43 Arc::new(Self {
44 router,
45 base_domain,
46 stream_id_counter: AtomicU64::new(1),
47 response_registry,
48 tls_acceptor,
49 })
50 }
51
52 fn next_stream_id(&self) -> u64 {
53 self.stream_id_counter.fetch_add(1, Ordering::Relaxed)
54 }
55
56 async fn serve_connection<S>(self: Arc<Self>, stream: S, peer_addr: SocketAddr)
58 where
59 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
60 {
61 let io = TokioIo::new(stream);
62
63 let service = service_fn(move |req| {
64 let this = self.clone();
65 async move { this.handle_request(req).await }
66 });
67
68 if let Err(e) = http1::Builder::new().serve_connection(io, service).await {
69 tracing::debug!("HTTP connection error from {}: {}", peer_addr, e);
70 }
71 }
72
73 pub async fn run(self: Arc<Self>, addr: SocketAddr) -> Result<()> {
75 let listener = TcpListener::bind(addr).await?;
76
77 if self.tls_acceptor.is_some() {
78 tracing::info!("HTTPS plane listening on {}", addr);
79 } else {
80 tracing::info!("HTTP plane listening on {}", addr);
81 }
82
83 self.run_with_listener(listener).await
84 }
85
86 pub async fn run_with_listener(self: Arc<Self>, listener: TcpListener) -> Result<()> {
91 loop {
92 let (stream, peer_addr) = listener.accept().await?;
93 tracing::debug!("HTTP connection from {}", peer_addr);
94 let this = self.clone();
95
96 tokio::spawn(async move {
97 if let Some(ref acceptor) = this.tls_acceptor {
98 match acceptor.accept(stream).await {
100 Ok(tls_stream) => {
101 this.serve_connection(tls_stream, peer_addr).await;
102 }
103 Err(e) => {
104 tracing::warn!("TLS handshake failed from {}: {}", peer_addr, e);
105 }
106 }
107 } else {
108 this.serve_connection(stream, peer_addr).await;
110 }
111 });
112 }
113 }
114
115 async fn handle_request(
116 self: Arc<Self>,
117 req: Request<Incoming>,
118 ) -> Result<Response<Full<Bytes>>, Infallible> {
119 tracing::debug!(
120 "HTTP request: {} {} (Host: {:?})",
121 req.method(),
122 req.uri(),
123 req.headers().get("host")
124 );
125
126 let subdomain = match self.extract_subdomain(&req) {
128 Some(s) => s,
129 None => {
130 tracing::warn!("Request without valid subdomain");
131 return Ok(Response::builder()
132 .status(StatusCode::BAD_REQUEST)
133 .body(Full::new(Bytes::from("Invalid or missing subdomain")))
134 .unwrap());
135 }
136 };
137
138 tracing::debug!("Forwarding to tunnel: {}", subdomain);
139
140 let sender = match self.router.get_sender(&subdomain) {
142 Some(s) => s,
143 None => {
144 tracing::warn!("No tunnel for subdomain: {}", subdomain);
145 return Ok(Response::builder()
146 .status(StatusCode::NOT_FOUND)
147 .body(Full::new(Bytes::from(format!(
148 "Tunnel not found for: {}",
149 subdomain
150 ))))
151 .unwrap());
152 }
153 };
154
155 let stream_id = self.next_stream_id();
157
158 let method = req.method().to_string();
160 let uri = req.uri().to_string();
161
162 let headers: Vec<(String, String)> = req
163 .headers()
164 .iter()
165 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
166 .collect();
167
168 let body = match req.into_body().collect().await {
170 Ok(collected) => collected.to_bytes().to_vec(),
171 Err(e) => {
172 tracing::error!("Failed to read request body: {}", e);
173 return Ok(Response::builder()
174 .status(StatusCode::INTERNAL_SERVER_ERROR)
175 .body(Full::new(Bytes::from("Failed to read request body")))
176 .unwrap());
177 }
178 };
179
180 let (response_tx, response_rx) = oneshot::channel();
182
183 self.response_registry.insert(stream_id, response_tx);
185
186 let msg = ServerMessage::HttpRequest {
188 stream_id,
189 method,
190 uri,
191 headers,
192 body,
193 };
194
195 if let Err(e) = sender.send(msg).await {
196 tracing::error!("Failed to send request to tunnel: {}", e);
197 self.response_registry.remove(&stream_id);
199
200 return Ok(Response::builder()
201 .status(StatusCode::BAD_GATEWAY)
202 .body(Full::new(Bytes::from("Tunnel connection lost")))
203 .unwrap());
204 }
205
206 let timeout = Duration::from_secs(30);
208 match tokio::time::timeout(timeout, response_rx).await {
209 Ok(Ok(response_data)) => {
210 let mut builder = Response::builder().status(response_data.status);
212
213 for (name, value) in response_data.headers {
214 builder = builder.header(name, value);
215 }
216
217 Ok(builder
218 .body(Full::new(Bytes::from(response_data.body)))
219 .unwrap())
220 }
221 Ok(Err(_)) => {
222 tracing::error!("Tunnel disconnected while waiting for response");
224 Ok(Response::builder()
225 .status(StatusCode::BAD_GATEWAY)
226 .body(Full::new(Bytes::from("Tunnel disconnected")))
227 .unwrap())
228 }
229 Err(_) => {
230 tracing::error!("Timeout waiting for tunnel response");
232 self.response_registry.remove(&stream_id);
234
235 Ok(Response::builder()
236 .status(StatusCode::GATEWAY_TIMEOUT)
237 .body(Full::new(Bytes::from("Tunnel response timeout")))
238 .unwrap())
239 }
240 }
241 }
242
243 fn extract_subdomain(&self, req: &Request<Incoming>) -> Option<String> {
245 let host = req.headers().get("host")?.to_str().ok()?;
246
247 let host = host.split(':').next()?;
249
250 if !host.ends_with(&self.base_domain) {
252 return None;
253 }
254
255 let subdomain_part = host.strip_suffix(&format!(".{}", self.base_domain))?;
257
258 Some(subdomain_part.split('.').next()?.to_string())
260 }
261}