// uv_distribution/distribution_database.rs

1use std::future::Future;
2use std::io;
3use std::path::Path;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use futures::{FutureExt, TryStreamExt};
9use tempfile::TempDir;
10use tokio::io::{AsyncRead, AsyncSeekExt, ReadBuf};
11use tokio::sync::Semaphore;
12use tokio_util::compat::FuturesAsyncReadCompatExt;
13use tracing::{Instrument, info_span, instrument, warn};
14use url::Url;
15
16use uv_cache::{ArchiveId, CacheBucket, CacheEntry, WheelCache};
17use uv_cache_info::{CacheInfo, Timestamp};
18use uv_client::{
19    CacheControl, CachedClientError, Connectivity, DataWithCachePolicy, RegistryClient,
20};
21use uv_distribution_filename::WheelFilename;
22use uv_distribution_types::{
23    BuildInfo, BuildableSource, BuiltDist, Dist, File, HashPolicy, Hashed, IndexUrl, InstalledDist,
24    Name, SourceDist, ToUrlError,
25};
26use uv_extract::hash::Hasher;
27use uv_fs::write_atomic;
28use uv_platform_tags::Tags;
29use uv_pypi_types::{HashDigest, HashDigests, PyProjectToml};
30use uv_redacted::DisplaySafeUrl;
31use uv_types::{BuildContext, BuildStack};
32
33use crate::archive::Archive;
34use crate::metadata::{ArchiveMetadata, Metadata};
35use crate::source::SourceDistributionBuilder;
36use crate::{Error, LocalWheel, Reporter, RequiresDist};
37
/// A cached high-level interface to convert distributions (a requirement resolved to a location)
/// to a wheel or wheel metadata.
///
/// For wheel metadata, this happens by either fetching the metadata from the remote wheel or by
/// building the source distribution. For wheel files, either the wheel is downloaded or a source
/// distribution is downloaded, built and the new wheel gets returned.
///
/// All kinds of wheel sources (index, URL, path) and source distribution source (index, URL, path,
/// Git) are supported.
///
/// This struct also has the task of acquiring locks around source dist builds in general and git
/// operation especially, as well as respecting concurrency limits.
pub struct DistributionDatabase<'a, Context: BuildContext> {
    /// The build context, providing access to the cache, configured locations, and builds.
    build_context: &'a Context,
    /// Downloads and builds source distributions into wheels and metadata.
    builder: SourceDistributionBuilder<'a, Context>,
    /// The registry client, managed to respect the configured download concurrency limit.
    client: ManagedClient<'a>,
    /// Optional reporter used to surface download progress events.
    reporter: Option<Arc<dyn Reporter>>,
}
56
57impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
58    pub fn new(
59        client: &'a RegistryClient,
60        build_context: &'a Context,
61        concurrent_downloads: usize,
62    ) -> Self {
63        Self {
64            build_context,
65            builder: SourceDistributionBuilder::new(build_context),
66            client: ManagedClient::new(client, concurrent_downloads),
67            reporter: None,
68        }
69    }
70
71    /// Set the build stack to use for the [`DistributionDatabase`].
72    #[must_use]
73    pub fn with_build_stack(self, build_stack: &'a BuildStack) -> Self {
74        Self {
75            builder: self.builder.with_build_stack(build_stack),
76            ..self
77        }
78    }
79
80    /// Set the [`Reporter`] to use for the [`DistributionDatabase`].
81    #[must_use]
82    pub fn with_reporter(self, reporter: Arc<dyn Reporter>) -> Self {
83        Self {
84            builder: self.builder.with_reporter(reporter.clone()),
85            reporter: Some(reporter),
86            ..self
87        }
88    }
89
90    /// Handle a specific `reqwest` error, and convert it to [`io::Error`].
91    fn handle_response_errors(&self, err: reqwest::Error) -> io::Error {
92        if err.is_timeout() {
93            io::Error::new(
94                io::ErrorKind::TimedOut,
95                format!(
96                    "Failed to download distribution due to network timeout. Try increasing UV_HTTP_TIMEOUT (current value: {}s).",
97                    self.client.unmanaged.timeout().as_secs()
98                ),
99            )
100        } else {
101            io::Error::other(err)
102        }
103    }
104
105    /// Either fetch the wheel or fetch and build the source distribution
106    ///
107    /// Returns a wheel that's compliant with the given platform tags.
108    ///
109    /// While hashes will be generated in some cases, hash-checking is only enforced for source
110    /// distributions, and should be enforced by the caller for wheels.
111    #[instrument(skip_all, fields(%dist))]
112    pub async fn get_or_build_wheel(
113        &self,
114        dist: &Dist,
115        tags: &Tags,
116        hashes: HashPolicy<'_>,
117    ) -> Result<LocalWheel, Error> {
118        match dist {
119            Dist::Built(built) => self.get_wheel(built, hashes).await,
120            Dist::Source(source) => self.build_wheel(source, tags, hashes).await,
121        }
122    }
123
124    /// Either fetch the only wheel metadata (directly from the index or with range requests) or
125    /// fetch and build the source distribution.
126    ///
127    /// While hashes will be generated in some cases, hash-checking is only enforced for source
128    /// distributions, and should be enforced by the caller for wheels.
129    #[instrument(skip_all, fields(%dist))]
130    pub async fn get_installed_metadata(
131        &self,
132        dist: &InstalledDist,
133    ) -> Result<ArchiveMetadata, Error> {
134        // If the metadata was provided by the user directly, prefer it.
135        if let Some(metadata) = self
136            .build_context
137            .dependency_metadata()
138            .get(dist.name(), Some(dist.version()))
139        {
140            return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
141        }
142
143        let metadata = dist
144            .read_metadata()
145            .map_err(|err| Error::ReadInstalled(Box::new(dist.clone()), err))?;
146
147        Ok(ArchiveMetadata::from_metadata23(metadata.clone()))
148    }
149
150    /// Either fetch the only wheel metadata (directly from the index or with range requests) or
151    /// fetch and build the source distribution.
152    ///
153    /// While hashes will be generated in some cases, hash-checking is only enforced for source
154    /// distributions, and should be enforced by the caller for wheels.
155    #[instrument(skip_all, fields(%dist))]
156    pub async fn get_or_build_wheel_metadata(
157        &self,
158        dist: &Dist,
159        hashes: HashPolicy<'_>,
160    ) -> Result<ArchiveMetadata, Error> {
161        match dist {
162            Dist::Built(built) => self.get_wheel_metadata(built, hashes).await,
163            Dist::Source(source) => {
164                self.build_wheel_metadata(&BuildableSource::Dist(source), hashes)
165                    .await
166            }
167        }
168    }
169
    /// Fetch a wheel from the cache or download it from the index.
    ///
    /// While hashes will be generated in all cases, hash-checking is _not_ enforced and should
    /// instead be enforced by the caller.
    async fn get_wheel(
        &self,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<LocalWheel, Error> {
        match dist {
            // A wheel hosted on a registry; the file may still point at a local
            // `file://` URL (e.g., for `--find-links` directories).
            BuiltDist::Registry(wheels) => {
                let wheel = wheels.best_wheel();
                let WheelTarget {
                    url,
                    extension,
                    size,
                } = WheelTarget::try_from(&*wheel.file)?;

                // Create a cache entry for the wheel.
                let wheel_entry = self.build_context.cache().entry(
                    CacheBucket::Wheels,
                    WheelCache::Index(&wheel.index).wheel_dir(wheel.name().as_ref()),
                    wheel.filename.cache_key(),
                );

                // If the URL is a file URL, load the wheel directly.
                if url.scheme() == "file" {
                    let path = url
                        .to_file_path()
                        .map_err(|()| Error::NonFileUrl(url.clone()))?;
                    return self
                        .load_wheel(
                            &path,
                            &wheel.filename,
                            WheelExtension::Whl,
                            wheel_entry,
                            dist,
                            hashes,
                        )
                        .await;
                }

                // Download and unzip.
                match self
                    .stream_wheel(
                        url.clone(),
                        dist.index(),
                        &wheel.filename,
                        extension,
                        size,
                        &wheel_entry,
                        dist,
                        hashes,
                    )
                    .await
                {
                    Ok(archive) => Ok(LocalWheel {
                        dist: Dist::Built(dist.clone()),
                        archive: self
                            .build_context
                            .cache()
                            .archive(&archive.id)
                            .into_boxed_path(),
                        hashes: archive.hashes,
                        filename: wheel.filename.clone(),
                        cache: CacheInfo::default(),
                        build: None,
                    }),
                    Err(Error::Extract(name, err)) => {
                        // Streaming extraction can fail for server-side reasons; only
                        // those failures fall back to a full download. Any other
                        // extraction error is propagated as-is.
                        if err.is_http_streaming_unsupported() {
                            warn!(
                                "Streaming unsupported for {dist}; downloading wheel to disk ({err})"
                            );
                        } else if err.is_http_streaming_failed() {
                            warn!("Streaming failed for {dist}; downloading wheel to disk ({err})");
                        } else {
                            return Err(Error::Extract(name, err));
                        }

                        // If the request failed because streaming is unsupported, download the
                        // wheel directly.
                        let archive = self
                            .download_wheel(
                                url,
                                dist.index(),
                                &wheel.filename,
                                extension,
                                size,
                                &wheel_entry,
                                dist,
                                hashes,
                            )
                            .await?;

                        Ok(LocalWheel {
                            dist: Dist::Built(dist.clone()),
                            archive: self
                                .build_context
                                .cache()
                                .archive(&archive.id)
                                .into_boxed_path(),
                            hashes: archive.hashes,
                            filename: wheel.filename.clone(),
                            cache: CacheInfo::default(),
                            build: None,
                        })
                    }
                    Err(err) => Err(err),
                }
            }

            // A wheel referenced by a direct URL (e.g., `https://.../pkg.whl`).
            BuiltDist::DirectUrl(wheel) => {
                // Create a cache entry for the wheel.
                let wheel_entry = self.build_context.cache().entry(
                    CacheBucket::Wheels,
                    WheelCache::Url(&wheel.url).wheel_dir(wheel.name().as_ref()),
                    wheel.filename.cache_key(),
                );

                // Download and unzip. Direct URLs have no index (`None`) and no known
                // size up-front.
                match self
                    .stream_wheel(
                        wheel.url.raw().clone(),
                        None,
                        &wheel.filename,
                        WheelExtension::Whl,
                        None,
                        &wheel_entry,
                        dist,
                        hashes,
                    )
                    .await
                {
                    Ok(archive) => Ok(LocalWheel {
                        dist: Dist::Built(dist.clone()),
                        archive: self
                            .build_context
                            .cache()
                            .archive(&archive.id)
                            .into_boxed_path(),
                        hashes: archive.hashes,
                        filename: wheel.filename.clone(),
                        cache: CacheInfo::default(),
                        build: None,
                    }),
                    Err(Error::Client(err)) if err.is_http_streaming_unsupported() => {
                        warn!(
                            "Streaming unsupported for {dist}; downloading wheel to disk ({err})"
                        );

                        // If the request failed because streaming is unsupported, download the
                        // wheel directly.
                        let archive = self
                            .download_wheel(
                                wheel.url.raw().clone(),
                                None,
                                &wheel.filename,
                                WheelExtension::Whl,
                                None,
                                &wheel_entry,
                                dist,
                                hashes,
                            )
                            .await?;
                        Ok(LocalWheel {
                            dist: Dist::Built(dist.clone()),
                            archive: self
                                .build_context
                                .cache()
                                .archive(&archive.id)
                                .into_boxed_path(),
                            hashes: archive.hashes,
                            filename: wheel.filename.clone(),
                            cache: CacheInfo::default(),
                            build: None,
                        })
                    }
                    Err(err) => Err(err),
                }
            }

            // A wheel on the local filesystem: load it directly, no network involved.
            BuiltDist::Path(wheel) => {
                let cache_entry = self.build_context.cache().entry(
                    CacheBucket::Wheels,
                    WheelCache::Url(&wheel.url).wheel_dir(wheel.name().as_ref()),
                    wheel.filename.cache_key(),
                );

                self.load_wheel(
                    &wheel.install_path,
                    &wheel.filename,
                    WheelExtension::Whl,
                    cache_entry,
                    dist,
                    hashes,
                )
                .await
            }
        }
    }
370
    /// Convert a source distribution into a wheel, fetching it from the cache or building it if
    /// necessary.
    ///
    /// The returned wheel is guaranteed to come from a distribution with a matching hash, and
    /// no build processes will be executed for distributions with mismatched hashes.
    async fn build_wheel(
        &self,
        dist: &SourceDist,
        tags: &Tags,
        hashes: HashPolicy<'_>,
    ) -> Result<LocalWheel, Error> {
        // Download (if necessary) and build the source distribution into a wheel.
        let built_wheel = self
            .builder
            .download_and_build(&BuildableSource::Dist(dist), tags, hashes, &self.client)
            .boxed_local()
            .await?;

        // Check that the wheel is compatible with its install target.
        //
        // When building a build dependency for a cross-install, the build dependency needs
        // to install and run on the host instead of the target. In this case the `tags` are already
        // for the host instead of the target, so this check passes.
        if !built_wheel.filename.is_compatible(tags) {
            // Report the incompatibility against the target or host platform, depending
            // on whether this is a cross-install.
            return if tags.is_cross() {
                Err(Error::BuiltWheelIncompatibleTargetPlatform {
                    filename: built_wheel.filename,
                    python_platform: tags.python_platform().clone(),
                    python_version: tags.python_version(),
                })
            } else {
                Err(Error::BuiltWheelIncompatibleHostPlatform {
                    filename: built_wheel.filename,
                    python_platform: tags.python_platform().clone(),
                    python_version: tags.python_version(),
                })
            };
        }

        // Acquire the advisory lock.
        // On Windows, a `.lock` file alongside the unzip target guards against
        // concurrent writers to the same cache entry.
        #[cfg(windows)]
        let _lock = {
            let lock_entry = CacheEntry::new(
                built_wheel.target.parent().unwrap(),
                format!(
                    "{}.lock",
                    built_wheel.target.file_name().unwrap().to_str().unwrap()
                ),
            );
            lock_entry.lock().await.map_err(Error::CacheLock)?
        };

        // If the wheel was unzipped previously, respect it. Source distributions are
        // cached under a unique revision ID, so unzipped directories are never stale.
        match self.build_context.cache().resolve_link(&built_wheel.target) {
            Ok(archive) => {
                return Ok(LocalWheel {
                    dist: Dist::Source(dist.clone()),
                    archive: archive.into_boxed_path(),
                    filename: built_wheel.filename,
                    hashes: built_wheel.hashes,
                    cache: built_wheel.cache_info,
                    build: Some(built_wheel.build_info),
                });
            }
            // A missing link just means the wheel hasn't been unzipped yet; fall through.
            Err(err) if err.kind() == io::ErrorKind::NotFound => {}
            Err(err) => return Err(Error::CacheRead(err)),
        }

        // Otherwise, unzip the wheel.
        let id = self
            .unzip_wheel(&built_wheel.path, &built_wheel.target)
            .await?;

        Ok(LocalWheel {
            dist: Dist::Source(dist.clone()),
            archive: self.build_context.cache().archive(&id).into_boxed_path(),
            hashes: built_wheel.hashes,
            filename: built_wheel.filename,
            cache: built_wheel.cache_info,
            build: Some(built_wheel.build_info),
        })
    }
453
    /// Fetch the wheel metadata from the index, or from the cache if possible.
    ///
    /// While hashes will be generated in some cases, hash-checking is _not_ enforced and should
    /// instead be enforced by the caller.
    async fn get_wheel_metadata(
        &self,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<ArchiveMetadata, Error> {
        // If hash generation is enabled, and the distribution isn't hosted on a registry, get the
        // entire wheel to ensure that the hashes are included in the response. If the distribution
        // is hosted on an index, the hashes will be included in the simple metadata response.
        // For hash _validation_, callers are expected to enforce the policy when retrieving the
        // wheel.
        //
        // Historically, for `uv pip compile --universal`, we also generate hashes for
        // registry-based distributions when the relevant registry doesn't provide them. This was
        // motivated by `--find-links`. We continue that behavior (under `HashGeneration::All`) for
        // backwards compatibility, but it's a little dubious, since we're only hashing _one_
        // distribution here (as opposed to hashing all distributions for the version), and it may
        // not even be a compatible distribution!
        //
        // TODO(charlie): Request the hashes via a separate method, to reduce the coupling in this API.
        if hashes.is_generate(dist) {
            let wheel = self.get_wheel(dist, hashes).await?;
            // If the metadata was provided by the user directly, prefer it.
            let metadata = if let Some(metadata) = self
                .build_context
                .dependency_metadata()
                .get(dist.name(), Some(dist.version()))
            {
                metadata.clone()
            } else {
                wheel.metadata()?
            };
            // The hashes were computed while fetching the wheel above.
            let hashes = wheel.hashes;
            return Ok(ArchiveMetadata {
                metadata: Metadata::from_metadata23(metadata),
                hashes,
            });
        }

        // If the metadata was provided by the user directly, prefer it.
        if let Some(metadata) = self
            .build_context
            .dependency_metadata()
            .get(dist.name(), Some(dist.version()))
        {
            return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
        }

        // Otherwise, fetch the metadata alone via the client (e.g., from the index).
        let result = self
            .client
            .managed(|client| {
                client
                    .wheel_metadata(dist, self.build_context.capabilities())
                    .boxed_local()
            })
            .await;

        match result {
            Ok(metadata) => {
                // Validate that the metadata is consistent with the distribution.
                Ok(ArchiveMetadata::from_metadata23(metadata))
            }
            Err(err) if err.is_http_streaming_unsupported() => {
                warn!(
                    "Streaming unsupported when fetching metadata for {dist}; downloading wheel directly ({err})"
                );

                // If the request failed due to an error that could be resolved by
                // downloading the wheel directly, try that.
                let wheel = self.get_wheel(dist, hashes).await?;
                let metadata = wheel.metadata()?;
                let hashes = wheel.hashes;
                Ok(ArchiveMetadata {
                    metadata: Metadata::from_metadata23(metadata),
                    hashes,
                })
            }
            Err(err) => Err(err.into()),
        }
    }
537
538    /// Build the wheel metadata for a source distribution, or fetch it from the cache if possible.
539    ///
540    /// The returned metadata is guaranteed to come from a distribution with a matching hash, and
541    /// no build processes will be executed for distributions with mismatched hashes.
542    pub async fn build_wheel_metadata(
543        &self,
544        source: &BuildableSource<'_>,
545        hashes: HashPolicy<'_>,
546    ) -> Result<ArchiveMetadata, Error> {
547        // If the metadata was provided by the user directly, prefer it.
548        if let Some(dist) = source.as_dist() {
549            if let Some(metadata) = self
550                .build_context
551                .dependency_metadata()
552                .get(dist.name(), dist.version())
553            {
554                // If we skipped the build, we should still resolve any Git dependencies to precise
555                // commits.
556                self.builder.resolve_revision(source, &self.client).await?;
557
558                return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
559            }
560        }
561
562        let metadata = self
563            .builder
564            .download_and_build_metadata(source, hashes, &self.client)
565            .boxed_local()
566            .await?;
567
568        Ok(metadata)
569    }
570
571    /// Return the [`RequiresDist`] from a `pyproject.toml`, if it can be statically extracted.
572    pub async fn requires_dist(
573        &self,
574        path: &Path,
575        pyproject_toml: &PyProjectToml,
576    ) -> Result<Option<RequiresDist>, Error> {
577        self.builder
578            .source_tree_requires_dist(
579                path,
580                pyproject_toml,
581                self.client.unmanaged.credentials_cache(),
582            )
583            .await
584    }
585
    /// Stream a wheel from a URL, unzipping it into the cache as it's downloaded.
    async fn stream_wheel(
        &self,
        url: DisplaySafeUrl,
        index: Option<&IndexUrl>,
        filename: &WheelFilename,
        extension: WheelExtension,
        size: Option<u64>,
        wheel_entry: &CacheEntry,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<Archive, Error> {
        // Acquire an advisory lock, to guard against concurrent writes.
        #[cfg(windows)]
        let _lock = {
            let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
            lock_entry.lock().await.map_err(Error::CacheLock)?
        };

        // Create an entry for the HTTP cache.
        let http_entry = wheel_entry.with_file(format!("{}.http", filename.cache_key()));

        // Callback invoked on a cache miss: stream the response body, hashing and
        // extracting it on the fly, then persist the result to the cache.
        let download = |response: reqwest::Response| {
            async {
                // Prefer the caller-provided size; otherwise, use the response's `Content-Length`.
                let size = size.or_else(|| content_length(&response));

                let progress = self
                    .reporter
                    .as_ref()
                    .map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));

                let reader = response
                    .bytes_stream()
                    .map_err(|err| self.handle_response_errors(err))
                    .into_async_read();

                // Create a hasher for each hash algorithm.
                let algorithms = hashes.algorithms();
                let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
                let mut hasher = uv_extract::hash::HashReader::new(reader.compat(), &mut hashers);

                // Download and unzip the wheel to a temporary directory.
                let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
                    .map_err(Error::CacheWrite)?;

                match progress {
                    // With a reporter attached, wrap the reader to emit progress events.
                    Some((reporter, progress)) => {
                        let mut reader = ProgressReader::new(&mut hasher, progress, &**reporter);
                        match extension {
                            WheelExtension::Whl => {
                                uv_extract::stream::unzip(&mut reader, temp_dir.path())
                                    .await
                                    .map_err(|err| Error::Extract(filename.to_string(), err))?;
                            }
                            WheelExtension::WhlZst => {
                                uv_extract::stream::untar_zst(&mut reader, temp_dir.path())
                                    .await
                                    .map_err(|err| Error::Extract(filename.to_string(), err))?;
                            }
                        }
                    }
                    None => match extension {
                        WheelExtension::Whl => {
                            uv_extract::stream::unzip(&mut hasher, temp_dir.path())
                                .await
                                .map_err(|err| Error::Extract(filename.to_string(), err))?;
                        }
                        WheelExtension::WhlZst => {
                            uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
                                .await
                                .map_err(|err| Error::Extract(filename.to_string(), err))?;
                        }
                    },
                }

                // If necessary, exhaust the reader to compute the hash.
                if !hashes.is_none() {
                    hasher.finish().await.map_err(Error::HashExhaustion)?;
                }

                // Persist the temporary directory to the directory store.
                let id = self
                    .build_context
                    .cache()
                    .persist(temp_dir.keep(), wheel_entry.path())
                    .await
                    .map_err(Error::CacheRead)?;

                if let Some((reporter, progress)) = progress {
                    reporter.on_download_complete(dist.name(), progress);
                }

                Ok(Archive::new(
                    id,
                    hashers.into_iter().map(HashDigest::from).collect(),
                    filename.clone(),
                ))
            }
            .instrument(info_span!("wheel", wheel = %dist))
        };

        // Fetch the archive from the cache, or download it if necessary.
        let req = self.request(url.clone())?;

        // Determine the cache control policy for the URL.
        let cache_control = match self.client.unmanaged.connectivity() {
            Connectivity::Online => {
                // A per-index cache-control override takes precedence over the
                // freshness check on the cached HTTP entry.
                if let Some(header) = index.and_then(|index| {
                    self.build_context
                        .locations()
                        .artifact_cache_control_for(index)
                }) {
                    CacheControl::Override(header)
                } else {
                    CacheControl::from(
                        self.build_context
                            .cache()
                            .freshness(&http_entry, Some(&filename.name), None)
                            .map_err(Error::CacheRead)?,
                    )
                }
            }
            Connectivity::Offline => CacheControl::AllowStale,
        };

        let archive = self
            .client
            .managed(|client| {
                client.cached_client().get_serde_with_retry(
                    req,
                    &http_entry,
                    cache_control,
                    download,
                )
            })
            .await
            .map_err(|err| match err {
                CachedClientError::Callback { err, .. } => err,
                CachedClientError::Client { err, .. } => Error::Client(err),
            })?;

        // If the archive is missing the required hashes, or has since been removed, force a refresh.
        let archive = Some(archive)
            .filter(|archive| archive.has_digests(hashes))
            .filter(|archive| archive.exists(self.build_context.cache()));

        let archive = if let Some(archive) = archive {
            archive
        } else {
            // The cached archive was unusable: bypass the HTTP cache and re-download.
            self.client
                .managed(async |client| {
                    client
                        .cached_client()
                        .skip_cache_with_retry(
                            self.request(url)?,
                            &http_entry,
                            cache_control,
                            download,
                        )
                        .await
                        .map_err(|err| match err {
                            CachedClientError::Callback { err, .. } => err,
                            CachedClientError::Client { err, .. } => Error::Client(err),
                        })
                })
                .await?
        };

        Ok(archive)
    }
756
    /// Download a wheel from a URL, then unzip it into the cache.
    ///
    /// The fetch is routed through the HTTP cache: if a cached [`Archive`] is still valid and
    /// carries the required hashes, it is returned without re-downloading. Otherwise the wheel
    /// is streamed to a temporary file, unzipped (and hashed, when `hashes` requests it), and
    /// persisted into the cache's directory store.
    async fn download_wheel(
        &self,
        url: DisplaySafeUrl,
        index: Option<&IndexUrl>,
        filename: &WheelFilename,
        extension: WheelExtension,
        size: Option<u64>,
        wheel_entry: &CacheEntry,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<Archive, Error> {
        // Acquire an advisory lock, to guard against concurrent writes.
        #[cfg(windows)]
        let _lock = {
            let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
            lock_entry.lock().await.map_err(Error::CacheLock)?
        };

        // Create an entry for the HTTP cache.
        let http_entry = wheel_entry.with_file(format!("{}.http", filename.cache_key()));

        // Callback invoked on a cache miss: stream the response body to a temporary file,
        // then unzip (and optionally hash) it into the cache.
        let download = |response: reqwest::Response| {
            async {
                // Prefer the caller-provided size; fall back to the `Content-Length` header.
                let size = size.or_else(|| content_length(&response));

                let progress = self
                    .reporter
                    .as_ref()
                    .map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));

                let reader = response
                    .bytes_stream()
                    .map_err(|err| self.handle_response_errors(err))
                    .into_async_read();

                // Download the wheel to a temporary file.
                let temp_file = tempfile::tempfile_in(self.build_context.cache().root())
                    .map_err(Error::CacheWrite)?;
                let mut writer = tokio::io::BufWriter::new(fs_err::tokio::File::from_std(
                    // It's an unnamed file on Linux so that's the best approximation.
                    fs_err::File::from_parts(temp_file, self.build_context.cache().root()),
                ));

                match progress {
                    Some((reporter, progress)) => {
                        // Wrap the reader in a progress reporter. This will report 100% progress
                        // after the download is complete, even if we still have to unzip and hash
                        // part of the file.
                        let mut reader =
                            ProgressReader::new(reader.compat(), progress, &**reporter);

                        tokio::io::copy(&mut reader, &mut writer)
                            .await
                            .map_err(Error::CacheWrite)?;
                    }
                    None => {
                        tokio::io::copy(&mut reader.compat(), &mut writer)
                            .await
                            .map_err(Error::CacheWrite)?;
                    }
                }

                // Unzip the wheel to a temporary directory.
                let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
                    .map_err(Error::CacheWrite)?;
                let mut file = writer.into_inner();
                // Rewind so extraction reads the downloaded file from the beginning.
                file.seek(io::SeekFrom::Start(0))
                    .await
                    .map_err(Error::CacheWrite)?;

                // If no hashes are required, parallelize the unzip operation.
                let hashes = if hashes.is_none() {
                    // Move the file to a blocking thread, since the non-streaming unzip is
                    // synchronous.
                    let file = file.into_std().await;
                    tokio::task::spawn_blocking({
                        let target = temp_dir.path().to_owned();
                        move || -> Result<(), uv_extract::Error> {
                            // Unzip the wheel into a temporary directory.
                            match extension {
                                WheelExtension::Whl => {
                                    uv_extract::unzip(file, &target)?;
                                }
                                WheelExtension::WhlZst => {
                                    uv_extract::stream::untar_zst_file(file, &target)?;
                                }
                            }
                            Ok(())
                        }
                    })
                    .await?
                    .map_err(|err| Error::Extract(filename.to_string(), err))?;

                    HashDigests::empty()
                } else {
                    // Create a hasher for each hash algorithm.
                    let algorithms = hashes.algorithms();
                    let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
                    let mut hasher = uv_extract::hash::HashReader::new(file, &mut hashers);

                    // Stream-unzip, so hashing and extraction happen in a single pass.
                    match extension {
                        WheelExtension::Whl => {
                            uv_extract::stream::unzip(&mut hasher, temp_dir.path())
                                .await
                                .map_err(|err| Error::Extract(filename.to_string(), err))?;
                        }
                        WheelExtension::WhlZst => {
                            uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
                                .await
                                .map_err(|err| Error::Extract(filename.to_string(), err))?;
                        }
                    }

                    // If necessary, exhaust the reader to compute the hash.
                    hasher.finish().await.map_err(Error::HashExhaustion)?;

                    hashers.into_iter().map(HashDigest::from).collect()
                };

                // Persist the temporary directory to the directory store.
                // NOTE(review): `persist` is a write operation, but its error is mapped to
                // `Error::CacheRead` here, while `unzip_wheel` and `load_wheel` map the same
                // call to `Error::CacheWrite` — confirm whether this is intentional.
                let id = self
                    .build_context
                    .cache()
                    .persist(temp_dir.keep(), wheel_entry.path())
                    .await
                    .map_err(Error::CacheRead)?;

                if let Some((reporter, progress)) = progress {
                    reporter.on_download_complete(dist.name(), progress);
                }

                Ok(Archive::new(id, hashes, filename.clone()))
            }
            .instrument(info_span!("wheel", wheel = %dist))
        };

        // Fetch the archive from the cache, or download it if necessary.
        let req = self.request(url.clone())?;

        // Determine the cache control policy for the URL.
        let cache_control = match self.client.unmanaged.connectivity() {
            Connectivity::Online => {
                // An index-specific cache-control override takes precedence over the cache's
                // own freshness policy.
                if let Some(header) = index.and_then(|index| {
                    self.build_context
                        .locations()
                        .artifact_cache_control_for(index)
                }) {
                    CacheControl::Override(header)
                } else {
                    CacheControl::from(
                        self.build_context
                            .cache()
                            .freshness(&http_entry, Some(&filename.name), None)
                            .map_err(Error::CacheRead)?,
                    )
                }
            }
            Connectivity::Offline => CacheControl::AllowStale,
        };

        let archive = self
            .client
            .managed(|client| {
                client.cached_client().get_serde_with_retry(
                    req,
                    &http_entry,
                    cache_control,
                    download,
                )
            })
            .await
            .map_err(|err| match err {
                CachedClientError::Callback { err, .. } => err,
                CachedClientError::Client { err, .. } => Error::Client(err),
            })?;

        // If the archive is missing the required hashes, or has since been removed, force a refresh.
        let archive = Some(archive)
            .filter(|archive| archive.has_digests(hashes))
            .filter(|archive| archive.exists(self.build_context.cache()));

        let archive = if let Some(archive) = archive {
            archive
        } else {
            // Bypass the HTTP cache entirely and re-download the wheel.
            self.client
                .managed(async |client| {
                    client
                        .cached_client()
                        .skip_cache_with_retry(
                            self.request(url)?,
                            &http_entry,
                            cache_control,
                            download,
                        )
                        .await
                        .map_err(|err| match err {
                            CachedClientError::Callback { err, .. } => err,
                            CachedClientError::Client { err, .. } => Error::Client(err),
                        })
                })
                .await?
        };

        Ok(archive)
    }
961
    /// Load a wheel from a local path.
    ///
    /// A `.rev` pointer file records the archive produced from a previous unzip along with the
    /// wheel's modification time. If the pointer is up-to-date (and carries the required
    /// hashes), the cached archive is reused; otherwise the wheel is re-unzipped — and hashed,
    /// when `hashes` requests it — and the pointer is rewritten.
    async fn load_wheel(
        &self,
        path: &Path,
        filename: &WheelFilename,
        extension: WheelExtension,
        wheel_entry: CacheEntry,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<LocalWheel, Error> {
        // Acquire an advisory lock, to guard against concurrent writes.
        #[cfg(windows)]
        let _lock = {
            let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
            lock_entry.lock().await.map_err(Error::CacheLock)?
        };

        // Determine the last-modified time of the wheel.
        let modified = Timestamp::from_path(path).map_err(Error::CacheRead)?;

        // Attempt to read the archive pointer from the cache.
        let pointer_entry = wheel_entry.with_file(format!("{}.rev", filename.cache_key()));
        let pointer = LocalArchivePointer::read_from(&pointer_entry)?;

        // Extract the archive from the pointer, discarding it if it's stale or is missing any
        // of the required hashes.
        let archive = pointer
            .filter(|pointer| pointer.is_up_to_date(modified))
            .map(LocalArchivePointer::into_archive)
            .filter(|archive| archive.has_digests(hashes));

        // If the file is already unzipped, and the cache is up-to-date, return it.
        if let Some(archive) = archive {
            Ok(LocalWheel {
                dist: Dist::Built(dist.clone()),
                archive: self
                    .build_context
                    .cache()
                    .archive(&archive.id)
                    .into_boxed_path(),
                hashes: archive.hashes,
                filename: filename.clone(),
                cache: CacheInfo::from_timestamp(modified),
                build: None,
            })
        } else if hashes.is_none() {
            // Otherwise, unzip the wheel.
            let archive = Archive::new(
                self.unzip_wheel(path, wheel_entry.path()).await?,
                HashDigests::empty(),
                filename.clone(),
            );

            // Write the archive pointer to the cache.
            let pointer = LocalArchivePointer {
                timestamp: modified,
                archive: archive.clone(),
            };
            pointer.write_to(&pointer_entry).await?;

            Ok(LocalWheel {
                dist: Dist::Built(dist.clone()),
                archive: self
                    .build_context
                    .cache()
                    .archive(&archive.id)
                    .into_boxed_path(),
                hashes: archive.hashes,
                filename: filename.clone(),
                cache: CacheInfo::from_timestamp(modified),
                build: None,
            })
        } else {
            // If necessary, compute the hashes of the wheel.
            let file = fs_err::tokio::File::open(path)
                .await
                .map_err(Error::CacheRead)?;
            let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
                .map_err(Error::CacheWrite)?;

            // Create a hasher for each hash algorithm.
            let algorithms = hashes.algorithms();
            let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
            let mut hasher = uv_extract::hash::HashReader::new(file, &mut hashers);

            // Unzip the wheel to a temporary directory, hashing the bytes as they stream by.
            match extension {
                WheelExtension::Whl => {
                    uv_extract::stream::unzip(&mut hasher, temp_dir.path())
                        .await
                        .map_err(|err| Error::Extract(filename.to_string(), err))?;
                }
                WheelExtension::WhlZst => {
                    uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
                        .await
                        .map_err(|err| Error::Extract(filename.to_string(), err))?;
                }
            }

            // Exhaust the reader to compute the hash.
            hasher.finish().await.map_err(Error::HashExhaustion)?;

            let hashes = hashers.into_iter().map(HashDigest::from).collect();

            // Persist the temporary directory to the directory store.
            let id = self
                .build_context
                .cache()
                .persist(temp_dir.keep(), wheel_entry.path())
                .await
                .map_err(Error::CacheWrite)?;

            // Create an archive.
            let archive = Archive::new(id, hashes, filename.clone());

            // Write the archive pointer to the cache.
            let pointer = LocalArchivePointer {
                timestamp: modified,
                archive: archive.clone(),
            };
            pointer.write_to(&pointer_entry).await?;

            Ok(LocalWheel {
                dist: Dist::Built(dist.clone()),
                archive: self
                    .build_context
                    .cache()
                    .archive(&archive.id)
                    .into_boxed_path(),
                hashes: archive.hashes,
                filename: filename.clone(),
                cache: CacheInfo::from_timestamp(modified),
                build: None,
            })
        }
    }
1096
1097    /// Unzip a wheel into the cache, returning the path to the unzipped directory.
1098    async fn unzip_wheel(&self, path: &Path, target: &Path) -> Result<ArchiveId, Error> {
1099        let temp_dir = tokio::task::spawn_blocking({
1100            let path = path.to_owned();
1101            let root = self.build_context.cache().root().to_path_buf();
1102            move || -> Result<TempDir, Error> {
1103                // Unzip the wheel into a temporary directory.
1104                let temp_dir = tempfile::tempdir_in(root).map_err(Error::CacheWrite)?;
1105                let reader = fs_err::File::open(&path).map_err(Error::CacheWrite)?;
1106                uv_extract::unzip(reader, temp_dir.path())
1107                    .map_err(|err| Error::Extract(path.to_string_lossy().into_owned(), err))?;
1108                Ok(temp_dir)
1109            }
1110        })
1111        .await??;
1112
1113        // Persist the temporary directory to the directory store.
1114        let id = self
1115            .build_context
1116            .cache()
1117            .persist(temp_dir.keep(), target)
1118            .await
1119            .map_err(Error::CacheWrite)?;
1120
1121        Ok(id)
1122    }
1123
1124    /// Returns a GET [`reqwest::Request`] for the given URL.
1125    fn request(&self, url: DisplaySafeUrl) -> Result<reqwest::Request, reqwest::Error> {
1126        self.client
1127            .unmanaged
1128            .uncached_client(&url)
1129            .get(Url::from(url))
1130            .header(
1131                // `reqwest` defaults to accepting compressed responses.
1132                // Specify identity encoding to get consistent .whl downloading
1133                // behavior from servers. ref: https://github.com/pypa/pip/pull/1688
1134                "accept-encoding",
1135                reqwest::header::HeaderValue::from_static("identity"),
1136            )
1137            .build()
1138    }
1139
    /// Return the [`ManagedClient`] used by this resolver.
    ///
    /// Exposes the concurrency-limited client for callers that need to issue their own requests.
    pub fn client(&self) -> &ManagedClient<'a> {
        &self.client
    }
1144}
1145
/// A wrapper around `RegistryClient` that manages a concurrency limit.
pub struct ManagedClient<'a> {
    /// The underlying client, exposed without any concurrency control.
    pub unmanaged: &'a RegistryClient,
    /// The semaphore that bounds the number of concurrent requests.
    control: Semaphore,
}
1151
1152impl<'a> ManagedClient<'a> {
1153    /// Create a new `ManagedClient` using the given client and concurrency limit.
1154    fn new(client: &'a RegistryClient, concurrency: usize) -> Self {
1155        ManagedClient {
1156            unmanaged: client,
1157            control: Semaphore::new(concurrency),
1158        }
1159    }
1160
1161    /// Perform a request using the client, respecting the concurrency limit.
1162    ///
1163    /// If the concurrency limit has been reached, this method will wait until a pending
1164    /// operation completes before executing the closure.
1165    pub async fn managed<F, T>(&self, f: impl FnOnce(&'a RegistryClient) -> F) -> T
1166    where
1167        F: Future<Output = T>,
1168    {
1169        let _permit = self.control.acquire().await.unwrap();
1170        f(self.unmanaged).await
1171    }
1172
1173    /// Perform a request using a client that internally manages the concurrency limit.
1174    ///
1175    /// The callback is passed the client and a semaphore. It must acquire the semaphore before
1176    /// any request through the client and drop it after.
1177    ///
1178    /// This method serves as an escape hatch for functions that may want to send multiple requests
1179    /// in parallel.
1180    pub async fn manual<F, T>(&'a self, f: impl FnOnce(&'a RegistryClient, &'a Semaphore) -> F) -> T
1181    where
1182        F: Future<Output = T>,
1183    {
1184        f(self.unmanaged, &self.control).await
1185    }
1186}
1187
1188/// Returns the value of the `Content-Length` header from the [`reqwest::Response`], if present.
1189fn content_length(response: &reqwest::Response) -> Option<u64> {
1190    response
1191        .headers()
1192        .get(reqwest::header::CONTENT_LENGTH)
1193        .and_then(|val| val.to_str().ok())
1194        .and_then(|val| val.parse::<u64>().ok())
1195}
1196
/// An asynchronous reader that reports progress as bytes are read.
struct ProgressReader<'a, R> {
    /// The underlying reader being wrapped.
    reader: R,
    /// The identifier passed to the reporter with each progress update.
    index: usize,
    /// The reporter that receives progress updates.
    reporter: &'a dyn Reporter,
}
1203
1204impl<'a, R> ProgressReader<'a, R> {
1205    /// Create a new [`ProgressReader`] that wraps another reader.
1206    fn new(reader: R, index: usize, reporter: &'a dyn Reporter) -> Self {
1207        Self {
1208            reader,
1209            index,
1210            reporter,
1211        }
1212    }
1213}
1214
impl<R> AsyncRead for ProgressReader<'_, R>
where
    R: AsyncRead + Unpin,
{
    // Delegate the read to the inner reader; on success, report the buffer's filled length
    // to the reporter as download progress.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // NOTE(review): this reports `buf.filled().len()` — the total bytes filled in `buf`
        // after this call — which equals the bytes read by this call only if the caller passes
        // a buffer that is empty on entry. TODO confirm against the caller's buffer reuse
        // (e.g. `tokio::io::copy`).
        Pin::new(&mut self.as_mut().reader)
            .poll_read(cx, buf)
            .map_ok(|()| {
                self.reporter
                    .on_download_progress(self.index, buf.filled().len() as u64);
            })
    }
}
1232
/// A pointer to an archive in the cache, fetched from an HTTP archive.
///
/// Encoded with `MsgPack`, and represented on disk by a `.http` file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HttpArchivePointer {
    /// The cached archive that this pointer references.
    archive: Archive,
}
1240
1241impl HttpArchivePointer {
1242    /// Read an [`HttpArchivePointer`] from the cache.
1243    pub fn read_from(path: impl AsRef<Path>) -> Result<Option<Self>, Error> {
1244        match fs_err::File::open(path.as_ref()) {
1245            Ok(file) => {
1246                let data = DataWithCachePolicy::from_reader(file)?.data;
1247                let archive = rmp_serde::from_slice::<Archive>(&data)?;
1248                Ok(Some(Self { archive }))
1249            }
1250            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
1251            Err(err) => Err(Error::CacheRead(err)),
1252        }
1253    }
1254
1255    /// Return the [`Archive`] from the pointer.
1256    pub fn into_archive(self) -> Archive {
1257        self.archive
1258    }
1259
1260    /// Return the [`CacheInfo`] from the pointer.
1261    pub fn to_cache_info(&self) -> CacheInfo {
1262        CacheInfo::default()
1263    }
1264
1265    /// Return the [`BuildInfo`] from the pointer.
1266    pub fn to_build_info(&self) -> Option<BuildInfo> {
1267        None
1268    }
1269}
1270
/// A pointer to an archive in the cache, fetched from a local path.
///
/// Encoded with `MsgPack`, and represented on disk by a `.rev` file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LocalArchivePointer {
    /// The last-modified time of the source file at the time the archive was cached;
    /// compared against the file's current timestamp to detect staleness.
    timestamp: Timestamp,
    /// The cached archive that this pointer references.
    archive: Archive,
}
1279
1280impl LocalArchivePointer {
1281    /// Read an [`LocalArchivePointer`] from the cache.
1282    pub fn read_from(path: impl AsRef<Path>) -> Result<Option<Self>, Error> {
1283        match fs_err::read(path) {
1284            Ok(cached) => Ok(Some(rmp_serde::from_slice::<Self>(&cached)?)),
1285            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
1286            Err(err) => Err(Error::CacheRead(err)),
1287        }
1288    }
1289
1290    /// Write an [`LocalArchivePointer`] to the cache.
1291    pub async fn write_to(&self, entry: &CacheEntry) -> Result<(), Error> {
1292        write_atomic(entry.path(), rmp_serde::to_vec(&self)?)
1293            .await
1294            .map_err(Error::CacheWrite)
1295    }
1296
1297    /// Returns `true` if the archive is up-to-date with the given modified timestamp.
1298    pub fn is_up_to_date(&self, modified: Timestamp) -> bool {
1299        self.timestamp == modified
1300    }
1301
1302    /// Return the [`Archive`] from the pointer.
1303    pub fn into_archive(self) -> Archive {
1304        self.archive
1305    }
1306
1307    /// Return the [`CacheInfo`] from the pointer.
1308    pub fn to_cache_info(&self) -> CacheInfo {
1309        CacheInfo::from_timestamp(self.timestamp)
1310    }
1311
1312    /// Return the [`BuildInfo`] from the pointer.
1313    pub fn to_build_info(&self) -> Option<BuildInfo> {
1314        None
1315    }
1316}
1317
/// A resolved location from which a wheel file can be fetched.
#[derive(Debug, Clone)]
struct WheelTarget {
    /// The URL from which the wheel can be downloaded.
    url: DisplaySafeUrl,
    /// The expected extension of the wheel file.
    extension: WheelExtension,
    /// The expected size of the wheel file, if known.
    size: Option<u64>,
}
1327
1328impl TryFrom<&File> for WheelTarget {
1329    type Error = ToUrlError;
1330
1331    /// Determine the [`WheelTarget`] from a [`File`].
1332    fn try_from(file: &File) -> Result<Self, Self::Error> {
1333        let url = file.url.to_url()?;
1334        if let Some(zstd) = file.zstd.as_ref() {
1335            Ok(Self {
1336                url: add_tar_zst_extension(url),
1337                extension: WheelExtension::WhlZst,
1338                size: zstd.size,
1339            })
1340        } else {
1341            Ok(Self {
1342                url,
1343                extension: WheelExtension::Whl,
1344                size: file.size,
1345            })
1346        }
1347    }
1348}
1349
/// The on-disk format of a downloadable wheel.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum WheelExtension {
    /// A `.whl` file.
    Whl,
    /// A `.whl.tar.zst` file.
    WhlZst,
}
1357
1358/// Add `.tar.zst` to the end of the URL path, if it doesn't already exist.
1359#[must_use]
1360fn add_tar_zst_extension(mut url: DisplaySafeUrl) -> DisplaySafeUrl {
1361    let mut path = url.path().to_string();
1362
1363    if !path.ends_with(".tar.zst") {
1364        path.push_str(".tar.zst");
1365    }
1366
1367    url.set_path(&path);
1368    url
1369}
1370
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_tar_zst_extension() {
        // (input, expected) pairs: the suffix is appended exactly once, and
        // percent-encoded characters in the path are preserved.
        let cases = [
            (
                "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl",
                "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst",
            ),
            (
                "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst",
                "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst",
            ),
            (
                "https://files.pythonhosted.org/flask-3.1.0%2Bcu124-py3-none-any.whl",
                "https://files.pythonhosted.org/flask-3.1.0%2Bcu124-py3-none-any.whl.tar.zst",
            ),
        ];

        for (input, expected) in cases {
            let url = DisplaySafeUrl::parse(input).unwrap();
            assert_eq!(add_tar_zst_extension(url).as_str(), expected);
        }
    }
}