rust_bert/models/gpt_j/gpt_j_model.rs

// Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved.
// Copyright 2022 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::common::kind::get_min;
use crate::gpt_j::attention::LayerState;
use crate::gpt_j::transformer::GptJBlock;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
    PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
use tch::nn::{embedding, Linear};
use tch::{nn, Device, Tensor};

/// # GPT-J Pretrained model weight files
pub struct GptJModelResources;

/// # GPT-J Pretrained model config files
pub struct GptJConfigResources;

/// # GPT-J Pretrained model vocab files
pub struct GptJVocabResources;

/// # GPT-J Pretrained model merges files
pub struct GptJMergesResources;

/// Model weights for Rust are not available out of the box for GPT-J, but they can be created
/// with the following command:
///
/// ```ignore
/// python utils/convert_model.py path/to/gpt_j/pytorch_model.bin
/// ```
///
/// Where `pytorch_model.bin` was downloaded from [EleutherAI GPT-J 6B][gpt-j-6B] or
/// [EleutherAI GPT-J 6B (float16)][gpt-j-6B-float16]. Note that converting GPT-J 6B requires
/// about 32 GB of RAM, and converting GPT-J 6B float16 requires about 12 GB of RAM.
///
/// [gpt-j-6B]: https://huggingface.co/EleutherAI/gpt-j-6B/tree/main
/// [gpt-j-6B-float16]: https://huggingface.co/EleutherAI/gpt-j-6B/tree/float16
impl GptJModelResources {
    pub const GPT_J_TINY_RANDOM: (&'static str, &'static str) = (
        "gpt-j-tiny-random/model",
        "https://huggingface.co/anton-l/gpt-j-tiny-random/resolve/main/rust_model.ot",
    );
}

impl GptJConfigResources {
    /// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
    pub const GPT_J_6B: (&'static str, &'static str) = (
        "gpt-j-6B/config",
        "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/main/config.json",
    );
    pub const GPT_J_6B_FLOAT16: (&'static str, &'static str) = (
        "gpt-j-6B/config",
        "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/float16/config.json",
    );
    pub const GPT_J_TINY_RANDOM: (&'static str, &'static str) = (
        "gpt-j-tiny-random/config",
        "https://huggingface.co/anton-l/gpt-j-tiny-random/resolve/main/config.json",
    );
}

impl GptJVocabResources {
    /// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
    pub const GPT_J_6B: (&'static str, &'static str) = (
        "gpt-j-6B/vocab",
        "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/main/vocab.json",
    );
    pub const GPT_J_6B_FLOAT16: (&'static str, &'static str) = (
        "gpt-j-6B/vocab",
        "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/float16/vocab.json",
    );
    pub const GPT_J_TINY_RANDOM: (&'static str, &'static str) = (
        "gpt-j-tiny-random/vocab",
        "https://huggingface.co/anton-l/gpt-j-tiny-random/resolve/main/vocab.json",
    );
}

impl GptJMergesResources {
    /// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
    pub const GPT_J_6B: (&'static str, &'static str) = (
        "gpt-j-6B/merges",
        "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/main/merges.txt",
    );
    pub const GPT_J_6B_FLOAT16: (&'static str, &'static str) = (
        "gpt-j-6B/merges",
        "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/float16/merges.txt",
    );
    pub const GPT_J_TINY_RANDOM: (&'static str, &'static str) = (
        "gpt-j-tiny-random/merges",
        "https://huggingface.co/anton-l/gpt-j-tiny-random/resolve/main/merges.txt",
    );
}

#[derive(Debug, Serialize, Deserialize, Clone)]
/// # GPT-J model configuration
/// Defines the GPT-J model architecture (e.g. number of layers, hidden layer size, vocab size...).
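///
/// # Example
///
/// An illustrative configuration override (the values below are arbitrary and only meant to
/// show the struct-update syntax; real models should load their configuration from file):
///
/// ```no_run
/// use rust_bert::gpt_j::GptJConfig;
///
/// // Small, hypothetical configuration for experimentation.
/// let config = GptJConfig {
///     n_layer: 4,
///     n_head: 4,
///     n_embd: 128,
///     vocab_size: 1000,
///     ..Default::default()
/// };
/// ```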
pub struct GptJConfig {
    pub attn_pdrop: Option<f64>,
    pub embd_pdrop: Option<f64>,
    pub hidden_dropout_prob: Option<f64>,
    pub afn: Option<Activation>,
    pub initializer_range: f64,
    pub layer_norm_epsilon: f64,
    pub n_embd: i64,
    pub n_head: i64,
    pub n_layer: i64,
    pub n_positions: i64,
    pub n_inner: Option<i64>,
    pub num_labels: Option<i64>,
    pub use_cache: Option<bool>,
    pub output_attentions: Option<bool>,
    pub output_hidden_states: Option<bool>,
    pub resid_pdrop: Option<f64>,
    pub rotary_dim: Option<i64>,
    pub vocab_size: i64,
    pub scale_attn_weights: Option<bool>,
    #[serde(default = "default_preload_on_cpu")]
    pub preload_on_cpu: bool,
    pub decoder_start_token_id: Option<i64>,
    pub forced_bos_token_id: Option<i64>,
    pub forced_eos_token_id: Option<i64>,
}

impl Config for GptJConfig {}

impl Default for GptJConfig {
    fn default() -> Self {
        GptJConfig {
            attn_pdrop: Some(0.1),
            embd_pdrop: Some(0.1),
            hidden_dropout_prob: None,
            afn: Some(Activation::gelu_new),
            initializer_range: 0.02,
            layer_norm_epsilon: 1e-5,
            n_embd: 4096,
            n_head: 16,
            n_layer: 28,
            n_positions: 2048,
            n_inner: None,
            num_labels: None,
            use_cache: None,
            output_attentions: None,
            output_hidden_states: None,
            resid_pdrop: Some(0.1),
            rotary_dim: Some(64),
            vocab_size: 50400,
            scale_attn_weights: Some(true),
            preload_on_cpu: default_preload_on_cpu(),
            decoder_start_token_id: None,
            forced_bos_token_id: None,
            forced_eos_token_id: None,
        }
    }
}

fn default_preload_on_cpu() -> bool {
    true
}

/// # GPT-J Base model
/// Base architecture for GPT-J model. Usually complemented with a task-specific head, such as a language model head.
/// It is made of the following blocks:
/// - `wte`: token embeddings
/// - `h`: Encoder (transformer) made of a vector of layers. Each layer is made of a multi-head attention layer, a layer-normalization layer, and an MLP made of linear layers.
/// - `use_cache`: flag indicating if the model should return a past state. This can be fed back to the model to improve the quality of text generated.
/// - `output_hidden_states`: flag indicating if the model should return all hidden states (as opposed to only the last layer)
/// - `output_attentions`: flag indicating if the model should return activation weights
pub struct GptJModel {
    wte: nn::Embedding,
    drop: Dropout,
    ln_f: nn::LayerNorm,
    h: Vec<GptJBlock>,
    use_cache: bool,
    output_hidden_states: bool,
    output_attentions: bool,
}
impl GptJModel {
    /// Build a new `GptJModel`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the GPT-J model
    /// * `config` - `GptJConfig` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::gpt_j::{GptJConfig, GptJModel};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = GptJConfig::from_file(config_path);
    /// let gpt_j: GptJModel = GptJModel::new(&p.root() / "gpt_j", &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &GptJConfig) -> GptJModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow() / "transformer";

        let wte = embedding(
            &p / "wte",
            config.vocab_size,
            config.n_embd,
            Default::default(),
        );

        let embd_pdrop = config.embd_pdrop.unwrap_or(0.1);
        let drop = Dropout::new(embd_pdrop);

        let layer_norm_config = nn::LayerNormConfig {
            eps: config.layer_norm_epsilon,
            ..Default::default()
        };
        let ln_f = nn::layer_norm(&p / "ln_f", vec![config.n_embd], layer_norm_config);

        let mut h: Vec<GptJBlock> = vec![];
        let h_path = &p / "h";
        for layer_index in 0..config.n_layer {
            h.push(GptJBlock::new(&h_path / layer_index, config));
        }

        let use_cache = config.use_cache.unwrap_or(true);
        let output_attentions = config.output_attentions.unwrap_or(false);
        let output_hidden_states = config.output_hidden_states.unwrap_or(false);

        GptJModel {
            wte,
            drop,
            ln_f,
            h,
            use_cache,
            output_hidden_states,
            output_attentions,
        }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `layer_past` - Optional vector of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*). When provided, these are concatenated with the current input keys and values.
    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked value 1. If None set to 1
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings.
    /// * `_position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). Currently unused: GPT-J injects position information through rotary embeddings in the attention layers.
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `GptJModelOutput` containing:
    ///   - `output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*) representing the activations of the last hidden state
    ///   - `cache` - `Option<Vec<Option<LayerState>>>` of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*)
    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* containing the attention weights for each layer
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::{Int64, Double};
    /// use rust_bert::gpt_j::{GptJConfig, GptJModel, LayerState};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = GptJConfig::from_file(config_path);
    /// # let gpt_j_model: GptJModel = GptJModel::new(&vs.root(), &config);
    /// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
    /// let input_tensor = Tensor::randint(config.vocab_size, &[batch_size, sequence_length], (Int64, device));
    /// let mut past: Vec<Option<LayerState>> = Vec::with_capacity(config.n_layer as usize);
    /// for _ in 0..config.n_layer as usize {
    ///     past.push(Some(LayerState {
    ///         prev_key: Tensor::rand(
    ///             &[
    ///                 batch_size,
    ///                 config.n_head,
    ///                 past_sequence_length,
    ///                 config.n_embd / config.n_head,
    ///             ],
    ///             (Double, device),
    ///         ),
    ///         prev_value: Tensor::rand(
    ///             &[
    ///                 batch_size,
    ///                 config.n_head,
    ///                 past_sequence_length,
    ///                 config.n_embd / config.n_head,
    ///             ],
    ///             (Double, device),
    ///         ),
    ///     }))
    /// }
    /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     gpt_j_model
    ///         .forward_t(
    ///             Some(&input_tensor),
    ///             Some(past),
    ///             Some(&attention_mask),
    ///             Some(&token_type_ids),
    ///             None,
    ///             None,
    ///             false,
    ///         )
    ///         .unwrap()
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        layer_past: Option<Vec<Option<LayerState>>>,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        _position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<GptJModelOutput, RustBertError> {
        let (calc_input_embeddings, _input_size, _device) =
            process_ids_embeddings_pair(input_ids, input_embeds, &self.wte)?;

        let input_embeddings =
            input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());

        let (layer_past, _layer_past_length) = match layer_past {
            Some(value) => {
                if value.len() != self.h.len() {
                    return Err(RustBertError::ValueError(format!(
                        "Past activations vector length ({}) must be equal to the number of layers ({})",
                        value.len(),
                        self.h.len()
                    )));
                } else {
                    let length = value.len();
                    (value, length)
                }
            }
            None => {
                let mut out = Vec::with_capacity(self.h.len());
                out.resize_with(self.h.len(), || None);
                (out, 0)
            }
        };

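        // Convert the optional 0/1 attention mask into an additive mask: positions to attend to
        // become 0, and masked positions become the smallest value representable by the
        // embeddings' dtype, so they vanish after the softmax in the attention layers.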
        let kind_min = get_min(input_embeddings.kind())?;
        let attention_mask: Option<Tensor> = attention_mask.map(|value| {
            let attention_mask = value
                .view((input_embeddings.size()[0], -1))
                .unsqueeze(1)
                .unsqueeze(2)
                .to_kind(input_embeddings.kind());

            (attention_mask.ones_like() - attention_mask.to_kind(input_embeddings.kind()))
                * kind_min
        });

        let mut hidden_state: Tensor = input_embeddings.copy();
        if let Some(token_type_ids) = token_type_ids {
            let token_type_embeds = token_type_ids.apply(&self.wte);
            hidden_state = hidden_state + token_type_embeds;
        }
        hidden_state = hidden_state.apply_t(&self.drop, train);

        let mut all_presents: Option<Vec<Option<LayerState>>> = self.use_cache.then(Vec::new);
        let mut all_hidden_states: Option<Vec<Tensor>> = self.output_hidden_states.then(Vec::new);
        let mut all_attentions: Option<Vec<Tensor>> = self.output_attentions.then(Vec::new);

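        // Run the stack of GPT-J blocks, threading the hidden state through each layer and
        // collecting the per-layer cache, attention weights and hidden states when requested.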
        for (layer, past) in self.h.iter().zip(layer_past) {
            let temp =
                layer.forward_t(&hidden_state, past.as_ref(), attention_mask.as_ref(), train);
            hidden_state = temp.0;
            if let Some(presents) = all_presents.borrow_mut() {
                presents.push(temp.1);
            };
            if let Some(attentions) = all_attentions.borrow_mut() {
                attentions.push(temp.2.unwrap());
            };
            if let Some(hidden_states) = all_hidden_states.borrow_mut() {
                // Push a copy so the running hidden state remains valid for the next layer.
                hidden_states.push(hidden_state.copy());
            };
        }

        let output = hidden_state.apply(&self.ln_f);

        Ok(GptJModelOutput {
            output,
            cache: all_presents,
            all_hidden_states,
            all_attentions,
        })
    }
}

/// # GPT-J Language Modeling head
/// GPT-J model with a decoding head (linear layer with bias). The weights of the linear layer are not tied to the word embeddings.
/// It is made of the following blocks:
/// - `transformer`: Base GptJModel
/// - `lm_head`: Linear layer projecting the hidden states to vocabulary logits
pub struct GptJLMHeadModel {
    transformer: GptJModel,
    lm_head: Linear,
}

impl GptJLMHeadModel {
    /// Build a new `GptJLMHeadModel`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the GPT-J model
    /// * `config` - `GptJConfig` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::gpt_j::{GptJConfig, GptJLMHeadModel};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = GptJConfig::from_file(config_path);
    /// let gpt_j: GptJLMHeadModel = GptJLMHeadModel::new(&p.root() / "gpt_j", &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &GptJConfig) -> GptJLMHeadModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let transformer = GptJModel::new(p, config);
        let lm_head = nn::linear(
            p / "lm_head",
            config.n_embd,
            config.vocab_size,
            Default::default(),
        );

        GptJLMHeadModel {
            transformer,
            lm_head,
        }
    }

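    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `layer_past` - `Cache` containing the past keys and values. Only `Cache::GPTJCache` and `Cache::None` are accepted; any other variant results in an error.
    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked value 1. If None set to 1
    /// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings.
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). Currently unused by the GPT-J implementation.
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `_encoder_outputs` - Unused for GPT-J
    /// * `_decoder_input_ids` - Unused for GPT-J
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `LMModelOutput` containing:
    ///   - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocabulary item and position
    ///   - `cache` - `Cache::GPTJCache` wrapping the updated keys and values for each layer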
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        layer_past: Cache,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        _encoder_outputs: Option<&Tensor>,
        _decoder_input_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        let base_model_output = match layer_past {
            Cache::GPTJCache(layer_past) => self.transformer.forward_t(
                input_ids,
                layer_past,
                attention_mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
            Cache::None => self.transformer.forward_t(
                input_ids,
                None,
                attention_mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
            _ => {
                return Err(RustBertError::ValueError(
                    "Cache not compatible with GPT-J Model".into(),
                ));
            }
        }?;

        let lm_logits = base_model_output.output.apply(&self.lm_head);

        Ok(LMModelOutput {
            lm_logits,
            cache: Cache::GPTJCache(base_model_output.cache),
        })
    }
}

/// Container for the GPT-J model output.
pub struct GptJModelOutput {
    /// Hidden state of the last layer of the decoder, or logits for a custom head
    /// module after the decoder (e.g. vocabulary logits for language modeling tasks)
    pub output: Tensor,
    /// Cached attention layers keys and values if the model is used for generation
    pub cache: Option<Vec<Option<LayerState>>>,
    /// Hidden states for all intermediate layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
}

/// # Language generation model based on the GPT-J architecture
pub struct GptJGenerator {
    model: GptJLMHeadModel,
    tokenizer: TokenizerOption,
    var_store: nn::VarStore,
    generate_config: GenerateConfig,
    bos_token_id: Option<i64>,
    eos_token_ids: Option<Vec<i64>>,
    pad_token_id: Option<i64>,
    is_encoder_decoder: bool,
    vocab_size: i64,
    decoder_start_id: Option<i64>,
    max_position_embeddings: i64,
}

impl GptJGenerator {
    /// Build a new `GptJGenerator`
    ///
    /// # Arguments
    ///
    /// * `generate_config` - `GenerateConfig` object containing the resource references (model, vocabulary, configuration), generation options and device placement (CPU/GPU)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::gpt_j::GptJGenerator;
    /// use rust_bert::pipelines::generation_utils::GenerateConfig;
    ///
    /// let generate_config = GenerateConfig {
    ///     max_length: Some(30),
    ///     do_sample: true,
    ///     num_beams: 5,
    ///     temperature: 1.1,
    ///     num_return_sequences: 3,
    ///     ..Default::default()
    /// };
    /// let gpt_j_generator = GptJGenerator::new(generate_config)?;
    /// # Ok(())
    /// # }
    /// ```
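    ///
    /// Once built, the generator can be used for text generation through the `LanguageGenerator`
    /// trait. A rough sketch (not compiled as a doc-test; the exact `generate` signature may
    /// differ across versions of this crate):
    ///
    /// ```ignore
    /// use rust_bert::pipelines::generation_utils::LanguageGenerator;
    ///
    /// let prompts = ["It was a bright cold day in April"];
    /// let output = gpt_j_generator.generate(Some(&prompts), None);
    /// ```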
    pub fn new(generate_config: GenerateConfig) -> Result<GptJGenerator, RustBertError> {
        let vocab_path = generate_config.vocab_resource.get_local_path()?;
        let merges_path = generate_config
            .merges_resource
            .as_ref()
            .ok_or_else(|| {
                RustBertError::InvalidConfigurationError(
                    "GPT-J expects a merges resource to be provided".to_string(),
                )
            })?
            .get_local_path()?;

        let tokenizer = TokenizerOption::from_file(
            ModelType::GPTJ,
            vocab_path.to_str().unwrap(),
            Some(merges_path.to_str().unwrap()),
            false,
            None,
            None,
        )?;

        Self::new_with_tokenizer(generate_config, tokenizer)
    }

    pub fn new_with_tokenizer(
        generate_config: GenerateConfig,
        tokenizer: TokenizerOption,
    ) -> Result<GptJGenerator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();
        let mut var_store = nn::VarStore::new(device);

        let config = GptJConfig::from_file(config_path);
        let model = GptJLMHeadModel::new(var_store.root(), &config);
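        // When `preload_on_cpu` is set, variables are created on CPU first and only moved to the
        // target device after the weights have been loaded, limiting peak device memory usage.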
        if config.preload_on_cpu && device != Device::Cpu {
            var_store.set_device(Device::Cpu);
        }
        crate::resources::load_weights(
            &generate_config.model_resource,
            &mut var_store,
            generate_config.kind,
            device,
        )?;
        if device != Device::Cpu {
            var_store.set_device(device);
        }

        let bos_token_id = tokenizer.get_bos_id();
        let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
        let pad_token_id = tokenizer.get_pad_id();
        let max_position_embeddings = config.n_positions;
        let is_encoder_decoder = false;
        let vocab_size = config.vocab_size;
        let decoder_start_id = config.decoder_start_token_id;

        Ok(GptJGenerator {
            model,
            tokenizer,
            var_store,
            generate_config,
            bos_token_id,
            eos_token_ids,
            pad_token_id,
            is_encoder_decoder,
            vocab_size,
            decoder_start_id,
            max_position_embeddings,
        })
    }
}

impl PrivateLanguageGenerator for GptJGenerator {
    fn _get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }
    fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        &mut self.tokenizer
    }
    fn get_device(&self) -> Device {
        self.var_store.device()
    }
    fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
        Ok(&mut self.var_store)
    }
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }
    fn get_bos_id(&self) -> Option<i64> {
        self.bos_token_id
    }
    fn get_eos_ids(&self) -> Option<&Vec<i64>> {
        self.eos_token_ids.as_ref()
    }
    fn get_pad_id(&self) -> Option<i64> {
        self.pad_token_id
    }
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }
    fn get_max_positions_embeddings(&self) -> Option<i64> {
        Some(self.max_position_embeddings)
    }

    fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        layer_past: Cache,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        _encoder_outputs: Option<&Tensor>,
        _decoder_input_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        let base_model_output = match layer_past {
            Cache::GPTJCache(layer_past) => self.model.transformer.forward_t(
                input_ids,
                layer_past,
                attention_mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
            Cache::None => self.model.transformer.forward_t(
                input_ids,
                None,
                attention_mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
            _ => {
                return Err(RustBertError::ValueError(
                    "Cache not compatible with GPT-J Model".into(),
                ));
            }
        }?;

        let lm_logits = base_model_output.output.apply(&self.model.lm_head);

        Ok(LMModelOutput {
            lm_logits,
            cache: Cache::GPTJCache(base_model_output.cache),
        })
    }

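    // When a populated cache is available, only the last generated token is fed back to the
    // model (the keys/values for previous positions are already cached); otherwise the full
    // input sequence is used.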
    fn prepare_inputs_for_generation<'a>(
        &self,
        input_ids: Tensor,
        _encoder_outputs: Option<&'a Tensor>,
        past: Cache,
        attention_mask: Tensor,
    ) -> PreparedInput<'a> {
        match past {
            Cache::GPTJCache(past) => {
                if past.is_some() {
                    PreparedInput {
                        prepared_input: Some(input_ids.select(1, -1).unsqueeze(-1)),
                        prepared_attention_mask: Some(attention_mask),
                        prepared_encoder_output: None,
                        prepared_decoder_input: None,
                        prepared_position_ids: None,
                        prepared_past: Cache::GPTJCache(past),
                    }
                } else {
                    PreparedInput {
                        prepared_input: Some(input_ids),
                        prepared_attention_mask: Some(attention_mask),
                        prepared_encoder_output: None,
                        prepared_decoder_input: None,
                        prepared_position_ids: None,
                        prepared_past: Cache::GPTJCache(None),
                    }
                }
            }
            Cache::None => PreparedInput {
                prepared_input: Some(input_ids),
                prepared_attention_mask: Some(attention_mask),
                prepared_encoder_output: None,
                prepared_decoder_input: None,
                prepared_position_ids: None,
                prepared_past: Cache::GPTJCache(None),
            },
            _ => panic!("Cache type incompatible with GPT-J"),
        }
    }

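    // During beam search, re-order the cached keys/values of every layer so that they follow
    // the selected beam indices.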
    fn reorder_cache(
        &self,
        past: &mut Cache,
        _encoder_outputs: Option<Tensor>,
        beam_indices: &Tensor,
    ) -> Option<Tensor> {
        match past {
            Cache::GPTJCache(cached_decoder_state) => match cached_decoder_state {
                Some(old_cache) => {
                    for layer_state in old_cache.iter_mut() {
                        if layer_state.is_some() {
                            layer_state.as_mut().unwrap().reorder_cache(beam_indices)
                        };
                    }
                    None
                }
                None => None,
            },
            Cache::None => None,
            _ => {
                panic!("Invalid cache for GPT-J model");
            }
        }
    }
}

impl LanguageGenerator for GptJGenerator {}