1use std::borrow::Cow;
40use std::str::FromStr;
41use std::sync::LazyLock;
42
43use either::Either;
44use url::Url;
45
46use uv_distribution_types::IndexUrl;
47use uv_normalize::PackageName;
48use uv_pep440::Version;
49use uv_platform_tags::Os;
50use uv_static::EnvVars;
51
52use crate::{Accelerator, AcceleratorError, AmdGpuArchitecture};
53
/// The user-facing selector for which PyTorch backend to target.
///
/// Every variant other than [`TorchMode::Auto`] maps onto the [`TorchBackend`] of the same name
/// in [`TorchStrategy::from_mode`]. Values are (de)serialized in kebab-case (e.g. `cu128`), with
/// the ROCm variants explicitly renamed to their dotted version strings (e.g. `rocm6.2.4`).
#[derive(Debug, Copy, Clone, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(rename_all = "kebab-case")]
pub enum TorchMode {
    /// Detect the accelerator on this machine and derive the strategy from it, using the CPU
    /// backend when no accelerator is found.
    Auto,
    /// Select the CPU backend.
    Cpu,
    /// Select the CUDA 13.0 backend.
    Cu130,
    /// Select the CUDA 12.9 backend.
    Cu129,
    /// Select the CUDA 12.8 backend.
    Cu128,
    /// Select the CUDA 12.6 backend.
    Cu126,
    /// Select the CUDA 12.5 backend.
    Cu125,
    /// Select the CUDA 12.4 backend.
    Cu124,
    /// Select the CUDA 12.3 backend.
    Cu123,
    /// Select the CUDA 12.2 backend.
    Cu122,
    /// Select the CUDA 12.1 backend.
    Cu121,
    /// Select the CUDA 12.0 backend.
    Cu120,
    /// Select the CUDA 11.8 backend.
    Cu118,
    /// Select the CUDA 11.7 backend.
    Cu117,
    /// Select the CUDA 11.6 backend.
    Cu116,
    /// Select the CUDA 11.5 backend.
    Cu115,
    /// Select the CUDA 11.4 backend.
    Cu114,
    /// Select the CUDA 11.3 backend.
    Cu113,
    /// Select the CUDA 11.2 backend.
    Cu112,
    /// Select the CUDA 11.1 backend.
    Cu111,
    /// Select the CUDA 11.0 backend.
    Cu110,
    /// Select the CUDA 10.2 backend.
    Cu102,
    /// Select the CUDA 10.1 backend.
    Cu101,
    /// Select the CUDA 10.0 backend.
    Cu100,
    /// Select the CUDA 9.2 backend.
    Cu92,
    /// Select the CUDA 9.1 backend.
    Cu91,
    /// Select the CUDA 9.0 backend.
    Cu90,
    /// Select the CUDA 8.0 backend.
    Cu80,
    /// Select the ROCm 7.1 backend.
    #[serde(rename = "rocm7.1")]
    #[cfg_attr(feature = "clap", clap(name = "rocm7.1"))]
    Rocm71,
    /// Select the ROCm 7.0 backend.
    #[serde(rename = "rocm7.0")]
    #[cfg_attr(feature = "clap", clap(name = "rocm7.0"))]
    Rocm70,
    /// Select the ROCm 6.4 backend.
    #[serde(rename = "rocm6.4")]
    #[cfg_attr(feature = "clap", clap(name = "rocm6.4"))]
    Rocm64,
    /// Select the ROCm 6.3 backend.
    #[serde(rename = "rocm6.3")]
    #[cfg_attr(feature = "clap", clap(name = "rocm6.3"))]
    Rocm63,
    /// Select the ROCm 6.2.4 backend.
    #[serde(rename = "rocm6.2.4")]
    #[cfg_attr(feature = "clap", clap(name = "rocm6.2.4"))]
    Rocm624,
    /// Select the ROCm 6.2 backend.
    #[serde(rename = "rocm6.2")]
    #[cfg_attr(feature = "clap", clap(name = "rocm6.2"))]
    Rocm62,
    /// Select the ROCm 6.1 backend.
    #[serde(rename = "rocm6.1")]
    #[cfg_attr(feature = "clap", clap(name = "rocm6.1"))]
    Rocm61,
    /// Select the ROCm 6.0 backend.
    #[serde(rename = "rocm6.0")]
    #[cfg_attr(feature = "clap", clap(name = "rocm6.0"))]
    Rocm60,
    /// Select the ROCm 5.7 backend.
    #[serde(rename = "rocm5.7")]
    #[cfg_attr(feature = "clap", clap(name = "rocm5.7"))]
    Rocm57,
    /// Select the ROCm 5.6 backend.
    #[serde(rename = "rocm5.6")]
    #[cfg_attr(feature = "clap", clap(name = "rocm5.6"))]
    Rocm56,
    /// Select the ROCm 5.5 backend.
    #[serde(rename = "rocm5.5")]
    #[cfg_attr(feature = "clap", clap(name = "rocm5.5"))]
    Rocm55,
    /// Select the ROCm 5.4.2 backend.
    #[serde(rename = "rocm5.4.2")]
    #[cfg_attr(feature = "clap", clap(name = "rocm5.4.2"))]
    Rocm542,
    /// Select the ROCm 5.4 backend.
    #[serde(rename = "rocm5.4")]
    #[cfg_attr(feature = "clap", clap(name = "rocm5.4"))]
    Rocm54,
    /// Select the ROCm 5.3 backend.
    #[serde(rename = "rocm5.3")]
    #[cfg_attr(feature = "clap", clap(name = "rocm5.3"))]
    Rocm53,
    /// Select the ROCm 5.2 backend.
    #[serde(rename = "rocm5.2")]
    #[cfg_attr(feature = "clap", clap(name = "rocm5.2"))]
    Rocm52,
    /// Select the ROCm 5.1.1 backend.
    #[serde(rename = "rocm5.1.1")]
    #[cfg_attr(feature = "clap", clap(name = "rocm5.1.1"))]
    Rocm511,
    /// Select the ROCm 4.2 backend.
    #[serde(rename = "rocm4.2")]
    #[cfg_attr(feature = "clap", clap(name = "rocm4.2"))]
    Rocm42,
    /// Select the ROCm 4.1 backend.
    #[serde(rename = "rocm4.1")]
    #[cfg_attr(feature = "clap", clap(name = "rocm4.1"))]
    Rocm41,
    /// Select the ROCm 4.0.1 backend.
    #[serde(rename = "rocm4.0.1")]
    #[cfg_attr(feature = "clap", clap(name = "rocm4.0.1"))]
    Rocm401,
    /// Select the XPU backend.
    Xpu,
}
195
/// The provider whose indexes are used for a backend.
///
/// The source picks between the `PYTORCH_*` and `PYX_*` index URLs in
/// [`TorchBackend::index_url`] and decides which package names
/// [`TorchStrategy::applies_to`] accepts.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
pub enum TorchSource {
    /// Use the `PYTORCH_*` index URLs (hosted on `download.pytorch.org`). This is the default.
    #[default]
    PyTorch,
    /// Use the `PYX_*` index URLs.
    // NOTE(review): the `PYX_*` statics are defined outside this view; presumably they point at
    // the host derived from `PYX_API_BASE_URL` — confirm.
    Pyx,
}
204
/// How the indexes for accelerator-aware packages are chosen.
///
/// Values are built by [`TorchStrategy::from_mode`]: the accelerator-specific variants are only
/// produced for [`TorchMode::Auto`], while every explicit mode yields
/// [`TorchStrategy::Backend`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum TorchStrategy {
    /// A CUDA accelerator was detected; candidate backends are those whose minimum driver
    /// version is satisfied by `driver_version`.
    Cuda {
        /// Operating system, which selects the Linux or Windows driver table (any other system
        /// resolves to the CPU index only).
        os: Os,
        /// Driver version reported by accelerator detection.
        driver_version: Version,
        /// Provider of the index URLs.
        source: TorchSource,
    },
    /// An AMD accelerator was detected; candidate backends are those listed for
    /// `gpu_architecture`.
    Amd {
        /// Operating system; only the Linux variants consult the architecture table.
        os: Os,
        /// GPU architecture reported by accelerator detection.
        gpu_architecture: AmdGpuArchitecture,
        /// Provider of the index URLs.
        source: TorchSource,
    },
    /// An XPU accelerator was detected.
    Xpu { os: Os, source: TorchSource },
    /// A single, explicitly chosen backend.
    Backend {
        /// The backend whose index is used.
        backend: TorchBackend,
        /// Provider of the index URL.
        source: TorchSource,
    },
}
228
229impl TorchStrategy {
230 pub fn from_mode(
232 mode: TorchMode,
233 source: TorchSource,
234 os: &Os,
235 ) -> Result<Self, AcceleratorError> {
236 let backend = match mode {
237 TorchMode::Auto => match Accelerator::detect()? {
238 Some(Accelerator::Cuda { driver_version }) => {
239 return Ok(Self::Cuda {
240 os: os.clone(),
241 driver_version: driver_version.clone(),
242 source,
243 });
244 }
245 Some(Accelerator::Amd { gpu_architecture }) => {
246 return Ok(Self::Amd {
247 os: os.clone(),
248 gpu_architecture,
249 source,
250 });
251 }
252 Some(Accelerator::Xpu) => {
253 return Ok(Self::Xpu {
254 os: os.clone(),
255 source,
256 });
257 }
258 None => TorchBackend::Cpu,
259 },
260 TorchMode::Cpu => TorchBackend::Cpu,
261 TorchMode::Cu130 => TorchBackend::Cu130,
262 TorchMode::Cu129 => TorchBackend::Cu129,
263 TorchMode::Cu128 => TorchBackend::Cu128,
264 TorchMode::Cu126 => TorchBackend::Cu126,
265 TorchMode::Cu125 => TorchBackend::Cu125,
266 TorchMode::Cu124 => TorchBackend::Cu124,
267 TorchMode::Cu123 => TorchBackend::Cu123,
268 TorchMode::Cu122 => TorchBackend::Cu122,
269 TorchMode::Cu121 => TorchBackend::Cu121,
270 TorchMode::Cu120 => TorchBackend::Cu120,
271 TorchMode::Cu118 => TorchBackend::Cu118,
272 TorchMode::Cu117 => TorchBackend::Cu117,
273 TorchMode::Cu116 => TorchBackend::Cu116,
274 TorchMode::Cu115 => TorchBackend::Cu115,
275 TorchMode::Cu114 => TorchBackend::Cu114,
276 TorchMode::Cu113 => TorchBackend::Cu113,
277 TorchMode::Cu112 => TorchBackend::Cu112,
278 TorchMode::Cu111 => TorchBackend::Cu111,
279 TorchMode::Cu110 => TorchBackend::Cu110,
280 TorchMode::Cu102 => TorchBackend::Cu102,
281 TorchMode::Cu101 => TorchBackend::Cu101,
282 TorchMode::Cu100 => TorchBackend::Cu100,
283 TorchMode::Cu92 => TorchBackend::Cu92,
284 TorchMode::Cu91 => TorchBackend::Cu91,
285 TorchMode::Cu90 => TorchBackend::Cu90,
286 TorchMode::Cu80 => TorchBackend::Cu80,
287 TorchMode::Rocm71 => TorchBackend::Rocm71,
288 TorchMode::Rocm70 => TorchBackend::Rocm70,
289 TorchMode::Rocm64 => TorchBackend::Rocm64,
290 TorchMode::Rocm63 => TorchBackend::Rocm63,
291 TorchMode::Rocm624 => TorchBackend::Rocm624,
292 TorchMode::Rocm62 => TorchBackend::Rocm62,
293 TorchMode::Rocm61 => TorchBackend::Rocm61,
294 TorchMode::Rocm60 => TorchBackend::Rocm60,
295 TorchMode::Rocm57 => TorchBackend::Rocm57,
296 TorchMode::Rocm56 => TorchBackend::Rocm56,
297 TorchMode::Rocm55 => TorchBackend::Rocm55,
298 TorchMode::Rocm542 => TorchBackend::Rocm542,
299 TorchMode::Rocm54 => TorchBackend::Rocm54,
300 TorchMode::Rocm53 => TorchBackend::Rocm53,
301 TorchMode::Rocm52 => TorchBackend::Rocm52,
302 TorchMode::Rocm511 => TorchBackend::Rocm511,
303 TorchMode::Rocm42 => TorchBackend::Rocm42,
304 TorchMode::Rocm41 => TorchBackend::Rocm41,
305 TorchMode::Rocm401 => TorchBackend::Rocm401,
306 TorchMode::Xpu => TorchBackend::Xpu,
307 };
308 Ok(Self::Backend { backend, source })
309 }
310
311 pub fn applies_to(&self, package_name: &PackageName) -> bool {
313 let source = match self {
314 Self::Cuda { source, .. } => *source,
315 Self::Amd { source, .. } => *source,
316 Self::Xpu { source, .. } => *source,
317 Self::Backend { source, .. } => *source,
318 };
319 match source {
320 TorchSource::PyTorch => {
321 matches!(
322 package_name.as_str(),
323 "pytorch-triton"
324 | "pytorch-triton-rocm"
325 | "pytorch-triton-xpu"
326 | "torch"
327 | "torch-tensorrt"
328 | "torchao"
329 | "torcharrow"
330 | "torchaudio"
331 | "torchcsprng"
332 | "torchdata"
333 | "torchdistx"
334 | "torchserve"
335 | "torchtext"
336 | "torchvision"
337 | "triton"
338 )
339 }
340 TorchSource::Pyx => {
341 matches!(
342 package_name.as_str(),
343 "deepspeed"
344 | "flash-attn"
345 | "flash-attn-3"
346 | "megablocks"
347 | "natten"
348 | "pyg-lib"
349 | "pytorch-triton"
350 | "pytorch-triton-rocm"
351 | "pytorch-triton-xpu"
352 | "torch"
353 | "torch-cluster"
354 | "torch-scatter"
355 | "torch-sparse"
356 | "torch-spline-conv"
357 | "torch-tensorrt"
358 | "torchao"
359 | "torcharrow"
360 | "torchaudio"
361 | "torchcsprng"
362 | "torchdata"
363 | "torchdistx"
364 | "torchserve"
365 | "torchtext"
366 | "torchvision"
367 | "triton"
368 | "vllm"
369 )
370 }
371 }
372 }
373
374 pub fn has_system_dependency(&self, package_name: &PackageName) -> bool {
380 matches!(
381 package_name.as_str(),
382 "deepspeed"
383 | "flash-attn"
384 | "flash-attn-3"
385 | "megablocks"
386 | "natten"
387 | "torch"
388 | "torch-tensorrt"
389 | "torchao"
390 | "torcharrow"
391 | "torchaudio"
392 | "torchcsprng"
393 | "torchdata"
394 | "torchdistx"
395 | "torchtext"
396 | "torchvision"
397 | "vllm"
398 )
399 }
400
401 pub fn index_urls(&self) -> impl Iterator<Item = &IndexUrl> {
403 match self {
404 Self::Cuda {
405 os,
406 driver_version,
407 source,
408 } => {
409 match os {
414 Os::Manylinux { .. } | Os::Musllinux { .. } => {
415 Either::Left(Either::Left(Either::Left(
416 LINUX_CUDA_DRIVERS
417 .iter()
418 .filter_map(move |(backend, version)| {
419 if driver_version >= version {
420 Some(backend.index_url(*source))
421 } else {
422 None
423 }
424 })
425 .chain(std::iter::once(TorchBackend::Cpu.index_url(*source))),
426 )))
427 }
428 Os::Windows => Either::Left(Either::Left(Either::Right(
429 WINDOWS_CUDA_VERSIONS
430 .iter()
431 .filter_map(move |(backend, version)| {
432 if driver_version >= version {
433 Some(backend.index_url(*source))
434 } else {
435 None
436 }
437 })
438 .chain(std::iter::once(TorchBackend::Cpu.index_url(*source))),
439 ))),
440 Os::Macos { .. }
441 | Os::FreeBsd { .. }
442 | Os::NetBsd { .. }
443 | Os::OpenBsd { .. }
444 | Os::Dragonfly { .. }
445 | Os::Illumos { .. }
446 | Os::Haiku { .. }
447 | Os::Android { .. }
448 | Os::Pyodide { .. }
449 | Os::Ios { .. } => Either::Right(Either::Left(std::iter::once(
450 TorchBackend::Cpu.index_url(*source),
451 ))),
452 }
453 }
454 Self::Amd {
455 os,
456 gpu_architecture,
457 source,
458 } => match os {
459 Os::Manylinux { .. } | Os::Musllinux { .. } => Either::Left(Either::Right(
460 LINUX_AMD_GPU_DRIVERS
461 .iter()
462 .filter_map(move |(backend, architecture)| {
463 if gpu_architecture == architecture {
464 Some(backend.index_url(*source))
465 } else {
466 None
467 }
468 })
469 .chain(std::iter::once(TorchBackend::Cpu.index_url(*source))),
470 )),
471 Os::Windows
472 | Os::Macos { .. }
473 | Os::FreeBsd { .. }
474 | Os::NetBsd { .. }
475 | Os::OpenBsd { .. }
476 | Os::Dragonfly { .. }
477 | Os::Illumos { .. }
478 | Os::Haiku { .. }
479 | Os::Android { .. }
480 | Os::Pyodide { .. }
481 | Os::Ios { .. } => Either::Right(Either::Left(std::iter::once(
482 TorchBackend::Cpu.index_url(*source),
483 ))),
484 },
485 Self::Xpu { os, source } => match os {
486 Os::Manylinux { .. } | Os::Windows => Either::Right(Either::Right(Either::Left(
487 std::iter::once(TorchBackend::Xpu.index_url(*source)),
488 ))),
489 Os::Musllinux { .. }
490 | Os::Macos { .. }
491 | Os::FreeBsd { .. }
492 | Os::NetBsd { .. }
493 | Os::OpenBsd { .. }
494 | Os::Dragonfly { .. }
495 | Os::Illumos { .. }
496 | Os::Haiku { .. }
497 | Os::Android { .. }
498 | Os::Pyodide { .. }
499 | Os::Ios { .. } => Either::Right(Either::Left(std::iter::once(
500 TorchBackend::Cpu.index_url(*source),
501 ))),
502 },
503 Self::Backend { backend, source } => Either::Right(Either::Right(Either::Right(
504 std::iter::once(backend.index_url(*source)),
505 ))),
506 }
507 }
508}
509
/// A concrete PyTorch backend: CPU, a specific CUDA or ROCm release, or XPU.
///
/// Each backend has one index URL per [`TorchSource`] (see `TorchBackend::index_url`) and can be
/// parsed from its identifier (e.g. `cu128`, `rocm6.2.4`) through the `FromStr` implementation.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum TorchBackend {
    // CPU-only.
    Cpu,
    // CUDA releases, newest first (`Cu130` is CUDA 13.0, `Cu92` is CUDA 9.2, ...).
    Cu130,
    Cu129,
    Cu128,
    Cu126,
    Cu125,
    Cu124,
    Cu123,
    Cu122,
    Cu121,
    Cu120,
    Cu118,
    Cu117,
    Cu116,
    Cu115,
    Cu114,
    Cu113,
    Cu112,
    Cu111,
    Cu110,
    Cu102,
    Cu101,
    Cu100,
    Cu92,
    Cu91,
    Cu90,
    Cu80,
    // ROCm releases, newest first (`Rocm624` is ROCm 6.2.4, `Rocm57` is ROCm 5.7, ...).
    Rocm71,
    Rocm70,
    Rocm64,
    Rocm63,
    Rocm624,
    Rocm62,
    Rocm61,
    Rocm60,
    Rocm57,
    Rocm56,
    Rocm55,
    Rocm542,
    Rocm54,
    Rocm53,
    Rocm52,
    Rocm511,
    Rocm42,
    Rocm41,
    Rocm401,
    // XPU.
    Xpu,
}
561
562impl TorchBackend {
563 fn index_url(self, source: TorchSource) -> &'static IndexUrl {
565 match self {
566 Self::Cpu => match source {
567 TorchSource::PyTorch => &PYTORCH_CPU_INDEX_URL,
568 TorchSource::Pyx => &PYX_CPU_INDEX_URL,
569 },
570 Self::Cu130 => match source {
571 TorchSource::PyTorch => &PYTORCH_CU130_INDEX_URL,
572 TorchSource::Pyx => &PYX_CU130_INDEX_URL,
573 },
574 Self::Cu129 => match source {
575 TorchSource::PyTorch => &PYTORCH_CU129_INDEX_URL,
576 TorchSource::Pyx => &PYX_CU129_INDEX_URL,
577 },
578 Self::Cu128 => match source {
579 TorchSource::PyTorch => &PYTORCH_CU128_INDEX_URL,
580 TorchSource::Pyx => &PYX_CU128_INDEX_URL,
581 },
582 Self::Cu126 => match source {
583 TorchSource::PyTorch => &PYTORCH_CU126_INDEX_URL,
584 TorchSource::Pyx => &PYX_CU126_INDEX_URL,
585 },
586 Self::Cu125 => match source {
587 TorchSource::PyTorch => &PYTORCH_CU125_INDEX_URL,
588 TorchSource::Pyx => &PYX_CU125_INDEX_URL,
589 },
590 Self::Cu124 => match source {
591 TorchSource::PyTorch => &PYTORCH_CU124_INDEX_URL,
592 TorchSource::Pyx => &PYX_CU124_INDEX_URL,
593 },
594 Self::Cu123 => match source {
595 TorchSource::PyTorch => &PYTORCH_CU123_INDEX_URL,
596 TorchSource::Pyx => &PYX_CU123_INDEX_URL,
597 },
598 Self::Cu122 => match source {
599 TorchSource::PyTorch => &PYTORCH_CU122_INDEX_URL,
600 TorchSource::Pyx => &PYX_CU122_INDEX_URL,
601 },
602 Self::Cu121 => match source {
603 TorchSource::PyTorch => &PYTORCH_CU121_INDEX_URL,
604 TorchSource::Pyx => &PYX_CU121_INDEX_URL,
605 },
606 Self::Cu120 => match source {
607 TorchSource::PyTorch => &PYTORCH_CU120_INDEX_URL,
608 TorchSource::Pyx => &PYX_CU120_INDEX_URL,
609 },
610 Self::Cu118 => match source {
611 TorchSource::PyTorch => &PYTORCH_CU118_INDEX_URL,
612 TorchSource::Pyx => &PYX_CU118_INDEX_URL,
613 },
614 Self::Cu117 => match source {
615 TorchSource::PyTorch => &PYTORCH_CU117_INDEX_URL,
616 TorchSource::Pyx => &PYX_CU117_INDEX_URL,
617 },
618 Self::Cu116 => match source {
619 TorchSource::PyTorch => &PYTORCH_CU116_INDEX_URL,
620 TorchSource::Pyx => &PYX_CU116_INDEX_URL,
621 },
622 Self::Cu115 => match source {
623 TorchSource::PyTorch => &PYTORCH_CU115_INDEX_URL,
624 TorchSource::Pyx => &PYX_CU115_INDEX_URL,
625 },
626 Self::Cu114 => match source {
627 TorchSource::PyTorch => &PYTORCH_CU114_INDEX_URL,
628 TorchSource::Pyx => &PYX_CU114_INDEX_URL,
629 },
630 Self::Cu113 => match source {
631 TorchSource::PyTorch => &PYTORCH_CU113_INDEX_URL,
632 TorchSource::Pyx => &PYX_CU113_INDEX_URL,
633 },
634 Self::Cu112 => match source {
635 TorchSource::PyTorch => &PYTORCH_CU112_INDEX_URL,
636 TorchSource::Pyx => &PYX_CU112_INDEX_URL,
637 },
638 Self::Cu111 => match source {
639 TorchSource::PyTorch => &PYTORCH_CU111_INDEX_URL,
640 TorchSource::Pyx => &PYX_CU111_INDEX_URL,
641 },
642 Self::Cu110 => match source {
643 TorchSource::PyTorch => &PYTORCH_CU110_INDEX_URL,
644 TorchSource::Pyx => &PYX_CU110_INDEX_URL,
645 },
646 Self::Cu102 => match source {
647 TorchSource::PyTorch => &PYTORCH_CU102_INDEX_URL,
648 TorchSource::Pyx => &PYX_CU102_INDEX_URL,
649 },
650 Self::Cu101 => match source {
651 TorchSource::PyTorch => &PYTORCH_CU101_INDEX_URL,
652 TorchSource::Pyx => &PYX_CU101_INDEX_URL,
653 },
654 Self::Cu100 => match source {
655 TorchSource::PyTorch => &PYTORCH_CU100_INDEX_URL,
656 TorchSource::Pyx => &PYX_CU100_INDEX_URL,
657 },
658 Self::Cu92 => match source {
659 TorchSource::PyTorch => &PYTORCH_CU92_INDEX_URL,
660 TorchSource::Pyx => &PYX_CU92_INDEX_URL,
661 },
662 Self::Cu91 => match source {
663 TorchSource::PyTorch => &PYTORCH_CU91_INDEX_URL,
664 TorchSource::Pyx => &PYX_CU91_INDEX_URL,
665 },
666 Self::Cu90 => match source {
667 TorchSource::PyTorch => &PYTORCH_CU90_INDEX_URL,
668 TorchSource::Pyx => &PYX_CU90_INDEX_URL,
669 },
670 Self::Cu80 => match source {
671 TorchSource::PyTorch => &PYTORCH_CU80_INDEX_URL,
672 TorchSource::Pyx => &PYX_CU80_INDEX_URL,
673 },
674 Self::Rocm71 => match source {
675 TorchSource::PyTorch => &PYTORCH_ROCM71_INDEX_URL,
676 TorchSource::Pyx => &PYX_ROCM71_INDEX_URL,
677 },
678 Self::Rocm70 => match source {
679 TorchSource::PyTorch => &PYTORCH_ROCM70_INDEX_URL,
680 TorchSource::Pyx => &PYX_ROCM70_INDEX_URL,
681 },
682 Self::Rocm64 => match source {
683 TorchSource::PyTorch => &PYTORCH_ROCM64_INDEX_URL,
684 TorchSource::Pyx => &PYX_ROCM64_INDEX_URL,
685 },
686 Self::Rocm63 => match source {
687 TorchSource::PyTorch => &PYTORCH_ROCM63_INDEX_URL,
688 TorchSource::Pyx => &PYX_ROCM63_INDEX_URL,
689 },
690 Self::Rocm624 => match source {
691 TorchSource::PyTorch => &PYTORCH_ROCM624_INDEX_URL,
692 TorchSource::Pyx => &PYX_ROCM624_INDEX_URL,
693 },
694 Self::Rocm62 => match source {
695 TorchSource::PyTorch => &PYTORCH_ROCM62_INDEX_URL,
696 TorchSource::Pyx => &PYX_ROCM62_INDEX_URL,
697 },
698 Self::Rocm61 => match source {
699 TorchSource::PyTorch => &PYTORCH_ROCM61_INDEX_URL,
700 TorchSource::Pyx => &PYX_ROCM61_INDEX_URL,
701 },
702 Self::Rocm60 => match source {
703 TorchSource::PyTorch => &PYTORCH_ROCM60_INDEX_URL,
704 TorchSource::Pyx => &PYX_ROCM60_INDEX_URL,
705 },
706 Self::Rocm57 => match source {
707 TorchSource::PyTorch => &PYTORCH_ROCM57_INDEX_URL,
708 TorchSource::Pyx => &PYX_ROCM57_INDEX_URL,
709 },
710 Self::Rocm56 => match source {
711 TorchSource::PyTorch => &PYTORCH_ROCM56_INDEX_URL,
712 TorchSource::Pyx => &PYX_ROCM56_INDEX_URL,
713 },
714 Self::Rocm55 => match source {
715 TorchSource::PyTorch => &PYTORCH_ROCM55_INDEX_URL,
716 TorchSource::Pyx => &PYX_ROCM55_INDEX_URL,
717 },
718 Self::Rocm542 => match source {
719 TorchSource::PyTorch => &PYTORCH_ROCM542_INDEX_URL,
720 TorchSource::Pyx => &PYX_ROCM542_INDEX_URL,
721 },
722 Self::Rocm54 => match source {
723 TorchSource::PyTorch => &PYTORCH_ROCM54_INDEX_URL,
724 TorchSource::Pyx => &PYX_ROCM54_INDEX_URL,
725 },
726 Self::Rocm53 => match source {
727 TorchSource::PyTorch => &PYTORCH_ROCM53_INDEX_URL,
728 TorchSource::Pyx => &PYX_ROCM53_INDEX_URL,
729 },
730 Self::Rocm52 => match source {
731 TorchSource::PyTorch => &PYTORCH_ROCM52_INDEX_URL,
732 TorchSource::Pyx => &PYX_ROCM52_INDEX_URL,
733 },
734 Self::Rocm511 => match source {
735 TorchSource::PyTorch => &PYTORCH_ROCM511_INDEX_URL,
736 TorchSource::Pyx => &PYX_ROCM511_INDEX_URL,
737 },
738 Self::Rocm42 => match source {
739 TorchSource::PyTorch => &PYTORCH_ROCM42_INDEX_URL,
740 TorchSource::Pyx => &PYX_ROCM42_INDEX_URL,
741 },
742 Self::Rocm41 => match source {
743 TorchSource::PyTorch => &PYTORCH_ROCM41_INDEX_URL,
744 TorchSource::Pyx => &PYX_ROCM41_INDEX_URL,
745 },
746 Self::Rocm401 => match source {
747 TorchSource::PyTorch => &PYTORCH_ROCM401_INDEX_URL,
748 TorchSource::Pyx => &PYX_ROCM401_INDEX_URL,
749 },
750 Self::Xpu => match source {
751 TorchSource::PyTorch => &PYTORCH_XPU_INDEX_URL,
752 TorchSource::Pyx => &PYX_XPU_INDEX_URL,
753 },
754 }
755 }
756
757 pub fn from_index(index: &Url) -> Option<Self> {
759 let backend_identifier = if index.host_str() == Some("download.pytorch.org") {
760 let mut path_segments = index.path_segments()?;
762 if path_segments.next() != Some("whl") {
763 return None;
764 }
765 path_segments.next()?
766 } else if index.host_str() == PYX_API_BASE_URL.strip_prefix("https://") {
768 let mut path_segments = index.path_segments()?;
770 if path_segments.next() != Some("simple") {
771 return None;
772 }
773 if path_segments.next() != Some("astral-sh") {
774 return None;
775 }
776 path_segments.next()?
777 } else {
778 return None;
779 };
780 Self::from_str(backend_identifier).ok()
781 }
782
783 pub fn cuda_version(&self) -> Option<Version> {
785 match self {
786 Self::Cpu => None,
787 Self::Cu130 => Some(Version::new([13, 0])),
788 Self::Cu129 => Some(Version::new([12, 9])),
789 Self::Cu128 => Some(Version::new([12, 8])),
790 Self::Cu126 => Some(Version::new([12, 6])),
791 Self::Cu125 => Some(Version::new([12, 5])),
792 Self::Cu124 => Some(Version::new([12, 4])),
793 Self::Cu123 => Some(Version::new([12, 3])),
794 Self::Cu122 => Some(Version::new([12, 2])),
795 Self::Cu121 => Some(Version::new([12, 1])),
796 Self::Cu120 => Some(Version::new([12, 0])),
797 Self::Cu118 => Some(Version::new([11, 8])),
798 Self::Cu117 => Some(Version::new([11, 7])),
799 Self::Cu116 => Some(Version::new([11, 6])),
800 Self::Cu115 => Some(Version::new([11, 5])),
801 Self::Cu114 => Some(Version::new([11, 4])),
802 Self::Cu113 => Some(Version::new([11, 3])),
803 Self::Cu112 => Some(Version::new([11, 2])),
804 Self::Cu111 => Some(Version::new([11, 1])),
805 Self::Cu110 => Some(Version::new([11, 0])),
806 Self::Cu102 => Some(Version::new([10, 2])),
807 Self::Cu101 => Some(Version::new([10, 1])),
808 Self::Cu100 => Some(Version::new([10, 0])),
809 Self::Cu92 => Some(Version::new([9, 2])),
810 Self::Cu91 => Some(Version::new([9, 1])),
811 Self::Cu90 => Some(Version::new([9, 0])),
812 Self::Cu80 => Some(Version::new([8, 0])),
813 Self::Rocm71 => None,
814 Self::Rocm70 => None,
815 Self::Rocm64 => None,
816 Self::Rocm63 => None,
817 Self::Rocm624 => None,
818 Self::Rocm62 => None,
819 Self::Rocm61 => None,
820 Self::Rocm60 => None,
821 Self::Rocm57 => None,
822 Self::Rocm56 => None,
823 Self::Rocm55 => None,
824 Self::Rocm542 => None,
825 Self::Rocm54 => None,
826 Self::Rocm53 => None,
827 Self::Rocm52 => None,
828 Self::Rocm511 => None,
829 Self::Rocm42 => None,
830 Self::Rocm41 => None,
831 Self::Rocm401 => None,
832 Self::Xpu => None,
833 }
834 }
835
836 pub fn rocm_version(&self) -> Option<Version> {
838 match self {
839 Self::Cpu => None,
840 Self::Cu130 => None,
841 Self::Cu129 => None,
842 Self::Cu128 => None,
843 Self::Cu126 => None,
844 Self::Cu125 => None,
845 Self::Cu124 => None,
846 Self::Cu123 => None,
847 Self::Cu122 => None,
848 Self::Cu121 => None,
849 Self::Cu120 => None,
850 Self::Cu118 => None,
851 Self::Cu117 => None,
852 Self::Cu116 => None,
853 Self::Cu115 => None,
854 Self::Cu114 => None,
855 Self::Cu113 => None,
856 Self::Cu112 => None,
857 Self::Cu111 => None,
858 Self::Cu110 => None,
859 Self::Cu102 => None,
860 Self::Cu101 => None,
861 Self::Cu100 => None,
862 Self::Cu92 => None,
863 Self::Cu91 => None,
864 Self::Cu90 => None,
865 Self::Cu80 => None,
866 Self::Rocm71 => Some(Version::new([7, 1])),
867 Self::Rocm70 => Some(Version::new([7, 0])),
868 Self::Rocm64 => Some(Version::new([6, 4])),
869 Self::Rocm63 => Some(Version::new([6, 3])),
870 Self::Rocm624 => Some(Version::new([6, 2, 4])),
871 Self::Rocm62 => Some(Version::new([6, 2])),
872 Self::Rocm61 => Some(Version::new([6, 1])),
873 Self::Rocm60 => Some(Version::new([6, 0])),
874 Self::Rocm57 => Some(Version::new([5, 7])),
875 Self::Rocm56 => Some(Version::new([5, 6])),
876 Self::Rocm55 => Some(Version::new([5, 5])),
877 Self::Rocm542 => Some(Version::new([5, 4, 2])),
878 Self::Rocm54 => Some(Version::new([5, 4])),
879 Self::Rocm53 => Some(Version::new([5, 3])),
880 Self::Rocm52 => Some(Version::new([5, 2])),
881 Self::Rocm511 => Some(Version::new([5, 1, 1])),
882 Self::Rocm42 => Some(Version::new([4, 2])),
883 Self::Rocm41 => Some(Version::new([4, 1])),
884 Self::Rocm401 => Some(Version::new([4, 0, 1])),
885 Self::Xpu => None,
886 }
887 }
888}
889
890impl FromStr for TorchBackend {
891 type Err = String;
892
893 fn from_str(s: &str) -> Result<Self, Self::Err> {
894 match s {
895 "cpu" => Ok(Self::Cpu),
896 "cu130" => Ok(Self::Cu130),
897 "cu129" => Ok(Self::Cu129),
898 "cu128" => Ok(Self::Cu128),
899 "cu126" => Ok(Self::Cu126),
900 "cu125" => Ok(Self::Cu125),
901 "cu124" => Ok(Self::Cu124),
902 "cu123" => Ok(Self::Cu123),
903 "cu122" => Ok(Self::Cu122),
904 "cu121" => Ok(Self::Cu121),
905 "cu120" => Ok(Self::Cu120),
906 "cu118" => Ok(Self::Cu118),
907 "cu117" => Ok(Self::Cu117),
908 "cu116" => Ok(Self::Cu116),
909 "cu115" => Ok(Self::Cu115),
910 "cu114" => Ok(Self::Cu114),
911 "cu113" => Ok(Self::Cu113),
912 "cu112" => Ok(Self::Cu112),
913 "cu111" => Ok(Self::Cu111),
914 "cu110" => Ok(Self::Cu110),
915 "cu102" => Ok(Self::Cu102),
916 "cu101" => Ok(Self::Cu101),
917 "cu100" => Ok(Self::Cu100),
918 "cu92" => Ok(Self::Cu92),
919 "cu91" => Ok(Self::Cu91),
920 "cu90" => Ok(Self::Cu90),
921 "cu80" => Ok(Self::Cu80),
922 "rocm7.1" => Ok(Self::Rocm71),
923 "rocm7.0" => Ok(Self::Rocm70),
924 "rocm6.4" => Ok(Self::Rocm64),
925 "rocm6.3" => Ok(Self::Rocm63),
926 "rocm6.2.4" => Ok(Self::Rocm624),
927 "rocm6.2" => Ok(Self::Rocm62),
928 "rocm6.1" => Ok(Self::Rocm61),
929 "rocm6.0" => Ok(Self::Rocm60),
930 "rocm5.7" => Ok(Self::Rocm57),
931 "rocm5.6" => Ok(Self::Rocm56),
932 "rocm5.5" => Ok(Self::Rocm55),
933 "rocm5.4.2" => Ok(Self::Rocm542),
934 "rocm5.4" => Ok(Self::Rocm54),
935 "rocm5.3" => Ok(Self::Rocm53),
936 "rocm5.2" => Ok(Self::Rocm52),
937 "rocm5.1.1" => Ok(Self::Rocm511),
938 "rocm4.2" => Ok(Self::Rocm42),
939 "rocm4.1" => Ok(Self::Rocm41),
940 "rocm4.0.1" => Ok(Self::Rocm401),
941 "xpu" => Ok(Self::Xpu),
942 _ => Err(format!("Unknown PyTorch backend: {s}")),
943 }
944 }
945}
946
/// Minimum Linux driver version paired with each CUDA backend, ordered from the newest backend
/// to the oldest.
///
/// [`TorchStrategy::index_urls`] walks this table in order and keeps every backend whose
/// minimum version is less than or equal to the detected driver version, so the order here is
/// the order in which the candidate indexes are yielded.
// NOTE(review): the version numbers presumably come from NVIDIA's CUDA toolkit release notes —
// confirm against the upstream tables before editing.
static LINUX_CUDA_DRIVERS: LazyLock<[(TorchBackend, Version); 26]> = LazyLock::new(|| {
    [
        // CUDA 13.x
        (TorchBackend::Cu130, Version::new([580])),
        // CUDA 12.x
        (TorchBackend::Cu129, Version::new([525, 60, 13])),
        (TorchBackend::Cu128, Version::new([525, 60, 13])),
        (TorchBackend::Cu126, Version::new([525, 60, 13])),
        (TorchBackend::Cu125, Version::new([525, 60, 13])),
        (TorchBackend::Cu124, Version::new([525, 60, 13])),
        (TorchBackend::Cu123, Version::new([525, 60, 13])),
        (TorchBackend::Cu122, Version::new([525, 60, 13])),
        (TorchBackend::Cu121, Version::new([525, 60, 13])),
        (TorchBackend::Cu120, Version::new([525, 60, 13])),
        // CUDA 11.x
        (TorchBackend::Cu118, Version::new([450, 80, 2])),
        (TorchBackend::Cu117, Version::new([450, 80, 2])),
        (TorchBackend::Cu116, Version::new([450, 80, 2])),
        (TorchBackend::Cu115, Version::new([450, 80, 2])),
        (TorchBackend::Cu114, Version::new([450, 80, 2])),
        (TorchBackend::Cu113, Version::new([450, 80, 2])),
        (TorchBackend::Cu112, Version::new([450, 80, 2])),
        (TorchBackend::Cu111, Version::new([450, 80, 2])),
        (TorchBackend::Cu110, Version::new([450, 36, 6])),
        // CUDA 10.x and older
        (TorchBackend::Cu102, Version::new([440, 33])),
        (TorchBackend::Cu101, Version::new([418, 39])),
        (TorchBackend::Cu100, Version::new([410, 48])),
        (TorchBackend::Cu92, Version::new([396, 26])),
        (TorchBackend::Cu91, Version::new([390, 46])),
        (TorchBackend::Cu90, Version::new([384, 81])),
        (TorchBackend::Cu80, Version::new([375, 26])),
    ]
});
986
/// Minimum Windows driver version paired with each CUDA backend, ordered from the newest
/// backend to the oldest.
///
/// [`TorchStrategy::index_urls`] walks this table in order and keeps every backend whose
/// minimum version is less than or equal to the detected driver version, so the order here is
/// the order in which the candidate indexes are yielded.
// NOTE(review): the version numbers presumably come from NVIDIA's CUDA toolkit release notes —
// confirm against the upstream tables before editing.
static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 26]> = LazyLock::new(|| {
    [
        // CUDA 13.x
        (TorchBackend::Cu130, Version::new([580])),
        // CUDA 12.x
        (TorchBackend::Cu129, Version::new([528, 33])),
        (TorchBackend::Cu128, Version::new([528, 33])),
        (TorchBackend::Cu126, Version::new([528, 33])),
        (TorchBackend::Cu125, Version::new([528, 33])),
        (TorchBackend::Cu124, Version::new([528, 33])),
        (TorchBackend::Cu123, Version::new([528, 33])),
        (TorchBackend::Cu122, Version::new([528, 33])),
        (TorchBackend::Cu121, Version::new([528, 33])),
        (TorchBackend::Cu120, Version::new([528, 33])),
        // CUDA 11.x
        (TorchBackend::Cu118, Version::new([452, 39])),
        (TorchBackend::Cu117, Version::new([452, 39])),
        (TorchBackend::Cu116, Version::new([452, 39])),
        (TorchBackend::Cu115, Version::new([452, 39])),
        (TorchBackend::Cu114, Version::new([452, 39])),
        (TorchBackend::Cu113, Version::new([452, 39])),
        (TorchBackend::Cu112, Version::new([452, 39])),
        (TorchBackend::Cu111, Version::new([452, 39])),
        (TorchBackend::Cu110, Version::new([451, 22])),
        // CUDA 10.x and older
        (TorchBackend::Cu102, Version::new([441, 22])),
        (TorchBackend::Cu101, Version::new([418, 96])),
        (TorchBackend::Cu100, Version::new([411, 31])),
        (TorchBackend::Cu92, Version::new([398, 26])),
        (TorchBackend::Cu91, Version::new([391, 29])),
        (TorchBackend::Cu90, Version::new([385, 54])),
        (TorchBackend::Cu80, Version::new([376, 51])),
    ]
});
1026
/// `(backend, architecture)` pairs listing which ROCm backends are offered for which AMD GPU
/// architecture, grouped by backend from the newest ROCm release to the oldest.
///
/// [`TorchStrategy::index_urls`] walks this table in order and keeps the backend of every entry
/// whose architecture equals the detected one, so for a given architecture the candidate
/// indexes are yielded newest backend first.
// NOTE(review): the architecture lists presumably mirror the GPU targets of the upstream ROCm
// builds — confirm against the upstream source before editing.
static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 79]> =
    LazyLock::new(|| {
        [
            // ROCm 7.1
            (TorchBackend::Rocm71, AmdGpuArchitecture::Gfx900),
            (TorchBackend::Rocm71, AmdGpuArchitecture::Gfx906),
            (TorchBackend::Rocm71, AmdGpuArchitecture::Gfx908),
            (TorchBackend::Rocm71, AmdGpuArchitecture::Gfx90a),
            (TorchBackend::Rocm71, AmdGpuArchitecture::Gfx942),
            (TorchBackend::Rocm71, AmdGpuArchitecture::Gfx950),
            (TorchBackend::Rocm71, AmdGpuArchitecture::Gfx1030),
            (TorchBackend::Rocm71, AmdGpuArchitecture::Gfx1100),
            (TorchBackend::Rocm71, AmdGpuArchitecture::Gfx1101),
            (TorchBackend::Rocm71, AmdGpuArchitecture::Gfx1102),
            (TorchBackend::Rocm71, AmdGpuArchitecture::Gfx1200),
            (TorchBackend::Rocm71, AmdGpuArchitecture::Gfx1201),
            // ROCm 7.0
            (TorchBackend::Rocm70, AmdGpuArchitecture::Gfx900),
            (TorchBackend::Rocm70, AmdGpuArchitecture::Gfx906),
            (TorchBackend::Rocm70, AmdGpuArchitecture::Gfx908),
            (TorchBackend::Rocm70, AmdGpuArchitecture::Gfx90a),
            (TorchBackend::Rocm70, AmdGpuArchitecture::Gfx942),
            (TorchBackend::Rocm70, AmdGpuArchitecture::Gfx950),
            (TorchBackend::Rocm70, AmdGpuArchitecture::Gfx1030),
            (TorchBackend::Rocm70, AmdGpuArchitecture::Gfx1100),
            (TorchBackend::Rocm70, AmdGpuArchitecture::Gfx1101),
            (TorchBackend::Rocm70, AmdGpuArchitecture::Gfx1102),
            (TorchBackend::Rocm70, AmdGpuArchitecture::Gfx1200),
            (TorchBackend::Rocm70, AmdGpuArchitecture::Gfx1201),
            // ROCm 6.4
            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx900),
            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx906),
            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx908),
            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx90a),
            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx942),
            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1030),
            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1100),
            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1101),
            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1102),
            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1200),
            (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1201),
            // ROCm 6.3
            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx900),
            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx906),
            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx908),
            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx90a),
            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx942),
            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1030),
            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1100),
            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1101),
            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1102),
            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1200),
            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1201),
            // ROCm 6.2.4
            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx900),
            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx906),
            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx908),
            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx90a),
            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx942),
            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1030),
            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1100),
            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1101),
            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1102),
            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1200),
            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1201),
            // ROCm 6.2
            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx900),
            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx906),
            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx908),
            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx90a),
            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx1030),
            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx1100),
            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx942),
            // ROCm 6.1
            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx900),
            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx906),
            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx908),
            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx90a),
            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx942),
            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1030),
            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1100),
            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1101),
            // ROCm 6.0
            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx900),
            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx906),
            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx908),
            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx90a),
            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx1030),
            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx1100),
            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx942),
        ]
    });
1131
// Index URLs on `download.pytorch.org`, one per backend, each built the first time it is
// dereferenced. The `unwrap` can only panic if the hard-coded literal fails to parse as an
// `IndexUrl`.

/// Index URL for the CPU backend.
static PYTORCH_CPU_INDEX_URL: LazyLock<IndexUrl> =
    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cpu").unwrap());
/// Index URL for the CUDA 13.0 backend.
static PYTORCH_CU130_INDEX_URL: LazyLock<IndexUrl> =
    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu130").unwrap());
/// Index URL for the CUDA 12.9 backend.
static PYTORCH_CU129_INDEX_URL: LazyLock<IndexUrl> =
    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu129").unwrap());
/// Index URL for the CUDA 12.8 backend.
static PYTORCH_CU128_INDEX_URL: LazyLock<IndexUrl> =
    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu128").unwrap());
/// Index URL for the CUDA 12.6 backend.
static PYTORCH_CU126_INDEX_URL: LazyLock<IndexUrl> =
    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu126").unwrap());
/// Index URL for the CUDA 12.5 backend.
static PYTORCH_CU125_INDEX_URL: LazyLock<IndexUrl> =
    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu125").unwrap());
/// Index URL for the CUDA 12.4 backend.
static PYTORCH_CU124_INDEX_URL: LazyLock<IndexUrl> =
    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu124").unwrap());
/// Index URL for the CUDA 12.3 backend.
static PYTORCH_CU123_INDEX_URL: LazyLock<IndexUrl> =
    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu123").unwrap());
/// Index URL for the CUDA 12.2 backend.
static PYTORCH_CU122_INDEX_URL: LazyLock<IndexUrl> =
    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu122").unwrap());
/// Index URL for the CUDA 12.1 backend.
static PYTORCH_CU121_INDEX_URL: LazyLock<IndexUrl> =
    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu121").unwrap());
/// Index URL for the CUDA 12.0 backend.
static PYTORCH_CU120_INDEX_URL: LazyLock<IndexUrl> =
    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu120").unwrap());
/// Index URL for the CUDA 11.8 backend.
static PYTORCH_CU118_INDEX_URL: LazyLock<IndexUrl> =
    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu118").unwrap());
/// Index URL for the CUDA 11.7 backend.
static PYTORCH_CU117_INDEX_URL: LazyLock<IndexUrl> =
    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu117").unwrap());
/// Index URL for the CUDA 11.6 backend.
static PYTORCH_CU116_INDEX_URL: LazyLock<IndexUrl> =
    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu116").unwrap());
/// Index URL for the CUDA 11.5 backend.
static PYTORCH_CU115_INDEX_URL: LazyLock<IndexUrl> =
    LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu115").unwrap());
1162static PYTORCH_CU114_INDEX_URL: LazyLock<IndexUrl> =
1163 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu114").unwrap());
1164static PYTORCH_CU113_INDEX_URL: LazyLock<IndexUrl> =
1165 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu113").unwrap());
1166static PYTORCH_CU112_INDEX_URL: LazyLock<IndexUrl> =
1167 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu112").unwrap());
1168static PYTORCH_CU111_INDEX_URL: LazyLock<IndexUrl> =
1169 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu111").unwrap());
1170static PYTORCH_CU110_INDEX_URL: LazyLock<IndexUrl> =
1171 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu110").unwrap());
1172static PYTORCH_CU102_INDEX_URL: LazyLock<IndexUrl> =
1173 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu102").unwrap());
1174static PYTORCH_CU101_INDEX_URL: LazyLock<IndexUrl> =
1175 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu101").unwrap());
1176static PYTORCH_CU100_INDEX_URL: LazyLock<IndexUrl> =
1177 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu100").unwrap());
1178static PYTORCH_CU92_INDEX_URL: LazyLock<IndexUrl> =
1179 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu92").unwrap());
1180static PYTORCH_CU91_INDEX_URL: LazyLock<IndexUrl> =
1181 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu91").unwrap());
1182static PYTORCH_CU90_INDEX_URL: LazyLock<IndexUrl> =
1183 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu90").unwrap());
1184static PYTORCH_CU80_INDEX_URL: LazyLock<IndexUrl> =
1185 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu80").unwrap());
1186static PYTORCH_ROCM71_INDEX_URL: LazyLock<IndexUrl> =
1187 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm7.1").unwrap());
1188static PYTORCH_ROCM70_INDEX_URL: LazyLock<IndexUrl> =
1189 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm7.0").unwrap());
1190static PYTORCH_ROCM64_INDEX_URL: LazyLock<IndexUrl> =
1191 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.4").unwrap());
1192static PYTORCH_ROCM63_INDEX_URL: LazyLock<IndexUrl> =
1193 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.3").unwrap());
1194static PYTORCH_ROCM624_INDEX_URL: LazyLock<IndexUrl> =
1195 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.2.4").unwrap());
1196static PYTORCH_ROCM62_INDEX_URL: LazyLock<IndexUrl> =
1197 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.2").unwrap());
1198static PYTORCH_ROCM61_INDEX_URL: LazyLock<IndexUrl> =
1199 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.1").unwrap());
1200static PYTORCH_ROCM60_INDEX_URL: LazyLock<IndexUrl> =
1201 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.0").unwrap());
1202static PYTORCH_ROCM57_INDEX_URL: LazyLock<IndexUrl> =
1203 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.7").unwrap());
1204static PYTORCH_ROCM56_INDEX_URL: LazyLock<IndexUrl> =
1205 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.6").unwrap());
1206static PYTORCH_ROCM55_INDEX_URL: LazyLock<IndexUrl> =
1207 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.5").unwrap());
1208static PYTORCH_ROCM542_INDEX_URL: LazyLock<IndexUrl> =
1209 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.4.2").unwrap());
1210static PYTORCH_ROCM54_INDEX_URL: LazyLock<IndexUrl> =
1211 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.4").unwrap());
1212static PYTORCH_ROCM53_INDEX_URL: LazyLock<IndexUrl> =
1213 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.3").unwrap());
1214static PYTORCH_ROCM52_INDEX_URL: LazyLock<IndexUrl> =
1215 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.2").unwrap());
1216static PYTORCH_ROCM511_INDEX_URL: LazyLock<IndexUrl> =
1217 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.1.1").unwrap());
1218static PYTORCH_ROCM42_INDEX_URL: LazyLock<IndexUrl> =
1219 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.2").unwrap());
1220static PYTORCH_ROCM41_INDEX_URL: LazyLock<IndexUrl> =
1221 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.1").unwrap());
1222static PYTORCH_ROCM401_INDEX_URL: LazyLock<IndexUrl> =
1223 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.0.1").unwrap());
1224static PYTORCH_XPU_INDEX_URL: LazyLock<IndexUrl> =
1225 LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/xpu").unwrap());
1226
1227static PYX_API_BASE_URL: LazyLock<Cow<'static, str>> = LazyLock::new(|| {
1228 std::env::var(EnvVars::PYX_API_URL)
1229 .map(Cow::Owned)
1230 .unwrap_or(Cow::Borrowed("https://api.pyx.dev"))
1231});
1232static PYX_CPU_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1233 let api_base_url = &*PYX_API_BASE_URL;
1234 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cpu")).unwrap()
1235});
1236static PYX_CU130_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1237 let api_base_url = &*PYX_API_BASE_URL;
1238 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu130")).unwrap()
1239});
1240static PYX_CU129_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1241 let api_base_url = &*PYX_API_BASE_URL;
1242 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu129")).unwrap()
1243});
1244static PYX_CU128_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1245 let api_base_url = &*PYX_API_BASE_URL;
1246 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu128")).unwrap()
1247});
1248static PYX_CU126_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1249 let api_base_url = &*PYX_API_BASE_URL;
1250 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu126")).unwrap()
1251});
1252static PYX_CU125_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1253 let api_base_url = &*PYX_API_BASE_URL;
1254 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu125")).unwrap()
1255});
1256static PYX_CU124_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1257 let api_base_url = &*PYX_API_BASE_URL;
1258 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu124")).unwrap()
1259});
1260static PYX_CU123_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1261 let api_base_url = &*PYX_API_BASE_URL;
1262 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu123")).unwrap()
1263});
1264static PYX_CU122_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1265 let api_base_url = &*PYX_API_BASE_URL;
1266 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu122")).unwrap()
1267});
1268static PYX_CU121_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1269 let api_base_url = &*PYX_API_BASE_URL;
1270 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu121")).unwrap()
1271});
1272static PYX_CU120_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1273 let api_base_url = &*PYX_API_BASE_URL;
1274 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu120")).unwrap()
1275});
1276static PYX_CU118_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1277 let api_base_url = &*PYX_API_BASE_URL;
1278 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu118")).unwrap()
1279});
1280static PYX_CU117_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1281 let api_base_url = &*PYX_API_BASE_URL;
1282 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu117")).unwrap()
1283});
1284static PYX_CU116_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1285 let api_base_url = &*PYX_API_BASE_URL;
1286 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu116")).unwrap()
1287});
1288static PYX_CU115_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1289 let api_base_url = &*PYX_API_BASE_URL;
1290 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu115")).unwrap()
1291});
1292static PYX_CU114_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1293 let api_base_url = &*PYX_API_BASE_URL;
1294 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu114")).unwrap()
1295});
1296static PYX_CU113_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1297 let api_base_url = &*PYX_API_BASE_URL;
1298 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu113")).unwrap()
1299});
1300static PYX_CU112_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1301 let api_base_url = &*PYX_API_BASE_URL;
1302 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu112")).unwrap()
1303});
1304static PYX_CU111_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1305 let api_base_url = &*PYX_API_BASE_URL;
1306 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu111")).unwrap()
1307});
1308static PYX_CU110_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1309 let api_base_url = &*PYX_API_BASE_URL;
1310 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu110")).unwrap()
1311});
1312static PYX_CU102_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1313 let api_base_url = &*PYX_API_BASE_URL;
1314 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu102")).unwrap()
1315});
1316static PYX_CU101_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1317 let api_base_url = &*PYX_API_BASE_URL;
1318 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu101")).unwrap()
1319});
1320static PYX_CU100_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1321 let api_base_url = &*PYX_API_BASE_URL;
1322 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu100")).unwrap()
1323});
1324static PYX_CU92_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1325 let api_base_url = &*PYX_API_BASE_URL;
1326 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu92")).unwrap()
1327});
1328static PYX_CU91_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1329 let api_base_url = &*PYX_API_BASE_URL;
1330 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu91")).unwrap()
1331});
1332static PYX_CU90_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1333 let api_base_url = &*PYX_API_BASE_URL;
1334 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu90")).unwrap()
1335});
1336static PYX_CU80_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1337 let api_base_url = &*PYX_API_BASE_URL;
1338 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu80")).unwrap()
1339});
1340static PYX_ROCM71_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1341 let api_base_url = &*PYX_API_BASE_URL;
1342 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm7.1")).unwrap()
1343});
1344static PYX_ROCM70_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1345 let api_base_url = &*PYX_API_BASE_URL;
1346 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm7.0")).unwrap()
1347});
1348static PYX_ROCM64_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1349 let api_base_url = &*PYX_API_BASE_URL;
1350 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.4")).unwrap()
1351});
1352static PYX_ROCM63_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1353 let api_base_url = &*PYX_API_BASE_URL;
1354 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.3")).unwrap()
1355});
1356static PYX_ROCM624_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1357 let api_base_url = &*PYX_API_BASE_URL;
1358 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.2.4")).unwrap()
1359});
1360static PYX_ROCM62_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1361 let api_base_url = &*PYX_API_BASE_URL;
1362 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.2")).unwrap()
1363});
1364static PYX_ROCM61_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1365 let api_base_url = &*PYX_API_BASE_URL;
1366 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.1")).unwrap()
1367});
1368static PYX_ROCM60_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1369 let api_base_url = &*PYX_API_BASE_URL;
1370 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.0")).unwrap()
1371});
1372static PYX_ROCM57_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1373 let api_base_url = &*PYX_API_BASE_URL;
1374 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.7")).unwrap()
1375});
1376static PYX_ROCM56_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1377 let api_base_url = &*PYX_API_BASE_URL;
1378 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.6")).unwrap()
1379});
1380static PYX_ROCM55_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1381 let api_base_url = &*PYX_API_BASE_URL;
1382 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.5")).unwrap()
1383});
1384static PYX_ROCM542_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1385 let api_base_url = &*PYX_API_BASE_URL;
1386 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.4.2")).unwrap()
1387});
1388static PYX_ROCM54_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1389 let api_base_url = &*PYX_API_BASE_URL;
1390 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.4")).unwrap()
1391});
1392static PYX_ROCM53_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1393 let api_base_url = &*PYX_API_BASE_URL;
1394 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.3")).unwrap()
1395});
1396static PYX_ROCM52_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1397 let api_base_url = &*PYX_API_BASE_URL;
1398 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.2")).unwrap()
1399});
1400static PYX_ROCM511_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1401 let api_base_url = &*PYX_API_BASE_URL;
1402 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.1.1")).unwrap()
1403});
1404static PYX_ROCM42_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1405 let api_base_url = &*PYX_API_BASE_URL;
1406 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm4.2")).unwrap()
1407});
1408static PYX_ROCM41_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1409 let api_base_url = &*PYX_API_BASE_URL;
1410 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm4.1")).unwrap()
1411});
1412static PYX_ROCM401_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1413 let api_base_url = &*PYX_API_BASE_URL;
1414 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm4.0.1")).unwrap()
1415});
1416static PYX_XPU_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
1417 let api_base_url = &*PYX_API_BASE_URL;
1418 IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/xpu")).unwrap()
1419});