1use std::io;
2use std::path::{Path, PathBuf};
3use std::time::Duration;
4
5use base64::Engine;
6use base64::prelude::BASE64_URL_SAFE_NO_PAD;
7use etcetera::BaseStrategy;
8use reqwest_middleware::ClientWithMiddleware;
9use tracing::debug;
10use url::Url;
11
12use uv_cache_key::CanonicalUrl;
13use uv_redacted::{DisplaySafeUrl, DisplaySafeUrlError};
14use uv_small_str::SmallString;
15use uv_state::{StateBucket, StateStore};
16use uv_static::EnvVars;
17
18use crate::credentials::Token;
19use crate::{AccessToken, Credentials, Realm};
20
21fn read_pyx_api_key() -> Option<String> {
23 std::env::var(EnvVars::PYX_API_KEY)
24 .ok()
25 .or_else(|| std::env::var(EnvVars::UV_API_KEY).ok())
26}
27
28fn read_pyx_auth_token() -> Option<AccessToken> {
30 std::env::var(EnvVars::PYX_AUTH_TOKEN)
31 .ok()
32 .or_else(|| std::env::var(EnvVars::UV_AUTH_TOKEN).ok())
33 .map(AccessToken::from)
34}
35
/// Tokens for the OAuth flow.
///
/// These are persisted as JSON to `tokens.json` in the credentials subdirectory.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct PyxOAuthTokens {
    /// The access token, decoded as a JWT to check its expiration.
    pub access_token: AccessToken,
    /// The refresh token, sent to `auth/cli/refresh` to obtain new tokens.
    pub refresh_token: String,
}
45
/// An access token obtained by exchanging an API key at `auth/cli/access-token`.
///
/// Only `access_token` is written to disk. `api_key` is re-read from the environment.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct PyxApiKeyTokens {
    /// The access token returned for `api_key`.
    pub access_token: AccessToken,
    /// The API key (from `PYX_API_KEY` / `UV_API_KEY`) used to mint `access_token`.
    pub api_key: String,
}
52
/// The pyx credentials managed by the token store.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum PyxTokens {
    /// Access and refresh tokens; refreshed via the refresh token.
    OAuth(PyxOAuthTokens),
    /// An access token minted from an API key; refreshed by re-exchanging the key.
    ApiKey(PyxApiKeyTokens),
}
65
66impl From<PyxTokens> for AccessToken {
67 fn from(tokens: PyxTokens) -> Self {
68 match tokens {
69 PyxTokens::OAuth(PyxOAuthTokens { access_token, .. }) => access_token,
70 PyxTokens::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
71 }
72 }
73}
74
75impl From<PyxTokens> for Credentials {
76 fn from(tokens: PyxTokens) -> Self {
77 let access_token = match tokens {
78 PyxTokens::OAuth(PyxOAuthTokens { access_token, .. }) => access_token,
79 PyxTokens::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
80 };
81 Self::from(access_token)
82 }
83}
84
85impl From<AccessToken> for Credentials {
86 fn from(access_token: AccessToken) -> Self {
87 Self::Bearer {
88 token: Token::new(access_token.into_bytes()),
89 }
90 }
91}
92
/// Default expiry tolerance, in seconds (five minutes).
///
/// NOTE(review): presumably passed as `tolerance_secs` to `PyxTokenStore::access_token` /
/// `PyxTokenStore::init`, where an access token expiring within this window is refreshed —
/// confirm at call sites.
pub const DEFAULT_TOLERANCE_SECS: u64 = 60 * 5;
95
/// Filesystem locations used to persist pyx credentials for a single API.
#[derive(Debug, Clone)]
struct PyxDirectories {
    /// The root credentials directory.
    root: PathBuf,
    /// `root` joined with a digest of the canonicalized API URL.
    subdirectory: PathBuf,
}
103
104impl PyxDirectories {
105 fn from_api(api: &DisplaySafeUrl) -> Result<Self, io::Error> {
107 let digest = uv_cache_key::cache_digest(&CanonicalUrl::new(api));
109
110 if let Some(root) = std::env::var_os(EnvVars::PYX_CREDENTIALS_DIR) {
112 let root = std::path::absolute(root)?;
113 let subdirectory = root.join(&digest);
114 return Ok(Self { root, subdirectory });
115 }
116
117 let root = if let Some(tool_dir) = std::env::var_os(EnvVars::UV_CREDENTIALS_DIR) {
120 std::path::absolute(tool_dir)?
121 } else {
122 StateStore::from_settings(None)?.bucket(StateBucket::Credentials)
123 };
124 let subdirectory = root.join(&digest);
125 if subdirectory.exists() {
126 return Ok(Self { root, subdirectory });
127 }
128
129 let Ok(xdg) = etcetera::base_strategy::choose_base_strategy() else {
131 return Err(io::Error::new(
132 io::ErrorKind::NotFound,
133 "Could not determine user data directory",
134 ));
135 };
136
137 let root = xdg.data_dir().join("pyx").join("credentials");
138 let subdirectory = root.join(&digest);
139 Ok(Self { root, subdirectory })
140 }
141}
142
/// A store for pyx credentials, persisted on disk and keyed by the pyx API URL.
#[derive(Debug, Clone)]
pub struct PyxTokenStore {
    /// The root credentials directory.
    root: PathBuf,
    /// The subdirectory of `root` holding the tokens for `api`.
    subdirectory: PathBuf,
    /// The base URL of the pyx API.
    api: DisplaySafeUrl,
    /// The CDN domain. HTTPS URLs on this domain or its subdomains are treated as known.
    cdn: SmallString,
}
154
155impl PyxTokenStore {
156 pub fn from_settings() -> Result<Self, TokenStoreError> {
158 let api = if let Ok(api_url) = std::env::var(EnvVars::PYX_API_URL) {
161 DisplaySafeUrl::parse(&api_url)
162 } else {
163 DisplaySafeUrl::parse("https://api.pyx.dev")
164 }?;
165 let cdn = std::env::var(EnvVars::PYX_CDN_DOMAIN)
166 .ok()
167 .map(SmallString::from)
168 .unwrap_or_else(|| SmallString::from(arcstr::literal!("astralhosted.com")));
169
170 let PyxDirectories { root, subdirectory } = PyxDirectories::from_api(&api)?;
172
173 Ok(Self {
174 root,
175 subdirectory,
176 api,
177 cdn,
178 })
179 }
180
181 pub fn root(&self) -> &Path {
183 &self.root
184 }
185
186 pub fn api(&self) -> &DisplaySafeUrl {
188 &self.api
189 }
190
191 pub async fn access_token(
200 &self,
201 client: &ClientWithMiddleware,
202 tolerance_secs: u64,
203 ) -> Result<Option<AccessToken>, TokenStoreError> {
204 if let Some(access_token) = read_pyx_auth_token() {
206 return Ok(Some(access_token));
207 }
208
209 let tokens = self.init(client, tolerance_secs).await?;
211
212 Ok(tokens.map(AccessToken::from))
214 }
215
216 pub async fn init(
223 &self,
224 client: &ClientWithMiddleware,
225 tolerance_secs: u64,
226 ) -> Result<Option<PyxTokens>, TokenStoreError> {
227 match self.read().await? {
228 Some(tokens) => {
229 let tokens = self.refresh(tokens, client, tolerance_secs).await?;
231 Ok(Some(tokens))
232 }
233 None => {
234 self.bootstrap(client).await
236 }
237 }
238 }
239
240 pub async fn write(&self, tokens: &PyxTokens) -> Result<(), TokenStoreError> {
242 fs_err::tokio::create_dir_all(&self.subdirectory).await?;
243 match tokens {
244 PyxTokens::OAuth(tokens) => {
245 fs_err::tokio::write(
247 self.subdirectory.join("tokens.json"),
248 serde_json::to_vec(tokens)?,
249 )
250 .await?;
251 }
252 PyxTokens::ApiKey(tokens) => {
253 let digest = uv_cache_key::cache_digest(&tokens.api_key);
255 fs_err::tokio::write(
256 self.subdirectory.join(format!("{digest}.json")),
257 &tokens.access_token,
258 )
259 .await?;
260 }
261 }
262 Ok(())
263 }
264
265 pub fn has_auth_token(&self) -> bool {
267 read_pyx_auth_token().is_some()
268 }
269
270 pub fn has_api_key(&self) -> bool {
272 read_pyx_api_key().is_some()
273 }
274
275 pub fn has_oauth_tokens(&self) -> bool {
277 self.subdirectory.join("tokens.json").is_file()
278 }
279
280 pub fn has_credentials(&self) -> bool {
282 self.has_auth_token() || self.has_api_key() || self.has_oauth_tokens()
283 }
284
285 pub async fn read(&self) -> Result<Option<PyxTokens>, TokenStoreError> {
287 if let Some(api_key) = read_pyx_api_key() {
288 let digest = uv_cache_key::cache_digest(&api_key);
290 match fs_err::tokio::read(self.subdirectory.join(format!("{digest}.json"))).await {
291 Ok(data) => {
292 let access_token =
293 AccessToken::from(String::from_utf8(data).expect("Invalid UTF-8"));
294 Ok(Some(PyxTokens::ApiKey(PyxApiKeyTokens {
295 access_token,
296 api_key,
297 })))
298 }
299 Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
300 Err(err) => Err(err.into()),
301 }
302 } else {
303 match fs_err::tokio::read(self.subdirectory.join("tokens.json")).await {
304 Ok(data) => {
305 let tokens: PyxOAuthTokens = serde_json::from_slice(&data)?;
306 Ok(Some(PyxTokens::OAuth(tokens)))
307 }
308 Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
309 Err(err) => Err(err.into()),
310 }
311 }
312 }
313
314 pub async fn delete(&self) -> Result<(), io::Error> {
316 fs_err::tokio::remove_dir_all(&self.subdirectory).await?;
317 Ok(())
318 }
319
320 async fn bootstrap(
322 &self,
323 client: &ClientWithMiddleware,
324 ) -> Result<Option<PyxTokens>, TokenStoreError> {
325 #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
326 struct Payload {
327 access_token: AccessToken,
328 }
329
330 let Some(api_key) = read_pyx_api_key() else {
332 return Ok(None);
333 };
334
335 debug!("Bootstrapping access token from an API key");
336
337 let mut url = self.api.clone();
339 url.set_path("auth/cli/access-token");
340
341 let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
342 request.headers_mut().insert(
343 "Authorization",
344 reqwest::header::HeaderValue::from_str(&format!("Bearer {api_key}"))?,
345 );
346
347 let response = client.execute(request).await?;
348 let Payload { access_token } = response.error_for_status()?.json::<Payload>().await?;
349 let tokens = PyxTokens::ApiKey(PyxApiKeyTokens {
350 access_token,
351 api_key,
352 });
353
354 self.write(&tokens).await?;
356
357 Ok(Some(tokens))
358 }
359
360 async fn refresh(
365 &self,
366 tokens: PyxTokens,
367 client: &ClientWithMiddleware,
368 tolerance_secs: u64,
369 ) -> Result<PyxTokens, TokenStoreError> {
370 let jwt = PyxJwt::decode(match &tokens {
372 PyxTokens::OAuth(PyxOAuthTokens { access_token, .. }) => access_token,
373 PyxTokens::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
374 })?;
375
376 let is_up_to_date = match jwt.exp {
378 None => {
379 debug!("Access token has no expiration; refreshing...");
380 false
381 }
382 Some(..) if tolerance_secs == 0 => {
383 debug!("Refreshing access token due to zero tolerance...");
384 false
385 }
386 Some(jwt) => {
387 let exp = jiff::Timestamp::from_second(jwt)?;
388 let now = jiff::Timestamp::now();
389 if exp < now {
390 debug!("Access token is expired (`{exp}`); refreshing...");
391 false
392 } else if exp < now + Duration::from_secs(tolerance_secs) {
393 debug!(
394 "Access token will expire within the tolerance (`{exp}`); refreshing..."
395 );
396 false
397 } else {
398 debug!("Access token is up-to-date (`{exp}`)");
399 true
400 }
401 }
402 };
403
404 if is_up_to_date {
405 return Ok(tokens);
406 }
407
408 let tokens = match tokens {
409 PyxTokens::OAuth(PyxOAuthTokens { refresh_token, .. }) => {
410 let mut url = self.api.clone();
412 url.set_path("auth/cli/refresh");
413
414 let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
415 let body = serde_json::json!({
416 "refresh_token": refresh_token
417 });
418 *request.body_mut() = Some(body.to_string().into());
419
420 let response = client.execute(request).await?;
421 let tokens = response
422 .error_for_status()?
423 .json::<PyxOAuthTokens>()
424 .await?;
425 PyxTokens::OAuth(tokens)
426 }
427 PyxTokens::ApiKey(PyxApiKeyTokens { api_key, .. }) => {
428 #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
429 struct Payload {
430 access_token: AccessToken,
431 }
432
433 let mut url = self.api.clone();
435 url.set_path("auth/cli/access-token");
436
437 let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
438 request.headers_mut().insert(
439 "Authorization",
440 reqwest::header::HeaderValue::from_str(&format!("Bearer {api_key}"))?,
441 );
442
443 let response = client.execute(request).await?;
444 let Payload { access_token } =
445 response.error_for_status()?.json::<Payload>().await?;
446 PyxTokens::ApiKey(PyxApiKeyTokens {
447 access_token,
448 api_key,
449 })
450 }
451 };
452
453 self.write(&tokens).await?;
455 Ok(tokens)
456 }
457
458 pub fn is_known_url(&self, url: &Url) -> bool {
461 is_known_url(url, &self.api, &self.cdn)
462 }
463
464 pub fn is_known_domain(&self, url: &Url) -> bool {
469 is_known_domain(url, &self.api, &self.cdn)
470 }
471}
472
/// Errors that can occur while locating, reading, writing, or refreshing pyx tokens.
#[derive(thiserror::Error, Debug)]
pub enum TokenStoreError {
    /// The configured API URL could not be parsed.
    #[error(transparent)]
    Url(#[from] DisplaySafeUrlError),
    /// An I/O error, e.g. while resolving the credentials directory or reading/writing tokens.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// Token JSON could not be serialized or deserialized.
    #[error(transparent)]
    Serialization(#[from] serde_json::Error),
    /// An HTTP response had an error status or an unparseable body.
    #[error(transparent)]
    Reqwest(#[from] reqwest::Error),
    /// The middleware-wrapped client failed to execute a request.
    #[error(transparent)]
    ReqwestMiddleware(#[from] reqwest_middleware::Error),
    /// The API key could not be encoded as an `Authorization` header value.
    #[error(transparent)]
    InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue),
    /// The access token's `exp` claim could not be converted into a timestamp.
    #[error(transparent)]
    Jiff(#[from] jiff::Error),
    /// The access token could not be decoded as a JWT.
    #[error(transparent)]
    Jwt(#[from] JwtError),
}
492
493impl TokenStoreError {
494 pub fn is_unauthorized(&self) -> bool {
496 match self {
497 Self::Reqwest(err) => err.status() == Some(reqwest::StatusCode::UNAUTHORIZED),
498 Self::ReqwestMiddleware(err) => err.status() == Some(reqwest::StatusCode::UNAUTHORIZED),
499 _ => false,
500 }
501 }
502}
503
/// Claims decoded from the payload of a pyx access token.
///
/// Only the claims below are read; other claims in the payload are ignored.
#[derive(Debug, serde::Deserialize)]
pub struct PyxJwt {
    /// The `exp` claim, in seconds since the Unix epoch.
    pub exp: Option<i64>,
    /// The `iss` (issuer) claim.
    pub iss: Option<String>,
    /// The `urn:pyx:org_name` claim.
    #[serde(rename = "urn:pyx:org_name")]
    pub name: Option<String>,
}
515
516impl PyxJwt {
517 pub fn decode(access_token: &AccessToken) -> Result<Self, JwtError> {
519 let mut token_segments = access_token.as_str().splitn(3, '.');
520
521 let _header = token_segments.next().ok_or(JwtError::MissingHeader)?;
522 let payload = token_segments.next().ok_or(JwtError::MissingPayload)?;
523 let _signature = token_segments.next().ok_or(JwtError::MissingSignature)?;
524 if token_segments.next().is_some() {
525 return Err(JwtError::TooManySegments);
526 }
527
528 let decoded = BASE64_URL_SAFE_NO_PAD.decode(payload)?;
529
530 let jwt = serde_json::from_slice::<Self>(&decoded)?;
531 Ok(jwt)
532 }
533}
534
/// Errors from decoding a JWT access token with `PyxJwt::decode`.
#[derive(thiserror::Error, Debug)]
pub enum JwtError {
    /// The token has no header segment.
    #[error("JWT is missing a header")]
    MissingHeader,
    /// The token has no payload segment.
    #[error("JWT is missing a payload")]
    MissingPayload,
    /// The token has no signature segment.
    #[error("JWT is missing a signature")]
    MissingSignature,
    /// The token has more than three `.`-separated segments.
    #[error("JWT has too many segments")]
    TooManySegments,
    /// The payload segment is not valid unpadded URL-safe base64.
    #[error(transparent)]
    Base64(#[from] base64::DecodeError),
    /// The decoded payload is not JSON matching the expected claims.
    #[error(transparent)]
    Serde(#[from] serde_json::Error),
}
550
551fn is_known_url(url: &Url, api: &DisplaySafeUrl, cdn: &str) -> bool {
552 if Realm::from(url) == Realm::from(&**api) {
554 return true;
555 }
556
557 if matches!(url.scheme(), "https") && matches_domain(url, cdn) {
562 return true;
563 }
564
565 false
566}
567
568fn is_known_domain(url: &Url, api: &DisplaySafeUrl, cdn: &str) -> bool {
569 if let Some(domain) = url.domain() {
571 if matches_domain(api, domain) {
572 return true;
573 }
574 }
575 is_known_url(url, api, cdn)
576}
577
578fn matches_domain(url: &Url, domain: &str) -> bool {
580 url.domain().is_some_and(|subdomain| {
581 subdomain == domain
582 || subdomain
583 .strip_suffix(domain)
584 .is_some_and(|prefix| prefix.ends_with('.'))
585 })
586}
587
#[cfg(test)]
mod tests {
    use super::*;

    /// Check each `(url, expected)` pair against `check`, using the default pyx API URL and
    /// CDN domain.
    fn assert_cases(check: fn(&Url, &DisplaySafeUrl, &str) -> bool, cases: &[(&str, bool)]) {
        let api_url = DisplaySafeUrl::parse("https://api.pyx.dev").unwrap();
        let cdn_domain = "astralhosted.com";

        for &(raw, expected) in cases {
            let url = Url::parse(raw).unwrap();
            assert_eq!(
                check(&url, &api_url, cdn_domain),
                expected,
                "unexpected result for `{raw}`"
            );
        }
    }

    #[test]
    fn test_is_known_url() {
        assert_cases(
            is_known_url,
            &[
                ("https://api.pyx.dev/simple/", true),
                ("https://api.pyx.dev/v1/", true),
                ("https://astralhosted.com/packages/", true),
                ("https://files.astralhosted.com/packages/", true),
                // The CDN is only trusted over HTTPS.
                ("http://astralhosted.com/packages/", false),
                ("https://pypi.org/simple/", false),
                // A shared suffix without a `.` boundary is not a subdomain.
                ("https://badastralhosted.com/packages/", false),
            ],
        );
    }

    #[test]
    fn test_is_known_domain() {
        assert_cases(
            is_known_domain,
            &[
                ("https://api.pyx.dev/simple/", true),
                // The parent domain of the API host is known.
                ("https://pyx.dev", true),
                // Subdomains and siblings of the API host are not.
                ("https://foo.api.pyx.dev", false),
                ("https://beta.pyx.dev/", false),
                ("https://astralhosted.com/packages/", true),
                ("https://files.astralhosted.com/packages/", true),
                ("https://pypi.org/simple/", false),
                ("https://pyx.com/", false),
            ],
        );
    }

    #[test]
    fn test_matches_domain() {
        let cases = [
            ("https://example.com", "example.com", true),
            ("https://foo.example.com", "example.com", true),
            ("https://bar.foo.example.com", "example.com", true),
            ("https://example.com", "other.com", false),
            ("https://example.org", "example.com", false),
            ("https://badexample.com", "example.com", false),
        ];

        for (raw, domain, expected) in cases {
            assert_eq!(
                matches_domain(&Url::parse(raw).unwrap(), domain),
                expected,
                "unexpected result for `{raw}` against `{domain}`"
            );
        }
    }
}