Skip to main content

systemprompt_models/profile/
gateway.rs

1use crate::services::ai::ModelPricing;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::path::PathBuf;
7use systemprompt_identifiers::{ModelId, ProviderId, RouteId, SecretName};
8use thiserror::Error;
9
10#[derive(Debug, Error)]
11pub enum GatewayProfileError {
12    #[error("Failed to read gateway catalog {path}: {source}")]
13    CatalogRead {
14        path: PathBuf,
15        #[source]
16        source: std::io::Error,
17    },
18
19    #[error("Failed to parse gateway catalog {path}: {source}")]
20    CatalogParse {
21        path: PathBuf,
22        #[source]
23        source: serde_yaml::Error,
24    },
25
26    #[error("Invalid gateway catalog {path}: {source}")]
27    CatalogInvalid {
28        path: PathBuf,
29        #[source]
30        source: Box<Self>,
31    },
32
33    #[error("gateway catalog model has empty id")]
34    ModelEmptyId,
35
36    #[error("gateway catalog model '{model}' references unknown provider '{provider}'")]
37    UnknownProvider { model: String, provider: String },
38
39    #[error("gateway catalog provider has empty name")]
40    ProviderEmptyName,
41
42    #[error("gateway catalog provider '{name}' has empty endpoint")]
43    ProviderEmptyEndpoint { name: String },
44
45    #[error("gateway {label} endpoint '{endpoint}' is not permitted: {reason}")]
46    BlockedEndpoint {
47        label: String,
48        endpoint: String,
49        reason: String,
50    },
51
52    #[error(
53        "gateway route '{route}' provider '{provider}' is not declared in the catalog providers"
54    )]
55    RouteProviderNotInCatalog { route: String, provider: String },
56
57    #[error(
58        "gateway route '{route}' endpoint '{route_endpoint}' disagrees with catalog provider \
59         '{provider}' endpoint '{catalog_endpoint}'"
60    )]
61    RouteEndpointMismatch {
62        route: String,
63        provider: String,
64        route_endpoint: String,
65        catalog_endpoint: String,
66    },
67
68    #[error("gateway catalog model id or alias '{id}' is declared more than once")]
69    DuplicateModelId { id: String },
70
71    #[error("gateway route id '{id}' is declared more than once")]
72    DuplicateRouteId { id: String },
73
74    #[error("gateway catalog model '{model}' has no route whose pattern matches its id")]
75    UnreachableModel { model: String },
76}
77
78/// Reject gateway upstream endpoints that point at the local host or private
79/// network ranges; an operator-configured endpoint pointing at
80/// `169.254.169.254` or an internal service would otherwise turn the inference
81/// proxy into an SSRF primitive. Delegates to the shared outbound-URL guard so
82/// gateway, webhook, and authz destinations enforce one policy.
83fn validate_endpoint(label: &str, endpoint: &str) -> GatewayResult<()> {
84    crate::net::validate_outbound_url(endpoint)
85        .map(|_| ())
86        .map_err(|e| GatewayProfileError::BlockedEndpoint {
87            label: label.to_owned(),
88            endpoint: endpoint.to_owned(),
89            reason: e.to_string(),
90        })
91}
92
93pub type GatewayResult<T> = Result<T, GatewayProfileError>;
94
95#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
96#[serde(deny_unknown_fields)]
97pub struct GatewayConfig {
98    #[serde(default)]
99    pub enabled: bool,
100    #[serde(default)]
101    pub routes: Vec<GatewayRoute>,
102    #[serde(default, skip_serializing_if = "Option::is_none")]
103    pub catalog_path: Option<PathBuf>,
104    #[serde(default, skip)]
105    pub catalog: Option<GatewayCatalog>,
106    #[serde(default = "default_auth_scheme")]
107    pub auth_scheme: String,
108    #[serde(default = "default_inference_path_prefix")]
109    pub inference_path_prefix: String,
110}
111
112impl Default for GatewayConfig {
113    fn default() -> Self {
114        Self {
115            enabled: false,
116            routes: Vec::new(),
117            catalog_path: None,
118            catalog: None,
119            auth_scheme: default_auth_scheme(),
120            inference_path_prefix: default_inference_path_prefix(),
121        }
122    }
123}
124
125fn default_auth_scheme() -> String {
126    "bearer".to_owned()
127}
128
129fn default_inference_path_prefix() -> String {
130    "/v1".to_owned()
131}
132
133fn default_route_id() -> RouteId {
134    RouteId::new("")
135}
136
137impl GatewayConfig {
138    pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
139        self.routes.iter().find(|route| route.matches(model))
140    }
141
142    #[must_use]
143    pub fn is_model_exposed(&self, model: &str) -> bool {
144        self.catalog
145            .as_ref()
146            .is_none_or(|c| c.contains_model(model))
147    }
148
149    pub fn validate(&self) -> GatewayResult<()> {
150        let mut route_ids: std::collections::HashSet<&str> =
151            std::collections::HashSet::with_capacity(self.routes.len());
152        for route in &self.routes {
153            if !route_ids.insert(route.id.as_str()) {
154                return Err(GatewayProfileError::DuplicateRouteId {
155                    id: route.id.as_str().to_owned(),
156                });
157            }
158        }
159        let Some(catalog) = self.catalog.as_ref() else {
160            return Ok(());
161        };
162        catalog.validate()?;
163        for route in &self.routes {
164            if catalog.find_provider(route.provider.as_str()).is_none() {
165                return Err(GatewayProfileError::RouteProviderNotInCatalog {
166                    route: route.model_pattern.clone(),
167                    provider: route.provider.as_str().to_owned(),
168                });
169            }
170        }
171        let mut seen = std::collections::HashSet::with_capacity(catalog.models.len());
172        for model in &catalog.models {
173            if !seen.insert(model.id.as_str()) {
174                return Err(GatewayProfileError::DuplicateModelId {
175                    id: model.id.as_str().to_owned(),
176                });
177            }
178            for alias in &model.aliases {
179                if !seen.insert(alias.as_str()) {
180                    return Err(GatewayProfileError::DuplicateModelId {
181                        id: alias.as_str().to_owned(),
182                    });
183                }
184            }
185            if !self.routes.iter().any(|r| r.matches(model.id.as_str())) {
186                return Err(GatewayProfileError::UnreachableModel {
187                    model: model.id.as_str().to_owned(),
188                });
189            }
190        }
191        Ok(())
192    }
193}
194
195#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
196#[serde(deny_unknown_fields)]
197pub struct GatewayCatalog {
198    #[serde(default)]
199    pub providers: Vec<GatewayProvider>,
200    #[serde(default)]
201    pub models: Vec<GatewayModel>,
202}
203
204impl GatewayCatalog {
205    pub fn validate(&self) -> GatewayResult<()> {
206        for model in &self.models {
207            if model.id.as_str().is_empty() {
208                return Err(GatewayProfileError::ModelEmptyId);
209            }
210            if !self.providers.iter().any(|p| p.name == model.provider) {
211                return Err(GatewayProfileError::UnknownProvider {
212                    model: model.id.as_str().to_owned(),
213                    provider: model.provider.as_str().to_owned(),
214                });
215            }
216        }
217        for provider in &self.providers {
218            if provider.name.as_str().is_empty() {
219                return Err(GatewayProfileError::ProviderEmptyName);
220            }
221            if provider.endpoint.is_empty() {
222                return Err(GatewayProfileError::ProviderEmptyEndpoint {
223                    name: provider.name.as_str().to_owned(),
224                });
225            }
226            validate_endpoint(
227                &format!("provider '{}'", provider.name.as_str()),
228                &provider.endpoint,
229            )?;
230        }
231        Ok(())
232    }
233
234    pub fn find_provider(&self, name: &str) -> Option<&GatewayProvider> {
235        self.providers.iter().find(|p| p.name.as_str() == name)
236    }
237
238    #[must_use]
239    pub fn contains_model(&self, requested: &str) -> bool {
240        self.models.iter().any(|m| {
241            m.id.as_str() == requested || m.aliases.iter().any(|a| a.as_str() == requested)
242        })
243    }
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
247#[serde(deny_unknown_fields)]
248pub struct GatewayProvider {
249    pub name: ProviderId,
250    pub endpoint: String,
251    pub api_key_secret: SecretName,
252    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
253    pub extra_headers: HashMap<String, String>,
254}
255
256#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
257#[serde(deny_unknown_fields)]
258pub struct GatewayModel {
259    pub id: ModelId,
260    pub provider: ProviderId,
261    #[serde(default, skip_serializing_if = "Vec::is_empty")]
262    pub aliases: Vec<ModelId>,
263    #[serde(default, skip_serializing_if = "Option::is_none")]
264    pub display_name: Option<String>,
265    #[serde(default, skip_serializing_if = "Option::is_none")]
266    pub upstream_model: Option<String>,
267    #[serde(default, skip_serializing_if = "Option::is_none")]
268    pub pricing: Option<ModelPricing>,
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
272#[serde(deny_unknown_fields)]
273pub struct GatewayRoute {
274    #[serde(default = "default_route_id")]
275    pub id: RouteId,
276    pub model_pattern: String,
277    pub provider: ProviderId,
278    #[serde(default, skip_serializing_if = "Option::is_none")]
279    pub upstream_model: Option<String>,
280    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
281    pub extra_headers: HashMap<String, String>,
282    #[serde(default, skip_serializing_if = "Option::is_none")]
283    pub pricing: Option<ModelPricing>,
284}
285
286impl GatewayRoute {
287    pub fn matches(&self, model: &str) -> bool {
288        match_pattern(&self.model_pattern, model)
289    }
290
291    pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
292        self.upstream_model.as_deref().unwrap_or(requested)
293    }
294
295    pub fn ensure_id(&mut self) {
296        if self.id.as_str().trim().is_empty() {
297            self.id = synthesize_route_id(&self.model_pattern, self.provider.as_str());
298        }
299    }
300
301    pub fn resolve<'a>(&self, providers: &'a [GatewayProvider]) -> Option<&'a GatewayProvider> {
302        providers.iter().find(|p| p.name == self.provider)
303    }
304}
305
306/// Slugify a model pattern for use in a stable id.
307///
308/// Mirrors the template's historical implementation in
309/// `extensions/web/admin/.../gateway.rs`: `*` becomes `star`,
310/// non-alphanumeric runs collapse to a single `-`, leading/trailing `-`
311/// are trimmed, and an empty result becomes `route`.
312#[must_use]
313pub fn slugify_pattern(pattern: &str) -> String {
314    let mut out = String::with_capacity(pattern.len());
315    let mut last_dash = false;
316    for ch in pattern.chars() {
317        if ch == '*' {
318            out.push_str("star");
319            last_dash = false;
320        } else if ch.is_ascii_alphanumeric() {
321            for lc in ch.to_lowercase() {
322                out.push(lc);
323            }
324            last_dash = false;
325        } else if !last_dash && !out.is_empty() {
326            out.push('-');
327            last_dash = true;
328        }
329    }
330    while out.ends_with('-') {
331        out.pop();
332    }
333    while out.starts_with('-') {
334        out.remove(0);
335    }
336    if out.is_empty() {
337        out.push_str("route");
338    }
339    out
340}
341
342// Format: <slug>-<6 hex chars> where the hex digest is the first 6 chars of
343// DefaultHasher over (model_pattern, provider). The collision check in
344// GatewayConfig::validate() guards against the vanishingly unlikely case of
345// two operator-authored patterns colliding on the 6-hex tail.
346#[must_use]
347pub fn synthesize_route_id(model_pattern: &str, provider: &str) -> RouteId {
348    let mut hasher = DefaultHasher::new();
349    model_pattern.hash(&mut hasher);
350    provider.hash(&mut hasher);
351    let h = hasher.finish();
352    let hash6: String = format!("{h:016x}").chars().take(6).collect();
353    RouteId::new(format!("{}-{}", slugify_pattern(model_pattern), hash6))
354}
355
356fn match_pattern(pattern: &str, model: &str) -> bool {
357    if pattern == "*" {
358        return true;
359    }
360    if let Some(prefix) = pattern.strip_suffix('*') {
361        return model.starts_with(prefix);
362    }
363    if let Some(suffix) = pattern.strip_prefix('*') {
364        return model.ends_with(suffix);
365    }
366    pattern == model
367}