Skip to main content

ComputeBackend

Trait ComputeBackend 

Source
pub trait ComputeBackend: Send + Sync {
    // Required methods
    fn name(&self) -> &str;
    fn device_info(&self) -> DeviceInfo;
    fn allocate(&self, size_bytes: usize) -> Result<BufferHandle>;
    fn free(&self, handle: BufferHandle) -> Result<()>;
    fn matmul(
        &self,
        a: &BufferHandle,
        b: &BufferHandle,
        out: &BufferHandle,
        m: u32,
        n: u32,
        k: u32,
    ) -> Result<()>;
    fn softmax(
        &self,
        input: &BufferHandle,
        output: &BufferHandle,
        size: u32,
    ) -> Result<()>;
    fn rms_norm(
        &self,
        input: &BufferHandle,
        weight: &BufferHandle,
        output: &BufferHandle,
        size: u32,
        eps: f32,
    ) -> Result<()>;
    fn rope(
        &self,
        q: &BufferHandle,
        k: &BufferHandle,
        pos: u32,
        head_dim: u32,
        freq_base: f32,
        n_heads_q: u32,
        n_heads_k: u32,
    ) -> Result<()>;
    fn silu(
        &self,
        input: &BufferHandle,
        output: &BufferHandle,
        size: u32,
    ) -> Result<()>;
    fn element_mul(
        &self,
        a: &BufferHandle,
        b: &BufferHandle,
        output: &BufferHandle,
        size: u32,
    ) -> Result<()>;
    fn add(
        &self,
        a: &BufferHandle,
        b: &BufferHandle,
        output: &BufferHandle,
        size: u32,
    ) -> Result<()>;
    fn copy_to_device(&self, data: &[u8], handle: &BufferHandle) -> Result<()>;
    fn copy_from_device(
        &self,
        handle: &BufferHandle,
        data: &mut [u8],
    ) -> Result<()>;
    fn copy_buffer(
        &self,
        src: &BufferHandle,
        dst: &BufferHandle,
        size: usize,
    ) -> Result<()>;
    fn copy_buffer_offset(
        &self,
        src: &BufferHandle,
        dst: &BufferHandle,
        src_offset: usize,
        dst_offset: usize,
        size: usize,
    ) -> Result<()>;
    fn synchronize(&self) -> Result<()>;

    // Provided methods
    fn attn_score(
        &self,
        _q: &BufferHandle,
        _k_cache: &BufferHandle,
        _scores: &BufferHandle,
        _head_dim: u32,
        _seq_len: u32,
        _head_offset: u32,
        _kv_offset: u32,
        _kv_stride: u32,
    ) -> Result<()> { ... }
    fn attn_value(
        &self,
        _weights: &BufferHandle,
        _v_cache: &BufferHandle,
        _output: &BufferHandle,
        _head_dim: u32,
        _seq_len: u32,
        _kv_offset: u32,
        _kv_stride: u32,
        _out_offset: u32,
    ) -> Result<()> { ... }
    fn quantized_matmul(
        &self,
        _weights: &BufferHandle,
        _input: &BufferHandle,
        _output: &BufferHandle,
        _n_rows: u32,
        _n_cols: u32,
        _dtype: DType,
    ) -> Result<()> { ... }
}

Required Methods§

Source

fn name(&self) -> &str

Source

fn device_info(&self) -> DeviceInfo

Source

fn allocate(&self, size_bytes: usize) -> Result<BufferHandle>

Source

fn free(&self, handle: BufferHandle) -> Result<()>

Source

fn matmul( &self, a: &BufferHandle, b: &BufferHandle, out: &BufferHandle, m: u32, n: u32, k: u32, ) -> Result<()>

Source

fn softmax( &self, input: &BufferHandle, output: &BufferHandle, size: u32, ) -> Result<()>

Source

fn rms_norm( &self, input: &BufferHandle, weight: &BufferHandle, output: &BufferHandle, size: u32, eps: f32, ) -> Result<()>

Source

fn rope( &self, q: &BufferHandle, k: &BufferHandle, pos: u32, head_dim: u32, freq_base: f32, n_heads_q: u32, n_heads_k: u32, ) -> Result<()>

Source

fn silu( &self, input: &BufferHandle, output: &BufferHandle, size: u32, ) -> Result<()>

Source

fn element_mul( &self, a: &BufferHandle, b: &BufferHandle, output: &BufferHandle, size: u32, ) -> Result<()>

Source

fn add( &self, a: &BufferHandle, b: &BufferHandle, output: &BufferHandle, size: u32, ) -> Result<()>

Source

fn copy_to_device(&self, data: &[u8], handle: &BufferHandle) -> Result<()>

Source

fn copy_from_device(&self, handle: &BufferHandle, data: &mut [u8]) -> Result<()>

Source

fn copy_buffer( &self, src: &BufferHandle, dst: &BufferHandle, size: usize, ) -> Result<()>

Source

fn copy_buffer_offset( &self, src: &BufferHandle, dst: &BufferHandle, src_offset: usize, dst_offset: usize, size: usize, ) -> Result<()>

Source

fn synchronize(&self) -> Result<()>

Provided Methods§

Source

fn attn_score( &self, _q: &BufferHandle, _k_cache: &BufferHandle, _scores: &BufferHandle, _head_dim: u32, _seq_len: u32, _head_offset: u32, _kv_offset: u32, _kv_stride: u32, ) -> Result<()>

Compute attention scores: scores[pos] = Q[head_offset..] · K_cache[pos*kv_stride+kv_offset..]

Source

fn attn_value( &self, _weights: &BufferHandle, _v_cache: &BufferHandle, _output: &BufferHandle, _head_dim: u32, _seq_len: u32, _kv_offset: u32, _kv_stride: u32, _out_offset: u32, ) -> Result<()>

Compute weighted value aggregation: out[out_offset+d] = sum_pos(weights[pos] * V[pos*kv_stride+kv_offset+d])

Source

fn quantized_matmul( &self, _weights: &BufferHandle, _input: &BufferHandle, _output: &BufferHandle, _n_rows: u32, _n_cols: u32, _dtype: DType, ) -> Result<()>

Fused dequantize + matrix-vector multiply for quantized weights. GPU backends override this with fused in-VRAM kernels. The default implementation falls back to a regular matmul (it assumes the data has already been dequantized).

Implementors§