1use crate::auth::error::{AuthError, Result};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::SystemTime;
7use strum::Display;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct AuthTokens {
11 pub access_token: String,
12 pub refresh_token: String,
13 pub expires_at: SystemTime,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17#[serde(tag = "type")]
18pub enum Credential {
19 AuthTokens(AuthTokens),
20 ApiKey { value: String },
21}
22
23impl Credential {
24 pub fn credential_type(&self) -> CredentialType {
25 match self {
26 Credential::AuthTokens(_) => CredentialType::AuthTokens,
27 Credential::ApiKey { .. } => CredentialType::ApiKey,
28 }
29 }
30}
31
32#[derive(Debug, Clone, Copy, Serialize, Deserialize, Display, PartialEq, Eq, Hash)]
33pub enum CredentialType {
34 AuthTokens,
35 ApiKey,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize, Default)]
43struct CredentialStore(HashMap<String, HashMap<CredentialType, Credential>>);
44
45#[async_trait]
46pub trait AuthStorage: Send + Sync {
47 async fn get_credential(
48 &self,
49 provider: &str,
50 credential_type: CredentialType,
51 ) -> Result<Option<Credential>>;
52 async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()>;
53 async fn remove_credential(
54 &self,
55 provider: &str,
56 credential_type: CredentialType,
57 ) -> Result<()>;
58}
59
/// Credential backend that keeps everything in the operating-system keyring.
pub struct KeyringStorage {
    /// Keyring service name the credential entry is filed under.
    service_name: String,
}
64
65impl Default for KeyringStorage {
66 fn default() -> Self {
67 Self::new("steer")
68 }
69}
70
71impl KeyringStorage {
72 pub fn new(service_name: &str) -> Self {
73 Self {
74 service_name: service_name.to_string(),
75 }
76 }
77
78 fn get_username() -> String {
79 whoami::username()
80 }
81}
82
83#[async_trait]
84impl AuthStorage for KeyringStorage {
85 async fn get_credential(
86 &self,
87 provider: &str,
88 credential_type: CredentialType,
89 ) -> Result<Option<Credential>> {
90 let provider = provider.to_string();
91 let username = Self::get_username();
92 let service = self.service_name.clone();
93
94 tokio::task::spawn_blocking(
96 move || -> std::result::Result<Option<Credential>, keyring::Error> {
97 let entry = keyring::Entry::new(&service, &username)?;
98 let store_json = match entry.get_password() {
99 Ok(pwd) => pwd,
100 Err(keyring::Error::NoEntry) => return Ok(None),
101 Err(e) => return Err(e),
102 };
103
104 let store: CredentialStore = serde_json::from_str(&store_json).unwrap_or_default();
105 let cred = store
106 .0
107 .get(&provider)
108 .and_then(|m| m.get(&credential_type))
109 .cloned();
110 Ok(cred)
111 },
112 )
113 .await
114 .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?
115 .map_err(AuthError::from)
116 }
117
118 async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
119 let service = self.service_name.clone();
120 let username = Self::get_username();
121 let provider = provider.to_string();
122 let cred_type = credential.credential_type();
123
124 tokio::task::spawn_blocking(move || -> std::result::Result<(), keyring::Error> {
125 let entry = keyring::Entry::new(&service, &username)?;
126 let mut store: CredentialStore = match entry.get_password() {
128 Ok(pwd) => serde_json::from_str(&pwd).unwrap_or_default(),
129 Err(keyring::Error::NoEntry) => CredentialStore::default(),
130 Err(e) => return Err(e),
131 };
132
133 store
135 .0
136 .entry(provider)
137 .or_default()
138 .insert(cred_type, credential);
139
140 let data = serde_json::to_string(&store).expect("serialize credential store");
141 entry.set_password(&data)?;
142 Ok(())
143 })
144 .await
145 .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?
146 .map_err(AuthError::from)
147 }
148
149 async fn remove_credential(
150 &self,
151 provider: &str,
152 credential_type: CredentialType,
153 ) -> Result<()> {
154 let service = self.service_name.clone();
155 let username = Self::get_username();
156 let provider = provider.to_string();
157
158 tokio::task::spawn_blocking(move || -> std::result::Result<(), keyring::Error> {
159 let entry = keyring::Entry::new(&service, &username)?;
160
161 let store_json = match entry.get_password() {
163 Ok(pwd) => pwd,
164 Err(keyring::Error::NoEntry) => return Ok(()),
165 Err(e) => return Err(e),
166 };
167
168 let mut store: CredentialStore = serde_json::from_str(&store_json).unwrap_or_default();
169
170 if let Some(map) = store.0.get_mut(&provider) {
171 map.remove(&credential_type);
172 if map.is_empty() {
173 store.0.remove(&provider);
174 }
175 }
176
177 if store.0.is_empty() {
178 let _ = entry.delete_credential();
180 } else {
181 let data = serde_json::to_string(&store).expect("serialize credential store");
182 entry.set_password(&data)?;
183 }
184 Ok(())
185 })
186 .await
187 .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?
188 .map_err(AuthError::from)
189 }
190}
191
192pub struct DefaultAuthStorage {
194 keyring: Arc<dyn AuthStorage>,
195}
196
197impl DefaultAuthStorage {
198 pub fn new() -> Result<Self> {
199 if !cfg!(any(
201 target_os = "macos",
202 target_os = "windows",
203 target_os = "linux"
204 )) {
205 return Err(AuthError::Storage(
206 "Keyring not supported on this platform".to_string(),
207 ));
208 }
209
210 let keyring = Arc::new(KeyringStorage::new("steer")) as Arc<dyn AuthStorage>;
211
212 Ok(Self { keyring })
213 }
214
215 pub async fn get_auth_tokens(&self, provider: &str) -> Result<Option<AuthTokens>> {
217 match self
218 .get_credential(provider, CredentialType::AuthTokens)
219 .await?
220 {
221 Some(Credential::AuthTokens(tokens)) => Ok(Some(tokens)),
222 _ => Ok(None),
223 }
224 }
225
226 pub async fn set_auth_tokens(&self, provider: &str, tokens: AuthTokens) -> Result<()> {
227 self.set_credential(provider, Credential::AuthTokens(tokens))
228 .await
229 }
230
231 pub async fn get_api_key(&self, provider: &str) -> Result<Option<String>> {
232 match self
233 .get_credential(provider, CredentialType::ApiKey)
234 .await?
235 {
236 Some(Credential::ApiKey { value }) => Ok(Some(value)),
237 _ => Ok(None),
238 }
239 }
240
241 pub async fn set_api_key(&self, provider: &str, api_key: String) -> Result<()> {
242 self.set_credential(provider, Credential::ApiKey { value: api_key })
243 .await
244 }
245
246 pub async fn remove_auth_tokens(&self, provider: &str) -> Result<()> {
247 self.remove_credential(provider, CredentialType::AuthTokens)
248 .await
249 }
250
251 pub async fn remove_api_key(&self, provider: &str) -> Result<()> {
252 self.remove_credential(provider, CredentialType::ApiKey)
253 .await
254 }
255}
256
257#[async_trait]
258impl AuthStorage for DefaultAuthStorage {
259 async fn get_credential(
260 &self,
261 provider: &str,
262 credential_type: CredentialType,
263 ) -> Result<Option<Credential>> {
264 self.keyring.get_credential(provider, credential_type).await
265 }
266
267 async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
268 self.keyring
269 .set_credential(provider, credential.clone())
270 .await
271 }
272
273 async fn remove_credential(
274 &self,
275 provider: &str,
276 credential_type: CredentialType,
277 ) -> Result<()> {
278 self.keyring
279 .remove_credential(provider, credential_type)
280 .await
281 }
282}