Skip to main content

systemprompt_models/profile/
gateway.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::PathBuf;
4use thiserror::Error;
5
/// Errors produced while reading, parsing, or validating a gateway model catalog.
#[derive(Debug, Error)]
pub enum GatewayProfileError {
    /// The catalog file at `path` could not be read from disk.
    #[error("Failed to read gateway catalog {path}: {source}")]
    CatalogRead {
        path: PathBuf,
        #[source]
        source: std::io::Error,
    },

    /// The catalog file at `path` was read but could not be deserialized as YAML.
    #[error("Failed to parse gateway catalog {path}: {source}")]
    CatalogParse {
        path: PathBuf,
        #[source]
        source: serde_yaml::Error,
    },

    /// The catalog at `path` was parsed but rejected; `source` is the wrapped inner error.
    /// Boxed because the variant contains the enum itself.
    /// NOTE(review): presumably `source` is one of the variants returned by
    /// `GatewayCatalog::validate` — confirm at the construction site.
    /// NOTE(review): `{source}` is both interpolated into the message and exposed via
    /// `#[source]`, so chain-printing reporters will show the inner error twice.
    #[error("Invalid gateway catalog {path}: {source}")]
    CatalogInvalid {
        path: PathBuf,
        #[source]
        source: Box<Self>,
    },

    /// A catalog model has an empty `id`.
    #[error("gateway catalog model has empty id")]
    ModelEmptyId,

    /// A catalog model names a `provider` that no catalog provider declares.
    #[error("gateway catalog model '{model}' references unknown provider '{provider}'")]
    UnknownProvider { model: String, provider: String },

    /// A catalog provider has an empty `name`.
    #[error("gateway catalog provider has empty name")]
    ProviderEmptyName,

    /// The catalog provider `name` has an empty `endpoint`.
    #[error("gateway catalog provider '{name}' has empty endpoint")]
    ProviderEmptyEndpoint { name: String },
}
41
42pub type GatewayResult<T> = Result<T, GatewayProfileError>;
43
/// Gateway section of a profile.
///
/// Holds whether the gateway is enabled, the ordered routing table, an optional catalog
/// location, and auth/path settings. Every field falls back to a default when omitted
/// during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
    /// Whether the gateway is enabled; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
    /// Routes consulted in declaration order by [`GatewayConfig::find_route`]; empty when omitted.
    #[serde(default)]
    pub routes: Vec<GatewayRoute>,
    /// Optional catalog file location; left out of serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub catalog_path: Option<PathBuf>,
    /// In-memory catalog. Serde never reads or writes this field (`skip`), so it is `None`
    /// right after deserialization.
    /// NOTE(review): presumably filled in from `catalog_path` by loader code elsewhere — confirm.
    #[serde(default, skip)]
    pub catalog: Option<GatewayCatalog>,
    /// Auth scheme identifier; `"bearer"` when omitted.
    #[serde(default = "default_auth_scheme")]
    pub auth_scheme: String,
    /// Path prefix (by its name, presumably for inference endpoints); `"/v1"` when omitted.
    #[serde(default = "default_inference_path_prefix")]
    pub inference_path_prefix: String,
}
59
60impl Default for GatewayConfig {
61    fn default() -> Self {
62        Self {
63            enabled: false,
64            routes: Vec::new(),
65            catalog_path: None,
66            catalog: None,
67            auth_scheme: default_auth_scheme(),
68            inference_path_prefix: default_inference_path_prefix(),
69        }
70    }
71}
72
/// Default for [`GatewayConfig::auth_scheme`]: the string `"bearer"`.
fn default_auth_scheme() -> String {
    String::from("bearer")
}
76
/// Default for [`GatewayConfig::inference_path_prefix`]: the string `"/v1"`.
fn default_inference_path_prefix() -> String {
    String::from("/v1")
}
80
81impl GatewayConfig {
82    pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
83        self.routes.iter().find(|route| route.matches(model))
84    }
85}
86
/// Catalog of upstream providers and the models served through them.
///
/// Both lists default to empty when omitted. Call [`GatewayCatalog::validate`] to check
/// that every model references a declared provider.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GatewayCatalog {
    /// Declared upstream providers, looked up by name.
    #[serde(default)]
    pub providers: Vec<GatewayProvider>,
    /// Models exposed by the catalog; each names one entry of `providers`.
    #[serde(default)]
    pub models: Vec<GatewayModel>,
}
94
95impl GatewayCatalog {
96    pub fn validate(&self) -> GatewayResult<()> {
97        for model in &self.models {
98            if model.id.is_empty() {
99                return Err(GatewayProfileError::ModelEmptyId);
100            }
101            if !self.providers.iter().any(|p| p.name == model.provider) {
102                return Err(GatewayProfileError::UnknownProvider {
103                    model: model.id.clone(),
104                    provider: model.provider.clone(),
105                });
106            }
107        }
108        for provider in &self.providers {
109            if provider.name.is_empty() {
110                return Err(GatewayProfileError::ProviderEmptyName);
111            }
112            if provider.endpoint.is_empty() {
113                return Err(GatewayProfileError::ProviderEmptyEndpoint {
114                    name: provider.name.clone(),
115                });
116            }
117        }
118        Ok(())
119    }
120
121    pub fn find_provider(&self, name: &str) -> Option<&GatewayProvider> {
122        self.providers.iter().find(|p| p.name == name)
123    }
124}
125
/// An upstream provider declared in a [`GatewayCatalog`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayProvider {
    /// Name that [`GatewayModel::provider`] refers to; `validate` requires it to be non-empty.
    pub name: String,
    /// Upstream endpoint; `validate` requires it to be non-empty.
    pub endpoint: String,
    /// NOTE(review): by its name, presumably identifies the secret that holds the API key
    /// rather than the key itself — confirm.
    pub api_key_secret: String,
    /// Additional headers for this provider; left out of serialized output when empty.
    /// NOTE(review): presumably sent on upstream requests — confirm. The derived `Debug`
    /// prints these values, which matters if they ever carry credentials.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra_headers: HashMap<String, String>,
}
134
/// A model entry in a [`GatewayCatalog`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayModel {
    /// Model identifier; `validate` requires it to be non-empty.
    pub id: String,
    /// Name of the [`GatewayProvider`] serving this model; `validate` requires it to exist.
    pub provider: String,
    /// Optional human-readable name; left out of serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    /// Optional model name to use upstream; left out of serialized output when `None`.
    /// NOTE(review): presumably overrides `id` when calling the provider — confirm at use site.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_model: Option<String>,
}
144
/// A static routing rule: requests for models matching `model_pattern` are sent to the
/// given provider/endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayRoute {
    /// Wildcard pattern checked against the requested model name by [`GatewayRoute::matches`].
    pub model_pattern: String,
    /// Provider label for this route.
    pub provider: String,
    /// Upstream endpoint for this route.
    pub endpoint: String,
    /// NOTE(review): by its name, presumably identifies the secret holding the API key — confirm.
    pub api_key_secret: String,
    /// When set, replaces the requested model name (see
    /// [`GatewayRoute::effective_upstream_model`]); left out of serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_model: Option<String>,
    /// Additional headers for this route; left out of serialized output when empty.
    /// NOTE(review): the derived `Debug` prints these values — beware if they carry credentials.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra_headers: HashMap<String, String>,
}
156
157impl GatewayRoute {
158    pub fn matches(&self, model: &str) -> bool {
159        match_pattern(&self.model_pattern, model)
160    }
161
162    pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
163        self.upstream_model.as_deref().unwrap_or(requested)
164    }
165}
166
/// Matches `model` against a simple glob `pattern` in which `*` stands for any (possibly
/// empty) run of characters.
///
/// A pattern without `*` must equal `model` exactly. There are no other metacharacters and
/// no escape for a literal `*`. Matching is case-sensitive.
///
/// The forms `*`, `prefix*`, `*suffix` and exact names behave as before. Patterns such as
/// `*sonnet*` (contains), `gpt-*-turbo` (interior wildcard) and `**` now work as
/// wildcards instead of having their extra `*` compared literally.
fn match_pattern(pattern: &str, model: &str) -> bool {
    let Some((head, tail)) = pattern.split_once('*') else {
        return pattern == model;
    };
    // `last` is the literal text after the final `*`; `middle` holds the `*`-separated
    // literals between the first and last `*` (empty when there is only one `*`).
    let (middle, last) = tail.rsplit_once('*').unwrap_or(("", tail));

    // The length check keeps the fixed head and tail from overlapping in `model`.
    if model.len() < head.len() + last.len() || !model.starts_with(head) || !model.ends_with(last)
    {
        return false;
    }

    // Both slice bounds are char boundaries: `starts_with`/`ends_with` succeeded above.
    let mut rest = &model[head.len()..model.len() - last.len()];
    // Leftmost greedy placement of each middle literal is optimal for `*`-only globs:
    // it leaves the most room for the literals that follow.
    for segment in middle.split('*') {
        match rest.find(segment) {
            Some(idx) => rest = &rest[idx + segment.len()..],
            None => return false,
        }
    }
    true
}