Skip to main content

systemprompt_api/services/server/
builder.rs

1use anyhow::Result;
2use axum::Router;
3use axum::extract::DefaultBodyLimit;
4use systemprompt_runtime::AppContext;
5use systemprompt_traits::{AppContext as _, StartupEvent, StartupEventExt, StartupEventSender};
6
7use super::routes::configure_routes;
8use crate::models::ServerConfig;
9use crate::services::middleware::{
10    AnalyticsMiddleware, ContextMiddleware, CorsMiddleware, JwtContextExtractor, SessionMiddleware,
11    inject_security_headers, inject_served_by, inject_trace_header, remove_trailing_slash,
12};
13
14pub use super::discovery::*;
15pub use super::health::handle_health;
16
/// HTTP API server: an assembled axum [`Router`] plus optional startup-event
/// reporting.
///
/// Construct via [`setup_api_server`] (routes + global middleware), or directly
/// with [`ApiServer::new`] / [`ApiServer::with_config`], then run it with
/// [`ApiServer::serve`] or extract the router with [`ApiServer::into_router`].
#[derive(Debug)]
pub struct ApiServer {
    /// Router that handles every request served by this server.
    router: Router,
    /// Server configuration. Stored at construction but not read by any
    /// method visible in this file (hence the leading underscore).
    _config: ServerConfig,
    /// Optional sender used to report startup progress (binding / listening).
    events: Option<StartupEventSender>,
}
23
24impl ApiServer {
25    pub fn new(router: Router, events: Option<StartupEventSender>) -> Self {
26        Self::with_config(router, ServerConfig::default(), events)
27    }
28
29    pub const fn with_config(
30        router: Router,
31        config: ServerConfig,
32        events: Option<StartupEventSender>,
33    ) -> Self {
34        Self {
35            router,
36            _config: config,
37            events,
38        }
39    }
40
41    pub fn into_router(self) -> Router {
42        self.router
43    }
44
45    pub async fn serve(self, addr: &str) -> Result<()> {
46        if let Some(ref tx) = self.events {
47            if tx
48                .unbounded_send(StartupEvent::ServerBinding {
49                    address: addr.to_owned(),
50                })
51                .is_err()
52            {
53                tracing::debug!("Startup event receiver dropped");
54            }
55        }
56
57        let listener = self.create_listener(addr).await?;
58
59        if let Some(ref tx) = self.events {
60            tx.server_listening(addr, std::process::id());
61        }
62
63        axum::serve(
64            listener,
65            self.router
66                .into_make_service_with_connect_info::<std::net::SocketAddr>(),
67        )
68        .await?;
69        Ok(())
70    }
71
72    async fn create_listener(&self, addr: &str) -> Result<tokio::net::TcpListener> {
73        tokio::net::TcpListener::bind(addr)
74            .await
75            .map_err(|e| anyhow::anyhow!("Failed to bind to {addr}: {e}"))
76    }
77}
78
79pub fn setup_api_server(ctx: &AppContext, events: Option<StartupEventSender>) -> Result<ApiServer> {
80    let rate_config = &ctx.config().rate_limits;
81
82    if rate_config.disabled {
83        if let Some(ref tx) = events {
84            tx.warning("Rate limiting disabled - development mode only");
85        }
86    }
87
88    let router = configure_routes(ctx, events.as_ref())?;
89    let router = apply_global_middleware(router, ctx)?;
90
91    Ok(ApiServer::new(router, events))
92}
93
94fn apply_global_middleware(router: Router, ctx: &AppContext) -> Result<Router> {
95    let mut router = router;
96
97    router = router.layer(DefaultBodyLimit::max(2 * 1024 * 1024));
98
99    let analytics_middleware = AnalyticsMiddleware::new(ctx)?;
100    router = router.layer(axum::middleware::from_fn({
101        let middleware = analytics_middleware;
102        move |req, next| {
103            let middleware = middleware.clone();
104            async move { middleware.track_request(req, next).await }
105        }
106    }));
107
108    let analytics = ctx
109        .analytics_provider()
110        .ok_or_else(|| anyhow::anyhow!("AnalyticsProvider required for JWT middleware"))?;
111    let user_provider = ctx
112        .user_provider()
113        .ok_or_else(|| anyhow::anyhow!("UserProvider required for JWT middleware"))?;
114    let jwt_extractor = JwtContextExtractor::new(analytics, user_provider);
115    let global_context_middleware = ContextMiddleware::public(jwt_extractor);
116    router = router.layer(axum::middleware::from_fn({
117        let middleware = global_context_middleware;
118        move |req, next| {
119            let middleware = middleware.clone();
120            async move { middleware.handle(req, next).await }
121        }
122    }));
123
124    let session_middleware = SessionMiddleware::new(ctx)?;
125    router = router.layer(axum::middleware::from_fn({
126        let middleware = session_middleware;
127        move |req, next| {
128            let middleware = middleware.clone();
129            async move { middleware.handle(req, next).await }
130        }
131    }));
132
133    let cors = CorsMiddleware::build_layer(ctx.config())?;
134    router = router.layer(cors);
135
136    router = router.layer(axum::middleware::from_fn(remove_trailing_slash));
137
138    router = router.layer(axum::middleware::from_fn(inject_trace_header));
139
140    router = router.layer(axum::middleware::from_fn(inject_served_by));
141
142    if ctx.config().content_negotiation.enabled {
143        router = router.layer(axum::middleware::from_fn(
144            crate::services::middleware::content_negotiation_middleware,
145        ));
146    }
147
148    if ctx.config().security_headers.enabled {
149        let security_config = ctx.config().security_headers.clone();
150        router = router.layer(axum::middleware::from_fn(move |req, next| {
151            let config = security_config.clone();
152            inject_security_headers(config, req, next)
153        }));
154    }
155
156    Ok(router)
157}