1use anyhow::{bail, Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::oauth::OAuthCredential;
7use crate::provider::Provider;
8
/// Application name; used as the per-app subdirectory under the platform
/// config and local-data directories (see `config_path`, `state_path`, ...).
pub const APP_NAME: &str = "shunt";
10
11pub fn config_path() -> PathBuf {
12 dirs::config_dir()
13 .unwrap_or_else(|| PathBuf::from("."))
14 .join(APP_NAME)
15 .join("config.toml")
16}
17
18pub fn credentials_path() -> PathBuf {
19 dirs::config_dir()
20 .unwrap_or_else(|| PathBuf::from("."))
21 .join(APP_NAME)
22 .join("credentials.json")
23}
24
25pub fn state_path() -> PathBuf {
26 dirs::data_local_dir()
27 .unwrap_or_else(|| PathBuf::from("."))
28 .join(APP_NAME)
29 .join("state.json")
30}
31
32pub fn log_path() -> PathBuf {
33 dirs::data_local_dir()
34 .unwrap_or_else(|| PathBuf::from("."))
35 .join(APP_NAME)
36 .join("proxy.log")
37}
38
39pub fn pid_path() -> PathBuf {
40 dirs::data_local_dir()
41 .unwrap_or_else(|| PathBuf::from("."))
42 .join(APP_NAME)
43 .join("shunt.pid")
44}
45
/// On-disk store of OAuth credentials, keyed by account name.
///
/// Persisted as pretty-printed JSON at `credentials_path()` by `CredentialsStore::save`.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct CredentialsStore {
    /// Credentials per account; `load_config` looks entries up by `[[accounts]].name`.
    pub accounts: HashMap<String, OAuthCredential>,
}
54
55impl CredentialsStore {
56 pub fn load() -> Self {
57 let p = credentials_path();
58 if !p.exists() {
59 return Self::default();
60 }
61 match std::fs::read_to_string(&p) {
62 Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
63 Err(_) => Self::default(),
64 }
65 }
66
67 pub fn save(&self) -> Result<()> {
68 let p = credentials_path();
69 if let Some(parent) = p.parent() {
70 std::fs::create_dir_all(parent)?;
71 }
72 let tmp = p.with_extension("tmp");
73 std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
74 std::fs::rename(&tmp, &p)?;
75 #[cfg(unix)]
76 {
77 use std::os::unix::fs::PermissionsExt;
78 std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o600))?;
79 }
80 #[cfg(windows)]
82 {
83 if let Some(path_str) = p.to_str() {
84 let username = std::env::var("USERNAME").unwrap_or_default();
85 if !username.is_empty() {
86 let _ = std::process::Command::new("icacls")
87 .arg(path_str)
88 .arg("/inheritance:r")
89 .arg("/grant:r")
90 .arg(format!("{username}:F"))
91 .status();
92 }
93 }
94 }
95 Ok(())
96 }
97}
98
/// Raw, as-deserialized shape of the TOML config file, before `load_config`
/// resolves defaults and environment fallbacks into [`Config`].
#[derive(Debug, Deserialize)]
struct RawConfig {
    /// The `[server]` table; `RawServer::default()` when the table is absent.
    #[serde(default)]
    server: RawServer,
    /// The `[[accounts]]` array of tables; empty when absent.
    #[serde(default)]
    accounts: Vec<RawAccount>,
}
110
111#[derive(Debug, Deserialize)]
112struct RawServer {
113 #[serde(default = "default_host")]
114 host: String,
115 #[serde(default = "default_port")]
116 port: u16,
117 #[serde(default = "default_log_level")]
118 log_level: String,
119 upstream_url: Option<String>,
120 remote_key: Option<String>,
121 relay_url: Option<String>,
122 pub custom_domain: Option<String>,
123 sticky_ttl_minutes: Option<u64>,
125 expiry_soon_minutes: Option<u64>,
127 request_timeout_secs: Option<u64>,
129}
130
131impl Default for RawServer {
132 fn default() -> Self {
133 Self {
134 host: default_host(),
135 port: default_port(),
136 log_level: default_log_level(),
137 upstream_url: None,
138 remote_key: None,
139 relay_url: None,
140 custom_domain: None,
141 sticky_ttl_minutes: None,
142 expiry_soon_minutes: None,
143 request_timeout_secs: None,
144 }
145 }
146}
147
/// One raw `[[accounts]]` entry from the config file.
#[derive(Debug, Deserialize)]
struct RawAccount {
    /// Account name; `load_config` uses it as the key into the credentials store.
    name: String,
    /// Plan type string; defaults to `default_plan_type()` (`"pro"`).
    #[serde(default = "default_plan_type")]
    plan_type: String,
    /// Provider identifier, parsed with `Provider::from_str`; `load_config` uses
    /// `Provider::default()` when absent.
    #[serde(default)]
    provider: Option<String>,
}
157
/// Serde default for `server.host`.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Serde default for `server.port`.
fn default_port() -> u16 {
    8082
}

/// Serde default for `server.log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `accounts[].plan_type`.
fn default_plan_type() -> String {
    String::from("pro")
}
162
/// Fully resolved server settings, as produced by `load_config`.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Host address from `server.host`.
    pub host: String,
    /// Port from `server.port`.
    pub port: u16,
    /// Log level string from `server.log_level` (e.g. `"info"`).
    pub log_level: String,
    /// Upstream base URL after config / environment / provider-default resolution.
    pub upstream_url: String,
    /// Optional remote key from `server.remote_key`, passed through unchanged.
    pub remote_key: Option<String>,
    /// Relay base URL after config / environment / built-in-default resolution.
    pub relay_url: String,
    /// Optional custom domain from `server.custom_domain`, passed through unchanged.
    pub custom_domain: Option<String>,
    /// Sticky TTL in milliseconds (configured in minutes as `server.sticky_ttl_minutes`).
    pub sticky_ttl_ms: u64,
    /// "Expiry soon" threshold in seconds (configured in minutes as `server.expiry_soon_minutes`).
    pub expiry_soon_secs: u64,
    /// Request timeout in seconds (`server.request_timeout_secs`).
    pub request_timeout_secs: u64,
}
186
187impl Default for ServerConfig {
188 fn default() -> Self {
189 Self {
190 host: "127.0.0.1".into(),
191 port: 8082,
192 log_level: "info".into(),
193 upstream_url: "https://api.anthropic.com".into(),
194 remote_key: None,
195 relay_url: "https://relay.ramcharan.shop".into(),
196 custom_domain: None,
197 sticky_ttl_ms: 10 * 60 * 1000,
198 expiry_soon_secs: 30 * 60,
199 request_timeout_secs: 600,
200 }
201 }
202}
203
/// One resolved `[[accounts]]` entry.
#[derive(Debug, Clone)]
pub struct AccountConfig {
    /// Account name; also the key into `CredentialsStore::accounts`.
    pub name: String,
    /// Plan type string from the config (`"pro"` when omitted).
    pub plan_type: String,
    /// Provider parsed from the account's `provider` key.
    pub provider: Provider,
    /// Credential from the credentials store, else the provider's local
    /// credentials; `None` when neither is available.
    pub credential: Option<OAuthCredential>,
    /// Per-account upstream URL. `load_config` sets this (to the provider's
    /// default upstream) only when the account's provider differs from the first
    /// account's provider; otherwise `None` — presumably meaning "use
    /// `ServerConfig::upstream_url`" (TODO confirm against consumers).
    pub upstream_url: Option<String>,
}
218
/// Complete runtime configuration returned by `load_config`.
#[derive(Debug, Clone)]
pub struct Config {
    /// Resolved `[server]` settings.
    pub server: ServerConfig,
    /// Resolved accounts, in config-file order.
    pub accounts: Vec<AccountConfig>,
    /// Path of the config file that was loaded.
    pub config_file: PathBuf,
}
225
226pub fn load_config(path: Option<&Path>) -> Result<Config> {
231 let p = path.map(PathBuf::from).unwrap_or_else(config_path);
232
233 if !p.exists() {
234 bail!(
235 "Config not found: {}\nRun `shunt setup` to get started.",
236 p.display()
237 );
238 }
239
240 let raw_text = std::fs::read_to_string(&p)
241 .with_context(|| format!("Failed to read config: {}", p.display()))?;
242
243 let raw: RawConfig = toml::from_str(&raw_text)
244 .with_context(|| format!("Failed to parse config: {}", p.display()))?;
245
246 let default_upstream = raw.accounts.first()
250 .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
251 .unwrap_or_default()
252 .default_upstream_url()
253 .to_owned();
254
255 let upstream_url = raw
256 .server
257 .upstream_url
258 .clone()
259 .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
260 .unwrap_or(default_upstream);
261
262 let relay_url = raw
263 .server
264 .relay_url
265 .clone()
266 .or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
267 .unwrap_or_else(|| "https://relay.ramcharan.shop".into());
268
269 let server = ServerConfig {
270 host: raw.server.host,
271 port: raw.server.port,
272 log_level: raw.server.log_level,
273 upstream_url,
274 remote_key: raw.server.remote_key,
275 relay_url,
276 custom_domain: raw.server.custom_domain,
277 sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
278 expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
279 request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
280 };
281
282 if raw.accounts.is_empty() {
283 bail!("Config has no accounts. Run `shunt setup` to add one.");
284 }
285
286 let store = CredentialsStore::load();
287
288 let primary_provider = raw.accounts.first()
291 .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
292 .unwrap_or_default();
293
294 let mut accounts = Vec::new();
295 for a in &raw.accounts {
296 let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();
297
298 let cred = store
300 .accounts
301 .get(&a.name)
302 .cloned()
303 .or_else(|| provider.read_local_credentials());
304
305 let acct_upstream = if provider != primary_provider {
308 Some(provider.default_upstream_url().to_owned())
309 } else {
310 None
311 };
312
313 accounts.push(AccountConfig {
314 name: a.name.clone(),
315 plan_type: a.plan_type.clone(),
316 provider,
317 credential: cred,
318 upstream_url: acct_upstream,
319 });
320 }
321
322 Ok(Config { server, accounts, config_file: p })
323}
324
/// Renders a starter TOML config: a default `[server]` table followed by one
/// `[[accounts]]` table per `(name, plan_type)` pair, in order.
///
/// Values are emitted as escaped TOML basic strings, so names containing quotes,
/// backslashes, or control characters still produce a parseable file. They also
/// cannot inject extra TOML keys.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        out.push_str("\n[[accounts]]\nname = ");
        push_toml_basic_string(&mut out, name);
        out.push_str("\nplan_type = ");
        push_toml_basic_string(&mut out, plan_type);
        out.push('\n');
    }
    out
}

/// Appends `value` to `out` as a double-quoted TOML basic string, escaping `"`,
/// `\`, and control characters per the TOML v1.0 spec.
fn push_toml_basic_string(out: &mut String, value: &str) {
    out.push('"');
    for ch in value.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            '\u{8}' => out.push_str("\\b"),
            '\u{c}' => out.push_str("\\f"),
            c if c.is_control() => out.push_str(&format!("\\u{:04X}", u32::from(c))),
            c => out.push(c),
        }
    }
    out.push('"');
}