1use anyhow::Result;
2use axum::extract::DefaultBodyLimit;
3use axum::routing::get;
4use axum::{Json, Router};
5use serde_json::json;
6use systemprompt_database::DatabaseQuery;
7use systemprompt_models::api::SingleResponse;
8use systemprompt_models::modules::ApiPaths;
9use systemprompt_models::AppPaths;
10use systemprompt_runtime::AppContext;
11use systemprompt_traits::{StartupEvent, StartupEventExt, StartupEventSender};
12
13use super::routes::configure_routes;
14use crate::models::ServerConfig;
15use crate::services::middleware::{
16 inject_security_headers, inject_trace_header, remove_trailing_slash, AnalyticsMiddleware,
17 ContextMiddleware, CorsMiddleware, JwtContextExtractor, SessionMiddleware,
18};
19
20const HEALTH_CHECK_QUERY: DatabaseQuery = DatabaseQuery::new("SELECT 1");
21
22const DB_SIZE_QUERY: DatabaseQuery = DatabaseQuery::new(
23 "SELECT pg_database_size(current_database()) as size_bytes, current_database() as db_name",
24);
25
26const TABLE_SIZES_QUERY: DatabaseQuery = DatabaseQuery::new(
27 "SELECT relname as table_name, pg_total_relation_size(relid) as total_bytes, n_live_tup as \
28 row_estimate FROM pg_stat_user_tables ORDER BY pg_total_relation_size(relid) DESC LIMIT 15",
29);
30
31const TABLE_COUNT_QUERY: DatabaseQuery =
32 DatabaseQuery::new("SELECT COUNT(*) as count FROM pg_stat_user_tables");
33
34const AUDIT_LOG_QUERY: DatabaseQuery = DatabaseQuery::new(
35 "SELECT COUNT(*) as row_count, pg_total_relation_size('audit_log') as size_bytes, \
36 MIN(created_at) as oldest, MAX(created_at) as newest FROM audit_log",
37);
38
39#[derive(Debug)]
40pub struct ApiServer {
41 router: Router,
42 _config: ServerConfig,
43 events: Option<StartupEventSender>,
44}
45
46impl ApiServer {
47 pub fn new(router: Router, events: Option<StartupEventSender>) -> Self {
48 Self::with_config(router, ServerConfig::default(), events)
49 }
50
51 pub const fn with_config(
52 router: Router,
53 config: ServerConfig,
54 events: Option<StartupEventSender>,
55 ) -> Self {
56 Self {
57 router,
58 _config: config,
59 events,
60 }
61 }
62
63 pub async fn serve(self, addr: &str) -> Result<()> {
64 if let Some(ref tx) = self.events {
65 if tx
66 .unbounded_send(StartupEvent::ServerBinding {
67 address: addr.to_string(),
68 })
69 .is_err()
70 {
71 tracing::debug!("Startup event receiver dropped");
72 }
73 }
74
75 let listener = self.create_listener(addr).await?;
76
77 if let Some(ref tx) = self.events {
78 tx.server_listening(addr, std::process::id());
79 }
80
81 axum::serve(
82 listener,
83 self.router
84 .into_make_service_with_connect_info::<std::net::SocketAddr>(),
85 )
86 .await?;
87 Ok(())
88 }
89
90 async fn create_listener(&self, addr: &str) -> Result<tokio::net::TcpListener> {
91 tokio::net::TcpListener::bind(addr)
92 .await
93 .map_err(|e| anyhow::anyhow!("Failed to bind to {addr}: {e}"))
94 }
95}
96
97pub fn setup_api_server(ctx: &AppContext, events: Option<StartupEventSender>) -> Result<ApiServer> {
98 let rate_config = &ctx.config().rate_limits;
99
100 if rate_config.disabled {
101 if let Some(ref tx) = events {
102 tx.warning("Rate limiting disabled - development mode only");
103 }
104 }
105
106 let router = configure_routes(ctx, events.as_ref())?;
107 let router = apply_global_middleware(router, ctx)?;
108
109 Ok(ApiServer::new(router, events))
110}
111
112fn apply_global_middleware(router: Router, ctx: &AppContext) -> Result<Router> {
113 let mut router = router;
114
115 router = router.layer(DefaultBodyLimit::max(100 * 1024 * 1024));
116
117 let analytics_middleware = AnalyticsMiddleware::new(ctx)?;
118 router = router.layer(axum::middleware::from_fn({
119 let middleware = analytics_middleware;
120 move |req, next| {
121 let middleware = middleware.clone();
122 async move { middleware.track_request(req, next).await }
123 }
124 }));
125
126 let jwt_extractor = JwtContextExtractor::new(
127 systemprompt_models::SecretsBootstrap::jwt_secret()?,
128 ctx.db_pool(),
129 );
130 let global_context_middleware = ContextMiddleware::public(jwt_extractor);
131 router = router.layer(axum::middleware::from_fn({
132 let middleware = global_context_middleware;
133 move |req, next| {
134 let middleware = middleware.clone();
135 async move { middleware.handle(req, next).await }
136 }
137 }));
138
139 let session_middleware = SessionMiddleware::new(ctx)?;
140 router = router.layer(axum::middleware::from_fn({
141 let middleware = session_middleware;
142 move |req, next| {
143 let middleware = middleware.clone();
144 async move { middleware.handle(req, next).await }
145 }
146 }));
147
148 let cors = CorsMiddleware::build_layer(ctx.config())?;
149 router = router.layer(cors);
150
151 router = router.layer(axum::middleware::from_fn(remove_trailing_slash));
152
153 router = router.layer(axum::middleware::from_fn(inject_trace_header));
154
155 if ctx.config().content_negotiation.enabled {
156 router = router.layer(axum::middleware::from_fn(
157 crate::services::middleware::content_negotiation_middleware,
158 ));
159 }
160
161 if ctx.config().security_headers.enabled {
162 let security_config = ctx.config().security_headers.clone();
163 router = router.layer(axum::middleware::from_fn(move |req, next| {
164 let config = security_config.clone();
165 inject_security_headers(config, req, next)
166 }));
167 }
168
169 Ok(router)
170}
171
172pub async fn handle_root_discovery(
173 axum::extract::State(ctx): axum::extract::State<AppContext>,
174) -> impl axum::response::IntoResponse {
175 let base = &ctx.config().api_external_url;
176 let data = json!({
177 "name": format!("{} API", ctx.config().sitename),
178 "version": "1.0.0",
179 "description": "systemprompt.io OS API Gateway",
180 "endpoints": {
181 "health": format!("{}{}", base, ApiPaths::HEALTH),
182 "oauth": {
183 "href": format!("{}{}", base, ApiPaths::OAUTH_BASE),
184 "description": "OAuth2/OIDC authentication and WebAuthn",
185 "endpoints": {
186 "authorize": format!("{}{}", base, ApiPaths::OAUTH_AUTHORIZE),
187 "token": format!("{}{}", base, ApiPaths::OAUTH_TOKEN),
188 "userinfo": format!("{}{}/userinfo", base, ApiPaths::OAUTH_BASE),
189 "introspect": format!("{}{}/introspect", base, ApiPaths::OAUTH_BASE),
190 "revoke": format!("{}{}/revoke", base, ApiPaths::OAUTH_BASE),
191 "webauthn": format!("{}{}/webauthn", base, ApiPaths::OAUTH_BASE)
192 }
193 },
194 "core": {
195 "href": format!("{}{}", base, ApiPaths::CORE_BASE),
196 "description": "Core conversation, task, and artifact management",
197 "endpoints": {
198 "contexts": format!("{}{}", base, ApiPaths::CORE_CONTEXTS),
199 "tasks": format!("{}{}", base, ApiPaths::CORE_TASKS),
200 "artifacts": format!("{}{}", base, ApiPaths::CORE_ARTIFACTS)
201 }
202 },
203 "agents": {
204 "href": format!("{}{}", base, ApiPaths::AGENTS_REGISTRY),
205 "description": "A2A protocol agent registry and proxy",
206 "endpoints": {
207 "registry": format!("{}{}", base, ApiPaths::AGENTS_REGISTRY),
208 "proxy": format!("{}{}{{agent_id}}", base, ApiPaths::AGENTS_BASE)
209 }
210 },
211 "mcp": {
212 "href": format!("{}{}", base, ApiPaths::MCP_REGISTRY),
213 "description": "MCP server registry and lifecycle management",
214 "endpoints": {
215 "registry": format!("{}{}", base, ApiPaths::MCP_REGISTRY),
216 "proxy": format!("{}{}{{server_name}}", base, ApiPaths::MCP_BASE)
217 }
218 },
219 "stream": {
220 "href": format!("{}{}", base, ApiPaths::STREAM_BASE),
221 "description": "Server-Sent Events (SSE) for real-time updates",
222 "endpoints": {
223 "contexts": format!("{}{}", base, ApiPaths::STREAM_CONTEXTS)
224 }
225 }
226 },
227 "wellknown": {
228 "oauth": format!("{}{}", base, ApiPaths::WELLKNOWN_OAUTH_SERVER),
229 "agent": format!("{}{}", base, ApiPaths::WELLKNOWN_AGENT_CARD)
230 }
231 });
232
233 Json(SingleResponse::new(data))
234}
235
/// Extracts the numeric value of a `Key:   <number> kB` style line.
///
/// Only the first line of `content` that starts with `key` is considered; its
/// second whitespace-separated field is parsed as `u64`. Returns `None` when no
/// line matches, the line has no second field, or that field is not a number.
#[cfg(target_os = "linux")]
fn parse_proc_status_kb(content: &str, key: &str) -> Option<u64> {
    let matching_line = content
        .lines()
        .find(|candidate| candidate.starts_with(key))?;
    let value_field = matching_line.split_whitespace().nth(1)?;
    value_field.parse::<u64>().ok()
}
247
248#[cfg(target_os = "linux")]
249fn get_process_memory() -> Option<serde_json::Value> {
250 let content = std::fs::read_to_string("/proc/self/status").ok()?;
251
252 let rss_kb = parse_proc_status_kb(&content, "VmRSS:");
253 let virt_kb = parse_proc_status_kb(&content, "VmSize:");
254 let peak_kb = parse_proc_status_kb(&content, "VmPeak:");
255
256 Some(json!({
257 "rss_mb": rss_kb.map(|kb| kb / 1024),
258 "virtual_mb": virt_kb.map(|kb| kb / 1024),
259 "peak_mb": peak_kb.map(|kb| kb / 1024)
260 }))
261}
262
263#[cfg(not(target_os = "linux"))]
264fn get_process_memory() -> Option<serde_json::Value> {
265 None
266}
267
/// Formats a byte count with one decimal place and a binary-scaled unit
/// (`B`, `KB`, `MB`, `GB`, `TB`; each step is a factor of 1024).
///
/// Values of 1024 TB or more stay in `TB`. Values below 1024 — including
/// negative ones — are printed unscaled in `B`.
#[allow(clippy::cast_precision_loss)]
fn human_bytes(bytes: i64) -> String {
    const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB"];

    let mut scaled = bytes as f64;
    let mut label = UNITS[0];
    // Step up one unit at a time until the value fits or the units run out.
    for &next_label in &UNITS[1..] {
        if scaled < 1024.0 {
            break;
        }
        scaled /= 1024.0;
        label = next_label;
    }

    format!("{scaled:.1} {label}")
}
279
280async fn get_disk_usage() -> Option<serde_json::Value> {
281 let output = tokio::process::Command::new("df")
282 .args(["-B1", "--output=size,used,avail", "."])
283 .output()
284 .await
285 .ok()?;
286
287 let stdout = String::from_utf8_lossy(&output.stdout);
288 let line = stdout.lines().nth(1)?;
289 let parts: Vec<&str> = line.split_whitespace().collect();
290 if parts.len() < 3 {
291 return None;
292 }
293
294 let total: u64 = parts[0].parse().ok()?;
295 let used: u64 = parts[1].parse().ok()?;
296 let available: u64 = parts[2].parse().ok()?;
297
298 #[allow(clippy::cast_precision_loss)]
299 let usage_pct = if total > 0 {
300 (used as f64 / total as f64) * 100.0
301 } else {
302 0.0
303 };
304
305 #[allow(clippy::cast_possible_wrap)]
306 Some(json!({
307 "total": human_bytes(total as i64),
308 "used": human_bytes(used as i64),
309 "available": human_bytes(available as i64),
310 "usage_percent": (usage_pct * 10.0).round() / 10.0
311 }))
312}
313
314async fn get_system_stats(
315 db: &dyn systemprompt_database::DatabaseProvider,
316) -> Option<serde_json::Value> {
317 let db_size_fut = db.fetch_one(&DB_SIZE_QUERY, &[]);
318 let table_sizes_fut = db.fetch_all(&TABLE_SIZES_QUERY, &[]);
319 let table_count_fut = db.fetch_one(&TABLE_COUNT_QUERY, &[]);
320 let audit_fut = db.fetch_optional(&AUDIT_LOG_QUERY, &[]);
321 let disk_fut = get_disk_usage();
322
323 let (db_size, table_sizes, table_count, audit, disk) = tokio::join!(
324 db_size_fut,
325 table_sizes_fut,
326 table_count_fut,
327 audit_fut,
328 disk_fut
329 );
330
331 let database =
332 if let (Ok(size_row), Ok(tables), Ok(count_row)) = (&db_size, &table_sizes, &table_count) {
333 let size_bytes = size_row
334 .get("size_bytes")
335 .and_then(serde_json::Value::as_i64)
336 .unwrap_or(0);
337 let db_name = size_row
338 .get("db_name")
339 .and_then(serde_json::Value::as_str)
340 .unwrap_or("unknown");
341 let tbl_count = count_row
342 .get("count")
343 .and_then(serde_json::Value::as_i64)
344 .unwrap_or(0);
345
346 let top_tables: Vec<serde_json::Value> = tables
347 .iter()
348 .map(|row| {
349 let name = row
350 .get("table_name")
351 .and_then(serde_json::Value::as_str)
352 .unwrap_or("?");
353 let total = row
354 .get("total_bytes")
355 .and_then(serde_json::Value::as_i64)
356 .unwrap_or(0);
357 let rows = row
358 .get("row_estimate")
359 .and_then(serde_json::Value::as_i64)
360 .unwrap_or(0);
361 json!({
362 "table_name": name,
363 "total_size": human_bytes(total),
364 "total_size_bytes": total,
365 "row_estimate": rows
366 })
367 })
368 .collect();
369
370 Some(json!({
371 "name": db_name,
372 "total_size": human_bytes(size_bytes),
373 "total_size_bytes": size_bytes,
374 "table_count": tbl_count,
375 "top_tables": top_tables
376 }))
377 } else {
378 None
379 };
380
381 let logs = audit.ok().flatten().map(|row| {
382 let row_count = row
383 .get("row_count")
384 .and_then(serde_json::Value::as_i64)
385 .unwrap_or(0);
386 let size_bytes = row
387 .get("size_bytes")
388 .and_then(serde_json::Value::as_i64)
389 .unwrap_or(0);
390 json!({
391 "audit_rows": row_count,
392 "audit_size": human_bytes(size_bytes),
393 "audit_size_bytes": size_bytes,
394 "oldest": row.get("oldest"),
395 "newest": row.get("newest")
396 })
397 });
398
399 Some(json!({
400 "database": database,
401 "disk": disk,
402 "logs": logs
403 }))
404}
405
406pub async fn handle_health(
407 axum::extract::State(ctx): axum::extract::State<AppContext>,
408) -> impl axum::response::IntoResponse {
409 use axum::http::StatusCode;
410 use systemprompt_database::{DatabaseProvider, ServiceRepository};
411
412 let start = std::time::Instant::now();
413
414 let (db_status, db_latency_ms) = {
415 let db_start = std::time::Instant::now();
416 let status = match ctx.db_pool().fetch_optional(&HEALTH_CHECK_QUERY, &[]).await {
417 Ok(_) => "healthy",
418 Err(_) => "unhealthy",
419 };
420 (status, db_start.elapsed().as_millis())
421 };
422
423 let (agent_count, agent_status, mcp_count, mcp_status) =
424 match ServiceRepository::new(ctx.db_pool()) {
425 Ok(service_repo) => {
426 let (ac, as_) = match service_repo.count_running_services("agent").await {
427 Ok(count) if count > 0 => (count, "healthy"),
428 Ok(_) => (0, "none"),
429 Err(_) => (0, "error"),
430 };
431 let (mc, ms) = match service_repo.count_running_services("mcp").await {
432 Ok(count) if count > 0 => (count, "healthy"),
433 Ok(_) => (0, "none"),
434 Err(_) => (0, "error"),
435 };
436 (ac, as_, mc, ms)
437 },
438 Err(_) => (0, "error", 0, "error"),
439 };
440
441 let web_dir = AppPaths::get()
442 .map(|p| p.web().dist().to_path_buf())
443 .unwrap_or_else(|e| {
444 tracing::debug!(error = %e, "Failed to get web dist path, using default");
445 std::path::PathBuf::from("/var/www/html/dist")
446 });
447 let sitemap_exists = web_dir.join("sitemap.xml").exists();
448 let index_exists = web_dir.join("index.html").exists();
449
450 let db_healthy = db_status == "healthy";
451 let services_ok = agent_status != "error" && mcp_status != "error";
452 let content_ok = sitemap_exists && index_exists;
453
454 let (overall_status, http_status) = match (db_healthy, services_ok && content_ok) {
455 (false, _) => ("unhealthy", StatusCode::SERVICE_UNAVAILABLE),
456 (true, false) => ("degraded", StatusCode::OK),
457 (true, true) => ("healthy", StatusCode::OK),
458 };
459
460 let system_stats = get_system_stats(ctx.db_pool().as_ref()).await;
461
462 let check_duration_ms = start.elapsed().as_millis();
463 let memory = get_process_memory();
464
465 let data = json!({
466 "status": overall_status,
467 "timestamp": chrono::Utc::now().to_rfc3339(),
468 "version": env!("CARGO_PKG_VERSION"),
469 "checks": {
470 "database": {
471 "status": db_status,
472 "latency_ms": db_latency_ms
473 },
474 "agents": {
475 "status": agent_status,
476 "count": agent_count
477 },
478 "mcp": {
479 "status": mcp_status,
480 "count": mcp_count
481 },
482 "static_content": {
483 "status": if content_ok { "healthy" } else { "degraded" },
484 "index_html": index_exists,
485 "sitemap_xml": sitemap_exists
486 }
487 },
488 "memory": memory,
489 "system": system_stats,
490 "response_time_ms": check_duration_ms
491 });
492
493 (http_status, Json(data))
494}
495
496pub async fn handle_core_discovery(
497 axum::extract::State(ctx): axum::extract::State<AppContext>,
498) -> impl axum::response::IntoResponse {
499 let base = &ctx.config().api_external_url;
500 let data = json!({
501 "name": "Core Services",
502 "description": "Core conversation, task, and artifact management APIs",
503 "endpoints": {
504 "contexts": {
505 "href": format!("{}{}", base, ApiPaths::CORE_CONTEXTS),
506 "description": "Conversation context management",
507 "methods": ["GET", "POST", "DELETE"]
508 },
509 "tasks": {
510 "href": format!("{}{}", base, ApiPaths::CORE_TASKS),
511 "description": "Task management for agent operations",
512 "methods": ["GET", "POST", "PUT", "DELETE"]
513 },
514 "artifacts": {
515 "href": format!("{}{}", base, ApiPaths::CORE_ARTIFACTS),
516 "description": "Artifact storage and retrieval",
517 "methods": ["GET", "POST", "DELETE"]
518 },
519 "oauth": {
520 "href": format!("{}{}", base, ApiPaths::OAUTH_BASE),
521 "description": "OAuth2/OIDC authentication endpoints"
522 }
523 }
524 });
525 Json(SingleResponse::new(data))
526}
527
528pub async fn handle_agents_discovery(
529 axum::extract::State(ctx): axum::extract::State<AppContext>,
530) -> impl axum::response::IntoResponse {
531 let base = &ctx.config().api_external_url;
532 let data = json!({
533 "name": "Agent Services",
534 "description": "A2A protocol agent registry and proxy",
535 "endpoints": {
536 "registry": {
537 "href": format!("{}{}", base, ApiPaths::AGENTS_REGISTRY),
538 "description": "List and discover available agents",
539 "methods": ["GET"]
540 },
541 "proxy": {
542 "href": format!("{}{}/<agent_id>/", base, ApiPaths::AGENTS_BASE),
543 "description": "Proxy requests to specific agents",
544 "methods": ["GET", "POST"]
545 }
546 }
547 });
548 Json(SingleResponse::new(data))
549}
550
551pub async fn handle_mcp_discovery(
552 axum::extract::State(ctx): axum::extract::State<AppContext>,
553) -> impl axum::response::IntoResponse {
554 let base = &ctx.config().api_external_url;
555 let data = json!({
556 "name": "MCP Services",
557 "description": "Model Context Protocol server registry and proxy",
558 "endpoints": {
559 "registry": {
560 "href": format!("{}{}", base, ApiPaths::MCP_REGISTRY),
561 "description": "List and discover available MCP servers",
562 "methods": ["GET"]
563 },
564 "proxy": {
565 "href": format!("{}{}/<server_name>/mcp", base, ApiPaths::MCP_BASE),
566 "description": "Proxy requests to specific MCP servers",
567 "methods": ["GET", "POST"]
568 }
569 }
570 });
571 Json(SingleResponse::new(data))
572}
573
574pub fn discovery_router(ctx: &AppContext) -> Router {
575 Router::new()
576 .route(ApiPaths::DISCOVERY, get(handle_root_discovery))
577 .route(ApiPaths::HEALTH, get(handle_health))
578 .route("/health", get(handle_health))
579 .route(ApiPaths::CORE_BASE, get(handle_core_discovery))
580 .route(ApiPaths::AGENTS_BASE, get(handle_agents_discovery))
581 .route(ApiPaths::MCP_BASE, get(handle_mcp_discovery))
582 .with_state(ctx.clone())
583}