Skip to main content

shunt/
config.rs

1use anyhow::{bail, Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::credential::{deserialize_credential_map, Credential};
7use crate::provider::Provider;
8
/// Application name. The path helpers in this file use it as the directory
/// component under the platform config / local-data directories.
pub const APP_NAME: &str = "shunt";
10
11pub fn config_path() -> PathBuf {
12    dirs::config_dir()
13        .unwrap_or_else(|| PathBuf::from("."))
14        .join(APP_NAME)
15        .join("config.toml")
16}
17
18pub fn credentials_path() -> PathBuf {
19    dirs::config_dir()
20        .unwrap_or_else(|| PathBuf::from("."))
21        .join(APP_NAME)
22        .join("credentials.json")
23}
24
25pub fn state_path() -> PathBuf {
26    dirs::data_local_dir()
27        .unwrap_or_else(|| PathBuf::from("."))
28        .join(APP_NAME)
29        .join("state.json")
30}
31
32pub fn log_path() -> PathBuf {
33    dirs::data_local_dir()
34        .unwrap_or_else(|| PathBuf::from("."))
35        .join(APP_NAME)
36        .join("proxy.log")
37}
38
39pub fn pid_path() -> PathBuf {
40    dirs::data_local_dir()
41        .unwrap_or_else(|| PathBuf::from("."))
42        .join(APP_NAME)
43        .join("shunt.pid")
44}
45
46// ---------------------------------------------------------------------------
47// Credentials store  (separate file from config — never commit this)
48// ---------------------------------------------------------------------------
49
50#[derive(Debug, Default, Serialize, Deserialize)]
51pub struct CredentialsStore {
52    #[serde(deserialize_with = "deserialize_credential_map", default)]
53    pub accounts: HashMap<String, Credential>,
54}
55
56impl CredentialsStore {
57    pub fn load() -> Self {
58        let p = credentials_path();
59        if !p.exists() {
60            return Self::default();
61        }
62        match std::fs::read_to_string(&p) {
63            Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
64            Err(_) => Self::default(),
65        }
66    }
67
68    pub fn save(&self) -> Result<()> {
69        let p = credentials_path();
70        if let Some(parent) = p.parent() {
71            std::fs::create_dir_all(parent)?;
72        }
73        let tmp = p.with_extension("tmp");
74        std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
75        std::fs::rename(&tmp, &p)?;
76        #[cfg(unix)]
77        {
78            use std::os::unix::fs::PermissionsExt;
79            std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o600))?;
80        }
81        // On Windows, restrict the file to the current user via icacls (best-effort).
82        #[cfg(windows)]
83        {
84            if let Some(path_str) = p.to_str() {
85                let username = std::env::var("USERNAME").unwrap_or_default();
86                if !username.is_empty() {
87                    let _ = std::process::Command::new("icacls")
88                        .arg(path_str)
89                        .arg("/inheritance:r")
90                        .arg("/grant:r")
91                        .arg(format!("{username}:F"))
92                        .status();
93                }
94            }
95        }
96        Ok(())
97    }
98}
99
100// ---------------------------------------------------------------------------
101// Raw TOML config types
102// ---------------------------------------------------------------------------
103
104#[derive(Debug, Deserialize)]
105struct RawConfig {
106    #[serde(default)]
107    server: RawServer,
108    #[serde(default)]
109    accounts: Vec<RawAccount>,
110    /// Global model-name mapping: `"claude-sonnet-4-6" = "llama-3.3-70b-versatile"`
111    /// Applied when routing Anthropic-format requests to non-Anthropic providers.
112    #[serde(default)]
113    model_mapping: HashMap<String, String>,
114}
115
116#[derive(Debug, Deserialize)]
117struct RawServer {
118    #[serde(default = "default_host")]
119    host: String,
120    #[serde(default = "default_port")]
121    port: u16,
122    #[serde(default = "default_control_port")]
123    control_port: u16,
124    #[serde(default = "default_log_level")]
125    log_level: String,
126    upstream_url: Option<String>,
127    remote_key: Option<String>,
128    relay_url: Option<String>,
129    pub custom_domain: Option<String>,
130    /// Conversation stickiness TTL in minutes (default: 10)
131    sticky_ttl_minutes: Option<u64>,
132    /// "use-it-or-lose-it" expiry window in minutes (default: 30)
133    expiry_soon_minutes: Option<u64>,
134    /// Upstream request timeout in seconds (default: 600)
135    request_timeout_secs: Option<u64>,
136}
137
138impl Default for RawServer {
139    fn default() -> Self {
140        Self {
141            host: default_host(),
142            port: default_port(),
143            control_port: default_control_port(),
144            log_level: default_log_level(),
145            upstream_url: None,
146            remote_key: None,
147            relay_url: None,
148            custom_domain: None,
149            sticky_ttl_minutes: None,
150            expiry_soon_minutes: None,
151            request_timeout_secs: None,
152        }
153    }
154}
155
156#[derive(Debug, Deserialize)]
157struct RawAccount {
158    name: String,
159    #[serde(default = "default_plan_type")]
160    plan_type: String,
161    /// "anthropic" (default) | "openai" / "codex" | "groq" | "mistral" | "local" | …
162    #[serde(default)]
163    provider: Option<String>,
164    /// Inline API key (use api_key_env for better security).
165    #[serde(default)]
166    api_key: Option<String>,
167    /// Name of an environment variable that holds the API key.
168    #[serde(default)]
169    api_key_env: Option<String>,
170    /// Per-account upstream URL override (required for Local provider).
171    #[serde(default)]
172    upstream_url: Option<String>,
173    /// Pin this account to a specific model, overriding global model_mapping
174    /// and the provider's default_model(). Useful for mixing model tiers.
175    #[serde(default)]
176    model: Option<String>,
177}
178
/// Default listen host when `server.host` is not configured.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Default listen port when `server.port` is not configured.
fn default_port() -> u16 {
    8082
}

/// Default control-plane port when `server.control_port` is not configured.
fn default_control_port() -> u16 {
    19081
}

/// Default log level when `server.log_level` is not configured.
fn default_log_level() -> String {
    String::from("info")
}

/// Default plan type for an account that does not set `plan_type`.
fn default_plan_type() -> String {
    String::from("pro")
}
184
185// ---------------------------------------------------------------------------
186// Resolved config types
187// ---------------------------------------------------------------------------
188
/// Fully resolved server settings, produced by `load_config`.
#[derive(Clone, Debug)]
pub struct ServerConfig {
    /// Listen host.
    pub host: String,
    /// Listen port.
    pub port: u16,
    /// Control-plane port (/status, /use, /health) — sees all accounts.
    pub control_port: u16,
    /// Log level name.
    pub log_level: String,
    /// Upstream base URL.
    pub upstream_url: String,
    /// If set, remote requests have to send this value as `x-api-key`.
    pub remote_key: Option<String>,
    /// Relay URL used by `shunt push` / `shunt login`. Overridable via SHUNT_RELAY_URL.
    pub relay_url: String,
    /// Custom domain for permanent online sharing (e.g. https://shunt.mysite.com).
    pub custom_domain: Option<String>,
    /// How long a conversation stays pinned to its account, in milliseconds.
    pub sticky_ttl_ms: u64,
    /// "Use-it-or-lose-it" window: accounts whose 5h window resets within
    /// this many seconds are preferred.
    pub expiry_soon_secs: u64,
    /// Timeout for upstream requests, in seconds.
    pub request_timeout_secs: u64,
}
210
211impl Default for ServerConfig {
212    fn default() -> Self {
213        Self {
214            host: "127.0.0.1".into(),
215            port: 8082,
216            control_port: 19081,
217            log_level: "info".into(),
218            upstream_url: "https://api.anthropic.com".into(),
219            remote_key: None,
220            relay_url: "https://relay.ramcharan.shop".into(),
221            custom_domain: None,
222            sticky_ttl_ms: 10 * 60 * 1000,
223            expiry_soon_secs: 30 * 60,
224            request_timeout_secs: 600,
225        }
226    }
227}
228
229#[derive(Debug, Clone)]
230pub struct AccountConfig {
231    pub name: String,
232    pub plan_type: String,
233    pub provider: Provider,
234    /// `None` when the account has no credential.
235    /// OAuth accounts: None means reauth required (shown as auth_failed).
236    /// ApiKey accounts: None means key not yet configured.
237    /// Local accounts: None is normal (no auth required).
238    pub credential: Option<Credential>,
239    /// Override the upstream base URL for this account.
240    /// `None` means use `config.server.upstream_url` (primary provider) or
241    /// `provider.default_upstream_url()` (non-primary provider).
242    pub upstream_url: Option<String>,
243    /// Pin this account to a specific model name.
244    /// Overrides both `model_mapping` and `provider.default_model()`.
245    pub model: Option<String>,
246}
247
248#[derive(Debug, Clone)]
249pub struct Config {
250    pub server: ServerConfig,
251    pub accounts: Vec<AccountConfig>,
252    pub config_file: PathBuf,
253    /// Global model-name overrides: claude model → provider model.
254    /// e.g. `"claude-sonnet-4-6" → "llama-3.3-70b-versatile"`
255    pub model_mapping: HashMap<String, String>,
256}
257
258// ---------------------------------------------------------------------------
259// Loading
260// ---------------------------------------------------------------------------
261
262pub fn load_config(path: Option<&Path>) -> Result<Config> {
263    let p = path.map(PathBuf::from).unwrap_or_else(config_path);
264
265    if !p.exists() {
266        bail!(
267            "Config not found: {}\nRun `shunt setup` to get started.",
268            p.display()
269        );
270    }
271
272    let raw_text = std::fs::read_to_string(&p)
273        .with_context(|| format!("Failed to read config: {}", p.display()))?;
274
275    let raw: RawConfig = toml::from_str(&raw_text)
276        .with_context(|| format!("Failed to parse config: {}", p.display()))?;
277
278    // Derive the default upstream URL from the first account's provider so that
279    // an all-OpenAI config automatically points at api.openai.com without any
280    // explicit `upstream_url` in the config file.
281    let default_upstream = raw.accounts.first()
282        .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
283        .unwrap_or_default()
284        .default_upstream_url()
285        .to_owned();
286
287    let upstream_url = raw
288        .server
289        .upstream_url
290        .clone()
291        .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
292        .unwrap_or(default_upstream);
293
294    let relay_url = raw
295        .server
296        .relay_url
297        .clone()
298        .or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
299        .unwrap_or_else(|| "https://relay.ramcharan.shop".into());
300
301    let server = ServerConfig {
302        host: raw.server.host,
303        port: raw.server.port,
304        control_port: raw.server.control_port,
305        log_level: raw.server.log_level,
306        upstream_url,
307        remote_key: raw.server.remote_key,
308        relay_url,
309        custom_domain: raw.server.custom_domain,
310        sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
311        expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
312        request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
313    };
314
315    if raw.accounts.is_empty() {
316        bail!("Config has no accounts. Run `shunt setup` to add one.");
317    }
318
319    let store = CredentialsStore::load();
320
321    // Determine the primary provider (first account) so we know which accounts
322    // use config.server.upstream_url and which need the provider's default URL.
323    let primary_provider = raw.accounts.first()
324        .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
325        .unwrap_or_default();
326
327    let mut accounts = Vec::new();
328    for a in &raw.accounts {
329        let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();
330
331        // Resolve credential.
332        //
333        // OAuth providers (Anthropic, OpenAI): credentials.json first, then
334        // auto-import from the provider's local CLI tool.
335        //
336        // API-key providers: credentials.json first, then inline api_key field,
337        // then api_key_env field, then the provider's well-known env var.
338        let cred: Option<Credential> = store.accounts.get(&a.name).cloned()
339            .or_else(|| {
340                // Inline api_key from TOML (less secure, but convenient for testing).
341                a.api_key.as_deref().map(|k| Credential::Apikey { key: k.to_owned() })
342            })
343            .or_else(|| {
344                // api_key_env: name of env var holding the key.
345                a.api_key_env.as_deref()
346                    .and_then(|var| std::env::var(var).ok())
347                    .map(|k| Credential::Apikey { key: k })
348            })
349            .or_else(|| {
350                // Auto-import from provider's CLI tool (OAuth providers) or
351                // well-known env var (API-key providers).
352                provider.read_local_credentials()
353            });
354
355        // Upstream URL: per-account override from TOML takes priority, then
356        // non-primary-provider accounts get the provider's default URL so
357        // the forwarder knows where to send requests.
358        let acct_upstream = a.upstream_url.clone().or_else(|| {
359            if provider != primary_provider {
360                Some(provider.default_upstream_url().to_owned())
361            } else {
362                None
363            }
364        });
365
366        accounts.push(AccountConfig {
367            name: a.name.clone(),
368            plan_type: a.plan_type.clone(),
369            provider,
370            credential: cred,
371            upstream_url: acct_upstream,
372            model: a.model.clone(),
373        });
374    }
375
376    Ok(Config { server, accounts, config_file: p, model_mapping: raw.model_mapping })
377}
378
379// ---------------------------------------------------------------------------
380// Config file template
381// ---------------------------------------------------------------------------
382
/// Renders `raw` as a quoted TOML basic string.
///
/// Escapes the double quote, the backslash, and control characters so that
/// arbitrary input cannot terminate the string early or break the document.
fn toml_basic_string(raw: &str) -> String {
    let mut out = String::with_capacity(raw.len() + 2);
    out.push('"');
    for ch in raw.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            // Remaining control characters have no short escape; use \uXXXX.
            c if c.is_control() => out.push_str(&format!("\\u{:04X}", c as u32)),
            c => out.push(c),
        }
    }
    out.push('"');
    out
}

/// Builds the text of a starter config file.
///
/// The output begins with a `[server]` table holding the default host, port,
/// control port and log level, followed by one `[[accounts]]` entry per
/// `(name, plan_type)` pair in `accounts`, in the given order.
///
/// Names and plan types are written as properly escaped TOML strings, so
/// values containing quotes, backslashes or newlines still produce a valid
/// file. Values without such characters render exactly as they are.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\ncontrol_port = 19081\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        out.push_str("\n[[accounts]]\nname = ");
        out.push_str(&toml_basic_string(name));
        out.push_str("\nplan_type = ");
        out.push_str(&toml_basic_string(plan_type));
        out.push('\n');
    }
    out
}