Skip to main content

systemprompt_models/profile/
gateway.rs

1use crate::services::ai::ModelPricing;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::path::PathBuf;
7use thiserror::Error;
8
/// Errors raised while reading, parsing, or validating a gateway catalog.
///
/// `CatalogRead`, `CatalogParse`, and `CatalogInvalid` carry the catalog
/// file path; the remaining variants are the ones returned by
/// [`GatewayCatalog::validate`].
// NOTE(review): the `Catalog*` messages interpolate `{source}` while also
// exposing it through `#[source]`, so reporters that print every link of the
// error chain may show the inner error twice. Left as-is because the message
// text may be relied on elsewhere.
#[derive(Debug, Error)]
pub enum GatewayProfileError {
    /// The catalog file at `path` could not be read.
    #[error("Failed to read gateway catalog {path}: {source}")]
    CatalogRead {
        path: PathBuf,
        #[source]
        source: std::io::Error,
    },

    /// The catalog file at `path` could not be deserialized by `serde_yaml`.
    #[error("Failed to parse gateway catalog {path}: {source}")]
    CatalogParse {
        path: PathBuf,
        #[source]
        source: serde_yaml::Error,
    },

    /// The catalog at `path` was rejected; `source` is the nested error
    /// (presumably one of the validation variants below — confirm at the
    /// construction site).
    #[error("Invalid gateway catalog {path}: {source}")]
    CatalogInvalid {
        path: PathBuf,
        #[source]
        source: Box<Self>,
    },

    /// A catalog model has an empty `id`.
    #[error("gateway catalog model has empty id")]
    ModelEmptyId,

    /// A catalog model's `provider` matches no provider `name` in the catalog.
    #[error("gateway catalog model '{model}' references unknown provider '{provider}'")]
    UnknownProvider { model: String, provider: String },

    /// A catalog provider has an empty `name`.
    #[error("gateway catalog provider has empty name")]
    ProviderEmptyName,

    /// A catalog provider has an empty `endpoint`.
    #[error("gateway catalog provider '{name}' has empty endpoint")]
    ProviderEmptyEndpoint { name: String },
}
44
45pub type GatewayResult<T> = Result<T, GatewayProfileError>;
46
/// Gateway section of a profile: routing rules plus an optional model catalog.
///
/// Every field has a serde default, so deserializing an empty mapping yields
/// the same values as [`GatewayConfig::default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
    /// Whether the gateway is enabled; defaults to `false`.
    #[serde(default)]
    pub enabled: bool,
    /// Routing rules, tried in order by [`GatewayConfig::find_route`]; the
    /// first matching route wins.
    #[serde(default)]
    pub routes: Vec<GatewayRoute>,
    /// Optional path to a catalog file; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub catalog_path: Option<PathBuf>,
    /// Parsed catalog. Never serialized or deserialized (`skip`); presumably
    /// populated after loading `catalog_path` — confirm at the loader.
    #[serde(default, skip)]
    pub catalog: Option<GatewayCatalog>,
    /// Auth scheme name; defaults to `"bearer"`.
    #[serde(default = "default_auth_scheme")]
    pub auth_scheme: String,
    /// Inference path prefix; defaults to `"/v1"`.
    #[serde(default = "default_inference_path_prefix")]
    pub inference_path_prefix: String,
}
62
63impl Default for GatewayConfig {
64    fn default() -> Self {
65        Self {
66            enabled: false,
67            routes: Vec::new(),
68            catalog_path: None,
69            catalog: None,
70            auth_scheme: default_auth_scheme(),
71            inference_path_prefix: default_inference_path_prefix(),
72        }
73    }
74}
75
/// Serde default for `GatewayConfig::auth_scheme`.
fn default_auth_scheme() -> String {
    String::from("bearer")
}
79
/// Serde default for `GatewayConfig::inference_path_prefix`.
fn default_inference_path_prefix() -> String {
    String::from("/v1")
}
83
84impl GatewayConfig {
85    pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
86        self.routes.iter().find(|route| route.matches(model))
87    }
88}
89
/// Catalog of upstream providers and the models that reference them.
///
/// Use [`GatewayCatalog::validate`] to check cross-references after loading.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GatewayCatalog {
    /// Upstream providers; models refer to them by `name`.
    #[serde(default)]
    pub providers: Vec<GatewayProvider>,
    /// Catalog models; each names its provider in `GatewayModel::provider`.
    #[serde(default)]
    pub models: Vec<GatewayModel>,
}
97
98impl GatewayCatalog {
99    pub fn validate(&self) -> GatewayResult<()> {
100        for model in &self.models {
101            if model.id.is_empty() {
102                return Err(GatewayProfileError::ModelEmptyId);
103            }
104            if !self.providers.iter().any(|p| p.name == model.provider) {
105                return Err(GatewayProfileError::UnknownProvider {
106                    model: model.id.clone(),
107                    provider: model.provider.clone(),
108                });
109            }
110        }
111        for provider in &self.providers {
112            if provider.name.is_empty() {
113                return Err(GatewayProfileError::ProviderEmptyName);
114            }
115            if provider.endpoint.is_empty() {
116                return Err(GatewayProfileError::ProviderEmptyEndpoint {
117                    name: provider.name.clone(),
118                });
119            }
120        }
121        Ok(())
122    }
123
124    pub fn find_provider(&self, name: &str) -> Option<&GatewayProvider> {
125        self.providers.iter().find(|p| p.name == name)
126    }
127}
128
/// An upstream provider entry in a [`GatewayCatalog`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayProvider {
    /// Name that `GatewayModel::provider` refers to; must be non-empty
    /// (checked by [`GatewayCatalog::validate`]). Uniqueness is not checked.
    pub name: String,
    /// Upstream endpoint; must be non-empty (checked by [`GatewayCatalog::validate`]).
    pub endpoint: String,
    /// Reference to the API key secret (presumably a secret name rather than
    /// the key itself — confirm where it is resolved).
    pub api_key_secret: String,
    /// Extra headers keyed by header name; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra_headers: HashMap<String, String>,
}
137
/// A model entry in a [`GatewayCatalog`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayModel {
    /// Model identifier; must be non-empty (checked by [`GatewayCatalog::validate`]).
    pub id: String,
    /// `name` of the [`GatewayProvider`] this model belongs to; must exist in
    /// the catalog (checked by [`GatewayCatalog::validate`]).
    pub provider: String,
    /// Optional human-readable name; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    /// Optional upstream model name (presumably used instead of `id` when
    /// calling the provider — confirm at the call site); omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_model: Option<String>,
    /// Optional pricing information; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pricing: Option<ModelPricing>,
}
149
/// A routing rule mapping requested model names to an upstream provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayRoute {
    /// Route id; defaults to empty and can be derived with [`GatewayRoute::ensure_id`].
    #[serde(default)]
    pub id: String,
    /// Pattern the requested model name is tested against; `*` is the
    /// wildcard character (see [`GatewayRoute::matches`]).
    pub model_pattern: String,
    /// Provider name for this route.
    pub provider: String,
    /// Upstream endpoint for this route.
    pub endpoint: String,
    /// Reference to the API key secret (presumably a secret name rather than
    /// the key itself — confirm where it is resolved).
    pub api_key_secret: String,
    /// Model name to use instead of the requested one; see
    /// [`GatewayRoute::effective_upstream_model`]. Omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_model: Option<String>,
    /// Extra headers keyed by header name; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra_headers: HashMap<String, String>,
    /// Optional pricing information; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pricing: Option<ModelPricing>,
}
165
166impl GatewayRoute {
167    pub fn matches(&self, model: &str) -> bool {
168        match_pattern(&self.model_pattern, model)
169    }
170
171    pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
172        self.upstream_model.as_deref().unwrap_or(requested)
173    }
174
175    pub fn ensure_id(&mut self) {
176        if self.id.trim().is_empty() {
177            self.id = synthesize_route_id(&self.model_pattern, &self.provider, &self.endpoint);
178        }
179    }
180}
181
/// Turn a model pattern into a lowercase slug suitable for a stable id.
///
/// The output must stay identical to the template's historical implementation
/// in `extensions/web/admin/.../gateway.rs`, so ids agree on both sides:
///
/// - each `*` becomes the literal text `star`, with no separator added
///   around it (so `a*b` becomes `astarb`);
/// - ASCII letters and digits are kept and lowercased;
/// - every other character, including non-ASCII letters, is a separator, and
///   each run of separators between two kept pieces becomes a single `-`;
/// - separators at the start or end produce nothing;
/// - if nothing is kept, the result is `route`.
#[must_use]
pub fn slugify_pattern(pattern: &str) -> String {
    let mut slug = String::with_capacity(pattern.len());
    // Set once a separator follows kept output. The `-` is only written when
    // another kept piece arrives, which drops leading and trailing separators.
    let mut separator_pending = false;
    for ch in pattern.chars() {
        let is_star = ch == '*';
        if !is_star && !ch.is_ascii_alphanumeric() {
            if !slug.is_empty() {
                separator_pending = true;
            }
            continue;
        }
        if separator_pending {
            slug.push('-');
            separator_pending = false;
        }
        if is_star {
            slug.push_str("star");
        } else {
            slug.push(ch.to_ascii_lowercase());
        }
    }
    if slug.is_empty() {
        String::from("route")
    } else {
        slug
    }
}
217
218/// Build a stable route id from `(model_pattern, provider, endpoint)`.
219///
220/// The id is `<slug>-<6 hex chars>` where the hex digest is the first 6
221/// chars of `DefaultHasher` over the same triple. Mirrors the template
222/// logic so ids stay identical across the seam.
223#[must_use]
224pub fn synthesize_route_id(model_pattern: &str, provider: &str, endpoint: &str) -> String {
225    let mut hasher = DefaultHasher::new();
226    model_pattern.hash(&mut hasher);
227    provider.hash(&mut hasher);
228    endpoint.hash(&mut hasher);
229    let h = hasher.finish();
230    let hash6: String = format!("{h:016x}").chars().take(6).collect();
231    format!("{}-{}", slugify_pattern(model_pattern), hash6)
232}
233
/// Test whether a requested model name matches a route pattern.
///
/// `*` is a wildcard that matches any run of characters, including none. It
/// may appear anywhere in the pattern and any number of times; every other
/// character must match literally. A pattern without `*` is an exact
/// comparison. The common forms `*`, `prefix*`, `*suffix` and `exact` reduce
/// to match-all, prefix, suffix and equality checks respectively.
///
/// Interior or doubled wildcards such as `*mini*` or `claude-*-sonnet` are
/// honoured. Previously only a single leading or trailing `*` was recognised,
/// and such patterns could only match model names containing a literal `*`.
fn match_pattern(pattern: &str, model: &str) -> bool {
    // No wildcard at all: plain equality.
    let Some((head, tail)) = pattern.split_once('*') else {
        return pattern == model;
    };
    // The literal text before the first `*` must be a prefix of the model.
    let Some(rest) = model.strip_prefix(head) else {
        return false;
    };
    // The literal text after the last `*` must be a suffix of what remains.
    // Stripping it from `rest` (not `model`) keeps prefix and suffix from
    // overlapping.
    let (middle, last) = tail.rsplit_once('*').unwrap_or(("", tail));
    let Some(mut rest) = rest.strip_suffix(last) else {
        return false;
    };
    // Interior literal segments must appear in order without overlapping.
    // Taking the leftmost occurrence each time is optimal for `*`-only globs.
    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        match rest.find(segment) {
            Some(idx) => rest = &rest[idx + segment.len()..],
            None => return false,
        }
    }
    true
}