Skip to main content

systemprompt_cli/commands/admin/config/provider.rs

1use anyhow::Result;
2use clap::{Args, Subcommand};
3
4use super::types::{
5    read_yaml_file, write_yaml_file, ConfigSection, ProviderInfo, ProviderListOutput,
6    ProviderSetOutput,
7};
8use crate::shared::{render_result, CommandResult};
9use crate::CliConfig;
10
11#[derive(Debug, Subcommand)]
12pub enum ProviderCommands {
13    #[command(about = "List AI providers")]
14    List(ListArgs),
15
16    #[command(about = "Set default provider")]
17    Set(SetArgs),
18
19    #[command(about = "Enable a provider")]
20    Enable(EnableArgs),
21
22    #[command(about = "Disable a provider")]
23    Disable(DisableArgs),
24}
25
26#[derive(Debug, Clone, Copy, Args)]
27pub struct ListArgs;
28
29#[derive(Debug, Clone, Args)]
30pub struct SetArgs {
31    #[arg(value_name = "PROVIDER")]
32    pub provider: String,
33}
34
35#[derive(Debug, Clone, Args)]
36pub struct EnableArgs {
37    #[arg(value_name = "PROVIDER")]
38    pub provider: String,
39}
40
41#[derive(Debug, Clone, Args)]
42pub struct DisableArgs {
43    #[arg(value_name = "PROVIDER")]
44    pub provider: String,
45}
46
47pub fn execute(cmd: ProviderCommands, _config: &CliConfig) -> Result<()> {
48    match cmd {
49        ProviderCommands::List(_args) => {
50            let result = list_providers()?;
51            render_result(
52                &CommandResult::table(serde_json::to_value(result)?).with_title("AI Providers"),
53            );
54        },
55        ProviderCommands::Set(args) => {
56            let result = set_default_provider(&args.provider)?;
57            render_result(
58                &CommandResult::card(serde_json::to_value(result)?).with_title("Provider Updated"),
59            );
60        },
61        ProviderCommands::Enable(args) => {
62            let result = set_provider_enabled(&args.provider, true)?;
63            render_result(
64                &CommandResult::card(serde_json::to_value(result)?).with_title("Provider Enabled"),
65            );
66        },
67        ProviderCommands::Disable(args) => {
68            let result = set_provider_enabled(&args.provider, false)?;
69            render_result(
70                &CommandResult::card(serde_json::to_value(result)?).with_title("Provider Disabled"),
71            );
72        },
73    }
74    Ok(())
75}
76
77fn get_ai_config_path() -> Result<std::path::PathBuf> {
78    ConfigSection::Ai.file_path()
79}
80
81fn list_providers() -> Result<ProviderListOutput> {
82    let file_path = get_ai_config_path()?;
83    let content = read_yaml_file(&file_path)?;
84
85    let ai = content
86        .get("ai")
87        .ok_or_else(|| anyhow::anyhow!("Missing 'ai' section in config"))?;
88
89    let default_provider = ai
90        .get("default_provider")
91        .and_then(|v| v.as_str())
92        .unwrap_or("unknown")
93        .to_string();
94
95    let providers_section = ai.get("providers");
96
97    let mut providers = Vec::new();
98
99    if let Some(serde_yaml::Value::Mapping(providers_map)) = providers_section {
100        for (name, config) in providers_map {
101            let name_str = name.as_str().unwrap_or("unknown").to_string();
102
103            let enabled = config
104                .get("enabled")
105                .and_then(serde_yaml::Value::as_bool)
106                .unwrap_or(true);
107
108            let model = config
109                .get("default_model")
110                .and_then(|v| v.as_str())
111                .unwrap_or("unknown")
112                .to_string();
113
114            let endpoint = config
115                .get("endpoint")
116                .and_then(|v| v.as_str())
117                .map(String::from);
118
119            providers.push(ProviderInfo {
120                name: name_str.clone(),
121                enabled,
122                is_default: name_str == default_provider,
123                model,
124                endpoint,
125            });
126        }
127    }
128
129    Ok(ProviderListOutput {
130        providers,
131        default_provider,
132    })
133}
134
135fn set_default_provider(provider: &str) -> Result<ProviderSetOutput> {
136    let file_path = get_ai_config_path()?;
137    let mut content = read_yaml_file(&file_path)?;
138
139    let providers = content
140        .get("ai")
141        .and_then(|ai| ai.get("providers"))
142        .ok_or_else(|| anyhow::anyhow!("Missing providers section"))?;
143
144    if !providers
145        .as_mapping()
146        .is_some_and(|m| m.contains_key(serde_yaml::Value::String(provider.to_string())))
147    {
148        let available: Vec<String> = providers
149            .as_mapping()
150            .map(|m| {
151                m.keys()
152                    .filter_map(|k| k.as_str().map(String::from))
153                    .collect()
154            })
155            .unwrap_or_default();
156        anyhow::bail!(
157            "Unknown provider: '{}'. Available providers: {:?}",
158            provider,
159            available
160        );
161    }
162
163    if let Some(serde_yaml::Value::Mapping(ai_map)) = content.get_mut("ai") {
164        ai_map.insert(
165            serde_yaml::Value::String("default_provider".to_string()),
166            serde_yaml::Value::String(provider.to_string()),
167        );
168    }
169
170    write_yaml_file(&file_path, &content)?;
171
172    Ok(ProviderSetOutput {
173        provider: provider.to_string(),
174        action: "set_default".to_string(),
175        message: format!("Default provider set to '{}'", provider),
176    })
177}
178
179fn set_provider_enabled(provider: &str, enabled: bool) -> Result<ProviderSetOutput> {
180    let file_path = get_ai_config_path()?;
181    let mut content = read_yaml_file(&file_path)?;
182
183    let ai = content
184        .get_mut("ai")
185        .ok_or_else(|| anyhow::anyhow!("Missing 'ai' section"))?;
186
187    let providers = ai
188        .get_mut("providers")
189        .ok_or_else(|| anyhow::anyhow!("Missing 'providers' section"))?;
190
191    let provider_config = providers
192        .get_mut(provider)
193        .ok_or_else(|| anyhow::anyhow!("Unknown provider: '{}'", provider))?;
194
195    if let serde_yaml::Value::Mapping(config_map) = provider_config {
196        config_map.insert(
197            serde_yaml::Value::String("enabled".to_string()),
198            serde_yaml::Value::Bool(enabled),
199        );
200    }
201
202    write_yaml_file(&file_path, &content)?;
203
204    let action = if enabled { "enabled" } else { "disabled" };
205
206    Ok(ProviderSetOutput {
207        provider: provider.to_string(),
208        action: action.to_string(),
209        message: format!("Provider '{}' {}", provider, action),
210    })
211}