Skip to main content

systemprompt_models/profile/
gateway.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::collections::hash_map::DefaultHasher;
4use std::hash::{Hash, Hasher};
5use std::path::PathBuf;
6use thiserror::Error;
7
/// Errors produced while reading, parsing, or validating a gateway catalog.
///
/// The validation variants (`ModelEmptyId` .. `ProviderEmptyEndpoint`) are the
/// ones returned by `GatewayCatalog::validate`.
// NOTE(review): `CatalogRead`, `CatalogParse` and `CatalogInvalid` both
// interpolate `{source}` into their message AND expose it via `#[source]`, so
// reporters that walk `Error::source()` will print the inner error twice.
// Consider keeping only one of the two; left unchanged here because the
// rendered text is runtime behavior callers may depend on.
#[derive(Debug, Error)]
pub enum GatewayProfileError {
    /// The catalog file at `path` could not be read.
    #[error("Failed to read gateway catalog {path}: {source}")]
    CatalogRead {
        path: PathBuf,
        #[source]
        source: std::io::Error,
    },

    /// The catalog file at `path` was read but `serde_yaml` failed to parse it.
    #[error("Failed to parse gateway catalog {path}: {source}")]
    CatalogParse {
        path: PathBuf,
        #[source]
        source: serde_yaml::Error,
    },

    /// The catalog at `path` parsed but was rejected; wraps the nested
    /// `GatewayProfileError` describing why (boxed because the type is recursive).
    #[error("Invalid gateway catalog {path}: {source}")]
    CatalogInvalid {
        path: PathBuf,
        #[source]
        source: Box<Self>,
    },

    /// A catalog model has an empty `id`.
    #[error("gateway catalog model has empty id")]
    ModelEmptyId,

    /// A catalog model's `provider` does not name any provider in the catalog.
    #[error("gateway catalog model '{model}' references unknown provider '{provider}'")]
    UnknownProvider { model: String, provider: String },

    /// A catalog provider has an empty `name`.
    #[error("gateway catalog provider has empty name")]
    ProviderEmptyName,

    /// The provider called `name` has an empty `endpoint`.
    #[error("gateway catalog provider '{name}' has empty endpoint")]
    ProviderEmptyEndpoint { name: String },
}
43
/// Convenience result type for gateway-profile operations that fail with
/// [`GatewayProfileError`].
pub type GatewayResult<T> = Result<T, GatewayProfileError>;
45
/// Gateway section of a profile: an ordered route table plus an optional model catalog.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
    /// Gateway on/off flag; `false` when absent from the serialized form.
    #[serde(default)]
    pub enabled: bool,
    /// Routes in priority order; `find_route` returns the first match. Empty when absent.
    #[serde(default)]
    pub routes: Vec<GatewayRoute>,
    /// Optional path to a catalog file; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub catalog_path: Option<PathBuf>,
    /// Never serialized or deserialized, so it is always `None` straight after
    /// deserialization. Presumably filled from `catalog_path` by a loader
    /// elsewhere — not visible in this file.
    #[serde(default, skip)]
    pub catalog: Option<GatewayCatalog>,
    /// Authentication scheme name; defaults to `"bearer"` when absent.
    #[serde(default = "default_auth_scheme")]
    pub auth_scheme: String,
    /// Path prefix for inference endpoints; defaults to `"/v1"` when absent.
    #[serde(default = "default_inference_path_prefix")]
    pub inference_path_prefix: String,
}
61
62impl Default for GatewayConfig {
63    fn default() -> Self {
64        Self {
65            enabled: false,
66            routes: Vec::new(),
67            catalog_path: None,
68            catalog: None,
69            auth_scheme: default_auth_scheme(),
70            inference_path_prefix: default_inference_path_prefix(),
71        }
72    }
73}
74
/// Serde default for `GatewayConfig::auth_scheme`: the scheme name `bearer`.
fn default_auth_scheme() -> String {
    String::from("bearer")
}
78
/// Serde default for `GatewayConfig::inference_path_prefix`: `/v1`.
fn default_inference_path_prefix() -> String {
    String::from("/v1")
}
82
83impl GatewayConfig {
84    pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
85        self.routes.iter().find(|route| route.matches(model))
86    }
87}
88
/// Catalog of upstream providers and the models they serve.
///
/// Use `validate` to check that every model references a declared provider.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GatewayCatalog {
    /// Declared providers; empty when absent.
    #[serde(default)]
    pub providers: Vec<GatewayProvider>,
    /// Declared models, each naming one of `providers`; empty when absent.
    #[serde(default)]
    pub models: Vec<GatewayModel>,
}
96
97impl GatewayCatalog {
98    pub fn validate(&self) -> GatewayResult<()> {
99        for model in &self.models {
100            if model.id.is_empty() {
101                return Err(GatewayProfileError::ModelEmptyId);
102            }
103            if !self.providers.iter().any(|p| p.name == model.provider) {
104                return Err(GatewayProfileError::UnknownProvider {
105                    model: model.id.clone(),
106                    provider: model.provider.clone(),
107                });
108            }
109        }
110        for provider in &self.providers {
111            if provider.name.is_empty() {
112                return Err(GatewayProfileError::ProviderEmptyName);
113            }
114            if provider.endpoint.is_empty() {
115                return Err(GatewayProfileError::ProviderEmptyEndpoint {
116                    name: provider.name.clone(),
117                });
118            }
119        }
120        Ok(())
121    }
122
123    pub fn find_provider(&self, name: &str) -> Option<&GatewayProvider> {
124        self.providers.iter().find(|p| p.name == name)
125    }
126}
127
/// An upstream provider entry in a [`GatewayCatalog`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayProvider {
    /// Provider name, referenced by `GatewayModel::provider`; `validate` rejects empty names.
    pub name: String,
    /// Provider endpoint; `validate` rejects empty endpoints.
    pub endpoint: String,
    /// Presumably the name of a secret holding the API key rather than the key
    /// itself — confirm against the code that resolves it.
    pub api_key_secret: String,
    /// Extra header name/value pairs; omitted from serialized output when empty.
    /// Presumably attached to upstream requests — not visible in this file.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra_headers: HashMap<String, String>,
}
136
/// A model entry in a [`GatewayCatalog`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayModel {
    /// Model id; `validate` rejects empty ids.
    pub id: String,
    /// Name of the serving provider; `validate` requires it to match a declared provider.
    pub provider: String,
    /// Optional human-readable name; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    /// Optional upstream model name; omitted from serialized output when `None`.
    /// Presumably plays the same role as `GatewayRoute::upstream_model` — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_model: Option<String>,
}
146
/// A routing rule that maps requested model names to an upstream provider endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayRoute {
    /// Route id; empty when absent. `ensure_id` fills in a synthesized id when blank.
    #[serde(default)]
    pub id: String,
    /// Pattern tested against requested model names by `matches`.
    pub model_pattern: String,
    /// Provider name for this route.
    pub provider: String,
    /// Upstream endpoint for this route.
    pub endpoint: String,
    /// Presumably the name of a secret holding the API key — confirm against its consumer.
    pub api_key_secret: String,
    /// When set, replaces the requested model name (see `effective_upstream_model`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_model: Option<String>,
    /// Extra header name/value pairs; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra_headers: HashMap<String, String>,
}
160
161impl GatewayRoute {
162    pub fn matches(&self, model: &str) -> bool {
163        match_pattern(&self.model_pattern, model)
164    }
165
166    pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
167        self.upstream_model.as_deref().unwrap_or(requested)
168    }
169
170    pub fn ensure_id(&mut self) {
171        if self.id.trim().is_empty() {
172            self.id = synthesize_route_id(&self.model_pattern, &self.provider, &self.endpoint);
173        }
174    }
175}
176
/// Slugify a model pattern for use in a stable id.
///
/// Mirrors the template's historical implementation in
/// `extensions/web/admin/.../gateway.rs`:
/// - `*` becomes `star`;
/// - ASCII alphanumerics are lowercased;
/// - every other run of characters collapses to a single `-`;
/// - no leading or trailing `-` is produced;
/// - an empty result becomes `route`.
///
/// Non-ASCII characters count as separators.
#[must_use]
pub fn slugify_pattern(pattern: &str) -> String {
    let mut out = String::with_capacity(pattern.len());
    // Separators are deferred: a `-` is written only when more content follows
    // and something precedes it. That yields collapsed runs with no
    // leading/trailing dash and no trimming pass.
    let mut pending_separator = false;
    for ch in pattern.chars() {
        if ch == '*' || ch.is_ascii_alphanumeric() {
            if pending_separator && !out.is_empty() {
                out.push('-');
            }
            pending_separator = false;
            if ch == '*' {
                out.push_str("star");
            } else {
                out.push(ch.to_ascii_lowercase());
            }
        } else {
            pending_separator = true;
        }
    }
    if out.is_empty() {
        out.push_str("route");
    }
    out
}

/// Build a route id from `(model_pattern, provider, endpoint)`.
///
/// The id is `<slug>-<hex6>`:
/// - `<slug>` is [`slugify_pattern`] of `model_pattern`;
/// - `<hex6>` is the six most-significant hex digits (top 24 bits) of a
///   [`DefaultHasher`] digest over the three strings, hashed in that order.
///
/// This mirrors the template logic so ids match across the seam.
///
/// Stability caveat: the standard library leaves `DefaultHasher`'s algorithm
/// unspecified and free to change between Rust releases. Ids are therefore
/// only guaranteed to agree between binaries built with the same toolchain.
/// Persisted ids should be stored, not re-derived, across upgrades.
#[must_use]
pub fn synthesize_route_id(model_pattern: &str, provider: &str, endpoint: &str) -> String {
    let mut hasher = DefaultHasher::new();
    model_pattern.hash(&mut hasher);
    provider.hash(&mut hasher);
    endpoint.hash(&mut hasher);
    // The top 24 bits are exactly the first 6 chars of the zero-padded
    // `{:016x}` rendering, without allocating that intermediate string.
    let hash6 = hasher.finish() >> 40;
    format!("{}-{hash6:06x}", slugify_pattern(model_pattern))
}
228
/// Case-sensitive glob match where `*` matches any (possibly empty) run of characters.
///
/// Matching rules:
/// - a pattern without `*` must equal `model` exactly;
/// - `*` alone matches everything;
/// - `prefix*` and `*suffix` behave as prefix and suffix tests;
/// - `*` may also appear at both ends (`*mini*`) or in the middle (`claude-*-sonnet`).
///
/// Previously those last two forms were compared against a literal `*` and so
/// effectively never matched.
fn match_pattern(pattern: &str, model: &str) -> bool {
    let (head, rest) = match pattern.split_once('*') {
        Some(parts) => parts,
        None => return pattern == model,
    };
    // Literal text before the first `*` anchors the start.
    let after_head = match model.strip_prefix(head) {
        Some(remaining) => remaining,
        None => return false,
    };
    // Literal text after the last `*` anchors the end. It is checked on the
    // remainder so head and tail can't overlap (e.g. `a*a` must not match `a`).
    let (middle, tail) = rest.rsplit_once('*').unwrap_or(("", rest));
    let mut remaining = match after_head.strip_suffix(tail) {
        Some(remaining) => remaining,
        None => return false,
    };
    // Interior literal segments must occur in order. Taking the leftmost
    // occurrence each time is optimal when `*` is the only wildcard.
    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        match remaining.find(segment) {
            Some(pos) => remaining = &remaining[pos + segment.len()..],
            None => return false,
        }
    }
    true
}