pub struct TranslationConfig {
    pub model_type: ModelType,
    pub model_resource: Box<dyn ResourceProvider + Send>,
    pub config_resource: Box<dyn ResourceProvider + Send>,
    pub vocab_resource: Box<dyn ResourceProvider + Send>,
    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
    pub source_languages: HashSet<Language>,
    pub target_languages: HashSet<Language>,
    pub min_length: i64,
    pub max_length: Option<i64>,
    pub do_sample: bool,
    pub early_stopping: bool,
    pub num_beams: i64,
    pub temperature: f64,
    pub top_k: i64,
    pub top_p: f64,
    pub repetition_penalty: f64,
    pub length_penalty: f64,
    pub no_repeat_ngram_size: i64,
    pub num_return_sequences: i64,
    pub device: Device,
    pub num_beam_groups: Option<i64>,
    pub diversity_penalty: Option<f64>,
}

Configuration for text translation

Contains the information needed to load the model and sets the device to place the model on. It mirrors the GenerationConfig, with a different set of default parameters.

Fields§

§model_type: ModelType

Model type used for translation

§model_resource: Box<dyn ResourceProvider + Send>

Model weights resource

§config_resource: Box<dyn ResourceProvider + Send>

Config resource

§vocab_resource: Box<dyn ResourceProvider + Send>

Vocab resource

§merges_resource: Option<Box<dyn ResourceProvider + Send>>

Merges resource

§source_languages: HashSet<Language>

Supported source languages

§target_languages: HashSet<Language>

Supported target languages

§min_length: i64

Minimum sequence length (default: 0)

§max_length: Option<i64>

Maximum sequence length (default: 512)

§do_sample: bool

Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
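The difference between greedy decoding and sampling can be sketched in plain Rust. This is an illustration only, not the library's implementation: the function names `greedy_pick` and `sample_pick` are hypothetical, and the uniform random number `u` is passed in by the caller so the sketch needs no random number crate.

```rust
/// Greedy decoding: always pick the most probable token.
/// Returns `None` for an empty distribution.
fn greedy_pick(probs: &[f64]) -> Option<usize> {
    probs
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(index, _)| index)
}

/// Sampling: pick a token by walking the cumulative distribution until it
/// exceeds `u`, a uniform random number in [0, 1) supplied by the caller.
fn sample_pick(probs: &[f64], u: f64) -> Option<usize> {
    let mut cumulative = 0.0;
    for (index, p) in probs.iter().enumerate() {
        cumulative += p;
        if u < cumulative {
            return Some(index);
        }
    }
    // Fall back to the last index if rounding left `u` above the total mass.
    probs.len().checked_sub(1)
}
```

Greedy decoding returns the same token for the same distribution every time, while the sampled token depends on `u`.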

§early_stopping: bool

Early stopping flag indicating whether the beam search should stop as soon as num_beams hypotheses have been generated (default: false)

§num_beams: i64

Number of beams for beam search (default: 5)

§temperature: f64

Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
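As a rough illustration of the effect of temperature, the sketch below divides the logits by the temperature before applying a softmax. It is written in plain Rust on slices rather than tensors, and the function name `softmax_with_temperature` is hypothetical; it assumes a strictly positive temperature.

```rust
/// Softmax over `logits` after dividing each logit by `temperature`.
/// Assumes `temperature` is strictly positive.
fn softmax_with_temperature(logits: &[f64], temperature: f64) -> Vec<f64> {
    let scaled: Vec<f64> = logits.iter().map(|l| l / temperature).collect();
    // Subtract the largest scaled logit for numerical stability.
    let max = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = scaled.iter().map(|s| (s - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```

A temperature above 1 flattens the distribution, so less likely tokens are sampled more often; a temperature below 1 concentrates probability on the top tokens.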

§top_k: i64

Top_k value for sampling tokens. Values higher than 0 enable the feature (default: 0)
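Top-k filtering can be sketched as masking every logit below the k-th largest one. This is a plain-Rust illustration with a hypothetical function name (`top_k_filter`), not the library's tensor-based code; it mirrors the convention above that a value of 0 disables the feature.

```rust
/// Keeps the `top_k` largest logits and masks the rest with negative infinity.
/// A `top_k` of 0, or one at least as large as the vocabulary, leaves the
/// logits unchanged. Logits tied with the k-th largest value are all kept.
fn top_k_filter(logits: &[f64], top_k: usize) -> Vec<f64> {
    if top_k == 0 || top_k >= logits.len() {
        return logits.to_vec();
    }
    // Sort a copy in descending order to find the k-th largest logit.
    let mut sorted = logits.to_vec();
    sorted.sort_by(|a, b| b.total_cmp(a));
    let threshold = sorted[top_k - 1];
    logits
        .iter()
        .map(|&l| if l < threshold { f64::NEG_INFINITY } else { l })
        .collect()
}
```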

§top_p: f64

Top_p value for nucleus sampling (Holtzman et al.). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
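The "keep top tokens until cumulative probability reaches top_p" rule can be sketched as follows. The function name `nucleus_indices` is hypothetical and the code works on a plain probability slice; it is meant to illustrate the idea, not to reproduce the library's implementation.

```rust
/// Returns the indices of the smallest set of tokens whose cumulative
/// probability reaches `top_p`, ordered from most to least probable.
fn nucleus_indices(probs: &[f64], top_p: f64) -> Vec<usize> {
    // Visit tokens from most to least probable.
    let mut order: Vec<usize> = (0..probs.len()).collect();
    order.sort_by(|&a, &b| probs[b].total_cmp(&probs[a]));

    let mut kept = Vec::new();
    let mut cumulative = 0.0;
    for index in order {
        kept.push(index);
        cumulative += probs[index];
        if cumulative >= top_p {
            break;
        }
    }
    kept
}
```

Sampling then draws only from the returned indices, after renormalizing their probabilities.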

§repetition_penalty: f64

Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
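A minimal sketch of a repetition penalty, assuming the CTRL-style formulation in which positive logits of already generated tokens are divided by the penalty and negative ones are multiplied by it. The function name `apply_repetition_penalty` is hypothetical, and whether the library applies the penalty exactly this way is an assumption.

```rust
use std::collections::HashSet;

/// Lowers the logits of tokens that already appear in `generated`.
/// Assumed CTRL-style rule: divide positive logits by `penalty`,
/// multiply non-positive logits by `penalty`.
fn apply_repetition_penalty(logits: &mut [f64], generated: &[usize], penalty: f64) {
    let mut seen = HashSet::new();
    for &token in generated {
        // Penalize each distinct token once and ignore out-of-range ids.
        if token < logits.len() && seen.insert(token) {
            if logits[token] > 0.0 {
                logits[token] /= penalty;
            } else {
                logits[token] *= penalty;
            }
        }
    }
}
```

With a penalty of 1.0 the logits are left unchanged, which matches the default above.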

§length_penalty: f64

Exponential penalty based on the length of the hypotheses generated (default: 1.0)
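One common way to apply an exponential length penalty in beam search is to divide the summed log-probability of a hypothesis by its length raised to the penalty. The sketch below assumes that formulation; the function name `length_normalized_score` is hypothetical.

```rust
/// Scores a hypothesis as `sum_log_probs / length^length_penalty`.
/// Assumes `length` is at least 1.
fn length_normalized_score(sum_log_probs: f64, length: usize, length_penalty: f64) -> f64 {
    sum_log_probs / (length as f64).powf(length_penalty)
}
```

Because log-probabilities are negative, a larger penalty shrinks the magnitude of the score for long hypotheses and so favors longer output, while a penalty of 0.0 compares raw summed log-probabilities.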

§no_repeat_ngram_size: i64

Size of the n-grams that may not be repeated in the generated sequence. Values higher than 0 turn on this feature (default: 3)
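The n-gram blocking idea can be sketched as computing, at each step, the tokens that would complete an n-gram already present in the generated sequence. The function name `banned_next_tokens` is hypothetical and the code is a plain-Rust illustration, not the library's implementation.

```rust
/// Returns the token ids that would complete an n-gram of size `ngram_size`
/// already present in `tokens`. A size of 0 disables the check.
fn banned_next_tokens(tokens: &[u32], ngram_size: usize) -> Vec<u32> {
    if ngram_size == 0 || tokens.len() + 1 < ngram_size {
        return Vec::new();
    }
    // The last n-1 generated tokens form the prefix of the next n-gram.
    let prefix_len = ngram_size - 1;
    let prefix = &tokens[tokens.len() - prefix_len..];

    let mut banned = Vec::new();
    for window in tokens.windows(ngram_size) {
        let candidate = window[prefix_len];
        // An earlier n-gram with the same prefix bans its final token.
        if &window[..prefix_len] == prefix && !banned.contains(&candidate) {
            banned.push(candidate);
        }
    }
    banned
}
```

During generation, the scores of the returned tokens would be set to negative infinity before the next token is chosen.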

§num_return_sequences: i64

Number of sequences to return for each prompt text (default: 1)

§device: Device

Device to place the model on (default: CUDA/GPU when available)

§num_beam_groups: Option<i64>

Number of beam groups for diverse beam generation. If provided and higher than 1, will split the beams into beam subgroups leading to more diverse generation.

§diversity_penalty: Option<f64>

Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: 5.5)
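A minimal sketch of how a diversity penalty can push beam groups apart, assuming a Hamming-style penalty: each token already selected by an earlier group at the current step has its score reduced by the penalty, once per selection. The function name `apply_diversity_penalty` is hypothetical and the exact rule used by the library is an assumption.

```rust
/// Lowers the score of every token already chosen by earlier beam groups at
/// the current step. A token chosen several times is penalized several times.
fn apply_diversity_penalty(
    scores: &mut [f64],
    previous_group_tokens: &[usize],
    diversity_penalty: f64,
) {
    for &token in previous_group_tokens {
        // Ignore ids outside the vocabulary.
        if token < scores.len() {
            scores[token] -= diversity_penalty;
        }
    }
}
```

The first group is scored without any penalty; each later group sees the tokens picked by the groups before it.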

Implementations§

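Create a new TranslationConfig from a model type, a set of resources, and the supported source and target languages.

Arguments
  • model_type - ModelType used for translation (e.g. ModelType::Marian)
  • model_resource - Model weights resource
  • config_resource - Config resource
  • vocab_resource - Vocab resource
  • merges_resource - Optional merges resource (the SentencePiece model in the Marian example below)
  • source_languages - Supported source languages
  • target_languages - Supported target languages
  • device - Device to place the model on (CPU/GPU)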
Example
use rust_bert::marian::{
    MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
    MarianTargetLanguages, MarianVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::TranslationConfig;
use rust_bert::resources::RemoteResource;
use tch::Device;

// Remote resources for the pretrained Marian ROMANCE2ENGLISH model.
let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
let spm_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);

// Languages supported by this model on each side of the translation.
let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;

let translation_config = TranslationConfig::new(
    ModelType::Marian,
    model_resource,
    config_resource,
    vocab_resource,
    // The SentencePiece model is passed in the merges resource position.
    Some(spm_resource),
    source_languages,
    target_languages,
    // Use the GPU when one is available, otherwise the CPU.
    Device::cuda_if_available(),
);

Trait Implementations§

Converts to this type from the input type.
