systemprompt_models/profile/providers/
mod.rs1mod error;
17mod protocol;
18mod surface;
19
20use std::collections::{HashMap, HashSet};
21
22use serde::{Deserialize, Serialize};
23use systemprompt_identifiers::{ModelId, ProviderId, SecretName};
24
25use crate::services::ai::{ModelCapabilities, ModelLimits, ModelPricing};
26
27pub use error::{ProviderRegistryError, ProviderRegistryResult};
28pub use protocol::WireProtocol;
29pub use surface::ApiSurface;
30
/// Provider catalog embedded into the binary at compile time.
///
/// `include_str!` resolves `default_catalog.yaml` relative to this source file.
/// The text is parsed by [`ProviderRegistry::default_seed`].
const DEFAULT_CATALOG_YAML: &str = include_str!("default_catalog.yaml");
32
/// Deserialization shape of the embedded default catalog.
///
/// The YAML wraps its entries under a top-level `providers:` key. By contrast,
/// [`ProviderRegistry`] is `#[serde(transparent)]` over a bare list.
#[derive(Deserialize)]
struct DefaultCatalogFile {
    /// Provider entries, in the order they appear in the YAML sequence.
    providers: Vec<ProviderEntry>,
}
37
38#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
39#[serde(deny_unknown_fields)]
40pub struct ProviderModel {
41 pub id: ModelId,
42
43 #[serde(default, skip_serializing_if = "Vec::is_empty")]
44 pub aliases: Vec<ModelId>,
45
46 #[serde(default, skip_serializing_if = "Option::is_none")]
50 pub upstream_model: Option<String>,
51
52 #[serde(default)]
53 pub pricing: ModelPricing,
54
55 #[serde(default)]
56 pub capabilities: ModelCapabilities,
57
58 #[serde(default)]
59 pub limits: ModelLimits,
60}
61
62impl ProviderModel {
63 #[must_use]
64 pub fn matches(&self, requested: &str) -> bool {
65 self.id.as_str() == requested || self.aliases.iter().any(|a| a.as_str() == requested)
66 }
67
68 #[must_use]
69 pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
70 self.upstream_model.as_deref().unwrap_or(requested)
71 }
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
75#[serde(deny_unknown_fields)]
76pub struct ProviderEntry {
77 pub name: ProviderId,
78
79 pub wire: WireProtocol,
82
83 pub surface: ApiSurface,
87
88 pub endpoint: String,
89
90 pub api_key_secret: SecretName,
91
92 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
93 pub extra_headers: HashMap<String, String>,
94
95 #[serde(default, skip_serializing_if = "Vec::is_empty")]
96 pub models: Vec<ProviderModel>,
97}
98
99impl ProviderEntry {
100 #[must_use]
101 pub fn find_model(&self, requested: &str) -> Option<&ProviderModel> {
102 self.models.iter().find(|m| m.matches(requested))
103 }
104}
105
106#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
107#[serde(transparent)]
108pub struct ProviderRegistry {
109 pub providers: Vec<ProviderEntry>,
110}
111
112impl ProviderRegistry {
113 pub fn default_seed() -> ProviderRegistryResult<Self> {
114 let file: DefaultCatalogFile = serde_yaml::from_str(DEFAULT_CATALOG_YAML)
115 .map_err(|e| ProviderRegistryError::InvalidDefaultCatalog(e.to_string()))?;
116 Ok(Self {
117 providers: file.providers,
118 })
119 }
120
121 #[must_use]
122 pub fn find_provider(&self, name: &str) -> Option<&ProviderEntry> {
123 self.providers.iter().find(|p| p.name.as_str() == name)
124 }
125
126 #[must_use]
127 pub fn contains_model(&self, requested: &str) -> bool {
128 self.providers
129 .iter()
130 .any(|p| p.find_model(requested).is_some())
131 }
132
133 pub fn advertised_providers(&self) -> impl Iterator<Item = &ProviderEntry> {
139 self.providers
140 .iter()
141 .filter(|entry| entry.surface.is_advertised())
142 }
143
144 #[must_use]
151 pub fn advertised_model_ids(&self, surfaces: &[ApiSurface]) -> Vec<String> {
152 self.advertised_providers()
153 .filter(|entry| surfaces.is_empty() || surfaces.contains(&entry.surface))
154 .flat_map(|entry| {
155 entry.models.iter().flat_map(|m| {
156 std::iter::once(m.id.as_str().to_owned())
157 .chain(m.aliases.iter().map(|a| a.as_str().to_owned()))
158 })
159 })
160 .collect()
161 }
162
163 pub fn validate(&self) -> ProviderRegistryResult<()> {
164 let trusted = crate::net::trusted_http_hosts_from_env();
165 let mut seen_providers: HashSet<&str> = HashSet::with_capacity(self.providers.len());
166 let mut seen_models: HashSet<&str> = HashSet::new();
167
168 for provider in &self.providers {
169 if !seen_providers.insert(provider.name.as_str()) {
170 return Err(ProviderRegistryError::DuplicateProvider {
171 name: provider.name.as_str().to_owned(),
172 });
173 }
174 if provider.endpoint.is_empty() {
175 return Err(ProviderRegistryError::EmptyEndpoint {
176 name: provider.name.as_str().to_owned(),
177 });
178 }
179 crate::net::validate_outbound_url_with_trust(&provider.endpoint, &trusted).map_err(
180 |e| ProviderRegistryError::BlockedEndpoint {
181 provider: provider.name.as_str().to_owned(),
182 endpoint: provider.endpoint.clone(),
183 reason: e.to_string(),
184 },
185 )?;
186
187 for model in &provider.models {
188 if model.id.as_str().is_empty() {
189 return Err(ProviderRegistryError::EmptyModelId {
190 id: provider.name.as_str().to_owned(),
191 });
192 }
193 if !seen_models.insert(model.id.as_str()) {
194 return Err(ProviderRegistryError::DuplicateModel {
195 id: model.id.as_str().to_owned(),
196 });
197 }
198 for alias in &model.aliases {
199 if !seen_models.insert(alias.as_str()) {
200 return Err(ProviderRegistryError::DuplicateModel {
201 id: alias.as_str().to_owned(),
202 });
203 }
204 }
205 }
206 }
207 Ok(())
208 }
209}