1use anyhow::{bail, Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::oauth::OAuthCredential;
7use crate::provider::Provider;
8
/// Application name; used as the per-application directory name under both the
/// platform config directory and the platform local-data directory.
pub const APP_NAME: &str = "shunt";
10
11pub fn config_path() -> PathBuf {
12 dirs::config_dir()
13 .unwrap_or_else(|| PathBuf::from("."))
14 .join(APP_NAME)
15 .join("config.toml")
16}
17
18pub fn credentials_path() -> PathBuf {
19 dirs::config_dir()
20 .unwrap_or_else(|| PathBuf::from("."))
21 .join(APP_NAME)
22 .join("credentials.json")
23}
24
25pub fn state_path() -> PathBuf {
26 dirs::data_local_dir()
27 .unwrap_or_else(|| PathBuf::from("."))
28 .join(APP_NAME)
29 .join("state.json")
30}
31
32pub fn log_path() -> PathBuf {
33 dirs::data_local_dir()
34 .unwrap_or_else(|| PathBuf::from("."))
35 .join(APP_NAME)
36 .join("proxy.log")
37}
38
39pub fn pid_path() -> PathBuf {
40 dirs::data_local_dir()
41 .unwrap_or_else(|| PathBuf::from("."))
42 .join(APP_NAME)
43 .join("shunt.pid")
44}
45
/// Credentials persisted by shunt, keyed by account name.
///
/// Stored as pretty-printed JSON at `credentials_path()`; see
/// `CredentialsStore::load` / `CredentialsStore::save`.
///
/// NOTE(review): the derived `Debug` formats every credential through
/// `OAuthCredential`'s own `Debug` impl — confirm that impl redacts tokens
/// before this store is ever logged with `{:?}`.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct CredentialsStore {
    /// Saved credential per account; `load_config` looks entries up by the
    /// `name` of each `[[accounts]]` entry.
    pub accounts: HashMap<String, OAuthCredential>,
}
54
55impl CredentialsStore {
56 pub fn load() -> Self {
57 let p = credentials_path();
58 if !p.exists() {
59 return Self::default();
60 }
61 match std::fs::read_to_string(&p) {
62 Ok(text) => serde_json::from_str(&text).unwrap_or_default(),
63 Err(_) => Self::default(),
64 }
65 }
66
67 pub fn save(&self) -> Result<()> {
68 let p = credentials_path();
69 if let Some(parent) = p.parent() {
70 std::fs::create_dir_all(parent)?;
71 }
72 let tmp = p.with_extension("tmp");
73 std::fs::write(&tmp, serde_json::to_string_pretty(self)?)?;
74 std::fs::rename(&tmp, &p)?;
75 #[cfg(unix)]
76 {
77 use std::os::unix::fs::PermissionsExt;
78 std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o600))?;
79 }
80 #[cfg(windows)]
82 {
83 if let Some(path_str) = p.to_str() {
84 let username = std::env::var("USERNAME").unwrap_or_default();
85 if !username.is_empty() {
86 let _ = std::process::Command::new("icacls")
87 .arg(path_str)
88 .arg("/inheritance:r")
89 .arg("/grant:r")
90 .arg(format!("{username}:F"))
91 .status();
92 }
93 }
94 }
95 Ok(())
96 }
97}
98
/// Top-level shape of `config.toml` as deserialized by serde, before defaults
/// and environment fallbacks are resolved into `Config` by `load_config`.
#[derive(Debug, Deserialize)]
struct RawConfig {
    /// The `[server]` table; `RawServer::default()` when the table is absent.
    #[serde(default)]
    server: RawServer,
    /// The `[[accounts]]` entries in file order; empty when none are present
    /// (which `load_config` rejects).
    #[serde(default)]
    accounts: Vec<RawAccount>,
}
110
/// The `[server]` table of `config.toml`.
///
/// `host`, `port` and `log_level` fall back to the `default_*` helpers; the
/// optional keys stay `None` here and are defaulted in `load_config`.
/// Unknown keys are ignored (no `deny_unknown_fields`), so a misspelled key
/// silently falls back to its default.
#[derive(Debug, Deserialize)]
struct RawServer {
    #[serde(default = "default_host")]
    host: String,
    #[serde(default = "default_port")]
    port: u16,
    #[serde(default = "default_log_level")]
    log_level: String,
    // Global upstream URL; `load_config` falls back to `SHUNT_UPSTREAM_URL`,
    // then to the first account's provider default.
    upstream_url: Option<String>,
    // Copied unchanged into `ServerConfig::remote_key`.
    remote_key: Option<String>,
    // Relay URL; `load_config` falls back to `SHUNT_RELAY_URL`, then a built-in URL.
    relay_url: Option<String>,
    // Minutes; `load_config` defaults it to 10 and converts to milliseconds.
    sticky_ttl_minutes: Option<u64>,
    // Minutes; `load_config` defaults it to 30 and converts to seconds.
    expiry_soon_minutes: Option<u64>,
    // Seconds; `load_config` defaults it to 600.
    request_timeout_secs: Option<u64>,
}
129
130impl Default for RawServer {
131 fn default() -> Self {
132 Self {
133 host: default_host(),
134 port: default_port(),
135 log_level: default_log_level(),
136 upstream_url: None,
137 remote_key: None,
138 relay_url: None,
139 sticky_ttl_minutes: None,
140 expiry_soon_minutes: None,
141 request_timeout_secs: None,
142 }
143 }
144}
145
/// One `[[accounts]]` entry of `config.toml`.
#[derive(Debug, Deserialize)]
struct RawAccount {
    /// Account name; also the key looked up in the `CredentialsStore`.
    name: String,
    /// Plan label, `"pro"` when omitted; copied verbatim into `AccountConfig`.
    #[serde(default = "default_plan_type")]
    plan_type: String,
    /// Provider name, parsed with `Provider::from_str` in `load_config`;
    /// when omitted, `Provider::default()` is used.
    #[serde(default)]
    provider: Option<String>,
}
155
// Serde default providers, referenced by `#[serde(default = "...")]` attributes
// on `RawServer` / `RawAccount`.

/// Default `[server].host`: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Default `[server].port`.
fn default_port() -> u16 {
    8082
}

/// Default `[server].log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Default `[[accounts]].plan_type`.
fn default_plan_type() -> String {
    String::from("pro")
}
160
/// Resolved server settings produced by `load_config`.
///
/// `Debug` is implemented by hand so that `remote_key` is never printed; only
/// whether it is set is shown.
#[derive(Clone)]
pub struct ServerConfig {
    /// Host from `[server].host` (default `127.0.0.1`).
    pub host: String,
    /// Port from `[server].port` (default `8082`).
    pub port: u16,
    /// Log level string from `[server].log_level` (default `"info"`).
    pub log_level: String,
    /// Global upstream base URL.
    pub upstream_url: String,
    /// Optional key from `[server].remote_key`; treated as a secret
    /// (NOTE(review): presumably authenticates remote access — confirm).
    pub remote_key: Option<String>,
    /// Relay URL.
    pub relay_url: String,
    /// Sticky-session TTL in milliseconds.
    pub sticky_ttl_ms: u64,
    /// "Expiring soon" threshold in seconds.
    pub expiry_soon_secs: u64,
    /// Request timeout in seconds.
    pub request_timeout_secs: u64,
}

impl std::fmt::Debug for ServerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServerConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("log_level", &self.log_level)
            .field("upstream_url", &self.upstream_url)
            // Never leak the key itself into logs.
            .field("remote_key", &self.remote_key.as_ref().map(|_| "<redacted>"))
            .field("relay_url", &self.relay_url)
            .field("sticky_ttl_ms", &self.sticky_ttl_ms)
            .field("expiry_soon_secs", &self.expiry_soon_secs)
            .field("request_timeout_secs", &self.request_timeout_secs)
            .finish()
    }
}
182
183impl Default for ServerConfig {
184 fn default() -> Self {
185 Self {
186 host: "127.0.0.1".into(),
187 port: 8082,
188 log_level: "info".into(),
189 upstream_url: "https://api.anthropic.com".into(),
190 remote_key: None,
191 relay_url: "https://relay.ramcharan.shop".into(),
192 sticky_ttl_ms: 10 * 60 * 1000,
193 expiry_soon_secs: 30 * 60,
194 request_timeout_secs: 600,
195 }
196 }
197}
198
/// A fully resolved account, built by `load_config` from one `[[accounts]]` entry.
#[derive(Debug, Clone)]
pub struct AccountConfig {
    /// Account name from the config file.
    pub name: String,
    /// Plan label from the config file (`"pro"` when omitted).
    pub plan_type: String,
    /// Provider parsed from the entry's `provider` key (`Provider::default()` when omitted).
    pub provider: Provider,
    /// Credential from the `CredentialsStore`, otherwise the provider's local
    /// credentials; `None` when neither source yields one.
    pub credential: Option<OAuthCredential>,
    /// Per-account upstream override: `load_config` sets it to the provider's
    /// default upstream only when this account's provider differs from the first
    /// account's, and leaves it `None` otherwise.
    /// NOTE(review): presumably callers then fall back to `ServerConfig::upstream_url`.
    pub upstream_url: Option<String>,
}
213
/// Complete resolved configuration returned by `load_config`.
#[derive(Debug, Clone)]
pub struct Config {
    /// Resolved `[server]` settings.
    pub server: ServerConfig,
    /// Resolved accounts in config-file order; never empty when produced by `load_config`.
    pub accounts: Vec<AccountConfig>,
    /// Path of the config file that was loaded.
    pub config_file: PathBuf,
}
220
221pub fn load_config(path: Option<&Path>) -> Result<Config> {
226 let p = path.map(PathBuf::from).unwrap_or_else(config_path);
227
228 if !p.exists() {
229 bail!(
230 "Config not found: {}\nRun `shunt setup` to get started.",
231 p.display()
232 );
233 }
234
235 let raw_text = std::fs::read_to_string(&p)
236 .with_context(|| format!("Failed to read config: {}", p.display()))?;
237
238 let raw: RawConfig = toml::from_str(&raw_text)
239 .with_context(|| format!("Failed to parse config: {}", p.display()))?;
240
241 let default_upstream = raw.accounts.first()
245 .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
246 .unwrap_or_default()
247 .default_upstream_url()
248 .to_owned();
249
250 let upstream_url = raw
251 .server
252 .upstream_url
253 .clone()
254 .or_else(|| std::env::var("SHUNT_UPSTREAM_URL").ok())
255 .unwrap_or(default_upstream);
256
257 let relay_url = raw
258 .server
259 .relay_url
260 .clone()
261 .or_else(|| std::env::var("SHUNT_RELAY_URL").ok())
262 .unwrap_or_else(|| "https://relay.ramcharan.shop".into());
263
264 let server = ServerConfig {
265 host: raw.server.host,
266 port: raw.server.port,
267 log_level: raw.server.log_level,
268 upstream_url,
269 remote_key: raw.server.remote_key,
270 relay_url,
271 sticky_ttl_ms: raw.server.sticky_ttl_minutes.unwrap_or(10) * 60 * 1000,
272 expiry_soon_secs: raw.server.expiry_soon_minutes.unwrap_or(30) * 60,
273 request_timeout_secs: raw.server.request_timeout_secs.unwrap_or(600),
274 };
275
276 if raw.accounts.is_empty() {
277 bail!("Config has no accounts. Run `shunt setup` to add one.");
278 }
279
280 let store = CredentialsStore::load();
281
282 let primary_provider = raw.accounts.first()
285 .map(|a| a.provider.as_deref().map(Provider::from_str).unwrap_or_default())
286 .unwrap_or_default();
287
288 let mut accounts = Vec::new();
289 for a in &raw.accounts {
290 let provider = a.provider.as_deref().map(Provider::from_str).unwrap_or_default();
291
292 let cred = store
294 .accounts
295 .get(&a.name)
296 .cloned()
297 .or_else(|| provider.read_local_credentials());
298
299 let acct_upstream = if provider != primary_provider {
302 Some(provider.default_upstream_url().to_owned())
303 } else {
304 None
305 };
306
307 accounts.push(AccountConfig {
308 name: a.name.clone(),
309 plan_type: a.plan_type.clone(),
310 provider,
311 credential: cred,
312 upstream_url: acct_upstream,
313 });
314 }
315
316 Ok(Config { server, accounts, config_file: p })
317}
318
/// Render a starter `config.toml`.
///
/// The output holds a default `[server]` table followed by one `[[accounts]]`
/// entry per `(name, plan_type)` pair, in the given order. Values are written
/// as TOML basic strings with `"`, `\` and control characters escaped, so any
/// account name still produces a parseable file. Ordinary names render exactly
/// as `name = "<name>"`.
pub fn config_template(accounts: &[(&str, &str)]) -> String {
    // Quote `s` as a TOML basic string, escaping every character the TOML spec
    // forbids verbatim (`"`, `\`, control characters).
    fn toml_basic_string(s: &str) -> String {
        let mut quoted = String::with_capacity(s.len() + 2);
        quoted.push('"');
        for c in s.chars() {
            match c {
                '"' => quoted.push_str("\\\""),
                '\\' => quoted.push_str("\\\\"),
                '\n' => quoted.push_str("\\n"),
                '\r' => quoted.push_str("\\r"),
                '\t' => quoted.push_str("\\t"),
                '\u{8}' => quoted.push_str("\\b"),
                '\u{c}' => quoted.push_str("\\f"),
                c if c.is_control() => quoted.push_str(&format!("\\u{:04X}", c as u32)),
                c => quoted.push(c),
            }
        }
        quoted.push('"');
        quoted
    }

    let mut out = String::from(
        "[server]\nhost = \"127.0.0.1\"\nport = 8082\nlog_level = \"info\"\n",
    );
    for (name, plan_type) in accounts {
        out.push_str("\n[[accounts]]\nname = ");
        out.push_str(&toml_basic_string(name));
        out.push_str("\nplan_type = ");
        out.push_str(&toml_basic_string(plan_type));
        out.push('\n');
    }
    out
}