rust_bert/models/openai_gpt/openai_gpt_model.rs

// Copyright 2018-present, the HuggingFace Inc. team
// Copyright 2018-present, The OpenAI Team Authors
// Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::gpt2::Gpt2Config;
use crate::openai_gpt::transformer::Block;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Config, RustBertError};
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Device, Tensor};

/// # GPT Pretrained model weight files
pub struct OpenAiGptModelResources;

/// # GPT Pretrained model config files
pub struct OpenAiGptConfigResources;

/// # GPT Pretrained model vocab files
pub struct OpenAiGptVocabResources;

/// # GPT Pretrained model merges files
pub struct OpenAiGptMergesResources;

impl OpenAiGptModelResources {
    /// Shared under MIT license by the OpenAI team at <https://github.com/openai/finetune-transformer-lm>. Modified with conversion to C-array format.
    pub const GPT: (&'static str, &'static str) = (
        "openai-gpt/model",
        "https://huggingface.co/openai-gpt/resolve/main/rust_model.ot",
    );
}

impl OpenAiGptConfigResources {
    /// Shared under MIT license by the OpenAI team at <https://github.com/openai/finetune-transformer-lm>. Modified with conversion to C-array format.
    pub const GPT: (&'static str, &'static str) = (
        "openai-gpt/config",
        "https://huggingface.co/openai-gpt/resolve/main/config.json",
    );
}

impl OpenAiGptVocabResources {
    /// Shared under MIT license by the OpenAI team at <https://github.com/openai/finetune-transformer-lm>. Modified with conversion to C-array format.
    pub const GPT: (&'static str, &'static str) = (
        "openai-gpt/vocab",
        "https://huggingface.co/openai-gpt/resolve/main/vocab.json",
    );
}

impl OpenAiGptMergesResources {
    /// Shared under MIT license by the OpenAI team at <https://github.com/openai/finetune-transformer-lm>. Modified with conversion to C-array format.
    pub const GPT: (&'static str, &'static str) = (
        "openai-gpt/merges",
        "https://huggingface.co/openai-gpt/resolve/main/merges.txt",
    );
}

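// A minimal sketch (assuming the `RemoteResource` helper from `rust_bert::resources`,
// which accepts the `(identifier, URL)` tuples defined above) of how these
// pretrained resources are typically materialized into local files:
//
//     use rust_bert::resources::{RemoteResource, ResourceProvider};
//     let model_resource = RemoteResource::from_pretrained(OpenAiGptModelResources::GPT);
//     let weights_path = model_resource.get_local_path()?;
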
/// # OpenAI GPT model configuration
/// Defines the OpenAI GPT model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub type OpenAiGptConfig = Gpt2Config;

/// # GPT Base model
/// Base architecture for the GPT model. Usually complemented with a task-specific head, such as a language model head. Unlike GPT2, GPT does not allow re-using past activations as an input.
/// It is made of the following blocks:
/// - `tokens_embed`: `token` embeddings
/// - `positions_embed`: `position` embeddings
/// - `h`: Encoder (transformer) made of a vector of layers. Each layer is made of a multi-head attention layer, layer-normalization layers and an MLP made of linear layers.
/// - `output_hidden_states`: flag indicating if the model should return all hidden states (as opposed to only the last layer)
/// - `output_attentions`: flag indicating if the model should return attention weights
pub struct OpenAiGptModel {
    tokens_embed: nn::Embedding,
    positions_embed: nn::Embedding,
    drop: Dropout,
    h: Vec<Block>,
    output_hidden_states: bool,
    output_attentions: bool,
}

impl OpenAiGptModel {
    /// Build a new `OpenAiGptModel`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the GPT model
    /// * `config` - `OpenAiGptConfig` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::openai_gpt::{OpenAiGptConfig, OpenAiGptModel};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = OpenAiGptConfig::from_file(config_path);
    /// let gpt_model: OpenAiGptModel = OpenAiGptModel::new(&p.root() / "gpt", &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &Gpt2Config) -> OpenAiGptModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let tokens_embed = embedding(
            p / "tokens_embed",
            config.vocab_size,
            config.n_embd,
            Default::default(),
        );
        let positions_embed = embedding(
            p / "positions_embed",
            config.n_positions,
            config.n_embd,
            Default::default(),
        );

        let embd_pdrop = config.embd_pdrop.unwrap_or(0.1);
        let drop = Dropout::new(embd_pdrop);
        let mut h: Vec<Block> = vec![];
        let h_path = p / "h";
        for layer_index in 0..config.n_layer {
            h.push(Block::new(&h_path / layer_index, config, true));
        }
        let output_attentions = config.output_attentions.unwrap_or(false);
        let output_hidden_states = config.output_hidden_states.unwrap_or(false);
        OpenAiGptModel {
            tokens_embed,
            positions_embed,
            drop,
            h,
            output_hidden_states,
            output_attentions,
        }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, all positions are attended to (no masking)
    /// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings.
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, positions are numbered sequentially from 0 (GPT does not use cached past inputs).
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `OpenAiGptModelOutput` containing:
    ///   - `output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*) representing the activations of the last hidden state
    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, *sequence_length*)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::{Int64, Double};
    /// use rust_bert::gpt2::Gpt2Config;
    /// use rust_bert::openai_gpt::OpenAiGptModel;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = Gpt2Config::from_file(config_path);
    /// # let gpt_model: OpenAiGptModel = OpenAiGptModel::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::randint(config.vocab_size, &[batch_size, sequence_length], (Int64, device));
    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
    /// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
    ///     .expand(&[batch_size, sequence_length], true);
    ///
    /// let model_output = no_grad(|| {
    ///     gpt_model
    ///         .forward_t(
    ///             Some(&input_tensor),
    ///             Some(&attention_mask),
    ///             Some(&token_type_ids),
    ///             Some(&position_ids),
    ///             None,
    ///             false,
    ///         )
    ///         .unwrap()
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<OpenAiGptModelOutput, RustBertError> {
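        // Validate that exactly one of `input_ids` / `input_embeds` is provided,
        // and compute the token embeddings from the ids when needed.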
        let (calc_input_embeddings, input_shape, _) =
            process_ids_embeddings_pair(input_ids, input_embeds, &self.tokens_embed)?;
        let input_embeddings =
            input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
        let seq_length = input_shape[1];

        let position_ids = match position_ids {
            Some(value) => value.copy(),
            None => Tensor::arange(seq_length, (Int64, input_embeddings.device())).unsqueeze(0),
        };

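        // Convert the 0/1 attention mask into an additive bias: positions to attend
        // to map to 0.0 and masked positions to -10000.0, which effectively removes
        // them from the softmax inside each attention layer.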
        let attention_mask = attention_mask.as_ref().map(|value| {
            ((value
                .view((input_embeddings.size()[0], -1))
                .unsqueeze(1)
                .unsqueeze(2)
                - 1.0)
                * 10000.0)
                .to_kind(input_embeddings.kind())
        });

        let position_embeds = position_ids.apply(&self.positions_embed);
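        // GPT has no dedicated token type embedding matrix: token type ids are
        // looked up in the token embedding matrix (`tokens_embed`) instead.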
        let token_type_embeds = match token_type_ids {
            Some(value) => value.apply(&self.tokens_embed),
            None => Tensor::zeros_like(&position_embeds),
        };
        let mut hidden_state: Tensor =
            (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
        let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
            Some(vec![])
        } else {
            None
        };
        let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
            Some(vec![])
        } else {
            None
        };

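        // Run the input through each transformer block in turn, optionally
        // collecting the per-layer attention weights and hidden states.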
        for layer in &self.h {
            let temp = layer.forward_t(&hidden_state, attention_mask.as_ref(), train);
            hidden_state = temp.0;
            if let Some(attentions) = all_attentions.borrow_mut() {
                attentions.push(temp.1.unwrap());
            };
            if let Some(hidden_states) = all_hidden_states.borrow_mut() {
                hidden_states.push(hidden_state.as_ref().copy());
            };
        }

        Ok(OpenAiGptModelOutput {
            hidden_state,
            all_hidden_states,
            all_attentions,
        })
    }
}

/// # GPT Language Modeling head
/// GPT model with a decoding head (linear layer without bias). The weights of the linear layer are tied to the word embeddings.
/// It is made of the following blocks:
/// - `transformer`: Base `OpenAiGptModel`
/// - `lm_head`: Linear layer without bias tied to the weights of the token id embeddings
pub struct OpenAIGPTLMHeadModel {
    transformer: OpenAiGptModel,
    lm_head: LinearNoBias,
}

impl OpenAIGPTLMHeadModel {
    /// Build a new `OpenAIGPTLMHeadModel`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the GPT model
    /// * `config` - `Gpt2Config` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::gpt2::Gpt2Config;
    /// use rust_bert::openai_gpt::OpenAIGPTLMHeadModel;
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = Gpt2Config::from_file(config_path);
    /// let gpt_lm: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&p.root() / "gpt", &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &Gpt2Config) -> OpenAIGPTLMHeadModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let transformer = OpenAiGptModel::new(p, config);
        let lm_head = linear_no_bias(
            p / "lm_head",
            config.n_embd,
            config.vocab_size,
            Default::default(),
        );
        OpenAIGPTLMHeadModel {
            transformer,
            lm_head,
        }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `_layer_past` - Unused for GPT
    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, all positions are attended to (no masking)
    /// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings.
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, positions are numbered sequentially from 0 (GPT does not use cached past inputs).
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `_encoder_outputs` - Unused for GPT
    /// * `_decoder_input_ids` - Unused for GPT
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `LMModelOutput` containing:
    ///   - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocabulary item and position
    ///   - `cache` - `Cache::None`, as GPT does not re-use past activations
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::{Int64, Double};
    /// use rust_bert::gpt2::Gpt2Config;
    /// use rust_bert::openai_gpt::OpenAIGPTLMHeadModel;
    /// use rust_bert::pipelines::generation_utils::Cache;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = Gpt2Config::from_file(config_path);
    /// # let mut gpt_model: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::randint(config.vocab_size, &[batch_size, sequence_length], (Int64, device));
    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
    /// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
    ///     .expand(&[batch_size, sequence_length], true);
    ///
    /// let model_output = no_grad(|| {
    ///     gpt_model
    ///         .forward_t(
    ///             Some(&input_tensor),
    ///             Cache::None,
    ///             Some(&attention_mask),
    ///             Some(&token_type_ids),
    ///             Some(&position_ids),
    ///             None,
    ///             None,
    ///             None,
    ///             false,
    ///         )
    ///         .unwrap()
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        _layer_past: Cache,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        _encoder_outputs: Option<&Tensor>,
        _decoder_input_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
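        // Run the base transformer, then project the final hidden states onto the
        // vocabulary to obtain per-position next-token logits.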
        let base_model_output = self.transformer.forward_t(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            input_embeds,
            train,
        )?;

        let lm_logits = base_model_output.hidden_state.apply(&self.lm_head);
        Ok(LMModelOutput {
            lm_logits,
            cache: Cache::None,
        })
    }
}
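
// A minimal sketch of consuming `LMModelOutput` for greedy next-token selection
// (assuming a `model_output` obtained from the `forward_t` example above):
//
//     let next_token_logits = model_output.lm_logits.select(1, -1); // last position
//     let next_token_id = next_token_logits.argmax(-1, false); // shape (batch size,)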

/// Container for the OpenAI GPT model output.
pub struct OpenAiGptModelOutput {
    /// Hidden state of the last layer of the decoder, or logits for a custom head
    /// module after the decoder (e.g. vocabulary logits for language modeling tasks)
    pub hidden_state: Tensor,
    /// Hidden states for all intermediate layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
}

/// # Language generation model based on the GPT architecture
pub struct OpenAIGenerator {
    model: OpenAIGPTLMHeadModel,
    tokenizer: TokenizerOption,
    var_store: nn::VarStore,
    generate_config: GenerateConfig,
    bos_token_id: Option<i64>,
    eos_token_ids: Option<Vec<i64>>,
    pad_token_id: Option<i64>,
    is_encoder_decoder: bool,
    vocab_size: i64,
    decoder_start_id: Option<i64>,
    max_position_embeddings: i64,
}

impl OpenAIGenerator {
    /// Build a new `OpenAIGenerator`
    ///
    /// # Arguments
    ///
    /// * `generate_config` - `GenerateConfig` object containing the resource references (model, vocabulary, configuration), generation options and device placement (CPU/GPU)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::openai_gpt::OpenAIGenerator;
    /// use rust_bert::pipelines::generation_utils::GenerateConfig;
    /// let generate_config = GenerateConfig {
    ///     max_length: Some(30),
    ///     do_sample: true,
    ///     num_beams: 5,
    ///     temperature: 1.1,
    ///     num_return_sequences: 3,
    ///     ..Default::default()
    /// };
    /// let gpt_generator = OpenAIGenerator::new(generate_config)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(generate_config: GenerateConfig) -> Result<OpenAIGenerator, RustBertError> {
        let vocab_path = generate_config.vocab_resource.get_local_path()?;
        let merges_path = generate_config
            .merges_resource
            .as_ref()
            .ok_or_else(|| {
                RustBertError::InvalidConfigurationError(
                    "GPT expects a merges resource to be provided".to_string(),
                )
            })?
            .get_local_path()?;

        let tokenizer = TokenizerOption::from_file(
            ModelType::OpenAiGpt,
            vocab_path.to_str().unwrap(),
            Some(merges_path.to_str().unwrap()),
            true,
            None,
            None,
        )?;

        Self::new_with_tokenizer(generate_config, tokenizer)
    }

    pub fn new_with_tokenizer(
        generate_config: GenerateConfig,
        tokenizer: TokenizerOption,
    ) -> Result<OpenAIGenerator, RustBertError> {
        generate_config.validate();

        let config_path = generate_config.config_resource.get_local_path()?;
        let device = generate_config.device;

        let mut var_store = nn::VarStore::new(device);
        let config = Gpt2Config::from_file(config_path);
        let model = OpenAIGPTLMHeadModel::new(var_store.root(), &config);
        crate::resources::load_weights(
            &generate_config.model_resource,
            &mut var_store,
            generate_config.kind,
            device,
        )?;

        let bos_token_id = tokenizer.get_bos_id();
        let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
        let pad_token_id = tokenizer.get_pad_id();
        let is_encoder_decoder = false;
        let vocab_size = config.vocab_size;
        let decoder_start_id = config.decoder_start_token_id;
        let max_position_embeddings = config.n_positions;

        Ok(OpenAIGenerator {
            model,
            tokenizer,
            var_store,
            generate_config,
            bos_token_id,
            eos_token_ids,
            pad_token_id,
            is_encoder_decoder,
            vocab_size,
            decoder_start_id,
            max_position_embeddings,
        })
    }
}
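
// A minimal usage sketch (assuming the `generate` method provided by the
// `LanguageGenerator` trait implemented below; the prompt text is illustrative):
//
//     use rust_bert::pipelines::generation_utils::LanguageGenerator;
//     let prompts = ["The dog"];
//     let output = gpt_generator.generate(Some(&prompts), None);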

impl PrivateLanguageGenerator for OpenAIGenerator {
    fn _get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }
    fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        &mut self.tokenizer
    }
    fn get_device(&self) -> Device {
        self.var_store.device()
    }
    fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
        Ok(&mut self.var_store)
    }
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }
    fn get_bos_id(&self) -> Option<i64> {
        self.bos_token_id
    }
    fn get_eos_ids(&self) -> Option<&Vec<i64>> {
        self.eos_token_ids.as_ref()
    }
    fn get_pad_id(&self) -> Option<i64> {
        self.pad_token_id
    }
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }
    fn get_max_positions_embeddings(&self) -> Option<i64> {
        Some(self.max_position_embeddings)
    }

    fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        _layer_past: Cache,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        _encoder_outputs: Option<&Tensor>,
        _decoder_input_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        self.model.forward_t(
            input_ids,
            _layer_past,
            attention_mask,
            token_type_ids,
            position_ids,
            input_embeds,
            _encoder_outputs,
            _decoder_input_ids,
            train,
        )
    }
}

impl LanguageGenerator for OpenAIGenerator {}