Skip to main content

systemprompt_models/profile/
gateway.rs

1use crate::services::ai::ModelPricing;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::path::PathBuf;
7use thiserror::Error;
8
9#[derive(Debug, Error)]
10pub enum GatewayProfileError {
11    #[error("Failed to read gateway catalog {path}: {source}")]
12    CatalogRead {
13        path: PathBuf,
14        #[source]
15        source: std::io::Error,
16    },
17
18    #[error("Failed to parse gateway catalog {path}: {source}")]
19    CatalogParse {
20        path: PathBuf,
21        #[source]
22        source: serde_yaml::Error,
23    },
24
25    #[error("Invalid gateway catalog {path}: {source}")]
26    CatalogInvalid {
27        path: PathBuf,
28        #[source]
29        source: Box<Self>,
30    },
31
32    #[error("gateway catalog model has empty id")]
33    ModelEmptyId,
34
35    #[error("gateway catalog model '{model}' references unknown provider '{provider}'")]
36    UnknownProvider { model: String, provider: String },
37
38    #[error("gateway catalog provider has empty name")]
39    ProviderEmptyName,
40
41    #[error("gateway catalog provider '{name}' has empty endpoint")]
42    ProviderEmptyEndpoint { name: String },
43
44    #[error("gateway {label} endpoint '{endpoint}' is not permitted: {reason}")]
45    BlockedEndpoint {
46        label: String,
47        endpoint: String,
48        reason: String,
49    },
50}
51
52/// Reject gateway upstream endpoints that point at the local host or private
53/// network ranges; an operator-configured endpoint pointing at
54/// `169.254.169.254` or an internal service would otherwise turn the inference
55/// proxy into an SSRF primitive. Delegates to the shared outbound-URL guard so
56/// gateway, webhook, and authz destinations enforce one policy.
57fn validate_endpoint(label: &str, endpoint: &str) -> GatewayResult<()> {
58    crate::net::validate_outbound_url(endpoint)
59        .map(|_| ())
60        .map_err(|e| GatewayProfileError::BlockedEndpoint {
61            label: label.to_string(),
62            endpoint: endpoint.to_string(),
63            reason: e.to_string(),
64        })
65}
66
67pub type GatewayResult<T> = Result<T, GatewayProfileError>;
68
69#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
70#[serde(deny_unknown_fields)]
71pub struct GatewayConfig {
72    #[serde(default)]
73    pub enabled: bool,
74    #[serde(default)]
75    pub routes: Vec<GatewayRoute>,
76    #[serde(default, skip_serializing_if = "Option::is_none")]
77    pub catalog_path: Option<PathBuf>,
78    #[serde(default, skip)]
79    pub catalog: Option<GatewayCatalog>,
80    #[serde(default = "default_auth_scheme")]
81    pub auth_scheme: String,
82    #[serde(default = "default_inference_path_prefix")]
83    pub inference_path_prefix: String,
84}
85
86impl Default for GatewayConfig {
87    fn default() -> Self {
88        Self {
89            enabled: false,
90            routes: Vec::new(),
91            catalog_path: None,
92            catalog: None,
93            auth_scheme: default_auth_scheme(),
94            inference_path_prefix: default_inference_path_prefix(),
95        }
96    }
97}
98
/// Serde default for `GatewayConfig::auth_scheme`: `"bearer"`.
fn default_auth_scheme() -> String {
    String::from("bearer")
}
102
/// Serde default for `GatewayConfig::inference_path_prefix`: `"/v1"`.
fn default_inference_path_prefix() -> String {
    "/v1".to_owned()
}
106
107impl GatewayConfig {
108    pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
109        self.routes.iter().find(|route| route.matches(model))
110    }
111
112    pub fn validate_routes(&self) -> GatewayResult<()> {
113        for route in &self.routes {
114            validate_endpoint(&format!("route '{}'", route.model_pattern), &route.endpoint)?;
115        }
116        Ok(())
117    }
118}
119
120#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
121#[serde(deny_unknown_fields)]
122pub struct GatewayCatalog {
123    #[serde(default)]
124    pub providers: Vec<GatewayProvider>,
125    #[serde(default)]
126    pub models: Vec<GatewayModel>,
127}
128
129impl GatewayCatalog {
130    pub fn validate(&self) -> GatewayResult<()> {
131        for model in &self.models {
132            if model.id.is_empty() {
133                return Err(GatewayProfileError::ModelEmptyId);
134            }
135            if !self.providers.iter().any(|p| p.name == model.provider) {
136                return Err(GatewayProfileError::UnknownProvider {
137                    model: model.id.clone(),
138                    provider: model.provider.clone(),
139                });
140            }
141        }
142        for provider in &self.providers {
143            if provider.name.is_empty() {
144                return Err(GatewayProfileError::ProviderEmptyName);
145            }
146            if provider.endpoint.is_empty() {
147                return Err(GatewayProfileError::ProviderEmptyEndpoint {
148                    name: provider.name.clone(),
149                });
150            }
151            validate_endpoint(&format!("provider '{}'", provider.name), &provider.endpoint)?;
152        }
153        Ok(())
154    }
155
156    pub fn find_provider(&self, name: &str) -> Option<&GatewayProvider> {
157        self.providers.iter().find(|p| p.name == name)
158    }
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
162#[serde(deny_unknown_fields)]
163pub struct GatewayProvider {
164    pub name: String,
165    pub endpoint: String,
166    pub api_key_secret: String,
167    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
168    pub extra_headers: HashMap<String, String>,
169}
170
171#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
172#[serde(deny_unknown_fields)]
173pub struct GatewayModel {
174    pub id: String,
175    pub provider: String,
176    #[serde(default, skip_serializing_if = "Option::is_none")]
177    pub display_name: Option<String>,
178    #[serde(default, skip_serializing_if = "Option::is_none")]
179    pub upstream_model: Option<String>,
180    #[serde(default, skip_serializing_if = "Option::is_none")]
181    pub pricing: Option<ModelPricing>,
182}
183
184#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
185#[serde(deny_unknown_fields)]
186pub struct GatewayRoute {
187    #[serde(default)]
188    pub id: String,
189    pub model_pattern: String,
190    pub provider: String,
191    pub endpoint: String,
192    pub api_key_secret: String,
193    #[serde(default, skip_serializing_if = "Option::is_none")]
194    pub upstream_model: Option<String>,
195    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
196    pub extra_headers: HashMap<String, String>,
197    #[serde(default, skip_serializing_if = "Option::is_none")]
198    pub pricing: Option<ModelPricing>,
199}
200
201impl GatewayRoute {
202    pub fn matches(&self, model: &str) -> bool {
203        match_pattern(&self.model_pattern, model)
204    }
205
206    pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
207        self.upstream_model.as_deref().unwrap_or(requested)
208    }
209
210    pub fn ensure_id(&mut self) {
211        if self.id.trim().is_empty() {
212            self.id = synthesize_route_id(&self.model_pattern, &self.provider, &self.endpoint);
213        }
214    }
215}
216
/// Turn a model pattern into a slug suitable for a stable id.
///
/// Mirrors the template's historical implementation in
/// `extensions/web/admin/.../gateway.rs`, so the output must not drift:
///
/// - `*` becomes `star`;
/// - ASCII letters and digits are kept, lower-cased;
/// - every other run of characters (including non-ASCII) becomes a single `-`;
/// - no leading or trailing `-` is produced;
/// - an empty result becomes `route`.
#[must_use]
pub fn slugify_pattern(pattern: &str) -> String {
    let mut slug = String::with_capacity(pattern.len());
    // A separator only *requests* a dash. The dash is written just before the
    // next piece of content, so a separator run at either end never emits one.
    let mut dash_pending = false;

    for ch in pattern.chars() {
        let is_content = ch == '*' || ch.is_ascii_alphanumeric();
        if !is_content {
            if !slug.is_empty() {
                dash_pending = true;
            }
            continue;
        }

        if dash_pending {
            slug.push('-');
            dash_pending = false;
        }
        if ch == '*' {
            slug.push_str("star");
        } else {
            slug.push(ch.to_ascii_lowercase());
        }
    }

    if slug.is_empty() {
        String::from("route")
    } else {
        slug
    }
}
252
253// Format: <slug>-<6 hex chars> where the hex digest is the first 6 chars of
254// DefaultHasher over (model_pattern, provider, endpoint). Mirrors the template
255// logic so ids stay identical across the core/template seam.
256#[must_use]
257pub fn synthesize_route_id(model_pattern: &str, provider: &str, endpoint: &str) -> String {
258    let mut hasher = DefaultHasher::new();
259    model_pattern.hash(&mut hasher);
260    provider.hash(&mut hasher);
261    endpoint.hash(&mut hasher);
262    let h = hasher.finish();
263    let hash6: String = format!("{h:016x}").chars().take(6).collect();
264    format!("{}-{}", slugify_pattern(model_pattern), hash6)
265}
266
/// Test `model` against a route `model_pattern`.
///
/// Only a leading and/or trailing `*` acts as a wildcard:
///
/// - `*` matches every model;
/// - `*infix*` matches models containing `infix`; an empty infix (`**`)
///   matches everything;
/// - `prefix*` matches models starting with `prefix`;
/// - `*suffix` matches models ending with `suffix`;
/// - any other pattern must equal `model` exactly (case-sensitive), and a `*`
///   anywhere else is a literal character.
fn match_pattern(pattern: &str, model: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    // Checked before the one-sided forms. Otherwise `*infix*` would fall into the
    // `prefix*` arm and require the model to literally start with `*infix`, which
    // no real model id does.
    if let Some(infix) = pattern
        .strip_prefix('*')
        .and_then(|rest| rest.strip_suffix('*'))
    {
        return model.contains(infix);
    }
    if let Some(prefix) = pattern.strip_suffix('*') {
        return model.starts_with(prefix);
    }
    if let Some(suffix) = pattern.strip_prefix('*') {
        return model.ends_with(suffix);
    }
    pattern == model
}