// siphon_server/http_plane.rs

use std::convert::Infallible;
use std::io;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio_rustls::TlsAcceptor;

use siphon_protocol::ServerMessage;

use crate::router::Router;
use crate::state::ResponseRegistry;

25/// HTTP data plane that receives traffic from Cloudflare
26pub struct HttpPlane {
27    router: Arc<Router>,
28    base_domain: String,
29    stream_id_counter: AtomicU64,
30    /// Shared registry for pending responses
31    response_registry: ResponseRegistry,
32    /// Optional TLS acceptor for HTTPS mode
33    tls_acceptor: Option<TlsAcceptor>,
34}
35
36impl HttpPlane {
37    pub fn new(
38        router: Arc<Router>,
39        base_domain: String,
40        response_registry: ResponseRegistry,
41        tls_acceptor: Option<TlsAcceptor>,
42    ) -> Arc<Self> {
43        Arc::new(Self {
44            router,
45            base_domain,
46            stream_id_counter: AtomicU64::new(1),
47            response_registry,
48            tls_acceptor,
49        })
50    }
51
52    fn next_stream_id(&self) -> u64 {
53        self.stream_id_counter.fetch_add(1, Ordering::Relaxed)
54    }
55
56    /// Serve an HTTP connection on any AsyncRead + AsyncWrite stream
57    async fn serve_connection<S>(self: Arc<Self>, stream: S, peer_addr: SocketAddr)
58    where
59        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
60    {
61        let io = TokioIo::new(stream);
62
63        let service = service_fn(move |req| {
64            let this = self.clone();
65            async move { this.handle_request(req).await }
66        });
67
68        if let Err(e) = http1::Builder::new().serve_connection(io, service).await {
69            tracing::debug!("HTTP connection error from {}: {}", peer_addr, e);
70        }
71    }
72
73    /// Start listening for HTTP/HTTPS traffic from Cloudflare
74    pub async fn run(self: Arc<Self>, addr: SocketAddr) -> Result<()> {
75        let listener = TcpListener::bind(addr).await?;
76
77        if self.tls_acceptor.is_some() {
78            tracing::info!("HTTPS plane listening on {}", addr);
79        } else {
80            tracing::info!("HTTP plane listening on {}", addr);
81        }
82
83        self.run_with_listener(listener).await
84    }
85
86    /// Start accepting connections from a pre-bound listener
87    ///
88    /// This is useful for testing where the caller wants to bind to an
89    /// ephemeral port and get the actual address before starting the server.
90    pub async fn run_with_listener(self: Arc<Self>, listener: TcpListener) -> Result<()> {
91        loop {
92            let (stream, peer_addr) = listener.accept().await?;
93            tracing::debug!("HTTP connection from {}", peer_addr);
94            let this = self.clone();
95
96            tokio::spawn(async move {
97                if let Some(ref acceptor) = this.tls_acceptor {
98                    // TLS mode
99                    match acceptor.accept(stream).await {
100                        Ok(tls_stream) => {
101                            this.serve_connection(tls_stream, peer_addr).await;
102                        }
103                        Err(e) => {
104                            tracing::warn!("TLS handshake failed from {}: {}", peer_addr, e);
105                        }
106                    }
107                } else {
108                    // Plain HTTP mode
109                    this.serve_connection(stream, peer_addr).await;
110                }
111            });
112        }
113    }
114
115    async fn handle_request(
116        self: Arc<Self>,
117        req: Request<Incoming>,
118    ) -> Result<Response<Full<Bytes>>, Infallible> {
119        tracing::debug!(
120            "HTTP request: {} {} (Host: {:?})",
121            req.method(),
122            req.uri(),
123            req.headers().get("host")
124        );
125
126        // Extract subdomain from Host header
127        let subdomain = match self.extract_subdomain(&req) {
128            Some(s) => s,
129            None => {
130                tracing::warn!("Request without valid subdomain");
131                return Ok(Response::builder()
132                    .status(StatusCode::BAD_REQUEST)
133                    .body(Full::new(Bytes::from("Invalid or missing subdomain")))
134                    .unwrap());
135            }
136        };
137
138        tracing::debug!("Forwarding to tunnel: {}", subdomain);
139
140        // Find the tunnel for this subdomain
141        let sender = match self.router.get_sender(&subdomain) {
142            Some(s) => s,
143            None => {
144                tracing::warn!("No tunnel for subdomain: {}", subdomain);
145                return Ok(Response::builder()
146                    .status(StatusCode::NOT_FOUND)
147                    .body(Full::new(Bytes::from(format!(
148                        "Tunnel not found for: {}",
149                        subdomain
150                    ))))
151                    .unwrap());
152            }
153        };
154
155        // Generate stream ID
156        let stream_id = self.next_stream_id();
157
158        // Convert request to protocol message
159        let method = req.method().to_string();
160        let uri = req.uri().to_string();
161
162        let headers: Vec<(String, String)> = req
163            .headers()
164            .iter()
165            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
166            .collect();
167
168        // Collect body
169        let body = match req.into_body().collect().await {
170            Ok(collected) => collected.to_bytes().to_vec(),
171            Err(e) => {
172                tracing::error!("Failed to read request body: {}", e);
173                return Ok(Response::builder()
174                    .status(StatusCode::INTERNAL_SERVER_ERROR)
175                    .body(Full::new(Bytes::from("Failed to read request body")))
176                    .unwrap());
177            }
178        };
179
180        // Create response channel
181        let (response_tx, response_rx) = oneshot::channel();
182
183        // Register pending response in shared registry
184        self.response_registry.insert(stream_id, response_tx);
185
186        // Send request to tunnel
187        let msg = ServerMessage::HttpRequest {
188            stream_id,
189            method,
190            uri,
191            headers,
192            body,
193        };
194
195        if let Err(e) = sender.send(msg).await {
196            tracing::error!("Failed to send request to tunnel: {}", e);
197            // Clean up pending response
198            self.response_registry.remove(&stream_id);
199
200            return Ok(Response::builder()
201                .status(StatusCode::BAD_GATEWAY)
202                .body(Full::new(Bytes::from("Tunnel connection lost")))
203                .unwrap());
204        }
205
206        // Wait for response with timeout
207        let timeout = Duration::from_secs(30);
208        match tokio::time::timeout(timeout, response_rx).await {
209            Ok(Ok(response_data)) => {
210                // Build HTTP response
211                let mut builder = Response::builder().status(response_data.status);
212
213                for (name, value) in response_data.headers {
214                    builder = builder.header(name, value);
215                }
216
217                Ok(builder
218                    .body(Full::new(Bytes::from(response_data.body)))
219                    .unwrap())
220            }
221            Ok(Err(_)) => {
222                // Channel closed (tunnel disconnected)
223                tracing::error!("Tunnel disconnected while waiting for response");
224                Ok(Response::builder()
225                    .status(StatusCode::BAD_GATEWAY)
226                    .body(Full::new(Bytes::from("Tunnel disconnected")))
227                    .unwrap())
228            }
229            Err(_) => {
230                // Timeout
231                tracing::error!("Timeout waiting for tunnel response");
232                // Clean up pending response
233                self.response_registry.remove(&stream_id);
234
235                Ok(Response::builder()
236                    .status(StatusCode::GATEWAY_TIMEOUT)
237                    .body(Full::new(Bytes::from("Tunnel response timeout")))
238                    .unwrap())
239            }
240        }
241    }
242
243    /// Extract subdomain from Host header
244    fn extract_subdomain(&self, req: &Request<Incoming>) -> Option<String> {
245        let host = req.headers().get("host")?.to_str().ok()?;
246
247        // Remove port if present
248        let host = host.split(':').next()?;
249
250        // Check if it ends with our base domain
251        if !host.ends_with(&self.base_domain) {
252            return None;
253        }
254
255        // Extract subdomain
256        let subdomain_part = host.strip_suffix(&format!(".{}", self.base_domain))?;
257
258        // Return only the first part (in case of multi-level subdomain)
259        Some(subdomain_part.split('.').next()?.to_string())
260    }
261}