Skip to main content

systemprompt_api/services/server/
builder.rs

1use anyhow::Result;
2use axum::extract::DefaultBodyLimit;
3use axum::routing::get;
4use axum::{Json, Router};
5use serde_json::json;
6use systemprompt_database::DatabaseQuery;
7use systemprompt_models::api::SingleResponse;
8use systemprompt_models::modules::ApiPaths;
9use systemprompt_models::AppPaths;
10use systemprompt_runtime::AppContext;
11use systemprompt_traits::{StartupEvent, StartupEventExt, StartupEventSender};
12
13use super::routes::configure_routes;
14use crate::models::ServerConfig;
15use crate::services::middleware::{
16    inject_security_headers, inject_trace_header, remove_trailing_slash, AnalyticsMiddleware,
17    ContextMiddleware, CorsMiddleware, JwtContextExtractor, SessionMiddleware,
18};
19
20const HEALTH_CHECK_QUERY: DatabaseQuery = DatabaseQuery::new("SELECT 1");
21
22const DB_SIZE_QUERY: DatabaseQuery = DatabaseQuery::new(
23    "SELECT pg_database_size(current_database()) as size_bytes, current_database() as db_name",
24);
25
26const TABLE_SIZES_QUERY: DatabaseQuery = DatabaseQuery::new(
27    "SELECT relname as table_name, pg_total_relation_size(relid) as total_bytes, n_live_tup as \
28     row_estimate FROM pg_stat_user_tables ORDER BY pg_total_relation_size(relid) DESC LIMIT 15",
29);
30
31const TABLE_COUNT_QUERY: DatabaseQuery =
32    DatabaseQuery::new("SELECT COUNT(*) as count FROM pg_stat_user_tables");
33
34const AUDIT_LOG_QUERY: DatabaseQuery = DatabaseQuery::new(
35    "SELECT COUNT(*) as row_count, pg_total_relation_size('audit_log') as size_bytes, \
36     MIN(created_at) as oldest, MAX(created_at) as newest FROM audit_log",
37);
38
39#[derive(Debug)]
40pub struct ApiServer {
41    router: Router,
42    _config: ServerConfig,
43    events: Option<StartupEventSender>,
44}
45
46impl ApiServer {
47    pub fn new(router: Router, events: Option<StartupEventSender>) -> Self {
48        Self::with_config(router, ServerConfig::default(), events)
49    }
50
51    pub const fn with_config(
52        router: Router,
53        config: ServerConfig,
54        events: Option<StartupEventSender>,
55    ) -> Self {
56        Self {
57            router,
58            _config: config,
59            events,
60        }
61    }
62
63    pub async fn serve(self, addr: &str) -> Result<()> {
64        if let Some(ref tx) = self.events {
65            if tx
66                .unbounded_send(StartupEvent::ServerBinding {
67                    address: addr.to_string(),
68                })
69                .is_err()
70            {
71                tracing::debug!("Startup event receiver dropped");
72            }
73        }
74
75        let listener = self.create_listener(addr).await?;
76
77        if let Some(ref tx) = self.events {
78            tx.server_listening(addr, std::process::id());
79        }
80
81        axum::serve(
82            listener,
83            self.router
84                .into_make_service_with_connect_info::<std::net::SocketAddr>(),
85        )
86        .await?;
87        Ok(())
88    }
89
90    async fn create_listener(&self, addr: &str) -> Result<tokio::net::TcpListener> {
91        tokio::net::TcpListener::bind(addr)
92            .await
93            .map_err(|e| anyhow::anyhow!("Failed to bind to {addr}: {e}"))
94    }
95}
96
97pub fn setup_api_server(ctx: &AppContext, events: Option<StartupEventSender>) -> Result<ApiServer> {
98    let rate_config = &ctx.config().rate_limits;
99
100    if rate_config.disabled {
101        if let Some(ref tx) = events {
102            tx.warning("Rate limiting disabled - development mode only");
103        }
104    }
105
106    let router = configure_routes(ctx, events.as_ref())?;
107    let router = apply_global_middleware(router, ctx)?;
108
109    Ok(ApiServer::new(router, events))
110}
111
112fn apply_global_middleware(router: Router, ctx: &AppContext) -> Result<Router> {
113    let mut router = router;
114
115    router = router.layer(DefaultBodyLimit::max(100 * 1024 * 1024));
116
117    let analytics_middleware = AnalyticsMiddleware::new(ctx)?;
118    router = router.layer(axum::middleware::from_fn({
119        let middleware = analytics_middleware;
120        move |req, next| {
121            let middleware = middleware.clone();
122            async move { middleware.track_request(req, next).await }
123        }
124    }));
125
126    let jwt_extractor = JwtContextExtractor::new(
127        systemprompt_models::SecretsBootstrap::jwt_secret()?,
128        ctx.db_pool(),
129    );
130    let global_context_middleware = ContextMiddleware::public(jwt_extractor);
131    router = router.layer(axum::middleware::from_fn({
132        let middleware = global_context_middleware;
133        move |req, next| {
134            let middleware = middleware.clone();
135            async move { middleware.handle(req, next).await }
136        }
137    }));
138
139    let session_middleware = SessionMiddleware::new(ctx)?;
140    router = router.layer(axum::middleware::from_fn({
141        let middleware = session_middleware;
142        move |req, next| {
143            let middleware = middleware.clone();
144            async move { middleware.handle(req, next).await }
145        }
146    }));
147
148    let cors = CorsMiddleware::build_layer(ctx.config())?;
149    router = router.layer(cors);
150
151    router = router.layer(axum::middleware::from_fn(remove_trailing_slash));
152
153    router = router.layer(axum::middleware::from_fn(inject_trace_header));
154
155    if ctx.config().content_negotiation.enabled {
156        router = router.layer(axum::middleware::from_fn(
157            crate::services::middleware::content_negotiation_middleware,
158        ));
159    }
160
161    if ctx.config().security_headers.enabled {
162        let security_config = ctx.config().security_headers.clone();
163        router = router.layer(axum::middleware::from_fn(move |req, next| {
164            let config = security_config.clone();
165            inject_security_headers(config, req, next)
166        }));
167    }
168
169    Ok(router)
170}
171
172pub async fn handle_root_discovery(
173    axum::extract::State(ctx): axum::extract::State<AppContext>,
174) -> impl axum::response::IntoResponse {
175    let base = &ctx.config().api_external_url;
176    let data = json!({
177        "name": format!("{} API", ctx.config().sitename),
178        "version": "1.0.0",
179        "description": "systemprompt.io OS API Gateway",
180        "endpoints": {
181            "health": format!("{}{}", base, ApiPaths::HEALTH),
182            "oauth": {
183                "href": format!("{}{}", base, ApiPaths::OAUTH_BASE),
184                "description": "OAuth2/OIDC authentication and WebAuthn",
185                "endpoints": {
186                    "authorize": format!("{}{}", base, ApiPaths::OAUTH_AUTHORIZE),
187                    "token": format!("{}{}", base, ApiPaths::OAUTH_TOKEN),
188                    "userinfo": format!("{}{}/userinfo", base, ApiPaths::OAUTH_BASE),
189                    "introspect": format!("{}{}/introspect", base, ApiPaths::OAUTH_BASE),
190                    "revoke": format!("{}{}/revoke", base, ApiPaths::OAUTH_BASE),
191                    "webauthn": format!("{}{}/webauthn", base, ApiPaths::OAUTH_BASE)
192                }
193            },
194            "core": {
195                "href": format!("{}{}", base, ApiPaths::CORE_BASE),
196                "description": "Core conversation, task, and artifact management",
197                "endpoints": {
198                    "contexts": format!("{}{}", base, ApiPaths::CORE_CONTEXTS),
199                    "tasks": format!("{}{}", base, ApiPaths::CORE_TASKS),
200                    "artifacts": format!("{}{}", base, ApiPaths::CORE_ARTIFACTS)
201                }
202            },
203            "agents": {
204                "href": format!("{}{}", base, ApiPaths::AGENTS_REGISTRY),
205                "description": "A2A protocol agent registry and proxy",
206                "endpoints": {
207                    "registry": format!("{}{}", base, ApiPaths::AGENTS_REGISTRY),
208                    "proxy": format!("{}{}{{agent_id}}", base, ApiPaths::AGENTS_BASE)
209                }
210            },
211            "mcp": {
212                "href": format!("{}{}", base, ApiPaths::MCP_REGISTRY),
213                "description": "MCP server registry and lifecycle management",
214                "endpoints": {
215                    "registry": format!("{}{}", base, ApiPaths::MCP_REGISTRY),
216                    "proxy": format!("{}{}{{server_name}}", base, ApiPaths::MCP_BASE)
217                }
218            },
219            "stream": {
220                "href": format!("{}{}", base, ApiPaths::STREAM_BASE),
221                "description": "Server-Sent Events (SSE) for real-time updates",
222                "endpoints": {
223                    "contexts": format!("{}{}", base, ApiPaths::STREAM_CONTEXTS)
224                }
225            }
226        },
227        "wellknown": {
228            "oauth": format!("{}{}", base, ApiPaths::WELLKNOWN_OAUTH_SERVER),
229            "agent": format!("{}{}", base, ApiPaths::WELLKNOWN_AGENT_CARD)
230        }
231    });
232
233    Json(SingleResponse::new(data))
234}
235
/// Returns the numeric value of the first line in `content` that starts with `key`.
///
/// Intended for `/proc/self/status`-style text such as `VmRSS:\t  1234 kB`. The value is the
/// second whitespace-separated token. Returns `None` in either of these cases:
/// - no line starts with `key`;
/// - the first matching line's second token is missing or not a `u64`.
#[cfg(target_os = "linux")]
fn parse_proc_status_kb(content: &str, key: &str) -> Option<u64> {
    let matching_line = content.lines().find(|line| line.starts_with(key))?;
    let value = matching_line.split_whitespace().nth(1)?;
    value.parse::<u64>().ok()
}
247
248#[cfg(target_os = "linux")]
249fn get_process_memory() -> Option<serde_json::Value> {
250    let content = std::fs::read_to_string("/proc/self/status").ok()?;
251
252    let rss_kb = parse_proc_status_kb(&content, "VmRSS:");
253    let virt_kb = parse_proc_status_kb(&content, "VmSize:");
254    let peak_kb = parse_proc_status_kb(&content, "VmPeak:");
255
256    Some(json!({
257        "rss_mb": rss_kb.map(|kb| kb / 1024),
258        "virtual_mb": virt_kb.map(|kb| kb / 1024),
259        "peak_mb": peak_kb.map(|kb| kb / 1024)
260    }))
261}
262
263#[cfg(not(target_os = "linux"))]
264fn get_process_memory() -> Option<serde_json::Value> {
265    None
266}
267
/// Formats a byte count with one decimal place using 1024-based units (`B` up to `TB`).
///
/// Values of 1024 TB or more stay in `TB`. Negative inputs are not scaled (e.g. `-1.0 B`).
// Display-only: f64 rounding above 2^53 bytes is irrelevant at one decimal place.
#[allow(clippy::cast_precision_loss)]
fn human_bytes(bytes: i64) -> String {
    const UNITS: [&str; 5] = ["B", "KB", "MB", "GB", "TB"];

    let mut scaled = bytes as f64;
    let mut unit = UNITS[0];
    for &larger in &UNITS[1..] {
        if scaled < 1024.0 {
            break;
        }
        scaled /= 1024.0;
        unit = larger;
    }
    format!("{scaled:.1} {unit}")
}
279
280async fn get_disk_usage() -> Option<serde_json::Value> {
281    let output = tokio::process::Command::new("df")
282        .args(["-B1", "--output=size,used,avail", "."])
283        .output()
284        .await
285        .ok()?;
286
287    let stdout = String::from_utf8_lossy(&output.stdout);
288    let line = stdout.lines().nth(1)?;
289    let parts: Vec<&str> = line.split_whitespace().collect();
290    if parts.len() < 3 {
291        return None;
292    }
293
294    let total: u64 = parts[0].parse().ok()?;
295    let used: u64 = parts[1].parse().ok()?;
296    let available: u64 = parts[2].parse().ok()?;
297
298    #[allow(clippy::cast_precision_loss)]
299    let usage_pct = if total > 0 {
300        (used as f64 / total as f64) * 100.0
301    } else {
302        0.0
303    };
304
305    #[allow(clippy::cast_possible_wrap)]
306    Some(json!({
307        "total": human_bytes(total as i64),
308        "used": human_bytes(used as i64),
309        "available": human_bytes(available as i64),
310        "usage_percent": (usage_pct * 10.0).round() / 10.0
311    }))
312}
313
314async fn get_system_stats(
315    db: &dyn systemprompt_database::DatabaseProvider,
316) -> Option<serde_json::Value> {
317    let db_size_fut = db.fetch_one(&DB_SIZE_QUERY, &[]);
318    let table_sizes_fut = db.fetch_all(&TABLE_SIZES_QUERY, &[]);
319    let table_count_fut = db.fetch_one(&TABLE_COUNT_QUERY, &[]);
320    let audit_fut = db.fetch_optional(&AUDIT_LOG_QUERY, &[]);
321    let disk_fut = get_disk_usage();
322
323    let (db_size, table_sizes, table_count, audit, disk) = tokio::join!(
324        db_size_fut,
325        table_sizes_fut,
326        table_count_fut,
327        audit_fut,
328        disk_fut
329    );
330
331    let database =
332        if let (Ok(size_row), Ok(tables), Ok(count_row)) = (&db_size, &table_sizes, &table_count) {
333            let size_bytes = size_row
334                .get("size_bytes")
335                .and_then(serde_json::Value::as_i64)
336                .unwrap_or(0);
337            let db_name = size_row
338                .get("db_name")
339                .and_then(serde_json::Value::as_str)
340                .unwrap_or("unknown");
341            let tbl_count = count_row
342                .get("count")
343                .and_then(serde_json::Value::as_i64)
344                .unwrap_or(0);
345
346            let top_tables: Vec<serde_json::Value> = tables
347                .iter()
348                .map(|row| {
349                    let name = row
350                        .get("table_name")
351                        .and_then(serde_json::Value::as_str)
352                        .unwrap_or("?");
353                    let total = row
354                        .get("total_bytes")
355                        .and_then(serde_json::Value::as_i64)
356                        .unwrap_or(0);
357                    let rows = row
358                        .get("row_estimate")
359                        .and_then(serde_json::Value::as_i64)
360                        .unwrap_or(0);
361                    json!({
362                        "table_name": name,
363                        "total_size": human_bytes(total),
364                        "total_size_bytes": total,
365                        "row_estimate": rows
366                    })
367                })
368                .collect();
369
370            Some(json!({
371                "name": db_name,
372                "total_size": human_bytes(size_bytes),
373                "total_size_bytes": size_bytes,
374                "table_count": tbl_count,
375                "top_tables": top_tables
376            }))
377        } else {
378            None
379        };
380
381    let logs = audit.ok().flatten().map(|row| {
382        let row_count = row
383            .get("row_count")
384            .and_then(serde_json::Value::as_i64)
385            .unwrap_or(0);
386        let size_bytes = row
387            .get("size_bytes")
388            .and_then(serde_json::Value::as_i64)
389            .unwrap_or(0);
390        json!({
391            "audit_rows": row_count,
392            "audit_size": human_bytes(size_bytes),
393            "audit_size_bytes": size_bytes,
394            "oldest": row.get("oldest"),
395            "newest": row.get("newest")
396        })
397    });
398
399    Some(json!({
400        "database": database,
401        "disk": disk,
402        "logs": logs
403    }))
404}
405
406pub async fn handle_health(
407    axum::extract::State(ctx): axum::extract::State<AppContext>,
408) -> impl axum::response::IntoResponse {
409    use axum::http::StatusCode;
410    use systemprompt_database::{DatabaseProvider, ServiceRepository};
411
412    let start = std::time::Instant::now();
413
414    let (db_status, db_latency_ms) = {
415        let db_start = std::time::Instant::now();
416        let status = match ctx.db_pool().fetch_optional(&HEALTH_CHECK_QUERY, &[]).await {
417            Ok(_) => "healthy",
418            Err(_) => "unhealthy",
419        };
420        (status, db_start.elapsed().as_millis())
421    };
422
423    let (agent_count, agent_status, mcp_count, mcp_status) =
424        match ServiceRepository::new(ctx.db_pool()) {
425            Ok(service_repo) => {
426                let (ac, as_) = match service_repo.count_running_services("agent").await {
427                    Ok(count) if count > 0 => (count, "healthy"),
428                    Ok(_) => (0, "none"),
429                    Err(_) => (0, "error"),
430                };
431                let (mc, ms) = match service_repo.count_running_services("mcp").await {
432                    Ok(count) if count > 0 => (count, "healthy"),
433                    Ok(_) => (0, "none"),
434                    Err(_) => (0, "error"),
435                };
436                (ac, as_, mc, ms)
437            },
438            Err(_) => (0, "error", 0, "error"),
439        };
440
441    let web_dir = AppPaths::get()
442        .map(|p| p.web().dist().to_path_buf())
443        .unwrap_or_else(|e| {
444            tracing::debug!(error = %e, "Failed to get web dist path, using default");
445            std::path::PathBuf::from("/var/www/html/dist")
446        });
447    let sitemap_exists = web_dir.join("sitemap.xml").exists();
448    let index_exists = web_dir.join("index.html").exists();
449
450    let db_healthy = db_status == "healthy";
451    let services_ok = agent_status != "error" && mcp_status != "error";
452    let content_ok = sitemap_exists && index_exists;
453
454    let (overall_status, http_status) = match (db_healthy, services_ok && content_ok) {
455        (false, _) => ("unhealthy", StatusCode::SERVICE_UNAVAILABLE),
456        (true, false) => ("degraded", StatusCode::OK),
457        (true, true) => ("healthy", StatusCode::OK),
458    };
459
460    let system_stats = get_system_stats(ctx.db_pool().as_ref()).await;
461
462    let check_duration_ms = start.elapsed().as_millis();
463    let memory = get_process_memory();
464
465    let data = json!({
466        "status": overall_status,
467        "timestamp": chrono::Utc::now().to_rfc3339(),
468        "version": env!("CARGO_PKG_VERSION"),
469        "checks": {
470            "database": {
471                "status": db_status,
472                "latency_ms": db_latency_ms
473            },
474            "agents": {
475                "status": agent_status,
476                "count": agent_count
477            },
478            "mcp": {
479                "status": mcp_status,
480                "count": mcp_count
481            },
482            "static_content": {
483                "status": if content_ok { "healthy" } else { "degraded" },
484                "index_html": index_exists,
485                "sitemap_xml": sitemap_exists
486            }
487        },
488        "memory": memory,
489        "system": system_stats,
490        "response_time_ms": check_duration_ms
491    });
492
493    (http_status, Json(data))
494}
495
496pub async fn handle_core_discovery(
497    axum::extract::State(ctx): axum::extract::State<AppContext>,
498) -> impl axum::response::IntoResponse {
499    let base = &ctx.config().api_external_url;
500    let data = json!({
501        "name": "Core Services",
502        "description": "Core conversation, task, and artifact management APIs",
503        "endpoints": {
504            "contexts": {
505                "href": format!("{}{}", base, ApiPaths::CORE_CONTEXTS),
506                "description": "Conversation context management",
507                "methods": ["GET", "POST", "DELETE"]
508            },
509            "tasks": {
510                "href": format!("{}{}", base, ApiPaths::CORE_TASKS),
511                "description": "Task management for agent operations",
512                "methods": ["GET", "POST", "PUT", "DELETE"]
513            },
514            "artifacts": {
515                "href": format!("{}{}", base, ApiPaths::CORE_ARTIFACTS),
516                "description": "Artifact storage and retrieval",
517                "methods": ["GET", "POST", "DELETE"]
518            },
519            "oauth": {
520                "href": format!("{}{}", base, ApiPaths::OAUTH_BASE),
521                "description": "OAuth2/OIDC authentication endpoints"
522            }
523        }
524    });
525    Json(SingleResponse::new(data))
526}
527
528pub async fn handle_agents_discovery(
529    axum::extract::State(ctx): axum::extract::State<AppContext>,
530) -> impl axum::response::IntoResponse {
531    let base = &ctx.config().api_external_url;
532    let data = json!({
533        "name": "Agent Services",
534        "description": "A2A protocol agent registry and proxy",
535        "endpoints": {
536            "registry": {
537                "href": format!("{}{}", base, ApiPaths::AGENTS_REGISTRY),
538                "description": "List and discover available agents",
539                "methods": ["GET"]
540            },
541            "proxy": {
542                "href": format!("{}{}/<agent_id>/", base, ApiPaths::AGENTS_BASE),
543                "description": "Proxy requests to specific agents",
544                "methods": ["GET", "POST"]
545            }
546        }
547    });
548    Json(SingleResponse::new(data))
549}
550
551pub async fn handle_mcp_discovery(
552    axum::extract::State(ctx): axum::extract::State<AppContext>,
553) -> impl axum::response::IntoResponse {
554    let base = &ctx.config().api_external_url;
555    let data = json!({
556        "name": "MCP Services",
557        "description": "Model Context Protocol server registry and proxy",
558        "endpoints": {
559            "registry": {
560                "href": format!("{}{}", base, ApiPaths::MCP_REGISTRY),
561                "description": "List and discover available MCP servers",
562                "methods": ["GET"]
563            },
564            "proxy": {
565                "href": format!("{}{}/<server_name>/mcp", base, ApiPaths::MCP_BASE),
566                "description": "Proxy requests to specific MCP servers",
567                "methods": ["GET", "POST"]
568            }
569        }
570    });
571    Json(SingleResponse::new(data))
572}
573
574pub fn discovery_router(ctx: &AppContext) -> Router {
575    Router::new()
576        .route(ApiPaths::DISCOVERY, get(handle_root_discovery))
577        .route(ApiPaths::HEALTH, get(handle_health))
578        .route("/health", get(handle_health))
579        .route(ApiPaths::CORE_BASE, get(handle_core_discovery))
580        .route(ApiPaths::AGENTS_BASE, get(handle_agents_discovery))
581        .route(ApiPaths::MCP_BASE, get(handle_mcp_discovery))
582        .with_state(ctx.clone())
583}