uv_torch/
backend.rs

1//! `uv-torch` is a library for determining the appropriate PyTorch index based on the operating
2//! system and CUDA driver version.
3//!
//! This library is derived from `light-the-torch` by Philip Meier, which is available under the
//! following BSD 3-Clause license:
6//!
7//! ```text
8//! BSD 3-Clause License
9//!
10//! Copyright (c) 2020, Philip Meier
11//! All rights reserved.
12//!
13//! Redistribution and use in source and binary forms, with or without
14//! modification, are permitted provided that the following conditions are met:
15//!
16//! 1. Redistributions of source code must retain the above copyright notice, this
17//!    list of conditions and the following disclaimer.
18//!
19//! 2. Redistributions in binary form must reproduce the above copyright notice,
20//!    this list of conditions and the following disclaimer in the documentation
21//!    and/or other materials provided with the distribution.
22//!
23//! 3. Neither the name of the copyright holder nor the names of its
24//!    contributors may be used to endorse or promote products derived from
25//!    this software without specific prior written permission.
26//!
27//! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
28//! AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
29//! IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
30//! DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
31//! FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
32//! DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
33//! SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
34//! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
35//! OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
36//! OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37//! ```
38
39use std::borrow::Cow;
40use std::str::FromStr;
41use std::sync::LazyLock;
42
43use either::Either;
44use url::Url;
45
46use uv_distribution_types::IndexUrl;
47use uv_normalize::PackageName;
48use uv_pep440::Version;
49use uv_platform_tags::Os;
50use uv_static::EnvVars;
51
52use crate::{Accelerator, AcceleratorError, AmdGpuArchitecture};
53
/// The strategy to use when determining the appropriate PyTorch index.
// This is the user-facing setting (deserialized via serde, and parsed via clap when the `clap`
// feature is enabled). It is resolved into a concrete `TorchStrategy` by
// `TorchStrategy::from_mode`.
//
// NOTE(review): the `///` comments on this type and its variants are presumably also surfaced
// by the clap `ValueEnum` and schemars derives (help text / schema descriptions) — verify before
// rewording them.
#[derive(Debug, Copy, Clone, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(rename_all = "kebab-case")]
pub enum TorchMode {
    // NOTE: despite the text below, `Auto` also detects AMD (ROCm) and Intel (XPU) GPUs, not only
    // CUDA drivers; see the `TorchMode::Auto` arm of `TorchStrategy::from_mode`.
    /// Select the appropriate PyTorch index based on the operating system and CUDA driver version.
    Auto,
    /// Use the CPU-only PyTorch index.
    Cpu,
    // CUDA variants get their serialized names from `rename_all = "kebab-case"`
    // (e.g., `Cu128` -> `cu128`). They are listed newest first, mirroring `TorchBackend`.
    /// Use the PyTorch index for CUDA 13.0.
    Cu130,
    /// Use the PyTorch index for CUDA 12.9.
    Cu129,
    /// Use the PyTorch index for CUDA 12.8.
    Cu128,
    /// Use the PyTorch index for CUDA 12.6.
    Cu126,
    /// Use the PyTorch index for CUDA 12.5.
    Cu125,
    /// Use the PyTorch index for CUDA 12.4.
    Cu124,
    /// Use the PyTorch index for CUDA 12.3.
    Cu123,
    /// Use the PyTorch index for CUDA 12.2.
    Cu122,
    /// Use the PyTorch index for CUDA 12.1.
    Cu121,
    /// Use the PyTorch index for CUDA 12.0.
    Cu120,
    /// Use the PyTorch index for CUDA 11.8.
    Cu118,
    /// Use the PyTorch index for CUDA 11.7.
    Cu117,
    /// Use the PyTorch index for CUDA 11.6.
    Cu116,
    /// Use the PyTorch index for CUDA 11.5.
    Cu115,
    /// Use the PyTorch index for CUDA 11.4.
    Cu114,
    /// Use the PyTorch index for CUDA 11.3.
    Cu113,
    /// Use the PyTorch index for CUDA 11.2.
    Cu112,
    /// Use the PyTorch index for CUDA 11.1.
    Cu111,
    /// Use the PyTorch index for CUDA 11.0.
    Cu110,
    /// Use the PyTorch index for CUDA 10.2.
    Cu102,
    /// Use the PyTorch index for CUDA 10.1.
    Cu101,
    /// Use the PyTorch index for CUDA 10.0.
    Cu100,
    /// Use the PyTorch index for CUDA 9.2.
    Cu92,
    /// Use the PyTorch index for CUDA 9.1.
    Cu91,
    /// Use the PyTorch index for CUDA 9.0.
    Cu90,
    /// Use the PyTorch index for CUDA 8.0.
    Cu80,
    // ROCm variants need explicit serde/clap names: kebab-case alone would produce `rocm64`,
    // whereas the index identifiers keep the dots of the ROCm version (e.g., `rocm6.4`).
    /// Use the PyTorch index for ROCm 6.4.
    #[serde(rename = "rocm6.4")]
    #[cfg_attr(feature = "clap", clap(name = "rocm6.4"))]
    Rocm64,
    /// Use the PyTorch index for ROCm 6.3.
    #[serde(rename = "rocm6.3")]
    #[cfg_attr(feature = "clap", clap(name = "rocm6.3"))]
    Rocm63,
    /// Use the PyTorch index for ROCm 6.2.4.
    #[serde(rename = "rocm6.2.4")]
    #[cfg_attr(feature = "clap", clap(name = "rocm6.2.4"))]
    Rocm624,
    /// Use the PyTorch index for ROCm 6.2.
    #[serde(rename = "rocm6.2")]
    #[cfg_attr(feature = "clap", clap(name = "rocm6.2"))]
    Rocm62,
    /// Use the PyTorch index for ROCm 6.1.
    #[serde(rename = "rocm6.1")]
    #[cfg_attr(feature = "clap", clap(name = "rocm6.1"))]
    Rocm61,
    /// Use the PyTorch index for ROCm 6.0.
    #[serde(rename = "rocm6.0")]
    #[cfg_attr(feature = "clap", clap(name = "rocm6.0"))]
    Rocm60,
    /// Use the PyTorch index for ROCm 5.7.
    #[serde(rename = "rocm5.7")]
    #[cfg_attr(feature = "clap", clap(name = "rocm5.7"))]
    Rocm57,
    /// Use the PyTorch index for ROCm 5.6.
    #[serde(rename = "rocm5.6")]
    #[cfg_attr(feature = "clap", clap(name = "rocm5.6"))]
    Rocm56,
    /// Use the PyTorch index for ROCm 5.5.
    #[serde(rename = "rocm5.5")]
    #[cfg_attr(feature = "clap", clap(name = "rocm5.5"))]
    Rocm55,
    /// Use the PyTorch index for ROCm 5.4.2.
    #[serde(rename = "rocm5.4.2")]
    #[cfg_attr(feature = "clap", clap(name = "rocm5.4.2"))]
    Rocm542,
    /// Use the PyTorch index for ROCm 5.4.
    #[serde(rename = "rocm5.4")]
    #[cfg_attr(feature = "clap", clap(name = "rocm5.4"))]
    Rocm54,
    /// Use the PyTorch index for ROCm 5.3.
    #[serde(rename = "rocm5.3")]
    #[cfg_attr(feature = "clap", clap(name = "rocm5.3"))]
    Rocm53,
    /// Use the PyTorch index for ROCm 5.2.
    #[serde(rename = "rocm5.2")]
    #[cfg_attr(feature = "clap", clap(name = "rocm5.2"))]
    Rocm52,
    /// Use the PyTorch index for ROCm 5.1.1.
    #[serde(rename = "rocm5.1.1")]
    #[cfg_attr(feature = "clap", clap(name = "rocm5.1.1"))]
    Rocm511,
    /// Use the PyTorch index for ROCm 4.2.
    #[serde(rename = "rocm4.2")]
    #[cfg_attr(feature = "clap", clap(name = "rocm4.2"))]
    Rocm42,
    /// Use the PyTorch index for ROCm 4.1.
    #[serde(rename = "rocm4.1")]
    #[cfg_attr(feature = "clap", clap(name = "rocm4.1"))]
    Rocm41,
    /// Use the PyTorch index for ROCm 4.0.1.
    #[serde(rename = "rocm4.0.1")]
    #[cfg_attr(feature = "clap", clap(name = "rocm4.0.1"))]
    Rocm401,
    /// Use the PyTorch index for Intel XPU.
    Xpu,
}
187
/// The provider whose index URLs are used when resolving a [`TorchBackend`] to an index.
///
/// Defaults to [`TorchSource::PyTorch`]. With [`TorchSource::Pyx`], the set of packages that a
/// [`TorchStrategy`] applies to is extended beyond the core PyTorch packages (e.g., `flash-attn`,
/// `vllm`).
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
pub enum TorchSource {
    /// Download PyTorch builds from the official PyTorch index.
    #[default]
    PyTorch,
    /// Download PyTorch builds from the pyx index.
    Pyx,
}
196
/// The strategy to use when determining the appropriate PyTorch index.
///
/// Whereas [`TorchMode`] is the user-facing setting, a [`TorchStrategy`] carries the concrete
/// inputs needed to choose index URLs: either a detected accelerator together with the target
/// [`Os`], or an explicitly selected [`TorchBackend`]. It is built by
/// [`TorchStrategy::from_mode`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum TorchStrategy {
    /// Select the appropriate PyTorch index based on the operating system and CUDA driver version (e.g., `550.144.03`).
    Cuda {
        /// The target operating system.
        os: Os,
        /// The detected CUDA driver version; compared against a per-backend minimum driver
        /// version when selecting indexes.
        driver_version: Version,
        /// Which provider's index URLs to return.
        source: TorchSource,
    },
    /// Select the appropriate PyTorch index based on the operating system and AMD GPU architecture (e.g., `gfx1100`).
    Amd {
        /// The target operating system.
        os: Os,
        /// The detected AMD GPU architecture; matched exactly against each backend's supported
        /// architecture when selecting indexes.
        gpu_architecture: AmdGpuArchitecture,
        /// Which provider's index URLs to return.
        source: TorchSource,
    },
    /// Select the appropriate PyTorch index based on the operating system and Intel GPU presence.
    Xpu {
        /// The target operating system.
        os: Os,
        /// Which provider's index URLs to return.
        source: TorchSource,
    },
    /// Use the specified PyTorch index.
    Backend {
        /// The explicitly selected backend.
        backend: TorchBackend,
        /// Which provider's index URLs to return.
        source: TorchSource,
    },
}
220
221impl TorchStrategy {
222    /// Determine the [`TorchStrategy`] from the given [`TorchMode`], [`Os`], and [`Accelerator`].
223    pub fn from_mode(
224        mode: TorchMode,
225        source: TorchSource,
226        os: &Os,
227    ) -> Result<Self, AcceleratorError> {
228        let backend = match mode {
229            TorchMode::Auto => match Accelerator::detect()? {
230                Some(Accelerator::Cuda { driver_version }) => {
231                    return Ok(Self::Cuda {
232                        os: os.clone(),
233                        driver_version: driver_version.clone(),
234                        source,
235                    });
236                }
237                Some(Accelerator::Amd { gpu_architecture }) => {
238                    return Ok(Self::Amd {
239                        os: os.clone(),
240                        gpu_architecture,
241                        source,
242                    });
243                }
244                Some(Accelerator::Xpu) => {
245                    return Ok(Self::Xpu {
246                        os: os.clone(),
247                        source,
248                    });
249                }
250                None => TorchBackend::Cpu,
251            },
252            TorchMode::Cpu => TorchBackend::Cpu,
253            TorchMode::Cu130 => TorchBackend::Cu130,
254            TorchMode::Cu129 => TorchBackend::Cu129,
255            TorchMode::Cu128 => TorchBackend::Cu128,
256            TorchMode::Cu126 => TorchBackend::Cu126,
257            TorchMode::Cu125 => TorchBackend::Cu125,
258            TorchMode::Cu124 => TorchBackend::Cu124,
259            TorchMode::Cu123 => TorchBackend::Cu123,
260            TorchMode::Cu122 => TorchBackend::Cu122,
261            TorchMode::Cu121 => TorchBackend::Cu121,
262            TorchMode::Cu120 => TorchBackend::Cu120,
263            TorchMode::Cu118 => TorchBackend::Cu118,
264            TorchMode::Cu117 => TorchBackend::Cu117,
265            TorchMode::Cu116 => TorchBackend::Cu116,
266            TorchMode::Cu115 => TorchBackend::Cu115,
267            TorchMode::Cu114 => TorchBackend::Cu114,
268            TorchMode::Cu113 => TorchBackend::Cu113,
269            TorchMode::Cu112 => TorchBackend::Cu112,
270            TorchMode::Cu111 => TorchBackend::Cu111,
271            TorchMode::Cu110 => TorchBackend::Cu110,
272            TorchMode::Cu102 => TorchBackend::Cu102,
273            TorchMode::Cu101 => TorchBackend::Cu101,
274            TorchMode::Cu100 => TorchBackend::Cu100,
275            TorchMode::Cu92 => TorchBackend::Cu92,
276            TorchMode::Cu91 => TorchBackend::Cu91,
277            TorchMode::Cu90 => TorchBackend::Cu90,
278            TorchMode::Cu80 => TorchBackend::Cu80,
279            TorchMode::Rocm64 => TorchBackend::Rocm64,
280            TorchMode::Rocm63 => TorchBackend::Rocm63,
281            TorchMode::Rocm624 => TorchBackend::Rocm624,
282            TorchMode::Rocm62 => TorchBackend::Rocm62,
283            TorchMode::Rocm61 => TorchBackend::Rocm61,
284            TorchMode::Rocm60 => TorchBackend::Rocm60,
285            TorchMode::Rocm57 => TorchBackend::Rocm57,
286            TorchMode::Rocm56 => TorchBackend::Rocm56,
287            TorchMode::Rocm55 => TorchBackend::Rocm55,
288            TorchMode::Rocm542 => TorchBackend::Rocm542,
289            TorchMode::Rocm54 => TorchBackend::Rocm54,
290            TorchMode::Rocm53 => TorchBackend::Rocm53,
291            TorchMode::Rocm52 => TorchBackend::Rocm52,
292            TorchMode::Rocm511 => TorchBackend::Rocm511,
293            TorchMode::Rocm42 => TorchBackend::Rocm42,
294            TorchMode::Rocm41 => TorchBackend::Rocm41,
295            TorchMode::Rocm401 => TorchBackend::Rocm401,
296            TorchMode::Xpu => TorchBackend::Xpu,
297        };
298        Ok(Self::Backend { backend, source })
299    }
300
301    /// Returns `true` if the [`TorchStrategy`] applies to the given [`PackageName`].
302    pub fn applies_to(&self, package_name: &PackageName) -> bool {
303        let source = match self {
304            Self::Cuda { source, .. } => *source,
305            Self::Amd { source, .. } => *source,
306            Self::Xpu { source, .. } => *source,
307            Self::Backend { source, .. } => *source,
308        };
309        match source {
310            TorchSource::PyTorch => {
311                matches!(
312                    package_name.as_str(),
313                    "torch"
314                        | "torcharrow"
315                        | "torchaudio"
316                        | "torchcsprng"
317                        | "torchdata"
318                        | "torchdistx"
319                        | "torchserve"
320                        | "torchtext"
321                        | "torchvision"
322                        | "triton"
323                        | "pytorch-triton"
324                        | "pytorch-triton-rocm"
325                        | "pytorch-triton-xpu"
326                )
327            }
328            TorchSource::Pyx => {
329                matches!(
330                    package_name.as_str(),
331                    "deepspeed"
332                        | "flash-attn"
333                        | "flash-attn-3"
334                        | "megablocks"
335                        | "natten"
336                        | "pyg-lib"
337                        | "torch-cluster"
338                        | "torch-scatter"
339                        | "torch-sparse"
340                        | "torch-spline-conv"
341                        | "vllm"
342                        | "torch"
343                        | "torcharrow"
344                        | "torchaudio"
345                        | "torchcsprng"
346                        | "torchdata"
347                        | "torchdistx"
348                        | "torchserve"
349                        | "torchtext"
350                        | "torchvision"
351                        | "triton"
352                        | "pytorch-triton"
353                        | "pytorch-triton-rocm"
354                        | "pytorch-triton-xpu"
355                )
356            }
357        }
358    }
359
360    /// Returns `true` if the given [`PackageName`] has a system dependency (e.g., CUDA or ROCm).
361    ///
362    /// For example, `triton` is hosted on the PyTorch indexes, but does not have a system
363    /// dependency on the associated CUDA version (i.e., the `triton` on the `cu128` index doesn't
364    /// depend on CUDA 12.8).
365    pub fn has_system_dependency(&self, package_name: &PackageName) -> bool {
366        matches!(
367            package_name.as_str(),
368            "flash-attn"
369                | "flash-attn-3"
370                | "megablocks"
371                | "natten"
372                | "deepspeed"
373                | "vllm"
374                | "torch"
375                | "torcharrow"
376                | "torchaudio"
377                | "torchcsprng"
378                | "torchdata"
379                | "torchdistx"
380                | "torchtext"
381                | "torchvision"
382        )
383    }
384
385    /// Return the appropriate index URLs for the given [`TorchStrategy`].
386    pub fn index_urls(&self) -> impl Iterator<Item = &IndexUrl> {
387        match self {
388            Self::Cuda {
389                os,
390                driver_version,
391                source,
392            } => {
393                // If this is a GPU-enabled package, and CUDA drivers are installed, use PyTorch's CUDA
394                // indexes.
395                //
396                // See: https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_patch.py#L36-L49
397                match os {
398                    Os::Manylinux { .. } | Os::Musllinux { .. } => {
399                        Either::Left(Either::Left(Either::Left(
400                            LINUX_CUDA_DRIVERS
401                                .iter()
402                                .filter_map(move |(backend, version)| {
403                                    if driver_version >= version {
404                                        Some(backend.index_url(*source))
405                                    } else {
406                                        None
407                                    }
408                                })
409                                .chain(std::iter::once(TorchBackend::Cpu.index_url(*source))),
410                        )))
411                    }
412                    Os::Windows => Either::Left(Either::Left(Either::Right(
413                        WINDOWS_CUDA_VERSIONS
414                            .iter()
415                            .filter_map(move |(backend, version)| {
416                                if driver_version >= version {
417                                    Some(backend.index_url(*source))
418                                } else {
419                                    None
420                                }
421                            })
422                            .chain(std::iter::once(TorchBackend::Cpu.index_url(*source))),
423                    ))),
424                    Os::Macos { .. }
425                    | Os::FreeBsd { .. }
426                    | Os::NetBsd { .. }
427                    | Os::OpenBsd { .. }
428                    | Os::Dragonfly { .. }
429                    | Os::Illumos { .. }
430                    | Os::Haiku { .. }
431                    | Os::Android { .. }
432                    | Os::Pyodide { .. }
433                    | Os::Ios { .. } => Either::Right(Either::Left(std::iter::once(
434                        TorchBackend::Cpu.index_url(*source),
435                    ))),
436                }
437            }
438            Self::Amd {
439                os,
440                gpu_architecture,
441                source,
442            } => match os {
443                Os::Manylinux { .. } | Os::Musllinux { .. } => Either::Left(Either::Right(
444                    LINUX_AMD_GPU_DRIVERS
445                        .iter()
446                        .filter_map(move |(backend, architecture)| {
447                            if gpu_architecture == architecture {
448                                Some(backend.index_url(*source))
449                            } else {
450                                None
451                            }
452                        })
453                        .chain(std::iter::once(TorchBackend::Cpu.index_url(*source))),
454                )),
455                Os::Windows
456                | Os::Macos { .. }
457                | Os::FreeBsd { .. }
458                | Os::NetBsd { .. }
459                | Os::OpenBsd { .. }
460                | Os::Dragonfly { .. }
461                | Os::Illumos { .. }
462                | Os::Haiku { .. }
463                | Os::Android { .. }
464                | Os::Pyodide { .. }
465                | Os::Ios { .. } => Either::Right(Either::Left(std::iter::once(
466                    TorchBackend::Cpu.index_url(*source),
467                ))),
468            },
469            Self::Xpu { os, source } => match os {
470                Os::Manylinux { .. } | Os::Windows => Either::Right(Either::Right(Either::Left(
471                    std::iter::once(TorchBackend::Xpu.index_url(*source)),
472                ))),
473                Os::Musllinux { .. }
474                | Os::Macos { .. }
475                | Os::FreeBsd { .. }
476                | Os::NetBsd { .. }
477                | Os::OpenBsd { .. }
478                | Os::Dragonfly { .. }
479                | Os::Illumos { .. }
480                | Os::Haiku { .. }
481                | Os::Android { .. }
482                | Os::Pyodide { .. }
483                | Os::Ios { .. } => Either::Right(Either::Left(std::iter::once(
484                    TorchBackend::Cpu.index_url(*source),
485                ))),
486            },
487            Self::Backend { backend, source } => Either::Right(Either::Right(Either::Right(
488                std::iter::once(backend.index_url(*source)),
489            ))),
490        }
491    }
492}
493
494/// The available backends for PyTorch.
495#[derive(Debug, Copy, Clone, Eq, PartialEq)]
496pub enum TorchBackend {
497    Cpu,
498    Cu130,
499    Cu129,
500    Cu128,
501    Cu126,
502    Cu125,
503    Cu124,
504    Cu123,
505    Cu122,
506    Cu121,
507    Cu120,
508    Cu118,
509    Cu117,
510    Cu116,
511    Cu115,
512    Cu114,
513    Cu113,
514    Cu112,
515    Cu111,
516    Cu110,
517    Cu102,
518    Cu101,
519    Cu100,
520    Cu92,
521    Cu91,
522    Cu90,
523    Cu80,
524    Rocm64,
525    Rocm63,
526    Rocm624,
527    Rocm62,
528    Rocm61,
529    Rocm60,
530    Rocm57,
531    Rocm56,
532    Rocm55,
533    Rocm542,
534    Rocm54,
535    Rocm53,
536    Rocm52,
537    Rocm511,
538    Rocm42,
539    Rocm41,
540    Rocm401,
541    Xpu,
542}
543
544impl TorchBackend {
545    /// Return the appropriate index URL for the given [`TorchBackend`].
546    fn index_url(self, source: TorchSource) -> &'static IndexUrl {
547        match self {
548            Self::Cpu => match source {
549                TorchSource::PyTorch => &PYTORCH_CPU_INDEX_URL,
550                TorchSource::Pyx => &PYX_CPU_INDEX_URL,
551            },
552            Self::Cu130 => match source {
553                TorchSource::PyTorch => &PYTORCH_CU130_INDEX_URL,
554                TorchSource::Pyx => &PYX_CU130_INDEX_URL,
555            },
556            Self::Cu129 => match source {
557                TorchSource::PyTorch => &PYTORCH_CU129_INDEX_URL,
558                TorchSource::Pyx => &PYX_CU129_INDEX_URL,
559            },
560            Self::Cu128 => match source {
561                TorchSource::PyTorch => &PYTORCH_CU128_INDEX_URL,
562                TorchSource::Pyx => &PYX_CU128_INDEX_URL,
563            },
564            Self::Cu126 => match source {
565                TorchSource::PyTorch => &PYTORCH_CU126_INDEX_URL,
566                TorchSource::Pyx => &PYX_CU126_INDEX_URL,
567            },
568            Self::Cu125 => match source {
569                TorchSource::PyTorch => &PYTORCH_CU125_INDEX_URL,
570                TorchSource::Pyx => &PYX_CU125_INDEX_URL,
571            },
572            Self::Cu124 => match source {
573                TorchSource::PyTorch => &PYTORCH_CU124_INDEX_URL,
574                TorchSource::Pyx => &PYX_CU124_INDEX_URL,
575            },
576            Self::Cu123 => match source {
577                TorchSource::PyTorch => &PYTORCH_CU123_INDEX_URL,
578                TorchSource::Pyx => &PYX_CU123_INDEX_URL,
579            },
580            Self::Cu122 => match source {
581                TorchSource::PyTorch => &PYTORCH_CU122_INDEX_URL,
582                TorchSource::Pyx => &PYX_CU122_INDEX_URL,
583            },
584            Self::Cu121 => match source {
585                TorchSource::PyTorch => &PYTORCH_CU121_INDEX_URL,
586                TorchSource::Pyx => &PYX_CU121_INDEX_URL,
587            },
588            Self::Cu120 => match source {
589                TorchSource::PyTorch => &PYTORCH_CU120_INDEX_URL,
590                TorchSource::Pyx => &PYX_CU120_INDEX_URL,
591            },
592            Self::Cu118 => match source {
593                TorchSource::PyTorch => &PYTORCH_CU118_INDEX_URL,
594                TorchSource::Pyx => &PYX_CU118_INDEX_URL,
595            },
596            Self::Cu117 => match source {
597                TorchSource::PyTorch => &PYTORCH_CU117_INDEX_URL,
598                TorchSource::Pyx => &PYX_CU117_INDEX_URL,
599            },
600            Self::Cu116 => match source {
601                TorchSource::PyTorch => &PYTORCH_CU116_INDEX_URL,
602                TorchSource::Pyx => &PYX_CU116_INDEX_URL,
603            },
604            Self::Cu115 => match source {
605                TorchSource::PyTorch => &PYTORCH_CU115_INDEX_URL,
606                TorchSource::Pyx => &PYX_CU115_INDEX_URL,
607            },
608            Self::Cu114 => match source {
609                TorchSource::PyTorch => &PYTORCH_CU114_INDEX_URL,
610                TorchSource::Pyx => &PYX_CU114_INDEX_URL,
611            },
612            Self::Cu113 => match source {
613                TorchSource::PyTorch => &PYTORCH_CU113_INDEX_URL,
614                TorchSource::Pyx => &PYX_CU113_INDEX_URL,
615            },
616            Self::Cu112 => match source {
617                TorchSource::PyTorch => &PYTORCH_CU112_INDEX_URL,
618                TorchSource::Pyx => &PYX_CU112_INDEX_URL,
619            },
620            Self::Cu111 => match source {
621                TorchSource::PyTorch => &PYTORCH_CU111_INDEX_URL,
622                TorchSource::Pyx => &PYX_CU111_INDEX_URL,
623            },
624            Self::Cu110 => match source {
625                TorchSource::PyTorch => &PYTORCH_CU110_INDEX_URL,
626                TorchSource::Pyx => &PYX_CU110_INDEX_URL,
627            },
628            Self::Cu102 => match source {
629                TorchSource::PyTorch => &PYTORCH_CU102_INDEX_URL,
630                TorchSource::Pyx => &PYX_CU102_INDEX_URL,
631            },
632            Self::Cu101 => match source {
633                TorchSource::PyTorch => &PYTORCH_CU101_INDEX_URL,
634                TorchSource::Pyx => &PYX_CU101_INDEX_URL,
635            },
636            Self::Cu100 => match source {
637                TorchSource::PyTorch => &PYTORCH_CU100_INDEX_URL,
638                TorchSource::Pyx => &PYX_CU100_INDEX_URL,
639            },
640            Self::Cu92 => match source {
641                TorchSource::PyTorch => &PYTORCH_CU92_INDEX_URL,
642                TorchSource::Pyx => &PYX_CU92_INDEX_URL,
643            },
644            Self::Cu91 => match source {
645                TorchSource::PyTorch => &PYTORCH_CU91_INDEX_URL,
646                TorchSource::Pyx => &PYX_CU91_INDEX_URL,
647            },
648            Self::Cu90 => match source {
649                TorchSource::PyTorch => &PYTORCH_CU90_INDEX_URL,
650                TorchSource::Pyx => &PYX_CU90_INDEX_URL,
651            },
652            Self::Cu80 => match source {
653                TorchSource::PyTorch => &PYTORCH_CU80_INDEX_URL,
654                TorchSource::Pyx => &PYX_CU80_INDEX_URL,
655            },
656            Self::Rocm64 => match source {
657                TorchSource::PyTorch => &PYTORCH_ROCM64_INDEX_URL,
658                TorchSource::Pyx => &PYX_ROCM64_INDEX_URL,
659            },
660            Self::Rocm63 => match source {
661                TorchSource::PyTorch => &PYTORCH_ROCM63_INDEX_URL,
662                TorchSource::Pyx => &PYX_ROCM63_INDEX_URL,
663            },
664            Self::Rocm624 => match source {
665                TorchSource::PyTorch => &PYTORCH_ROCM624_INDEX_URL,
666                TorchSource::Pyx => &PYX_ROCM624_INDEX_URL,
667            },
668            Self::Rocm62 => match source {
669                TorchSource::PyTorch => &PYTORCH_ROCM62_INDEX_URL,
670                TorchSource::Pyx => &PYX_ROCM62_INDEX_URL,
671            },
672            Self::Rocm61 => match source {
673                TorchSource::PyTorch => &PYTORCH_ROCM61_INDEX_URL,
674                TorchSource::Pyx => &PYX_ROCM61_INDEX_URL,
675            },
676            Self::Rocm60 => match source {
677                TorchSource::PyTorch => &PYTORCH_ROCM60_INDEX_URL,
678                TorchSource::Pyx => &PYX_ROCM60_INDEX_URL,
679            },
680            Self::Rocm57 => match source {
681                TorchSource::PyTorch => &PYTORCH_ROCM57_INDEX_URL,
682                TorchSource::Pyx => &PYX_ROCM57_INDEX_URL,
683            },
684            Self::Rocm56 => match source {
685                TorchSource::PyTorch => &PYTORCH_ROCM56_INDEX_URL,
686                TorchSource::Pyx => &PYX_ROCM56_INDEX_URL,
687            },
688            Self::Rocm55 => match source {
689                TorchSource::PyTorch => &PYTORCH_ROCM55_INDEX_URL,
690                TorchSource::Pyx => &PYX_ROCM55_INDEX_URL,
691            },
692            Self::Rocm542 => match source {
693                TorchSource::PyTorch => &PYTORCH_ROCM542_INDEX_URL,
694                TorchSource::Pyx => &PYX_ROCM542_INDEX_URL,
695            },
696            Self::Rocm54 => match source {
697                TorchSource::PyTorch => &PYTORCH_ROCM54_INDEX_URL,
698                TorchSource::Pyx => &PYX_ROCM54_INDEX_URL,
699            },
700            Self::Rocm53 => match source {
701                TorchSource::PyTorch => &PYTORCH_ROCM53_INDEX_URL,
702                TorchSource::Pyx => &PYX_ROCM53_INDEX_URL,
703            },
704            Self::Rocm52 => match source {
705                TorchSource::PyTorch => &PYTORCH_ROCM52_INDEX_URL,
706                TorchSource::Pyx => &PYX_ROCM52_INDEX_URL,
707            },
708            Self::Rocm511 => match source {
709                TorchSource::PyTorch => &PYTORCH_ROCM511_INDEX_URL,
710                TorchSource::Pyx => &PYX_ROCM511_INDEX_URL,
711            },
712            Self::Rocm42 => match source {
713                TorchSource::PyTorch => &PYTORCH_ROCM42_INDEX_URL,
714                TorchSource::Pyx => &PYX_ROCM42_INDEX_URL,
715            },
716            Self::Rocm41 => match source {
717                TorchSource::PyTorch => &PYTORCH_ROCM41_INDEX_URL,
718                TorchSource::Pyx => &PYX_ROCM41_INDEX_URL,
719            },
720            Self::Rocm401 => match source {
721                TorchSource::PyTorch => &PYTORCH_ROCM401_INDEX_URL,
722                TorchSource::Pyx => &PYX_ROCM401_INDEX_URL,
723            },
724            Self::Xpu => match source {
725                TorchSource::PyTorch => &PYTORCH_XPU_INDEX_URL,
726                TorchSource::Pyx => &PYX_XPU_INDEX_URL,
727            },
728        }
729    }
730
731    /// Extract a [`TorchBackend`] from an index URL.
732    pub fn from_index(index: &Url) -> Option<Self> {
733        let backend_identifier = if index.host_str() == Some("download.pytorch.org") {
734            // E.g., `https://download.pytorch.org/whl/cu124`
735            let mut path_segments = index.path_segments()?;
736            if path_segments.next() != Some("whl") {
737                return None;
738            }
739            path_segments.next()?
740        // TODO(zanieb): We should consolidate this with `is_known_url` somehow
741        } else if index.host_str() == PYX_API_BASE_URL.strip_prefix("https://") {
742            // E.g., `https://api.pyx.dev/simple/astral-sh/cu124`
743            let mut path_segments = index.path_segments()?;
744            if path_segments.next() != Some("simple") {
745                return None;
746            }
747            if path_segments.next() != Some("astral-sh") {
748                return None;
749            }
750            path_segments.next()?
751        } else {
752            return None;
753        };
754        Self::from_str(backend_identifier).ok()
755    }
756
757    /// Returns the CUDA [`Version`] for the given [`TorchBackend`].
758    pub fn cuda_version(&self) -> Option<Version> {
759        match self {
760            Self::Cpu => None,
761            Self::Cu130 => Some(Version::new([13, 0])),
762            Self::Cu129 => Some(Version::new([12, 9])),
763            Self::Cu128 => Some(Version::new([12, 8])),
764            Self::Cu126 => Some(Version::new([12, 6])),
765            Self::Cu125 => Some(Version::new([12, 5])),
766            Self::Cu124 => Some(Version::new([12, 4])),
767            Self::Cu123 => Some(Version::new([12, 3])),
768            Self::Cu122 => Some(Version::new([12, 2])),
769            Self::Cu121 => Some(Version::new([12, 1])),
770            Self::Cu120 => Some(Version::new([12, 0])),
771            Self::Cu118 => Some(Version::new([11, 8])),
772            Self::Cu117 => Some(Version::new([11, 7])),
773            Self::Cu116 => Some(Version::new([11, 6])),
774            Self::Cu115 => Some(Version::new([11, 5])),
775            Self::Cu114 => Some(Version::new([11, 4])),
776            Self::Cu113 => Some(Version::new([11, 3])),
777            Self::Cu112 => Some(Version::new([11, 2])),
778            Self::Cu111 => Some(Version::new([11, 1])),
779            Self::Cu110 => Some(Version::new([11, 0])),
780            Self::Cu102 => Some(Version::new([10, 2])),
781            Self::Cu101 => Some(Version::new([10, 1])),
782            Self::Cu100 => Some(Version::new([10, 0])),
783            Self::Cu92 => Some(Version::new([9, 2])),
784            Self::Cu91 => Some(Version::new([9, 1])),
785            Self::Cu90 => Some(Version::new([9, 0])),
786            Self::Cu80 => Some(Version::new([8, 0])),
787            Self::Rocm64 => None,
788            Self::Rocm63 => None,
789            Self::Rocm624 => None,
790            Self::Rocm62 => None,
791            Self::Rocm61 => None,
792            Self::Rocm60 => None,
793            Self::Rocm57 => None,
794            Self::Rocm56 => None,
795            Self::Rocm55 => None,
796            Self::Rocm542 => None,
797            Self::Rocm54 => None,
798            Self::Rocm53 => None,
799            Self::Rocm52 => None,
800            Self::Rocm511 => None,
801            Self::Rocm42 => None,
802            Self::Rocm41 => None,
803            Self::Rocm401 => None,
804            Self::Xpu => None,
805        }
806    }
807
808    /// Returns the ROCM [`Version`] for the given [`TorchBackend`].
809    pub fn rocm_version(&self) -> Option<Version> {
810        match self {
811            Self::Cpu => None,
812            Self::Cu130 => None,
813            Self::Cu129 => None,
814            Self::Cu128 => None,
815            Self::Cu126 => None,
816            Self::Cu125 => None,
817            Self::Cu124 => None,
818            Self::Cu123 => None,
819            Self::Cu122 => None,
820            Self::Cu121 => None,
821            Self::Cu120 => None,
822            Self::Cu118 => None,
823            Self::Cu117 => None,
824            Self::Cu116 => None,
825            Self::Cu115 => None,
826            Self::Cu114 => None,
827            Self::Cu113 => None,
828            Self::Cu112 => None,
829            Self::Cu111 => None,
830            Self::Cu110 => None,
831            Self::Cu102 => None,
832            Self::Cu101 => None,
833            Self::Cu100 => None,
834            Self::Cu92 => None,
835            Self::Cu91 => None,
836            Self::Cu90 => None,
837            Self::Cu80 => None,
838            Self::Rocm64 => Some(Version::new([6, 4])),
839            Self::Rocm63 => Some(Version::new([6, 3])),
840            Self::Rocm624 => Some(Version::new([6, 2, 4])),
841            Self::Rocm62 => Some(Version::new([6, 2])),
842            Self::Rocm61 => Some(Version::new([6, 1])),
843            Self::Rocm60 => Some(Version::new([6, 0])),
844            Self::Rocm57 => Some(Version::new([5, 7])),
845            Self::Rocm56 => Some(Version::new([5, 6])),
846            Self::Rocm55 => Some(Version::new([5, 5])),
847            Self::Rocm542 => Some(Version::new([5, 4, 2])),
848            Self::Rocm54 => Some(Version::new([5, 4])),
849            Self::Rocm53 => Some(Version::new([5, 3])),
850            Self::Rocm52 => Some(Version::new([5, 2])),
851            Self::Rocm511 => Some(Version::new([5, 1, 1])),
852            Self::Rocm42 => Some(Version::new([4, 2])),
853            Self::Rocm41 => Some(Version::new([4, 1])),
854            Self::Rocm401 => Some(Version::new([4, 0, 1])),
855            Self::Xpu => None,
856        }
857    }
858}
859
860impl FromStr for TorchBackend {
861    type Err = String;
862
863    fn from_str(s: &str) -> Result<Self, Self::Err> {
864        match s {
865            "cpu" => Ok(Self::Cpu),
866            "cu130" => Ok(Self::Cu130),
867            "cu129" => Ok(Self::Cu129),
868            "cu128" => Ok(Self::Cu128),
869            "cu126" => Ok(Self::Cu126),
870            "cu125" => Ok(Self::Cu125),
871            "cu124" => Ok(Self::Cu124),
872            "cu123" => Ok(Self::Cu123),
873            "cu122" => Ok(Self::Cu122),
874            "cu121" => Ok(Self::Cu121),
875            "cu120" => Ok(Self::Cu120),
876            "cu118" => Ok(Self::Cu118),
877            "cu117" => Ok(Self::Cu117),
878            "cu116" => Ok(Self::Cu116),
879            "cu115" => Ok(Self::Cu115),
880            "cu114" => Ok(Self::Cu114),
881            "cu113" => Ok(Self::Cu113),
882            "cu112" => Ok(Self::Cu112),
883            "cu111" => Ok(Self::Cu111),
884            "cu110" => Ok(Self::Cu110),
885            "cu102" => Ok(Self::Cu102),
886            "cu101" => Ok(Self::Cu101),
887            "cu100" => Ok(Self::Cu100),
888            "cu92" => Ok(Self::Cu92),
889            "cu91" => Ok(Self::Cu91),
890            "cu90" => Ok(Self::Cu90),
891            "cu80" => Ok(Self::Cu80),
892            "rocm6.4" => Ok(Self::Rocm64),
893            "rocm6.3" => Ok(Self::Rocm63),
894            "rocm6.2.4" => Ok(Self::Rocm624),
895            "rocm6.2" => Ok(Self::Rocm62),
896            "rocm6.1" => Ok(Self::Rocm61),
897            "rocm6.0" => Ok(Self::Rocm60),
898            "rocm5.7" => Ok(Self::Rocm57),
899            "rocm5.6" => Ok(Self::Rocm56),
900            "rocm5.5" => Ok(Self::Rocm55),
901            "rocm5.4.2" => Ok(Self::Rocm542),
902            "rocm5.4" => Ok(Self::Rocm54),
903            "rocm5.3" => Ok(Self::Rocm53),
904            "rocm5.2" => Ok(Self::Rocm52),
905            "rocm5.1.1" => Ok(Self::Rocm511),
906            "rocm4.2" => Ok(Self::Rocm42),
907            "rocm4.1" => Ok(Self::Rocm41),
908            "rocm4.0.1" => Ok(Self::Rocm401),
909            "xpu" => Ok(Self::Xpu),
910            _ => Err(format!("Unknown PyTorch backend: {s}")),
911        }
912    }
913}
914
915/// Linux CUDA driver versions and the corresponding CUDA versions.
916///
917/// See: <https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_cb.py#L150-L213>
918static LINUX_CUDA_DRIVERS: LazyLock<[(TorchBackend, Version); 26]> = LazyLock::new(|| {
919    [
920        // Table 2 from
921        // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
922        (TorchBackend::Cu130, Version::new([580])),
923        (TorchBackend::Cu129, Version::new([525, 60, 13])),
924        (TorchBackend::Cu128, Version::new([525, 60, 13])),
925        (TorchBackend::Cu126, Version::new([525, 60, 13])),
926        (TorchBackend::Cu125, Version::new([525, 60, 13])),
927        (TorchBackend::Cu124, Version::new([525, 60, 13])),
928        (TorchBackend::Cu123, Version::new([525, 60, 13])),
929        (TorchBackend::Cu122, Version::new([525, 60, 13])),
930        (TorchBackend::Cu121, Version::new([525, 60, 13])),
931        (TorchBackend::Cu120, Version::new([525, 60, 13])),
932        // Table 2 from
933        // https://docs.nvidia.com/cuda/archive/11.8.0/cuda-toolkit-release-notes/index.html
934        (TorchBackend::Cu118, Version::new([450, 80, 2])),
935        (TorchBackend::Cu117, Version::new([450, 80, 2])),
936        (TorchBackend::Cu116, Version::new([450, 80, 2])),
937        (TorchBackend::Cu115, Version::new([450, 80, 2])),
938        (TorchBackend::Cu114, Version::new([450, 80, 2])),
939        (TorchBackend::Cu113, Version::new([450, 80, 2])),
940        (TorchBackend::Cu112, Version::new([450, 80, 2])),
941        (TorchBackend::Cu111, Version::new([450, 80, 2])),
942        (TorchBackend::Cu110, Version::new([450, 36, 6])),
943        // Table 1 from
944        // https://docs.nvidia.com/cuda/archive/10.2/cuda-toolkit-release-notes/index.html
945        (TorchBackend::Cu102, Version::new([440, 33])),
946        (TorchBackend::Cu101, Version::new([418, 39])),
947        (TorchBackend::Cu100, Version::new([410, 48])),
948        (TorchBackend::Cu92, Version::new([396, 26])),
949        (TorchBackend::Cu91, Version::new([390, 46])),
950        (TorchBackend::Cu90, Version::new([384, 81])),
951        (TorchBackend::Cu80, Version::new([375, 26])),
952    ]
953});
954
955/// Windows CUDA driver versions and the corresponding CUDA versions.
956///
957/// See: <https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_cb.py#L150-L213>
958static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 26]> = LazyLock::new(|| {
959    [
960        // Table 2 from
961        // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
962        (TorchBackend::Cu130, Version::new([580])),
963        (TorchBackend::Cu129, Version::new([528, 33])),
964        (TorchBackend::Cu128, Version::new([528, 33])),
965        (TorchBackend::Cu126, Version::new([528, 33])),
966        (TorchBackend::Cu125, Version::new([528, 33])),
967        (TorchBackend::Cu124, Version::new([528, 33])),
968        (TorchBackend::Cu123, Version::new([528, 33])),
969        (TorchBackend::Cu122, Version::new([528, 33])),
970        (TorchBackend::Cu121, Version::new([528, 33])),
971        (TorchBackend::Cu120, Version::new([528, 33])),
972        // Table 2 from
973        // https://docs.nvidia.com/cuda/archive/11.8.0/cuda-toolkit-release-notes/index.html
974        (TorchBackend::Cu118, Version::new([452, 39])),
975        (TorchBackend::Cu117, Version::new([452, 39])),
976        (TorchBackend::Cu116, Version::new([452, 39])),
977        (TorchBackend::Cu115, Version::new([452, 39])),
978        (TorchBackend::Cu114, Version::new([452, 39])),
979        (TorchBackend::Cu113, Version::new([452, 39])),
980        (TorchBackend::Cu112, Version::new([452, 39])),
981        (TorchBackend::Cu111, Version::new([452, 39])),
982        (TorchBackend::Cu110, Version::new([451, 22])),
983        // Table 1 from
984        // https://docs.nvidia.com/cuda/archive/10.2/cuda-toolkit-release-notes/index.html
985        (TorchBackend::Cu102, Version::new([441, 22])),
986        (TorchBackend::Cu101, Version::new([418, 96])),
987        (TorchBackend::Cu100, Version::new([411, 31])),
988        (TorchBackend::Cu92, Version::new([398, 26])),
989        (TorchBackend::Cu91, Version::new([391, 29])),
990        (TorchBackend::Cu90, Version::new([385, 54])),
991        (TorchBackend::Cu80, Version::new([376, 51])),
992    ]
993});
994
995/// Linux AMD GPU architectures and the corresponding PyTorch backends.
996///
997/// These were inferred by running the following snippet for each ROCm version:
998///
999/// ```python
1000/// import torch
1001///
1002/// print(torch.cuda.get_arch_list())
1003/// ```
1004///
1005/// AMD also provides a compatibility matrix: <https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html>;
1006/// however, this list includes a broader array of GPUs than those in the matrix.
1007static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 55]> =
1008    LazyLock::new(|| {
1009        [
1010            // ROCm 6.4
1011            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx900),
1012            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx906),
1013            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx908),
1014            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx90a),
1015            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx942),
1016            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1030),
1017            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1100),
1018            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1101),
1019            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1102),
1020            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1200),
1021            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1201),
1022            // ROCm 6.3
1023            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx900),
1024            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx906),
1025            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx908),
1026            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx90a),
1027            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx942),
1028            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1030),
1029            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1100),
1030            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1101),
1031            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1102),
1032            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1200),
1033            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1201),
1034            // ROCm 6.2.4
1035            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx900),
1036            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx906),
1037            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx908),
1038            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx90a),
1039            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx942),
1040            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1030),
1041            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1100),
1042            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1101),
1043            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1102),
1044            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1200),
1045            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1201),
1046            // ROCm 6.2
1047            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx900),
1048            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx906),
1049            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx908),
1050            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx90a),
1051            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx1030),
1052            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx1100),
1053            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx942),
1054            // ROCm 6.1
1055            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx900),
1056            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx906),
1057            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx908),
1058            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx90a),
1059            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx942),
1060            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1030),
1061            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1100),
1062            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1101),
1063            // ROCm 6.0
1064            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx900),
1065            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx906),
1066            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx908),
1067            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx90a),
1068            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx1030),
1069            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx1100),
1070            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx942),
1071        ]
1072    });
1073
1074static PYTORCH_CPU_INDEX_URL: LazyLock<IndexUrl> =
1075    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cpu").unwrap());
1076static PYTORCH_CU130_INDEX_URL: LazyLock<IndexUrl> =
1077    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu130").unwrap());
1078static PYTORCH_CU129_INDEX_URL: LazyLock<IndexUrl> =
1079    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu129").unwrap());
1080static PYTORCH_CU128_INDEX_URL: LazyLock<IndexUrl> =
1081    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu128").unwrap());
1082static PYTORCH_CU126_INDEX_URL: LazyLock<IndexUrl> =
1083    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu126").unwrap());
1084static PYTORCH_CU125_INDEX_URL: LazyLock<IndexUrl> =
1085    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu125").unwrap());
1086static PYTORCH_CU124_INDEX_URL: LazyLock<IndexUrl> =
1087    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu124").unwrap());
1088static PYTORCH_CU123_INDEX_URL: LazyLock<IndexUrl> =
1089    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu123").unwrap());
1090static PYTORCH_CU122_INDEX_URL: LazyLock<IndexUrl> =
1091    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu122").unwrap());
1092static PYTORCH_CU121_INDEX_URL: LazyLock<IndexUrl> =
1093    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu121").unwrap());
1094static PYTORCH_CU120_INDEX_URL: LazyLock<IndexUrl> =
1095    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu120").unwrap());
1096static PYTORCH_CU118_INDEX_URL: LazyLock<IndexUrl> =
1097    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu118").unwrap());
1098static PYTORCH_CU117_INDEX_URL: LazyLock<IndexUrl> =
1099    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu117").unwrap());
1100static PYTORCH_CU116_INDEX_URL: LazyLock<IndexUrl> =
1101    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu116").unwrap());
1102static PYTORCH_CU115_INDEX_URL: LazyLock<IndexUrl> =
1103    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu115").unwrap());
1104static PYTORCH_CU114_INDEX_URL: LazyLock<IndexUrl> =
1105    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu114").unwrap());
1106static PYTORCH_CU113_INDEX_URL: LazyLock<IndexUrl> =
1107    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu113").unwrap());
1108static PYTORCH_CU112_INDEX_URL: LazyLock<IndexUrl> =
1109    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu112").unwrap());
1110static PYTORCH_CU111_INDEX_URL: LazyLock<IndexUrl> =
1111    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu111").unwrap());
1112static PYTORCH_CU110_INDEX_URL: LazyLock<IndexUrl> =
1113    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu110").unwrap());
1114static PYTORCH_CU102_INDEX_URL: LazyLock<IndexUrl> =
1115    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu102").unwrap());
1116static PYTORCH_CU101_INDEX_URL: LazyLock<IndexUrl> =
1117    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu101").unwrap());
1118static PYTORCH_CU100_INDEX_URL: LazyLock<IndexUrl> =
1119    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu100").unwrap());
1120static PYTORCH_CU92_INDEX_URL: LazyLock<IndexUrl> =
1121    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu92").unwrap());
1122static PYTORCH_CU91_INDEX_URL: LazyLock<IndexUrl> =
1123    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu91").unwrap());
1124static PYTORCH_CU90_INDEX_URL: LazyLock<IndexUrl> =
1125    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu90").unwrap());
1126static PYTORCH_CU80_INDEX_URL: LazyLock<IndexUrl> =
1127    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu80").unwrap());
1128static PYTORCH_ROCM64_INDEX_URL: LazyLock<IndexUrl> =
1129    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.4").unwrap());
1130static PYTORCH_ROCM63_INDEX_URL: LazyLock<IndexUrl> =
1131    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.3").unwrap());
1132static PYTORCH_ROCM624_INDEX_URL: LazyLock<IndexUrl> =
1133    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.2.4").unwrap());
1134static PYTORCH_ROCM62_INDEX_URL: LazyLock<IndexUrl> =
1135    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.2").unwrap());
1136static PYTORCH_ROCM61_INDEX_URL: LazyLock<IndexUrl> =
1137    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.1").unwrap());
1138static PYTORCH_ROCM60_INDEX_URL: LazyLock<IndexUrl> =
1139    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.0").unwrap());
1140static PYTORCH_ROCM57_INDEX_URL: LazyLock<IndexUrl> =
1141    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.7").unwrap());
1142static PYTORCH_ROCM56_INDEX_URL: LazyLock<IndexUrl> =
1143    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.6").unwrap());
1144static PYTORCH_ROCM55_INDEX_URL: LazyLock<IndexUrl> =
1145    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.5").unwrap());
1146static PYTORCH_ROCM542_INDEX_URL: LazyLock<IndexUrl> =
1147    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.4.2").unwrap());
1148static PYTORCH_ROCM54_INDEX_URL: LazyLock<IndexUrl> =
1149    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.4").unwrap());
1150static PYTORCH_ROCM53_INDEX_URL: LazyLock<IndexUrl> =
1151    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.3").unwrap());
1152static PYTORCH_ROCM52_INDEX_URL: LazyLock<IndexUrl> =
1153    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.2").unwrap());
1154static PYTORCH_ROCM511_INDEX_URL: LazyLock<IndexUrl> =
1155    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.1.1").unwrap());
1156static PYTORCH_ROCM42_INDEX_URL: LazyLock<IndexUrl> =
1157    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.2").unwrap());
1158static PYTORCH_ROCM41_INDEX_URL: LazyLock<IndexUrl> =
1159    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.1").unwrap());
1160static PYTORCH_ROCM401_INDEX_URL: LazyLock<IndexUrl> =
1161    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.0.1").unwrap());
1162static PYTORCH_XPU_INDEX_URL: LazyLock<IndexUrl> =
1163    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/xpu").unwrap());
1164
1165static PYX_API_BASE_URL: LazyLock<Cow<'static, str>> = LazyLock::new(|| {
1166    std::env::var(EnvVars::PYX_API_URL)
1167        .map(Cow::Owned)
1168        .unwrap_or(Cow::Borrowed("https://api.pyx.dev"))
1169});
1170static PYX_CPU_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1171    let api_base_url = &*PYX_API_BASE_URL;
1172    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cpu")).unwrap()
1173});
1174static PYX_CU130_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1175    let api_base_url = &*PYX_API_BASE_URL;
1176    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu130")).unwrap()
1177});
1178static PYX_CU129_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1179    let api_base_url = &*PYX_API_BASE_URL;
1180    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu129")).unwrap()
1181});
1182static PYX_CU128_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1183    let api_base_url = &*PYX_API_BASE_URL;
1184    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu128")).unwrap()
1185});
1186static PYX_CU126_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1187    let api_base_url = &*PYX_API_BASE_URL;
1188    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu126")).unwrap()
1189});
1190static PYX_CU125_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1191    let api_base_url = &*PYX_API_BASE_URL;
1192    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu125")).unwrap()
1193});
1194static PYX_CU124_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1195    let api_base_url = &*PYX_API_BASE_URL;
1196    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu124")).unwrap()
1197});
1198static PYX_CU123_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1199    let api_base_url = &*PYX_API_BASE_URL;
1200    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu123")).unwrap()
1201});
1202static PYX_CU122_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1203    let api_base_url = &*PYX_API_BASE_URL;
1204    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu122")).unwrap()
1205});
1206static PYX_CU121_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1207    let api_base_url = &*PYX_API_BASE_URL;
1208    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu121")).unwrap()
1209});
1210static PYX_CU120_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1211    let api_base_url = &*PYX_API_BASE_URL;
1212    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu120")).unwrap()
1213});
1214static PYX_CU118_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1215    let api_base_url = &*PYX_API_BASE_URL;
1216    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu118")).unwrap()
1217});
1218static PYX_CU117_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1219    let api_base_url = &*PYX_API_BASE_URL;
1220    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu117")).unwrap()
1221});
1222static PYX_CU116_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1223    let api_base_url = &*PYX_API_BASE_URL;
1224    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu116")).unwrap()
1225});
1226static PYX_CU115_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1227    let api_base_url = &*PYX_API_BASE_URL;
1228    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu115")).unwrap()
1229});
1230static PYX_CU114_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1231    let api_base_url = &*PYX_API_BASE_URL;
1232    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu114")).unwrap()
1233});
1234static PYX_CU113_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1235    let api_base_url = &*PYX_API_BASE_URL;
1236    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu113")).unwrap()
1237});
1238static PYX_CU112_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1239    let api_base_url = &*PYX_API_BASE_URL;
1240    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu112")).unwrap()
1241});
1242static PYX_CU111_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1243    let api_base_url = &*PYX_API_BASE_URL;
1244    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu111")).unwrap()
1245});
1246static PYX_CU110_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1247    let api_base_url = &*PYX_API_BASE_URL;
1248    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu110")).unwrap()
1249});
1250static PYX_CU102_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1251    let api_base_url = &*PYX_API_BASE_URL;
1252    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu102")).unwrap()
1253});
1254static PYX_CU101_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1255    let api_base_url = &*PYX_API_BASE_URL;
1256    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu101")).unwrap()
1257});
1258static PYX_CU100_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1259    let api_base_url = &*PYX_API_BASE_URL;
1260    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu100")).unwrap()
1261});
1262static PYX_CU92_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1263    let api_base_url = &*PYX_API_BASE_URL;
1264    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu92")).unwrap()
1265});
1266static PYX_CU91_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1267    let api_base_url = &*PYX_API_BASE_URL;
1268    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu91")).unwrap()
1269});
1270static PYX_CU90_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1271    let api_base_url = &*PYX_API_BASE_URL;
1272    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu90")).unwrap()
1273});
1274static PYX_CU80_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1275    let api_base_url = &*PYX_API_BASE_URL;
1276    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu80")).unwrap()
1277});
1278static PYX_ROCM64_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1279    let api_base_url = &*PYX_API_BASE_URL;
1280    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.4")).unwrap()
1281});
1282static PYX_ROCM63_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1283    let api_base_url = &*PYX_API_BASE_URL;
1284    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.3")).unwrap()
1285});
1286static PYX_ROCM624_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1287    let api_base_url = &*PYX_API_BASE_URL;
1288    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.2.4")).unwrap()
1289});
1290static PYX_ROCM62_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1291    let api_base_url = &*PYX_API_BASE_URL;
1292    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.2")).unwrap()
1293});
1294static PYX_ROCM61_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1295    let api_base_url = &*PYX_API_BASE_URL;
1296    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.1")).unwrap()
1297});
1298static PYX_ROCM60_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1299    let api_base_url = &*PYX_API_BASE_URL;
1300    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.0")).unwrap()
1301});
1302static PYX_ROCM57_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1303    let api_base_url = &*PYX_API_BASE_URL;
1304    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.7")).unwrap()
1305});
1306static PYX_ROCM56_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1307    let api_base_url = &*PYX_API_BASE_URL;
1308    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.6")).unwrap()
1309});
1310static PYX_ROCM55_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1311    let api_base_url = &*PYX_API_BASE_URL;
1312    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.5")).unwrap()
1313});
1314static PYX_ROCM542_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1315    let api_base_url = &*PYX_API_BASE_URL;
1316    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.4.2")).unwrap()
1317});
1318static PYX_ROCM54_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1319    let api_base_url = &*PYX_API_BASE_URL;
1320    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.4")).unwrap()
1321});
1322static PYX_ROCM53_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1323    let api_base_url = &*PYX_API_BASE_URL;
1324    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.3")).unwrap()
1325});
1326static PYX_ROCM52_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1327    let api_base_url = &*PYX_API_BASE_URL;
1328    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.2")).unwrap()
1329});
1330static PYX_ROCM511_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1331    let api_base_url = &*PYX_API_BASE_URL;
1332    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.1.1")).unwrap()
1333});
1334static PYX_ROCM42_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1335    let api_base_url = &*PYX_API_BASE_URL;
1336    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm4.2")).unwrap()
1337});
1338static PYX_ROCM41_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1339    let api_base_url = &*PYX_API_BASE_URL;
1340    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm4.1")).unwrap()
1341});
1342static PYX_ROCM401_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1343    let api_base_url = &*PYX_API_BASE_URL;
1344    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm4.0.1")).unwrap()
1345});
1346static PYX_XPU_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1347    let api_base_url = &*PYX_API_BASE_URL;
1348    IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/xpu")).unwrap()
1349});