//! `yule_gpu` crate root (`lib.rs`): compute-backend abstraction over
//! CPU, Vulkan, CUDA, and Metal devices, selected via Cargo feature flags.
1pub mod backend;
2pub mod buffer;
3
4#[cfg(feature = "cpu")]
5pub mod cpu;
6
7#[cfg(feature = "vulkan")]
8pub mod vulkan;
9#[cfg(feature = "vulkan")]
10pub use ash::vk;
11
12#[cfg(feature = "cuda")]
13pub mod cuda;
14
15#[cfg(feature = "metal")]
16pub mod metal;
17
18use yule_core::error::Result;
19
/// Abstraction over a compute device (CPU, Vulkan, CUDA, Metal) exposing the
/// kernel set used by this crate for transformer-style inference.
///
/// Implementations must be `Send + Sync` so a `Box<dyn ComputeBackend>` can be
/// shared across threads. All operations are fallible and surface device
/// errors through `yule_core`'s `Result`.
///
/// Buffers are referred to only through opaque [`BufferHandle`]s; sizes are
/// element counts or byte counts as noted per method.
pub trait ComputeBackend: Send + Sync {
    /// Human-readable backend name (e.g. for logging / diagnostics).
    fn name(&self) -> &str;
    /// Static description of the underlying device.
    fn device_info(&self) -> DeviceInfo;
    /// Allocate `size_bytes` of device memory, returning an opaque handle.
    fn allocate(&self, size_bytes: usize) -> Result<BufferHandle>;
    /// Release a buffer previously returned by [`ComputeBackend::allocate`].
    fn free(&self, handle: BufferHandle) -> Result<()>;
    /// Matrix multiply with dimensions (m×k)·(k×n) → (m×n) written to `out`.
    /// NOTE(review): element dtype and storage order are backend-defined here —
    /// confirm against the concrete implementations.
    fn matmul(&self, a: &BufferHandle, b: &BufferHandle, out: &BufferHandle, m: u32, n: u32, k: u32) -> Result<()>;
    /// Softmax over `size` elements of `input`, written to `output`.
    fn softmax(&self, input: &BufferHandle, output: &BufferHandle, size: u32) -> Result<()>;
    /// RMS normalization of `size` elements with per-element `weight` scale and
    /// stabilizing epsilon `eps`.
    fn rms_norm(&self, input: &BufferHandle, weight: &BufferHandle, output: &BufferHandle, size: u32, eps: f32) -> Result<()>;
    /// Rotary position embedding applied in place to query buffer `q` and key
    /// buffer `k` at sequence position `pos`, for `n_heads_q` query heads and
    /// `n_heads_k` key heads of `head_dim` elements, using `freq_base` as the
    /// rotation frequency base.
    fn rope(&self, q: &BufferHandle, k: &BufferHandle, pos: u32, head_dim: u32, freq_base: f32, n_heads_q: u32, n_heads_k: u32) -> Result<()>;
    /// SiLU activation over `size` elements of `input`, written to `output`.
    fn silu(&self, input: &BufferHandle, output: &BufferHandle, size: u32) -> Result<()>;
    /// Element-wise multiply of `size` elements: output = a * b.
    fn element_mul(&self, a: &BufferHandle, b: &BufferHandle, output: &BufferHandle, size: u32) -> Result<()>;
    /// Element-wise add of `size` elements: output = a + b.
    fn add(&self, a: &BufferHandle, b: &BufferHandle, output: &BufferHandle, size: u32) -> Result<()>;
    /// Upload host bytes into the device buffer behind `handle`.
    fn copy_to_device(&self, data: &[u8], handle: &BufferHandle) -> Result<()>;
    /// Download the device buffer behind `handle` into host bytes.
    fn copy_from_device(&self, handle: &BufferHandle, data: &mut [u8]) -> Result<()>;
    /// Device-to-device copy of `size` bytes from `src` to `dst`.
    fn copy_buffer(&self, src: &BufferHandle, dst: &BufferHandle, size: usize) -> Result<()>;
    /// Device-to-device copy of `size` bytes starting at byte offsets
    /// `src_offset` / `dst_offset` within the respective buffers.
    fn copy_buffer_offset(
        &self, src: &BufferHandle, dst: &BufferHandle,
        src_offset: usize, dst_offset: usize, size: usize,
    ) -> Result<()>;
    /// Block until all previously issued device work has completed.
    fn synchronize(&self) -> Result<()>;

    /// Compute attention scores: scores[pos] = Q[head_offset..] · K_cache[pos*kv_stride+kv_offset..]
    ///
    /// Optional fused kernel; the default implementation reports it as
    /// unsupported so only backends with a dedicated kernel need to override.
    fn attn_score(
        &self,
        _q: &BufferHandle,
        _k_cache: &BufferHandle,
        _scores: &BufferHandle,
        _head_dim: u32,
        _seq_len: u32,
        _head_offset: u32,
        _kv_offset: u32,
        _kv_stride: u32,
    ) -> Result<()> {
        Err(yule_core::error::YuleError::Gpu(
            "attn_score not supported on this backend".into(),
        ))
    }

    /// Compute weighted value aggregation: out[out_offset+d] = sum_pos(weights[pos] * V[pos*kv_stride+kv_offset+d])
    ///
    /// Optional fused kernel; the default implementation reports it as
    /// unsupported so only backends with a dedicated kernel need to override.
    fn attn_value(
        &self,
        _weights: &BufferHandle,
        _v_cache: &BufferHandle,
        _output: &BufferHandle,
        _head_dim: u32,
        _seq_len: u32,
        _kv_offset: u32,
        _kv_stride: u32,
        _out_offset: u32,
    ) -> Result<()> {
        Err(yule_core::error::YuleError::Gpu(
            "attn_value not supported on this backend".into(),
        ))
    }

    /// Fused dequantize + matrix-vector multiply for quantized weights
    /// (`n_rows` × `n_cols`, element encoding given by `_dtype`).
    /// GPU backends override this for fused VRAM kernels.
    /// The default implementation returns an "unsupported" error — callers
    /// must dequantize and use [`ComputeBackend::matmul`] on backends that do
    /// not override this.
    fn quantized_matmul(
        &self,
        _weights: &BufferHandle,
        _input: &BufferHandle,
        _output: &BufferHandle,
        _n_rows: u32,
        _n_cols: u32,
        _dtype: yule_core::dtype::DType,
    ) -> Result<()> {
        Err(yule_core::error::YuleError::Gpu(
            "quantized_matmul not supported on this backend".into(),
        ))
    }
}
92
/// Opaque, backend-assigned identifier for a device buffer.
///
/// `Copy` + `Hash` so handles can be passed around freely and used as map
/// keys; only the backend that issued a handle can interpret its `u64` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BufferHandle(pub u64);
95
/// Static description of a compute device, returned by
/// [`ComputeBackend::device_info`].
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    // Marketing/device name as reported by the driver or backend.
    pub name: String,
    // Which backend kind produced this device.
    pub backend: BackendKind,
    // Total device memory in bytes.
    pub memory_bytes: u64,
    // Number of compute units (meaning is backend-specific — e.g. cores vs.
    // SMs; presumably informational only. TODO confirm per backend).
    pub compute_units: u32,
}
103
/// The set of compute backends this crate can be compiled with.
///
/// All variants always exist so `BackendKind` values can be named and compared
/// regardless of feature flags; whether a variant is *usable* depends on the
/// corresponding Cargo feature (see [`create_backend`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendKind {
    Cpu,
    Vulkan,
    Cuda,
    Metal,
}
111
112pub fn detect_backends() -> Vec<BackendKind> {
113    let mut backends = vec![BackendKind::Cpu];
114
115    #[cfg(feature = "vulkan")]
116    if vulkan::VulkanBackend::is_available() {
117        backends.push(BackendKind::Vulkan);
118    }
119
120    #[cfg(feature = "cuda")]
121    backends.push(BackendKind::Cuda);
122
123    #[cfg(feature = "metal")]
124    backends.push(BackendKind::Metal);
125
126    backends
127}
128
/// Instantiate the compute backend for `kind`.
///
/// # Errors
/// Returns `YuleError::Gpu` when the requested backend was not compiled into
/// this build (its Cargo feature is off). Note that `Cuda` and `Metal` have
/// no construction arm here yet, so they hit the error arm even when their
/// features are enabled.
pub fn create_backend(kind: BackendKind) -> Result<Box<dyn ComputeBackend>> {
    match kind {
        #[cfg(feature = "cpu")]
        BackendKind::Cpu => Ok(Box::new(cpu::CpuBackend::new())),
        // Vulkan construction itself can fail, hence the `?`.
        #[cfg(feature = "vulkan")]
        BackendKind::Vulkan => Ok(Box::new(vulkan::VulkanBackend::new()?)),
        // Catch-all covers feature-disabled variants, plus Cuda/Metal which
        // are not yet constructible from this function.
        _ => Err(yule_core::error::YuleError::Gpu(format!(
            "backend {kind:?} not compiled in — enable the feature flag"
        ))),
    }
}