Struct rust_bert::pipelines::translation::TranslationConfig
pub struct TranslationConfig {
pub model_type: ModelType,
pub model_resource: Box<dyn ResourceProvider + Send>,
pub config_resource: Box<dyn ResourceProvider + Send>,
pub vocab_resource: Box<dyn ResourceProvider + Send>,
pub merges_resource: Box<dyn ResourceProvider + Send>,
pub source_languages: HashSet<Language>,
pub target_languages: HashSet<Language>,
pub min_length: i64,
pub max_length: i64,
pub do_sample: bool,
pub early_stopping: bool,
pub num_beams: i64,
pub temperature: f64,
pub top_k: i64,
pub top_p: f64,
pub repetition_penalty: f64,
pub length_penalty: f64,
pub no_repeat_ngram_size: i64,
pub num_return_sequences: i64,
pub device: Device,
pub num_beam_groups: Option<i64>,
pub diversity_penalty: Option<f64>,
}
Configuration for text translation
Contains information regarding the model to load; mirrors GenerateConfig with a different set of default parameters, and sets the device to place the model on.
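A TranslationConfig is typically consumed by the translation pipeline. The snippet below is a minimal usage sketch, assuming the crate's TranslationModel::new constructor and its translate method taking the input texts plus optional source and target languages; translation_config is built as in the constructor example further down.
use rust_bert::pipelines::translation::{Language, TranslationModel};

// `translation_config` built with TranslationConfig::new (see the constructor example below).
let model = TranslationModel::new(translation_config)?;
// Translate from French to English, matching the ROMANCE2ENGLISH resources used below.
let output = model.translate(
    &["Ceci est une phrase à traduire."],
    Language::French,
    Language::English,
)?;
for sentence in output {
    println!("{sentence}");
}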
Fields
model_type: ModelType - Model type used for translation
model_resource: Box<dyn ResourceProvider + Send> - Model weights resource
config_resource: Box<dyn ResourceProvider + Send> - Config resource
vocab_resource: Box<dyn ResourceProvider + Send> - Vocab resource
merges_resource: Box<dyn ResourceProvider + Send> - Merges resource
source_languages: HashSet<Language> - Supported source languages
target_languages: HashSet<Language> - Supported target languages
min_length: i64 - Minimum sequence length (default: 0)
max_length: i64 - Maximum sequence length (default: 20)
do_sample: bool - Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens; otherwise greedy (deterministic) decoding (default: true)
early_stopping: bool - Early stopping flag indicating if the beam search should stop as soon as num_beams hypotheses have been generated (default: false)
num_beams: i64 - Number of beams for beam search (default: 5)
temperature: f64 - Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
top_k: i64 - Top_k value for sampling tokens. Values higher than 0 enable the feature (default: 0)
top_p: f64 - Top_p value for nucleus sampling (Holtzman et al.). Keep the top tokens until their cumulative probability reaches top_p (default: 0.9)
repetition_penalty: f64 - Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have already been generated (default: 1.0)
length_penalty: f64 - Exponential penalty based on the length of the hypotheses generated (default: 1.0)
no_repeat_ngram_size: i64 - Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
num_return_sequences: i64 - Number of sequences to return for each prompt text (default: 1)
device: Device - Device to place the model on (default: CUDA/GPU when available)
num_beam_groups: Option<i64> - Number of beam groups for diverse beam generation. If provided and higher than 1, splits the beams into subgroups, leading to more diverse generation.
diversity_penalty: Option<f64> - Diversity penalty for diverse beam search. Higher values enforce more difference between beam groups (default: 5.5). A sketch of overriding these generation defaults follows the field list.
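All fields are public, so the generation defaults listed above can be overridden in place after construction. A minimal sketch, assuming translation_config was built (as mutable) with TranslationConfig::new as in the constructor example further down; the values are illustrative only.
// `translation_config` declared as `let mut translation_config = TranslationConfig::new(...);`
// Use more beams and allow longer outputs than the defaults (illustrative values).
translation_config.num_beams = 6;
translation_config.max_length = 512;
// Enable diverse beam search by splitting the 6 beams into 3 groups.
translation_config.num_beam_groups = Some(3);
translation_config.diversity_penalty = Some(5.5);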
Implementations
impl TranslationConfig
pub fn new<R, S, T>(
    model_type: ModelType,
    model_resource: R,
    config_resource: R,
    vocab_resource: R,
    merges_resource: R,
    source_languages: S,
    target_languages: T,
    device: impl Into<Option<Device>>,
) -> TranslationConfig
where
    R: ResourceProvider + Send + 'static,
    S: AsRef<[Language]>,
    T: AsRef<[Language]>,
Create a new TranslationConfig from model resources and the sets of supported source and target languages.
Arguments
model_type - ModelType of the model to load (e.g. ModelType::Marian)
model_resource - ResourceProvider pointing to the model weights
config_resource - ResourceProvider pointing to the model configuration
vocab_resource - ResourceProvider pointing to the tokenizer vocabulary
merges_resource - ResourceProvider pointing to the tokenizer merges
source_languages - Supported source Language values for the model
target_languages - Supported target Language values for the model
device - Device to place the model on (CPU/GPU)
Example
use rust_bert::marian::{
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages,
MarianVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::TranslationConfig;
use rust_bert::resources::RemoteResource;
use tch::Device;
let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
let translation_config = TranslationConfig::new(
    ModelType::Marian,
    model_resource,
    config_resource,
    // The vocabulary resource is passed for both the vocab and merges arguments here.
    vocab_resource.clone(),
    vocab_resource,
    source_languages,
    target_languages,
    Device::cuda_if_available(),
);
Trait Implementations
impl From<TranslationConfig> for GenerateConfig
fn from(config: TranslationConfig) -> GenerateConfig
Converts to this type from the input type.
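As a short sketch of this conversion (translation_config as built in the example above; the GenerateConfig path is assumed to be pipelines::generation_utils):
use rust_bert::pipelines::generation_utils::GenerateConfig;

// Consumes the TranslationConfig and yields the equivalent generation configuration.
let generate_config: GenerateConfig = translation_config.into();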
Auto Trait Implementations
impl !RefUnwindSafe for TranslationConfig
impl Send for TranslationConfig
impl !Sync for TranslationConfig
impl Unpin for TranslationConfig
impl !UnwindSafe for TranslationConfig
Blanket Implementations
impl<T> BorrowMut<T> for T where T: ?Sized
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
Instruments this type with the provided Span, returning an Instrumented wrapper.