use tch::kind::Kind::Int64;
use tch::{no_grad, Device, Kind, Tensor};

use crate::bart::LayerState as BartLayerState;
use crate::common::resources::ResourceProvider;
use crate::gpt_j::LayerState as GPTJLayerState;
use crate::gpt_neo::LayerState as GPTNeoLayerState;
use crate::pipelines::generation_utils::private_generation_utils::{
    InternalGenerateOptions, PrivateLanguageGenerator,
};
use crate::prophetnet::LayerState as ProphetNetLayerState;
use crate::reformer::LayerState as ReformerLayerState;
use crate::t5::LayerState as T5LayerState;
use crate::xlnet::LayerState as XLNetLayerState;

use self::ordered_float::OrderedFloat;
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};

extern crate ordered_float;
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::ONNXLayerCache;
use crate::RustBertError;
#[cfg(feature = "remote")]
use crate::{
    gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
    resources::RemoteResource,
};

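/// Configuration for text generation, holding the model resources and the decoding
/// hyper-parameters shared by all generation pipelines (sampling, beam search and
/// diverse beam search).
///
/// The example below is a minimal usage sketch, assuming the `remote` feature is enabled
/// (it is required for the GPT-2 based `Default` implementation shown here).
/// ```no_run
/// use rust_bert::pipelines::generation_utils::GenerateConfig;
///
/// let generate_config = GenerateConfig {
///     max_length: Some(30),
///     do_sample: true,
///     num_beams: 5,
///     temperature: 1.1,
///     num_return_sequences: 3,
///     ..Default::default()
/// };
/// ```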
pub struct GenerateConfig {
    pub model_type: ModelType,
    pub model_resource: ModelResource,
    pub config_resource: Box<dyn ResourceProvider + Send>,
    pub vocab_resource: Box<dyn ResourceProvider + Send>,
    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
    pub min_length: i64,
    pub max_length: Option<i64>,
    pub do_sample: bool,
    pub early_stopping: bool,
    pub num_beams: i64,
    pub temperature: f64,
    pub top_k: i64,
    pub top_p: f64,
    pub repetition_penalty: f64,
    pub length_penalty: f64,
    pub no_repeat_ngram_size: i64,
    pub num_return_sequences: i64,
    pub num_beam_groups: Option<i64>,
    pub diversity_penalty: Option<f64>,
    pub device: Device,
    pub kind: Option<Kind>,
}

#[cfg(feature = "remote")]
impl Default for GenerateConfig {
    fn default() -> GenerateConfig {
        GenerateConfig {
            model_type: ModelType::GPT2,
            model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
                Gpt2ModelResources::GPT2,
            ))),
            config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
            vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
            merges_resource: Some(Box::new(RemoteResource::from_pretrained(
                Gpt2MergesResources::GPT2,
            ))),
            min_length: 0,
            max_length: Some(56),
            do_sample: true,
            early_stopping: true,
            num_beams: 5,
            temperature: 1.0,
            top_k: 0,
            top_p: 0.9,
            repetition_penalty: 1.0,
            length_penalty: 1.0,
            no_repeat_ngram_size: 3,
            num_return_sequences: 1,
            num_beam_groups: None,
            diversity_penalty: None,
            device: Device::cuda_if_available(),
            kind: None,
        }
    }
}

impl GenerateConfig {
    pub(crate) fn validate(&self) {
        assert!(self.temperature > 0f64, "temperature must be positive");
        assert!(
            (self.top_p >= 0f64) & (self.top_p <= 1f64),
            "top_p must be between 0 and 1"
        );
        assert!(
            self.repetition_penalty >= 1f64,
            "repetition_penalty must be greater than or equal to 1"
        );
        assert!(
            self.length_penalty > 0f64,
            "length_penalty must be strictly greater than 0"
        );
        assert!(
            self.num_return_sequences > 0i64,
            "num_return_sequences must be strictly greater than 0"
        );
        assert!(
            self.num_beams > 0i64,
            "num_beams must be strictly greater than 0"
        );

        if !self.do_sample {
            if self.num_beams == 1 {
                assert_eq!(
                    self.num_return_sequences, 1,
                    "num_return_sequences must be set to 1 for greedy decoding"
                )
            } else {
                assert!(
                    self.num_beams >= self.num_return_sequences,
                    "num_return_sequences must be lower than or equal to the number of beams"
                )
            }
        }
        if let Some(num_beam_groups_value) = self.num_beam_groups {
            if num_beam_groups_value > 1 {
                assert_eq!(
                    self.num_beams % num_beam_groups_value,
                    0,
                    "num_beams must be a multiple of num_beam_groups"
                )
            }
        }
    }
}

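/// Cache for past keys and values, used to speed up auto-regressive decoding.
/// Each variant wraps the layer state type of the corresponding model family.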
#[derive(Debug)]
pub enum Cache {
    GPT2Cache(Option<Vec<Tensor>>),
    BARTCache(Option<Vec<(Option<BartLayerState>, Option<BartLayerState>)>>),
    T5Cache(Option<Vec<(Option<T5LayerState>, Option<T5LayerState>)>>),
    LongT5Cache(Option<Vec<(Option<T5LayerState>, Option<T5LayerState>)>>),
    XLNetCache(Option<Vec<Option<XLNetLayerState>>>),
    ReformerCache(Option<Vec<Option<ReformerLayerState>>>),
    ProphetNetCache(Option<Vec<(Option<ProphetNetLayerState>, Option<ProphetNetLayerState>)>>),
    GPTNeoCache(Option<Vec<Option<GPTNeoLayerState>>>),
    GPTJCache(Option<Vec<Option<GPTJLayerState>>>),
    #[cfg(feature = "onnx")]
    ONNXCache(ONNXLayerCache),
    None,
}

241pub(crate) mod private_generation_utils {
242 use rust_tokenizers::TokenIdsWithOffsets;
243 use std::cmp::{max, min};
244 use std::collections::HashMap;
245 use std::convert::TryFrom;
246 use std::mem;
247
248 use rust_tokenizers::tokenizer::{truncate_sequences, TruncationStrategy};
249 use tch::{nn, Device, Kind, Tensor};
250
251 use crate::pipelines::common::TokenizerOption;
252 use crate::pipelines::generation_utils::{
253 BeamHypotheses, Cache, GenerateConfig, LMModelOutput, PrefixAllowedFunction,
254 };
255
256 use super::ordered_float::OrderedFloat;
257 use crate::common::kind::{get_negative_infinity, get_positive_infinity};
258 use crate::RustBertError;
259
260 pub struct InternalGenerateOptions<'a> {
261 pub min_length: i64,
262 pub max_length: Option<i64>,
263 pub do_sample: bool,
264 pub temperature: f64,
265 pub top_k: i64,
266 pub top_p: f64,
267 pub repetition_penalty: f64,
268 pub no_repeat_ngram_size: i64,
269 pub pad_token_id: Option<i64>,
270 pub eos_token_ids: Option<Vec<i64>>,
271 pub num_return_sequences: i64,
272 pub early_stopping: bool,
273 pub num_beams: i64,
274 pub length_penalty: f64,
275 pub num_beam_groups: Option<i64>,
276 pub diversity_penalty: Option<f64>,
277 pub forced_bos_token_id: Option<i64>,
278 pub bad_word_ids: Option<&'a Vec<Vec<i64>>>,
279 }
280
281 pub struct PreparedInput<'a> {
282 pub prepared_input: Option<Tensor>,
283 pub prepared_attention_mask: Option<Tensor>,
284 pub prepared_encoder_output: Option<&'a Tensor>,
285 pub prepared_decoder_input: Option<Tensor>,
286 pub prepared_position_ids: Option<Tensor>,
287 pub prepared_past: Cache,
288 }
289
290 pub struct GeneratedOutputWithScores {
291 pub indices: Tensor,
292 pub scores: Option<Vec<f64>>,
293 pub token_scores: Option<Vec<Vec<f64>>>,
294 }
295
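    /// Internal trait providing the decoding machinery (input preparation, logits
    /// post-processing, greedy/sampling and beam search loops) shared by all generators.
    /// It is not meant to be used directly; the public entry points live in `LanguageGenerator`.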
296 pub trait PrivateLanguageGenerator {
297 fn _get_tokenizer(&self) -> &TokenizerOption;
298 fn get_device(&self) -> Device;
299 fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError>;
300 fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption;
301 fn get_config(&self) -> &GenerateConfig;
302 fn get_bos_id(&self) -> Option<i64>;
303 fn get_eos_ids(&self) -> Option<&Vec<i64>>;
304 fn get_forced_bos_token_id(&self) -> Option<i64> {
305 None
306 }
307 fn get_forced_eos_token_id(&self) -> Option<i64> {
308 None
309 }
310 fn get_pad_id(&self) -> Option<i64>;
311 fn is_encoder_decoder(&self) -> bool;
312 fn get_vocab_size(&self) -> i64;
313 fn get_decoder_start_id(&self) -> Option<i64>;
314 fn get_max_positions_embeddings(&self) -> Option<i64>;
315
316 fn forward_t(
317 &self,
318 input_ids: Option<&Tensor>,
319 layer_past: Cache,
320 attention_mask: Option<&Tensor>,
321 token_type_ids: Option<&Tensor>,
322 position_ids: Option<&Tensor>,
323 input_embeds: Option<&Tensor>,
324 encoder_outputs: Option<&Tensor>,
325 decoder_input_ids: Option<&Tensor>,
326 train: bool,
327 ) -> Result<LMModelOutput, RustBertError>;
328
329 fn prepare_scores_for_generation(
330 &self,
331 scores: &mut Tensor,
332 current_length: i64,
333 max_length: Option<i64>,
334 forced_bos_token_id: Option<i64>,
335 ) {
336 if current_length == 1 {
337 if let Some(forced_bos_token_id) =
338 forced_bos_token_id.or(self.get_forced_bos_token_id())
339 {
340 force_token_id_generation(
341 scores,
342 &[forced_bos_token_id],
343 self.get_vocab_size(),
344 );
345 }
346 } else if let Some(max_length) = max_length {
347 if let Some(forced_eos_token_id) = self.get_forced_eos_token_id() {
348 if current_length == max_length - 1 {
349 force_token_id_generation(
350 scores,
351 &[forced_eos_token_id],
352 self.get_vocab_size(),
353 );
354 }
355 }
356 }
357 }
358
359 fn encode(&self, _input_ids: &Tensor, _attention_mask: Option<&Tensor>) -> Option<Tensor> {
360 None
361 }
362
363 fn prepare_inputs_for_generation<'a>(
364 &self,
365 input_ids: Tensor,
366 _encoder_outputs: Option<&'a Tensor>,
367 past: Cache,
368 attention_mask: Tensor,
369 ) -> PreparedInput<'a> {
370 PreparedInput {
371 prepared_input: Some(input_ids),
372 prepared_attention_mask: Some(attention_mask),
373 prepared_encoder_output: None,
374 prepared_decoder_input: None,
375 prepared_position_ids: None,
376 prepared_past: past,
377 }
378 }
379
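        /// Tokenizes the prompt texts and pads them to a common length: padding is appended
        /// for encoder-decoder models and prepended (left padding) for decoder-only models.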
380 fn encode_prompt_text<S>(
381 &self,
382 prompt_text: &[S],
383 max_len: Option<i64>,
384 pad_token_id: Option<i64>,
385 ) -> Tensor
386 where
387 S: AsRef<str> + Send + Sync,
388 {
389 let token_ids = if self.is_encoder_decoder() {
390 let tokens = self._get_tokenizer().encode_list(
391 prompt_text,
392 max_len
393 .map(|max_len| max_len as usize)
394 .unwrap_or(usize::MAX),
395 &TruncationStrategy::LongestFirst,
396 0,
397 );
398 tokens
399 .into_iter()
400 .map(|tokenized_input| tokenized_input.token_ids)
401 .collect::<Vec<Vec<i64>>>()
402 } else {
403 let tokens = self._get_tokenizer().tokenize_list(prompt_text);
405 let token_ids = tokens
406 .into_iter()
407 .map(|prompt_tokens| {
408 self._get_tokenizer().convert_tokens_to_ids(&prompt_tokens)
409 })
410 .collect::<Vec<Vec<i64>>>();
411
412 let num_truncated_tokens = token_ids
413 .iter()
414 .map(|token_ids| {
415 max_len
416 .map(|max_len| {
417 if token_ids.len() > max_len as usize {
418 token_ids.len() - max_len as usize
419 } else {
420 0
421 }
422 })
423 .unwrap_or(0)
424 })
425 .collect::<Vec<usize>>();
426
427 token_ids
428 .into_iter()
429 .zip(num_truncated_tokens)
430 .map(|(tokens, num_truncated_tokens)| {
431 truncate_sequences(
432 TokenIdsWithOffsets {
433 ids: tokens,
434 offsets: vec![],
435 reference_offsets: vec![],
436 masks: vec![],
437 },
438 None,
439 num_truncated_tokens,
440 &TruncationStrategy::LongestFirst,
441 0,
442 )
443 .unwrap()
444 .0
445 .ids
446 })
447 .collect::<Vec<Vec<i64>>>()
448 };
449
450 let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
451
452 let pad_token = match pad_token_id {
453 Some(value) => value,
454 None => self._get_tokenizer().get_unk_id(),
455 };
456
457 let token_ids = token_ids
458 .into_iter()
459 .map(|mut input| {
460 let mut temp = vec![pad_token; max_len - input.len()];
461 if self.is_encoder_decoder() {
462 input.extend(temp);
463 input
464 } else {
465 temp.extend(input);
467 temp
468 }
469 })
470 .map(|tokens| Tensor::from_slice(&tokens).to(self.get_device()))
471 .collect::<Vec<Tensor>>();
472
473 Tensor::stack(&token_ids, 0)
474 }
475
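        /// Applies a CTRL-style repetition penalty: logits of previously generated tokens are
        /// multiplied by the penalty when negative and divided by it when positive.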
476 fn enforce_repetition_penalty(
477 &self,
478 next_token_logits: &mut Tensor,
479 batch_size: i64,
480 num_beams: i64,
481 prev_output_tokens: &Tensor,
482 repetition_penalty: f64,
483 ) {
484 for i in 0..(batch_size * num_beams) {
485 for token_position in 0..prev_output_tokens.get(i).size()[0] {
486 let token = prev_output_tokens.get(i).int64_value(&[token_position]);
487 let updated_value = &next_token_logits.double_value(&[i, token]);
488 if updated_value < &0f64 {
489 let _ = next_token_logits.get(i).index_fill_(
490 0,
491 &Tensor::from_slice(&[token])
492 .to_kind(Kind::Int64)
493 .to_device(next_token_logits.device()),
494 updated_value * repetition_penalty,
495 );
496 } else {
497 let _ = next_token_logits.get(i).index_fill_(
498 0,
499 &Tensor::from_slice(&[token])
500 .to_kind(Kind::Int64)
501 .to_device(next_token_logits.device()),
502 updated_value / repetition_penalty,
503 );
504 }
505 }
506 }
507 }
508
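        /// Returns, for each hypothesis, the tokens that would complete an n-gram already
        /// present in the generated sequence (used for `no_repeat_ngram_size`).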
509 fn get_banned_tokens(
510 &self,
511 input_ids: &Tensor,
512 no_repeat_ngram_size: i64,
513 cur_len: i64,
514 ) -> Vec<Vec<i64>> {
515 if cur_len + 1 < no_repeat_ngram_size {
517 vec![vec![]]
518 } else {
519 let input_ids = input_ids.to(Device::Cpu);
520 let num_hypothesis = *input_ids.size().first().unwrap();
521 let mut banned_tokens: Vec<Vec<i64>> = Vec::with_capacity(num_hypothesis as usize);
522 for hypothesis_index in 0..num_hypothesis {
523 let hypothesis_input_ids = input_ids.get(hypothesis_index);
524 let mut generated_ngram: HashMap<Vec<i64>, Vec<i64>> = HashMap::new();
525 let input: Vec<i64> = (0..hypothesis_input_ids.size1().unwrap()).collect();
526 let hypothesis_input_ids = hypothesis_input_ids
527 .iter::<i64>()
528 .unwrap()
529 .collect::<Vec<i64>>();
530 let query = &hypothesis_input_ids
531 [cur_len as usize + 1 - no_repeat_ngram_size as usize..]
532 .to_vec();
533 for ngram in input
534 .windows(no_repeat_ngram_size as usize)
535 .map(|win| (*win.first().unwrap(), *win.last().unwrap()))
536 {
537 let ngram = &hypothesis_input_ids[ngram.0 as usize..ngram.1 as usize + 1];
538 let key = ngram[..no_repeat_ngram_size as usize - 1].to_vec();
539 let value = *ngram.last().unwrap();
540 generated_ngram
541 .entry(key)
542 .or_insert_with(|| vec![value])
543 .push(value);
544 }
545 let hypothesis_banned_tokens = match generated_ngram.get(query) {
546 Some(banned_tokens) => banned_tokens.clone(),
547 None => vec![],
548 };
549 banned_tokens.push(hypothesis_banned_tokens);
550 }
551 banned_tokens
552 }
553 }
554
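        /// In-place top-k and nucleus (top-p) filtering: logits outside the top-k tokens or
        /// beyond the cumulative probability mass `top_p` are set to negative infinity.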
555 fn top_k_top_p_filtering(
556 &self,
557 logits: &mut Tensor,
558 top_k: i64,
559 top_p: f64,
560 min_tokens_to_keep: i64,
561 ) {
562 let vocab_size = *logits.size().last().unwrap();
565 if top_k > 0 {
566 let top_k = vocab_size - min(max(top_k, min_tokens_to_keep), vocab_size);
567 let (_, indices_to_remove) = logits.topk(top_k, -1, false, false);
568 for index in 0..*logits.size().first().unwrap() {
569 let _ = logits.get(index).index_fill_(
570 0,
571 &indices_to_remove.get(index),
572 f64::NEG_INFINITY,
573 );
574 }
575 }
576 if top_p < 1f64 {
577 let (sorted_logits, sorted_indices) = logits.sort(-1, true);
578 let cumulative_probabilities = sorted_logits
579 .softmax(-1, sorted_logits.kind())
580 .cumsum(-1, sorted_logits.kind());
581 let mut sorted_indices_to_remove =
582 cumulative_probabilities.ge(top_p).to_kind(Kind::Int64);
583 if min_tokens_to_keep > 1 {
584 let _ = sorted_indices_to_remove.index_fill_(
585 1,
586 &Tensor::arange_start(
587 0,
588 min_tokens_to_keep + 1,
589 (Kind::Int64, logits.device()),
590 ),
591 0,
592 );
593 }
594 let _ = sorted_indices_to_remove.index_copy_(
595 1,
596 &Tensor::arange_start(1, vocab_size, (Kind::Int64, logits.device())),
597 &sorted_indices_to_remove
598 .slice(1, 0, vocab_size - 1, 1)
599 .copy(),
600 );
601 let _ = sorted_indices_to_remove.index_fill_(
602 1,
603 &Tensor::from_slice(&[0])
604 .to_kind(Kind::Int64)
605 .to_device(sorted_indices_to_remove.device()),
606 0,
607 );
608 let indices_to_remove = sorted_indices_to_remove
609 .scatter(1, &sorted_indices, &sorted_indices_to_remove)
610 .to_kind(Kind::Bool);
611 let _ = logits.masked_fill_(&indices_to_remove, f64::NEG_INFINITY);
612 }
613 }
614
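        /// Subtracts a Hamming diversity penalty from the scores of tokens already selected by
        /// previous beam groups at the current step (diverse beam search).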
615 fn run_hamming_diversity_penalty(
616 &self,
617 scores: &mut Tensor,
618 current_tokens: &Tensor,
619 diversity_penalty: f64,
620 num_beams: i64,
621 batch_size: i64,
622 group_size: i64,
623 group_start_index: i64,
624 ) {
625 if group_start_index > 0 {
626 let vocab_size = *scores.size().last().unwrap();
627 for batch_index in 0..batch_size {
628 let previous_group_tokens = current_tokens.slice(
629 0,
630 batch_index * num_beams,
631 batch_index * num_beams + group_start_index,
632 1,
633 );
634 let diversity_penalty = previous_group_tokens
635 .bincount::<Tensor>(None, vocab_size)
636 * diversity_penalty;
637 let _ = scores
638 .slice(
639 0,
640 batch_index * group_size,
641 (batch_index + 1) * group_size,
642 1,
643 )
644 .subtract_(&diversity_penalty);
645 }
646 }
647 }
648
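        /// Restricts generation to the tokens returned by a user-provided constraint function,
        /// pushing all other vocabulary entries to negative infinity.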
649 fn apply_prefix_allowed_tokens_function(
650 &self,
651 prefix_allowed_tokens_fn: &dyn Fn(i64, &Tensor) -> Vec<i64>,
652 num_beams: i64,
653 input_ids: &Tensor,
654 scores: &mut Tensor,
655 ) {
656 let mask = scores.new_full(
657 scores.size().as_slice(),
658 get_positive_infinity(scores.kind()).unwrap(),
659 (scores.kind(), scores.device()),
660 );
661 for idx in 0..scores.size()[0] {
662 let batch_id = idx / num_beams;
663 let allowed_tokens: Vec<i64> =
664 prefix_allowed_tokens_fn(batch_id, &input_ids.get(idx));
665 let _ = mask.get(idx).index_fill_(
666 0,
667 &Tensor::from_slice(allowed_tokens.as_slice()).to(scores.device()),
668 0,
669 );
670 }
671 let _ = scores.subtract_(&mask);
672 }
673
674 fn split_bad_word_ids<'a>(
675 &self,
676 bad_word_ids: Option<&'a Vec<Vec<i64>>>,
677 ) -> (Option<Vec<i64>>, Option<Vec<&'a Vec<i64>>>) {
678 if let Some(bad_word_ids) = bad_word_ids {
679 let mut bad_word_ids_length_1 = vec![];
680 let mut bad_word_ids_length_greater_than_1 = vec![];
681 for bad_word in bad_word_ids {
682 if bad_word.len() == 1 {
683 bad_word_ids_length_1.push(bad_word[0]);
684 } else {
685 bad_word_ids_length_greater_than_1.push(bad_word);
686 }
687 }
688 let bad_word_ids_length_1 = if !bad_word_ids_length_1.is_empty() {
689 Some(bad_word_ids_length_1)
690 } else {
691 None
692 };
693 let bad_word_ids_length_greater_than_1 =
694 if !bad_word_ids_length_greater_than_1.is_empty() {
695 Some(bad_word_ids_length_greater_than_1)
696 } else {
697 None
698 };
699 (bad_word_ids_length_1, bad_word_ids_length_greater_than_1)
700 } else {
701 (None, None)
702 }
703 }
704
705 fn tokens_match(&self, prev_tokens: &[i64], tokens: &[i64]) -> bool {
706 if tokens.is_empty() {
707 true
708 } else if tokens.len() > prev_tokens.len() {
709 false
710 } else {
711 &prev_tokens[prev_tokens.len() - tokens.len()..] == tokens
712 }
713 }
714
715 fn calc_static_bad_word_mask(
716 &self,
717 scores: &Tensor,
718 bad_words_id_length_1: &[i64],
719 ) -> Tensor {
720 let mut static_bad_words_mask =
721 Tensor::zeros([scores.size()[1]], (Kind::Int8, scores.device()));
722 let _ = static_bad_words_mask.index_fill_(
723 0,
724 &Tensor::from_slice(bad_words_id_length_1).to_device(scores.device()),
725 1,
726 );
727 static_bad_words_mask.unsqueeze(0).totype(Kind::Bool)
728 }
729
730 fn get_dynamic_bad_word_ids(
731 &self,
732 prev_tokens: &[Vec<i64>],
733 bad_word_ids_length_greater_than_1: &[&Vec<i64>],
734 ) -> Vec<Vec<i64>> {
735 let mut banned_tokens = Vec::new();
736 for prev_token_sequence in prev_tokens {
737 let mut sequence_banned_tokens = Vec::new();
738 for bad_word_ids in bad_word_ids_length_greater_than_1 {
739 if self
740 .tokens_match(prev_token_sequence, &bad_word_ids[..bad_word_ids.len() - 1])
741 {
742 sequence_banned_tokens.push(*bad_word_ids.last().unwrap());
743 }
744 }
745 banned_tokens.push(sequence_banned_tokens);
746 }
747
748 banned_tokens
749 }
750
751 fn ban_bad_words(
752 &self,
753 dynamic_bad_words: Option<&Vec<&Vec<i64>>>,
754 static_bad_words_mask: Option<&Tensor>,
755 token_ids: &Tensor,
756 scores: &mut Tensor,
757 ) {
758 let longest_bad_word = dynamic_bad_words
759 .iter()
760 .map(|bad_word| bad_word.len())
761 .max()
762 .unwrap() as i64;
763
764 let last_token_ids = token_ids.slice(1, -longest_bad_word, None, 1);
765 let mut prev_tokens = Vec::new();
766 for sequence_idx in 0..token_ids.size()[0] {
767 prev_tokens.push(
768 last_token_ids
769 .get(sequence_idx)
770 .iter::<i64>()
771 .unwrap()
772 .collect::<Vec<i64>>(),
773 )
774 }
775
776 let dynamic_bad_words_mask = if let Some(dynamic_bad_words) = dynamic_bad_words {
777 let dynamic_banned_tokens =
778 self.get_dynamic_bad_word_ids(&prev_tokens, dynamic_bad_words);
779 let dynamic_banned_mask =
780 Tensor::zeros(scores.size().as_slice(), (Kind::Int, scores.device()));
781 for (sequence_index, sequence_ban_tokens) in
782 dynamic_banned_tokens.iter().enumerate()
783 {
784 if !sequence_ban_tokens.is_empty() {
785 let _ = dynamic_banned_mask.get(sequence_index as i64).index_fill_(
786 0,
787 &Tensor::from_slice(sequence_ban_tokens).to_device(scores.device()),
788 1,
789 );
790 }
791 }
792 Some(dynamic_banned_mask.to_kind(Kind::Bool))
793 } else {
794 None
795 };
796
797 let combined_bad_word_mask = {
798 if let (Some(static_mask), Some(dynamic_mask)) =
799 (static_bad_words_mask, &dynamic_bad_words_mask)
800 {
801 Some(static_mask.bitwise_or_tensor(dynamic_mask))
802 } else {
803 None
804 }
805 };
806
807 let bad_word_mask = if combined_bad_word_mask.is_some() {
808 combined_bad_word_mask.as_ref()
809 } else if static_bad_words_mask.is_some() {
810 static_bad_words_mask
811 } else if dynamic_bad_words_mask.is_some() {
812 dynamic_bad_words_mask.as_ref()
813 } else {
814 None
815 };
816
817 if let Some(bad_word_mask) = bad_word_mask {
818 let _ = scores.masked_fill_(bad_word_mask, f64::NEG_INFINITY);
819 }
820 }
821
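        /// Decoding loop used when `num_beams == 1`: greedy argmax or multinomial sampling,
        /// with repetition, bad-word, n-gram and min/max length constraints applied per step.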
822 fn generate_no_beam_search(
823 &self,
824 input_ids: Tensor,
825 encoder_outputs: Option<Tensor>,
826 cur_len: i64,
827 batch_size: i64,
828 attention_mask: Tensor,
829 gen_opt: InternalGenerateOptions,
830 prefix_allowed_tokens_fn: Option<PrefixAllowedFunction>,
831 output_scores: bool,
832 ) -> GeneratedOutputWithScores {
833 let mut unfinished_sentences =
834 Tensor::ones([batch_size], (Kind::Int64, self.get_device()));
835 let mut sentence_lengths: Tensor =
836 Tensor::ones([batch_size], (Kind::Int64, self.get_device()));
837 let (bad_word_ids_length_1, bad_word_ids_length_greater_than_1) =
838 self.split_bad_word_ids(gen_opt.bad_word_ids);
839 let mut static_bad_words_mask: Option<Tensor> = None;
840 let mut attention_mask = attention_mask.copy();
841 let mut input_ids = input_ids.copy();
842 let mut past: Cache = Cache::None;
843 let mut outputs: Tensor;
844 let mut current_length = cur_len;
845 let mut token_scores_output: Option<Vec<Tensor>> =
846 if output_scores { Some(vec![]) } else { None };
847
848 loop {
849 let prepared_input = self.prepare_inputs_for_generation(
850 input_ids.copy(),
851 encoder_outputs.as_ref(),
852 past,
853 attention_mask.copy(),
854 );
855 let temp = self
856 .forward_t(
857 prepared_input.prepared_input.as_ref(),
858 prepared_input.prepared_past,
859 prepared_input.prepared_attention_mask.as_ref(),
860 None,
861 prepared_input.prepared_position_ids.as_ref(),
862 None,
863 prepared_input.prepared_encoder_output,
864 prepared_input.prepared_decoder_input.as_ref(),
865 false,
866 )
867 .unwrap();
868 outputs = temp.lm_logits;
869 past = temp.cache;
870
871 let mut next_token_logits = outputs.select(1, -1);
872 if gen_opt.repetition_penalty > 1f64 {
874 self.enforce_repetition_penalty(
875 &mut next_token_logits,
876 batch_size,
877 1,
878 &input_ids,
879 gen_opt.repetition_penalty,
880 )
881 }
882
883 if gen_opt.bad_word_ids.is_some() {
885 if let Some(bad_word_ids_length_1) = &bad_word_ids_length_1 {
887 if static_bad_words_mask.is_none() {
888 static_bad_words_mask = Some(self.calc_static_bad_word_mask(
889 &next_token_logits,
890 bad_word_ids_length_1,
891 ));
892 }
893 }
894 self.ban_bad_words(
895 bad_word_ids_length_greater_than_1.as_ref(),
896 static_bad_words_mask.as_ref(),
897 &input_ids,
898 &mut next_token_logits,
899 );
900 }
901
902 if gen_opt.no_repeat_ngram_size > 0 {
904 let banned_tokens = self.get_banned_tokens(
905 &input_ids,
906 gen_opt.no_repeat_ngram_size,
907 current_length,
908 );
909 for (batch_index, index_banned_token) in
910 (0..banned_tokens.len() as i64).zip(banned_tokens)
911 {
912 let _ = next_token_logits.get(batch_index).index_fill_(
913 0,
914 &Tensor::from_slice(&index_banned_token)
915 .to_device(next_token_logits.device()),
916 f64::NEG_INFINITY,
917 );
918 }
919 }
920
921 if let Some(prefix_allowed_tokens_function) = prefix_allowed_tokens_fn {
923 self.apply_prefix_allowed_tokens_function(
924 prefix_allowed_tokens_function,
925 1,
926 &input_ids,
927 &mut next_token_logits,
928 )
929 }
930
931 if (gen_opt.eos_token_ids.is_some()) & (current_length < gen_opt.min_length) {
933 let _ = next_token_logits.index_fill_(
934 1,
935 &Tensor::from_slice(gen_opt.eos_token_ids.as_ref().unwrap())
936 .to(next_token_logits.device()),
937 f64::NEG_INFINITY,
938 );
939 }
940
941 self.prepare_scores_for_generation(
942 &mut next_token_logits,
943 current_length,
944 gen_opt.max_length,
945 gen_opt.forced_bos_token_id,
946 );
947
948 let next_token = if gen_opt.do_sample {
950 if gen_opt.temperature > 1f64 {
951 next_token_logits /= gen_opt.temperature;
952 }
953 self.top_k_top_p_filtering(
954 &mut next_token_logits,
955 gen_opt.top_k,
956 gen_opt.top_p,
957 1,
958 );
959 let probabilities = next_token_logits.softmax(-1, next_token_logits.kind());
960 probabilities.multinomial(1, false).squeeze_dim(1)
961 } else {
962 next_token_logits.argmax(-1, false)
963 };
964
965 if let Some(prev_scores) = token_scores_output.as_mut() {
966 let finished_mask = unfinished_sentences.eq(0);
967 prev_scores.push(
968 next_token_logits
969 .log_softmax(-1, next_token_logits.kind())
970 .gather(1, &next_token.reshape([-1, 1]), false)
971 .squeeze()
972 .masked_fill(&finished_mask, 0),
973 );
974 };
975
976 let tokens_to_add = match &gen_opt.eos_token_ids {
978 Some(_) => {
979 next_token * &unfinished_sentences
980 - gen_opt.pad_token_id.unwrap() * (&unfinished_sentences - 1)
981 }
982 None => next_token,
983 };
984
985 input_ids = Tensor::cat(&[input_ids, tokens_to_add.unsqueeze(-1)], -1);
986 if gen_opt.eos_token_ids.is_some() {
987 for eos_token_id in gen_opt.eos_token_ids.as_ref().unwrap() {
988 let sentence_with_eos =
989 tokens_to_add.eq(*eos_token_id).to_kind(Kind::Int64);
990 let sentence_with_eos: Tensor = sentence_with_eos * &unfinished_sentences;
991 let _ = sentence_lengths.masked_fill_(
992 &sentence_with_eos
993 .to_kind(Kind::Bool)
994 .to_device(sentence_lengths.device()),
995 current_length + 1,
996 );
997 unfinished_sentences = -unfinished_sentences * (sentence_with_eos - 1);
998 }
999 if i64::try_from(unfinished_sentences.max()).unwrap() == 0 {
1000 break;
1001 }
1002 }
1003 if !self.is_encoder_decoder() {
1004 attention_mask = Tensor::cat(
1005 &[
1006 attention_mask.as_ref(),
1007 Tensor::ones(
1008 [*attention_mask.size().first().unwrap(), 1],
1009 (Kind::Int64, attention_mask.device()),
1010 )
1011 .as_ref(),
1012 ],
1013 -1,
1014 );
1015 }
1016 current_length += 1;
1017 if let Some(max_length) = gen_opt.max_length {
1018 if current_length >= max_length {
1019 let _ = sentence_lengths.masked_fill_(
1020 &unfinished_sentences
1021 .to_kind(Kind::Bool)
1022 .to_device(sentence_lengths.device()),
1023 current_length,
1024 );
1025 break;
1026 }
1027 }
1028 }
1029 let scores_output = token_scores_output.as_ref().map(|scores_tensor| {
1030 (Tensor::stack(scores_tensor, 1).sum_dim_intlist(
1031 [1].as_slice(),
1032 false,
1033 Kind::Float,
1034 ) / sentence_lengths.pow_tensor_scalar(gen_opt.length_penalty))
1035 .iter::<f64>()
1036 .unwrap()
1037 .collect::<Vec<f64>>()
1038 });
1039 let token_scores_output = token_scores_output.map(|score_tensors| {
1040 Tensor::stack(&score_tensors, 1)
1041 .split(1, 0)
1042 .iter()
1043 .map(|sequence_scores| {
1044 sequence_scores
1045 .squeeze_dim(0)
1046 .iter::<f64>()
1047 .unwrap()
1048 .collect::<Vec<f64>>()
1049 })
1050 .collect()
1051 });
1052 GeneratedOutputWithScores {
1053 indices: input_ids,
1054 scores: scores_output,
1055 token_scores: token_scores_output,
1056 }
1057 }
1058
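        /// Beam search decoding loop (optionally with beam groups for diverse beam search),
        /// maintaining `BeamHypotheses` per batch element and returning the best sequences.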
1059 fn generate_beam_search(
1060 &self,
1061 mut input_ids: Tensor,
1062 encoder_outputs: Option<Tensor>,
1063 cur_len: i64,
1064 batch_size: i64,
1065 mut attention_mask: Tensor,
1066 gen_opt: InternalGenerateOptions,
1067 prefix_allowed_tokens_fn: Option<PrefixAllowedFunction>,
1068 output_scores: bool,
1069 ) -> GeneratedOutputWithScores {
1070 let num_beam_groups = gen_opt.num_beam_groups.unwrap_or(1);
1071 let num_sub_beams = gen_opt.num_beams / num_beam_groups;
1072 let diversity_penalty = gen_opt.diversity_penalty.unwrap_or(5.5);
1073 let (bad_word_ids_length_1, bad_word_ids_length_greater_than_1) =
1074 self.split_bad_word_ids(gen_opt.bad_word_ids);
1075 let mut static_bad_words_mask: Option<Tensor> = None;
1076
1077 let mut hypotheses = (0..batch_size)
1078 .map(|_| {
1079 BeamHypotheses::new(
1080 gen_opt.num_beams,
1081 gen_opt.max_length,
1082 gen_opt.length_penalty,
1083 gen_opt.early_stopping,
1084 )
1085 })
1086 .collect::<Vec<BeamHypotheses>>();
1087
1088 let vocab_size = self.get_vocab_size();
1089 let beam_scores = Tensor::ones(
1090 [batch_size, gen_opt.num_beams],
1091 (Kind::Float, self.get_device()),
1092 ) * -1e9;
1093 let _ = beam_scores
1094 .slice(1, 0, *beam_scores.size().last().unwrap(), num_sub_beams)
1095 .fill_(0);
1096
1097 let mut beam_scores = beam_scores.view_([-1]);
1098 let mut beam_tokens = Tensor::zeros(
1099 [batch_size * gen_opt.num_beams],
1100 (Kind::Int64, self.get_device()),
1101 );
1102 let mut beam_indices = Tensor::zeros(
1103 [batch_size * gen_opt.num_beams],
1104 (Kind::Int64, self.get_device()),
1105 );
1106 let mut saved_beam_scores: Option<Vec<Tensor>> =
1107 if output_scores { Some(vec![]) } else { None };
1108 let mut current_tokens = Tensor::new();
1109
1110 let mut past: Cache = Cache::None;
1111 let mut done = vec![false; batch_size as usize];
1112
1113 let mut outputs: Tensor;
1114 let mut encoder_outputs = encoder_outputs;
1115 let mut current_length = cur_len;
1116
1117 loop {
1118 if num_beam_groups > 1 {
1119 current_tokens = Tensor::zeros(
1120 [batch_size * gen_opt.num_beams],
1121 (input_ids.kind(), input_ids.device()),
1122 );
1123 }
1124 let prepared_input = self.prepare_inputs_for_generation(
1125 input_ids.copy(),
1126 encoder_outputs.as_ref(),
1127 past,
1128 attention_mask.copy(),
1129 );
1130 let temp = self
1131 .forward_t(
1132 prepared_input.prepared_input.as_ref(),
1133 prepared_input.prepared_past,
1134 prepared_input.prepared_attention_mask.as_ref(),
1135 None,
1136 prepared_input.prepared_position_ids.as_ref(),
1137 None,
1138 prepared_input.prepared_encoder_output,
1139 prepared_input.prepared_decoder_input.as_ref(),
1140 false,
1141 )
1142 .unwrap();
1143 outputs = temp.lm_logits;
1144 past = temp.cache;
1145
1146 for beam_group_index in 0..num_beam_groups {
1147 let group_start_index = beam_group_index * num_sub_beams;
1148 let group_end_index = min(group_start_index + num_sub_beams, gen_opt.num_beams);
1149 let group_size = group_end_index - group_start_index;
1150
1151 let (group_input_ids, batch_group_indices) = if num_beam_groups > 1 {
1152 let mut batch_group_indices: Vec<i64> =
1153 Vec::with_capacity((batch_size * group_size) as usize);
1154 for batch_index in 0..batch_size {
1155 batch_group_indices.extend(
1156 (group_start_index..group_end_index)
1157 .map(|value| value + batch_index * gen_opt.num_beams),
1158 )
1159 }
1160 let batch_group_indices =
1161 Tensor::from_slice(batch_group_indices.as_slice())
1162 .to(input_ids.device());
1163 (
1164 Some(input_ids.index_select(0, &batch_group_indices)),
1165 Some(batch_group_indices),
1166 )
1167 } else {
1168 (None, None)
1169 };
1170
1171 let mut next_token_logits = if num_beam_groups <= 1 {
1172 outputs.select(1, -1)
1173 } else {
1174 outputs
1175 .select(1, -1)
1176 .index_select(0, batch_group_indices.as_ref().unwrap())
1177 };
1178 if gen_opt.repetition_penalty > 1f64 {
1180 self.enforce_repetition_penalty(
1181 &mut next_token_logits,
1182 batch_size,
1183 1,
1184 group_input_ids.as_ref().unwrap_or(&input_ids),
1185 gen_opt.repetition_penalty,
1186 )
1187 }
1188
1189 if gen_opt.temperature > 1f64 {
1190 next_token_logits /= gen_opt.temperature;
1191 }
1192 self.prepare_scores_for_generation(
1193 &mut next_token_logits,
1194 current_length,
1195 gen_opt.max_length,
1196 gen_opt.forced_bos_token_id,
1197 );
1198
1199 let mut scores = next_token_logits.log_softmax(-1, next_token_logits.kind());
1200
1201 if (gen_opt.eos_token_ids.is_some()) & (current_length < gen_opt.min_length) {
1203 let _ = scores.index_fill_(
1204 1,
1205 &Tensor::from_slice(gen_opt.eos_token_ids.as_ref().unwrap())
1206 .to(scores.device()),
1207 f64::NEG_INFINITY,
1208 );
1209 }
1210
1211 if gen_opt.bad_word_ids.is_some() {
1213 if let Some(bad_word_ids_length_1) = &bad_word_ids_length_1 {
1215 if static_bad_words_mask.is_none() {
1216 static_bad_words_mask = Some(
1217 self.calc_static_bad_word_mask(&scores, bad_word_ids_length_1),
1218 );
1219 }
1220 }
1221 self.ban_bad_words(
1222 bad_word_ids_length_greater_than_1.as_ref(),
1223 static_bad_words_mask.as_ref(),
1224 group_input_ids.as_ref().unwrap_or(&input_ids),
1225 &mut scores,
1226 );
1227 }
1228
1229 if gen_opt.no_repeat_ngram_size > 0 {
1231 let banned_tokens = self.get_banned_tokens(
1232 group_input_ids.as_ref().unwrap_or(&input_ids),
1233 gen_opt.no_repeat_ngram_size,
1234 current_length,
1235 );
1236 for (batch_index, index_banned_token) in
1237 (0..banned_tokens.len() as i64).zip(banned_tokens)
1238 {
1239 let _ = scores.get(batch_index).index_fill_(
1240 0,
1241 &Tensor::from_slice(&index_banned_token)
1242 .to_device(next_token_logits.device()),
1243 f64::NEG_INFINITY,
1244 );
1245 }
1246 }
1247
1248 if num_beam_groups > 1 {
1250 self.run_hamming_diversity_penalty(
1251 &mut scores,
                            &current_tokens,
1253 diversity_penalty,
1254 gen_opt.num_beams,
1255 batch_size,
1256 group_size,
1257 group_start_index,
1258 );
1259 }
1260
1261 if let Some(prefix_allowed_tokens_function) = prefix_allowed_tokens_fn {
1263 self.apply_prefix_allowed_tokens_function(
1264 prefix_allowed_tokens_function,
1265 num_sub_beams,
1266 &input_ids,
1267 &mut scores,
1268 )
1269 }
1270
1271 let mut next_scores: Tensor = &scores
1272 + (if num_beam_groups > 1 {
1273 beam_scores
1274 .index_select(0, batch_group_indices.as_ref().unwrap())
1275 .unsqueeze(-1)
1276 .expand_as(&scores)
1277 } else {
1278 beam_scores.unsqueeze(-1).expand_as(&scores)
1279 });
1280
1281 let (next_scores, next_tokens) = if gen_opt.do_sample {
1282 self.top_k_top_p_filtering(
1283 &mut next_scores,
1284 gen_opt.top_k,
1285 gen_opt.top_p,
1286 2,
1287 );
1288 let _scores = next_scores
1289 .contiguous()
1290 .view((batch_size, group_size * vocab_size));
1291
1292 let probabilities = _scores.softmax(-1, _scores.kind());
1293 let next_tokens = probabilities.multinomial(2 * group_size, false);
1294 let _scores = _scores.gather(-1, &next_tokens, false);
1295 let (_scores, next_scores_indices) = _scores.sort(1, true);
1296 let next_tokens = next_tokens.gather(-1, &next_scores_indices, false);
1297 (_scores, next_tokens)
1298 } else {
1299 let _scores = next_scores
1300 .contiguous()
1301 .view((batch_size, group_size * vocab_size));
1302 _scores.topk(2 * group_size, 1, true, true)
1303 };
1304
1305 let eos_token_ids = gen_opt.eos_token_ids.as_ref();
1306 let beam_ids_tensor = &next_tokens.divide_scalar_mode(vocab_size, "floor");
1307 let effective_beam_ids_tensor =
1308 (&next_tokens.ones_like().cumsum(0, Kind::Int64) - 1) * group_size
1309 + beam_ids_tensor;
1310 let token_id_tensor = &next_tokens - beam_ids_tensor * vocab_size;
1311 let (max_scores, _) = next_scores.max_dim(1, false);
1312 let mut eos_mask = token_id_tensor.ones_like();
1313 if let Some(eos_token_id) = eos_token_ids {
1314 eos_mask -= token_id_tensor.eq(eos_token_id[0]).to_kind(Kind::Int64);
1315 }
1316 let eos_mask2 = eos_mask
1317 .cumsum(1, Kind::Int64)
1318 .le(group_size)
1319 .to_kind(Kind::Bool)
1320 .logical_and(&eos_mask);
1321
1322 let group_beam_scores = next_scores.masked_select(&eos_mask2);
1323 let group_beam_tokens = token_id_tensor.masked_select(&eos_mask2);
1324 let group_beam_indices = effective_beam_ids_tensor.masked_select(&eos_mask2);
1325 let eos_pos = (eos_mask.ones_like() - eos_mask).nonzero();
1326
1327 for eos_idx in 0..eos_pos.size()[0] {
1328 let eos_data = eos_pos.get(eos_idx);
1329 let batch_index = eos_data.int64_value(&[0]);
1330 if !done[batch_index as usize] {
1331 let beam_index_pos = eos_data.int64_value(&[1]);
1332 let is_beam_token_worse_than_top_num_beams =
1333 beam_index_pos >= gen_opt.num_beams;
1334 if is_beam_token_worse_than_top_num_beams {
1335 continue;
1336 }
1337 let effective_beam_id = effective_beam_ids_tensor
1338 .int64_value(&[batch_index, beam_index_pos]);
1339 let beam_token_score =
1340 next_scores.double_value(&[batch_index, beam_index_pos]);
1341 let saved_beam_scores =
1342 saved_beam_scores.as_ref().map(|step_wise_scores| {
1343 Tensor::stack(step_wise_scores, 1)
1344 .get(effective_beam_id)
1345 .copy()
1346 });
1347 hypotheses[batch_index as usize].add(
1348 input_ids.get(effective_beam_id).copy(),
1349 beam_token_score,
1350 saved_beam_scores,
1351 );
1352 }
1353 }
1354
1355 for batch_index in 0..batch_size {
1356 if done[batch_index as usize] {
1357 let _ = group_beam_scores
1358 .narrow(0, batch_index * gen_opt.num_beams, gen_opt.num_beams)
1359 .fill_(0f64);
1360 let _ = group_beam_tokens
1361 .narrow(0, batch_index * gen_opt.num_beams, gen_opt.num_beams)
1362 .fill_(gen_opt.pad_token_id.unwrap());
1363 let _ = group_beam_indices
1364 .narrow(0, batch_index * gen_opt.num_beams, gen_opt.num_beams)
1365 .fill_(0);
1366 continue;
1367 } else {
1368 done[batch_index as usize] |= hypotheses[batch_index as usize]
1369 .is_done(max_scores.double_value(&[batch_index]), current_length);
1370 }
1371 }
1372
1373 if num_beam_groups <= 1 {
1374 beam_scores = group_beam_scores.view(-1);
1375 beam_tokens = group_beam_tokens.view(-1);
1376 beam_indices = group_beam_indices.view(-1);
1377 } else {
1378 let _ = beam_scores.index_copy_(
1379 0,
1380 batch_group_indices.as_ref().unwrap(),
1381 &group_beam_scores,
1382 );
1383 let _ = beam_tokens.index_copy_(
1384 0,
1385 batch_group_indices.as_ref().unwrap(),
1386 &group_beam_tokens,
1387 );
1388 let new_indices = gen_opt.num_beams
1389 * group_beam_indices.divide_scalar_mode(group_size, "floor")
1390 + group_start_index
1391 + group_beam_indices.remainder(group_size);
1392 let _ = beam_indices.index_copy_(
1393 0,
1394 batch_group_indices.as_ref().unwrap(),
1395 &new_indices,
1396 );
1397 let _ = current_tokens.index_copy_(
1398 0,
1399 batch_group_indices.as_ref().unwrap(),
1400 &group_beam_tokens,
1401 );
1402 }
1403 }
1404
1405 if let Some(scores_output) = saved_beam_scores.as_mut() {
1406 scores_output.push(beam_scores.copy());
1407 }
1408 if done.iter().all(|&x| x) {
1409 break;
1410 }
1411
1412 input_ids = Tensor::cat(
1413 &[
1414 input_ids.index_select(0, &beam_indices),
1415 beam_tokens.unsqueeze(1),
1416 ],
1417 -1,
1418 );
1419
1420 current_length += 1;
1421 if let Some(max_length) = gen_opt.max_length {
1422 if current_length >= max_length {
1423 break;
1424 }
1425 }
1426 encoder_outputs = self.reorder_cache(&mut past, encoder_outputs, &beam_indices);
1427
1428 if !self.is_encoder_decoder() {
1429 attention_mask = Tensor::cat(
1430 &[
1431 attention_mask.as_ref(),
1432 Tensor::ones(
1433 [*attention_mask.size().first().unwrap(), 1],
1434 (Kind::Int64, attention_mask.device()),
1435 )
1436 .as_ref(),
1437 ],
1438 -1,
1439 );
1440 }
1441 }
1442
1443 let mut batch_index = 0i64;
1444
1445 let mut saved_beam_scores = saved_beam_scores
1446 .map(|step_wise_scores| Tensor::stack(&step_wise_scores, 1).split(1, 0));
1447 loop {
1448 if batch_index == batch_size {
1449 break;
1450 }
1451 if done[batch_index as usize] {
1452 batch_index += 1;
1453 continue;
1454 }
1455 for beam_index in 0..gen_opt.num_beams {
1456 let effective_beam_id = batch_index * gen_opt.num_beams + beam_index;
1457 let beam_saved_token_scores = saved_beam_scores.as_mut().map(|saved_tokens| {
1458 mem::replace(&mut saved_tokens[effective_beam_id as usize], Tensor::new())
1459 });
1460 let final_score = f64::try_from(beam_scores.get(effective_beam_id)).unwrap();
1461 let final_tokens = input_ids.get(effective_beam_id);
1462 hypotheses[batch_index as usize].add(
1463 final_tokens,
1464 final_score,
1465 beam_saved_token_scores,
1466 );
1467 }
1468 batch_index += 1;
1469 }
1470 let (output_batch_size, output_num_return_sequences_per_batch) = if gen_opt.do_sample {
1471 (batch_size, 1)
1472 } else {
1473 (
1474 batch_size * gen_opt.num_return_sequences,
1475 gen_opt.num_return_sequences,
1476 )
1477 };
1478
1479 let mut sentence_lengths =
1480 Tensor::zeros([output_batch_size], (Kind::Int64, input_ids.device()));
1481 let mut best_ids = vec![];
1482
1483 let mut scores_output = if output_scores {
1484 Some(Vec::with_capacity(best_ids.len()))
1485 } else {
1486 None
1487 };
1488 let mut token_scores_output = if output_scores {
1489 Some(Vec::with_capacity(best_ids.len()))
1490 } else {
1491 None
1492 };
1493 for (hypothesis_index, hypothesis) in hypotheses.iter().enumerate() {
1494 let mut sorted_hypotheses = hypothesis.clone();
1495 sorted_hypotheses
1496 .beams
1497 .sort_by_key(|(score, _, _)| OrderedFloat(*score));
1498 for j in 0..output_num_return_sequences_per_batch {
1499 let effective_batch_index =
1500 output_num_return_sequences_per_batch * hypothesis_index as i64 + j;
1501
1502 let (best_score, best_hyp, best_token_scores) =
1503 sorted_hypotheses.beams.pop().unwrap();
1504 let _ = sentence_lengths.index_fill_(
1505 0,
1506 &Tensor::from_slice(&[effective_batch_index]).to(sentence_lengths.device()),
1507 *best_hyp.size().first().unwrap(),
1508 );
1509 best_ids.push(best_hyp);
1510 if let Some(current_best_scores) = &mut scores_output {
1511 current_best_scores.push(best_score);
1512 }
1513 if let Some(current_best_token_scores) = &mut token_scores_output {
1514 current_best_token_scores.push(
1515 best_token_scores
1516 .unwrap()
1517 .iter::<f64>()
1518 .unwrap()
1519 .collect::<Vec<f64>>(),
1520 );
1521 }
1522 }
1523 }
1524 let sentence_max_length = gen_opt
1525 .max_length
1526 .map(|max_length| {
1527 min(
1528 i64::try_from(sentence_lengths.max()).unwrap() + 1,
1529 max_length,
1530 )
1531 })
1532 .unwrap_or(i64::try_from(sentence_lengths.max()).unwrap() + 1);
1533
1534 let mut decoded = input_ids.new_empty(
1535 [output_batch_size, sentence_max_length],
1536 (Kind::Int64, input_ids.device()),
1537 );
1538 if i64::try_from(sentence_lengths.max()).unwrap()
1539 != i64::try_from(sentence_lengths.min()).unwrap()
1540 {
1541 let _ = decoded.fill_(
1542 gen_opt
1543 .pad_token_id
1544 .unwrap_or_else(|| gen_opt.eos_token_ids.as_ref().unwrap()[0]),
1545 );
1546 }
1547 for (hypothesis_index, best_id) in best_ids.iter().enumerate() {
1548 let _ = decoded.get(hypothesis_index as i64).index_copy_(
1549 0,
1550 &Tensor::arange_start(
1551 0,
1552 i64::try_from(sentence_lengths.get(hypothesis_index as i64)).unwrap(),
1553 (Kind::Int64, input_ids.device()),
1554 ),
1555 best_id,
1556 );
1557 let sentence_length =
1558 i64::try_from(sentence_lengths.get(hypothesis_index as i64)).unwrap();
1559 let sentence_length_max = gen_opt
1560 .max_length
1561 .unwrap_or_else(|| i64::try_from(sentence_lengths.max()).unwrap());
1562 if sentence_length < sentence_length_max {
1563 let _ = decoded.get(hypothesis_index as i64).index_fill_(
1564 0,
1565 &Tensor::from_slice(&[sentence_length]).to_device(input_ids.device()),
1566 gen_opt.eos_token_ids.as_ref().unwrap()[0],
1567 );
1568 }
1569 }
1570 GeneratedOutputWithScores {
1571 indices: decoded,
1572 scores: scores_output,
1573 token_scores: token_scores_output,
1574 }
1575 }
1576
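        /// Reorders the cache (and optionally the encoder outputs) along the beam dimension
        /// after a beam search step; the default implementation only supports an empty cache,
        /// so model-specific generators are expected to override it.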
1577 fn reorder_cache(
1578 &self,
1579 past: &mut Cache,
1580 _encoder_outputs: Option<Tensor>,
1581 _beam_indices: &Tensor,
1582 ) -> Option<Tensor> {
1583 match past {
1584 Cache::None => None,
1585 _ => {
1586 panic!("Not implemented");
1587 }
1588 }
1589 }
1590 }
1591
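    /// Forces generation of the given token ids by setting the scores of all other
    /// vocabulary entries to negative infinity.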
1592 pub fn force_token_id_generation(scores: &mut Tensor, token_ids: &[i64], vocab_size: i64) {
1593 let impossible_tokens: Vec<i64> = (0..vocab_size)
1594 .filter(|pos| !token_ids.contains(pos))
1595 .collect();
1596 let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
1597 let _ = scores.index_fill_(
1598 1,
1599 &impossible_tokens,
1600 get_negative_infinity(scores.kind()).unwrap(),
1601 );
1602 }
1603}
1604
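/// Generated text output, containing the decoded text and an optional overall sequence score.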
1605#[derive(Debug, Clone)]
1606pub struct GeneratedTextOutput {
1609 pub text: String,
1610 pub score: Option<f64>,
1611}
1612
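/// Generated indices output, containing the generated token ids, an optional sequence score
/// and optional per-token scores.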
1613#[derive(Debug, Clone)]
1614pub struct GeneratedIndicesOutput {
1617 pub indices: Vec<i64>,
1618 pub score: Option<f64>,
1619 pub token_scores: Option<Vec<f64>>,
1620}
1621
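/// Constraint function called at each generation step: given the batch id and the tokens
/// generated so far, it returns the token ids allowed at the next position.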
1622pub type PrefixAllowedFunction<'a> = &'a dyn Fn(i64, &Tensor) -> Vec<i64>;
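
/// Per-call generation options. Fields left as `None` generally fall back to the
/// generator's `GenerateConfig`.
///
/// A minimal sketch of how the options are typically built:
/// ```no_run
/// use rust_bert::pipelines::generation_utils::GenerateOptions;
///
/// let generate_options = GenerateOptions {
///     min_length: Some(32),
///     max_length: Some(128),
///     output_scores: true,
///     ..Default::default()
/// };
/// ```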
1623#[derive(Clone, Copy, Default)]
1629pub struct GenerateOptions<'a> {
1634 pub min_length: Option<i64>,
1636 pub max_length: Option<i64>,
1638 pub max_new_tokens: Option<i64>,
1642 pub early_stopping: Option<bool>,
1644 pub num_return_sequences: Option<i64>,
1646 pub num_beams: Option<i64>,
1648 pub num_beam_groups: Option<i64>,
1649 pub do_sample: Option<bool>,
1651 pub temperature: Option<f64>,
1653 pub top_k: Option<i64>,
1655 pub top_p: Option<f64>,
1657 pub repetition_penalty: Option<f64>,
1659 pub length_penalty: Option<f64>,
1661 pub no_repeat_ngram_size: Option<i64>,
1663 pub diversity_penalty: Option<f64>,
1665 pub decoder_start_token_id: Option<i64>,
1667 pub forced_bos_token_id: Option<i64>,
1669 pub prefix_allowed_tokens_fn: Option<PrefixAllowedFunction<'a>>,
1671 pub bad_word_ids: Option<&'a Vec<Vec<i64>>>,
1673 pub output_scores: bool,
1675}
1676
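// Resolves a generation parameter: takes the value from `GenerateOptions` when provided,
// otherwise falls back to the corresponding `GenerateConfig` field.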
1677macro_rules! unpack_config {
1678 ($field_name:ident, $generate_options: ident, $generate_config: ident) => {
1679 $generate_options.map_or($generate_config.$field_name, |opts| {
1680 opts.$field_name.unwrap_or($generate_config.$field_name)
1681 })
1682 };
1683}
1684
1685pub trait LanguageGenerator: PrivateLanguageGenerator {
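    /// Generates text from an optional list of prompts, using the provided `GenerateOptions`
    /// (or the generator's `GenerateConfig` when `None`).
    ///
    /// A minimal sketch, assuming the `remote` feature and the pretrained GPT-2 resources
    /// referenced by the default configuration:
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::gpt2::GPT2Generator;
    /// use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
    ///
    /// let generate_config = GenerateConfig {
    ///     max_length: Some(30),
    ///     do_sample: true,
    ///     num_beams: 5,
    ///     ..Default::default()
    /// };
    /// let model = GPT2Generator::new(generate_config)?;
    /// let output = model.generate(Some(&["The dog", "The cat was"]), None)?;
    /// # Ok(())
    /// # }
    /// ```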
1688 fn generate<S>(
1775 &self,
1776 prompt_texts: Option<&[S]>,
1777 generate_options: Option<GenerateOptions>,
1778 ) -> Result<Vec<GeneratedTextOutput>, RustBertError>
1779 where
1780 S: AsRef<str> + Send + Sync,
1781 {
1782 let indices_outputs = self.generate_indices(prompt_texts, generate_options)?;
1783 let mut output = Vec::with_capacity(indices_outputs.len());
1784 for generated_sequence in indices_outputs {
1785 output.push(GeneratedTextOutput {
1786 text: self
1787 ._get_tokenizer()
1788 .decode(&generated_sequence.indices, true, true),
1789 score: generated_sequence.score,
1790 });
1791 }
1792 Ok(output)
1793 }
1794
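    /// Same as `generate`, but returns the generated token indices (and optional scores)
    /// instead of decoded text.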
1795 fn generate_indices<S>(
1869 &self,
1870 prompt_texts: Option<&[S]>,
1871 generate_options: Option<GenerateOptions>,
1872 ) -> Result<Vec<GeneratedIndicesOutput>, RustBertError>
1873 where
1874 S: AsRef<str> + Send + Sync,
1875 {
1876 let eos_token_ids = self.get_eos_ids();
1877
1878 let config = self.get_config();
1879
1880 let max_length = generate_options.map_or(config.max_length, |generate_options| {
1881 generate_options.max_length
1882 });
1883 let encoding_max_len = if self.is_encoder_decoder() {
1884 self.get_max_positions_embeddings()
1885 } else {
1886 max_length
1887 };
1888 let pad_token_id = match self.get_pad_id() {
1889 Some(value) => Some(value),
1890 None => eos_token_ids.as_ref().map(|eos_ids| eos_ids[0]),
1891 };
1892
1893 let input_ids = match prompt_texts {
1894 Some(prompts) if !prompts.is_empty() => {
1895 self.encode_prompt_text(prompts, encoding_max_len, pad_token_id)
1896 }
1897 None => match self.get_bos_id() {
1898 Some(bos_id) => Tensor::ones([1, 1], (Int64, self.get_device())) * bos_id,
1899 None => return Err(RustBertError::ValueError(
1900 "A model with a BOS token must be used to start generation with an empty input"
1901 .to_string(),
1902 )),
1903 },
1904 _ => return Ok(Vec::new()),
1905 };
1906 self.generate_from_ids_and_past(input_ids, None, generate_options)
1907 }
1908
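    /// Lower-level entry point generating from already-encoded input ids (and an optional
    /// attention mask), shared by `generate` and `generate_indices`.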
1909 fn generate_from_ids_and_past(
1960 &self,
1961 mut input_ids: Tensor,
1962 mut attention_mask: Option<Tensor>,
1963 generate_options: Option<GenerateOptions>,
1964 ) -> Result<Vec<GeneratedIndicesOutput>, RustBertError> {
1965 let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).cloned();
1966
1967 let config = PrivateLanguageGenerator::get_config(self);
1968
1969 let do_sample = unpack_config!(do_sample, generate_options, config);
1972 let num_return_sequences = unpack_config!(num_return_sequences, generate_options, config);
1973 let num_beams = unpack_config!(num_beams, generate_options, config);
1974 let min_length = unpack_config!(min_length, generate_options, config);
1975 let early_stopping = unpack_config!(early_stopping, generate_options, config);
1976 let temperature = unpack_config!(temperature, generate_options, config);
1977 let top_k = unpack_config!(top_k, generate_options, config);
1978 let top_p = unpack_config!(top_p, generate_options, config);
1979 let repetition_penalty = unpack_config!(repetition_penalty, generate_options, config);
1980 let length_penalty = unpack_config!(length_penalty, generate_options, config);
1981 let no_repeat_ngram_size = unpack_config!(no_repeat_ngram_size, generate_options, config);
1982 let num_beam_groups = generate_options.map_or(config.num_beam_groups, |opts| {
1983 opts.num_beam_groups.or(config.num_beam_groups)
1984 });
1985 let diversity_penalty = generate_options.map_or(config.diversity_penalty, |opts| {
1986 opts.diversity_penalty.or(config.diversity_penalty)
1987 });
1988 let decoder_start_token_id = generate_options.and_then(|opts| opts.decoder_start_token_id);
1989 let forced_bos_token_id = generate_options.and_then(|opts| opts.forced_bos_token_id);
1990 let bad_word_ids = generate_options.and_then(|opts| opts.bad_word_ids);
1991 let prefix_allowed_tokens_fn =
1992 generate_options.and_then(|opts| opts.prefix_allowed_tokens_fn);
1993 let output_scores = generate_options.map_or(false, |opts| opts.output_scores);
1994
1995 let pad_token_id = match self.get_pad_id() {
1996 Some(value) => Some(value),
1997 None => eos_token_ids.as_ref().map(|eos_ids| eos_ids[0]),
1998 };
1999
2000 let input_id_size = input_ids.size();
2001 let mut input_ids_len = *input_id_size.last().unwrap();
2002 if input_ids_len == 0 {
2003 input_ids = Tensor::ones(
2004 [*input_id_size.first().unwrap(), 1],
2005 (Int64, input_ids.device()),
2006 ) * self
2007 .get_bos_id()
2008 .expect("`bos_token_id` has to be defined when no `input_ids` are provided.");
2009 attention_mask = Some(Tensor::ones(
2010 [*input_id_size.first().unwrap(), 1],
2011 (Int64, input_ids.device()),
2012 ));
2013 input_ids_len += 1;
2014 }
2015
2016 let cur_len = if !self.is_encoder_decoder() {
2017 *input_ids.size().last().unwrap()
2018 } else {
2019 1
2020 };
2021 let batch_size = *input_ids.size().first().unwrap();
2022
2023 let (effective_batch_size, effective_batch_mult) = match do_sample {
2024 true => (batch_size * num_return_sequences, num_return_sequences),
2025 false => (batch_size, 1),
2026 };
2027
2028 let attention_mask = match attention_mask {
2029 Some(value) => value,
2030 None => match pad_token_id {
2031 Some(pad_id) => input_ids.ne(pad_id).to_kind(Int64),
2032 None => input_ids.ones_like().to_kind(Int64),
2033 },
2034 };
2035
2036 let encoder_outputs = if self.is_encoder_decoder() {
2037 let encoder_outputs = self
2038 .encode(&input_ids, Some(&attention_mask))
2039 .ok_or(RustBertError::UnsupportedError)?;
2040 let expanded_batch_indices = Tensor::arange(batch_size, (Int64, input_ids.device()))
2041 .view((-1, 1))
2042 .repeat([1, num_beams * effective_batch_mult])
2043 .view(-1);
2044 Some(encoder_outputs.index_select(0, &expanded_batch_indices))
2045 } else {
2046 None
2047 };
2048
2049 let (input_ids, attention_mask) = if !self.is_encoder_decoder() {
2050 if (num_return_sequences > 1) | (num_beams > 1) {
2051 (
2052 input_ids
2053 .unsqueeze(1)
2054 .expand(
2055 [batch_size, effective_batch_mult * num_beams, cur_len],
2056 true,
2057 )
2058 .contiguous()
2059 .view((effective_batch_size * num_beams, cur_len)),
2060 attention_mask
2061 .unsqueeze(1)
2062 .expand(
2063 [batch_size, effective_batch_mult * num_beams, cur_len],
2064 true,
2065 )
2066 .contiguous()
2067 .view((effective_batch_size * num_beams, cur_len)),
2068 )
2069 } else {
2070 (input_ids, attention_mask)
2071 }
2072 } else {
2073 let decoder_start_token_id = decoder_start_token_id
2074 .or(self.get_decoder_start_id())
2075 .ok_or(RustBertError::ValueError(
2076 "decoder start id must be specified for encoder decoders".to_string(),
2077 ))?;
2078 let input_ids = Tensor::full(
2079 [effective_batch_size * num_beams, 1],
2080 decoder_start_token_id,
2081 (Int64, input_ids.device()),
2082 );
2083 let attention_mask = if (num_return_sequences > 1) | (num_beams > 1) {
2084 attention_mask
2085 .unsqueeze(1)
2086 .expand(
2087 [batch_size, effective_batch_mult * num_beams, input_ids_len],
2088 true,
2089 )
2090 .contiguous()
2091 .view((effective_batch_size * num_beams, input_ids_len))
2092 } else {
2093 attention_mask
2094 };
2095 (input_ids, attention_mask)
2096 };
2097
2098 let max_length = if let Some(generate_options) = generate_options {
2099 match (generate_options.max_length, generate_options.max_new_tokens) {
2100 (Some(max_length), _) => Some(max_length),
2101 (None, Some(max_new_tokens)) => {
2102 Some(max_new_tokens + input_ids.size().last().unwrap())
2103 }
2104 (None, None) => config.max_length,
2105 }
2106 } else {
2107 config.max_length
2108 };
2109
2110 if let Some(max_length) = max_length {
2111 if input_ids.size2()?.1 > max_length {
                return Err(RustBertError::ValueError("The input ids exceed the maximum length for generation. \
                    Reduce the size of the provided input ids or increase the allowable maximum generation length.".to_string()));
2114 }
2115 }
2116
2117 if max_length.is_none() & eos_token_ids.is_none() {
2118 return Err(RustBertError::InvalidConfigurationError("No maximum length given for a model without an EOS token. \
2119 This would lead to an infinite generation loop. Please provide a `max_length` or `max_new_tokens`".to_string()));
2120 }
2121
2122 let gen_opt = InternalGenerateOptions {
2123 min_length,
2124 max_length,
2125 do_sample,
2126 temperature,
2127 top_k,
2128 top_p,
2129 repetition_penalty,
2130 no_repeat_ngram_size,
2131 pad_token_id,
2132 eos_token_ids,
2133 num_return_sequences,
2134 early_stopping,
2135 num_beams,
2136 length_penalty,
2137 num_beam_groups,
2138 diversity_penalty,
2139 forced_bos_token_id,
2140 bad_word_ids,
2141 };
2142
2143 let generated_output_with_scores = no_grad(|| {
2144 if num_beams > 1 {
2145 self.generate_beam_search(
2146 input_ids,
2147 encoder_outputs,
2148 cur_len,
2149 effective_batch_size,
2150 attention_mask,
2151 gen_opt,
2152 prefix_allowed_tokens_fn,
2153 output_scores,
2154 )
2155 } else {
2156 self.generate_no_beam_search(
2157 input_ids,
2158 encoder_outputs,
2159 cur_len,
2160 effective_batch_size,
2161 attention_mask,
2162 gen_opt,
2163 prefix_allowed_tokens_fn,
2164 output_scores,
2165 )
2166 }
2167 });
2168 let (decoded, scores, mut token_scores) = (
2169 generated_output_with_scores.indices,
2170 generated_output_with_scores.scores,
2171 generated_output_with_scores.token_scores,
2172 );
2173 let num_sequences = *decoded.size().first().unwrap();
2174 let mut output = Vec::with_capacity(num_sequences as usize);
2175 for sequence_index in 0..num_sequences {
2176 let indices = decoded
2177 .as_ref()
2178 .get(sequence_index)
2179 .iter::<i64>()
2180 .unwrap()
2181 .collect::<Vec<i64>>();
2182 let score = scores
2183 .as_ref()
2184 .map(|scores_value| scores_value[sequence_index as usize]);
2185
2186 let token_scores = token_scores
2187 .as_mut()
2188 .map(|token_scores| std::mem::take(&mut token_scores[sequence_index as usize]));
2189
2190 output.push(GeneratedIndicesOutput {
2191 indices,
2192 score,
2193 token_scores,
2194 });
2195 }
2196 Ok(output)
2197 }
2198
2199 fn get_tokenizer(&self) -> &TokenizerOption {
2236 self._get_tokenizer()
2237 }
2238
2239 fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
2240 self._get_tokenizer_mut()
2241 }
2242
2243 fn half(&mut self) -> Result<(), RustBertError> {
2244 self.get_var_store_mut()?.half();
2245 Ok(())
2246 }
2247
2248 fn float(&mut self) -> Result<(), RustBertError> {
2249 self.get_var_store_mut()?.float();
2250 Ok(())
2251 }
2252
2253 fn set_device(&mut self, device: Device) -> Result<(), RustBertError> {
2254 self.get_var_store_mut()?.set_device(device);
2255 Ok(())
2256 }
2257}
2258
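/// Container keeping the `num_beams` best finished hypotheses for one batch element,
/// scored by length-normalized log-probability.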
2259#[derive(Debug)]
2260struct BeamHypotheses {
2261 max_length: Option<i64>,
2262 length_penalty: f64,
2263 early_stopping: bool,
2264 num_beams: i64,
2265 beams: Vec<(f64, Tensor, Option<Tensor>)>,
2266 worst_score: f64,
2267}
2268
2269impl Clone for BeamHypotheses {
2270 fn clone(&self) -> Self {
2271 BeamHypotheses {
2272 max_length: self.max_length,
2273 length_penalty: self.length_penalty,
2274 early_stopping: self.early_stopping,
2275 num_beams: self.num_beams,
2276 beams: self
2277 .beams
2278 .iter()
2279 .map(|(score, tensor, scores_tensor)| {
2280 (
2281 *score,
2282 tensor.copy(),
2283 scores_tensor
2284 .as_ref()
2285 .map(|scores_tensor| scores_tensor.copy()),
2286 )
2287 })
2288 .collect::<Vec<(f64, Tensor, Option<Tensor>)>>(),
2289 worst_score: self.worst_score,
2290 }
2291 }
2292}
2293
2294impl BeamHypotheses {
2295 fn new(
2296 num_beams: i64,
2297 max_length: Option<i64>,
2298 length_penalty: f64,
2299 early_stopping: bool,
2300 ) -> BeamHypotheses {
2301 BeamHypotheses {
2302 max_length: max_length.map(|max_length| max_length - 1),
2303 length_penalty,
2304 early_stopping,
2305 num_beams,
2306 beams: Vec::with_capacity(num_beams as usize + 1),
2307 worst_score: 1e9f64,
2308 }
2309 }
2310
2311 fn len(&self) -> i64 {
2312 self.beams.len() as i64
2313 }
2314
2315 fn add(
2316 &mut self,
2317 hypothesis: Tensor,
2318 sum_log_probabilities: f64,
2319 token_scores: Option<Tensor>,
2320 ) {
2321 let score =
2322 sum_log_probabilities / ((hypothesis.size()[0] as f64).powf(self.length_penalty));
2323 if (self.len() < self.num_beams) | (score > self.worst_score) {
2324 let token_scores = token_scores.map(|scores_tensor| {
2325 scores_tensor.squeeze_dim(0).diff::<Tensor>(
2326 1,
2327 0,
2328 Some(Tensor::zeros(
2329 [1],
2330 (scores_tensor.kind(), scores_tensor.device()),
2331 )),
2332 None,
2333 )
2334 });
2335 self.beams.push((score, hypothesis, token_scores));
2336 if self.len() > self.num_beams {
2337 let (worst_score_position, _) = self
2338 .beams
2339 .iter()
2340 .enumerate()
2341 .min_by_key(|(_, (score, _, _))| OrderedFloat(*score))
2342 .unwrap();
2343 let _ = self.beams.remove(worst_score_position);
2344 }
2345 self.worst_score = self
2346 .beams
2347 .iter()
2348 .min_by_key(|(score, _, _)| OrderedFloat(*score))
2349 .unwrap()
2350 .0;
2351 }
2352 }
2353
2354 fn is_done(&self, best_sum_log_probabilities: f64, current_length: i64) -> bool {
2355 if self.len() < self.num_beams {
2356 false
2357 } else if self.early_stopping {
2358 true
2359 } else {
2360 self.worst_score
2361 >= best_sum_log_probabilities / (current_length as f64).powf(self.length_penalty)
2362 }
2363 }
2364}
2365
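/// Language model forward pass output, containing the logits and the updated cache.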
2366pub struct LMModelOutput {
2368 pub lm_logits: Tensor,
2370 pub cache: Cache,
2372}