Skip to main content

shunt/
config.rs

1use anyhow::{bail, Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::oauth::OAuthCredential;
7use crate::provider::Provider;
8
/// Application name; used as the per-app directory component of every
/// config/data path built in this module.
pub const APP_NAME: &str = "shunt";
10
11pub fn config_path() -> PathBuf {
12    dirs::config_dir()
13        .unwrap_or_else(|| PathBuf::from("."))
14        .join(APP_NAME)
15        .join("config.toml")
16}
17
18pub fn credentials_path() -> PathBuf {
19    dirs::config_dir()
20        .unwrap_or_else(|| PathBuf::from("."))
21        .join(APP_NAME)
22        .join("credentials.json")
23}
24
25pub fn state_path() -> PathBuf {
26    dirs::data_local_dir()
27        .unwrap_or_else(|| PathBuf::from("."))
28        .join(APP_NAME)
29        .join("state.json")
30}
31
32pub fn log_path() -> PathBuf {
33    dirs::data_local_dir()
34        .unwrap_or_else(|| PathBuf::from("."))
35        .join(APP_NAME)
36        .join("proxy.log")
37}
38
39pub fn pid_path() -> PathBuf {
40    dirs::data_local_dir()
41        .unwrap_or_else(|| PathBuf::from("."))
42        .join(APP_NAME)
43        .join("shunt.pid")
44}
45
// ---------------------------------------------------------------------------
// Credentials store  (separate file from config — never commit this)
// ---------------------------------------------------------------------------
49
50#[derive(Debug, Default, Serialize, Deserialize)]
51pub struct CredentialsStore {
52    pub accounts: HashMap<String, OAuthCredential>,
53}
54
55impl CredentialsStore {
56    pub fn load() -> Self {
57        let p = credentials_path();
58        if !p.exists() {
59            return Self::default();
60        }
61        match std::fs::read_to_string(&p) {
62            Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
63            Err(_) => Self::default(),
64        }
65    }
66
67    pub fn save(&self) -> Result<()> {
68        let p = credentials_path();
69        if let Some(parent) = p.parent() {
70            std::fs::create_dir_all(parent)?;
71        }
72        let tmp = p.with_extension("tmp");
73        std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
74        std::fs::rename(&tmp, &p)?;
75        #[cfg(unix)]
76        {
77            use std::os::unix::fs::PermissionsExt;
78            std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o600))?;
79        }
80        // On Windows, restrict the file to the current user via icacls (best-effort).
81        #[cfg(windows)]
82        {
83            if let Some(path_str) = p.to_str() {
84                let username = std::env::var("USERNAME").unwrap_or_default();
85                if !username.is_empty() {
86                    let _ = std::process::Command::new("icacls")
87                        .arg(path_str)
88                        .arg("/inheritance:r")
89                        .arg("/grant:r")
90                        .arg(format!("{username}:F"))
91                        .status();
92                }
93            }
94        }
95        Ok(())
96    }
97}
98
// ---------------------------------------------------------------------------
// Raw TOML config types
// ---------------------------------------------------------------------------
102
103#[derive(Debug, Deserialize)]
104struct RawConfig {
105    #[serde(default)]
106    server: RawServer,
107    #[serde(default)]
108    accounts: Vec<RawAccount>,
109}
110
111#[derive(Debug, Deserialize)]
112struct RawServer {
113    #[serde(default = "default_host")]
114    host: String,
115    #[serde(default = "default_port")]
116    port: u16,
117    #[serde(default = "default_log_level")]
118    log_level: String,
119    upstream_url: Option<String>,
120    remote_key: Option<String>,
121    relay_url: Option<String>,
122    /// Conversation stickiness TTL in minutes (default: 10)
123    sticky_ttl_minutes: Option<u64>,
124    /// "use-it-or-lose-it" expiry window in minutes (default: 30)
125    expiry_soon_minutes: Option<u64>,
126    /// Upstream request timeout in seconds (default: 600)
127    request_timeout_secs: Option<u64>,
128}
129
130impl Default for RawServer {
131    fn default() -> Self {
132        Self {
133            host: default_host(),
134            port: default_port(),
135            log_level: default_log_level(),
136            upstream_url: None,
137            remote_key: None,
138            relay_url: None,
139            sticky_ttl_minutes: None,
140            expiry_soon_minutes: None,
141            request_timeout_secs: None,
142        }
143    }
144}
145
146#[derive(Debug, Deserialize)]
147struct RawAccount {
148    name: String,
149    #[serde(default = "default_plan_type")]
150    plan_type: String,
151    /// "anthropic" (default) | "openai" / "codex"
152    #[serde(default)]
153    provider: Option<String>,
154}
155
/// Serde default for the server `host` field.
fn default_host() -> String {
    String::from("127.0.0.1")
}
/// Serde default for the server `port` field.
fn default_port() -> u16 {
    8082_u16
}
/// Serde default for the server `log_level` field.
fn default_log_level() -> String {
    "info".to_owned()
}
/// Serde default for an account's `plan_type` field.
fn default_plan_type() -> String {
    String::from("pro")
}
160
// ---------------------------------------------------------------------------
// Resolved config types
// ---------------------------------------------------------------------------
164
/// Fully resolved server settings: every optional value from the config file
/// has been replaced by a concrete one.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Host to bind to.
    pub host: String,
    /// Port to bind to.
    pub port: u16,
    /// Log level name.
    pub log_level: String,
    /// Upstream base URL.
    pub upstream_url: String,
    /// When set, remote requests must supply this value as `x-api-key`.
    pub remote_key: Option<String>,
    /// Relay URL for `shunt push` / `shunt login`. Overridable via SHUNT_RELAY_URL.
    pub relay_url: String,
    /// Conversation stickiness TTL in milliseconds.
    pub sticky_ttl_ms: u64,
    /// Accounts whose 5h window resets within this many seconds are preferred ("use-it-or-lose-it").
    pub expiry_soon_secs: u64,
    /// Upstream request timeout in seconds.
    pub request_timeout_secs: u64,
}

impl Default for ServerConfig {
    /// Built-in settings: localhost:8082, `info` logging, the Anthropic
    /// upstream, no remote key, 10-minute stickiness, 30-minute expiry window
    /// and a 600-second request timeout.
    fn default() -> Self {
        const STICKY_TTL_MINUTES: u64 = 10;
        const EXPIRY_SOON_MINUTES: u64 = 30;
        const REQUEST_TIMEOUT_SECS: u64 = 600;

        Self {
            host: String::from("127.0.0.1"),
            port: 8082,
            log_level: String::from("info"),
            upstream_url: String::from("https://api.anthropic.com"),
            remote_key: None,
            relay_url: String::from("https://relay.ramcharan.shop"),
            sticky_ttl_ms: STICKY_TTL_MINUTES * 60 * 1000,
            expiry_soon_secs: EXPIRY_SOON_MINUTES * 60,
            request_timeout_secs: REQUEST_TIMEOUT_SECS,
        }
    }
}
198
199#[derive(Debug, Clone)]
200pub struct AccountConfig {
201    pub name: String,
202    pub plan_type: String,
203    pub provider: Provider,
204    /// `None` when the account is in config but has no credential yet.
205    /// These accounts are shown in status but skipped during proxying.
206    pub credential: Option<OAuthCredential>,
207}
208
209#[derive(Debug, Clone)]
210pub struct Config {
211    pub server: ServerConfig,
212    pub accounts: Vec<AccountConfig>,
213    pub config_file: PathBuf,
214}
215
// ---------------------------------------------------------------------------
// Loading
// ---------------------------------------------------------------------------
219
220pub fn load_config(path: Option<&Path>) -> Result<Config> {
221    let p = path.map(PathBuf::from).unwrap_or_else(config_path);
222
223    if !p.exists() {
224        bail!(
225            "Config not found: {}\nRun `shunt setup` to get started.",
226            p.display()
227        );
228    }
229
230    let raw_text = std::fs::read_to_string(&p)
231        .with_context(|| format!("Failed to read config: {}", p.display()))?;
232
233    let raw: RawConfig = toml::from_str(&raw_text)
234        .with_context(|| format!("Failed to parse config: {}", p.display()))?;
235
236    // Derive the default upstream URL from the first account's provider so that
237    // an all-OpenAI config automatically points at api.openai.com without any
238    // explicit `upstream_url` in the config file.
239    let default_upstream = raw.accounts.first()
240        .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
241        .unwrap_or_default()
242        .default_upstream_url()
243        .to_owned();
244
245    let upstream_url = raw
246        .server
247        .upstream_url
248        .clone()
249        .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
250        .unwrap_or(default_upstream);
251
252    let relay_url = raw
253        .server
254        .relay_url
255        .clone()
256        .or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
257        .unwrap_or_else(|| "https://relay.ramcharan.shop".into());
258
259    let server = ServerConfig {
260        host: raw.server.host,
261        port: raw.server.port,
262        log_level: raw.server.log_level,
263        upstream_url,
264        remote_key: raw.server.remote_key,
265        relay_url,
266        sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
267        expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
268        request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
269    };
270
271    if raw.accounts.is_empty() {
272        bail!("Config has no accounts. Run `shunt setup` to add one.");
273    }
274
275    let store = CredentialsStore::load();
276
277    let mut accounts = Vec::new();
278    for a in &raw.accounts {
279        let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();
280
281        // Resolve credential: stored credential first, then auto-import from provider's local CLI.
282        let cred = store
283            .accounts
284            .get(&a.name)
285            .cloned()
286            .or_else(|| provider.read_local_credentials());
287
288        accounts.push(AccountConfig {
289            name: a.name.clone(),
290            plan_type: a.plan_type.clone(),
291            provider,
292            credential: cred,
293        });
294    }
295
296    Ok(Config { server, accounts, config_file: p })
297}
298
// ---------------------------------------------------------------------------
// Config file template
// ---------------------------------------------------------------------------
302
/// Renders a starter config file.
///
/// The output is a fixed `[server]` table followed by one `[[accounts]]`
/// entry per `(name, plan_type)` pair, in the given order. Both values are
/// escaped as TOML basic strings, so names containing quotes, backslashes or
/// control characters still produce a valid file.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        let name = escape_toml_basic(name);
        let plan_type = escape_toml_basic(plan_type);
        out.push_str(&format!(
            "\n[[accounts]]\nname = \"{name}\"\nplan_type = \"{plan_type}\"\n"
        ));
    }
    out
}

/// Escapes `raw` for use between the double quotes of a TOML basic string.
fn escape_toml_basic(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            '\u{08}' => escaped.push_str("\\b"),
            '\u{0C}' => escaped.push_str("\\f"),
            // Remaining control characters have no short escape.
            c if c.is_control() => escaped.push_str(&format!("\\u{:04X}", c as u32)),
            c => escaped.push(c),
        }
    }
    escaped
}