Skip to main content

shunt/
config.rs

1use anyhow::{bail, Context, Result};
2
3/// Validate that an upstream URL uses http/https and does not point to
4/// loopback or link-local addresses (SSRF guard).
5/// Pass `allow_loopback = true` for Local-provider accounts (e.g. Ollama).
6fn validate_upstream_url(url: &str, allow_loopback: bool) -> Result<()> {
7    let parsed = url::Url::parse(url)
8        .with_context(|| format!("Invalid upstream URL: {url}"))?;
9    match parsed.scheme() {
10        "http" | "https" => {}
11        s => bail!("Upstream URL must use http or https, got scheme '{s}': {url}"),
12    }
13    if !allow_loopback {
14        if let Some(host) = parsed.host_str() {
15            let blocked = matches!(host, "localhost" | "127.0.0.1" | "::1" | "[::1]")
16                || host.starts_with("169.254.")
17                || host.starts_with("fd");
18            if blocked {
19                bail!("Upstream URL must not point to loopback or link-local addresses: {url}");
20            }
21        }
22    }
23    Ok(())
24}
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::path::{Path, PathBuf};
28
29use crate::credential::{deserialize_credential_map, Credential};
30use crate::provider::Provider;
31
32pub const APP_NAME: &str = "shunt";
33
34pub fn config_path() -> PathBuf {
35    dirs::config_dir()
36        .unwrap_or_else(|| PathBuf::from("."))
37        .join(APP_NAME)
38        .join("config.toml")
39}
40
41pub fn credentials_path() -> PathBuf {
42    dirs::config_dir()
43        .unwrap_or_else(|| PathBuf::from("."))
44        .join(APP_NAME)
45        .join("credentials.json")
46}
47
48pub fn state_path() -> PathBuf {
49    dirs::data_local_dir()
50        .unwrap_or_else(|| PathBuf::from("."))
51        .join(APP_NAME)
52        .join("state.json")
53}
54
55pub fn log_path() -> PathBuf {
56    dirs::data_local_dir()
57        .unwrap_or_else(|| PathBuf::from("."))
58        .join(APP_NAME)
59        .join("proxy.log")
60}
61
62pub fn notify_log_path() -> PathBuf {
63    dirs::data_local_dir()
64        .unwrap_or_else(|| PathBuf::from("."))
65        .join(APP_NAME)
66        .join("notify.log")
67}
68
69pub fn pid_path() -> PathBuf {
70    dirs::data_local_dir()
71        .unwrap_or_else(|| PathBuf::from("."))
72        .join(APP_NAME)
73        .join("shunt.pid")
74}
75
76// ---------------------------------------------------------------------------
77// Credentials store  (separate file from config — never commit this)
78// ---------------------------------------------------------------------------
79
80#[derive(Debug, Default, Serialize, Deserialize)]
81pub struct CredentialsStore {
82    #[serde(deserialize_with = "deserialize_credential_map", default)]
83    pub accounts: HashMap<String, Credential>,
84}
85
86impl CredentialsStore {
87    pub fn load() -> Self {
88        let p = credentials_path();
89        if !p.exists() {
90            return Self::default();
91        }
92        match std::fs::read_to_string(&p) {
93            Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
94            Err(_) => Self::default(),
95        }
96    }
97
98    pub fn save(&self) -> Result<()> {
99        let p = credentials_path();
100        if let Some(parent) = p.parent() {
101            std::fs::create_dir_all(parent)?;
102        }
103        let tmp = p.with_extension("tmp");
104        std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
105        #[cfg(unix)]
106        {
107            use std::os::unix::fs::PermissionsExt;
108            std::fs::set_permissions(&tmp, std::fs::Permissions::from_mode(0o600))?;
109        }
110        std::fs::rename(&tmp, &p)?;
111        // On Windows, restrict the file to the current user via icacls (best-effort).
112        #[cfg(windows)]
113        {
114            if let Some(path_str) = p.to_str() {
115                let username = std::env::var("USERNAME").unwrap_or_default();
116                if !username.is_empty() {
117                    let _ = std::process::Command::new("icacls")
118                        .arg(path_str)
119                        .arg("/inheritance:r")
120                        .arg("/grant:r")
121                        .arg(format!("{username}:F"))
122                        .status();
123                }
124            }
125        }
126        Ok(())
127    }
128}
129
130// ---------------------------------------------------------------------------
131// Raw TOML config types
132// ---------------------------------------------------------------------------
133
134#[derive(Debug, Deserialize)]
135struct RawConfig {
136    #[serde(default)]
137    server: RawServer,
138    #[serde(default)]
139    accounts: Vec<RawAccount>,
140    /// Global model-name mapping: `"claude-sonnet-4-6" = "llama-3.3-70b-versatile"`
141    /// Applied when routing Anthropic-format requests to non-Anthropic providers.
142    #[serde(default)]
143    model_mapping: HashMap<String, String>,
144}
145
146#[derive(Debug, Deserialize)]
147struct RawServer {
148    #[serde(default = "default_host")]
149    host: String,
150    #[serde(default = "default_port")]
151    port: u16,
152    #[serde(default = "default_control_port")]
153    control_port: u16,
154    #[serde(default = "default_log_level")]
155    log_level: String,
156    upstream_url: Option<String>,
157    remote_key: Option<String>,
158    relay_url: Option<String>,
159    pub custom_domain: Option<String>,
160    /// Conversation stickiness TTL in minutes (default: 10)
161    sticky_ttl_minutes: Option<u64>,
162    /// "use-it-or-lose-it" expiry window in minutes (default: 30)
163    expiry_soon_minutes: Option<u64>,
164    /// Account selection strategy: "earliest-expiry" (default), "round-robin", "least-utilized"
165    routing_strategy: Option<String>,
166    /// Upstream request timeout in seconds (default: 600)
167    request_timeout_secs: Option<u64>,
168    /// Per-IP rate limit in requests per minute (0 = disabled, default disabled).
169    rate_limit_rpm: Option<u32>,
170    /// Trust X-Real-IP / X-Forwarded-For headers for per-IP rate limiting.
171    /// Set to true only when shunt sits behind a trusted reverse proxy (e.g. cloudflared).
172    /// When false (default), all requests share one rate-limit bucket.
173    trust_proxy_headers: Option<bool>,
174    /// URL of a shunt relay-server instance for multi-machine history aggregation.
175    /// e.g. "http://relay.internal:3001"
176    telemetry_url: Option<String>,
177    /// Bearer token sent to the relay-server. Must match RELAY_TOKEN on the server.
178    telemetry_token: Option<String>,
179    /// Human-readable name for this shunt instance (shown in the relay dashboard).
180    /// Defaults to the system hostname.
181    instance_name: Option<String>,
182}
183
184impl Default for RawServer {
185    fn default() -> Self {
186        Self {
187            host: default_host(),
188            port: default_port(),
189            control_port: default_control_port(),
190            log_level: default_log_level(),
191            upstream_url: None,
192            remote_key: None,
193            relay_url: None,
194            custom_domain: None,
195            sticky_ttl_minutes: None,
196            expiry_soon_minutes: None,
197            routing_strategy: None,
198            request_timeout_secs: None,
199            rate_limit_rpm: None,
200            trust_proxy_headers: None,
201            telemetry_url: None,
202            telemetry_token: None,
203            instance_name: None,
204        }
205    }
206}
207
208#[derive(Debug, Deserialize)]
209struct RawAccount {
210    name: String,
211    #[serde(default = "default_plan_type")]
212    plan_type: String,
213    /// "anthropic" (default) | "openai" / "codex" | "groq" | "mistral" | "local" | …
214    #[serde(default)]
215    provider: Option<String>,
216    /// Inline API key (use api_key_env for better security).
217    #[serde(default)]
218    api_key: Option<String>,
219    /// Name of an environment variable that holds the API key.
220    #[serde(default)]
221    api_key_env: Option<String>,
222    /// Per-account upstream URL override (required for Local provider).
223    #[serde(default)]
224    upstream_url: Option<String>,
225    /// Pin this account to a specific model, overriding global model_mapping
226    /// and the provider's default_model(). Useful for mixing model tiers.
227    #[serde(default)]
228    model: Option<String>,
229}
230
/// Serde default for `server.host`.
fn default_host() -> String {
    String::from("127.0.0.1")
}
232
233pub fn default_instance_name() -> String {
234    hostname::get()
235        .ok()
236        .and_then(|h| h.into_string().ok())
237        .unwrap_or_else(|| "shunt".into())
238}
/// Serde default for `server.port`.
fn default_port() -> u16 {
    8082
}

/// Serde default for `server.control_port`.
fn default_control_port() -> u16 {
    19_081
}

/// Serde default for `server.log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `accounts[].plan_type`.
fn default_plan_type() -> String {
    "pro".to_owned()
}
243
244// ---------------------------------------------------------------------------
245// Resolved config types
246// ---------------------------------------------------------------------------
247
/// Account-selection algorithm used when no sticky or pinned account applies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RoutingStrategy {
    /// Harvest every token before the window expires — use-it-or-lose-it.
    /// Drains accounts whose quota windows expire soonest first, then prefers
    /// the account with the most remaining quota. Maximises total token usage over time.
    /// Config: `"reaper"`
    Reaper,
    /// Spins through accounts in a fixed round-robin cycle, ignoring quota state.
    /// Config: `"carousel"`
    Carousel,
    /// Always routes to the account with the softest landing — the most remaining
    /// capacity across both 5h and 7d windows (binding window primary, secondary as tiebreak).
    /// Config: `"cushion"`
    Cushion,
    /// Time-weighted dual-window optimizer. Scores each account as:
    ///   health_5h = 1 - (time_fraction_5h × util_5h)
    ///   health_7d = 1 - (time_fraction_7d × util_7d)
    ///   score     = health_5h × health_7d
    /// where time_fraction = secs_to_reset / window_duration (0 = resetting now, 1 = just started).
    /// Accounts for how much quota remains AND how soon each window refreshes.
    /// Config: `"maximus"`
    #[default]
    Maximus,
}

impl RoutingStrategy {
    /// Canonical config name of this strategy; always accepted by
    /// [`RoutingStrategy::from_str`].
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Reaper => "reaper",
            Self::Carousel => "carousel",
            Self::Cushion => "cushion",
            Self::Maximus => "maximus",
        }
    }

    /// Parse a strategy name, e.g. from `server.routing_strategy`.
    ///
    /// Matching ignores ASCII case and surrounding whitespace, so `"Maximus"`
    /// or `" reaper "` are accepted. Besides the canonical names, these
    /// aliases are recognised:
    ///
    /// * `reaper`: `earliest-expiry`, `earliest_expiry`
    /// * `carousel`: `round-robin`, `round_robin`
    /// * `cushion`: `most-available`, `most_available`, `least-utilized`, `least_utilized`
    ///
    /// Returns `None` for anything else.
    pub fn from_str(s: &str) -> Option<Self> {
        match s.trim().to_ascii_lowercase().as_str() {
            "reaper" | "earliest-expiry" | "earliest_expiry" => Some(Self::Reaper),
            "carousel" | "round-robin" | "round_robin" => Some(Self::Carousel),
            "cushion" | "most-available" | "most_available" | "least-utilized"
            | "least_utilized" => Some(Self::Cushion),
            "maximus" => Some(Self::Maximus),
            _ => None,
        }
    }
}
294
295#[derive(Debug, Clone)]
296pub struct ServerConfig {
297    pub host: String,
298    pub port: u16,
299    /// Port for the control plane (/status, /use, /health) — sees all accounts.
300    pub control_port: u16,
301    pub log_level: String,
302    pub upstream_url: String,
303    /// When set, remote requests must supply this value as `x-api-key`.
304    pub remote_key: Option<String>,
305    /// Relay URL for `shunt push` / `shunt login`. Overridable via SHUNT_RELAY_URL.
306    pub relay_url: String,
307    /// Custom domain for permanent online sharing (e.g. https://shunt.mysite.com).
308    pub custom_domain: Option<String>,
309    /// Conversation stickiness TTL in milliseconds.
310    pub sticky_ttl_ms: u64,
311    /// Accounts whose 5h window resets within this many seconds are preferred ("use-it-or-lose-it").
312    pub expiry_soon_secs: u64,
313    /// Which routing algorithm to use for account selection.
314    pub routing_strategy: RoutingStrategy,
315    /// Upstream request timeout in seconds.
316    pub request_timeout_secs: u64,
317    /// Per-IP rate limit in requests per minute (0 = disabled, default disabled).
318    pub rate_limit_rpm: u32,
319    /// Trust X-Real-IP for per-IP rate limiting (only when behind a trusted proxy).
320    pub trust_proxy_headers: bool,
321    /// Optional relay-server URL for cross-instance history aggregation.
322    pub telemetry_url: Option<String>,
323    /// Bearer token for the relay-server.
324    pub telemetry_token: Option<String>,
325    /// Identifier for this shunt instance sent in telemetry payloads.
326    pub instance_name: String,
327}
328
329impl Default for ServerConfig {
330    fn default() -> Self {
331        Self {
332            host: "127.0.0.1".into(),
333            port: 8082,
334            control_port: 19081,
335            log_level: "info".into(),
336            upstream_url: "https://api.anthropic.com".into(),
337            remote_key: None,
338            relay_url: "https://relay.ramcharan.shop".into(),
339            custom_domain: None,
340            sticky_ttl_ms: 10 * 60 * 1000,
341            expiry_soon_secs: 30 * 60,
342            routing_strategy: RoutingStrategy::Maximus,
343            request_timeout_secs: 600,
344            rate_limit_rpm: 0,
345            trust_proxy_headers: false,
346            telemetry_url: None,
347            telemetry_token: None,
348            instance_name: default_instance_name(),
349        }
350    }
351}
352
353#[derive(Debug, Clone)]
354pub struct AccountConfig {
355    pub name: String,
356    pub plan_type: String,
357    pub provider: Provider,
358    /// `None` when the account has no credential.
359    /// OAuth accounts: None means reauth required (shown as auth_failed).
360    /// ApiKey accounts: None means key not yet configured.
361    /// Local accounts: None is normal (no auth required).
362    pub credential: Option<Credential>,
363    /// Override the upstream base URL for this account.
364    /// `None` means use `config.server.upstream_url` (primary provider) or
365    /// `provider.default_upstream_url()` (non-primary provider).
366    pub upstream_url: Option<String>,
367    /// Pin this account to a specific model name.
368    /// Overrides both `model_mapping` and `provider.default_model()`.
369    pub model: Option<String>,
370}
371
372#[derive(Debug, Clone)]
373pub struct Config {
374    pub server: ServerConfig,
375    pub accounts: Vec<AccountConfig>,
376    pub config_file: PathBuf,
377    /// Global model-name overrides: claude model → provider model.
378    /// e.g. `"claude-sonnet-4-6" → "llama-3.3-70b-versatile"`
379    pub model_mapping: HashMap<String, String>,
380}
381
382// ---------------------------------------------------------------------------
383// Loading
384// ---------------------------------------------------------------------------
385
386pub fn load_config(path: Option<&Path>) -> Result<Config> {
387    let p = path.map(PathBuf::from).unwrap_or_else(config_path);
388
389    if !p.exists() {
390        bail!(
391            "Config not found: {}\nRun `shunt setup` to get started.",
392            p.display()
393        );
394    }
395
396    let raw_text = std::fs::read_to_string(&p)
397        .with_context(|| format!("Failed to read config: {}", p.display()))?;
398
399    let raw: RawConfig = toml::from_str(&raw_text)
400        .with_context(|| format!("Failed to parse config: {}", p.display()))?;
401
402    // Derive the default upstream URL from the first account's provider so that
403    // an all-OpenAI config automatically points at api.openai.com without any
404    // explicit `upstream_url` in the config file.
405    let primary_provider_derived = raw.accounts.first()
406        .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
407        .unwrap_or_default();
408    let default_upstream = primary_provider_derived.default_upstream_url().to_owned();
409
410    let upstream_url = raw
411        .server
412        .upstream_url
413        .clone()
414        .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
415        .unwrap_or(default_upstream);
416
417    let relay_url = raw
418        .server
419        .relay_url
420        .clone()
421        .or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
422        .unwrap_or_else(|| "https://relay.ramcharan.shop".into());
423
424    let telemetry_url = raw.server.telemetry_url.clone()
425        .or_else(|| std::env::var("SHUNT_TELEMETRY_URL").ok());
426    let telemetry_token = raw.server.telemetry_token.clone()
427        .or_else(|| std::env::var("SHUNT_TELEMETRY_TOKEN").ok());
428    let instance_name = raw.server.instance_name.clone()
429        .or_else(|| std::env::var("SHUNT_INSTANCE_NAME").ok())
430        .unwrap_or_else(default_instance_name);
431
432    // #6 SSRF: validate the server-level upstream URL.
433    // Allow loopback only when the URL was derived from a Local provider's default
434    // (e.g. an all-Ollama config); explicit upstream_url entries are never allowed to
435    // use loopback unless explicitly set via SHUNT_UPSTREAM_URL (trust the operator).
436    let server_url_is_local_derived = raw.server.upstream_url.is_none()
437        && std::env::var("SHUNT_UPSTREAM_URL").is_err()
438        && matches!(primary_provider_derived, Provider::Local);
439    validate_upstream_url(&upstream_url, server_url_is_local_derived)
440        .with_context(|| "server.upstream_url failed validation")?;
441
442    let server = ServerConfig {
443        host: raw.server.host,
444        port: raw.server.port,
445        control_port: raw.server.control_port,
446        log_level: raw.server.log_level,
447        upstream_url,
448        remote_key: raw.server.remote_key,
449        relay_url,
450        custom_domain: raw.server.custom_domain,
451        sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
452        expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
453        routing_strategy: raw.server.routing_strategy.as_deref()
454            .and_then(RoutingStrategy::from_str)
455            .unwrap_or_default(),
456        request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
457        rate_limit_rpm: raw.server.rate_limit_rpm.unwrap_or(0),
458        trust_proxy_headers: raw.server.trust_proxy_headers.unwrap_or(false),
459        telemetry_url,
460        telemetry_token,
461        instance_name,
462    };
463
464    if raw.accounts.is_empty() {
465        bail!("Config has no accounts. Run `shunt setup` to add one.");
466    }
467
468    let store = CredentialsStore::load();
469
470    // primary_provider_derived was already computed above for the server URL derivation.
471    let primary_provider = primary_provider_derived;
472
473    let mut accounts = Vec::new();
474    for a in &raw.accounts {
475        let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();
476
477        // Resolve credential.
478        //
479        // OAuth providers (Anthropic, OpenAI): credentials.json first, then
480        // auto-import from the provider's local CLI tool.
481        //
482        // API-key providers: credentials.json first, then inline api_key field,
483        // then api_key_env field, then the provider's well-known env var.
484        let cred: Option<Credential> = store.accounts.get(&a.name).cloned()
485            .or_else(|| {
486                // Inline api_key from TOML (less secure, but convenient for testing).
487                a.api_key.as_deref().map(|k| {
488                    tracing::warn!(account = %a.name, "Inline api_key in config.toml is insecure — use api_key_env instead");
489                    Credential::Apikey { key: k.to_owned() }
490                })
491            })
492            .or_else(|| {
493                // api_key_env: name of env var holding the key.
494                a.api_key_env.as_deref()
495                    .and_then(|var| std::env::var(var).ok())
496                    .map(|k| Credential::Apikey { key: k })
497            })
498            .or_else(|| {
499                // Auto-import from provider's CLI tool (OAuth providers) or
500                // well-known env var (API-key providers).
501                provider.read_local_credentials()
502            });
503
504        // Upstream URL: per-account override from TOML takes priority, then
505        // non-primary-provider accounts get the provider's default URL so
506        // the forwarder knows where to send requests.
507        let is_local = matches!(provider, Provider::Local);
508        if let Some(ref url) = a.upstream_url {
509            // #6 SSRF: allow loopback only for Local provider (e.g. Ollama at localhost).
510            validate_upstream_url(url, is_local)
511                .with_context(|| format!("account '{}' upstream_url failed validation", a.name))?;
512        }
513        let acct_upstream = a.upstream_url.clone().or_else(|| {
514            if provider != primary_provider {
515                Some(provider.default_upstream_url().to_owned())
516            } else {
517                None
518            }
519        });
520
521        accounts.push(AccountConfig {
522            name: a.name.clone(),
523            plan_type: a.plan_type.clone(),
524            provider,
525            credential: cred,
526            upstream_url: acct_upstream,
527            model: a.model.clone(),
528        });
529    }
530
531    Ok(Config { server, accounts, config_file: p, model_mapping: raw.model_mapping })
532}
533
534// ---------------------------------------------------------------------------
535// Config file template
536// ---------------------------------------------------------------------------
537
/// Render a starter `config.toml`.
///
/// The output contains a `[server]` table with the default host, ports and
/// log level, followed by one `[[accounts]]` table per `(name, plan_type)`
/// pair, in order. Values are emitted as escaped TOML basic strings, so names
/// containing quotes, backslashes or control characters still produce a valid
/// file (and cannot inject extra keys).
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\ncontrol_port = 19081\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        out.push_str("\n[[accounts]]\nname = ");
        out.push_str(&toml_basic_string(name));
        out.push_str("\nplan_type = ");
        out.push_str(&toml_basic_string(plan_type));
        out.push('\n');
    }
    out
}

/// Quote `s` as a TOML basic string, escaping `"`, `\` and control characters.
fn toml_basic_string(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    for c in s.chars() {
        match c {
            '"' => quoted.push_str("\\\""),
            '\\' => quoted.push_str("\\\\"),
            '\n' => quoted.push_str("\\n"),
            '\r' => quoted.push_str("\\r"),
            '\t' => quoted.push_str("\\t"),
            // TOML forbids raw control characters in basic strings.
            c if c.is_control() => quoted.push_str(&format!("\\u{:04X}", c as u32)),
            c => quoted.push(c),
        }
    }
    quoted.push('"');
    quoted
}