1use crate::common::dropout::Dropout;
14use crate::common::embeddings::process_ids_embeddings_pair;
15use crate::gpt_neo::decoder::GptNeoBlock;
16use crate::gpt_neo::LayerState;
17use crate::pipelines::common::{ModelType, TokenizerOption};
18use crate::pipelines::generation_utils::private_generation_utils::{
19 PreparedInput, PrivateLanguageGenerator,
20};
21use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
22use crate::{Activation, Config, RustBertError};
23use serde::{Deserialize, Serialize};
24use std::borrow::{Borrow, BorrowMut};
25use tch::{nn, Device, Kind, Tensor};
26
/// # GPT-Neo pretrained model weight files
///
/// Namespace type whose associated constants are `(name, URL)` string pairs, each
/// pointing at a `rust_model.ot` weights file for one GPT-Neo checkpoint size.
pub struct GptNeoModelResources;
29
/// # GPT-Neo pretrained model configuration files
///
/// Namespace type whose associated constants are `(name, URL)` string pairs, each
/// pointing at a `config.json` file for one GPT-Neo checkpoint size.
pub struct GptNeoConfigResources;
32
/// # GPT-Neo pretrained model vocabulary files
///
/// Namespace type whose associated constants are `(name, URL)` string pairs, each
/// pointing at a `vocab.json` file for one GPT-Neo checkpoint size.
pub struct GptNeoVocabResources;
35
/// # GPT-Neo pretrained model merges files
///
/// Namespace type whose associated constants are `(name, URL)` string pairs, each
/// pointing at a `merges.txt` file for one GPT-Neo checkpoint size.
pub struct GptNeoMergesResources;
38
// Each constant is a `(name, URL)` pair. The first element looks like a local cache
// key for the download — NOTE(review): confirm against the resource loader.
impl GptNeoModelResources {
    /// Weights (`rust_model.ot`) for the 125M checkpoint, hosted under the EleutherAI
    /// organisation on huggingface.co.
    pub const GPT_NEO_125M: (&'static str, &'static str) = (
        "gpt-neo-125M/model",
        "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/rust_model.ot",
    );
    /// Weights (`rust_model.ot`) for the 1.3B checkpoint, hosted under the EleutherAI
    /// organisation on huggingface.co.
    pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
        "gpt-neo-1_3B/model",
        "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/rust_model.ot",
    );
    /// Weights (`rust_model.ot`) for the 2.7B checkpoint, hosted under the EleutherAI
    /// organisation on huggingface.co.
    pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
        "gpt-neo-2_7B/model",
        "https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/rust_model.ot",
    );
}
56
// Each constant is a `(name, URL)` pair. The first element looks like a local cache
// key for the download — NOTE(review): confirm against the resource loader.
impl GptNeoConfigResources {
    /// Configuration (`config.json`) for the 125M checkpoint, hosted under the
    /// EleutherAI organisation on huggingface.co.
    pub const GPT_NEO_125M: (&'static str, &'static str) = (
        "gpt-neo-125M/config",
        "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/config.json",
    );
    /// Configuration (`config.json`) for the 1.3B checkpoint, hosted under the
    /// EleutherAI organisation on huggingface.co.
    pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
        "gpt-neo-1_3B/config",
        "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/config.json",
    );
    /// Configuration (`config.json`) for the 2.7B checkpoint, hosted under the
    /// EleutherAI organisation on huggingface.co.
    pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
        "gpt-neo-2_7B/config",
        "https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/config.json",
    );
}
74
// Each constant is a `(name, URL)` pair. The first element looks like a local cache
// key for the download — NOTE(review): confirm against the resource loader.
impl GptNeoVocabResources {
    /// Vocabulary (`vocab.json`) for the 125M checkpoint, hosted under the
    /// EleutherAI organisation on huggingface.co.
    pub const GPT_NEO_125M: (&'static str, &'static str) = (
        "gpt-neo-125M/vocab",
        "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/vocab.json",
    );
    /// Vocabulary (`vocab.json`) for the 1.3B checkpoint, hosted under the
    /// EleutherAI organisation on huggingface.co.
    pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
        "gpt-neo-1_3B/vocab",
        "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/vocab.json",
    );
    /// Vocabulary (`vocab.json`) for the 2.7B checkpoint, hosted under the
    /// EleutherAI organisation on huggingface.co.
    pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
        "gpt-neo-2_7B/vocab",
        "https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/vocab.json",
    );
}
92
// Each constant is a `(name, URL)` pair. The first element looks like a local cache
// key for the download — NOTE(review): confirm against the resource loader.
impl GptNeoMergesResources {
    /// Merges file (`merges.txt`) for the 125M checkpoint, hosted under the
    /// EleutherAI organisation on huggingface.co.
    pub const GPT_NEO_125M: (&'static str, &'static str) = (
        "gpt-neo-125M/merges",
        "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/merges.txt",
    );
    /// Merges file (`merges.txt`) for the 1.3B checkpoint, hosted under the
    /// EleutherAI organisation on huggingface.co.
    pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
        "gpt-neo-1_3B/merges",
        "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/merges.txt",
    );
    /// Merges file (`merges.txt`) for the 2.7B checkpoint, hosted under the
    /// EleutherAI organisation on huggingface.co.
    pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
        "gpt-neo-2_7B/merges",
        "https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/merges.txt",
    );
}
110
/// Kind of attention used by a single GPT-Neo layer.
///
/// With `rename_all = "camelCase"` the variants (de)serialize as `"global"` and `"local"`.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub enum AttentionLayerType {
    /// Global attention layer.
    Global,
    /// Local attention layer — presumably restricted to `GptNeoConfig::window_size`
    /// positions; confirm against the attention implementation.
    Local,
}
118
/// # GPT-Neo model configuration
///
/// Hyper-parameters of a GPT-Neo model, (de)serializable with serde. The same
/// configuration reference is handed to every `GptNeoBlock` built by `GptNeoModel::new`,
/// so several fields are consumed outside of this file.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GptNeoConfig {
    /// Activation function selector (consumed by the decoder blocks).
    pub activation_function: Activation,
    /// Presumably the dropout probability applied inside attention — consumed by the decoder blocks.
    pub attention_dropout: f64,
    /// Attention type of each layer; the default holds 24 entries alternating `Global` / `Local`.
    pub attention_layers: Vec<AttentionLayerType>,
    /// `(pattern, count)` pairs; presumably a compact form of `attention_layers` — TODO confirm.
    pub attention_types: Vec<(Vec<AttentionLayerType>, i64)>,
    /// Optional intermediate size (presumably of the feed-forward layers; consumed by the decoder blocks).
    pub intermediate_size: Option<i64>,
    /// Beginning-of-sequence token id.
    pub bos_token_id: i64,
    /// End-of-sequence token id.
    pub eos_token_id: i64,
    /// Optional forced beginning-of-sequence token id (not read in this file).
    pub forced_bos_token_id: Option<i64>,
    /// Optional forced end-of-sequence token id (not read in this file).
    pub forced_eos_token_id: Option<i64>,
    /// Number of rows of the word embedding table (`wte`).
    pub vocab_size: i64,
    /// Number of decoder blocks created by `GptNeoModel::new`.
    pub num_layers: i64,
    /// Presumably the number of attention heads — consumed by the decoder blocks.
    pub num_heads: i64,
    /// Embedding dimension; also the size normalized by the final layer norm.
    pub hidden_size: i64,
    /// Presumably the window of the local attention layers — consumed by the decoder blocks.
    pub window_size: i64,
    /// Dropout probability applied to the summed input embeddings.
    pub embed_dropout: f64,
    /// Initializer range (not read in this file).
    pub initializer_range: f64,
    /// Epsilon of the final layer normalization.
    pub layer_norm_epsilon: f64,
    /// Number of rows of the position embedding table (`wpe`).
    pub max_position_embeddings: i64,
    /// Optional flag (not read in this file).
    pub output_past: Option<bool>,
    /// When `Some(true)`, `GptNeoModel::forward_t` collects per-layer attention weights (`None` means `false`).
    pub output_attentions: Option<bool>,
    /// When `Some(true)`, `GptNeoModel::forward_t` collects per-layer hidden states (`None` means `false`).
    pub output_hidden_states: Option<bool>,
    /// Presumably the dropout probability on residual connections — consumed by the decoder blocks.
    pub resid_dropout: f64,
    /// Optional decoder start token id, exposed by `GptNeoGenerator::get_decoder_start_id`.
    pub decoder_start_token_id: Option<i64>,
}
147
// Marker implementation relying on the trait's provided methods; presumably the source
// of `GptNeoConfig::from_file` used by the generator — confirm against `crate::Config`.
impl Config for GptNeoConfig {}
149
150impl Default for GptNeoConfig {
151 fn default() -> Self {
152 GptNeoConfig {
153 activation_function: Activation::gelu_new,
154 attention_dropout: 0.0,
155 attention_layers: [AttentionLayerType::Global, AttentionLayerType::Local]
156 .iter()
157 .cycle()
158 .take(24)
159 .map(|layer_type| layer_type.to_owned())
160 .collect::<Vec<AttentionLayerType>>(),
161 attention_types: vec![(
162 vec![AttentionLayerType::Global, AttentionLayerType::Local],
163 12,
164 )],
165 intermediate_size: None,
166 bos_token_id: 50256,
167 eos_token_id: 50256,
168 forced_bos_token_id: None,
169 forced_eos_token_id: None,
170 vocab_size: 50257,
171 num_layers: 24,
172 num_heads: 16,
173 hidden_size: 2048,
174 window_size: 256,
175 embed_dropout: 0.0,
176 initializer_range: 0.02,
177 layer_norm_epsilon: 1e-5,
178 max_position_embeddings: 2048,
179 output_past: None,
180 output_attentions: None,
181 output_hidden_states: None,
182 resid_dropout: 0.0,
183 decoder_start_token_id: None,
184 }
185 }
186}
187
/// # GPT-Neo base model
///
/// Embeddings, a stack of `GptNeoBlock` layers and a final layer normalization.
pub struct GptNeoModel {
    /// Token embedding table (registered under `wte`).
    word_embeddings: nn::Embedding,
    /// Position embedding table (registered under `wpe`).
    position_embeddings: nn::Embedding,
    /// Decoder blocks, applied in order (registered under `h/<index>`).
    layers: Vec<GptNeoBlock>,
    /// Dropout applied to the summed input embeddings.
    dropout: Dropout,
    /// Final layer normalization (registered under `ln_f`).
    layer_norm: nn::LayerNorm,
    /// Whether `forward_t` collects the attention weights of every layer.
    output_attentions: bool,
    /// Whether `forward_t` collects the hidden state produced by every layer.
    output_hidden_states: bool,
}
203
204impl GptNeoModel {
205 pub fn new<'p, P>(p: P, config: &GptNeoConfig) -> Result<GptNeoModel, RustBertError>
227 where
228 P: Borrow<nn::Path<'p>>,
229 {
230 let p = p.borrow();
231
232 let word_embeddings = nn::embedding(
233 p / "wte",
234 config.vocab_size,
235 config.hidden_size,
236 Default::default(),
237 );
238
239 let position_embeddings = nn::embedding(
240 p / "wpe",
241 config.max_position_embeddings,
242 config.hidden_size,
243 Default::default(),
244 );
245
246 let dropout = Dropout::new(config.embed_dropout);
247
248 let layer_norm_config = nn::LayerNormConfig {
249 eps: config.layer_norm_epsilon,
250 ..Default::default()
251 };
252
253 let layer_norm = nn::layer_norm(p / "ln_f", vec![config.hidden_size], layer_norm_config);
254
255 let mut layers: Vec<GptNeoBlock> = Vec::with_capacity(config.num_layers as usize);
256 let p_layers = p / "h";
257 for layer_index in 0..config.num_layers {
258 layers.push(GptNeoBlock::new(
259 &p_layers / layer_index,
260 layer_index as usize,
261 config,
262 ));
263 }
264
265 let output_attentions = config.output_attentions.unwrap_or(false);
266 let output_hidden_states = config.output_hidden_states.unwrap_or(false);
267
268 Ok(GptNeoModel {
269 word_embeddings,
270 position_embeddings,
271 layers,
272 dropout,
273 layer_norm,
274 output_attentions,
275 output_hidden_states,
276 })
277 }
278
279 pub fn forward_t(
330 &self,
331 input_ids: Option<&Tensor>,
332 input_embeds: Option<&Tensor>,
333 token_type_ids: Option<&Tensor>,
334 position_ids: Option<&Tensor>,
335 layer_states: Option<Vec<Option<LayerState>>>,
336 attention_mask: Option<&Tensor>,
337 train: bool,
338 ) -> Result<GptNeoModelOutput, RustBertError> {
339 let (calc_input_embeddings, input_shape, device) =
340 process_ids_embeddings_pair(input_ids, input_embeds, &self.word_embeddings)?;
341
342 let (batch_size, current_sequence_length) = (input_shape[0], input_shape[1]);
343
344 let past_length = if let Some(past_state_value) = &layer_states {
345 if let Some(first_layer_state) = &past_state_value[0] {
346 let mut size_iter = first_layer_state.prev_key.size().into_iter().rev();
347 size_iter.next();
348 size_iter.next().unwrap()
349 } else {
350 0
351 }
352 } else {
353 0
354 };
355
356 let full_sequence_length = current_sequence_length + past_length;
357
358 let calc_position_ids = if position_ids.is_none() {
359 let position_ids =
360 Tensor::arange_start(past_length, full_sequence_length, (Kind::Int64, device));
361 Some(
362 position_ids
363 .unsqueeze(0)
364 .view([-1, current_sequence_length]),
365 )
366 } else {
367 None
368 };
369
370 let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());
371
372 let input_embeds = input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
373 let position_embeds = position_ids.apply(&self.position_embeddings);
374
375 let attention_mask = attention_mask.map(|attention_mask_value| {
376 let attention_mask = attention_mask_value
377 .view([batch_size, -1])
378 .unsqueeze(1)
379 .unsqueeze(1);
380 let attention_mask = attention_mask.to_kind(position_embeds.kind());
381 (1 - attention_mask) * -1e4
382 });
383
384 let mut hidden_state = input_embeds + position_embeds;
385 if let Some(token_type_ids) = token_type_ids {
386 hidden_state = hidden_state + token_type_ids.apply(&self.word_embeddings);
387 };
388 hidden_state = hidden_state.apply_t(&self.dropout, train);
389 let mut output_shape = input_shape;
390 output_shape.push(*hidden_state.size().last().unwrap());
391
392 let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
393 Some(vec![])
394 } else {
395 None
396 };
397 let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
398 Some(vec![])
399 } else {
400 None
401 };
402 let old_cache = layer_states.unwrap_or_else(|| vec![None; self.layers.len()]);
403 let mut next_cache = vec![None; self.layers.len()];
404
405 let mut x: Option<Tensor> = None;
406 let mut attention_weights: Option<Tensor>;
407
408 for ((layer_idx, layer), layer_state) in
409 self.layers.iter().enumerate().zip(old_cache.into_iter())
410 {
411 let temp = if let Some(x_value) = &x {
412 layer.forward_t(
413 x_value,
414 layer_state.as_ref(),
415 attention_mask.as_ref(),
416 train,
417 )?
418 } else {
419 layer.forward_t(
420 &hidden_state,
421 layer_state.as_ref(),
422 attention_mask.as_ref(),
423 train,
424 )?
425 };
426 x = Some(temp.0);
427 attention_weights = temp.1;
428 next_cache[layer_idx] = temp.2;
429 if let Some(attentions) = all_attentions.borrow_mut() {
430 attentions.push(std::mem::take(&mut attention_weights.unwrap()));
431 };
432 if let Some(hidden_states) = all_hidden_states.borrow_mut() {
433 hidden_states.push(x.as_ref().unwrap().copy());
434 };
435 }
436
437 let hidden_states = x
438 .unwrap()
439 .apply(&self.layer_norm)
440 .view(output_shape.as_slice());
441
442 Ok(GptNeoModelOutput {
443 hidden_states,
444 next_cache: Some(next_cache),
445 all_hidden_states,
446 all_attentions,
447 })
448 }
449}
450
/// # GPT-Neo model with a language modeling head
///
/// Wraps a `GptNeoModel`; the language modeling logits are computed from the base
/// model's word embedding weights.
pub struct GptNeoForCausalLM {
    /// Base GPT-Neo model (registered under `transformer`).
    transformer: GptNeoModel,
}
458
459impl GptNeoForCausalLM {
460 pub fn new<'p, P>(p: P, config: &GptNeoConfig) -> Result<GptNeoForCausalLM, RustBertError>
482 where
483 P: Borrow<nn::Path<'p>>,
484 {
485 let p = p.borrow();
486
487 let transformer = GptNeoModel::new(p / "transformer", config)?;
488
489 Ok(GptNeoForCausalLM { transformer })
490 }
491
492 pub fn forward_t(
543 &self,
544 input_ids: Option<&Tensor>,
545 input_embeds: Option<&Tensor>,
546 token_type_ids: Option<&Tensor>,
547 position_ids: Option<&Tensor>,
548 layer_states: Option<Vec<Option<LayerState>>>,
549 attention_mask: Option<&Tensor>,
550 train: bool,
551 ) -> Result<GptNeoModelLMOutput, RustBertError> {
552 let base_model_output = self.transformer.forward_t(
553 input_ids,
554 input_embeds,
555 token_type_ids,
556 position_ids,
557 layer_states,
558 attention_mask,
559 train,
560 )?;
561
562 let lm_logits = base_model_output
563 .hidden_states
564 .linear::<Tensor>(&self.transformer.word_embeddings.ws, None);
565
566 Ok(GptNeoModelLMOutput {
567 lm_logits,
568 next_cache: base_model_output.next_cache,
569 all_hidden_states: base_model_output.all_hidden_states,
570 all_attentions: base_model_output.all_attentions,
571 })
572 }
573}
574
/// Container for the output of `GptNeoModel::forward_t`.
pub struct GptNeoModelOutput {
    /// Hidden states after the final layer normalization, viewed as the input shape
    /// with the hidden dimension appended.
    pub hidden_states: Tensor,
    /// Updated cache with one entry per layer.
    pub next_cache: Option<Vec<Option<LayerState>>>,
    /// Hidden state produced by every layer, when enabled in the configuration.
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights of every layer, when enabled in the configuration.
    pub all_attentions: Option<Vec<Tensor>>,
}
586
/// Container for the output of `GptNeoForCausalLM::forward_t`.
pub struct GptNeoModelLMOutput {
    /// Logits obtained by projecting the hidden states with the word embedding weights.
    pub lm_logits: Tensor,
    /// Updated cache, passed through from the base model output.
    pub next_cache: Option<Vec<Option<LayerState>>>,
    /// Per-layer hidden states, passed through from the base model output.
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Per-layer attention weights, passed through from the base model output.
    pub all_attentions: Option<Vec<Tensor>>,
}
598
/// # Language generation model based on the GPT-Neo architecture
///
/// Bundles the model, its tokenizer, the variable store holding the weights and the
/// values exposed through the `PrivateLanguageGenerator` accessors.
pub struct GptNeoGenerator {
    /// GPT-Neo model with language modeling head.
    model: GptNeoForCausalLM,
    /// Tokenizer used by the generator.
    tokenizer: TokenizerOption,
    /// Variable store the model weights are loaded into; also defines the device.
    var_store: nn::VarStore,
    /// Generation configuration supplied at construction.
    generate_config: GenerateConfig,
    /// Beginning-of-sequence token id reported by the tokenizer.
    bos_token_id: Option<i64>,
    /// End-of-sequence token id reported by the tokenizer, wrapped in a vector.
    eos_token_ids: Option<Vec<i64>>,
    /// Padding token id reported by the tokenizer.
    pad_token_id: Option<i64>,
    /// Always `false` for this generator.
    is_encoder_decoder: bool,
    /// Vocabulary size read from the model configuration.
    vocab_size: i64,
    /// Decoder start token id read from the model configuration.
    decoder_start_id: Option<i64>,
    /// Maximum number of position embeddings read from the model configuration.
    max_position_embeddings: i64,
}
613
614impl GptNeoGenerator {
615 pub fn new(generate_config: GenerateConfig) -> Result<GptNeoGenerator, RustBertError> {
641 let vocab_path = generate_config.vocab_resource.get_local_path()?;
642 let merges_path = generate_config
643 .merges_resource
644 .as_ref()
645 .ok_or_else(|| {
646 RustBertError::InvalidConfigurationError(
647 "GPT-Neo expects a merges resources to be provided".to_string(),
648 )
649 })?
650 .get_local_path()?;
651
652 let tokenizer = TokenizerOption::from_file(
653 ModelType::GPTNeo,
654 vocab_path.to_str().unwrap(),
655 Some(merges_path.to_str().unwrap()),
656 false,
657 None,
658 None,
659 )?;
660
661 Self::new_with_tokenizer(generate_config, tokenizer)
662 }
663
664 pub fn new_with_tokenizer(
665 generate_config: GenerateConfig,
666 tokenizer: TokenizerOption,
667 ) -> Result<GptNeoGenerator, RustBertError> {
668 let config_path = generate_config.config_resource.get_local_path()?;
669 let device = generate_config.device;
670
671 generate_config.validate();
672 let mut var_store = nn::VarStore::new(device);
673 let config = GptNeoConfig::from_file(config_path);
674 let model = GptNeoForCausalLM::new(var_store.root(), &config)?;
675 crate::resources::load_weights(
676 &generate_config.model_resource,
677 &mut var_store,
678 generate_config.kind,
679 device,
680 )?;
681
682 let bos_token_id = tokenizer.get_bos_id();
683 let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
684 let pad_token_id = tokenizer.get_pad_id();
685 let is_encoder_decoder = false;
686 let vocab_size = config.vocab_size;
687 let decoder_start_id = config.decoder_start_token_id;
688 let max_position_embeddings = config.max_position_embeddings;
689
690 Ok(GptNeoGenerator {
691 model,
692 tokenizer,
693 var_store,
694 generate_config,
695 bos_token_id,
696 eos_token_ids,
697 pad_token_id,
698 is_encoder_decoder,
699 vocab_size,
700 decoder_start_id,
701 max_position_embeddings,
702 })
703 }
704}
705
706impl PrivateLanguageGenerator for GptNeoGenerator {
707 fn _get_tokenizer(&self) -> &TokenizerOption {
708 &self.tokenizer
709 }
710 fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
711 &mut self.tokenizer
712 }
713 fn get_device(&self) -> Device {
714 self.var_store.device()
715 }
716 fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
717 Ok(&mut self.var_store)
718 }
719 fn get_config(&self) -> &GenerateConfig {
720 &self.generate_config
721 }
722 fn get_bos_id(&self) -> Option<i64> {
723 self.bos_token_id
724 }
725 fn get_eos_ids(&self) -> Option<&Vec<i64>> {
726 self.eos_token_ids.as_ref()
727 }
728 fn get_pad_id(&self) -> Option<i64> {
729 self.pad_token_id
730 }
731 fn is_encoder_decoder(&self) -> bool {
732 self.is_encoder_decoder
733 }
734 fn get_vocab_size(&self) -> i64 {
735 self.vocab_size
736 }
737 fn get_decoder_start_id(&self) -> Option<i64> {
738 self.decoder_start_id
739 }
740
741 fn get_max_positions_embeddings(&self) -> Option<i64> {
742 Some(self.max_position_embeddings)
743 }
744
745 fn forward_t(
746 &self,
747 input_ids: Option<&Tensor>,
748 layer_past: Cache,
749 attention_mask: Option<&Tensor>,
750 token_type_ids: Option<&Tensor>,
751 position_ids: Option<&Tensor>,
752 input_embeds: Option<&Tensor>,
753 _encoder_outputs: Option<&Tensor>,
754 _decoder_input_ids: Option<&Tensor>,
755 train: bool,
756 ) -> Result<LMModelOutput, RustBertError> {
757 let base_model_output = match layer_past {
758 Cache::GPTNeoCache(layer_past) => self.model.forward_t(
759 input_ids,
760 input_embeds,
761 token_type_ids,
762 position_ids,
763 layer_past,
764 attention_mask,
765 train,
766 ),
767 Cache::None => self.model.forward_t(
768 input_ids,
769 input_embeds,
770 token_type_ids,
771 position_ids,
772 None,
773 attention_mask,
774 train,
775 ),
776 _ => {
777 return Err(RustBertError::ValueError(
778 "Cache not compatible with GPT-Neo Model".into(),
779 ));
780 }
781 }?;
782
783 Ok(LMModelOutput {
784 lm_logits: base_model_output.lm_logits,
785 cache: Cache::GPTNeoCache(base_model_output.next_cache),
786 })
787 }
788 fn prepare_inputs_for_generation<'a>(
789 &self,
790 input_ids: Tensor,
791 _encoder_outputs: Option<&'a Tensor>,
792 past: Cache,
793 attention_mask: Tensor,
794 ) -> PreparedInput<'a> {
795 let position_ids = (attention_mask.totype(Kind::Int64).cumsum(-1, Kind::Int64) - 1)
796 .masked_fill(&attention_mask.eq(0), 1);
797
798 match past {
799 Cache::GPTNeoCache(past) => {
800 if past.is_some() {
801 PreparedInput {
802 prepared_input: Some(input_ids.select(1, -1).unsqueeze(-1)),
803 prepared_attention_mask: Some(attention_mask),
804 prepared_encoder_output: None,
805 prepared_decoder_input: None,
806 prepared_position_ids: Some(position_ids.select(1, -1).unsqueeze(-1)),
807 prepared_past: Cache::GPTNeoCache(past),
808 }
809 } else {
810 PreparedInput {
811 prepared_input: Some(input_ids),
812 prepared_attention_mask: Some(attention_mask),
813 prepared_encoder_output: None,
814 prepared_decoder_input: None,
815 prepared_position_ids: Some(position_ids),
816 prepared_past: Cache::GPTNeoCache(None),
817 }
818 }
819 }
820 Cache::None => PreparedInput {
821 prepared_input: Some(input_ids),
822 prepared_attention_mask: Some(attention_mask),
823 prepared_encoder_output: None,
824 prepared_decoder_input: None,
825 prepared_position_ids: Some(position_ids),
826 prepared_past: Cache::GPTNeoCache(None),
827 },
828 _ => panic!("Cache type incompatible with GPT-Neo"),
829 }
830 }
831
832 fn reorder_cache(
833 &self,
834 past: &mut Cache,
835 _encoder_outputs: Option<Tensor>,
836 beam_indices: &Tensor,
837 ) -> Option<Tensor> {
838 match past {
839 Cache::GPTNeoCache(cached_decoder_state) => match cached_decoder_state {
840 Some(old_cache) => {
841 for layer_state in old_cache.iter_mut() {
842 if layer_state.is_some() {
843 layer_state.as_mut().unwrap().reorder_cache(beam_indices)
844 };
845 }
846 None
847 }
848 None => None,
849 },
850 Cache::None => None,
851 _ => {
852 panic!("Invalid cache for GPT-Neo model");
853 }
854 }
855 }
856}
857
// No overrides: the generator relies entirely on the methods provided by the
// `LanguageGenerator` trait itself.
impl LanguageGenerator for GptNeoGenerator {}