// uv_distribution/distribution_database.rs

1use std::future::Future;
2use std::io;
3use std::path::Path;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use futures::{FutureExt, TryStreamExt};
9use tempfile::TempDir;
10use tokio::io::{AsyncRead, AsyncSeekExt, ReadBuf};
11use tokio::sync::Semaphore;
12use tokio_util::compat::FuturesAsyncReadCompatExt;
13use tracing::{Instrument, info_span, instrument, warn};
14use url::Url;
15
16use uv_cache::{ArchiveId, CacheBucket, CacheEntry, WheelCache};
17use uv_cache_info::{CacheInfo, Timestamp};
18use uv_client::{
19    CacheControl, CachedClientError, Connectivity, DataWithCachePolicy, RegistryClient,
20};
21use uv_distribution_filename::WheelFilename;
22use uv_distribution_types::{
23    BuildInfo, BuildableSource, BuiltDist, Dist, File, HashPolicy, Hashed, IndexUrl, InstalledDist,
24    Name, SourceDist, ToUrlError,
25};
26use uv_extract::hash::Hasher;
27use uv_fs::write_atomic;
28use uv_platform_tags::Tags;
29use uv_pypi_types::{HashDigest, HashDigests, PyProjectToml};
30use uv_redacted::DisplaySafeUrl;
31use uv_types::{BuildContext, BuildStack};
32
33use crate::archive::Archive;
34use crate::metadata::{ArchiveMetadata, Metadata};
35use crate::source::SourceDistributionBuilder;
36use crate::{Error, LocalWheel, Reporter, RequiresDist};
37
/// A cached high-level interface to convert distributions (a requirement resolved to a location)
/// to a wheel or wheel metadata.
///
/// For wheel metadata, this happens by either fetching the metadata from the remote wheel or by
/// building the source distribution. For wheel files, either the wheel is downloaded or a source
/// distribution is downloaded, built and the new wheel gets returned.
///
/// All kinds of wheel sources (index, URL, path) and source distribution source (index, URL, path,
/// Git) are supported.
///
/// This struct also has the task of acquiring locks around source dist builds in general and git
/// operation especially, as well as respecting concurrency limits.
pub struct DistributionDatabase<'a, Context: BuildContext> {
    /// The build context used for cache access and user-provided dependency metadata.
    build_context: &'a Context,
    /// Downloads and builds source distributions into wheels (and their metadata).
    builder: SourceDistributionBuilder<'a, Context>,
    /// Registry client wrapper that enforces the configured download concurrency limit.
    client: ManagedClient<'a>,
    /// Optional reporter used to surface download progress events.
    reporter: Option<Arc<dyn Reporter>>,
}
56
57impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
58    pub fn new(
59        client: &'a RegistryClient,
60        build_context: &'a Context,
61        concurrent_downloads: usize,
62    ) -> Self {
63        Self {
64            build_context,
65            builder: SourceDistributionBuilder::new(build_context),
66            client: ManagedClient::new(client, concurrent_downloads),
67            reporter: None,
68        }
69    }
70
71    /// Set the build stack to use for the [`DistributionDatabase`].
72    #[must_use]
73    pub fn with_build_stack(self, build_stack: &'a BuildStack) -> Self {
74        Self {
75            builder: self.builder.with_build_stack(build_stack),
76            ..self
77        }
78    }
79
80    /// Set the [`Reporter`] to use for the [`DistributionDatabase`].
81    #[must_use]
82    pub fn with_reporter(self, reporter: Arc<dyn Reporter>) -> Self {
83        Self {
84            builder: self.builder.with_reporter(reporter.clone()),
85            reporter: Some(reporter),
86            ..self
87        }
88    }
89
90    /// Handle a specific `reqwest` error, and convert it to [`io::Error`].
91    fn handle_response_errors(&self, err: reqwest::Error) -> io::Error {
92        if err.is_timeout() {
93            // Assumption: The connect timeout with the 10s default is not the culprit.
94            io::Error::new(
95                io::ErrorKind::TimedOut,
96                format!(
97                    "Failed to download distribution due to network timeout. Try increasing UV_HTTP_TIMEOUT (current value: {}s).",
98                    self.client.unmanaged.read_timeout().as_secs()
99                ),
100            )
101        } else {
102            io::Error::other(err)
103        }
104    }
105
106    /// Either fetch the wheel or fetch and build the source distribution
107    ///
108    /// Returns a wheel that's compliant with the given platform tags.
109    ///
110    /// While hashes will be generated in some cases, hash-checking is only enforced for source
111    /// distributions, and should be enforced by the caller for wheels.
112    #[instrument(skip_all, fields(%dist))]
113    pub async fn get_or_build_wheel(
114        &self,
115        dist: &Dist,
116        tags: &Tags,
117        hashes: HashPolicy<'_>,
118    ) -> Result<LocalWheel, Error> {
119        match dist {
120            Dist::Built(built) => self.get_wheel(built, hashes).await,
121            Dist::Source(source) => self.build_wheel(source, tags, hashes).await,
122        }
123    }
124
125    /// Either fetch the only wheel metadata (directly from the index or with range requests) or
126    /// fetch and build the source distribution.
127    ///
128    /// While hashes will be generated in some cases, hash-checking is only enforced for source
129    /// distributions, and should be enforced by the caller for wheels.
130    #[instrument(skip_all, fields(%dist))]
131    pub async fn get_installed_metadata(
132        &self,
133        dist: &InstalledDist,
134    ) -> Result<ArchiveMetadata, Error> {
135        // If the metadata was provided by the user directly, prefer it.
136        if let Some(metadata) = self
137            .build_context
138            .dependency_metadata()
139            .get(dist.name(), Some(dist.version()))
140        {
141            return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
142        }
143
144        let metadata = dist
145            .read_metadata()
146            .map_err(|err| Error::ReadInstalled(Box::new(dist.clone()), err))?;
147
148        Ok(ArchiveMetadata::from_metadata23(metadata.clone()))
149    }
150
151    /// Either fetch the only wheel metadata (directly from the index or with range requests) or
152    /// fetch and build the source distribution.
153    ///
154    /// While hashes will be generated in some cases, hash-checking is only enforced for source
155    /// distributions, and should be enforced by the caller for wheels.
156    #[instrument(skip_all, fields(%dist))]
157    pub async fn get_or_build_wheel_metadata(
158        &self,
159        dist: &Dist,
160        hashes: HashPolicy<'_>,
161    ) -> Result<ArchiveMetadata, Error> {
162        match dist {
163            Dist::Built(built) => self.get_wheel_metadata(built, hashes).await,
164            Dist::Source(source) => {
165                self.build_wheel_metadata(&BuildableSource::Dist(source), hashes)
166                    .await
167            }
168        }
169    }
170
171    /// Fetch a wheel from the cache or download it from the index.
172    ///
173    /// While hashes will be generated in all cases, hash-checking is _not_ enforced and should
174    /// instead be enforced by the caller.
175    async fn get_wheel(
176        &self,
177        dist: &BuiltDist,
178        hashes: HashPolicy<'_>,
179    ) -> Result<LocalWheel, Error> {
180        match dist {
181            BuiltDist::Registry(wheels) => {
182                let wheel = wheels.best_wheel();
183                let WheelTarget {
184                    url,
185                    extension,
186                    size,
187                } = WheelTarget::try_from(&*wheel.file)?;
188
189                // Create a cache entry for the wheel.
190                let wheel_entry = self.build_context.cache().entry(
191                    CacheBucket::Wheels,
192                    WheelCache::Index(&wheel.index).wheel_dir(wheel.name().as_ref()),
193                    wheel.filename.cache_key(),
194                );
195
196                // If the URL is a file URL, load the wheel directly.
197                if url.scheme() == "file" {
198                    let path = url
199                        .to_file_path()
200                        .map_err(|()| Error::NonFileUrl(url.clone()))?;
201                    return self
202                        .load_wheel(
203                            &path,
204                            &wheel.filename,
205                            WheelExtension::Whl,
206                            wheel_entry,
207                            dist,
208                            hashes,
209                        )
210                        .await;
211                }
212
213                // Download and unzip.
214                match self
215                    .stream_wheel(
216                        url.clone(),
217                        dist.index(),
218                        &wheel.filename,
219                        extension,
220                        size,
221                        &wheel_entry,
222                        dist,
223                        hashes,
224                    )
225                    .await
226                {
227                    Ok(archive) => Ok(LocalWheel {
228                        dist: Dist::Built(dist.clone()),
229                        archive: self
230                            .build_context
231                            .cache()
232                            .archive(&archive.id)
233                            .into_boxed_path(),
234                        hashes: archive.hashes,
235                        filename: wheel.filename.clone(),
236                        cache: CacheInfo::default(),
237                        build: None,
238                    }),
239                    Err(Error::Extract(name, err)) => {
240                        if err.is_http_streaming_unsupported() {
241                            warn!(
242                                "Streaming unsupported for {dist}; downloading wheel to disk ({err})"
243                            );
244                        } else if err.is_http_streaming_failed() {
245                            warn!("Streaming failed for {dist}; downloading wheel to disk ({err})");
246                        } else {
247                            return Err(Error::Extract(name, err));
248                        }
249
250                        // If the request failed because streaming is unsupported, download the
251                        // wheel directly.
252                        let archive = self
253                            .download_wheel(
254                                url,
255                                dist.index(),
256                                &wheel.filename,
257                                extension,
258                                size,
259                                &wheel_entry,
260                                dist,
261                                hashes,
262                            )
263                            .await?;
264
265                        Ok(LocalWheel {
266                            dist: Dist::Built(dist.clone()),
267                            archive: self
268                                .build_context
269                                .cache()
270                                .archive(&archive.id)
271                                .into_boxed_path(),
272                            hashes: archive.hashes,
273                            filename: wheel.filename.clone(),
274                            cache: CacheInfo::default(),
275                            build: None,
276                        })
277                    }
278                    Err(err) => Err(err),
279                }
280            }
281
282            BuiltDist::DirectUrl(wheel) => {
283                // Create a cache entry for the wheel.
284                let wheel_entry = self.build_context.cache().entry(
285                    CacheBucket::Wheels,
286                    WheelCache::Url(&wheel.url).wheel_dir(wheel.name().as_ref()),
287                    wheel.filename.cache_key(),
288                );
289
290                // Download and unzip.
291                match self
292                    .stream_wheel(
293                        wheel.url.raw().clone(),
294                        None,
295                        &wheel.filename,
296                        WheelExtension::Whl,
297                        None,
298                        &wheel_entry,
299                        dist,
300                        hashes,
301                    )
302                    .await
303                {
304                    Ok(archive) => Ok(LocalWheel {
305                        dist: Dist::Built(dist.clone()),
306                        archive: self
307                            .build_context
308                            .cache()
309                            .archive(&archive.id)
310                            .into_boxed_path(),
311                        hashes: archive.hashes,
312                        filename: wheel.filename.clone(),
313                        cache: CacheInfo::default(),
314                        build: None,
315                    }),
316                    Err(Error::Client(err)) if err.is_http_streaming_unsupported() => {
317                        warn!(
318                            "Streaming unsupported for {dist}; downloading wheel to disk ({err})"
319                        );
320
321                        // If the request failed because streaming is unsupported, download the
322                        // wheel directly.
323                        let archive = self
324                            .download_wheel(
325                                wheel.url.raw().clone(),
326                                None,
327                                &wheel.filename,
328                                WheelExtension::Whl,
329                                None,
330                                &wheel_entry,
331                                dist,
332                                hashes,
333                            )
334                            .await?;
335                        Ok(LocalWheel {
336                            dist: Dist::Built(dist.clone()),
337                            archive: self
338                                .build_context
339                                .cache()
340                                .archive(&archive.id)
341                                .into_boxed_path(),
342                            hashes: archive.hashes,
343                            filename: wheel.filename.clone(),
344                            cache: CacheInfo::default(),
345                            build: None,
346                        })
347                    }
348                    Err(err) => Err(err),
349                }
350            }
351
352            BuiltDist::Path(wheel) => {
353                let cache_entry = self.build_context.cache().entry(
354                    CacheBucket::Wheels,
355                    WheelCache::Url(&wheel.url).wheel_dir(wheel.name().as_ref()),
356                    wheel.filename.cache_key(),
357                );
358
359                self.load_wheel(
360                    &wheel.install_path,
361                    &wheel.filename,
362                    WheelExtension::Whl,
363                    cache_entry,
364                    dist,
365                    hashes,
366                )
367                .await
368            }
369        }
370    }
371
    /// Convert a source distribution into a wheel, fetching it from the cache or building it if
    /// necessary.
    ///
    /// The returned wheel is guaranteed to come from a distribution with a matching hash, and
    /// no build processes will be executed for distributions with mismatched hashes.
    async fn build_wheel(
        &self,
        dist: &SourceDist,
        tags: &Tags,
        hashes: HashPolicy<'_>,
    ) -> Result<LocalWheel, Error> {
        // Download the source distribution and build it into a wheel (or retrieve a previously
        // built wheel from the cache).
        let built_wheel = self
            .builder
            .download_and_build(&BuildableSource::Dist(dist), tags, hashes, &self.client)
            .boxed_local()
            .await?;

        // Check that the wheel is compatible with its install target.
        //
        // When building a build dependency for a cross-install, the build dependency needs
        // to install and run on the host instead of the target. In this case the `tags` are already
        // for the host instead of the target, so this check passes.
        if !built_wheel.filename.is_compatible(tags) {
            return if tags.is_cross() {
                Err(Error::BuiltWheelIncompatibleTargetPlatform {
                    filename: built_wheel.filename,
                    python_platform: tags.python_platform().clone(),
                    python_version: tags.python_version(),
                })
            } else {
                Err(Error::BuiltWheelIncompatibleHostPlatform {
                    filename: built_wheel.filename,
                    python_platform: tags.python_platform().clone(),
                    python_version: tags.python_version(),
                })
            };
        }

        // Acquire the advisory lock.
        #[cfg(windows)]
        let _lock = {
            // NOTE(review): the `parent()`/`file_name()`/`to_str()` unwraps assume `target` is a
            // UTF-8 file path inside the cache — confirm that invariant holds for all builds.
            let lock_entry = CacheEntry::new(
                built_wheel.target.parent().unwrap(),
                format!(
                    "{}.lock",
                    built_wheel.target.file_name().unwrap().to_str().unwrap()
                ),
            );
            lock_entry.lock().await.map_err(Error::CacheLock)?
        };

        // If the wheel was unzipped previously, respect it. Source distributions are
        // cached under a unique revision ID, so unzipped directories are never stale.
        match self.build_context.cache().resolve_link(&built_wheel.target) {
            Ok(archive) => {
                return Ok(LocalWheel {
                    dist: Dist::Source(dist.clone()),
                    archive: archive.into_boxed_path(),
                    filename: built_wheel.filename,
                    hashes: built_wheel.hashes,
                    cache: built_wheel.cache_info,
                    build: Some(built_wheel.build_info),
                });
            }
            // A missing link just means the wheel hasn't been unzipped yet; fall through.
            Err(err) if err.kind() == io::ErrorKind::NotFound => {}
            Err(err) => return Err(Error::CacheRead(err)),
        }

        // Otherwise, unzip the wheel.
        let id = self
            .unzip_wheel(&built_wheel.path, &built_wheel.target)
            .await?;

        Ok(LocalWheel {
            dist: Dist::Source(dist.clone()),
            archive: self.build_context.cache().archive(&id).into_boxed_path(),
            hashes: built_wheel.hashes,
            filename: built_wheel.filename,
            cache: built_wheel.cache_info,
            build: Some(built_wheel.build_info),
        })
    }
454
    /// Fetch the wheel metadata from the index, or from the cache if possible.
    ///
    /// While hashes will be generated in some cases, hash-checking is _not_ enforced and should
    /// instead be enforced by the caller.
    async fn get_wheel_metadata(
        &self,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<ArchiveMetadata, Error> {
        // If hash generation is enabled, and the distribution isn't hosted on a registry, get the
        // entire wheel to ensure that the hashes are included in the response. If the distribution
        // is hosted on an index, the hashes will be included in the simple metadata response.
        // For hash _validation_, callers are expected to enforce the policy when retrieving the
        // wheel.
        //
        // Historically, for `uv pip compile --universal`, we also generate hashes for
        // registry-based distributions when the relevant registry doesn't provide them. This was
        // motivated by `--find-links`. We continue that behavior (under `HashGeneration::All`) for
        // backwards compatibility, but it's a little dubious, since we're only hashing _one_
        // distribution here (as opposed to hashing all distributions for the version), and it may
        // not even be a compatible distribution!
        //
        // TODO(charlie): Request the hashes via a separate method, to reduce the coupling in this API.
        if hashes.is_generate(dist) {
            // Fetch the full wheel so that its hashes can be computed.
            let wheel = self.get_wheel(dist, hashes).await?;
            // If the metadata was provided by the user directly, prefer it.
            let metadata = if let Some(metadata) = self
                .build_context
                .dependency_metadata()
                .get(dist.name(), Some(dist.version()))
            {
                metadata.clone()
            } else {
                wheel.metadata()?
            };
            let hashes = wheel.hashes;
            return Ok(ArchiveMetadata {
                metadata: Metadata::from_metadata23(metadata),
                hashes,
            });
        }

        // If the metadata was provided by the user directly, prefer it.
        if let Some(metadata) = self
            .build_context
            .dependency_metadata()
            .get(dist.name(), Some(dist.version()))
        {
            return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
        }

        // Fetch the metadata alone, without downloading the full wheel.
        let result = self
            .client
            .managed(|client| {
                client
                    .wheel_metadata(dist, self.build_context.capabilities())
                    .boxed_local()
            })
            .await;

        match result {
            Ok(metadata) => {
                // Validate that the metadata is consistent with the distribution.
                Ok(ArchiveMetadata::from_metadata23(metadata))
            }
            Err(err) if err.is_http_streaming_unsupported() => {
                warn!(
                    "Streaming unsupported when fetching metadata for {dist}; downloading wheel directly ({err})"
                );

                // If the request failed due to an error that could be resolved by
                // downloading the wheel directly, try that.
                let wheel = self.get_wheel(dist, hashes).await?;
                let metadata = wheel.metadata()?;
                let hashes = wheel.hashes;
                Ok(ArchiveMetadata {
                    metadata: Metadata::from_metadata23(metadata),
                    hashes,
                })
            }
            Err(err) => Err(err.into()),
        }
    }
538
539    /// Build the wheel metadata for a source distribution, or fetch it from the cache if possible.
540    ///
541    /// The returned metadata is guaranteed to come from a distribution with a matching hash, and
542    /// no build processes will be executed for distributions with mismatched hashes.
543    pub async fn build_wheel_metadata(
544        &self,
545        source: &BuildableSource<'_>,
546        hashes: HashPolicy<'_>,
547    ) -> Result<ArchiveMetadata, Error> {
548        // If the metadata was provided by the user directly, prefer it.
549        if let Some(dist) = source.as_dist() {
550            if let Some(metadata) = self
551                .build_context
552                .dependency_metadata()
553                .get(dist.name(), dist.version())
554            {
555                // If we skipped the build, we should still resolve any Git dependencies to precise
556                // commits.
557                self.builder.resolve_revision(source, &self.client).await?;
558
559                return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
560            }
561        }
562
563        let metadata = self
564            .builder
565            .download_and_build_metadata(source, hashes, &self.client)
566            .boxed_local()
567            .await?;
568
569        Ok(metadata)
570    }
571
572    /// Return the [`RequiresDist`] from a `pyproject.toml`, if it can be statically extracted.
573    pub async fn requires_dist(
574        &self,
575        path: &Path,
576        pyproject_toml: &PyProjectToml,
577    ) -> Result<Option<RequiresDist>, Error> {
578        self.builder
579            .source_tree_requires_dist(
580                path,
581                pyproject_toml,
582                self.client.unmanaged.credentials_cache(),
583            )
584            .await
585    }
586
    /// Stream a wheel from a URL, unzipping it into the cache as it's downloaded.
    async fn stream_wheel(
        &self,
        url: DisplaySafeUrl,
        index: Option<&IndexUrl>,
        filename: &WheelFilename,
        extension: WheelExtension,
        size: Option<u64>,
        wheel_entry: &CacheEntry,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<Archive, Error> {
        // Acquire an advisory lock, to guard against concurrent writes.
        #[cfg(windows)]
        let _lock = {
            let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
            lock_entry.lock().await.map_err(Error::CacheLock)?
        };

        // Create an entry for the HTTP cache.
        let http_entry = wheel_entry.with_file(format!("{}.http", filename.cache_key()));

        // NOTE(review): this borrows a temporary created by `url.clone()` (kept alive by
        // temporary-lifetime extension); binding the clone to a plain local would read clearer.
        let query_url = &url.clone();

        // Callback invoked with the HTTP response: streams the body, unzipping and hashing on
        // the fly, then persists the extracted directory into the cache.
        let download = |response: reqwest::Response| {
            async {
                // Prefer the size passed by the caller; otherwise read the Content-Length header.
                let size = size.or_else(|| content_length(&response));

                let progress = self
                    .reporter
                    .as_ref()
                    .map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));

                let reader = response
                    .bytes_stream()
                    .map_err(|err| self.handle_response_errors(err))
                    .into_async_read();

                // Create a hasher for each hash algorithm.
                let algorithms = hashes.algorithms();
                let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
                let mut hasher = uv_extract::hash::HashReader::new(reader.compat(), &mut hashers);

                // Download and unzip the wheel to a temporary directory.
                let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
                    .map_err(Error::CacheWrite)?;

                match progress {
                    Some((reporter, progress)) => {
                        // Wrap the reader so each chunk read is reported as download progress.
                        let mut reader = ProgressReader::new(&mut hasher, progress, &**reporter);
                        match extension {
                            WheelExtension::Whl => {
                                uv_extract::stream::unzip(query_url, &mut reader, temp_dir.path())
                                    .await
                                    .map_err(|err| Error::Extract(filename.to_string(), err))?;
                            }
                            WheelExtension::WhlZst => {
                                uv_extract::stream::untar_zst(&mut reader, temp_dir.path())
                                    .await
                                    .map_err(|err| Error::Extract(filename.to_string(), err))?;
                            }
                        }
                    }
                    None => match extension {
                        WheelExtension::Whl => {
                            uv_extract::stream::unzip(query_url, &mut hasher, temp_dir.path())
                                .await
                                .map_err(|err| Error::Extract(filename.to_string(), err))?;
                        }
                        WheelExtension::WhlZst => {
                            uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
                                .await
                                .map_err(|err| Error::Extract(filename.to_string(), err))?;
                        }
                    },
                }

                // If necessary, exhaust the reader to compute the hash.
                if !hashes.is_none() {
                    hasher.finish().await.map_err(Error::HashExhaustion)?;
                }

                // Persist the temporary directory to the directory store.
                let id = self
                    .build_context
                    .cache()
                    .persist(temp_dir.keep(), wheel_entry.path())
                    .await
                    .map_err(Error::CacheRead)?;

                if let Some((reporter, progress)) = progress {
                    reporter.on_download_complete(dist.name(), progress);
                }

                Ok(Archive::new(
                    id,
                    hashers.into_iter().map(HashDigest::from).collect(),
                    filename.clone(),
                ))
            }
            .instrument(info_span!("wheel", wheel = %dist))
        };

        // Fetch the archive from the cache, or download it if necessary.
        let req = self.request(url.clone())?;

        // Determine the cache control policy for the URL.
        let cache_control = match self.client.unmanaged.connectivity() {
            Connectivity::Online => {
                // A per-index cache-control override takes precedence over freshness checks.
                if let Some(header) = index.and_then(|index| {
                    self.build_context
                        .locations()
                        .artifact_cache_control_for(index)
                }) {
                    CacheControl::Override(header)
                } else {
                    CacheControl::from(
                        self.build_context
                            .cache()
                            .freshness(&http_entry, Some(&filename.name), None)
                            .map_err(Error::CacheRead)?,
                    )
                }
            }
            Connectivity::Offline => CacheControl::AllowStale,
        };

        let archive = self
            .client
            .managed(|client| {
                client.cached_client().get_serde_with_retry(
                    req,
                    &http_entry,
                    cache_control,
                    download,
                )
            })
            .await
            .map_err(|err| match err {
                CachedClientError::Callback { err, .. } => err,
                CachedClientError::Client(err) => Error::Client(err),
            })?;

        // If the archive is missing the required hashes, or has since been removed, force a refresh.
        let archive = Some(archive)
            .filter(|archive| archive.has_digests(hashes))
            .filter(|archive| archive.exists(self.build_context.cache()));

        let archive = if let Some(archive) = archive {
            archive
        } else {
            // The cached archive was stale or incomplete: re-fetch, bypassing the HTTP cache.
            self.client
                .managed(async |client| {
                    client
                        .cached_client()
                        .skip_cache_with_retry(
                            self.request(url)?,
                            &http_entry,
                            cache_control,
                            download,
                        )
                        .await
                        .map_err(|err| match err {
                            CachedClientError::Callback { err, .. } => err,
                            CachedClientError::Client(err) => Error::Client(err),
                        })
                })
                .await?
        };

        Ok(archive)
    }
759
    /// Download a wheel from a URL, then unzip it into the cache.
    ///
    /// The response body is streamed to an unnamed temporary file under the cache
    /// root, then unzipped (hashing on the fly when the [`HashPolicy`] requires
    /// digests), and finally persisted into the cache's archive store. An HTTP
    /// cache-policy entry (`{cache_key}.http`) is written alongside so later calls
    /// can revalidate instead of re-downloading. If a cached archive is missing
    /// the required digests or its directory has been removed, the download is
    /// retried with the cache skipped.
    ///
    /// # Errors
    ///
    /// Returns an error if the request, download, extraction, hashing, or any
    /// cache read/write fails.
    async fn download_wheel(
        &self,
        // URL of the `.whl` (or `.whl.tar.zst`) payload.
        url: DisplaySafeUrl,
        // Originating index, if any; used to look up per-index cache-control overrides.
        index: Option<&IndexUrl>,
        filename: &WheelFilename,
        // On-the-wire format of the payload (plain zip vs. zstd tarball).
        extension: WheelExtension,
        // Expected payload size, if known; used to seed progress reporting.
        size: Option<u64>,
        wheel_entry: &CacheEntry,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<Archive, Error> {
        // Acquire an advisory lock, to guard against concurrent writes.
        // Windows-only: presumably POSIX platforms tolerate the concurrent
        // rename performed by `persist` — TODO confirm.
        #[cfg(windows)]
        let _lock = {
            let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
            lock_entry.lock().await.map_err(Error::CacheLock)?
        };

        // Create an entry for the HTTP cache.
        let http_entry = wheel_entry.with_file(format!("{}.http", filename.cache_key()));

        // A clone is needed because `url` itself is moved into the fallback
        // `self.request(url)` below, while the `download` closure (which may run
        // in either attempt) needs the URL for error reporting during unzip.
        let query_url = &url.clone();

        // Invoked by the cached client on a cache miss (or revalidation) with the
        // live HTTP response; produces the `Archive` that gets cached.
        let download = |response: reqwest::Response| {
            async {
                // Prefer the caller-provided size; fall back to `Content-Length`.
                let size = size.or_else(|| content_length(&response));

                // Pair each reporter with the progress id it hands out for this download.
                let progress = self
                    .reporter
                    .as_ref()
                    .map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));

                let reader = response
                    .bytes_stream()
                    .map_err(|err| self.handle_response_errors(err))
                    .into_async_read();

                // Download the wheel to a temporary file.
                let temp_file = tempfile::tempfile_in(self.build_context.cache().root())
                    .map_err(Error::CacheWrite)?;
                let mut writer = tokio::io::BufWriter::new(fs_err::tokio::File::from_std(
                    // It's an unnamed file on Linux so that's the best approximation.
                    fs_err::File::from_parts(temp_file, self.build_context.cache().root()),
                ));

                match progress {
                    Some((reporter, progress)) => {
                        // Wrap the reader in a progress reporter. This will report 100% progress
                        // after the download is complete, even if we still have to unzip and hash
                        // part of the file.
                        let mut reader =
                            ProgressReader::new(reader.compat(), progress, &**reporter);

                        tokio::io::copy(&mut reader, &mut writer)
                            .await
                            .map_err(Error::CacheWrite)?;
                    }
                    None => {
                        tokio::io::copy(&mut reader.compat(), &mut writer)
                            .await
                            .map_err(Error::CacheWrite)?;
                    }
                }

                // Unzip the wheel to a temporary directory.
                let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
                    .map_err(Error::CacheWrite)?;
                let mut file = writer.into_inner();
                // Rewind: the same handle that received the download is now read back.
                file.seek(io::SeekFrom::Start(0))
                    .await
                    .map_err(Error::CacheWrite)?;

                // If no hashes are required, parallelize the unzip operation.
                let hashes = if hashes.is_none() {
                    // Extraction is blocking work; hand the std file to the blocking pool.
                    let file = file.into_std().await;
                    tokio::task::spawn_blocking({
                        let target = temp_dir.path().to_owned();
                        move || -> Result<(), uv_extract::Error> {
                            // Unzip the wheel into a temporary directory.
                            match extension {
                                WheelExtension::Whl => {
                                    uv_extract::unzip(file, &target)?;
                                }
                                WheelExtension::WhlZst => {
                                    uv_extract::stream::untar_zst_file(file, &target)?;
                                }
                            }
                            Ok(())
                        }
                    })
                    .await?
                    .map_err(|err| Error::Extract(filename.to_string(), err))?;

                    HashDigests::empty()
                } else {
                    // Create a hasher for each hash algorithm.
                    let algorithms = hashes.algorithms();
                    let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
                    // Hashing wraps the read side, so digests cover the raw payload bytes.
                    let mut hasher = uv_extract::hash::HashReader::new(file, &mut hashers);

                    match extension {
                        WheelExtension::Whl => {
                            uv_extract::stream::unzip(query_url, &mut hasher, temp_dir.path())
                                .await
                                .map_err(|err| Error::Extract(filename.to_string(), err))?;
                        }
                        WheelExtension::WhlZst => {
                            uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
                                .await
                                .map_err(|err| Error::Extract(filename.to_string(), err))?;
                        }
                    }

                    // If necessary, exhaust the reader to compute the hash.
                    hasher.finish().await.map_err(Error::HashExhaustion)?;

                    hashers.into_iter().map(HashDigest::from).collect()
                };

                // Persist the temporary directory to the directory store.
                // NOTE(review): the persist error is mapped to `Error::CacheRead`
                // here, while the sibling `load_wheel` maps the same operation to
                // `Error::CacheWrite` — confirm which variant is intended.
                let id = self
                    .build_context
                    .cache()
                    .persist(temp_dir.keep(), wheel_entry.path())
                    .await
                    .map_err(Error::CacheRead)?;

                if let Some((reporter, progress)) = progress {
                    reporter.on_download_complete(dist.name(), progress);
                }

                Ok(Archive::new(id, hashes, filename.clone()))
            }
            .instrument(info_span!("wheel", wheel = %dist))
        };

        // Fetch the archive from the cache, or download it if necessary.
        let req = self.request(url.clone())?;

        // Determine the cache control policy for the URL.
        let cache_control = match self.client.unmanaged.connectivity() {
            Connectivity::Online => {
                // A per-index override takes precedence over the freshness check.
                if let Some(header) = index.and_then(|index| {
                    self.build_context
                        .locations()
                        .artifact_cache_control_for(index)
                }) {
                    CacheControl::Override(header)
                } else {
                    CacheControl::from(
                        self.build_context
                            .cache()
                            .freshness(&http_entry, Some(&filename.name), None)
                            .map_err(Error::CacheRead)?,
                    )
                }
            }
            // Offline: serve whatever is cached, however stale.
            Connectivity::Offline => CacheControl::AllowStale,
        };

        let archive = self
            .client
            .managed(|client| {
                client.cached_client().get_serde_with_retry(
                    req,
                    &http_entry,
                    cache_control,
                    download,
                )
            })
            .await
            .map_err(|err| match err {
                // Errors raised inside `download` pass through unchanged.
                CachedClientError::Callback { err, .. } => err,
                CachedClientError::Client(err) => Error::Client(err),
            })?;

        // If the archive is missing the required hashes, or has since been removed, force a refresh.
        let archive = Some(archive)
            .filter(|archive| archive.has_digests(hashes))
            .filter(|archive| archive.exists(self.build_context.cache()));

        let archive = if let Some(archive) = archive {
            archive
        } else {
            // Re-download unconditionally, bypassing the HTTP cache entry.
            self.client
                .managed(async |client| {
                    client
                        .cached_client()
                        .skip_cache_with_retry(
                            self.request(url)?,
                            &http_entry,
                            cache_control,
                            download,
                        )
                        .await
                        .map_err(|err| match err {
                            CachedClientError::Callback { err, .. } => err,
                            CachedClientError::Client(err) => Error::Client(err),
                        })
                })
                .await?
        };

        Ok(archive)
    }
966
    /// Load a wheel from a local path.
    ///
    /// Uses the wheel file's last-modified timestamp to decide whether a
    /// previously unzipped copy in the cache (tracked by a `.rev` pointer) is
    /// still valid. Otherwise the wheel is unzipped — with on-the-fly hashing
    /// when the [`HashPolicy`] requires digests — and the pointer is refreshed.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read, extraction or hashing fails,
    /// or the cache cannot be written.
    async fn load_wheel(
        &self,
        path: &Path,
        filename: &WheelFilename,
        // On-disk format of the wheel (plain zip vs. zstd tarball).
        extension: WheelExtension,
        wheel_entry: CacheEntry,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<LocalWheel, Error> {
        // Advisory lock to guard against concurrent writes (Windows-only, as in
        // `download_wheel`).
        #[cfg(windows)]
        let _lock = {
            let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
            lock_entry.lock().await.map_err(Error::CacheLock)?
        };

        // Determine the last-modified time of the wheel.
        let modified = Timestamp::from_path(path).map_err(Error::CacheRead)?;

        // Attempt to read the archive pointer from the cache.
        let pointer_entry = wheel_entry.with_file(format!("{}.rev", filename.cache_key()));
        let pointer = LocalArchivePointer::read_from(&pointer_entry)?;

        // Extract the archive from the pointer. It is only reusable if it matches
        // the file's current timestamp and already carries the required digests.
        let archive = pointer
            .filter(|pointer| pointer.is_up_to_date(modified))
            .map(LocalArchivePointer::into_archive)
            .filter(|archive| archive.has_digests(hashes));

        // If the file is already unzipped, and the cache is up-to-date, return it.
        if let Some(archive) = archive {
            Ok(LocalWheel {
                dist: Dist::Built(dist.clone()),
                archive: self
                    .build_context
                    .cache()
                    .archive(&archive.id)
                    .into_boxed_path(),
                hashes: archive.hashes,
                filename: filename.clone(),
                cache: CacheInfo::from_timestamp(modified),
                build: None,
            })
        } else if hashes.is_none() {
            // Otherwise, unzip the wheel. No digests are required, so take the
            // fast path: parallel unzip on the blocking pool, no hashing.
            // NOTE(review): `unzip_wheel` only handles zip payloads; this branch
            // appears to assume local wheels are never `WhlZst` — confirm.
            let archive = Archive::new(
                self.unzip_wheel(path, wheel_entry.path()).await?,
                HashDigests::empty(),
                filename.clone(),
            );

            // Write the archive pointer to the cache.
            let pointer = LocalArchivePointer {
                timestamp: modified,
                archive: archive.clone(),
            };
            pointer.write_to(&pointer_entry).await?;

            Ok(LocalWheel {
                dist: Dist::Built(dist.clone()),
                archive: self
                    .build_context
                    .cache()
                    .archive(&archive.id)
                    .into_boxed_path(),
                hashes: archive.hashes,
                filename: filename.clone(),
                cache: CacheInfo::from_timestamp(modified),
                build: None,
            })
        } else {
            // If necessary, compute the hashes of the wheel.
            let file = fs_err::tokio::File::open(path)
                .await
                .map_err(Error::CacheRead)?;
            let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
                .map_err(Error::CacheWrite)?;

            // Create a hasher for each hash algorithm.
            let algorithms = hashes.algorithms();
            let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
            // Hashing wraps the read side, so digests cover the raw file bytes.
            let mut hasher = uv_extract::hash::HashReader::new(file, &mut hashers);

            // Unzip the wheel to a temporary directory.
            match extension {
                WheelExtension::Whl => {
                    uv_extract::stream::unzip(path.display(), &mut hasher, temp_dir.path())
                        .await
                        .map_err(|err| Error::Extract(filename.to_string(), err))?;
                }
                WheelExtension::WhlZst => {
                    uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
                        .await
                        .map_err(|err| Error::Extract(filename.to_string(), err))?;
                }
            }

            // Exhaust the reader to compute the hash.
            hasher.finish().await.map_err(Error::HashExhaustion)?;

            let hashes = hashers.into_iter().map(HashDigest::from).collect();

            // Persist the temporary directory to the directory store.
            let id = self
                .build_context
                .cache()
                .persist(temp_dir.keep(), wheel_entry.path())
                .await
                .map_err(Error::CacheWrite)?;

            // Create an archive.
            let archive = Archive::new(id, hashes, filename.clone());

            // Write the archive pointer to the cache.
            let pointer = LocalArchivePointer {
                timestamp: modified,
                archive: archive.clone(),
            };
            pointer.write_to(&pointer_entry).await?;

            Ok(LocalWheel {
                dist: Dist::Built(dist.clone()),
                archive: self
                    .build_context
                    .cache()
                    .archive(&archive.id)
                    .into_boxed_path(),
                hashes: archive.hashes,
                filename: filename.clone(),
                cache: CacheInfo::from_timestamp(modified),
                build: None,
            })
        }
    }
1101
1102    /// Unzip a wheel into the cache, returning the path to the unzipped directory.
1103    async fn unzip_wheel(&self, path: &Path, target: &Path) -> Result<ArchiveId, Error> {
1104        let temp_dir = tokio::task::spawn_blocking({
1105            let path = path.to_owned();
1106            let root = self.build_context.cache().root().to_path_buf();
1107            move || -> Result<TempDir, Error> {
1108                // Unzip the wheel into a temporary directory.
1109                let temp_dir = tempfile::tempdir_in(root).map_err(Error::CacheWrite)?;
1110                let reader = fs_err::File::open(&path).map_err(Error::CacheWrite)?;
1111                uv_extract::unzip(reader, temp_dir.path())
1112                    .map_err(|err| Error::Extract(path.to_string_lossy().into_owned(), err))?;
1113                Ok(temp_dir)
1114            }
1115        })
1116        .await??;
1117
1118        // Persist the temporary directory to the directory store.
1119        let id = self
1120            .build_context
1121            .cache()
1122            .persist(temp_dir.keep(), target)
1123            .await
1124            .map_err(Error::CacheWrite)?;
1125
1126        Ok(id)
1127    }
1128
1129    /// Returns a GET [`reqwest::Request`] for the given URL.
1130    fn request(&self, url: DisplaySafeUrl) -> Result<reqwest::Request, reqwest::Error> {
1131        self.client
1132            .unmanaged
1133            .uncached_client(&url)
1134            .get(Url::from(url))
1135            .header(
1136                // `reqwest` defaults to accepting compressed responses.
1137                // Specify identity encoding to get consistent .whl downloading
1138                // behavior from servers. ref: https://github.com/pypa/pip/pull/1688
1139                "accept-encoding",
1140                reqwest::header::HeaderValue::from_static("identity"),
1141            )
1142            .build()
1143    }
1144
    /// Return the [`ManagedClient`] used by this resolver.
    ///
    /// All requests issued through the returned client are subject to its
    /// concurrency limit.
    pub fn client(&self) -> &ManagedClient<'a> {
        &self.client
    }
1149}
1150
/// A wrapper around `RegistryClient` that manages a concurrency limit.
pub struct ManagedClient<'a> {
    /// The underlying client, not subject to any concurrency limit.
    pub unmanaged: &'a RegistryClient,
    /// Semaphore bounding the number of in-flight requests.
    control: Semaphore,
}
1156
1157impl<'a> ManagedClient<'a> {
1158    /// Create a new `ManagedClient` using the given client and concurrency limit.
1159    fn new(client: &'a RegistryClient, concurrency: usize) -> Self {
1160        ManagedClient {
1161            unmanaged: client,
1162            control: Semaphore::new(concurrency),
1163        }
1164    }
1165
1166    /// Perform a request using the client, respecting the concurrency limit.
1167    ///
1168    /// If the concurrency limit has been reached, this method will wait until a pending
1169    /// operation completes before executing the closure.
1170    pub async fn managed<F, T>(&self, f: impl FnOnce(&'a RegistryClient) -> F) -> T
1171    where
1172        F: Future<Output = T>,
1173    {
1174        let _permit = self.control.acquire().await.unwrap();
1175        f(self.unmanaged).await
1176    }
1177
1178    /// Perform a request using a client that internally manages the concurrency limit.
1179    ///
1180    /// The callback is passed the client and a semaphore. It must acquire the semaphore before
1181    /// any request through the client and drop it after.
1182    ///
1183    /// This method serves as an escape hatch for functions that may want to send multiple requests
1184    /// in parallel.
1185    pub async fn manual<F, T>(&'a self, f: impl FnOnce(&'a RegistryClient, &'a Semaphore) -> F) -> T
1186    where
1187        F: Future<Output = T>,
1188    {
1189        f(self.unmanaged, &self.control).await
1190    }
1191}
1192
1193/// Returns the value of the `Content-Length` header from the [`reqwest::Response`], if present.
1194fn content_length(response: &reqwest::Response) -> Option<u64> {
1195    response
1196        .headers()
1197        .get(reqwest::header::CONTENT_LENGTH)
1198        .and_then(|val| val.to_str().ok())
1199        .and_then(|val| val.parse::<u64>().ok())
1200}
1201
/// An asynchronous reader that reports progress as bytes are read.
struct ProgressReader<'a, R> {
    /// The underlying reader being wrapped.
    reader: R,
    /// Identifier for this download, as returned by `Reporter::on_download_start`.
    index: usize,
    /// The reporter that receives progress updates.
    reporter: &'a dyn Reporter,
}
1208
1209impl<'a, R> ProgressReader<'a, R> {
1210    /// Create a new [`ProgressReader`] that wraps another reader.
1211    fn new(reader: R, index: usize, reporter: &'a dyn Reporter) -> Self {
1212        Self {
1213            reader,
1214            index,
1215            reporter,
1216        }
1217    }
1218}
1219
impl<R> AsyncRead for ProgressReader<'_, R>
where
    R: AsyncRead + Unpin,
{
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Delegate to the inner reader; on a successful (ready) read, report
        // the buffer's filled length to the reporter.
        // NOTE(review): `buf.filled().len()` is the cumulative fill of this
        // `ReadBuf`, not necessarily just the bytes added by this poll —
        // confirm the caller always supplies an empty buffer per poll.
        Pin::new(&mut self.as_mut().reader)
            .poll_read(cx, buf)
            .map_ok(|()| {
                self.reporter
                    .on_download_progress(self.index, buf.filled().len() as u64);
            })
    }
}
1237
/// A pointer to an archive in the cache, fetched from an HTTP archive.
///
/// Encoded with `MsgPack`, and represented on disk by a `.http` file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HttpArchivePointer {
    /// The cached archive (unzipped wheel) this pointer refers to.
    archive: Archive,
}
1245
1246impl HttpArchivePointer {
1247    /// Read an [`HttpArchivePointer`] from the cache.
1248    pub fn read_from(path: impl AsRef<Path>) -> Result<Option<Self>, Error> {
1249        match fs_err::File::open(path.as_ref()) {
1250            Ok(file) => {
1251                let data = DataWithCachePolicy::from_reader(file)?.data;
1252                let archive = rmp_serde::from_slice::<Archive>(&data)?;
1253                Ok(Some(Self { archive }))
1254            }
1255            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
1256            Err(err) => Err(Error::CacheRead(err)),
1257        }
1258    }
1259
1260    /// Return the [`Archive`] from the pointer.
1261    pub fn into_archive(self) -> Archive {
1262        self.archive
1263    }
1264
1265    /// Return the [`CacheInfo`] from the pointer.
1266    pub fn to_cache_info(&self) -> CacheInfo {
1267        CacheInfo::default()
1268    }
1269
1270    /// Return the [`BuildInfo`] from the pointer.
1271    pub fn to_build_info(&self) -> Option<BuildInfo> {
1272        None
1273    }
1274}
1275
/// A pointer to an archive in the cache, fetched from a local path.
///
/// Encoded with `MsgPack`, and represented on disk by a `.rev` file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LocalArchivePointer {
    /// Last-modified timestamp of the source file when the archive was created.
    timestamp: Timestamp,
    /// The cached archive (unzipped wheel) this pointer refers to.
    archive: Archive,
}
1284
1285impl LocalArchivePointer {
1286    /// Read an [`LocalArchivePointer`] from the cache.
1287    pub fn read_from(path: impl AsRef<Path>) -> Result<Option<Self>, Error> {
1288        match fs_err::read(path) {
1289            Ok(cached) => Ok(Some(rmp_serde::from_slice::<Self>(&cached)?)),
1290            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
1291            Err(err) => Err(Error::CacheRead(err)),
1292        }
1293    }
1294
1295    /// Write an [`LocalArchivePointer`] to the cache.
1296    pub async fn write_to(&self, entry: &CacheEntry) -> Result<(), Error> {
1297        write_atomic(entry.path(), rmp_serde::to_vec(&self)?)
1298            .await
1299            .map_err(Error::CacheWrite)
1300    }
1301
1302    /// Returns `true` if the archive is up-to-date with the given modified timestamp.
1303    pub fn is_up_to_date(&self, modified: Timestamp) -> bool {
1304        self.timestamp == modified
1305    }
1306
1307    /// Return the [`Archive`] from the pointer.
1308    pub fn into_archive(self) -> Archive {
1309        self.archive
1310    }
1311
1312    /// Return the [`CacheInfo`] from the pointer.
1313    pub fn to_cache_info(&self) -> CacheInfo {
1314        CacheInfo::from_timestamp(self.timestamp)
1315    }
1316
1317    /// Return the [`BuildInfo`] from the pointer.
1318    pub fn to_build_info(&self) -> Option<BuildInfo> {
1319        None
1320    }
1321}
1322
/// The resolved download target for a wheel: its URL, on-the-wire format, and
/// expected size.
#[derive(Debug, Clone)]
struct WheelTarget {
    /// The URL from which the wheel can be downloaded.
    url: DisplaySafeUrl,
    /// The expected extension of the wheel file.
    extension: WheelExtension,
    /// The expected size of the wheel file, if known.
    size: Option<u64>,
}
1332
1333impl TryFrom<&File> for WheelTarget {
1334    type Error = ToUrlError;
1335
1336    /// Determine the [`WheelTarget`] from a [`File`].
1337    fn try_from(file: &File) -> Result<Self, Self::Error> {
1338        let url = file.url.to_url()?;
1339        if let Some(zstd) = file.zstd.as_ref() {
1340            Ok(Self {
1341                url: add_tar_zst_extension(url),
1342                extension: WheelExtension::WhlZst,
1343                size: zstd.size,
1344            })
1345        } else {
1346            Ok(Self {
1347                url,
1348                extension: WheelExtension::Whl,
1349                size: file.size,
1350            })
1351        }
1352    }
1353}
1354
/// The supported on-disk / on-the-wire formats for a wheel payload.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum WheelExtension {
    /// A `.whl` file.
    Whl,
    /// A `.whl.tar.zst` file.
    WhlZst,
}
1362
1363/// Add `.tar.zst` to the end of the URL path, if it doesn't already exist.
1364#[must_use]
1365fn add_tar_zst_extension(mut url: DisplaySafeUrl) -> DisplaySafeUrl {
1366    let mut path = url.path().to_string();
1367
1368    if !path.ends_with(".tar.zst") {
1369        path.push_str(".tar.zst");
1370    }
1371
1372    url.set_path(&path);
1373    url
1374}
1375
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_tar_zst_extension() {
        // (input, expected): the suffix is appended at most once, and
        // percent-encoded characters in the path are preserved.
        let cases = [
            (
                "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl",
                "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst",
            ),
            (
                "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst",
                "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst",
            ),
            (
                "https://files.pythonhosted.org/flask-3.1.0%2Bcu124-py3-none-any.whl",
                "https://files.pythonhosted.org/flask-3.1.0%2Bcu124-py3-none-any.whl.tar.zst",
            ),
        ];
        for (input, expected) in cases {
            let url = DisplaySafeUrl::parse(input).unwrap();
            assert_eq!(add_tar_zst_extension(url).as_str(), expected);
        }
    }
}