Skip to main content

socle/bootstrap/
serve.rs

//! `serve()` implementation — wires adapters, binds the listener, runs until
//! shutdown, then drains.

4use std::future::Future;
5use std::net::SocketAddr;
6
7use tower_http::trace::TraceLayer;
8
9use crate::bootstrap::builder::{ServiceBootstrap, ShutdownHook};
10use crate::bootstrap::ctx::BootstrapCtx;
11use crate::error::{Error, Result};
12
13impl ServiceBootstrap {
14    /// Run the service. Initialises every enabled integration in dependency
15    /// order, binds the listener, serves until SIGINT/SIGTERM, then drains.
16    pub async fn serve(self, addr: impl Into<String>) -> Result<()> {
17        let addr: SocketAddr = addr
18            .into()
19            .parse()
20            .map_err(|e: std::net::AddrParseError| Error::Config(e.to_string()))?;
21        let listener = tokio::net::TcpListener::bind(addr)
22            .await
23            .map_err(|e| Error::Bind(e.to_string()))?;
24        tracing::info!(%addr, service = %self.service_name, "socle: listening");
25        self.serve_with_shutdown(listener, shutdown_signal()).await
26    }
27
28    /// Run the service using a pre-bound listener and a caller-supplied shutdown
29    /// future. Useful for integration tests where you need to bind on port 0
30    /// and control when the server stops.
31    ///
32    /// ```rust,no_run
33    /// use axum::{Router, routing::get};
34    /// use socle::{BootstrapCtx, ServiceBootstrap};
35    /// use tokio::net::TcpListener;
36    ///
37    /// # #[tokio::main] async fn main() {
38    /// # let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
39    /// ServiceBootstrap::new("my-service")
40    ///     .with_router(|_: &BootstrapCtx| Router::new().route("/", get(|| async { "ok" })))
41    ///     .serve_with_shutdown(listener, std::future::pending())
42    ///     .await
43    ///     .unwrap()
44    /// # }
45    /// ```
46    pub async fn serve_with_shutdown(
47        self,
48        listener: tokio::net::TcpListener,
49        shutdown: impl Future<Output = ()> + Send + 'static,
50    ) -> Result<()> {
51        // Destructure early so we can mutate shutdown_hooks without partial-move issues.
52        let service_name = self.service_name;
53        #[cfg_attr(not(feature = "telemetry"), allow(unused_mut))]
54        let mut shutdown_hooks = self.shutdown_hooks;
55        let shutdown_timeout = self.shutdown_timeout;
56        let extra_layers = self.extra_layers;
57        let rate_limit_provider = self.rate_limit_provider;
58        let auth_provider = self.auth_provider;
59        let audit_sink = self.audit_sink;
60        let audit_filter = self.audit_filter;
61        let cors = self.cors;
62        let router_builder = self.router_builder;
63        let version = self.version;
64        let health_path = self.health_path;
65        let body_limit_bytes = self.body_limit_bytes;
66        let readiness_checks = self.readiness_checks;
67
68        #[cfg(feature = "database")]
69        let database_url = self.database_url;
70        #[cfg(feature = "database")]
71        let db_pool = self.db_pool;
72        #[cfg(feature = "database")]
73        let migrator = self.migrator;
74
75        #[cfg(feature = "ratelimit")]
76        let rate_limit = self.rate_limit;
77        #[cfg(feature = "ratelimit")]
78        let ratelimit_extractor = self.ratelimit_extractor;
79
80        #[cfg(feature = "openapi")]
81        let openapi = self.openapi;
82        #[cfg(feature = "openapi")]
83        let openapi_spec_path = self.openapi_spec_path;
84        #[cfg(feature = "openapi")]
85        let openapi_ui_path = self.openapi_ui_path;
86
87        #[cfg(feature = "telemetry")]
88        let telemetry_enabled = self.telemetry;
89        #[cfg(feature = "telemetry")]
90        let telemetry_provider = self.telemetry_provider;
91        #[cfg(feature = "telemetry")]
92        let telemetry_init = self.telemetry_init;
93
94        // 1. Telemetry first — priority: provider > init_fn > builtin.
95        #[cfg(feature = "telemetry")]
96        if telemetry_enabled {
97            if let Some(provider) = telemetry_provider {
98                provider
99                    .init(&service_name)
100                    .map_err(|e| Error::Telemetry(e.to_string()))?;
101                // Register the provider's flush as the last drain hook so it
102                // runs after all user-registered hooks have completed.
103                let provider = std::sync::Arc::new(provider);
104                let hook: crate::bootstrap::builder::ShutdownHookFn =
105                    std::sync::Arc::new(move || {
106                        let p = provider.clone();
107                        Box::pin(async move { p.on_shutdown().await })
108                    });
109                shutdown_hooks.push(ShutdownHook {
110                    name: "telemetry-flush".into(),
111                    hook,
112                    timeout: std::time::Duration::from_secs(30),
113                });
114            } else {
115                match telemetry_init {
116                    Some(init_fn) => {
117                        init_fn(&service_name).map_err(|e| Error::Telemetry(e.to_string()))?
118                    }
119                    None => crate::adapters::observability::telemetry::init_basic_tracing(),
120                }
121            }
122        }
123
124        // 2. Database pool — prefer pre-built pool over URL construction.
125        #[cfg(feature = "database")]
126        let db: Option<sqlx::PgPool> = if let Some(pool) = db_pool {
127            if let Some(ref migrator) = migrator {
128                tracing::warn!(
129                    service = %service_name,
130                    "socle: running migrations in-process"
131                );
132                migrator
133                    .run(&pool)
134                    .await
135                    .map_err(|e| Error::Database(format!("migrate: {e}")))?;
136                tracing::info!("socle: migrations applied successfully");
137            }
138            Some(pool)
139        } else if let Some(ref url) = database_url {
140            let pool = sqlx::PgPool::connect(url)
141                .await
142                .map_err(|e| Error::Database(e.to_string()))?;
143
144            if let Some(ref migrator) = migrator {
145                tracing::warn!(
146                    service = %service_name,
147                    "socle: running migrations in-process"
148                );
149                migrator
150                    .run(&pool)
151                    .await
152                    .map_err(|e| Error::Database(format!("migrate: {e}")))?;
153                tracing::info!("socle: migrations applied successfully");
154            }
155
156            Some(pool)
157        } else if migrator.is_some() {
158            return Err(Error::Config(
159                "with_migrations(...) requires with_database(...) to be called first".into(),
160            ));
161        } else {
162            None
163        };
164
165        // 3. Build the user router via ctx.
166        let ctx = BootstrapCtx {
167            service_name: service_name.clone(),
168            #[cfg(feature = "database")]
169            db: db.clone(),
170            extensions: std::collections::HashMap::new(),
171        };
172
173        let user_router = router_builder
174            .ok_or_else(|| Error::Config("with_router(...) was never called".into()))?(
175            &ctx
176        );
177
178        // 4. Mount health endpoints.
179        let health_router = crate::adapters::health::build_health_router(
180            &health_path,
181            &service_name,
182            &version,
183            readiness_checks.clone(),
184        );
185        #[cfg_attr(not(feature = "openapi"), allow(unused_mut))]
186        let mut user_router = user_router.merge(health_router);
187
188        // OpenAPI spec + Swagger UI.
189        #[cfg(feature = "openapi")]
190        if let Some(mut api) = openapi.clone() {
191            api = crate::adapters::openapi::merge_health_paths(api, &health_path);
192            user_router = crate::adapters::openapi::mount_openapi(
193                user_router,
194                api,
195                &openapi_spec_path,
196                &openapi_ui_path,
197            );
198        }
199
200        let user_router = user_router.fallback(crate::adapters::health::not_found_fallback);
201
202        // 5. Apply layers.
203        let mut app = user_router;
204
205        // Rate limit — priority: provider > built-in memory backend.
206        if let Some(provider) = rate_limit_provider {
207            app = provider.apply(app);
208        } else {
209            #[cfg(feature = "ratelimit-memory")]
210            if let Some(cfg) = rate_limit {
211                use crate::adapters::security::rate_limit::RateLimitLayer;
212                app = app.layer(RateLimitLayer::new_memory(
213                    cfg.limit,
214                    cfg.window_secs,
215                    ratelimit_extractor,
216                ));
217            }
218        }
219
220        // Auth — applied after rate-limit (unauthenticated requests still
221        // counted) and before extra_layers.
222        if let Some(provider) = auth_provider {
223            app = provider.apply(app);
224        }
225
226        // Audit — applied after auth so principal context is available.
227        if let Some(sink) = audit_sink {
228            let layer = crate::audit::AuditLayer::new(sink);
229            let layer = match audit_filter {
230                Some(f) => layer.with_filter(f),
231                None => layer,
232            };
233            app = app.layer(layer);
234        }
235
236        // Extra layers registered via with_layer() — applied innermost first.
237        for layer_fn in extra_layers {
238            app = layer_fn(app);
239        }
240
241        // Enrich bare error responses.
242        app = app.layer(axum::middleware::from_fn(
243            crate::adapters::security::enrich_error::enrich_error_response,
244        ));
245
246        // Cross-cutting tower-http layers.
247        use tower_http::catch_panic::CatchPanicLayer;
248        use tower_http::compression::CompressionLayer;
249        use tower_http::limit::RequestBodyLimitLayer;
250        use tower_http::request_id::{PropagateRequestIdLayer, SetRequestIdLayer};
251
252        let request_id_header = axum::http::HeaderName::from_static("x-request-id");
253
254        let trace_layer =
255            TraceLayer::new_for_http().make_span_with(|req: &axum::http::Request<_>| {
256                let request_id = crate::request_id::extract_request_id(req);
257                tracing::info_span!(
258                    "request",
259                    method = %req.method(),
260                    uri = %req.uri(),
261                    "request.id" = request_id,
262                )
263            });
264
265        // CORS is opt-in. Omitting with_cors_config() means no CORS headers are
266        // sent, which is safe for APIs not accessed from browsers.
267        if let Some(cors) = cors {
268            app = app.layer(cors);
269        }
270
271        app = app
272            .layer(CompressionLayer::new())
273            .layer(RequestBodyLimitLayer::new(body_limit_bytes))
274            .layer(CatchPanicLayer::custom(crate::handler_error::panic_handler))
275            .layer(trace_layer)
276            .layer(PropagateRequestIdLayer::new(request_id_header.clone()))
277            .layer(crate::request_id::RequestIdTaskLocalLayer)
278            .layer(SetRequestIdLayer::new(
279                request_id_header,
280                crate::request_id::MakeRequestUuidV7,
281            ));
282
283        // 6. Serve with caller-supplied shutdown signal.
284        let make_service = app.into_make_service_with_connect_info::<std::net::SocketAddr>();
285        let server = axum::serve(listener, make_service).with_graceful_shutdown(shutdown);
286
287        server.await.map_err(|e| Error::Serve(e.to_string()))?;
288
289        run_shutdown_hooks(shutdown_hooks, shutdown_timeout).await;
290        tracing::info!(service = %service_name, "socle: shutdown complete");
291        Ok(())
292    }
293}
294
295async fn run_shutdown_hooks(hooks: Vec<ShutdownHook>, _default_timeout: std::time::Duration) {
296    for hook in hooks.into_iter().rev() {
297        tracing::info!(hook = %hook.name, "socle: running shutdown hook");
298        match tokio::time::timeout(hook.timeout, (hook.hook)()).await {
299            Ok(()) => tracing::info!(hook = %hook.name, "socle: shutdown hook completed"),
300            Err(_) => tracing::error!(
301                hook = %hook.name,
302                timeout_secs = hook.timeout.as_secs(),
303                "socle: shutdown hook timed out"
304            ),
305        }
306    }
307}
308
309pub(crate) async fn shutdown_signal() {
310    use tokio::signal;
311    let ctrl_c = async {
312        signal::ctrl_c().await.ok();
313    };
314    #[cfg(unix)]
315    let terminate = async {
316        if let Ok(mut sig) = signal::unix::signal(signal::unix::SignalKind::terminate()) {
317            sig.recv().await;
318        }
319    };
320    #[cfg(not(unix))]
321    let terminate = std::future::pending::<()>();
322    tokio::select! {
323        _ = ctrl_c => {},
324        _ = terminate => {},
325    }
326}