// rust_bert/models/gpt_neo/gpt_neo_model.rs

// Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved.
// Copyright 2021 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::gpt_neo::decoder::GptNeoBlock;
use crate::gpt_neo::LayerState;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
    PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Activation, Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Device, Kind, Tensor};

/// # GPT-Neo Pretrained model weight files
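///
/// Each constant below is a `(cache_subdirectory, download_url)` pair that can be wrapped in a
/// remote resource, e.g. `RemoteResource::from_pretrained(GptNeoModelResources::GPT_NEO_125M)`
/// (using `rust_bert::resources::RemoteResource`).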
pub struct GptNeoModelResources;

/// # GPT-Neo Pretrained model config files
pub struct GptNeoConfigResources;

/// # GPT-Neo Pretrained model vocab files
pub struct GptNeoVocabResources;

/// # GPT-Neo Pretrained model merges files
pub struct GptNeoMergesResources;

impl GptNeoModelResources {
    /// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
    pub const GPT_NEO_125M: (&'static str, &'static str) = (
        "gpt-neo-125M/model",
        "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
    pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
        "gpt-neo-1_3B/model",
        "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
    pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
        "gpt-neo-2_7B/model",
        "https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/rust_model.ot",
    );
}

impl GptNeoConfigResources {
    /// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
    pub const GPT_NEO_125M: (&'static str, &'static str) = (
        "gpt-neo-125M/config",
        "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/config.json",
    );
    /// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
    pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
        "gpt-neo-1_3B/config",
        "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/config.json",
    );
    /// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
    pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
        "gpt-neo-2_7B/config",
        "https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/config.json",
    );
}

impl GptNeoVocabResources {
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT_NEO_125M: (&'static str, &'static str) = (
        "gpt-neo-125M/vocab",
        "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/vocab.json",
    );
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
        "gpt-neo-1_3B/vocab",
        "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/vocab.json",
    );
    /// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
    pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
        "gpt-neo-2_7B/vocab",
        "https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/vocab.json",
    );
}

impl GptNeoMergesResources {
    /// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
    pub const GPT_NEO_125M: (&'static str, &'static str) = (
        "gpt-neo-125M/merges",
        "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/merges.txt",
    );
    /// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
    pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
        "gpt-neo-1_3B/merges",
        "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/merges.txt",
    );
    /// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
    pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
        "gpt-neo-2_7B/merges",
        "https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/merges.txt",
    );
}

#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
/// # GPT-Neo attention layer type
pub enum AttentionLayerType {
    Global,
    Local,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
/// # GPT-Neo model configuration
/// Defines the GPT-Neo model architecture (e.g. number of layers, hidden layer size, vocab size...).
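///
/// # Example
///
/// A minimal sketch of loading a configuration from a local JSON file (the path below is a
/// placeholder):
///
/// ```no_run
/// use rust_bert::gpt_neo::GptNeoConfig;
/// use rust_bert::Config;
/// use std::path::Path;
///
/// let config = GptNeoConfig::from_file(Path::new("path/to/config.json"));
/// println!("{} layers, hidden size {}", config.num_layers, config.hidden_size);
/// ```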
pub struct GptNeoConfig {
    pub activation_function: Activation,
    pub attention_dropout: f64,
    pub attention_layers: Vec<AttentionLayerType>,
    pub attention_types: Vec<(Vec<AttentionLayerType>, i64)>,
    pub intermediate_size: Option<i64>,
    pub bos_token_id: i64,
    pub eos_token_id: i64,
    pub forced_bos_token_id: Option<i64>,
    pub forced_eos_token_id: Option<i64>,
    pub vocab_size: i64,
    pub num_layers: i64,
    pub num_heads: i64,
    pub hidden_size: i64,
    pub window_size: i64,
    pub embed_dropout: f64,
    pub initializer_range: f64,
    pub layer_norm_epsilon: f64,
    pub max_position_embeddings: i64,
    pub output_past: Option<bool>,
    pub output_attentions: Option<bool>,
    pub output_hidden_states: Option<bool>,
    pub resid_dropout: f64,
    pub decoder_start_token_id: Option<i64>,
}

impl Config for GptNeoConfig {}

impl Default for GptNeoConfig {
    fn default() -> Self {
        GptNeoConfig {
            activation_function: Activation::gelu_new,
            attention_dropout: 0.0,
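            // Alternate global and local attention layer types, repeated over the 24 layers
            // (these defaults follow the GPT-Neo 1.3B configuration).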
            attention_layers: [AttentionLayerType::Global, AttentionLayerType::Local]
                .iter()
                .cycle()
                .take(24)
                .map(|layer_type| layer_type.to_owned())
                .collect::<Vec<AttentionLayerType>>(),
            attention_types: vec![(
                vec![AttentionLayerType::Global, AttentionLayerType::Local],
                12,
            )],
            intermediate_size: None,
            bos_token_id: 50256,
            eos_token_id: 50256,
            forced_bos_token_id: None,
            forced_eos_token_id: None,
            vocab_size: 50257,
            num_layers: 24,
            num_heads: 16,
            hidden_size: 2048,
            window_size: 256,
            embed_dropout: 0.0,
            initializer_range: 0.02,
            layer_norm_epsilon: 1e-5,
            max_position_embeddings: 2048,
            output_past: None,
            output_attentions: None,
            output_hidden_states: None,
            resid_dropout: 0.0,
            decoder_start_token_id: None,
        }
    }
}

/// # GPT-Neo Base model
/// Base architecture for GPT-Neo models. Task-specific models will be built from this common base model.
/// It is made of the following blocks:
/// - `word_embeddings`: Word embeddings
/// - `position_embeddings`: Position embeddings
/// - `layers`: Vector of `GptNeoBlock` (transformer part of the model)
/// - `dropout`: Dropout applied to the sum of word and position embeddings
/// - `layer_norm`: Final layer normalization applied to the transformer output
pub struct GptNeoModel {
    word_embeddings: nn::Embedding,
    position_embeddings: nn::Embedding,
    layers: Vec<GptNeoBlock>,
    dropout: Dropout,
    layer_norm: nn::LayerNorm,
    output_attentions: bool,
    output_hidden_states: bool,
}

impl GptNeoModel {
    /// Build a new `GptNeoModel`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the GPT-Neo model
    /// * `config` - `GptNeoConfig` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::gpt_neo::{GptNeoConfig, GptNeoModel};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = GptNeoConfig::from_file(config_path);
    /// let gpt_neo_model = GptNeoModel::new(&p.root(), &config).unwrap();
    /// ```
    pub fn new<'p, P>(p: P, config: &GptNeoConfig) -> Result<GptNeoModel, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let word_embeddings = nn::embedding(
            p / "wte",
            config.vocab_size,
            config.hidden_size,
            Default::default(),
        );

        let position_embeddings = nn::embedding(
            p / "wpe",
            config.max_position_embeddings,
            config.hidden_size,
            Default::default(),
        );

        let dropout = Dropout::new(config.embed_dropout);

        let layer_norm_config = nn::LayerNormConfig {
            eps: config.layer_norm_epsilon,
            ..Default::default()
        };

        let layer_norm = nn::layer_norm(p / "ln_f", vec![config.hidden_size], layer_norm_config);

        let mut layers: Vec<GptNeoBlock> = Vec::with_capacity(config.num_layers as usize);
        let p_layers = p / "h";
        for layer_index in 0..config.num_layers {
            layers.push(GptNeoBlock::new(
                &p_layers / layer_index,
                layer_index as usize,
                config,
            ));
        }

        let output_attentions = config.output_attentions.unwrap_or(false);
        let output_hidden_states = config.output_hidden_states.unwrap_or(false);

        Ok(GptNeoModel {
            word_embeddings,
            position_embeddings,
            layers,
            dropout,
            layer_norm,
            output_attentions,
            output_hidden_states,
        })
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
    /// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
    /// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings.
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented starting from the length of the past input.
    /// * `layer_states` - Optional vector of length *n_layer* (`Option<Vec<Option<LayerState>>>`) containing the past keys and values for the self-attention of each layer.
    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the input positions. Positions with a mask value of 0 will be masked.
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `Result<GptNeoModelOutput, RustBertError>` containing:
    ///   - `hidden_states` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*) representing the activations of the last hidden state
    ///   - `next_cache` - `Option<Vec<Option<LayerState>>>` of length *n_layer* containing the past keys and values for the attention layers
    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer + 1* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *n_layer* containing the attention weights for each layer
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::{Int64, Double};
    /// use rust_bert::gpt_neo::{GptNeoConfig, GptNeoModel};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = GptNeoConfig::from_file(config_path);
    /// # let gpt_neo_model = GptNeoModel::new(&vs.root(), &config).unwrap();
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::randint(config.vocab_size, &[batch_size, sequence_length], (Int64, device));
    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     gpt_neo_model.forward_t(
    ///         Some(&input_tensor),
    ///         None,
    ///         None,
    ///         None,
    ///         None,
    ///         Some(&attention_mask),
    ///         false,
    ///     )
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        layer_states: Option<Vec<Option<LayerState>>>,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<GptNeoModelOutput, RustBertError> {
        let (calc_input_embeddings, input_shape, device) =
            process_ids_embeddings_pair(input_ids, input_embeds, &self.word_embeddings)?;

        let (batch_size, current_sequence_length) = (input_shape[0], input_shape[1]);

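        // Determine how many tokens are already cached. Assuming the cached `prev_key` has
        // shape (..., past_sequence_length, head_dim), the past length is its second-to-last
        // dimension, extracted here by walking the size vector in reverse.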
        let past_length = if let Some(past_state_value) = &layer_states {
            if let Some(first_layer_state) = &past_state_value[0] {
                let mut size_iter = first_layer_state.prev_key.size().into_iter().rev();
                size_iter.next();
                size_iter.next().unwrap()
            } else {
                0
            }
        } else {
            0
        };

        let full_sequence_length = current_sequence_length + past_length;

        let calc_position_ids = if position_ids.is_none() {
            let position_ids =
                Tensor::arange_start(past_length, full_sequence_length, (Kind::Int64, device));
            Some(
                position_ids
                    .unsqueeze(0)
                    .view([-1, current_sequence_length]),
            )
        } else {
            None
        };

        let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());

        let input_embeds = input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
        let position_embeds = position_ids.apply(&self.position_embeddings);

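        // Turn the 0/1 padding mask into an additive bias broadcastable over attention scores:
        // masked positions receive a large negative value so they vanish after the softmax.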
        let attention_mask = attention_mask.map(|attention_mask_value| {
            let attention_mask = attention_mask_value
                .view([batch_size, -1])
                .unsqueeze(1)
                .unsqueeze(1);
            let attention_mask = attention_mask.to_kind(position_embeds.kind());
            (1 - attention_mask) * -1e4
        });

        let mut hidden_state = input_embeds + position_embeds;
        if let Some(token_type_ids) = token_type_ids {
            hidden_state = hidden_state + token_type_ids.apply(&self.word_embeddings);
        };
        hidden_state = hidden_state.apply_t(&self.dropout, train);
        let mut output_shape = input_shape;
        output_shape.push(*hidden_state.size().last().unwrap());

        let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
            Some(vec![])
        } else {
            None
        };
        let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
            Some(vec![])
        } else {
            None
        };
        let old_cache = layer_states.unwrap_or_else(|| vec![None; self.layers.len()]);
        let mut next_cache = vec![None; self.layers.len()];

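        // Run the hidden state through the stack of decoder blocks, updating the per-layer
        // cache and optionally collecting attention weights and intermediate hidden states.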
        let mut x: Option<Tensor> = None;
        let mut attention_weights: Option<Tensor>;

        for ((layer_idx, layer), layer_state) in
            self.layers.iter().enumerate().zip(old_cache.into_iter())
        {
            let temp = if let Some(x_value) = &x {
                layer.forward_t(
                    x_value,
                    layer_state.as_ref(),
                    attention_mask.as_ref(),
                    train,
                )?
            } else {
                layer.forward_t(
                    &hidden_state,
                    layer_state.as_ref(),
                    attention_mask.as_ref(),
                    train,
                )?
            };
            x = Some(temp.0);
            attention_weights = temp.1;
            next_cache[layer_idx] = temp.2;
            if let Some(attentions) = all_attentions.borrow_mut() {
                attentions.push(std::mem::take(&mut attention_weights.unwrap()));
            };
            if let Some(hidden_states) = all_hidden_states.borrow_mut() {
                hidden_states.push(x.as_ref().unwrap().copy());
            };
        }

        let hidden_states = x
            .unwrap()
            .apply(&self.layer_norm)
            .view(output_shape.as_slice());

        Ok(GptNeoModelOutput {
            hidden_states,
            next_cache: Some(next_cache),
            all_hidden_states,
            all_attentions,
        })
    }
}

/// # GPT-Neo Model for causal language modeling
/// GPT-Neo model with a vocabulary decoding head. The language model decoding head is tied to the word embedding matrix weights.
/// It is made of the following blocks:
/// - `transformer`: `GptNeoModel` Base GPT-Neo model
pub struct GptNeoForCausalLM {
    transformer: GptNeoModel,
}

impl GptNeoForCausalLM {
    /// Build a new `GptNeoForCausalLM`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the GPT-Neo model
    /// * `config` - `GptNeoConfig` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::gpt_neo::{GptNeoConfig, GptNeoForCausalLM};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = GptNeoConfig::from_file(config_path);
    /// let gpt_neo_model = GptNeoForCausalLM::new(&p.root(), &config).unwrap();
    /// ```
    pub fn new<'p, P>(p: P, config: &GptNeoConfig) -> Result<GptNeoForCausalLM, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let transformer = GptNeoModel::new(p / "transformer", config)?;

        Ok(GptNeoForCausalLM { transformer })
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
    /// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
    /// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings.
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented starting from the length of the past input.
    /// * `layer_states` - Optional vector of length *n_layer* (`Option<Vec<Option<LayerState>>>`) containing the past keys and values for the self-attention of each layer.
    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the input positions. Positions with a mask value of 0 will be masked.
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `Result<GptNeoModelLMOutput, RustBertError>` containing:
    ///   - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
    ///   - `next_cache` - `Option<Vec<Option<LayerState>>>` of length *n_layer* containing the past keys and values for the attention layers
    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer + 1* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *n_layer* containing the attention weights for each layer
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::{Int64, Double};
    /// use rust_bert::gpt_neo::{GptNeoConfig, GptNeoForCausalLM};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = GptNeoConfig::from_file(config_path);
    /// # let gpt_neo_model = GptNeoForCausalLM::new(&vs.root(), &config).unwrap();
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::randint(config.vocab_size, &[batch_size, sequence_length], (Int64, device));
    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     gpt_neo_model.forward_t(
    ///         Some(&input_tensor),
    ///         None,
    ///         None,
    ///         None,
    ///         None,
    ///         Some(&attention_mask),
    ///         false,
    ///     )
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        layer_states: Option<Vec<Option<LayerState>>>,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<GptNeoModelLMOutput, RustBertError> {
        let base_model_output = self.transformer.forward_t(
            input_ids,
            input_embeds,
            token_type_ids,
            position_ids,
            layer_states,
            attention_mask,
            train,
        )?;

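        // The LM head is tied to the input word embeddings: project the hidden states back
        // onto the vocabulary by multiplying with the (transposed) embedding matrix.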
        let lm_logits = base_model_output
            .hidden_states
            .linear::<Tensor>(&self.transformer.word_embeddings.ws, None);

        Ok(GptNeoModelLMOutput {
            lm_logits,
            next_cache: base_model_output.next_cache,
            all_hidden_states: base_model_output.all_hidden_states,
            all_attentions: base_model_output.all_attentions,
        })
    }
}

/// Container for the GPT-Neo model output.
pub struct GptNeoModelOutput {
    /// Last hidden states from the model
    pub hidden_states: Tensor,
    /// Cached outputs of the model (attention layers keys and values) if the model is used for generation
    pub next_cache: Option<Vec<Option<LayerState>>>,
    /// Hidden states for all intermediate layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
}

/// Container holding a GPT-Neo model with LM head output
pub struct GptNeoModelLMOutput {
    /// logits
    pub lm_logits: Tensor,
    /// Cached outputs of the model (attention layers keys and values) if the model is used for generation
    pub next_cache: Option<Vec<Option<LayerState>>>,
    /// Hidden states for all intermediate layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
}

/// # Language generation model based on the GPT-Neo architecture
pub struct GptNeoGenerator {
    model: GptNeoForCausalLM,
    tokenizer: TokenizerOption,
    var_store: nn::VarStore,
    generate_config: GenerateConfig,
    bos_token_id: Option<i64>,
    eos_token_ids: Option<Vec<i64>>,
    pad_token_id: Option<i64>,
    is_encoder_decoder: bool,
    vocab_size: i64,
    decoder_start_id: Option<i64>,
    max_position_embeddings: i64,
}

impl GptNeoGenerator {
    /// Build a new `GptNeoGenerator`
    ///
    /// # Arguments
    ///
    /// * `generate_config` - `GenerateConfig` object containing the resource references (model, vocabulary, configuration), generation options and device placement (CPU/GPU)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::gpt_neo::GptNeoGenerator;
    /// use rust_bert::pipelines::generation_utils::GenerateConfig;
    ///
    /// let generate_config = GenerateConfig {
    ///     max_length: Some(30),
    ///     do_sample: true,
    ///     num_beams: 5,
    ///     temperature: 1.1,
    ///     num_return_sequences: 3,
    ///     ..Default::default()
    /// };
    /// let gpt_neo_generator = GptNeoGenerator::new(generate_config)?;
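    /// // Text generation via the `generate` method from the `LanguageGenerator` trait
    /// // (a sketch; the exact signature and return type may differ across rust_bert versions):
    /// let output = gpt_neo_generator.generate(Some(&["The dog"]), None);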
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(generate_config: GenerateConfig) -> Result<GptNeoGenerator, RustBertError> {
        let vocab_path = generate_config.vocab_resource.get_local_path()?;
        let merges_path = generate_config
            .merges_resource
            .as_ref()
            .ok_or_else(|| {
                RustBertError::InvalidConfigurationError(
                    "GPT-Neo expects a merges resource to be provided".to_string(),
                )
            })?
            .get_local_path()?;

        let tokenizer = TokenizerOption::from_file(
            ModelType::GPTNeo,
            vocab_path.to_str().unwrap(),
            Some(merges_path.to_str().unwrap()),
            false,
            None,
            None,
        )?;

        Self::new_with_tokenizer(generate_config, tokenizer)
    }

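    /// Build a new `GptNeoGenerator` from a `GenerateConfig` and a caller-supplied tokenizer
    ///
    /// # Arguments
    ///
    /// * `generate_config` - `GenerateConfig` object containing the resource references (model, configuration), generation options and device placement (CPU/GPU)
    /// * `tokenizer` - `TokenizerOption` to use instead of loading the tokenizer from the `vocab_resource`/`merges_resource` of the configuration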
    pub fn new_with_tokenizer(
        generate_config: GenerateConfig,
        tokenizer: TokenizerOption,
    ) -> Result<GptNeoGenerator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();
        let mut var_store = nn::VarStore::new(device);
        let config = GptNeoConfig::from_file(config_path);
        let model = GptNeoForCausalLM::new(var_store.root(), &config)?;
        crate::resources::load_weights(
            &generate_config.model_resource,
            &mut var_store,
            generate_config.kind,
            device,
        )?;

        let bos_token_id = tokenizer.get_bos_id();
        let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
        let pad_token_id = tokenizer.get_pad_id();
        let is_encoder_decoder = false;
        let vocab_size = config.vocab_size;
        let decoder_start_id = config.decoder_start_token_id;
        let max_position_embeddings = config.max_position_embeddings;

        Ok(GptNeoGenerator {
            model,
            tokenizer,
            var_store,
            generate_config,
            bos_token_id,
            eos_token_ids,
            pad_token_id,
            is_encoder_decoder,
            vocab_size,
            decoder_start_id,
            max_position_embeddings,
        })
    }
}

impl PrivateLanguageGenerator for GptNeoGenerator {
    fn _get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }
    fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        &mut self.tokenizer
    }
    fn get_device(&self) -> Device {
        self.var_store.device()
    }
    fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
        Ok(&mut self.var_store)
    }
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }
    fn get_bos_id(&self) -> Option<i64> {
        self.bos_token_id
    }
    fn get_eos_ids(&self) -> Option<&Vec<i64>> {
        self.eos_token_ids.as_ref()
    }
    fn get_pad_id(&self) -> Option<i64> {
        self.pad_token_id
    }
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }

    fn get_max_positions_embeddings(&self) -> Option<i64> {
        Some(self.max_position_embeddings)
    }

    fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        layer_past: Cache,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        _encoder_outputs: Option<&Tensor>,
        _decoder_input_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        let base_model_output = match layer_past {
            Cache::GPTNeoCache(layer_past) => self.model.forward_t(
                input_ids,
                input_embeds,
                token_type_ids,
                position_ids,
                layer_past,
                attention_mask,
                train,
            ),
            Cache::None => self.model.forward_t(
                input_ids,
                input_embeds,
                token_type_ids,
                position_ids,
                None,
                attention_mask,
                train,
            ),
            _ => {
                return Err(RustBertError::ValueError(
                    "Cache not compatible with GPT-Neo Model".into(),
                ));
            }
        }?;

        Ok(LMModelOutput {
            lm_logits: base_model_output.lm_logits,
            cache: Cache::GPTNeoCache(base_model_output.next_cache),
        })
    }
    fn prepare_inputs_for_generation<'a>(
        &self,
        input_ids: Tensor,
        _encoder_outputs: Option<&'a Tensor>,
        past: Cache,
        attention_mask: Tensor,
    ) -> PreparedInput<'a> {
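        // Build position ids from the attention mask: the cumulative sum over the mask gives
        // 0-based positions for real tokens, while padded positions are filled with a dummy
        // value of 1 (they are masked out of the attention anyway).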
        let position_ids = (attention_mask.totype(Kind::Int64).cumsum(-1, Kind::Int64) - 1)
            .masked_fill(&attention_mask.eq(0), 1);

        match past {
            Cache::GPTNeoCache(past) => {
                if past.is_some() {
                    PreparedInput {
                        prepared_input: Some(input_ids.select(1, -1).unsqueeze(-1)),
                        prepared_attention_mask: Some(attention_mask),
                        prepared_encoder_output: None,
                        prepared_decoder_input: None,
                        prepared_position_ids: Some(position_ids.select(1, -1).unsqueeze(-1)),
                        prepared_past: Cache::GPTNeoCache(past),
                    }
                } else {
                    PreparedInput {
                        prepared_input: Some(input_ids),
                        prepared_attention_mask: Some(attention_mask),
                        prepared_encoder_output: None,
                        prepared_decoder_input: None,
                        prepared_position_ids: Some(position_ids),
                        prepared_past: Cache::GPTNeoCache(None),
                    }
                }
            }
            Cache::None => PreparedInput {
                prepared_input: Some(input_ids),
                prepared_attention_mask: Some(attention_mask),
                prepared_encoder_output: None,
                prepared_decoder_input: None,
                prepared_position_ids: Some(position_ids),
                prepared_past: Cache::GPTNeoCache(None),
            },
            _ => panic!("Cache type incompatible with GPT-Neo"),
        }
    }

    fn reorder_cache(
        &self,
        past: &mut Cache,
        _encoder_outputs: Option<Tensor>,
        beam_indices: &Tensor,
    ) -> Option<Tensor> {
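        // During beam search, reorder each layer's cached keys/values to follow the surviving
        // beams; GPT-Neo has no encoder output to reorder, so `None` is returned.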
        match past {
            Cache::GPTNeoCache(cached_decoder_state) => match cached_decoder_state {
                Some(old_cache) => {
                    for layer_state in old_cache.iter_mut() {
                        if layer_state.is_some() {
                            layer_state.as_mut().unwrap().reorder_cache(beam_indices)
                        };
                    }
                    None
                }
                None => None,
            },
            Cache::None => None,
            _ => {
                panic!("Invalid cache for GPT-Neo model");
            }
        }
    }
}

impl LanguageGenerator for GptNeoGenerator {}