systemprompt_api/routes/gateway/
models.rs1use axum::Json;
2use axum::http::{HeaderMap, StatusCode};
3use serde::Serialize;
4use std::collections::BTreeMap;
5use systemprompt_config::ProfileBootstrap;
6use systemprompt_identifiers::headers::INFERENCE_PROTOCOL;
7use systemprompt_models::profile::{ProviderRegistry, WireProtocol};
8
9#[derive(Debug, Serialize)]
10pub struct RootResponse {
11 pub service: &'static str,
12 pub version: &'static str,
13 pub endpoints: Vec<&'static str>,
14}
15
16pub async fn root() -> Json<RootResponse> {
17 Json(RootResponse {
18 service: "systemprompt-gateway",
19 version: env!("CARGO_PKG_VERSION"),
20 endpoints: vec!["/v1/models", "/v1/messages"],
21 })
22}
23
24#[derive(Debug, Serialize)]
25pub struct ModelEntry {
26 #[serde(rename = "type")]
27 pub kind: &'static str,
28 pub id: String,
29 pub display_name: String,
30 pub created_at: String,
31}
32
33#[derive(Debug, Serialize)]
34pub struct ModelsResponse {
35 pub data: Vec<ModelEntry>,
36 pub has_more: bool,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub first_id: Option<String>,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 pub last_id: Option<String>,
41}
42
43pub async fn list(headers: HeaderMap) -> Result<Json<ModelsResponse>, (StatusCode, String)> {
44 let profile = ProfileBootstrap::get().map_err(|e| {
45 (
46 StatusCode::SERVICE_UNAVAILABLE,
47 format!("Profile not ready: {e}"),
48 )
49 })?;
50
51 profile
52 .gateway
53 .as_ref()
54 .and_then(systemprompt_models::profile::GatewayState::resolved)
55 .filter(|g| g.enabled)
56 .ok_or_else(|| (StatusCode::NOT_FOUND, "Gateway not enabled".to_owned()))?;
57
58 let protocols = protocols_from_header(&headers)?;
59 let entries = model_entries(&profile.providers, &protocols);
60 let first_id = entries.first().map(|e| e.id.clone());
61 let last_id = entries.last().map(|e| e.id.clone());
62
63 Ok(Json(ModelsResponse {
64 data: entries,
65 has_more: false,
66 first_id,
67 last_id,
68 }))
69}
70
71fn protocols_from_header(headers: &HeaderMap) -> Result<Vec<WireProtocol>, (StatusCode, String)> {
76 let Some(raw) = headers
77 .get(INFERENCE_PROTOCOL)
78 .and_then(|v| v.to_str().ok())
79 else {
80 return Ok(Vec::new());
81 };
82 let mut protocols = Vec::new();
83 for tag in raw.split(',').map(str::trim).filter(|t| !t.is_empty()) {
84 let protocol = WireProtocol::from_tag(tag).ok_or_else(|| {
85 (
86 StatusCode::BAD_REQUEST,
87 format!("unknown {INFERENCE_PROTOCOL} value: {tag}"),
88 )
89 })?;
90 protocols.push(protocol);
91 }
92 Ok(protocols)
93}
94
95pub fn model_entries(registry: &ProviderRegistry, protocols: &[WireProtocol]) -> Vec<ModelEntry> {
96 let mut by_id: BTreeMap<String, ModelEntry> = BTreeMap::new();
97 for id in registry.advertised_model_ids(protocols) {
98 by_id.insert(
99 id.clone(),
100 ModelEntry {
101 kind: "model",
102 display_name: humanize_model_id(&id),
103 id,
104 created_at: "1970-01-01T00:00:00Z".to_owned(),
105 },
106 );
107 }
108 by_id.into_values().collect()
109}
110
/// Turns a hyphenated model id into a display label.
///
/// The id is split on `-`; the first character of each piece is upper-cased
/// (ASCII only) and the pieces are re-joined with single spaces. Empty pieces
/// are kept, so consecutive hyphens produce consecutive spaces.
fn humanize_model_id(id: &str) -> String {
    let mut words: Vec<String> = Vec::new();
    for segment in id.split('-') {
        let mut rest = segment.chars();
        let word = match rest.next() {
            // `rest` now points just past the first character.
            Some(first) => format!("{}{}", first.to_ascii_uppercase(), rest.as_str()),
            None => String::new(),
        };
        words.push(word);
    }
    words.join(" ")
}