Skip to main content

systemprompt_api/routes/gateway/
models.rs

1use axum::Json;
2use axum::http::{HeaderMap, StatusCode};
3use serde::Serialize;
4use std::collections::BTreeMap;
5use systemprompt_config::ProfileBootstrap;
6use systemprompt_identifiers::headers::INFERENCE_PROTOCOL;
7use systemprompt_models::profile::{ApiSurface, ProviderRegistry};
8
9#[derive(Debug, Serialize)]
10pub struct RootResponse {
11    pub service: &'static str,
12    pub version: &'static str,
13    pub endpoints: Vec<&'static str>,
14}
15
16pub async fn root() -> Json<RootResponse> {
17    Json(RootResponse {
18        service: "systemprompt-gateway",
19        version: env!("CARGO_PKG_VERSION"),
20        endpoints: vec!["/v1/models", "/v1/messages"],
21    })
22}
23
24#[derive(Debug, Serialize)]
25pub struct ModelEntry {
26    #[serde(rename = "type")]
27    pub kind: &'static str,
28    pub id: String,
29    pub display_name: String,
30    pub created_at: String,
31}
32
33#[derive(Debug, Serialize)]
34pub struct ModelsResponse {
35    pub data: Vec<ModelEntry>,
36    pub has_more: bool,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub first_id: Option<String>,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub last_id: Option<String>,
41}
42
43pub async fn list(headers: HeaderMap) -> Result<Json<ModelsResponse>, (StatusCode, String)> {
44    let profile = ProfileBootstrap::get().map_err(|e| {
45        (
46            StatusCode::SERVICE_UNAVAILABLE,
47            format!("Profile not ready: {e}"),
48        )
49    })?;
50
51    profile
52        .gateway
53        .as_ref()
54        .and_then(systemprompt_models::profile::GatewayState::resolved)
55        .filter(|g| g.enabled)
56        .ok_or_else(|| (StatusCode::NOT_FOUND, "Gateway not enabled".to_owned()))?;
57
58    let surfaces = surfaces_from_header(&headers)?;
59    let entries = model_entries(&profile.providers, &surfaces);
60    let first_id = entries.first().map(|e| e.id.clone());
61    let last_id = entries.last().map(|e| e.id.clone());
62
63    Ok(Json(ModelsResponse {
64        data: entries,
65        has_more: false,
66        first_id,
67        last_id,
68    }))
69}
70
71/// Resolve the `x-inference-protocol` selection header into API surfaces. An
72/// absent or empty header yields the full catalog (empty slice); an
73/// unrecognised tag, or `backend` (never a client surface), is a
74/// misconfiguration and fails with `400` rather than silently widening or
75/// leaking the advertised set.
76fn surfaces_from_header(headers: &HeaderMap) -> Result<Vec<ApiSurface>, (StatusCode, String)> {
77    let Some(raw) = headers
78        .get(INFERENCE_PROTOCOL)
79        .and_then(|v| v.to_str().ok())
80    else {
81        return Ok(Vec::new());
82    };
83    let mut surfaces = Vec::new();
84    for tag in raw.split(',').map(str::trim).filter(|t| !t.is_empty()) {
85        let surface = ApiSurface::from_tag(tag)
86            .filter(|s| *s != ApiSurface::Backend)
87            .ok_or_else(|| {
88                (
89                    StatusCode::BAD_REQUEST,
90                    format!("unknown {INFERENCE_PROTOCOL} value: {tag}"),
91                )
92            })?;
93        surfaces.push(surface);
94    }
95    Ok(surfaces)
96}
97
98pub fn model_entries(registry: &ProviderRegistry, surfaces: &[ApiSurface]) -> Vec<ModelEntry> {
99    let mut by_id: BTreeMap<String, ModelEntry> = BTreeMap::new();
100    for id in registry.advertised_model_ids(surfaces) {
101        by_id.insert(
102            id.clone(),
103            ModelEntry {
104                kind: "model",
105                display_name: humanize_model_id(&id),
106                id,
107                created_at: "1970-01-01T00:00:00Z".to_owned(),
108            },
109        );
110    }
111    by_id.into_values().collect()
112}
113
/// Turn a hyphenated model id into a display label.
///
/// The id is split on `-`; each non-empty segment has its first character
/// ASCII-uppercased (non-ASCII characters are left as they are) and the
/// segments are joined with single spaces. Empty segments, which come from
/// leading, trailing or consecutive hyphens, are skipped so the label never
/// carries stray or doubled spaces. An empty id yields an empty label.
fn humanize_model_id(id: &str) -> String {
    id.split('-')
        .filter(|part| !part.is_empty())
        .map(|part| {
            let mut chars = part.chars();
            match chars.next() {
                Some(first) => {
                    let mut word = String::with_capacity(part.len());
                    word.push(first.to_ascii_uppercase());
                    word.push_str(chars.as_str());
                    word
                }
                None => String::new(),
            }
        })
        .collect::<Vec<_>>()
        .join(" ")
}