1use anyhow::{bail, Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::credential::{deserialize_credential_map, Credential};
7use crate::oauth::OAuthCredential;
8use crate::provider::Provider;
9
/// Application name; used as the per-application subdirectory name under the platform
/// config and local-data directories by the `*_path` helpers below.
pub const APP_NAME: &str = "shunt";
11
12pub fn config_path() -> PathBuf {
13 dirs::config_dir()
14 .unwrap_or_else(|| PathBuf::from("."))
15 .join(APP_NAME)
16 .join("config.toml")
17}
18
19pub fn credentials_path() -> PathBuf {
20 dirs::config_dir()
21 .unwrap_or_else(|| PathBuf::from("."))
22 .join(APP_NAME)
23 .join("credentials.json")
24}
25
26pub fn state_path() -> PathBuf {
27 dirs::data_local_dir()
28 .unwrap_or_else(|| PathBuf::from("."))
29 .join(APP_NAME)
30 .join("state.json")
31}
32
33pub fn log_path() -> PathBuf {
34 dirs::data_local_dir()
35 .unwrap_or_else(|| PathBuf::from("."))
36 .join(APP_NAME)
37 .join("proxy.log")
38}
39
40pub fn pid_path() -> PathBuf {
41 dirs::data_local_dir()
42 .unwrap_or_else(|| PathBuf::from("."))
43 .join(APP_NAME)
44 .join("shunt.pid")
45}
46
47#[derive(Debug, Default, Serialize, Deserialize)]
52pub struct CredentialsStore {
53 #[serde(deserialize_with = "deserialize_credential_map", default)]
54 pub accounts: HashMap<String, Credential>,
55}
56
57impl CredentialsStore {
58 pub fn load() -> Self {
59 let p = credentials_path();
60 if !p.exists() {
61 return Self::default();
62 }
63 match std::fs::read_to_string(&p) {
64 Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
65 Err(_) => Self::default(),
66 }
67 }
68
69 pub fn save(&self) -> Result<()> {
70 let p = credentials_path();
71 if let Some(parent) = p.parent() {
72 std::fs::create_dir_all(parent)?;
73 }
74 let tmp = p.with_extension("tmp");
75 std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
76 std::fs::rename(&tmp, &p)?;
77 #[cfg(unix)]
78 {
79 use std::os::unix::fs::PermissionsExt;
80 std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o600))?;
81 }
82 #[cfg(windows)]
84 {
85 if let Some(path_str) = p.to_str() {
86 let username = std::env::var("USERNAME").unwrap_or_default();
87 if !username.is_empty() {
88 let _ = std::process::Command::new("icacls")
89 .arg(path_str)
90 .arg("/inheritance:r")
91 .arg("/grant:r")
92 .arg(format!("{username}:F"))
93 .status();
94 }
95 }
96 }
97 Ok(())
98 }
99}
100
101#[derive(Debug, Deserialize)]
106struct RawConfig {
107 #[serde(default)]
108 server: RawServer,
109 #[serde(default)]
110 accounts: Vec<RawAccount>,
111}
112
113#[derive(Debug, Deserialize)]
114struct RawServer {
115 #[serde(default = "default_host")]
116 host: String,
117 #[serde(default = "default_port")]
118 port: u16,
119 #[serde(default = "default_control_port")]
120 control_port: u16,
121 #[serde(default = "default_log_level")]
122 log_level: String,
123 upstream_url: Option<String>,
124 remote_key: Option<String>,
125 relay_url: Option<String>,
126 pub custom_domain: Option<String>,
127 sticky_ttl_minutes: Option<u64>,
129 expiry_soon_minutes: Option<u64>,
131 request_timeout_secs: Option<u64>,
133}
134
135impl Default for RawServer {
136 fn default() -> Self {
137 Self {
138 host: default_host(),
139 port: default_port(),
140 control_port: default_control_port(),
141 log_level: default_log_level(),
142 upstream_url: None,
143 remote_key: None,
144 relay_url: None,
145 custom_domain: None,
146 sticky_ttl_minutes: None,
147 expiry_soon_minutes: None,
148 request_timeout_secs: None,
149 }
150 }
151}
152
153#[derive(Debug, Deserialize)]
154struct RawAccount {
155 name: String,
156 #[serde(default = "default_plan_type")]
157 plan_type: String,
158 #[serde(default)]
160 provider: Option<String>,
161 #[serde(default)]
163 api_key: Option<String>,
164 #[serde(default)]
166 api_key_env: Option<String>,
167 #[serde(default)]
169 upstream_url: Option<String>,
170}
171
/// Serde default for `[server] host`: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Serde default for `[server] port`.
fn default_port() -> u16 {
    8082
}

/// Serde default for `[server] control_port`.
fn default_control_port() -> u16 {
    19081
}

/// Serde default for `[server] log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `[[accounts]] plan_type`.
fn default_plan_type() -> String {
    String::from("pro")
}
177
/// Resolved server settings, with all defaults and environment overrides applied.
#[derive(Clone, Debug)]
pub struct ServerConfig {
    pub host: String,
    pub port: u16,
    pub control_port: u16,
    pub log_level: String,
    /// Upstream base URL: the config value, else `SHUNT_UPSTREAM_URL`, else the first
    /// account's provider default (as resolved by `load_config`).
    pub upstream_url: String,
    pub remote_key: Option<String>,
    /// Relay URL: the config value, else `SHUNT_RELAY_URL`, else the built-in relay.
    pub relay_url: String,
    pub custom_domain: Option<String>,
    /// Milliseconds; converted from `sticky_ttl_minutes`.
    pub sticky_ttl_ms: u64,
    /// Seconds; converted from `expiry_soon_minutes`.
    pub expiry_soon_secs: u64,
    /// Seconds.
    pub request_timeout_secs: u64,
}
203
204impl Default for ServerConfig {
205 fn default() -> Self {
206 Self {
207 host: "127.0.0.1".into(),
208 port: 8082,
209 control_port: 19081,
210 log_level: "info".into(),
211 upstream_url: "https://api.anthropic.com".into(),
212 remote_key: None,
213 relay_url: "https://relay.ramcharan.shop".into(),
214 custom_domain: None,
215 sticky_ttl_ms: 10 * 60 * 1000,
216 expiry_soon_secs: 30 * 60,
217 request_timeout_secs: 600,
218 }
219 }
220}
221
222#[derive(Debug, Clone)]
223pub struct AccountConfig {
224 pub name: String,
225 pub plan_type: String,
226 pub provider: Provider,
227 pub credential: Option<Credential>,
232 pub upstream_url: Option<String>,
236}
237
238#[derive(Debug, Clone)]
239pub struct Config {
240 pub server: ServerConfig,
241 pub accounts: Vec<AccountConfig>,
242 pub config_file: PathBuf,
243}
244
245pub fn load_config(path: Option<&Path>) -> Result<Config> {
250 let p = path.map(PathBuf::from).unwrap_or_else(config_path);
251
252 if !p.exists() {
253 bail!(
254 "Config not found: {}\nRun `shunt setup` to get started.",
255 p.display()
256 );
257 }
258
259 let raw_text = std::fs::read_to_string(&p)
260 .with_context(|| format!("Failed to read config: {}", p.display()))?;
261
262 let raw: RawConfig = toml::from_str(&raw_text)
263 .with_context(|| format!("Failed to parse config: {}", p.display()))?;
264
265 let default_upstream = raw.accounts.first()
269 .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
270 .unwrap_or_default()
271 .default_upstream_url()
272 .to_owned();
273
274 let upstream_url = raw
275 .server
276 .upstream_url
277 .clone()
278 .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
279 .unwrap_or(default_upstream);
280
281 let relay_url = raw
282 .server
283 .relay_url
284 .clone()
285 .or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
286 .unwrap_or_else(|| "https://relay.ramcharan.shop".into());
287
288 let server = ServerConfig {
289 host: raw.server.host,
290 port: raw.server.port,
291 control_port: raw.server.control_port,
292 log_level: raw.server.log_level,
293 upstream_url,
294 remote_key: raw.server.remote_key,
295 relay_url,
296 custom_domain: raw.server.custom_domain,
297 sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
298 expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
299 request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
300 };
301
302 if raw.accounts.is_empty() {
303 bail!("Config has no accounts. Run `shunt setup` to add one.");
304 }
305
306 let store = CredentialsStore::load();
307
308 let primary_provider = raw.accounts.first()
311 .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
312 .unwrap_or_default();
313
314 let mut accounts = Vec::new();
315 for a in &raw.accounts {
316 let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();
317
318 let cred: Option<Credential> = store.accounts.get(&a.name).cloned()
326 .or_else(|| {
327 a.api_key.as_deref().map(|k| Credential::Apikey { key: k.to_owned() })
329 })
330 .or_else(|| {
331 a.api_key_env.as_deref()
333 .and_then(|var| std::env::var(var).ok())
334 .map(|k| Credential::Apikey { key: k })
335 })
336 .or_else(|| {
337 provider.read_local_credentials()
340 });
341
342 let acct_upstream = a.upstream_url.clone().or_else(|| {
346 if provider != primary_provider {
347 Some(provider.default_upstream_url().to_owned())
348 } else {
349 None
350 }
351 });
352
353 accounts.push(AccountConfig {
354 name: a.name.clone(),
355 plan_type: a.plan_type.clone(),
356 provider,
357 credential: cred,
358 upstream_url: acct_upstream,
359 });
360 }
361
362 Ok(Config { server, accounts, config_file: p })
363}
364
/// Renders a starter `config.toml`. The output contains the default `[server]` table and
/// one `[[accounts]]` entry per `(name, plan_type)` pair, in order.
///
/// Values are escaped as TOML basic strings. Names or plan types containing quotes,
/// backslashes or control characters therefore still yield a parseable file.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\ncontrol_port = 19081\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        let name = toml_basic_escape(name);
        let plan_type = toml_basic_escape(plan_type);
        out.push_str(&format!(
            "\n[[accounts]]\nname = \"{name}\"\nplan_type = \"{plan_type}\"\n"
        ));
    }
    out
}

/// Escapes `s` for use between the double quotes of a TOML basic string.
fn toml_basic_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            // TOML forbids raw control characters in basic strings; emit a \uXXXX escape.
            c if c.is_control() => out.push_str(&format!("\\u{:04X}", u32::from(c))),
            c => out.push(c),
        }
    }
    out
}