1use anyhow::{bail, Context, Result};
2
3fn validate_upstream_url(url: &str, allow_loopback: bool) -> Result<()> {
7 let parsed = url::Url::parse(url)
8 .with_context(|| format!("Invalid upstream URL: {url}"))?;
9 match parsed.scheme() {
10 "http" | "https" => {}
11 s => bail!("Upstream URL must use http or https, got scheme '{s}': {url}"),
12 }
13 if !allow_loopback {
14 if let Some(host) = parsed.host_str() {
15 let blocked = matches!(host, "localhost" | "127.0.0.1" | "::1" | "[::1]")
16 || host.starts_with("169.254.")
17 || host.starts_with("fd");
18 if blocked {
19 bail!("Upstream URL must not point to loopback or link-local addresses: {url}");
20 }
21 }
22 }
23 Ok(())
24}
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::path::{Path, PathBuf};
28
29use crate::credential::{deserialize_credential_map, Credential};
30use crate::provider::Provider;
31
32pub const APP_NAME: &str = "shunt";
33
34pub fn config_path() -> PathBuf {
35 dirs::config_dir()
36 .unwrap_or_else(|| PathBuf::from("."))
37 .join(APP_NAME)
38 .join("config.toml")
39}
40
41pub fn credentials_path() -> PathBuf {
42 dirs::config_dir()
43 .unwrap_or_else(|| PathBuf::from("."))
44 .join(APP_NAME)
45 .join("credentials.json")
46}
47
48pub fn state_path() -> PathBuf {
49 dirs::data_local_dir()
50 .unwrap_or_else(|| PathBuf::from("."))
51 .join(APP_NAME)
52 .join("state.json")
53}
54
55pub fn log_path() -> PathBuf {
56 dirs::data_local_dir()
57 .unwrap_or_else(|| PathBuf::from("."))
58 .join(APP_NAME)
59 .join("proxy.log")
60}
61
62pub fn notify_log_path() -> PathBuf {
63 dirs::data_local_dir()
64 .unwrap_or_else(|| PathBuf::from("."))
65 .join(APP_NAME)
66 .join("notify.log")
67}
68
69pub fn pid_path() -> PathBuf {
70 dirs::data_local_dir()
71 .unwrap_or_else(|| PathBuf::from("."))
72 .join(APP_NAME)
73 .join("shunt.pid")
74}
75
76#[derive(Debug, Default, Serialize, Deserialize)]
81pub struct CredentialsStore {
82 #[serde(deserialize_with = "deserialize_credential_map", default)]
83 pub accounts: HashMap<String, Credential>,
84}
85
86impl CredentialsStore {
87 pub fn load() -> Self {
88 let p = credentials_path();
89 if !p.exists() {
90 return Self::default();
91 }
92 match std::fs::read_to_string(&p) {
93 Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
94 Err(_) => Self::default(),
95 }
96 }
97
98 pub fn save(&self) -> Result<()> {
99 let p = credentials_path();
100 if let Some(parent) = p.parent() {
101 std::fs::create_dir_all(parent)?;
102 }
103 let tmp = p.with_extension("tmp");
104 std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
105 #[cfg(unix)]
106 {
107 use std::os::unix::fs::PermissionsExt;
108 std::fs::set_permissions(&tmp, std::fs::Permissions::from_mode(0o600))?;
109 }
110 std::fs::rename(&tmp, &p)?;
111 #[cfg(windows)]
113 {
114 if let Some(path_str) = p.to_str() {
115 let username = std::env::var("USERNAME").unwrap_or_default();
116 if !username.is_empty() {
117 let _ = std::process::Command::new("icacls")
118 .arg(path_str)
119 .arg("/inheritance:r")
120 .arg("/grant:r")
121 .arg(format!("{username}:F"))
122 .status();
123 }
124 }
125 }
126 Ok(())
127 }
128}
129
/// Raw, unvalidated shape of `config.toml` as deserialized by `toml`.
///
/// `load_config` turns this into a `Config`, applying environment overrides, defaults and
/// upstream URL validation.
#[derive(Debug, Deserialize)]
struct RawConfig {
    /// The `[server]` table; when absent, `RawServer::default()` is used.
    #[serde(default)]
    server: RawServer,
    /// The `[[accounts]]` entries in file order. `load_config` treats the first entry's
    /// provider as the primary provider.
    #[serde(default)]
    accounts: Vec<RawAccount>,
    /// The `[model_mapping]` table, passed through unchanged into `Config::model_mapping`.
    /// NOTE(review): presumably maps requested model names to upstream model names — confirm.
    #[serde(default)]
    model_mapping: HashMap<String, String>,
}
145
146#[derive(Debug, Deserialize)]
147struct RawServer {
148 #[serde(default = "default_host")]
149 host: String,
150 #[serde(default = "default_port")]
151 port: u16,
152 #[serde(default = "default_control_port")]
153 control_port: u16,
154 #[serde(default = "default_log_level")]
155 log_level: String,
156 upstream_url: Option<String>,
157 remote_key: Option<String>,
158 relay_url: Option<String>,
159 pub custom_domain: Option<String>,
160 sticky_ttl_minutes: Option<u64>,
162 expiry_soon_minutes: Option<u64>,
164 routing_strategy: Option<String>,
166 request_timeout_secs: Option<u64>,
168 rate_limit_rpm: Option<u32>,
170 trust_proxy_headers: Option<bool>,
174 health_check_enabled: Option<bool>,
176 health_check_interval_secs: Option<u64>,
178 health_check_timeout_secs: Option<u64>,
180 telemetry_url: Option<String>,
183 telemetry_token: Option<String>,
185 instance_name: Option<String>,
188}
189
190impl Default for RawServer {
191 fn default() -> Self {
192 Self {
193 host: default_host(),
194 port: default_port(),
195 control_port: default_control_port(),
196 log_level: default_log_level(),
197 upstream_url: None,
198 remote_key: None,
199 relay_url: None,
200 custom_domain: None,
201 sticky_ttl_minutes: None,
202 expiry_soon_minutes: None,
203 routing_strategy: None,
204 request_timeout_secs: None,
205 rate_limit_rpm: None,
206 trust_proxy_headers: None,
207 health_check_enabled: None,
208 health_check_interval_secs: None,
209 health_check_timeout_secs: None,
210 telemetry_url: None,
211 telemetry_token: None,
212 instance_name: None,
213 }
214 }
215}
216
/// One raw `[[accounts]]` entry from `config.toml`.
#[derive(Debug, Deserialize)]
struct RawAccount {
    /// Account name; also the key looked up in `CredentialsStore::accounts`.
    name: String,
    /// Plan label; defaults to `"pro"` when omitted.
    #[serde(default = "default_plan_type")]
    plan_type: String,
    /// Provider name, parsed with `Provider::from_str`; when absent, `Provider::default()` is used.
    #[serde(default)]
    provider: Option<String>,
    /// Inline API key. Used only when the credentials store has no entry for this account;
    /// `load_config` logs a warning because secrets in `config.toml` are insecure.
    #[serde(default)]
    api_key: Option<String>,
    /// Name of an environment variable holding the API key; consulted after `api_key`.
    #[serde(default)]
    api_key_env: Option<String>,
    /// Per-account upstream override; checked by `validate_upstream_url` in `load_config`.
    #[serde(default)]
    upstream_url: Option<String>,
    /// Optional model name, copied into `AccountConfig::model`.
    #[serde(default)]
    model: Option<String>,
}
239
/// Default listen address for `server.host`: IPv4 loopback only.
fn default_host() -> String {
    String::from("127.0.0.1")
}
241
242pub fn default_instance_name() -> String {
243 hostname::get()
244 .ok()
245 .and_then(|h| h.into_string().ok())
246 .unwrap_or_else(|| "shunt".into())
247}
/// Default for `server.port`.
fn default_port() -> u16 {
    8_082
}

/// Default for `server.control_port`.
fn default_control_port() -> u16 {
    19_081
}

/// Default for `server.log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Default for `[[accounts]].plan_type`.
fn default_plan_type() -> String {
    "pro".to_owned()
}
252
/// Strategy used to choose which account serves a request.
///
/// Parsed from `server.routing_strategy` by [`RoutingStrategy::from_str`]. The canonical
/// spelling of each variant is returned by [`RoutingStrategy::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RoutingStrategy {
    /// `"reaper"` (aliases `earliest-expiry`, `earliest_expiry`).
    /// NOTE(review): by its alias, presumably prefers the account expiring soonest — confirm.
    Reaper,
    /// `"carousel"` (aliases `round-robin`, `round_robin`).
    Carousel,
    /// `"cushion"` (aliases `most-available`, `most_available`).
    Cushion,
    /// `"maximus"`; the default when no strategy is configured.
    #[default]
    Maximus,
}

impl RoutingStrategy {
    /// Canonical lowercase name of the strategy; round-trips through [`RoutingStrategy::from_str`].
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Reaper => "reaper",
            Self::Carousel => "carousel",
            Self::Cushion => "cushion",
            Self::Maximus => "maximus",
        }
    }

    /// Parse a strategy name or one of its aliases.
    ///
    /// Matching ignores ASCII case and surrounding whitespace, so hand-edited values such as
    /// `"Round-Robin"` or `" maximus "` are accepted. Returns `None` for unrecognised names.
    // Kept as an inherent `Option`-returning fn (rather than only `FromStr`) because callers use
    // it as a function path, e.g. `.and_then(RoutingStrategy::from_str)`.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        let normalized = s.trim().to_ascii_lowercase();
        match normalized.as_str() {
            "reaper" | "earliest-expiry" | "earliest_expiry" => Some(Self::Reaper),
            "carousel" | "round-robin" | "round_robin" => Some(Self::Carousel),
            "cushion" | "most-available" | "most_available" => Some(Self::Cushion),
            "maximus" => Some(Self::Maximus),
            _ => None,
        }
    }
}
303
/// Resolved server settings produced by `load_config`, after defaults, environment overrides
/// and upstream URL validation have been applied.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Listen address.
    pub host: String,
    /// Listen port.
    pub port: u16,
    /// Control port.
    pub control_port: u16,
    /// Log level string (default `"info"`).
    pub log_level: String,
    /// Upstream base URL; checked by `validate_upstream_url` in `load_config`.
    pub upstream_url: String,
    /// Copied from `server.remote_key`.
    pub remote_key: Option<String>,
    /// Relay URL (config, then `SHUNT_RELAY_URL`, then the built-in relay).
    pub relay_url: String,
    /// Copied from `server.custom_domain`.
    pub custom_domain: Option<String>,
    /// Sticky TTL in milliseconds (the config value is in minutes).
    pub sticky_ttl_ms: u64,
    /// "Expiring soon" threshold in seconds (the config value is in minutes).
    pub expiry_soon_secs: u64,
    /// Account selection strategy.
    pub routing_strategy: RoutingStrategy,
    /// Request timeout in seconds.
    pub request_timeout_secs: u64,
    /// Requests-per-minute limit; default 0.
    /// NOTE(review): 0 presumably means "no limit" — confirm.
    pub rate_limit_rpm: u32,
    /// Whether proxy headers are trusted.
    pub trust_proxy_headers: bool,
    /// Whether health checks run.
    pub health_check_enabled: bool,
    /// Health check interval in seconds.
    pub health_check_interval_secs: u64,
    /// Health check timeout in seconds.
    pub health_check_timeout_secs: u64,
    /// Telemetry endpoint, if configured.
    pub telemetry_url: Option<String>,
    /// Telemetry token, if configured.
    pub telemetry_token: Option<String>,
    /// Instance name (config, then `SHUNT_INSTANCE_NAME`, then hostname).
    pub instance_name: String,
}
343
344impl Default for ServerConfig {
345 fn default() -> Self {
346 Self {
347 host: "127.0.0.1".into(),
348 port: 8082,
349 control_port: 19081,
350 log_level: "info".into(),
351 upstream_url: "https://api.anthropic.com".into(),
352 remote_key: None,
353 relay_url: "https://relay.ramcharan.shop".into(),
354 custom_domain: None,
355 sticky_ttl_ms: 10 * 60 * 1000,
356 expiry_soon_secs: 30 * 60,
357 routing_strategy: RoutingStrategy::Maximus,
358 request_timeout_secs: 600,
359 rate_limit_rpm: 0,
360 trust_proxy_headers: false,
361 health_check_enabled: true,
362 health_check_interval_secs: 300,
363 health_check_timeout_secs: 10,
364 telemetry_url: None,
365 telemetry_token: None,
366 instance_name: default_instance_name(),
367 }
368 }
369}
370
/// A fully resolved account entry produced by `load_config`.
#[derive(Debug, Clone)]
pub struct AccountConfig {
    /// Account name from `[[accounts]].name`.
    pub name: String,
    /// Plan label (defaults to `"pro"`).
    pub plan_type: String,
    /// Provider parsed from `[[accounts]].provider`, or `Provider::default()` when absent.
    pub provider: Provider,
    /// Resolved credential, or `None` if no source supplied one. The sources are tried in
    /// order: credentials store, inline `api_key`, `api_key_env`, provider-local credentials.
    pub credential: Option<Credential>,
    /// The explicit per-account upstream if set. Otherwise the provider's default upstream
    /// when this provider differs from the primary (first) account's provider, else `None`.
    /// NOTE(review): `None` presumably means "use `ServerConfig::upstream_url`" — confirm.
    pub upstream_url: Option<String>,
    /// Optional model name from `[[accounts]].model`.
    pub model: Option<String>,
}
389
/// Complete runtime configuration returned by `load_config`.
#[derive(Debug, Clone)]
pub struct Config {
    /// Resolved server settings.
    pub server: ServerConfig,
    /// Resolved accounts, in the order they appear in the config file.
    pub accounts: Vec<AccountConfig>,
    /// Path the configuration was loaded from.
    pub config_file: PathBuf,
    /// The `[model_mapping]` table, passed through unchanged.
    pub model_mapping: HashMap<String, String>,
}
399
400pub fn load_config(path: Option<&Path>) -> Result<Config> {
405 let p = path.map(PathBuf::from).unwrap_or_else(config_path);
406
407 if !p.exists() {
408 bail!(
409 "Config not found: {}\nRun `shunt setup` to get started.",
410 p.display()
411 );
412 }
413
414 let raw_text = std::fs::read_to_string(&p)
415 .with_context(|| format!("Failed to read config: {}", p.display()))?;
416
417 let raw: RawConfig = toml::from_str(&raw_text)
418 .with_context(|| format!("Failed to parse config: {}", p.display()))?;
419
420 let primary_provider_derived = raw.accounts.first()
424 .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
425 .unwrap_or_default();
426 let default_upstream = primary_provider_derived.default_upstream_url().to_owned();
427
428 let upstream_url = raw
429 .server
430 .upstream_url
431 .clone()
432 .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
433 .unwrap_or(default_upstream);
434
435 let relay_url = raw
436 .server
437 .relay_url
438 .clone()
439 .or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
440 .unwrap_or_else(|| "https://relay.ramcharan.shop".into());
441
442 let telemetry_url = raw.server.telemetry_url.clone()
443 .or_else(|| std::env::var("SHUNT_TELEMETRY_URL").ok());
444 let telemetry_token = raw.server.telemetry_token.clone()
445 .or_else(|| std::env::var("SHUNT_TELEMETRY_TOKEN").ok());
446 let instance_name = raw.server.instance_name.clone()
447 .or_else(|| std::env::var("SHUNT_INSTANCE_NAME").ok())
448 .unwrap_or_else(default_instance_name);
449
450 let server_url_is_local_derived = raw.server.upstream_url.is_none()
455 && std::env::var("SHUNT_UPSTREAM_URL").is_err()
456 && matches!(primary_provider_derived, Provider::Local);
457 validate_upstream_url(&upstream_url, server_url_is_local_derived)
458 .with_context(|| "server.upstream_url failed validation")?;
459
460 let server = ServerConfig {
461 host: raw.server.host,
462 port: raw.server.port,
463 control_port: raw.server.control_port,
464 log_level: raw.server.log_level,
465 upstream_url,
466 remote_key: raw.server.remote_key,
467 relay_url,
468 custom_domain: raw.server.custom_domain,
469 sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
470 expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
471 routing_strategy: raw.server.routing_strategy.as_deref()
472 .and_then(RoutingStrategy::from_str)
473 .unwrap_or_default(),
474 request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
475 rate_limit_rpm: raw.server.rate_limit_rpm.unwrap_or(0),
476 trust_proxy_headers: raw.server.trust_proxy_headers.unwrap_or(false),
477 health_check_enabled: raw.server.health_check_enabled.unwrap_or(true),
478 health_check_interval_secs: raw.server.health_check_interval_secs.unwrap_or(300),
479 health_check_timeout_secs: raw.server.health_check_timeout_secs.unwrap_or(10),
480 telemetry_url,
481 telemetry_token,
482 instance_name,
483 };
484
485 if raw.accounts.is_empty() {
486 bail!("Config has no accounts. Run `shunt setup` to add one.");
487 }
488
489 let store = CredentialsStore::load();
490
491 let primary_provider = primary_provider_derived;
493
494 let mut accounts = Vec::new();
495 for a in &raw.accounts {
496 let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();
497
498 let cred: Option<Credential> = store.accounts.get(&a.name).cloned()
506 .or_else(|| {
507 a.api_key.as_deref().map(|k| {
509 tracing::warn!(account = %a.name, "Inline api_key in config.toml is insecure — use api_key_env instead");
510 Credential::Apikey { key: k.to_owned() }
511 })
512 })
513 .or_else(|| {
514 a.api_key_env.as_deref()
516 .and_then(|var| std::env::var(var).ok())
517 .map(|k| Credential::Apikey { key: k })
518 })
519 .or_else(|| {
520 provider.read_local_credentials()
523 });
524
525 let is_local = matches!(provider, Provider::Local);
529 if let Some(ref url) = a.upstream_url {
530 validate_upstream_url(url, is_local)
532 .with_context(|| format!("account '{}' upstream_url failed validation", a.name))?;
533 }
534 let acct_upstream = a.upstream_url.clone().or_else(|| {
535 if provider != primary_provider {
536 Some(provider.default_upstream_url().to_owned())
537 } else {
538 None
539 }
540 });
541
542 accounts.push(AccountConfig {
543 name: a.name.clone(),
544 plan_type: a.plan_type.clone(),
545 provider,
546 credential: cred,
547 upstream_url: acct_upstream,
548 model: a.model.clone(),
549 });
550 }
551
552 Ok(Config { server, accounts, config_file: p, model_mapping: raw.model_mapping })
553}
554
/// Render a starter `config.toml`.
///
/// The output has default `[server]` settings and one `[[accounts]]` table per
/// `(name, plan_type)` pair, in the given order. Names and plan types are written as escaped
/// TOML basic strings, so values containing quotes, backslashes or control characters still
/// produce a parseable file instead of corrupting (or injecting keys into) the template.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\ncontrol_port = 19081\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        out.push_str(&format!(
            "\n[[accounts]]\nname = {}\nplan_type = {}\n",
            toml_basic_string(name),
            toml_basic_string(plan_type)
        ));
    }
    out
}

/// Quote `value` as a TOML basic string (`"..."`), escaping characters the spec forbids raw.
fn toml_basic_string(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('"');
    for c in value.chars() {
        match c {
            '"' => quoted.push_str("\\\""),
            '\\' => quoted.push_str("\\\\"),
            '\n' => quoted.push_str("\\n"),
            '\r' => quoted.push_str("\\r"),
            '\t' => quoted.push_str("\\t"),
            // Any other control character must use the `\uXXXX` form.
            c if c.is_control() => quoted.push_str(&format!("\\u{:04X}", u32::from(c))),
            c => quoted.push(c),
        }
    }
    quoted.push('"');
    quoted
}