Skip to main content

shunt/
config.rs

1use anyhow::{bail, Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::credential::{deserialize_credential_map, Credential};
7use crate::provider::Provider;
8
9pub const APP_NAME: &str = "shunt";
10
11pub fn config_path() -> PathBuf {
12    dirs::config_dir()
13        .unwrap_or_else(|| PathBuf::from("."))
14        .join(APP_NAME)
15        .join("config.toml")
16}
17
18pub fn credentials_path() -> PathBuf {
19    dirs::config_dir()
20        .unwrap_or_else(|| PathBuf::from("."))
21        .join(APP_NAME)
22        .join("credentials.json")
23}
24
25pub fn state_path() -> PathBuf {
26    dirs::data_local_dir()
27        .unwrap_or_else(|| PathBuf::from("."))
28        .join(APP_NAME)
29        .join("state.json")
30}
31
32pub fn log_path() -> PathBuf {
33    dirs::data_local_dir()
34        .unwrap_or_else(|| PathBuf::from("."))
35        .join(APP_NAME)
36        .join("proxy.log")
37}
38
39pub fn pid_path() -> PathBuf {
40    dirs::data_local_dir()
41        .unwrap_or_else(|| PathBuf::from("."))
42        .join(APP_NAME)
43        .join("shunt.pid")
44}
45
46// ---------------------------------------------------------------------------
47// Credentials store  (separate file from config — never commit this)
48// ---------------------------------------------------------------------------
49
/// On-disk credential store, persisted as JSON at `credentials_path()` —
/// a file separate from `config.toml`.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct CredentialsStore {
    /// Credentials keyed by account name; `load_config` looks entries up by
    /// each `[[accounts]]` `name`. Parsing is delegated to
    /// `deserialize_credential_map`; a missing `accounts` key yields an empty map.
    #[serde(deserialize_with = "deserialize_credential_map", default)]
    pub accounts: HashMap<String, Credential>,
}
55
56impl CredentialsStore {
57    pub fn load() -> Self {
58        let p = credentials_path();
59        if !p.exists() {
60            return Self::default();
61        }
62        match std::fs::read_to_string(&p) {
63            Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
64            Err(_) => Self::default(),
65        }
66    }
67
68    pub fn save(&self) -> Result<()> {
69        let p = credentials_path();
70        if let Some(parent) = p.parent() {
71            std::fs::create_dir_all(parent)?;
72        }
73        let tmp = p.with_extension("tmp");
74        std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
75        std::fs::rename(&tmp, &p)?;
76        #[cfg(unix)]
77        {
78            use std::os::unix::fs::PermissionsExt;
79            std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o600))?;
80        }
81        // On Windows, restrict the file to the current user via icacls (best-effort).
82        #[cfg(windows)]
83        {
84            if let Some(path_str) = p.to_str() {
85                let username = std::env::var("USERNAME").unwrap_or_default();
86                if !username.is_empty() {
87                    let _ = std::process::Command::new("icacls")
88                        .arg(path_str)
89                        .arg("/inheritance:r")
90                        .arg("/grant:r")
91                        .arg(format!("{username}:F"))
92                        .status();
93                }
94            }
95        }
96        Ok(())
97    }
98}
99
100// ---------------------------------------------------------------------------
101// Raw TOML config types
102// ---------------------------------------------------------------------------
103
/// Top-level shape of `config.toml`, as parsed by `toml::from_str` in
/// `load_config`. Every section may be omitted from the file.
#[derive(Debug, Deserialize)]
struct RawConfig {
    /// `[server]` table; when absent, `RawServer::default()` is used.
    #[serde(default)]
    server: RawServer,
    /// `[[accounts]]` entries; `load_config` rejects an empty list.
    #[serde(default)]
    accounts: Vec<RawAccount>,
    /// Global model-name mapping: `"claude-sonnet-4-6" = "llama-3.3-70b-versatile"`
    /// Applied when routing Anthropic-format requests to non-Anthropic providers.
    #[serde(default)]
    model_mapping: HashMap<String, String>,
}
115
116#[derive(Debug, Deserialize)]
117struct RawServer {
118    #[serde(default = "default_host")]
119    host: String,
120    #[serde(default = "default_port")]
121    port: u16,
122    #[serde(default = "default_control_port")]
123    control_port: u16,
124    #[serde(default = "default_log_level")]
125    log_level: String,
126    upstream_url: Option<String>,
127    remote_key: Option<String>,
128    relay_url: Option<String>,
129    pub custom_domain: Option<String>,
130    /// Conversation stickiness TTL in minutes (default: 10)
131    sticky_ttl_minutes: Option<u64>,
132    /// "use-it-or-lose-it" expiry window in minutes (default: 30)
133    expiry_soon_minutes: Option<u64>,
134    /// Upstream request timeout in seconds (default: 600)
135    request_timeout_secs: Option<u64>,
136    /// URL of a shunt relay-server instance for multi-machine history aggregation.
137    /// e.g. "http://relay.internal:3001"
138    telemetry_url: Option<String>,
139    /// Bearer token sent to the relay-server. Must match RELAY_TOKEN on the server.
140    telemetry_token: Option<String>,
141    /// Human-readable name for this shunt instance (shown in the relay dashboard).
142    /// Defaults to the system hostname.
143    instance_name: Option<String>,
144}
145
146impl Default for RawServer {
147    fn default() -> Self {
148        Self {
149            host: default_host(),
150            port: default_port(),
151            control_port: default_control_port(),
152            log_level: default_log_level(),
153            upstream_url: None,
154            remote_key: None,
155            relay_url: None,
156            custom_domain: None,
157            sticky_ttl_minutes: None,
158            expiry_soon_minutes: None,
159            request_timeout_secs: None,
160            telemetry_url: None,
161            telemetry_token: None,
162            instance_name: None,
163        }
164    }
165}
166
/// One `[[accounts]]` entry from `config.toml`, before `load_config` resolves
/// its provider, credential, and upstream URL.
#[derive(Debug, Deserialize)]
struct RawAccount {
    /// Account name; also the key looked up in `credentials.json`.
    name: String,
    /// Plan tier label (see `default_plan_type`).
    #[serde(default = "default_plan_type")]
    plan_type: String,
    /// "anthropic" (default) | "openai" / "codex" | "groq" | "mistral" | "local" | …
    #[serde(default)]
    provider: Option<String>,
    /// Inline API key (use api_key_env for better security).
    #[serde(default)]
    api_key: Option<String>,
    /// Name of an environment variable that holds the API key.
    #[serde(default)]
    api_key_env: Option<String>,
    /// Per-account upstream URL override (required for Local provider).
    #[serde(default)]
    upstream_url: Option<String>,
    /// Pin this account to a specific model, overriding global model_mapping
    /// and the provider's default_model(). Useful for mixing model tiers.
    #[serde(default)]
    model: Option<String>,
}
189
/// Default listen address: loopback only.
fn default_host() -> String {
    String::from("127.0.0.1")
}
191
192pub fn default_instance_name() -> String {
193    hostname::get()
194        .ok()
195        .and_then(|h| h.into_string().ok())
196        .unwrap_or_else(|| "shunt".into())
197}
/// Default proxy listen port.
fn default_port() -> u16 {
    8082
}

/// Default control-plane port.
fn default_control_port() -> u16 {
    19081
}

/// Default log level.
fn default_log_level() -> String {
    String::from("info")
}

/// Plan type assumed when an account omits `plan_type`.
fn default_plan_type() -> String {
    String::from("pro")
}
202
203// ---------------------------------------------------------------------------
204// Resolved config types
205// ---------------------------------------------------------------------------
206
/// Resolved `[server]` settings, after `load_config` has applied environment
/// fallbacks and built-in defaults.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Address the proxy listens on.
    pub host: String,
    /// Port the proxy listens on.
    pub port: u16,
    /// Port for the control plane (/status, /use, /health) — sees all accounts.
    pub control_port: u16,
    /// Log level string.
    pub log_level: String,
    /// Upstream base URL used by accounts of the primary provider.
    pub upstream_url: String,
    /// When set, remote requests must supply this value as `x-api-key`.
    pub remote_key: Option<String>,
    /// Relay URL for `shunt push` / `shunt login`. `SHUNT_RELAY_URL` is used
    /// only when the config file does not set `relay_url`.
    pub relay_url: String,
    /// Custom domain for permanent online sharing (e.g. https://shunt.mysite.com).
    pub custom_domain: Option<String>,
    /// Conversation stickiness TTL in milliseconds.
    pub sticky_ttl_ms: u64,
    /// Accounts whose 5h window resets within this many seconds are preferred ("use-it-or-lose-it").
    pub expiry_soon_secs: u64,
    /// Upstream request timeout in seconds.
    pub request_timeout_secs: u64,
    /// Optional relay-server URL for cross-instance history aggregation.
    pub telemetry_url: Option<String>,
    /// Bearer token for the relay-server.
    pub telemetry_token: Option<String>,
    /// Identifier for this shunt instance sent in telemetry payloads.
    pub instance_name: String,
}
234
235impl Default for ServerConfig {
236    fn default() -> Self {
237        Self {
238            host: "127.0.0.1".into(),
239            port: 8082,
240            control_port: 19081,
241            log_level: "info".into(),
242            upstream_url: "https://api.anthropic.com".into(),
243            remote_key: None,
244            relay_url: "https://relay.ramcharan.shop".into(),
245            custom_domain: None,
246            sticky_ttl_ms: 10 * 60 * 1000,
247            expiry_soon_secs: 30 * 60,
248            request_timeout_secs: 600,
249            telemetry_url: None,
250            telemetry_token: None,
251            instance_name: default_instance_name(),
252        }
253    }
254}
255
/// An account after `load_config` has resolved its provider, credential, and
/// upstream URL.
#[derive(Debug, Clone)]
pub struct AccountConfig {
    /// Account name as written in `config.toml`.
    pub name: String,
    /// Plan tier label.
    pub plan_type: String,
    /// Provider parsed from the account's `provider` field.
    pub provider: Provider,
    /// `None` when the account has no credential.
    /// OAuth accounts: None means reauth required (shown as auth_failed).
    /// ApiKey accounts: None means key not yet configured.
    /// Local accounts: None is normal (no auth required).
    pub credential: Option<Credential>,
    /// Override the upstream base URL for this account.
    /// `None` means use `config.server.upstream_url` (primary provider) or
    /// `provider.default_upstream_url()` (non-primary provider).
    pub upstream_url: Option<String>,
    /// Pin this account to a specific model name.
    /// Overrides both `model_mapping` and `provider.default_model()`.
    pub model: Option<String>,
}
274
/// Fully resolved configuration returned by `load_config`.
#[derive(Debug, Clone)]
pub struct Config {
    /// Resolved server settings.
    pub server: ServerConfig,
    /// Accounts in the order they appear in `config.toml`.
    pub accounts: Vec<AccountConfig>,
    /// Path the configuration was loaded from.
    pub config_file: PathBuf,
    /// Global model-name overrides: claude model → provider model.
    /// e.g. `"claude-sonnet-4-6" → "llama-3.3-70b-versatile"`
    pub model_mapping: HashMap<String, String>,
}
284
285// ---------------------------------------------------------------------------
286// Loading
287// ---------------------------------------------------------------------------
288
289pub fn load_config(path: Option<&Path>) -> Result<Config> {
290    let p = path.map(PathBuf::from).unwrap_or_else(config_path);
291
292    if !p.exists() {
293        bail!(
294            "Config not found: {}\nRun `shunt setup` to get started.",
295            p.display()
296        );
297    }
298
299    let raw_text = std::fs::read_to_string(&p)
300        .with_context(|| format!("Failed to read config: {}", p.display()))?;
301
302    let raw: RawConfig = toml::from_str(&raw_text)
303        .with_context(|| format!("Failed to parse config: {}", p.display()))?;
304
305    // Derive the default upstream URL from the first account's provider so that
306    // an all-OpenAI config automatically points at api.openai.com without any
307    // explicit `upstream_url` in the config file.
308    let default_upstream = raw.accounts.first()
309        .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
310        .unwrap_or_default()
311        .default_upstream_url()
312        .to_owned();
313
314    let upstream_url = raw
315        .server
316        .upstream_url
317        .clone()
318        .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
319        .unwrap_or(default_upstream);
320
321    let relay_url = raw
322        .server
323        .relay_url
324        .clone()
325        .or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
326        .unwrap_or_else(|| "https://relay.ramcharan.shop".into());
327
328    let telemetry_url = raw.server.telemetry_url.clone()
329        .or_else(|| std::env::var("SHUNT_TELEMETRY_URL").ok());
330    let telemetry_token = raw.server.telemetry_token.clone()
331        .or_else(|| std::env::var("SHUNT_TELEMETRY_TOKEN").ok());
332    let instance_name = raw.server.instance_name.clone()
333        .or_else(|| std::env::var("SHUNT_INSTANCE_NAME").ok())
334        .unwrap_or_else(default_instance_name);
335
336    let server = ServerConfig {
337        host: raw.server.host,
338        port: raw.server.port,
339        control_port: raw.server.control_port,
340        log_level: raw.server.log_level,
341        upstream_url,
342        remote_key: raw.server.remote_key,
343        relay_url,
344        custom_domain: raw.server.custom_domain,
345        sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
346        expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
347        request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
348        telemetry_url,
349        telemetry_token,
350        instance_name,
351    };
352
353    if raw.accounts.is_empty() {
354        bail!("Config has no accounts. Run `shunt setup` to add one.");
355    }
356
357    let store = CredentialsStore::load();
358
359    // Determine the primary provider (first account) so we know which accounts
360    // use config.server.upstream_url and which need the provider's default URL.
361    let primary_provider = raw.accounts.first()
362        .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
363        .unwrap_or_default();
364
365    let mut accounts = Vec::new();
366    for a in &raw.accounts {
367        let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();
368
369        // Resolve credential.
370        //
371        // OAuth providers (Anthropic, OpenAI): credentials.json first, then
372        // auto-import from the provider's local CLI tool.
373        //
374        // API-key providers: credentials.json first, then inline api_key field,
375        // then api_key_env field, then the provider's well-known env var.
376        let cred: Option<Credential> = store.accounts.get(&a.name).cloned()
377            .or_else(|| {
378                // Inline api_key from TOML (less secure, but convenient for testing).
379                a.api_key.as_deref().map(|k| Credential::Apikey { key: k.to_owned() })
380            })
381            .or_else(|| {
382                // api_key_env: name of env var holding the key.
383                a.api_key_env.as_deref()
384                    .and_then(|var| std::env::var(var).ok())
385                    .map(|k| Credential::Apikey { key: k })
386            })
387            .or_else(|| {
388                // Auto-import from provider's CLI tool (OAuth providers) or
389                // well-known env var (API-key providers).
390                provider.read_local_credentials()
391            });
392
393        // Upstream URL: per-account override from TOML takes priority, then
394        // non-primary-provider accounts get the provider's default URL so
395        // the forwarder knows where to send requests.
396        let acct_upstream = a.upstream_url.clone().or_else(|| {
397            if provider != primary_provider {
398                Some(provider.default_upstream_url().to_owned())
399            } else {
400                None
401            }
402        });
403
404        accounts.push(AccountConfig {
405            name: a.name.clone(),
406            plan_type: a.plan_type.clone(),
407            provider,
408            credential: cred,
409            upstream_url: acct_upstream,
410            model: a.model.clone(),
411        });
412    }
413
414    Ok(Config { server, accounts, config_file: p, model_mapping: raw.model_mapping })
415}
416
417// ---------------------------------------------------------------------------
418// Config file template
419// ---------------------------------------------------------------------------
420
/// Render a starter `config.toml` with a default `[server]` table and one
/// `[[accounts]]` entry per `(name, plan_type)` pair.
///
/// Values are emitted as escaped TOML basic strings, so names containing
/// quotes, backslashes or control characters still produce a valid file
/// (and cannot inject extra keys or tables).
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\ncontrol_port = 19081\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        out.push_str("\n[[accounts]]\nname = ");
        push_toml_basic_string(&mut out, name);
        out.push_str("\nplan_type = ");
        push_toml_basic_string(&mut out, plan_type);
        out.push('\n');
    }
    out
}

/// Append `s` to `out` as a double-quoted TOML basic string, escaping the
/// characters TOML forbids unescaped (`"`, `\`, and control characters).
fn push_toml_basic_string(out: &mut String, s: &str) {
    out.push('"');
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            c if c.is_control() => out.push_str(&format!("\\u{:04X}", c as u32)),
            c => out.push(c),
        }
    }
    out.push('"');
}