pub trait ComputeBackend: Send + Sync {
// 19 methods: 16 required, 3 provided
// Required methods
fn name(&self) -> &str;
fn device_info(&self) -> DeviceInfo;
fn allocate(&self, size_bytes: usize) -> Result<BufferHandle>;
fn free(&self, handle: BufferHandle) -> Result<()>;
fn matmul(
&self,
a: &BufferHandle,
b: &BufferHandle,
out: &BufferHandle,
m: u32,
n: u32,
k: u32,
) -> Result<()>;
fn softmax(
&self,
input: &BufferHandle,
output: &BufferHandle,
size: u32,
) -> Result<()>;
fn rms_norm(
&self,
input: &BufferHandle,
weight: &BufferHandle,
output: &BufferHandle,
size: u32,
eps: f32,
) -> Result<()>;
fn rope(
&self,
q: &BufferHandle,
k: &BufferHandle,
pos: u32,
head_dim: u32,
freq_base: f32,
n_heads_q: u32,
n_heads_k: u32,
) -> Result<()>;
fn silu(
&self,
input: &BufferHandle,
output: &BufferHandle,
size: u32,
) -> Result<()>;
fn element_mul(
&self,
a: &BufferHandle,
b: &BufferHandle,
output: &BufferHandle,
size: u32,
) -> Result<()>;
fn add(
&self,
a: &BufferHandle,
b: &BufferHandle,
output: &BufferHandle,
size: u32,
) -> Result<()>;
fn copy_to_device(&self, data: &[u8], handle: &BufferHandle) -> Result<()>;
fn copy_from_device(
&self,
handle: &BufferHandle,
data: &mut [u8],
) -> Result<()>;
fn copy_buffer(
&self,
src: &BufferHandle,
dst: &BufferHandle,
size: usize,
) -> Result<()>;
fn copy_buffer_offset(
&self,
src: &BufferHandle,
dst: &BufferHandle,
src_offset: usize,
dst_offset: usize,
size: usize,
) -> Result<()>;
fn synchronize(&self) -> Result<()>;
// Provided methods
fn attn_score(
&self,
_q: &BufferHandle,
_k_cache: &BufferHandle,
_scores: &BufferHandle,
_head_dim: u32,
_seq_len: u32,
_head_offset: u32,
_kv_offset: u32,
_kv_stride: u32,
) -> Result<()> { ... }
fn attn_value(
&self,
_weights: &BufferHandle,
_v_cache: &BufferHandle,
_output: &BufferHandle,
_head_dim: u32,
_seq_len: u32,
_kv_offset: u32,
_kv_stride: u32,
_out_offset: u32,
) -> Result<()> { ... }
fn quantized_matmul(
&self,
_weights: &BufferHandle,
_input: &BufferHandle,
_output: &BufferHandle,
_n_rows: u32,
_n_cols: u32,
_dtype: DType,
) -> Result<()> { ... }
}

Required Methods§
fn name(&self) -> &str
fn device_info(&self) -> DeviceInfo
fn allocate(&self, size_bytes: usize) -> Result<BufferHandle>
fn free(&self, handle: BufferHandle) -> Result<()>
fn matmul( &self, a: &BufferHandle, b: &BufferHandle, out: &BufferHandle, m: u32, n: u32, k: u32, ) -> Result<()>
fn softmax( &self, input: &BufferHandle, output: &BufferHandle, size: u32, ) -> Result<()>
fn rms_norm( &self, input: &BufferHandle, weight: &BufferHandle, output: &BufferHandle, size: u32, eps: f32, ) -> Result<()>
fn rope( &self, q: &BufferHandle, k: &BufferHandle, pos: u32, head_dim: u32, freq_base: f32, n_heads_q: u32, n_heads_k: u32, ) -> Result<()>
fn silu( &self, input: &BufferHandle, output: &BufferHandle, size: u32, ) -> Result<()>
fn element_mul( &self, a: &BufferHandle, b: &BufferHandle, output: &BufferHandle, size: u32, ) -> Result<()>
fn add( &self, a: &BufferHandle, b: &BufferHandle, output: &BufferHandle, size: u32, ) -> Result<()>
fn copy_to_device(&self, data: &[u8], handle: &BufferHandle) -> Result<()>
fn copy_from_device(&self, handle: &BufferHandle, data: &mut [u8]) -> Result<()>
fn copy_buffer( &self, src: &BufferHandle, dst: &BufferHandle, size: usize, ) -> Result<()>
fn copy_buffer_offset( &self, src: &BufferHandle, dst: &BufferHandle, src_offset: usize, dst_offset: usize, size: usize, ) -> Result<()>
fn synchronize(&self) -> Result<()>
Provided Methods§
fn attn_score(
&self,
_q: &BufferHandle,
_k_cache: &BufferHandle,
_scores: &BufferHandle,
_head_dim: u32,
_seq_len: u32,
_head_offset: u32,
_kv_offset: u32,
_kv_stride: u32,
) -> Result<()>
fn attn_score( &self, _q: &BufferHandle, _k_cache: &BufferHandle, _scores: &BufferHandle, _head_dim: u32, _seq_len: u32, _head_offset: u32, _kv_offset: u32, _kv_stride: u32, ) -> Result<()>
Compute attention scores: scores[pos] = Q[head_offset..] · K_cache[pos*kv_stride+kv_offset..]
fn attn_value(
&self,
_weights: &BufferHandle,
_v_cache: &BufferHandle,
_output: &BufferHandle,
_head_dim: u32,
_seq_len: u32,
_kv_offset: u32,
_kv_stride: u32,
_out_offset: u32,
) -> Result<()>
fn attn_value( &self, _weights: &BufferHandle, _v_cache: &BufferHandle, _output: &BufferHandle, _head_dim: u32, _seq_len: u32, _kv_offset: u32, _kv_stride: u32, _out_offset: u32, ) -> Result<()>
Compute weighted value aggregation: out[out_offset+d] = sum_pos(weights[pos] * V[pos*kv_stride+kv_offset+d])
fn quantized_matmul(
&self,
_weights: &BufferHandle,
_input: &BufferHandle,
_output: &BufferHandle,
_n_rows: u32,
_n_cols: u32,
_dtype: DType,
) -> Result<()>
fn quantized_matmul( &self, _weights: &BufferHandle, _input: &BufferHandle, _output: &BufferHandle, _n_rows: u32, _n_cols: u32, _dtype: DType, ) -> Result<()>
Fused dequantize + matrix-vector multiply for quantized weights. GPU backends override this for fused VRAM kernels. Default impl falls back to regular matmul (assumes pre-dequantized data).