// systemprompt_models/profile/gateway/config.rs
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};

use serde::{Deserialize, Serialize};
use systemprompt_identifiers::{ProviderId, RouteId};

use super::super::providers::ProviderRegistry;
use super::error::{GatewayProfileError, GatewayResult};
use super::route::GatewayRoute;

/// Model pattern given to the catch-all route synthesized for `default_provider`.
pub(crate) const DEFAULT_ROUTE_PATTERN: &str = "*";

20#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
21#[serde(deny_unknown_fields)]
22pub struct GatewayConfigSpec {
23 #[serde(default)]
24 pub enabled: bool,
25 #[serde(default)]
26 pub routes: Vec<GatewayRoute>,
27 #[serde(default, skip_serializing_if = "Option::is_none")]
31 pub default_provider: Option<ProviderId>,
32 #[serde(default)]
37 pub allow_unlisted_models: bool,
38 #[serde(default = "default_auth_scheme")]
39 pub auth_scheme: String,
40 #[serde(default = "default_inference_path_prefix")]
41 pub inference_path_prefix: String,
42}
43
44impl Default for GatewayConfigSpec {
45 fn default() -> Self {
46 Self {
47 enabled: false,
48 routes: Vec::new(),
49 default_provider: None,
50 allow_unlisted_models: false,
51 auth_scheme: default_auth_scheme(),
52 inference_path_prefix: default_inference_path_prefix(),
53 }
54 }
55}
56
/// Serde default for `auth_scheme`.
fn default_auth_scheme() -> String {
    String::from("bearer")
}

/// Serde default for `inference_path_prefix`.
fn default_inference_path_prefix() -> String {
    String::from("/v1")
}

65impl GatewayConfigSpec {
66 #[must_use]
67 pub fn resolve(self) -> GatewayConfig {
68 let Self {
69 enabled,
70 routes,
71 default_provider,
72 allow_unlisted_models,
73 auth_scheme,
74 inference_path_prefix,
75 } = self;
76
77 GatewayConfig {
78 enabled,
79 routes,
80 default_provider,
81 allow_unlisted_models,
82 auth_scheme,
83 inference_path_prefix,
84 }
85 }
86}
87
88#[derive(Debug, Clone)]
95pub struct GatewayConfig {
96 pub enabled: bool,
97 pub routes: Vec<GatewayRoute>,
98 pub default_provider: Option<ProviderId>,
99 pub allow_unlisted_models: bool,
100 pub auth_scheme: String,
101 pub inference_path_prefix: String,
102}
103
104impl Default for GatewayConfig {
105 fn default() -> Self {
106 Self {
107 enabled: false,
108 routes: Vec::new(),
109 default_provider: None,
110 allow_unlisted_models: false,
111 auth_scheme: default_auth_scheme(),
112 inference_path_prefix: default_inference_path_prefix(),
113 }
114 }
115}
116
117impl GatewayConfig {
118 pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
119 self.routes.iter().find(|route| route.matches(model))
120 }
121
122 pub fn candidate_routes<'a>(
123 &'a self,
124 registry: &ProviderRegistry,
125 ) -> impl Iterator<Item = Cow<'a, GatewayRoute>> {
126 self.routes
127 .iter()
128 .map(Cow::Borrowed)
129 .chain(self.synthesize_default_route(registry).map(Cow::Owned))
130 }
131
132 #[must_use]
133 pub fn resolve_route<'a>(
134 &'a self,
135 registry: &ProviderRegistry,
136 model: &str,
137 ) -> Option<Cow<'a, GatewayRoute>> {
138 self.candidate_routes(registry)
139 .find(|route| route.matches(model))
140 }
141
142 #[must_use]
143 pub fn dispatchable_route_ids(&self, registry: &ProviderRegistry) -> Vec<RouteId> {
144 let mut ids: Vec<RouteId> = Vec::new();
145 let mut seen: std::collections::HashSet<RouteId> = std::collections::HashSet::new();
146 for route in self.candidate_routes(registry) {
147 let mut route = route.into_owned();
148 route.ensure_id();
149 if seen.insert(route.id.clone()) {
150 ids.push(route.id);
151 }
152 }
153 ids
154 }
155
156 fn synthesize_default_route(&self, registry: &ProviderRegistry) -> Option<GatewayRoute> {
159 let provider = self.default_provider.as_ref()?;
160 registry.find_provider(provider.as_str())?;
161 let mut route = GatewayRoute {
162 id: RouteId::new(""),
163 model_pattern: DEFAULT_ROUTE_PATTERN.to_owned(),
164 provider: provider.clone(),
165 upstream_model: None,
166 extra_headers: HashMap::new(),
167 pricing: None,
168 };
169 route.ensure_id();
170 Some(route)
171 }
172
173 #[must_use]
178 pub fn is_model_exposed(&self, registry: &ProviderRegistry, model: &str) -> bool {
179 if self.find_route(model).is_some() || registry.contains_model(model) {
180 return true;
181 }
182 if self.default_provider.is_some() && self.allow_unlisted_models {
183 tracing::warn!(
184 model,
185 "gateway forwarding an unlisted model to default_provider \
186 (allow_unlisted_models=true): open allowlist posture"
187 );
188 return true;
189 }
190 false
191 }
192
193 pub fn validate(&self, registry: &ProviderRegistry) -> GatewayResult<()> {
194 let mut route_ids: std::collections::HashSet<&str> =
195 std::collections::HashSet::with_capacity(self.routes.len());
196 for route in &self.routes {
197 if !route_ids.insert(route.id.as_str()) {
198 return Err(GatewayProfileError::DuplicateRouteId {
199 id: route.id.as_str().to_owned(),
200 });
201 }
202 }
203 if let Some(provider) = self.default_provider.as_ref() {
204 if registry.find_provider(provider.as_str()).is_none() {
205 return Err(GatewayProfileError::DefaultProviderNotInRegistry {
206 provider: provider.as_str().to_owned(),
207 });
208 }
209 }
210 for route in &self.routes {
211 if registry.find_provider(route.provider.as_str()).is_none() {
212 return Err(GatewayProfileError::RouteProviderNotInRegistry {
213 route: route.model_pattern.clone(),
214 provider: route.provider.as_str().to_owned(),
215 });
216 }
217 }
218 Ok(())
219 }
220
221 #[must_use]
222 pub fn to_spec(&self) -> GatewayConfigSpec {
223 GatewayConfigSpec {
224 enabled: self.enabled,
225 routes: self.routes.clone(),
226 default_provider: self.default_provider.clone(),
227 allow_unlisted_models: self.allow_unlisted_models,
228 auth_scheme: self.auth_scheme.clone(),
229 inference_path_prefix: self.inference_path_prefix.clone(),
230 }
231 }
232}