Skip to main content

systemprompt_api/services/server/
builder.rs

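//! API server construction and the global middleware stack.
//!
//! [`setup_api_server`] composes the route tree and applies the global layers
//! (body limit, analytics, context, session, CORS, trailing-slash, trace and
//! served-by headers, content negotiation, security headers), producing an
//! [`ApiServer`] that binds and serves with graceful shutdown.
//!
//! The layers above are listed in the order they are added. Each
//! `Router::layer` call wraps everything added before it, so the last layer
//! listed is the outermost: it is the first to see an incoming request and the
//! last to touch the outgoing response.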
8
9use anyhow::Result;
10use axum::Router;
11use axum::extract::DefaultBodyLimit;
12use systemprompt_runtime::AppContext;
13use systemprompt_traits::{StartupEvent, StartupEventExt, StartupEventSender};
14
15use super::routes::configure_routes;
16use crate::models::ServerConfig;
17use crate::services::middleware::{
18    AnalyticsMiddleware, CorsMiddleware, PublicContextMiddleware, SessionMiddleware,
19    inject_security_headers, inject_served_by, inject_trace_header, remove_trailing_slash,
20};
21
22pub use super::discovery::*;
23pub use super::health::handle_health;
24
25#[derive(Debug)]
26pub struct ApiServer {
27    router: Router,
28    _config: ServerConfig,
29    events: Option<StartupEventSender>,
30}
31
32impl ApiServer {
33    pub fn new(router: Router, events: Option<StartupEventSender>) -> Self {
34        Self::with_config(router, ServerConfig::default(), events)
35    }
36
37    pub const fn with_config(
38        router: Router,
39        config: ServerConfig,
40        events: Option<StartupEventSender>,
41    ) -> Self {
42        Self {
43            router,
44            _config: config,
45            events,
46        }
47    }
48
49    pub fn into_router(self) -> Router {
50        self.router
51    }
52
53    pub async fn serve<F>(self, addr: &str, shutdown: F) -> Result<()>
54    where
55        F: Future<Output = ()> + Send + 'static,
56    {
57        if let Some(ref tx) = self.events {
58            if tx
59                .unbounded_send(StartupEvent::ServerBinding {
60                    address: addr.to_owned(),
61                })
62                .is_err()
63            {
64                tracing::debug!("Startup event receiver dropped");
65            }
66        }
67
68        let listener = self.create_listener(addr).await?;
69
70        if let Some(ref tx) = self.events {
71            tx.server_listening(addr, std::process::id());
72        }
73
74        axum::serve(
75            listener,
76            self.router
77                .into_make_service_with_connect_info::<std::net::SocketAddr>(),
78        )
79        .with_graceful_shutdown(shutdown)
80        .await?;
81        Ok(())
82    }
83
84    async fn create_listener(&self, addr: &str) -> Result<tokio::net::TcpListener> {
85        tokio::net::TcpListener::bind(addr)
86            .await
87            .map_err(|e| anyhow::anyhow!("Failed to bind to {addr}: {e}"))
88    }
89}
90
91pub fn setup_api_server(ctx: &AppContext, events: Option<StartupEventSender>) -> Result<ApiServer> {
92    let rate_config = &ctx.config().rate_limits;
93
94    if rate_config.disabled {
95        if let Some(ref tx) = events {
96            tx.warning("Rate limiting disabled - development mode only");
97        }
98    }
99
100    let router = configure_routes(ctx, events.as_ref())?;
101    let router = apply_global_middleware(router, ctx)?;
102
103    Ok(ApiServer::new(router, events))
104}
105
106fn apply_global_middleware(router: Router, ctx: &AppContext) -> Result<Router> {
107    let mut router = router;
108
109    router = router.layer(DefaultBodyLimit::max(2 * 1024 * 1024));
110
111    let analytics_middleware = AnalyticsMiddleware::new(ctx)?;
112    router = router.layer(axum::middleware::from_fn({
113        let middleware = analytics_middleware;
114        move |req, next| {
115            let middleware = middleware.clone();
116            async move { middleware.track_request(req, next).await }
117        }
118    }));
119
120    let global_context_middleware = PublicContextMiddleware::new();
121    router = router.layer(axum::middleware::from_fn({
122        let middleware = global_context_middleware;
123        move |req, next| async move { middleware.handle(req, next).await }
124    }));
125
126    let session_middleware = SessionMiddleware::new(ctx)?;
127    router = router.layer(axum::middleware::from_fn({
128        let middleware = session_middleware;
129        move |req, next| {
130            let middleware = middleware.clone();
131            async move { middleware.handle(req, next).await }
132        }
133    }));
134
135    let cors = CorsMiddleware::build_layer(ctx.config())?;
136    router = router.layer(cors);
137
138    router = router.layer(axum::middleware::from_fn(remove_trailing_slash));
139
140    router = router.layer(axum::middleware::from_fn(inject_trace_header));
141
142    router = router.layer(axum::middleware::from_fn(inject_served_by));
143
144    if ctx.config().content_negotiation.enabled {
145        router = router.layer(axum::middleware::from_fn(
146            crate::services::middleware::content_negotiation_middleware,
147        ));
148    }
149
150    if ctx.config().security_headers.enabled {
151        let security_config = ctx.config().security_headers.clone();
152        router = router.layer(axum::middleware::from_fn(move |req, next| {
153            let config = security_config.clone();
154            inject_security_headers(config, req, next)
155        }));
156    }
157
158    Ok(router)
159}