rust_bert/models/t5/t5_model.rs
// Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Borrow;

use serde::{Deserialize, Serialize};
use tch::nn::{embedding, LinearConfig};
use tch::{nn, Device, Tensor};

use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
    PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::pipelines::translation::Language;
use crate::t5::attention::LayerState;
use crate::t5::encoder::T5Stack;
use crate::{Config, RustBertError};

/// # T5 Pretrained model weight files
pub struct T5ModelResources;

/// # T5 Pretrained model config files
pub struct T5ConfigResources;

/// # T5 Pretrained model vocab files
pub struct T5VocabResources;

/// # T5 optional prefixes
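///
/// T5 is a text-to-text model: the task to perform is selected by prepending a natural
/// language prefix to the source text. A minimal sketch of how such a prefix could be
/// applied to an arbitrary example sentence is shown below.
///
/// ```
/// use rust_bert::t5::T5Prefix;
///
/// let input = "How old are you?";
/// let prefixed = format!("{} {}", T5Prefix::ENGLISH2FRENCH.unwrap(), input);
/// assert_eq!(prefixed, "translate English to French: How old are you?");
/// ```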
pub struct T5Prefix;

/// # T5 source languages pre-sets
pub struct T5SourceLanguages;

/// # T5 target languages pre-sets
pub type T5TargetLanguages = T5SourceLanguages;

impl T5ModelResources {
    /// Shared under Apache 2.0 license by the T5 Authors at <https://github.com/google-research/text-to-text-transfer-transformer>. Modified with conversion to C-array format.
    pub const T5_SMALL: (&'static str, &'static str) = (
        "t5-small/model",
        "https://huggingface.co/t5-small/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the T5 Authors at <https://github.com/google-research/text-to-text-transfer-transformer>. Modified with conversion to C-array format.
    pub const T5_BASE: (&'static str, &'static str) = (
        "t5-base/model",
        "https://huggingface.co/t5-base/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
    pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
        "sentence-t5-base/model",
        "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/rust_model.ot",
    );
}

impl T5ConfigResources {
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
    pub const T5_SMALL: (&'static str, &'static str) = (
        "t5-small/config",
        "https://huggingface.co/t5-small/resolve/main/config.json",
    );
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
    pub const T5_BASE: (&'static str, &'static str) = (
        "t5-base/config",
        "https://huggingface.co/t5-base/resolve/main/config.json",
    );
    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
    pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
        "sentence-t5-base/config",
        "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/config.json",
    );
}

impl T5VocabResources {
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
    pub const T5_SMALL: (&'static str, &'static str) = (
        "t5-small/spiece",
        "https://huggingface.co/t5-small/resolve/main/spiece.model",
    );
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
    pub const T5_BASE: (&'static str, &'static str) = (
        "t5-base/spiece",
        "https://huggingface.co/t5-base/resolve/main/spiece.model",
    );
    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
    pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
        "sentence-t5-base/spiece",
        "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/spiece.model",
    );
}

const T5LANGUAGES: [Language; 3] = [Language::English, Language::French, Language::German];

impl T5SourceLanguages {
    pub const T5_SMALL: [Language; 3] = T5LANGUAGES;
    pub const T5_BASE: [Language; 3] = T5LANGUAGES;
}

impl T5Prefix {
    pub const ENGLISH2FRENCH: Option<&'static str> = Some("translate English to French:");
    pub const ENGLISH2GERMAN: Option<&'static str> = Some("translate English to German:");
}

#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
#[serde(rename_all = "kebab-case")]
/// # Options for T5 Feed-forward projection layer
pub enum FeedForwardProj {
    /// ReLU
    Relu,
    /// Gated geLU
    GatedGelu,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
/// # T5 model configuration
/// Defines the T5 model architecture (e.g. number of layers, hidden layer size, vocabulary size...)
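///
/// # Example
///
/// A minimal usage sketch; the configuration path below is a placeholder and the asserted value
/// simply reflects the defaults defined further down in this file.
///
/// ```no_run
/// use rust_bert::t5::T5Config;
/// use rust_bert::Config;
/// use std::path::Path;
///
/// // Load a configuration from a local JSON file...
/// let config = T5Config::from_file(Path::new("path/to/config.json"));
///
/// // ...or start from the library defaults (a t5-small-like architecture).
/// let default_config = T5Config::default();
/// assert_eq!(default_config.num_layers, 6);
/// ```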
pub struct T5Config {
    pub dropout_rate: f64,
    pub d_model: i64,
    pub d_ff: i64,
    pub d_kv: i64,
    pub decoder_start_token_id: Option<i64>,
    pub bos_token_id: Option<i64>,
    pub eos_token_id: Option<i64>,
    pub forced_bos_token_id: Option<i64>,
    pub forced_eos_token_id: Option<i64>,
    pub initializer_factor: f64,
    pub is_encoder_decoder: Option<bool>,
    pub layer_norm_epsilon: f64,
    pub num_heads: i64,
    pub num_layers: i64,
    pub output_past: Option<bool>,
    pub pad_token_id: Option<i64>,
    pub relative_attention_num_buckets: i64,
    pub relative_attention_max_distance: Option<i64>,
    pub vocab_size: i64,
    pub feed_forward_proj: Option<FeedForwardProj>,
    pub tie_word_embeddings: Option<bool>,
    pub task_specific_params: Option<TaskSpecificParams>,
    pub output_attentions: Option<bool>,
    pub output_hidden_states: Option<bool>,
}

/// # T5 task-specific configurations
/// Defines the T5 configuration for summarization and translation tasks
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TaskSpecificParams {
    summarization: Summarization,
    translation_en_to_de: TranslationEnToDe,
    translation_en_to_fr: TranslationEnToFr,
    translation_en_to_ro: TranslationEnToRo,
}

/// # T5 summarization configuration
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Summarization {
    early_stopping: bool,
    length_penalty: f64,
    max_length: i64,
    min_length: i64,
    no_repeat_ngram_size: i64,
    num_beams: i64,
    prefix: String,
}

/// # T5 English to German configuration
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TranslationEnToDe {
    early_stopping: bool,
    max_length: i64,
    num_beams: i64,
    prefix: String,
}

/// # T5 English to French configuration
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TranslationEnToFr {
    early_stopping: bool,
    max_length: i64,
    num_beams: i64,
    prefix: String,
}

/// # T5 English to Romanian configuration
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TranslationEnToRo {
    early_stopping: bool,
    max_length: i64,
    num_beams: i64,
    prefix: String,
}

impl Config for T5Config {}

impl Default for T5Config {
    fn default() -> Self {
        T5Config {
            dropout_rate: 0.1,
            d_model: 512,
            d_ff: 2048,
            d_kv: 64,
            decoder_start_token_id: None,
            bos_token_id: None,
            eos_token_id: Some(1),
            forced_bos_token_id: None,
            forced_eos_token_id: None,
            initializer_factor: 1.0,
            is_encoder_decoder: None,
            layer_norm_epsilon: 1e-6,
            num_heads: 8,
            num_layers: 6,
            output_past: None,
            pad_token_id: Some(0),
            relative_attention_num_buckets: 32,
            relative_attention_max_distance: Some(128),
            vocab_size: 32128,
            feed_forward_proj: Some(FeedForwardProj::Relu),
            tie_word_embeddings: None,
            task_specific_params: None,
            output_attentions: None,
            output_hidden_states: None,
        }
    }
}

/// # T5 Base model
/// Base architecture for T5 model. Usually complemented with a task-specific head, such as a language model head.
/// It is made of the following blocks:
/// - `encoder`: `T5Stack` (transformer) made of a vector of encoding layers
/// - `decoder`: `T5Stack` (transformer) made of a vector of decoding layers with self attention and encoder cross-attention.
///   caching is implemented for the decoder to avoid recalculating static states (encoder key/values and previously calculated decoder key/values)
/// - `embeddings`: `nn::Embedding` Shared embeddings for the encoder and decoder.
pub struct T5Model {
    pub(crate) encoder: T5Stack,
    decoder: T5Stack,
    pub(crate) embeddings: nn::Embedding,
}

impl T5Model {
    /// Build a new `T5Model`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the T5 model
    /// * `config` - `T5Config` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::t5::{T5Config, T5Model};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = T5Config::from_file(config_path);
    /// let t5: T5Model = T5Model::new(&p.root() / "t5", &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &T5Config) -> T5Model
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let embeddings: nn::Embedding = embedding(
            p / "shared",
            config.vocab_size,
            config.d_model,
            Default::default(),
        );

        let encoder = T5Stack::new(
            p / "encoder",
            config,
            false,
            false,
            config.output_attentions.unwrap_or(false),
            config.output_hidden_states.unwrap_or(false),
        );
        let decoder = T5Stack::new(
            p / "decoder",
            config,
            true,
            true,
            config.output_attentions.unwrap_or(false),
            config.output_hidden_states.unwrap_or(false),
        );

        T5Model {
            encoder,
            decoder,
            embeddings,
        }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). This or `input_embeds` must be provided.
    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). This or `decoder_input_embeds` must be provided.
    /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
    ///   These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
    /// * `input_embeds` - Optional input tensor of shape (*batch size*, *source_sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
    /// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided.
    /// * `old_layer_states` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `T5ModelOutput` containing:
    ///   - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state
    ///   - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
    ///   - `cache` - `Option<Vec<(Option<LayerState>, Option<LayerState>)>>` of length *num_layers* containing the past keys and values for both the self-attention and the encoder cross-attention of each layer of the decoder.
    ///   - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
    ///   - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
    ///   - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
    ///   - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::{Int64, Double};
    /// use rust_bert::t5::{T5Config, T5Model};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = T5Config::from_file(config_path);
    /// # let t5_model: T5Model = T5Model::new(&vs.root(), &config);
    /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
    /// let input_tensor =
    ///     Tensor::randint(32128, &[batch_size, source_sequence_length], (Int64, device));
    /// let target_tensor =
    ///     Tensor::randint(32128, &[batch_size, target_sequence_length], (Int64, device));
    /// let encoder_attention_mask =
    ///     Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
    /// let decoder_attention_mask =
    ///     Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     t5_model.forward_t(
    ///         Some(&input_tensor),
    ///         Some(&encoder_attention_mask),
    ///         None,
    ///         Some(&target_tensor),
    ///         Some(&decoder_attention_mask),
    ///         None,
    ///         None,
    ///         None,
    ///         false,
    ///     )
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        encoder_outputs: Option<&Tensor>,
        decoder_input_ids: Option<&Tensor>,
        decoder_attention_mask: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        decoder_input_embeds: Option<&Tensor>,
        old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
        train: bool,
    ) -> T5ModelOutput {
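        // Run the encoder only when pre-computed encoder outputs are not provided
        // (e.g. when they were cached from a previous decoding step during generation).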
        let calc_encoder_outputs = if encoder_outputs.is_none() {
            Some(
                self.encoder
                    .forward_t(
                        input_ids,
                        attention_mask,
                        None,
                        None,
                        input_embeds,
                        &self.embeddings,
                        None,
                        train,
                    )
                    .unwrap(),
            )
        } else {
            None
        };

        let (calc_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
            if let Some(calc_encoder_outputs) = calc_encoder_outputs {
                (
                    Some(calc_encoder_outputs.hidden_state),
                    calc_encoder_outputs.all_hidden_states,
                    calc_encoder_outputs.all_attentions,
                )
            } else {
                (None, None, None)
            };

        let encoder_output =
            encoder_outputs.unwrap_or_else(|| calc_hidden_states.as_ref().unwrap());

        let decoder_output = self
            .decoder
            .forward_t(
                decoder_input_ids,
                decoder_attention_mask,
                Some(encoder_output),
                attention_mask,
                decoder_input_embeds,
                &self.embeddings,
                old_layer_states,
                train,
            )
            .unwrap();
        T5ModelOutput {
            decoder_output: decoder_output.hidden_state,
            encoder_hidden_state: calc_hidden_states,
            next_cache: decoder_output.next_cache,
            all_decoder_hidden_states: decoder_output.all_hidden_states,
            all_decoder_attentions: decoder_output.all_attentions,
            all_encoder_hidden_states,
            all_encoder_attentions,
        }
    }
}

/// # T5 Model for conditional generation
/// T5 model with a vocabulary decoding head
/// It is made of the following blocks:
/// - `base_model`: `T5Model` Base T5 model
/// - `model_dim`: `f64` representation of the model dimension for scaling of the generated logits
pub struct T5ForConditionalGeneration {
    base_model: T5Model,
    model_dim: f64,
    tie_word_embeddings: bool,
    lm_head: Option<nn::Linear>,
}

impl T5ForConditionalGeneration {
    /// Build a new `T5ForConditionalGeneration`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the T5 model
    /// * `config` - `T5Config` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::t5::{T5Config, T5ForConditionalGeneration};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = T5Config::from_file(config_path);
    /// let t5 = T5ForConditionalGeneration::new(&p.root() / "t5", &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &T5Config) -> T5ForConditionalGeneration
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let base_model = T5Model::new(p, config);
        let tie_word_embeddings = config.tie_word_embeddings.unwrap_or(true);

        let lm_head = if !tie_word_embeddings {
            Some(nn::linear(
                p / "lm_head",
                config.d_model,
                config.vocab_size,
                LinearConfig {
                    bias: false,
                    ..Default::default()
                },
            ))
        } else {
            None
        };

        T5ForConditionalGeneration {
            base_model,
            model_dim: config.d_model as f64,
            tie_word_embeddings,
            lm_head,
        }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). This or `input_embeds` must be provided.
    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). This or `decoder_input_embeds` must be provided.
    /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
    ///   These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
    /// * `input_embeds` - Optional input tensor of shape (*batch size*, *source_sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
    /// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided.
    /// * `old_layer_states` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `T5ModelOutput` containing:
    ///   - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each sequence position and vocabulary item
    ///   - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
    ///   - `cache` - `Option<Vec<(Option<LayerState>, Option<LayerState>)>>` of length *num_layers* containing the past keys and values for both the self-attention and the encoder cross-attention of each layer of the decoder.
    ///   - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
    ///   - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
    ///   - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
    ///   - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::{Int64, Double};
    /// use rust_bert::t5::{T5Config, T5ForConditionalGeneration};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = T5Config::from_file(config_path);
    /// # let t5_model: T5ForConditionalGeneration = T5ForConditionalGeneration::new(&vs.root(), &config);
    /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
    /// let input_tensor =
    ///     Tensor::randint(32128, &[batch_size, source_sequence_length], (Int64, device));
    /// let target_tensor =
    ///     Tensor::randint(32128, &[batch_size, target_sequence_length], (Int64, device));
    /// let encoder_attention_mask =
    ///     Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
    /// let decoder_attention_mask =
    ///     Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     t5_model.forward_t(
    ///         Some(&input_tensor),
    ///         Some(&encoder_attention_mask),
    ///         None,
    ///         Some(&target_tensor),
    ///         Some(&decoder_attention_mask),
    ///         None,
    ///         None,
    ///         None,
    ///         false,
    ///     )
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        encoder_outputs: Option<&Tensor>,
        decoder_input_ids: Option<&Tensor>,
        decoder_attention_mask: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        decoder_input_embeds: Option<&Tensor>,
        old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
        train: bool,
    ) -> T5ModelOutput {
        let base_model_output = self.base_model.forward_t(
            input_ids,
            attention_mask,
            encoder_outputs,
            decoder_input_ids,
            decoder_attention_mask,
            input_embeds,
            decoder_input_embeds,
            old_layer_states,
            train,
        );

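        // When the word embeddings are tied, the decoder output is projected onto the shared
        // embedding matrix and rescaled by d_model^-0.5; otherwise a dedicated `lm_head`
        // linear layer produces the logits.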
        let lm_logits = if self.tie_word_embeddings {
            base_model_output
                .decoder_output
                .linear::<Tensor>(&self.base_model.embeddings.ws, None)
                * (self.model_dim.powf(-0.5))
        } else {
            base_model_output
                .decoder_output
                .apply(self.lm_head.as_ref().unwrap())
        };

        T5ModelOutput {
            decoder_output: lm_logits,
            ..base_model_output
        }
    }

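    /// Encode the source sequence and return the last encoder hidden state.
    ///
    /// This is a convenience wrapper around the encoder stack used during generation: the
    /// returned tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*) can be
    /// passed back to `forward_t` as `encoder_outputs` so that the encoder only runs once per
    /// input sequence.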
    pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
        self.base_model
            .encoder
            .forward_t(
                Some(input_ids),
                attention_mask,
                None,
                None,
                None,
                &self.base_model.embeddings,
                None,
                false,
            )
            .unwrap()
            .hidden_state
    }
}

/// # T5 for sentence embeddings
/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
pub struct T5ForSentenceEmbeddings {
    embeddings: nn::Embedding,
    encoder: T5Stack,
}

impl T5ForSentenceEmbeddings {
    /// Build a new `T5ForSentenceEmbeddings`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the T5 model
    /// * `config` - `T5Config` object defining the model architecture
    ///
    /// The resulting model consists of only an encoder (there is no decoder).
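    ///
    /// # Example
    ///
    /// A minimal construction sketch; the configuration path below is a placeholder.
    ///
    /// ```no_run
    /// use rust_bert::t5::{T5Config, T5ForSentenceEmbeddings};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = T5Config::from_file(config_path);
    /// let t5 = T5ForSentenceEmbeddings::new(&p.root() / "t5", &config);
    /// ```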
    pub fn new<'p, P>(p: P, config: &T5Config) -> Self
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let embeddings: nn::Embedding = embedding(
            p / "shared",
            config.vocab_size,
            config.d_model,
            Default::default(),
        );

        let encoder = T5Stack::new(
            p / "encoder",
            config,
            false,
            false,
            config.output_attentions.unwrap_or(false),
            config.output_hidden_states.unwrap_or(false),
        );

        Self {
            embeddings,
            encoder,
        }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Input of shape (*batch size*, *source_sequence_length*).
    /// * `mask` - Attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
    ///
    /// # Returns
    ///
    /// * Tuple containing:
    ///   - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
    ///   - `Option<Vec<Tensor>>` of length *num_encoder_layers* representing the attention weights for all layers of the encoder
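    ///
    /// # Example
    ///
    /// A minimal forward-pass sketch; the configuration path and tensor shapes are placeholders.
    ///
    /// ```no_run
    /// # use tch::{nn, no_grad, Device, Tensor};
    /// # use tch::kind::Kind::Int64;
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// use rust_bert::t5::{T5Config, T5ForSentenceEmbeddings};
    /// # let config = T5Config::from_file(Path::new("path/to/config.json"));
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let model = T5ForSentenceEmbeddings::new(vs.root(), &config);
    /// let (batch_size, sequence_length) = (2, 16);
    /// let input_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
    ///
    /// let (hidden_state, all_attentions) =
    ///     no_grad(|| model.forward(&input_ids, &attention_mask)).unwrap();
    /// ```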
    pub fn forward(
        &self,
        input_ids: &Tensor,
        mask: &Tensor,
    ) -> Result<(Tensor, Option<Vec<Tensor>>), RustBertError> {
        let transformer_output = self.encoder.forward_t(
            Some(input_ids),
            Some(mask),
            None,
            None,
            None,
            &self.embeddings,
            None,
            false,
        )?;
        Ok((
            transformer_output.hidden_state,
            transformer_output.all_attentions,
        ))
    }
}

/// Container holding a T5 model output. The decoder output may hold the hidden state of
/// the last layer of the decoder, or may hold logits for a custom head module after the
/// decoder (e.g. for language modeling tasks)
pub struct T5ModelOutput {
    /// Hidden state of the last layer of the decoder, or logits for a custom head
    /// module after the decoder (e.g. for language modeling tasks)
    pub decoder_output: Tensor,
    /// Hidden state for the last layer of the encoder, if it was calculated, otherwise None
    pub encoder_hidden_state: Option<Tensor>,
    /// Cached outputs of the model (attention layers keys and values) if the model is used for generation
    pub next_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
    /// Hidden states for all layers of the decoder
    pub all_decoder_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all layers of the decoder
    pub all_decoder_attentions: Option<Vec<Tensor>>,
    /// Hidden states for all layers of the encoder
    pub all_encoder_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all layers of the encoder
    pub all_encoder_attentions: Option<Vec<Tensor>>,
}

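/// # T5 language generation model
/// Wraps a `T5ForConditionalGeneration` model together with its tokenizer and generation
/// configuration so that it can be used through the `LanguageGenerator` interface.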
pub struct T5Generator {
    model: T5ForConditionalGeneration,
    tokenizer: TokenizerOption,
    var_store: nn::VarStore,
    generate_config: GenerateConfig,
    bos_token_id: Option<i64>,
    eos_token_ids: Option<Vec<i64>>,
    pad_token_id: Option<i64>,
    is_encoder_decoder: bool,
    vocab_size: i64,
    decoder_start_id: Option<i64>,
    max_position_embeddings: i64,
}

impl T5Generator {
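    /// Build a new `T5Generator`
    ///
    /// # Arguments
    ///
    /// * `generate_config` - `GenerateConfig` object containing the resource references (model, vocabulary, configuration), generation options and device placement (CPU/GPU)
    ///
    /// # Example
    ///
    /// A minimal construction sketch; the `GenerateConfig` fields shown are illustrative and the
    /// default resources are assumed to be replaced with T5 resources in practice.
    ///
    /// ```no_run
    /// use rust_bert::pipelines::generation_utils::GenerateConfig;
    /// use rust_bert::t5::T5Generator;
    ///
    /// let generate_config = GenerateConfig {
    ///     num_beams: 4,
    ///     ..Default::default()
    /// };
    /// let t5_generator = T5Generator::new(generate_config).unwrap();
    /// ```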
    pub fn new(generate_config: GenerateConfig) -> Result<T5Generator, RustBertError> {
        let vocab_path = generate_config.vocab_resource.get_local_path()?;

        let tokenizer = TokenizerOption::from_file(
            ModelType::T5,
            vocab_path.to_str().unwrap(),
            None,
            false,
            None,
            None,
        )?;

        Self::new_with_tokenizer(generate_config, tokenizer)
    }

    pub fn new_with_tokenizer(
        generate_config: GenerateConfig,
        tokenizer: TokenizerOption,
    ) -> Result<T5Generator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();
        let mut var_store = nn::VarStore::new(device);

        let config = T5Config::from_file(config_path);
        let model = T5ForConditionalGeneration::new(var_store.root(), &config);
        crate::resources::load_weights(
            &generate_config.model_resource,
            &mut var_store,
            generate_config.kind,
            device,
        )?;

        let bos_token_id = Some(config.bos_token_id.unwrap_or(-1));
        let eos_token_ids = Some(match config.eos_token_id {
            Some(value) => vec![value],
            None => vec![1],
        });
        let pad_token_id = Some(config.pad_token_id.unwrap_or(0));
        let vocab_size = config.vocab_size;
        let is_encoder_decoder = true;
        let decoder_start_id = config.decoder_start_token_id;
        // T5 does not have an embedding matrix for position IDs and relies on relative positions instead
        let max_position_embeddings = i64::MAX;

        Ok(T5Generator {
            model,
            tokenizer,
            var_store,
            generate_config,
            bos_token_id,
            eos_token_ids,
            pad_token_id,
            is_encoder_decoder,
            vocab_size,
            decoder_start_id,
            max_position_embeddings,
        })
    }
}

impl PrivateLanguageGenerator for T5Generator {
    fn _get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }
    fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        &mut self.tokenizer
    }
    fn get_device(&self) -> Device {
        self.var_store.device()
    }
    fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
        Ok(&mut self.var_store)
    }
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }
    fn get_bos_id(&self) -> Option<i64> {
        self.bos_token_id
    }
    fn get_eos_ids(&self) -> Option<&Vec<i64>> {
        self.eos_token_ids.as_ref()
    }
    fn get_pad_id(&self) -> Option<i64> {
        self.pad_token_id
    }
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }
    fn get_max_positions_embeddings(&self) -> Option<i64> {
        Some(self.max_position_embeddings)
    }
    fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        cache: Cache,
        attention_mask: Option<&Tensor>,
        _token_type_ids: Option<&Tensor>,
        _position_ids: Option<&Tensor>,
        _input_embeds: Option<&Tensor>,
        encoder_outputs: Option<&Tensor>,
        decoder_input_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        let base_model_output = match cache {
            Cache::T5Cache(cached_layer_states) => self.model.forward_t(
                input_ids,
                attention_mask,
                encoder_outputs,
                decoder_input_ids,
                None,
                None,
                None,
                cached_layer_states,
                train,
            ),
            Cache::None => self.model.forward_t(
                input_ids,
                attention_mask,
                encoder_outputs,
                decoder_input_ids,
                None,
                None,
                None,
                None,
                train,
            ),
            _ => {
                return Err(RustBertError::ValueError(
                    "Cache not compatible with T5 Model".into(),
                ));
            }
        };

        Ok(LMModelOutput {
            lm_logits: base_model_output.decoder_output,
            cache: Cache::T5Cache(base_model_output.next_cache),
        })
    }
    fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
        Some(self.model.encode(input_ids, attention_mask))
    }

    fn prepare_inputs_for_generation<'a>(
        &self,
        input_ids: Tensor,
        encoder_outputs: Option<&'a Tensor>,
        past: Cache,
        attention_mask: Tensor,
    ) -> PreparedInput<'a> {
        match past {
            Cache::T5Cache(past) => PreparedInput {
                prepared_input: None,
                prepared_attention_mask: Some(attention_mask),
                prepared_encoder_output: encoder_outputs,
                prepared_decoder_input: Some(input_ids.narrow(1, -1, 1)),
                prepared_position_ids: None,
                prepared_past: Cache::T5Cache(past),
            },
            Cache::None => PreparedInput {
                prepared_input: None,
                prepared_attention_mask: Some(attention_mask),
                prepared_encoder_output: encoder_outputs,
                prepared_decoder_input: Some(input_ids),
                prepared_position_ids: None,
                prepared_past: Cache::T5Cache(None),
            },
            _ => panic!("Cache type incompatible with T5"),
        }
    }

    fn reorder_cache(
        &self,
        past: &mut Cache,
        encoder_outputs: Option<Tensor>,
        beam_indices: &Tensor,
    ) -> Option<Tensor> {
        match past {
            Cache::T5Cache(old_cache_option) => match old_cache_option {
                Some(old_cache) => {
                    for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
                        if self_layer_state.is_some() {
                            self_layer_state
                                .as_mut()
                                .unwrap()
                                .reorder_cache(beam_indices)
                        };
                        if encoder_layer_state.is_some() {
                            encoder_layer_state
                                .as_mut()
                                .unwrap()
                                .reorder_cache(beam_indices)
                        };
                    }
                }
                None => {}
            },
            Cache::None => {}
            _ => {
                panic!("Invalid cache for T5 model");
            }
        };
        encoder_outputs
    }
}

impl LanguageGenerator for T5Generator {}