// systemprompt_cli/commands/admin/config/catalog.rs

//! `admin config catalog` — edit the profile's provider registry
//! (`profile.providers`).
//!
//! Mutates the typed `ProviderRegistry` on the profile — adding or removing
//! providers and the models each provider serves — then revalidates the whole
//! profile before writing it back. This is how an instance declares a custom
//! provider such as `minimax` (its wire protocol, endpoint, credential, and
//! model catalog) without hand-editing YAML.

use std::collections::HashMap;

use anyhow::{Result, bail};
use clap::{Args, Subcommand};
use systemprompt_config::ProfileBootstrap;
use systemprompt_identifiers::{ModelId, ProviderId, SecretName};
use systemprompt_models::Profile;
use systemprompt_models::profile::{ProviderEntry, ProviderModel, WireProtocol};

use super::profile_io::{load_profile, save_profile};
use super::types::ConfigMutationOutput;
use crate::CliConfig;
use crate::shared::{CommandResult, render_result};

24#[derive(Debug, Subcommand)]
25pub enum CatalogCommands {
26    #[command(subcommand, about = "Manage registry providers")]
27    Provider(ProviderCommands),
28
29    #[command(subcommand, about = "Manage the models a provider serves")]
30    Model(ModelCommands),
31}
32
33#[derive(Debug, Subcommand)]
34pub enum ProviderCommands {
35    #[command(about = "List declared providers")]
36    List,
37    #[command(about = "Add or replace a provider")]
38    Add(ProviderAddArgs),
39    #[command(about = "Remove a provider by name")]
40    Remove {
41        #[arg(long)]
42        name: String,
43    },
44}
45
46#[derive(Debug, Subcommand)]
47pub enum ModelCommands {
48    #[command(about = "Add or replace a model under a provider")]
49    Add(ModelAddArgs),
50    #[command(about = "Remove a model by id from a provider")]
51    Remove {
52        #[arg(long, help = "Provider that serves the model")]
53        provider: String,
54        #[arg(long)]
55        id: String,
56    },
57}
58
59#[derive(Debug, Clone, Args)]
60pub struct ProviderAddArgs {
61    #[arg(long)]
62    pub name: String,
63    #[arg(
64        long,
65        help = "Wire protocol: anthropic | openai-chat | openai-responses | gemini"
66    )]
67    pub protocol: String,
68    #[arg(long)]
69    pub endpoint: String,
70    #[arg(long)]
71    pub api_key_secret: String,
72    #[arg(long = "header", help = "Extra header as KEY=VALUE (repeatable)")]
73    pub headers: Vec<String>,
74}
75
76#[derive(Debug, Clone, Args)]
77pub struct ModelAddArgs {
78    #[arg(long, help = "Provider that serves this model")]
79    pub provider: String,
80    #[arg(long)]
81    pub id: String,
82    #[arg(long = "alias", help = "Model alias (repeatable)")]
83    pub aliases: Vec<String>,
84    #[arg(
85        long,
86        help = "Vendor-side model name to forward upstream (defaults to id)"
87    )]
88    pub upstream_model: Option<String>,
89}
90
91pub async fn execute(command: &CatalogCommands, _config: &CliConfig) -> Result<()> {
92    if matches!(command, CatalogCommands::Provider(ProviderCommands::List)) {
93        return list_providers();
94    }
95
96    let profile_path = ProfileBootstrap::get_path()?;
97    let mut profile = load_profile(profile_path)?;
98
99    let message = match command {
100        CatalogCommands::Provider(ProviderCommands::List) => unreachable!("handled above"),
101        CatalogCommands::Provider(ProviderCommands::Add(args)) => add_provider(&mut profile, args)?,
102        CatalogCommands::Provider(ProviderCommands::Remove { name }) => {
103            remove_provider(&mut profile, name)?
104        },
105        CatalogCommands::Model(ModelCommands::Add(args)) => add_model(&mut profile, args)?,
106        CatalogCommands::Model(ModelCommands::Remove { provider, id }) => {
107            remove_model(&mut profile, provider, id)?
108        },
109    };
110
111    save_profile(&profile, profile_path)?;
112    let outcome = super::reconcile::reconcile_authz(&profile, profile_path).await;
113
114    render_result(
115        &CommandResult::text(ConfigMutationOutput {
116            field: "providers".to_owned(),
117            message: super::reconcile::append_reconcile_notice(message, &outcome),
118        })
119        .with_title("Provider Registry Updated"),
120    );
121    Ok(())
122}
123
124fn parse_protocol(raw: &str) -> Result<WireProtocol> {
125    serde_yaml::from_str(raw).map_err(|e| {
126        anyhow::anyhow!(
127            "invalid --protocol '{raw}' ({e}); expected one of: anthropic, openai-chat, \
128             openai-responses, gemini"
129        )
130    })
131}
132
133fn parse_headers(raw: &[String]) -> Result<HashMap<String, String>> {
134    raw.iter()
135        .map(|h| {
136            h.split_once('=')
137                .map(|(k, v)| (k.to_owned(), v.to_owned()))
138                .ok_or_else(|| anyhow::anyhow!("invalid --header '{h}'; expected KEY=VALUE"))
139        })
140        .collect()
141}
142
143fn add_provider(profile: &mut Profile, args: &ProviderAddArgs) -> Result<String> {
144    // Preserve the existing model catalog when replacing a provider in place.
145    let models = profile
146        .providers
147        .find_provider(&args.name)
148        .map(|p| p.models.clone())
149        .unwrap_or_default();
150    let entry = ProviderEntry {
151        name: ProviderId::new(&args.name),
152        protocol: parse_protocol(&args.protocol)?,
153        endpoint: args.endpoint.clone(),
154        api_key_secret: SecretName::new(&args.api_key_secret),
155        extra_headers: parse_headers(&args.headers)?,
156        models,
157    };
158    profile
159        .providers
160        .providers
161        .retain(|p| p.name.as_str() != args.name);
162    profile.providers.providers.push(entry);
163    Ok(format!("Provider {} ({}) added", args.name, args.protocol))
164}
165
166fn remove_provider(profile: &mut Profile, name: &str) -> Result<String> {
167    let before = profile.providers.providers.len();
168    profile
169        .providers
170        .providers
171        .retain(|p| p.name.as_str() != name);
172    if profile.providers.providers.len() == before {
173        bail!("No provider named {}", name);
174    }
175    Ok(format!("Provider {} removed", name))
176}
177
178fn add_model(profile: &mut Profile, args: &ModelAddArgs) -> Result<String> {
179    let provider = profile
180        .providers
181        .providers
182        .iter_mut()
183        .find(|p| p.name.as_str() == args.provider)
184        .ok_or_else(|| anyhow::anyhow!("No provider named {}", args.provider))?;
185    let model = ProviderModel {
186        id: ModelId::new(&args.id),
187        aliases: args.aliases.iter().map(ModelId::new).collect(),
188        upstream_model: args.upstream_model.clone(),
189        pricing: systemprompt_models::services::ai::ModelPricing::default(),
190        capabilities: systemprompt_models::services::ai::ModelCapabilities::default(),
191        limits: systemprompt_models::services::ai::ModelLimits::default(),
192    };
193    provider.models.retain(|m| m.id.as_str() != args.id);
194    provider.models.push(model);
195    Ok(format!("Model {} added to {}", args.id, args.provider))
196}
197
198fn remove_model(profile: &mut Profile, provider_name: &str, id: &str) -> Result<String> {
199    let provider = profile
200        .providers
201        .providers
202        .iter_mut()
203        .find(|p| p.name.as_str() == provider_name)
204        .ok_or_else(|| anyhow::anyhow!("No provider named {}", provider_name))?;
205    let before = provider.models.len();
206    provider.models.retain(|m| m.id.as_str() != id);
207    if provider.models.len() == before {
208        bail!("No model with id {} under provider {}", id, provider_name);
209    }
210    Ok(format!("Model {} removed from {}", id, provider_name))
211}
212
213fn list_providers() -> Result<()> {
214    let profile_path = ProfileBootstrap::get_path()?;
215    let profile = load_profile(profile_path)?;
216    let rows: Vec<String> = profile
217        .providers
218        .providers
219        .iter()
220        .map(|p| {
221            let models: Vec<&str> = p.models.iter().map(|m| m.id.as_str()).collect();
222            format!(
223                "{} [{}] {} ({} models: {})",
224                p.name.as_str(),
225                p.protocol,
226                p.endpoint,
227                models.len(),
228                models.join(", ")
229            )
230        })
231        .collect();
232
233    render_result(&CommandResult::list(rows).with_title("Provider Registry"));
234    Ok(())
235}