Skip to main content

st/
config.rs

1//! Smart Tree Configuration System
2//!
3//! Unified config for API keys, model preferences, and daemon settings.
4//! Config file: ~/.st/config.toml
5
6use anyhow::{Context, Result};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs;
10use std::path::PathBuf;
11
12/// Main configuration structure
13#[derive(Debug, Clone, Serialize, Deserialize, Default)]
14pub struct StConfig {
15    /// LLM provider API keys
16    #[serde(default)]
17    pub api_keys: ApiKeys,
18
19    /// Model preferences and aliases
20    #[serde(default)]
21    pub models: ModelConfig,
22
23    /// Daemon settings
24    #[serde(default)]
25    pub daemon: DaemonConfig,
26
27    /// Safety/trust settings
28    #[serde(default)]
29    pub safety: SafetyConfig,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize, Default)]
33pub struct ApiKeys {
34    pub anthropic: Option<String>,
35    pub openai: Option<String>,
36    pub google: Option<String>,
37    pub openrouter: Option<String>,
38    pub grok: Option<String>,
39    /// Custom providers: name -> api_key
40    #[serde(default)]
41    pub custom: HashMap<String, String>,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ModelConfig {
46    /// Default model for chat
47    pub default_model: String,
48    /// Model aliases: short_name -> full_model_id
49    #[serde(default)]
50    pub aliases: HashMap<String, String>,
51    /// Blocked models (safety)
52    #[serde(default)]
53    pub blocked: Vec<String>,
54}
55
56impl Default for ModelConfig {
57    fn default() -> Self {
58        let mut aliases = HashMap::new();
59        aliases.insert("claude".into(), "claude-sonnet-4-6".into());
60        aliases.insert("opus".into(), "claude-opus-4-6".into());
61        aliases.insert("haiku".into(), "claude-haiku-4-5".into());
62        aliases.insert("gpt4".into(), "gpt-4o".into());
63        aliases.insert("gemini".into(), "gemini-2.0-flash".into());
64
65        Self {
66            default_model: "claude-sonnet-4-6".into(),
67            aliases,
68            blocked: vec!["greatcoderMDK".into()], // Known bad actor
69        }
70    }
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct DaemonConfig {
75    pub port: u16,
76    pub auto_start: bool,
77    /// Allow external connections (not just localhost)
78    pub allow_external: bool,
79}
80
81impl Default for DaemonConfig {
82    fn default() -> Self {
83        Self {
84            port: 28428,
85            auto_start: false,
86            allow_external: false,
87        }
88    }
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct SafetyConfig {
93    /// Enable The Custodian monitoring
94    pub custodian_enabled: bool,
95    /// Log all LLM requests for transparency
96    pub transparency_logging: bool,
97    /// Model safety scores (model_id -> score 0-10)
98    #[serde(default)]
99    pub model_scores: HashMap<String, u8>,
100}
101
102impl Default for SafetyConfig {
103    fn default() -> Self {
104        let mut scores = HashMap::new();
105        scores.insert("claude-opus-4-6".into(), 10);
106        scores.insert("claude-sonnet-4-6".into(), 10);
107        scores.insert("claude-haiku-4-5".into(), 10);
108        scores.insert("gpt-4o".into(), 9);
109        scores.insert("gpt-4-turbo".into(), 9);
110        scores.insert("gemini-2.0-flash".into(), 9);
111        scores.insert("greatcoderMDK".into(), 2); // Suspicious
112
113        Self {
114            custodian_enabled: true,
115            transparency_logging: true,
116            model_scores: scores,
117        }
118    }
119}
120
121impl StConfig {
122    /// Get config file path
123    pub fn config_path() -> Result<PathBuf> {
124        let st_dir = dirs::home_dir()
125            .context("Could not find home directory")?
126            .join(".st");
127        fs::create_dir_all(&st_dir)?;
128        Ok(st_dir.join("config.toml"))
129    }
130
131    /// Load config from file, or create default
132    pub fn load() -> Result<Self> {
133        let path = Self::config_path()?;
134
135        if path.exists() {
136            let content = fs::read_to_string(&path)
137                .with_context(|| format!("Failed to read {}", path.display()))?;
138            let config: StConfig = toml::from_str(&content)
139                .with_context(|| format!("Failed to parse {}", path.display()))?;
140            Ok(config)
141        } else {
142            // Create default config
143            let config = Self::default();
144            config.save()?;
145            Ok(config)
146        }
147    }
148
149    /// Save config to file
150    pub fn save(&self) -> Result<()> {
151        let path = Self::config_path()?;
152        let content = toml::to_string_pretty(self)?;
153        fs::write(&path, content)?;
154        Ok(())
155    }
156
157    /// Get API key for a provider (checks config then env)
158    pub fn get_api_key(&self, provider: &str) -> Option<String> {
159        // Check config first
160        let from_config = match provider.to_lowercase().as_str() {
161            "anthropic" | "claude" => self.api_keys.anthropic.clone(),
162            "openai" | "gpt" => self.api_keys.openai.clone(),
163            "google" | "gemini" => self.api_keys.google.clone(),
164            "openrouter" => self.api_keys.openrouter.clone(),
165            "grok" | "xai" => self.api_keys.grok.clone(),
166            other => self.api_keys.custom.get(other).cloned(),
167        };
168
169        // Fall back to env var
170        from_config.or_else(|| {
171            let env_var = match provider.to_lowercase().as_str() {
172                "anthropic" | "claude" => "ANTHROPIC_API_KEY",
173                "openai" | "gpt" => "OPENAI_API_KEY",
174                "google" | "gemini" => "GOOGLE_API_KEY",
175                "openrouter" => "OPENROUTER_API_KEY",
176                "grok" | "xai" => "XAI_API_KEY",
177                _ => return None,
178            };
179            std::env::var(env_var).ok()
180        })
181    }
182
183    /// Check if a model is blocked
184    pub fn is_model_blocked(&self, model: &str) -> bool {
185        self.models.blocked.iter().any(|b| model.contains(b))
186    }
187
188    /// Get safety score for a model (0-10)
189    pub fn get_model_score(&self, model: &str) -> u8 {
190        self.safety.model_scores.get(model).copied().unwrap_or(5) // Default: neutral
191    }
192}