1use crate::models::auth::ProviderAuth;
39use crate::oauth::error::{OAuthError, OAuthResult};
40use serde::{Deserialize, Serialize};
41use std::collections::HashMap;
42use std::path::{Path, PathBuf};
43
/// File name of the credentials store inside the configuration directory.
const AUTH_FILE_NAME: &str = "auth.toml";

/// Name of the profile whose credentials are used as a fallback when the
/// requested profile has no entry for a provider.
const ALL_PROFILE: &str = "all";
49
/// Serializable representation of the credentials file.
///
/// [`AuthManager`] parses this type from TOML on load and serializes it back
/// to TOML on save.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AuthFile {
    /// Credentials keyed by profile name, then by provider name.
    ///
    /// `#[serde(flatten)]` places each profile at the top level of the
    /// document instead of nesting everything under a `profiles` key.
    #[serde(flatten)]
    pub profiles: HashMap<String, HashMap<String, ProviderAuth>>,
}
57
/// Loads, queries and persists provider credentials kept in a single TOML file.
///
/// NOTE(review): the derived `Debug` formats `auth_file`; confirm that
/// `ProviderAuth`'s `Debug` output does not print secrets before logging
/// values of this type.
#[derive(Debug, Clone)]
pub struct AuthManager {
    /// Location of the credentials file this manager reads from and writes to.
    auth_path: PathBuf,
    /// In-memory copy of the file contents; mutating methods write it back to disk.
    auth_file: AuthFile,
}
66
67impl AuthManager {
68 pub fn new(config_dir: &Path) -> OAuthResult<Self> {
70 let auth_path = config_dir.join(AUTH_FILE_NAME);
71 let auth_file = if auth_path.is_file() {
72 let content = std::fs::read_to_string(&auth_path)?;
73 toml::from_str(&content)?
74 } else {
75 AuthFile::default()
76 };
77
78 Ok(Self {
79 auth_path,
80 auth_file,
81 })
82 }
83
84 pub fn from_default_dir() -> OAuthResult<Self> {
86 let config_dir = get_default_config_dir()?;
87 Self::new(&config_dir)
88 }
89
90 pub fn get(&self, profile: &str, provider: &str) -> Option<&ProviderAuth> {
96 if let Some(providers) = self.auth_file.profiles.get(profile)
98 && let Some(auth) = providers.get(provider)
99 {
100 return Some(auth);
101 }
102
103 if profile != ALL_PROFILE
105 && let Some(providers) = self.auth_file.profiles.get(ALL_PROFILE)
106 && let Some(auth) = providers.get(provider)
107 {
108 return Some(auth);
109 }
110
111 None
112 }
113
114 pub fn set(&mut self, profile: &str, provider: &str, auth: ProviderAuth) -> OAuthResult<()> {
116 self.auth_file
117 .profiles
118 .entry(profile.to_string())
119 .or_default()
120 .insert(provider.to_string(), auth);
121
122 self.save()
123 }
124
125 pub fn remove(&mut self, profile: &str, provider: &str) -> OAuthResult<bool> {
127 let removed = if let Some(providers) = self.auth_file.profiles.get_mut(profile) {
128 let removed = providers.remove(provider).is_some();
129 if providers.is_empty() {
131 self.auth_file.profiles.remove(profile);
132 }
133 removed
134 } else {
135 false
136 };
137
138 if removed {
139 self.save()?;
140 }
141
142 Ok(removed)
143 }
144
145 pub fn list(&self) -> &HashMap<String, HashMap<String, ProviderAuth>> {
147 &self.auth_file.profiles
148 }
149
150 pub fn list_for_profile(&self, profile: &str) -> HashMap<String, &ProviderAuth> {
152 let mut result = HashMap::new();
153
154 if let Some(all_providers) = self.auth_file.profiles.get(ALL_PROFILE) {
156 for (provider, auth) in all_providers {
157 result.insert(provider.clone(), auth);
158 }
159 }
160
161 if profile != ALL_PROFILE
163 && let Some(profile_providers) = self.auth_file.profiles.get(profile)
164 {
165 for (provider, auth) in profile_providers {
166 result.insert(provider.clone(), auth);
167 }
168 }
169
170 result
171 }
172
173 pub fn has_credentials(&self) -> bool {
175 self.auth_file
176 .profiles
177 .values()
178 .any(|providers| !providers.is_empty())
179 }
180
181 pub fn auth_path(&self) -> &Path {
183 &self.auth_path
184 }
185
186 pub fn update_oauth_tokens(
188 &mut self,
189 profile: &str,
190 provider: &str,
191 access: &str,
192 refresh: &str,
193 expires: i64,
194 ) -> OAuthResult<()> {
195 let auth = ProviderAuth::oauth(access, refresh, expires);
196 self.set(profile, provider, auth)
197 }
198
199 fn save(&self) -> OAuthResult<()> {
201 if let Some(parent) = self.auth_path.parent() {
203 std::fs::create_dir_all(parent)?;
204 }
205
206 let content = toml::to_string_pretty(&self.auth_file)?;
207
208 let temp_path = self.auth_path.with_extension("toml.tmp");
210 std::fs::write(&temp_path, &content)?;
211
212 #[cfg(unix)]
214 {
215 use std::os::unix::fs::PermissionsExt;
216 let permissions = std::fs::Permissions::from_mode(0o600);
217 std::fs::set_permissions(&temp_path, permissions)?;
218 }
219
220 std::fs::rename(&temp_path, &self.auth_path)?;
222
223 Ok(())
224 }
225}
226
227pub fn get_default_config_dir() -> OAuthResult<PathBuf> {
229 let home = dirs::home_dir().ok_or_else(|| {
230 OAuthError::IoError(std::io::Error::new(
231 std::io::ErrorKind::NotFound,
232 "Could not determine home directory",
233 ))
234 })?;
235
236 Ok(home.join(".stakpak"))
237}
238
239pub fn get_auth_file_path(config_dir: &Path) -> PathBuf {
241 config_dir.join(AUTH_FILE_NAME)
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a manager backed by a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned too, so the directory stays alive for
    /// as long as the caller holds on to it.
    fn create_test_auth_manager() -> (AuthManager, TempDir) {
        let scratch = TempDir::new().unwrap();
        let manager = AuthManager::new(scratch.path()).unwrap();
        (manager, scratch)
    }

    /// A manager created over an empty directory holds no credentials.
    #[test]
    fn test_new_empty() {
        let (manager, _guard) = create_test_auth_manager();

        assert!(manager.list().is_empty());
        assert!(!manager.has_credentials());
    }

    /// A stored credential can be read back from the same profile.
    #[test]
    fn test_set_and_get() {
        let (mut manager, _guard) = create_test_auth_manager();
        let expected = ProviderAuth::api_key("sk-test-key");

        manager
            .set("default", "anthropic", expected.clone())
            .unwrap();

        assert_eq!(manager.get("default", "anthropic"), Some(&expected));
    }

    /// Credentials stored under "all" are visible from every profile.
    #[test]
    fn test_profile_inheritance() {
        let (mut manager, _guard) = create_test_auth_manager();
        let shared = ProviderAuth::api_key("sk-all-key");

        manager.set("all", "anthropic", shared.clone()).unwrap();

        for profile in ["default", "work", "all"] {
            assert_eq!(manager.get(profile, "anthropic"), Some(&shared));
        }
    }

    /// A profile-specific credential takes precedence over the "all" one.
    #[test]
    fn test_profile_override() {
        let (mut manager, _guard) = create_test_auth_manager();
        let shared = ProviderAuth::api_key("sk-all-key");
        let work_only = ProviderAuth::api_key("sk-work-key");

        manager.set("all", "anthropic", shared.clone()).unwrap();
        manager.set("work", "anthropic", work_only.clone()).unwrap();

        // "work" sees its own key; "default" still falls back to "all".
        assert_eq!(manager.get("work", "anthropic"), Some(&work_only));
        assert_eq!(manager.get("default", "anthropic"), Some(&shared));
    }

    /// Removing an existing credential reports `true` and makes it unreachable.
    #[test]
    fn test_remove() {
        let (mut manager, _guard) = create_test_auth_manager();

        manager
            .set("default", "anthropic", ProviderAuth::api_key("sk-test-key"))
            .unwrap();
        assert!(manager.get("default", "anthropic").is_some());

        assert!(manager.remove("default", "anthropic").unwrap());
        assert!(manager.get("default", "anthropic").is_none());
    }

    /// Removing something that was never stored reports `false`.
    #[test]
    fn test_remove_nonexistent() {
        let (mut manager, _guard) = create_test_auth_manager();

        assert!(!manager.remove("default", "anthropic").unwrap());
    }

    /// The merged view layers profile entries on top of the "all" entries.
    #[test]
    fn test_list_for_profile() {
        let (mut manager, _guard) = create_test_auth_manager();

        let shared_anthropic = ProviderAuth::api_key("sk-all-anthropic");
        let shared_openai = ProviderAuth::api_key("sk-all-openai");
        let work_anthropic = ProviderAuth::api_key("sk-work-anthropic");

        let fixtures = [
            ("all", "anthropic", &shared_anthropic),
            ("all", "openai", &shared_openai),
            ("work", "anthropic", &work_anthropic),
        ];
        for (profile, provider, auth) in fixtures {
            manager.set(profile, provider, auth.clone()).unwrap();
        }

        let work = manager.list_for_profile("work");
        assert_eq!(work.len(), 2);
        assert_eq!(work.get("anthropic").copied(), Some(&work_anthropic));
        assert_eq!(work.get("openai").copied(), Some(&shared_openai));

        let default = manager.list_for_profile("default");
        assert_eq!(default.len(), 2);
        assert_eq!(default.get("anthropic").copied(), Some(&shared_anthropic));
        assert_eq!(default.get("openai").copied(), Some(&shared_openai));
    }

    /// Credentials written by one manager are loaded by a later one.
    #[test]
    fn test_persistence() {
        let scratch = TempDir::new().unwrap();

        let mut writer = AuthManager::new(scratch.path()).unwrap();
        writer
            .set("default", "anthropic", ProviderAuth::api_key("sk-test-key"))
            .unwrap();
        drop(writer);

        let reader = AuthManager::new(scratch.path()).unwrap();
        let stored = reader.get("default", "anthropic").unwrap();
        assert_eq!(stored.api_key_value(), Some("sk-test-key"));
    }

    /// OAuth credentials keep their access and refresh tokens.
    #[test]
    fn test_oauth_tokens() {
        let (mut manager, _guard) = create_test_auth_manager();
        let expires = chrono::Utc::now().timestamp_millis() + 3_600_000;

        manager
            .set(
                "default",
                "anthropic",
                ProviderAuth::oauth("access-token", "refresh-token", expires),
            )
            .unwrap();

        let stored = manager.get("default", "anthropic").unwrap();
        assert!(stored.is_oauth());
        assert_eq!(stored.access_token(), Some("access-token"));
        assert_eq!(stored.refresh_token(), Some("refresh-token"));
    }

    /// Updating OAuth tokens replaces the previously stored pair.
    #[test]
    fn test_update_oauth_tokens() {
        let (mut manager, _guard) = create_test_auth_manager();

        let stale = ProviderAuth::oauth("old-access", "old-refresh", 0);
        manager.set("default", "anthropic", stale).unwrap();

        let new_expires = chrono::Utc::now().timestamp_millis() + 3_600_000;
        manager
            .update_oauth_tokens(
                "default",
                "anthropic",
                "new-access",
                "new-refresh",
                new_expires,
            )
            .unwrap();

        let stored = manager.get("default", "anthropic").unwrap();
        assert_eq!(stored.access_token(), Some("new-access"));
        assert_eq!(stored.refresh_token(), Some("new-refresh"));
    }

    /// The saved file is readable and writable by its owner only.
    #[cfg(unix)]
    #[test]
    fn test_file_permissions() {
        use std::os::unix::fs::PermissionsExt;

        let (mut manager, _guard) = create_test_auth_manager();
        manager
            .set("default", "anthropic", ProviderAuth::api_key("sk-test-key"))
            .unwrap();

        let mode = std::fs::metadata(manager.auth_path())
            .unwrap()
            .permissions()
            .mode();

        assert_eq!(mode & 0o777, 0o600);
    }
}