1use std::io;
2use std::path::{Path, PathBuf};
3use std::time::Duration;
4
5use base64::Engine;
6use base64::prelude::BASE64_URL_SAFE_NO_PAD;
7use etcetera::BaseStrategy;
8use reqwest_middleware::ClientWithMiddleware;
9use tracing::debug;
10use url::Url;
11use uv_fs::{LockedFile, LockedFileMode};
12
13use uv_cache_key::CanonicalUrl;
14use uv_redacted::{DisplaySafeUrl, DisplaySafeUrlError};
15use uv_small_str::SmallString;
16use uv_state::{StateBucket, StateStore};
17use uv_static::EnvVars;
18
19use crate::credentials::Token;
20use crate::{AccessToken, Credentials, Realm};
21
/// The pyx API base URL used when `PYX_API_URL` is not set, and by `is_default_pyx_domain`.
const PYX_DEFAULT_API_URL: &str = "https://api.pyx.dev";

/// The pyx CDN domain used when `PYX_CDN_DOMAIN` is not set, and by `is_default_pyx_domain`.
const PYX_DEFAULT_CDN_DOMAIN: &str = "astralhosted.com";
27
28fn read_pyx_api_key() -> Option<String> {
30 std::env::var(EnvVars::PYX_API_KEY)
31 .ok()
32 .or_else(|| std::env::var(EnvVars::UV_API_KEY).ok())
33}
34
35fn read_pyx_auth_token() -> Option<AccessToken> {
37 std::env::var(EnvVars::PYX_AUTH_TOKEN)
38 .ok()
39 .or_else(|| std::env::var(EnvVars::UV_AUTH_TOKEN).ok())
40 .map(AccessToken::from)
41}
42
/// Tokens obtained through the OAuth login flow.
///
/// `PyxTokenStore::write` persists them as JSON in `tokens.json` under the store's API-specific
/// subdirectory. `PyxTokenStore::refresh` renews them by POSTing the refresh token to
/// `auth/cli/refresh`, which responds with a new pair.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct PyxOAuthTokens {
    /// The token sent to pyx as a bearer credential.
    pub access_token: AccessToken,
    /// The token exchanged for a new access/refresh pair when the access token needs renewing.
    pub refresh_token: String,
}
52
/// An access token obtained by exchanging a pyx API key.
///
/// `PyxTokenStore::write` writes only the access token, to `{digest}.json`, where the digest is
/// computed from the API key. The API key itself is not written there. `PyxTokenStore::read`
/// instead takes it from the environment (`PYX_API_KEY` / `UV_API_KEY`).
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct PyxApiKeyTokens {
    /// The token sent to pyx as a bearer credential.
    pub access_token: AccessToken,
    /// The API key the access token was issued for; used again to obtain a new token.
    pub api_key: String,
}
59
/// pyx credentials, tagged by how they were obtained.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum PyxTokens {
    /// Tokens from the OAuth login flow, renewable via their refresh token.
    OAuth(PyxOAuthTokens),
    /// An access token derived from an API key, renewable by re-exchanging the key.
    ApiKey(PyxApiKeyTokens),
}
72
73impl From<PyxTokens> for AccessToken {
74 fn from(tokens: PyxTokens) -> Self {
75 match tokens {
76 PyxTokens::OAuth(PyxOAuthTokens { access_token, .. }) => access_token,
77 PyxTokens::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
78 }
79 }
80}
81
82impl From<PyxTokens> for Credentials {
83 fn from(tokens: PyxTokens) -> Self {
84 let access_token = match tokens {
85 PyxTokens::OAuth(PyxOAuthTokens { access_token, .. }) => access_token,
86 PyxTokens::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
87 };
88 Self::from(access_token)
89 }
90}
91
92impl From<AccessToken> for Credentials {
93 fn from(access_token: AccessToken) -> Self {
94 Self::Bearer {
95 token: Token::new(access_token.into_bytes()),
96 }
97 }
98}
99
/// Why `PyxTokens::check_fresh` decided a token must be refreshed.
#[derive(Debug, Clone)]
enum ExpiredTokenReason {
    /// The token could not be decoded as a JWT, has no `exp` claim, or its `exp` is not a
    /// representable timestamp.
    MissingExpiration,
    /// The caller passed a tolerance of zero seconds, which always forces a refresh.
    ForcedRefresh,
    /// The token's expiration time is already in the past.
    Expired(jiff::Timestamp),
    /// The token is still valid but expires within the requested tolerance.
    ExpiringSoon(jiff::Timestamp),
}
112
113impl std::fmt::Display for ExpiredTokenReason {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 match self {
116 Self::MissingExpiration => write!(f, "missing expiration"),
117 Self::ForcedRefresh => write!(f, "forced refresh"),
118 Self::Expired(exp) => write!(f, "token expired (`{exp}`)"),
119 Self::ExpiringSoon(exp) => write!(f, "token will expire within tolerance (`{exp}`)"),
120 }
121 }
122}
123
124impl PyxTokens {
125 fn access_token(&self) -> &AccessToken {
127 match self {
128 Self::OAuth(PyxOAuthTokens { access_token, .. }) => access_token,
129 Self::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
130 }
131 }
132
133 fn check_fresh(&self, tolerance_secs: u64) -> Result<jiff::Timestamp, ExpiredTokenReason> {
137 let Ok(jwt) = PyxJwt::decode(self.access_token()) else {
138 return Err(ExpiredTokenReason::MissingExpiration);
139 };
140 match jwt.exp {
141 None => Err(ExpiredTokenReason::MissingExpiration),
142 Some(_) if tolerance_secs == 0 => Err(ExpiredTokenReason::ForcedRefresh),
143 Some(exp) => {
144 let Ok(exp) = jiff::Timestamp::from_second(exp) else {
145 return Err(ExpiredTokenReason::MissingExpiration);
146 };
147 let now = jiff::Timestamp::now();
148 if exp < now {
149 Err(ExpiredTokenReason::Expired(exp))
150 } else if exp < now + Duration::from_secs(tolerance_secs) {
151 Err(ExpiredTokenReason::ExpiringSoon(exp))
152 } else {
153 Ok(exp)
154 }
155 }
156 }
157 }
158}
159
/// Five minutes, in seconds; presumably the default `tolerance_secs` for
/// `PyxTokenStore::access_token` / `PyxTokenStore::init` (TODO: confirm against callers).
/// Tokens expiring within this window are refreshed early.
pub const DEFAULT_TOLERANCE_SECS: u64 = 60 * 5;
162
/// The directories holding pyx credentials for a single API URL.
#[derive(Debug, Clone)]
struct PyxDirectories {
    /// The credentials root directory.
    root: PathBuf,
    /// `root` joined with a digest of the canonicalized API URL.
    subdirectory: PathBuf,
}
170
171impl PyxDirectories {
172 fn from_api(api: &DisplaySafeUrl) -> Result<Self, io::Error> {
174 let digest = uv_cache_key::cache_digest(&CanonicalUrl::new(api));
176
177 if let Some(root) = std::env::var_os(EnvVars::PYX_CREDENTIALS_DIR) {
179 let root = std::path::absolute(root)?;
180 let subdirectory = root.join(&digest);
181 return Ok(Self { root, subdirectory });
182 }
183
184 let root = if let Some(tool_dir) = std::env::var_os(EnvVars::UV_CREDENTIALS_DIR) {
187 std::path::absolute(tool_dir)?
188 } else {
189 StateStore::from_settings(None)?.bucket(StateBucket::Credentials)
190 };
191 let subdirectory = root.join(&digest);
192 if subdirectory.exists() {
193 return Ok(Self { root, subdirectory });
194 }
195
196 let Ok(xdg) = etcetera::base_strategy::choose_base_strategy() else {
198 return Err(io::Error::new(
199 io::ErrorKind::NotFound,
200 "Could not determine user data directory",
201 ));
202 };
203
204 let root = xdg.data_dir().join("pyx").join("credentials");
205 let subdirectory = root.join(&digest);
206 Ok(Self { root, subdirectory })
207 }
208}
209
/// On-disk store for pyx credentials, scoped to a single pyx API URL.
#[derive(Debug, Clone)]
pub struct PyxTokenStore {
    /// The credentials root directory.
    root: PathBuf,
    /// The API-specific subdirectory of `root` that holds token and lock files.
    subdirectory: PathBuf,
    /// The pyx API base URL (from `PYX_API_URL`, or `PYX_DEFAULT_API_URL`).
    api: DisplaySafeUrl,
    /// The CDN domain; HTTPS URLs on it or its subdomains are treated as pyx URLs.
    cdn: SmallString,
}
221
222impl PyxTokenStore {
223 pub fn from_settings() -> Result<Self, TokenStoreError> {
225 let api = if let Ok(api_url) = std::env::var(EnvVars::PYX_API_URL) {
228 DisplaySafeUrl::parse(&api_url)
229 } else {
230 DisplaySafeUrl::parse(PYX_DEFAULT_API_URL)
231 }?;
232 let cdn = std::env::var(EnvVars::PYX_CDN_DOMAIN)
233 .ok()
234 .map(SmallString::from)
235 .unwrap_or_else(|| SmallString::from(arcstr::literal!(PYX_DEFAULT_CDN_DOMAIN)));
236
237 let PyxDirectories { root, subdirectory } = PyxDirectories::from_api(&api)?;
239
240 Ok(Self {
241 root,
242 subdirectory,
243 api,
244 cdn,
245 })
246 }
247
248 pub fn root(&self) -> &Path {
250 &self.root
251 }
252
253 pub fn api(&self) -> &DisplaySafeUrl {
255 &self.api
256 }
257
258 pub async fn access_token(
267 &self,
268 client: &ClientWithMiddleware,
269 tolerance_secs: u64,
270 ) -> Result<Option<AccessToken>, TokenStoreError> {
271 if let Some(access_token) = read_pyx_auth_token() {
273 return Ok(Some(access_token));
274 }
275
276 let tokens = self.init(client, tolerance_secs).await?;
278
279 Ok(tokens.map(AccessToken::from))
281 }
282
283 pub async fn init(
290 &self,
291 client: &ClientWithMiddleware,
292 tolerance_secs: u64,
293 ) -> Result<Option<PyxTokens>, TokenStoreError> {
294 match self.read().await? {
295 Some(tokens) => {
296 let tokens = self.refresh(tokens, client, tolerance_secs).await?;
298 Ok(Some(tokens))
299 }
300 None => {
301 self.bootstrap(client).await
303 }
304 }
305 }
306
307 pub async fn write(&self, tokens: &PyxTokens) -> Result<(), TokenStoreError> {
309 fs_err::tokio::create_dir_all(&self.subdirectory).await?;
310 match tokens {
311 PyxTokens::OAuth(tokens) => {
312 fs_err::tokio::write(
314 self.subdirectory.join("tokens.json"),
315 serde_json::to_vec(tokens)?,
316 )
317 .await?;
318 }
319 PyxTokens::ApiKey(tokens) => {
320 let digest = uv_cache_key::cache_digest(&tokens.api_key);
322 fs_err::tokio::write(
323 self.subdirectory.join(format!("{digest}.json")),
324 &tokens.access_token,
325 )
326 .await?;
327 }
328 }
329 Ok(())
330 }
331
332 pub fn has_auth_token(&self) -> bool {
334 read_pyx_auth_token().is_some()
335 }
336
337 pub fn has_api_key(&self) -> bool {
339 read_pyx_api_key().is_some()
340 }
341
342 pub fn has_oauth_tokens(&self) -> bool {
344 self.subdirectory.join("tokens.json").is_file()
345 }
346
347 pub fn has_credentials(&self) -> bool {
349 self.has_auth_token() || self.has_api_key() || self.has_oauth_tokens()
350 }
351
352 pub async fn read(&self) -> Result<Option<PyxTokens>, TokenStoreError> {
354 if let Some(api_key) = read_pyx_api_key() {
355 let digest = uv_cache_key::cache_digest(&api_key);
357 match fs_err::tokio::read(self.subdirectory.join(format!("{digest}.json"))).await {
358 Ok(data) => {
359 let access_token =
360 AccessToken::from(String::from_utf8(data).expect("Invalid UTF-8"));
361 Ok(Some(PyxTokens::ApiKey(PyxApiKeyTokens {
362 access_token,
363 api_key,
364 })))
365 }
366 Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
367 Err(err) => Err(err.into()),
368 }
369 } else {
370 match fs_err::tokio::read(self.subdirectory.join("tokens.json")).await {
371 Ok(data) => {
372 let tokens: PyxOAuthTokens = serde_json::from_slice(&data)?;
373 Ok(Some(PyxTokens::OAuth(tokens)))
374 }
375 Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
376 Err(err) => Err(err.into()),
377 }
378 }
379 }
380
381 pub async fn delete(&self) -> Result<(), io::Error> {
383 fs_err::tokio::remove_dir_all(&self.subdirectory).await?;
384 Ok(())
385 }
386
387 fn lock_path(&self, tokens: &PyxTokens) -> PathBuf {
392 match tokens {
393 PyxTokens::OAuth(_) => self.subdirectory.join("tokens.lock"),
394 PyxTokens::ApiKey(PyxApiKeyTokens { api_key, .. }) => {
395 let digest = uv_cache_key::cache_digest(api_key);
396 self.subdirectory.join(format!("{digest}.lock"))
397 }
398 }
399 }
400
401 async fn bootstrap(
403 &self,
404 client: &ClientWithMiddleware,
405 ) -> Result<Option<PyxTokens>, TokenStoreError> {
406 #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
407 struct Payload {
408 access_token: AccessToken,
409 }
410
411 let Some(api_key) = read_pyx_api_key() else {
413 return Ok(None);
414 };
415
416 debug!("Bootstrapping access token from an API key");
417
418 let mut url = self.api.clone();
420 url.set_path("auth/cli/access-token");
421
422 let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
423 request.headers_mut().insert(
424 "Authorization",
425 reqwest::header::HeaderValue::from_str(&format!("Bearer {api_key}"))?,
426 );
427
428 let response = client.execute(request).await?;
429 let Payload { access_token } = response.error_for_status()?.json::<Payload>().await?;
430 let tokens = PyxTokens::ApiKey(PyxApiKeyTokens {
431 access_token,
432 api_key,
433 });
434
435 self.write(&tokens).await?;
437
438 Ok(Some(tokens))
439 }
440
441 async fn refresh(
446 &self,
447 tokens: PyxTokens,
448 client: &ClientWithMiddleware,
449 tolerance_secs: u64,
450 ) -> Result<PyxTokens, TokenStoreError> {
451 let reason = match tokens.check_fresh(tolerance_secs) {
452 Ok(exp) => {
453 debug!("Access token is up-to-date (`{exp}`)");
454 return Ok(tokens);
455 }
456 Err(reason) => reason,
457 };
458 debug!("Refreshing token due to {reason}");
459
460 fs_err::tokio::create_dir_all(&self.subdirectory).await?;
462
463 let lock_path = self.lock_path(&tokens);
465
466 let _lock = LockedFile::acquire(&lock_path, LockedFileMode::Exclusive, "pyx refresh")
468 .await
469 .map_err(|err| TokenStoreError::Io(io::Error::other(err.to_string())))?;
470
471 if let Some(tokens) = self.read().await? {
473 match tokens.check_fresh(tolerance_secs) {
474 Ok(exp) => {
475 debug!("Using recently refreshed token (`{exp}`)");
476 return Ok(tokens);
477 }
478 Err(reason) => {
479 debug!("Token on disk still needs refresh due to {reason}");
480 }
481 }
482 }
483
484 let tokens = match tokens {
486 PyxTokens::OAuth(PyxOAuthTokens { refresh_token, .. }) => {
487 let mut url = self.api.clone();
489 url.set_path("auth/cli/refresh");
490
491 let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
492 let body = serde_json::json!({
493 "refresh_token": refresh_token
494 });
495 *request.body_mut() = Some(body.to_string().into());
496
497 let response = client.execute(request).await?;
498 let tokens = response
499 .error_for_status()?
500 .json::<PyxOAuthTokens>()
501 .await?;
502 PyxTokens::OAuth(tokens)
503 }
504 PyxTokens::ApiKey(PyxApiKeyTokens { api_key, .. }) => {
505 #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
506 struct Payload {
507 access_token: AccessToken,
508 }
509
510 let mut url = self.api.clone();
512 url.set_path("auth/cli/access-token");
513
514 let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
515 request.headers_mut().insert(
516 "Authorization",
517 reqwest::header::HeaderValue::from_str(&format!("Bearer {api_key}"))?,
518 );
519
520 let response = client.execute(request).await?;
521 let Payload { access_token } =
522 response.error_for_status()?.json::<Payload>().await?;
523 PyxTokens::ApiKey(PyxApiKeyTokens {
524 access_token,
525 api_key,
526 })
527 }
528 };
529
530 self.write(&tokens).await?;
532
533 Ok(tokens)
534 }
535
536 pub fn is_known_url(&self, url: &Url) -> bool {
539 is_known_url(url, &self.api, &self.cdn)
540 }
541
542 pub fn is_known_domain(&self, url: &Url) -> bool {
547 is_known_domain(url, &self.api, &self.cdn)
548 }
549}
550
551#[derive(thiserror::Error, Debug)]
552pub enum TokenStoreError {
553 #[error(transparent)]
554 Url(#[from] DisplaySafeUrlError),
555 #[error(transparent)]
556 Io(#[from] io::Error),
557 #[error(transparent)]
558 Serialization(#[from] serde_json::Error),
559 #[error(transparent)]
560 Reqwest(#[from] reqwest::Error),
561 #[error(transparent)]
562 ReqwestMiddleware(#[from] reqwest_middleware::Error),
563 #[error(transparent)]
564 InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue),
565 #[error(transparent)]
566 Jiff(#[from] jiff::Error),
567 #[error(transparent)]
568 Jwt(#[from] JwtError),
569}
570
571impl TokenStoreError {
572 pub fn is_unauthorized(&self) -> bool {
574 match self {
575 Self::Reqwest(err) => err.status() == Some(reqwest::StatusCode::UNAUTHORIZED),
576 Self::ReqwestMiddleware(err) => err.status() == Some(reqwest::StatusCode::UNAUTHORIZED),
577 _ => false,
578 }
579 }
580}
581
/// The subset of pyx access-token (JWT) claims this module reads.
///
/// Produced by `PyxJwt::decode`, which does not verify the signature.
#[derive(Debug, serde::Deserialize)]
pub struct PyxJwt {
    /// Expiration time in seconds since the Unix epoch (the `exp` claim).
    pub exp: Option<i64>,
    /// The token issuer (the `iss` claim).
    pub iss: Option<String>,
    /// The custom `urn:pyx:org_name` claim (presumably the organization name).
    #[serde(rename = "urn:pyx:org_name")]
    pub name: Option<String>,
}
593
594impl PyxJwt {
595 pub fn decode(access_token: &AccessToken) -> Result<Self, JwtError> {
597 let mut token_segments = access_token.as_str().splitn(3, '.');
598
599 let _header = token_segments.next().ok_or(JwtError::MissingHeader)?;
600 let payload = token_segments.next().ok_or(JwtError::MissingPayload)?;
601 let _signature = token_segments.next().ok_or(JwtError::MissingSignature)?;
602 if token_segments.next().is_some() {
603 return Err(JwtError::TooManySegments);
604 }
605
606 let decoded = BASE64_URL_SAFE_NO_PAD.decode(payload)?;
607
608 let jwt = serde_json::from_slice::<Self>(&decoded)?;
609 Ok(jwt)
610 }
611}
612
/// Errors from `PyxJwt::decode`.
#[derive(thiserror::Error, Debug)]
pub enum JwtError {
    /// The token has no header segment.
    #[error("JWT is missing a header")]
    MissingHeader,
    /// The token has no payload segment.
    #[error("JWT is missing a payload")]
    MissingPayload,
    /// The token has no signature segment.
    #[error("JWT is missing a signature")]
    MissingSignature,
    /// The token has more than the three expected segments.
    #[error("JWT has too many segments")]
    TooManySegments,
    /// The payload is not valid unpadded base64url.
    #[error(transparent)]
    Base64(#[from] base64::DecodeError),
    /// The decoded payload is not valid JSON for `PyxJwt`.
    #[error(transparent)]
    Serde(#[from] serde_json::Error),
}
628
629fn is_known_url(url: &Url, api: &DisplaySafeUrl, cdn: &str) -> bool {
630 if Realm::from(url) == Realm::from(&**api) {
632 return true;
633 }
634
635 if matches!(url.scheme(), "https") && matches_domain(url, cdn) {
640 return true;
641 }
642
643 false
644}
645
646fn is_known_domain(url: &Url, api: &DisplaySafeUrl, cdn: &str) -> bool {
647 if let Some(domain) = url.domain() {
649 if matches_domain(api, domain) {
650 return true;
651 }
652 }
653 is_known_url(url, api, cdn)
654}
655
656pub fn is_default_pyx_domain(url: &Url) -> bool {
661 let api = DisplaySafeUrl::parse(PYX_DEFAULT_API_URL).expect("default API URL should be valid");
662 is_known_domain(url, &api, PYX_DEFAULT_CDN_DOMAIN)
663}
664
665fn matches_domain(url: &Url, domain: &str) -> bool {
667 url.domain().is_some_and(|subdomain| {
668 subdomain == domain
669 || subdomain
670 .strip_suffix(domain)
671 .is_some_and(|prefix| prefix.ends_with('.'))
672 })
673}
674
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a URL literal from a test table; every literal here is valid.
    fn parse_url(raw: &str) -> Url {
        Url::parse(raw).unwrap()
    }

    /// The default pyx API URL and CDN domain, spelled out explicitly.
    fn fixture() -> (DisplaySafeUrl, &'static str) {
        (
            DisplaySafeUrl::parse("https://api.pyx.dev").unwrap(),
            "astralhosted.com",
        )
    }

    #[test]
    fn test_is_known_url() {
        let (api_url, cdn_domain) = fixture();
        let cases = [
            // The API itself, on any path.
            ("https://api.pyx.dev/simple/", true),
            ("https://api.pyx.dev/v1/", true),
            // The CDN domain and its subdomains, over HTTPS only.
            ("https://astralhosted.com/packages/", true),
            ("https://files.astralhosted.com/packages/", true),
            ("http://astralhosted.com/packages/", false),
            // Unrelated hosts, including one that merely shares a suffix.
            ("https://pypi.org/simple/", false),
            ("https://badastralhosted.com/packages/", false),
        ];
        for (raw, expected) in cases {
            assert_eq!(
                is_known_url(&parse_url(raw), &api_url, cdn_domain),
                expected,
                "is_known_url({raw})"
            );
        }
    }

    #[test]
    fn test_is_known_domain() {
        let (api_url, cdn_domain) = fixture();
        let cases = [
            ("https://api.pyx.dev/simple/", true),
            // A parent domain of the API host counts...
            ("https://pyx.dev", true),
            // ...but subdomains and siblings of the API host do not.
            ("https://foo.api.pyx.dev", false),
            ("https://beta.pyx.dev/", false),
            ("https://astralhosted.com/packages/", true),
            ("https://files.astralhosted.com/packages/", true),
            ("https://pypi.org/simple/", false),
            ("https://pyx.com/", false),
        ];
        for (raw, expected) in cases {
            assert_eq!(
                is_known_domain(&parse_url(raw), &api_url, cdn_domain),
                expected,
                "is_known_domain({raw})"
            );
        }
    }

    #[test]
    fn test_is_default_pyx_domain() {
        let cases = [
            ("https://pyx.dev", true),
            ("https://api.pyx.dev", true),
            ("https://astralhosted.com", true),
            ("https://files.astralhosted.com", true),
            ("http://localhost:8000", false),
            ("https://pypi.org", false),
            ("https://pyx.com", false),
        ];
        for (raw, expected) in cases {
            assert_eq!(
                is_default_pyx_domain(&parse_url(raw)),
                expected,
                "is_default_pyx_domain({raw})"
            );
        }
    }

    #[test]
    fn test_matches_domain() {
        let cases = [
            ("https://example.com", "example.com", true),
            ("https://foo.example.com", "example.com", true),
            ("https://bar.foo.example.com", "example.com", true),
            ("https://example.com", "other.com", false),
            ("https://example.org", "example.com", false),
            ("https://badexample.com", "example.com", false),
        ];
        for (raw, domain, expected) in cases {
            assert_eq!(
                matches_domain(&parse_url(raw), domain),
                expected,
                "matches_domain({raw}, {domain})"
            );
        }
    }

    #[test]
    fn test_is_default_pyx_domain_staging() {
        // Other hosts under `pyx.dev` (such as staging) are not the default deployment.
        for raw in ["https://astral-sh-staging-api.pyx.dev", "https://beta.pyx.dev"] {
            assert!(!is_default_pyx_domain(&parse_url(raw)), "{raw}");
        }
    }
}