1pub mod backend;
2pub mod buffer;
3
4#[cfg(feature = "cpu")]
5pub mod cpu;
6
7#[cfg(feature = "vulkan")]
8pub mod vulkan;
9#[cfg(feature = "vulkan")]
10pub use ash::vk;
11
12#[cfg(feature = "cuda")]
13pub mod cuda;
14
15#[cfg(feature = "metal")]
16pub mod metal;
17
18use yule_core::error::Result;
19
/// Abstraction over a compute device (CPU, Vulkan, CUDA, Metal) that can
/// allocate device buffers and execute the tensor kernels used for inference.
///
/// `Send + Sync` is required so a backend can be shared across threads
/// behind a `Box<dyn ComputeBackend>` or `Arc`.
pub trait ComputeBackend: Send + Sync {
    /// Human-readable backend name (for logging / selection).
    fn name(&self) -> &str;
    /// Description of the underlying device (name, kind, memory, units).
    fn device_info(&self) -> DeviceInfo;
    /// Allocates a device buffer of `size_bytes` bytes and returns its handle.
    fn allocate(&self, size_bytes: usize) -> Result<BufferHandle>;
    /// Releases the buffer identified by `handle`.
    fn free(&self, handle: BufferHandle) -> Result<()>;
    /// Matrix multiply of `a` and `b` into `out` with dimensions `m`, `n`, `k`.
    /// NOTE(review): operand layout (row- vs column-major, which operand is
    /// m×k vs k×n) is not visible here — confirm against the concrete backends.
    fn matmul(&self, a: &BufferHandle, b: &BufferHandle, out: &BufferHandle, m: u32, n: u32, k: u32) -> Result<()>;
    /// Softmax over `size` elements of `input`, written to `output`.
    fn softmax(&self, input: &BufferHandle, output: &BufferHandle, size: u32) -> Result<()>;
    /// RMS normalization of `input` scaled by `weight` into `output`,
    /// with `eps` added for numerical stability.
    fn rms_norm(&self, input: &BufferHandle, weight: &BufferHandle, output: &BufferHandle, size: u32, eps: f32) -> Result<()>;
    /// Rotary position embedding (RoPE) applied to `q` and `k` at position
    /// `pos`. No output buffer is taken, so this presumably mutates `q`/`k`
    /// in place — confirm against the implementations.
    fn rope(&self, q: &BufferHandle, k: &BufferHandle, pos: u32, head_dim: u32, freq_base: f32, n_heads_q: u32, n_heads_k: u32) -> Result<()>;
    /// SiLU activation over `size` elements of `input` into `output`.
    fn silu(&self, input: &BufferHandle, output: &BufferHandle, size: u32) -> Result<()>;
    /// Element-wise multiply of `a` and `b` into `output` (`size` elements).
    fn element_mul(&self, a: &BufferHandle, b: &BufferHandle, output: &BufferHandle, size: u32) -> Result<()>;
    /// Element-wise add of `a` and `b` into `output` (`size` elements).
    fn add(&self, a: &BufferHandle, b: &BufferHandle, output: &BufferHandle, size: u32) -> Result<()>;
    /// Uploads host bytes `data` into the device buffer `handle`.
    fn copy_to_device(&self, data: &[u8], handle: &BufferHandle) -> Result<()>;
    /// Downloads the device buffer `handle` into host bytes `data`.
    fn copy_from_device(&self, handle: &BufferHandle, data: &mut [u8]) -> Result<()>;
    /// Copies `size` bytes from device buffer `src` to device buffer `dst`.
    fn copy_buffer(&self, src: &BufferHandle, dst: &BufferHandle, size: usize) -> Result<()>;
    /// Copies `size` bytes from `src` at `src_offset` to `dst` at
    /// `dst_offset` (offsets presumably in bytes — confirm).
    fn copy_buffer_offset(
        &self, src: &BufferHandle, dst: &BufferHandle,
        src_offset: usize, dst_offset: usize, size: usize,
    ) -> Result<()>;
    /// Blocks until all previously issued device work has completed.
    fn synchronize(&self) -> Result<()>;

    /// Optional fused attention-score kernel (q · k over the KV cache into
    /// `scores`). The default implementation returns a `Gpu` error; backends
    /// with a fast path override it.
    fn attn_score(
        &self,
        _q: &BufferHandle,
        _k_cache: &BufferHandle,
        _scores: &BufferHandle,
        _head_dim: u32,
        _seq_len: u32,
        _head_offset: u32,
        _kv_offset: u32,
        _kv_stride: u32,
    ) -> Result<()> {
        Err(yule_core::error::YuleError::Gpu(
            "attn_score not supported on this backend".into(),
        ))
    }

    /// Optional fused attention-value kernel (weights · v over the KV cache
    /// into `output`). Default returns a `Gpu` error; override to support.
    fn attn_value(
        &self,
        _weights: &BufferHandle,
        _v_cache: &BufferHandle,
        _output: &BufferHandle,
        _head_dim: u32,
        _seq_len: u32,
        _kv_offset: u32,
        _kv_stride: u32,
        _out_offset: u32,
    ) -> Result<()> {
        Err(yule_core::error::YuleError::Gpu(
            "attn_value not supported on this backend".into(),
        ))
    }

    /// Optional matmul against quantized weights of type `dtype` with an
    /// `n_rows` × `n_cols` weight matrix. Default returns a `Gpu` error;
    /// override to support.
    fn quantized_matmul(
        &self,
        _weights: &BufferHandle,
        _input: &BufferHandle,
        _output: &BufferHandle,
        _n_rows: u32,
        _n_cols: u32,
        _dtype: yule_core::dtype::DType,
    ) -> Result<()> {
        Err(yule_core::error::YuleError::Gpu(
            "quantized_matmul not supported on this backend".into(),
        ))
    }
}
92
/// Opaque identifier for a device buffer allocated through a
/// [`ComputeBackend`]. Only meaningful to the backend that issued it;
/// `Copy + Eq + Hash` so it can be stored in maps and passed around freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BufferHandle(pub u64);
95
/// Static description of a compute device, as reported by
/// [`ComputeBackend::device_info`].
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// Device name as reported by the driver/backend.
    pub name: String,
    /// Which backend kind this device belongs to.
    pub backend: BackendKind,
    /// Device memory in bytes (total vs. available is backend-defined —
    /// confirm against the implementations).
    pub memory_bytes: u64,
    /// Number of compute units (cores/SMs/EUs — unit is backend-specific).
    pub compute_units: u32,
}
103
/// The family of compute backend a device belongs to. Variants correspond
/// to the feature-gated modules declared at the top of this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendKind {
    Cpu,
    Vulkan,
    Cuda,
    Metal,
}
111
112pub fn detect_backends() -> Vec<BackendKind> {
113 let mut backends = vec![BackendKind::Cpu];
114
115 #[cfg(feature = "vulkan")]
116 if vulkan::VulkanBackend::is_available() {
117 backends.push(BackendKind::Vulkan);
118 }
119
120 #[cfg(feature = "cuda")]
121 backends.push(BackendKind::Cuda);
122
123 #[cfg(feature = "metal")]
124 backends.push(BackendKind::Metal);
125
126 backends
127}
128
129pub fn create_backend(kind: BackendKind) -> Result<Box<dyn ComputeBackend>> {
130 match kind {
131 #[cfg(feature = "cpu")]
132 BackendKind::Cpu => Ok(Box::new(cpu::CpuBackend::new())),
133 #[cfg(feature = "vulkan")]
134 BackendKind::Vulkan => Ok(Box::new(vulkan::VulkanBackend::new()?)),
135 _ => Err(yule_core::error::YuleError::Gpu(format!(
136 "backend {kind:?} not compiled in — enable the feature flag"
137 ))),
138 }
139}