Skip to main content

systemprompt_models/profile/providers/
mod.rs

1//! Provider registry: the single source of upstream connectivity.
2//!
3//! [`ProviderRegistry`] is the per-environment `profile.providers` section.
4//! Each [`ProviderEntry`] declares one upstream exactly once — its
5//! [`WireProtocol`], endpoint, credential ([`SecretName`]), extra headers, and
6//! the model catalog it serves. The two policy layers reference entries by
7//! [`ProviderId`] and never re-declare connectivity: the gateway policy
8//! (`profile.gateway`) routes external model names to a provider, and the AI
9//! policy (`services/ai/config.yaml`) selects an agent default and per-provider
10//! overrides.
11//!
12//! Validation here is the authority for connectivity: unique provider names,
13//! SSRF-guarded endpoints, and globally-unique model ids/aliases. The gateway
14//! and AI layers validate only their references *into* this registry.
15
16mod error;
17mod protocol;
18
19use std::collections::{HashMap, HashSet};
20
21use serde::{Deserialize, Serialize};
22use systemprompt_identifiers::{ModelId, ProviderId, SecretName};
23
24use crate::services::ai::{ModelCapabilities, ModelLimits, ModelPricing};
25
26pub use error::{ProviderRegistryError, ProviderRegistryResult};
27pub use protocol::WireProtocol;
28
29/// The canonical out-of-the-box provider catalog, embedded at build time. The
30/// single seed source shared by the setup wizard, cloud-init scaffolding, and
31/// the catalog-parity tests — see [`ProviderRegistry::default_seed`].
32const DEFAULT_CATALOG_YAML: &str = include_str!("default_catalog.yaml");
33
34#[derive(Deserialize)]
35struct DefaultCatalogFile {
36    providers: Vec<ProviderEntry>,
37}
38
39/// One model served by a provider: identity, routing, and economics.
40///
41/// A model's full description lives here exactly once: identity and routing
42/// (id, aliases, `upstream_model`, pricing) alongside agent-side capabilities
43/// and limits.
44#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
45#[serde(deny_unknown_fields)]
46pub struct ProviderModel {
47    pub id: ModelId,
48
49    #[serde(default, skip_serializing_if = "Vec::is_empty")]
50    pub aliases: Vec<ModelId>,
51
52    /// Vendor-side model name to send upstream when it differs from
53    /// [`Self::id`] (the external-facing name). `None` forwards `id`
54    /// unchanged.
55    #[serde(default, skip_serializing_if = "Option::is_none")]
56    pub upstream_model: Option<String>,
57
58    #[serde(default)]
59    pub pricing: ModelPricing,
60
61    #[serde(default)]
62    pub capabilities: ModelCapabilities,
63
64    #[serde(default)]
65    pub limits: ModelLimits,
66}
67
68impl ProviderModel {
69    #[must_use]
70    pub fn matches(&self, requested: &str) -> bool {
71        self.id.as_str() == requested || self.aliases.iter().any(|a| a.as_str() == requested)
72    }
73
74    #[must_use]
75    pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
76        self.upstream_model.as_deref().unwrap_or(requested)
77    }
78}
79
80/// One upstream provider declared once: connectivity + the models it serves.
81#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
82#[serde(deny_unknown_fields)]
83pub struct ProviderEntry {
84    pub name: ProviderId,
85
86    pub protocol: WireProtocol,
87
88    pub endpoint: String,
89
90    pub api_key_secret: SecretName,
91
92    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
93    pub extra_headers: HashMap<String, String>,
94
95    #[serde(default, skip_serializing_if = "Vec::is_empty")]
96    pub models: Vec<ProviderModel>,
97}
98
99impl ProviderEntry {
100    #[must_use]
101    pub fn find_model(&self, requested: &str) -> Option<&ProviderModel> {
102        self.models.iter().find(|m| m.matches(requested))
103    }
104}
105
106/// The `profile.providers` section: the registry of upstream providers.
107#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
108#[serde(transparent)]
109pub struct ProviderRegistry {
110    pub providers: Vec<ProviderEntry>,
111}
112
113impl ProviderRegistry {
114    /// Parse the embedded `DEFAULT_CATALOG_YAML` into the canonical seed
115    /// registry (every known provider + its full model catalog). Errs only if
116    /// the in-tree YAML is malformed — a build-time bug caught by tests.
117    pub fn default_seed() -> ProviderRegistryResult<Self> {
118        let file: DefaultCatalogFile = serde_yaml::from_str(DEFAULT_CATALOG_YAML)
119            .map_err(|e| ProviderRegistryError::InvalidDefaultCatalog(e.to_string()))?;
120        Ok(Self {
121            providers: file.providers,
122        })
123    }
124
125    #[must_use]
126    pub fn find_provider(&self, name: &str) -> Option<&ProviderEntry> {
127        self.providers.iter().find(|p| p.name.as_str() == name)
128    }
129
130    #[must_use]
131    pub fn contains_model(&self, requested: &str) -> bool {
132        self.providers
133            .iter()
134            .any(|p| p.find_model(requested).is_some())
135    }
136
137    pub fn validate(&self) -> ProviderRegistryResult<()> {
138        let trusted = crate::net::trusted_http_hosts_from_env();
139        let mut seen_providers: HashSet<&str> = HashSet::with_capacity(self.providers.len());
140        let mut seen_models: HashSet<&str> = HashSet::new();
141
142        for provider in &self.providers {
143            if !seen_providers.insert(provider.name.as_str()) {
144                return Err(ProviderRegistryError::DuplicateProvider {
145                    name: provider.name.as_str().to_owned(),
146                });
147            }
148            if provider.endpoint.is_empty() {
149                return Err(ProviderRegistryError::EmptyEndpoint {
150                    name: provider.name.as_str().to_owned(),
151                });
152            }
153            crate::net::validate_outbound_url_with_trust(&provider.endpoint, &trusted).map_err(
154                |e| ProviderRegistryError::BlockedEndpoint {
155                    provider: provider.name.as_str().to_owned(),
156                    endpoint: provider.endpoint.clone(),
157                    reason: e.to_string(),
158                },
159            )?;
160
161            for model in &provider.models {
162                if model.id.as_str().is_empty() {
163                    return Err(ProviderRegistryError::EmptyModelId {
164                        id: provider.name.as_str().to_owned(),
165                    });
166                }
167                if !seen_models.insert(model.id.as_str()) {
168                    return Err(ProviderRegistryError::DuplicateModel {
169                        id: model.id.as_str().to_owned(),
170                    });
171                }
172                for alias in &model.aliases {
173                    if !seen_models.insert(alias.as_str()) {
174                        return Err(ProviderRegistryError::DuplicateModel {
175                            id: alias.as_str().to_owned(),
176                        });
177                    }
178                }
179            }
180        }
181        Ok(())
182    }
183}