systemprompt_models/profile/
gateway.rs1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::PathBuf;
4use thiserror::Error;
5
6#[derive(Debug, Error)]
7pub enum GatewayProfileError {
8 #[error("Failed to read gateway catalog {path}: {source}")]
9 CatalogRead {
10 path: PathBuf,
11 #[source]
12 source: std::io::Error,
13 },
14
15 #[error("Failed to parse gateway catalog {path}: {source}")]
16 CatalogParse {
17 path: PathBuf,
18 #[source]
19 source: serde_yaml::Error,
20 },
21
22 #[error("Invalid gateway catalog {path}: {source}")]
23 CatalogInvalid {
24 path: PathBuf,
25 #[source]
26 source: Box<Self>,
27 },
28
29 #[error("gateway catalog model has empty id")]
30 ModelEmptyId,
31
32 #[error("gateway catalog model '{model}' references unknown provider '{provider}'")]
33 UnknownProvider { model: String, provider: String },
34
35 #[error("gateway catalog provider has empty name")]
36 ProviderEmptyName,
37
38 #[error("gateway catalog provider '{name}' has empty endpoint")]
39 ProviderEmptyEndpoint { name: String },
40}
41
42pub type GatewayResult<T> = Result<T, GatewayProfileError>;
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct GatewayConfig {
46 #[serde(default)]
47 pub enabled: bool,
48 #[serde(default)]
49 pub routes: Vec<GatewayRoute>,
50 #[serde(default, skip_serializing_if = "Option::is_none")]
51 pub catalog_path: Option<PathBuf>,
52 #[serde(default, skip)]
53 pub catalog: Option<GatewayCatalog>,
54 #[serde(default = "default_auth_scheme")]
55 pub auth_scheme: String,
56 #[serde(default = "default_inference_path_prefix")]
57 pub inference_path_prefix: String,
58}
59
60impl Default for GatewayConfig {
61 fn default() -> Self {
62 Self {
63 enabled: false,
64 routes: Vec::new(),
65 catalog_path: None,
66 catalog: None,
67 auth_scheme: default_auth_scheme(),
68 inference_path_prefix: default_inference_path_prefix(),
69 }
70 }
71}
72
/// Serde default for `GatewayConfig::auth_scheme`.
fn default_auth_scheme() -> String {
    String::from("bearer")
}
76
/// Serde default for `GatewayConfig::inference_path_prefix`.
fn default_inference_path_prefix() -> String {
    "/v1".to_owned()
}
80
81impl GatewayConfig {
82 pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
83 self.routes.iter().find(|route| route.matches(model))
84 }
85}
86
87#[derive(Debug, Clone, Default, Serialize, Deserialize)]
88pub struct GatewayCatalog {
89 #[serde(default)]
90 pub providers: Vec<GatewayProvider>,
91 #[serde(default)]
92 pub models: Vec<GatewayModel>,
93}
94
95impl GatewayCatalog {
96 pub fn validate(&self) -> GatewayResult<()> {
97 for model in &self.models {
98 if model.id.is_empty() {
99 return Err(GatewayProfileError::ModelEmptyId);
100 }
101 if !self.providers.iter().any(|p| p.name == model.provider) {
102 return Err(GatewayProfileError::UnknownProvider {
103 model: model.id.clone(),
104 provider: model.provider.clone(),
105 });
106 }
107 }
108 for provider in &self.providers {
109 if provider.name.is_empty() {
110 return Err(GatewayProfileError::ProviderEmptyName);
111 }
112 if provider.endpoint.is_empty() {
113 return Err(GatewayProfileError::ProviderEmptyEndpoint {
114 name: provider.name.clone(),
115 });
116 }
117 }
118 Ok(())
119 }
120
121 pub fn find_provider(&self, name: &str) -> Option<&GatewayProvider> {
122 self.providers.iter().find(|p| p.name == name)
123 }
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct GatewayProvider {
128 pub name: String,
129 pub endpoint: String,
130 pub api_key_secret: String,
131 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
132 pub extra_headers: HashMap<String, String>,
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct GatewayModel {
137 pub id: String,
138 pub provider: String,
139 #[serde(default, skip_serializing_if = "Option::is_none")]
140 pub display_name: Option<String>,
141 #[serde(default, skip_serializing_if = "Option::is_none")]
142 pub upstream_model: Option<String>,
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct GatewayRoute {
147 pub model_pattern: String,
148 pub provider: String,
149 pub endpoint: String,
150 pub api_key_secret: String,
151 #[serde(default, skip_serializing_if = "Option::is_none")]
152 pub upstream_model: Option<String>,
153 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
154 pub extra_headers: HashMap<String, String>,
155}
156
157impl GatewayRoute {
158 pub fn matches(&self, model: &str) -> bool {
159 match_pattern(&self.model_pattern, model)
160 }
161
162 pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
163 self.upstream_model.as_deref().unwrap_or(requested)
164 }
165}
166
/// Glob-style match of a model id against a route pattern.
///
/// `*` matches any (possibly empty) run of characters and may appear anywhere,
/// any number of times; no other characters are special. A pattern without `*`
/// must equal `model` exactly.
///
/// The previous forms keep their meaning: `*` matches everything, `abc*` is a
/// prefix test and `*abc` is a suffix test. Patterns such as `*abc*` (contains)
/// and `a*b` (prefix + suffix) now work instead of being compared literally.
fn match_pattern(pattern: &str, model: &str) -> bool {
    // No wildcard at all: the pattern is a literal model id.
    let (head, tail) = match pattern.split_once('*') {
        Some(parts) => parts,
        None => return pattern == model,
    };

    // Text before the first `*` must be a prefix of the model id.
    let mut rest = match model.strip_prefix(head) {
        Some(rest) => rest,
        None => return false,
    };

    // Text after the last `*` must be a suffix of whatever remains.
    let (middle, last) = tail.rsplit_once('*').unwrap_or(("", tail));

    // Fragments between stars must appear in order without overlapping. Taking
    // the leftmost occurrence leaves the most room for later fragments.
    for fragment in middle.split('*').filter(|f| !f.is_empty()) {
        match rest.find(fragment) {
            Some(at) => rest = &rest[at + fragment.len()..],
            None => return false,
        }
    }

    // `rest` excludes everything already consumed, so the suffix cannot
    // overlap the prefix or the middle fragments.
    rest.ends_with(last)
}