//! Keyring-backed credential storage (`steer_core/auth/storage.rs`).

1use crate::auth::error::{AuthError, Result};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::Arc;
6pub use steer_auth_plugin::storage::{
7    AuthStorage, AuthTokens, Credential, CredentialType, OAuth2Token,
8};
9
10/// Collection of all credentials kept in the keyring. The first key is the
11/// provider id (e.g. `"anthropic"`), the second key is the credential type
12/// (`"AuthTokens"` / `"ApiKey"`). Each leaf holds the raw `Credential` value
13/// for that pair.
14#[derive(Debug, Clone, Serialize, Deserialize, Default)]
15struct CredentialStore(HashMap<String, HashMap<CredentialType, Credential>>);
16
/// Credential storage backed by the operating-system keyring.
///
/// All credentials share a single keyring entry, identified by `service_name` together
/// with the current OS user name.
pub struct KeyringStorage {
    /// Keyring service under which the shared entry is filed.
    service_name: String,
}
21
22impl Default for KeyringStorage {
23    fn default() -> Self {
24        Self::new("steer")
25    }
26}
27
28impl KeyringStorage {
29    pub fn new(service_name: &str) -> Self {
30        Self {
31            service_name: service_name.to_string(),
32        }
33    }
34
35    fn get_username() -> String {
36        whoami::username()
37    }
38}
39
40#[async_trait]
41impl AuthStorage for KeyringStorage {
42    async fn get_credential(
43        &self,
44        provider: &str,
45        credential_type: CredentialType,
46    ) -> Result<Option<Credential>> {
47        let provider = provider.to_string();
48        let username = Self::get_username();
49        let service = self.service_name.clone();
50        let cred_type = credential_type;
51
52        // Load, parse and query the credential store
53        let result = tokio::task::spawn_blocking(
54            move || -> std::result::Result<Option<Credential>, keyring::Error> {
55                let entry = keyring::Entry::new(&service, &username)?;
56                let store_json = match entry.get_password() {
57                    Ok(pwd) => pwd,
58                    Err(keyring::Error::NoEntry) => return Ok(None),
59                    Err(e) => return Err(e),
60                };
61
62                let store: CredentialStore = serde_json::from_str(&store_json).unwrap_or_default();
63
64                // Get the credential with the requested type
65                // The serde aliases handle migration from old "AuthTokens" to "OAuth2" automatically
66                let cred = store
67                    .0
68                    .get(&provider)
69                    .and_then(|m| m.get(&cred_type))
70                    .cloned();
71
72                Ok(cred)
73            },
74        )
75        .await
76        .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?;
77
78        result.map_err(AuthError::from)
79    }
80
81    async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
82        let service = self.service_name.clone();
83        let username = Self::get_username();
84        let provider = provider.to_string();
85        let cred_type = credential.credential_type();
86
87        tokio::task::spawn_blocking(move || -> std::result::Result<(), keyring::Error> {
88            let entry = keyring::Entry::new(&service, &username)?;
89            // Load existing store (if any)
90            let mut store: CredentialStore = match entry.get_password() {
91                Ok(pwd) => serde_json::from_str(&pwd).unwrap_or_default(),
92                Err(keyring::Error::NoEntry) => CredentialStore::default(),
93                Err(e) => return Err(e),
94            };
95
96            // Update
97            store
98                .0
99                .entry(provider)
100                .or_default()
101                .insert(cred_type, credential);
102
103            let data = serde_json::to_string(&store).map_err(|e| {
104                keyring::Error::PlatformFailure(format!("serialize credential store: {e}").into())
105            })?;
106            entry.set_password(&data)?;
107            Ok(())
108        })
109        .await
110        .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?
111        .map_err(AuthError::from)
112    }
113
114    async fn remove_credential(
115        &self,
116        provider: &str,
117        credential_type: CredentialType,
118    ) -> Result<()> {
119        let service = self.service_name.clone();
120        let username = Self::get_username();
121        let provider = provider.to_string();
122
123        tokio::task::spawn_blocking(move || -> std::result::Result<(), keyring::Error> {
124            let entry = keyring::Entry::new(&service, &username)?;
125
126            // Load existing store, return Ok if none
127            let store_json = match entry.get_password() {
128                Ok(pwd) => pwd,
129                Err(keyring::Error::NoEntry) => return Ok(()),
130                Err(e) => return Err(e),
131            };
132
133            let mut store: CredentialStore = serde_json::from_str(&store_json).unwrap_or_default();
134
135            if let Some(map) = store.0.get_mut(&provider) {
136                map.remove(&credential_type);
137                if map.is_empty() {
138                    store.0.remove(&provider);
139                }
140            }
141
142            if store.0.is_empty() {
143                // No credentials left – remove the keyring entry entirely.
144                let _ = entry.delete_credential();
145            } else {
146                let data = serde_json::to_string(&store).map_err(|e| {
147                    keyring::Error::PlatformFailure(
148                        format!("serialize credential store: {e}").into(),
149                    )
150                })?;
151                entry.set_password(&data)?;
152            }
153            Ok(())
154        })
155        .await
156        .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?
157        .map_err(AuthError::from)
158    }
159}
160
161/// Default storage implementation that tries keyring first, then falls back to encrypted file
162pub struct DefaultAuthStorage {
163    keyring: Arc<dyn AuthStorage>,
164}
165
166impl DefaultAuthStorage {
167    pub fn new() -> Result<Self> {
168        // Try to create keyring storage
169        if !cfg!(any(
170            target_os = "macos",
171            target_os = "windows",
172            target_os = "linux"
173        )) {
174            return Err(AuthError::Storage(
175                "Keyring not supported on this platform".to_string(),
176            ));
177        }
178
179        let keyring = Arc::new(KeyringStorage::new("steer")) as Arc<dyn AuthStorage>;
180
181        Ok(Self { keyring })
182    }
183
184    // Convenience methods for working with specific credential types
185    pub async fn get_auth_tokens(&self, provider: &str) -> Result<Option<OAuth2Token>> {
186        match self
187            .get_credential(provider, CredentialType::OAuth2)
188            .await?
189        {
190            Some(Credential::OAuth2(tokens)) => Ok(Some(tokens)),
191            _ => Ok(None),
192        }
193    }
194
195    pub async fn set_auth_tokens(&self, provider: &str, tokens: OAuth2Token) -> Result<()> {
196        self.set_credential(provider, Credential::OAuth2(tokens))
197            .await
198    }
199
200    pub async fn get_api_key(&self, provider: &str) -> Result<Option<String>> {
201        match self
202            .get_credential(provider, CredentialType::ApiKey)
203            .await?
204        {
205            Some(Credential::ApiKey { value }) => Ok(Some(value)),
206            _ => Ok(None),
207        }
208    }
209
210    pub async fn set_api_key(&self, provider: &str, api_key: String) -> Result<()> {
211        self.set_credential(provider, Credential::ApiKey { value: api_key })
212            .await
213    }
214
215    pub async fn remove_auth_tokens(&self, provider: &str) -> Result<()> {
216        self.remove_credential(provider, CredentialType::OAuth2)
217            .await
218    }
219
220    pub async fn remove_api_key(&self, provider: &str) -> Result<()> {
221        self.remove_credential(provider, CredentialType::ApiKey)
222            .await
223    }
224}
225
226#[async_trait]
227impl AuthStorage for DefaultAuthStorage {
228    async fn get_credential(
229        &self,
230        provider: &str,
231        credential_type: CredentialType,
232    ) -> Result<Option<Credential>> {
233        self.keyring.get_credential(provider, credential_type).await
234    }
235
236    async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
237        self.keyring
238            .set_credential(provider, credential.clone())
239            .await
240    }
241
242    async fn remove_credential(
243        &self,
244        provider: &str,
245        credential_type: CredentialType,
246    ) -> Result<()> {
247        self.keyring
248            .remove_credential(provider, credential_type)
249            .await
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, SystemTime};

    #[test]
    fn test_credential_deserialization_with_alias() {
        // A store written before the rename uses "AuthTokens" both as the map key and as the
        // credential's "type" tag. It must still load as an OAuth2 credential.
        let legacy_json = r#"{
            "anthropic": {
                "AuthTokens": {
                    "type": "AuthTokens",
                    "access_token": "old_access_token",
                    "refresh_token": "old_refresh_token",
                    "expires_at": {
                        "secs_since_epoch": 1678886400,
                        "nanos_since_epoch": 0
                    }
                }
            }
        }"#;

        let parsed: CredentialStore =
            serde_json::from_str(legacy_json).expect("Failed to deserialize old format");

        let credential = parsed
            .0
            .get("anthropic")
            .and_then(|by_type| by_type.get(&CredentialType::OAuth2))
            .unwrap();

        let token = match credential {
            Credential::OAuth2(token) => token,
            Credential::ApiKey { .. } => {
                panic!("Deserialization failed: expected OAuth2 credential")
            }
        };

        let expected_expiry = SystemTime::UNIX_EPOCH + Duration::from_secs(1_678_886_400);
        assert_eq!(token.access_token, "old_access_token");
        assert_eq!(token.refresh_token, "old_refresh_token");
        assert_eq!(token.expires_at, expected_expiry);
    }

    #[test]
    fn test_credential_type_deserialization_with_alias() {
        // Both the legacy "AuthTokens" spelling and the current "OAuth2" spelling map to OAuth2.
        for raw in [r#""AuthTokens""#, r#""OAuth2""#] {
            let parsed: CredentialType = serde_json::from_str(raw).unwrap();
            assert_eq!(parsed, CredentialType::OAuth2);
        }
    }
}