Skip to main content

synapse_sdk/
http_server.rs

1//! HTTP/HTTPS server for service RPC endpoints
2//!
3//! Supports both plain HTTP (for development) and mTLS (for production).
4
5use anyhow::Result;
6use bytes::Bytes;
7use http_body_util::{BodyExt, Full};
8use hyper::{Method, Request, Response, StatusCode, server::conn::http1, service::service_fn};
9use std::{net::SocketAddr, path::Path, sync::Arc};
10use synapse_rpc::{
11    ParsedMessage, RpcServer,
12    codec::{ContentType, decode_message},
13    create_health_response, create_rpc_response, parse_message,
14};
15use tokio::net::TcpListener;
16use tokio_rustls::TlsAcceptor;
17use tracing::{debug, error, info, warn};
18
19/// Health check provider for the HTTP server
20pub type HealthProvider = Arc<dyn Fn() -> synapse_proto::HealthResponse + Send + Sync>;
21
22/// HTTP/HTTPS server for handling RPC requests
23pub struct HttpRpcServer {
24    rpc_server: Arc<RpcServer>,
25    bind_addr: SocketAddr,
26    tls_acceptor: Option<TlsAcceptor>,
27    health_provider: Option<HealthProvider>,
28}
29
30impl HttpRpcServer {
31    /// Create a new plain HTTP RPC server (no TLS)
32    pub fn new(rpc_server: Arc<RpcServer>, bind_addr: SocketAddr) -> Self {
33        Self {
34            rpc_server,
35            bind_addr,
36            tls_acceptor: None,
37            health_provider: None,
38        }
39    }
40
41    /// Create a new mTLS RPC server with a pre-built TLS acceptor
42    pub fn with_acceptor(
43        rpc_server: Arc<RpcServer>,
44        bind_addr: SocketAddr,
45        tls_acceptor: TlsAcceptor,
46    ) -> Self {
47        Self {
48            rpc_server,
49            bind_addr,
50            tls_acceptor: Some(tls_acceptor),
51            health_provider: None,
52        }
53    }
54
55    /// Create a new mTLS RPC server from file paths
56    ///
57    /// # Arguments
58    /// * `rpc_server` - The RPC server to handle requests
59    /// * `bind_addr` - Address to bind to
60    /// * `cert_path` - Path to server certificate PEM file
61    /// * `key_path` - Path to server private key PEM file
62    /// * `ca_cert_path` - Path to CA certificate for client verification
63    pub fn with_mtls<P: AsRef<Path>>(
64        rpc_server: Arc<RpcServer>,
65        bind_addr: SocketAddr,
66        cert_path: P,
67        key_path: P,
68        ca_cert_path: P,
69    ) -> Result<Self> {
70        let tls_acceptor =
71            synapse_mtls::server::build_mtls_acceptor(&cert_path, &key_path, &ca_cert_path)?;
72
73        Ok(Self {
74            rpc_server,
75            bind_addr,
76            tls_acceptor: Some(tls_acceptor),
77            health_provider: None,
78        })
79    }
80
81    /// Set a health provider function that returns HealthResponse
82    pub fn with_health_provider(mut self, provider: HealthProvider) -> Self {
83        self.health_provider = Some(provider);
84        self
85    }
86
87    /// Start the server (HTTP or HTTPS depending on configuration)
88    pub async fn serve(self: Arc<Self>) -> Result<()> {
89        let listener = TcpListener::bind(&self.bind_addr).await?;
90
91        if self.tls_acceptor.is_some() {
92            info!("HTTPS (mTLS) RPC server listening on {}", self.bind_addr);
93        } else {
94            info!("HTTP RPC server listening on {}", self.bind_addr);
95            warn!("Running without TLS - use with_mtls() for production");
96        }
97
98        loop {
99            let (stream, peer_addr) = listener.accept().await?;
100            let self_clone = Arc::clone(&self);
101
102            tokio::spawn(async move {
103                if let Some(ref tls_acceptor) = self_clone.tls_acceptor {
104                    // mTLS connection
105                    match tls_acceptor.accept(stream).await {
106                        Ok(tls_stream) => {
107                            let io = hyper_util::rt::TokioIo::new(tls_stream);
108                            let service = service_fn(move |req| {
109                                let self_clone = Arc::clone(&self_clone);
110                                async move { self_clone.handle_request(req).await }
111                            });
112
113                            if let Err(err) =
114                                http1::Builder::new().serve_connection(io, service).await
115                            {
116                                error!("Error serving TLS connection from {}: {}", peer_addr, err);
117                            }
118                        }
119                        Err(e) => {
120                            warn!("TLS handshake failed from {}: {}", peer_addr, e);
121                        }
122                    }
123                } else {
124                    // Plain HTTP connection
125                    let io = hyper_util::rt::TokioIo::new(stream);
126                    let service = service_fn(move |req| {
127                        let self_clone = Arc::clone(&self_clone);
128                        async move { self_clone.handle_request(req).await }
129                    });
130
131                    if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
132                        error!("Error serving connection from {}: {}", peer_addr, err);
133                    }
134                }
135            });
136        }
137    }
138
139    /// Detect content type from header
140    fn detect_content_type(header: Option<&str>) -> ContentType {
141        match header {
142            Some(h) if h.contains("application/json") => ContentType::Json,
143            Some(h) if h.contains("application/protobuf") => ContentType::Protobuf,
144            Some(h) if h.contains("application/x-protobuf") => ContentType::Protobuf,
145            _ => ContentType::Json,
146        }
147    }
148
149    /// Build a simple response, logging on the (unlikely) builder failure
150    fn build_response(status: StatusCode, body: impl Into<Bytes>) -> Result<Response<Full<Bytes>>> {
151        Response::builder()
152            .status(status)
153            .body(Full::new(body.into()))
154            .map_err(|e| anyhow::anyhow!("Failed to build response: {}", e))
155    }
156
157    /// Build a response with headers
158    fn build_response_with_headers(
159        status: StatusCode,
160        headers: &[(&str, &str)],
161        body: impl Into<Bytes>,
162    ) -> Result<Response<Full<Bytes>>> {
163        let mut builder = Response::builder().status(status);
164        for (name, value) in headers {
165            builder = builder.header(*name, *value);
166        }
167        builder
168            .body(Full::new(body.into()))
169            .map_err(|e| anyhow::anyhow!("Failed to build response: {}", e))
170    }
171
172    /// Handle an HTTP request (RPC or health pull)
173    async fn handle_request(
174        &self,
175        req: Request<hyper::body::Incoming>,
176    ) -> Result<Response<Full<Bytes>>> {
177        // Only accept POST
178        if req.method() != Method::POST {
179            return Self::build_response_with_headers(
180                StatusCode::METHOD_NOT_ALLOWED,
181                &[("Allow", "POST")],
182                "Method not allowed",
183            );
184        }
185
186        // Get content type
187        let content_type_header = req
188            .headers()
189            .get("content-type")
190            .and_then(|v| v.to_str().ok());
191
192        let content_type = Self::detect_content_type(content_type_header);
193
194        debug!("Received request: content_type={:?}", content_type_header);
195
196        // Read body
197        let body_bytes = req
198            .into_body()
199            .collect()
200            .await
201            .map_err(|e| anyhow::anyhow!("Failed to read body: {}", e))?
202            .to_bytes();
203
204        // Decode the SynapseMessage
205        let synapse_msg = match decode_message(&body_bytes, content_type) {
206            Ok(msg) => msg,
207            Err(e) => {
208                error!("Failed to decode request: {}", e);
209                return Self::build_response(
210                    StatusCode::BAD_REQUEST,
211                    format!("Invalid request: {}", e),
212                );
213            }
214        };
215
216        // Route based on message type
217        match parse_message(synapse_msg) {
218            ParsedMessage::RpcRequest {
219                request_id,
220                request,
221            } => {
222                self.handle_rpc_request(request_id, request, content_type)
223                    .await
224            }
225            ParsedMessage::HealthPull { request_id } => {
226                self.handle_health_pull(request_id, content_type).await
227            }
228            other => {
229                warn!(
230                    "Unexpected message type: {:?}",
231                    std::mem::discriminant(&other)
232                );
233                Self::build_response(
234                    StatusCode::BAD_REQUEST,
235                    "Expected RPC_REQUEST or HEALTH_PULL message",
236                )
237            }
238        }
239    }
240
241    /// Handle an RPC request
242    async fn handle_rpc_request(
243        &self,
244        request_id: Bytes,
245        rpc_request: synapse_rpc::RpcRequest,
246        content_type: ContentType,
247    ) -> Result<Response<Full<Bytes>>> {
248        debug!(
249            "Handling RPC: interface={}, method={}",
250            rpc_request.interface_id, rpc_request.method_id
251        );
252
253        // Handle through RPC server
254        let rpc_response = self.rpc_server.handle_request(rpc_request).await;
255
256        debug!("RPC response: status={}", rpc_response.status);
257
258        // Encode response
259        let response_body = match create_rpc_response(&request_id, rpc_response, content_type) {
260            Ok(body) => body,
261            Err(e) => {
262                error!("Failed to encode response: {}", e);
263                return Self::build_response(
264                    StatusCode::INTERNAL_SERVER_ERROR,
265                    "Failed to encode response",
266                );
267            }
268        };
269
270        Self::build_response_with_headers(
271            StatusCode::OK,
272            &[("Content-Type", content_type.mime_type())],
273            response_body,
274        )
275    }
276
277    /// Handle a health pull request (Synapse protocol)
278    async fn handle_health_pull(
279        &self,
280        request_id: Bytes,
281        content_type: ContentType,
282    ) -> Result<Response<Full<Bytes>>> {
283        debug!("Handling health pull request");
284
285        let health_response = self.get_health_response();
286
287        let response_body = match create_health_response(&request_id, health_response, content_type)
288        {
289            Ok(body) => body,
290            Err(e) => {
291                error!("Failed to encode health response: {}", e);
292                return Self::build_response(
293                    StatusCode::INTERNAL_SERVER_ERROR,
294                    "Failed to encode health response",
295                );
296            }
297        };
298
299        Self::build_response_with_headers(
300            StatusCode::OK,
301            &[("Content-Type", content_type.mime_type())],
302            response_body,
303        )
304    }
305
306    /// Get health response from provider or return default
307    fn get_health_response(&self) -> synapse_proto::HealthResponse {
308        if let Some(ref provider) = self.health_provider {
309            provider()
310        } else {
311            // Default healthy response
312            synapse_proto::HealthResponse {
313                instance_id: Bytes::new(),
314                status: synapse_proto::HealthStatus::Healthy as i32,
315                version: env!("CARGO_PKG_VERSION").to_string(),
316                uptime_ms: 0,
317                message: String::new(),
318            }
319        }
320    }
321}