1use std::io;
2use std::path::{Path, PathBuf};
3use std::time::Duration;
4
5use base64::Engine;
6use base64::prelude::BASE64_URL_SAFE_NO_PAD;
7use etcetera::BaseStrategy;
8use reqwest_middleware::ClientWithMiddleware;
9use tracing::debug;
10use url::Url;
11use uv_fs::{LockedFile, LockedFileMode};
12
13use uv_cache_key::CanonicalUrl;
14use uv_redacted::{DisplaySafeUrl, DisplaySafeUrlError};
15use uv_small_str::SmallString;
16use uv_state::{StateBucket, StateStore};
17use uv_static::EnvVars;
18
19use crate::credentials::Token;
20use crate::{AccessToken, Credentials, Realm};
21
22fn read_pyx_api_key() -> Option<String> {
24 std::env::var(EnvVars::PYX_API_KEY)
25 .ok()
26 .or_else(|| std::env::var(EnvVars::UV_API_KEY).ok())
27}
28
29fn read_pyx_auth_token() -> Option<AccessToken> {
31 std::env::var(EnvVars::PYX_AUTH_TOKEN)
32 .ok()
33 .or_else(|| std::env::var(EnvVars::UV_AUTH_TOKEN).ok())
34 .map(AccessToken::from)
35}
36
/// Tokens obtained through the pyx OAuth flow.
///
/// Persisted as JSON in `tokens.json` inside the per-API credentials subdirectory by
/// [`PyxTokenStore::write`].
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct PyxOAuthTokens {
    /// The access token, sent as a bearer credential.
    pub access_token: AccessToken,
    /// Token posted to the `auth/cli/refresh` endpoint to obtain a new [`PyxOAuthTokens`].
    pub refresh_token: String,
}
46
/// An access token obtained by exchanging a pyx API key.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct PyxApiKeyTokens {
    /// The access token, sent as a bearer credential.
    pub access_token: AccessToken,
    /// The API key that was exchanged for `access_token`. A digest of this key names the
    /// on-disk token file and lock file, so distinct keys never share cached tokens.
    pub api_key: String,
}
53
/// pyx credentials, tagged by how they were obtained.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum PyxTokens {
    /// Tokens from the OAuth flow; refreshed with a refresh token.
    OAuth(PyxOAuthTokens),
    /// An access token exchanged for an API key; refreshed by re-exchanging the key.
    ApiKey(PyxApiKeyTokens),
}
66
67impl From<PyxTokens> for AccessToken {
68 fn from(tokens: PyxTokens) -> Self {
69 match tokens {
70 PyxTokens::OAuth(PyxOAuthTokens { access_token, .. }) => access_token,
71 PyxTokens::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
72 }
73 }
74}
75
76impl From<PyxTokens> for Credentials {
77 fn from(tokens: PyxTokens) -> Self {
78 let access_token = match tokens {
79 PyxTokens::OAuth(PyxOAuthTokens { access_token, .. }) => access_token,
80 PyxTokens::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
81 };
82 Self::from(access_token)
83 }
84}
85
86impl From<AccessToken> for Credentials {
87 fn from(access_token: AccessToken) -> Self {
88 Self::Bearer {
89 token: Token::new(access_token.into_bytes()),
90 }
91 }
92}
93
/// Why a stored access token is considered in need of a refresh.
#[derive(Debug, Clone)]
enum ExpiredTokenReason {
    /// The token could not be decoded as a JWT, had no `exp` claim, or its `exp` was not a
    /// representable timestamp.
    MissingExpiration,
    /// A tolerance of zero seconds was requested, which always forces a refresh.
    ForcedRefresh,
    /// The token's expiration (carried here) is already in the past.
    Expired(jiff::Timestamp),
    /// The token's expiration (carried here) falls within the requested tolerance window.
    ExpiringSoon(jiff::Timestamp),
}
106
107impl std::fmt::Display for ExpiredTokenReason {
108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109 match self {
110 Self::MissingExpiration => write!(f, "missing expiration"),
111 Self::ForcedRefresh => write!(f, "forced refresh"),
112 Self::Expired(exp) => write!(f, "token expired (`{exp}`)"),
113 Self::ExpiringSoon(exp) => write!(f, "token will expire within tolerance (`{exp}`)"),
114 }
115 }
116}
117
118impl PyxTokens {
119 fn access_token(&self) -> &AccessToken {
121 match self {
122 Self::OAuth(PyxOAuthTokens { access_token, .. }) => access_token,
123 Self::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
124 }
125 }
126
127 fn check_fresh(&self, tolerance_secs: u64) -> Result<jiff::Timestamp, ExpiredTokenReason> {
131 let Ok(jwt) = PyxJwt::decode(self.access_token()) else {
132 return Err(ExpiredTokenReason::MissingExpiration);
133 };
134 match jwt.exp {
135 None => Err(ExpiredTokenReason::MissingExpiration),
136 Some(_) if tolerance_secs == 0 => Err(ExpiredTokenReason::ForcedRefresh),
137 Some(exp) => {
138 let Ok(exp) = jiff::Timestamp::from_second(exp) else {
139 return Err(ExpiredTokenReason::MissingExpiration);
140 };
141 let now = jiff::Timestamp::now();
142 if exp < now {
143 Err(ExpiredTokenReason::Expired(exp))
144 } else if exp < now + Duration::from_secs(tolerance_secs) {
145 Err(ExpiredTokenReason::ExpiringSoon(exp))
146 } else {
147 Ok(exp)
148 }
149 }
150 }
151 }
152}
153
/// Five minutes, in seconds. Presumably the default `tolerance_secs` passed to
/// [`PyxTokenStore::access_token`] / [`PyxTokenStore::init`] (TODO confirm against callers):
/// a stored token expiring within this window is refreshed.
pub const DEFAULT_TOLERANCE_SECS: u64 = 60 * 5;
156
/// On-disk locations for pyx credentials, as resolved by [`PyxDirectories::from_api`].
#[derive(Debug, Clone)]
struct PyxDirectories {
    /// Root credentials directory (shared across API URLs).
    root: PathBuf,
    /// `root` joined with a digest of the canonicalized API URL; holds this API's token and
    /// lock files.
    subdirectory: PathBuf,
}
164
165impl PyxDirectories {
166 fn from_api(api: &DisplaySafeUrl) -> Result<Self, io::Error> {
168 let digest = uv_cache_key::cache_digest(&CanonicalUrl::new(api));
170
171 if let Some(root) = std::env::var_os(EnvVars::PYX_CREDENTIALS_DIR) {
173 let root = std::path::absolute(root)?;
174 let subdirectory = root.join(&digest);
175 return Ok(Self { root, subdirectory });
176 }
177
178 let root = if let Some(tool_dir) = std::env::var_os(EnvVars::UV_CREDENTIALS_DIR) {
181 std::path::absolute(tool_dir)?
182 } else {
183 StateStore::from_settings(None)?.bucket(StateBucket::Credentials)
184 };
185 let subdirectory = root.join(&digest);
186 if subdirectory.exists() {
187 return Ok(Self { root, subdirectory });
188 }
189
190 let Ok(xdg) = etcetera::base_strategy::choose_base_strategy() else {
192 return Err(io::Error::new(
193 io::ErrorKind::NotFound,
194 "Could not determine user data directory",
195 ));
196 };
197
198 let root = xdg.data_dir().join("pyx").join("credentials");
199 let subdirectory = root.join(&digest);
200 Ok(Self { root, subdirectory })
201 }
202}
203
/// Reads, writes, refreshes, and bootstraps pyx credentials stored on disk for one pyx API.
#[derive(Debug, Clone)]
pub struct PyxTokenStore {
    /// Root credentials directory (shared across API URLs).
    root: PathBuf,
    /// Per-API subdirectory of `root` (named by a digest of the API URL) holding token and lock
    /// files.
    subdirectory: PathBuf,
    /// Base URL of the pyx API; auth endpoints are formed by replacing its path.
    api: DisplaySafeUrl,
    /// CDN domain; HTTPS URLs on it or its subdomains are treated as known pyx URLs.
    cdn: SmallString,
}
215
216impl PyxTokenStore {
217 pub fn from_settings() -> Result<Self, TokenStoreError> {
219 let api = if let Ok(api_url) = std::env::var(EnvVars::PYX_API_URL) {
222 DisplaySafeUrl::parse(&api_url)
223 } else {
224 DisplaySafeUrl::parse("https://api.pyx.dev")
225 }?;
226 let cdn = std::env::var(EnvVars::PYX_CDN_DOMAIN)
227 .ok()
228 .map(SmallString::from)
229 .unwrap_or_else(|| SmallString::from(arcstr::literal!("astralhosted.com")));
230
231 let PyxDirectories { root, subdirectory } = PyxDirectories::from_api(&api)?;
233
234 Ok(Self {
235 root,
236 subdirectory,
237 api,
238 cdn,
239 })
240 }
241
242 pub fn root(&self) -> &Path {
244 &self.root
245 }
246
247 pub fn api(&self) -> &DisplaySafeUrl {
249 &self.api
250 }
251
252 pub async fn access_token(
261 &self,
262 client: &ClientWithMiddleware,
263 tolerance_secs: u64,
264 ) -> Result<Option<AccessToken>, TokenStoreError> {
265 if let Some(access_token) = read_pyx_auth_token() {
267 return Ok(Some(access_token));
268 }
269
270 let tokens = self.init(client, tolerance_secs).await?;
272
273 Ok(tokens.map(AccessToken::from))
275 }
276
277 pub async fn init(
284 &self,
285 client: &ClientWithMiddleware,
286 tolerance_secs: u64,
287 ) -> Result<Option<PyxTokens>, TokenStoreError> {
288 match self.read().await? {
289 Some(tokens) => {
290 let tokens = self.refresh(tokens, client, tolerance_secs).await?;
292 Ok(Some(tokens))
293 }
294 None => {
295 self.bootstrap(client).await
297 }
298 }
299 }
300
301 pub async fn write(&self, tokens: &PyxTokens) -> Result<(), TokenStoreError> {
303 fs_err::tokio::create_dir_all(&self.subdirectory).await?;
304 match tokens {
305 PyxTokens::OAuth(tokens) => {
306 fs_err::tokio::write(
308 self.subdirectory.join("tokens.json"),
309 serde_json::to_vec(tokens)?,
310 )
311 .await?;
312 }
313 PyxTokens::ApiKey(tokens) => {
314 let digest = uv_cache_key::cache_digest(&tokens.api_key);
316 fs_err::tokio::write(
317 self.subdirectory.join(format!("{digest}.json")),
318 &tokens.access_token,
319 )
320 .await?;
321 }
322 }
323 Ok(())
324 }
325
326 pub fn has_auth_token(&self) -> bool {
328 read_pyx_auth_token().is_some()
329 }
330
331 pub fn has_api_key(&self) -> bool {
333 read_pyx_api_key().is_some()
334 }
335
336 pub fn has_oauth_tokens(&self) -> bool {
338 self.subdirectory.join("tokens.json").is_file()
339 }
340
341 pub fn has_credentials(&self) -> bool {
343 self.has_auth_token() || self.has_api_key() || self.has_oauth_tokens()
344 }
345
346 pub async fn read(&self) -> Result<Option<PyxTokens>, TokenStoreError> {
348 if let Some(api_key) = read_pyx_api_key() {
349 let digest = uv_cache_key::cache_digest(&api_key);
351 match fs_err::tokio::read(self.subdirectory.join(format!("{digest}.json"))).await {
352 Ok(data) => {
353 let access_token =
354 AccessToken::from(String::from_utf8(data).expect("Invalid UTF-8"));
355 Ok(Some(PyxTokens::ApiKey(PyxApiKeyTokens {
356 access_token,
357 api_key,
358 })))
359 }
360 Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
361 Err(err) => Err(err.into()),
362 }
363 } else {
364 match fs_err::tokio::read(self.subdirectory.join("tokens.json")).await {
365 Ok(data) => {
366 let tokens: PyxOAuthTokens = serde_json::from_slice(&data)?;
367 Ok(Some(PyxTokens::OAuth(tokens)))
368 }
369 Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
370 Err(err) => Err(err.into()),
371 }
372 }
373 }
374
375 pub async fn delete(&self) -> Result<(), io::Error> {
377 fs_err::tokio::remove_dir_all(&self.subdirectory).await?;
378 Ok(())
379 }
380
381 fn lock_path(&self, tokens: &PyxTokens) -> PathBuf {
386 match tokens {
387 PyxTokens::OAuth(_) => self.subdirectory.join("tokens.lock"),
388 PyxTokens::ApiKey(PyxApiKeyTokens { api_key, .. }) => {
389 let digest = uv_cache_key::cache_digest(api_key);
390 self.subdirectory.join(format!("{digest}.lock"))
391 }
392 }
393 }
394
395 async fn bootstrap(
397 &self,
398 client: &ClientWithMiddleware,
399 ) -> Result<Option<PyxTokens>, TokenStoreError> {
400 #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
401 struct Payload {
402 access_token: AccessToken,
403 }
404
405 let Some(api_key) = read_pyx_api_key() else {
407 return Ok(None);
408 };
409
410 debug!("Bootstrapping access token from an API key");
411
412 let mut url = self.api.clone();
414 url.set_path("auth/cli/access-token");
415
416 let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
417 request.headers_mut().insert(
418 "Authorization",
419 reqwest::header::HeaderValue::from_str(&format!("Bearer {api_key}"))?,
420 );
421
422 let response = client.execute(request).await?;
423 let Payload { access_token } = response.error_for_status()?.json::<Payload>().await?;
424 let tokens = PyxTokens::ApiKey(PyxApiKeyTokens {
425 access_token,
426 api_key,
427 });
428
429 self.write(&tokens).await?;
431
432 Ok(Some(tokens))
433 }
434
435 async fn refresh(
440 &self,
441 tokens: PyxTokens,
442 client: &ClientWithMiddleware,
443 tolerance_secs: u64,
444 ) -> Result<PyxTokens, TokenStoreError> {
445 let reason = match tokens.check_fresh(tolerance_secs) {
446 Ok(exp) => {
447 debug!("Access token is up-to-date (`{exp}`)");
448 return Ok(tokens);
449 }
450 Err(reason) => reason,
451 };
452 debug!("Refreshing token due to {reason}");
453
454 fs_err::tokio::create_dir_all(&self.subdirectory).await?;
456
457 let lock_path = self.lock_path(&tokens);
459
460 let _lock = LockedFile::acquire(&lock_path, LockedFileMode::Exclusive, "pyx refresh")
462 .await
463 .map_err(|err| TokenStoreError::Io(io::Error::other(err.to_string())))?;
464
465 if let Some(tokens) = self.read().await? {
467 match tokens.check_fresh(tolerance_secs) {
468 Ok(exp) => {
469 debug!("Using recently refreshed token (`{exp}`)");
470 return Ok(tokens);
471 }
472 Err(reason) => {
473 debug!("Token on disk still needs refresh due to {reason}");
474 }
475 }
476 }
477
478 let tokens = match tokens {
480 PyxTokens::OAuth(PyxOAuthTokens { refresh_token, .. }) => {
481 let mut url = self.api.clone();
483 url.set_path("auth/cli/refresh");
484
485 let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
486 let body = serde_json::json!({
487 "refresh_token": refresh_token
488 });
489 *request.body_mut() = Some(body.to_string().into());
490
491 let response = client.execute(request).await?;
492 let tokens = response
493 .error_for_status()?
494 .json::<PyxOAuthTokens>()
495 .await?;
496 PyxTokens::OAuth(tokens)
497 }
498 PyxTokens::ApiKey(PyxApiKeyTokens { api_key, .. }) => {
499 #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
500 struct Payload {
501 access_token: AccessToken,
502 }
503
504 let mut url = self.api.clone();
506 url.set_path("auth/cli/access-token");
507
508 let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
509 request.headers_mut().insert(
510 "Authorization",
511 reqwest::header::HeaderValue::from_str(&format!("Bearer {api_key}"))?,
512 );
513
514 let response = client.execute(request).await?;
515 let Payload { access_token } =
516 response.error_for_status()?.json::<Payload>().await?;
517 PyxTokens::ApiKey(PyxApiKeyTokens {
518 access_token,
519 api_key,
520 })
521 }
522 };
523
524 self.write(&tokens).await?;
526
527 Ok(tokens)
528 }
529
530 pub fn is_known_url(&self, url: &Url) -> bool {
533 is_known_url(url, &self.api, &self.cdn)
534 }
535
536 pub fn is_known_domain(&self, url: &Url) -> bool {
541 is_known_domain(url, &self.api, &self.cdn)
542 }
543}
544
/// Errors produced while reading, writing, refreshing, or bootstrapping pyx tokens.
#[derive(thiserror::Error, Debug)]
pub enum TokenStoreError {
    /// The configured API URL could not be parsed.
    #[error(transparent)]
    Url(#[from] DisplaySafeUrlError),
    /// A filesystem operation (or lock acquisition) failed.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// Stored tokens could not be (de)serialized as JSON.
    #[error(transparent)]
    Serialization(#[from] serde_json::Error),
    /// An HTTP request failed or returned an error status.
    #[error(transparent)]
    Reqwest(#[from] reqwest::Error),
    /// The HTTP middleware stack failed.
    #[error(transparent)]
    ReqwestMiddleware(#[from] reqwest_middleware::Error),
    /// A header value (e.g. the `Bearer` authorization) contained invalid characters.
    #[error(transparent)]
    InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue),
    /// A timestamp operation failed.
    #[error(transparent)]
    Jiff(#[from] jiff::Error),
    /// A JWT could not be decoded.
    #[error(transparent)]
    Jwt(#[from] JwtError),
}
564
565impl TokenStoreError {
566 pub fn is_unauthorized(&self) -> bool {
568 match self {
569 Self::Reqwest(err) => err.status() == Some(reqwest::StatusCode::UNAUTHORIZED),
570 Self::ReqwestMiddleware(err) => err.status() == Some(reqwest::StatusCode::UNAUTHORIZED),
571 _ => false,
572 }
573 }
574}
575
/// Claims decoded from the payload of a pyx access token (a JWT).
///
/// Decoding via [`PyxJwt::decode`] does not verify the signature.
#[derive(Debug, serde::Deserialize)]
pub struct PyxJwt {
    /// The `exp` claim: expiration time in seconds since the Unix epoch.
    pub exp: Option<i64>,
    /// The `iss` (issuer) claim, if present.
    pub iss: Option<String>,
    /// The `urn:pyx:org_name` claim, if present.
    #[serde(rename = "urn:pyx:org_name")]
    pub name: Option<String>,
}
587
588impl PyxJwt {
589 pub fn decode(access_token: &AccessToken) -> Result<Self, JwtError> {
591 let mut token_segments = access_token.as_str().splitn(3, '.');
592
593 let _header = token_segments.next().ok_or(JwtError::MissingHeader)?;
594 let payload = token_segments.next().ok_or(JwtError::MissingPayload)?;
595 let _signature = token_segments.next().ok_or(JwtError::MissingSignature)?;
596 if token_segments.next().is_some() {
597 return Err(JwtError::TooManySegments);
598 }
599
600 let decoded = BASE64_URL_SAFE_NO_PAD.decode(payload)?;
601
602 let jwt = serde_json::from_slice::<Self>(&decoded)?;
603 Ok(jwt)
604 }
605}
606
/// Errors produced by [`PyxJwt::decode`].
#[derive(thiserror::Error, Debug)]
pub enum JwtError {
    /// No header segment was found.
    #[error("JWT is missing a header")]
    MissingHeader,
    /// No payload segment followed the header.
    #[error("JWT is missing a payload")]
    MissingPayload,
    /// No signature segment followed the payload.
    #[error("JWT is missing a signature")]
    MissingSignature,
    /// More than three `.`-separated segments were found.
    #[error("JWT has too many segments")]
    TooManySegments,
    /// The payload was not valid unpadded base64url.
    #[error(transparent)]
    Base64(#[from] base64::DecodeError),
    /// The decoded payload was not JSON matching [`PyxJwt`].
    #[error(transparent)]
    Serde(#[from] serde_json::Error),
}
622
623fn is_known_url(url: &Url, api: &DisplaySafeUrl, cdn: &str) -> bool {
624 if Realm::from(url) == Realm::from(&**api) {
626 return true;
627 }
628
629 if matches!(url.scheme(), "https") && matches_domain(url, cdn) {
634 return true;
635 }
636
637 false
638}
639
640fn is_known_domain(url: &Url, api: &DisplaySafeUrl, cdn: &str) -> bool {
641 if let Some(domain) = url.domain() {
643 if matches_domain(api, domain) {
644 return true;
645 }
646 }
647 is_known_url(url, api, cdn)
648}
649
650fn matches_domain(url: &Url, domain: &str) -> bool {
652 url.domain().is_some_and(|subdomain| {
653 subdomain == domain
654 || subdomain
655 .strip_suffix(domain)
656 .is_some_and(|prefix| prefix.ends_with('.'))
657 })
658}
659
#[cfg(test)]
mod tests {
    use super::*;

    /// The API URL and CDN domain shared by the URL-classification tests.
    fn pyx_fixture() -> (DisplaySafeUrl, &'static str) {
        let api_url = DisplaySafeUrl::parse("https://api.pyx.dev").unwrap();
        (api_url, "astralhosted.com")
    }

    #[test]
    fn test_is_known_url() {
        let (api_url, cdn_domain) = pyx_fixture();
        // (candidate URL, expected classification)
        let cases = [
            // Same realm as the API.
            ("https://api.pyx.dev/simple/", true),
            ("https://api.pyx.dev/v1/", true),
            // The CDN domain and its subdomains, over HTTPS only.
            ("https://astralhosted.com/packages/", true),
            ("https://files.astralhosted.com/packages/", true),
            ("http://astralhosted.com/packages/", false),
            // Unrelated hosts, including a suffix look-alike.
            ("https://pypi.org/simple/", false),
            ("https://badastralhosted.com/packages/", false),
        ];
        for (candidate, expected) in cases {
            let url = Url::parse(candidate).unwrap();
            assert_eq!(
                is_known_url(&url, &api_url, cdn_domain),
                expected,
                "is_known_url({candidate})"
            );
        }
    }

    #[test]
    fn test_is_known_domain() {
        let (api_url, cdn_domain) = pyx_fixture();
        // (candidate URL, expected classification)
        let cases = [
            // The API host and its parent domain.
            ("https://api.pyx.dev/simple/", true),
            ("https://pyx.dev", true),
            // Subdomains of the API host and sibling hosts are not accepted.
            ("https://foo.api.pyx.dev", false),
            ("https://beta.pyx.dev/", false),
            // The CDN domain and its subdomains.
            ("https://astralhosted.com/packages/", true),
            ("https://files.astralhosted.com/packages/", true),
            // Unrelated hosts.
            ("https://pypi.org/simple/", false),
            ("https://pyx.com/", false),
        ];
        for (candidate, expected) in cases {
            let url = Url::parse(candidate).unwrap();
            assert_eq!(
                is_known_domain(&url, &api_url, cdn_domain),
                expected,
                "is_known_domain({candidate})"
            );
        }
    }

    #[test]
    fn test_matches_domain() {
        // (candidate URL, domain, expected match)
        let cases = [
            ("https://example.com", "example.com", true),
            ("https://foo.example.com", "example.com", true),
            ("https://bar.foo.example.com", "example.com", true),
            ("https://example.com", "other.com", false),
            ("https://example.org", "example.com", false),
            ("https://badexample.com", "example.com", false),
        ];
        for (candidate, domain, expected) in cases {
            let url = Url::parse(candidate).unwrap();
            assert_eq!(
                matches_domain(&url, domain),
                expected,
                "matches_domain({candidate}, {domain})"
            );
        }
    }
}