Skip to main content

systemprompt_cli/commands/admin/config/
provider.rs

//! `admin config provider` command: manage the AI providers declared in
//! `ai/config.yaml`.
//!
//! [`ProviderCommands`] lists providers, sets the default, and toggles a
//! provider's enabled flag, editing the AI config YAML in place.
6
7use anyhow::Result;
8use clap::{Args, Subcommand};
9use systemprompt_config::ProfileBootstrap;
10
11use super::types::{
12    ConfigSection, ProviderInfo, ProviderListOutput, ProviderSetOutput, read_yaml_file,
13    write_yaml_file,
14};
15use crate::CliConfig;
16use crate::shared::{CommandResult, render_result};
17
18#[derive(Debug, Subcommand)]
19pub enum ProviderCommands {
20    #[command(about = "List AI providers")]
21    List(ListArgs),
22
23    #[command(about = "Set default provider")]
24    Set(SetArgs),
25
26    #[command(about = "Enable a provider")]
27    Enable(EnableArgs),
28
29    #[command(about = "Disable a provider")]
30    Disable(DisableArgs),
31}
32
33#[derive(Debug, Clone, Copy, Args)]
34pub struct ListArgs;
35
36#[derive(Debug, Clone, Args)]
37pub struct SetArgs {
38    #[arg(value_name = "PROVIDER")]
39    pub provider: String,
40}
41
42#[derive(Debug, Clone, Args)]
43pub struct EnableArgs {
44    #[arg(value_name = "PROVIDER")]
45    pub provider: String,
46}
47
48#[derive(Debug, Clone, Args)]
49pub struct DisableArgs {
50    #[arg(value_name = "PROVIDER")]
51    pub provider: String,
52}
53
54pub fn execute(cmd: ProviderCommands, _config: &CliConfig) -> Result<()> {
55    match cmd {
56        ProviderCommands::List(_args) => {
57            let result = list_providers()?;
58            render_result(&CommandResult::table(result).with_title("AI Providers"));
59        },
60        ProviderCommands::Set(args) => {
61            let result = set_default_provider(&args.provider)?;
62            render_result(&CommandResult::card(result).with_title("Provider Updated"));
63        },
64        ProviderCommands::Enable(args) => {
65            let result = set_provider_enabled(&args.provider, true)?;
66            render_result(&CommandResult::card(result).with_title("Provider Enabled"));
67        },
68        ProviderCommands::Disable(args) => {
69            let result = set_provider_enabled(&args.provider, false)?;
70            render_result(&CommandResult::card(result).with_title("Provider Disabled"));
71        },
72    }
73    Ok(())
74}
75
76fn get_ai_config_path() -> Result<std::path::PathBuf> {
77    ConfigSection::Ai.file_path()
78}
79
80fn list_providers() -> Result<ProviderListOutput> {
81    let registry = &ProfileBootstrap::get()?.providers;
82    let file_path = get_ai_config_path()?;
83    let content = read_yaml_file(&file_path)?;
84
85    let ai = content
86        .get("ai")
87        .ok_or_else(|| anyhow::anyhow!("Missing 'ai' section in config"))?;
88
89    let default_provider = ai
90        .get("default_provider")
91        .and_then(|v| v.as_str())
92        .unwrap_or("unknown")
93        .to_owned();
94
95    let providers_section = ai.get("providers");
96
97    let mut providers = Vec::new();
98
99    if let Some(serde_yaml::Value::Mapping(providers_map)) = providers_section {
100        for (name, config) in providers_map {
101            let name_str = name.as_str().unwrap_or("unknown").to_owned();
102
103            let enabled = config
104                .get("enabled")
105                .and_then(serde_yaml::Value::as_bool)
106                .unwrap_or(true);
107
108            let model = config
109                .get("default_model")
110                .and_then(|v| v.as_str())
111                .unwrap_or("unknown")
112                .to_owned();
113
114            let endpoint = registry
115                .find_provider(&name_str)
116                .map(|entry| entry.endpoint.clone());
117
118            providers.push(ProviderInfo {
119                name: name_str.clone(),
120                enabled,
121                is_default: name_str == default_provider,
122                model,
123                endpoint,
124            });
125        }
126    }
127
128    Ok(ProviderListOutput {
129        providers,
130        default_provider,
131    })
132}
133
134fn set_default_provider(provider: &str) -> Result<ProviderSetOutput> {
135    let registry = &ProfileBootstrap::get()?.providers;
136    if registry.find_provider(provider).is_none() {
137        let available: Vec<&str> = registry.providers.iter().map(|p| p.name.as_str()).collect();
138        anyhow::bail!(
139            "Unknown provider: '{}' is not in profile.providers. Available: {:?}",
140            provider,
141            available
142        );
143    }
144
145    let file_path = get_ai_config_path()?;
146    let mut content = read_yaml_file(&file_path)?;
147
148    let policy = content.get("ai").and_then(|ai| ai.get("providers"));
149    let enabled = policy
150        .and_then(|p| p.get(provider))
151        .and_then(|p| p.get("enabled"))
152        .and_then(serde_yaml::Value::as_bool)
153        .unwrap_or(true);
154    if !enabled {
155        anyhow::bail!(
156            "Provider '{}' is disabled in AI policy; enable it first \
157             (admin config provider enable {})",
158            provider,
159            provider
160        );
161    }
162
163    if let Some(serde_yaml::Value::Mapping(ai_map)) = content.get_mut("ai") {
164        ai_map.insert(
165            serde_yaml::Value::String("default_provider".to_owned()),
166            serde_yaml::Value::String(provider.to_owned()),
167        );
168    }
169
170    write_yaml_file(&file_path, &content)?;
171
172    Ok(ProviderSetOutput {
173        provider: provider.to_owned(),
174        action: "set_default".to_owned(),
175        message: format!("Default provider set to '{}'", provider),
176    })
177}
178
179fn set_provider_enabled(provider: &str, enabled: bool) -> Result<ProviderSetOutput> {
180    let file_path = get_ai_config_path()?;
181    let mut content = read_yaml_file(&file_path)?;
182
183    let ai = content
184        .get_mut("ai")
185        .ok_or_else(|| anyhow::anyhow!("Missing 'ai' section"))?;
186
187    let providers = ai
188        .get_mut("providers")
189        .ok_or_else(|| anyhow::anyhow!("Missing 'providers' section"))?;
190
191    let provider_config = providers
192        .get_mut(provider)
193        .ok_or_else(|| anyhow::anyhow!("Unknown provider: '{}'", provider))?;
194
195    if let serde_yaml::Value::Mapping(config_map) = provider_config {
196        config_map.insert(
197            serde_yaml::Value::String("enabled".to_owned()),
198            serde_yaml::Value::Bool(enabled),
199        );
200    }
201
202    write_yaml_file(&file_path, &content)?;
203
204    let action = if enabled { "enabled" } else { "disabled" };
205
206    Ok(ProviderSetOutput {
207        provider: provider.to_owned(),
208        action: action.to_owned(),
209        message: format!("Provider '{}' {}", provider, action),
210    })
211}