Skip to main content

shunt/
config.rs

1use anyhow::{bail, Context, Result};
2
3/// Validate that an upstream URL uses http/https and does not point to
4/// loopback or link-local addresses (SSRF guard).
5/// Pass `allow_loopback = true` for Local-provider accounts (e.g. Ollama).
6fn validate_upstream_url(url: &str, allow_loopback: bool) -> Result<()> {
7    let parsed = url::Url::parse(url)
8        .with_context(|| format!("Invalid upstream URL: {url}"))?;
9    match parsed.scheme() {
10        "http" | "https" => {}
11        s => bail!("Upstream URL must use http or https, got scheme '{s}': {url}"),
12    }
13    if !allow_loopback {
14        if let Some(host) = parsed.host_str() {
15            let blocked = matches!(host, "localhost" | "127.0.0.1" | "::1" | "[::1]")
16                || host.starts_with("169.254.")
17                || host.starts_with("fd");
18            if blocked {
19                bail!("Upstream URL must not point to loopback or link-local addresses: {url}");
20            }
21        }
22    }
23    Ok(())
24}
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::path::{Path, PathBuf};
28
29use crate::credential::{deserialize_credential_map, Credential};
30use crate::provider::Provider;
31
/// Application name; the path helpers in this file use it as the directory
/// component under the platform config / local-data directories.
pub const APP_NAME: &str = "shunt";
33
34pub fn config_path() -> PathBuf {
35    dirs::config_dir()
36        .unwrap_or_else(|| PathBuf::from("."))
37        .join(APP_NAME)
38        .join("config.toml")
39}
40
41pub fn credentials_path() -> PathBuf {
42    dirs::config_dir()
43        .unwrap_or_else(|| PathBuf::from("."))
44        .join(APP_NAME)
45        .join("credentials.json")
46}
47
48pub fn state_path() -> PathBuf {
49    dirs::data_local_dir()
50        .unwrap_or_else(|| PathBuf::from("."))
51        .join(APP_NAME)
52        .join("state.json")
53}
54
55pub fn log_path() -> PathBuf {
56    dirs::data_local_dir()
57        .unwrap_or_else(|| PathBuf::from("."))
58        .join(APP_NAME)
59        .join("proxy.log")
60}
61
62pub fn notify_log_path() -> PathBuf {
63    dirs::data_local_dir()
64        .unwrap_or_else(|| PathBuf::from("."))
65        .join(APP_NAME)
66        .join("notify.log")
67}
68
69pub fn pid_path() -> PathBuf {
70    dirs::data_local_dir()
71        .unwrap_or_else(|| PathBuf::from("."))
72        .join(APP_NAME)
73        .join("shunt.pid")
74}
75
76// ---------------------------------------------------------------------------
77// Credentials store  (separate file from config — never commit this)
78// ---------------------------------------------------------------------------
79
/// Credential store persisted as JSON at `credentials_path()`.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct CredentialsStore {
    /// Credentials keyed by account name (`load_config` looks entries up by the
    /// account's `name`). A missing `accounts` key deserializes to an empty map;
    /// values are decoded by `deserialize_credential_map`.
    #[serde(deserialize_with = "deserialize_credential_map", default)]
    pub accounts: HashMap<String, Credential>,
}
85
86impl CredentialsStore {
87    pub fn load() -> Self {
88        let p = credentials_path();
89        if !p.exists() {
90            return Self::default();
91        }
92        match std::fs::read_to_string(&p) {
93            Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
94            Err(_) => Self::default(),
95        }
96    }
97
98    pub fn save(&self) -> Result<()> {
99        let p = credentials_path();
100        if let Some(parent) = p.parent() {
101            std::fs::create_dir_all(parent)?;
102        }
103        let tmp = p.with_extension("tmp");
104        std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
105        #[cfg(unix)]
106        {
107            use std::os::unix::fs::PermissionsExt;
108            std::fs::set_permissions(&tmp, std::fs::Permissions::from_mode(0o600))?;
109        }
110        std::fs::rename(&tmp, &p)?;
111        // On Windows, restrict the file to the current user via icacls (best-effort).
112        #[cfg(windows)]
113        {
114            if let Some(path_str) = p.to_str() {
115                let username = std::env::var("USERNAME").unwrap_or_default();
116                if !username.is_empty() {
117                    let _ = std::process::Command::new("icacls")
118                        .arg(path_str)
119                        .arg("/inheritance:r")
120                        .arg("/grant:r")
121                        .arg(format!("{username}:F"))
122                        .status();
123                }
124            }
125        }
126        Ok(())
127    }
128}
129
130// ---------------------------------------------------------------------------
131// Raw TOML config types
132// ---------------------------------------------------------------------------
133
/// Top-level shape of `config.toml` as deserialized, before defaults,
/// environment fallbacks, and credential resolution are applied by `load_config`.
#[derive(Debug, Deserialize)]
struct RawConfig {
    /// `[server]` table; falls back to `RawServer::default()` when absent.
    #[serde(default)]
    server: RawServer,
    /// `[[accounts]]` entries; an empty list is rejected by `load_config`.
    #[serde(default)]
    accounts: Vec<RawAccount>,
    /// Global model-name mapping: `"claude-sonnet-4-6" = "llama-3.3-70b-versatile"`
    /// Applied when routing Anthropic-format requests to non-Anthropic providers.
    #[serde(default)]
    model_mapping: HashMap<String, String>,
}
145
/// `[server]` table of `config.toml` as deserialized. Optional fields are
/// resolved to concrete values (defaults / environment fallbacks) in `load_config`.
#[derive(Debug, Deserialize)]
struct RawServer {
    /// Listen host (default: `default_host()`).
    #[serde(default = "default_host")]
    host: String,
    /// Listen port (default: `default_port()`).
    #[serde(default = "default_port")]
    port: u16,
    /// Control-plane port (default: `default_control_port()`).
    #[serde(default = "default_control_port")]
    control_port: u16,
    /// Log level (default: `default_log_level()`).
    #[serde(default = "default_log_level")]
    log_level: String,
    /// Server-level upstream base URL. When unset, `load_config` falls back to
    /// `SHUNT_UPSTREAM_URL`, then to the first account's provider default.
    upstream_url: Option<String>,
    /// Copied unchanged into `ServerConfig::remote_key`.
    remote_key: Option<String>,
    /// Relay URL. When unset, `load_config` falls back to `SHUNT_RELAY_URL`,
    /// then to a built-in default.
    relay_url: Option<String>,
    /// Copied unchanged into `ServerConfig::custom_domain`.
    // NOTE(review): this is the only `pub` field on a private struct — likely
    // unintentional; confirm and drop the `pub` for consistency.
    pub custom_domain: Option<String>,
    /// Conversation stickiness TTL in minutes (default: 10)
    sticky_ttl_minutes: Option<u64>,
    /// "use-it-or-lose-it" expiry window in minutes (default: 30)
    expiry_soon_minutes: Option<u64>,
    /// Account selection strategy: "reaper", "carousel", "cushion", or "maximus" (default).
    /// Legacy aliases "earliest-expiry", "round-robin" and "most-available" are also
    /// accepted by `RoutingStrategy::from_str`; unrecognized values fall back to the default.
    routing_strategy: Option<String>,
    /// Upstream request timeout in seconds (default: 600)
    request_timeout_secs: Option<u64>,
    /// Per-IP rate limit in requests per minute (0 = disabled, default disabled).
    rate_limit_rpm: Option<u32>,
    /// Trust X-Real-IP / X-Forwarded-For headers for per-IP rate limiting.
    /// Set to true only when shunt sits behind a trusted reverse proxy (e.g. cloudflared).
    /// When false (default), all requests share one rate-limit bucket.
    trust_proxy_headers: Option<bool>,
    /// Enable periodic health-check probes for all accounts (default: true).
    health_check_enabled: Option<bool>,
    /// Seconds between health-check probe rounds (default: 300 = 5 min).
    health_check_interval_secs: Option<u64>,
    /// Per-account probe timeout in seconds (default: 10).
    health_check_timeout_secs: Option<u64>,
    /// URL of a shunt relay-server instance for multi-machine history aggregation.
    /// e.g. "http://relay.internal:3001"
    /// Falls back to `SHUNT_TELEMETRY_URL` when unset.
    telemetry_url: Option<String>,
    /// Bearer token sent to the relay-server. Must match RELAY_TOKEN on the server.
    /// Falls back to `SHUNT_TELEMETRY_TOKEN` when unset.
    telemetry_token: Option<String>,
    /// Human-readable name for this shunt instance (shown in the relay dashboard).
    /// Falls back to `SHUNT_INSTANCE_NAME`, then to the system hostname.
    instance_name: Option<String>,
}
189
190impl Default for RawServer {
191    fn default() -> Self {
192        Self {
193            host: default_host(),
194            port: default_port(),
195            control_port: default_control_port(),
196            log_level: default_log_level(),
197            upstream_url: None,
198            remote_key: None,
199            relay_url: None,
200            custom_domain: None,
201            sticky_ttl_minutes: None,
202            expiry_soon_minutes: None,
203            routing_strategy: None,
204            request_timeout_secs: None,
205            rate_limit_rpm: None,
206            trust_proxy_headers: None,
207            health_check_enabled: None,
208            health_check_interval_secs: None,
209            health_check_timeout_secs: None,
210            telemetry_url: None,
211            telemetry_token: None,
212            instance_name: None,
213        }
214    }
215}
216
/// One `[[accounts]]` entry of `config.toml` as deserialized.
#[derive(Debug, Deserialize)]
struct RawAccount {
    /// Account name; also the key used to look up stored credentials.
    name: String,
    /// Plan type (default: `default_plan_type()`).
    #[serde(default = "default_plan_type")]
    plan_type: String,
    /// "anthropic" (default) | "openai" / "codex" | "groq" | "mistral" | "local" | …
    #[serde(default)]
    provider: Option<String>,
    /// Inline API key (use api_key_env for better security).
    #[serde(default)]
    api_key: Option<String>,
    /// Name of an environment variable that holds the API key.
    #[serde(default)]
    api_key_env: Option<String>,
    /// Per-account upstream URL override (required for Local provider).
    #[serde(default)]
    upstream_url: Option<String>,
    /// Pin this account to a specific model, overriding global model_mapping
    /// and the provider's default_model(). Useful for mixing model tiers.
    #[serde(default)]
    model: Option<String>,
}
239
/// Serde default for `server.host`.
fn default_host() -> String {
    String::from("127.0.0.1")
}
241
242pub fn default_instance_name() -> String {
243    hostname::get()
244        .ok()
245        .and_then(|h| h.into_string().ok())
246        .unwrap_or_else(|| "shunt".into())
247}
/// Serde default for `server.port`.
fn default_port() -> u16 {
    8082
}

/// Serde default for `server.control_port`.
fn default_control_port() -> u16 {
    19081
}

/// Serde default for `server.log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for an account's `plan_type`.
fn default_plan_type() -> String {
    String::from("pro")
}
252
253// ---------------------------------------------------------------------------
254// Resolved config types
255// ---------------------------------------------------------------------------
256
/// Account-selection algorithm used when no sticky or pinned account applies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RoutingStrategy {
    /// Harvest every token before the window expires — use-it-or-lose-it.
    /// Drains accounts whose quota windows expire soonest first, then prefers
    /// the account with the most remaining quota. Maximises total token usage over time.
    /// Config: `"reaper"` (alias: `"earliest-expiry"`)
    Reaper,
    /// Spins through accounts in a fixed round-robin cycle, ignoring quota state.
    /// Config: `"carousel"` (alias: `"round-robin"`)
    Carousel,
    /// Always routes to the account with the softest landing — the most remaining
    /// capacity across both 5h and 7d windows (binding window primary, secondary as tiebreak).
    /// Config: `"cushion"` (aliases: `"most-available"`, `"least-utilized"`)
    Cushion,
    /// Time-weighted dual-window optimizer. Scores each account as:
    ///   health_5h = 1 - (time_fraction_5h × util_5h)
    ///   health_7d = 1 - (time_fraction_7d × util_7d)
    ///   score     = health_5h × health_7d
    /// where time_fraction = secs_to_reset / window_duration (0 = resetting now, 1 = just started).
    /// Accounts for how much quota remains AND how soon each window refreshes.
    /// Config: `"maximus"`
    #[default]
    Maximus,
}

impl RoutingStrategy {
    /// Canonical config name of this strategy (round-trips through [`Self::from_str`]).
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Reaper => "reaper",
            Self::Carousel => "carousel",
            Self::Cushion => "cushion",
            Self::Maximus => "maximus",
        }
    }

    /// Parse a strategy name from config.
    ///
    /// Matching ignores surrounding whitespace and ASCII case, and treats `_`
    /// and `-` as interchangeable, so `"Round_Robin"` and `" carousel "` both
    /// resolve to [`Self::Carousel`]. Returns `None` for unrecognized names.
    pub fn from_str(s: &str) -> Option<Self> {
        let normalized = s.trim().to_ascii_lowercase().replace('_', "-");
        match normalized.as_str() {
            "reaper" | "earliest-expiry" => Some(Self::Reaper),
            "carousel" | "round-robin" => Some(Self::Carousel),
            // "least-utilized" is a legacy alias: least utilized == most capacity left.
            "cushion" | "most-available" | "least-utilized" => Some(Self::Cushion),
            "maximus" => Some(Self::Maximus),
            _ => None,
        }
    }
}
303
/// Resolved server settings: every optional `[server]` value from the config
/// file has been replaced by a concrete value (config → environment → default)
/// by `load_config`.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Listen host.
    pub host: String,
    /// Listen port.
    pub port: u16,
    /// Port for the control plane (/status, /use, /health) — sees all accounts.
    pub control_port: u16,
    /// Log level string as given in the config (default `"info"`).
    pub log_level: String,
    /// Server-level upstream base URL (validated by the SSRF guard in `load_config`).
    pub upstream_url: String,
    /// When set, remote requests must supply this value as `x-api-key`.
    pub remote_key: Option<String>,
    /// Relay URL for `shunt push` / `shunt login`. Overridable via SHUNT_RELAY_URL.
    pub relay_url: String,
    /// Custom domain for permanent online sharing (e.g. https://shunt.mysite.com).
    pub custom_domain: Option<String>,
    /// Conversation stickiness TTL in milliseconds.
    pub sticky_ttl_ms: u64,
    /// Accounts whose 5h window resets within this many seconds are preferred ("use-it-or-lose-it").
    pub expiry_soon_secs: u64,
    /// Which routing algorithm to use for account selection.
    pub routing_strategy: RoutingStrategy,
    /// Upstream request timeout in seconds.
    pub request_timeout_secs: u64,
    /// Per-IP rate limit in requests per minute (0 = disabled, default disabled).
    pub rate_limit_rpm: u32,
    /// Trust X-Real-IP for per-IP rate limiting (only when behind a trusted proxy).
    pub trust_proxy_headers: bool,
    /// Enable periodic health-check probes for all accounts.
    pub health_check_enabled: bool,
    /// Seconds between health-check probe rounds.
    pub health_check_interval_secs: u64,
    /// Per-account probe timeout in seconds.
    pub health_check_timeout_secs: u64,
    /// Optional relay-server URL for cross-instance history aggregation.
    pub telemetry_url: Option<String>,
    /// Bearer token for the relay-server.
    pub telemetry_token: Option<String>,
    /// Identifier for this shunt instance sent in telemetry payloads.
    pub instance_name: String,
}
343
344impl Default for ServerConfig {
345    fn default() -> Self {
346        Self {
347            host: "127.0.0.1".into(),
348            port: 8082,
349            control_port: 19081,
350            log_level: "info".into(),
351            upstream_url: "https://api.anthropic.com".into(),
352            remote_key: None,
353            relay_url: "https://relay.ramcharan.shop".into(),
354            custom_domain: None,
355            sticky_ttl_ms: 10 * 60 * 1000,
356            expiry_soon_secs: 30 * 60,
357            routing_strategy: RoutingStrategy::Maximus,
358            request_timeout_secs: 600,
359            rate_limit_rpm: 0,
360            trust_proxy_headers: false,
361            health_check_enabled: true,
362            health_check_interval_secs: 300,
363            health_check_timeout_secs: 10,
364            telemetry_url: None,
365            telemetry_token: None,
366            instance_name: default_instance_name(),
367        }
368    }
369}
370
/// Resolved account entry produced by `load_config` from one `[[accounts]]`
/// table plus whatever credential could be found for it.
#[derive(Debug, Clone)]
pub struct AccountConfig {
    /// Account name from the config file.
    pub name: String,
    /// Plan type from the config file (defaults to `"pro"` when omitted).
    pub plan_type: String,
    /// Provider parsed from the account's `provider` string, or the provider default.
    pub provider: Provider,
    /// `None` when the account has no credential.
    /// OAuth accounts: None means reauth required (shown as auth_failed).
    /// ApiKey accounts: None means key not yet configured.
    /// Local accounts: None is normal (no auth required).
    pub credential: Option<Credential>,
    /// Override the upstream base URL for this account.
    /// `None` means use `config.server.upstream_url` (primary provider) or
    /// `provider.default_upstream_url()` (non-primary provider).
    pub upstream_url: Option<String>,
    /// Pin this account to a specific model name.
    /// Overrides both `model_mapping` and `provider.default_model()`.
    pub model: Option<String>,
}
389
/// Fully resolved configuration returned by `load_config`.
#[derive(Debug, Clone)]
pub struct Config {
    /// Resolved `[server]` settings.
    pub server: ServerConfig,
    /// Resolved accounts, in config-file order (never empty).
    pub accounts: Vec<AccountConfig>,
    /// Path of the config file this was loaded from.
    pub config_file: PathBuf,
    /// Global model-name overrides: claude model → provider model.
    /// e.g. `"claude-sonnet-4-6" → "llama-3.3-70b-versatile"`
    pub model_mapping: HashMap<String, String>,
}
399
400// ---------------------------------------------------------------------------
401// Loading
402// ---------------------------------------------------------------------------
403
404pub fn load_config(path: Option<&Path>) -> Result<Config> {
405    let p = path.map(PathBuf::from).unwrap_or_else(config_path);
406
407    if !p.exists() {
408        bail!(
409            "Config not found: {}\nRun `shunt setup` to get started.",
410            p.display()
411        );
412    }
413
414    let raw_text = std::fs::read_to_string(&p)
415        .with_context(|| format!("Failed to read config: {}", p.display()))?;
416
417    let raw: RawConfig = toml::from_str(&raw_text)
418        .with_context(|| format!("Failed to parse config: {}", p.display()))?;
419
420    // Derive the default upstream URL from the first account's provider so that
421    // an all-OpenAI config automatically points at api.openai.com without any
422    // explicit `upstream_url` in the config file.
423    let primary_provider_derived = raw.accounts.first()
424        .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
425        .unwrap_or_default();
426    let default_upstream = primary_provider_derived.default_upstream_url().to_owned();
427
428    let upstream_url = raw
429        .server
430        .upstream_url
431        .clone()
432        .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
433        .unwrap_or(default_upstream);
434
435    let relay_url = raw
436        .server
437        .relay_url
438        .clone()
439        .or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
440        .unwrap_or_else(|| "https://relay.ramcharan.shop".into());
441
442    let telemetry_url = raw.server.telemetry_url.clone()
443        .or_else(|| std::env::var("SHUNT_TELEMETRY_URL").ok());
444    let telemetry_token = raw.server.telemetry_token.clone()
445        .or_else(|| std::env::var("SHUNT_TELEMETRY_TOKEN").ok());
446    let instance_name = raw.server.instance_name.clone()
447        .or_else(|| std::env::var("SHUNT_INSTANCE_NAME").ok())
448        .unwrap_or_else(default_instance_name);
449
450    // #6 SSRF: validate the server-level upstream URL.
451    // Allow loopback only when the URL was derived from a Local provider's default
452    // (e.g. an all-Ollama config); explicit upstream_url entries are never allowed to
453    // use loopback unless explicitly set via SHUNT_UPSTREAM_URL (trust the operator).
454    let server_url_is_local_derived = raw.server.upstream_url.is_none()
455        && std::env::var("SHUNT_UPSTREAM_URL").is_err()
456        && matches!(primary_provider_derived, Provider::Local);
457    validate_upstream_url(&upstream_url, server_url_is_local_derived)
458        .with_context(|| "server.upstream_url failed validation")?;
459
460    let server = ServerConfig {
461        host: raw.server.host,
462        port: raw.server.port,
463        control_port: raw.server.control_port,
464        log_level: raw.server.log_level,
465        upstream_url,
466        remote_key: raw.server.remote_key,
467        relay_url,
468        custom_domain: raw.server.custom_domain,
469        sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
470        expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
471        routing_strategy: raw.server.routing_strategy.as_deref()
472            .and_then(RoutingStrategy::from_str)
473            .unwrap_or_default(),
474        request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
475        rate_limit_rpm: raw.server.rate_limit_rpm.unwrap_or(0),
476        trust_proxy_headers: raw.server.trust_proxy_headers.unwrap_or(false),
477        health_check_enabled: raw.server.health_check_enabled.unwrap_or(true),
478        health_check_interval_secs: raw.server.health_check_interval_secs.unwrap_or(300),
479        health_check_timeout_secs: raw.server.health_check_timeout_secs.unwrap_or(10),
480        telemetry_url,
481        telemetry_token,
482        instance_name,
483    };
484
485    if raw.accounts.is_empty() {
486        bail!("Config has no accounts. Run `shunt setup` to add one.");
487    }
488
489    let store = CredentialsStore::load();
490
491    // primary_provider_derived was already computed above for the server URL derivation.
492    let primary_provider = primary_provider_derived;
493
494    let mut accounts = Vec::new();
495    for a in &raw.accounts {
496        let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();
497
498        // Resolve credential.
499        //
500        // OAuth providers (Anthropic, OpenAI): credentials.json first, then
501        // auto-import from the provider's local CLI tool.
502        //
503        // API-key providers: credentials.json first, then inline api_key field,
504        // then api_key_env field, then the provider's well-known env var.
505        let cred: Option<Credential> = store.accounts.get(&a.name).cloned()
506            .or_else(|| {
507                // Inline api_key from TOML (less secure, but convenient for testing).
508                a.api_key.as_deref().map(|k| {
509                    tracing::warn!(account = %a.name, "Inline api_key in config.toml is insecure — use api_key_env instead");
510                    Credential::Apikey { key: k.to_owned() }
511                })
512            })
513            .or_else(|| {
514                // api_key_env: name of env var holding the key.
515                a.api_key_env.as_deref()
516                    .and_then(|var| std::env::var(var).ok())
517                    .map(|k| Credential::Apikey { key: k })
518            })
519            .or_else(|| {
520                // Auto-import from provider's CLI tool (OAuth providers) or
521                // well-known env var (API-key providers).
522                provider.read_local_credentials()
523            });
524
525        // Upstream URL: per-account override from TOML takes priority, then
526        // non-primary-provider accounts get the provider's default URL so
527        // the forwarder knows where to send requests.
528        let is_local = matches!(provider, Provider::Local);
529        if let Some(ref url) = a.upstream_url {
530            // #6 SSRF: allow loopback only for Local provider (e.g. Ollama at localhost).
531            validate_upstream_url(url, is_local)
532                .with_context(|| format!("account '{}' upstream_url failed validation", a.name))?;
533        }
534        let acct_upstream = a.upstream_url.clone().or_else(|| {
535            if provider != primary_provider {
536                Some(provider.default_upstream_url().to_owned())
537            } else {
538                None
539            }
540        });
541
542        accounts.push(AccountConfig {
543            name: a.name.clone(),
544            plan_type: a.plan_type.clone(),
545            provider,
546            credential: cred,
547            upstream_url: acct_upstream,
548            model: a.model.clone(),
549        });
550    }
551
552    Ok(Config { server, accounts, config_file: p, model_mapping: raw.model_mapping })
553}
554
555// ---------------------------------------------------------------------------
556// Config file template
557// ---------------------------------------------------------------------------
558
/// Render a starter `config.toml`: a `[server]` table with the default
/// settings followed by one `[[accounts]]` table per `(name, plan_type)` pair.
///
/// Names and plan types are escaped as TOML basic strings, so values containing
/// quotes, backslashes, or control characters still produce a valid file.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    use std::fmt::Write as _;

    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\ncontrol_port = 19081\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        // Writing to a `String` cannot fail.
        let _ = write!(
            out,
            "\n[[accounts]]\nname = \"{name}\"\nplan_type = \"{plan}\"\n",
            name = escape_toml_basic_string(name),
            plan = escape_toml_basic_string(plan_type),
        );
    }
    out
}

/// Escape `raw` for use between the double quotes of a TOML basic string.
fn escape_toml_basic_string(raw: &str) -> String {
    use std::fmt::Write as _;

    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            c if c.is_control() => {
                // Remaining control characters must use the \uXXXX form.
                let _ = write!(escaped, "\\u{:04X}", c as u32);
            }
            c => escaped.push(c),
        }
    }
    escaped
}