Skip to main content

systemprompt_models/profile/providers/
mod.rs

//! Provider registry: the single source of upstream connectivity.
//!
//! [`ProviderRegistry`] is the per-environment `profile.providers` section.
//! Each [`ProviderEntry`] declares one upstream exactly once — its
//! [`WireProtocol`], endpoint, credential ([`SecretName`]), extra headers, and
//! the model catalog it serves. The two policy layers reference entries by
//! [`ProviderId`] and never re-declare connectivity: the gateway policy
//! (`profile.gateway`) routes external model names to a provider, and the AI
//! policy (`services/ai/config.yaml`) selects an agent default and per-provider
//! overrides.
//!
//! Validation here is the authority for connectivity: unique provider names,
//! SSRF-guarded endpoints, and globally-unique model ids/aliases. The gateway
//! and AI layers validate only their references *into* this registry.
15
16mod error;
17mod protocol;
18
19use std::collections::{HashMap, HashSet};
20
21use serde::{Deserialize, Serialize};
22use systemprompt_identifiers::{ModelId, ProviderId, SecretName};
23
24use crate::services::ai::{ModelCapabilities, ModelLimits, ModelPricing};
25
26pub use error::{ProviderRegistryError, ProviderRegistryResult};
27pub use protocol::WireProtocol;
28
29/// The canonical out-of-the-box provider catalog, embedded at build time. The
30/// single seed source shared by the setup wizard, cloud-init scaffolding, and
31/// the catalog-parity tests — see [`ProviderRegistry::default_seed`].
32const DEFAULT_CATALOG_YAML: &str = include_str!("default_catalog.yaml");
33
34#[derive(Deserialize)]
35struct DefaultCatalogFile {
36    providers: Vec<ProviderEntry>,
37}
38
39/// One model served by a provider: identity, routing, and economics.
40///
41/// A model's full description lives here exactly once: identity and routing
42/// (id, aliases, `upstream_model`, pricing) alongside agent-side capabilities
43/// and limits.
44#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
45#[serde(deny_unknown_fields)]
46pub struct ProviderModel {
47    pub id: ModelId,
48
49    #[serde(default, skip_serializing_if = "Vec::is_empty")]
50    pub aliases: Vec<ModelId>,
51
52    /// Vendor-side model name to send upstream when it differs from
53    /// [`Self::id`] (the external-facing name). `None` forwards `id`
54    /// unchanged.
55    #[serde(default, skip_serializing_if = "Option::is_none")]
56    pub upstream_model: Option<String>,
57
58    #[serde(default)]
59    pub pricing: ModelPricing,
60
61    #[serde(default)]
62    pub capabilities: ModelCapabilities,
63
64    #[serde(default)]
65    pub limits: ModelLimits,
66}
67
68impl ProviderModel {
69    #[must_use]
70    pub fn matches(&self, requested: &str) -> bool {
71        self.id.as_str() == requested || self.aliases.iter().any(|a| a.as_str() == requested)
72    }
73
74    #[must_use]
75    pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
76        self.upstream_model.as_deref().unwrap_or(requested)
77    }
78}
79
80/// One upstream provider declared once: connectivity + the models it serves.
81#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
82#[serde(deny_unknown_fields)]
83pub struct ProviderEntry {
84    pub name: ProviderId,
85
86    pub protocol: WireProtocol,
87
88    pub endpoint: String,
89
90    pub api_key_secret: SecretName,
91
92    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
93    pub extra_headers: HashMap<String, String>,
94
95    #[serde(default, skip_serializing_if = "Vec::is_empty")]
96    pub models: Vec<ProviderModel>,
97}
98
99impl ProviderEntry {
100    #[must_use]
101    pub fn find_model(&self, requested: &str) -> Option<&ProviderModel> {
102        self.models.iter().find(|m| m.matches(requested))
103    }
104}
105
106/// The `profile.providers` section: the registry of upstream providers.
107#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
108#[serde(transparent)]
109pub struct ProviderRegistry {
110    pub providers: Vec<ProviderEntry>,
111}
112
113impl ProviderRegistry {
114    /// Parse the embedded `DEFAULT_CATALOG_YAML` into the canonical seed
115    /// registry (every known provider + its full model catalog). Errs only if
116    /// the in-tree YAML is malformed — a build-time bug caught by tests.
117    pub fn default_seed() -> ProviderRegistryResult<Self> {
118        let file: DefaultCatalogFile = serde_yaml::from_str(DEFAULT_CATALOG_YAML)
119            .map_err(|e| ProviderRegistryError::InvalidDefaultCatalog(e.to_string()))?;
120        Ok(Self {
121            providers: file.providers,
122        })
123    }
124
125    #[must_use]
126    pub fn find_provider(&self, name: &str) -> Option<&ProviderEntry> {
127        self.providers.iter().find(|p| p.name.as_str() == name)
128    }
129
130    #[must_use]
131    pub fn contains_model(&self, requested: &str) -> bool {
132        self.providers
133            .iter()
134            .any(|p| p.find_model(requested).is_some())
135    }
136
137    /// Model ids (each model's `id` plus its `aliases`) advertised for the
138    /// given wire protocols; an empty `protocols` slice means the full
139    /// catalog.
140    ///
141    /// This is the single source of truth for "which models may a {protocol}
142    /// client see". A gateway front door (e.g. Cowork in Anthropic mode)
143    /// rejects its whole config if advertised models include a name from
144    /// another wire family, so a caller scopes the list to its own
145    /// protocol; routes may still proxy those names to a different provider
146    /// underneath.
147    #[must_use]
148    pub fn advertised_model_ids(&self, protocols: &[WireProtocol]) -> Vec<String> {
149        self.providers
150            .iter()
151            .filter(|entry| protocols.is_empty() || protocols.contains(&entry.protocol))
152            .flat_map(|entry| {
153                entry.models.iter().flat_map(|m| {
154                    std::iter::once(m.id.as_str().to_owned())
155                        .chain(m.aliases.iter().map(|a| a.as_str().to_owned()))
156                })
157            })
158            .collect()
159    }
160
161    pub fn validate(&self) -> ProviderRegistryResult<()> {
162        let trusted = crate::net::trusted_http_hosts_from_env();
163        let mut seen_providers: HashSet<&str> = HashSet::with_capacity(self.providers.len());
164        let mut seen_models: HashSet<&str> = HashSet::new();
165
166        for provider in &self.providers {
167            if !seen_providers.insert(provider.name.as_str()) {
168                return Err(ProviderRegistryError::DuplicateProvider {
169                    name: provider.name.as_str().to_owned(),
170                });
171            }
172            if provider.endpoint.is_empty() {
173                return Err(ProviderRegistryError::EmptyEndpoint {
174                    name: provider.name.as_str().to_owned(),
175                });
176            }
177            crate::net::validate_outbound_url_with_trust(&provider.endpoint, &trusted).map_err(
178                |e| ProviderRegistryError::BlockedEndpoint {
179                    provider: provider.name.as_str().to_owned(),
180                    endpoint: provider.endpoint.clone(),
181                    reason: e.to_string(),
182                },
183            )?;
184
185            for model in &provider.models {
186                if model.id.as_str().is_empty() {
187                    return Err(ProviderRegistryError::EmptyModelId {
188                        id: provider.name.as_str().to_owned(),
189                    });
190                }
191                if !seen_models.insert(model.id.as_str()) {
192                    return Err(ProviderRegistryError::DuplicateModel {
193                        id: model.id.as_str().to_owned(),
194                    });
195                }
196                for alias in &model.aliases {
197                    if !seen_models.insert(alias.as_str()) {
198                        return Err(ProviderRegistryError::DuplicateModel {
199                            id: alias.as_str().to_owned(),
200                        });
201                    }
202                }
203            }
204        }
205        Ok(())
206    }
207}