Module longt5


LongT5 (Efficient Text-To-Text Transformer for Long Sequences)

Implementation of the LongT5 language model (LongT5: Efficient Text-To-Text Transformer for Long Sequences, Guo, Ainslie, Uthus, Ontanon, Ni, Sung, Yang, 2021). The base model is implemented in the longt5_model::LongT5Model struct. The model with a language modeling head, longt5_model::LongT5ForConditionalGeneration, implements the generation_utils::LanguageGenerator trait shared by the models used for generation (see pipelines for more information).
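LongT5 handles long inputs by replacing full self-attention in the encoder with the sparse patterns described in the paper: local attention, where each token attends only to neighbours within a radius r, and transient-global (TGlobal) attention, which additionally lets every token attend to one global token summarizing each block of k input tokens. The std-only Rust sketch below counts the keys a single query sees under these patterns. It is a conceptual illustration of the paper's description, not code from this module (whose blockwise implementation may differ in details); all helper names are hypothetical, and r and k are assumed to correspond to the local_radius and global_block_size keys of a Transformers LongT5 configuration.

```rust
use std::ops::Range;

/// Hypothetical helper: key positions visible to query `i` under local attention
/// with radius `r` (|i - j| <= r), clipped to a sequence of `n` tokens.
fn local_keys(i: usize, n: usize, r: usize) -> Range<usize> {
    let start = i.saturating_sub(r);
    let end = (i + r + 1).min(n);
    start..end
}

/// Hypothetical helper: one transient global token per block of `k` input tokens
/// (the last block may be partial, hence the ceiling division).
fn num_global_tokens(n: usize, k: usize) -> usize {
    (n + k - 1) / k
}

/// Hypothetical helper: upper bound on keys per query with TGlobal attention,
/// i.e. the full local window plus every global token.
fn keys_per_query_tglobal(n: usize, r: usize, k: usize) -> usize {
    (2 * r + 1).min(n) + num_global_tokens(n, k)
}

fn main() {
    // Illustrative values: 16,384 input tokens, local radius 127, block size 16.
    let (n, r, k) = (16_384, 127, 16);
    println!("local window of token 0: {:?}", local_keys(0, n, r)); // 0..128
    println!("global tokens: {}", num_global_tokens(n, k)); // 1024
    println!("keys per query, TGlobal: {}", keys_per_query_tglobal(n, r, k)); // 1279
    println!("keys per query, full attention: {}", n); // 16384
}
```

With these values each query attends to at most 255 + 1,024 = 1,279 keys instead of 16,384. Local attention cost grows linearly with input length for a fixed r, while TGlobal adds n/k global keys per query, which shrinks the quadratic term by a factor of k rather than removing it.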

Model set-up and pre-trained weights loading

All models expect the following resources:

  • Configuration file with a structure following the Transformers library
  • Model weights with a structure and parameter names following the Transformers library. The .bin weights must be converted to the .ot format using the Python utility scripts.
  • T5Tokenizer using a spiece.model SentencePiece model
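Before loading, it can help to confirm that the three resources are in place. The following is a minimal std-only sketch, assuming all files sit in one directory and use the placeholder names from the loading example in this section (config.json, spiece.model, model.ot). The helpers missing_resources and needs_conversion are hypothetical, and pytorch_model.bin is assumed to be the file name of the unconverted Transformers weights.

```rust
use std::path::Path;

/// Hypothetical helper: lists the expected resource files that are not present in `dir`.
/// The names mirror the placeholders used in the loading example.
fn missing_resources(dir: &Path) -> Vec<&'static str> {
    const EXPECTED: [&str; 3] = ["config.json", "spiece.model", "model.ot"];
    EXPECTED
        .iter()
        .copied()
        .filter(|name| !dir.join(*name).is_file())
        .collect()
}

/// Hypothetical helper: true when unconverted PyTorch weights exist but the `.ot`
/// file does not, meaning the Python conversion step still has to be run.
fn needs_conversion(dir: &Path) -> bool {
    dir.join("pytorch_model.bin").is_file() && !dir.join("model.ot").is_file()
}

fn main() {
    let dir = Path::new("path/to");
    let missing = missing_resources(dir);
    if missing.is_empty() {
        println!("all resources found in {}", dir.display());
    } else {
        println!("missing from {}: {:?}", dir.display(), missing);
    }
    if needs_conversion(dir) {
        println!("convert pytorch_model.bin to model.ot before loading");
    }
}
```

Checking for the .bin file separately makes the failure mode explicit: a missing model.ot next to existing PyTorch weights usually means the conversion step was skipped, not that the download failed.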

Pretrained models are available and can be downloaded using RemoteResources (see LongT5ModelResources, LongT5ConfigResources and LongT5VocabResources).

use std::path::PathBuf;
use tch::{nn, Device};
use rust_bert::longt5::{LongT5Config, LongT5ForConditionalGeneration};
use rust_bert::resources::{LocalResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::T5Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Point each resource at a local file (RemoteResource can be used to download them instead)
    let config_resource = LocalResource {
        local_path: PathBuf::from("path/to/config.json"),
    };
    let sentence_piece_resource = LocalResource {
        local_path: PathBuf::from("path/to/spiece.model"),
    };
    let weights_resource = LocalResource {
        local_path: PathBuf::from("path/to/model.ot"),
    };
    // Resolve the resources to local file paths
    let config_path = config_resource.get_local_path()?;
    let spiece_path = sentence_piece_resource.get_local_path()?;
    let weights_path = weights_resource.get_local_path()?;

    // Build the tokenizer, configuration and model, then load the converted weights
    let device = Device::cuda_if_available();
    let mut vs = nn::VarStore::new(device);
    let tokenizer = T5Tokenizer::from_file(spiece_path.to_str().unwrap(), true)?;
    let config = LongT5Config::from_file(config_path);
    let longt5_model = LongT5ForConditionalGeneration::new(&vs.root(), &config);
    vs.load(weights_path)?;
    Ok(())
}

Structs

LongT5Config
LongT5 model configuration
LongT5ConfigResources
LongT5 Pretrained model config files
LongT5ForConditionalGeneration
LongT5 Model for conditional generation
LongT5Generator
LongT5Model
LongT5 Base model
LongT5ModelResources
LongT5 Pretrained model weight files
LongT5VocabResources
LongT5 Pretrained model vocab files

Type Aliases

LayerState