rust_bert/models/gpt2/
gpt2_model.rs

// Copyright 2018-present, the HuggingFace Inc. team
// Copyright 2018-present, The OpenAI Team Authors
// Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::gpt2::transformer::Block;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
    PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Device, Kind, Tensor};

/// # GPT2 Pretrained model weight files
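///
/// These `(identifier, url)` tuples are designed to be handed to a resource
/// provider. A minimal sketch (assumes the crate's default `remote` feature so
/// that `RemoteResource` is available):
///
/// ```no_run
/// use rust_bert::gpt2::Gpt2ModelResources;
/// use rust_bert::resources::{RemoteResource, ResourceProvider};
///
/// // Downloads (and caches) the GPT2 base weights, returning the local path.
/// let model_resource = RemoteResource::from_pretrained(Gpt2ModelResources::GPT2);
/// let weights_path = model_resource.get_local_path().unwrap();
/// ```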
pub struct Gpt2ModelResources;

/// # GPT2 Pretrained model config files
pub struct Gpt2ConfigResources;

/// # GPT2 Pretrained model vocab files
pub struct Gpt2VocabResources;

/// # GPT2 Pretrained model merges files
pub struct Gpt2MergesResources;

impl Gpt2ModelResources {
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT2: (&'static str, &'static str) = (
        "gpt2/model",
        "https://huggingface.co/gpt2/resolve/main/rust_model.ot",
    );
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT2_MEDIUM: (&'static str, &'static str) = (
        "gpt2-medium/model",
        "https://huggingface.co/gpt2-medium/resolve/main/rust_model.ot",
    );
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT2_LARGE: (&'static str, &'static str) = (
        "gpt2-large/model",
        "https://huggingface.co/gpt2-large/resolve/main/rust_model.ot",
    );
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT2_XL: (&'static str, &'static str) = (
        "gpt2-xl/model",
        "https://huggingface.co/gpt2-xl/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_GPT2: (&'static str, &'static str) = (
        "distilgpt2/model",
        "https://huggingface.co/distilgpt2/resolve/main/rust_model.ot",
    );
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-medium>. Modified with conversion to C-array format.
    pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
        "dialogpt-medium/model",
        "https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/rust_model.ot",
    );
}

impl Gpt2ConfigResources {
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT2: (&'static str, &'static str) = (
        "gpt2/config",
        "https://huggingface.co/gpt2/resolve/main/config.json",
    );
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT2_MEDIUM: (&'static str, &'static str) = (
        "gpt2-medium/config",
        "https://huggingface.co/gpt2-medium/resolve/main/config.json",
    );
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT2_LARGE: (&'static str, &'static str) = (
        "gpt2-large/config",
        "https://huggingface.co/gpt2-large/resolve/main/config.json",
    );
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT2_XL: (&'static str, &'static str) = (
        "gpt2-xl/config",
        "https://huggingface.co/gpt2-xl/resolve/main/config.json",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_GPT2: (&'static str, &'static str) = (
        "distilgpt2/config",
        "https://huggingface.co/distilgpt2/resolve/main/config.json",
    );
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-medium>. Modified with conversion to C-array format.
    pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
        "dialogpt-medium/config",
        "https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/config.json",
    );
}

impl Gpt2VocabResources {
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT2: (&'static str, &'static str) = (
        "gpt2/vocab",
        "https://huggingface.co/gpt2/resolve/main/vocab.json",
    );
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT2_MEDIUM: (&'static str, &'static str) = (
        "gpt2-medium/vocab",
        "https://huggingface.co/gpt2-medium/resolve/main/vocab.json",
    );
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT2_LARGE: (&'static str, &'static str) = (
        "gpt2-large/vocab",
        "https://huggingface.co/gpt2-large/resolve/main/vocab.json",
    );
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT2_XL: (&'static str, &'static str) = (
        "gpt2-xl/vocab",
        "https://huggingface.co/gpt2-xl/resolve/main/vocab.json",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_GPT2: (&'static str, &'static str) = (
        "distilgpt2/vocab",
        "https://huggingface.co/distilgpt2/resolve/main/vocab.json",
    );
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-medium>. Modified with conversion to C-array format.
    pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
        "dialogpt-medium/vocab",
        "https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/vocab.json",
    );
}

impl Gpt2MergesResources {
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT2: (&'static str, &'static str) = (
        "gpt2/merges",
        "https://huggingface.co/gpt2/resolve/main/merges.txt",
    );
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT2_MEDIUM: (&'static str, &'static str) = (
        "gpt2-medium/merges",
        "https://huggingface.co/gpt2-medium/resolve/main/merges.txt",
    );
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT2_LARGE: (&'static str, &'static str) = (
        "gpt2-large/merges",
        "https://huggingface.co/gpt2-large/resolve/main/merges.txt",
    );
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT2_XL: (&'static str, &'static str) = (
        "gpt2-xl/merges",
        "https://huggingface.co/gpt2-xl/resolve/main/merges.txt",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_GPT2: (&'static str, &'static str) = (
        "distilgpt2/merges",
        "https://huggingface.co/distilgpt2/resolve/main/merges.txt",
    );
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-medium>. Modified with conversion to C-array format.
    pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
        "dialogpt-medium/merges",
        "https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/merges.txt",
    );
}

#[derive(Debug, Serialize, Deserialize, Clone)]
/// # GPT2 model configuration
/// Defines the GPT2 model architecture (e.g. number of layers, hidden layer size, vocab size...).
/// Shared between GPT and GPT2 models.
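///
/// Configurations are typically deserialized from a JSON file through the
/// `Config` trait implemented below. A minimal sketch (the path is a placeholder):
///
/// ```no_run
/// use rust_bert::gpt2::Gpt2Config;
/// use rust_bert::Config;
/// use std::path::Path;
///
/// let config = Gpt2Config::from_file(Path::new("path/to/config.json"));
/// ```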
pub struct Gpt2Config {
    pub attn_pdrop: Option<f64>,
    pub embd_pdrop: Option<f64>,
    pub hidden_dropout_prob: Option<f64>,
    pub afn: Option<Activation>,
    pub initializer_range: f64,
    pub layer_norm_epsilon: f64,
    pub n_ctx: i64,
    pub n_embd: i64,
    pub n_head: i64,
    pub n_layer: i64,
    pub n_positions: i64,
    pub num_labels: Option<i64>,
    pub output_past: Option<bool>,
    pub output_attentions: Option<bool>,
    pub output_hidden_states: Option<bool>,
    pub resid_pdrop: Option<f64>,
    pub vocab_size: i64,
    pub decoder_start_token_id: Option<i64>,
    pub forced_bos_token_id: Option<i64>,
    pub forced_eos_token_id: Option<i64>,
}

impl Config for Gpt2Config {}

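/// A sketch of building a variant of the default configuration via struct-update
/// syntax (the override below is illustrative, not a published checkpoint):
///
/// ```no_run
/// use rust_bert::gpt2::Gpt2Config;
///
/// let shallow_config = Gpt2Config {
///     n_layer: 6,
///     ..Default::default()
/// };
/// ```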
impl Default for Gpt2Config {
    fn default() -> Self {
        Gpt2Config {
            attn_pdrop: Some(0.1),
            embd_pdrop: Some(0.1),
            hidden_dropout_prob: None,
            afn: Some(Activation::gelu_new),
            initializer_range: 0.02,
            layer_norm_epsilon: 1e-5,
            n_ctx: 1024,
            n_embd: 768,
            n_head: 12,
            n_layer: 12,
            n_positions: 1024,
            num_labels: None,
            output_past: None,
            output_attentions: None,
            output_hidden_states: None,
            resid_pdrop: Some(0.1),
            vocab_size: 50257,
            decoder_start_token_id: None,
            forced_bos_token_id: None,
            forced_eos_token_id: None,
        }
    }
}

/// # GPT2 Base model
/// Base architecture for GPT2 model. Usually complemented with a task-specific head, such as a language model head.
/// It is made of the following blocks:
/// - `wte`: `token` embeddings
/// - `wpe`: `position` embeddings
/// - `h`: Encoder (transformer) made of a vector of layers. Each layer is made of a multi-head attention layer, layer-normalization layers and an MLP made of linear layers.
/// - `output_past`: flag indicating if the model should return a past state. This can be fed back to the model in subsequent forward passes to speed up sequential decoding.
/// - `output_hidden_states`: flag indicating if the model should return all hidden states (as opposed to only the last layer)
/// - `output_attentions`: flag indicating if the model should return attention weights
pub struct Gpt2Model {
    wte: nn::Embedding,
    wpe: nn::Embedding,
    drop: Dropout,
    ln_f: nn::LayerNorm,
    h: Vec<Block>,
    output_past: bool,
    output_hidden_states: bool,
    output_attentions: bool,
}

impl Gpt2Model {
    /// Build a new `Gpt2Model`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the GPT2 model
    /// * `config` - `Gpt2Config` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = Gpt2Config::from_file(config_path);
    /// let gpt2: Gpt2Model = Gpt2Model::new(&p.root() / "gpt2", &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &Gpt2Config) -> Gpt2Model
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow() / "transformer";

        let wte = embedding(
            &p / "wte",
            config.vocab_size,
            config.n_embd,
            Default::default(),
        );
        let wpe = embedding(
            &p / "wpe",
            config.n_positions,
            config.n_embd,
            Default::default(),
        );

        let embd_pdrop = config.embd_pdrop.unwrap_or(0.1);
        let drop = Dropout::new(embd_pdrop);
        let layer_norm_config = nn::LayerNormConfig {
            eps: config.layer_norm_epsilon,
            ..Default::default()
        };
        let ln_f = nn::layer_norm(&p / "ln_f", vec![config.n_embd], layer_norm_config);
        let mut h: Vec<Block> = vec![];
        let h_path = &p / "h";
        for layer_index in 0..config.n_layer {
            h.push(Block::new(&h_path / layer_index, config, true));
        }
        let output_attentions = config.output_attentions.unwrap_or(false);
        let output_past = config.output_past.unwrap_or(true);
        let output_hidden_states = config.output_hidden_states.unwrap_or(false);

        Gpt2Model {
            wte,
            wpe,
            drop,
            ln_f,
            h,
            output_past,
            output_hidden_states,
            output_attentions,
        }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `layer_past` - Optional vector of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*). When provided, these are concatenated with the current input keys and values.
    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1 for all positions (no masking).
    /// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings.
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented starting from the length of the past input.
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `Gpt2ModelOutput` containing:
    ///   - `output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*) representing the activations of the last hidden state
    ///   - `cache` - `Option<Vec<Tensor>>` of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*)
    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with attention weights of shape (*batch size*, *number of heads*, *sequence_length*, *sequence_length*)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::{Int64, Double};
    /// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = Gpt2Config::from_file(config_path);
    /// # let gpt2_model: Gpt2Model = Gpt2Model::new(&vs.root(), &config);
    /// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
    /// // Random integer token ids in [0, vocab_size) stand in for tokenized input.
    /// let input_tensor = Tensor::randint(
    ///     config.vocab_size,
    ///     &[batch_size, sequence_length],
    ///     (Int64, device),
    /// );
    /// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
    /// for _ in 0..config.n_layer as usize {
    ///     past.push(Tensor::rand(
    ///         &[
    ///             2,
    ///             batch_size,
    ///             config.n_head,
    ///             past_sequence_length,
    ///             config.n_embd / config.n_head,
    ///         ],
    ///         (Double, device),
    ///     ))
    /// }
    /// // Mask value 1 marks positions to attend to, 0 marks masked positions.
    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
    /// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
    ///     .expand(&[batch_size, sequence_length], true);
    ///
    /// let model_output = no_grad(|| {
    ///     gpt2_model
    ///         .forward_t(
    ///             Some(&input_tensor),
    ///             Some(&past),
    ///             Some(&attention_mask),
    ///             Some(&token_type_ids),
    ///             Some(&position_ids),
    ///             None,
    ///             false,
    ///         )
    ///         .unwrap()
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        layer_past: Option<&Vec<Tensor>>,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<Gpt2ModelOutput, RustBertError> {
        let (calc_input_embeddings, input_size, _) =
            process_ids_embeddings_pair(input_ids, input_embeds, &self.wte)?;
        let input_embeddings =
            input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());

        let seq_length = input_size[1];

        let (layer_past, layer_past_length) = match layer_past {
            Some(value) => {
                assert_eq!(
                    value.len(),
                    self.h.len(),
                    "Past activations vector must be of length equal to the number of layers"
                );
                (
                    value
                        .iter()
                        .map(|v| Some(v.copy()))
                        .collect::<Vec<Option<Tensor>>>(),
                    value[0].size()[3],
                )
            }
            None => {
                let mut out = Vec::with_capacity(self.h.len());
                out.resize_with(self.h.len(), || None::<Tensor>);
                (out, 0)
            }
        };

        let position_ids = match position_ids {
            Some(value) => value.copy(),
            None => Tensor::arange_start(
                layer_past_length,
                seq_length + layer_past_length,
                (Int64, input_embeddings.device()),
            )
            .unsqueeze(0),
        };

        let attention_mask: Option<Tensor> = attention_mask.map(|value| {
            let attention_mask = value
                .view((input_embeddings.size()[0], -1))
                .unsqueeze(1)
                .unsqueeze(2)
                .to_kind(input_embeddings.kind());

            // Convert the {0, 1} mask to additive form: masked positions receive a
            // large negative bias before the attention softmax, others receive 0.
            let attention_mask: Tensor = (1.0 - attention_mask) * (-10000.0);
            attention_mask.to_kind(input_embeddings.kind())
        });

        let position_embeds = position_ids.apply(&self.wpe);
        let token_type_embeds = match token_type_ids {
            Some(value) => value.apply(&self.wte),
            None => Tensor::zeros_like(&position_embeds),
        };
        let mut hidden_state: Tensor =
            (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
        let mut all_presents: Option<Vec<Tensor>> =
            if self.output_past { Some(vec![]) } else { None };
        let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
            Some(vec![])
        } else {
            None
        };
        let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
            Some(vec![])
        } else {
            None
        };

        let layer_iter = self.h.iter().zip(layer_past);
        for layer_values in layer_iter {
            let (layer, past) = layer_values;
            let temp =
                layer.forward_t(&hidden_state, past.as_ref(), attention_mask.as_ref(), train);
            hidden_state = temp.0;
            if let Some(presents) = all_presents.borrow_mut() {
                presents.push(temp.1);
            };
            if let Some(attentions) = all_attentions.borrow_mut() {
                attentions.push(temp.2.unwrap());
            };
            if let Some(hidden_states) = all_hidden_states.borrow_mut() {
                hidden_states.push(hidden_state.as_ref().copy());
            };
        }

        Ok(Gpt2ModelOutput {
            output: hidden_state.apply(&self.ln_f),
            cache: all_presents,
            all_hidden_states,
            all_attentions,
        })
    }
}

/// # GPT2 Language Modeling head
/// GPT2 model with a decoding head (linear layer without bias). The weights of the linear layer are tied to the word embeddings.
/// It is made of the following blocks:
/// - `transformer`: Base Gpt2Model
pub struct GPT2LMHeadModel {
    transformer: Gpt2Model,
}

impl GPT2LMHeadModel {
    /// Build a new `GPT2LMHeadModel`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the GPT2 model
    /// * `config` - `Gpt2Config` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = Gpt2Config::from_file(config_path);
    /// let gpt2: GPT2LMHeadModel = GPT2LMHeadModel::new(&p.root() / "gpt2", &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &Gpt2Config) -> GPT2LMHeadModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let transformer = Gpt2Model::new(p, config);

        GPT2LMHeadModel { transformer }
    }

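    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `layer_past` - Optional vector of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*). When provided, these are concatenated with the current input keys and values.
    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1 for all positions (no masking).
    /// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings.
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented starting from the length of the past input.
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `LMModelOutput` containing:
    ///   - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) with the vocabulary logits at each position
    ///   - `cache` - `Cache::GPT2Cache` wrapping the optional past keys and values of each layer
    ///
    /// # Example
    ///
    /// A minimal inference sketch (random token ids stand in for real tokenized input):
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = Gpt2Config::from_file(config_path);
    /// # let gpt2_model: GPT2LMHeadModel = GPT2LMHeadModel::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (2, 16);
    /// let input_tensor = Tensor::randint(
    ///     config.vocab_size,
    ///     &[batch_size, sequence_length],
    ///     (Int64, device),
    /// );
    ///
    /// let model_output = no_grad(|| {
    ///     gpt2_model
    ///         .forward_t(Some(&input_tensor), None, None, None, None, None, false)
    ///         .unwrap()
    /// });
    /// ```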
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        layer_past: Option<&Vec<Tensor>>,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        let base_model_output = self.transformer.forward_t(
            input_ids,
            layer_past,
            attention_mask,
            token_type_ids,
            position_ids,
            input_embeds,
            train,
        )?;

        // Project hidden states onto the vocabulary by re-using (tying) the token
        // embedding matrix as the output projection, with no bias term.
        let lm_logits = base_model_output
            .output
            .linear::<Tensor>(&self.transformer.wte.ws, None);
        Ok(LMModelOutput {
            lm_logits,
            cache: Cache::GPT2Cache(base_model_output.cache),
        })
    }
}

/// Container for the GPT2 model output.
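///
/// A sketch of consuming the fields of a forward pass result:
///
/// ```no_run
/// # use rust_bert::gpt2::Gpt2ModelOutput;
/// fn inspect(model_output: &Gpt2ModelOutput) {
///     // Last hidden state: (batch size, sequence length, hidden size)
///     println!("{:?}", model_output.output.size());
///     // One cached key/value tensor per layer when `output_past` is enabled
///     if let Some(cache) = &model_output.cache {
///         println!("cached layers: {}", cache.len());
///     }
/// }
/// ```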
pub struct Gpt2ModelOutput {
    /// Hidden state of the last layer of the decoder, or logits for a custom head
    /// module after the decoder (e.g. vocabulary logits for language modeling tasks)
    pub output: Tensor,
    /// Cached attention layers keys and values if the model is used for generation
    pub cache: Option<Vec<Tensor>>,
    /// Hidden states for all intermediate layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
}

/// # Language generation model based on the GPT2 architecture
pub struct GPT2Generator {
    model: GPT2LMHeadModel,
    tokenizer: TokenizerOption,
    var_store: nn::VarStore,
    generate_config: GenerateConfig,
    bos_token_id: Option<i64>,
    eos_token_ids: Option<Vec<i64>>,
    pad_token_id: Option<i64>,
    is_encoder_decoder: bool,
    vocab_size: i64,
    decoder_start_id: Option<i64>,
    max_position_embeddings: i64,
}

impl GPT2Generator {
    /// Build a new `GPT2Generator`
    ///
    /// # Arguments
    ///
    /// * `generate_config` - `GenerateConfig` object containing the resource references (model, vocabulary, configuration), generation options and device placement (CPU/GPU)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::gpt2::GPT2Generator;
    /// use rust_bert::pipelines::generation_utils::GenerateConfig;
    ///
    /// let generate_config = GenerateConfig {
    ///     max_length: Some(30),
    ///     do_sample: true,
    ///     num_beams: 5,
    ///     temperature: 1.1,
    ///     num_return_sequences: 3,
    ///     ..Default::default()
    /// };
    /// let gpt2_generator = GPT2Generator::new(generate_config)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(generate_config: GenerateConfig) -> Result<GPT2Generator, RustBertError> {
        let vocab_path = generate_config.vocab_resource.get_local_path()?;
        let merges_path = generate_config
            .merges_resource
            .as_ref()
            .ok_or_else(|| {
                RustBertError::InvalidConfigurationError(
                    "GPT2 expects a merges resource to be provided".to_string(),
                )
            })?
            .get_local_path()?;

        let tokenizer = TokenizerOption::from_file(
            ModelType::GPT2,
            vocab_path.to_str().unwrap(),
            Some(merges_path.to_str().unwrap()),
            false,
            None,
            None,
        )?;

        Self::new_with_tokenizer(generate_config, tokenizer)
    }

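    /// Build a new `GPT2Generator` from a configuration and a pre-built tokenizer
    ///
    /// # Arguments
    ///
    /// * `generate_config` - `GenerateConfig` object containing the resource references (model, vocabulary, configuration), generation options and device placement (CPU/GPU)
    /// * `tokenizer` - `TokenizerOption` used to encode prompts and decode generated ids
    ///
    /// # Example
    ///
    /// A sketch reusing the tokenizer construction shown in `new` (file paths are placeholders):
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::gpt2::GPT2Generator;
    /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
    /// use rust_bert::pipelines::generation_utils::GenerateConfig;
    ///
    /// let tokenizer = TokenizerOption::from_file(
    ///     ModelType::GPT2,
    ///     "path/to/vocab.json",
    ///     Some("path/to/merges.txt"),
    ///     false,
    ///     None,
    ///     None,
    /// )?;
    /// let gpt2_generator = GPT2Generator::new_with_tokenizer(GenerateConfig::default(), tokenizer)?;
    /// # Ok(())
    /// # }
    /// ```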
    pub fn new_with_tokenizer(
        generate_config: GenerateConfig,
        tokenizer: TokenizerOption,
    ) -> Result<GPT2Generator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();
        let mut var_store = nn::VarStore::new(device);

        let config = Gpt2Config::from_file(config_path);
        let model = GPT2LMHeadModel::new(var_store.root(), &config);
        crate::resources::load_weights(
            &generate_config.model_resource,
            &mut var_store,
            generate_config.kind,
            device,
        )?;

        let bos_token_id = tokenizer.get_bos_id();
        let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
        let pad_token_id = tokenizer.get_pad_id();
        let max_position_embeddings = config.n_positions;
        let is_encoder_decoder = false;
        let vocab_size = config.vocab_size;
        let decoder_start_id = config.decoder_start_token_id;

        Ok(GPT2Generator {
            model,
            tokenizer,
            var_store,
            generate_config,
            bos_token_id,
            eos_token_ids,
            pad_token_id,
            is_encoder_decoder,
            vocab_size,
            decoder_start_id,
            max_position_embeddings,
        })
    }
}

impl PrivateLanguageGenerator for GPT2Generator {
    fn _get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }
    fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        &mut self.tokenizer
    }
    fn get_device(&self) -> Device {
        self.var_store.device()
    }
    fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
        Ok(&mut self.var_store)
    }
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }
    fn get_bos_id(&self) -> Option<i64> {
        self.bos_token_id
    }
    fn get_eos_ids(&self) -> Option<&Vec<i64>> {
        self.eos_token_ids.as_ref()
    }
    fn get_pad_id(&self) -> Option<i64> {
        self.pad_token_id
    }
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }
    fn get_max_positions_embeddings(&self) -> Option<i64> {
        Some(self.max_position_embeddings)
    }

    fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        layer_past: Cache,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        _encoder_outputs: Option<&Tensor>,
        _decoder_input_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        match layer_past {
            Cache::GPT2Cache(layer_past) => self.model.forward_t(
                input_ids,
                layer_past.as_ref(),
                attention_mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
            Cache::None => self.model.forward_t(
                input_ids,
                None,
                attention_mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
            _ => Err(RustBertError::ValueError(
                "Cache not compatible with GPT2 Model".into(),
            )),
        }
    }

    fn prepare_inputs_for_generation<'a>(
        &self,
        input_ids: Tensor,
        _encoder_outputs: Option<&'a Tensor>,
        past: Cache,
        attention_mask: Tensor,
    ) -> PreparedInput<'a> {
        // Position ids are derived from the attention mask: a running count of
        // non-masked positions, with masked positions clamped to a valid index.
        let position_ids = (attention_mask.totype(Kind::Int64).cumsum(-1, Kind::Int64) - 1)
            .masked_fill(&attention_mask.eq(0), 1);

        match past {
            Cache::GPT2Cache(past) => {
                if past.is_some() {
                    // With a populated cache, only the last generated token needs
                    // to be fed through the model.
                    PreparedInput {
                        prepared_input: Some(input_ids.select(1, -1).unsqueeze(-1)),
                        prepared_attention_mask: Some(attention_mask),
                        prepared_encoder_output: None,
                        prepared_decoder_input: None,
                        prepared_position_ids: Some(position_ids.select(1, -1).unsqueeze(-1)),
                        prepared_past: Cache::GPT2Cache(past),
                    }
                } else {
                    PreparedInput {
                        prepared_input: Some(input_ids),
                        prepared_attention_mask: Some(attention_mask),
                        prepared_encoder_output: None,
                        prepared_decoder_input: None,
                        prepared_position_ids: Some(position_ids),
                        prepared_past: Cache::GPT2Cache(None),
                    }
                }
            }
            Cache::None => PreparedInput {
                prepared_input: Some(input_ids),
                prepared_attention_mask: Some(attention_mask),
                prepared_encoder_output: None,
                prepared_decoder_input: None,
                prepared_position_ids: Some(position_ids),
                prepared_past: Cache::GPT2Cache(None),
            },
            _ => panic!("Cache type incompatible with GPT2"),
        }
    }

    fn reorder_cache(
        &self,
        past: &mut Cache,
        _encoder_outputs: Option<Tensor>,
        beam_indices: &Tensor,
    ) -> Option<Tensor> {
        match past {
            Cache::GPT2Cache(cached_decoder_state) => match cached_decoder_state {
                Some(value) => {
                    // Reorder the cached keys and values along the batch dimension
                    // so they follow the surviving beams.
                    for layer_past in value.iter_mut() {
                        *layer_past = layer_past.index_select(1, beam_indices);
                    }
                    None
                }
                None => None,
            },
            Cache::None => None,
            _ => {
                panic!("Invalid cache for GPT2 model");
            }
        }
    }
}

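/// Exposes the high-level `generate` API for `GPT2Generator` through the blanket
/// implementation provided by the `LanguageGenerator` trait.
///
/// A generation sketch (hedged: the exact `generate` signature and output type are
/// defined by this crate version's `LanguageGenerator` trait):
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::gpt2::GPT2Generator;
/// use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
///
/// let generate_config = GenerateConfig {
///     max_length: Some(30),
///     ..Default::default()
/// };
/// let gpt2_generator = GPT2Generator::new(generate_config)?;
/// let output = gpt2_generator.generate(Some(&["The dog"]), None)?;
/// # Ok(())
/// # }
/// ```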
impl LanguageGenerator for GPT2Generator {}