use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::gpt2::transformer::Block;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
    PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Device, Kind, Tensor};

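/// # GPT2 Pretrained model weight files
///
/// `(cache subdirectory, download URL)` pairs for pretrained weights hosted on the
/// Hugging Face model hub.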
pub struct Gpt2ModelResources;

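/// # GPT2 Pretrained model config files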
pub struct Gpt2ConfigResources;

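/// # GPT2 Pretrained model vocab files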
pub struct Gpt2VocabResources;

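/// # GPT2 Pretrained model merges files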
pub struct Gpt2MergesResources;

impl Gpt2ModelResources {
    pub const GPT2: (&'static str, &'static str) = (
        "gpt2/model",
        "https://huggingface.co/gpt2/resolve/main/rust_model.ot",
    );
    pub const GPT2_MEDIUM: (&'static str, &'static str) = (
        "gpt2-medium/model",
        "https://huggingface.co/gpt2-medium/resolve/main/rust_model.ot",
    );
    pub const GPT2_LARGE: (&'static str, &'static str) = (
        "gpt2-large/model",
        "https://huggingface.co/gpt2-large/resolve/main/rust_model.ot",
    );
    pub const GPT2_XL: (&'static str, &'static str) = (
        "gpt2-xl/model",
        "https://huggingface.co/gpt2-xl/resolve/main/rust_model.ot",
    );
    pub const DISTIL_GPT2: (&'static str, &'static str) = (
        "distilgpt2/model",
        "https://huggingface.co/distilgpt2/resolve/main/rust_model.ot",
    );
    pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
        "dialogpt-medium/model",
        "https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/rust_model.ot",
    );
}

impl Gpt2ConfigResources {
    pub const GPT2: (&'static str, &'static str) = (
        "gpt2/config",
        "https://huggingface.co/gpt2/resolve/main/config.json",
    );
    pub const GPT2_MEDIUM: (&'static str, &'static str) = (
        "gpt2-medium/config",
        "https://huggingface.co/gpt2-medium/resolve/main/config.json",
    );
    pub const GPT2_LARGE: (&'static str, &'static str) = (
        "gpt2-large/config",
        "https://huggingface.co/gpt2-large/resolve/main/config.json",
    );
    pub const GPT2_XL: (&'static str, &'static str) = (
        "gpt2-xl/config",
        "https://huggingface.co/gpt2-xl/resolve/main/config.json",
    );
    pub const DISTIL_GPT2: (&'static str, &'static str) = (
        "distilgpt2/config",
        "https://huggingface.co/distilgpt2/resolve/main/config.json",
    );
    pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
        "dialogpt-medium/config",
        "https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/config.json",
    );
}

impl Gpt2VocabResources {
    pub const GPT2: (&'static str, &'static str) = (
        "gpt2/vocab",
        "https://huggingface.co/gpt2/resolve/main/vocab.json",
    );
    pub const GPT2_MEDIUM: (&'static str, &'static str) = (
        "gpt2-medium/vocab",
        "https://huggingface.co/gpt2-medium/resolve/main/vocab.json",
    );
    pub const GPT2_LARGE: (&'static str, &'static str) = (
        "gpt2-large/vocab",
        "https://huggingface.co/gpt2-large/resolve/main/vocab.json",
    );
    pub const GPT2_XL: (&'static str, &'static str) = (
        "gpt2-xl/vocab",
        "https://huggingface.co/gpt2-xl/resolve/main/vocab.json",
    );
    pub const DISTIL_GPT2: (&'static str, &'static str) = (
        "distilgpt2/vocab",
        "https://huggingface.co/distilgpt2/resolve/main/vocab.json",
    );
    pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
        "dialogpt-medium/vocab",
        "https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/vocab.json",
    );
}

impl Gpt2MergesResources {
    pub const GPT2: (&'static str, &'static str) = (
        "gpt2/merges",
        "https://huggingface.co/gpt2/resolve/main/merges.txt",
    );
    pub const GPT2_MEDIUM: (&'static str, &'static str) = (
        "gpt2-medium/merges",
        "https://huggingface.co/gpt2-medium/resolve/main/merges.txt",
    );
    pub const GPT2_LARGE: (&'static str, &'static str) = (
        "gpt2-large/merges",
        "https://huggingface.co/gpt2-large/resolve/main/merges.txt",
    );
    pub const GPT2_XL: (&'static str, &'static str) = (
        "gpt2-xl/merges",
        "https://huggingface.co/gpt2-xl/resolve/main/merges.txt",
    );
    pub const DISTIL_GPT2: (&'static str, &'static str) = (
        "distilgpt2/merges",
        "https://huggingface.co/distilgpt2/resolve/main/merges.txt",
    );
    pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
        "dialogpt-medium/merges",
        "https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/merges.txt",
    );
}

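/// # GPT2 model configuration
///
/// Defines the GPT2 model architecture (e.g. number of layers, hidden size, vocabulary size).
/// Compatible with the Hugging Face `config.json` format; optional fields fall back to
/// defaults at model construction.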
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Gpt2Config {
    pub attn_pdrop: Option<f64>,
    pub embd_pdrop: Option<f64>,
    pub hidden_dropout_prob: Option<f64>,
    pub afn: Option<Activation>,
    pub initializer_range: f64,
    pub layer_norm_epsilon: f64,
    pub n_ctx: i64,
    pub n_embd: i64,
    pub n_head: i64,
    pub n_layer: i64,
    pub n_positions: i64,
    pub num_labels: Option<i64>,
    pub output_past: Option<bool>,
    pub output_attentions: Option<bool>,
    pub output_hidden_states: Option<bool>,
    pub resid_pdrop: Option<f64>,
    pub vocab_size: i64,
    pub decoder_start_token_id: Option<i64>,
    pub forced_bos_token_id: Option<i64>,
    pub forced_eos_token_id: Option<i64>,
}

impl Config for Gpt2Config {}

impl Default for Gpt2Config {
    fn default() -> Self {
        Gpt2Config {
            attn_pdrop: Some(0.1),
            embd_pdrop: Some(0.1),
            hidden_dropout_prob: None,
            afn: Some(Activation::gelu_new),
            initializer_range: 0.02,
            layer_norm_epsilon: 1e-5,
            n_ctx: 1024,
            n_embd: 768,
            n_head: 12,
            n_layer: 12,
            // GPT2 supports up to 1024 positions; a value of 0 would create an
            // empty position embedding table.
            n_positions: 1024,
            num_labels: None,
            output_past: None,
            output_attentions: None,
            output_hidden_states: None,
            resid_pdrop: Some(0.1),
            vocab_size: 50257,
            decoder_start_token_id: None,
            forced_bos_token_id: None,
            forced_eos_token_id: None,
        }
    }
}

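/// # GPT2 Base model
///
/// Base GPT2 architecture with no task-specific head. It is composed of:
/// - `wte`: token embeddings
/// - `wpe`: position embeddings
/// - `h`: stack of `n_layer` transformer blocks
/// - `ln_f`: final layer normalization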
pub struct Gpt2Model {
    wte: nn::Embedding,
    wpe: nn::Embedding,
    drop: Dropout,
    ln_f: nn::LayerNorm,
    h: Vec<Block>,
    output_past: bool,
    output_hidden_states: bool,
    output_attentions: bool,
}

impl Gpt2Model {
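    /// Builds a new `Gpt2Model`.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the GPT2 model
    /// * `config` - `Gpt2Config` object defining the model architecture
    ///
    /// # Example
    ///
    /// A minimal construction sketch: the config path is a placeholder, and the
    /// `rust_bert::gpt2` re-exports are assumed from this crate's public API.
    ///
    /// ```no_run
    /// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = Gpt2Config::from_file(config_path);
    /// let gpt2 = Gpt2Model::new(&p.root() / "gpt2", &config);
    /// ```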
    pub fn new<'p, P>(p: P, config: &Gpt2Config) -> Gpt2Model
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow() / "transformer";

        let wte = embedding(
            &p / "wte",
            config.vocab_size,
            config.n_embd,
            Default::default(),
        );
        let wpe = embedding(
            &p / "wpe",
            config.n_positions,
            config.n_embd,
            Default::default(),
        );

        let embd_pdrop = config.embd_pdrop.unwrap_or(0.1);
        let drop = Dropout::new(embd_pdrop);
        let layer_norm_config = nn::LayerNormConfig {
            eps: config.layer_norm_epsilon,
            ..Default::default()
        };
        let ln_f = nn::layer_norm(&p / "ln_f", vec![config.n_embd], layer_norm_config);
        let mut h: Vec<Block> = vec![];
        let h_path = &p / "h";
        for layer_index in 0..config.n_layer {
            h.push(Block::new(&h_path / layer_index, config, true));
        }
        let output_attentions = config.output_attentions.unwrap_or(false);
        let output_past = config.output_past.unwrap_or(true);
        let output_hidden_states = config.output_hidden_states.unwrap_or(false);

        Gpt2Model {
            wte,
            wpe,
            drop,
            ln_f,
            h,
            output_past,
            output_hidden_states,
            output_attentions,
        }
    }

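    /// Forward pass through the model.
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence length*). Required if no `input_embeds` is provided.
    /// * `layer_past` - Optional vector of cached key/value tensors, one entry per layer.
    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence length*); 0 marks positions to be masked.
    /// * `token_type_ids` - Optional segment ids embedded with the token embedding matrix.
    /// * `position_ids` - Optional position indices; if `None`, they are computed from the cached past length.
    /// * `input_embeds` - Optional pre-computed input embeddings, mutually exclusive with `input_ids`.
    /// * `train` - Boolean flag to turn on/off dropout.
    ///
    /// # Returns
    ///
    /// * `Gpt2ModelOutput` with the final hidden state and, depending on the configuration
    ///   flags, the per-layer caches, hidden states and attention weights.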
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        layer_past: Option<&Vec<Tensor>>,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<Gpt2ModelOutput, RustBertError> {
        // Validate the ids/embeddings pair and compute token embeddings from ids if needed.
        let (calc_input_embeddings, input_size, _) =
            process_ids_embeddings_pair(input_ids, input_embeds, &self.wte)?;
        let input_embeddings =
            input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());

        let seq_length = input_size[1];

        // Copy the cached layer states if provided; the past length offsets the position ids.
        let (layer_past, layer_past_length) = match layer_past {
            Some(value) => {
                assert_eq!(
                    value.len(),
                    self.h.len(),
                    "Past activations vector must be of length equal to the number of layers"
                );
                (
                    value
                        .iter()
                        .map(|v| Some(v.copy()))
                        .collect::<Vec<Option<Tensor>>>(),
                    value[0].size()[3],
                )
            }
            None => {
                let mut out = Vec::with_capacity(self.h.len());
                out.resize_with(self.h.len(), || None::<Tensor>);
                (out, 0)
            }
        };

        let position_ids = match position_ids {
            Some(value) => value.copy(),
            None => Tensor::arange_start(
                layer_past_length,
                seq_length + layer_past_length,
                (Int64, input_embeddings.device()),
            )
            .unsqueeze(0),
        };

        // Convert the 0/1 attention mask into an additive mask with large negative
        // values at masked positions, broadcastable over attention heads.
        let attention_mask: Option<Tensor> = attention_mask.map(|value| {
            let attention_mask = value
                .view((input_embeddings.size()[0], -1))
                .unsqueeze(1)
                .unsqueeze(2)
                .to_kind(input_embeddings.kind());

            let attention_mask: Tensor = (1.0 - attention_mask) * (-10000.0);
            attention_mask.to_kind(input_embeddings.kind())
        });

        let position_embeds = position_ids.apply(&self.wpe);
        let token_type_embeds = match token_type_ids {
            Some(value) => value.apply(&self.wte),
            None => Tensor::zeros_like(&position_embeds),
        };
        let mut hidden_state: Tensor =
            (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
        let mut all_presents: Option<Vec<Tensor>> =
            if self.output_past { Some(vec![]) } else { None };
        let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
            Some(vec![])
        } else {
            None
        };
        let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
            Some(vec![])
        } else {
            None
        };

        // Run the transformer blocks, collecting caches, attentions and hidden states
        // when the corresponding output flags are enabled.
        let layer_iter = self.h.iter().zip(layer_past);
        for layer_values in layer_iter {
            let (layer, past) = layer_values;
            let temp =
                layer.forward_t(&hidden_state, past.as_ref(), attention_mask.as_ref(), train);
            hidden_state = temp.0;
            if let Some(presents) = all_presents.borrow_mut() {
                presents.push(temp.1);
            };
            if let Some(attentions) = all_attentions.borrow_mut() {
                attentions.push(temp.2.unwrap());
            };
            if let Some(hidden_states) = all_hidden_states.borrow_mut() {
                hidden_states.push(hidden_state.as_ref().copy());
            };
        }

        Ok(Gpt2ModelOutput {
            output: hidden_state.apply(&self.ln_f),
            cache: all_presents,
            all_hidden_states,
            all_attentions,
        })
    }
}

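/// # GPT2 Language Modeling head
///
/// GPT2 base model with a language modeling head whose weights are tied to the input
/// token embeddings.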
pub struct GPT2LMHeadModel {
    transformer: Gpt2Model,
}

impl GPT2LMHeadModel {
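    /// Builds a new `GPT2LMHeadModel`.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the model
    /// * `config` - `Gpt2Config` object defining the model architecture
    ///
    /// # Example
    ///
    /// A sketch mirroring the `Gpt2Model` example above; the config path is a placeholder.
    ///
    /// ```no_run
    /// use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let p = nn::VarStore::new(Device::Cpu);
    /// let config = Gpt2Config::from_file(config_path);
    /// let gpt2 = GPT2LMHeadModel::new(&p.root() / "gpt2", &config);
    /// ```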
    pub fn new<'p, P>(p: P, config: &Gpt2Config) -> GPT2LMHeadModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let transformer = Gpt2Model::new(p, config);

        GPT2LMHeadModel { transformer }
    }

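    /// Forward pass through the model, returning logits over the vocabulary.
    ///
    /// Accepts the same arguments as `Gpt2Model::forward_t` and returns an `LMModelOutput`
    /// with the language modeling logits and the GPT2 cache.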
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        layer_past: Option<&Vec<Tensor>>,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        let base_model_output = self.transformer.forward_t(
            input_ids,
            layer_past,
            attention_mask,
            token_type_ids,
            position_ids,
            input_embeds,
            train,
        )?;

        // Project the hidden states onto the vocabulary by reusing the input token
        // embedding matrix (weight tying).
        let lm_logits = base_model_output
            .output
            .linear::<Tensor>(&self.transformer.wte.ws, None);
        Ok(LMModelOutput {
            lm_logits,
            cache: Cache::GPT2Cache(base_model_output.cache),
        })
    }
}

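/// Container for the output of a GPT2 forward pass:
/// * `output` - last hidden state after the final layer normalization
/// * `cache` - per-layer key/value cache, if `output_past` is enabled
/// * `all_hidden_states` - hidden states for all layers, if `output_hidden_states` is enabled
/// * `all_attentions` - attention weights for all layers, if `output_attentions` is enabled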
pub struct Gpt2ModelOutput {
    pub output: Tensor,
    pub cache: Option<Vec<Tensor>>,
    pub all_hidden_states: Option<Vec<Tensor>>,
    pub all_attentions: Option<Vec<Tensor>>,
}

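/// # Language generation model based on the GPT2 architecture
///
/// Wraps a `GPT2LMHeadModel`, its tokenizer and the generation configuration to
/// implement the `LanguageGenerator` interface.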
pub struct GPT2Generator {
    model: GPT2LMHeadModel,
    tokenizer: TokenizerOption,
    var_store: nn::VarStore,
    generate_config: GenerateConfig,
    bos_token_id: Option<i64>,
    eos_token_ids: Option<Vec<i64>>,
    pad_token_id: Option<i64>,
    is_encoder_decoder: bool,
    vocab_size: i64,
    decoder_start_id: Option<i64>,
    max_position_embeddings: i64,
}

impl GPT2Generator {
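    /// Builds a new `GPT2Generator` from a `GenerateConfig`, locating the configured
    /// resources and loading the tokenizer.
    ///
    /// # Example
    ///
    /// A minimal sketch; it assumes the default `GenerateConfig` points at valid GPT2
    /// resources and that `GPT2Generator` is re-exported from `rust_bert::gpt2`.
    ///
    /// ```no_run
    /// use rust_bert::gpt2::GPT2Generator;
    /// use rust_bert::pipelines::generation_utils::GenerateConfig;
    ///
    /// # fn main() -> Result<(), rust_bert::RustBertError> {
    /// let generate_config = GenerateConfig::default();
    /// let gpt2_generator = GPT2Generator::new(generate_config)?;
    /// # Ok(())
    /// # }
    /// ```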
    pub fn new(generate_config: GenerateConfig) -> Result<GPT2Generator, RustBertError> {
        let vocab_path = generate_config.vocab_resource.get_local_path()?;
        let merges_path = generate_config
            .merges_resource
            .as_ref()
            .ok_or_else(|| {
                RustBertError::InvalidConfigurationError(
                    "GPT2 expects a merges resource to be provided".to_string(),
                )
            })?
            .get_local_path()?;

        let tokenizer = TokenizerOption::from_file(
            ModelType::GPT2,
            vocab_path.to_str().unwrap(),
            Some(merges_path.to_str().unwrap()),
            false,
            None,
            None,
        )?;

        Self::new_with_tokenizer(generate_config, tokenizer)
    }

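    /// Builds a new `GPT2Generator` from a `GenerateConfig` and an already constructed
    /// `TokenizerOption`, loading the model weights onto the configured device.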
    pub fn new_with_tokenizer(
        generate_config: GenerateConfig,
        tokenizer: TokenizerOption,
    ) -> Result<GPT2Generator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();
        let mut var_store = nn::VarStore::new(device);

        let config = Gpt2Config::from_file(config_path);
        let model = GPT2LMHeadModel::new(var_store.root(), &config);
        crate::resources::load_weights(
            &generate_config.model_resource,
            &mut var_store,
            generate_config.kind,
            device,
        )?;

        let bos_token_id = tokenizer.get_bos_id();
        let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
        let pad_token_id = tokenizer.get_pad_id();
        let max_position_embeddings = config.n_positions;
        let is_encoder_decoder = false;
        let vocab_size = config.vocab_size;
        let decoder_start_id = config.decoder_start_token_id;

        Ok(GPT2Generator {
            model,
            tokenizer,
            var_store,
            generate_config,
            bos_token_id,
            eos_token_ids,
            pad_token_id,
            is_encoder_decoder,
            vocab_size,
            decoder_start_id,
            max_position_embeddings,
        })
    }
}

impl PrivateLanguageGenerator for GPT2Generator {
    fn _get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }
    fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        &mut self.tokenizer
    }
    fn get_device(&self) -> Device {
        self.var_store.device()
    }
    fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
        Ok(&mut self.var_store)
    }
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }
    fn get_bos_id(&self) -> Option<i64> {
        self.bos_token_id
    }
    fn get_eos_ids(&self) -> Option<&Vec<i64>> {
        self.eos_token_ids.as_ref()
    }
    fn get_pad_id(&self) -> Option<i64> {
        self.pad_token_id
    }
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }
    fn get_max_positions_embeddings(&self) -> Option<i64> {
        Some(self.max_position_embeddings)
    }

    fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        layer_past: Cache,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        _encoder_outputs: Option<&Tensor>,
        _decoder_input_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        // Only GPT2-style caches (or no cache at all) are valid for this model.
        match layer_past {
            Cache::GPT2Cache(layer_past) => self.model.forward_t(
                input_ids,
                layer_past.as_ref(),
                attention_mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
            Cache::None => self.model.forward_t(
                input_ids,
                None,
                attention_mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
            _ => Err(RustBertError::ValueError(
                "Cache not compatible with GPT2 Model".into(),
            )),
        }
    }

    fn prepare_inputs_for_generation<'a>(
        &self,
        input_ids: Tensor,
        _encoder_outputs: Option<&'a Tensor>,
        past: Cache,
        attention_mask: Tensor,
    ) -> PreparedInput<'a> {
        // Derive position ids from the attention mask: a running count of non-masked
        // tokens, with masked positions clamped to 1.
        let position_ids = (attention_mask.totype(Kind::Int64).cumsum(-1, Kind::Int64) - 1)
            .masked_fill(&attention_mask.eq(0), 1);

        match past {
            Cache::GPT2Cache(past) => {
                if past.is_some() {
                    // With a populated cache, only the last token needs to be processed.
                    PreparedInput {
                        prepared_input: Some(input_ids.select(1, -1).unsqueeze(-1)),
                        prepared_attention_mask: Some(attention_mask),
                        prepared_encoder_output: None,
                        prepared_decoder_input: None,
                        prepared_position_ids: Some(position_ids.select(1, -1).unsqueeze(-1)),
                        prepared_past: Cache::GPT2Cache(past),
                    }
                } else {
                    PreparedInput {
                        prepared_input: Some(input_ids),
                        prepared_attention_mask: Some(attention_mask),
                        prepared_encoder_output: None,
                        prepared_decoder_input: None,
                        prepared_position_ids: Some(position_ids),
                        prepared_past: Cache::GPT2Cache(None),
                    }
                }
            }
            Cache::None => PreparedInput {
                prepared_input: Some(input_ids),
                prepared_attention_mask: Some(attention_mask),
                prepared_encoder_output: None,
                prepared_decoder_input: None,
                prepared_position_ids: Some(position_ids),
                prepared_past: Cache::GPT2Cache(None),
            },
            _ => panic!("Cache type incompatible with GPT2"),
        }
    }

    fn reorder_cache(
        &self,
        past: &mut Cache,
        _encoder_outputs: Option<Tensor>,
        beam_indices: &Tensor,
    ) -> Option<Tensor> {
        // Reorder the cached states along the beam dimension so that they follow the
        // beams selected at the current step of beam search.
        match past {
            Cache::GPT2Cache(cached_decoder_state) => match cached_decoder_state {
                Some(value) => {
                    for layer_past in value.iter_mut() {
                        *layer_past = layer_past.index_select(1, beam_indices);
                    }
                    None
                }
                None => None,
            },
            Cache::None => None,
            _ => {
                panic!("Invalid cache for GPT2 model");
            }
        }
    }
}

impl LanguageGenerator for GPT2Generator {}