systemprompt_cli/commands/admin/config/
catalog.rs1use std::collections::HashMap;
11
12use anyhow::{Result, bail};
13use clap::{Args, Subcommand};
14use systemprompt_config::ProfileBootstrap;
15use systemprompt_identifiers::{ModelId, ProviderId, SecretName};
16use systemprompt_models::Profile;
17use systemprompt_models::profile::{ProviderEntry, ProviderModel, WireProtocol};
18
19use super::profile_io::{load_profile, save_profile};
20use super::types::ConfigMutationOutput;
21use crate::CliConfig;
22use crate::shared::{CommandResult, render_result};
23
24#[derive(Debug, Subcommand)]
25pub enum CatalogCommands {
26 #[command(subcommand, about = "Manage registry providers")]
27 Provider(ProviderCommands),
28
29 #[command(subcommand, about = "Manage the models a provider serves")]
30 Model(ModelCommands),
31}
32
33#[derive(Debug, Subcommand)]
34pub enum ProviderCommands {
35 #[command(about = "List declared providers")]
36 List,
37 #[command(about = "Add or replace a provider")]
38 Add(ProviderAddArgs),
39 #[command(about = "Remove a provider by name")]
40 Remove {
41 #[arg(long)]
42 name: String,
43 },
44}
45
46#[derive(Debug, Subcommand)]
47pub enum ModelCommands {
48 #[command(about = "Add or replace a model under a provider")]
49 Add(ModelAddArgs),
50 #[command(about = "Remove a model by id from a provider")]
51 Remove {
52 #[arg(long, help = "Provider that serves the model")]
53 provider: String,
54 #[arg(long)]
55 id: String,
56 },
57}
58
59#[derive(Debug, Clone, Args)]
60pub struct ProviderAddArgs {
61 #[arg(long)]
62 pub name: String,
63 #[arg(
64 long,
65 help = "Wire protocol: anthropic | openai-chat | openai-responses | gemini"
66 )]
67 pub protocol: String,
68 #[arg(long)]
69 pub endpoint: String,
70 #[arg(long)]
71 pub api_key_secret: String,
72 #[arg(long = "header", help = "Extra header as KEY=VALUE (repeatable)")]
73 pub headers: Vec<String>,
74}
75
76#[derive(Debug, Clone, Args)]
77pub struct ModelAddArgs {
78 #[arg(long, help = "Provider that serves this model")]
79 pub provider: String,
80 #[arg(long)]
81 pub id: String,
82 #[arg(long = "alias", help = "Model alias (repeatable)")]
83 pub aliases: Vec<String>,
84 #[arg(
85 long,
86 help = "Vendor-side model name to forward upstream (defaults to id)"
87 )]
88 pub upstream_model: Option<String>,
89}
90
91pub fn execute(command: &CatalogCommands, _config: &CliConfig) -> Result<()> {
92 if matches!(command, CatalogCommands::Provider(ProviderCommands::List)) {
93 return list_providers();
94 }
95
96 let profile_path = ProfileBootstrap::get_path()?;
97 let mut profile = load_profile(profile_path)?;
98
99 let message = match command {
100 CatalogCommands::Provider(ProviderCommands::List) => unreachable!("handled above"),
101 CatalogCommands::Provider(ProviderCommands::Add(args)) => add_provider(&mut profile, args)?,
102 CatalogCommands::Provider(ProviderCommands::Remove { name }) => {
103 remove_provider(&mut profile, name)?
104 },
105 CatalogCommands::Model(ModelCommands::Add(args)) => add_model(&mut profile, args)?,
106 CatalogCommands::Model(ModelCommands::Remove { provider, id }) => {
107 remove_model(&mut profile, provider, id)?
108 },
109 };
110
111 save_profile(&profile, profile_path)?;
112
113 render_result(
114 &CommandResult::text(ConfigMutationOutput {
115 field: "providers".to_owned(),
116 message,
117 })
118 .with_title("Provider Registry Updated"),
119 );
120 Ok(())
121}
122
123fn parse_protocol(raw: &str) -> Result<WireProtocol> {
124 serde_yaml::from_str(raw).map_err(|e| {
125 anyhow::anyhow!(
126 "invalid --protocol '{raw}' ({e}); expected one of: anthropic, openai-chat, \
127 openai-responses, gemini"
128 )
129 })
130}
131
132fn parse_headers(raw: &[String]) -> Result<HashMap<String, String>> {
133 raw.iter()
134 .map(|h| {
135 h.split_once('=')
136 .map(|(k, v)| (k.to_owned(), v.to_owned()))
137 .ok_or_else(|| anyhow::anyhow!("invalid --header '{h}'; expected KEY=VALUE"))
138 })
139 .collect()
140}
141
142fn add_provider(profile: &mut Profile, args: &ProviderAddArgs) -> Result<String> {
143 let models = profile
145 .providers
146 .find_provider(&args.name)
147 .map(|p| p.models.clone())
148 .unwrap_or_default();
149 let entry = ProviderEntry {
150 name: ProviderId::new(&args.name),
151 protocol: parse_protocol(&args.protocol)?,
152 endpoint: args.endpoint.clone(),
153 api_key_secret: SecretName::new(&args.api_key_secret),
154 extra_headers: parse_headers(&args.headers)?,
155 models,
156 };
157 profile
158 .providers
159 .providers
160 .retain(|p| p.name.as_str() != args.name);
161 profile.providers.providers.push(entry);
162 Ok(format!("Provider {} ({}) added", args.name, args.protocol))
163}
164
165fn remove_provider(profile: &mut Profile, name: &str) -> Result<String> {
166 let before = profile.providers.providers.len();
167 profile
168 .providers
169 .providers
170 .retain(|p| p.name.as_str() != name);
171 if profile.providers.providers.len() == before {
172 bail!("No provider named {}", name);
173 }
174 Ok(format!("Provider {} removed", name))
175}
176
177fn add_model(profile: &mut Profile, args: &ModelAddArgs) -> Result<String> {
178 let provider = profile
179 .providers
180 .providers
181 .iter_mut()
182 .find(|p| p.name.as_str() == args.provider)
183 .ok_or_else(|| anyhow::anyhow!("No provider named {}", args.provider))?;
184 let model = ProviderModel {
185 id: ModelId::new(&args.id),
186 aliases: args.aliases.iter().map(ModelId::new).collect(),
187 upstream_model: args.upstream_model.clone(),
188 pricing: systemprompt_models::services::ai::ModelPricing::default(),
189 capabilities: systemprompt_models::services::ai::ModelCapabilities::default(),
190 limits: systemprompt_models::services::ai::ModelLimits::default(),
191 };
192 provider.models.retain(|m| m.id.as_str() != args.id);
193 provider.models.push(model);
194 Ok(format!("Model {} added to {}", args.id, args.provider))
195}
196
197fn remove_model(profile: &mut Profile, provider_name: &str, id: &str) -> Result<String> {
198 let provider = profile
199 .providers
200 .providers
201 .iter_mut()
202 .find(|p| p.name.as_str() == provider_name)
203 .ok_or_else(|| anyhow::anyhow!("No provider named {}", provider_name))?;
204 let before = provider.models.len();
205 provider.models.retain(|m| m.id.as_str() != id);
206 if provider.models.len() == before {
207 bail!("No model with id {} under provider {}", id, provider_name);
208 }
209 Ok(format!("Model {} removed from {}", id, provider_name))
210}
211
212fn list_providers() -> Result<()> {
213 let profile_path = ProfileBootstrap::get_path()?;
214 let profile = load_profile(profile_path)?;
215 let rows: Vec<String> = profile
216 .providers
217 .providers
218 .iter()
219 .map(|p| {
220 let models: Vec<&str> = p.models.iter().map(|m| m.id.as_str()).collect();
221 format!(
222 "{} [{}] {} ({} models: {})",
223 p.name.as_str(),
224 p.protocol,
225 p.endpoint,
226 models.len(),
227 models.join(", ")
228 )
229 })
230 .collect();
231
232 render_result(&CommandResult::list(rows).with_title("Provider Registry"));
233 Ok(())
234}