Skip to main content

shunt/
config.rs

1use anyhow::{bail, Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::oauth::OAuthCredential;
7use crate::provider::Provider;
8
9pub const APP_NAME: &str = "shunt";
10
11pub fn config_path() -> PathBuf {
12    dirs::config_dir()
13        .unwrap_or_else(|| PathBuf::from("."))
14        .join(APP_NAME)
15        .join("config.toml")
16}
17
18pub fn credentials_path() -> PathBuf {
19    dirs::config_dir()
20        .unwrap_or_else(|| PathBuf::from("."))
21        .join(APP_NAME)
22        .join("credentials.json")
23}
24
25pub fn state_path() -> PathBuf {
26    dirs::data_local_dir()
27        .unwrap_or_else(|| PathBuf::from("."))
28        .join(APP_NAME)
29        .join("state.json")
30}
31
32pub fn log_path() -> PathBuf {
33    dirs::data_local_dir()
34        .unwrap_or_else(|| PathBuf::from("."))
35        .join(APP_NAME)
36        .join("proxy.log")
37}
38
39pub fn pid_path() -> PathBuf {
40    dirs::data_local_dir()
41        .unwrap_or_else(|| PathBuf::from("."))
42        .join(APP_NAME)
43        .join("shunt.pid")
44}
45
46// ---------------------------------------------------------------------------
47// Credentials store  (separate file from config — never commit this)
48// ---------------------------------------------------------------------------
49
50#[derive(Debug, Default, Serialize, Deserialize)]
51pub struct CredentialsStore {
52    pub accounts: HashMap<String, OAuthCredential>,
53}
54
55impl CredentialsStore {
56    pub fn load() -> Self {
57        let p = credentials_path();
58        if !p.exists() {
59            return Self::default();
60        }
61        match std::fs::read_to_string(&p) {
62            Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
63            Err(_) => Self::default(),
64        }
65    }
66
67    pub fn save(&self) -> Result<()> {
68        let p = credentials_path();
69        if let Some(parent) = p.parent() {
70            std::fs::create_dir_all(parent)?;
71        }
72        let tmp = p.with_extension("tmp");
73        std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
74        std::fs::rename(&tmp, &p)?;
75        #[cfg(unix)]
76        {
77            use std::os::unix::fs::PermissionsExt;
78            std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o600))?;
79        }
80        // On Windows, restrict the file to the current user via icacls (best-effort).
81        #[cfg(windows)]
82        {
83            if let Some(path_str) = p.to_str() {
84                let username = std::env::var("USERNAME").unwrap_or_default();
85                if !username.is_empty() {
86                    let _ = std::process::Command::new("icacls")
87                        .arg(path_str)
88                        .arg("/inheritance:r")
89                        .arg("/grant:r")
90                        .arg(format!("{username}:F"))
91                        .status();
92                }
93            }
94        }
95        Ok(())
96    }
97}
98
99// ---------------------------------------------------------------------------
100// Raw TOML config types
101// ---------------------------------------------------------------------------
102
103#[derive(Debug, Deserialize)]
104struct RawConfig {
105    #[serde(default)]
106    server: RawServer,
107    #[serde(default)]
108    accounts: Vec<RawAccount>,
109}
110
111#[derive(Debug, Deserialize)]
112struct RawServer {
113    #[serde(default = "default_host")]
114    host: String,
115    #[serde(default = "default_port")]
116    port: u16,
117    #[serde(default = "default_log_level")]
118    log_level: String,
119    upstream_url: Option<String>,
120    remote_key: Option<String>,
121    relay_url: Option<String>,
122    pub custom_domain: Option<String>,
123    /// Conversation stickiness TTL in minutes (default: 10)
124    sticky_ttl_minutes: Option<u64>,
125    /// "use-it-or-lose-it" expiry window in minutes (default: 30)
126    expiry_soon_minutes: Option<u64>,
127    /// Upstream request timeout in seconds (default: 600)
128    request_timeout_secs: Option<u64>,
129}
130
131impl Default for RawServer {
132    fn default() -> Self {
133        Self {
134            host: default_host(),
135            port: default_port(),
136            log_level: default_log_level(),
137            upstream_url: None,
138            remote_key: None,
139            relay_url: None,
140            custom_domain: None,
141            sticky_ttl_minutes: None,
142            expiry_soon_minutes: None,
143            request_timeout_secs: None,
144        }
145    }
146}
147
148#[derive(Debug, Deserialize)]
149struct RawAccount {
150    name: String,
151    #[serde(default = "default_plan_type")]
152    plan_type: String,
153    /// "anthropic" (default) | "openai" / "codex"
154    #[serde(default)]
155    provider: Option<String>,
156}
157
/// Serde default for `server.host`: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Serde default for `server.port`.
fn default_port() -> u16 {
    8082
}

/// Serde default for `server.log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for an account's `plan_type`.
fn default_plan_type() -> String {
    String::from("pro")
}
162
163// ---------------------------------------------------------------------------
164// Resolved config types
165// ---------------------------------------------------------------------------
166
167#[derive(Debug, Clone)]
168pub struct ServerConfig {
169    pub host: String,
170    pub port: u16,
171    pub log_level: String,
172    pub upstream_url: String,
173    /// When set, remote requests must supply this value as `x-api-key`.
174    pub remote_key: Option<String>,
175    /// Relay URL for `shunt push` / `shunt login`. Overridable via SHUNT_RELAY_URL.
176    pub relay_url: String,
177    /// Custom domain for permanent online sharing (e.g. https://shunt.mysite.com).
178    pub custom_domain: Option<String>,
179    /// Conversation stickiness TTL in milliseconds.
180    pub sticky_ttl_ms: u64,
181    /// Accounts whose 5h window resets within this many seconds are preferred ("use-it-or-lose-it").
182    pub expiry_soon_secs: u64,
183    /// Upstream request timeout in seconds.
184    pub request_timeout_secs: u64,
185}
186
187impl Default for ServerConfig {
188    fn default() -> Self {
189        Self {
190            host: "127.0.0.1".into(),
191            port: 8082,
192            log_level: "info".into(),
193            upstream_url: "https://api.anthropic.com".into(),
194            remote_key: None,
195            relay_url: "https://relay.ramcharan.shop".into(),
196            custom_domain: None,
197            sticky_ttl_ms: 10 * 60 * 1000,
198            expiry_soon_secs: 30 * 60,
199            request_timeout_secs: 600,
200        }
201    }
202}
203
204#[derive(Debug, Clone)]
205pub struct AccountConfig {
206    pub name: String,
207    pub plan_type: String,
208    pub provider: Provider,
209    /// `None` when the account is in config but has no credential yet.
210    /// These accounts are shown in status but skipped during proxying.
211    pub credential: Option<OAuthCredential>,
212    /// Override the upstream base URL for this account.
213    /// Used in tests and for custom per-account routing.
214    /// `None` means use `config.server.upstream_url` (same-protocol) or
215    /// `provider.default_upstream_url()` (cross-protocol translation).
216    pub upstream_url: Option<String>,
217}
218
219#[derive(Debug, Clone)]
220pub struct Config {
221    pub server: ServerConfig,
222    pub accounts: Vec<AccountConfig>,
223    pub config_file: PathBuf,
224}
225
226// ---------------------------------------------------------------------------
227// Loading
228// ---------------------------------------------------------------------------
229
230pub fn load_config(path: Option<&Path>) -> Result<Config> {
231    let p = path.map(PathBuf::from).unwrap_or_else(config_path);
232
233    if !p.exists() {
234        bail!(
235            "Config not found: {}\nRun `shunt setup` to get started.",
236            p.display()
237        );
238    }
239
240    let raw_text = std::fs::read_to_string(&p)
241        .with_context(|| format!("Failed to read config: {}", p.display()))?;
242
243    let raw: RawConfig = toml::from_str(&raw_text)
244        .with_context(|| format!("Failed to parse config: {}", p.display()))?;
245
246    // Derive the default upstream URL from the first account's provider so that
247    // an all-OpenAI config automatically points at api.openai.com without any
248    // explicit `upstream_url` in the config file.
249    let default_upstream = raw.accounts.first()
250        .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
251        .unwrap_or_default()
252        .default_upstream_url()
253        .to_owned();
254
255    let upstream_url = raw
256        .server
257        .upstream_url
258        .clone()
259        .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
260        .unwrap_or(default_upstream);
261
262    let relay_url = raw
263        .server
264        .relay_url
265        .clone()
266        .or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
267        .unwrap_or_else(|| "https://relay.ramcharan.shop".into());
268
269    let server = ServerConfig {
270        host: raw.server.host,
271        port: raw.server.port,
272        log_level: raw.server.log_level,
273        upstream_url,
274        remote_key: raw.server.remote_key,
275        relay_url,
276        custom_domain: raw.server.custom_domain,
277        sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
278        expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
279        request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
280    };
281
282    if raw.accounts.is_empty() {
283        bail!("Config has no accounts. Run `shunt setup` to add one.");
284    }
285
286    let store = CredentialsStore::load();
287
288    // Determine the primary provider (first account) so we know which accounts
289    // use config.server.upstream_url and which need the provider's default URL.
290    let primary_provider = raw.accounts.first()
291        .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
292        .unwrap_or_default();
293
294    let mut accounts = Vec::new();
295    for a in &raw.accounts {
296        let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();
297
298        // Resolve credential: stored credential first, then auto-import from provider's local CLI.
299        let cred = store
300            .accounts
301            .get(&a.name)
302            .cloned()
303            .or_else(|| provider.read_local_credentials());
304
305        // Non-primary-provider accounts get the provider's real upstream URL pre-populated.
306        // Primary-provider accounts leave it None so proxy falls back to config.server.upstream_url.
307        let acct_upstream = if provider != primary_provider {
308            Some(provider.default_upstream_url().to_owned())
309        } else {
310            None
311        };
312
313        accounts.push(AccountConfig {
314            name: a.name.clone(),
315            plan_type: a.plan_type.clone(),
316            provider,
317            credential: cred,
318            upstream_url: acct_upstream,
319        });
320    }
321
322    Ok(Config { server, accounts, config_file: p })
323}
324
325// ---------------------------------------------------------------------------
326// Config file template
327// ---------------------------------------------------------------------------
328
/// Renders a starter `config.toml` with the default `[server]` section and
/// one `[[accounts]]` entry per `(name, plan_type)` pair, in order.
///
/// Values are emitted as TOML basic strings with `"`, `\` and control
/// characters escaped. Any account name therefore yields a parseable file,
/// rather than one that breaks on a quote or backslash.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        out.push_str("\n[[accounts]]\nname = ");
        out.push_str(&toml_basic_string(name));
        out.push_str("\nplan_type = ");
        out.push_str(&toml_basic_string(plan_type));
        out.push('\n');
    }
    out
}

/// Quotes `s` as a TOML basic string (`"..."`), escaping characters the TOML
/// spec does not allow unescaped inside one.
fn toml_basic_string(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    for c in s.chars() {
        match c {
            '"' => quoted.push_str("\\\""),
            '\\' => quoted.push_str("\\\\"),
            '\n' => quoted.push_str("\\n"),
            '\r' => quoted.push_str("\\r"),
            '\t' => quoted.push_str("\\t"),
            '\u{08}' => quoted.push_str("\\b"),
            '\u{0C}' => quoted.push_str("\\f"),
            // Remaining control characters use the generic \uXXXX escape.
            c if c.is_control() => quoted.push_str(&format!("\\u{:04X}", c as u32)),
            c => quoted.push(c),
        }
    }
    quoted.push('"');
    quoted
}