systemprompt_models/profile/
gateway.rs1use crate::services::ai::ModelPricing;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::path::PathBuf;
7use systemprompt_identifiers::{ModelId, ProviderId, RouteId, SecretName};
8use thiserror::Error;
9
10#[derive(Debug, Error)]
11pub enum GatewayProfileError {
12 #[error("Failed to read gateway catalog {path}: {source}")]
13 CatalogRead {
14 path: PathBuf,
15 #[source]
16 source: std::io::Error,
17 },
18
19 #[error("Failed to parse gateway catalog {path}: {source}")]
20 CatalogParse {
21 path: PathBuf,
22 #[source]
23 source: serde_yaml::Error,
24 },
25
26 #[error("Invalid gateway catalog {path}: {source}")]
27 CatalogInvalid {
28 path: PathBuf,
29 #[source]
30 source: Box<Self>,
31 },
32
33 #[error("gateway catalog model has empty id")]
34 ModelEmptyId,
35
36 #[error("gateway catalog model '{model}' references unknown provider '{provider}'")]
37 UnknownProvider { model: String, provider: String },
38
39 #[error("gateway catalog provider has empty name")]
40 ProviderEmptyName,
41
42 #[error("gateway catalog provider '{name}' has empty endpoint")]
43 ProviderEmptyEndpoint { name: String },
44
45 #[error("gateway {label} endpoint '{endpoint}' is not permitted: {reason}")]
46 BlockedEndpoint {
47 label: String,
48 endpoint: String,
49 reason: String,
50 },
51
52 #[error(
53 "gateway route '{route}' provider '{provider}' is not declared in the catalog providers"
54 )]
55 RouteProviderNotInCatalog { route: String, provider: String },
56
57 #[error(
58 "gateway route '{route}' endpoint '{route_endpoint}' disagrees with catalog provider \
59 '{provider}' endpoint '{catalog_endpoint}'"
60 )]
61 RouteEndpointMismatch {
62 route: String,
63 provider: String,
64 route_endpoint: String,
65 catalog_endpoint: String,
66 },
67
68 #[error("gateway catalog model id or alias '{id}' is declared more than once")]
69 DuplicateModelId { id: String },
70
71 #[error("gateway route id '{id}' is declared more than once")]
72 DuplicateRouteId { id: String },
73
74 #[error("gateway catalog model '{model}' has no route whose pattern matches its id")]
75 UnreachableModel { model: String },
76}
77
78fn validate_endpoint(label: &str, endpoint: &str) -> GatewayResult<()> {
84 crate::net::validate_outbound_url(endpoint)
85 .map(|_| ())
86 .map_err(|e| GatewayProfileError::BlockedEndpoint {
87 label: label.to_owned(),
88 endpoint: endpoint.to_owned(),
89 reason: e.to_string(),
90 })
91}
92
93pub type GatewayResult<T> = Result<T, GatewayProfileError>;
94
95#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
96#[serde(deny_unknown_fields)]
97pub struct GatewayConfig {
98 #[serde(default)]
99 pub enabled: bool,
100 #[serde(default)]
101 pub routes: Vec<GatewayRoute>,
102 #[serde(default, skip_serializing_if = "Option::is_none")]
103 pub catalog_path: Option<PathBuf>,
104 #[serde(default, skip)]
105 pub catalog: Option<GatewayCatalog>,
106 #[serde(default = "default_auth_scheme")]
107 pub auth_scheme: String,
108 #[serde(default = "default_inference_path_prefix")]
109 pub inference_path_prefix: String,
110}
111
112impl Default for GatewayConfig {
113 fn default() -> Self {
114 Self {
115 enabled: false,
116 routes: Vec::new(),
117 catalog_path: None,
118 catalog: None,
119 auth_scheme: default_auth_scheme(),
120 inference_path_prefix: default_inference_path_prefix(),
121 }
122 }
123}
124
/// Serde default for `GatewayConfig::auth_scheme`.
fn default_auth_scheme() -> String {
    String::from("bearer")
}

/// Serde default for `GatewayConfig::inference_path_prefix`.
fn default_inference_path_prefix() -> String {
    String::from("/v1")
}
132
133fn default_route_id() -> RouteId {
134 RouteId::new("")
135}
136
137impl GatewayConfig {
138 pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
139 self.routes.iter().find(|route| route.matches(model))
140 }
141
142 #[must_use]
143 pub fn is_model_exposed(&self, model: &str) -> bool {
144 self.catalog
145 .as_ref()
146 .is_none_or(|c| c.contains_model(model))
147 }
148
149 pub fn validate(&self) -> GatewayResult<()> {
150 let mut route_ids: std::collections::HashSet<&str> =
151 std::collections::HashSet::with_capacity(self.routes.len());
152 for route in &self.routes {
153 if !route_ids.insert(route.id.as_str()) {
154 return Err(GatewayProfileError::DuplicateRouteId {
155 id: route.id.as_str().to_owned(),
156 });
157 }
158 }
159 let Some(catalog) = self.catalog.as_ref() else {
160 return Ok(());
161 };
162 catalog.validate()?;
163 for route in &self.routes {
164 if catalog.find_provider(route.provider.as_str()).is_none() {
165 return Err(GatewayProfileError::RouteProviderNotInCatalog {
166 route: route.model_pattern.clone(),
167 provider: route.provider.as_str().to_owned(),
168 });
169 }
170 }
171 let mut seen = std::collections::HashSet::with_capacity(catalog.models.len());
172 for model in &catalog.models {
173 if !seen.insert(model.id.as_str()) {
174 return Err(GatewayProfileError::DuplicateModelId {
175 id: model.id.as_str().to_owned(),
176 });
177 }
178 for alias in &model.aliases {
179 if !seen.insert(alias.as_str()) {
180 return Err(GatewayProfileError::DuplicateModelId {
181 id: alias.as_str().to_owned(),
182 });
183 }
184 }
185 if !self.routes.iter().any(|r| r.matches(model.id.as_str())) {
186 return Err(GatewayProfileError::UnreachableModel {
187 model: model.id.as_str().to_owned(),
188 });
189 }
190 }
191 Ok(())
192 }
193}
194
195#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
196#[serde(deny_unknown_fields)]
197pub struct GatewayCatalog {
198 #[serde(default)]
199 pub providers: Vec<GatewayProvider>,
200 #[serde(default)]
201 pub models: Vec<GatewayModel>,
202}
203
204impl GatewayCatalog {
205 pub fn validate(&self) -> GatewayResult<()> {
206 for model in &self.models {
207 if model.id.as_str().is_empty() {
208 return Err(GatewayProfileError::ModelEmptyId);
209 }
210 if !self.providers.iter().any(|p| p.name == model.provider) {
211 return Err(GatewayProfileError::UnknownProvider {
212 model: model.id.as_str().to_owned(),
213 provider: model.provider.as_str().to_owned(),
214 });
215 }
216 }
217 for provider in &self.providers {
218 if provider.name.as_str().is_empty() {
219 return Err(GatewayProfileError::ProviderEmptyName);
220 }
221 if provider.endpoint.is_empty() {
222 return Err(GatewayProfileError::ProviderEmptyEndpoint {
223 name: provider.name.as_str().to_owned(),
224 });
225 }
226 validate_endpoint(
227 &format!("provider '{}'", provider.name.as_str()),
228 &provider.endpoint,
229 )?;
230 }
231 Ok(())
232 }
233
234 pub fn find_provider(&self, name: &str) -> Option<&GatewayProvider> {
235 self.providers.iter().find(|p| p.name.as_str() == name)
236 }
237
238 #[must_use]
239 pub fn contains_model(&self, requested: &str) -> bool {
240 self.models.iter().any(|m| {
241 m.id.as_str() == requested || m.aliases.iter().any(|a| a.as_str() == requested)
242 })
243 }
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
247#[serde(deny_unknown_fields)]
248pub struct GatewayProvider {
249 pub name: ProviderId,
250 pub endpoint: String,
251 pub api_key_secret: SecretName,
252 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
253 pub extra_headers: HashMap<String, String>,
254}
255
256#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
257#[serde(deny_unknown_fields)]
258pub struct GatewayModel {
259 pub id: ModelId,
260 pub provider: ProviderId,
261 #[serde(default, skip_serializing_if = "Vec::is_empty")]
262 pub aliases: Vec<ModelId>,
263 #[serde(default, skip_serializing_if = "Option::is_none")]
264 pub display_name: Option<String>,
265 #[serde(default, skip_serializing_if = "Option::is_none")]
266 pub upstream_model: Option<String>,
267 #[serde(default, skip_serializing_if = "Option::is_none")]
268 pub pricing: Option<ModelPricing>,
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
272#[serde(deny_unknown_fields)]
273pub struct GatewayRoute {
274 #[serde(default = "default_route_id")]
275 pub id: RouteId,
276 pub model_pattern: String,
277 pub provider: ProviderId,
278 #[serde(default, skip_serializing_if = "Option::is_none")]
279 pub upstream_model: Option<String>,
280 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
281 pub extra_headers: HashMap<String, String>,
282 #[serde(default, skip_serializing_if = "Option::is_none")]
283 pub pricing: Option<ModelPricing>,
284}
285
286impl GatewayRoute {
287 pub fn matches(&self, model: &str) -> bool {
288 match_pattern(&self.model_pattern, model)
289 }
290
291 pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
292 self.upstream_model.as_deref().unwrap_or(requested)
293 }
294
295 pub fn ensure_id(&mut self) {
296 if self.id.as_str().trim().is_empty() {
297 self.id = synthesize_route_id(&self.model_pattern, self.provider.as_str());
298 }
299 }
300
301 pub fn resolve<'a>(&self, providers: &'a [GatewayProvider]) -> Option<&'a GatewayProvider> {
302 providers.iter().find(|p| p.name == self.provider)
303 }
304}
305
/// Turns a route `model_pattern` into a lowercase, dash-separated slug.
///
/// ASCII letters and digits are kept (lowercased) and each `*` becomes `star`. Every run of
/// other characters (punctuation, whitespace, non-ASCII) collapses into a single `-`. The slug
/// never starts or ends with `-`. Returns `"route"` when nothing survives, so the result is
/// never empty.
#[must_use]
pub fn slugify_pattern(pattern: &str) -> String {
    let mut out = String::with_capacity(pattern.len());
    // A separator is only written once more content follows it. This rules out leading and
    // trailing dashes and collapses runs, with no post-processing needed.
    let mut pending_dash = false;
    for ch in pattern.chars() {
        if ch == '*' || ch.is_ascii_alphanumeric() {
            if pending_dash {
                out.push('-');
                pending_dash = false;
            }
            if ch == '*' {
                out.push_str("star");
            } else {
                out.push(ch.to_ascii_lowercase());
            }
        } else if !out.is_empty() {
            pending_dash = true;
        }
    }
    if out.is_empty() {
        out.push_str("route");
    }
    out
}
341
342#[must_use]
347pub fn synthesize_route_id(model_pattern: &str, provider: &str) -> RouteId {
348 let mut hasher = DefaultHasher::new();
349 model_pattern.hash(&mut hasher);
350 provider.hash(&mut hasher);
351 let h = hasher.finish();
352 let hash6: String = format!("{h:016x}").chars().take(6).collect();
353 RouteId::new(format!("{}-{}", slugify_pattern(model_pattern), hash6))
354}
355
/// Matches `model` against a route pattern in which `*` stands for any run of characters,
/// including none.
///
/// - A pattern without `*` must equal `model` exactly.
/// - Literal text before the first `*` must be a prefix of `model`.
/// - Literal text after the last `*` must be a suffix.
/// - Literal segments between `*`s must appear in order, without overlapping the prefix or
///   suffix.
///
/// So `*`, `prefix*`, `*suffix` and plain names behave as simple wildcard/prefix/suffix/exact
/// matches, and patterns such as `*sonnet*` or `gpt-*-mini` match the way they read.
fn match_pattern(pattern: &str, model: &str) -> bool {
    let Some((head, tail)) = pattern.split_once('*') else {
        return pattern == model;
    };
    let Some(rest) = model.strip_prefix(head) else {
        return false;
    };
    // Text after the final `*` is an anchored suffix. Text between the first and last `*` is a
    // sequence of unanchored segments.
    let (middle, last) = tail.rsplit_once('*').unwrap_or(("", tail));
    let Some(mut remaining) = rest.strip_suffix(last) else {
        return false;
    };
    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        // Taking the leftmost occurrence consumes as little as possible, leaving the most room
        // for later segments, so greedy search is sufficient for `*`-only patterns.
        let Some(at) = remaining.find(segment) else {
            return false;
        };
        remaining = &remaining[at + segment.len()..];
    }
    true
}