steer_core/auth/
storage.rs1use crate::auth::error::{AuthError, Result};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::Arc;
6pub use steer_auth_plugin::storage::{
7 AuthStorage, AuthTokens, Credential, CredentialType, OAuth2Token,
8};
9
10#[derive(Debug, Clone, Serialize, Deserialize, Default)]
15struct CredentialStore(HashMap<String, HashMap<CredentialType, Credential>>);
16
/// [`AuthStorage`] implementation that keeps every credential in the system keyring.
///
/// All credentials live in a single keyring entry, addressed by `service_name` and the
/// current OS user name.
// Debug/Clone only expose the service name; no secret material is held by this type.
#[derive(Debug, Clone)]
pub struct KeyringStorage {
    /// Keyring service name under which the credential entry is stored.
    service_name: String,
}
21
22impl Default for KeyringStorage {
23 fn default() -> Self {
24 Self::new("steer")
25 }
26}
27
28impl KeyringStorage {
29 pub fn new(service_name: &str) -> Self {
30 Self {
31 service_name: service_name.to_string(),
32 }
33 }
34
35 fn get_username() -> String {
36 whoami::username()
37 }
38}
39
40#[async_trait]
41impl AuthStorage for KeyringStorage {
42 async fn get_credential(
43 &self,
44 provider: &str,
45 credential_type: CredentialType,
46 ) -> Result<Option<Credential>> {
47 let provider = provider.to_string();
48 let username = Self::get_username();
49 let service = self.service_name.clone();
50 let cred_type = credential_type;
51
52 let result = tokio::task::spawn_blocking(
54 move || -> std::result::Result<Option<Credential>, keyring::Error> {
55 let entry = keyring::Entry::new(&service, &username)?;
56 let store_json = match entry.get_password() {
57 Ok(pwd) => pwd,
58 Err(keyring::Error::NoEntry) => return Ok(None),
59 Err(e) => return Err(e),
60 };
61
62 let store: CredentialStore = serde_json::from_str(&store_json).unwrap_or_default();
63
64 let cred = store
67 .0
68 .get(&provider)
69 .and_then(|m| m.get(&cred_type))
70 .cloned();
71
72 Ok(cred)
73 },
74 )
75 .await
76 .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?;
77
78 result.map_err(AuthError::from)
79 }
80
81 async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
82 let service = self.service_name.clone();
83 let username = Self::get_username();
84 let provider = provider.to_string();
85 let cred_type = credential.credential_type();
86
87 tokio::task::spawn_blocking(move || -> std::result::Result<(), keyring::Error> {
88 let entry = keyring::Entry::new(&service, &username)?;
89 let mut store: CredentialStore = match entry.get_password() {
91 Ok(pwd) => serde_json::from_str(&pwd).unwrap_or_default(),
92 Err(keyring::Error::NoEntry) => CredentialStore::default(),
93 Err(e) => return Err(e),
94 };
95
96 store
98 .0
99 .entry(provider)
100 .or_default()
101 .insert(cred_type, credential);
102
103 let data = serde_json::to_string(&store).map_err(|e| {
104 keyring::Error::PlatformFailure(format!("serialize credential store: {e}").into())
105 })?;
106 entry.set_password(&data)?;
107 Ok(())
108 })
109 .await
110 .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?
111 .map_err(AuthError::from)
112 }
113
114 async fn remove_credential(
115 &self,
116 provider: &str,
117 credential_type: CredentialType,
118 ) -> Result<()> {
119 let service = self.service_name.clone();
120 let username = Self::get_username();
121 let provider = provider.to_string();
122
123 tokio::task::spawn_blocking(move || -> std::result::Result<(), keyring::Error> {
124 let entry = keyring::Entry::new(&service, &username)?;
125
126 let store_json = match entry.get_password() {
128 Ok(pwd) => pwd,
129 Err(keyring::Error::NoEntry) => return Ok(()),
130 Err(e) => return Err(e),
131 };
132
133 let mut store: CredentialStore = serde_json::from_str(&store_json).unwrap_or_default();
134
135 if let Some(map) = store.0.get_mut(&provider) {
136 map.remove(&credential_type);
137 if map.is_empty() {
138 store.0.remove(&provider);
139 }
140 }
141
142 if store.0.is_empty() {
143 let _ = entry.delete_credential();
145 } else {
146 let data = serde_json::to_string(&store).map_err(|e| {
147 keyring::Error::PlatformFailure(
148 format!("serialize credential store: {e}").into(),
149 )
150 })?;
151 entry.set_password(&data)?;
152 }
153 Ok(())
154 })
155 .await
156 .map_err(|e| AuthError::Storage(format!("Task join error: {e}")))?
157 .map_err(AuthError::from)
158 }
159}
160
161pub struct DefaultAuthStorage {
163 keyring: Arc<dyn AuthStorage>,
164}
165
166impl DefaultAuthStorage {
167 pub fn new() -> Result<Self> {
168 if !cfg!(any(
170 target_os = "macos",
171 target_os = "windows",
172 target_os = "linux"
173 )) {
174 return Err(AuthError::Storage(
175 "Keyring not supported on this platform".to_string(),
176 ));
177 }
178
179 let keyring = Arc::new(KeyringStorage::new("steer")) as Arc<dyn AuthStorage>;
180
181 Ok(Self { keyring })
182 }
183
184 pub async fn get_auth_tokens(&self, provider: &str) -> Result<Option<OAuth2Token>> {
186 match self
187 .get_credential(provider, CredentialType::OAuth2)
188 .await?
189 {
190 Some(Credential::OAuth2(tokens)) => Ok(Some(tokens)),
191 _ => Ok(None),
192 }
193 }
194
195 pub async fn set_auth_tokens(&self, provider: &str, tokens: OAuth2Token) -> Result<()> {
196 self.set_credential(provider, Credential::OAuth2(tokens))
197 .await
198 }
199
200 pub async fn get_api_key(&self, provider: &str) -> Result<Option<String>> {
201 match self
202 .get_credential(provider, CredentialType::ApiKey)
203 .await?
204 {
205 Some(Credential::ApiKey { value }) => Ok(Some(value)),
206 _ => Ok(None),
207 }
208 }
209
210 pub async fn set_api_key(&self, provider: &str, api_key: String) -> Result<()> {
211 self.set_credential(provider, Credential::ApiKey { value: api_key })
212 .await
213 }
214
215 pub async fn remove_auth_tokens(&self, provider: &str) -> Result<()> {
216 self.remove_credential(provider, CredentialType::OAuth2)
217 .await
218 }
219
220 pub async fn remove_api_key(&self, provider: &str) -> Result<()> {
221 self.remove_credential(provider, CredentialType::ApiKey)
222 .await
223 }
224}
225
226#[async_trait]
227impl AuthStorage for DefaultAuthStorage {
228 async fn get_credential(
229 &self,
230 provider: &str,
231 credential_type: CredentialType,
232 ) -> Result<Option<Credential>> {
233 self.keyring.get_credential(provider, credential_type).await
234 }
235
236 async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
237 self.keyring
238 .set_credential(provider, credential.clone())
239 .await
240 }
241
242 async fn remove_credential(
243 &self,
244 provider: &str,
245 credential_type: CredentialType,
246 ) -> Result<()> {
247 self.keyring
248 .remove_credential(provider, credential_type)
249 .await
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, SystemTime};

    /// Stores written with the legacy `AuthTokens` type name must still load as OAuth2.
    #[test]
    fn test_credential_deserialization_with_alias() {
        let old_json_format = r#"{
            "anthropic": {
                "AuthTokens": {
                    "type": "AuthTokens",
                    "access_token": "old_access_token",
                    "refresh_token": "old_refresh_token",
                    "expires_at": {
                        "secs_since_epoch": 1678886400,
                        "nanos_since_epoch": 0
                    }
                }
            }
        }"#;

        let store: CredentialStore =
            serde_json::from_str(old_json_format).expect("Failed to deserialize old format");

        let creds = store.0.get("anthropic").unwrap();
        let cred = creds.get(&CredentialType::OAuth2).unwrap();

        match cred {
            Credential::OAuth2(token) => {
                assert_eq!(token.access_token, "old_access_token");
                assert_eq!(token.refresh_token, "old_refresh_token");
                assert_eq!(
                    token.expires_at,
                    SystemTime::UNIX_EPOCH + Duration::from_secs(1_678_886_400)
                );
            }
            Credential::ApiKey { .. } => {
                panic!("Deserialization failed: expected OAuth2 credential")
            }
        }
    }

    /// Both the legacy (`AuthTokens`) and current (`OAuth2`) type names map to OAuth2.
    #[test]
    fn test_credential_type_deserialization_with_alias() {
        let old_type = r#""AuthTokens""#;
        let cred_type: CredentialType = serde_json::from_str(old_type).unwrap();
        assert_eq!(cred_type, CredentialType::OAuth2);

        let new_type = r#""OAuth2""#;
        let cred_type: CredentialType = serde_json::from_str(new_type).unwrap();
        assert_eq!(cred_type, CredentialType::OAuth2);
    }

    /// A store in the current format must survive a serialize/deserialize cycle.
    #[test]
    fn test_credential_store_round_trip() {
        let mut store = CredentialStore::default();
        store.0.entry("openai".to_string()).or_default().insert(
            CredentialType::ApiKey,
            Credential::ApiKey {
                value: "sk-test".to_string(),
            },
        );

        let json = serde_json::to_string(&store).expect("serialize store");
        let restored: CredentialStore = serde_json::from_str(&json).expect("deserialize store");

        let cred = restored
            .0
            .get("openai")
            .and_then(|creds| creds.get(&CredentialType::ApiKey));
        match cred {
            Some(Credential::ApiKey { value }) => assert_eq!(value.as_str(), "sk-test"),
            other => panic!("expected ApiKey credential, got {other:?}"),
        }
    }

    /// An empty store is written as an empty JSON object and read back as empty.
    #[test]
    fn test_empty_credential_store_round_trip() {
        let store: CredentialStore = serde_json::from_str("{}").expect("parse empty store");
        assert!(store.0.is_empty());
        assert_eq!(serde_json::to_string(&store).expect("serialize"), "{}");
    }
}