Skip to main content

systemprompt_cli/commands/admin/config/
gateway.rs

1//! `admin config gateway` — edit the profile's gateway section: enable state,
2//! routing patterns, and the default provider.
3//!
4//! Every mutation resolves the resulting spec and validates it against the
5//! profile's provider registry (`profile.providers`), so a route or
6//! default-provider that names a provider absent from the registry fails at the
7//! edit rather than at the next boot. The gateway owns no catalog: providers
8//! and models live in `profile.providers` (see `admin config catalog`).
9
10use std::collections::HashMap;
11
12use anyhow::{Result, anyhow, bail};
13use clap::{Args, Subcommand};
14use systemprompt_config::ProfileBootstrap;
15use systemprompt_identifiers::{ProviderId, RouteId};
16use systemprompt_models::Profile;
17use systemprompt_models::profile::{GatewayConfigSpec, GatewayRoute, GatewayState};
18
19use super::profile_io::{load_profile, save_profile};
20use super::types::ConfigMutationOutput;
21use crate::CliConfig;
22use crate::shared::{CommandOutput, render_result};
23use systemprompt_models::artifacts::ListItem;
24
25#[derive(Debug, Subcommand)]
26pub enum GatewayCommands {
27    #[command(about = "Enable the gateway")]
28    Enable,
29
30    #[command(about = "Disable the gateway")]
31    Disable,
32
33    #[command(subcommand, about = "Manage gateway routes")]
34    Route(RouteCommands),
35
36    #[command(
37        subcommand,
38        about = "Manage the default provider (catch-all fallback route)"
39    )]
40    DefaultProvider(DefaultProviderCommands),
41}
42
43#[derive(Debug, Subcommand)]
44pub enum DefaultProviderCommands {
45    #[command(about = "Set the default provider (must exist in profile.providers)")]
46    Set {
47        #[arg(long, help = "Provider name declared in profile.providers")]
48        provider: String,
49    },
50
51    #[command(about = "Clear the default provider")]
52    Clear,
53}
54
55#[derive(Debug, Subcommand)]
56pub enum RouteCommands {
57    #[command(about = "Add or replace a route (upsert by model pattern)")]
58    Add(RouteAddArgs),
59
60    #[command(about = "Remove a route by model pattern")]
61    Remove {
62        #[arg(long, help = "Model pattern to remove (e.g. claude-*)")]
63        model_pattern: String,
64    },
65
66    #[command(about = "List configured routes")]
67    List,
68}
69
70#[derive(Debug, Clone, Args)]
71pub struct RouteAddArgs {
72    #[arg(long, help = "Model pattern (e.g. claude-*)")]
73    pub model_pattern: String,
74
75    #[arg(long, help = "Provider name (must exist in profile.providers)")]
76    pub provider: String,
77
78    #[arg(long, help = "Upstream model name the provider expects (optional)")]
79    pub upstream_model: Option<String>,
80}
81
82pub async fn execute(command: &GatewayCommands, _config: &CliConfig) -> Result<()> {
83    if matches!(command, GatewayCommands::Route(RouteCommands::List)) {
84        return list_routes();
85    }
86
87    let profile_path = ProfileBootstrap::get_path()?;
88    let mut profile = load_profile(profile_path)?;
89
90    let message = match command {
91        GatewayCommands::Enable => set_enabled(&mut profile, true)?,
92        GatewayCommands::Disable => set_enabled(&mut profile, false)?,
93        GatewayCommands::Route(RouteCommands::Add(args)) => add_route(&mut profile, args)?,
94        GatewayCommands::Route(RouteCommands::Remove { model_pattern }) => {
95            remove_route(&mut profile, model_pattern)?
96        },
97        GatewayCommands::Route(RouteCommands::List) => unreachable!("handled above"),
98        GatewayCommands::DefaultProvider(DefaultProviderCommands::Set { provider }) => {
99            set_default_provider(&mut profile, provider)?
100        },
101        GatewayCommands::DefaultProvider(DefaultProviderCommands::Clear) => {
102            clear_default_provider(&mut profile)?
103        },
104    };
105
106    validate_gateway(&profile)?;
107    save_profile(&profile, profile_path)?;
108    let outcome = super::reconcile::reconcile_authz(&profile, profile_path).await;
109
110    render_result(&CommandOutput::card_value(
111        "Gateway Updated",
112        &ConfigMutationOutput {
113            field: "gateway".to_owned(),
114            message: super::reconcile::append_reconcile_notice(message, &outcome),
115        },
116    ));
117    Ok(())
118}
119
120fn spec_mut(profile: &mut Profile) -> Result<&mut GatewayConfigSpec> {
121    profile
122        .gateway
123        .get_or_insert_with(|| GatewayState::Spec(GatewayConfigSpec::default()))
124        .as_spec_mut()
125        .ok_or_else(|| anyhow!("gateway is in a resolved state and cannot be edited"))
126}
127
128fn set_enabled(profile: &mut Profile, enabled: bool) -> Result<String> {
129    spec_mut(profile)?.enabled = enabled;
130    Ok(format!("Gateway enabled = {}", enabled))
131}
132
133fn add_route(profile: &mut Profile, args: &RouteAddArgs) -> Result<String> {
134    let mut route = GatewayRoute {
135        id: RouteId::new(""),
136        model_pattern: args.model_pattern.clone(),
137        provider: ProviderId::new(&args.provider),
138        upstream_model: args.upstream_model.clone(),
139        extra_headers: HashMap::new(),
140        pricing: None,
141    };
142    route.ensure_id();
143    let spec = spec_mut(profile)?;
144    spec.routes
145        .retain(|r| r.model_pattern != args.model_pattern);
146    spec.routes.push(route);
147    Ok(format!(
148        "Route {} -> {} added",
149        args.model_pattern, args.provider
150    ))
151}
152
153fn set_default_provider(profile: &mut Profile, provider: &str) -> Result<String> {
154    spec_mut(profile)?.default_provider = Some(ProviderId::new(provider));
155    Ok(format!("Gateway default provider set to {}", provider))
156}
157
158fn clear_default_provider(profile: &mut Profile) -> Result<String> {
159    spec_mut(profile)?.default_provider = None;
160    Ok("Gateway default provider cleared".to_owned())
161}
162
163fn remove_route(profile: &mut Profile, model_pattern: &str) -> Result<String> {
164    let spec = spec_mut(profile)?;
165    let before = spec.routes.len();
166    spec.routes.retain(|r| r.model_pattern != model_pattern);
167    if spec.routes.len() == before {
168        bail!("No route found for model pattern {}", model_pattern);
169    }
170    Ok(format!("Route {} removed", model_pattern))
171}
172
173fn validate_gateway(profile: &Profile) -> Result<()> {
174    let Some(state) = &profile.gateway else {
175        return Ok(());
176    };
177    let resolved = state.clone().into_spec().resolve();
178    resolved
179        .validate(&profile.providers)
180        .map_err(|e| anyhow!("gateway validation failed: {e}"))
181}
182
183fn list_routes() -> Result<()> {
184    let profile_path = ProfileBootstrap::get_path()?;
185    let profile = load_profile(profile_path)?;
186    let items: Vec<ListItem> = profile
187        .gateway
188        .map(|state| state.into_spec().routes)
189        .unwrap_or_default()
190        .iter()
191        .map(|r| {
192            let route = format!("{} -> {}", r.model_pattern, r.provider.as_str());
193            ListItem::new(route, String::new(), String::new())
194        })
195        .collect();
196    render_result(&CommandOutput::list(items).with_title("Gateway Routes"));
197    Ok(())
198}