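use std::borrow::Borrow;
use rust_bert::bert::BertConfig;
use rust_bert::RustBertError;
use tch::nn::Path;
use tch::Tensor;

pub trait BertEmbedding {
    // Required methods
    fn new<'p, P>(p: P, config: &BertConfig) -> Self
        where P: Borrow<Path<'p>>;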
    fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool
    ) -> Result<Tensor, RustBertError>;
}

BertEmbedding trait (for use in BertModel or RoBERTaModel)

Defines an interface for the embedding layers in BERT-based models.
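As a rough illustration of what an implementor computes, the dependency-free sketch below sums word, position, and token-type embeddings per token, which is the standard BERT scheme. It is a simplification, not rust-bert code: `ToyEmbeddings` is a hypothetical type, plain `Vec<f32>` rows stand in for `tch::Tensor`, and the layer normalization and dropout a real BERT embedding layer applies afterwards are omitted.

```rust
/// Hypothetical toy embedding layer (not rust-bert code): three lookup tables
/// whose rows are summed per token. LayerNorm and dropout are omitted.
pub struct ToyEmbeddings {
    pub word: Vec<Vec<f32>>,       // [vocab_size][hidden_size]
    pub position: Vec<Vec<f32>>,   // [max_positions][hidden_size]
    pub token_type: Vec<Vec<f32>>, // [type_vocab_size][hidden_size]
}

impl ToyEmbeddings {
    /// Returns one hidden vector per input token, or `None` if any id,
    /// position, or token type falls outside its table.
    pub fn forward(
        &self,
        input_ids: &[usize],
        token_type_ids: Option<&[usize]>,
    ) -> Option<Vec<Vec<f32>>> {
        input_ids
            .iter()
            .enumerate()
            .map(|(pos, &id)| {
                // Token types default to 0 (first segment); positions are 0..seq_len.
                let tt = match token_type_ids {
                    Some(types) => *types.get(pos)?,
                    None => 0,
                };
                let w = self.word.get(id)?;
                let p = self.position.get(pos)?;
                let t = self.token_type.get(tt)?;
                // Element-wise sum of the three embedding rows.
                Some(
                    w.iter()
                        .zip(p)
                        .zip(t)
                        .map(|((a, b), c)| a + b + c)
                        .collect::<Vec<f32>>(),
                )
            })
            .collect()
    }
}
```

Returning `Option` keeps the sketch short; the trait's `forward_t` instead returns `Result<Tensor, RustBertError>`, which lets an implementation report why an input was rejected.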

Required Methods

fn new<'p, P>(p: P, config: &BertConfig) -> Self
where P: Borrow<Path<'p>>,
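The `P: Borrow<Path<'p>>` bound on `new` lets a caller pass either a variable-store path by value or a reference to one. A minimal sketch of that pattern follows; `Path` here is a hypothetical stand-in for `tch::nn::Path` that only records a name prefix, and `ToyLayer` and `sub` are illustrative names, not the tch API.

```rust
use std::borrow::Borrow;

/// Hypothetical stand-in for `tch::nn::Path`: only records a parameter-name prefix.
struct Path<'a> {
    prefix: Vec<&'a str>,
}

impl<'a> Path<'a> {
    /// Returns a child path with `name` appended (illustrative helper).
    fn sub(&self, name: &'a str) -> Path<'a> {
        let mut prefix = self.prefix.clone();
        prefix.push(name);
        Path { prefix }
    }
}

/// Hypothetical layer whose constructor uses the same bound as `BertEmbedding::new`.
struct ToyLayer {
    weight_name: String,
}

impl ToyLayer {
    fn new<'p, P>(p: P) -> Self
    where
        P: Borrow<Path<'p>>,
    {
        // `borrow()` yields `&Path` whether the caller passed `Path` or `&Path`.
        let p = p.borrow();
        let full = p.sub("word_embeddings");
        ToyLayer { weight_name: full.prefix.join(".") }
    }
}
```

Calling `p.borrow()` once at the top of the constructor is what makes both call styles work, since the standard library implements `Borrow<T>` for both `T` and `&T`.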

fn forward_t(
    &self,
    input_ids: Option<&Tensor>,
    token_type_ids: Option<&Tensor>,
    position_ids: Option<&Tensor>,
    input_embeds: Option<&Tensor>,
    train: bool
) -> Result<Tensor, RustBertError>
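`forward_t` accepts either token ids or precomputed embeddings. A common contract for BERT-style embedding layers, assumed here rather than stated on this page, is that exactly one of `input_ids` and `input_embeds` is set, while missing `token_type_ids` and `position_ids` fall back to all zeros and `0..seq_len`. The dependency-free sketch below shows that argument handling; `resolve`, `Source`, `Resolved`, and the `String` error are hypothetical simplifications of `Tensor` and `RustBertError`.

```rust
/// Where the embedding vectors come from (hypothetical helper type).
#[derive(Debug, PartialEq)]
enum Source<'a> {
    Ids(&'a [i64]),
    Embeds(&'a [Vec<f32>]),
}

/// Resolved arguments: the input source plus concrete position and token-type ids.
#[derive(Debug, PartialEq)]
struct Resolved<'a> {
    source: Source<'a>,
    position_ids: Vec<i64>,
    token_type_ids: Vec<i64>,
}

/// Hypothetical argument check mirroring `forward_t`'s optional inputs.
fn resolve<'a>(
    input_ids: Option<&'a [i64]>,
    token_type_ids: Option<&[i64]>,
    position_ids: Option<&[i64]>,
    input_embeds: Option<&'a [Vec<f32>]>,
) -> Result<Resolved<'a>, String> {
    // Exactly one of input_ids / input_embeds must be provided.
    let (source, seq_len) = match (input_ids, input_embeds) {
        (Some(ids), None) => (Source::Ids(ids), ids.len()),
        (None, Some(emb)) => (Source::Embeds(emb), emb.len()),
        (Some(_), Some(_)) => return Err("only one of input_ids or input_embeds may be set".into()),
        (None, None) => return Err("one of input_ids or input_embeds must be set".into()),
    };
    // Default positions are 0..seq_len; default token types are all 0.
    let position_ids = position_ids.map_or_else(|| (0..seq_len as i64).collect(), |p| p.to_vec());
    let token_type_ids = token_type_ids.map_or_else(|| vec![0; seq_len], |t| t.to_vec());
    Ok(Resolved { source, position_ids, token_type_ids })
}
```

Rejecting the ambiguous "both set" case up front avoids silently preferring one input over the other.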

Object Safety

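This trait is not object safe: `new` takes a generic type parameter and returns `Self`, so `dyn BertEmbedding` trait objects cannot be formed.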
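Because the trait is not object safe, a model cannot hold a `Box<dyn BertEmbedding>`; the usual alternative is to make the model generic over its embedding type so the choice is fixed at compile time, which fits the description above of one trait serving both BertModel and RoBERTaModel. The sketch below shows that pattern with a hypothetical `Embed` trait, `Model` struct, and toy implementors; it mirrors the `Self`-returning constructor but is not rust-bert's model code.

```rust
/// Hypothetical trait with a `Self`-returning constructor, like `BertEmbedding::new`.
/// Such a constructor (without a `where Self: Sized` bound) rules out `dyn Embed`.
trait Embed {
    fn new(hidden_size: usize) -> Self;
    fn forward(&self, ids: &[usize]) -> Vec<f32>;
}

/// Toy embedding: a zero vector of length `hidden` per token, flattened.
struct ZeroEmbed {
    hidden: usize,
}

impl Embed for ZeroEmbed {
    fn new(hidden_size: usize) -> Self {
        ZeroEmbed { hidden: hidden_size }
    }
    fn forward(&self, ids: &[usize]) -> Vec<f32> {
        vec![0.0; ids.len() * self.hidden]
    }
}

/// Toy embedding: each id cast to `f32` (hidden size ignored).
struct IdEmbed;

impl Embed for IdEmbed {
    fn new(_hidden_size: usize) -> Self {
        IdEmbed
    }
    fn forward(&self, ids: &[usize]) -> Vec<f32> {
        ids.iter().map(|&i| i as f32).collect()
    }
}

/// A model generic over its embedding layer; each `E` gets its own monomorphized copy.
struct Model<E: Embed> {
    embeddings: E,
}

impl<E: Embed> Model<E> {
    fn new(hidden_size: usize) -> Self {
        Model { embeddings: E::new(hidden_size) }
    }
    fn embed(&self, ids: &[usize]) -> Vec<f32> {
        self.embeddings.forward(ids)
    }
}
```

Static dispatch costs a separate compiled copy per embedding type but needs no vtable, and it lets the constructor stay on the trait.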

Implementors