// xet_client/cas_client/auth.rs
use std::fmt::Debug;
use std::sync::Arc;
#[cfg(not(target_family = "wasm"))]
use std::time::{SystemTime, UNIX_EPOCH};

use derivative::Derivative;
use reqwest_middleware::ClientWithMiddleware;
use thiserror::Error;

use crate::common::auth::CredentialHelper;

12#[derive(Debug, Error)]
13#[non_exhaustive]
14pub enum AuthError {
15 #[error("Refresh function: {0} is not callable")]
16 RefreshFunctionNotCallable(String),
17
18 #[error("Token refresh failed: {0}")]
19 TokenRefreshFailure(String),
20}
21
22impl AuthError {
23 pub fn token_refresh_failure(err: impl ToString) -> Self {
24 Self::TokenRefreshFailure(err.to_string())
25 }
26}
27
/// Safety margin in seconds: a token is considered expired once the current
/// time is within this many seconds of its recorded expiration, so it is
/// renewed before it actually lapses.
const REFRESH_BUFFER_SEC: u64 = 30;

/// `(token, expiration)` pair returned by a [`TokenRefresher`]; the expiration
/// is compared against the current Unix time in seconds.
pub type TokenInfo = (String, u64);

35#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
37#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
38pub trait TokenRefresher: Send + Sync {
39 async fn refresh(&self) -> Result<TokenInfo, AuthError>;
41}
42
43#[derive(Debug)]
44pub struct NoOpTokenRefresher;
45
46#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
47#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
48impl TokenRefresher for NoOpTokenRefresher {
49 async fn refresh(&self) -> Result<TokenInfo, AuthError> {
50 Ok(("token".to_string(), 0))
51 }
52}
53
54#[derive(Debug)]
55pub struct ErrTokenRefresher;
56
57#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
58#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
59impl TokenRefresher for ErrTokenRefresher {
60 async fn refresh(&self) -> Result<TokenInfo, AuthError> {
61 Err(AuthError::RefreshFunctionNotCallable("Token refresh not expected".to_string()))
62 }
63}
64
65pub struct DirectRefreshRouteTokenRefresher {
70 refresh_route: String,
71 client: ClientWithMiddleware,
72 cred_helper: Option<Arc<dyn CredentialHelper>>,
73}
74
75impl std::fmt::Debug for DirectRefreshRouteTokenRefresher {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.debug_struct("DirectRefreshRouteTokenRefresher")
78 .field("refresh_route", &self.refresh_route)
79 .finish_non_exhaustive()
80 }
81}
82
83impl DirectRefreshRouteTokenRefresher {
84 pub fn new(
85 refresh_route: impl Into<String>,
86 client: ClientWithMiddleware,
87 cred_helper: Option<Arc<dyn CredentialHelper>>,
88 ) -> Self {
89 Self {
90 refresh_route: refresh_route.into(),
91 client,
92 cred_helper,
93 }
94 }
95
96 pub async fn get_cas_jwt(&self) -> Result<crate::hub_client::CasJWTInfo, crate::ClientError> {
97 let client = self.client.clone();
98 let refresh_route = self.refresh_route.clone();
99 let cred_helper = self.cred_helper.clone();
100
101 let jwt_info: crate::hub_client::CasJWTInfo = super::retry_wrapper::RetryWrapper::new("xet-token")
102 .run_and_extract_json(move || {
103 let refresh_route = refresh_route.clone();
104 let client = client.clone();
105 let cred_helper = cred_helper.clone();
106 async move {
107 let req = client
108 .get(&refresh_route)
109 .with_extension(crate::common::http_client::Api("xet-token"));
110 let req = if let Some(helper) = cred_helper {
111 helper
112 .fill_credential(req)
113 .await
114 .map_err(reqwest_middleware::Error::middleware)?
115 } else {
116 req
117 };
118 req.send().await
119 }
120 })
121 .await?;
122
123 Ok(jwt_info)
124 }
125}
126
127#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
128#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
129impl TokenRefresher for DirectRefreshRouteTokenRefresher {
130 async fn refresh(&self) -> Result<TokenInfo, AuthError> {
131 let jwt_info = self.get_cas_jwt().await.map_err(AuthError::token_refresh_failure)?;
132
133 Ok((jwt_info.access_token, jwt_info.exp))
134 }
135}
136
137#[derive(Clone, Derivative)]
139#[derivative(Debug)]
140pub struct AuthConfig {
141 pub token: String,
143 pub token_expiration: u64,
145 #[derivative(Debug = "ignore")]
147 pub token_refresher: Arc<dyn TokenRefresher>,
148}
149
150impl AuthConfig {
151 pub fn maybe_new(
153 token: Option<String>,
154 token_expiry: Option<u64>,
155 token_refresher: Option<Arc<dyn TokenRefresher>>,
156 ) -> Option<Self> {
157 match (token, token_expiry, token_refresher) {
158 (token, expiry, Some(refresher)) => Some(Self {
160 token: token.unwrap_or_default(),
161 token_expiration: expiry.unwrap_or_default(),
162 token_refresher: refresher,
163 }),
164 (Some(token), expiry, None) => Some(Self {
167 token,
168 token_expiration: expiry.unwrap_or(u64::MAX),
169 token_refresher: Arc::new(ErrTokenRefresher),
170 }),
171 (_, _, _) => None,
172 }
173 }
174}
175
176pub struct TokenProvider {
177 token: String,
178 expiration: u64,
179 refresher: Arc<dyn TokenRefresher>,
180}
181
182impl TokenProvider {
183 pub fn new(cfg: &AuthConfig) -> Self {
184 Self {
185 token: cfg.token.clone(),
186 expiration: cfg.token_expiration,
187 refresher: cfg.token_refresher.clone(),
188 }
189 }
190
191 pub async fn get_valid_token(&mut self) -> Result<String, AuthError> {
192 if self.is_expired() {
193 let (new_token, new_expiry) = self.refresher.refresh().await?;
194 self.token = new_token;
195 self.expiration = new_expiry;
196 }
197 Ok(self.token.clone())
198 }
199
200 fn is_expired(&self) -> bool {
201 #[cfg(not(target_family = "wasm"))]
202 let cur_time = SystemTime::now()
203 .duration_since(UNIX_EPOCH)
204 .map(|d| d.as_secs())
205 .unwrap_or(u64::MAX);
206 #[cfg(target_family = "wasm")]
207 let cur_time = web_time::SystemTime::now()
208 .duration_since(web_time::UNIX_EPOCH)
209 .map(|d| d.as_secs())
210 .unwrap_or(u64::MAX);
211 self.expiration <= cur_time + REFRESH_BUFFER_SEC
212 }
213}