systemprompt_api/services/server/
builder.rs1use anyhow::Result;
10use axum::Router;
11use axum::extract::DefaultBodyLimit;
12use systemprompt_runtime::AppContext;
13use systemprompt_traits::{StartupEvent, StartupEventExt, StartupEventSender};
14
15use super::routes::configure_routes;
16use crate::models::ServerConfig;
17use crate::services::middleware::{
18 AnalyticsMiddleware, CorsMiddleware, PublicContextMiddleware, SessionMiddleware,
19 inject_security_headers, inject_served_by, inject_trace_header, remove_trailing_slash,
20};
21
22pub use super::discovery::*;
23pub use super::health::handle_health;
24
25#[derive(Debug)]
26pub struct ApiServer {
27 router: Router,
28 _config: ServerConfig,
29 events: Option<StartupEventSender>,
30}
31
32impl ApiServer {
33 pub fn new(router: Router, events: Option<StartupEventSender>) -> Self {
34 Self::with_config(router, ServerConfig::default(), events)
35 }
36
37 pub const fn with_config(
38 router: Router,
39 config: ServerConfig,
40 events: Option<StartupEventSender>,
41 ) -> Self {
42 Self {
43 router,
44 _config: config,
45 events,
46 }
47 }
48
49 pub fn into_router(self) -> Router {
50 self.router
51 }
52
53 pub async fn serve<F>(self, addr: &str, shutdown: F) -> Result<()>
54 where
55 F: Future<Output = ()> + Send + 'static,
56 {
57 if let Some(ref tx) = self.events {
58 if tx
59 .unbounded_send(StartupEvent::ServerBinding {
60 address: addr.to_owned(),
61 })
62 .is_err()
63 {
64 tracing::debug!("Startup event receiver dropped");
65 }
66 }
67
68 let listener = self.create_listener(addr).await?;
69
70 if let Some(ref tx) = self.events {
71 tx.server_listening(addr, std::process::id());
72 }
73
74 axum::serve(
75 listener,
76 self.router
77 .into_make_service_with_connect_info::<std::net::SocketAddr>(),
78 )
79 .with_graceful_shutdown(shutdown)
80 .await?;
81 Ok(())
82 }
83
84 async fn create_listener(&self, addr: &str) -> Result<tokio::net::TcpListener> {
85 tokio::net::TcpListener::bind(addr)
86 .await
87 .map_err(|e| anyhow::anyhow!("Failed to bind to {addr}: {e}"))
88 }
89}
90
91pub fn setup_api_server(ctx: &AppContext, events: Option<StartupEventSender>) -> Result<ApiServer> {
92 let rate_config = &ctx.config().rate_limits;
93
94 if rate_config.disabled {
95 if let Some(ref tx) = events {
96 tx.warning("Rate limiting disabled - development mode only");
97 }
98 }
99
100 let router = configure_routes(ctx, events.as_ref())?;
101 let router = apply_global_middleware(router, ctx)?;
102
103 Ok(ApiServer::new(router, events))
104}
105
106fn apply_global_middleware(router: Router, ctx: &AppContext) -> Result<Router> {
107 let mut router = router;
108
109 router = router.layer(DefaultBodyLimit::max(2 * 1024 * 1024));
110
111 let analytics_middleware = AnalyticsMiddleware::new(ctx)?;
112 router = router.layer(axum::middleware::from_fn({
113 let middleware = analytics_middleware;
114 move |req, next| {
115 let middleware = middleware.clone();
116 async move { middleware.track_request(req, next).await }
117 }
118 }));
119
120 let global_context_middleware = PublicContextMiddleware::new();
121 router = router.layer(axum::middleware::from_fn({
122 let middleware = global_context_middleware;
123 move |req, next| async move { middleware.handle(req, next).await }
124 }));
125
126 let session_middleware = SessionMiddleware::new(ctx)?;
127 router = router.layer(axum::middleware::from_fn({
128 let middleware = session_middleware;
129 move |req, next| {
130 let middleware = middleware.clone();
131 async move { middleware.handle(req, next).await }
132 }
133 }));
134
135 let cors = CorsMiddleware::build_layer(ctx.config())?;
136 router = router.layer(cors);
137
138 router = router.layer(axum::middleware::from_fn(remove_trailing_slash));
139
140 router = router.layer(axum::middleware::from_fn(inject_trace_header));
141
142 router = router.layer(axum::middleware::from_fn(inject_served_by));
143
144 if ctx.config().content_negotiation.enabled {
145 router = router.layer(axum::middleware::from_fn(
146 crate::services::middleware::content_negotiation_middleware,
147 ));
148 }
149
150 if ctx.config().security_headers.enabled {
151 let security_config = ctx.config().security_headers.clone();
152 router = router.layer(axum::middleware::from_fn(move |req, next| {
153 let config = security_config.clone();
154 inject_security_headers(config, req, next)
155 }));
156 }
157
158 Ok(router)
159}