1use anyhow::{bail, Context, Result};
2
3fn validate_upstream_url(url: &str, allow_loopback: bool) -> Result<()> {
7 let parsed = url::Url::parse(url)
8 .with_context(|| format!("Invalid upstream URL: {url}"))?;
9 match parsed.scheme() {
10 "http" | "https" => {}
11 s => bail!("Upstream URL must use http or https, got scheme '{s}': {url}"),
12 }
13 if !allow_loopback {
14 if let Some(host) = parsed.host_str() {
15 let blocked = matches!(host, "localhost" | "127.0.0.1" | "::1" | "[::1]")
16 || host.starts_with("169.254.")
17 || host.starts_with("fd");
18 if blocked {
19 bail!("Upstream URL must not point to loopback or link-local addresses: {url}");
20 }
21 }
22 }
23 Ok(())
24}
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::path::{Path, PathBuf};
28
29use crate::credential::{deserialize_credential_map, Credential};
30use crate::provider::Provider;
31
32pub const APP_NAME: &str = "shunt";
33
34pub fn config_path() -> PathBuf {
35 dirs::config_dir()
36 .unwrap_or_else(|| PathBuf::from("."))
37 .join(APP_NAME)
38 .join("config.toml")
39}
40
41pub fn credentials_path() -> PathBuf {
42 dirs::config_dir()
43 .unwrap_or_else(|| PathBuf::from("."))
44 .join(APP_NAME)
45 .join("credentials.json")
46}
47
48pub fn state_path() -> PathBuf {
49 dirs::data_local_dir()
50 .unwrap_or_else(|| PathBuf::from("."))
51 .join(APP_NAME)
52 .join("state.json")
53}
54
55pub fn log_path() -> PathBuf {
56 dirs::data_local_dir()
57 .unwrap_or_else(|| PathBuf::from("."))
58 .join(APP_NAME)
59 .join("proxy.log")
60}
61
62pub fn notify_log_path() -> PathBuf {
63 dirs::data_local_dir()
64 .unwrap_or_else(|| PathBuf::from("."))
65 .join(APP_NAME)
66 .join("notify.log")
67}
68
69pub fn pid_path() -> PathBuf {
70 dirs::data_local_dir()
71 .unwrap_or_else(|| PathBuf::from("."))
72 .join(APP_NAME)
73 .join("shunt.pid")
74}
75
76#[derive(Debug, Default, Serialize, Deserialize)]
81pub struct CredentialsStore {
82 #[serde(deserialize_with = "deserialize_credential_map", default)]
83 pub accounts: HashMap<String, Credential>,
84}
85
86impl CredentialsStore {
87 pub fn load() -> Self {
88 let p = credentials_path();
89 if !p.exists() {
90 return Self::default();
91 }
92 match std::fs::read_to_string(&p) {
93 Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
94 Err(_) => Self::default(),
95 }
96 }
97
98 pub fn save(&self) -> Result<()> {
99 let p = credentials_path();
100 if let Some(parent) = p.parent() {
101 std::fs::create_dir_all(parent)?;
102 }
103 let tmp = p.with_extension("tmp");
104 std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
105 #[cfg(unix)]
106 {
107 use std::os::unix::fs::PermissionsExt;
108 std::fs::set_permissions(&tmp, std::fs::Permissions::from_mode(0o600))?;
109 }
110 std::fs::rename(&tmp, &p)?;
111 #[cfg(windows)]
113 {
114 if let Some(path_str) = p.to_str() {
115 let username = std::env::var("USERNAME").unwrap_or_default();
116 if !username.is_empty() {
117 let _ = std::process::Command::new("icacls")
118 .arg(path_str)
119 .arg("/inheritance:r")
120 .arg("/grant:r")
121 .arg(format!("{username}:F"))
122 .status();
123 }
124 }
125 }
126 Ok(())
127 }
128}
129
130#[derive(Debug, Deserialize)]
135struct RawConfig {
136 #[serde(default)]
137 server: RawServer,
138 #[serde(default)]
139 accounts: Vec<RawAccount>,
140 #[serde(default)]
143 model_mapping: HashMap<String, String>,
144}
145
146#[derive(Debug, Deserialize)]
147struct RawServer {
148 #[serde(default = "default_host")]
149 host: String,
150 #[serde(default = "default_port")]
151 port: u16,
152 #[serde(default = "default_control_port")]
153 control_port: u16,
154 #[serde(default = "default_log_level")]
155 log_level: String,
156 upstream_url: Option<String>,
157 remote_key: Option<String>,
158 relay_url: Option<String>,
159 pub custom_domain: Option<String>,
160 sticky_ttl_minutes: Option<u64>,
162 expiry_soon_minutes: Option<u64>,
164 routing_strategy: Option<String>,
166 request_timeout_secs: Option<u64>,
168 rate_limit_rpm: Option<u32>,
170 trust_proxy_headers: Option<bool>,
174 health_check_enabled: Option<bool>,
176 health_check_interval_secs: Option<u64>,
178 health_check_timeout_secs: Option<u64>,
180 telemetry_url: Option<String>,
183 telemetry_token: Option<String>,
185 instance_name: Option<String>,
188 burst_rpm_limit: Option<u32>,
191 fallback_model: Option<String>,
194}
195
196impl Default for RawServer {
197 fn default() -> Self {
198 Self {
199 host: default_host(),
200 port: default_port(),
201 control_port: default_control_port(),
202 log_level: default_log_level(),
203 upstream_url: None,
204 remote_key: None,
205 relay_url: None,
206 custom_domain: None,
207 sticky_ttl_minutes: None,
208 expiry_soon_minutes: None,
209 routing_strategy: None,
210 request_timeout_secs: None,
211 rate_limit_rpm: None,
212 trust_proxy_headers: None,
213 health_check_enabled: None,
214 health_check_interval_secs: None,
215 health_check_timeout_secs: None,
216 telemetry_url: None,
217 telemetry_token: None,
218 instance_name: None,
219 burst_rpm_limit: None,
220 fallback_model: None,
221 }
222 }
223}
224
225#[derive(Debug, Deserialize)]
226struct RawAccount {
227 name: String,
228 #[serde(default = "default_plan_type")]
229 plan_type: String,
230 #[serde(default)]
232 provider: Option<String>,
233 #[serde(default)]
235 api_key: Option<String>,
236 #[serde(default)]
238 api_key_env: Option<String>,
239 #[serde(default)]
241 upstream_url: Option<String>,
242 #[serde(default)]
245 model: Option<String>,
246}
247
/// Serde default for `server.host`.
fn default_host() -> String {
    String::from("127.0.0.1")
}
249
250pub fn default_instance_name() -> String {
251 hostname::get()
252 .ok()
253 .and_then(|h| h.into_string().ok())
254 .unwrap_or_else(|| "shunt".into())
255}
/// Serde default for `server.port`.
fn default_port() -> u16 {
    8082
}

/// Serde default for `server.control_port`.
fn default_control_port() -> u16 {
    19081
}

/// Serde default for `server.log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `[[accounts]].plan_type`.
fn default_plan_type() -> String {
    String::from("pro")
}
260
/// Account-selection strategy, configured by `server.routing_strategy`.
///
/// [`RoutingStrategy::Maximus`] is the default. NOTE(review): each strategy's
/// selection behaviour lives in the router, which is not visible here. The
/// aliases hint at it (earliest expiry, round robin, most available); confirm
/// there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RoutingStrategy {
    /// `reaper`; aliases `earliest-expiry` and `earliest_expiry`.
    Reaper,
    /// `carousel`; aliases `round-robin` and `round_robin`.
    Carousel,
    /// `cushion`; aliases `most-available` and `most_available`.
    Cushion,
    /// `maximus`; the default strategy.
    #[default]
    Maximus,
}

/// Error returned by `RoutingStrategy`'s [`std::str::FromStr`] impl for an
/// unrecognized name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseRoutingStrategyError {
    input: String,
}

impl std::fmt::Display for ParseRoutingStrategyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "unknown routing strategy '{}' (expected one of: reaper, carousel, cushion, maximus)",
            self.input
        )
    }
}

impl std::error::Error for ParseRoutingStrategyError {}

impl RoutingStrategy {
    /// Every accepted spelling mapped to its strategy.
    const NAMES: &'static [(&'static str, Self)] = &[
        ("reaper", Self::Reaper),
        ("earliest-expiry", Self::Reaper),
        ("earliest_expiry", Self::Reaper),
        ("carousel", Self::Carousel),
        ("round-robin", Self::Carousel),
        ("round_robin", Self::Carousel),
        ("cushion", Self::Cushion),
        ("most-available", Self::Cushion),
        ("most_available", Self::Cushion),
        ("maximus", Self::Maximus),
    ];

    /// Canonical lowercase config name. Parsing it returns the same variant.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Reaper => "reaper",
            Self::Carousel => "carousel",
            Self::Cushion => "cushion",
            Self::Maximus => "maximus",
        }
    }

    /// Parse a strategy name or alias, ignoring ASCII case and surrounding
    /// whitespace. Returns `None` for unrecognized input.
    ///
    /// This `Option`-returning form is kept for existing callers; the
    /// [`std::str::FromStr`] impl wraps it with a descriptive error.
    #[allow(clippy::should_implement_trait)] // Option-returning API kept for backward compatibility.
    pub fn from_str(s: &str) -> Option<Self> {
        Self::lookup(s)
    }

    /// Find `s` in the alias table, ignoring ASCII case and surrounding whitespace.
    fn lookup(s: &str) -> Option<Self> {
        let wanted = s.trim();
        Self::NAMES
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(wanted))
            .map(|&(_, strategy)| strategy)
    }
}

impl std::fmt::Display for RoutingStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

impl std::str::FromStr for RoutingStrategy {
    type Err = ParseRoutingStrategyError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::lookup(s).ok_or_else(|| ParseRoutingStrategyError { input: s.to_owned() })
    }
}
311
/// Resolved `[server]` settings, as produced by `load_config`.
///
/// Each field notes its `config.toml` key, any `SHUNT_*` environment fallback,
/// and the default `load_config` applies when both are absent.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// `server.host` (default `"127.0.0.1"`).
    pub host: String,
    /// `server.port` (default `8082`).
    pub port: u16,
    /// `server.control_port` (default `19081`).
    pub control_port: u16,
    /// `server.log_level` (default `"info"`).
    pub log_level: String,
    /// `server.upstream_url`, else `SHUNT_UPSTREAM_URL`, else the first
    /// account's provider default. Checked by `validate_upstream_url`.
    pub upstream_url: String,
    /// `server.remote_key`; no default.
    pub remote_key: Option<String>,
    /// `server.relay_url`, else `SHUNT_RELAY_URL`, else
    /// `https://relay.ramcharan.shop`.
    pub relay_url: String,
    /// `server.custom_domain`; no default.
    pub custom_domain: Option<String>,
    /// `server.sticky_ttl_minutes` converted to milliseconds (default 10 minutes).
    pub sticky_ttl_ms: u64,
    /// `server.expiry_soon_minutes` converted to seconds (default 30 minutes).
    pub expiry_soon_secs: u64,
    /// `server.routing_strategy`; absent or unrecognized values fall back to
    /// `RoutingStrategy::default()`.
    pub routing_strategy: RoutingStrategy,
    /// `server.request_timeout_secs` (default 600).
    pub request_timeout_secs: u64,
    /// `server.rate_limit_rpm` (default 0). NOTE(review): presumably 0 means
    /// unlimited; confirm in the rate limiter.
    pub rate_limit_rpm: u32,
    /// `server.trust_proxy_headers` (default `false`).
    pub trust_proxy_headers: bool,
    /// `server.health_check_enabled` (default `true`).
    pub health_check_enabled: bool,
    /// `server.health_check_interval_secs` (default 300).
    pub health_check_interval_secs: u64,
    /// `server.health_check_timeout_secs` (default 10).
    pub health_check_timeout_secs: u64,
    /// `server.telemetry_url`, else `SHUNT_TELEMETRY_URL`; `None` if neither is set.
    pub telemetry_url: Option<String>,
    /// `server.telemetry_token`, else `SHUNT_TELEMETRY_TOKEN`; `None` if neither is set.
    pub telemetry_token: Option<String>,
    /// `server.instance_name`, else `SHUNT_INSTANCE_NAME`, else
    /// `default_instance_name()`.
    pub instance_name: String,
    /// `server.burst_rpm_limit` (default 10).
    pub burst_rpm_limit: u32,
    /// `server.fallback_model`; no default.
    pub fallback_model: Option<String>,
}
355
356impl Default for ServerConfig {
357 fn default() -> Self {
358 Self {
359 host: "127.0.0.1".into(),
360 port: 8082,
361 control_port: 19081,
362 log_level: "info".into(),
363 upstream_url: "https://api.anthropic.com".into(),
364 remote_key: None,
365 relay_url: "https://relay.ramcharan.shop".into(),
366 custom_domain: None,
367 sticky_ttl_ms: 10 * 60 * 1000,
368 expiry_soon_secs: 30 * 60,
369 routing_strategy: RoutingStrategy::Maximus,
370 request_timeout_secs: 600,
371 rate_limit_rpm: 0,
372 trust_proxy_headers: false,
373 health_check_enabled: true,
374 health_check_interval_secs: 300,
375 health_check_timeout_secs: 10,
376 telemetry_url: None,
377 telemetry_token: None,
378 instance_name: default_instance_name(),
379 burst_rpm_limit: 10,
380 fallback_model: None,
381 }
382 }
383}
384
/// One resolved `[[accounts]]` entry, as produced by `load_config`.
#[derive(Debug, Clone)]
pub struct AccountConfig {
    /// `[[accounts]].name`; also the key used to look up the credentials store.
    pub name: String,
    /// `[[accounts]].plan_type` (default `"pro"`).
    pub plan_type: String,
    /// Parsed from `[[accounts]].provider`; `Provider::default()` when absent.
    pub provider: Provider,
    /// First credential found, in order: the credentials store entry for
    /// `name`, the inline `api_key`, the variable named by `api_key_env`, and
    /// the provider's local credentials. `None` if none of these yield one.
    pub credential: Option<Credential>,
    /// The account's explicit `upstream_url`. If that is absent and the
    /// provider differs from the first account's, the provider's default URL
    /// is used. Otherwise `None`.
    /// NOTE(review): `None` presumably means "use `server.upstream_url`";
    /// confirm in the router.
    pub upstream_url: Option<String>,
    /// `[[accounts]].model`, passed through unchanged.
    pub model: Option<String>,
}
403
/// Fully resolved configuration returned by `load_config`.
#[derive(Debug, Clone)]
pub struct Config {
    /// Resolved `[server]` settings.
    pub server: ServerConfig,
    /// Resolved accounts, in the order they appear in `config.toml`.
    pub accounts: Vec<AccountConfig>,
    /// Path of the config file these settings were loaded from.
    pub config_file: PathBuf,
    /// The `[model_mapping]` table, passed through unchanged.
    pub model_mapping: HashMap<String, String>,
}
413
414pub fn load_config(path: Option<&Path>) -> Result<Config> {
419 let p = path.map(PathBuf::from).unwrap_or_else(config_path);
420
421 if !p.exists() {
422 bail!(
423 "Config not found: {}\nRun `shunt setup` to get started.",
424 p.display()
425 );
426 }
427
428 let raw_text = std::fs::read_to_string(&p)
429 .with_context(|| format!("Failed to read config: {}", p.display()))?;
430
431 let raw: RawConfig = toml::from_str(&raw_text)
432 .with_context(|| format!("Failed to parse config: {}", p.display()))?;
433
434 let primary_provider_derived = raw.accounts.first()
438 .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
439 .unwrap_or_default();
440 let default_upstream = primary_provider_derived.default_upstream_url().to_owned();
441
442 let upstream_url = raw
443 .server
444 .upstream_url
445 .clone()
446 .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
447 .unwrap_or(default_upstream);
448
449 let relay_url = raw
450 .server
451 .relay_url
452 .clone()
453 .or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
454 .unwrap_or_else(|| "https://relay.ramcharan.shop".into());
455
456 let telemetry_url = raw.server.telemetry_url.clone()
457 .or_else(|| std::env::var("SHUNT_TELEMETRY_URL").ok());
458 let telemetry_token = raw.server.telemetry_token.clone()
459 .or_else(|| std::env::var("SHUNT_TELEMETRY_TOKEN").ok());
460 let instance_name = raw.server.instance_name.clone()
461 .or_else(|| std::env::var("SHUNT_INSTANCE_NAME").ok())
462 .unwrap_or_else(default_instance_name);
463
464 let server_url_is_local_derived = raw.server.upstream_url.is_none()
469 && std::env::var("SHUNT_UPSTREAM_URL").is_err()
470 && matches!(primary_provider_derived, Provider::Local);
471 validate_upstream_url(&upstream_url, server_url_is_local_derived)
472 .with_context(|| "server.upstream_url failed validation")?;
473
474 let server = ServerConfig {
475 host: raw.server.host,
476 port: raw.server.port,
477 control_port: raw.server.control_port,
478 log_level: raw.server.log_level,
479 upstream_url,
480 remote_key: raw.server.remote_key,
481 relay_url,
482 custom_domain: raw.server.custom_domain,
483 sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
484 expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
485 routing_strategy: raw.server.routing_strategy.as_deref()
486 .and_then(RoutingStrategy::from_str)
487 .unwrap_or_default(),
488 request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
489 rate_limit_rpm: raw.server.rate_limit_rpm.unwrap_or(0),
490 trust_proxy_headers: raw.server.trust_proxy_headers.unwrap_or(false),
491 health_check_enabled: raw.server.health_check_enabled.unwrap_or(true),
492 health_check_interval_secs: raw.server.health_check_interval_secs.unwrap_or(300),
493 health_check_timeout_secs: raw.server.health_check_timeout_secs.unwrap_or(10),
494 telemetry_url,
495 telemetry_token,
496 instance_name,
497 burst_rpm_limit: raw.server.burst_rpm_limit.unwrap_or(10),
498 fallback_model: raw.server.fallback_model,
499 };
500
501 if raw.accounts.is_empty() {
502 bail!("Config has no accounts. Run `shunt setup` to add one.");
503 }
504
505 let store = CredentialsStore::load();
506
507 let primary_provider = primary_provider_derived;
509
510 let mut accounts = Vec::new();
511 for a in &raw.accounts {
512 let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();
513
514 let cred: Option<Credential> = store.accounts.get(&a.name).cloned()
522 .or_else(|| {
523 a.api_key.as_deref().map(|k| {
525 tracing::warn!(account = %a.name, "Inline api_key in config.toml is insecure — use api_key_env instead");
526 Credential::Apikey { key: k.to_owned() }
527 })
528 })
529 .or_else(|| {
530 a.api_key_env.as_deref()
532 .and_then(|var| std::env::var(var).ok())
533 .map(|k| Credential::Apikey { key: k })
534 })
535 .or_else(|| {
536 provider.read_local_credentials()
539 });
540
541 let is_local = matches!(provider, Provider::Local);
545 if let Some(ref url) = a.upstream_url {
546 validate_upstream_url(url, is_local)
548 .with_context(|| format!("account '{}' upstream_url failed validation", a.name))?;
549 }
550 let acct_upstream = a.upstream_url.clone().or_else(|| {
551 if provider != primary_provider {
552 Some(provider.default_upstream_url().to_owned())
553 } else {
554 None
555 }
556 });
557
558 accounts.push(AccountConfig {
559 name: a.name.clone(),
560 plan_type: a.plan_type.clone(),
561 provider,
562 credential: cred,
563 upstream_url: acct_upstream,
564 model: a.model.clone(),
565 });
566 }
567
568 Ok(Config { server, accounts, config_file: p, model_mapping: raw.model_mapping })
569}
570
/// Render a starter `config.toml`.
///
/// The output contains the default `[server]` table followed by one
/// `[[accounts]]` entry per `(name, plan_type)` pair. Both values are written
/// as escaped TOML basic strings, so names containing quotes, backslashes, or
/// control characters still produce a parseable file.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\ncontrol_port = 19081\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        out.push_str("\n[[accounts]]\nname = ");
        push_toml_basic_string(&mut out, name);
        out.push_str("\nplan_type = ");
        push_toml_basic_string(&mut out, plan_type);
        out.push('\n');
    }
    out
}

/// Append `value` to `out` as a double-quoted TOML basic string.
///
/// Escapes `"`, `\`, and control characters as the TOML spec requires.
fn push_toml_basic_string(out: &mut String, value: &str) {
    out.push('"');
    for ch in value.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            '\u{8}' => out.push_str("\\b"),
            '\u{c}' => out.push_str("\\f"),
            c if c.is_control() => out.push_str(&format!("\\u{:04X}", u32::from(c))),
            c => out.push(c),
        }
    }
    out.push('"');
}