Trait safetensors::tensor::View

source ·
/// Interface a tensor type implements so that safetensors can serialize it.
pub trait View {
    // Required methods
    /// The `Dtype` of the tensor.
    fn dtype(&self) -> Dtype;
    /// The shape of the tensor.
    fn shape(&self) -> &[usize];
    /// The data of the tensor, as bytes (borrowed or owned).
    fn data(&self) -> Cow<'_, [u8]>;
    /// The length of the data, in bytes. Kept separate from `data()` because it
    /// may be faster to obtain than `data().len()`, e.g. for tensors residing on GPU.
    fn data_len(&self) -> usize;
}
Expand description

The trait that a tensor type must implement so that safetensors can serialize it. If you have an owned tensor, implement it like this:

use safetensors::tensor::{View, Dtype};
use std::borrow::Cow;
/// A tensor that owns its bytes.
struct Tensor {
    /// Element type of the tensor.
    dtype: MyDtype,
    /// Size of each dimension.
    shape: Vec<usize>,
    /// Raw tensor bytes.
    data: Vec<u8>,
}

/// Serializes an owned `Tensor` through a shared reference.
///
/// The bytes are already stored in a `Vec<u8>`, so `data` hands out a borrow
/// instead of copying.
impl<'data> View for &'data Tensor {
    /// Converts the tensor's own dtype into the safetensors `Dtype`.
    fn dtype(&self) -> Dtype {
        self.dtype.into()
    }

    /// Returns the tensor dimensions.
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Borrows the backing buffer. The `'_` mirrors the trait's declared
    /// return type so the borrow is visible in the signature.
    fn data(&self) -> Cow<'_, [u8]> {
        Cow::Borrowed(self.data.as_slice())
    }

    /// Number of bytes in the backing buffer.
    fn data_len(&self) -> usize {
        self.data.len()
    }
}

For a borrowed tensor:

use safetensors::tensor::{View, Dtype};
use std::borrow::Cow;
/// A tensor whose bytes are borrowed from elsewhere for `'data`.
struct Tensor<'data> {
    /// Element type of the tensor.
    dtype: MyDtype,
    /// Size of each dimension.
    shape: Vec<usize>,
    /// Raw tensor bytes, borrowed.
    data: &'data [u8],
}

/// Serializes a `Tensor` that only borrows its bytes.
impl<'data> View for Tensor<'data> {
    /// Converts the tensor's own dtype into the safetensors `Dtype`.
    fn dtype(&self) -> Dtype {
        self.dtype.into()
    }

    /// Returns the tensor dimensions.
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Re-borrows the byte slice the tensor already holds; nothing is copied.
    /// The `'_` mirrors the trait's declared return type.
    fn data(&self) -> Cow<'_, [u8]> {
        Cow::Borrowed(self.data)
    }

    /// Number of bytes in the borrowed slice.
    fn data_len(&self) -> usize {
        self.data.len()
    }
}

Now, if you have some opaque buffer that may live on a GPU, for instance, you can implement the trait so that it returns an owned local buffer containing a copy of the data on the CPU (which is needed to write it to disk):

use safetensors::tensor::{View, Dtype};
use std::borrow::Cow;

/// A tensor whose bytes sit behind an opaque buffer type.
struct Tensor {
    /// Element type of the tensor.
    dtype: MyDtype,
    /// Size of each dimension.
    shape: Vec<usize>,
    /// Opaque buffer holding the tensor contents (presumably device memory).
    data: OpaqueGpu,
}

/// Serializes a `Tensor` whose bytes must first be copied into a local buffer.
impl View for Tensor {
    /// Converts the tensor's own dtype into the safetensors `Dtype`.
    fn dtype(&self) -> Dtype {
        self.dtype.into()
    }

    /// Returns the tensor dimensions.
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns an owned copy of the tensor bytes. The `'_` mirrors the
    /// trait's declared return type.
    fn data(&self) -> Cow<'_, [u8]> {
        // This copies data from GPU to CPU.
        let host_bytes: Vec<u8> = self.data.to_vec();
        Cow::Owned(host_bytes)
    }

    /// Byte length derived from shape and dtype alone, so the buffer itself
    /// is never touched.
    fn data_len(&self) -> usize {
        let element_count: usize = self.shape.iter().product();
        element_count * self.dtype.size()
    }
}

Required Methods§

source

fn dtype(&self) -> Dtype

The Dtype of the tensor

source

fn shape(&self) -> &[usize]

The shape of the tensor

source

fn data(&self) -> Cow<'_, [u8]>

The data of the tensor

source

fn data_len(&self) -> usize

The length of the data, in bytes. This is a separate method because it can be faster to obtain than `data().len()`, for instance for tensors residing on a GPU.

Implementors§

source§

impl<'data> View for &TensorView<'data>