systemprompt_models/profile/
gateway.rs1use crate::services::ai::ModelPricing;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::path::PathBuf;
7use thiserror::Error;
8
9#[derive(Debug, Error)]
10pub enum GatewayProfileError {
11 #[error("Failed to read gateway catalog {path}: {source}")]
12 CatalogRead {
13 path: PathBuf,
14 #[source]
15 source: std::io::Error,
16 },
17
18 #[error("Failed to parse gateway catalog {path}: {source}")]
19 CatalogParse {
20 path: PathBuf,
21 #[source]
22 source: serde_yaml::Error,
23 },
24
25 #[error("Invalid gateway catalog {path}: {source}")]
26 CatalogInvalid {
27 path: PathBuf,
28 #[source]
29 source: Box<Self>,
30 },
31
32 #[error("gateway catalog model has empty id")]
33 ModelEmptyId,
34
35 #[error("gateway catalog model '{model}' references unknown provider '{provider}'")]
36 UnknownProvider { model: String, provider: String },
37
38 #[error("gateway catalog provider has empty name")]
39 ProviderEmptyName,
40
41 #[error("gateway catalog provider '{name}' has empty endpoint")]
42 ProviderEmptyEndpoint { name: String },
43
44 #[error("gateway {label} endpoint '{endpoint}' is not permitted: {reason}")]
45 BlockedEndpoint {
46 label: String,
47 endpoint: String,
48 reason: String,
49 },
50}
51
52fn validate_endpoint(label: &str, endpoint: &str) -> GatewayResult<()> {
58 crate::net::validate_outbound_url(endpoint)
59 .map(|_| ())
60 .map_err(|e| GatewayProfileError::BlockedEndpoint {
61 label: label.to_string(),
62 endpoint: endpoint.to_string(),
63 reason: e.to_string(),
64 })
65}
66
67pub type GatewayResult<T> = Result<T, GatewayProfileError>;
68
/// Gateway section of a profile: an on/off switch, ordered model routes, and an
/// optional provider/model catalog.
///
/// Unknown keys are rejected during deserialization (`deny_unknown_fields`).
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct GatewayConfig {
    /// Master switch; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
    /// Routes in priority order; `GatewayConfig::find_route` returns the first
    /// route whose pattern matches. Empty when omitted.
    #[serde(default)]
    pub routes: Vec<GatewayRoute>,
    /// Optional path to a catalog file; left out of serialized output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub catalog_path: Option<PathBuf>,
    /// In-memory catalog. Never read from or written to the serialized config
    /// (`skip`). NOTE(review): presumably filled from `catalog_path` by a
    /// loader elsewhere — confirm.
    #[serde(default, skip)]
    pub catalog: Option<GatewayCatalog>,
    /// Authentication scheme name; `"bearer"` when omitted. How the scheme is
    /// applied is not visible in this file.
    #[serde(default = "default_auth_scheme")]
    pub auth_scheme: String,
    /// Path prefix, `"/v1"` when omitted. NOTE(review): presumably the URL
    /// prefix under which inference endpoints are served — confirm.
    #[serde(default = "default_inference_path_prefix")]
    pub inference_path_prefix: String,
}
85
86impl Default for GatewayConfig {
87 fn default() -> Self {
88 Self {
89 enabled: false,
90 routes: Vec::new(),
91 catalog_path: None,
92 catalog: None,
93 auth_scheme: default_auth_scheme(),
94 inference_path_prefix: default_inference_path_prefix(),
95 }
96 }
97}
98
/// Serde default for `GatewayConfig::auth_scheme`: the string `"bearer"`.
fn default_auth_scheme() -> String {
    String::from("bearer")
}
102
/// Serde default for `GatewayConfig::inference_path_prefix`: the string `"/v1"`.
fn default_inference_path_prefix() -> String {
    "/v1".to_owned()
}
106
107impl GatewayConfig {
108 pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
109 self.routes.iter().find(|route| route.matches(model))
110 }
111
112 pub fn validate_routes(&self) -> GatewayResult<()> {
113 for route in &self.routes {
114 validate_endpoint(&format!("route '{}'", route.model_pattern), &route.endpoint)?;
115 }
116 Ok(())
117 }
118}
119
/// Catalog of upstream providers and the models they serve.
///
/// Deserialization does not check cross-references; call
/// `GatewayCatalog::validate` after loading.
#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct GatewayCatalog {
    /// Declared providers; models refer to them by `name`. Empty when omitted.
    #[serde(default)]
    pub providers: Vec<GatewayProvider>,
    /// Declared models; `validate` requires each to name a declared provider.
    /// Empty when omitted.
    #[serde(default)]
    pub models: Vec<GatewayModel>,
}
128
129impl GatewayCatalog {
130 pub fn validate(&self) -> GatewayResult<()> {
131 for model in &self.models {
132 if model.id.is_empty() {
133 return Err(GatewayProfileError::ModelEmptyId);
134 }
135 if !self.providers.iter().any(|p| p.name == model.provider) {
136 return Err(GatewayProfileError::UnknownProvider {
137 model: model.id.clone(),
138 provider: model.provider.clone(),
139 });
140 }
141 }
142 for provider in &self.providers {
143 if provider.name.is_empty() {
144 return Err(GatewayProfileError::ProviderEmptyName);
145 }
146 if provider.endpoint.is_empty() {
147 return Err(GatewayProfileError::ProviderEmptyEndpoint {
148 name: provider.name.clone(),
149 });
150 }
151 validate_endpoint(&format!("provider '{}'", provider.name), &provider.endpoint)?;
152 }
153 Ok(())
154 }
155
156 pub fn find_provider(&self, name: &str) -> Option<&GatewayProvider> {
157 self.providers.iter().find(|p| p.name == name)
158 }
159}
160
/// An upstream provider entry in a [`GatewayCatalog`].
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct GatewayProvider {
    /// Name that `GatewayModel::provider` refers to. Uniqueness is not
    /// enforced; `GatewayCatalog::find_provider` returns the first match.
    pub name: String,
    /// Provider endpoint URL. `GatewayCatalog::validate` requires it to be
    /// non-empty and to pass the outbound-URL check.
    pub endpoint: String,
    /// NOTE(review): presumably the name of the secret holding the API key,
    /// not the key itself — confirm.
    pub api_key_secret: String,
    /// Extra HTTP headers (presumably added to upstream requests — confirm).
    /// Left out of serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra_headers: HashMap<String, String>,
}
170
/// A model entry in a [`GatewayCatalog`], served by one of its providers.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct GatewayModel {
    /// Model identifier; `GatewayCatalog::validate` rejects an empty id.
    pub id: String,
    /// Name of the serving provider; must equal some `GatewayProvider::name`.
    pub provider: String,
    /// Optional human-readable name; left out of serialized output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    /// Optional upstream model name. NOTE(review): presumably sent upstream in
    /// place of `id`, mirroring `GatewayRoute::upstream_model` — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_model: Option<String>,
    /// Optional pricing information; left out of serialized output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pricing: Option<ModelPricing>,
}
183
/// A static route sending requests for matching model names to an endpoint.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct GatewayRoute {
    /// Route identifier. May be omitted (it defaults to empty), in which case
    /// `GatewayRoute::ensure_id` synthesizes one from the pattern, provider and
    /// endpoint.
    #[serde(default)]
    pub id: String,
    /// Pattern tested against requested model names by `GatewayRoute::matches`;
    /// `*` acts as a wildcard.
    pub model_pattern: String,
    /// Provider name; it also feeds id synthesis. NOTE(review): not
    /// cross-checked against a catalog in this file — confirm how it is used.
    pub provider: String,
    /// Upstream endpoint URL; checked by `GatewayConfig::validate_routes`.
    pub endpoint: String,
    /// NOTE(review): presumably the name of the secret holding the API key,
    /// not the key itself — confirm.
    pub api_key_secret: String,
    /// When set, replaces the requested model name sent upstream (see
    /// `GatewayRoute::effective_upstream_model`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_model: Option<String>,
    /// Extra HTTP headers (presumably added to upstream requests — confirm).
    /// Left out of serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra_headers: HashMap<String, String>,
    /// Optional pricing information; left out of serialized output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pricing: Option<ModelPricing>,
}
200
201impl GatewayRoute {
202 pub fn matches(&self, model: &str) -> bool {
203 match_pattern(&self.model_pattern, model)
204 }
205
206 pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
207 self.upstream_model.as_deref().unwrap_or(requested)
208 }
209
210 pub fn ensure_id(&mut self) {
211 if self.id.trim().is_empty() {
212 self.id = synthesize_route_id(&self.model_pattern, &self.provider, &self.endpoint);
213 }
214 }
215}
216
/// Turns a route `model_pattern` into a lowercase, dash-separated slug.
///
/// The rules are:
/// - ASCII letters and digits are kept (lowercased).
/// - Each `*` becomes the literal text `star`.
/// - Every run of other characters, including non-ASCII ones, collapses to a
///   single `-`.
/// - Separators at either end are dropped.
/// - If nothing remains, the result is `"route"`.
///
/// For example, `"Claude-3.5*"` becomes `"claude-3-5star"`.
#[must_use]
pub fn slugify_pattern(pattern: &str) -> String {
    let mut out = String::with_capacity(pattern.len());
    // True when the last thing pushed was a '-', so runs collapse to one dash.
    let mut last_dash = false;
    for ch in pattern.chars() {
        if ch == '*' {
            out.push_str("star");
            last_dash = false;
        } else if ch.is_ascii_alphanumeric() {
            out.push(ch.to_ascii_lowercase());
            last_dash = false;
        } else if !last_dash && !out.is_empty() {
            // Separators are skipped while `out` is still empty, so the slug
            // can never start with '-'; only the trailing end needs trimming.
            out.push('-');
            last_dash = true;
        }
    }
    // Runs are collapsed above, so at most one trailing '-' can be present.
    if out.ends_with('-') {
        out.pop();
    }
    if out.is_empty() {
        out.push_str("route");
    }
    out
}
252
253#[must_use]
257pub fn synthesize_route_id(model_pattern: &str, provider: &str, endpoint: &str) -> String {
258 let mut hasher = DefaultHasher::new();
259 model_pattern.hash(&mut hasher);
260 provider.hash(&mut hasher);
261 endpoint.hash(&mut hasher);
262 let h = hasher.finish();
263 let hash6: String = format!("{h:016x}").chars().take(6).collect();
264 format!("{}-{}", slugify_pattern(model_pattern), hash6)
265}
266
/// Reports whether `model` matches the route `pattern`.
///
/// `*` matches any (possibly empty) run of characters. It may appear anywhere
/// and any number of times; every other character must match literally.
/// Some common shapes:
/// - A pattern without `*` requires an exact match.
/// - `"*"` matches everything.
/// - `"gpt-*"` is a prefix match and `"*-mini"` a suffix match.
/// - `"*mini*"` is a substring match and `"gpt-*-mini"` an infix match.
///
/// Previously only a single leading or trailing `*` was honoured, and any
/// other `*` was compared literally. As a result, patterns like `"*mini*"`
/// silently never matched.
fn match_pattern(pattern: &str, model: &str) -> bool {
    let mut segments = pattern.split('*');
    // `split` always yields at least one item: the literal text before the
    // first `*`, or the whole pattern when it contains no `*`.
    let head = segments.next().unwrap_or("");
    let tail = match segments.next_back() {
        Some(tail) => tail,
        // No wildcard at all: exact comparison.
        None => return pattern == model,
    };
    let mut rest = match model.strip_prefix(head) {
        Some(rest) => rest,
        None => return false,
    };
    // Taking the leftmost occurrence of each middle segment is optimal for
    // `*`-only globs: consuming as little as possible leaves the most room
    // for the segments that follow.
    for segment in segments {
        match rest.find(segment) {
            Some(at) => rest = &rest[at + segment.len()..],
            None => return false,
        }
    }
    // The tail is checked against what is left, so it cannot overlap text
    // already consumed by the head or a middle segment.
    rest.ends_with(tail)
}