1use crate::auth::error::{AuthError, Result};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::SystemTime;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct OAuth2Token {
10 pub access_token: String,
11 pub refresh_token: String,
12 pub expires_at: SystemTime,
13}
14
15pub type AuthTokens = OAuth2Token;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19#[serde(tag = "type")]
20pub enum Credential {
21 #[serde(alias = "AuthTokens")]
22 OAuth2(OAuth2Token),
23 ApiKey {
24 value: String,
25 },
26}
27
28impl Credential {
29 pub fn credential_type(&self) -> CredentialType {
30 match self {
31 Credential::OAuth2(_) => CredentialType::OAuth2,
32 Credential::ApiKey { .. } => CredentialType::ApiKey,
33 }
34 }
35}
36
37#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
38pub enum CredentialType {
39 #[serde(alias = "AuthTokens")]
40 OAuth2,
41 ApiKey,
42}
43
44impl std::fmt::Display for CredentialType {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 match self {
47 CredentialType::OAuth2 => write!(f, "OAuth2"),
48 CredentialType::ApiKey => write!(f, "ApiKey"),
49 }
50 }
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize, Default)]
58struct CredentialStore(HashMap<String, HashMap<CredentialType, Credential>>);
59
60#[async_trait]
61pub trait AuthStorage: Send + Sync {
62 async fn get_credential(
63 &self,
64 provider: &str,
65 credential_type: CredentialType,
66 ) -> Result<Option<Credential>>;
67 async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()>;
68 async fn remove_credential(
69 &self,
70 provider: &str,
71 credential_type: CredentialType,
72 ) -> Result<()>;
73}
74
/// [`AuthStorage`] backend that keeps credentials in the operating system keyring.
pub struct KeyringStorage {
    /// Service name passed to `keyring::Entry::new` when opening the entry.
    service_name: String,
}
79
80impl Default for KeyringStorage {
81 fn default() -> Self {
82 Self::new("steer")
83 }
84}
85
86impl KeyringStorage {
87 pub fn new(service_name: &str) -> Self {
88 Self {
89 service_name: service_name.to_string(),
90 }
91 }
92
93 fn get_username() -> String {
94 whoami::username()
95 }
96}
97
98#[async_trait]
99impl AuthStorage for KeyringStorage {
100 async fn get_credential(
101 &self,
102 provider: &str,
103 credential_type: CredentialType,
104 ) -> Result<Option<Credential>> {
105 let provider = provider.to_string();
106 let username = Self::get_username();
107 let service = self.service_name.clone();
108 let cred_type = credential_type;
109
110 let result = tokio::task::spawn_blocking(
112 move || -> std::result::Result<Option<Credential>, keyring::Error> {
113 let entry = keyring::Entry::new(&service, &username)?;
114 let store_json = match entry.get_password() {
115 Ok(pwd) => pwd,
116 Err(keyring::Error::NoEntry) => return Ok(None),
117 Err(e) => return Err(e),
118 };
119
120 let store: CredentialStore = serde_json::from_str(&store_json).unwrap_or_default();
121
122 let cred = store
125 .0
126 .get(&provider)
127 .and_then(|m| m.get(&cred_type))
128 .cloned();
129
130 Ok(cred)
131 },
132 )
133 .await
134 .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?;
135
136 result.map_err(AuthError::from)
137 }
138
139 async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
140 let service = self.service_name.clone();
141 let username = Self::get_username();
142 let provider = provider.to_string();
143 let cred_type = credential.credential_type();
144
145 tokio::task::spawn_blocking(move || -> std::result::Result<(), keyring::Error> {
146 let entry = keyring::Entry::new(&service, &username)?;
147 let mut store: CredentialStore = match entry.get_password() {
149 Ok(pwd) => serde_json::from_str(&pwd).unwrap_or_default(),
150 Err(keyring::Error::NoEntry) => CredentialStore::default(),
151 Err(e) => return Err(e),
152 };
153
154 store
156 .0
157 .entry(provider)
158 .or_default()
159 .insert(cred_type, credential);
160
161 let data = serde_json::to_string(&store).expect("serialize credential store");
162 entry.set_password(&data)?;
163 Ok(())
164 })
165 .await
166 .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?
167 .map_err(AuthError::from)
168 }
169
170 async fn remove_credential(
171 &self,
172 provider: &str,
173 credential_type: CredentialType,
174 ) -> Result<()> {
175 let service = self.service_name.clone();
176 let username = Self::get_username();
177 let provider = provider.to_string();
178
179 tokio::task::spawn_blocking(move || -> std::result::Result<(), keyring::Error> {
180 let entry = keyring::Entry::new(&service, &username)?;
181
182 let store_json = match entry.get_password() {
184 Ok(pwd) => pwd,
185 Err(keyring::Error::NoEntry) => return Ok(()),
186 Err(e) => return Err(e),
187 };
188
189 let mut store: CredentialStore = serde_json::from_str(&store_json).unwrap_or_default();
190
191 if let Some(map) = store.0.get_mut(&provider) {
192 map.remove(&credential_type);
193 if map.is_empty() {
194 store.0.remove(&provider);
195 }
196 }
197
198 if store.0.is_empty() {
199 let _ = entry.delete_credential();
201 } else {
202 let data = serde_json::to_string(&store).expect("serialize credential store");
203 entry.set_password(&data)?;
204 }
205 Ok(())
206 })
207 .await
208 .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?
209 .map_err(AuthError::from)
210 }
211}
212
213pub struct DefaultAuthStorage {
215 keyring: Arc<dyn AuthStorage>,
216}
217
218impl DefaultAuthStorage {
219 pub fn new() -> Result<Self> {
220 if !cfg!(any(
222 target_os = "macos",
223 target_os = "windows",
224 target_os = "linux"
225 )) {
226 return Err(AuthError::Storage(
227 "Keyring not supported on this platform".to_string(),
228 ));
229 }
230
231 let keyring = Arc::new(KeyringStorage::new("steer")) as Arc<dyn AuthStorage>;
232
233 Ok(Self { keyring })
234 }
235
236 pub async fn get_auth_tokens(&self, provider: &str) -> Result<Option<OAuth2Token>> {
238 match self
239 .get_credential(provider, CredentialType::OAuth2)
240 .await?
241 {
242 Some(Credential::OAuth2(tokens)) => Ok(Some(tokens)),
243 _ => Ok(None),
244 }
245 }
246
247 pub async fn set_auth_tokens(&self, provider: &str, tokens: OAuth2Token) -> Result<()> {
248 self.set_credential(provider, Credential::OAuth2(tokens))
249 .await
250 }
251
252 pub async fn get_api_key(&self, provider: &str) -> Result<Option<String>> {
253 match self
254 .get_credential(provider, CredentialType::ApiKey)
255 .await?
256 {
257 Some(Credential::ApiKey { value }) => Ok(Some(value)),
258 _ => Ok(None),
259 }
260 }
261
262 pub async fn set_api_key(&self, provider: &str, api_key: String) -> Result<()> {
263 self.set_credential(provider, Credential::ApiKey { value: api_key })
264 .await
265 }
266
267 pub async fn remove_auth_tokens(&self, provider: &str) -> Result<()> {
268 self.remove_credential(provider, CredentialType::OAuth2)
269 .await
270 }
271
272 pub async fn remove_api_key(&self, provider: &str) -> Result<()> {
273 self.remove_credential(provider, CredentialType::ApiKey)
274 .await
275 }
276}
277
278#[async_trait]
279impl AuthStorage for DefaultAuthStorage {
280 async fn get_credential(
281 &self,
282 provider: &str,
283 credential_type: CredentialType,
284 ) -> Result<Option<Credential>> {
285 self.keyring.get_credential(provider, credential_type).await
286 }
287
288 async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
289 self.keyring
290 .set_credential(provider, credential.clone())
291 .await
292 }
293
294 async fn remove_credential(
295 &self,
296 provider: &str,
297 credential_type: CredentialType,
298 ) -> Result<()> {
299 self.keyring
300 .remove_credential(provider, credential_type)
301 .await
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, SystemTime};

    /// A store written with the `"AuthTokens"` name, both as map key and as
    /// `"type"` tag, must load as an OAuth2 credential.
    #[test]
    fn test_credential_deserialization_with_alias() {
        let old_json_format = r#"{
            "anthropic": {
                "AuthTokens": {
                    "type": "AuthTokens",
                    "access_token": "old_access_token",
                    "refresh_token": "old_refresh_token",
                    "expires_at": {
                        "secs_since_epoch": 1678886400,
                        "nanos_since_epoch": 0
                    }
                }
            }
        }"#;

        let store: CredentialStore =
            serde_json::from_str(old_json_format).expect("Failed to deserialize old format");

        // The aliased key must be reachable under the current variant.
        let cred = store
            .0
            .get("anthropic")
            .and_then(|creds| creds.get(&CredentialType::OAuth2))
            .unwrap();

        if let Credential::OAuth2(token) = cred {
            let expected_expiry = SystemTime::UNIX_EPOCH + Duration::from_secs(1678886400);
            assert_eq!(token.access_token, "old_access_token");
            assert_eq!(token.refresh_token, "old_refresh_token");
            assert_eq!(token.expires_at, expected_expiry);
        } else {
            panic!("Deserialization failed: expected OAuth2 credential");
        }
    }

    /// Both the aliased and the current spelling decode to `CredentialType::OAuth2`.
    #[test]
    fn test_credential_type_deserialization_with_alias() {
        for raw in [r#""AuthTokens""#, r#""OAuth2""#] {
            let cred_type: CredentialType = serde_json::from_str(raw).unwrap();
            assert_eq!(cred_type, CredentialType::OAuth2);
        }
    }
}