Skip to main content

systemprompt_cli/commands/admin/config/
provider.rs

1use anyhow::Result;
2use clap::{Args, Subcommand};
3
4use super::types::{
5    ConfigSection, ProviderInfo, ProviderListOutput, ProviderSetOutput, read_yaml_file,
6    write_yaml_file,
7};
8use crate::CliConfig;
9use crate::shared::{CommandResult, render_result};
10
11#[derive(Debug, Subcommand)]
12pub enum ProviderCommands {
13    #[command(about = "List AI providers")]
14    List(ListArgs),
15
16    #[command(about = "Set default provider")]
17    Set(SetArgs),
18
19    #[command(about = "Enable a provider")]
20    Enable(EnableArgs),
21
22    #[command(about = "Disable a provider")]
23    Disable(DisableArgs),
24}
25
26#[derive(Debug, Clone, Copy, Args)]
27pub struct ListArgs;
28
29#[derive(Debug, Clone, Args)]
30pub struct SetArgs {
31    #[arg(value_name = "PROVIDER")]
32    pub provider: String,
33}
34
35#[derive(Debug, Clone, Args)]
36pub struct EnableArgs {
37    #[arg(value_name = "PROVIDER")]
38    pub provider: String,
39}
40
41#[derive(Debug, Clone, Args)]
42pub struct DisableArgs {
43    #[arg(value_name = "PROVIDER")]
44    pub provider: String,
45}
46
47pub fn execute(cmd: ProviderCommands, _config: &CliConfig) -> Result<()> {
48    match cmd {
49        ProviderCommands::List(_args) => {
50            let result = list_providers()?;
51            render_result(
52                &CommandResult::table(serde_json::to_value(result)?).with_title("AI Providers"),
53            );
54        },
55        ProviderCommands::Set(args) => {
56            let result = set_default_provider(&args.provider)?;
57            render_result(
58                &CommandResult::card(serde_json::to_value(result)?).with_title("Provider Updated"),
59            );
60        },
61        ProviderCommands::Enable(args) => {
62            let result = set_provider_enabled(&args.provider, true)?;
63            render_result(
64                &CommandResult::card(serde_json::to_value(result)?).with_title("Provider Enabled"),
65            );
66        },
67        ProviderCommands::Disable(args) => {
68            let result = set_provider_enabled(&args.provider, false)?;
69            render_result(
70                &CommandResult::card(serde_json::to_value(result)?).with_title("Provider Disabled"),
71            );
72        },
73    }
74    Ok(())
75}
76
77fn get_ai_config_path() -> Result<std::path::PathBuf> {
78    ConfigSection::Ai.file_path()
79}
80
81fn list_providers() -> Result<ProviderListOutput> {
82    let file_path = get_ai_config_path()?;
83    let content = read_yaml_file(&file_path)?;
84
85    let ai = content
86        .get("ai")
87        .ok_or_else(|| anyhow::anyhow!("Missing 'ai' section in config"))?;
88
89    let default_provider = ai
90        .get("default_provider")
91        .and_then(|v| v.as_str())
92        .unwrap_or("unknown")
93        .to_string();
94
95    let providers_section = ai.get("providers");
96
97    let mut providers = Vec::new();
98
99    if let Some(serde_yaml::Value::Mapping(providers_map)) = providers_section {
100        for (name, config) in providers_map {
101            let name_str = name.as_str().unwrap_or("unknown").to_string();
102
103            let enabled = config
104                .get("enabled")
105                .and_then(serde_yaml::Value::as_bool)
106                .unwrap_or(true);
107
108            let model = config
109                .get("default_model")
110                .and_then(|v| v.as_str())
111                .unwrap_or("unknown")
112                .to_string();
113
114            let endpoint = config
115                .get("endpoint")
116                .and_then(|v| v.as_str())
117                .map(String::from);
118
119            providers.push(ProviderInfo {
120                name: name_str.clone(),
121                enabled,
122                is_default: name_str == default_provider,
123                model,
124                endpoint,
125            });
126        }
127    }
128
129    Ok(ProviderListOutput {
130        providers,
131        default_provider,
132    })
133}
134
135fn set_default_provider(provider: &str) -> Result<ProviderSetOutput> {
136    let file_path = get_ai_config_path()?;
137    let mut content = read_yaml_file(&file_path)?;
138
139    let providers = content
140        .get("ai")
141        .and_then(|ai| ai.get("providers"))
142        .ok_or_else(|| anyhow::anyhow!("Missing providers section"))?;
143
144    if !providers
145        .as_mapping()
146        .is_some_and(|m| m.contains_key(serde_yaml::Value::String(provider.to_string())))
147    {
148        let available: Vec<String> = providers.as_mapping().map_or_else(Vec::new, |m| {
149            m.keys()
150                .filter_map(|k| k.as_str().map(String::from))
151                .collect()
152        });
153        anyhow::bail!(
154            "Unknown provider: '{}'. Available providers: {:?}",
155            provider,
156            available
157        );
158    }
159
160    if let Some(serde_yaml::Value::Mapping(ai_map)) = content.get_mut("ai") {
161        ai_map.insert(
162            serde_yaml::Value::String("default_provider".to_string()),
163            serde_yaml::Value::String(provider.to_string()),
164        );
165    }
166
167    write_yaml_file(&file_path, &content)?;
168
169    Ok(ProviderSetOutput {
170        provider: provider.to_string(),
171        action: "set_default".to_string(),
172        message: format!("Default provider set to '{}'", provider),
173    })
174}
175
176fn set_provider_enabled(provider: &str, enabled: bool) -> Result<ProviderSetOutput> {
177    let file_path = get_ai_config_path()?;
178    let mut content = read_yaml_file(&file_path)?;
179
180    let ai = content
181        .get_mut("ai")
182        .ok_or_else(|| anyhow::anyhow!("Missing 'ai' section"))?;
183
184    let providers = ai
185        .get_mut("providers")
186        .ok_or_else(|| anyhow::anyhow!("Missing 'providers' section"))?;
187
188    let provider_config = providers
189        .get_mut(provider)
190        .ok_or_else(|| anyhow::anyhow!("Unknown provider: '{}'", provider))?;
191
192    if let serde_yaml::Value::Mapping(config_map) = provider_config {
193        config_map.insert(
194            serde_yaml::Value::String("enabled".to_string()),
195            serde_yaml::Value::Bool(enabled),
196        );
197    }
198
199    write_yaml_file(&file_path, &content)?;
200
201    let action = if enabled { "enabled" } else { "disabled" };
202
203    Ok(ProviderSetOutput {
204        provider: provider.to_string(),
205        action: action.to_string(),
206        message: format!("Provider '{}' {}", provider, action),
207    })
208}