use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::common::kind::get_min;
use crate::gpt_j::attention::LayerState;
use crate::gpt_j::transformer::GptJBlock;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
    PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
use tch::nn::{embedding, Linear};
use tch::{nn, Device, Tensor};

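/// # GPT-J Pretrained model weight files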
pub struct GptJModelResources;

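/// # GPT-J Pretrained model config files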
pub struct GptJConfigResources;

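/// # GPT-J Pretrained model vocab files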
pub struct GptJVocabResources;

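/// # GPT-J Pretrained model merges files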
pub struct GptJMergesResources;

impl GptJModelResources {
    pub const GPT_J_TINY_RANDOM: (&'static str, &'static str) = (
        "gpt-j-tiny-random/model",
        "https://huggingface.co/anton-l/gpt-j-tiny-random/resolve/main/rust_model.ot",
    );
}

impl GptJConfigResources {
    pub const GPT_J_6B: (&'static str, &'static str) = (
        "gpt-j-6B/config",
        "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/main/config.json",
    );
    pub const GPT_J_6B_FLOAT16: (&'static str, &'static str) = (
        "gpt-j-6B-float16/config",
        "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/float16/config.json",
    );
    pub const GPT_J_TINY_RANDOM: (&'static str, &'static str) = (
        "gpt-j-tiny-random/config",
        "https://huggingface.co/anton-l/gpt-j-tiny-random/resolve/main/config.json",
    );
}

impl GptJVocabResources {
    pub const GPT_J_6B: (&'static str, &'static str) = (
        "gpt-j-6B/vocab",
        "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/main/vocab.json",
    );
    pub const GPT_J_6B_FLOAT16: (&'static str, &'static str) = (
        "gpt-j-6B-float16/vocab",
        "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/float16/vocab.json",
    );
    pub const GPT_J_TINY_RANDOM: (&'static str, &'static str) = (
        "gpt-j-tiny-random/vocab",
        "https://huggingface.co/anton-l/gpt-j-tiny-random/resolve/main/vocab.json",
    );
}

impl GptJMergesResources {
    pub const GPT_J_6B: (&'static str, &'static str) = (
        "gpt-j-6B/merges",
        "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/main/merges.txt",
    );
    pub const GPT_J_6B_FLOAT16: (&'static str, &'static str) = (
        "gpt-j-6B-float16/merges",
        "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/float16/merges.txt",
    );
    pub const GPT_J_TINY_RANDOM: (&'static str, &'static str) = (
        "gpt-j-tiny-random/merges",
        "https://huggingface.co/anton-l/gpt-j-tiny-random/resolve/main/merges.txt",
    );
}

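/// # GPT-J model configuration
/// Defines the GPT-J model architecture (e.g. number of layers, hidden layer size, vocabulary size).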
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GptJConfig {
    pub attn_pdrop: Option<f64>,
    pub embd_pdrop: Option<f64>,
    pub hidden_dropout_prob: Option<f64>,
    pub afn: Option<Activation>,
    pub initializer_range: f64,
    pub layer_norm_epsilon: f64,
    pub n_embd: i64,
    pub n_head: i64,
    pub n_layer: i64,
    pub n_positions: i64,
    pub n_inner: Option<i64>,
    pub num_labels: Option<i64>,
    pub use_cache: Option<bool>,
    pub output_attentions: Option<bool>,
    pub output_hidden_states: Option<bool>,
    pub resid_pdrop: Option<f64>,
    pub rotary_dim: Option<i64>,
    pub vocab_size: i64,
    pub scale_attn_weights: Option<bool>,
    #[serde(default = "default_preload_on_cpu")]
    pub preload_on_cpu: bool,
    pub decoder_start_token_id: Option<i64>,
    pub forced_bos_token_id: Option<i64>,
    pub forced_eos_token_id: Option<i64>,
}

impl Config for GptJConfig {}

impl Default for GptJConfig {
    fn default() -> Self {
        GptJConfig {
            attn_pdrop: Some(0.1),
            embd_pdrop: Some(0.1),
            hidden_dropout_prob: None,
            afn: Some(Activation::gelu_new),
            initializer_range: 0.02,
            layer_norm_epsilon: 1e-5,
            n_embd: 4096,
            n_head: 16,
            n_layer: 28,
            n_positions: 2048,
            n_inner: None,
            num_labels: None,
            use_cache: None,
            output_attentions: None,
            output_hidden_states: None,
            resid_pdrop: Some(0.1),
            rotary_dim: Some(64),
            vocab_size: 50400,
            scale_attn_weights: Some(true),
            preload_on_cpu: default_preload_on_cpu(),
            decoder_start_token_id: None,
            forced_bos_token_id: None,
            forced_eos_token_id: None,
        }
    }
}

fn default_preload_on_cpu() -> bool {
    true
}

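/// # GPT-J Base model
/// Base architecture for the GPT-J model, made of a token embedding (`wte`), a stack of
/// `GptJBlock` transformer layers (`h`) and a final layer norm (`ln_f`). Usually complemented
/// with a task-specific head, such as the language model head in `GptJLMHeadModel`.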
pub struct GptJModel {
    wte: nn::Embedding,
    drop: Dropout,
    ln_f: nn::LayerNorm,
    h: Vec<GptJBlock>,
    use_cache: bool,
    output_hidden_states: bool,
    output_attentions: bool,
}

impl GptJModel {
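    /// Build a new `GptJModel`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the GPT-J model
    /// * `config` - `GptJConfig` object defining the model architecture
    ///
    /// # Example (illustrative sketch; the config path is a placeholder)
    ///
    /// ```no_run
    /// use rust_bert::gpt_j::{GptJConfig, GptJModel};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = GptJConfig::from_file(config_path);
    /// let gpt_j: GptJModel = GptJModel::new(&p.root() / "gpt_j", &config);
    /// ```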
    pub fn new<'p, P>(p: P, config: &GptJConfig) -> GptJModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow() / "transformer";

        let wte = embedding(
            &p / "wte",
            config.vocab_size,
            config.n_embd,
            Default::default(),
        );

        let embd_pdrop = config.embd_pdrop.unwrap_or(0.1);
        let drop = Dropout::new(embd_pdrop);

        let layer_norm_config = nn::LayerNormConfig {
            eps: config.layer_norm_epsilon,
            ..Default::default()
        };
        let ln_f = nn::layer_norm(&p / "ln_f", vec![config.n_embd], layer_norm_config);

        let mut h: Vec<GptJBlock> = vec![];
        let h_path = &p / "h";
        for layer_index in 0..config.n_layer {
            h.push(GptJBlock::new(&h_path / layer_index, config));
        }

        let use_cache = config.use_cache.unwrap_or(true);
        let output_attentions = config.output_attentions.unwrap_or(false);
        let output_hidden_states = config.output_hidden_states.unwrap_or(false);

        GptJModel {
            wte,
            drop,
            ln_f,
            h,
            use_cache,
            output_hidden_states,
            output_attentions,
        }
    }

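    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `layer_past` - Optional vector of length `n_layer` containing the past keys and values of each layer from previous decoding steps
    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence length*). Positions with a mask value of 0 will be masked out
    /// * `token_type_ids` - Optional token type ids; the corresponding embeddings are added to the input embeddings
    /// * `_position_ids` - Optional position ids, unused by GPT-J (which relies on rotary position embeddings)
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence length*, *hidden size*). If None, input ids must be provided (see `input_ids`)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `GptJModelOutput` containing the last hidden state and, depending on the configuration, the cached layer states, all hidden states and all attention weights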
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        layer_past: Option<Vec<Option<LayerState>>>,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        _position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<GptJModelOutput, RustBertError> {
        let (calc_input_embeddings, _input_size, _device) =
            process_ids_embeddings_pair(input_ids, input_embeds, &self.wte)?;

        let input_embeddings =
            input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());

        let (layer_past, _layer_past_length) = match layer_past {
            Some(value) => {
                if value.len() != self.h.len() {
                    return Err(RustBertError::ValueError(format!(
                        "Past activations vector length ({}) must be equal to the number of layers ({})",
                        value.len(),
                        self.h.len()
                    )));
                } else {
                    let length = value.len();
                    (value, length)
                }
            }
            None => {
                let mut out = Vec::with_capacity(self.h.len());
                out.resize_with(self.h.len(), || None);
                (out, 0)
            }
        };

        // Convert the 0/1 attention mask into an additive mask: masked positions are set
        // to the minimum value representable by the tensor kind, unmasked positions to 0.
        let kind_min = get_min(input_embeddings.kind())?;
        let attention_mask: Option<Tensor> = attention_mask.map(|value| {
            let attention_mask = value
                .view((input_embeddings.size()[0], -1))
                .unsqueeze(1)
                .unsqueeze(2)
                .to_kind(input_embeddings.kind());

            (attention_mask.ones_like() - attention_mask) * kind_min
        });

        let mut hidden_state: Tensor = input_embeddings.copy();
        if let Some(token_type_ids) = token_type_ids {
            let token_type_embeds = token_type_ids.apply(&self.wte);
            hidden_state = hidden_state + token_type_embeds;
        }
        hidden_state = hidden_state.apply_t(&self.drop, train);

        let mut all_presents: Option<Vec<Option<LayerState>>> = self.use_cache.then(Vec::new);
        let mut all_hidden_states: Option<Vec<Tensor>> = self.output_hidden_states.then(Vec::new);
        let mut all_attentions: Option<Vec<Tensor>> = self.output_attentions.then(Vec::new);

        for (layer, past) in self.h.iter().zip(layer_past) {
            let temp =
                layer.forward_t(&hidden_state, past.as_ref(), attention_mask.as_ref(), train);
            hidden_state = temp.0;
            if let Some(presents) = all_presents.borrow_mut() {
                presents.push(temp.1);
            };
            if let Some(attentions) = all_attentions.borrow_mut() {
                attentions.push(temp.2.unwrap());
            };
            if let Some(hidden_states) = all_hidden_states.borrow_mut() {
                // Push a copy so the running activation is preserved for the next layer.
                hidden_states.push(hidden_state.copy());
            };
        }

        let output = hidden_state.apply(&self.ln_f);

        Ok(GptJModelOutput {
            output,
            cache: all_presents,
            all_hidden_states,
            all_attentions,
        })
    }
}

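/// # GPT-J Language Modeling head
/// GPT-J model with a language model head: a linear layer mapping the decoder hidden states to
/// logits over the vocabulary.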
pub struct GptJLMHeadModel {
    transformer: GptJModel,
    lm_head: Linear,
}

impl GptJLMHeadModel {
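    /// Build a new `GptJLMHeadModel`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the GPT-J model
    /// * `config` - `GptJConfig` object defining the model architecture
    ///
    /// # Example (illustrative sketch; the config path is a placeholder)
    ///
    /// ```no_run
    /// use rust_bert::gpt_j::{GptJConfig, GptJLMHeadModel};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = GptJConfig::from_file(config_path);
    /// let gpt_j: GptJLMHeadModel = GptJLMHeadModel::new(&p.root() / "gpt_j", &config);
    /// ```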
    pub fn new<'p, P>(p: P, config: &GptJConfig) -> GptJLMHeadModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let transformer = GptJModel::new(p, config);
        let lm_head = nn::linear(
            p / "lm_head",
            config.n_embd,
            config.vocab_size,
            Default::default(),
        );

        GptJLMHeadModel {
            transformer,
            lm_head,
        }
    }

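    /// Forward pass through the model, returning the language modeling logits
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence length*)
    /// * `layer_past` - `Cache` object containing the past keys and values; must be `Cache::GPTJCache` or `Cache::None`
    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence length*). Positions with a mask value of 0 will be masked out
    /// * `token_type_ids` - Optional token type ids
    /// * `position_ids` - Optional position ids, unused by GPT-J
    /// * `input_embeds` - Optional pre-computed input embeddings
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `LMModelOutput` containing the logits of shape (*batch size*, *sequence length*, *vocab size*) and the updated cache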
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        layer_past: Cache,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        _encoder_outputs: Option<&Tensor>,
        _decoder_input_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        let base_model_output = match layer_past {
            Cache::GPTJCache(layer_past) => self.transformer.forward_t(
                input_ids,
                layer_past,
                attention_mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
            Cache::None => self.transformer.forward_t(
                input_ids,
                None,
                attention_mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
            _ => {
                return Err(RustBertError::ValueError(
                    "Cache not compatible with GPT-J Model".into(),
                ));
            }
        }?;

        let lm_logits = base_model_output.output.apply(&self.lm_head);

        Ok(LMModelOutput {
            lm_logits,
            cache: Cache::GPTJCache(base_model_output.cache),
        })
    }
}

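/// Container for the GPT-J model output.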
pub struct GptJModelOutput {
    /// Hidden state of the last layer of the decoder
    pub output: Tensor,
    /// Cached attention layer keys and values if the model is used for generation
    pub cache: Option<Vec<Option<LayerState>>>,
    /// Hidden states for all intermediate layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
}

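/// # Language generation model based on the GPT-J architecture
/// Wraps a `GptJLMHeadModel` together with its tokenizer and generation configuration.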
pub struct GptJGenerator {
    model: GptJLMHeadModel,
    tokenizer: TokenizerOption,
    var_store: nn::VarStore,
    generate_config: GenerateConfig,
    bos_token_id: Option<i64>,
    eos_token_ids: Option<Vec<i64>>,
    pad_token_id: Option<i64>,
    is_encoder_decoder: bool,
    vocab_size: i64,
    decoder_start_id: Option<i64>,
    max_position_embeddings: i64,
}

impl GptJGenerator {
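    /// Build a new `GptJGenerator`
    ///
    /// # Arguments
    ///
    /// * `generate_config` - `GenerateConfig` object containing the resource references (model, vocabulary, configuration), generation options and device placement (CPU/GPU)
    ///
    /// # Example (illustrative sketch; `..Default::default()` is a placeholder and the
    /// config resources must point to GPT-J model, config, vocab and merges files)
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::gpt_j::GptJGenerator;
    /// use rust_bert::pipelines::generation_utils::GenerateConfig;
    ///
    /// let generate_config = GenerateConfig {
    ///     do_sample: true,
    ///     num_beams: 5,
    ///     ..Default::default()
    /// };
    /// let gpt_j_generator = GptJGenerator::new(generate_config)?;
    /// # Ok(())
    /// # }
    /// ```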
    pub fn new(generate_config: GenerateConfig) -> Result<GptJGenerator, RustBertError> {
        let vocab_path = generate_config.vocab_resource.get_local_path()?;
        let merges_path = generate_config
            .merges_resource
            .as_ref()
            .ok_or_else(|| {
                RustBertError::InvalidConfigurationError(
                    "GPT-J expects a merges resource to be provided".to_string(),
                )
            })?
            .get_local_path()?;

        let tokenizer = TokenizerOption::from_file(
            ModelType::GPTJ,
            vocab_path.to_str().unwrap(),
            Some(merges_path.to_str().unwrap()),
            false,
            None,
            None,
        )?;

        Self::new_with_tokenizer(generate_config, tokenizer)
    }

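    /// Build a new `GptJGenerator` from a `GenerateConfig` and a pre-built tokenizer
    ///
    /// # Arguments
    ///
    /// * `generate_config` - `GenerateConfig` object containing the resource references (model, vocabulary, configuration), generation options and device placement (CPU/GPU)
    /// * `tokenizer` - `TokenizerOption` tokenizer to use for generation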
    pub fn new_with_tokenizer(
        generate_config: GenerateConfig,
        tokenizer: TokenizerOption,
    ) -> Result<GptJGenerator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();
        let mut var_store = nn::VarStore::new(device);

        let config = GptJConfig::from_file(config_path);
        let model = GptJLMHeadModel::new(var_store.root(), &config);
        // If `preload_on_cpu` is set, load the weights on CPU first and only move
        // the variable store to the target device afterwards.
        if config.preload_on_cpu && device != Device::Cpu {
            var_store.set_device(Device::Cpu);
        }
        crate::resources::load_weights(
            &generate_config.model_resource,
            &mut var_store,
            generate_config.kind,
            device,
        )?;
        if device != Device::Cpu {
            var_store.set_device(device);
        }

        let bos_token_id = tokenizer.get_bos_id();
        let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
        let pad_token_id = tokenizer.get_pad_id();
        let max_position_embeddings = config.n_positions;
        let is_encoder_decoder = false;
        let vocab_size = config.vocab_size;
        let decoder_start_id = config.decoder_start_token_id;

        Ok(GptJGenerator {
            model,
            tokenizer,
            var_store,
            generate_config,
            bos_token_id,
            eos_token_ids,
            pad_token_id,
            is_encoder_decoder,
            vocab_size,
            decoder_start_id,
            max_position_embeddings,
        })
    }
}

impl PrivateLanguageGenerator for GptJGenerator {
    fn _get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }
    fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        &mut self.tokenizer
    }
    fn get_device(&self) -> Device {
        self.var_store.device()
    }
    fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
        Ok(&mut self.var_store)
    }
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }
    fn get_bos_id(&self) -> Option<i64> {
        self.bos_token_id
    }
    fn get_eos_ids(&self) -> Option<&Vec<i64>> {
        self.eos_token_ids.as_ref()
    }
    fn get_pad_id(&self) -> Option<i64> {
        self.pad_token_id
    }
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }
    fn get_max_positions_embeddings(&self) -> Option<i64> {
        Some(self.max_position_embeddings)
    }

    fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        layer_past: Cache,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        encoder_outputs: Option<&Tensor>,
        decoder_input_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        // Delegate to the language model head implementation rather than duplicating it.
        self.model.forward_t(
            input_ids,
            layer_past,
            attention_mask,
            token_type_ids,
            position_ids,
            input_embeds,
            encoder_outputs,
            decoder_input_ids,
            train,
        )
    }

    fn prepare_inputs_for_generation<'a>(
        &self,
        input_ids: Tensor,
        _encoder_outputs: Option<&'a Tensor>,
        past: Cache,
        attention_mask: Tensor,
    ) -> PreparedInput<'a> {
        match past {
            Cache::GPTJCache(past) => {
                if past.is_some() {
                    // With a populated cache, only the last generated token needs to be processed.
                    PreparedInput {
                        prepared_input: Some(input_ids.select(1, -1).unsqueeze(-1)),
                        prepared_attention_mask: Some(attention_mask),
                        prepared_encoder_output: None,
                        prepared_decoder_input: None,
                        prepared_position_ids: None,
                        prepared_past: Cache::GPTJCache(past),
                    }
                } else {
                    PreparedInput {
                        prepared_input: Some(input_ids),
                        prepared_attention_mask: Some(attention_mask),
                        prepared_encoder_output: None,
                        prepared_decoder_input: None,
                        prepared_position_ids: None,
                        prepared_past: Cache::GPTJCache(None),
                    }
                }
            }
            Cache::None => PreparedInput {
                prepared_input: Some(input_ids),
                prepared_attention_mask: Some(attention_mask),
                prepared_encoder_output: None,
                prepared_decoder_input: None,
                prepared_position_ids: None,
                prepared_past: Cache::GPTJCache(None),
            },
            _ => panic!("Cache type incompatible with GPT-J"),
        }
    }

    fn reorder_cache(
        &self,
        past: &mut Cache,
        _encoder_outputs: Option<Tensor>,
        beam_indices: &Tensor,
    ) -> Option<Tensor> {
        match past {
            Cache::GPTJCache(cached_decoder_state) => {
                if let Some(old_cache) = cached_decoder_state {
                    // Reorder each layer's cached keys/values according to the selected beams.
                    for layer_state in old_cache.iter_mut() {
                        if let Some(layer_state) = layer_state {
                            layer_state.reorder_cache(beam_indices);
                        }
                    }
                }
                None
            }
            Cache::None => None,
            _ => {
                panic!("Invalid cache for GPT-J model");
            }
        }
    }
}

impl LanguageGenerator for GptJGenerator {}