// Source: uv_distribution/distribution_database.rs
1use std::future::Future;
2use std::io;
3use std::path::Path;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use futures::{FutureExt, TryStreamExt};
9use tempfile::TempDir;
10use tokio::io::{AsyncRead, AsyncSeekExt, ReadBuf};
11use tokio::sync::Semaphore;
12use tokio_util::compat::FuturesAsyncReadCompatExt;
13use tracing::{Instrument, info_span, instrument, warn};
14use url::Url;
15
16use uv_cache::{ArchiveId, CacheBucket, CacheEntry, WheelCache};
17use uv_cache_info::{CacheInfo, Timestamp};
18use uv_client::{
19    CacheControl, CachedClientError, Connectivity, DataWithCachePolicy, RegistryClient,
20};
21use uv_distribution_filename::WheelFilename;
22use uv_distribution_types::{
23    BuildInfo, BuildableSource, BuiltDist, Dist, File, HashPolicy, Hashed, IndexUrl, InstalledDist,
24    Name, SourceDist, ToUrlError,
25};
26use uv_extract::hash::Hasher;
27use uv_fs::write_atomic;
28use uv_platform_tags::Tags;
29use uv_pypi_types::{HashDigest, HashDigests, PyProjectToml};
30use uv_redacted::DisplaySafeUrl;
31use uv_types::{BuildContext, BuildStack};
32
33use crate::archive::Archive;
34use crate::metadata::{ArchiveMetadata, Metadata};
35use crate::source::SourceDistributionBuilder;
36use crate::{Error, LocalWheel, Reporter, RequiresDist};
37
/// A cached high-level interface to convert distributions (a requirement resolved to a location)
/// to a wheel or wheel metadata.
///
/// For wheel metadata, this happens by either fetching the metadata from the remote wheel or by
/// building the source distribution. For wheel files, either the wheel is downloaded or a source
/// distribution is downloaded, built and the new wheel gets returned.
///
/// All kinds of wheel sources (index, URL, path) and source distribution source (index, URL, path,
/// Git) are supported.
///
/// This struct also has the task of acquiring locks around source dist builds in general and git
/// operation especially, as well as respecting concurrency limits.
pub struct DistributionDatabase<'a, Context: BuildContext> {
    /// Build context providing cache access, user-supplied dependency metadata, and builds.
    build_context: &'a Context,
    /// Builder used to fetch and build source distributions into wheels.
    builder: SourceDistributionBuilder<'a, Context>,
    /// Registry client wrapped to enforce the concurrent-download limit.
    client: ManagedClient<'a>,
    /// Optional reporter for download progress events.
    reporter: Option<Arc<dyn Reporter>>,
}
56
57impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
58    pub fn new(
59        client: &'a RegistryClient,
60        build_context: &'a Context,
61        concurrent_downloads: usize,
62    ) -> Self {
63        Self {
64            build_context,
65            builder: SourceDistributionBuilder::new(build_context),
66            client: ManagedClient::new(client, concurrent_downloads),
67            reporter: None,
68        }
69    }
70
71    /// Set the build stack to use for the [`DistributionDatabase`].
72    #[must_use]
73    pub fn with_build_stack(self, build_stack: &'a BuildStack) -> Self {
74        Self {
75            builder: self.builder.with_build_stack(build_stack),
76            ..self
77        }
78    }
79
80    /// Set the [`Reporter`] to use for the [`DistributionDatabase`].
81    #[must_use]
82    pub fn with_reporter(self, reporter: Arc<dyn Reporter>) -> Self {
83        Self {
84            builder: self.builder.with_reporter(reporter.clone()),
85            reporter: Some(reporter),
86            ..self
87        }
88    }
89
90    /// Handle a specific `reqwest` error, and convert it to [`io::Error`].
91    fn handle_response_errors(&self, err: reqwest::Error) -> io::Error {
92        if err.is_timeout() {
93            // Assumption: The connect timeout with the 10s default is not the culprit.
94            io::Error::new(
95                io::ErrorKind::TimedOut,
96                format!(
97                    "Failed to download distribution due to network timeout. Try increasing UV_HTTP_TIMEOUT (current value: {}s).",
98                    self.client.unmanaged.read_timeout().as_secs()
99                ),
100            )
101        } else {
102            io::Error::other(err)
103        }
104    }
105
106    /// Either fetch the wheel or fetch and build the source distribution
107    ///
108    /// Returns a wheel that's compliant with the given platform tags.
109    ///
110    /// While hashes will be generated in some cases, hash-checking is only enforced for source
111    /// distributions, and should be enforced by the caller for wheels.
112    #[instrument(skip_all, fields(%dist))]
113    pub async fn get_or_build_wheel(
114        &self,
115        dist: &Dist,
116        tags: &Tags,
117        hashes: HashPolicy<'_>,
118    ) -> Result<LocalWheel, Error> {
119        match dist {
120            Dist::Built(built) => self.get_wheel(built, hashes).await,
121            Dist::Source(source) => self.build_wheel(source, tags, hashes).await,
122        }
123    }
124
125    /// Either fetch the only wheel metadata (directly from the index or with range requests) or
126    /// fetch and build the source distribution.
127    ///
128    /// While hashes will be generated in some cases, hash-checking is only enforced for source
129    /// distributions, and should be enforced by the caller for wheels.
130    #[instrument(skip_all, fields(%dist))]
131    pub async fn get_installed_metadata(
132        &self,
133        dist: &InstalledDist,
134    ) -> Result<ArchiveMetadata, Error> {
135        // If the metadata was provided by the user directly, prefer it.
136        if let Some(metadata) = self
137            .build_context
138            .dependency_metadata()
139            .get(dist.name(), Some(dist.version()))
140        {
141            return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
142        }
143
144        let metadata = dist
145            .read_metadata()
146            .map_err(|err| Error::ReadInstalled(Box::new(dist.clone()), err))?;
147
148        Ok(ArchiveMetadata::from_metadata23(metadata.clone()))
149    }
150
151    /// Either fetch the only wheel metadata (directly from the index or with range requests) or
152    /// fetch and build the source distribution.
153    ///
154    /// While hashes will be generated in some cases, hash-checking is only enforced for source
155    /// distributions, and should be enforced by the caller for wheels.
156    #[instrument(skip_all, fields(%dist))]
157    pub async fn get_or_build_wheel_metadata(
158        &self,
159        dist: &Dist,
160        hashes: HashPolicy<'_>,
161    ) -> Result<ArchiveMetadata, Error> {
162        match dist {
163            Dist::Built(built) => self.get_wheel_metadata(built, hashes).await,
164            Dist::Source(source) => {
165                self.build_wheel_metadata(&BuildableSource::Dist(source), hashes)
166                    .await
167            }
168        }
169    }
170
    /// Fetch a wheel from the cache or download it from the index.
    ///
    /// While hashes will be generated in all cases, hash-checking is _not_ enforced and should
    /// instead be enforced by the caller.
    async fn get_wheel(
        &self,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<LocalWheel, Error> {
        match dist {
            BuiltDist::Registry(wheels) => {
                // Select the most compatible wheel among the registry's candidates.
                let wheel = wheels.best_wheel();
                let WheelTarget {
                    url,
                    extension,
                    size,
                } = WheelTarget::try_from(&*wheel.file)?;

                // Create a cache entry for the wheel.
                let wheel_entry = self.build_context.cache().entry(
                    CacheBucket::Wheels,
                    WheelCache::Index(&wheel.index).wheel_dir(wheel.name().as_ref()),
                    wheel.filename.cache_key(),
                );

                // If the URL is a file URL, load the wheel directly.
                if url.scheme() == "file" {
                    let path = url
                        .to_file_path()
                        .map_err(|()| Error::NonFileUrl(url.clone()))?;
                    return self
                        .load_wheel(
                            &path,
                            &wheel.filename,
                            WheelExtension::Whl,
                            wheel_entry,
                            dist,
                            hashes,
                        )
                        .await;
                }

                // Download and unzip.
                match self
                    .stream_wheel(
                        url.clone(),
                        dist.index(),
                        &wheel.filename,
                        extension,
                        size,
                        &wheel_entry,
                        dist,
                        hashes,
                    )
                    .await
                {
                    Ok(archive) => Ok(LocalWheel {
                        dist: Dist::Built(dist.clone()),
                        archive: self
                            .build_context
                            .cache()
                            .archive(&archive.id)
                            .into_boxed_path(),
                        hashes: archive.hashes,
                        filename: wheel.filename.clone(),
                        cache: CacheInfo::default(),
                        build: None,
                    }),
                    Err(Error::Extract(name, err)) => {
                        // Recoverable streaming failures fall back to a full on-disk download;
                        // any other extraction error is propagated as-is.
                        if err.is_http_streaming_unsupported() {
                            warn!(
                                "Streaming unsupported for {dist}; downloading wheel to disk ({err})"
                            );
                        } else if err.is_http_streaming_failed() {
                            warn!("Streaming failed for {dist}; downloading wheel to disk ({err})");
                        } else {
                            return Err(Error::Extract(name, err));
                        }

                        // If the request failed because streaming is unsupported, download the
                        // wheel directly.
                        let archive = self
                            .download_wheel(
                                url,
                                dist.index(),
                                &wheel.filename,
                                extension,
                                size,
                                &wheel_entry,
                                dist,
                                hashes,
                            )
                            .await?;

                        Ok(LocalWheel {
                            dist: Dist::Built(dist.clone()),
                            archive: self
                                .build_context
                                .cache()
                                .archive(&archive.id)
                                .into_boxed_path(),
                            hashes: archive.hashes,
                            filename: wheel.filename.clone(),
                            cache: CacheInfo::default(),
                            build: None,
                        })
                    }
                    Err(err) => Err(err),
                }
            }

            BuiltDist::DirectUrl(wheel) => {
                // Create a cache entry for the wheel.
                let wheel_entry = self.build_context.cache().entry(
                    CacheBucket::Wheels,
                    WheelCache::Url(&wheel.url).wheel_dir(wheel.name().as_ref()),
                    wheel.filename.cache_key(),
                );

                // Download and unzip. Direct URLs have no index, no known size, and are
                // always plain `.whl` archives.
                match self
                    .stream_wheel(
                        wheel.url.raw().clone(),
                        None,
                        &wheel.filename,
                        WheelExtension::Whl,
                        None,
                        &wheel_entry,
                        dist,
                        hashes,
                    )
                    .await
                {
                    Ok(archive) => Ok(LocalWheel {
                        dist: Dist::Built(dist.clone()),
                        archive: self
                            .build_context
                            .cache()
                            .archive(&archive.id)
                            .into_boxed_path(),
                        hashes: archive.hashes,
                        filename: wheel.filename.clone(),
                        cache: CacheInfo::default(),
                        build: None,
                    }),
                    Err(Error::Client(err)) if err.is_http_streaming_unsupported() => {
                        warn!(
                            "Streaming unsupported for {dist}; downloading wheel to disk ({err})"
                        );

                        // If the request failed because streaming is unsupported, download the
                        // wheel directly.
                        let archive = self
                            .download_wheel(
                                wheel.url.raw().clone(),
                                None,
                                &wheel.filename,
                                WheelExtension::Whl,
                                None,
                                &wheel_entry,
                                dist,
                                hashes,
                            )
                            .await?;
                        Ok(LocalWheel {
                            dist: Dist::Built(dist.clone()),
                            archive: self
                                .build_context
                                .cache()
                                .archive(&archive.id)
                                .into_boxed_path(),
                            hashes: archive.hashes,
                            filename: wheel.filename.clone(),
                            cache: CacheInfo::default(),
                            build: None,
                        })
                    }
                    Err(err) => Err(err),
                }
            }

            BuiltDist::Path(wheel) => {
                // A local wheel on disk: no download needed, load it directly.
                let cache_entry = self.build_context.cache().entry(
                    CacheBucket::Wheels,
                    WheelCache::Url(&wheel.url).wheel_dir(wheel.name().as_ref()),
                    wheel.filename.cache_key(),
                );

                self.load_wheel(
                    &wheel.install_path,
                    &wheel.filename,
                    WheelExtension::Whl,
                    cache_entry,
                    dist,
                    hashes,
                )
                .await
            }
        }
    }
371
    /// Convert a source distribution into a wheel, fetching it from the cache or building it if
    /// necessary.
    ///
    /// The returned wheel is guaranteed to come from a distribution with a matching hash, and
    /// no build processes will be executed for distributions with mismatched hashes.
    async fn build_wheel(
        &self,
        dist: &SourceDist,
        tags: &Tags,
        hashes: HashPolicy<'_>,
    ) -> Result<LocalWheel, Error> {
        // Fetch (and, if necessary, build) the source distribution into a wheel on disk.
        let built_wheel = self
            .builder
            .download_and_build(&BuildableSource::Dist(dist), tags, hashes, &self.client)
            .boxed_local()
            .await?;

        // Check that the wheel is compatible with its install target.
        //
        // When building a build dependency for a cross-install, the build dependency needs
        // to install and run on the host instead of the target. In this case the `tags` are already
        // for the host instead of the target, so this check passes.
        if !built_wheel.filename.is_compatible(tags) {
            return if tags.is_cross() {
                Err(Error::BuiltWheelIncompatibleTargetPlatform {
                    filename: built_wheel.filename,
                    python_platform: tags.python_platform().clone(),
                    python_version: tags.python_version(),
                })
            } else {
                Err(Error::BuiltWheelIncompatibleHostPlatform {
                    filename: built_wheel.filename,
                    python_platform: tags.python_platform().clone(),
                    python_version: tags.python_version(),
                })
            };
        }

        // Acquire the advisory lock (Windows-only), guarding against concurrent unzips of the
        // same target.
        #[cfg(windows)]
        let _lock = {
            // NOTE(review): the `unwrap`s assume the target path has a parent directory and a
            // UTF-8 file name — presumably guaranteed by the cache layout; confirm.
            let lock_entry = CacheEntry::new(
                built_wheel.target.parent().unwrap(),
                format!(
                    "{}.lock",
                    built_wheel.target.file_name().unwrap().to_str().unwrap()
                ),
            );
            lock_entry.lock().await.map_err(Error::CacheLock)?
        };

        // If the wheel was unzipped previously, respect it. Source distributions are
        // cached under a unique revision ID, so unzipped directories are never stale.
        match self.build_context.cache().resolve_link(&built_wheel.target) {
            Ok(archive) => {
                return Ok(LocalWheel {
                    dist: Dist::Source(dist.clone()),
                    archive: archive.into_boxed_path(),
                    filename: built_wheel.filename,
                    hashes: built_wheel.hashes,
                    cache: built_wheel.cache_info,
                    build: Some(built_wheel.build_info),
                });
            }
            // A missing link just means the wheel hasn't been unzipped yet; fall through.
            Err(err) if err.kind() == io::ErrorKind::NotFound => {}
            Err(err) => return Err(Error::CacheRead(err)),
        }

        // Otherwise, unzip the wheel.
        let id = self
            .unzip_wheel(&built_wheel.path, &built_wheel.target)
            .await?;

        Ok(LocalWheel {
            dist: Dist::Source(dist.clone()),
            archive: self.build_context.cache().archive(&id).into_boxed_path(),
            hashes: built_wheel.hashes,
            filename: built_wheel.filename,
            cache: built_wheel.cache_info,
            build: Some(built_wheel.build_info),
        })
    }
454
    /// Fetch the wheel metadata from the index, or from the cache if possible.
    ///
    /// While hashes will be generated in some cases, hash-checking is _not_ enforced and should
    /// instead be enforced by the caller.
    async fn get_wheel_metadata(
        &self,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<ArchiveMetadata, Error> {
        // If hash generation is enabled, and the distribution isn't hosted on a registry, get the
        // entire wheel to ensure that the hashes are included in the response. If the distribution
        // is hosted on an index, the hashes will be included in the simple metadata response.
        // For hash _validation_, callers are expected to enforce the policy when retrieving the
        // wheel.
        //
        // Historically, for `uv pip compile --universal`, we also generate hashes for
        // registry-based distributions when the relevant registry doesn't provide them. This was
        // motivated by `--find-links`. We continue that behavior (under `HashGeneration::All`) for
        // backwards compatibility, but it's a little dubious, since we're only hashing _one_
        // distribution here (as opposed to hashing all distributions for the version), and it may
        // not even be a compatible distribution!
        //
        // TODO(charlie): Request the hashes via a separate method, to reduce the coupling in this API.
        if hashes.is_generate(dist) {
            let wheel = self.get_wheel(dist, hashes).await?;
            // If the metadata was provided by the user directly, prefer it.
            let metadata = if let Some(metadata) = self
                .build_context
                .dependency_metadata()
                .get(dist.name(), Some(dist.version()))
            {
                metadata.clone()
            } else {
                wheel.metadata()?
            };
            // Pair the metadata with the hashes computed while fetching the wheel.
            let hashes = wheel.hashes;
            return Ok(ArchiveMetadata {
                metadata: Metadata::from_metadata23(metadata),
                hashes,
            });
        }

        // If the metadata was provided by the user directly, prefer it.
        if let Some(metadata) = self
            .build_context
            .dependency_metadata()
            .get(dist.name(), Some(dist.version()))
        {
            return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
        }

        // Otherwise, fetch the metadata via the client.
        let result = self
            .client
            .managed(|client| {
                client
                    .wheel_metadata(dist, self.build_context.capabilities())
                    .boxed_local()
            })
            .await;

        match result {
            Ok(metadata) => {
                // Successfully fetched the metadata from the index.
                Ok(ArchiveMetadata::from_metadata23(metadata))
            }
            Err(err) if err.is_http_streaming_unsupported() => {
                warn!(
                    "Streaming unsupported when fetching metadata for {dist}; downloading wheel directly ({err})"
                );

                // If the request failed due to an error that could be resolved by
                // downloading the wheel directly, try that.
                let wheel = self.get_wheel(dist, hashes).await?;
                let metadata = wheel.metadata()?;
                let hashes = wheel.hashes;
                Ok(ArchiveMetadata {
                    metadata: Metadata::from_metadata23(metadata),
                    hashes,
                })
            }
            Err(err) => Err(err.into()),
        }
    }
538
539    /// Build the wheel metadata for a source distribution, or fetch it from the cache if possible.
540    ///
541    /// The returned metadata is guaranteed to come from a distribution with a matching hash, and
542    /// no build processes will be executed for distributions with mismatched hashes.
543    pub async fn build_wheel_metadata(
544        &self,
545        source: &BuildableSource<'_>,
546        hashes: HashPolicy<'_>,
547    ) -> Result<ArchiveMetadata, Error> {
548        // If the metadata was provided by the user directly, prefer it.
549        if let Some(dist) = source.as_dist() {
550            if let Some(metadata) = self
551                .build_context
552                .dependency_metadata()
553                .get(dist.name(), dist.version())
554            {
555                // If we skipped the build, we should still resolve any Git dependencies to precise
556                // commits.
557                self.builder.resolve_revision(source, &self.client).await?;
558
559                return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
560            }
561        }
562
563        let metadata = self
564            .builder
565            .download_and_build_metadata(source, hashes, &self.client)
566            .boxed_local()
567            .await?;
568
569        Ok(metadata)
570    }
571
572    /// Return the [`RequiresDist`] from a `pyproject.toml`, if it can be statically extracted.
573    pub async fn requires_dist(
574        &self,
575        path: &Path,
576        pyproject_toml: &PyProjectToml,
577    ) -> Result<Option<RequiresDist>, Error> {
578        self.builder
579            .source_tree_requires_dist(
580                path,
581                pyproject_toml,
582                self.client.unmanaged.credentials_cache(),
583            )
584            .await
585    }
586
    /// Stream a wheel from a URL, unzipping it into the cache as it's downloaded.
    ///
    /// Returns the [`Archive`] describing the unzipped wheel in the cache, including any hash
    /// digests computed while streaming.
    async fn stream_wheel(
        &self,
        url: DisplaySafeUrl,
        index: Option<&IndexUrl>,
        filename: &WheelFilename,
        extension: WheelExtension,
        size: Option<u64>,
        wheel_entry: &CacheEntry,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<Archive, Error> {
        // Acquire an advisory lock, to guard against concurrent writes.
        #[cfg(windows)]
        let _lock = {
            let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
            lock_entry.lock().await.map_err(Error::CacheLock)?
        };

        // Create an entry for the HTTP cache.
        let http_entry = wheel_entry.with_file(format!("{}.http", filename.cache_key()));

        // Callback invoked with the HTTP response on a cache miss: streams the body through
        // the hashers and unzips it on the fly into a temporary directory.
        let download = |response: reqwest::Response| {
            async {
                // Prefer the caller-provided size; fall back to the `Content-Length` header.
                let size = size.or_else(|| content_length(&response));

                let progress = self
                    .reporter
                    .as_ref()
                    .map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));

                let reader = response
                    .bytes_stream()
                    .map_err(|err| self.handle_response_errors(err))
                    .into_async_read();

                // Create a hasher for each hash algorithm.
                let algorithms = hashes.algorithms();
                let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
                let mut hasher = uv_extract::hash::HashReader::new(reader.compat(), &mut hashers);

                // Download and unzip the wheel to a temporary directory.
                let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
                    .map_err(Error::CacheWrite)?;

                // With a reporter attached, wrap the hashing reader in a progress reader;
                // either way, dispatch on the archive format (`.whl` vs. `.whl.zst`).
                match progress {
                    Some((reporter, progress)) => {
                        let mut reader = ProgressReader::new(&mut hasher, progress, &**reporter);
                        match extension {
                            WheelExtension::Whl => {
                                uv_extract::stream::unzip(&mut reader, temp_dir.path())
                                    .await
                                    .map_err(|err| Error::Extract(filename.to_string(), err))?;
                            }
                            WheelExtension::WhlZst => {
                                uv_extract::stream::untar_zst(&mut reader, temp_dir.path())
                                    .await
                                    .map_err(|err| Error::Extract(filename.to_string(), err))?;
                            }
                        }
                    }
                    None => match extension {
                        WheelExtension::Whl => {
                            uv_extract::stream::unzip(&mut hasher, temp_dir.path())
                                .await
                                .map_err(|err| Error::Extract(filename.to_string(), err))?;
                        }
                        WheelExtension::WhlZst => {
                            uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
                                .await
                                .map_err(|err| Error::Extract(filename.to_string(), err))?;
                        }
                    },
                }

                // If necessary, exhaust the reader to compute the hash.
                if !hashes.is_none() {
                    hasher.finish().await.map_err(Error::HashExhaustion)?;
                }

                // Persist the temporary directory to the directory store.
                // NOTE(review): failures here are mapped to `Error::CacheRead` even though this
                // is a write operation — confirm whether `CacheWrite` was intended.
                let id = self
                    .build_context
                    .cache()
                    .persist(temp_dir.keep(), wheel_entry.path())
                    .await
                    .map_err(Error::CacheRead)?;

                if let Some((reporter, progress)) = progress {
                    reporter.on_download_complete(dist.name(), progress);
                }

                Ok(Archive::new(
                    id,
                    hashers.into_iter().map(HashDigest::from).collect(),
                    filename.clone(),
                ))
            }
            .instrument(info_span!("wheel", wheel = %dist))
        };

        // Fetch the archive from the cache, or download it if necessary.
        let req = self.request(url.clone())?;

        // Determine the cache control policy for the URL.
        let cache_control = match self.client.unmanaged.connectivity() {
            Connectivity::Online => {
                // A per-index cache-control override takes precedence over freshness checks.
                if let Some(header) = index.and_then(|index| {
                    self.build_context
                        .locations()
                        .artifact_cache_control_for(index)
                }) {
                    CacheControl::Override(header)
                } else {
                    CacheControl::from(
                        self.build_context
                            .cache()
                            .freshness(&http_entry, Some(&filename.name), None)
                            .map_err(Error::CacheRead)?,
                    )
                }
            }
            Connectivity::Offline => CacheControl::AllowStale,
        };

        let archive = self
            .client
            .managed(|client| {
                client.cached_client().get_serde_with_retry(
                    req,
                    &http_entry,
                    cache_control,
                    download,
                )
            })
            .await
            .map_err(|err| match err {
                CachedClientError::Callback { err, .. } => err,
                CachedClientError::Client(err) => Error::Client(err),
            })?;

        // If the archive is missing the required hashes, or has since been removed, force a refresh.
        let archive = Some(archive)
            .filter(|archive| archive.has_digests(hashes))
            .filter(|archive| archive.exists(self.build_context.cache()));

        let archive = if let Some(archive) = archive {
            archive
        } else {
            // Bypass the HTTP cache and re-download, re-running `download` on the fresh response.
            self.client
                .managed(async |client| {
                    client
                        .cached_client()
                        .skip_cache_with_retry(
                            self.request(url)?,
                            &http_entry,
                            cache_control,
                            download,
                        )
                        .await
                        .map_err(|err| match err {
                            CachedClientError::Callback { err, .. } => err,
                            CachedClientError::Client(err) => Error::Client(err),
                        })
                })
                .await?
        };

        Ok(archive)
    }
757
758    /// Download a wheel from a URL, then unzip it into the cache.
759    async fn download_wheel(
760        &self,
761        url: DisplaySafeUrl,
762        index: Option<&IndexUrl>,
763        filename: &WheelFilename,
764        extension: WheelExtension,
765        size: Option<u64>,
766        wheel_entry: &CacheEntry,
767        dist: &BuiltDist,
768        hashes: HashPolicy<'_>,
769    ) -> Result<Archive, Error> {
770        // Acquire an advisory lock, to guard against concurrent writes.
771        #[cfg(windows)]
772        let _lock = {
773            let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
774            lock_entry.lock().await.map_err(Error::CacheLock)?
775        };
776
777        // Create an entry for the HTTP cache.
778        let http_entry = wheel_entry.with_file(format!("{}.http", filename.cache_key()));
779
780        let download = |response: reqwest::Response| {
781            async {
782                let size = size.or_else(|| content_length(&response));
783
784                let progress = self
785                    .reporter
786                    .as_ref()
787                    .map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));
788
789                let reader = response
790                    .bytes_stream()
791                    .map_err(|err| self.handle_response_errors(err))
792                    .into_async_read();
793
794                // Download the wheel to a temporary file.
795                let temp_file = tempfile::tempfile_in(self.build_context.cache().root())
796                    .map_err(Error::CacheWrite)?;
797                let mut writer = tokio::io::BufWriter::new(fs_err::tokio::File::from_std(
798                    // It's an unnamed file on Linux so that's the best approximation.
799                    fs_err::File::from_parts(temp_file, self.build_context.cache().root()),
800                ));
801
802                match progress {
803                    Some((reporter, progress)) => {
804                        // Wrap the reader in a progress reporter. This will report 100% progress
805                        // after the download is complete, even if we still have to unzip and hash
806                        // part of the file.
807                        let mut reader =
808                            ProgressReader::new(reader.compat(), progress, &**reporter);
809
810                        tokio::io::copy(&mut reader, &mut writer)
811                            .await
812                            .map_err(Error::CacheWrite)?;
813                    }
814                    None => {
815                        tokio::io::copy(&mut reader.compat(), &mut writer)
816                            .await
817                            .map_err(Error::CacheWrite)?;
818                    }
819                }
820
821                // Unzip the wheel to a temporary directory.
822                let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
823                    .map_err(Error::CacheWrite)?;
824                let mut file = writer.into_inner();
825                file.seek(io::SeekFrom::Start(0))
826                    .await
827                    .map_err(Error::CacheWrite)?;
828
829                // If no hashes are required, parallelize the unzip operation.
830                let hashes = if hashes.is_none() {
831                    let file = file.into_std().await;
832                    tokio::task::spawn_blocking({
833                        let target = temp_dir.path().to_owned();
834                        move || -> Result<(), uv_extract::Error> {
835                            // Unzip the wheel into a temporary directory.
836                            match extension {
837                                WheelExtension::Whl => {
838                                    uv_extract::unzip(file, &target)?;
839                                }
840                                WheelExtension::WhlZst => {
841                                    uv_extract::stream::untar_zst_file(file, &target)?;
842                                }
843                            }
844                            Ok(())
845                        }
846                    })
847                    .await?
848                    .map_err(|err| Error::Extract(filename.to_string(), err))?;
849
850                    HashDigests::empty()
851                } else {
852                    // Create a hasher for each hash algorithm.
853                    let algorithms = hashes.algorithms();
854                    let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
855                    let mut hasher = uv_extract::hash::HashReader::new(file, &mut hashers);
856
857                    match extension {
858                        WheelExtension::Whl => {
859                            uv_extract::stream::unzip(&mut hasher, temp_dir.path())
860                                .await
861                                .map_err(|err| Error::Extract(filename.to_string(), err))?;
862                        }
863                        WheelExtension::WhlZst => {
864                            uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
865                                .await
866                                .map_err(|err| Error::Extract(filename.to_string(), err))?;
867                        }
868                    }
869
870                    // If necessary, exhaust the reader to compute the hash.
871                    hasher.finish().await.map_err(Error::HashExhaustion)?;
872
873                    hashers.into_iter().map(HashDigest::from).collect()
874                };
875
876                // Persist the temporary directory to the directory store.
877                let id = self
878                    .build_context
879                    .cache()
880                    .persist(temp_dir.keep(), wheel_entry.path())
881                    .await
882                    .map_err(Error::CacheRead)?;
883
884                if let Some((reporter, progress)) = progress {
885                    reporter.on_download_complete(dist.name(), progress);
886                }
887
888                Ok(Archive::new(id, hashes, filename.clone()))
889            }
890            .instrument(info_span!("wheel", wheel = %dist))
891        };
892
893        // Fetch the archive from the cache, or download it if necessary.
894        let req = self.request(url.clone())?;
895
896        // Determine the cache control policy for the URL.
897        let cache_control = match self.client.unmanaged.connectivity() {
898            Connectivity::Online => {
899                if let Some(header) = index.and_then(|index| {
900                    self.build_context
901                        .locations()
902                        .artifact_cache_control_for(index)
903                }) {
904                    CacheControl::Override(header)
905                } else {
906                    CacheControl::from(
907                        self.build_context
908                            .cache()
909                            .freshness(&http_entry, Some(&filename.name), None)
910                            .map_err(Error::CacheRead)?,
911                    )
912                }
913            }
914            Connectivity::Offline => CacheControl::AllowStale,
915        };
916
917        let archive = self
918            .client
919            .managed(|client| {
920                client.cached_client().get_serde_with_retry(
921                    req,
922                    &http_entry,
923                    cache_control,
924                    download,
925                )
926            })
927            .await
928            .map_err(|err| match err {
929                CachedClientError::Callback { err, .. } => err,
930                CachedClientError::Client(err) => Error::Client(err),
931            })?;
932
933        // If the archive is missing the required hashes, or has since been removed, force a refresh.
934        let archive = Some(archive)
935            .filter(|archive| archive.has_digests(hashes))
936            .filter(|archive| archive.exists(self.build_context.cache()));
937
938        let archive = if let Some(archive) = archive {
939            archive
940        } else {
941            self.client
942                .managed(async |client| {
943                    client
944                        .cached_client()
945                        .skip_cache_with_retry(
946                            self.request(url)?,
947                            &http_entry,
948                            cache_control,
949                            download,
950                        )
951                        .await
952                        .map_err(|err| match err {
953                            CachedClientError::Callback { err, .. } => err,
954                            CachedClientError::Client(err) => Error::Client(err),
955                        })
956                })
957                .await?
958        };
959
960        Ok(archive)
961    }
962
    /// Load a wheel from a local path.
    ///
    /// If the wheel was previously unzipped into the cache, the on-disk file has not been
    /// modified since (by timestamp), and the cached archive carries every digest required by
    /// `hashes`, the cached archive is reused. Otherwise, the wheel is unzipped (and hashed,
    /// when required) into the cache, and a fresh [`LocalArchivePointer`] is written so that
    /// subsequent loads can skip the work.
    async fn load_wheel(
        &self,
        path: &Path,
        filename: &WheelFilename,
        extension: WheelExtension,
        wheel_entry: CacheEntry,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<LocalWheel, Error> {
        // Acquire an advisory lock, to guard against concurrent writes.
        #[cfg(windows)]
        let _lock = {
            let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
            lock_entry.lock().await.map_err(Error::CacheLock)?
        };

        // Determine the last-modified time of the wheel.
        let modified = Timestamp::from_path(path).map_err(Error::CacheRead)?;

        // Attempt to read the archive pointer from the cache.
        let pointer_entry = wheel_entry.with_file(format!("{}.rev", filename.cache_key()));
        let pointer = LocalArchivePointer::read_from(&pointer_entry)?;

        // Extract the archive from the pointer.
        // The pointer is only usable if its recorded timestamp matches the file on disk AND
        // the cached archive already carries every digest the current hash policy requires.
        let archive = pointer
            .filter(|pointer| pointer.is_up_to_date(modified))
            .map(LocalArchivePointer::into_archive)
            .filter(|archive| archive.has_digests(hashes));

        // If the file is already unzipped, and the cache is up-to-date, return it.
        if let Some(archive) = archive {
            Ok(LocalWheel {
                dist: Dist::Built(dist.clone()),
                archive: self
                    .build_context
                    .cache()
                    .archive(&archive.id)
                    .into_boxed_path(),
                hashes: archive.hashes,
                filename: filename.clone(),
                cache: CacheInfo::from_timestamp(modified),
                build: None,
            })
        } else if hashes.is_none() {
            // Otherwise, unzip the wheel.
            // No hashes are required, so take the fast path: a blocking unzip without hashing.
            let archive = Archive::new(
                self.unzip_wheel(path, wheel_entry.path()).await?,
                HashDigests::empty(),
                filename.clone(),
            );

            // Write the archive pointer to the cache.
            let pointer = LocalArchivePointer {
                timestamp: modified,
                archive: archive.clone(),
            };
            pointer.write_to(&pointer_entry).await?;

            Ok(LocalWheel {
                dist: Dist::Built(dist.clone()),
                archive: self
                    .build_context
                    .cache()
                    .archive(&archive.id)
                    .into_boxed_path(),
                hashes: archive.hashes,
                filename: filename.clone(),
                cache: CacheInfo::from_timestamp(modified),
                build: None,
            })
        } else {
            // If necessary, compute the hashes of the wheel.
            let file = fs_err::tokio::File::open(path)
                .await
                .map_err(Error::CacheRead)?;
            let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
                .map_err(Error::CacheWrite)?;

            // Create a hasher for each hash algorithm.
            let algorithms = hashes.algorithms();
            let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
            let mut hasher = uv_extract::hash::HashReader::new(file, &mut hashers);

            // Unzip the wheel to a temporary directory.
            // The streaming extractors read through `hasher`, so hashing and unzipping
            // happen in a single pass over the file.
            match extension {
                WheelExtension::Whl => {
                    uv_extract::stream::unzip(&mut hasher, temp_dir.path())
                        .await
                        .map_err(|err| Error::Extract(filename.to_string(), err))?;
                }
                WheelExtension::WhlZst => {
                    uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
                        .await
                        .map_err(|err| Error::Extract(filename.to_string(), err))?;
                }
            }

            // Exhaust the reader to compute the hash.
            hasher.finish().await.map_err(Error::HashExhaustion)?;

            let hashes = hashers.into_iter().map(HashDigest::from).collect();

            // Persist the temporary directory to the directory store.
            let id = self
                .build_context
                .cache()
                .persist(temp_dir.keep(), wheel_entry.path())
                .await
                .map_err(Error::CacheWrite)?;

            // Create an archive.
            let archive = Archive::new(id, hashes, filename.clone());

            // Write the archive pointer to the cache.
            let pointer = LocalArchivePointer {
                timestamp: modified,
                archive: archive.clone(),
            };
            pointer.write_to(&pointer_entry).await?;

            Ok(LocalWheel {
                dist: Dist::Built(dist.clone()),
                archive: self
                    .build_context
                    .cache()
                    .archive(&archive.id)
                    .into_boxed_path(),
                hashes: archive.hashes,
                filename: filename.clone(),
                cache: CacheInfo::from_timestamp(modified),
                build: None,
            })
        }
    }
1097
1098    /// Unzip a wheel into the cache, returning the path to the unzipped directory.
1099    async fn unzip_wheel(&self, path: &Path, target: &Path) -> Result<ArchiveId, Error> {
1100        let temp_dir = tokio::task::spawn_blocking({
1101            let path = path.to_owned();
1102            let root = self.build_context.cache().root().to_path_buf();
1103            move || -> Result<TempDir, Error> {
1104                // Unzip the wheel into a temporary directory.
1105                let temp_dir = tempfile::tempdir_in(root).map_err(Error::CacheWrite)?;
1106                let reader = fs_err::File::open(&path).map_err(Error::CacheWrite)?;
1107                uv_extract::unzip(reader, temp_dir.path())
1108                    .map_err(|err| Error::Extract(path.to_string_lossy().into_owned(), err))?;
1109                Ok(temp_dir)
1110            }
1111        })
1112        .await??;
1113
1114        // Persist the temporary directory to the directory store.
1115        let id = self
1116            .build_context
1117            .cache()
1118            .persist(temp_dir.keep(), target)
1119            .await
1120            .map_err(Error::CacheWrite)?;
1121
1122        Ok(id)
1123    }
1124
1125    /// Returns a GET [`reqwest::Request`] for the given URL.
1126    fn request(&self, url: DisplaySafeUrl) -> Result<reqwest::Request, reqwest::Error> {
1127        self.client
1128            .unmanaged
1129            .uncached_client(&url)
1130            .get(Url::from(url))
1131            .header(
1132                // `reqwest` defaults to accepting compressed responses.
1133                // Specify identity encoding to get consistent .whl downloading
1134                // behavior from servers. ref: https://github.com/pypa/pip/pull/1688
1135                "accept-encoding",
1136                reqwest::header::HeaderValue::from_static("identity"),
1137            )
1138            .build()
1139    }
1140
    /// Return the [`ManagedClient`] used by this resolver.
    ///
    /// The returned client enforces the configured concurrency limit on requests.
    pub fn client(&self) -> &ManagedClient<'a> {
        &self.client
    }
1145}
1146
/// A wrapper around `RegistryClient` that manages a concurrency limit.
pub struct ManagedClient<'a> {
    /// The underlying client, with no concurrency limiting applied.
    pub unmanaged: &'a RegistryClient,
    /// Semaphore bounding the number of concurrent requests through [`ManagedClient::managed`].
    control: Semaphore,
}
1152
1153impl<'a> ManagedClient<'a> {
1154    /// Create a new `ManagedClient` using the given client and concurrency limit.
1155    fn new(client: &'a RegistryClient, concurrency: usize) -> Self {
1156        ManagedClient {
1157            unmanaged: client,
1158            control: Semaphore::new(concurrency),
1159        }
1160    }
1161
1162    /// Perform a request using the client, respecting the concurrency limit.
1163    ///
1164    /// If the concurrency limit has been reached, this method will wait until a pending
1165    /// operation completes before executing the closure.
1166    pub async fn managed<F, T>(&self, f: impl FnOnce(&'a RegistryClient) -> F) -> T
1167    where
1168        F: Future<Output = T>,
1169    {
1170        let _permit = self.control.acquire().await.unwrap();
1171        f(self.unmanaged).await
1172    }
1173
1174    /// Perform a request using a client that internally manages the concurrency limit.
1175    ///
1176    /// The callback is passed the client and a semaphore. It must acquire the semaphore before
1177    /// any request through the client and drop it after.
1178    ///
1179    /// This method serves as an escape hatch for functions that may want to send multiple requests
1180    /// in parallel.
1181    pub async fn manual<F, T>(&'a self, f: impl FnOnce(&'a RegistryClient, &'a Semaphore) -> F) -> T
1182    where
1183        F: Future<Output = T>,
1184    {
1185        f(self.unmanaged, &self.control).await
1186    }
1187}
1188
1189/// Returns the value of the `Content-Length` header from the [`reqwest::Response`], if present.
1190fn content_length(response: &reqwest::Response) -> Option<u64> {
1191    response
1192        .headers()
1193        .get(reqwest::header::CONTENT_LENGTH)
1194        .and_then(|val| val.to_str().ok())
1195        .and_then(|val| val.parse::<u64>().ok())
1196}
1197
/// An asynchronous reader that reports progress as bytes are read.
struct ProgressReader<'a, R> {
    /// The underlying reader to delegate reads to.
    reader: R,
    /// The progress identifier, as returned by `Reporter::on_download_start`.
    index: usize,
    /// The reporter to notify as bytes are read.
    reporter: &'a dyn Reporter,
}
1204
1205impl<'a, R> ProgressReader<'a, R> {
1206    /// Create a new [`ProgressReader`] that wraps another reader.
1207    fn new(reader: R, index: usize, reporter: &'a dyn Reporter) -> Self {
1208        Self {
1209            reader,
1210            index,
1211            reporter,
1212        }
1213    }
1214}
1215
1216impl<R> AsyncRead for ProgressReader<'_, R>
1217where
1218    R: AsyncRead + Unpin,
1219{
1220    fn poll_read(
1221        mut self: Pin<&mut Self>,
1222        cx: &mut Context<'_>,
1223        buf: &mut ReadBuf<'_>,
1224    ) -> Poll<io::Result<()>> {
1225        Pin::new(&mut self.as_mut().reader)
1226            .poll_read(cx, buf)
1227            .map_ok(|()| {
1228                self.reporter
1229                    .on_download_progress(self.index, buf.filled().len() as u64);
1230            })
1231    }
1232}
1233
/// A pointer to an archive in the cache, fetched from an HTTP archive.
///
/// Encoded with `MsgPack`, and represented on disk by a `.http` file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HttpArchivePointer {
    /// The cached archive this pointer resolves to.
    archive: Archive,
}
1241
1242impl HttpArchivePointer {
1243    /// Read an [`HttpArchivePointer`] from the cache.
1244    pub fn read_from(path: impl AsRef<Path>) -> Result<Option<Self>, Error> {
1245        match fs_err::File::open(path.as_ref()) {
1246            Ok(file) => {
1247                let data = DataWithCachePolicy::from_reader(file)?.data;
1248                let archive = rmp_serde::from_slice::<Archive>(&data)?;
1249                Ok(Some(Self { archive }))
1250            }
1251            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
1252            Err(err) => Err(Error::CacheRead(err)),
1253        }
1254    }
1255
1256    /// Return the [`Archive`] from the pointer.
1257    pub fn into_archive(self) -> Archive {
1258        self.archive
1259    }
1260
1261    /// Return the [`CacheInfo`] from the pointer.
1262    pub fn to_cache_info(&self) -> CacheInfo {
1263        CacheInfo::default()
1264    }
1265
1266    /// Return the [`BuildInfo`] from the pointer.
1267    pub fn to_build_info(&self) -> Option<BuildInfo> {
1268        None
1269    }
1270}
1271
/// A pointer to an archive in the cache, fetched from a local path.
///
/// Encoded with `MsgPack`, and represented on disk by a `.rev` file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LocalArchivePointer {
    /// The last-modified time of the local file at the time the archive was cached.
    timestamp: Timestamp,
    /// The cached archive this pointer resolves to.
    archive: Archive,
}
1280
1281impl LocalArchivePointer {
1282    /// Read an [`LocalArchivePointer`] from the cache.
1283    pub fn read_from(path: impl AsRef<Path>) -> Result<Option<Self>, Error> {
1284        match fs_err::read(path) {
1285            Ok(cached) => Ok(Some(rmp_serde::from_slice::<Self>(&cached)?)),
1286            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
1287            Err(err) => Err(Error::CacheRead(err)),
1288        }
1289    }
1290
1291    /// Write an [`LocalArchivePointer`] to the cache.
1292    pub async fn write_to(&self, entry: &CacheEntry) -> Result<(), Error> {
1293        write_atomic(entry.path(), rmp_serde::to_vec(&self)?)
1294            .await
1295            .map_err(Error::CacheWrite)
1296    }
1297
1298    /// Returns `true` if the archive is up-to-date with the given modified timestamp.
1299    pub fn is_up_to_date(&self, modified: Timestamp) -> bool {
1300        self.timestamp == modified
1301    }
1302
1303    /// Return the [`Archive`] from the pointer.
1304    pub fn into_archive(self) -> Archive {
1305        self.archive
1306    }
1307
1308    /// Return the [`CacheInfo`] from the pointer.
1309    pub fn to_cache_info(&self) -> CacheInfo {
1310        CacheInfo::from_timestamp(self.timestamp)
1311    }
1312
1313    /// Return the [`BuildInfo`] from the pointer.
1314    pub fn to_build_info(&self) -> Option<BuildInfo> {
1315        None
1316    }
1317}
1318
/// The resolved download location, payload format, and expected size of a wheel.
#[derive(Debug, Clone)]
struct WheelTarget {
    /// The URL from which the wheel can be downloaded.
    url: DisplaySafeUrl,
    /// The expected extension of the wheel file.
    extension: WheelExtension,
    /// The expected size of the wheel file, if known.
    size: Option<u64>,
}
1328
1329impl TryFrom<&File> for WheelTarget {
1330    type Error = ToUrlError;
1331
1332    /// Determine the [`WheelTarget`] from a [`File`].
1333    fn try_from(file: &File) -> Result<Self, Self::Error> {
1334        let url = file.url.to_url()?;
1335        if let Some(zstd) = file.zstd.as_ref() {
1336            Ok(Self {
1337                url: add_tar_zst_extension(url),
1338                extension: WheelExtension::WhlZst,
1339                size: zstd.size,
1340            })
1341        } else {
1342            Ok(Self {
1343                url,
1344                extension: WheelExtension::Whl,
1345                size: file.size,
1346            })
1347        }
1348    }
1349}
1350
/// The on-disk format of a wheel payload.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum WheelExtension {
    /// A `.whl` file.
    Whl,
    /// A `.whl.tar.zst` file.
    WhlZst,
}
1358
1359/// Add `.tar.zst` to the end of the URL path, if it doesn't already exist.
1360#[must_use]
1361fn add_tar_zst_extension(mut url: DisplaySafeUrl) -> DisplaySafeUrl {
1362    let mut path = url.path().to_string();
1363
1364    if !path.ends_with(".tar.zst") {
1365        path.push_str(".tar.zst");
1366    }
1367
1368    url.set_path(&path);
1369    url
1370}
1371
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_tar_zst_extension() {
        // `(input, expected)` pairs: a plain wheel URL, an already-suffixed URL, and a
        // URL containing a percent-encoded `+` that must survive untouched.
        let cases = [
            (
                "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl",
                "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst",
            ),
            (
                "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst",
                "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst",
            ),
            (
                "https://files.pythonhosted.org/flask-3.1.0%2Bcu124-py3-none-any.whl",
                "https://files.pythonhosted.org/flask-3.1.0%2Bcu124-py3-none-any.whl.tar.zst",
            ),
        ];
        for (input, expected) in cases {
            let url = DisplaySafeUrl::parse(input).unwrap();
            assert_eq!(add_tar_zst_extension(url).as_str(), expected);
        }
    }
}