Module rust_bert::marian

Marian

Implementation of the Marian language model (Marian: Fast Neural Machine Translation in C++, Junczys-Dowmunt, Grundkiewicz, Dwojak, Hoang, Heafield, Neckermann, Seide, Germann, Fikri Aji, Bogoychev, Martins, Birch, 2018). The base model is implemented in the bart::BartModel struct. This module provides the model with a language model head, marian::MarianForConditionalGeneration, which implements the generation::LMHeadModel trait shared by the models used for generation (see pipelines for more information).

Model set-up and pre-trained weights loading

A full working example is provided in examples/translation.rs, run with cargo run --example translation. All models expect the following resources:

  • Configuration file expected to have a structure following the Transformers library
  • Model weights are expected to have a structure and parameter names following the Transformers library. A conversion using the Python utility scripts is required to convert the .bin weights to the .ot format.
  • MarianTokenizer using a vocab.json vocabulary and spiece.model sentence piece model

Pretrained models for a number of language pairs are available and can be downloaded using RemoteResources. They are shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.

use std::path::PathBuf;

use tch::{nn, Device};
use rust_bert::bart::BartConfig;
use rust_bert::marian::MarianForConditionalGeneration;
use rust_bert::resources::{download_resource, LocalResource, Resource};
use rust_bert::Config;
use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;

let config_resource = Resource::Local(LocalResource {
    local_path: PathBuf::from("path/to/config.json"),
});
let vocab_resource = Resource::Local(LocalResource {
    local_path: PathBuf::from("path/to/vocab.json"),
});
let sentence_piece_resource = Resource::Local(LocalResource {
    local_path: PathBuf::from("path/to/spiece.model"),
});
let weights_resource = Resource::Local(LocalResource {
    local_path: PathBuf::from("path/to/model.ot"),
});
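// Resolve each resource to a path on disk. The `?` operator assumes this
// snippet runs inside a function that returns a Result.
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let spiece_path = download_resource(&sentence_piece_resource)?;
let weights_path = download_resource(&weights_resource)?;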

let device = Device::cuda_if_available();
let mut vs = nn::VarStore::new(device);
let tokenizer = MarianTokenizer::from_files(
    vocab_path.to_str().unwrap(),
    spiece_path.to_str().unwrap(),
    true,
);
let config = BartConfig::from_file(config_path);
let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
vs.load(weights_path)?;
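
The MarianPrefix struct listed below holds optional prefixes. Opus-MT models that translate into several target languages expect a target-language token of the form >>id<< at the start of each input sentence, while single-target models need none. The following is a minimal, standard-library-only sketch of that pre-processing step. The apply_prefix helper is hypothetical and not part of rust_bert, and the ">>fr<< " value is an assumed example of such a token.

```rust
/// Hypothetical helper (not part of rust_bert): prepends an optional
/// target-language prefix to every input sentence before tokenization.
fn apply_prefix(prefix: Option<&str>, sentences: &[&str]) -> Vec<String> {
    match prefix {
        // Multi-target model: put the language token in front of each sentence.
        Some(p) => sentences.iter().map(|s| format!("{}{}", p, s)).collect(),
        // Single-target model: leave the sentences unchanged.
        None => sentences.iter().map(|s| s.to_string()).collect(),
    }
}

fn main() {
    let inputs = ["Hello, world!", "How are you?"];

    // Assumed example token for a French target; the trailing space
    // separates the token from the sentence.
    for sentence in apply_prefix(Some(">>fr<< "), &inputs) {
        println!("{}", sentence);
    }

    // A model without a prefix receives the inputs as they are.
    for sentence in apply_prefix(None, &inputs) {
        println!("{}", sentence);
    }
}
```

Modelling the prefix as an Option keeps a single code path for both kinds of model: the caller passes None when the chosen model defines no prefix.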

Structs

MarianConfigResources

Marian Pretrained model config files

MarianForConditionalGeneration

Marian Model for conditional generation

MarianModelResources

Marian Pretrained model weight files

MarianPrefix

Marian optional prefixes

MarianSpmResources

Marian Pretrained sentence piece model files

MarianVocabResources

Marian Pretrained model vocab files