systemprompt_models/profile/gateway/
config.rs1use std::path::{Path, PathBuf};
2
3use serde::{Deserialize, Serialize};
4
5use super::catalog::GatewayCatalog;
6use super::error::{GatewayProfileError, GatewayResult};
7use super::route::GatewayRoute;
8
9#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
16#[serde(deny_unknown_fields)]
17pub struct GatewayConfigSpec {
18 #[serde(default)]
19 pub enabled: bool,
20 #[serde(default)]
21 pub routes: Vec<GatewayRoute>,
22 #[serde(default, skip_serializing_if = "Option::is_none")]
23 pub catalog: Option<GatewayCatalogSource>,
24 #[serde(default = "default_auth_scheme")]
25 pub auth_scheme: String,
26 #[serde(default = "default_inference_path_prefix")]
27 pub inference_path_prefix: String,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
36#[serde(untagged, deny_unknown_fields)]
37pub enum GatewayCatalogSource {
38 Path { path: PathBuf },
39 Inline(GatewayCatalog),
40}
41
42impl Default for GatewayConfigSpec {
43 fn default() -> Self {
44 Self {
45 enabled: false,
46 routes: Vec::new(),
47 catalog: None,
48 auth_scheme: default_auth_scheme(),
49 inference_path_prefix: default_inference_path_prefix(),
50 }
51 }
52}
53
/// Serde default for `auth_scheme`: the `"bearer"` scheme name.
fn default_auth_scheme() -> String {
    String::from("bearer")
}
57
/// Serde default for `inference_path_prefix`: the `"/v1"` prefix.
fn default_inference_path_prefix() -> String {
    String::from("/v1")
}
61
62impl GatewayConfigSpec {
63 pub fn resolve(self, profile_dir: &Path) -> GatewayResult<GatewayConfig> {
67 let Self {
68 enabled,
69 routes,
70 catalog,
71 auth_scheme,
72 inference_path_prefix,
73 } = self;
74
75 let catalog = match catalog {
76 None => None,
77 Some(GatewayCatalogSource::Inline(c)) => {
78 c.validate()?;
79 Some(c)
80 },
81 Some(GatewayCatalogSource::Path { path: rel }) => {
82 let absolute = if rel.is_absolute() {
83 rel
84 } else {
85 profile_dir.join(rel)
86 };
87 let content = std::fs::read_to_string(&absolute).map_err(|source| {
88 GatewayProfileError::CatalogRead {
89 path: absolute.clone(),
90 source,
91 }
92 })?;
93 let parsed: GatewayCatalog = serde_yaml::from_str(&content).map_err(|source| {
94 GatewayProfileError::CatalogParse {
95 path: absolute.clone(),
96 source,
97 }
98 })?;
99 parsed
100 .validate()
101 .map_err(|source| GatewayProfileError::CatalogInvalid {
102 path: absolute.clone(),
103 source: Box::new(source),
104 })?;
105 Some(parsed)
106 },
107 };
108
109 Ok(GatewayConfig {
110 enabled,
111 routes,
112 catalog,
113 auth_scheme,
114 inference_path_prefix,
115 })
116 }
117}
118
119#[derive(Debug, Clone)]
127pub struct GatewayConfig {
128 pub enabled: bool,
129 pub routes: Vec<GatewayRoute>,
130 pub catalog: Option<GatewayCatalog>,
131 pub auth_scheme: String,
132 pub inference_path_prefix: String,
133}
134
135impl Default for GatewayConfig {
136 fn default() -> Self {
137 Self {
138 enabled: false,
139 routes: Vec::new(),
140 catalog: None,
141 auth_scheme: default_auth_scheme(),
142 inference_path_prefix: default_inference_path_prefix(),
143 }
144 }
145}
146
147impl GatewayConfig {
148 pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
149 self.routes.iter().find(|route| route.matches(model))
150 }
151
152 #[must_use]
153 pub fn is_model_exposed(&self, model: &str) -> bool {
154 self.catalog
155 .as_ref()
156 .is_none_or(|c| c.contains_model(model))
157 }
158
159 pub fn validate(&self) -> GatewayResult<()> {
160 let mut route_ids: std::collections::HashSet<&str> =
161 std::collections::HashSet::with_capacity(self.routes.len());
162 for route in &self.routes {
163 if !route_ids.insert(route.id.as_str()) {
164 return Err(GatewayProfileError::DuplicateRouteId {
165 id: route.id.as_str().to_owned(),
166 });
167 }
168 }
169 let Some(catalog) = self.catalog.as_ref() else {
170 return Ok(());
171 };
172 catalog.validate()?;
173 for route in &self.routes {
174 if catalog.find_provider(route.provider.as_str()).is_none() {
175 return Err(GatewayProfileError::RouteProviderNotInCatalog {
176 route: route.model_pattern.clone(),
177 provider: route.provider.as_str().to_owned(),
178 });
179 }
180 }
181 let mut seen = std::collections::HashSet::with_capacity(catalog.models.len());
182 for model in &catalog.models {
183 if !seen.insert(model.id.as_str()) {
184 return Err(GatewayProfileError::DuplicateModelId {
185 id: model.id.as_str().to_owned(),
186 });
187 }
188 for alias in &model.aliases {
189 if !seen.insert(alias.as_str()) {
190 return Err(GatewayProfileError::DuplicateModelId {
191 id: alias.as_str().to_owned(),
192 });
193 }
194 }
195 if !self.routes.iter().any(|r| r.matches(model.id.as_str())) {
196 return Err(GatewayProfileError::UnreachableModel {
197 model: model.id.as_str().to_owned(),
198 });
199 }
200 }
201 Ok(())
202 }
203
204 #[must_use]
207 pub fn to_spec(&self) -> GatewayConfigSpec {
208 GatewayConfigSpec {
209 enabled: self.enabled,
210 routes: self.routes.clone(),
211 catalog: self.catalog.clone().map(GatewayCatalogSource::Inline),
212 auth_scheme: self.auth_scheme.clone(),
213 inference_path_prefix: self.inference_path_prefix.clone(),
214 }
215 }
216}