systemprompt_models/profile/gateway/
config.rs1use std::borrow::Cow;
9use std::collections::HashMap;
10
11use serde::{Deserialize, Serialize};
12use systemprompt_identifiers::{ProviderId, RouteId};
13
14use super::super::providers::ProviderRegistry;
15use super::error::{GatewayProfileError, GatewayResult};
16use super::route::GatewayRoute;
17
18#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
23#[serde(deny_unknown_fields)]
24pub struct GatewayConfigSpec {
25 #[serde(default)]
26 pub enabled: bool,
27 #[serde(default)]
28 pub routes: Vec<GatewayRoute>,
29 #[serde(default, skip_serializing_if = "Option::is_none")]
33 pub default_provider: Option<ProviderId>,
34 #[serde(default = "default_auth_scheme")]
35 pub auth_scheme: String,
36 #[serde(default = "default_inference_path_prefix")]
37 pub inference_path_prefix: String,
38}
39
40impl Default for GatewayConfigSpec {
41 fn default() -> Self {
42 Self {
43 enabled: false,
44 routes: Vec::new(),
45 default_provider: None,
46 auth_scheme: default_auth_scheme(),
47 inference_path_prefix: default_inference_path_prefix(),
48 }
49 }
50}
51
/// Fallback value for `auth_scheme` when none is supplied: `"bearer"`.
fn default_auth_scheme() -> String {
    String::from("bearer")
}
55
/// Fallback value for `inference_path_prefix` when none is supplied: `"/v1"`.
fn default_inference_path_prefix() -> String {
    String::from("/v1")
}
59
60impl GatewayConfigSpec {
61 #[must_use]
64 pub fn resolve(self) -> GatewayConfig {
65 let Self {
66 enabled,
67 routes,
68 default_provider,
69 auth_scheme,
70 inference_path_prefix,
71 } = self;
72
73 GatewayConfig {
74 enabled,
75 routes,
76 default_provider,
77 auth_scheme,
78 inference_path_prefix,
79 }
80 }
81}
82
83#[derive(Debug, Clone)]
90pub struct GatewayConfig {
91 pub enabled: bool,
92 pub routes: Vec<GatewayRoute>,
93 pub default_provider: Option<ProviderId>,
94 pub auth_scheme: String,
95 pub inference_path_prefix: String,
96}
97
98impl Default for GatewayConfig {
99 fn default() -> Self {
100 Self {
101 enabled: false,
102 routes: Vec::new(),
103 default_provider: None,
104 auth_scheme: default_auth_scheme(),
105 inference_path_prefix: default_inference_path_prefix(),
106 }
107 }
108}
109
110impl GatewayConfig {
111 pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
112 self.routes.iter().find(|route| route.matches(model))
113 }
114
115 #[must_use]
120 pub fn resolve_route<'a>(
121 &'a self,
122 registry: &ProviderRegistry,
123 model: &str,
124 ) -> Option<Cow<'a, GatewayRoute>> {
125 if let Some(route) = self.find_route(model) {
126 return Some(Cow::Borrowed(route));
127 }
128 self.synthesize_default_route(registry).map(Cow::Owned)
129 }
130
131 fn synthesize_default_route(&self, registry: &ProviderRegistry) -> Option<GatewayRoute> {
136 let provider = self.default_provider.as_ref()?;
137 registry.find_provider(provider.as_str())?;
138 let mut route = GatewayRoute {
139 id: RouteId::new(""),
140 model_pattern: "*".to_owned(),
141 provider: provider.clone(),
142 upstream_model: None,
143 extra_headers: HashMap::new(),
144 pricing: None,
145 };
146 route.ensure_id();
147 Some(route)
148 }
149
150 #[must_use]
151 pub fn is_model_exposed(&self, registry: &ProviderRegistry, model: &str) -> bool {
152 self.default_provider.is_some()
153 || self.find_route(model).is_some()
154 || registry.contains_model(model)
155 }
156
157 pub fn validate(&self, registry: &ProviderRegistry) -> GatewayResult<()> {
161 let mut route_ids: std::collections::HashSet<&str> =
162 std::collections::HashSet::with_capacity(self.routes.len());
163 for route in &self.routes {
164 if !route_ids.insert(route.id.as_str()) {
165 return Err(GatewayProfileError::DuplicateRouteId {
166 id: route.id.as_str().to_owned(),
167 });
168 }
169 }
170 if let Some(provider) = self.default_provider.as_ref() {
171 if registry.find_provider(provider.as_str()).is_none() {
172 return Err(GatewayProfileError::DefaultProviderNotInRegistry {
173 provider: provider.as_str().to_owned(),
174 });
175 }
176 }
177 for route in &self.routes {
178 if registry.find_provider(route.provider.as_str()).is_none() {
179 return Err(GatewayProfileError::RouteProviderNotInRegistry {
180 route: route.model_pattern.clone(),
181 provider: route.provider.as_str().to_owned(),
182 });
183 }
184 }
185 Ok(())
186 }
187
188 #[must_use]
191 pub fn to_spec(&self) -> GatewayConfigSpec {
192 GatewayConfigSpec {
193 enabled: self.enabled,
194 routes: self.routes.clone(),
195 default_provider: self.default_provider.clone(),
196 auth_scheme: self.auth_scheme.clone(),
197 inference_path_prefix: self.inference_path_prefix.clone(),
198 }
199 }
200}