systemprompt_models/profile/gateway/
config.rs1use std::borrow::Cow;
9use std::collections::HashMap;
10
11use serde::{Deserialize, Serialize};
12use systemprompt_identifiers::{ProviderId, RouteId};
13
14use super::super::providers::ProviderRegistry;
15use super::error::{GatewayProfileError, GatewayResult};
16use super::override_rule::SystemPromptRule;
17use super::route::GatewayRoute;
18
/// Model pattern assigned to the catch-all route that `GatewayConfig` synthesizes
/// from `default_provider` (presumably a wildcard matching every model — confirm
/// against `GatewayRoute::matches`).
pub(crate) const DEFAULT_ROUTE_PATTERN: &str = "*";
20
/// Serializable form of the gateway configuration.
///
/// Unknown keys are rejected when deserializing (`deny_unknown_fields`). Convert to
/// the runtime [`GatewayConfig`] with [`GatewayConfigSpec::resolve`].
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct GatewayConfigSpec {
    /// Whether the gateway is enabled. Defaults to `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
    /// Explicit routes. Route lookup takes the first route whose pattern matches,
    /// in declaration order. Defaults to an empty list.
    #[serde(default)]
    pub routes: Vec<GatewayRoute>,
    /// Provider used for the catch-all route that is tried after all explicit
    /// routes. It must exist in the provider registry. Omitted from serialized
    /// output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_provider: Option<ProviderId>,
    /// Opt-in open allowlist. Together with `default_provider`, it lets models that
    /// match no route and are unknown to the registry be treated as exposed (see
    /// `GatewayConfig::is_model_exposed`). Defaults to `false`.
    #[serde(default)]
    pub allow_unlisted_models: bool,
    /// Authentication scheme name. Defaults to `"bearer"`.
    #[serde(default = "default_auth_scheme")]
    pub auth_scheme: String,
    /// Inference path prefix. Defaults to `"/v1"`.
    #[serde(default = "default_inference_path_prefix")]
    pub inference_path_prefix: String,
    /// System-prompt override rules. Each rule, and its optional provider, is
    /// checked by `GatewayConfig::validate`. Omitted from serialized output when
    /// empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub system_prompt_overrides: Vec<SystemPromptRule>,
}
46
47impl Default for GatewayConfigSpec {
48 fn default() -> Self {
49 Self {
50 enabled: false,
51 routes: Vec::new(),
52 default_provider: None,
53 allow_unlisted_models: false,
54 auth_scheme: default_auth_scheme(),
55 inference_path_prefix: default_inference_path_prefix(),
56 system_prompt_overrides: Vec::new(),
57 }
58 }
59}
60
/// Default value for `auth_scheme`.
///
/// Used both by serde (for an omitted key) and by the `Default` impls.
fn default_auth_scheme() -> String {
    String::from("bearer")
}
64
/// Default value for `inference_path_prefix`.
///
/// Used both by serde (for an omitted key) and by the `Default` impls.
fn default_inference_path_prefix() -> String {
    String::from("/v1")
}
68
69impl GatewayConfigSpec {
70 #[must_use]
71 pub fn resolve(self) -> GatewayConfig {
72 let Self {
73 enabled,
74 routes,
75 default_provider,
76 allow_unlisted_models,
77 auth_scheme,
78 inference_path_prefix,
79 system_prompt_overrides,
80 } = self;
81
82 GatewayConfig {
83 enabled,
84 routes,
85 default_provider,
86 allow_unlisted_models,
87 auth_scheme,
88 inference_path_prefix,
89 system_prompt_overrides,
90 }
91 }
92}
93
/// Runtime gateway configuration, produced by [`GatewayConfigSpec::resolve`].
///
/// Carries the same fields as the spec. [`GatewayConfig::to_spec`] converts back.
#[derive(Debug, Clone)]
pub struct GatewayConfig {
    /// Whether the gateway is enabled.
    pub enabled: bool,
    /// Explicit routes, consulted in declaration order.
    pub routes: Vec<GatewayRoute>,
    /// Provider backing the synthesized catch-all route, if any.
    pub default_provider: Option<ProviderId>,
    /// Opt-in open allowlist, used together with `default_provider` by
    /// `is_model_exposed`.
    pub allow_unlisted_models: bool,
    /// Authentication scheme name.
    pub auth_scheme: String,
    /// Inference path prefix.
    pub inference_path_prefix: String,
    /// System-prompt override rules.
    pub system_prompt_overrides: Vec<SystemPromptRule>,
}
110
111impl Default for GatewayConfig {
112 fn default() -> Self {
113 Self {
114 enabled: false,
115 routes: Vec::new(),
116 default_provider: None,
117 allow_unlisted_models: false,
118 auth_scheme: default_auth_scheme(),
119 inference_path_prefix: default_inference_path_prefix(),
120 system_prompt_overrides: Vec::new(),
121 }
122 }
123}
124
125impl GatewayConfig {
126 pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
127 self.routes.iter().find(|route| route.matches(model))
128 }
129
130 pub fn candidate_routes<'a>(
131 &'a self,
132 registry: &ProviderRegistry,
133 ) -> impl Iterator<Item = Cow<'a, GatewayRoute>> {
134 self.routes
135 .iter()
136 .map(Cow::Borrowed)
137 .chain(self.synthesize_default_route(registry).map(Cow::Owned))
138 }
139
140 #[must_use]
141 pub fn resolve_route<'a>(
142 &'a self,
143 registry: &ProviderRegistry,
144 model: &str,
145 ) -> Option<Cow<'a, GatewayRoute>> {
146 self.candidate_routes(registry)
147 .find(|route| route.matches(model))
148 }
149
150 #[must_use]
151 pub fn dispatchable_route_ids(&self, registry: &ProviderRegistry) -> Vec<RouteId> {
152 let mut ids: Vec<RouteId> = Vec::new();
153 let mut seen: std::collections::HashSet<RouteId> = std::collections::HashSet::new();
154 for route in self.candidate_routes(registry) {
155 let mut route = route.into_owned();
156 route.ensure_id();
157 if seen.insert(route.id.clone()) {
158 ids.push(route.id);
159 }
160 }
161 ids
162 }
163
164 fn synthesize_default_route(&self, registry: &ProviderRegistry) -> Option<GatewayRoute> {
167 let provider = self.default_provider.as_ref()?;
168 registry.find_provider(provider.as_str())?;
169 let mut route = GatewayRoute {
170 id: RouteId::new(""),
171 model_pattern: DEFAULT_ROUTE_PATTERN.to_owned(),
172 provider: provider.clone(),
173 upstream_model: None,
174 extra_headers: HashMap::new(),
175 pricing: None,
176 };
177 route.ensure_id();
178 Some(route)
179 }
180
181 #[must_use]
186 pub fn is_model_exposed(&self, registry: &ProviderRegistry, model: &str) -> bool {
187 if self.find_route(model).is_some() || registry.contains_model(model) {
188 return true;
189 }
190 if self.default_provider.is_some() && self.allow_unlisted_models {
191 tracing::warn!(
192 model,
193 "gateway forwarding an unlisted model to default_provider \
194 (allow_unlisted_models=true): open allowlist posture"
195 );
196 return true;
197 }
198 false
199 }
200
201 pub fn validate(&self, registry: &ProviderRegistry) -> GatewayResult<()> {
202 let mut route_ids: std::collections::HashSet<&str> =
203 std::collections::HashSet::with_capacity(self.routes.len());
204 for route in &self.routes {
205 if !route_ids.insert(route.id.as_str()) {
206 return Err(GatewayProfileError::DuplicateRouteId {
207 id: route.id.as_str().to_owned(),
208 });
209 }
210 }
211 if let Some(provider) = self.default_provider.as_ref()
212 && registry.find_provider(provider.as_str()).is_none()
213 {
214 return Err(GatewayProfileError::DefaultProviderNotInRegistry {
215 provider: provider.as_str().to_owned(),
216 });
217 }
218 for route in &self.routes {
219 if registry.find_provider(route.provider.as_str()).is_none() {
220 return Err(GatewayProfileError::RouteProviderNotInRegistry {
221 route: route.model_pattern.clone(),
222 provider: route.provider.as_str().to_owned(),
223 });
224 }
225 }
226 for rule in &self.system_prompt_overrides {
227 rule.validate()?;
228 if let Some(provider) = rule.provider.as_ref()
229 && registry.find_provider(provider.as_str()).is_none()
230 {
231 return Err(GatewayProfileError::OverrideProviderNotInRegistry {
232 provider: provider.as_str().to_owned(),
233 });
234 }
235 }
236 Ok(())
237 }
238
239 #[must_use]
240 pub fn to_spec(&self) -> GatewayConfigSpec {
241 GatewayConfigSpec {
242 enabled: self.enabled,
243 routes: self.routes.clone(),
244 default_provider: self.default_provider.clone(),
245 allow_unlisted_models: self.allow_unlisted_models,
246 auth_scheme: self.auth_scheme.clone(),
247 inference_path_prefix: self.inference_path_prefix.clone(),
248 system_prompt_overrides: self.system_prompt_overrides.clone(),
249 }
250 }
251}