systemprompt_models/profile/gateway/
config.rs1use std::path::{Path, PathBuf};
9
10use serde::{Deserialize, Serialize};
11
12use super::catalog::GatewayCatalog;
13use super::error::{GatewayProfileError, GatewayResult};
14use super::route::GatewayRoute;
15
16#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
23#[serde(deny_unknown_fields)]
24pub struct GatewayConfigSpec {
25 #[serde(default)]
26 pub enabled: bool,
27 #[serde(default)]
28 pub routes: Vec<GatewayRoute>,
29 #[serde(default, skip_serializing_if = "Option::is_none")]
30 pub catalog: Option<GatewayCatalogSource>,
31 #[serde(default = "default_auth_scheme")]
32 pub auth_scheme: String,
33 #[serde(default = "default_inference_path_prefix")]
34 pub inference_path_prefix: String,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
43#[serde(untagged, deny_unknown_fields)]
44pub enum GatewayCatalogSource {
45 Path { path: PathBuf },
46 Inline(GatewayCatalog),
47}
48
49impl Default for GatewayConfigSpec {
50 fn default() -> Self {
51 Self {
52 enabled: false,
53 routes: Vec::new(),
54 catalog: None,
55 auth_scheme: default_auth_scheme(),
56 inference_path_prefix: default_inference_path_prefix(),
57 }
58 }
59}
60
/// Fallback value for `auth_scheme`, shared by serde and the `Default` impls.
fn default_auth_scheme() -> String {
    String::from("bearer")
}
64
/// Fallback value for `inference_path_prefix`, shared by serde and the
/// `Default` impls.
fn default_inference_path_prefix() -> String {
    String::from("/v1")
}
68
69impl GatewayConfigSpec {
70 pub fn resolve(self, profile_dir: &Path) -> GatewayResult<GatewayConfig> {
74 let Self {
75 enabled,
76 routes,
77 catalog,
78 auth_scheme,
79 inference_path_prefix,
80 } = self;
81
82 let catalog = match catalog {
83 None => None,
84 Some(GatewayCatalogSource::Inline(c)) => {
85 c.validate()?;
86 Some(c)
87 },
88 Some(GatewayCatalogSource::Path { path: rel }) => {
89 let absolute = if rel.is_absolute() {
90 rel
91 } else {
92 profile_dir.join(rel)
93 };
94 let content = std::fs::read_to_string(&absolute).map_err(|source| {
95 GatewayProfileError::CatalogRead {
96 path: absolute.clone(),
97 source,
98 }
99 })?;
100 let parsed: GatewayCatalog = serde_yaml::from_str(&content).map_err(|source| {
101 GatewayProfileError::CatalogParse {
102 path: absolute.clone(),
103 source,
104 }
105 })?;
106 parsed
107 .validate()
108 .map_err(|source| GatewayProfileError::CatalogInvalid {
109 path: absolute.clone(),
110 source: Box::new(source),
111 })?;
112 Some(parsed)
113 },
114 };
115
116 Ok(GatewayConfig {
117 enabled,
118 routes,
119 catalog,
120 auth_scheme,
121 inference_path_prefix,
122 })
123 }
124}
125
126#[derive(Debug, Clone)]
134pub struct GatewayConfig {
135 pub enabled: bool,
136 pub routes: Vec<GatewayRoute>,
137 pub catalog: Option<GatewayCatalog>,
138 pub auth_scheme: String,
139 pub inference_path_prefix: String,
140}
141
142impl Default for GatewayConfig {
143 fn default() -> Self {
144 Self {
145 enabled: false,
146 routes: Vec::new(),
147 catalog: None,
148 auth_scheme: default_auth_scheme(),
149 inference_path_prefix: default_inference_path_prefix(),
150 }
151 }
152}
153
154impl GatewayConfig {
155 pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
156 self.routes.iter().find(|route| route.matches(model))
157 }
158
159 #[must_use]
160 pub fn is_model_exposed(&self, model: &str) -> bool {
161 self.catalog
162 .as_ref()
163 .is_none_or(|c| c.contains_model(model))
164 }
165
166 pub fn validate(&self) -> GatewayResult<()> {
167 let mut route_ids: std::collections::HashSet<&str> =
168 std::collections::HashSet::with_capacity(self.routes.len());
169 for route in &self.routes {
170 if !route_ids.insert(route.id.as_str()) {
171 return Err(GatewayProfileError::DuplicateRouteId {
172 id: route.id.as_str().to_owned(),
173 });
174 }
175 }
176 let Some(catalog) = self.catalog.as_ref() else {
177 return Ok(());
178 };
179 catalog.validate()?;
180 for route in &self.routes {
181 if catalog.find_provider(route.provider.as_str()).is_none() {
182 return Err(GatewayProfileError::RouteProviderNotInCatalog {
183 route: route.model_pattern.clone(),
184 provider: route.provider.as_str().to_owned(),
185 });
186 }
187 }
188 let mut seen = std::collections::HashSet::with_capacity(catalog.models.len());
189 for model in &catalog.models {
190 if !seen.insert(model.id.as_str()) {
191 return Err(GatewayProfileError::DuplicateModelId {
192 id: model.id.as_str().to_owned(),
193 });
194 }
195 for alias in &model.aliases {
196 if !seen.insert(alias.as_str()) {
197 return Err(GatewayProfileError::DuplicateModelId {
198 id: alias.as_str().to_owned(),
199 });
200 }
201 }
202 if !self.routes.iter().any(|r| r.matches(model.id.as_str())) {
203 return Err(GatewayProfileError::UnreachableModel {
204 model: model.id.as_str().to_owned(),
205 });
206 }
207 }
208 Ok(())
209 }
210
211 #[must_use]
214 pub fn to_spec(&self) -> GatewayConfigSpec {
215 GatewayConfigSpec {
216 enabled: self.enabled,
217 routes: self.routes.clone(),
218 catalog: self.catalog.clone().map(GatewayCatalogSource::Inline),
219 auth_scheme: self.auth_scheme.clone(),
220 inference_path_prefix: self.inference_path_prefix.clone(),
221 }
222 }
223}