systemprompt_models/profile/
gateway.rs1use crate::services::ai::ModelPricing;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::path::PathBuf;
7use thiserror::Error;
8
9#[derive(Debug, Error)]
10pub enum GatewayProfileError {
11 #[error("Failed to read gateway catalog {path}: {source}")]
12 CatalogRead {
13 path: PathBuf,
14 #[source]
15 source: std::io::Error,
16 },
17
18 #[error("Failed to parse gateway catalog {path}: {source}")]
19 CatalogParse {
20 path: PathBuf,
21 #[source]
22 source: serde_yaml::Error,
23 },
24
25 #[error("Invalid gateway catalog {path}: {source}")]
26 CatalogInvalid {
27 path: PathBuf,
28 #[source]
29 source: Box<Self>,
30 },
31
32 #[error("gateway catalog model has empty id")]
33 ModelEmptyId,
34
35 #[error("gateway catalog model '{model}' references unknown provider '{provider}'")]
36 UnknownProvider { model: String, provider: String },
37
38 #[error("gateway catalog provider has empty name")]
39 ProviderEmptyName,
40
41 #[error("gateway catalog provider '{name}' has empty endpoint")]
42 ProviderEmptyEndpoint { name: String },
43}
44
45pub type GatewayResult<T> = Result<T, GatewayProfileError>;
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct GatewayConfig {
49 #[serde(default)]
50 pub enabled: bool,
51 #[serde(default)]
52 pub routes: Vec<GatewayRoute>,
53 #[serde(default, skip_serializing_if = "Option::is_none")]
54 pub catalog_path: Option<PathBuf>,
55 #[serde(default, skip)]
56 pub catalog: Option<GatewayCatalog>,
57 #[serde(default = "default_auth_scheme")]
58 pub auth_scheme: String,
59 #[serde(default = "default_inference_path_prefix")]
60 pub inference_path_prefix: String,
61}
62
63impl Default for GatewayConfig {
64 fn default() -> Self {
65 Self {
66 enabled: false,
67 routes: Vec::new(),
68 catalog_path: None,
69 catalog: None,
70 auth_scheme: default_auth_scheme(),
71 inference_path_prefix: default_inference_path_prefix(),
72 }
73 }
74}
75
/// Serde default for `GatewayConfig::auth_scheme`: the literal `"bearer"`.
fn default_auth_scheme() -> String {
    String::from("bearer")
}
79
/// Serde default for `GatewayConfig::inference_path_prefix`: the literal `"/v1"`.
fn default_inference_path_prefix() -> String {
    String::from("/v1")
}
83
84impl GatewayConfig {
85 pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
86 self.routes.iter().find(|route| route.matches(model))
87 }
88}
89
90#[derive(Debug, Clone, Default, Serialize, Deserialize)]
91pub struct GatewayCatalog {
92 #[serde(default)]
93 pub providers: Vec<GatewayProvider>,
94 #[serde(default)]
95 pub models: Vec<GatewayModel>,
96}
97
98impl GatewayCatalog {
99 pub fn validate(&self) -> GatewayResult<()> {
100 for model in &self.models {
101 if model.id.is_empty() {
102 return Err(GatewayProfileError::ModelEmptyId);
103 }
104 if !self.providers.iter().any(|p| p.name == model.provider) {
105 return Err(GatewayProfileError::UnknownProvider {
106 model: model.id.clone(),
107 provider: model.provider.clone(),
108 });
109 }
110 }
111 for provider in &self.providers {
112 if provider.name.is_empty() {
113 return Err(GatewayProfileError::ProviderEmptyName);
114 }
115 if provider.endpoint.is_empty() {
116 return Err(GatewayProfileError::ProviderEmptyEndpoint {
117 name: provider.name.clone(),
118 });
119 }
120 }
121 Ok(())
122 }
123
124 pub fn find_provider(&self, name: &str) -> Option<&GatewayProvider> {
125 self.providers.iter().find(|p| p.name == name)
126 }
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct GatewayProvider {
131 pub name: String,
132 pub endpoint: String,
133 pub api_key_secret: String,
134 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
135 pub extra_headers: HashMap<String, String>,
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct GatewayModel {
140 pub id: String,
141 pub provider: String,
142 #[serde(default, skip_serializing_if = "Option::is_none")]
143 pub display_name: Option<String>,
144 #[serde(default, skip_serializing_if = "Option::is_none")]
145 pub upstream_model: Option<String>,
146 #[serde(default, skip_serializing_if = "Option::is_none")]
147 pub pricing: Option<ModelPricing>,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct GatewayRoute {
152 #[serde(default)]
153 pub id: String,
154 pub model_pattern: String,
155 pub provider: String,
156 pub endpoint: String,
157 pub api_key_secret: String,
158 #[serde(default, skip_serializing_if = "Option::is_none")]
159 pub upstream_model: Option<String>,
160 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
161 pub extra_headers: HashMap<String, String>,
162 #[serde(default, skip_serializing_if = "Option::is_none")]
163 pub pricing: Option<ModelPricing>,
164}
165
166impl GatewayRoute {
167 pub fn matches(&self, model: &str) -> bool {
168 match_pattern(&self.model_pattern, model)
169 }
170
171 pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
172 self.upstream_model.as_deref().unwrap_or(requested)
173 }
174
175 pub fn ensure_id(&mut self) {
176 if self.id.trim().is_empty() {
177 self.id = synthesize_route_id(&self.model_pattern, &self.provider, &self.endpoint);
178 }
179 }
180}
181
/// Turns a model pattern into a lowercase, dash-separated slug for use in ids.
///
/// - ASCII letters and digits are kept, lowercased.
/// - Each `*` becomes the word `star`.
/// - Any run of other characters (including non-ASCII ones) collapses into a
///   single `-` between two kept tokens. Leading and trailing separators are
///   dropped.
/// - If nothing survives, the result is `"route"`.
///
/// For example, `"Claude 3.5*"` becomes `"claude-3-5star"`.
#[must_use]
pub fn slugify_pattern(pattern: &str) -> String {
    let mut out = String::with_capacity(pattern.len());
    // A separator is only written once another token follows it. This makes
    // leading and trailing dashes impossible, so no post-hoc trimming is needed.
    let mut pending_dash = false;
    for ch in pattern.chars() {
        if ch == '*' || ch.is_ascii_alphanumeric() {
            if pending_dash {
                out.push('-');
                pending_dash = false;
            }
            if ch == '*' {
                out.push_str("star");
            } else {
                out.push(ch.to_ascii_lowercase());
            }
        } else if !out.is_empty() {
            pending_dash = true;
        }
    }
    if out.is_empty() {
        out.push_str("route");
    }
    out
}
217
218#[must_use]
224pub fn synthesize_route_id(model_pattern: &str, provider: &str, endpoint: &str) -> String {
225 let mut hasher = DefaultHasher::new();
226 model_pattern.hash(&mut hasher);
227 provider.hash(&mut hasher);
228 endpoint.hash(&mut hasher);
229 let h = hasher.finish();
230 let hash6: String = format!("{h:016x}").chars().take(6).collect();
231 format!("{}-{}", slugify_pattern(model_pattern), hash6)
232}
233
/// Returns `true` when `model` matches the glob-style `pattern`.
///
/// `*` matches any run of characters, including an empty run, and may appear
/// any number of times anywhere in the pattern. Every other character must
/// match literally, and the match is anchored at both ends. A pattern with no
/// `*` is an exact comparison.
///
/// This accepts everything the simpler `*` / `prefix*` / `*suffix` / exact
/// forms accepted, with identical results. It also handles patterns such as
/// `*mini*` or `claude-*-sonnet`. Those used to be compared against a literal
/// asterisk, so routes written that way silently never matched.
fn match_pattern(pattern: &str, model: &str) -> bool {
    // No wildcard at all: plain equality.
    let (head, tail) = match pattern.split_once('*') {
        Some(parts) => parts,
        None => return pattern == model,
    };
    // `last` is the literal text after the final `*`; `middle` holds the
    // `*`-separated literals between the first and last wildcard (possibly empty).
    let (middle, last) = tail.rsplit_once('*').unwrap_or(("", tail));

    // Anchor the head and tail first. Stripping the suffix from what remains
    // after the prefix keeps the two from overlapping (e.g. "ab*ba" vs "aba").
    let mut rest = match model.strip_prefix(head).and_then(|r| r.strip_suffix(last)) {
        Some(rest) => rest,
        None => return false,
    };

    // Each middle literal must appear in order. Taking the leftmost occurrence
    // is always safe for `*`-only globs because it leaves the most room for
    // later literals.
    for segment in middle.split('*') {
        match rest.find(segment) {
            Some(idx) => rest = &rest[idx + segment.len()..],
            None => return false,
        }
    }
    true
}