// systemprompt_models/profile/providers/
mod.rs1mod error;
17mod protocol;
18
19use std::collections::{HashMap, HashSet};
20
21use serde::{Deserialize, Serialize};
22use systemprompt_identifiers::{ModelId, ProviderId, SecretName};
23
24use crate::services::ai::{ModelCapabilities, ModelLimits, ModelPricing};
25
26pub use error::{ProviderRegistryError, ProviderRegistryResult};
27pub use protocol::WireProtocol;
28
/// Provider catalog YAML embedded at compile time (via `include_str!`);
/// parsed by `ProviderRegistry::default_seed`.
const DEFAULT_CATALOG_YAML: &str = include_str!("default_catalog.yaml");

/// Top-level shape of `default_catalog.yaml`: a `providers` list of entries.
///
/// Unlike the entry types, this wrapper does not set `deny_unknown_fields`,
/// so extra top-level keys in the catalog file are ignored.
#[derive(Deserialize)]
struct DefaultCatalogFile {
    providers: Vec<ProviderEntry>,
}
38
// A model served by a provider. It can be requested either by `id` or by any
// entry in `aliases` (see `ProviderModel::matches`). Unknown keys are rejected
// on deserialization (`deny_unknown_fields`).
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct ProviderModel {
    // Primary name of the model.
    pub id: ModelId,

    // Alternative names that also resolve to this model; omitted from
    // serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub aliases: Vec<ModelId>,

    // When set, `effective_upstream_model` returns this instead of the
    // requested name; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_model: Option<String>,

    // pricing / capabilities / limits: each falls back to its type's
    // `Default` value when absent from the input.
    #[serde(default)]
    pub pricing: ModelPricing,

    #[serde(default)]
    pub capabilities: ModelCapabilities,

    #[serde(default)]
    pub limits: ModelLimits,
}
67
68impl ProviderModel {
69 #[must_use]
70 pub fn matches(&self, requested: &str) -> bool {
71 self.id.as_str() == requested || self.aliases.iter().any(|a| a.as_str() == requested)
72 }
73
74 #[must_use]
75 pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
76 self.upstream_model.as_deref().unwrap_or(requested)
77 }
78}
79
// One configured provider: its name, wire protocol, endpoint, API-key secret
// reference, optional extra headers, and the models it serves. Unknown keys
// are rejected on deserialization (`deny_unknown_fields`).
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct ProviderEntry {
    // Provider identifier; `ProviderRegistry::validate` requires it to be
    // unique within the registry.
    pub name: ProviderId,

    pub protocol: WireProtocol,

    // Endpoint URL; `ProviderRegistry::validate` requires it to be non-empty
    // and to pass `crate::net::validate_outbound_url_with_trust`.
    pub endpoint: String,

    // Name of the secret holding the API key (presumably resolved to the
    // actual key elsewhere; only the name is stored here).
    pub api_key_secret: SecretName,

    // Extra header name/value pairs (presumably attached to outbound requests;
    // confirm at the call site); omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra_headers: HashMap<String, String>,

    // Models served by this provider; omitted from serialized output when
    // empty. `ProviderRegistry::validate` requires ids and aliases to be
    // unique across the whole registry.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub models: Vec<ProviderModel>,
}
98
99impl ProviderEntry {
100 #[must_use]
101 pub fn find_model(&self, requested: &str) -> Option<&ProviderModel> {
102 self.models.iter().find(|m| m.matches(requested))
103 }
104}
105
// Ordered list of configured providers. `serde(transparent)` makes the
// registry (de)serialize as the bare list, not as an object with a
// `providers` key. `Default` yields an empty registry.
#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(transparent)]
pub struct ProviderRegistry {
    pub providers: Vec<ProviderEntry>,
}
112
113impl ProviderRegistry {
114 pub fn default_seed() -> ProviderRegistryResult<Self> {
118 let file: DefaultCatalogFile = serde_yaml::from_str(DEFAULT_CATALOG_YAML)
119 .map_err(|e| ProviderRegistryError::InvalidDefaultCatalog(e.to_string()))?;
120 Ok(Self {
121 providers: file.providers,
122 })
123 }
124
125 #[must_use]
126 pub fn find_provider(&self, name: &str) -> Option<&ProviderEntry> {
127 self.providers.iter().find(|p| p.name.as_str() == name)
128 }
129
130 #[must_use]
131 pub fn contains_model(&self, requested: &str) -> bool {
132 self.providers
133 .iter()
134 .any(|p| p.find_model(requested).is_some())
135 }
136
137 pub fn validate(&self) -> ProviderRegistryResult<()> {
138 let trusted = crate::net::trusted_http_hosts_from_env();
139 let mut seen_providers: HashSet<&str> = HashSet::with_capacity(self.providers.len());
140 let mut seen_models: HashSet<&str> = HashSet::new();
141
142 for provider in &self.providers {
143 if !seen_providers.insert(provider.name.as_str()) {
144 return Err(ProviderRegistryError::DuplicateProvider {
145 name: provider.name.as_str().to_owned(),
146 });
147 }
148 if provider.endpoint.is_empty() {
149 return Err(ProviderRegistryError::EmptyEndpoint {
150 name: provider.name.as_str().to_owned(),
151 });
152 }
153 crate::net::validate_outbound_url_with_trust(&provider.endpoint, &trusted).map_err(
154 |e| ProviderRegistryError::BlockedEndpoint {
155 provider: provider.name.as_str().to_owned(),
156 endpoint: provider.endpoint.clone(),
157 reason: e.to_string(),
158 },
159 )?;
160
161 for model in &provider.models {
162 if model.id.as_str().is_empty() {
163 return Err(ProviderRegistryError::EmptyModelId {
164 id: provider.name.as_str().to_owned(),
165 });
166 }
167 if !seen_models.insert(model.id.as_str()) {
168 return Err(ProviderRegistryError::DuplicateModel {
169 id: model.id.as_str().to_owned(),
170 });
171 }
172 for alias in &model.aliases {
173 if !seen_models.insert(alias.as_str()) {
174 return Err(ProviderRegistryError::DuplicateModel {
175 id: alias.as_str().to_owned(),
176 });
177 }
178 }
179 }
180 }
181 Ok(())
182 }
183}