Struct rust_bert::pipelines::translation::TranslationConfig
pub struct TranslationConfig {
pub model_type: ModelType,
pub model_resource: ModelResource,
pub config_resource: Box<dyn ResourceProvider + Send>,
pub vocab_resource: Box<dyn ResourceProvider + Send>,
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
pub source_languages: HashSet<Language>,
pub target_languages: HashSet<Language>,
pub min_length: i64,
pub max_length: Option<i64>,
pub do_sample: bool,
pub early_stopping: bool,
pub num_beams: i64,
pub temperature: f64,
pub top_k: i64,
pub top_p: f64,
pub repetition_penalty: f64,
pub length_penalty: f64,
pub no_repeat_ngram_size: i64,
pub num_return_sequences: i64,
pub device: Device,
pub num_beam_groups: Option<i64>,
pub diversity_penalty: Option<f64>,
pub kind: Option<Kind>,
}
Expand description
Configuration for text translation
Contains information regarding the model to load, mirrors the GenerationConfig, with a different set of default parameters and sets the device to place the model on.
Fields§
§model_type: ModelType
Model type used for translation
model_resource: ModelResource
Model weights resource
config_resource: Box<dyn ResourceProvider + Send>
Config resource
vocab_resource: Box<dyn ResourceProvider + Send>
Vocab resource
merges_resource: Option<Box<dyn ResourceProvider + Send>>
Merges resource
source_languages: HashSet<Language>
Supported source languages
target_languages: HashSet<Language>
Supported target languages
min_length: i64
Minimum sequence length (default: 0)
max_length: Option<i64>
Maximum sequence length (default: 512)
do_sample: bool
Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
early_stopping: bool
Early stopping flag indicating if the beam search should stop as soon as num_beams hypotheses have been generated (default: false)
num_beams: i64
Number of beams for beam search (default: 5)
temperature: f64
Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
top_k: i64
Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
top_p: f64
Top_p value for Nucleus sampling, Holtzman et al. Keep top tokens until cumulative probability reaches top_p (default: 0.9)
repetition_penalty: f64
Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
length_penalty: f64
Exponential penalty based on the length of the hypotheses generated (default: 1.0)
no_repeat_ngram_size: i64
Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
num_return_sequences: i64
Number of sequences to return for each prompt text (default: 1)
device: Device
Device to place the model on (default: CUDA/GPU when available)
num_beam_groups: Option<i64>
Number of beam groups for diverse beam generation. If provided and higher than 1, will split the beams into beam subgroups leading to more diverse generation.
diversity_penalty: Option<f64>
Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: 5.5)
kind: Option<Kind>
Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
Implementations§
source§impl TranslationConfig
impl TranslationConfig
sourcepub fn new<RC, RV, S, T>(
model_type: ModelType,
model_resource: ModelResource,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
source_languages: S,
target_languages: T,
device: impl Into<Option<Device>>
) -> TranslationConfig
where
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
S: AsRef<[Language]>,
T: AsRef<[Language]>,
pub fn new<RC, RV, S, T>(
model_type: ModelType,
model_resource: ModelResource,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
source_languages: S,
target_languages: T,
device: impl Into<Option<Device>>
) -> TranslationConfig
where
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
S: AsRef<[Language]>,
T: AsRef<[Language]>,
Create a new TranslationConfiguration from an available language.
Arguments
- language - Language enum value (e.g. Language::EnglishToFrench)
- device - Device to place the model on (CPU/GPU)
Example
use rust_bert::marian::{
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
MarianTargetLanguages, MarianVocabResources,
};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::TranslationConfig;
use rust_bert::resources::RemoteResource;
use tch::Device;
let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
MarianModelResources::ROMANCE2ENGLISH,
)));
let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
let spm_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);
let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
let translation_config = TranslationConfig::new(
ModelType::Marian,
model_resource,
config_resource,
vocab_resource,
Some(spm_resource),
source_languages,
target_languages,
Device::cuda_if_available(),
);