Skip to main content

xet_client/cas_client/
auth.rs

1use std::fmt::Debug;
2use std::sync::Arc;
3#[cfg(not(target_family = "wasm"))]
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use reqwest_middleware::ClientWithMiddleware;
7use thiserror::Error;
8use tracing::info;
9
10use crate::common::auth::CredentialHelper;
11
12#[derive(Debug, Error)]
13#[non_exhaustive]
14pub enum AuthError {
15    #[error("Refresh function: {0} is not callable")]
16    RefreshFunctionNotCallable(String),
17
18    #[error("Token refresh failed: {0}")]
19    TokenRefreshFailure(String),
20}
21
22impl AuthError {
23    pub fn token_refresh_failure(err: impl ToString) -> Self {
24        Self::TokenRefreshFailure(err.to_string())
25    }
26}
27
/// Safety margin in seconds: a token is treated as expired this long before its
/// stated expiration so that it is replaced before it actually lapses.
const REFRESH_BUFFER_SEC: u64 = 30;

/// An auth token paired with its expiration time (unix epoch seconds).
pub type TokenInfo = (String, u64);
34
35/// Helper to provide auth tokens to CAS.
36#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
37#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
38pub trait TokenRefresher: Send + Sync {
39    /// Get a new auth token for CAS and the unixtime (in seconds) for expiration
40    async fn refresh(&self) -> Result<TokenInfo, AuthError>;
41}
42
43#[derive(Debug)]
44pub struct NoOpTokenRefresher;
45
46#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
47#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
48impl TokenRefresher for NoOpTokenRefresher {
49    async fn refresh(&self) -> Result<TokenInfo, AuthError> {
50        Ok(("token".to_string(), 0))
51    }
52}
53
54#[derive(Debug)]
55pub struct ErrTokenRefresher;
56
57#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
58#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
59impl TokenRefresher for ErrTokenRefresher {
60    async fn refresh(&self) -> Result<TokenInfo, AuthError> {
61        Err(AuthError::RefreshFunctionNotCallable("Token refresh not expected".to_string()))
62    }
63}
64
65/// Token refresher that fetches a new token by making an authenticated GET request to a URL.
66///
67/// An optional [`CredentialHelper`](crate::common::auth::CredentialHelper) is applied to the
68/// request before it is sent; pass `None` when no additional credentials are needed.
69pub struct DirectRefreshRouteTokenRefresher {
70    refresh_route: String,
71    client: ClientWithMiddleware,
72    cred_helper: Option<Arc<dyn CredentialHelper>>,
73}
74
75impl std::fmt::Debug for DirectRefreshRouteTokenRefresher {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        f.debug_struct("DirectRefreshRouteTokenRefresher")
78            .field("refresh_route", &self.refresh_route)
79            .finish_non_exhaustive()
80    }
81}
82
83impl DirectRefreshRouteTokenRefresher {
84    pub fn new(
85        refresh_route: impl Into<String>,
86        client: ClientWithMiddleware,
87        cred_helper: Option<Arc<dyn CredentialHelper>>,
88    ) -> Self {
89        Self {
90            refresh_route: refresh_route.into(),
91            client,
92            cred_helper,
93        }
94    }
95
96    pub async fn get_cas_jwt(&self) -> Result<crate::hub_client::CasJWTInfo, crate::ClientError> {
97        let client = self.client.clone();
98        let refresh_route = self.refresh_route.clone();
99        let cred_helper = self.cred_helper.clone();
100
101        let jwt_info: crate::hub_client::CasJWTInfo = super::retry_wrapper::RetryWrapper::new("xet-token")
102            .run_and_extract_json(move || {
103                let refresh_route = refresh_route.clone();
104                let client = client.clone();
105                let cred_helper = cred_helper.clone();
106                async move {
107                    let req = client
108                        .get(&refresh_route)
109                        .with_extension(crate::common::http_client::Api("xet-token"));
110                    let req = if let Some(helper) = cred_helper {
111                        helper
112                            .fill_credential(req)
113                            .await
114                            .map_err(reqwest_middleware::Error::middleware)?
115                    } else {
116                        req
117                    };
118                    req.send().await
119                }
120            })
121            .await?;
122
123        Ok(jwt_info)
124    }
125}
126
127#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
128#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
129impl TokenRefresher for DirectRefreshRouteTokenRefresher {
130    async fn refresh(&self) -> Result<TokenInfo, AuthError> {
131        let jwt_info = self.get_cas_jwt().await.map_err(AuthError::token_refresh_failure)?;
132
133        Ok((jwt_info.access_token, jwt_info.exp))
134    }
135}
136
137/// Shared configuration for token-based auth
138#[derive(Clone)]
139pub struct AuthConfig {
140    /// Initial token to use
141    pub token: String,
142    /// Initial token expiration time in epoch seconds
143    pub token_expiration: u64,
144    /// A function to refresh tokens.
145    pub token_refresher: Arc<dyn TokenRefresher>,
146}
147
148impl Debug for AuthConfig {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        f.debug_struct("AuthConfig")
151            .field("token", &self.token)
152            .field("token_expiration", &self.token_expiration)
153            .finish_non_exhaustive()
154    }
155}
156
157impl AuthConfig {
158    /// Builds a new AuthConfig from the indicated optional parameters.
159    pub fn maybe_new(
160        token: Option<String>,
161        token_expiry: Option<u64>,
162        token_refresher: Option<Arc<dyn TokenRefresher>>,
163    ) -> Option<Self> {
164        match (token, token_expiry, token_refresher) {
165            // we have a refresher, so use that. Doesn't matter if the token/expiry are set since we can refresh them.
166            (token, expiry, Some(refresher)) => Some(Self {
167                token: token.unwrap_or_default(),
168                token_expiration: expiry.unwrap_or_default(),
169                token_refresher: refresher,
170            }),
171            // Since no refreshing, we instead use the token with some expiration (no expiration means we expect this
172            // token to live forever.
173            (Some(token), expiry, None) => Some(Self {
174                token,
175                token_expiration: expiry.unwrap_or(u64::MAX),
176                token_refresher: Arc::new(ErrTokenRefresher),
177            }),
178            (_, _, _) => None,
179        }
180    }
181}
182
183pub struct TokenProvider {
184    token: String,
185    expiration: u64,
186    refresher: Arc<dyn TokenRefresher>,
187}
188
189impl TokenProvider {
190    pub fn new(cfg: &AuthConfig) -> Self {
191        Self {
192            token: cfg.token.clone(),
193            expiration: cfg.token_expiration,
194            refresher: cfg.token_refresher.clone(),
195        }
196    }
197
198    pub async fn get_valid_token(&mut self) -> Result<String, AuthError> {
199        if self.is_expired() {
200            let (new_token, new_expiry) = self.refresher.refresh().await?;
201            self.token = new_token;
202            self.expiration = new_expiry;
203            info!(new_expiry = new_expiry, "Token refreshed");
204        }
205        Ok(self.token.clone())
206    }
207
208    fn is_expired(&self) -> bool {
209        #[cfg(not(target_family = "wasm"))]
210        let cur_time = SystemTime::now()
211            .duration_since(UNIX_EPOCH)
212            .map(|d| d.as_secs())
213            .unwrap_or(u64::MAX);
214        #[cfg(target_family = "wasm")]
215        let cur_time = web_time::SystemTime::now()
216            .duration_since(web_time::UNIX_EPOCH)
217            .map(|d| d.as_secs())
218            .unwrap_or(u64::MAX);
219        self.expiration <= cur_time + REFRESH_BUFFER_SEC
220    }
221}