Skip to main content

systemprompt_models/profile/providers/
mod.rs

1//! Provider registry: the single source of upstream connectivity.
2//!
3//! [`ProviderRegistry`] is the per-environment `profile.providers` section.
4//! Each [`ProviderEntry`] declares one upstream exactly once — its
5//! [`WireProtocol`], endpoint, credential ([`SecretName`]), extra headers, and
6//! the model catalog it serves. The two policy layers reference entries by
7//! [`ProviderId`] and never re-declare connectivity: the gateway policy
8//! (`profile.gateway`) routes external model names to a provider, and the AI
9//! policy (`services/ai/config.yaml`) selects an agent default and per-provider
10//! overrides.
11//!
12//! Validation here is the authority for connectivity: unique provider names,
13//! SSRF-guarded endpoints, and globally-unique model ids/aliases. The gateway
14//! and AI layers validate only their references *into* this registry.
15
16mod error;
17mod protocol;
18mod surface;
19
20use std::collections::{HashMap, HashSet};
21
22use serde::{Deserialize, Serialize};
23use systemprompt_identifiers::{ModelId, ProviderId, SecretName};
24
25use crate::services::ai::{ModelCapabilities, ModelLimits, ModelPricing};
26
27pub use error::{ProviderRegistryError, ProviderRegistryResult};
28pub use protocol::WireProtocol;
29pub use surface::ApiSurface;
30
31const DEFAULT_CATALOG_YAML: &str = include_str!("default_catalog.yaml");
32
33#[derive(Deserialize)]
34struct DefaultCatalogFile {
35    providers: Vec<ProviderEntry>,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
39#[serde(deny_unknown_fields)]
40pub struct ProviderModel {
41    pub id: ModelId,
42
43    #[serde(default, skip_serializing_if = "Vec::is_empty")]
44    pub aliases: Vec<ModelId>,
45
46    /// Vendor-side model name to send upstream when it differs from
47    /// [`Self::id`] (the external-facing name). `None` forwards `id`
48    /// unchanged.
49    #[serde(default, skip_serializing_if = "Option::is_none")]
50    pub upstream_model: Option<String>,
51
52    #[serde(default)]
53    pub pricing: ModelPricing,
54
55    #[serde(default)]
56    pub capabilities: ModelCapabilities,
57
58    #[serde(default)]
59    pub limits: ModelLimits,
60}
61
62impl ProviderModel {
63    #[must_use]
64    pub fn matches(&self, requested: &str) -> bool {
65        self.id.as_str() == requested || self.aliases.iter().any(|a| a.as_str() == requested)
66    }
67
68    #[must_use]
69    pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
70        self.upstream_model.as_deref().unwrap_or(requested)
71    }
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
75#[serde(deny_unknown_fields)]
76pub struct ProviderEntry {
77    pub name: ProviderId,
78
79    /// The wire codec the gateway speaks to reach this provider. Selects the
80    /// outbound adapter only — never which client API advertises these models.
81    pub wire: WireProtocol,
82
83    /// The client API family these models are advertised under. Required and
84    /// without a default: advertising a backend vendor as Anthropic must mean
85    /// literally writing `surface: anthropic`, not falling through a default.
86    pub surface: ApiSurface,
87
88    pub endpoint: String,
89
90    pub api_key_secret: SecretName,
91
92    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
93    pub extra_headers: HashMap<String, String>,
94
95    #[serde(default, skip_serializing_if = "Vec::is_empty")]
96    pub models: Vec<ProviderModel>,
97}
98
99impl ProviderEntry {
100    #[must_use]
101    pub fn find_model(&self, requested: &str) -> Option<&ProviderModel> {
102        self.models.iter().find(|m| m.matches(requested))
103    }
104}
105
106#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
107#[serde(transparent)]
108pub struct ProviderRegistry {
109    pub providers: Vec<ProviderEntry>,
110}
111
112impl ProviderRegistry {
113    pub fn default_seed() -> ProviderRegistryResult<Self> {
114        let file: DefaultCatalogFile = serde_yaml::from_str(DEFAULT_CATALOG_YAML)
115            .map_err(|e| ProviderRegistryError::InvalidDefaultCatalog(e.to_string()))?;
116        Ok(Self {
117            providers: file.providers,
118        })
119    }
120
121    #[must_use]
122    pub fn find_provider(&self, name: &str) -> Option<&ProviderEntry> {
123        self.providers.iter().find(|p| p.name.as_str() == name)
124    }
125
126    #[must_use]
127    pub fn contains_model(&self, requested: &str) -> bool {
128        self.providers
129            .iter()
130            .any(|p| p.find_model(requested).is_some())
131    }
132
133    /// The one place the advertisement rule is applied to the registry.
134    ///
135    /// A `surface: backend` provider (e.g. `minimax`) can never leak into a
136    /// client catalog through a hand-rolled flatten. Routing/admin paths that
137    /// must still see backend providers iterate `self.providers` directly.
138    pub fn advertised_providers(&self) -> impl Iterator<Item = &ProviderEntry> {
139        self.providers
140            .iter()
141            .filter(|entry| entry.surface.is_advertised())
142    }
143
144    /// An empty `surfaces` slice means the full catalog.
145    ///
146    /// A gateway front door (e.g. Cowork in Anthropic mode) rejects its whole
147    /// config if advertised models include a name from another vendor family,
148    /// so a caller scopes the list to its own surface; routes may still
149    /// proxy those names to a different provider underneath.
150    #[must_use]
151    pub fn advertised_model_ids(&self, surfaces: &[ApiSurface]) -> Vec<String> {
152        self.advertised_providers()
153            .filter(|entry| surfaces.is_empty() || surfaces.contains(&entry.surface))
154            .flat_map(|entry| {
155                entry.models.iter().flat_map(|m| {
156                    std::iter::once(m.id.as_str().to_owned())
157                        .chain(m.aliases.iter().map(|a| a.as_str().to_owned()))
158                })
159            })
160            .collect()
161    }
162
163    pub fn validate(&self) -> ProviderRegistryResult<()> {
164        let trusted = crate::net::trusted_http_hosts_from_env();
165        let mut seen_providers: HashSet<&str> = HashSet::with_capacity(self.providers.len());
166        let mut seen_models: HashSet<&str> = HashSet::new();
167
168        for provider in &self.providers {
169            if !seen_providers.insert(provider.name.as_str()) {
170                return Err(ProviderRegistryError::DuplicateProvider {
171                    name: provider.name.as_str().to_owned(),
172                });
173            }
174            if provider.endpoint.is_empty() {
175                return Err(ProviderRegistryError::EmptyEndpoint {
176                    name: provider.name.as_str().to_owned(),
177                });
178            }
179            crate::net::validate_outbound_url_with_trust(&provider.endpoint, &trusted).map_err(
180                |e| ProviderRegistryError::BlockedEndpoint {
181                    provider: provider.name.as_str().to_owned(),
182                    endpoint: provider.endpoint.clone(),
183                    reason: e.to_string(),
184                },
185            )?;
186
187            for model in &provider.models {
188                if model.id.as_str().is_empty() {
189                    return Err(ProviderRegistryError::EmptyModelId {
190                        id: provider.name.as_str().to_owned(),
191                    });
192                }
193                if !seen_models.insert(model.id.as_str()) {
194                    return Err(ProviderRegistryError::DuplicateModel {
195                        id: model.id.as_str().to_owned(),
196                    });
197                }
198                for alias in &model.aliases {
199                    if !seen_models.insert(alias.as_str()) {
200                        return Err(ProviderRegistryError::DuplicateModel {
201                            id: alias.as_str().to_owned(),
202                        });
203                    }
204                }
205            }
206        }
207        Ok(())
208    }
209}