Skip to main content

runway_middleware/
lib.rs

1use axum::{
2    Json, Router,
3    extract::Request,
4    middleware::{self, Next},
5    response::{IntoResponse, Response},
6    routing::get,
7};
8use serde_json::json;
9use tower::ServiceBuilder;
10use tower_http::{
11    compression::CompressionLayer,
12    cors::{Any, CorsLayer},
13    request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer},
14    trace::TraceLayer,
15};
16use tracing::Span;
17use uuid::Uuid;
18
/// Settings consumed by [`stack`] when it builds the middleware layers.
///
/// The binary fills this in at startup and hands it to `stack`; nothing
/// in this crate reads process environment variables.
#[derive(Clone, Debug, Default)]
pub struct MiddlewareConfig {
    /// Allowed CORS origins, separated by commas.
    ///
    /// An empty string (which is what `Default` produces) allows every
    /// origin. That suits local development; production binaries are
    /// expected to fill this from `ALLOWED_ORIGINS`.
    pub allowed_origins: String,
}
30
31/// Attach the full middleware stack to any Axum router.
32///
33/// Order matters: request-id is outermost, compression is innermost.
34pub fn stack<S>(router: Router<S>, cfg: &MiddlewareConfig) -> Router<S>
35where
36    S: Clone + Send + Sync + 'static,
37{
38    router.route("/health", get(health)).layer(
39        ServiceBuilder::new()
40            .layer(SetRequestIdLayer::x_request_id(MakeRequestUuid))
41            .layer(PropagateRequestIdLayer::x_request_id())
42            .layer(
43                TraceLayer::new_for_http()
44                    .make_span_with(|req: &Request<_>| {
45                        let request_id = req
46                            .headers()
47                            .get("x-request-id")
48                            .and_then(|v| v.to_str().ok())
49                            .unwrap_or("-");
50                        tracing::info_span!(
51                            "request",
52                            method = %req.method(),
53                            uri    = %req.uri(),
54                            request_id,
55                        )
56                    })
57                    .on_response(
58                        |resp: &Response<_>, latency: std::time::Duration, _span: &Span| {
59                            tracing::info!(
60                                status = resp.status().as_u16(),
61                                latency_ms = latency.as_millis(),
62                                "response"
63                            );
64                        },
65                    ),
66            )
67            .layer(CompressionLayer::new())
68            .layer(
69                CorsLayer::new()
70                    .allow_methods(Any)
71                    .allow_headers(Any)
72                    .allow_origin(cors_origin(&cfg.allowed_origins)),
73            )
74            .layer(middleware::from_fn(error_formatter)),
75    )
76}
77
78/// Serve the router on the given port with graceful SIGTERM shutdown.
79///
80/// Call `.with_state()` on your router before passing it here. The
81/// binary is responsible for resolving the port (Cloud Run sets `PORT`
82/// in the env; local dev defaults to 8080) — this crate no longer reads
83/// the environment.
84pub async fn serve(app: Router, port: u16) {
85    let addr = format!("0.0.0.0:{port}");
86    let listener = tokio::net::TcpListener::bind(&addr)
87        .await
88        .expect("bind failed");
89    tracing::info!("listening on {addr}");
90
91    axum::serve(listener, app)
92        .with_graceful_shutdown(shutdown_signal())
93        .await
94        .expect("server error");
95}
96
97async fn health() -> impl IntoResponse {
98    Json(json!({ "status": "ok" }))
99}
100
101/// Catch unhandled errors and return a clean JSON body (no stack trace to client).
102async fn error_formatter(req: Request, next: Next) -> Response {
103    let resp = next.run(req).await;
104    if resp.status().is_server_error() {
105        let status = resp.status();
106        return (
107            status,
108            Json(json!({
109                "error": status.canonical_reason().unwrap_or("internal error"),
110                "request_id": Uuid::new_v4().to_string(),
111            })),
112        )
113            .into_response();
114    }
115    resp
116}
117
118fn cors_origin(allowed_origins: &str) -> tower_http::cors::AllowOrigin {
119    if allowed_origins.is_empty() {
120        return tower_http::cors::AllowOrigin::any();
121    }
122    let parsed: Vec<_> = allowed_origins
123        .split(',')
124        .filter_map(|o| o.trim().parse::<axum::http::HeaderValue>().ok())
125        .collect();
126    tower_http::cors::AllowOrigin::list(parsed)
127}
128
129async fn shutdown_signal() {
130    let ctrl_c = async {
131        tokio::signal::ctrl_c()
132            .await
133            .expect("ctrl-c handler failed");
134    };
135    #[cfg(unix)]
136    let terminate = async {
137        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
138            .expect("SIGTERM handler failed")
139            .recv()
140            .await;
141    };
142    #[cfg(not(unix))]
143    let terminate = std::future::pending::<()>();
144
145    tokio::select! {
146        _ = ctrl_c   => tracing::info!("ctrl-c received, shutting down"),
147        _ = terminate => tracing::info!("SIGTERM received, shutting down"),
148    }
149}