tenflowers_core/device/
mod.rs1pub mod async_execution;
2pub mod context;
3pub mod placement;
4
5pub use context::{
6 CpuContext, DeviceAllocator, DeviceContext, DeviceKernel, DeviceManager, DeviceProperties,
7 DeviceStream, KernelArgs, KernelParam, DEVICE_MANAGER,
8};
9
10#[cfg(feature = "gpu")]
11pub use context::{get_gpu_context, GpuContext, GpuContextInfo};
12
13#[cfg(any(feature = "gpu", feature = "cudnn"))]
14pub use context::{get_enhanced_gpu_context, EnhancedGpuContext, GpuBackend};
15
16#[cfg(feature = "serialize")]
17use serde::{Deserialize, Serialize};
18
/// Identifies a device: the CPU, or a GPU / ROCm device selected by numeric ID.
///
/// `Cpu` always exists and is the `Default`. The `Gpu` and `Rocm` variants are
/// only compiled in when the `gpu` / `rocm` cargo features are enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum Device {
    /// The CPU (default device).
    #[default]
    Cpu,
    /// A GPU with the given device ID (requires the `gpu` feature).
    #[cfg(feature = "gpu")]
    Gpu(usize),
    /// A ROCm device with the given device ID (requires the `rocm` feature).
    #[cfg(feature = "rocm")]
    Rocm(usize),
}

impl Device {
    /// Returns `true` if this is `Device::Cpu`.
    pub fn is_cpu(&self) -> bool {
        matches!(self, Device::Cpu)
    }

    /// Returns `true` if this is a `Device::Gpu`.
    ///
    /// Callable in every build configuration, like `best_gpu` / `try_gpu`:
    /// when the `gpu` feature is off the variant does not exist, so the
    /// result is always `false`.
    pub fn is_gpu(&self) -> bool {
        match self {
            #[cfg(feature = "gpu")]
            Device::Gpu(_) => true,
            // Catch-all on purpose: the set of other variants depends on features.
            _ => false,
        }
    }

    /// Returns `true` if this is a `Device::Rocm`.
    ///
    /// Callable in every build configuration; always `false` when the `rocm`
    /// feature is off because the variant does not exist.
    pub fn is_rocm(&self) -> bool {
        match self {
            #[cfg(feature = "rocm")]
            Device::Rocm(_) => true,
            // Catch-all on purpose: the set of other variants depends on features.
            _ => false,
        }
    }

    /// Returns the numeric device ID: `0` for the CPU, otherwise the ID stored
    /// in the variant.
    pub fn id(&self) -> usize {
        match self {
            Device::Cpu => 0,
            #[cfg(feature = "gpu")]
            Device::Gpu(id) => *id,
            #[cfg(feature = "rocm")]
            Device::Rocm(id) => *id,
        }
    }

    /// Parses a device from its textual name.
    ///
    /// The input is trimmed and lower-cased before matching. Accepted forms:
    ///
    /// * `"cpu"`
    /// * with the `gpu` feature: `"gpu"` (ID 0) or `"gpu:<id>"`
    /// * with the `rocm` feature: `"rocm"` / `"amd"` (ID 0), `"rocm:<id>"` or
    ///   `"amd:<id>"`
    ///
    /// # Errors
    ///
    /// Returns a message when the ID after the colon is not a valid `usize`,
    /// or when the string matches none of the forms enabled in this build.
    /// The message echoes the normalized (trimmed, lower-cased) input.
    // The name deliberately mirrors `FromStr::from_str`; the lint is silenced
    // instead of renaming so existing `Device::from_str(..)` callers keep working.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Result<Self, String> {
        let s = s.trim().to_lowercase();

        if s == "cpu" {
            return Ok(Device::Cpu);
        }

        #[cfg(feature = "gpu")]
        {
            if s == "gpu" {
                return Ok(Device::Gpu(0));
            }
            if let Some(id_str) = s.strip_prefix("gpu:") {
                return id_str
                    .parse::<usize>()
                    .map(Device::Gpu)
                    .map_err(|_| format!("Invalid GPU ID: {}", id_str));
            }
        }

        #[cfg(feature = "rocm")]
        {
            if s == "rocm" || s == "amd" {
                return Ok(Device::Rocm(0));
            }
            if let Some(id_str) = s.strip_prefix("rocm:") {
                return id_str
                    .parse::<usize>()
                    .map(Device::Rocm)
                    .map_err(|_| format!("Invalid ROCm device ID: {}", id_str));
            }
            if let Some(id_str) = s.strip_prefix("amd:") {
                return id_str
                    .parse::<usize>()
                    .map(Device::Rocm)
                    .map_err(|_| format!("Invalid AMD GPU ID: {}", id_str));
            }
        }

        Err(format!("Invalid device string: {s}"))
    }

    /// Returns the preferred GPU device.
    ///
    /// Currently this simply delegates to `try_gpu(0)`.
    #[cfg(feature = "gpu")]
    pub fn best_gpu() -> Result<Self, String> {
        Self::try_gpu(0)
    }

    /// Returns `Device::Gpu(gpu_id)`.
    ///
    /// No check is performed here that a device with this ID actually exists.
    #[cfg(feature = "gpu")]
    pub fn try_gpu(gpu_id: usize) -> Result<Self, String> {
        Ok(Device::Gpu(gpu_id))
    }

    /// Fallback for builds without the `gpu` feature: always fails.
    ///
    /// # Errors
    ///
    /// Always returns `"GPU support not compiled"`.
    #[cfg(not(feature = "gpu"))]
    pub fn best_gpu() -> Result<Self, String> {
        Err("GPU support not compiled".to_string())
    }

    /// Fallback for builds without the `gpu` feature: always fails.
    ///
    /// # Errors
    ///
    /// Always returns `"GPU support not compiled"`, regardless of `_gpu_id`.
    #[cfg(not(feature = "gpu"))]
    pub fn try_gpu(_gpu_id: usize) -> Result<Self, String> {
        Err("GPU support not compiled".to_string())
    }
}
130
131impl std::fmt::Display for Device {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 match self {
134 Device::Cpu => write!(f, "cpu"),
135 #[cfg(feature = "gpu")]
136 Device::Gpu(id) => write!(f, "gpu:{}", id),
137 #[cfg(feature = "rocm")]
138 Device::Rocm(id) => write!(f, "rocm:{}", id),
139 }
140 }
141}