Skip to main content

xet_client/cas_client/
auth.rs

1use std::fmt::Debug;
2use std::sync::Arc;
3#[cfg(not(target_family = "wasm"))]
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use derivative::Derivative;
7use reqwest_middleware::ClientWithMiddleware;
8use thiserror::Error;
9
10use crate::common::auth::CredentialHelper;
11
/// Errors that can occur while obtaining or refreshing an auth token.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum AuthError {
    /// A refresh was requested but the refresh function cannot be called.
    /// The payload is interpolated into the error message.
    #[error("Refresh function: {0} is not callable")]
    RefreshFunctionNotCallable(String),

    /// A token refresh was attempted and failed.
    /// The payload describes the failure (e.g. the stringified underlying error).
    #[error("Token refresh failed: {0}")]
    TokenRefreshFailure(String),
}
21
22impl AuthError {
23    pub fn token_refresh_failure(err: impl ToString) -> Self {
24        Self::TokenRefreshFailure(err.to_string())
25    }
26}
27
/// Safety margin in seconds: a token is treated as expired (and therefore refreshed) once the
/// current time is within this many seconds of its expiration time.
const REFRESH_BUFFER_SEC: u64 = 30;
30
/// Helper type for information about an auth token:
/// `(token, expiration)`, where the expiration is a unix timestamp in seconds.
pub type TokenInfo = (String, u64);
34
/// Helper to provide auth tokens to CAS.
///
/// The trait is expanded by `async_trait`; on wasm targets the `?Send` variant is used, so the
/// generated futures are not required to be `Send` there.
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
pub trait TokenRefresher: Send + Sync {
    /// Get a new auth token for CAS and the unixtime (in seconds) for expiration.
    ///
    /// # Errors
    /// Returns an [`AuthError`] when a new token cannot be produced.
    async fn refresh(&self) -> Result<TokenInfo, AuthError>;
}
42
43#[derive(Debug)]
44pub struct NoOpTokenRefresher;
45
46#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
47#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
48impl TokenRefresher for NoOpTokenRefresher {
49    async fn refresh(&self) -> Result<TokenInfo, AuthError> {
50        Ok(("token".to_string(), 0))
51    }
52}
53
54#[derive(Debug)]
55pub struct ErrTokenRefresher;
56
57#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
58#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
59impl TokenRefresher for ErrTokenRefresher {
60    async fn refresh(&self) -> Result<TokenInfo, AuthError> {
61        Err(AuthError::RefreshFunctionNotCallable("Token refresh not expected".to_string()))
62    }
63}
64
/// Token refresher that fetches a new token by making an authenticated GET request to a URL.
///
/// An optional [`CredentialHelper`](crate::common::auth::CredentialHelper) is applied to the
/// request before it is sent; pass `None` when no additional credentials are needed.
pub struct DirectRefreshRouteTokenRefresher {
    /// URL that is requested (GET) to obtain a fresh token.
    refresh_route: String,
    /// HTTP client used to issue the refresh request.
    client: ClientWithMiddleware,
    /// Optional helper that adds credentials to the request before it is sent.
    cred_helper: Option<Arc<dyn CredentialHelper>>,
}
74
75impl std::fmt::Debug for DirectRefreshRouteTokenRefresher {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        f.debug_struct("DirectRefreshRouteTokenRefresher")
78            .field("refresh_route", &self.refresh_route)
79            .finish_non_exhaustive()
80    }
81}
82
83impl DirectRefreshRouteTokenRefresher {
84    pub fn new(
85        refresh_route: impl Into<String>,
86        client: ClientWithMiddleware,
87        cred_helper: Option<Arc<dyn CredentialHelper>>,
88    ) -> Self {
89        Self {
90            refresh_route: refresh_route.into(),
91            client,
92            cred_helper,
93        }
94    }
95
96    pub async fn get_cas_jwt(&self) -> Result<crate::hub_client::CasJWTInfo, crate::ClientError> {
97        let client = self.client.clone();
98        let refresh_route = self.refresh_route.clone();
99        let cred_helper = self.cred_helper.clone();
100
101        let jwt_info: crate::hub_client::CasJWTInfo = super::retry_wrapper::RetryWrapper::new("xet-token")
102            .run_and_extract_json(move || {
103                let refresh_route = refresh_route.clone();
104                let client = client.clone();
105                let cred_helper = cred_helper.clone();
106                async move {
107                    let req = client
108                        .get(&refresh_route)
109                        .with_extension(crate::common::http_client::Api("xet-token"));
110                    let req = if let Some(helper) = cred_helper {
111                        helper
112                            .fill_credential(req)
113                            .await
114                            .map_err(reqwest_middleware::Error::middleware)?
115                    } else {
116                        req
117                    };
118                    req.send().await
119                }
120            })
121            .await?;
122
123        Ok(jwt_info)
124    }
125}
126
127#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
128#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
129impl TokenRefresher for DirectRefreshRouteTokenRefresher {
130    async fn refresh(&self) -> Result<TokenInfo, AuthError> {
131        let jwt_info = self.get_cas_jwt().await.map_err(AuthError::token_refresh_failure)?;
132
133        Ok((jwt_info.access_token, jwt_info.exp))
134    }
135}
136
/// Shared configuration for token-based auth.
///
/// The derived `Debug` output skips `token_refresher`.
#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub struct AuthConfig {
    /// Initial token to use
    pub token: String,
    /// Initial token expiration time in epoch seconds
    pub token_expiration: u64,
    /// Refresher used to obtain a new token (and its expiration) when needed.
    #[derivative(Debug = "ignore")]
    pub token_refresher: Arc<dyn TokenRefresher>,
}
149
150impl AuthConfig {
151    /// Builds a new AuthConfig from the indicated optional parameters.
152    pub fn maybe_new(
153        token: Option<String>,
154        token_expiry: Option<u64>,
155        token_refresher: Option<Arc<dyn TokenRefresher>>,
156    ) -> Option<Self> {
157        match (token, token_expiry, token_refresher) {
158            // we have a refresher, so use that. Doesn't matter if the token/expiry are set since we can refresh them.
159            (token, expiry, Some(refresher)) => Some(Self {
160                token: token.unwrap_or_default(),
161                token_expiration: expiry.unwrap_or_default(),
162                token_refresher: refresher,
163            }),
164            // Since no refreshing, we instead use the token with some expiration (no expiration means we expect this
165            // token to live forever.
166            (Some(token), expiry, None) => Some(Self {
167                token,
168                token_expiration: expiry.unwrap_or(u64::MAX),
169                token_refresher: Arc::new(ErrTokenRefresher),
170            }),
171            (_, _, _) => None,
172        }
173    }
174}
175
/// Holds the current auth token and refreshes it through a [`TokenRefresher`] when it is
/// about to expire.
pub struct TokenProvider {
    /// Most recently obtained token.
    token: String,
    /// Expiration of `token`, in epoch seconds.
    expiration: u64,
    /// Source of replacement tokens.
    refresher: Arc<dyn TokenRefresher>,
}
181
182impl TokenProvider {
183    pub fn new(cfg: &AuthConfig) -> Self {
184        Self {
185            token: cfg.token.clone(),
186            expiration: cfg.token_expiration,
187            refresher: cfg.token_refresher.clone(),
188        }
189    }
190
191    pub async fn get_valid_token(&mut self) -> Result<String, AuthError> {
192        if self.is_expired() {
193            let (new_token, new_expiry) = self.refresher.refresh().await?;
194            self.token = new_token;
195            self.expiration = new_expiry;
196        }
197        Ok(self.token.clone())
198    }
199
200    fn is_expired(&self) -> bool {
201        #[cfg(not(target_family = "wasm"))]
202        let cur_time = SystemTime::now()
203            .duration_since(UNIX_EPOCH)
204            .map(|d| d.as_secs())
205            .unwrap_or(u64::MAX);
206        #[cfg(target_family = "wasm")]
207        let cur_time = web_time::SystemTime::now()
208            .duration_since(web_time::UNIX_EPOCH)
209            .map(|d| d.as_secs())
210            .unwrap_or(u64::MAX);
211        self.expiration <= cur_time + REFRESH_BUFFER_SEC
212    }
213}