Struct rust_bert::pipelines::generation_utils::GenerateConfig

pub struct GenerateConfig {
    pub model_resource: Resource,
    pub config_resource: Resource,
    pub vocab_resource: Resource,
    pub merges_resource: Resource,
    pub min_length: i64,
    pub max_length: i64,
    pub do_sample: bool,
    pub early_stopping: bool,
    pub num_beams: i64,
    pub temperature: f64,
    pub top_k: i64,
    pub top_p: f64,
    pub repetition_penalty: f64,
    pub length_penalty: f64,
    pub no_repeat_ngram_size: i64,
    pub num_return_sequences: i64,
    pub num_beam_groups: Option<i64>,
    pub diversity_penalty: Option<f64>,
    pub device: Device,
}

Fields

model_resource: Resource

Model weights resource (default: pretrained GPT2 model)

config_resource: Resource

Config resource (default: pretrained GPT2 model)

vocab_resource: Resource

Vocab resource (default: pretrained GPT2 model)

merges_resource: Resource

Merges resource (default: pretrained GPT2 model)

min_length: i64

Minimum sequence length (default: 0)

max_length: i64

Maximum sequence length (default: 20)

do_sample: bool

Sampling flag. If true, top-k and/or nucleus sampling is performed on generated tokens; otherwise greedy (deterministic) decoding is used (default: true)

early_stopping: bool

Early stopping flag indicating whether the beam search should stop as soon as num_beams hypotheses have been generated (default: false)

num_beams: i64

Number of beams for beam search (default: 5)

temperature: f64

Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)

top_k: i64

Top_k value for top-k sampling of tokens. Values higher than 0 enable the feature (default: 0)

top_p: f64

Top_p value for nucleus sampling (Holtzman et al.). Keep the top tokens until their cumulative probability reaches top_p (default: 0.9)

repetition_penalty: f64

Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 penalize tokens that have already been generated (default: 1.0)

length_penalty: f64

Exponential penalty based on the length of the hypotheses generated (default: 1.0)

no_repeat_ngram_size: i64

Size of the n-grams that may not be repeated in the generated sequence. Values higher than 0 turn on this feature (default: 3)

num_return_sequences: i64

Number of sequences to return for each prompt text (default: 1)

num_beam_groups: Option<i64>

Number of beam groups for diverse beam generation. If provided and higher than 1, the beams are split into subgroups, leading to more diverse generation.

diversity_penalty: Option<f64>

Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: 5.5)

device: Device

Device to place the model on (default: CUDA/GPU when available)
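
A minimal construction sketch (illustrative, not taken from the crate documentation): the Default implementation supplies the pretrained GPT2 resources, so only the decoding parameters of interest need to be overridden. The field values below are examples only.

use rust_bert::pipelines::generation_utils::GenerateConfig;
use tch::Device;

fn main() {
    // Start from the defaults (pretrained GPT2 resources) and override only
    // the decoding parameters of interest; all remaining fields keep their
    // default values via struct update syntax.
    let generate_config = GenerateConfig {
        max_length: 64,
        do_sample: true,
        top_k: 50,
        top_p: 0.95,
        num_beams: 1,
        device: Device::cuda_if_available(),
        ..Default::default()
    };
    println!("Generating up to {} tokens", generate_config.max_length);
}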

Trait Implementations

impl Default for GenerateConfig

impl From<ConversationConfig> for GenerateConfig

impl From<SummarizationConfig> for GenerateConfig

impl From<TextGenerationConfig> for GenerateConfig

impl From<TranslationConfig> for GenerateConfig
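
The From implementations above let the pipeline-specific configurations be converted into a GenerateConfig. A brief sketch, assuming TextGenerationConfig (from rust_bert::pipelines::text_generation) provides a Default implementation:

use rust_bert::pipelines::generation_utils::GenerateConfig;
use rust_bert::pipelines::text_generation::TextGenerationConfig;

fn main() {
    // A pipeline-specific configuration converts into the shared
    // GenerateConfig through its From implementation.
    let text_generation_config = TextGenerationConfig::default();
    let generate_config: GenerateConfig = text_generation_config.into();
    println!("num_beams: {}", generate_config.num_beams);
}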
