Trait State

pub trait State {
    // Required methods
    fn num_batch(&self) -> usize;
    fn init_shape(&self) -> Shape;
    fn init(&self) -> TensorCpu<f32>;
    fn att(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError>;
    fn ffn(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError>;
    fn load(
        &self,
        tensor: TensorCpu<f32>,
        batch: usize,
    ) -> Result<(), TensorError>;
    fn back(
        &self,
        batch: usize,
    ) -> BoxFuture<'_, Result<TensorCpu<f32>, TensorError>>;
    fn write(
        &self,
        tensor: TensorGpu<f32, ReadWrite>,
        batch: usize,
    ) -> Result<(), TensorError>;
    fn read(
        &self,
        batch: usize,
    ) -> Result<TensorGpu<f32, ReadWrite>, TensorError>;
    fn embed(
        &self,
        layer: usize,
        backed: TensorCpu<f32>,
    ) -> Result<TensorCpu<f32>, TensorError>;
}

Required Methods

fn num_batch(&self) -> usize

Number of batches held by this state.

fn init_shape(&self) -> Shape

Shape of a freshly initialized one-batch state on CPU, i.e. the shape of the tensor returned by `init`.

fn init(&self) -> TensorCpu<f32>

Initialize a one-batch state on CPU.

fn att(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError>

The part of the state used by the att (time-mixing) block of layer `layer`, returned as a borrowed GPU tensor view.

fn ffn(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError>

The part of the state used by the ffn (channel-mixing) block of layer `layer`, returned as a borrowed GPU tensor view.
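
The sketch below is for illustration only. It models how per-layer `att` and `ffn` views could be cut out of one flat buffer. `ToyState`, `ToyError`, and the layout it uses are all hypothetical: each layer stores its att part followed by its ffn part, batching is ignored, and plain slices stand in for `TensorGpuView`. The real web-rwkv layouts may differ between the v4, v5, v6, and v7 runtimes. The sketch only shows the contract the signatures imply: a valid `layer` yields a view borrowed from the state, and an out-of-range `layer` yields an error.

```rust
/// Hypothetical error for this sketch (the real trait returns `TensorError`).
#[derive(Debug, PartialEq)]
pub enum ToyError {
    LayerOutOfRange { layer: usize, num_layer: usize },
}

/// Hypothetical one-batch state: for each layer, `att_len` floats for the
/// att block followed by `ffn_len` floats for the ffn block.
pub struct ToyState {
    num_layer: usize,
    att_len: usize,
    ffn_len: usize,
    data: Vec<f32>,
}

impl ToyState {
    /// Fills the buffer with 0.0, 1.0, 2.0, ... so offsets are easy to check.
    pub fn new(num_layer: usize, att_len: usize, ffn_len: usize) -> Self {
        let len = num_layer * (att_len + ffn_len);
        let data = (0..len).map(|i| i as f32).collect();
        Self { num_layer, att_len, ffn_len, data }
    }

    /// Start offset of `layer`'s block, or an error if the layer does not exist.
    fn layer_offset(&self, layer: usize) -> Result<usize, ToyError> {
        if layer >= self.num_layer {
            return Err(ToyError::LayerOutOfRange { layer, num_layer: self.num_layer });
        }
        Ok(layer * (self.att_len + self.ffn_len))
    }

    /// Mirrors `att`: a view borrowed from `self` covering the att part of `layer`.
    pub fn att(&self, layer: usize) -> Result<&[f32], ToyError> {
        let start = self.layer_offset(layer)?;
        Ok(&self.data[start..start + self.att_len])
    }

    /// Mirrors `ffn`: a view borrowed from `self` covering the ffn part of `layer`.
    pub fn ffn(&self, layer: usize) -> Result<&[f32], ToyError> {
        let start = self.layer_offset(layer)? + self.att_len;
        Ok(&self.data[start..start + self.ffn_len])
    }
}
```

Returning a borrowed view rather than a copy matches the `'_` lifetime on `TensorGpuView<'_, f32>`: the view cannot outlive the state it points into.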

fn load(&self, tensor: TensorCpu<f32>, batch: usize) -> Result<(), TensorError>

Upload `tensor` from CPU into batch `batch` of the state on GPU.

fn back(&self, batch: usize) -> BoxFuture<'_, Result<TensorCpu<f32>, TensorError>>

Read batch `batch` of the state back from GPU to CPU. The readback is asynchronous, so the returned future must be awaited to obtain the tensor.
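
`init`, `load`, and `back` together support a CPU round trip, for example snapshotting one batch and restoring it later. Below is a minimal sketch of that round trip. `BatchedState` and `BatchError` are hypothetical. `Vec<f32>` stands in for both `TensorCpu<f32>` and the GPU buffer, and a `RefCell` mimics the `&self` receivers of the real methods. `back` is synchronous here, whereas the real method returns a `BoxFuture` that must be awaited. The zero-filled `init` is also an assumption, since real initial values depend on the runtime.

```rust
use std::cell::RefCell;

/// Hypothetical errors for this sketch (the real trait uses `TensorError`).
#[derive(Debug, PartialEq)]
pub enum BatchError {
    BatchOutOfRange { batch: usize, num_batch: usize },
    SizeMismatch { expected: usize, got: usize },
}

/// Hypothetical multi-batch state. Each inner `Vec` stands in for one batch
/// of the GPU-resident state; `RefCell` allows mutation through `&self`.
pub struct BatchedState {
    batch_len: usize,
    slots: RefCell<Vec<Vec<f32>>>,
}

impl BatchedState {
    pub fn new(num_batch: usize, batch_len: usize) -> Self {
        Self { batch_len, slots: RefCell::new(vec![vec![0.0; batch_len]; num_batch]) }
    }

    /// Mirrors `num_batch`.
    pub fn num_batch(&self) -> usize {
        self.slots.borrow().len()
    }

    /// Mirrors `init`: a one-batch CPU state (zero-filled by assumption).
    pub fn init(&self) -> Vec<f32> {
        vec![0.0; self.batch_len]
    }

    fn check_batch(&self, batch: usize) -> Result<(), BatchError> {
        let num_batch = self.num_batch();
        if batch < num_batch {
            Ok(())
        } else {
            Err(BatchError::BatchOutOfRange { batch, num_batch })
        }
    }

    /// Mirrors `load`: copy a one-batch CPU tensor into slot `batch`.
    pub fn load(&self, tensor: Vec<f32>, batch: usize) -> Result<(), BatchError> {
        self.check_batch(batch)?;
        if tensor.len() != self.batch_len {
            return Err(BatchError::SizeMismatch { expected: self.batch_len, got: tensor.len() });
        }
        self.slots.borrow_mut()[batch] = tensor;
        Ok(())
    }

    /// Mirrors `back`, but synchronously: copy slot `batch` out to the CPU.
    pub fn back(&self, batch: usize) -> Result<Vec<f32>, BatchError> {
        self.check_batch(batch)?;
        Ok(self.slots.borrow()[batch].clone())
    }
}
```

With the real trait the same pattern would read `let saved = state.back(batch).await?;` to take a snapshot and `state.load(saved, batch)?;` to restore it, assuming an async context whose error type can absorb `TensorError`.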

fn write(&self, tensor: TensorGpu<f32, ReadWrite>, batch: usize) -> Result<(), TensorError>

Write a GPU tensor into batch `batch` of the state.

fn read(&self, batch: usize) -> Result<TensorGpu<f32, ReadWrite>, TensorError>

Read batch `batch` of the state out into a GPU tensor.
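
`read` and `write` both work with GPU tensors, so pairing them should allow one batch's state to be copied over another without a CPU round trip. Below is a sketch of that pattern. It assumes `read` returns an independent copy of the batch, which the descriptions above do not specify (it could alias the state instead). The `GpuBatches` trait, `copy_batch`, and `MemBatches` are hypothetical. `Vec<f32>` stands in for `TensorGpu<f32, ReadWrite>`, and `String` stands in for `TensorError`.

```rust
use std::cell::RefCell;

/// Hypothetical trait mirroring only the `read`/`write` pair of `State`.
pub trait GpuBatches {
    fn read(&self, batch: usize) -> Result<Vec<f32>, String>;
    fn write(&self, tensor: Vec<f32>, batch: usize) -> Result<(), String>;
}

/// Copy the state of batch `from` over batch `to`.
pub fn copy_batch<S: GpuBatches>(state: &S, from: usize, to: usize) -> Result<(), String> {
    let tensor = state.read(from)?;
    state.write(tensor, to)
}

/// Hypothetical in-memory implementation; `RefCell` mimics the `&self` receivers.
pub struct MemBatches {
    slots: RefCell<Vec<Vec<f32>>>,
}

impl MemBatches {
    pub fn new(slots: Vec<Vec<f32>>) -> Self {
        Self { slots: RefCell::new(slots) }
    }
}

impl GpuBatches for MemBatches {
    /// Returns an owned copy of batch `batch` (the copy semantics are an assumption).
    fn read(&self, batch: usize) -> Result<Vec<f32>, String> {
        self.slots
            .borrow()
            .get(batch)
            .cloned()
            .ok_or_else(|| format!("batch {batch} out of range"))
    }

    /// Overwrites batch `batch`, rejecting tensors of the wrong length.
    fn write(&self, tensor: Vec<f32>, batch: usize) -> Result<(), String> {
        let mut slots = self.slots.borrow_mut();
        let slot = slots
            .get_mut(batch)
            .ok_or_else(|| format!("batch {batch} out of range"))?;
        if slot.len() != tensor.len() {
            return Err(format!("expected {} elements, got {}", slot.len(), tensor.len()));
        }
        *slot = tensor;
        Ok(())
    }
}
```

Against the real trait, the body of `copy_batch` would be the same two calls, `let tensor = state.read(from)?; state.write(tensor, to)`. Writing the helper generically over the trait keeps it usable with any of the implementors listed below.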

fn embed(&self, layer: usize, backed: TensorCpu<f32>) -> Result<TensorCpu<f32>, TensorError>

Get the embed vector of layer `layer` from `backed`, a state that has been read back to CPU (for example with `back`).

Implementors

impl State for web_rwkv::runtime::v4::State

impl State for web_rwkv::runtime::v5::State

impl State for web_rwkv::runtime::v6::State

impl State for web_rwkv::runtime::v7::State