// rust_bert/models/t5/t5_model.rs

// Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
12
13use std::borrow::Borrow;
14
15use serde::{Deserialize, Serialize};
16use tch::nn::{embedding, LinearConfig};
17use tch::{nn, Device, Tensor};
18
19use crate::pipelines::common::{ModelType, TokenizerOption};
20use crate::pipelines::generation_utils::private_generation_utils::{
21    PreparedInput, PrivateLanguageGenerator,
22};
23use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
24use crate::pipelines::translation::Language;
25use crate::t5::attention::LayerState;
26use crate::t5::encoder::T5Stack;
27use crate::{Config, RustBertError};
28
/// # T5 Pretrained model weight files
/// Marker type; its associated constants pair a resource name with the URL of the weight file.
pub struct T5ModelResources;
31
/// # T5 Pretrained model config files
/// Marker type; its associated constants pair a resource name with the URL of the configuration file.
pub struct T5ConfigResources;
34
/// # T5 Pretrained model vocab files
/// Marker type; its associated constants pair a resource name with the URL of the vocabulary file.
pub struct T5VocabResources;
37
/// # T5 optional prefixes
/// Marker type; its associated constants hold the optional task prefix strings.
pub struct T5Prefix;
40
/// # T5 source languages pre-sets
/// Marker type; its associated constants list the languages available for each pre-set.
pub struct T5SourceLanguages;
43
44/// # T5 target languages pre-sets
45pub type T5TargetLanguages = T5SourceLanguages;
46
47impl T5ModelResources {
48    /// Shared under Apache 2.0 license by the T5 Authors at <https://github.com/google-research/text-to-text-transfer-transformer>. Modified with conversion to C-array format.
49    pub const T5_SMALL: (&'static str, &'static str) = (
50        "t5-small/model",
51        "https://huggingface.co/t5-small/resolve/main/rust_model.ot",
52    );
53    /// Shared under Apache 2.0 license by the T5 Authors at <https://github.com/google-research/text-to-text-transfer-transformer>. Modified with conversion to C-array format.
54    pub const T5_BASE: (&'static str, &'static str) = (
55        "t5-base/model",
56        "https://huggingface.co/t5-base/resolve/main/rust_model.ot",
57    );
58    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
59    pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
60        "sentence-t5-base/model",
61        "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/rust_model.ot",
62    );
63}
64
65impl T5ConfigResources {
66    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
67    pub const T5_SMALL: (&'static str, &'static str) = (
68        "t5-small/config",
69        "https://huggingface.co/t5-small/resolve/main/config.json",
70    );
71    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
72    pub const T5_BASE: (&'static str, &'static str) = (
73        "t5-base/config",
74        "https://huggingface.co/t5-base/resolve/main/config.json",
75    );
76    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
77    pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
78        "sentence-t5-base/config",
79        "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/config.json",
80    );
81}
82
83impl T5VocabResources {
84    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
85    pub const T5_SMALL: (&'static str, &'static str) = (
86        "t5-small/spiece",
87        "https://huggingface.co/t5-small/resolve/main/spiece.model",
88    );
89    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
90    pub const T5_BASE: (&'static str, &'static str) = (
91        "t5-base/spiece",
92        "https://huggingface.co/t5-base/resolve/main/spiece.model",
93    );
94    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
95    pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
96        "sentence-t5-base/spiece",
97        "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/spiece.model",
98    );
99}
100
101const T5LANGUAGES: [Language; 3] = [Language::English, Language::French, Language::German];
102
103impl T5SourceLanguages {
104    pub const T5_SMALL: [Language; 3] = T5LANGUAGES;
105    pub const T5_BASE: [Language; 3] = T5LANGUAGES;
106}
107
108impl T5Prefix {
109    pub const ENGLISH2FRENCH: Option<&'static str> = Some("translate English to French:");
110    pub const ENGLISH2GERMAN: Option<&'static str> = Some("translate English to German:");
111}
112
113#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
114#[serde(rename_all = "kebab-case")]
115/// # Options for T5 Feed-forward projection layer
116pub enum FeedForwardProj {
117    /// ReLU
118    Relu,
119    /// Gated geLU
120    GatedGelu,
121}
122
123#[derive(Debug, Serialize, Deserialize, Clone)]
124/// # T5 model configuration
125/// Defines the T5 model architecture (e.g. number of layers, hidden layer size, label mapping...)
126pub struct T5Config {
127    pub dropout_rate: f64,
128    pub d_model: i64,
129    pub d_ff: i64,
130    pub d_kv: i64,
131    pub decoder_start_token_id: Option<i64>,
132    pub bos_token_id: Option<i64>,
133    pub eos_token_id: Option<i64>,
134    pub forced_bos_token_id: Option<i64>,
135    pub forced_eos_token_id: Option<i64>,
136    pub initializer_factor: f64,
137    pub is_encoder_decoder: Option<bool>,
138    pub layer_norm_epsilon: f64,
139    pub num_heads: i64,
140    pub num_layers: i64,
141    pub output_past: Option<bool>,
142    pub pad_token_id: Option<i64>,
143    pub relative_attention_num_buckets: i64,
144    pub relative_attention_max_distance: Option<i64>,
145    pub vocab_size: i64,
146    pub feed_forward_proj: Option<FeedForwardProj>,
147    pub tie_word_embeddings: Option<bool>,
148    pub task_specific_params: Option<TaskSpecificParams>,
149    pub output_attentions: Option<bool>,
150    pub output_hidden_states: Option<bool>,
151}
152
153/// # T5 task-specific configurations
154/// Defines the T5 configuration for summarization and translation tasks
155#[derive(Debug, Serialize, Deserialize, Clone)]
156pub struct TaskSpecificParams {
157    summarization: Summarization,
158    translation_en_to_de: TranslationEnToDe,
159    translation_en_to_fr: TranslationEnToFr,
160    translation_en_to_ro: TranslationEnToRo,
161}
162
163/// # T5 summarization configuration
164#[derive(Debug, Serialize, Deserialize, Clone)]
165pub struct Summarization {
166    early_stopping: bool,
167    length_penalty: f64,
168    max_length: i64,
169    min_length: i64,
170    no_repeat_ngram_size: i64,
171    num_beams: i64,
172    prefix: String,
173}
174
175/// # T5 English to German configuration
176#[derive(Debug, Serialize, Deserialize, Clone)]
177pub struct TranslationEnToDe {
178    early_stopping: bool,
179    max_length: i64,
180    num_beams: i64,
181    prefix: String,
182}
183
184/// # T5 English to French configuration
185#[derive(Debug, Serialize, Deserialize, Clone)]
186pub struct TranslationEnToFr {
187    early_stopping: bool,
188    max_length: i64,
189    num_beams: i64,
190    prefix: String,
191}
192
193/// # T5 English to Romanian configuration
194#[derive(Debug, Serialize, Deserialize, Clone)]
195pub struct TranslationEnToRo {
196    early_stopping: bool,
197    max_length: i64,
198    num_beams: i64,
199    prefix: String,
200}
201
202impl Config for T5Config {}
203
204impl Default for T5Config {
205    fn default() -> Self {
206        T5Config {
207            dropout_rate: 0.1,
208            d_model: 512,
209            d_ff: 2048,
210            d_kv: 64,
211            decoder_start_token_id: None,
212            bos_token_id: None,
213            eos_token_id: Some(1),
214            forced_bos_token_id: None,
215            forced_eos_token_id: None,
216            initializer_factor: 1.0,
217            is_encoder_decoder: None,
218            layer_norm_epsilon: 1e-6,
219            num_heads: 8,
220            num_layers: 6,
221            output_past: None,
222            pad_token_id: Some(0),
223            relative_attention_num_buckets: 32,
224            relative_attention_max_distance: Some(128),
225            vocab_size: 32128,
226            feed_forward_proj: Some(FeedForwardProj::Relu),
227            tie_word_embeddings: None,
228            task_specific_params: None,
229            output_attentions: None,
230            output_hidden_states: None,
231        }
232    }
233}
234
235/// # T5 Base model
236/// Base architecture for T5 model. Usually complemented with a task-specific head, such as a language model head.
237/// It is made of the following blocks:
238/// - `encoder`: `T5Stack` (transformer) made of a vector of encoding layers
239/// - `decoder`: `T5Stack` (transformer)  made of a vector of decoding layers with self attention and encoder cross-attention.
240///     caching is implemented for the decoder to avoid recalculating static states (encoder key/values and previously calculated decoder key/values)
241/// - `embeddings`: `nn::Embedding` Shared embeddings for the encoder and decoder.
242pub struct T5Model {
243    pub(crate) encoder: T5Stack,
244    decoder: T5Stack,
245    pub(crate) embeddings: nn::Embedding,
246}
247
248impl T5Model {
249    /// Build a new `T5Model`
250    ///
251    /// # Arguments
252    ///
253    /// * `p` - Variable store path for the root of the T5 model
254    /// * `config` - `T5Config` object defining the model architecture
255    ///
256    /// # Example
257    ///
258    /// ```no_run
259    /// use rust_bert::t5::{T5Config, T5Model};
260    /// use rust_bert::Config;
261    /// use std::path::Path;
262    /// use tch::{nn, Device};
263    ///
264    /// let config_path = Path::new("path/to/config.json");
265    /// let device = Device::Cpu;
266    /// let p = nn::VarStore::new(device);
267    /// let config = T5Config::from_file(config_path);
268    /// let t5: T5Model = T5Model::new(&p.root() / "t5", &config);
269    /// ```
270    pub fn new<'p, P>(p: P, config: &T5Config) -> T5Model
271    where
272        P: Borrow<nn::Path<'p>>,
273    {
274        let p = p.borrow();
275
276        let embeddings: nn::Embedding = embedding(
277            p / "shared",
278            config.vocab_size,
279            config.d_model,
280            Default::default(),
281        );
282
283        let encoder = T5Stack::new(
284            p / "encoder",
285            config,
286            false,
287            false,
288            config.output_attentions.unwrap_or(false),
289            config.output_hidden_states.unwrap_or(false),
290        );
291        let decoder = T5Stack::new(
292            p / "decoder",
293            config,
294            true,
295            true,
296            config.output_attentions.unwrap_or(false),
297            config.output_hidden_states.unwrap_or(false),
298        );
299
300        T5Model {
301            encoder,
302            decoder,
303            embeddings,
304        }
305    }
306
307    /// Forward pass through the model
308    ///
309    /// # Arguments
310    ///
311    /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). This or `input_embeds` must be provided.
312    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
313    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). This or `decoder_input_embeds` must be provided.
314    /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
315    ///     These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
316    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
317    /// * `input_embeds` - Optional input tensor of shape (*batch size*, *source_sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
318    /// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided.
319    /// * `old_layer_states` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
320    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
321    ///
322    /// # Returns
323    ///
324    /// * `T5ModelOutput` containing:
325    ///   - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state
326    ///   - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
327    ///   - `cache` - `Option<Vec<(Option<Vec<LayerState, LayerState>>)>>` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
328    ///   - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
329    ///   - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
330    ///   - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
331    ///   - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
332    ///
333    /// # Example
334    ///
335    /// ```no_run
336    /// # use tch::{nn, Device, Tensor, no_grad};
337    /// # use rust_bert::Config;
338    /// # use std::path::Path;
339    /// # use tch::kind::Kind::{Int64, Double};
340    /// use rust_bert::t5::{T5Config, T5Model};
341    /// # let config_path = Path::new("path/to/config.json");
342    /// # let vocab_path = Path::new("path/to/vocab.txt");
343    /// # let device = Device::Cpu;
344    /// # let vs = nn::VarStore::new(device);
345    /// # let config = T5Config::from_file(config_path);
346    /// # let t5_model: T5Model = T5Model::new(&vs.root(), &config);
347    /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
348    /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
349    /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
350    /// let encoder_attention_mask =
351    ///     Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
352    /// let decoder_attention_mask =
353    ///     Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
354    ///
355    /// let model_output = no_grad(|| {
356    ///     t5_model.forward_t(
357    ///         Some(&input_tensor),
358    ///         Some(&encoder_attention_mask),
359    ///         None,
360    ///         Some(&target_tensor),
361    ///         Some(&decoder_attention_mask),
362    ///         None,
363    ///         None,
364    ///         None,
365    ///         false,
366    ///     )
367    /// });
368    /// ```
369    pub fn forward_t(
370        &self,
371        input_ids: Option<&Tensor>,
372        attention_mask: Option<&Tensor>,
373        encoder_outputs: Option<&Tensor>,
374        decoder_input_ids: Option<&Tensor>,
375        decoder_attention_mask: Option<&Tensor>,
376        input_embeds: Option<&Tensor>,
377        decoder_input_embeds: Option<&Tensor>,
378        old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
379        train: bool,
380    ) -> T5ModelOutput {
381        let calc_encoder_outputs = if encoder_outputs.is_none() {
382            Some(
383                self.encoder
384                    .forward_t(
385                        input_ids,
386                        attention_mask,
387                        None,
388                        None,
389                        input_embeds,
390                        &self.embeddings,
391                        None,
392                        train,
393                    )
394                    .unwrap(),
395            )
396        } else {
397            None
398        };
399
400        let (calc_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
401            if let Some(calc_encoder_outputs) = calc_encoder_outputs {
402                (
403                    Some(calc_encoder_outputs.hidden_state),
404                    calc_encoder_outputs.all_hidden_states,
405                    calc_encoder_outputs.all_attentions,
406                )
407            } else {
408                (None, None, None)
409            };
410
411        let encoder_output =
412            encoder_outputs.unwrap_or_else(|| calc_hidden_states.as_ref().unwrap());
413
414        let decoder_output = self
415            .decoder
416            .forward_t(
417                decoder_input_ids,
418                decoder_attention_mask,
419                Some(encoder_output),
420                attention_mask,
421                decoder_input_embeds,
422                &self.embeddings,
423                old_layer_states,
424                train,
425            )
426            .unwrap();
427        T5ModelOutput {
428            decoder_output: decoder_output.hidden_state,
429            encoder_hidden_state: calc_hidden_states,
430            next_cache: decoder_output.next_cache,
431            all_decoder_hidden_states: decoder_output.all_hidden_states,
432            all_decoder_attentions: decoder_output.all_attentions,
433            all_encoder_hidden_states,
434            all_encoder_attentions,
435        }
436    }
437}
438
439/// # T5 Model for conditional generation
440/// T5 model with a vocabulary decoding head
441/// It is made of the following blocks:
442/// - `base_model`: `T5Model` Base T5 model
443/// - `model_dim`: `f64` representation of the model dimension for scaling of the generated logits
444pub struct T5ForConditionalGeneration {
445    base_model: T5Model,
446    model_dim: f64,
447    tie_word_embeddings: bool,
448    lm_head: Option<nn::Linear>,
449}
450
451impl T5ForConditionalGeneration {
452    /// Build a new `T5ForConditionalGeneration`
453    ///
454    /// # Arguments
455    ///
456    /// * `p` - Variable store path for the root of the BART model
457    /// * `config` - `T5Config` object defining the model architecture
458    ///
459    /// # Example
460    ///
461    /// ```no_run
462    /// use rust_bert::t5::{T5Config, T5ForConditionalGeneration};
463    /// use rust_bert::Config;
464    /// use std::path::Path;
465    /// use tch::{nn, Device};
466    ///
467    /// let config_path = Path::new("path/to/config.json");
468    /// let device = Device::Cpu;
469    /// let p = nn::VarStore::new(device);
470    /// let config = T5Config::from_file(config_path);
471    /// let t5 = T5ForConditionalGeneration::new(&p.root() / "t5", &config);
472    /// ```
473    pub fn new<'p, P>(p: P, config: &T5Config) -> T5ForConditionalGeneration
474    where
475        P: Borrow<nn::Path<'p>>,
476    {
477        let p = p.borrow();
478
479        let base_model = T5Model::new(p, config);
480        let tie_word_embeddings = config.tie_word_embeddings.unwrap_or(true);
481
482        let lm_head = if !tie_word_embeddings {
483            Some(nn::linear(
484                p / "lm_head",
485                config.d_model,
486                config.vocab_size,
487                LinearConfig {
488                    bias: false,
489                    ..Default::default()
490                },
491            ))
492        } else {
493            None
494        };
495
496        T5ForConditionalGeneration {
497            base_model,
498            model_dim: config.d_model as f64,
499            tie_word_embeddings,
500            lm_head,
501        }
502    }
503
504    /// Forward pass through the model
505    ///
506    /// # Arguments
507    ///
508    /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). This or `input_embeds` must be provided.
509    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
510    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). This or `decoder_input_embeds` must be provided.
511    /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
512    ///     These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
513    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
514    /// * `input_embeds` - Optional input tensor of shape (*batch size*, *source_sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
515    /// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided.
516    /// * `old_layer_states` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
517    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
518    ///
519    /// # Returns
520    ///
521    /// * `T5ModelOutput` containing:
522    ///   - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each sequence position and vocabulary item
523    ///   - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
524    ///   - `cache` - `Option<Vec<(Option<Vec<LayerState, LayerState>>)>>` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
525    ///   - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
526    ///   - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
527    ///   - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
528    ///   - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
529    ///
530    /// # Example
531    ///
532    /// ```no_run
533    /// # use tch::{nn, Device, Tensor, no_grad};
534    /// # use rust_bert::Config;
535    /// # use std::path::Path;
536    /// # use tch::kind::Kind::{Int64, Double};
537    /// use rust_bert::t5::{T5Config, T5ForConditionalGeneration};
538    /// # let config_path = Path::new("path/to/config.json");
539    /// # let vocab_path = Path::new("path/to/vocab.txt");
540    /// # let device = Device::Cpu;
541    /// # let vs = nn::VarStore::new(device);
542    /// # let config = T5Config::from_file(config_path);
543    /// # let t5_model: T5ForConditionalGeneration = T5ForConditionalGeneration::new(&vs.root(), &config);
544    /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
545    /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
546    /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
547    /// let encoder_attention_mask =
548    ///     Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
549    /// let decoder_attention_mask =
550    ///     Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
551    ///
552    /// let model_output = no_grad(|| {
553    ///     t5_model.forward_t(
554    ///         Some(&input_tensor),
555    ///         Some(&encoder_attention_mask),
556    ///         None,
557    ///         Some(&target_tensor),
558    ///         Some(&decoder_attention_mask),
559    ///         None,
560    ///         None,
561    ///         None,
562    ///         false,
563    ///     )
564    /// });
565    /// ```
566    pub fn forward_t(
567        &self,
568        input_ids: Option<&Tensor>,
569        attention_mask: Option<&Tensor>,
570        encoder_outputs: Option<&Tensor>,
571        decoder_input_ids: Option<&Tensor>,
572        decoder_attention_mask: Option<&Tensor>,
573        input_embeds: Option<&Tensor>,
574        decoder_input_embeds: Option<&Tensor>,
575        old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
576        train: bool,
577    ) -> T5ModelOutput {
578        let base_model_output = self.base_model.forward_t(
579            input_ids,
580            attention_mask,
581            encoder_outputs,
582            decoder_input_ids,
583            decoder_attention_mask,
584            input_embeds,
585            decoder_input_embeds,
586            old_layer_states,
587            train,
588        );
589
590        let lm_logits = if self.tie_word_embeddings {
591            base_model_output
592                .decoder_output
593                .linear::<Tensor>(&self.base_model.embeddings.ws, None)
594                * (self.model_dim.powf(-0.5))
595        } else {
596            base_model_output
597                .decoder_output
598                .apply(self.lm_head.as_ref().unwrap())
599        };
600
601        T5ModelOutput {
602            decoder_output: lm_logits,
603            ..base_model_output
604        }
605    }
606
607    pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
608        self.base_model
609            .encoder
610            .forward_t(
611                Some(input_ids),
612                attention_mask,
613                None,
614                None,
615                None,
616                &self.base_model.embeddings,
617                None,
618                false,
619            )
620            .unwrap()
621            .hidden_state
622    }
623}
624
625/// # T5 for sentence embeddings
626/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
627pub struct T5ForSentenceEmbeddings {
628    embeddings: nn::Embedding,
629    encoder: T5Stack,
630}
631
632impl T5ForSentenceEmbeddings {
633    /// Build a new `T5ForSentenceEmbeddings`
634    ///
635    /// # Arguments
636    ///
637    /// * `p` - Variable store path for the root of the BART model
638    /// * `config` - `T5Config` object defining the model architecture
639    ///
640    /// It consists of only an encoder (there is no decoder).
641    pub fn new<'p, P>(p: P, config: &T5Config) -> Self
642    where
643        P: Borrow<nn::Path<'p>>,
644    {
645        let p = p.borrow();
646
647        let embeddings: nn::Embedding = embedding(
648            p / "shared",
649            config.vocab_size,
650            config.d_model,
651            Default::default(),
652        );
653
654        let encoder = T5Stack::new(
655            p / "encoder",
656            config,
657            false,
658            false,
659            config.output_attentions.unwrap_or(false),
660            config.output_hidden_states.unwrap_or(false),
661        );
662
663        Self {
664            embeddings,
665            encoder,
666        }
667    }
668
669    /// Forward pass through the model
670    ///
671    /// # Arguments
672    ///
673    /// * `input_ids` - Input of shape (*batch size*, *source_sequence_length*).
674    /// * `mask` - Attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
675    ///
676    /// # Returns
677    ///
678    /// * Tuple containing:
679    ///   - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
680    ///   - `Option<Vec<Tensor>>` of length *num_encoder_layers* of shape (*batch size*, *target_sequence_length*, *hidden_size*)  representing attention weights for all layers of the encoder
681    pub fn forward(
682        &self,
683        input_ids: &Tensor,
684        mask: &Tensor,
685    ) -> Result<(Tensor, Option<Vec<Tensor>>), RustBertError> {
686        let transformer_output = self.encoder.forward_t(
687            Some(input_ids),
688            Some(mask),
689            None,
690            None,
691            None,
692            &self.embeddings,
693            None,
694            false,
695        )?;
696        Ok((
697            transformer_output.hidden_state,
698            transformer_output.all_attentions,
699        ))
700    }
701}
702
703/// Container holding a T5 model output. The decoder output may hold the hidden state of
704/// the last layer of the decoder, or may hold logits for a custom head module after the
705/// decoder (e.g. for language modeling tasks)
706pub struct T5ModelOutput {
707    /// Hidden state of the last layer of the decoder, or logits for a custom head
708    /// module after the decoder (e.g. for language modeling tasks)
709    pub decoder_output: Tensor,
710    /// Hidden state for the last layer of the encoder if they are calculated, otherwise None
711    pub encoder_hidden_state: Option<Tensor>,
712    /// Cached outputs of the model (attention layers keys and values) if the model is used for generation
713    pub next_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
714    /// Hidden states for all layers of the decoder
715    pub all_decoder_hidden_states: Option<Vec<Tensor>>,
716    /// Attention weights for all layers of the decoder
717    pub all_decoder_attentions: Option<Vec<Tensor>>,
718    /// Hidden states for all layers of the encoder
719    pub all_encoder_hidden_states: Option<Vec<Tensor>>,
720    /// Attention weights for all layers of the encoder
721    pub all_encoder_attentions: Option<Vec<Tensor>>,
722}
723
724pub struct T5Generator {
725    model: T5ForConditionalGeneration,
726    tokenizer: TokenizerOption,
727    var_store: nn::VarStore,
728    generate_config: GenerateConfig,
729    bos_token_id: Option<i64>,
730    eos_token_ids: Option<Vec<i64>>,
731    pad_token_id: Option<i64>,
732    is_encoder_decoder: bool,
733    vocab_size: i64,
734    decoder_start_id: Option<i64>,
735    max_position_embeddings: i64,
736}
737
738impl T5Generator {
739    pub fn new(generate_config: GenerateConfig) -> Result<T5Generator, RustBertError> {
740        let vocab_path = generate_config.vocab_resource.get_local_path()?;
741
742        let tokenizer = TokenizerOption::from_file(
743            ModelType::T5,
744            vocab_path.to_str().unwrap(),
745            None,
746            false,
747            None,
748            None,
749        )?;
750
751        Self::new_with_tokenizer(generate_config, tokenizer)
752    }
753
754    pub fn new_with_tokenizer(
755        generate_config: GenerateConfig,
756        tokenizer: TokenizerOption,
757    ) -> Result<T5Generator, RustBertError> {
758        let config_path = generate_config.config_resource.get_local_path()?;
759        let device = generate_config.device;
760
761        generate_config.validate();
762        let mut var_store = nn::VarStore::new(device);
763
764        let config = T5Config::from_file(config_path);
765        let model = T5ForConditionalGeneration::new(var_store.root(), &config);
766        crate::resources::load_weights(
767            &generate_config.model_resource,
768            &mut var_store,
769            generate_config.kind,
770            device,
771        )?;
772
773        let bos_token_id = Some(config.bos_token_id.unwrap_or(-1));
774        let eos_token_ids = Some(match config.eos_token_id {
775            Some(value) => vec![value],
776            None => vec![1],
777        });
778        let pad_token_id = Some(config.pad_token_id.unwrap_or(0));
779        let vocab_size = config.vocab_size;
780        let is_encoder_decoder = true;
781        let decoder_start_id = config.decoder_start_token_id;
782        // T5 do not have an embedding matrix for position IDs and relies on relative positions instead
783        let max_position_embeddings = i64::MAX;
784
785        Ok(T5Generator {
786            model,
787            tokenizer,
788            var_store,
789            generate_config,
790            bos_token_id,
791            eos_token_ids,
792            pad_token_id,
793            is_encoder_decoder,
794            vocab_size,
795            decoder_start_id,
796            max_position_embeddings,
797        })
798    }
799}
800
801impl PrivateLanguageGenerator for T5Generator {
802    fn _get_tokenizer(&self) -> &TokenizerOption {
803        &self.tokenizer
804    }
805    fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
806        &mut self.tokenizer
807    }
808    fn get_device(&self) -> Device {
809        self.var_store.device()
810    }
811    fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
812        Ok(&mut self.var_store)
813    }
814    fn get_config(&self) -> &GenerateConfig {
815        &self.generate_config
816    }
817    fn get_bos_id(&self) -> Option<i64> {
818        self.bos_token_id
819    }
820    fn get_eos_ids(&self) -> Option<&Vec<i64>> {
821        self.eos_token_ids.as_ref()
822    }
823    fn get_pad_id(&self) -> Option<i64> {
824        self.pad_token_id
825    }
826    fn is_encoder_decoder(&self) -> bool {
827        self.is_encoder_decoder
828    }
829    fn get_vocab_size(&self) -> i64 {
830        self.vocab_size
831    }
832    fn get_decoder_start_id(&self) -> Option<i64> {
833        self.decoder_start_id
834    }
835    fn get_max_positions_embeddings(&self) -> Option<i64> {
836        Some(self.max_position_embeddings)
837    }
838    fn forward_t(
839        &self,
840        input_ids: Option<&Tensor>,
841        cache: Cache,
842        attention_mask: Option<&Tensor>,
843        _token_type_ids: Option<&Tensor>,
844        _position_ids: Option<&Tensor>,
845        _input_embeds: Option<&Tensor>,
846        encoder_outputs: Option<&Tensor>,
847        decoder_input_ids: Option<&Tensor>,
848        train: bool,
849    ) -> Result<LMModelOutput, RustBertError> {
850        let base_model_output = match cache {
851            Cache::T5Cache(cached_layer_states) => self.model.forward_t(
852                input_ids,
853                attention_mask,
854                encoder_outputs,
855                decoder_input_ids,
856                None,
857                None,
858                None,
859                cached_layer_states,
860                train,
861            ),
862            Cache::None => self.model.forward_t(
863                input_ids,
864                attention_mask,
865                encoder_outputs,
866                decoder_input_ids,
867                None,
868                None,
869                None,
870                None,
871                train,
872            ),
873            _ => {
874                return Err(RustBertError::ValueError(
875                    "Cache not compatible with T5 Model".into(),
876                ));
877            }
878        };
879
880        Ok(LMModelOutput {
881            lm_logits: base_model_output.decoder_output,
882            cache: Cache::T5Cache(base_model_output.next_cache),
883        })
884    }
885    fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
886        Some(self.model.encode(input_ids, attention_mask))
887    }
888
889    fn prepare_inputs_for_generation<'a>(
890        &self,
891        input_ids: Tensor,
892        encoder_outputs: Option<&'a Tensor>,
893        past: Cache,
894        attention_mask: Tensor,
895    ) -> PreparedInput<'a> {
896        match past {
897            Cache::T5Cache(past) => PreparedInput {
898                prepared_input: None,
899                prepared_attention_mask: Some(attention_mask),
900                prepared_encoder_output: encoder_outputs,
901                prepared_decoder_input: Some(input_ids.narrow(1, -1, 1)),
902                prepared_position_ids: None,
903                prepared_past: Cache::T5Cache(past),
904            },
905            Cache::None => PreparedInput {
906                prepared_input: None,
907                prepared_attention_mask: Some(attention_mask),
908                prepared_encoder_output: encoder_outputs,
909                prepared_decoder_input: Some(input_ids),
910                prepared_position_ids: None,
911                prepared_past: Cache::T5Cache(None),
912            },
913            _ => panic!("Cache type incompatible with T5"),
914        }
915    }
916
917    fn reorder_cache(
918        &self,
919        past: &mut Cache,
920        encoder_outputs: Option<Tensor>,
921        beam_indices: &Tensor,
922    ) -> Option<Tensor> {
923        match past {
924            Cache::T5Cache(old_cache_option) => match old_cache_option {
925                Some(old_cache) => {
926                    for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
927                        if self_layer_state.is_some() {
928                            self_layer_state
929                                .as_mut()
930                                .unwrap()
931                                .reorder_cache(beam_indices)
932                        };
933                        if encoder_layer_state.is_some() {
934                            encoder_layer_state
935                                .as_mut()
936                                .unwrap()
937                                .reorder_cache(beam_indices)
938                        };
939                    }
940                }
941                None => {}
942            },
943            Cache::None => {}
944            _ => {
945                panic!("Invalid cache for T5 model");
946            }
947        };
948        encoder_outputs
949    }
950}
951
952impl LanguageGenerator for T5Generator {}