systemprompt_api/services/server/
builder.rs1use anyhow::Result;
2use axum::Router;
3use axum::extract::DefaultBodyLimit;
4use systemprompt_runtime::AppContext;
5use systemprompt_traits::{AppContext as _, StartupEvent, StartupEventExt, StartupEventSender};
6
7use super::routes::configure_routes;
8use crate::models::ServerConfig;
9use crate::services::middleware::{
10 AnalyticsMiddleware, ContextMiddleware, CorsMiddleware, JwtContextExtractor, SessionMiddleware,
11 inject_security_headers, inject_served_by, inject_trace_header, remove_trailing_slash,
12};
13
14pub use super::discovery::*;
15pub use super::health::handle_health;
16
17#[derive(Debug)]
18pub struct ApiServer {
19 router: Router,
20 _config: ServerConfig,
21 events: Option<StartupEventSender>,
22}
23
24impl ApiServer {
25 pub fn new(router: Router, events: Option<StartupEventSender>) -> Self {
26 Self::with_config(router, ServerConfig::default(), events)
27 }
28
29 pub const fn with_config(
30 router: Router,
31 config: ServerConfig,
32 events: Option<StartupEventSender>,
33 ) -> Self {
34 Self {
35 router,
36 _config: config,
37 events,
38 }
39 }
40
41 pub async fn serve(self, addr: &str) -> Result<()> {
42 if let Some(ref tx) = self.events {
43 if tx
44 .unbounded_send(StartupEvent::ServerBinding {
45 address: addr.to_string(),
46 })
47 .is_err()
48 {
49 tracing::debug!("Startup event receiver dropped");
50 }
51 }
52
53 let listener = self.create_listener(addr).await?;
54
55 if let Some(ref tx) = self.events {
56 tx.server_listening(addr, std::process::id());
57 }
58
59 axum::serve(
60 listener,
61 self.router
62 .into_make_service_with_connect_info::<std::net::SocketAddr>(),
63 )
64 .await?;
65 Ok(())
66 }
67
68 async fn create_listener(&self, addr: &str) -> Result<tokio::net::TcpListener> {
69 tokio::net::TcpListener::bind(addr)
70 .await
71 .map_err(|e| anyhow::anyhow!("Failed to bind to {addr}: {e}"))
72 }
73}
74
75pub fn setup_api_server(ctx: &AppContext, events: Option<StartupEventSender>) -> Result<ApiServer> {
76 let rate_config = &ctx.config().rate_limits;
77
78 if rate_config.disabled {
79 if let Some(ref tx) = events {
80 tx.warning("Rate limiting disabled - development mode only");
81 }
82 }
83
84 let router = configure_routes(ctx, events.as_ref())?;
85 let router = apply_global_middleware(router, ctx)?;
86
87 Ok(ApiServer::new(router, events))
88}
89
90fn apply_global_middleware(router: Router, ctx: &AppContext) -> Result<Router> {
91 let mut router = router;
92
93 router = router.layer(DefaultBodyLimit::max(2 * 1024 * 1024));
94
95 let analytics_middleware = AnalyticsMiddleware::new(ctx)?;
96 router = router.layer(axum::middleware::from_fn({
97 let middleware = analytics_middleware;
98 move |req, next| {
99 let middleware = middleware.clone();
100 async move { middleware.track_request(req, next).await }
101 }
102 }));
103
104 let analytics = ctx
105 .analytics_provider()
106 .ok_or_else(|| anyhow::anyhow!("AnalyticsProvider required for JWT middleware"))?;
107 let user_provider = ctx
108 .user_provider()
109 .ok_or_else(|| anyhow::anyhow!("UserProvider required for JWT middleware"))?;
110 let jwt_extractor = JwtContextExtractor::new(analytics, user_provider);
111 let global_context_middleware = ContextMiddleware::public(jwt_extractor);
112 router = router.layer(axum::middleware::from_fn({
113 let middleware = global_context_middleware;
114 move |req, next| {
115 let middleware = middleware.clone();
116 async move { middleware.handle(req, next).await }
117 }
118 }));
119
120 let session_middleware = SessionMiddleware::new(ctx)?;
121 router = router.layer(axum::middleware::from_fn({
122 let middleware = session_middleware;
123 move |req, next| {
124 let middleware = middleware.clone();
125 async move { middleware.handle(req, next).await }
126 }
127 }));
128
129 let cors = CorsMiddleware::build_layer(ctx.config())?;
130 router = router.layer(cors);
131
132 router = router.layer(axum::middleware::from_fn(remove_trailing_slash));
133
134 router = router.layer(axum::middleware::from_fn(inject_trace_header));
135
136 router = router.layer(axum::middleware::from_fn(inject_served_by));
137
138 if ctx.config().content_negotiation.enabled {
139 router = router.layer(axum::middleware::from_fn(
140 crate::services::middleware::content_negotiation_middleware,
141 ));
142 }
143
144 if ctx.config().security_headers.enabled {
145 let security_config = ctx.config().security_headers.clone();
146 router = router.layer(axum::middleware::from_fn(move |req, next| {
147 let config = security_config.clone();
148 inject_security_headers(config, req, next)
149 }));
150 }
151
152 Ok(router)
153}