uv_torch/
backend.rs

//! `uv-torch` is a library for determining the appropriate PyTorch index based on the operating
//! system and the available GPU accelerator (a CUDA driver version, an AMD GPU architecture, or an
//! Intel GPU).
3//!
//! This library is derived from `light-the-torch` by Philip Meier, which is available under the
5//! following BSD-3 Clause license:
6//!
7//! ```text
8//! BSD 3-Clause License
9//!
10//! Copyright (c) 2020, Philip Meier
11//! All rights reserved.
12//!
13//! Redistribution and use in source and binary forms, with or without
14//! modification, are permitted provided that the following conditions are met:
15//!
16//! 1. Redistributions of source code must retain the above copyright notice, this
17//!    list of conditions and the following disclaimer.
18//!
19//! 2. Redistributions in binary form must reproduce the above copyright notice,
20//!    this list of conditions and the following disclaimer in the documentation
21//!    and/or other materials provided with the distribution.
22//!
23//! 3. Neither the name of the copyright holder nor the names of its
24//!    contributors may be used to endorse or promote products derived from
25//!    this software without specific prior written permission.
26//!
27//! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
28//! AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
29//! IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
30//! DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
31//! FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
32//! DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
33//! SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
34//! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
35//! OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
36//! OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37//! ```
38
39use std::borrow::Cow;
40use std::str::FromStr;
41use std::sync::LazyLock;
42
43use either::Either;
44use url::Url;
45
46use uv_distribution_types::IndexUrl;
47use uv_normalize::PackageName;
48use uv_pep440::Version;
49use uv_platform_tags::Os;
50use uv_static::EnvVars;
51
52use crate::{Accelerator, AcceleratorError, AmdGpuArchitecture};
53
54/// The strategy to use when determining the appropriate PyTorch index.
55#[derive(Debug, Copy, Clone, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
56#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
57#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
58#[serde(rename_all = "kebab-case")]
59pub enum TorchMode {
60    /// Select the appropriate PyTorch index based on the operating system and CUDA driver version.
61    Auto,
62    /// Use the CPU-only PyTorch index.
63    Cpu,
64    /// Use the PyTorch index for CUDA 13.0.
65    Cu130,
66    /// Use the PyTorch index for CUDA 12.9.
67    Cu129,
68    /// Use the PyTorch index for CUDA 12.8.
69    Cu128,
70    /// Use the PyTorch index for CUDA 12.6.
71    Cu126,
72    /// Use the PyTorch index for CUDA 12.5.
73    Cu125,
74    /// Use the PyTorch index for CUDA 12.4.
75    Cu124,
76    /// Use the PyTorch index for CUDA 12.3.
77    Cu123,
78    /// Use the PyTorch index for CUDA 12.2.
79    Cu122,
80    /// Use the PyTorch index for CUDA 12.1.
81    Cu121,
82    /// Use the PyTorch index for CUDA 12.0.
83    Cu120,
84    /// Use the PyTorch index for CUDA 11.8.
85    Cu118,
86    /// Use the PyTorch index for CUDA 11.7.
87    Cu117,
88    /// Use the PyTorch index for CUDA 11.6.
89    Cu116,
90    /// Use the PyTorch index for CUDA 11.5.
91    Cu115,
92    /// Use the PyTorch index for CUDA 11.4.
93    Cu114,
94    /// Use the PyTorch index for CUDA 11.3.
95    Cu113,
96    /// Use the PyTorch index for CUDA 11.2.
97    Cu112,
98    /// Use the PyTorch index for CUDA 11.1.
99    Cu111,
100    /// Use the PyTorch index for CUDA 11.0.
101    Cu110,
102    /// Use the PyTorch index for CUDA 10.2.
103    Cu102,
104    /// Use the PyTorch index for CUDA 10.1.
105    Cu101,
106    /// Use the PyTorch index for CUDA 10.0.
107    Cu100,
108    /// Use the PyTorch index for CUDA 9.2.
109    Cu92,
110    /// Use the PyTorch index for CUDA 9.1.
111    Cu91,
112    /// Use the PyTorch index for CUDA 9.0.
113    Cu90,
114    /// Use the PyTorch index for CUDA 8.0.
115    Cu80,
116    /// Use the PyTorch index for ROCm 6.3.
117    #[serde(rename = "rocm6.3")]
118    #[cfg_attr(feature = "clap", clap(name = "rocm6.3"))]
119    Rocm63,
120    /// Use the PyTorch index for ROCm 6.2.4.
121    #[serde(rename = "rocm6.2.4")]
122    #[cfg_attr(feature = "clap", clap(name = "rocm6.2.4"))]
123    Rocm624,
124    /// Use the PyTorch index for ROCm 6.2.
125    #[serde(rename = "rocm6.2")]
126    #[cfg_attr(feature = "clap", clap(name = "rocm6.2"))]
127    Rocm62,
128    /// Use the PyTorch index for ROCm 6.1.
129    #[serde(rename = "rocm6.1")]
130    #[cfg_attr(feature = "clap", clap(name = "rocm6.1"))]
131    Rocm61,
132    /// Use the PyTorch index for ROCm 6.0.
133    #[serde(rename = "rocm6.0")]
134    #[cfg_attr(feature = "clap", clap(name = "rocm6.0"))]
135    Rocm60,
136    /// Use the PyTorch index for ROCm 5.7.
137    #[serde(rename = "rocm5.7")]
138    #[cfg_attr(feature = "clap", clap(name = "rocm5.7"))]
139    Rocm57,
140    /// Use the PyTorch index for ROCm 5.6.
141    #[serde(rename = "rocm5.6")]
142    #[cfg_attr(feature = "clap", clap(name = "rocm5.6"))]
143    Rocm56,
144    /// Use the PyTorch index for ROCm 5.5.
145    #[serde(rename = "rocm5.5")]
146    #[cfg_attr(feature = "clap", clap(name = "rocm5.5"))]
147    Rocm55,
148    /// Use the PyTorch index for ROCm 5.4.2.
149    #[serde(rename = "rocm5.4.2")]
150    #[cfg_attr(feature = "clap", clap(name = "rocm5.4.2"))]
151    Rocm542,
152    /// Use the PyTorch index for ROCm 5.4.
153    #[serde(rename = "rocm5.4")]
154    #[cfg_attr(feature = "clap", clap(name = "rocm5.4"))]
155    Rocm54,
156    /// Use the PyTorch index for ROCm 5.3.
157    #[serde(rename = "rocm5.3")]
158    #[cfg_attr(feature = "clap", clap(name = "rocm5.3"))]
159    Rocm53,
160    /// Use the PyTorch index for ROCm 5.2.
161    #[serde(rename = "rocm5.2")]
162    #[cfg_attr(feature = "clap", clap(name = "rocm5.2"))]
163    Rocm52,
164    /// Use the PyTorch index for ROCm 5.1.1.
165    #[serde(rename = "rocm5.1.1")]
166    #[cfg_attr(feature = "clap", clap(name = "rocm5.1.1"))]
167    Rocm511,
168    /// Use the PyTorch index for ROCm 4.2.
169    #[serde(rename = "rocm4.2")]
170    #[cfg_attr(feature = "clap", clap(name = "rocm4.2"))]
171    Rocm42,
172    /// Use the PyTorch index for ROCm 4.1.
173    #[serde(rename = "rocm4.1")]
174    #[cfg_attr(feature = "clap", clap(name = "rocm4.1"))]
175    Rocm41,
176    /// Use the PyTorch index for ROCm 4.0.1.
177    #[serde(rename = "rocm4.0.1")]
178    #[cfg_attr(feature = "clap", clap(name = "rocm4.0.1"))]
179    Rocm401,
180    /// Use the PyTorch index for Intel XPU.
181    Xpu,
182}
183
/// The provider that PyTorch builds are fetched from.
///
/// Each [`TorchBackend`] has one index per source; the official PyTorch indexes are used unless
/// pyx is explicitly selected.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum TorchSource {
    /// Fetch builds from the official PyTorch index (the default source).
    #[default]
    PyTorch,
    /// Fetch builds from the pyx index.
    Pyx,
}
192
193/// The strategy to use when determining the appropriate PyTorch index.
194#[derive(Debug, Clone, Eq, PartialEq)]
195pub enum TorchStrategy {
196    /// Select the appropriate PyTorch index based on the operating system and CUDA driver version (e.g., `550.144.03`).
197    Cuda {
198        os: Os,
199        driver_version: Version,
200        source: TorchSource,
201    },
202    /// Select the appropriate PyTorch index based on the operating system and AMD GPU architecture (e.g., `gfx1100`).
203    Amd {
204        os: Os,
205        gpu_architecture: AmdGpuArchitecture,
206        source: TorchSource,
207    },
208    /// Select the appropriate PyTorch index based on the operating system and Intel GPU presence.
209    Xpu { os: Os, source: TorchSource },
210    /// Use the specified PyTorch index.
211    Backend {
212        backend: TorchBackend,
213        source: TorchSource,
214    },
215}
216
217impl TorchStrategy {
218    /// Determine the [`TorchStrategy`] from the given [`TorchMode`], [`Os`], and [`Accelerator`].
219    pub fn from_mode(
220        mode: TorchMode,
221        source: TorchSource,
222        os: &Os,
223    ) -> Result<Self, AcceleratorError> {
224        let backend = match mode {
225            TorchMode::Auto => match Accelerator::detect()? {
226                Some(Accelerator::Cuda { driver_version }) => {
227                    return Ok(Self::Cuda {
228                        os: os.clone(),
229                        driver_version: driver_version.clone(),
230                        source,
231                    });
232                }
233                Some(Accelerator::Amd { gpu_architecture }) => {
234                    return Ok(Self::Amd {
235                        os: os.clone(),
236                        gpu_architecture,
237                        source,
238                    });
239                }
240                Some(Accelerator::Xpu) => {
241                    return Ok(Self::Xpu {
242                        os: os.clone(),
243                        source,
244                    });
245                }
246                None => TorchBackend::Cpu,
247            },
248            TorchMode::Cpu => TorchBackend::Cpu,
249            TorchMode::Cu130 => TorchBackend::Cu130,
250            TorchMode::Cu129 => TorchBackend::Cu129,
251            TorchMode::Cu128 => TorchBackend::Cu128,
252            TorchMode::Cu126 => TorchBackend::Cu126,
253            TorchMode::Cu125 => TorchBackend::Cu125,
254            TorchMode::Cu124 => TorchBackend::Cu124,
255            TorchMode::Cu123 => TorchBackend::Cu123,
256            TorchMode::Cu122 => TorchBackend::Cu122,
257            TorchMode::Cu121 => TorchBackend::Cu121,
258            TorchMode::Cu120 => TorchBackend::Cu120,
259            TorchMode::Cu118 => TorchBackend::Cu118,
260            TorchMode::Cu117 => TorchBackend::Cu117,
261            TorchMode::Cu116 => TorchBackend::Cu116,
262            TorchMode::Cu115 => TorchBackend::Cu115,
263            TorchMode::Cu114 => TorchBackend::Cu114,
264            TorchMode::Cu113 => TorchBackend::Cu113,
265            TorchMode::Cu112 => TorchBackend::Cu112,
266            TorchMode::Cu111 => TorchBackend::Cu111,
267            TorchMode::Cu110 => TorchBackend::Cu110,
268            TorchMode::Cu102 => TorchBackend::Cu102,
269            TorchMode::Cu101 => TorchBackend::Cu101,
270            TorchMode::Cu100 => TorchBackend::Cu100,
271            TorchMode::Cu92 => TorchBackend::Cu92,
272            TorchMode::Cu91 => TorchBackend::Cu91,
273            TorchMode::Cu90 => TorchBackend::Cu90,
274            TorchMode::Cu80 => TorchBackend::Cu80,
275            TorchMode::Rocm63 => TorchBackend::Rocm63,
276            TorchMode::Rocm624 => TorchBackend::Rocm624,
277            TorchMode::Rocm62 => TorchBackend::Rocm62,
278            TorchMode::Rocm61 => TorchBackend::Rocm61,
279            TorchMode::Rocm60 => TorchBackend::Rocm60,
280            TorchMode::Rocm57 => TorchBackend::Rocm57,
281            TorchMode::Rocm56 => TorchBackend::Rocm56,
282            TorchMode::Rocm55 => TorchBackend::Rocm55,
283            TorchMode::Rocm542 => TorchBackend::Rocm542,
284            TorchMode::Rocm54 => TorchBackend::Rocm54,
285            TorchMode::Rocm53 => TorchBackend::Rocm53,
286            TorchMode::Rocm52 => TorchBackend::Rocm52,
287            TorchMode::Rocm511 => TorchBackend::Rocm511,
288            TorchMode::Rocm42 => TorchBackend::Rocm42,
289            TorchMode::Rocm41 => TorchBackend::Rocm41,
290            TorchMode::Rocm401 => TorchBackend::Rocm401,
291            TorchMode::Xpu => TorchBackend::Xpu,
292        };
293        Ok(Self::Backend { backend, source })
294    }
295
296    /// Returns `true` if the [`TorchStrategy`] applies to the given [`PackageName`].
297    pub fn applies_to(&self, package_name: &PackageName) -> bool {
298        let source = match self {
299            Self::Cuda { source, .. } => *source,
300            Self::Amd { source, .. } => *source,
301            Self::Xpu { source, .. } => *source,
302            Self::Backend { source, .. } => *source,
303        };
304        match source {
305            TorchSource::PyTorch => {
306                matches!(
307                    package_name.as_str(),
308                    "torch"
309                        | "torcharrow"
310                        | "torchaudio"
311                        | "torchcsprng"
312                        | "torchdata"
313                        | "torchdistx"
314                        | "torchserve"
315                        | "torchtext"
316                        | "torchvision"
317                        | "triton"
318                        | "pytorch-triton"
319                        | "pytorch-triton-rocm"
320                        | "pytorch-triton-xpu"
321                )
322            }
323            TorchSource::Pyx => {
324                matches!(
325                    package_name.as_str(),
326                    "deepspeed"
327                        | "flash-attn"
328                        | "flash-attn-3"
329                        | "megablocks"
330                        | "natten"
331                        | "pyg-lib"
332                        | "torch-cluster"
333                        | "torch-scatter"
334                        | "torch-sparse"
335                        | "torch-spline-conv"
336                        | "vllm"
337                        | "torch"
338                        | "torcharrow"
339                        | "torchaudio"
340                        | "torchcsprng"
341                        | "torchdata"
342                        | "torchdistx"
343                        | "torchserve"
344                        | "torchtext"
345                        | "torchvision"
346                        | "triton"
347                        | "pytorch-triton"
348                        | "pytorch-triton-rocm"
349                        | "pytorch-triton-xpu"
350                )
351            }
352        }
353    }
354
355    /// Returns `true` if the given [`PackageName`] has a system dependency (e.g., CUDA or ROCm).
356    ///
357    /// For example, `triton` is hosted on the PyTorch indexes, but does not have a system
358    /// dependency on the associated CUDA version (i.e., the `triton` on the `cu128` index doesn't
359    /// depend on CUDA 12.8).
360    pub fn has_system_dependency(&self, package_name: &PackageName) -> bool {
361        matches!(
362            package_name.as_str(),
363            "flash-attn"
364                | "flash-attn-3"
365                | "megablocks"
366                | "natten"
367                | "deepspeed"
368                | "vllm"
369                | "torch"
370                | "torcharrow"
371                | "torchaudio"
372                | "torchcsprng"
373                | "torchdata"
374                | "torchdistx"
375                | "torchtext"
376                | "torchvision"
377        )
378    }
379
380    /// Return the appropriate index URLs for the given [`TorchStrategy`].
381    pub fn index_urls(&self) -> impl Iterator<Item = &IndexUrl> {
382        match self {
383            Self::Cuda {
384                os,
385                driver_version,
386                source,
387            } => {
388                // If this is a GPU-enabled package, and CUDA drivers are installed, use PyTorch's CUDA
389                // indexes.
390                //
391                // See: https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_patch.py#L36-L49
392                match os {
393                    Os::Manylinux { .. } | Os::Musllinux { .. } => {
394                        Either::Left(Either::Left(Either::Left(
395                            LINUX_CUDA_DRIVERS
396                                .iter()
397                                .filter_map(move |(backend, version)| {
398                                    if driver_version >= version {
399                                        Some(backend.index_url(*source))
400                                    } else {
401                                        None
402                                    }
403                                })
404                                .chain(std::iter::once(TorchBackend::Cpu.index_url(*source))),
405                        )))
406                    }
407                    Os::Windows => Either::Left(Either::Left(Either::Right(
408                        WINDOWS_CUDA_VERSIONS
409                            .iter()
410                            .filter_map(move |(backend, version)| {
411                                if driver_version >= version {
412                                    Some(backend.index_url(*source))
413                                } else {
414                                    None
415                                }
416                            })
417                            .chain(std::iter::once(TorchBackend::Cpu.index_url(*source))),
418                    ))),
419                    Os::Macos { .. }
420                    | Os::FreeBsd { .. }
421                    | Os::NetBsd { .. }
422                    | Os::OpenBsd { .. }
423                    | Os::Dragonfly { .. }
424                    | Os::Illumos { .. }
425                    | Os::Haiku { .. }
426                    | Os::Android { .. }
427                    | Os::Pyodide { .. }
428                    | Os::Ios { .. } => Either::Right(Either::Left(std::iter::once(
429                        TorchBackend::Cpu.index_url(*source),
430                    ))),
431                }
432            }
433            Self::Amd {
434                os,
435                gpu_architecture,
436                source,
437            } => match os {
438                Os::Manylinux { .. } | Os::Musllinux { .. } => Either::Left(Either::Right(
439                    LINUX_AMD_GPU_DRIVERS
440                        .iter()
441                        .filter_map(move |(backend, architecture)| {
442                            if gpu_architecture == architecture {
443                                Some(backend.index_url(*source))
444                            } else {
445                                None
446                            }
447                        })
448                        .chain(std::iter::once(TorchBackend::Cpu.index_url(*source))),
449                )),
450                Os::Windows
451                | Os::Macos { .. }
452                | Os::FreeBsd { .. }
453                | Os::NetBsd { .. }
454                | Os::OpenBsd { .. }
455                | Os::Dragonfly { .. }
456                | Os::Illumos { .. }
457                | Os::Haiku { .. }
458                | Os::Android { .. }
459                | Os::Pyodide { .. }
460                | Os::Ios { .. } => Either::Right(Either::Left(std::iter::once(
461                    TorchBackend::Cpu.index_url(*source),
462                ))),
463            },
464            Self::Xpu { os, source } => match os {
465                Os::Manylinux { .. } | Os::Windows => Either::Right(Either::Right(Either::Left(
466                    std::iter::once(TorchBackend::Xpu.index_url(*source)),
467                ))),
468                Os::Musllinux { .. }
469                | Os::Macos { .. }
470                | Os::FreeBsd { .. }
471                | Os::NetBsd { .. }
472                | Os::OpenBsd { .. }
473                | Os::Dragonfly { .. }
474                | Os::Illumos { .. }
475                | Os::Haiku { .. }
476                | Os::Android { .. }
477                | Os::Pyodide { .. }
478                | Os::Ios { .. } => Either::Right(Either::Left(std::iter::once(
479                    TorchBackend::Cpu.index_url(*source),
480                ))),
481            },
482            Self::Backend { backend, source } => Either::Right(Either::Right(Either::Right(
483                std::iter::once(backend.index_url(*source)),
484            ))),
485        }
486    }
487}
488
/// A concrete PyTorch build flavor, each served from its own index.
///
/// Variants are listed newest-first within each family: CPU, then CUDA, then ROCm, then Intel
/// XPU.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum TorchBackend {
    /// CPU-only builds.
    Cpu,
    /// CUDA 13.0 builds.
    Cu130,
    /// CUDA 12.9 builds.
    Cu129,
    /// CUDA 12.8 builds.
    Cu128,
    /// CUDA 12.6 builds.
    Cu126,
    /// CUDA 12.5 builds.
    Cu125,
    /// CUDA 12.4 builds.
    Cu124,
    /// CUDA 12.3 builds.
    Cu123,
    /// CUDA 12.2 builds.
    Cu122,
    /// CUDA 12.1 builds.
    Cu121,
    /// CUDA 12.0 builds.
    Cu120,
    /// CUDA 11.8 builds.
    Cu118,
    /// CUDA 11.7 builds.
    Cu117,
    /// CUDA 11.6 builds.
    Cu116,
    /// CUDA 11.5 builds.
    Cu115,
    /// CUDA 11.4 builds.
    Cu114,
    /// CUDA 11.3 builds.
    Cu113,
    /// CUDA 11.2 builds.
    Cu112,
    /// CUDA 11.1 builds.
    Cu111,
    /// CUDA 11.0 builds.
    Cu110,
    /// CUDA 10.2 builds.
    Cu102,
    /// CUDA 10.1 builds.
    Cu101,
    /// CUDA 10.0 builds.
    Cu100,
    /// CUDA 9.2 builds.
    Cu92,
    /// CUDA 9.1 builds.
    Cu91,
    /// CUDA 9.0 builds.
    Cu90,
    /// CUDA 8.0 builds.
    Cu80,
    /// ROCm 6.3 builds.
    Rocm63,
    /// ROCm 6.2.4 builds.
    Rocm624,
    /// ROCm 6.2 builds.
    Rocm62,
    /// ROCm 6.1 builds.
    Rocm61,
    /// ROCm 6.0 builds.
    Rocm60,
    /// ROCm 5.7 builds.
    Rocm57,
    /// ROCm 5.6 builds.
    Rocm56,
    /// ROCm 5.5 builds.
    Rocm55,
    /// ROCm 5.4.2 builds.
    Rocm542,
    /// ROCm 5.4 builds.
    Rocm54,
    /// ROCm 5.3 builds.
    Rocm53,
    /// ROCm 5.2 builds.
    Rocm52,
    /// ROCm 5.1.1 builds.
    Rocm511,
    /// ROCm 4.2 builds.
    Rocm42,
    /// ROCm 4.1 builds.
    Rocm41,
    /// ROCm 4.0.1 builds.
    Rocm401,
    /// Intel XPU builds.
    Xpu,
}
537
538impl TorchBackend {
539    /// Return the appropriate index URL for the given [`TorchBackend`].
540    fn index_url(self, source: TorchSource) -> &'static IndexUrl {
541        match self {
542            Self::Cpu => match source {
543                TorchSource::PyTorch => &PYTORCH_CPU_INDEX_URL,
544                TorchSource::Pyx => &PYX_CPU_INDEX_URL,
545            },
546            Self::Cu130 => match source {
547                TorchSource::PyTorch => &PYTORCH_CU130_INDEX_URL,
548                TorchSource::Pyx => &PYX_CU130_INDEX_URL,
549            },
550            Self::Cu129 => match source {
551                TorchSource::PyTorch => &PYTORCH_CU129_INDEX_URL,
552                TorchSource::Pyx => &PYX_CU129_INDEX_URL,
553            },
554            Self::Cu128 => match source {
555                TorchSource::PyTorch => &PYTORCH_CU128_INDEX_URL,
556                TorchSource::Pyx => &PYX_CU128_INDEX_URL,
557            },
558            Self::Cu126 => match source {
559                TorchSource::PyTorch => &PYTORCH_CU126_INDEX_URL,
560                TorchSource::Pyx => &PYX_CU126_INDEX_URL,
561            },
562            Self::Cu125 => match source {
563                TorchSource::PyTorch => &PYTORCH_CU125_INDEX_URL,
564                TorchSource::Pyx => &PYX_CU125_INDEX_URL,
565            },
566            Self::Cu124 => match source {
567                TorchSource::PyTorch => &PYTORCH_CU124_INDEX_URL,
568                TorchSource::Pyx => &PYX_CU124_INDEX_URL,
569            },
570            Self::Cu123 => match source {
571                TorchSource::PyTorch => &PYTORCH_CU123_INDEX_URL,
572                TorchSource::Pyx => &PYX_CU123_INDEX_URL,
573            },
574            Self::Cu122 => match source {
575                TorchSource::PyTorch => &PYTORCH_CU122_INDEX_URL,
576                TorchSource::Pyx => &PYX_CU122_INDEX_URL,
577            },
578            Self::Cu121 => match source {
579                TorchSource::PyTorch => &PYTORCH_CU121_INDEX_URL,
580                TorchSource::Pyx => &PYX_CU121_INDEX_URL,
581            },
582            Self::Cu120 => match source {
583                TorchSource::PyTorch => &PYTORCH_CU120_INDEX_URL,
584                TorchSource::Pyx => &PYX_CU120_INDEX_URL,
585            },
586            Self::Cu118 => match source {
587                TorchSource::PyTorch => &PYTORCH_CU118_INDEX_URL,
588                TorchSource::Pyx => &PYX_CU118_INDEX_URL,
589            },
590            Self::Cu117 => match source {
591                TorchSource::PyTorch => &PYTORCH_CU117_INDEX_URL,
592                TorchSource::Pyx => &PYX_CU117_INDEX_URL,
593            },
594            Self::Cu116 => match source {
595                TorchSource::PyTorch => &PYTORCH_CU116_INDEX_URL,
596                TorchSource::Pyx => &PYX_CU116_INDEX_URL,
597            },
598            Self::Cu115 => match source {
599                TorchSource::PyTorch => &PYTORCH_CU115_INDEX_URL,
600                TorchSource::Pyx => &PYX_CU115_INDEX_URL,
601            },
602            Self::Cu114 => match source {
603                TorchSource::PyTorch => &PYTORCH_CU114_INDEX_URL,
604                TorchSource::Pyx => &PYX_CU114_INDEX_URL,
605            },
606            Self::Cu113 => match source {
607                TorchSource::PyTorch => &PYTORCH_CU113_INDEX_URL,
608                TorchSource::Pyx => &PYX_CU113_INDEX_URL,
609            },
610            Self::Cu112 => match source {
611                TorchSource::PyTorch => &PYTORCH_CU112_INDEX_URL,
612                TorchSource::Pyx => &PYX_CU112_INDEX_URL,
613            },
614            Self::Cu111 => match source {
615                TorchSource::PyTorch => &PYTORCH_CU111_INDEX_URL,
616                TorchSource::Pyx => &PYX_CU111_INDEX_URL,
617            },
618            Self::Cu110 => match source {
619                TorchSource::PyTorch => &PYTORCH_CU110_INDEX_URL,
620                TorchSource::Pyx => &PYX_CU110_INDEX_URL,
621            },
622            Self::Cu102 => match source {
623                TorchSource::PyTorch => &PYTORCH_CU102_INDEX_URL,
624                TorchSource::Pyx => &PYX_CU102_INDEX_URL,
625            },
626            Self::Cu101 => match source {
627                TorchSource::PyTorch => &PYTORCH_CU101_INDEX_URL,
628                TorchSource::Pyx => &PYX_CU101_INDEX_URL,
629            },
630            Self::Cu100 => match source {
631                TorchSource::PyTorch => &PYTORCH_CU100_INDEX_URL,
632                TorchSource::Pyx => &PYX_CU100_INDEX_URL,
633            },
634            Self::Cu92 => match source {
635                TorchSource::PyTorch => &PYTORCH_CU92_INDEX_URL,
636                TorchSource::Pyx => &PYX_CU92_INDEX_URL,
637            },
638            Self::Cu91 => match source {
639                TorchSource::PyTorch => &PYTORCH_CU91_INDEX_URL,
640                TorchSource::Pyx => &PYX_CU91_INDEX_URL,
641            },
642            Self::Cu90 => match source {
643                TorchSource::PyTorch => &PYTORCH_CU90_INDEX_URL,
644                TorchSource::Pyx => &PYX_CU90_INDEX_URL,
645            },
646            Self::Cu80 => match source {
647                TorchSource::PyTorch => &PYTORCH_CU80_INDEX_URL,
648                TorchSource::Pyx => &PYX_CU80_INDEX_URL,
649            },
650            Self::Rocm63 => match source {
651                TorchSource::PyTorch => &PYTORCH_ROCM63_INDEX_URL,
652                TorchSource::Pyx => &PYX_ROCM63_INDEX_URL,
653            },
654            Self::Rocm624 => match source {
655                TorchSource::PyTorch => &PYTORCH_ROCM624_INDEX_URL,
656                TorchSource::Pyx => &PYX_ROCM624_INDEX_URL,
657            },
658            Self::Rocm62 => match source {
659                TorchSource::PyTorch => &PYTORCH_ROCM62_INDEX_URL,
660                TorchSource::Pyx => &PYX_ROCM62_INDEX_URL,
661            },
662            Self::Rocm61 => match source {
663                TorchSource::PyTorch => &PYTORCH_ROCM61_INDEX_URL,
664                TorchSource::Pyx => &PYX_ROCM61_INDEX_URL,
665            },
666            Self::Rocm60 => match source {
667                TorchSource::PyTorch => &PYTORCH_ROCM60_INDEX_URL,
668                TorchSource::Pyx => &PYX_ROCM60_INDEX_URL,
669            },
670            Self::Rocm57 => match source {
671                TorchSource::PyTorch => &PYTORCH_ROCM57_INDEX_URL,
672                TorchSource::Pyx => &PYX_ROCM57_INDEX_URL,
673            },
674            Self::Rocm56 => match source {
675                TorchSource::PyTorch => &PYTORCH_ROCM56_INDEX_URL,
676                TorchSource::Pyx => &PYX_ROCM56_INDEX_URL,
677            },
678            Self::Rocm55 => match source {
679                TorchSource::PyTorch => &PYTORCH_ROCM55_INDEX_URL,
680                TorchSource::Pyx => &PYX_ROCM55_INDEX_URL,
681            },
682            Self::Rocm542 => match source {
683                TorchSource::PyTorch => &PYTORCH_ROCM542_INDEX_URL,
684                TorchSource::Pyx => &PYX_ROCM542_INDEX_URL,
685            },
686            Self::Rocm54 => match source {
687                TorchSource::PyTorch => &PYTORCH_ROCM54_INDEX_URL,
688                TorchSource::Pyx => &PYX_ROCM54_INDEX_URL,
689            },
690            Self::Rocm53 => match source {
691                TorchSource::PyTorch => &PYTORCH_ROCM53_INDEX_URL,
692                TorchSource::Pyx => &PYX_ROCM53_INDEX_URL,
693            },
694            Self::Rocm52 => match source {
695                TorchSource::PyTorch => &PYTORCH_ROCM52_INDEX_URL,
696                TorchSource::Pyx => &PYX_ROCM52_INDEX_URL,
697            },
698            Self::Rocm511 => match source {
699                TorchSource::PyTorch => &PYTORCH_ROCM511_INDEX_URL,
700                TorchSource::Pyx => &PYX_ROCM511_INDEX_URL,
701            },
702            Self::Rocm42 => match source {
703                TorchSource::PyTorch => &PYTORCH_ROCM42_INDEX_URL,
704                TorchSource::Pyx => &PYX_ROCM42_INDEX_URL,
705            },
706            Self::Rocm41 => match source {
707                TorchSource::PyTorch => &PYTORCH_ROCM41_INDEX_URL,
708                TorchSource::Pyx => &PYX_ROCM41_INDEX_URL,
709            },
710            Self::Rocm401 => match source {
711                TorchSource::PyTorch => &PYTORCH_ROCM401_INDEX_URL,
712                TorchSource::Pyx => &PYX_ROCM401_INDEX_URL,
713            },
714            Self::Xpu => match source {
715                TorchSource::PyTorch => &PYTORCH_XPU_INDEX_URL,
716                TorchSource::Pyx => &PYX_XPU_INDEX_URL,
717            },
718        }
719    }
720
721    /// Extract a [`TorchBackend`] from an index URL.
722    pub fn from_index(index: &Url) -> Option<Self> {
723        let backend_identifier = if index.host_str() == Some("download.pytorch.org") {
724            // E.g., `https://download.pytorch.org/whl/cu124`
725            let mut path_segments = index.path_segments()?;
726            if path_segments.next() != Some("whl") {
727                return None;
728            }
729            path_segments.next()?
730        // TODO(zanieb): We should consolidate this with `is_known_url` somehow
731        } else if index.host_str() == PYX_API_BASE_URL.strip_prefix("https://") {
732            // E.g., `https://api.pyx.dev/simple/astral-sh/cu124`
733            let mut path_segments = index.path_segments()?;
734            if path_segments.next() != Some("simple") {
735                return None;
736            }
737            if path_segments.next() != Some("astral-sh") {
738                return None;
739            }
740            path_segments.next()?
741        } else {
742            return None;
743        };
744        Self::from_str(backend_identifier).ok()
745    }
746
747    /// Returns the CUDA [`Version`] for the given [`TorchBackend`].
748    pub fn cuda_version(&self) -> Option<Version> {
749        match self {
750            Self::Cpu => None,
751            Self::Cu130 => Some(Version::new([13, 0])),
752            Self::Cu129 => Some(Version::new([12, 9])),
753            Self::Cu128 => Some(Version::new([12, 8])),
754            Self::Cu126 => Some(Version::new([12, 6])),
755            Self::Cu125 => Some(Version::new([12, 5])),
756            Self::Cu124 => Some(Version::new([12, 4])),
757            Self::Cu123 => Some(Version::new([12, 3])),
758            Self::Cu122 => Some(Version::new([12, 2])),
759            Self::Cu121 => Some(Version::new([12, 1])),
760            Self::Cu120 => Some(Version::new([12, 0])),
761            Self::Cu118 => Some(Version::new([11, 8])),
762            Self::Cu117 => Some(Version::new([11, 7])),
763            Self::Cu116 => Some(Version::new([11, 6])),
764            Self::Cu115 => Some(Version::new([11, 5])),
765            Self::Cu114 => Some(Version::new([11, 4])),
766            Self::Cu113 => Some(Version::new([11, 3])),
767            Self::Cu112 => Some(Version::new([11, 2])),
768            Self::Cu111 => Some(Version::new([11, 1])),
769            Self::Cu110 => Some(Version::new([11, 0])),
770            Self::Cu102 => Some(Version::new([10, 2])),
771            Self::Cu101 => Some(Version::new([10, 1])),
772            Self::Cu100 => Some(Version::new([10, 0])),
773            Self::Cu92 => Some(Version::new([9, 2])),
774            Self::Cu91 => Some(Version::new([9, 1])),
775            Self::Cu90 => Some(Version::new([9, 0])),
776            Self::Cu80 => Some(Version::new([8, 0])),
777            Self::Rocm63 => None,
778            Self::Rocm624 => None,
779            Self::Rocm62 => None,
780            Self::Rocm61 => None,
781            Self::Rocm60 => None,
782            Self::Rocm57 => None,
783            Self::Rocm56 => None,
784            Self::Rocm55 => None,
785            Self::Rocm542 => None,
786            Self::Rocm54 => None,
787            Self::Rocm53 => None,
788            Self::Rocm52 => None,
789            Self::Rocm511 => None,
790            Self::Rocm42 => None,
791            Self::Rocm41 => None,
792            Self::Rocm401 => None,
793            Self::Xpu => None,
794        }
795    }
796
797    /// Returns the ROCM [`Version`] for the given [`TorchBackend`].
798    pub fn rocm_version(&self) -> Option<Version> {
799        match self {
800            Self::Cpu => None,
801            Self::Cu130 => None,
802            Self::Cu129 => None,
803            Self::Cu128 => None,
804            Self::Cu126 => None,
805            Self::Cu125 => None,
806            Self::Cu124 => None,
807            Self::Cu123 => None,
808            Self::Cu122 => None,
809            Self::Cu121 => None,
810            Self::Cu120 => None,
811            Self::Cu118 => None,
812            Self::Cu117 => None,
813            Self::Cu116 => None,
814            Self::Cu115 => None,
815            Self::Cu114 => None,
816            Self::Cu113 => None,
817            Self::Cu112 => None,
818            Self::Cu111 => None,
819            Self::Cu110 => None,
820            Self::Cu102 => None,
821            Self::Cu101 => None,
822            Self::Cu100 => None,
823            Self::Cu92 => None,
824            Self::Cu91 => None,
825            Self::Cu90 => None,
826            Self::Cu80 => None,
827            Self::Rocm63 => Some(Version::new([6, 3])),
828            Self::Rocm624 => Some(Version::new([6, 2, 4])),
829            Self::Rocm62 => Some(Version::new([6, 2])),
830            Self::Rocm61 => Some(Version::new([6, 1])),
831            Self::Rocm60 => Some(Version::new([6, 0])),
832            Self::Rocm57 => Some(Version::new([5, 7])),
833            Self::Rocm56 => Some(Version::new([5, 6])),
834            Self::Rocm55 => Some(Version::new([5, 5])),
835            Self::Rocm542 => Some(Version::new([5, 4, 2])),
836            Self::Rocm54 => Some(Version::new([5, 4])),
837            Self::Rocm53 => Some(Version::new([5, 3])),
838            Self::Rocm52 => Some(Version::new([5, 2])),
839            Self::Rocm511 => Some(Version::new([5, 1, 1])),
840            Self::Rocm42 => Some(Version::new([4, 2])),
841            Self::Rocm41 => Some(Version::new([4, 1])),
842            Self::Rocm401 => Some(Version::new([4, 0, 1])),
843            Self::Xpu => None,
844        }
845    }
846}
847
848impl FromStr for TorchBackend {
849    type Err = String;
850
851    fn from_str(s: &str) -> Result<Self, Self::Err> {
852        match s {
853            "cpu" => Ok(Self::Cpu),
854            "cu130" => Ok(Self::Cu130),
855            "cu129" => Ok(Self::Cu129),
856            "cu128" => Ok(Self::Cu128),
857            "cu126" => Ok(Self::Cu126),
858            "cu125" => Ok(Self::Cu125),
859            "cu124" => Ok(Self::Cu124),
860            "cu123" => Ok(Self::Cu123),
861            "cu122" => Ok(Self::Cu122),
862            "cu121" => Ok(Self::Cu121),
863            "cu120" => Ok(Self::Cu120),
864            "cu118" => Ok(Self::Cu118),
865            "cu117" => Ok(Self::Cu117),
866            "cu116" => Ok(Self::Cu116),
867            "cu115" => Ok(Self::Cu115),
868            "cu114" => Ok(Self::Cu114),
869            "cu113" => Ok(Self::Cu113),
870            "cu112" => Ok(Self::Cu112),
871            "cu111" => Ok(Self::Cu111),
872            "cu110" => Ok(Self::Cu110),
873            "cu102" => Ok(Self::Cu102),
874            "cu101" => Ok(Self::Cu101),
875            "cu100" => Ok(Self::Cu100),
876            "cu92" => Ok(Self::Cu92),
877            "cu91" => Ok(Self::Cu91),
878            "cu90" => Ok(Self::Cu90),
879            "cu80" => Ok(Self::Cu80),
880            "rocm6.3" => Ok(Self::Rocm63),
881            "rocm6.2.4" => Ok(Self::Rocm624),
882            "rocm6.2" => Ok(Self::Rocm62),
883            "rocm6.1" => Ok(Self::Rocm61),
884            "rocm6.0" => Ok(Self::Rocm60),
885            "rocm5.7" => Ok(Self::Rocm57),
886            "rocm5.6" => Ok(Self::Rocm56),
887            "rocm5.5" => Ok(Self::Rocm55),
888            "rocm5.4.2" => Ok(Self::Rocm542),
889            "rocm5.4" => Ok(Self::Rocm54),
890            "rocm5.3" => Ok(Self::Rocm53),
891            "rocm5.2" => Ok(Self::Rocm52),
892            "rocm5.1.1" => Ok(Self::Rocm511),
893            "rocm4.2" => Ok(Self::Rocm42),
894            "rocm4.1" => Ok(Self::Rocm41),
895            "rocm4.0.1" => Ok(Self::Rocm401),
896            "xpu" => Ok(Self::Xpu),
897            _ => Err(format!("Unknown PyTorch backend: {s}")),
898        }
899    }
900}
901
902/// Linux CUDA driver versions and the corresponding CUDA versions.
903///
904/// See: <https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_cb.py#L150-L213>
905static LINUX_CUDA_DRIVERS: LazyLock<[(TorchBackend, Version); 26]> = LazyLock::new(|| {
906    [
907        // Table 2 from
908        // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
909        (TorchBackend::Cu130, Version::new([580])),
910        (TorchBackend::Cu129, Version::new([525, 60, 13])),
911        (TorchBackend::Cu128, Version::new([525, 60, 13])),
912        (TorchBackend::Cu126, Version::new([525, 60, 13])),
913        (TorchBackend::Cu125, Version::new([525, 60, 13])),
914        (TorchBackend::Cu124, Version::new([525, 60, 13])),
915        (TorchBackend::Cu123, Version::new([525, 60, 13])),
916        (TorchBackend::Cu122, Version::new([525, 60, 13])),
917        (TorchBackend::Cu121, Version::new([525, 60, 13])),
918        (TorchBackend::Cu120, Version::new([525, 60, 13])),
919        // Table 2 from
920        // https://docs.nvidia.com/cuda/archive/11.8.0/cuda-toolkit-release-notes/index.html
921        (TorchBackend::Cu118, Version::new([450, 80, 2])),
922        (TorchBackend::Cu117, Version::new([450, 80, 2])),
923        (TorchBackend::Cu116, Version::new([450, 80, 2])),
924        (TorchBackend::Cu115, Version::new([450, 80, 2])),
925        (TorchBackend::Cu114, Version::new([450, 80, 2])),
926        (TorchBackend::Cu113, Version::new([450, 80, 2])),
927        (TorchBackend::Cu112, Version::new([450, 80, 2])),
928        (TorchBackend::Cu111, Version::new([450, 80, 2])),
929        (TorchBackend::Cu110, Version::new([450, 36, 6])),
930        // Table 1 from
931        // https://docs.nvidia.com/cuda/archive/10.2/cuda-toolkit-release-notes/index.html
932        (TorchBackend::Cu102, Version::new([440, 33])),
933        (TorchBackend::Cu101, Version::new([418, 39])),
934        (TorchBackend::Cu100, Version::new([410, 48])),
935        (TorchBackend::Cu92, Version::new([396, 26])),
936        (TorchBackend::Cu91, Version::new([390, 46])),
937        (TorchBackend::Cu90, Version::new([384, 81])),
938        (TorchBackend::Cu80, Version::new([375, 26])),
939    ]
940});
941
942/// Windows CUDA driver versions and the corresponding CUDA versions.
943///
944/// See: <https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_cb.py#L150-L213>
945static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 26]> = LazyLock::new(|| {
946    [
947        // Table 2 from
948        // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
949        (TorchBackend::Cu130, Version::new([580])),
950        (TorchBackend::Cu129, Version::new([528, 33])),
951        (TorchBackend::Cu128, Version::new([528, 33])),
952        (TorchBackend::Cu126, Version::new([528, 33])),
953        (TorchBackend::Cu125, Version::new([528, 33])),
954        (TorchBackend::Cu124, Version::new([528, 33])),
955        (TorchBackend::Cu123, Version::new([528, 33])),
956        (TorchBackend::Cu122, Version::new([528, 33])),
957        (TorchBackend::Cu121, Version::new([528, 33])),
958        (TorchBackend::Cu120, Version::new([528, 33])),
959        // Table 2 from
960        // https://docs.nvidia.com/cuda/archive/11.8.0/cuda-toolkit-release-notes/index.html
961        (TorchBackend::Cu118, Version::new([452, 39])),
962        (TorchBackend::Cu117, Version::new([452, 39])),
963        (TorchBackend::Cu116, Version::new([452, 39])),
964        (TorchBackend::Cu115, Version::new([452, 39])),
965        (TorchBackend::Cu114, Version::new([452, 39])),
966        (TorchBackend::Cu113, Version::new([452, 39])),
967        (TorchBackend::Cu112, Version::new([452, 39])),
968        (TorchBackend::Cu111, Version::new([452, 39])),
969        (TorchBackend::Cu110, Version::new([451, 22])),
970        // Table 1 from
971        // https://docs.nvidia.com/cuda/archive/10.2/cuda-toolkit-release-notes/index.html
972        (TorchBackend::Cu102, Version::new([441, 22])),
973        (TorchBackend::Cu101, Version::new([418, 96])),
974        (TorchBackend::Cu100, Version::new([411, 31])),
975        (TorchBackend::Cu92, Version::new([398, 26])),
976        (TorchBackend::Cu91, Version::new([391, 29])),
977        (TorchBackend::Cu90, Version::new([385, 54])),
978        (TorchBackend::Cu80, Version::new([376, 51])),
979    ]
980});
981
982/// Linux AMD GPU architectures and the corresponding PyTorch backends.
983///
984/// These were inferred by running the following snippet for each ROCm version:
985///
986/// ```python
987/// import torch
988///
989/// print(torch.cuda.get_arch_list())
990/// ```
991///
992/// AMD also provides a compatibility matrix: <https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html>;
993/// however, this list includes a broader array of GPUs than those in the matrix.
994static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 44]> =
995    LazyLock::new(|| {
996        [
997            // ROCm 6.3
998            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx900),
999            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx906),
1000            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx908),
1001            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx90a),
1002            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx942),
1003            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1030),
1004            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1100),
1005            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1101),
1006            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1102),
1007            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1200),
1008            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1201),
1009            // ROCm 6.2.4
1010            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx900),
1011            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx906),
1012            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx908),
1013            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx90a),
1014            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx942),
1015            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1030),
1016            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1100),
1017            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1101),
1018            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1102),
1019            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1200),
1020            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1201),
1021            // ROCm 6.2
1022            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx900),
1023            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx906),
1024            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx908),
1025            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx90a),
1026            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx1030),
1027            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx1100),
1028            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx942),
1029            // ROCm 6.1
1030            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx900),
1031            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx906),
1032            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx908),
1033            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx90a),
1034            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx942),
1035            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1030),
1036            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1100),
1037            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1101),
1038            // ROCm 6.0
1039            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx900),
1040            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx906),
1041            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx908),
1042            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx90a),
1043            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx1030),
1044            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx1100),
1045            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx942),
1046        ]
1047    });
1048
1049static PYTORCH_CPU_INDEX_URL: LazyLock<IndexUrl> =
1050    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cpu").unwrap());
1051static PYTORCH_CU130_INDEX_URL: LazyLock<IndexUrl> =
1052    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu130").unwrap());
1053static PYTORCH_CU129_INDEX_URL: LazyLock<IndexUrl> =
1054    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu129").unwrap());
1055static PYTORCH_CU128_INDEX_URL: LazyLock<IndexUrl> =
1056    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu128").unwrap());
1057static PYTORCH_CU126_INDEX_URL: LazyLock<IndexUrl> =
1058    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu126").unwrap());
1059static PYTORCH_CU125_INDEX_URL: LazyLock<IndexUrl> =
1060    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu125").unwrap());
1061static PYTORCH_CU124_INDEX_URL: LazyLock<IndexUrl> =
1062    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu124").unwrap());
1063static PYTORCH_CU123_INDEX_URL: LazyLock<IndexUrl> =
1064    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu123").unwrap());
1065static PYTORCH_CU122_INDEX_URL: LazyLock<IndexUrl> =
1066    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu122").unwrap());
1067static PYTORCH_CU121_INDEX_URL: LazyLock<IndexUrl> =
1068    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu121").unwrap());
1069static PYTORCH_CU120_INDEX_URL: LazyLock<IndexUrl> =
1070    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu120").unwrap());
1071static PYTORCH_CU118_INDEX_URL: LazyLock<IndexUrl> =
1072    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu118").unwrap());
1073static PYTORCH_CU117_INDEX_URL: LazyLock<IndexUrl> =
1074    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu117").unwrap());
1075static PYTORCH_CU116_INDEX_URL: LazyLock<IndexUrl> =
1076    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu116").unwrap());
1077static PYTORCH_CU115_INDEX_URL: LazyLock<IndexUrl> =
1078    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu115").unwrap());
1079static PYTORCH_CU114_INDEX_URL: LazyLock<IndexUrl> =
1080    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu114").unwrap());
1081static PYTORCH_CU113_INDEX_URL: LazyLock<IndexUrl> =
1082    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu113").unwrap());
1083static PYTORCH_CU112_INDEX_URL: LazyLock<IndexUrl> =
1084    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu112").unwrap());
1085static PYTORCH_CU111_INDEX_URL: LazyLock<IndexUrl> =
1086    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu111").unwrap());
1087static PYTORCH_CU110_INDEX_URL: LazyLock<IndexUrl> =
1088    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu110").unwrap());
1089static PYTORCH_CU102_INDEX_URL: LazyLock<IndexUrl> =
1090    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu102").unwrap());
1091static PYTORCH_CU101_INDEX_URL: LazyLock<IndexUrl> =
1092    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu101").unwrap());
1093static PYTORCH_CU100_INDEX_URL: LazyLock<IndexUrl> =
1094    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu100").unwrap());
1095static PYTORCH_CU92_INDEX_URL: LazyLock<IndexUrl> =
1096    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu92").unwrap());
1097static PYTORCH_CU91_INDEX_URL: LazyLock<IndexUrl> =
1098    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu91").unwrap());
1099static PYTORCH_CU90_INDEX_URL: LazyLock<IndexUrl> =
1100    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu90").unwrap());
1101static PYTORCH_CU80_INDEX_URL: LazyLock<IndexUrl> =
1102    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu80").unwrap());
1103static PYTORCH_ROCM63_INDEX_URL: LazyLock<IndexUrl> =
1104    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.3").unwrap());
1105static PYTORCH_ROCM624_INDEX_URL: LazyLock<IndexUrl> =
1106    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.2.4").unwrap());
1107static PYTORCH_ROCM62_INDEX_URL: LazyLock<IndexUrl> =
1108    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.2").unwrap());
1109static PYTORCH_ROCM61_INDEX_URL: LazyLock<IndexUrl> =
1110    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.1").unwrap());
1111static PYTORCH_ROCM60_INDEX_URL: LazyLock<IndexUrl> =
1112    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.0").unwrap());
1113static PYTORCH_ROCM57_INDEX_URL: LazyLock<IndexUrl> =
1114    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.7").unwrap());
1115static PYTORCH_ROCM56_INDEX_URL: LazyLock<IndexUrl> =
1116    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.6").unwrap());
1117static PYTORCH_ROCM55_INDEX_URL: LazyLock<IndexUrl> =
1118    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.5").unwrap());
1119static PYTORCH_ROCM542_INDEX_URL: LazyLock<IndexUrl> =
1120    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.4.2").unwrap());
1121static PYTORCH_ROCM54_INDEX_URL: LazyLock<IndexUrl> =
1122    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.4").unwrap());
1123static PYTORCH_ROCM53_INDEX_URL: LazyLock<IndexUrl> =
1124    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.3").unwrap());
1125static PYTORCH_ROCM52_INDEX_URL: LazyLock<IndexUrl> =
1126    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.2").unwrap());
1127static PYTORCH_ROCM511_INDEX_URL: LazyLock<IndexUrl> =
1128    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.1.1").unwrap());
1129static PYTORCH_ROCM42_INDEX_URL: LazyLock<IndexUrl> =
1130    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.2").unwrap());
1131static PYTORCH_ROCM41_INDEX_URL: LazyLock<IndexUrl> =
1132    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.1").unwrap());
1133static PYTORCH_ROCM401_INDEX_URL: LazyLock<IndexUrl> =
1134    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.0.1").unwrap());
1135static PYTORCH_XPU_INDEX_URL: LazyLock<IndexUrl> =
1136    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/xpu").unwrap());
1137
1138static PYX_API_BASE_URL: LazyLock<Cow<'static, str>> = LazyLock::new(|| {
1139    std::env::var(EnvVars::PYX_API_URL)
1140        .map(Cow::Owned)
1141        .unwrap_or(Cow::Borrowed("https://api.pyx.dev"))
1142});
1143static PYX_CPU_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1144    let api_base_url = &*PYX_API_BASE_URL;
1145    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cpu")).unwrap()
1146});
1147static PYX_CU130_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1148    let api_base_url = &*PYX_API_BASE_URL;
1149    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu130")).unwrap()
1150});
1151static PYX_CU129_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1152    let api_base_url = &*PYX_API_BASE_URL;
1153    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu129")).unwrap()
1154});
1155static PYX_CU128_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1156    let api_base_url = &*PYX_API_BASE_URL;
1157    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu128")).unwrap()
1158});
1159static PYX_CU126_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1160    let api_base_url = &*PYX_API_BASE_URL;
1161    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu126")).unwrap()
1162});
1163static PYX_CU125_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1164    let api_base_url = &*PYX_API_BASE_URL;
1165    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu125")).unwrap()
1166});
1167static PYX_CU124_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1168    let api_base_url = &*PYX_API_BASE_URL;
1169    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu124")).unwrap()
1170});
1171static PYX_CU123_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1172    let api_base_url = &*PYX_API_BASE_URL;
1173    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu123")).unwrap()
1174});
1175static PYX_CU122_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1176    let api_base_url = &*PYX_API_BASE_URL;
1177    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu122")).unwrap()
1178});
1179static PYX_CU121_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1180    let api_base_url = &*PYX_API_BASE_URL;
1181    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu121")).unwrap()
1182});
1183static PYX_CU120_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1184    let api_base_url = &*PYX_API_BASE_URL;
1185    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu120")).unwrap()
1186});
1187static PYX_CU118_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1188    let api_base_url = &*PYX_API_BASE_URL;
1189    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu118")).unwrap()
1190});
1191static PYX_CU117_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1192    let api_base_url = &*PYX_API_BASE_URL;
1193    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu117")).unwrap()
1194});
1195static PYX_CU116_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1196    let api_base_url = &*PYX_API_BASE_URL;
1197    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu116")).unwrap()
1198});
1199static PYX_CU115_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1200    let api_base_url = &*PYX_API_BASE_URL;
1201    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu115")).unwrap()
1202});
1203static PYX_CU114_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1204    let api_base_url = &*PYX_API_BASE_URL;
1205    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu114")).unwrap()
1206});
1207static PYX_CU113_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1208    let api_base_url = &*PYX_API_BASE_URL;
1209    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu113")).unwrap()
1210});
1211static PYX_CU112_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1212    let api_base_url = &*PYX_API_BASE_URL;
1213    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu112")).unwrap()
1214});
1215static PYX_CU111_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1216    let api_base_url = &*PYX_API_BASE_URL;
1217    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu111")).unwrap()
1218});
1219static PYX_CU110_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1220    let api_base_url = &*PYX_API_BASE_URL;
1221    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu110")).unwrap()
1222});
1223static PYX_CU102_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1224    let api_base_url = &*PYX_API_BASE_URL;
1225    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu102")).unwrap()
1226});
1227static PYX_CU101_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1228    let api_base_url = &*PYX_API_BASE_URL;
1229    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu101")).unwrap()
1230});
1231static PYX_CU100_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1232    let api_base_url = &*PYX_API_BASE_URL;
1233    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu100")).unwrap()
1234});
1235static PYX_CU92_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1236    let api_base_url = &*PYX_API_BASE_URL;
1237    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu92")).unwrap()
1238});
1239static PYX_CU91_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1240    let api_base_url = &*PYX_API_BASE_URL;
1241    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu91")).unwrap()
1242});
1243static PYX_CU90_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1244    let api_base_url = &*PYX_API_BASE_URL;
1245    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu90")).unwrap()
1246});
1247static PYX_CU80_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1248    let api_base_url = &*PYX_API_BASE_URL;
1249    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu80")).unwrap()
1250});
1251static PYX_ROCM63_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1252    let api_base_url = &*PYX_API_BASE_URL;
1253    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.3")).unwrap()
1254});
1255static PYX_ROCM624_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1256    let api_base_url = &*PYX_API_BASE_URL;
1257    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.2.4")).unwrap()
1258});
1259static PYX_ROCM62_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1260    let api_base_url = &*PYX_API_BASE_URL;
1261    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.2")).unwrap()
1262});
1263static PYX_ROCM61_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1264    let api_base_url = &*PYX_API_BASE_URL;
1265    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.1")).unwrap()
1266});
1267static PYX_ROCM60_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1268    let api_base_url = &*PYX_API_BASE_URL;
1269    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.0")).unwrap()
1270});
1271static PYX_ROCM57_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1272    let api_base_url = &*PYX_API_BASE_URL;
1273    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.7")).unwrap()
1274});
1275static PYX_ROCM56_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1276    let api_base_url = &*PYX_API_BASE_URL;
1277    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.6")).unwrap()
1278});
1279static PYX_ROCM55_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1280    let api_base_url = &*PYX_API_BASE_URL;
1281    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.5")).unwrap()
1282});
1283static PYX_ROCM542_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1284    let api_base_url = &*PYX_API_BASE_URL;
1285    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.4.2")).unwrap()
1286});
1287static PYX_ROCM54_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1288    let api_base_url = &*PYX_API_BASE_URL;
1289    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.4")).unwrap()
1290});
1291static PYX_ROCM53_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1292    let api_base_url = &*PYX_API_BASE_URL;
1293    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.3")).unwrap()
1294});
1295static PYX_ROCM52_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1296    let api_base_url = &*PYX_API_BASE_URL;
1297    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.2")).unwrap()
1298});
1299static PYX_ROCM511_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1300    let api_base_url = &*PYX_API_BASE_URL;
1301    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.1.1")).unwrap()
1302});
1303static PYX_ROCM42_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1304    let api_base_url = &*PYX_API_BASE_URL;
1305    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm4.2")).unwrap()
1306});
1307static PYX_ROCM41_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1308    let api_base_url = &*PYX_API_BASE_URL;
1309    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm4.1")).unwrap()
1310});
1311static PYX_ROCM401_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1312    let api_base_url = &*PYX_API_BASE_URL;
1313    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm4.0.1")).unwrap()
1314});
1315static PYX_XPU_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1316    let api_base_url = &*PYX_API_BASE_URL;
1317    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/xpu")).unwrap()
1318});