/// Interface to a batched state whose per-layer `att` and `ffn` parts are
/// exposed as GPU tensor views, and whose batches can be transferred between
/// CPU tensors and the GPU.
pub trait State {
// Required methods
/// NOTE(review): no description is visible for this method here; presumably
/// the number of batches this state holds — confirm against the implementors.
fn num_batch(&self) -> usize;
/// Shape of the initialized one-batch CPU state.
fn init_shape(&self) -> Shape;
/// NOTE(review): no description is visible for this method here; presumably
/// builds the initialized one-batch CPU state whose shape is reported by
/// `init_shape` — confirm against the implementors.
fn init(&self) -> TensorCpu<f32>;
/// The part of the state that is used in an att layer.
///
/// Returns a GPU tensor view for the given `layer`, or a `TensorError`.
fn att(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError>;
/// The part of the state that is used in an ffn layer.
///
/// Returns a GPU tensor view for the given `layer`, or a `TensorError`.
fn ffn(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError>;
/// Load a batch of the state from CPU to GPU.
///
/// `tensor` is the CPU-side data; `batch` presumably selects the destination
/// batch slot — TODO confirm. Fails with a `TensorError`.
fn load(
&self,
tensor: TensorCpu<f32>,
batch: usize,
) -> Result<(), TensorError>;
/// Read back a batch of the state from GPU to CPU.
///
/// Returns a boxed future that resolves to the CPU tensor or a `TensorError`.
fn back(
&self,
batch: usize,
) -> BoxFuture<'_, Result<TensorCpu<f32>, TensorError>>;
/// Write into the state from a GPU tensor.
///
/// `batch` presumably selects the destination batch slot — TODO confirm.
fn write(
&self,
tensor: TensorGpu<f32, ReadWrite>,
batch: usize,
) -> Result<(), TensorError>;
/// NOTE(review): no description is visible for this method here; by its
/// signature it looks like the GPU-to-GPU counterpart of `write` (yields a
/// GPU tensor for `batch`) — confirm.
fn read(
&self,
batch: usize,
) -> Result<TensorGpu<f32, ReadWrite>, TensorError>;
/// NOTE(review): no description is visible for this method here; it takes a
/// `layer` index and a CPU tensor named `backed` (possibly the result of
/// `back`) and returns another CPU tensor — exact semantics to be confirmed.
fn embed(
&self,
layer: usize,
backed: TensorCpu<f32>,
) -> Result<TensorCpu<f32>, TensorError>;
}Required Methods§
fn init_shape(&self) -> Shape
Shape of the initialized one-batch CPU state.
fn att(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError>
The part of the state that is used in an att layer.
fn ffn(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError>
The part of the state that is used in an ffn layer.
fn load(&self, tensor: TensorCpu<f32>, batch: usize) -> Result<(), TensorError>
Load a batch of the state from CPU to GPU.
fn back(&self, batch: usize) -> BoxFuture<'_, Result<TensorCpu<f32>, TensorError>>
Read back a batch of the state from GPU to CPU.
fn write(&self, tensor: TensorGpu<f32, ReadWrite>, batch: usize) -> Result<(), TensorError>
Write into the state from a GPU tensor.