Skip to main content

tenflowers_core/device/mod.rs

1pub mod async_execution;
2pub mod context;
3pub mod placement;
4
5pub use context::{
6    CpuContext, DeviceAllocator, DeviceContext, DeviceKernel, DeviceManager, DeviceProperties,
7    DeviceStream, KernelArgs, KernelParam, DEVICE_MANAGER,
8};
9
10#[cfg(feature = "gpu")]
11pub use context::{get_gpu_context, GpuContext, GpuContextInfo};
12
13#[cfg(any(feature = "gpu", feature = "cudnn"))]
14pub use context::{get_enhanced_gpu_context, EnhancedGpuContext, GpuBackend};
15
16#[cfg(feature = "serialize")]
17use serde::{Deserialize, Serialize};
18
/// Identifies a compute device.
///
/// `Cpu` is always present and is the [`Default`]. The `Gpu` and `Rocm`
/// variants only exist when the crate is built with the `gpu` / `rocm`
/// features respectively; each carries a device index.
#[derive(Default, Hash, Eq, PartialEq, Copy, Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum Device {
    /// The host CPU.
    #[default]
    Cpu,
    /// A GPU, identified by its index (`gpu` feature only).
    #[cfg(feature = "gpu")]
    Gpu(usize),
    /// A ROCm (AMD) device, identified by its index (`rocm` feature only).
    #[cfg(feature = "rocm")]
    Rocm(usize),
}
29
30impl Device {
31    pub fn is_cpu(&self) -> bool {
32        matches!(self, Device::Cpu)
33    }
34
35    #[cfg(feature = "gpu")]
36    pub fn is_gpu(&self) -> bool {
37        matches!(self, Device::Gpu(_))
38    }
39
40    #[cfg(feature = "rocm")]
41    pub fn is_rocm(&self) -> bool {
42        matches!(self, Device::Rocm(_))
43    }
44
45    pub fn id(&self) -> usize {
46        match self {
47            Device::Cpu => 0,
48            #[cfg(feature = "gpu")]
49            Device::Gpu(id) => *id,
50            #[cfg(feature = "rocm")]
51            Device::Rocm(id) => *id,
52        }
53    }
54
55    /// Parse a device string (e.g., "cpu", "gpu:0", "gpu:1")
56    #[allow(clippy::should_implement_trait)]
57    pub fn from_str(s: &str) -> Result<Self, String> {
58        let s = s.trim().to_lowercase();
59
60        if s == "cpu" {
61            return Ok(Device::Cpu);
62        }
63
64        #[cfg(feature = "gpu")]
65        {
66            if s.starts_with("gpu") {
67                if s == "gpu" {
68                    return Ok(Device::Gpu(0));
69                }
70                if let Some(id_str) = s.strip_prefix("gpu:") {
71                    match id_str.parse::<usize>() {
72                        Ok(id) => return Ok(Device::Gpu(id)),
73                        Err(_) => return Err(format!("Invalid GPU ID: {}", id_str)),
74                    }
75                }
76            }
77        }
78
79        #[cfg(feature = "rocm")]
80        {
81            if s.starts_with("rocm") || s.starts_with("amd") {
82                if s == "rocm" || s == "amd" {
83                    return Ok(Device::Rocm(0));
84                }
85                if let Some(id_str) = s.strip_prefix("rocm:") {
86                    match id_str.parse::<usize>() {
87                        Ok(id) => return Ok(Device::Rocm(id)),
88                        Err(_) => return Err(format!("Invalid ROCm device ID: {}", id_str)),
89                    }
90                }
91                if let Some(id_str) = s.strip_prefix("amd:") {
92                    match id_str.parse::<usize>() {
93                        Ok(id) => return Ok(Device::Rocm(id)),
94                        Err(_) => return Err(format!("Invalid AMD GPU ID: {}", id_str)),
95                    }
96                }
97            }
98        }
99
100        Err(format!("Invalid device string: {s}"))
101    }
102
103    /// Get the best available GPU device
104    #[cfg(feature = "gpu")]
105    pub fn best_gpu() -> Result<Self, String> {
106        // Try to get GPU 0 as the default best GPU
107        Self::try_gpu(0)
108    }
109
110    /// Try to create a GPU device with the specified ID
111    #[cfg(feature = "gpu")]
112    pub fn try_gpu(gpu_id: usize) -> Result<Self, String> {
113        // For now, assume GPU is available - in a full implementation,
114        // this would check actual GPU availability
115        Ok(Device::Gpu(gpu_id))
116    }
117
118    /// Get the best available GPU device (CPU fallback when GPU not available)
119    #[cfg(not(feature = "gpu"))]
120    pub fn best_gpu() -> Result<Self, String> {
121        Err("GPU support not compiled".to_string())
122    }
123
124    /// Try to create a GPU device (CPU fallback when GPU not available)
125    #[cfg(not(feature = "gpu"))]
126    pub fn try_gpu(_gpu_id: usize) -> Result<Self, String> {
127        Err("GPU support not compiled".to_string())
128    }
129}
130
131impl std::fmt::Display for Device {
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        match self {
134            Device::Cpu => write!(f, "cpu"),
135            #[cfg(feature = "gpu")]
136            Device::Gpu(id) => write!(f, "gpu:{}", id),
137            #[cfg(feature = "rocm")]
138            Device::Rocm(id) => write!(f, "rocm:{}", id),
139        }
140    }
141}