Skip to main content

uv_auth/
pyx.rs

1use std::io;
2use std::path::{Path, PathBuf};
3use std::time::Duration;
4
5use base64::Engine;
6use base64::prelude::BASE64_URL_SAFE_NO_PAD;
7use etcetera::BaseStrategy;
8use reqwest_middleware::ClientWithMiddleware;
9use tracing::debug;
10use url::Url;
11use uv_fs::{LockedFile, LockedFileMode};
12
13use uv_cache_key::CanonicalUrl;
14use uv_redacted::{DisplaySafeUrl, DisplaySafeUrlError};
15use uv_small_str::SmallString;
16use uv_state::{StateBucket, StateStore};
17use uv_static::EnvVars;
18
19use crate::credentials::Token;
20use crate::{AccessToken, Credentials, Realm};
21
22/// The default pyx API URL.
23const PYX_DEFAULT_API_URL: &str = "https://api.pyx.dev";
24
25/// The default pyx CDN domain.
26const PYX_DEFAULT_CDN_DOMAIN: &str = "astralhosted.com";
27
28/// Retrieve the pyx API key from the environment variable, or return `None`.
29fn read_pyx_api_key() -> Option<String> {
30    std::env::var(EnvVars::PYX_API_KEY)
31        .ok()
32        .or_else(|| std::env::var(EnvVars::UV_API_KEY).ok())
33}
34
35/// Retrieve the pyx authentication token (JWT) from the environment variable, or return `None`.
36fn read_pyx_auth_token() -> Option<AccessToken> {
37    std::env::var(EnvVars::PYX_AUTH_TOKEN)
38        .ok()
39        .or_else(|| std::env::var(EnvVars::UV_AUTH_TOKEN).ok())
40        .map(AccessToken::from)
41}
42
43/// An access token with an accompanying refresh token.
44///
45/// Refresh tokens are single-use tokens that can be exchanged for a renewed access token
46/// and a new refresh token.
47#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
48pub struct PyxOAuthTokens {
49    pub access_token: AccessToken,
50    pub refresh_token: String,
51}
52
53/// An access token with an accompanying API key.
54#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
55pub struct PyxApiKeyTokens {
56    pub access_token: AccessToken,
57    pub api_key: String,
58}
59
60#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
61pub enum PyxTokens {
62    /// An access token with an accompanying refresh token.
63    ///
64    /// Refresh tokens are single-use tokens that can be exchanged for a renewed access token
65    /// and a new refresh token.
66    OAuth(PyxOAuthTokens),
67    /// An access token with an accompanying API key.
68    ///
69    /// API keys are long-lived tokens that can be exchanged for an access token.
70    ApiKey(PyxApiKeyTokens),
71}
72
73impl From<PyxTokens> for AccessToken {
74    fn from(tokens: PyxTokens) -> Self {
75        match tokens {
76            PyxTokens::OAuth(PyxOAuthTokens { access_token, .. }) => access_token,
77            PyxTokens::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
78        }
79    }
80}
81
82impl From<PyxTokens> for Credentials {
83    fn from(tokens: PyxTokens) -> Self {
84        let access_token = match tokens {
85            PyxTokens::OAuth(PyxOAuthTokens { access_token, .. }) => access_token,
86            PyxTokens::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
87        };
88        Self::from(access_token)
89    }
90}
91
92impl From<AccessToken> for Credentials {
93    fn from(access_token: AccessToken) -> Self {
94        Self::Bearer {
95            token: Token::new(access_token.into_bytes()),
96        }
97    }
98}
99
100/// Reason why a token is considered expired and needs refresh.
101#[derive(Debug, Clone)]
102enum ExpiredTokenReason {
103    /// The token has no expiration claim.
104    MissingExpiration,
105    /// Zero tolerance was requested, forcing a refresh.
106    ForcedRefresh,
107    /// The token's expiration time has passed.
108    Expired(jiff::Timestamp),
109    /// The token will expire within the tolerance window.
110    ExpiringSoon(jiff::Timestamp),
111}
112
113impl std::fmt::Display for ExpiredTokenReason {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        match self {
116            Self::MissingExpiration => write!(f, "missing expiration"),
117            Self::ForcedRefresh => write!(f, "forced refresh"),
118            Self::Expired(exp) => write!(f, "token expired (`{exp}`)"),
119            Self::ExpiringSoon(exp) => write!(f, "token will expire within tolerance (`{exp}`)"),
120        }
121    }
122}
123
124impl PyxTokens {
125    /// Returns the access token.
126    fn access_token(&self) -> &AccessToken {
127        match self {
128            Self::OAuth(PyxOAuthTokens { access_token, .. }) => access_token,
129            Self::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
130        }
131    }
132
133    /// Check if the token is fresh (not expired and not expiring within tolerance).
134    ///
135    /// Returns `Ok(expiration)` if fresh, or `Err(reason)` if refresh is needed.
136    fn check_fresh(&self, tolerance_secs: u64) -> Result<jiff::Timestamp, ExpiredTokenReason> {
137        let Ok(jwt) = PyxJwt::decode(self.access_token()) else {
138            return Err(ExpiredTokenReason::MissingExpiration);
139        };
140        match jwt.exp {
141            None => Err(ExpiredTokenReason::MissingExpiration),
142            Some(_) if tolerance_secs == 0 => Err(ExpiredTokenReason::ForcedRefresh),
143            Some(exp) => {
144                let Ok(exp) = jiff::Timestamp::from_second(exp) else {
145                    return Err(ExpiredTokenReason::MissingExpiration);
146                };
147                let now = jiff::Timestamp::now();
148                if exp < now {
149                    Err(ExpiredTokenReason::Expired(exp))
150                } else if exp < now + Duration::from_secs(tolerance_secs) {
151                    Err(ExpiredTokenReason::ExpiringSoon(exp))
152                } else {
153                    Ok(exp)
154                }
155            }
156        }
157    }
158}
159
/// The default tolerance, in seconds, for the access token expiration (five minutes).
///
/// Intended to be passed as the `tolerance_secs` argument of the token store: a token expiring
/// within this window of the current time is refreshed ahead of time.
pub const DEFAULT_TOLERANCE_SECS: u64 = 60 * 5;
162
/// The on-disk locations used by the pyx token store.
#[derive(Debug, Clone)]
struct PyxDirectories {
    /// The root directory for the token store (e.g., `/Users/ferris/.local/share/pyx/credentials`).
    root: PathBuf,
    /// The per-API subdirectory of `root`, named after a digest of the API URL (e.g.,
    /// `/Users/ferris/.local/share/pyx/credentials/3859a629b26fda96`).
    subdirectory: PathBuf,
}
170
171impl PyxDirectories {
172    /// Detect the [`PyxDirectories`] for a given API URL.
173    fn from_api(api: &DisplaySafeUrl) -> Result<Self, io::Error> {
174        // Store credentials in a subdirectory based on the API URL.
175        let digest = uv_cache_key::cache_digest(&CanonicalUrl::new(api));
176
177        // If the user explicitly set `PYX_CREDENTIALS_DIR`, use that.
178        if let Some(root) = std::env::var_os(EnvVars::PYX_CREDENTIALS_DIR) {
179            let root = std::path::absolute(root)?;
180            let subdirectory = root.join(&digest);
181            return Ok(Self { root, subdirectory });
182        }
183
184        // If the user has pyx credentials in their uv credentials directory, read them for
185        // backwards compatibility.
186        let root = if let Some(tool_dir) = std::env::var_os(EnvVars::UV_CREDENTIALS_DIR) {
187            std::path::absolute(tool_dir)?
188        } else {
189            StateStore::from_settings(None)?.bucket(StateBucket::Credentials)
190        };
191        let subdirectory = root.join(&digest);
192        if subdirectory.exists() {
193            return Ok(Self { root, subdirectory });
194        }
195
196        // Otherwise, use (e.g.) `~/.local/share/pyx`.
197        let Ok(xdg) = etcetera::base_strategy::choose_base_strategy() else {
198            return Err(io::Error::new(
199                io::ErrorKind::NotFound,
200                "Could not determine user data directory",
201            ));
202        };
203
204        let root = xdg.data_dir().join("pyx").join("credentials");
205        let subdirectory = root.join(&digest);
206        Ok(Self { root, subdirectory })
207    }
208}
209
210#[derive(Debug, Clone)]
211pub struct PyxTokenStore {
212    /// The root directory for the token store (e.g., `/Users/ferris/.local/share/pyx/credentials`).
213    root: PathBuf,
214    /// The subdirectory for the token store (e.g., `/Users/ferris/.local/share/uv/credentials/3859a629b26fda96`).
215    subdirectory: PathBuf,
216    /// The API URL for the token store (e.g., `https://api.pyx.dev`).
217    api: DisplaySafeUrl,
218    /// The CDN domain for the token store (e.g., `astralhosted.com`).
219    cdn: SmallString,
220}
221
222impl PyxTokenStore {
223    /// Create a new [`PyxTokenStore`] from settings.
224    pub fn from_settings() -> Result<Self, TokenStoreError> {
225        // Read the API URL and CDN domain from the environment variables, or fallback to the
226        // defaults.
227        let api = if let Ok(api_url) = std::env::var(EnvVars::PYX_API_URL) {
228            DisplaySafeUrl::parse(&api_url)
229        } else {
230            DisplaySafeUrl::parse(PYX_DEFAULT_API_URL)
231        }?;
232        let cdn = std::env::var(EnvVars::PYX_CDN_DOMAIN)
233            .ok()
234            .map(SmallString::from)
235            .unwrap_or_else(|| SmallString::from(arcstr::literal!(PYX_DEFAULT_CDN_DOMAIN)));
236
237        // Determine the root directory for the token store.
238        let PyxDirectories { root, subdirectory } = PyxDirectories::from_api(&api)?;
239
240        Ok(Self {
241            root,
242            subdirectory,
243            api,
244            cdn,
245        })
246    }
247
248    /// Return the root directory for the token store.
249    pub fn root(&self) -> &Path {
250        &self.root
251    }
252
253    /// Return the API URL for the token store.
254    pub fn api(&self) -> &DisplaySafeUrl {
255        &self.api
256    }
257
258    /// Get or initialize an [`AccessToken`] from the store.
259    ///
260    /// If an access token is set in the environment, it will be returned as-is.
261    ///
262    /// If an access token is present on-disk, it will be returned (and refreshed, if necessary).
263    ///
264    /// If no access token is found, but an API key is present, the API key will be used to
265    /// bootstrap an access token.
266    pub async fn access_token(
267        &self,
268        client: &ClientWithMiddleware,
269        tolerance_secs: u64,
270    ) -> Result<Option<AccessToken>, TokenStoreError> {
271        // If the access token is already set in the environment, return it.
272        if let Some(access_token) = read_pyx_auth_token() {
273            return Ok(Some(access_token));
274        }
275
276        // Initialize the tokens from the store.
277        let tokens = self.init(client, tolerance_secs).await?;
278
279        // Extract the access token from the OAuth tokens or API key.
280        Ok(tokens.map(AccessToken::from))
281    }
282
283    /// Initialize the [`PyxTokens`] from the store.
284    ///
285    /// If an access token is already present, it will be returned (and refreshed, if necessary).
286    ///
287    /// If no access token is found, but an API key is present, the API key will be used to
288    /// bootstrap an access token.
289    pub async fn init(
290        &self,
291        client: &ClientWithMiddleware,
292        tolerance_secs: u64,
293    ) -> Result<Option<PyxTokens>, TokenStoreError> {
294        match self.read().await? {
295            Some(tokens) => {
296                // Refresh the tokens if they are expired.
297                let tokens = self.refresh(tokens, client, tolerance_secs).await?;
298                Ok(Some(tokens))
299            }
300            None => {
301                // If no tokens are present, bootstrap them from an API key.
302                self.bootstrap(client).await
303            }
304        }
305    }
306
307    /// Write the tokens to the store.
308    pub async fn write(&self, tokens: &PyxTokens) -> Result<(), TokenStoreError> {
309        fs_err::tokio::create_dir_all(&self.subdirectory).await?;
310        match tokens {
311            PyxTokens::OAuth(tokens) => {
312                // Write OAuth tokens to a generic `tokens.json` file.
313                fs_err::tokio::write(
314                    self.subdirectory.join("tokens.json"),
315                    serde_json::to_vec(tokens)?,
316                )
317                .await?;
318            }
319            PyxTokens::ApiKey(tokens) => {
320                // Write API key tokens to a file based on the API key.
321                let digest = uv_cache_key::cache_digest(&tokens.api_key);
322                fs_err::tokio::write(
323                    self.subdirectory.join(format!("{digest}.json")),
324                    &tokens.access_token,
325                )
326                .await?;
327            }
328        }
329        Ok(())
330    }
331
332    /// Returns `true` if the user appears to have an authentication token set.
333    pub fn has_auth_token(&self) -> bool {
334        read_pyx_auth_token().is_some()
335    }
336
337    /// Returns `true` if the user appears to have an API key set.
338    pub fn has_api_key(&self) -> bool {
339        read_pyx_api_key().is_some()
340    }
341
342    /// Returns `true` if the user appears to have OAuth tokens stored on disk.
343    pub fn has_oauth_tokens(&self) -> bool {
344        self.subdirectory.join("tokens.json").is_file()
345    }
346
347    /// Returns `true` if the user appears to have credentials (which may be invalid).
348    pub fn has_credentials(&self) -> bool {
349        self.has_auth_token() || self.has_api_key() || self.has_oauth_tokens()
350    }
351
352    /// Read the tokens from the store.
353    pub async fn read(&self) -> Result<Option<PyxTokens>, TokenStoreError> {
354        if let Some(api_key) = read_pyx_api_key() {
355            // Read the API key tokens from a file based on the API key.
356            let digest = uv_cache_key::cache_digest(&api_key);
357            match fs_err::tokio::read(self.subdirectory.join(format!("{digest}.json"))).await {
358                Ok(data) => {
359                    let access_token =
360                        AccessToken::from(String::from_utf8(data).expect("Invalid UTF-8"));
361                    Ok(Some(PyxTokens::ApiKey(PyxApiKeyTokens {
362                        access_token,
363                        api_key,
364                    })))
365                }
366                Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
367                Err(err) => Err(err.into()),
368            }
369        } else {
370            match fs_err::tokio::read(self.subdirectory.join("tokens.json")).await {
371                Ok(data) => {
372                    let tokens: PyxOAuthTokens = serde_json::from_slice(&data)?;
373                    Ok(Some(PyxTokens::OAuth(tokens)))
374                }
375                Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
376                Err(err) => Err(err.into()),
377            }
378        }
379    }
380
381    /// Remove the tokens from the store.
382    pub async fn delete(&self) -> Result<(), io::Error> {
383        fs_err::tokio::remove_dir_all(&self.subdirectory).await?;
384        Ok(())
385    }
386
387    /// Return the path to the refresh lock file for a given token type.
388    ///
389    /// For OAuth tokens, uses a fixed "tokens.lock" file.
390    /// For API key tokens, uses a file based on the API key digest.
391    fn lock_path(&self, tokens: &PyxTokens) -> PathBuf {
392        match tokens {
393            PyxTokens::OAuth(_) => self.subdirectory.join("tokens.lock"),
394            PyxTokens::ApiKey(PyxApiKeyTokens { api_key, .. }) => {
395                let digest = uv_cache_key::cache_digest(api_key);
396                self.subdirectory.join(format!("{digest}.lock"))
397            }
398        }
399    }
400
401    /// Bootstrap the tokens from the store.
402    async fn bootstrap(
403        &self,
404        client: &ClientWithMiddleware,
405    ) -> Result<Option<PyxTokens>, TokenStoreError> {
406        #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
407        struct Payload {
408            access_token: AccessToken,
409        }
410
411        // Retrieve the API key from the environment variable, if set.
412        let Some(api_key) = read_pyx_api_key() else {
413            return Ok(None);
414        };
415
416        debug!("Bootstrapping access token from an API key");
417
418        // Parse the API URL.
419        let mut url = self.api.clone();
420        url.set_path("auth/cli/access-token");
421
422        let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
423        request.headers_mut().insert(
424            "Authorization",
425            reqwest::header::HeaderValue::from_str(&format!("Bearer {api_key}"))?,
426        );
427
428        let response = client.execute(request).await?;
429        let Payload { access_token } = response.error_for_status()?.json::<Payload>().await?;
430        let tokens = PyxTokens::ApiKey(PyxApiKeyTokens {
431            access_token,
432            api_key,
433        });
434
435        // Write the tokens to disk.
436        self.write(&tokens).await?;
437
438        Ok(Some(tokens))
439    }
440
441    /// Refresh the tokens in the store, if they are expired.
442    ///
443    /// In theory, we should _also_ refresh if we hit a 401; but for now, we only refresh ahead of
444    /// time.
445    async fn refresh(
446        &self,
447        tokens: PyxTokens,
448        client: &ClientWithMiddleware,
449        tolerance_secs: u64,
450    ) -> Result<PyxTokens, TokenStoreError> {
451        let reason = match tokens.check_fresh(tolerance_secs) {
452            Ok(exp) => {
453                debug!("Access token is up-to-date (`{exp}`)");
454                return Ok(tokens);
455            }
456            Err(reason) => reason,
457        };
458        debug!("Refreshing token due to {reason}");
459
460        // Ensure the subdirectory exists before acquiring the lock
461        fs_err::tokio::create_dir_all(&self.subdirectory).await?;
462
463        // Get the lock path for this specific token
464        let lock_path = self.lock_path(&tokens);
465
466        // Acquire a lock to prevent concurrent refresh attempts for this token
467        let _lock = LockedFile::acquire(&lock_path, LockedFileMode::Exclusive, "pyx refresh")
468            .await
469            .map_err(|err| TokenStoreError::Io(io::Error::other(err.to_string())))?;
470
471        // Check if another process has already refreshed the tokens
472        if let Some(tokens) = self.read().await? {
473            match tokens.check_fresh(tolerance_secs) {
474                Ok(exp) => {
475                    debug!("Using recently refreshed token (`{exp}`)");
476                    return Ok(tokens);
477                }
478                Err(reason) => {
479                    debug!("Token on disk still needs refresh due to {reason}");
480                }
481            }
482        }
483
484        // Refresh the tokens
485        let tokens = match tokens {
486            PyxTokens::OAuth(PyxOAuthTokens { refresh_token, .. }) => {
487                // Parse the API URL.
488                let mut url = self.api.clone();
489                url.set_path("auth/cli/refresh");
490
491                let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
492                let body = serde_json::json!({
493                    "refresh_token": refresh_token
494                });
495                *request.body_mut() = Some(body.to_string().into());
496
497                let response = client.execute(request).await?;
498                let tokens = response
499                    .error_for_status()?
500                    .json::<PyxOAuthTokens>()
501                    .await?;
502                PyxTokens::OAuth(tokens)
503            }
504            PyxTokens::ApiKey(PyxApiKeyTokens { api_key, .. }) => {
505                #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
506                struct Payload {
507                    access_token: AccessToken,
508                }
509
510                // Parse the API URL.
511                let mut url = self.api.clone();
512                url.set_path("auth/cli/access-token");
513
514                let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
515                request.headers_mut().insert(
516                    "Authorization",
517                    reqwest::header::HeaderValue::from_str(&format!("Bearer {api_key}"))?,
518                );
519
520                let response = client.execute(request).await?;
521                let Payload { access_token } =
522                    response.error_for_status()?.json::<Payload>().await?;
523                PyxTokens::ApiKey(PyxApiKeyTokens {
524                    access_token,
525                    api_key,
526                })
527            }
528        };
529
530        // Write the new tokens to disk
531        self.write(&tokens).await?;
532
533        Ok(tokens)
534    }
535
536    /// Returns `true` if the given URL is "known" to this token store (i.e., should be
537    /// authenticated using the store's tokens).
538    pub fn is_known_url(&self, url: &Url) -> bool {
539        is_known_url(url, &self.api, &self.cdn)
540    }
541
542    /// Returns `true` if the URL is on a "known" domain (i.e., the same domain as the API or CDN).
543    ///
544    /// Like [`is_known_url`](Self::is_known_url), but also returns `true` if the API is on the
545    /// subdomain of the URL (e.g., if the API is `api.pyx.dev` and the URL is `pyx.dev`).
546    pub fn is_known_domain(&self, url: &Url) -> bool {
547        is_known_domain(url, &self.api, &self.cdn)
548    }
549}
550
551#[derive(thiserror::Error, Debug)]
552pub enum TokenStoreError {
553    #[error(transparent)]
554    Url(#[from] DisplaySafeUrlError),
555    #[error(transparent)]
556    Io(#[from] io::Error),
557    #[error(transparent)]
558    Serialization(#[from] serde_json::Error),
559    #[error(transparent)]
560    Reqwest(#[from] reqwest::Error),
561    #[error(transparent)]
562    ReqwestMiddleware(#[from] reqwest_middleware::Error),
563    #[error(transparent)]
564    InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue),
565    #[error(transparent)]
566    Jiff(#[from] jiff::Error),
567    #[error(transparent)]
568    Jwt(#[from] JwtError),
569}
570
571impl TokenStoreError {
572    /// Returns `true` if the error is a 401 (Unauthorized) error.
573    pub fn is_unauthorized(&self) -> bool {
574        match self {
575            Self::Reqwest(err) => err.status() == Some(reqwest::StatusCode::UNAUTHORIZED),
576            Self::ReqwestMiddleware(err) => err.status() == Some(reqwest::StatusCode::UNAUTHORIZED),
577            _ => false,
578        }
579    }
580}
581
582/// The payload of the JWT.
583#[derive(Debug, serde::Deserialize)]
584pub struct PyxJwt {
585    /// The expiration time of the JWT, as a Unix timestamp.
586    pub exp: Option<i64>,
587    /// The issuer of the JWT.
588    pub iss: Option<String>,
589    /// The name of the organization, if any.
590    #[serde(rename = "urn:pyx:org_name")]
591    pub name: Option<String>,
592}
593
594impl PyxJwt {
595    /// Decode the JWT from the access token.
596    pub fn decode(access_token: &AccessToken) -> Result<Self, JwtError> {
597        let mut token_segments = access_token.as_str().splitn(3, '.');
598
599        let _header = token_segments.next().ok_or(JwtError::MissingHeader)?;
600        let payload = token_segments.next().ok_or(JwtError::MissingPayload)?;
601        let _signature = token_segments.next().ok_or(JwtError::MissingSignature)?;
602        if token_segments.next().is_some() {
603            return Err(JwtError::TooManySegments);
604        }
605
606        let decoded = BASE64_URL_SAFE_NO_PAD.decode(payload)?;
607
608        let jwt = serde_json::from_slice::<Self>(&decoded)?;
609        Ok(jwt)
610    }
611}
612
613#[derive(thiserror::Error, Debug)]
614pub enum JwtError {
615    #[error("JWT is missing a header")]
616    MissingHeader,
617    #[error("JWT is missing a payload")]
618    MissingPayload,
619    #[error("JWT is missing a signature")]
620    MissingSignature,
621    #[error("JWT has too many segments")]
622    TooManySegments,
623    #[error(transparent)]
624    Base64(#[from] base64::DecodeError),
625    #[error(transparent)]
626    Serde(#[from] serde_json::Error),
627}
628
629fn is_known_url(url: &Url, api: &DisplaySafeUrl, cdn: &str) -> bool {
630    // Determine whether the URL matches the API realm.
631    if Realm::from(url) == Realm::from(&**api) {
632        return true;
633    }
634
635    // Determine whether the URL matches the CDN domain (or a subdomain of it).
636    //
637    // For example, if URL is on `files.astralhosted.com` and the CDN domain is
638    // `astralhosted.com`, consider it known.
639    if matches!(url.scheme(), "https") && matches_domain(url, cdn) {
640        return true;
641    }
642
643    false
644}
645
646fn is_known_domain(url: &Url, api: &DisplaySafeUrl, cdn: &str) -> bool {
647    // Determine whether the URL matches the API domain.
648    if let Some(domain) = url.domain() {
649        if matches_domain(api, domain) {
650            return true;
651        }
652    }
653    is_known_url(url, api, cdn)
654}
655
656/// Returns `true` if the URL is on the default pyx domain.
657///
658/// This is used in auth commands to recognize `pyx.dev` as a pyx domain even when
659/// `PYX_API_URL` points elsewhere (e.g., to a local development server).
660pub fn is_default_pyx_domain(url: &Url) -> bool {
661    let api = DisplaySafeUrl::parse(PYX_DEFAULT_API_URL).expect("default API URL should be valid");
662    is_known_domain(url, &api, PYX_DEFAULT_CDN_DOMAIN)
663}
664
665/// Returns `true` if the target URL is on the given domain.
666fn matches_domain(url: &Url, domain: &str) -> bool {
667    url.domain().is_some_and(|subdomain| {
668        subdomain == domain
669            || subdomain
670                .strip_suffix(domain)
671                .is_some_and(|prefix| prefix.ends_with('.'))
672    })
673}
674
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a URL literal used as test input.
    fn parse_url(raw: &str) -> Url {
        Url::parse(raw).unwrap()
    }

    /// The API URL and CDN domain shared by the `is_known_*` tests.
    fn store_settings() -> (DisplaySafeUrl, &'static str) {
        (
            DisplaySafeUrl::parse("https://api.pyx.dev").unwrap(),
            "astralhosted.com",
        )
    }

    #[test]
    fn test_is_known_url() {
        let (api_url, cdn_domain) = store_settings();
        let cases = [
            // Same realm as API.
            ("https://api.pyx.dev/simple/", true),
            // Different path on same API domain.
            ("https://api.pyx.dev/v1/", true),
            // CDN domain.
            ("https://astralhosted.com/packages/", true),
            // CDN subdomain.
            ("https://files.astralhosted.com/packages/", true),
            // CDN on HTTP.
            ("http://astralhosted.com/packages/", false),
            // Unknown domain.
            ("https://pypi.org/simple/", false),
            // Similar but not matching domain.
            ("https://badastralhosted.com/packages/", false),
        ];
        for (target, expected) in cases {
            assert_eq!(
                is_known_url(&parse_url(target), &api_url, cdn_domain),
                expected,
                "is_known_url({target})"
            );
        }
    }

    #[test]
    fn test_is_known_domain() {
        let (api_url, cdn_domain) = store_settings();
        let cases = [
            // Same realm as API.
            ("https://api.pyx.dev/simple/", true),
            // API super-domain.
            ("https://pyx.dev", true),
            // API subdomain.
            ("https://foo.api.pyx.dev", false),
            // Different subdomain.
            ("https://beta.pyx.dev/", false),
            // CDN domain.
            ("https://astralhosted.com/packages/", true),
            // CDN subdomain.
            ("https://files.astralhosted.com/packages/", true),
            // Unknown domain.
            ("https://pypi.org/simple/", false),
            // Different TLD.
            ("https://pyx.com/", false),
        ];
        for (target, expected) in cases {
            assert_eq!(
                is_known_domain(&parse_url(target), &api_url, cdn_domain),
                expected,
                "is_known_domain({target})"
            );
        }
    }

    #[test]
    fn test_is_default_pyx_domain() {
        let cases = [
            // pyx.dev is the default domain.
            ("https://pyx.dev", true),
            // The default API domain itself is recognized.
            ("https://api.pyx.dev", true),
            // The default CDN domain (and its subdomains) are also recognized.
            ("https://astralhosted.com", true),
            ("https://files.astralhosted.com", true),
            // Other domains are not.
            ("http://localhost:8000", false),
            ("https://pypi.org", false),
            ("https://pyx.com", false),
        ];
        for (target, expected) in cases {
            assert_eq!(
                is_default_pyx_domain(&parse_url(target)),
                expected,
                "is_default_pyx_domain({target})"
            );
        }
    }

    #[test]
    fn test_matches_domain() {
        let cases = [
            ("https://example.com", "example.com", true),
            ("https://foo.example.com", "example.com", true),
            ("https://bar.foo.example.com", "example.com", true),
            ("https://example.com", "other.com", false),
            ("https://example.org", "example.com", false),
            ("https://badexample.com", "example.com", false),
        ];
        for (target, domain, expected) in cases {
            assert_eq!(
                matches_domain(&parse_url(target), domain),
                expected,
                "matches_domain({target}, {domain})"
            );
        }
    }

    #[test]
    fn test_is_default_pyx_domain_staging() {
        // Staging URLs should NOT be recognized as default pyx domain; users must set
        // PYX_API_URL to use staging environments. Other non-default pyx subdomains should
        // not match either.
        for target in [
            "https://astral-sh-staging-api.pyx.dev",
            "https://beta.pyx.dev",
        ] {
            assert!(
                !is_default_pyx_domain(&parse_url(target)),
                "is_default_pyx_domain({target})"
            );
        }
    }
}