steer_core/auth/storage.rs

1use crate::auth::error::{AuthError, Result};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::SystemTime;
7use strum::Display;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct AuthTokens {
11    pub access_token: String,
12    pub refresh_token: String,
13    pub expires_at: SystemTime,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17#[serde(tag = "type")]
18pub enum Credential {
19    AuthTokens(AuthTokens),
20    ApiKey { value: String },
21}
22
23impl Credential {
24    pub fn credential_type(&self) -> CredentialType {
25        match self {
26            Credential::AuthTokens(_) => CredentialType::AuthTokens,
27            Credential::ApiKey { .. } => CredentialType::ApiKey,
28        }
29    }
30}
31
32#[derive(Debug, Clone, Copy, Serialize, Deserialize, Display, PartialEq, Eq, Hash)]
33pub enum CredentialType {
34    AuthTokens,
35    ApiKey,
36}
37
38/// Collection of all credentials kept in the keyring. The first key is the
39/// provider id (e.g. `"anthropic"`), the second key is the credential type
40/// (`"AuthTokens"` / `"ApiKey"`). Each leaf holds the raw `Credential` value
41/// for that pair.
42#[derive(Debug, Clone, Serialize, Deserialize, Default)]
43struct CredentialStore(HashMap<String, HashMap<CredentialType, Credential>>);
44
45#[async_trait]
46pub trait AuthStorage: Send + Sync {
47    async fn get_credential(
48        &self,
49        provider: &str,
50        credential_type: CredentialType,
51    ) -> Result<Option<Credential>>;
52    async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()>;
53    async fn remove_credential(
54        &self,
55        provider: &str,
56        credential_type: CredentialType,
57    ) -> Result<()>;
58}
59
/// [`AuthStorage`] backed by the operating-system keyring.
///
/// All credentials live together in one keyring entry, identified by
/// `service_name` and the current user's name.
pub struct KeyringStorage {
    /// Keyring service name the entry is filed under.
    service_name: String,
}
64
65impl Default for KeyringStorage {
66    fn default() -> Self {
67        Self::new("steer")
68    }
69}
70
71impl KeyringStorage {
72    pub fn new(service_name: &str) -> Self {
73        Self {
74            service_name: service_name.to_string(),
75        }
76    }
77
78    fn get_username() -> String {
79        whoami::username()
80    }
81}
82
83#[async_trait]
84impl AuthStorage for KeyringStorage {
85    async fn get_credential(
86        &self,
87        provider: &str,
88        credential_type: CredentialType,
89    ) -> Result<Option<Credential>> {
90        let provider = provider.to_string();
91        let username = Self::get_username();
92        let service = self.service_name.clone();
93
94        // Load, parse and query the credential store
95        tokio::task::spawn_blocking(
96            move || -> std::result::Result<Option<Credential>, keyring::Error> {
97                let entry = keyring::Entry::new(&service, &username)?;
98                let store_json = match entry.get_password() {
99                    Ok(pwd) => pwd,
100                    Err(keyring::Error::NoEntry) => return Ok(None),
101                    Err(e) => return Err(e),
102                };
103
104                let store: CredentialStore = serde_json::from_str(&store_json).unwrap_or_default();
105                let cred = store
106                    .0
107                    .get(&provider)
108                    .and_then(|m| m.get(&credential_type))
109                    .cloned();
110                Ok(cred)
111            },
112        )
113        .await
114        .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?
115        .map_err(AuthError::from)
116    }
117
118    async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
119        let service = self.service_name.clone();
120        let username = Self::get_username();
121        let provider = provider.to_string();
122        let cred_type = credential.credential_type();
123
124        tokio::task::spawn_blocking(move || -> std::result::Result<(), keyring::Error> {
125            let entry = keyring::Entry::new(&service, &username)?;
126            // Load existing store (if any)
127            let mut store: CredentialStore = match entry.get_password() {
128                Ok(pwd) => serde_json::from_str(&pwd).unwrap_or_default(),
129                Err(keyring::Error::NoEntry) => CredentialStore::default(),
130                Err(e) => return Err(e),
131            };
132
133            // Update
134            store
135                .0
136                .entry(provider)
137                .or_default()
138                .insert(cred_type, credential);
139
140            let data = serde_json::to_string(&store).expect("serialize credential store");
141            entry.set_password(&data)?;
142            Ok(())
143        })
144        .await
145        .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?
146        .map_err(AuthError::from)
147    }
148
149    async fn remove_credential(
150        &self,
151        provider: &str,
152        credential_type: CredentialType,
153    ) -> Result<()> {
154        let service = self.service_name.clone();
155        let username = Self::get_username();
156        let provider = provider.to_string();
157
158        tokio::task::spawn_blocking(move || -> std::result::Result<(), keyring::Error> {
159            let entry = keyring::Entry::new(&service, &username)?;
160
161            // Load existing store, return Ok if none
162            let store_json = match entry.get_password() {
163                Ok(pwd) => pwd,
164                Err(keyring::Error::NoEntry) => return Ok(()),
165                Err(e) => return Err(e),
166            };
167
168            let mut store: CredentialStore = serde_json::from_str(&store_json).unwrap_or_default();
169
170            if let Some(map) = store.0.get_mut(&provider) {
171                map.remove(&credential_type);
172                if map.is_empty() {
173                    store.0.remove(&provider);
174                }
175            }
176
177            if store.0.is_empty() {
178                // No credentials left – remove the keyring entry entirely.
179                let _ = entry.delete_credential();
180            } else {
181                let data = serde_json::to_string(&store).expect("serialize credential store");
182                entry.set_password(&data)?;
183            }
184            Ok(())
185        })
186        .await
187        .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?
188        .map_err(AuthError::from)
189    }
190}
191
192/// Default storage implementation that tries keyring first, then falls back to encrypted file
193pub struct DefaultAuthStorage {
194    keyring: Arc<dyn AuthStorage>,
195}
196
197impl DefaultAuthStorage {
198    pub fn new() -> Result<Self> {
199        // Try to create keyring storage
200        if !cfg!(any(
201            target_os = "macos",
202            target_os = "windows",
203            target_os = "linux"
204        )) {
205            return Err(AuthError::Storage(
206                "Keyring not supported on this platform".to_string(),
207            ));
208        }
209
210        let keyring = Arc::new(KeyringStorage::new("steer")) as Arc<dyn AuthStorage>;
211
212        Ok(Self { keyring })
213    }
214
215    // Convenience methods for working with specific credential types
216    pub async fn get_auth_tokens(&self, provider: &str) -> Result<Option<AuthTokens>> {
217        match self
218            .get_credential(provider, CredentialType::AuthTokens)
219            .await?
220        {
221            Some(Credential::AuthTokens(tokens)) => Ok(Some(tokens)),
222            _ => Ok(None),
223        }
224    }
225
226    pub async fn set_auth_tokens(&self, provider: &str, tokens: AuthTokens) -> Result<()> {
227        self.set_credential(provider, Credential::AuthTokens(tokens))
228            .await
229    }
230
231    pub async fn get_api_key(&self, provider: &str) -> Result<Option<String>> {
232        match self
233            .get_credential(provider, CredentialType::ApiKey)
234            .await?
235        {
236            Some(Credential::ApiKey { value }) => Ok(Some(value)),
237            _ => Ok(None),
238        }
239    }
240
241    pub async fn set_api_key(&self, provider: &str, api_key: String) -> Result<()> {
242        self.set_credential(provider, Credential::ApiKey { value: api_key })
243            .await
244    }
245
246    pub async fn remove_auth_tokens(&self, provider: &str) -> Result<()> {
247        self.remove_credential(provider, CredentialType::AuthTokens)
248            .await
249    }
250
251    pub async fn remove_api_key(&self, provider: &str) -> Result<()> {
252        self.remove_credential(provider, CredentialType::ApiKey)
253            .await
254    }
255}
256
257#[async_trait]
258impl AuthStorage for DefaultAuthStorage {
259    async fn get_credential(
260        &self,
261        provider: &str,
262        credential_type: CredentialType,
263    ) -> Result<Option<Credential>> {
264        self.keyring.get_credential(provider, credential_type).await
265    }
266
267    async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
268        self.keyring
269            .set_credential(provider, credential.clone())
270            .await
271    }
272
273    async fn remove_credential(
274        &self,
275        provider: &str,
276        credential_type: CredentialType,
277    ) -> Result<()> {
278        self.keyring
279            .remove_credential(provider, credential_type)
280            .await
281    }
282}