Skip to main content

stakpak_shared/
auth_manager.rs

1//! Authentication manager for storing and retrieving provider credentials
2//!
3//! # Deprecated
4//!
5//! This module is **deprecated**. Credentials are now stored directly in `config.toml`
6//! under `[profiles.{profile}.providers.{provider}.auth]` instead of in a separate
7//! `auth.toml` file.
8//!
9//! The `AuthManager` is kept temporarily for:
10//! - Reading existing `auth.toml` files during migration
11//! - Backward compatibility during the transition period
12//!
13//! New code should use `ProviderConfig::set_auth()` and `ProviderConfig::get_auth()`
14//! to manage provider credentials directly in `config.toml`.
15//!
16//! ## Migration
17//!
18//! When `config.toml` is loaded, any credentials in `auth.toml` are automatically
19//! migrated to the new format in `config.toml`, and `auth.toml` is backed up to
20//! `auth.toml.bak`.
21//!
22//! # Legacy File Structure (auth.toml - deprecated)
23//!
24//! ```toml
25//! # Shared across all profiles
26//! [all.anthropic]
27//! type = "oauth"
28//! access = "eyJ..."
29//! refresh = "eyJ..."
30//! expires = 1735600000000
31//!
32//! # Profile-specific override
33//! [work.anthropic]
34//! type = "api"
35//! key = "sk-ant-..."
36//! ```
37
38use crate::models::auth::ProviderAuth;
39use crate::oauth::error::{OAuthError, OAuthResult};
40use serde::{Deserialize, Serialize};
41use std::collections::HashMap;
42use std::path::{Path, PathBuf};
43
/// File name, relative to the config directory, of the legacy credentials store.
const AUTH_FILE_NAME: &str = "auth.toml";
46
/// Reserved profile whose credentials act as the fallback for every other profile.
const ALL_PROFILE: &str = "all";
49
50/// Structure of the auth.toml file
51#[derive(Debug, Clone, Serialize, Deserialize, Default)]
52pub struct AuthFile {
53    /// Profile-scoped credentials: profile_name -> provider_name -> auth
54    #[serde(flatten)]
55    pub profiles: HashMap<String, HashMap<String, ProviderAuth>>,
56}
57
58/// Manages provider credentials stored in auth.toml
59#[derive(Debug, Clone)]
60pub struct AuthManager {
61    /// Path to the auth.toml file
62    auth_path: PathBuf,
63    /// Loaded auth file contents
64    auth_file: AuthFile,
65}
66
67impl AuthManager {
68    /// Load auth manager for the given config directory
69    pub fn new(config_dir: &Path) -> OAuthResult<Self> {
70        let auth_path = config_dir.join(AUTH_FILE_NAME);
71        let auth_file = if auth_path.is_file() {
72            let content = std::fs::read_to_string(&auth_path)?;
73            toml::from_str(&content)?
74        } else {
75            AuthFile::default()
76        };
77
78        Ok(Self {
79            auth_path,
80            auth_file,
81        })
82    }
83
84    /// Load auth manager from the default Stakpak config directory (~/.stakpak/)
85    pub fn from_default_dir() -> OAuthResult<Self> {
86        let config_dir = get_default_config_dir()?;
87        Self::new(&config_dir)
88    }
89
90    /// Get credentials for a provider, respecting profile inheritance
91    ///
92    /// Resolution order:
93    /// 1. `[{profile}.{provider}]` - profile-specific
94    /// 2. `[all.{provider}]` - shared fallback
95    pub fn get(&self, profile: &str, provider: &str) -> Option<&ProviderAuth> {
96        // First, check profile-specific credentials
97        if let Some(providers) = self.auth_file.profiles.get(profile)
98            && let Some(auth) = providers.get(provider)
99        {
100            return Some(auth);
101        }
102
103        // Fall back to "all" profile
104        if profile != ALL_PROFILE
105            && let Some(providers) = self.auth_file.profiles.get(ALL_PROFILE)
106            && let Some(auth) = providers.get(provider)
107        {
108            return Some(auth);
109        }
110
111        None
112    }
113
114    /// Set credentials for a provider in a specific profile
115    pub fn set(&mut self, profile: &str, provider: &str, auth: ProviderAuth) -> OAuthResult<()> {
116        self.auth_file
117            .profiles
118            .entry(profile.to_string())
119            .or_default()
120            .insert(provider.to_string(), auth);
121
122        self.save()
123    }
124
125    /// Remove credentials for a provider from a specific profile
126    pub fn remove(&mut self, profile: &str, provider: &str) -> OAuthResult<bool> {
127        let removed = if let Some(providers) = self.auth_file.profiles.get_mut(profile) {
128            let removed = providers.remove(provider).is_some();
129            // Clean up empty profile entries
130            if providers.is_empty() {
131                self.auth_file.profiles.remove(profile);
132            }
133            removed
134        } else {
135            false
136        };
137
138        if removed {
139            self.save()?;
140        }
141
142        Ok(removed)
143    }
144
145    /// List all credentials
146    pub fn list(&self) -> &HashMap<String, HashMap<String, ProviderAuth>> {
147        &self.auth_file.profiles
148    }
149
150    /// Get all credentials for a specific profile (including inherited from "all")
151    pub fn list_for_profile(&self, profile: &str) -> HashMap<String, &ProviderAuth> {
152        let mut result = HashMap::new();
153
154        // Start with "all" profile credentials
155        if let Some(all_providers) = self.auth_file.profiles.get(ALL_PROFILE) {
156            for (provider, auth) in all_providers {
157                result.insert(provider.clone(), auth);
158            }
159        }
160
161        // Override with profile-specific credentials
162        if profile != ALL_PROFILE
163            && let Some(profile_providers) = self.auth_file.profiles.get(profile)
164        {
165            for (provider, auth) in profile_providers {
166                result.insert(provider.clone(), auth);
167            }
168        }
169
170        result
171    }
172
173    /// Check if any credentials are configured
174    pub fn has_credentials(&self) -> bool {
175        self.auth_file
176            .profiles
177            .values()
178            .any(|providers| !providers.is_empty())
179    }
180
181    /// Get the path to the auth file
182    pub fn auth_path(&self) -> &Path {
183        &self.auth_path
184    }
185
186    /// Update OAuth tokens for a provider (used during token refresh)
187    pub fn update_oauth_tokens(
188        &mut self,
189        profile: &str,
190        provider: &str,
191        access: &str,
192        refresh: &str,
193        expires: i64,
194    ) -> OAuthResult<()> {
195        let auth = ProviderAuth::oauth(access, refresh, expires);
196        self.set(profile, provider, auth)
197    }
198
199    /// Save changes to disk
200    fn save(&self) -> OAuthResult<()> {
201        // Ensure parent directory exists
202        if let Some(parent) = self.auth_path.parent() {
203            std::fs::create_dir_all(parent)?;
204        }
205
206        let content = toml::to_string_pretty(&self.auth_file)?;
207
208        // Write to a temp file first, then rename for atomicity
209        let temp_path = self.auth_path.with_extension("toml.tmp");
210        std::fs::write(&temp_path, &content)?;
211
212        // Set file permissions to 0600 (owner read/write only) on Unix
213        #[cfg(unix)]
214        {
215            use std::os::unix::fs::PermissionsExt;
216            let permissions = std::fs::Permissions::from_mode(0o600);
217            std::fs::set_permissions(&temp_path, permissions)?;
218        }
219
220        // Atomic rename
221        std::fs::rename(&temp_path, &self.auth_path)?;
222
223        Ok(())
224    }
225}
226
227/// Get the default Stakpak config directory
228pub fn get_default_config_dir() -> OAuthResult<PathBuf> {
229    let home = dirs::home_dir().ok_or_else(|| {
230        OAuthError::IoError(std::io::Error::new(
231            std::io::ErrorKind::NotFound,
232            "Could not determine home directory",
233        ))
234    })?;
235
236    Ok(home.join(".stakpak"))
237}
238
239/// Get the auth file path for a given config directory
240pub fn get_auth_file_path(config_dir: &Path) -> PathBuf {
241    config_dir.join(AUTH_FILE_NAME)
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Create a manager backed by a fresh temporary directory.
    ///
    /// The `TempDir` must be kept alive for as long as the manager is used.
    fn create_test_auth_manager() -> (AuthManager, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let manager = AuthManager::new(temp_dir.path()).unwrap();
        (manager, temp_dir)
    }

    #[test]
    fn test_new_empty() {
        let (manager, _temp) = create_test_auth_manager();
        assert!(!manager.has_credentials());
        assert!(manager.list().is_empty());
    }

    #[test]
    fn test_set_and_get() {
        let (mut manager, _temp) = create_test_auth_manager();

        let auth = ProviderAuth::api_key("sk-test-key");
        manager.set("default", "anthropic", auth.clone()).unwrap();

        let retrieved = manager.get("default", "anthropic");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap(), &auth);
    }

    #[test]
    fn test_profile_inheritance() {
        let (mut manager, _temp) = create_test_auth_manager();

        // Set in "all" profile
        let all_auth = ProviderAuth::api_key("sk-all-key");
        manager.set("all", "anthropic", all_auth.clone()).unwrap();

        // Should be accessible from any profile
        assert_eq!(manager.get("default", "anthropic"), Some(&all_auth));
        assert_eq!(manager.get("work", "anthropic"), Some(&all_auth));
        assert_eq!(manager.get("all", "anthropic"), Some(&all_auth));
    }

    #[test]
    fn test_profile_override() {
        let (mut manager, _temp) = create_test_auth_manager();

        // Set in "all" profile
        let all_auth = ProviderAuth::api_key("sk-all-key");
        manager.set("all", "anthropic", all_auth.clone()).unwrap();

        // Override in "work" profile
        let work_auth = ProviderAuth::api_key("sk-work-key");
        manager.set("work", "anthropic", work_auth.clone()).unwrap();

        // "work" should get its own key
        assert_eq!(manager.get("work", "anthropic"), Some(&work_auth));

        // "default" should still get the "all" key
        assert_eq!(manager.get("default", "anthropic"), Some(&all_auth));
    }

    #[test]
    fn test_remove() {
        let (mut manager, _temp) = create_test_auth_manager();

        let auth = ProviderAuth::api_key("sk-test-key");
        manager.set("default", "anthropic", auth).unwrap();

        assert!(manager.get("default", "anthropic").is_some());

        let removed = manager.remove("default", "anthropic").unwrap();
        assert!(removed);

        assert!(manager.get("default", "anthropic").is_none());
    }

    #[test]
    fn test_remove_nonexistent() {
        let (mut manager, _temp) = create_test_auth_manager();

        let removed = manager.remove("default", "anthropic").unwrap();
        assert!(!removed);
    }

    #[test]
    fn test_remove_cleans_up_empty_profile() {
        let (mut manager, _temp) = create_test_auth_manager();

        manager
            .set("work", "anthropic", ProviderAuth::api_key("sk-work-key"))
            .unwrap();
        assert!(manager.remove("work", "anthropic").unwrap());

        // Removing the last provider should drop the profile table entirely.
        assert!(!manager.list().contains_key("work"));
        assert!(!manager.has_credentials());
    }

    #[test]
    fn test_remove_persists() {
        let temp_dir = TempDir::new().unwrap();

        {
            let mut manager = AuthManager::new(temp_dir.path()).unwrap();
            manager
                .set("default", "anthropic", ProviderAuth::api_key("sk-test-key"))
                .unwrap();
            assert!(manager.remove("default", "anthropic").unwrap());
        }

        // A fresh load must not resurrect the removed entry.
        let manager = AuthManager::new(temp_dir.path()).unwrap();
        assert!(manager.get("default", "anthropic").is_none());
    }

    #[test]
    fn test_list_for_profile() {
        let (mut manager, _temp) = create_test_auth_manager();

        let all_anthropic = ProviderAuth::api_key("sk-all-anthropic");
        let all_openai = ProviderAuth::api_key("sk-all-openai");
        let work_anthropic = ProviderAuth::api_key("sk-work-anthropic");

        manager
            .set("all", "anthropic", all_anthropic.clone())
            .unwrap();
        manager.set("all", "openai", all_openai.clone()).unwrap();
        manager
            .set("work", "anthropic", work_anthropic.clone())
            .unwrap();

        let work_creds = manager.list_for_profile("work");
        assert_eq!(work_creds.len(), 2);
        assert_eq!(work_creds.get("anthropic"), Some(&&work_anthropic));
        assert_eq!(work_creds.get("openai"), Some(&&all_openai));

        let default_creds = manager.list_for_profile("default");
        assert_eq!(default_creds.len(), 2);
        assert_eq!(default_creds.get("anthropic"), Some(&&all_anthropic));
        assert_eq!(default_creds.get("openai"), Some(&&all_openai));
    }

    #[test]
    fn test_persistence() {
        let temp_dir = TempDir::new().unwrap();

        // Create and save credentials
        {
            let mut manager = AuthManager::new(temp_dir.path()).unwrap();
            let auth = ProviderAuth::api_key("sk-test-key");
            manager.set("default", "anthropic", auth).unwrap();
        }

        // Load and verify
        {
            let manager = AuthManager::new(temp_dir.path()).unwrap();
            let retrieved = manager.get("default", "anthropic");
            assert!(retrieved.is_some());
            assert_eq!(retrieved.unwrap().api_key_value(), Some("sk-test-key"));
        }
    }

    #[test]
    fn test_save_leaves_no_temp_file() {
        let (mut manager, temp) = create_test_auth_manager();

        manager
            .set("default", "anthropic", ProviderAuth::api_key("sk-test-key"))
            .unwrap();

        assert!(manager.auth_path().is_file());
        assert!(!temp.path().join("auth.toml.tmp").exists());
    }

    #[test]
    fn test_auth_file_path_matches_manager() {
        let (manager, temp) = create_test_auth_manager();
        assert_eq!(get_auth_file_path(temp.path()), manager.auth_path());
    }

    #[test]
    fn test_oauth_tokens() {
        let (mut manager, _temp) = create_test_auth_manager();

        let expires = chrono::Utc::now().timestamp_millis() + 3600000;
        let auth = ProviderAuth::oauth("access-token", "refresh-token", expires);
        manager.set("default", "anthropic", auth).unwrap();

        let retrieved = manager.get("default", "anthropic").unwrap();
        assert!(retrieved.is_oauth());
        assert_eq!(retrieved.access_token(), Some("access-token"));
        assert_eq!(retrieved.refresh_token(), Some("refresh-token"));
    }

    #[test]
    fn test_update_oauth_tokens() {
        let (mut manager, _temp) = create_test_auth_manager();

        // Initial set
        manager
            .set(
                "default",
                "anthropic",
                ProviderAuth::oauth("old-access", "old-refresh", 0),
            )
            .unwrap();

        // Update tokens
        let new_expires = chrono::Utc::now().timestamp_millis() + 3600000;
        manager
            .update_oauth_tokens(
                "default",
                "anthropic",
                "new-access",
                "new-refresh",
                new_expires,
            )
            .unwrap();

        let retrieved = manager.get("default", "anthropic").unwrap();
        assert_eq!(retrieved.access_token(), Some("new-access"));
        assert_eq!(retrieved.refresh_token(), Some("new-refresh"));
    }

    #[cfg(unix)]
    #[test]
    fn test_file_permissions() {
        use std::os::unix::fs::PermissionsExt;

        let temp_dir = TempDir::new().unwrap();
        let mut manager = AuthManager::new(temp_dir.path()).unwrap();

        let auth = ProviderAuth::api_key("sk-test-key");
        manager.set("default", "anthropic", auth).unwrap();

        let metadata = std::fs::metadata(manager.auth_path()).unwrap();
        let mode = metadata.permissions().mode();

        // Check that file is readable/writable only by owner (0600)
        assert_eq!(mode & 0o777, 0o600);
    }

    #[cfg(unix)]
    #[test]
    fn test_file_permissions_with_stale_temp_file() {
        use std::os::unix::fs::PermissionsExt;

        let temp_dir = TempDir::new().unwrap();

        // Simulate a temp file left behind by an interrupted save, with a permissive mode.
        let stale = temp_dir.path().join("auth.toml.tmp");
        std::fs::write(&stale, "stale").unwrap();
        std::fs::set_permissions(&stale, std::fs::Permissions::from_mode(0o644)).unwrap();

        let mut manager = AuthManager::new(temp_dir.path()).unwrap();
        manager
            .set("default", "anthropic", ProviderAuth::api_key("sk-test-key"))
            .unwrap();

        let mode = std::fs::metadata(manager.auth_path())
            .unwrap()
            .permissions()
            .mode();
        assert_eq!(mode & 0o777, 0o600);
        assert!(!stale.exists());
    }
}