Skip to main content

systemprompt_api/services/server/
builder.rs

1use anyhow::Result;
2use axum::extract::DefaultBodyLimit;
3use axum::routing::get;
4use axum::{Json, Router};
5use serde_json::json;
6use systemprompt_database::DatabaseQuery;
7use systemprompt_models::api::SingleResponse;
8use systemprompt_models::modules::ApiPaths;
9use systemprompt_models::AppPaths;
10use systemprompt_runtime::AppContext;
11use systemprompt_traits::{StartupEvent, StartupEventExt, StartupEventSender};
12
13use super::routes::configure_routes;
14use crate::models::ServerConfig;
15use crate::services::middleware::{
16    inject_trace_header, remove_trailing_slash, AnalyticsMiddleware, ContextMiddleware,
17    CorsMiddleware, JwtContextExtractor, SessionMiddleware,
18};
19
/// Trivial `SELECT 1` statement that `handle_health` runs to probe database
/// connectivity and measure round-trip latency.
const HEALTH_CHECK_QUERY: DatabaseQuery = DatabaseQuery::new("SELECT 1");
21
/// An axum [`Router`] bundled with its server configuration and an optional
/// startup-event channel.
#[derive(Debug)]
pub struct ApiServer {
    /// Fully configured router that `serve` hands to `axum::serve`.
    router: Router,
    /// Configuration supplied at construction. It is stored but not read by
    /// any code in this file, hence the leading underscore.
    _config: ServerConfig,
    /// Optional sender used by `serve` to report bind/listen progress.
    events: Option<StartupEventSender>,
}
28
29impl ApiServer {
30    pub fn new(router: Router, events: Option<StartupEventSender>) -> Self {
31        Self::with_config(router, ServerConfig::default(), events)
32    }
33
34    pub const fn with_config(
35        router: Router,
36        config: ServerConfig,
37        events: Option<StartupEventSender>,
38    ) -> Self {
39        Self {
40            router,
41            _config: config,
42            events,
43        }
44    }
45
46    pub async fn serve(self, addr: &str) -> Result<()> {
47        if let Some(ref tx) = self.events {
48            if tx
49                .unbounded_send(StartupEvent::ServerBinding {
50                    address: addr.to_string(),
51                })
52                .is_err()
53            {
54                tracing::debug!("Startup event receiver dropped");
55            }
56        }
57
58        let listener = self.create_listener(addr).await?;
59
60        if let Some(ref tx) = self.events {
61            tx.server_listening(addr, std::process::id());
62        }
63
64        axum::serve(
65            listener,
66            self.router
67                .into_make_service_with_connect_info::<std::net::SocketAddr>(),
68        )
69        .await?;
70        Ok(())
71    }
72
73    async fn create_listener(&self, addr: &str) -> Result<tokio::net::TcpListener> {
74        tokio::net::TcpListener::bind(addr)
75            .await
76            .map_err(|e| anyhow::anyhow!("Failed to bind to {addr}: {e}"))
77    }
78}
79
80pub fn setup_api_server(ctx: &AppContext, events: Option<StartupEventSender>) -> Result<ApiServer> {
81    let rate_config = &ctx.config().rate_limits;
82
83    if rate_config.disabled {
84        if let Some(ref tx) = events {
85            tx.warning("Rate limiting disabled - development mode only");
86        }
87    }
88
89    let router = configure_routes(ctx, events.as_ref())?;
90    let router = apply_global_middleware(router, ctx)?;
91
92    Ok(ApiServer::new(router, events))
93}
94
95fn apply_global_middleware(router: Router, ctx: &AppContext) -> Result<Router> {
96    let mut router = router;
97
98    router = router.layer(DefaultBodyLimit::max(100 * 1024 * 1024));
99
100    let analytics_middleware = AnalyticsMiddleware::new(ctx);
101    router = router.layer(axum::middleware::from_fn({
102        let middleware = analytics_middleware;
103        move |req, next| {
104            let middleware = middleware.clone();
105            async move { middleware.track_request(req, next).await }
106        }
107    }));
108
109    let jwt_extractor = JwtContextExtractor::new(
110        systemprompt_models::SecretsBootstrap::jwt_secret()?,
111        ctx.db_pool(),
112    );
113    let global_context_middleware = ContextMiddleware::public(jwt_extractor);
114    router = router.layer(axum::middleware::from_fn({
115        let middleware = global_context_middleware;
116        move |req, next| {
117            let middleware = middleware.clone();
118            async move { middleware.handle(req, next).await }
119        }
120    }));
121
122    let session_middleware = SessionMiddleware::new(ctx)?;
123    router = router.layer(axum::middleware::from_fn({
124        let middleware = session_middleware;
125        move |req, next| {
126            let middleware = middleware.clone();
127            async move { middleware.handle(req, next).await }
128        }
129    }));
130
131    let cors = CorsMiddleware::build_layer(ctx.config())?;
132    router = router.layer(cors);
133
134    router = router.layer(axum::middleware::from_fn(remove_trailing_slash));
135
136    router = router.layer(axum::middleware::from_fn(inject_trace_header));
137
138    if ctx.config().content_negotiation.enabled {
139        router = router.layer(axum::middleware::from_fn(
140            crate::services::middleware::content_negotiation_middleware,
141        ));
142    }
143
144    Ok(router)
145}
146
147pub async fn handle_root_discovery(
148    axum::extract::State(ctx): axum::extract::State<AppContext>,
149) -> impl axum::response::IntoResponse {
150    let base = &ctx.config().api_external_url;
151    let data = json!({
152        "name": format!("{} API", ctx.config().sitename),
153        "version": "1.0.0",
154        "description": "systemprompt.io OS API Gateway",
155        "endpoints": {
156            "health": format!("{}{}", base, ApiPaths::HEALTH),
157            "oauth": {
158                "href": format!("{}{}", base, ApiPaths::OAUTH_BASE),
159                "description": "OAuth2/OIDC authentication and WebAuthn",
160                "endpoints": {
161                    "authorize": format!("{}{}", base, ApiPaths::OAUTH_AUTHORIZE),
162                    "token": format!("{}{}", base, ApiPaths::OAUTH_TOKEN),
163                    "userinfo": format!("{}{}/userinfo", base, ApiPaths::OAUTH_BASE),
164                    "introspect": format!("{}{}/introspect", base, ApiPaths::OAUTH_BASE),
165                    "revoke": format!("{}{}/revoke", base, ApiPaths::OAUTH_BASE),
166                    "webauthn": format!("{}{}/webauthn", base, ApiPaths::OAUTH_BASE)
167                }
168            },
169            "core": {
170                "href": format!("{}{}", base, ApiPaths::CORE_BASE),
171                "description": "Core conversation, task, and artifact management",
172                "endpoints": {
173                    "contexts": format!("{}{}", base, ApiPaths::CORE_CONTEXTS),
174                    "tasks": format!("{}{}", base, ApiPaths::CORE_TASKS),
175                    "artifacts": format!("{}{}", base, ApiPaths::CORE_ARTIFACTS)
176                }
177            },
178            "agents": {
179                "href": format!("{}{}", base, ApiPaths::AGENTS_REGISTRY),
180                "description": "A2A protocol agent registry and proxy",
181                "endpoints": {
182                    "registry": format!("{}{}", base, ApiPaths::AGENTS_REGISTRY),
183                    "proxy": format!("{}{}{{agent_id}}", base, ApiPaths::AGENTS_BASE)
184                }
185            },
186            "mcp": {
187                "href": format!("{}{}", base, ApiPaths::MCP_REGISTRY),
188                "description": "MCP server registry and lifecycle management",
189                "endpoints": {
190                    "registry": format!("{}{}", base, ApiPaths::MCP_REGISTRY),
191                    "proxy": format!("{}{}{{server_name}}", base, ApiPaths::MCP_BASE)
192                }
193            },
194            "stream": {
195                "href": format!("{}{}", base, ApiPaths::STREAM_BASE),
196                "description": "Server-Sent Events (SSE) for real-time updates",
197                "endpoints": {
198                    "contexts": format!("{}{}", base, ApiPaths::STREAM_CONTEXTS)
199                }
200            }
201        },
202        "wellknown": {
203            "oauth": format!("{}{}", base, ApiPaths::WELLKNOWN_OAUTH_SERVER),
204            "agent": format!("{}{}", base, ApiPaths::WELLKNOWN_AGENT_CARD)
205        }
206    });
207
208    Json(SingleResponse::new(data))
209}
210
/// Extracts the numeric value of a `/proc/<pid>/status`-style field.
///
/// Only the first line of `content` that starts with `key` is considered. The
/// line is split on whitespace and its second token is parsed as a `u64`.
///
/// Returns `None` when no line starts with `key`, when the matching line has
/// fewer than two whitespace-separated tokens, or when the second token is not
/// a valid `u64`.
#[cfg(target_os = "linux")]
fn parse_proc_status_kb(content: &str, key: &str) -> Option<u64> {
    let line = content
        .lines()
        .find(|candidate| candidate.starts_with(key))?;

    let mut fields = line.split_whitespace();
    // First token is the label (e.g. `VmRSS:`); the value follows it.
    let _label = fields.next()?;
    let value = fields.next()?;

    value.parse().ok()
}
222
223#[cfg(target_os = "linux")]
224fn get_process_memory() -> Option<serde_json::Value> {
225    let content = std::fs::read_to_string("/proc/self/status").ok()?;
226
227    let rss_kb = parse_proc_status_kb(&content, "VmRSS:");
228    let virt_kb = parse_proc_status_kb(&content, "VmSize:");
229    let peak_kb = parse_proc_status_kb(&content, "VmPeak:");
230
231    Some(json!({
232        "rss_mb": rss_kb.map(|kb| kb / 1024),
233        "virtual_mb": virt_kb.map(|kb| kb / 1024),
234        "peak_mb": peak_kb.map(|kb| kb / 1024)
235    }))
236}
237
238#[cfg(not(target_os = "linux"))]
239fn get_process_memory() -> Option<serde_json::Value> {
240    None
241}
242
243pub async fn handle_health(
244    axum::extract::State(ctx): axum::extract::State<AppContext>,
245) -> impl axum::response::IntoResponse {
246    use axum::http::StatusCode;
247    use systemprompt_database::{DatabaseProvider, ServiceRepository};
248
249    let start = std::time::Instant::now();
250
251    let (db_status, db_latency_ms) = {
252        let db_start = std::time::Instant::now();
253        let status = match ctx.db_pool().fetch_optional(&HEALTH_CHECK_QUERY, &[]).await {
254            Ok(_) => "healthy",
255            Err(_) => "unhealthy",
256        };
257        (status, db_start.elapsed().as_millis())
258    };
259
260    let service_repo = ServiceRepository::new(ctx.db_pool().clone());
261
262    let (agent_count, agent_status) = match service_repo.count_running_services("agent").await {
263        Ok(count) if count > 0 => (count, "healthy"),
264        Ok(_) => (0, "none"),
265        Err(_) => (0, "error"),
266    };
267
268    let (mcp_count, mcp_status) = match service_repo.count_running_services("mcp").await {
269        Ok(count) if count > 0 => (count, "healthy"),
270        Ok(_) => (0, "none"),
271        Err(_) => (0, "error"),
272    };
273
274    let web_dir = AppPaths::get()
275        .map(|p| p.web().dist().to_path_buf())
276        .unwrap_or_else(|e| {
277            tracing::debug!(error = %e, "Failed to get web dist path, using default");
278            std::path::PathBuf::from("/var/www/html/dist")
279        });
280    let sitemap_exists = web_dir.join("sitemap.xml").exists();
281    let index_exists = web_dir.join("index.html").exists();
282
283    let db_healthy = db_status == "healthy";
284    let services_ok = agent_status != "error" && mcp_status != "error";
285    let content_ok = sitemap_exists && index_exists;
286
287    let (overall_status, http_status) = match (db_healthy, services_ok && content_ok) {
288        (false, _) => ("unhealthy", StatusCode::SERVICE_UNAVAILABLE),
289        (true, false) => ("degraded", StatusCode::OK),
290        (true, true) => ("healthy", StatusCode::OK),
291    };
292
293    let check_duration_ms = start.elapsed().as_millis();
294    let memory = get_process_memory();
295
296    let data = json!({
297        "status": overall_status,
298        "timestamp": chrono::Utc::now().to_rfc3339(),
299        "version": env!("CARGO_PKG_VERSION"),
300        "checks": {
301            "database": {
302                "status": db_status,
303                "latency_ms": db_latency_ms
304            },
305            "agents": {
306                "status": agent_status,
307                "count": agent_count
308            },
309            "mcp": {
310                "status": mcp_status,
311                "count": mcp_count
312            },
313            "static_content": {
314                "status": if content_ok { "healthy" } else { "degraded" },
315                "index_html": index_exists,
316                "sitemap_xml": sitemap_exists
317            }
318        },
319        "memory": memory,
320        "response_time_ms": check_duration_ms
321    });
322
323    (http_status, Json(data))
324}
325
326pub async fn handle_core_discovery(
327    axum::extract::State(ctx): axum::extract::State<AppContext>,
328) -> impl axum::response::IntoResponse {
329    let base = &ctx.config().api_external_url;
330    let data = json!({
331        "name": "Core Services",
332        "description": "Core conversation, task, and artifact management APIs",
333        "endpoints": {
334            "contexts": {
335                "href": format!("{}{}", base, ApiPaths::CORE_CONTEXTS),
336                "description": "Conversation context management",
337                "methods": ["GET", "POST", "DELETE"]
338            },
339            "tasks": {
340                "href": format!("{}{}", base, ApiPaths::CORE_TASKS),
341                "description": "Task management for agent operations",
342                "methods": ["GET", "POST", "PUT", "DELETE"]
343            },
344            "artifacts": {
345                "href": format!("{}{}", base, ApiPaths::CORE_ARTIFACTS),
346                "description": "Artifact storage and retrieval",
347                "methods": ["GET", "POST", "DELETE"]
348            },
349            "oauth": {
350                "href": format!("{}{}", base, ApiPaths::OAUTH_BASE),
351                "description": "OAuth2/OIDC authentication endpoints"
352            }
353        }
354    });
355    Json(SingleResponse::new(data))
356}
357
358pub async fn handle_agents_discovery(
359    axum::extract::State(ctx): axum::extract::State<AppContext>,
360) -> impl axum::response::IntoResponse {
361    let base = &ctx.config().api_external_url;
362    let data = json!({
363        "name": "Agent Services",
364        "description": "A2A protocol agent registry and proxy",
365        "endpoints": {
366            "registry": {
367                "href": format!("{}{}", base, ApiPaths::AGENTS_REGISTRY),
368                "description": "List and discover available agents",
369                "methods": ["GET"]
370            },
371            "proxy": {
372                "href": format!("{}{}/<agent_id>/", base, ApiPaths::AGENTS_BASE),
373                "description": "Proxy requests to specific agents",
374                "methods": ["GET", "POST"]
375            }
376        }
377    });
378    Json(SingleResponse::new(data))
379}
380
381pub async fn handle_mcp_discovery(
382    axum::extract::State(ctx): axum::extract::State<AppContext>,
383) -> impl axum::response::IntoResponse {
384    let base = &ctx.config().api_external_url;
385    let data = json!({
386        "name": "MCP Services",
387        "description": "Model Context Protocol server registry and proxy",
388        "endpoints": {
389            "registry": {
390                "href": format!("{}{}", base, ApiPaths::MCP_REGISTRY),
391                "description": "List and discover available MCP servers",
392                "methods": ["GET"]
393            },
394            "proxy": {
395                "href": format!("{}{}/<server_name>/mcp", base, ApiPaths::MCP_BASE),
396                "description": "Proxy requests to specific MCP servers",
397                "methods": ["GET", "POST"]
398            }
399        }
400    });
401    Json(SingleResponse::new(data))
402}
403
404pub fn discovery_router(ctx: &AppContext) -> Router {
405    Router::new()
406        .route(ApiPaths::DISCOVERY, get(handle_root_discovery))
407        .route(ApiPaths::HEALTH, get(handle_health))
408        .route("/health", get(handle_health))
409        .route(ApiPaths::CORE_BASE, get(handle_core_discovery))
410        .route(ApiPaths::AGENTS_BASE, get(handle_agents_discovery))
411        .route(ApiPaths::MCP_BASE, get(handle_mcp_discovery))
412        .with_state(ctx.clone())
413}