Struct rust_bert::pipelines::translation::TranslationConfig

pub struct TranslationConfig {
    pub model_resource: Resource,
    pub config_resource: Resource,
    pub vocab_resource: Resource,
    pub merges_resource: Resource,
    pub min_length: u64,
    pub max_length: u64,
    pub do_sample: bool,
    pub early_stopping: bool,
    pub num_beams: u64,
    pub temperature: f64,
    pub top_k: u64,
    pub top_p: f64,
    pub repetition_penalty: f64,
    pub length_penalty: f64,
    pub no_repeat_ngram_size: u64,
    pub num_return_sequences: u64,
    pub device: Device,
    pub prefix: Option<String>,
}

Configuration for text translation

Contains information regarding the model to load, mirrors the GenerationConfig, with a different set of default parameters and sets the device to place the model on.

Fields

model_resource: Resource

Model weights resource (default: pretrained translation model for the selected language pair)

config_resource: Resource

Config resource (default: configuration of the pretrained translation model for the selected language pair)

vocab_resource: Resource

Vocab resource (default: vocabulary of the pretrained translation model for the selected language pair)

merges_resource: Resource

Merges resource (default: merges file of the pretrained translation model for the selected language pair)

min_length: u64

Minimum sequence length (default: 0)

max_length: u64

Maximum sequence length (default: 20)

do_sample: bool

Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)

early_stopping: bool

Early stopping flag indicating if the beam search should stop as soon as num_beams hypotheses have been generated (default: false)

num_beams: u64

Number of beams for beam search (default: 5)

temperature: f64

Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)

top_k: u64

Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)

top_p: f64

Top_p value for Nucleus sampling, Holtzman et al.. Keep top tokens until cumulative probability reaches top_p (default: 0.9)

repetition_penalty: f64

Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)

length_penalty: f64

Exponential penalty based on the length of the hypotheses generated (default: 1.0)

no_repeat_ngram_size: u64

Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)

num_return_sequences: u64

Number of sequences to return for each prompt text (default: 1)

device: Device

Device to place the model on (default: CUDA/GPU when available)

prefix: Option<String>

Prefix to append translation inputs with

Implementations

impl TranslationConfig[src]

pub fn new(language: Language, device: Device) -> TranslationConfig[src]

Create a new TranslationConfig from an available language.

Arguments

  • language - Language enum value (e.g. Language::EnglishToFrench)
  • device - Device to place the model on (CPU/GPU)

Example

use rust_bert::pipelines::translation::{TranslationConfig, Language};
use tch::Device;

let translation_config =  TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());

pub fn new_from_resources(
    model_resource: Resource,
    config_resource: Resource,
    vocab_resource: Resource,
    sentence_piece_resource: Resource,
    prefix: Option<String>,
    device: Device
) -> TranslationConfig
[src]

Create a new TranslationConfig from custom (e.g. local) resources.

Arguments

  • model_resource - Resource pointing to the model
  • config_resource - Resource pointing to the configuration
  • vocab_resource - Resource pointing to the vocabulary
  • sentence_piece_resource - Resource pointing to the sentence piece model of the source language
  • prefix - Optional prefix to prepend to input texts (e.g. a target-language token such as ">>fr<<")
  • device - Device to place the model on (CPU/GPU)

Example

use rust_bert::pipelines::translation::TranslationConfig;
use tch::Device;
use rust_bert::resources::{Resource, LocalResource};
use std::path::PathBuf;

let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json") });
let model_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot") });
let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.json") });
let sentence_piece_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/spiece.model") });

let translation_config =  TranslationConfig::new_from_resources(model_resource,
                                           config_resource,
                                           vocab_resource,
                                           sentence_piece_resource,
                                           Some(">>fr<<".to_string()),
                                           Device::cuda_if_available());

Auto Trait Implementations

Blanket Implementations

impl<T> Any for T where
    T: 'static + ?Sized
[src]

impl<T> Borrow<T> for T where
    T: ?Sized
[src]

impl<T> BorrowMut<T> for T where
    T: ?Sized
[src]

impl<T> From<T> for T[src]

impl<T, U> Into<U> for T where
    U: From<T>, 
[src]

impl<T, U> TryFrom<U> for T where
    U: Into<T>, 
[src]

type Error = Infallible

The type returned in the event of a conversion error.

impl<T, U> TryInto<U> for T where
    U: TryFrom<T>, 
[src]

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.