steer_core/auth/
storage.rs

1use crate::auth::error::{AuthError, Result};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::SystemTime;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct OAuth2Token {
10    pub access_token: String,
11    pub refresh_token: String,
12    pub expires_at: SystemTime,
13}
14
15// Alias for backwards compatibility
16pub type AuthTokens = OAuth2Token;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19#[serde(tag = "type")]
20pub enum Credential {
21    #[serde(alias = "AuthTokens")]
22    OAuth2(OAuth2Token),
23    ApiKey {
24        value: String,
25    },
26}
27
28impl Credential {
29    pub fn credential_type(&self) -> CredentialType {
30        match self {
31            Credential::OAuth2(_) => CredentialType::OAuth2,
32            Credential::ApiKey { .. } => CredentialType::ApiKey,
33        }
34    }
35}
36
37#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
38pub enum CredentialType {
39    #[serde(alias = "AuthTokens")]
40    OAuth2,
41    ApiKey,
42}
43
44impl std::fmt::Display for CredentialType {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        match self {
47            CredentialType::OAuth2 => write!(f, "OAuth2"),
48            CredentialType::ApiKey => write!(f, "ApiKey"),
49        }
50    }
51}
52
53/// Collection of all credentials kept in the keyring. The first key is the
54/// provider id (e.g. `"anthropic"`), the second key is the credential type
55/// (`"AuthTokens"` / `"ApiKey"`). Each leaf holds the raw `Credential` value
56/// for that pair.
57#[derive(Debug, Clone, Serialize, Deserialize, Default)]
58struct CredentialStore(HashMap<String, HashMap<CredentialType, Credential>>);
59
60#[async_trait]
61pub trait AuthStorage: Send + Sync {
62    async fn get_credential(
63        &self,
64        provider: &str,
65        credential_type: CredentialType,
66    ) -> Result<Option<Credential>>;
67    async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()>;
68    async fn remove_credential(
69        &self,
70        provider: &str,
71        credential_type: CredentialType,
72    ) -> Result<()>;
73}
74
75/// Primary storage using OS keyring
76pub struct KeyringStorage {
77    service_name: String,
78}
79
80impl Default for KeyringStorage {
81    fn default() -> Self {
82        Self::new("steer")
83    }
84}
85
86impl KeyringStorage {
87    pub fn new(service_name: &str) -> Self {
88        Self {
89            service_name: service_name.to_string(),
90        }
91    }
92
93    fn get_username() -> String {
94        whoami::username()
95    }
96}
97
98#[async_trait]
99impl AuthStorage for KeyringStorage {
100    async fn get_credential(
101        &self,
102        provider: &str,
103        credential_type: CredentialType,
104    ) -> Result<Option<Credential>> {
105        let provider = provider.to_string();
106        let username = Self::get_username();
107        let service = self.service_name.clone();
108        let cred_type = credential_type;
109
110        // Load, parse and query the credential store
111        let result = tokio::task::spawn_blocking(
112            move || -> std::result::Result<Option<Credential>, keyring::Error> {
113                let entry = keyring::Entry::new(&service, &username)?;
114                let store_json = match entry.get_password() {
115                    Ok(pwd) => pwd,
116                    Err(keyring::Error::NoEntry) => return Ok(None),
117                    Err(e) => return Err(e),
118                };
119
120                let store: CredentialStore = serde_json::from_str(&store_json).unwrap_or_default();
121
122                // Get the credential with the requested type
123                // The serde aliases handle migration from old "AuthTokens" to "OAuth2" automatically
124                let cred = store
125                    .0
126                    .get(&provider)
127                    .and_then(|m| m.get(&cred_type))
128                    .cloned();
129
130                Ok(cred)
131            },
132        )
133        .await
134        .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?;
135
136        result.map_err(AuthError::from)
137    }
138
139    async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
140        let service = self.service_name.clone();
141        let username = Self::get_username();
142        let provider = provider.to_string();
143        let cred_type = credential.credential_type();
144
145        tokio::task::spawn_blocking(move || -> std::result::Result<(), keyring::Error> {
146            let entry = keyring::Entry::new(&service, &username)?;
147            // Load existing store (if any)
148            let mut store: CredentialStore = match entry.get_password() {
149                Ok(pwd) => serde_json::from_str(&pwd).unwrap_or_default(),
150                Err(keyring::Error::NoEntry) => CredentialStore::default(),
151                Err(e) => return Err(e),
152            };
153
154            // Update
155            store
156                .0
157                .entry(provider)
158                .or_default()
159                .insert(cred_type, credential);
160
161            let data = serde_json::to_string(&store).expect("serialize credential store");
162            entry.set_password(&data)?;
163            Ok(())
164        })
165        .await
166        .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?
167        .map_err(AuthError::from)
168    }
169
170    async fn remove_credential(
171        &self,
172        provider: &str,
173        credential_type: CredentialType,
174    ) -> Result<()> {
175        let service = self.service_name.clone();
176        let username = Self::get_username();
177        let provider = provider.to_string();
178
179        tokio::task::spawn_blocking(move || -> std::result::Result<(), keyring::Error> {
180            let entry = keyring::Entry::new(&service, &username)?;
181
182            // Load existing store, return Ok if none
183            let store_json = match entry.get_password() {
184                Ok(pwd) => pwd,
185                Err(keyring::Error::NoEntry) => return Ok(()),
186                Err(e) => return Err(e),
187            };
188
189            let mut store: CredentialStore = serde_json::from_str(&store_json).unwrap_or_default();
190
191            if let Some(map) = store.0.get_mut(&provider) {
192                map.remove(&credential_type);
193                if map.is_empty() {
194                    store.0.remove(&provider);
195                }
196            }
197
198            if store.0.is_empty() {
199                // No credentials left – remove the keyring entry entirely.
200                let _ = entry.delete_credential();
201            } else {
202                let data = serde_json::to_string(&store).expect("serialize credential store");
203                entry.set_password(&data)?;
204            }
205            Ok(())
206        })
207        .await
208        .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?
209        .map_err(AuthError::from)
210    }
211}
212
213/// Default storage implementation that tries keyring first, then falls back to encrypted file
214pub struct DefaultAuthStorage {
215    keyring: Arc<dyn AuthStorage>,
216}
217
218impl DefaultAuthStorage {
219    pub fn new() -> Result<Self> {
220        // Try to create keyring storage
221        if !cfg!(any(
222            target_os = "macos",
223            target_os = "windows",
224            target_os = "linux"
225        )) {
226            return Err(AuthError::Storage(
227                "Keyring not supported on this platform".to_string(),
228            ));
229        }
230
231        let keyring = Arc::new(KeyringStorage::new("steer")) as Arc<dyn AuthStorage>;
232
233        Ok(Self { keyring })
234    }
235
236    // Convenience methods for working with specific credential types
237    pub async fn get_auth_tokens(&self, provider: &str) -> Result<Option<OAuth2Token>> {
238        match self
239            .get_credential(provider, CredentialType::OAuth2)
240            .await?
241        {
242            Some(Credential::OAuth2(tokens)) => Ok(Some(tokens)),
243            _ => Ok(None),
244        }
245    }
246
247    pub async fn set_auth_tokens(&self, provider: &str, tokens: OAuth2Token) -> Result<()> {
248        self.set_credential(provider, Credential::OAuth2(tokens))
249            .await
250    }
251
252    pub async fn get_api_key(&self, provider: &str) -> Result<Option<String>> {
253        match self
254            .get_credential(provider, CredentialType::ApiKey)
255            .await?
256        {
257            Some(Credential::ApiKey { value }) => Ok(Some(value)),
258            _ => Ok(None),
259        }
260    }
261
262    pub async fn set_api_key(&self, provider: &str, api_key: String) -> Result<()> {
263        self.set_credential(provider, Credential::ApiKey { value: api_key })
264            .await
265    }
266
267    pub async fn remove_auth_tokens(&self, provider: &str) -> Result<()> {
268        self.remove_credential(provider, CredentialType::OAuth2)
269            .await
270    }
271
272    pub async fn remove_api_key(&self, provider: &str) -> Result<()> {
273        self.remove_credential(provider, CredentialType::ApiKey)
274            .await
275    }
276}
277
278#[async_trait]
279impl AuthStorage for DefaultAuthStorage {
280    async fn get_credential(
281        &self,
282        provider: &str,
283        credential_type: CredentialType,
284    ) -> Result<Option<Credential>> {
285        self.keyring.get_credential(provider, credential_type).await
286    }
287
288    async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
289        self.keyring
290            .set_credential(provider, credential.clone())
291            .await
292    }
293
294    async fn remove_credential(
295        &self,
296        provider: &str,
297        credential_type: CredentialType,
298    ) -> Result<()> {
299        self.keyring
300            .remove_credential(provider, credential_type)
301            .await
302    }
303}
304
#[cfg(test)]
mod tests {
    //! Serialization tests for the credential store format, covering both the
    //! legacy `"AuthTokens"` layout and the current `"OAuth2"` layout.
    use super::*;
    use std::collections::HashMap;
    use std::time::{Duration, SystemTime};

    #[test]
    fn test_credential_deserialization_with_alias() {
        // Test that old "AuthTokens" format deserializes correctly to OAuth2
        let old_json_format = r#"{
            "anthropic": {
                "AuthTokens": {
                    "type": "AuthTokens",
                    "access_token": "old_access_token",
                    "refresh_token": "old_refresh_token",
                    "expires_at": {
                        "secs_since_epoch": 1678886400,
                        "nanos_since_epoch": 0
                    }
                }
            }
        }"#;

        let store: CredentialStore =
            serde_json::from_str(old_json_format).expect("Failed to deserialize old format");

        // The serde alias should have converted AuthTokens to OAuth2
        let creds = store.0.get("anthropic").unwrap();
        let cred = creds.get(&CredentialType::OAuth2).unwrap();

        match cred {
            Credential::OAuth2(token) => {
                assert_eq!(token.access_token, "old_access_token");
                assert_eq!(token.refresh_token, "old_refresh_token");
                assert_eq!(
                    token.expires_at,
                    SystemTime::UNIX_EPOCH + Duration::from_secs(1678886400)
                );
            }
            _ => panic!("Deserialization failed: expected OAuth2 credential"),
        }
    }

    #[test]
    fn test_credential_type_deserialization_with_alias() {
        // Test that the old "AuthTokens" string deserializes to OAuth2
        let old_type = r#""AuthTokens""#;
        let cred_type: CredentialType = serde_json::from_str(old_type).unwrap();
        assert_eq!(cred_type, CredentialType::OAuth2);

        // Also test that "OAuth2" works
        let new_type = r#""OAuth2""#;
        let cred_type: CredentialType = serde_json::from_str(new_type).unwrap();
        assert_eq!(cred_type, CredentialType::OAuth2);
    }

    #[test]
    fn test_credential_store_round_trip_uses_current_names() {
        // Writes must use the current names (never the legacy "AuthTokens")
        // and must read back unchanged.
        let expires_at = SystemTime::UNIX_EPOCH + Duration::from_secs(1_700_000_000);

        let mut provider_creds = HashMap::new();
        provider_creds.insert(
            CredentialType::OAuth2,
            Credential::OAuth2(OAuth2Token {
                access_token: "access".to_string(),
                refresh_token: "refresh".to_string(),
                expires_at,
            }),
        );
        provider_creds.insert(
            CredentialType::ApiKey,
            Credential::ApiKey {
                value: "key".to_string(),
            },
        );
        let mut providers = HashMap::new();
        providers.insert("anthropic".to_string(), provider_creds);
        let store = CredentialStore(providers);

        let json = serde_json::to_string(&store).expect("serialize store");

        let raw: serde_json::Value = serde_json::from_str(&json).expect("parse as JSON value");
        assert_eq!(raw["anthropic"]["OAuth2"]["type"], "OAuth2");
        assert_eq!(raw["anthropic"]["ApiKey"]["type"], "ApiKey");
        assert_eq!(raw["anthropic"]["ApiKey"]["value"], "key");
        assert!(raw["anthropic"].get("AuthTokens").is_none());

        let restored: CredentialStore = serde_json::from_str(&json).expect("deserialize store");
        let creds = restored.0.get("anthropic").expect("provider present");

        match creds.get(&CredentialType::OAuth2) {
            Some(Credential::OAuth2(token)) => {
                assert_eq!(token.access_token, "access");
                assert_eq!(token.refresh_token, "refresh");
                assert_eq!(token.expires_at, expires_at);
            }
            _ => panic!("expected OAuth2 credential after round trip"),
        }
        match creds.get(&CredentialType::ApiKey) {
            Some(Credential::ApiKey { value }) => assert_eq!(value, "key"),
            _ => panic!("expected ApiKey credential after round trip"),
        }
    }

    #[test]
    fn test_credential_type_display() {
        assert_eq!(CredentialType::OAuth2.to_string(), "OAuth2");
        assert_eq!(CredentialType::ApiKey.to_string(), "ApiKey");
    }
}