xet_client/cas_client/
auth.rs1use std::fmt::Debug;
2use std::sync::Arc;
3#[cfg(not(target_family = "wasm"))]
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use reqwest_middleware::ClientWithMiddleware;
7use thiserror::Error;
8use tracing::info;
9
10use crate::common::auth::CredentialHelper;
11
12#[derive(Debug, Error)]
13#[non_exhaustive]
14pub enum AuthError {
15 #[error("Refresh function: {0} is not callable")]
16 RefreshFunctionNotCallable(String),
17
18 #[error("Token refresh failed: {0}")]
19 TokenRefreshFailure(String),
20}
21
22impl AuthError {
23 pub fn token_refresh_failure(err: impl ToString) -> Self {
24 Self::TokenRefreshFailure(err.to_string())
25 }
26}
27
/// Safety margin, in seconds: a token this close to its expiration is already treated as expired.
const REFRESH_BUFFER_SEC: u64 = 30;
30
/// A token paired with its expiration, expressed in seconds since the Unix epoch.
pub type TokenInfo = (String, u64);
34
35#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
37#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
38pub trait TokenRefresher: Send + Sync {
39 async fn refresh(&self) -> Result<TokenInfo, AuthError>;
41}
42
43#[derive(Debug)]
44pub struct NoOpTokenRefresher;
45
46#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
47#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
48impl TokenRefresher for NoOpTokenRefresher {
49 async fn refresh(&self) -> Result<TokenInfo, AuthError> {
50 Ok(("token".to_string(), 0))
51 }
52}
53
54#[derive(Debug)]
55pub struct ErrTokenRefresher;
56
57#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
58#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
59impl TokenRefresher for ErrTokenRefresher {
60 async fn refresh(&self) -> Result<TokenInfo, AuthError> {
61 Err(AuthError::RefreshFunctionNotCallable("Token refresh not expected".to_string()))
62 }
63}
64
65pub struct DirectRefreshRouteTokenRefresher {
70 refresh_route: String,
71 client: ClientWithMiddleware,
72 cred_helper: Option<Arc<dyn CredentialHelper>>,
73}
74
75impl std::fmt::Debug for DirectRefreshRouteTokenRefresher {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.debug_struct("DirectRefreshRouteTokenRefresher")
78 .field("refresh_route", &self.refresh_route)
79 .finish_non_exhaustive()
80 }
81}
82
83impl DirectRefreshRouteTokenRefresher {
84 pub fn new(
85 refresh_route: impl Into<String>,
86 client: ClientWithMiddleware,
87 cred_helper: Option<Arc<dyn CredentialHelper>>,
88 ) -> Self {
89 Self {
90 refresh_route: refresh_route.into(),
91 client,
92 cred_helper,
93 }
94 }
95
96 pub async fn get_cas_jwt(&self) -> Result<crate::hub_client::CasJWTInfo, crate::ClientError> {
97 let client = self.client.clone();
98 let refresh_route = self.refresh_route.clone();
99 let cred_helper = self.cred_helper.clone();
100
101 let jwt_info: crate::hub_client::CasJWTInfo = super::retry_wrapper::RetryWrapper::new("xet-token")
102 .run_and_extract_json(move || {
103 let refresh_route = refresh_route.clone();
104 let client = client.clone();
105 let cred_helper = cred_helper.clone();
106 async move {
107 let req = client
108 .get(&refresh_route)
109 .with_extension(crate::common::http_client::Api("xet-token"));
110 let req = if let Some(helper) = cred_helper {
111 helper
112 .fill_credential(req)
113 .await
114 .map_err(reqwest_middleware::Error::middleware)?
115 } else {
116 req
117 };
118 req.send().await
119 }
120 })
121 .await?;
122
123 Ok(jwt_info)
124 }
125}
126
127#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
128#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
129impl TokenRefresher for DirectRefreshRouteTokenRefresher {
130 async fn refresh(&self) -> Result<TokenInfo, AuthError> {
131 let jwt_info = self.get_cas_jwt().await.map_err(AuthError::token_refresh_failure)?;
132
133 Ok((jwt_info.access_token, jwt_info.exp))
134 }
135}
136
137#[derive(Clone)]
139pub struct AuthConfig {
140 pub token: String,
142 pub token_expiration: u64,
144 pub token_refresher: Arc<dyn TokenRefresher>,
146}
147
148impl Debug for AuthConfig {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 f.debug_struct("AuthConfig")
151 .field("token", &self.token)
152 .field("token_expiration", &self.token_expiration)
153 .finish_non_exhaustive()
154 }
155}
156
157impl AuthConfig {
158 pub fn maybe_new(
160 token: Option<String>,
161 token_expiry: Option<u64>,
162 token_refresher: Option<Arc<dyn TokenRefresher>>,
163 ) -> Option<Self> {
164 match (token, token_expiry, token_refresher) {
165 (token, expiry, Some(refresher)) => Some(Self {
167 token: token.unwrap_or_default(),
168 token_expiration: expiry.unwrap_or_default(),
169 token_refresher: refresher,
170 }),
171 (Some(token), expiry, None) => Some(Self {
174 token,
175 token_expiration: expiry.unwrap_or(u64::MAX),
176 token_refresher: Arc::new(ErrTokenRefresher),
177 }),
178 (_, _, _) => None,
179 }
180 }
181}
182
183pub struct TokenProvider {
184 token: String,
185 expiration: u64,
186 refresher: Arc<dyn TokenRefresher>,
187}
188
189impl TokenProvider {
190 pub fn new(cfg: &AuthConfig) -> Self {
191 Self {
192 token: cfg.token.clone(),
193 expiration: cfg.token_expiration,
194 refresher: cfg.token_refresher.clone(),
195 }
196 }
197
198 pub async fn get_valid_token(&mut self) -> Result<String, AuthError> {
199 if self.is_expired() {
200 let (new_token, new_expiry) = self.refresher.refresh().await?;
201 self.token = new_token;
202 self.expiration = new_expiry;
203 info!(new_expiry = new_expiry, "Token refreshed");
204 }
205 Ok(self.token.clone())
206 }
207
208 fn is_expired(&self) -> bool {
209 #[cfg(not(target_family = "wasm"))]
210 let cur_time = SystemTime::now()
211 .duration_since(UNIX_EPOCH)
212 .map(|d| d.as_secs())
213 .unwrap_or(u64::MAX);
214 #[cfg(target_family = "wasm")]
215 let cur_time = web_time::SystemTime::now()
216 .duration_since(web_time::UNIX_EPOCH)
217 .map(|d| d.as_secs())
218 .unwrap_or(u64::MAX);
219 self.expiration <= cur_time + REFRESH_BUFFER_SEC
220 }
221}