use anyhow::Result;
use bytes::Bytes;
use http_body_util::{BodyExt, Full, LengthLimitError, Limited};
use hyper::{Method, Request, Response, StatusCode, server::conn::http1, service::service_fn};
use std::{net::SocketAddr, path::Path, sync::Arc};
use synapse_rpc::{
    ParsedMessage, RpcServer,
    codec::{ContentType, decode_message},
    create_health_response, create_rpc_response, parse_message,
};
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;
use tracing::{debug, error, info, warn};
18
19pub type HealthProvider = Arc<dyn Fn() -> synapse_proto::HealthResponse + Send + Sync>;
21
22pub struct HttpRpcServer {
24 rpc_server: Arc<RpcServer>,
25 bind_addr: SocketAddr,
26 tls_acceptor: Option<TlsAcceptor>,
27 health_provider: Option<HealthProvider>,
28}
29
30impl HttpRpcServer {
31 pub fn new(rpc_server: Arc<RpcServer>, bind_addr: SocketAddr) -> Self {
33 Self {
34 rpc_server,
35 bind_addr,
36 tls_acceptor: None,
37 health_provider: None,
38 }
39 }
40
41 pub fn with_acceptor(
43 rpc_server: Arc<RpcServer>,
44 bind_addr: SocketAddr,
45 tls_acceptor: TlsAcceptor,
46 ) -> Self {
47 Self {
48 rpc_server,
49 bind_addr,
50 tls_acceptor: Some(tls_acceptor),
51 health_provider: None,
52 }
53 }
54
55 pub fn with_mtls<P: AsRef<Path>>(
64 rpc_server: Arc<RpcServer>,
65 bind_addr: SocketAddr,
66 cert_path: P,
67 key_path: P,
68 ca_cert_path: P,
69 ) -> Result<Self> {
70 let tls_acceptor =
71 synapse_mtls::server::build_mtls_acceptor(&cert_path, &key_path, &ca_cert_path)?;
72
73 Ok(Self {
74 rpc_server,
75 bind_addr,
76 tls_acceptor: Some(tls_acceptor),
77 health_provider: None,
78 })
79 }
80
81 pub fn with_health_provider(mut self, provider: HealthProvider) -> Self {
83 self.health_provider = Some(provider);
84 self
85 }
86
87 pub async fn serve(self: Arc<Self>) -> Result<()> {
89 let listener = TcpListener::bind(&self.bind_addr).await?;
90
91 if self.tls_acceptor.is_some() {
92 info!("HTTPS (mTLS) RPC server listening on {}", self.bind_addr);
93 } else {
94 info!("HTTP RPC server listening on {}", self.bind_addr);
95 warn!("Running without TLS - use with_mtls() for production");
96 }
97
98 loop {
99 let (stream, peer_addr) = listener.accept().await?;
100 let self_clone = Arc::clone(&self);
101
102 tokio::spawn(async move {
103 if let Some(ref tls_acceptor) = self_clone.tls_acceptor {
104 match tls_acceptor.accept(stream).await {
106 Ok(tls_stream) => {
107 let io = hyper_util::rt::TokioIo::new(tls_stream);
108 let service = service_fn(move |req| {
109 let self_clone = Arc::clone(&self_clone);
110 async move { self_clone.handle_request(req).await }
111 });
112
113 if let Err(err) =
114 http1::Builder::new().serve_connection(io, service).await
115 {
116 error!("Error serving TLS connection from {}: {}", peer_addr, err);
117 }
118 }
119 Err(e) => {
120 warn!("TLS handshake failed from {}: {}", peer_addr, e);
121 }
122 }
123 } else {
124 let io = hyper_util::rt::TokioIo::new(stream);
126 let service = service_fn(move |req| {
127 let self_clone = Arc::clone(&self_clone);
128 async move { self_clone.handle_request(req).await }
129 });
130
131 if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
132 error!("Error serving connection from {}: {}", peer_addr, err);
133 }
134 }
135 });
136 }
137 }
138
139 fn detect_content_type(header: Option<&str>) -> ContentType {
141 match header {
142 Some(h) if h.contains("application/json") => ContentType::Json,
143 Some(h) if h.contains("application/protobuf") => ContentType::Protobuf,
144 Some(h) if h.contains("application/x-protobuf") => ContentType::Protobuf,
145 _ => ContentType::Json,
146 }
147 }
148
149 fn build_response(status: StatusCode, body: impl Into<Bytes>) -> Result<Response<Full<Bytes>>> {
151 Response::builder()
152 .status(status)
153 .body(Full::new(body.into()))
154 .map_err(|e| anyhow::anyhow!("Failed to build response: {}", e))
155 }
156
157 fn build_response_with_headers(
159 status: StatusCode,
160 headers: &[(&str, &str)],
161 body: impl Into<Bytes>,
162 ) -> Result<Response<Full<Bytes>>> {
163 let mut builder = Response::builder().status(status);
164 for (name, value) in headers {
165 builder = builder.header(*name, *value);
166 }
167 builder
168 .body(Full::new(body.into()))
169 .map_err(|e| anyhow::anyhow!("Failed to build response: {}", e))
170 }
171
172 async fn handle_request(
174 &self,
175 req: Request<hyper::body::Incoming>,
176 ) -> Result<Response<Full<Bytes>>> {
177 if req.method() != Method::POST {
179 return Self::build_response_with_headers(
180 StatusCode::METHOD_NOT_ALLOWED,
181 &[("Allow", "POST")],
182 "Method not allowed",
183 );
184 }
185
186 let content_type_header = req
188 .headers()
189 .get("content-type")
190 .and_then(|v| v.to_str().ok());
191
192 let content_type = Self::detect_content_type(content_type_header);
193
194 debug!("Received request: content_type={:?}", content_type_header);
195
196 let body_bytes = req
198 .into_body()
199 .collect()
200 .await
201 .map_err(|e| anyhow::anyhow!("Failed to read body: {}", e))?
202 .to_bytes();
203
204 let synapse_msg = match decode_message(&body_bytes, content_type) {
206 Ok(msg) => msg,
207 Err(e) => {
208 error!("Failed to decode request: {}", e);
209 return Self::build_response(
210 StatusCode::BAD_REQUEST,
211 format!("Invalid request: {}", e),
212 );
213 }
214 };
215
216 match parse_message(synapse_msg) {
218 ParsedMessage::RpcRequest {
219 request_id,
220 request,
221 } => {
222 self.handle_rpc_request(request_id, request, content_type)
223 .await
224 }
225 ParsedMessage::HealthPull { request_id } => {
226 self.handle_health_pull(request_id, content_type).await
227 }
228 other => {
229 warn!(
230 "Unexpected message type: {:?}",
231 std::mem::discriminant(&other)
232 );
233 Self::build_response(
234 StatusCode::BAD_REQUEST,
235 "Expected RPC_REQUEST or HEALTH_PULL message",
236 )
237 }
238 }
239 }
240
241 async fn handle_rpc_request(
243 &self,
244 request_id: Bytes,
245 rpc_request: synapse_rpc::RpcRequest,
246 content_type: ContentType,
247 ) -> Result<Response<Full<Bytes>>> {
248 debug!(
249 "Handling RPC: interface={}, method={}",
250 rpc_request.interface_id, rpc_request.method_id
251 );
252
253 let rpc_response = self.rpc_server.handle_request(rpc_request).await;
255
256 debug!("RPC response: status={}", rpc_response.status);
257
258 let response_body = match create_rpc_response(&request_id, rpc_response, content_type) {
260 Ok(body) => body,
261 Err(e) => {
262 error!("Failed to encode response: {}", e);
263 return Self::build_response(
264 StatusCode::INTERNAL_SERVER_ERROR,
265 "Failed to encode response",
266 );
267 }
268 };
269
270 Self::build_response_with_headers(
271 StatusCode::OK,
272 &[("Content-Type", content_type.mime_type())],
273 response_body,
274 )
275 }
276
277 async fn handle_health_pull(
279 &self,
280 request_id: Bytes,
281 content_type: ContentType,
282 ) -> Result<Response<Full<Bytes>>> {
283 debug!("Handling health pull request");
284
285 let health_response = self.get_health_response();
286
287 let response_body = match create_health_response(&request_id, health_response, content_type)
288 {
289 Ok(body) => body,
290 Err(e) => {
291 error!("Failed to encode health response: {}", e);
292 return Self::build_response(
293 StatusCode::INTERNAL_SERVER_ERROR,
294 "Failed to encode health response",
295 );
296 }
297 };
298
299 Self::build_response_with_headers(
300 StatusCode::OK,
301 &[("Content-Type", content_type.mime_type())],
302 response_body,
303 )
304 }
305
306 fn get_health_response(&self) -> synapse_proto::HealthResponse {
308 if let Some(ref provider) = self.health_provider {
309 provider()
310 } else {
311 synapse_proto::HealthResponse {
313 instance_id: Bytes::new(),
314 status: synapse_proto::HealthStatus::Healthy as i32,
315 version: env!("CARGO_PKG_VERSION").to_string(),
316 uptime_ms: 0,
317 message: String::new(),
318 }
319 }
320 }
321}