1use crate::{Result, TextError};
2use scirs2_core::random::Random;
4use scirs2_core::rngs::StdRng;
5use scirs2_core::RngExt;
6use torsh_tensor::Tensor;
7
/// Configuration knobs for autoregressive text generation.
///
/// Mirrors the common HuggingFace-style generation options: greedy decoding
/// when `do_sample` is false and `num_beams == 1`, beam search when
/// `num_beams > 1`, and stochastic (temperature/top-k/top-p) sampling when
/// `do_sample` is true.
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Maximum number of generation steps (tokens appended).
    pub max_length: usize,
    /// Minimum generated length before an EOS token may stop generation.
    pub min_length: usize,
    /// If true, sample stochastically; otherwise decode greedily.
    pub do_sample: bool,
    /// Beam search only: stop as soon as enough finished hypotheses exist.
    pub early_stopping: bool,
    /// Number of beams; values > 1 enable beam search.
    pub num_beams: usize,
    /// Softmax temperature for sampling.
    pub temperature: f32,
    /// Keep only the `k` highest-probability tokens when sampling (if set).
    pub top_k: Option<usize>,
    /// Nucleus sampling: keep the smallest token set whose cumulative
    /// probability exceeds `p` (if set).
    pub top_p: Option<f32>,
    /// Penalty > 1.0 discourages repeating already-generated tokens.
    pub repetition_penalty: f32,
    /// Exponent used to length-normalize beam scores.
    pub length_penalty: f32,
    /// Ban n-grams of this size from repeating in the output (0 = disabled).
    pub no_repeat_ngram_size: usize,
    /// NOTE(review): not consumed by the generators in this module —
    /// presumably applied by encoder-decoder callers; confirm.
    pub encoder_no_repeat_ngram_size: usize,
    /// NOTE(review): not consumed in this module; reserved for callers.
    pub bad_words_ids: Vec<Vec<u32>>,
    /// NOTE(review): not consumed in this module; reserved for callers.
    pub force_words_ids: Vec<Vec<u32>>,
    /// NOTE(review): not consumed in this module; reserved for callers.
    pub pad_token_id: Option<u32>,
    /// NOTE(review): not consumed in this module; reserved for callers.
    pub bos_token_id: Option<u32>,
    /// Token id that terminates generation (subject to `min_length`).
    pub eos_token_id: Option<u32>,
    /// NOTE(review): not consumed in this module; reserved for callers.
    pub decoder_start_token_id: Option<u32>,
}

impl Default for GenerationConfig {
    /// Defaults describe plain greedy decoding with no filters or penalties.
    fn default() -> Self {
        Self {
            max_length: 50,
            min_length: 0,
            do_sample: false,
            early_stopping: false,
            num_beams: 1, // single beam == no beam search
            temperature: 1.0,
            top_k: None,
            top_p: None,
            repetition_penalty: 1.0, // 1.0 == no penalty applied
            length_penalty: 1.0,
            no_repeat_ngram_size: 0, // 0 == filter disabled
            encoder_no_repeat_ngram_size: 0,
            bad_words_ids: Vec::new(),
            force_words_ids: Vec::new(),
            pad_token_id: None,
            bos_token_id: None,
            eos_token_id: None,
            decoder_start_token_id: None,
        }
    }
}
59
/// Stateful token sampler implementing greedy, temperature, top-k and
/// top-p (nucleus) decoding strategies.
pub struct TextSampler {
    // PRNG used by the stochastic sampling paths; owned so each sampler has
    // an independent, reproducible random stream.
    rng: Random<StdRng>,
}

impl Default for TextSampler {
    fn default() -> Self {
        Self {
            // Fixed seed (42) makes default sampling deterministic across
            // runs — intentional for reproducibility.
            rng: Random::seed(42),
        }
    }
}
76
77impl TextSampler {
78 pub fn greedy_sample(&self, logits: &Tensor) -> Result<u32> {
80 let vocab_size = logits.shape().dims()[logits.shape().ndim() - 1];
81 let mut max_idx = 0;
82 let mut max_val = f32::NEG_INFINITY;
83
84 for i in 0..vocab_size {
86 let val = logits.select(0, i as i64)?.item()?;
87 if val > max_val {
88 max_val = val;
89 max_idx = i;
90 }
91 }
92
93 Ok(max_idx as u32)
94 }
95
96 pub fn temperature_sample(&mut self, logits: &Tensor, temperature: f32) -> Result<u32> {
98 if temperature <= 0.0 {
99 return self.greedy_sample(logits);
100 }
101
102 let scaled_logits = logits.div_scalar(temperature)?;
104
105 let probs = scaled_logits.softmax(-1)?;
107
108 self.multinomial_sample(&probs)
109 }
110
111 pub fn top_k_sample(&mut self, logits: &Tensor, k: usize, temperature: f32) -> Result<u32> {
113 let vocab_size = logits.shape().dims()[logits.shape().ndim() - 1];
114 let k = k.min(vocab_size);
115
116 let (top_values, top_indices) = self.get_top_k(logits, k)?;
118
119 let scaled_values = if temperature > 0.0 {
121 top_values.div_scalar(temperature)?
122 } else {
123 top_values
124 };
125
126 let probs = scaled_values.softmax(-1)?;
128
129 let local_idx = self.multinomial_sample(&probs)?;
131
132 let original_idx = top_indices.select(0, local_idx as i64)?.item()?;
134 Ok(original_idx as u32)
135 }
136
137 pub fn top_p_sample(&mut self, logits: &Tensor, p: f32, temperature: f32) -> Result<u32> {
139 let scaled_logits = if temperature > 0.0 {
141 logits.div_scalar(temperature)?
142 } else {
143 logits.clone()
144 };
145
146 let probs = scaled_logits.softmax(-1)?;
148
149 let (sorted_probs, sorted_indices) = self.sort_descending(&probs)?;
151
152 let cumsum = self.cumulative_sum(&sorted_probs)?;
154
155 let vocab_size = probs.shape().dims()[probs.shape().ndim() - 1];
157 let mut cutoff = vocab_size;
158
159 for i in 0..vocab_size {
160 let cum_prob = cumsum.select(0, i as i64)?.item()?;
161 if cum_prob > p {
162 cutoff = i + 1;
163 break;
164 }
165 }
166
167 let nucleus_probs = sorted_probs.narrow(0, 0, cutoff)?;
169 let nucleus_indices = sorted_indices.narrow(0, 0, cutoff)?;
170
171 let sum_tensor = nucleus_probs.sum()?;
173 let renormalized_probs = nucleus_probs.div(&sum_tensor)?;
174
175 let local_idx = self.multinomial_sample(&renormalized_probs)?;
177
178 let original_idx = nucleus_indices.select(0, local_idx as i64)?.item()?;
180 Ok(original_idx as u32)
181 }
182
183 pub fn top_k_top_p_sample(
185 &mut self,
186 logits: &Tensor,
187 k: Option<usize>,
188 p: Option<f32>,
189 temperature: f32,
190 ) -> Result<u32> {
191 let mut working_logits = logits.clone();
192
193 if let Some(k_val) = k {
195 let vocab_size = working_logits.shape().dims()[working_logits.shape().ndim() - 1];
196 if k_val < vocab_size {
197 let (top_values, top_indices) = self.get_top_k(&working_logits, k_val)?;
198
199 let mut new_logits_data = vec![f32::NEG_INFINITY; vocab_size];
201
202 for i in 0..k_val {
204 let idx = top_indices.select(0, i as i64)?.item()? as usize;
205 let val = top_values.select(0, i as i64)?.item()?;
206 if idx < vocab_size {
207 new_logits_data[idx] = val;
208 }
209 }
210
211 working_logits = Tensor::from_data(
212 new_logits_data,
213 working_logits.shape().dims().to_vec(),
214 torsh_core::device::DeviceType::Cpu,
215 )?;
216 }
217 }
218
219 if let Some(p_val) = p {
221 return self.top_p_sample(&working_logits, p_val, temperature);
222 }
223
224 self.temperature_sample(&working_logits, temperature)
226 }
227
228 fn multinomial_sample(&mut self, probs: &Tensor) -> Result<u32> {
230 let vocab_size = probs.shape().dims()[probs.shape().ndim() - 1];
231 let random_val: f32 = self.rng.random();
232
233 let mut cumulative = 0.0;
234 for i in 0..vocab_size {
235 let prob = probs.select(0, i as i64)?.item()?;
236 cumulative += prob;
237 if random_val <= cumulative {
238 return Ok(i as u32);
239 }
240 }
241
242 Ok((vocab_size - 1) as u32)
244 }
245
246 fn get_top_k(&self, tensor: &Tensor, k: usize) -> Result<(Tensor, Tensor)> {
247 let vocab_size = tensor.shape().dims()[tensor.shape().ndim() - 1];
249 let mut values_and_indices: Vec<(f32, usize)> = Vec::new();
250
251 for i in 0..vocab_size {
252 let val = tensor.select(0, i as i64)?.item()?;
253 values_and_indices.push((val, i));
254 }
255
256 values_and_indices
257 .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
258 values_and_indices.truncate(k);
259
260 let values: Vec<f32> = values_and_indices.iter().map(|(v, _)| *v).collect();
261 let indices: Vec<f32> = values_and_indices.iter().map(|(_, i)| *i as f32).collect();
262
263 let values_tensor = Tensor::from_vec(values, &[k])?.to_dtype(tensor.dtype())?;
264 let indices_tensor = Tensor::from_vec(indices, &[k])?.to_dtype(tensor.dtype())?;
265
266 Ok((values_tensor, indices_tensor))
267 }
268
269 fn sort_descending(&self, tensor: &Tensor) -> Result<(Tensor, Tensor)> {
270 let vocab_size = tensor.shape().dims()[tensor.shape().ndim() - 1];
271 let mut values_and_indices: Vec<(f32, usize)> = Vec::new();
272
273 for i in 0..vocab_size {
274 let val = tensor.select(0, i as i64)?.item()?;
275 values_and_indices.push((val, i));
276 }
277
278 values_and_indices
279 .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
280
281 let values: Vec<f32> = values_and_indices.iter().map(|(v, _)| *v).collect();
282 let indices: Vec<f32> = values_and_indices.iter().map(|(_, i)| *i as f32).collect();
283
284 let values_tensor = Tensor::from_vec(values, &[vocab_size])?.to_dtype(tensor.dtype())?;
285 let indices_tensor = Tensor::from_vec(indices, &[vocab_size])?.to_dtype(tensor.dtype())?;
286
287 Ok((values_tensor, indices_tensor))
288 }
289
290 fn cumulative_sum(&self, tensor: &Tensor) -> Result<Tensor> {
291 let size = tensor.shape().dims()[tensor.shape().ndim() - 1];
292 let mut cumsum = Vec::new();
293 let mut running_sum = 0.0;
294
295 for i in 0..size {
296 let val = tensor.select(0, i as i64)?.item()?;
297 running_sum += val;
298 cumsum.push(running_sum);
299 }
300
301 Ok(Tensor::from_vec(cumsum, &[size])?.to_dtype(tensor.dtype())?)
302 }
303}
304
/// A single (partial or finished) beam-search candidate sequence.
#[derive(Debug, Clone)]
pub struct BeamHypothesis {
    /// Token ids accumulated so far (including any initial prompt tokens).
    pub tokens: Vec<u32>,
    /// Sum of token log-probabilities along this hypothesis.
    pub score: f32,
    /// Cached `tokens.len()`, kept in sync by `new`.
    pub length: usize,
}

impl BeamHypothesis {
    /// Creates a hypothesis, caching the token count in `length`.
    pub fn new(tokens: Vec<u32>, score: f32) -> Self {
        let length = tokens.len();
        Self {
            tokens,
            score,
            length,
        }
    }

    /// Length-normalized score: `score / max(length, 1)^length_penalty`.
    ///
    /// `length_penalty > 1.0` favors longer sequences; `< 1.0` favors
    /// shorter ones. BUGFIX: an empty hypothesis previously divided by
    /// `0^penalty`, yielding `inf`/`NaN`; the length is now clamped to
    /// at least 1 so the score stays finite.
    pub fn normalized_score(&self, length_penalty: f32) -> f32 {
        self.score / (self.length.max(1) as f32).powf(length_penalty)
    }
}
330
331pub struct BeamSearchDecoder {
332 num_beams: usize,
333 max_length: usize,
334 length_penalty: f32,
335 early_stopping: bool,
336 eos_token_id: Option<u32>,
337}
338
339impl BeamSearchDecoder {
340 pub fn new(
341 num_beams: usize,
342 max_length: usize,
343 length_penalty: f32,
344 early_stopping: bool,
345 eos_token_id: Option<u32>,
346 ) -> Self {
347 Self {
348 num_beams,
349 max_length,
350 length_penalty,
351 early_stopping,
352 eos_token_id,
353 }
354 }
355
356 pub fn search(
357 &self,
358 initial_tokens: Vec<u32>,
359 vocab_size: usize,
360 get_logits: impl Fn(&[u32]) -> Result<Tensor>,
361 ) -> Result<Vec<BeamHypothesis>> {
362 let mut beam_hypotheses = BeamHypothesesPool::new(
363 self.num_beams,
364 self.max_length,
365 self.length_penalty,
366 self.early_stopping,
367 );
368
369 let mut beams: Vec<BeamHypothesis> = vec![BeamHypothesis::new(initial_tokens.clone(), 0.0)];
371
372 for _step in 0..self.max_length {
373 let mut all_candidates = Vec::new();
374
375 for beam in &beams {
376 if let Some(eos_id) = self.eos_token_id {
377 if beam.tokens.last() == Some(&eos_id) {
378 beam_hypotheses.add(beam.clone());
380 continue;
381 }
382 }
383
384 let logits = get_logits(&beam.tokens)?;
386 let log_probs = logits.log_softmax(-1)?;
387
388 for token_id in 0..vocab_size.min(self.num_beams * 2) {
390 let token_log_prob = log_probs.select(0, token_id as i64)?.item()?;
391 let new_score = beam.score + token_log_prob;
392
393 let mut new_tokens = beam.tokens.clone();
394 new_tokens.push(token_id as u32);
395
396 all_candidates.push(BeamHypothesis::new(new_tokens, new_score));
397 }
398 }
399
400 all_candidates.sort_by(|a, b| {
402 b.normalized_score(self.length_penalty)
403 .partial_cmp(&a.normalized_score(self.length_penalty))
404 .unwrap_or(std::cmp::Ordering::Equal)
405 });
406 beams = all_candidates.into_iter().take(self.num_beams).collect();
407
408 if self.early_stopping
410 && beam_hypotheses.is_done(
411 beams
412 .iter()
413 .map(|b| b.normalized_score(self.length_penalty))
414 .fold(f32::NEG_INFINITY, f32::max),
415 )
416 {
417 break;
418 }
419
420 beams.retain(|beam| {
422 if let Some(eos_id) = self.eos_token_id {
423 beam.tokens.last() != Some(&eos_id)
424 } else {
425 true
426 }
427 });
428
429 if beams.is_empty() {
430 break;
431 }
432 }
433
434 for beam in beams {
436 beam_hypotheses.add(beam);
437 }
438
439 Ok(beam_hypotheses.finalize())
440 }
441}
442
/// Fixed-capacity, score-ordered store of finished beam hypotheses.
///
/// Hypotheses are kept sorted best-first (descending length-penalized
/// score); once the pool is full, inserting a better hypothesis evicts the
/// current worst via `truncate`.
struct BeamHypothesesPool {
    // Invariant: sorted descending by normalized score (best first).
    hypotheses: Vec<BeamHypothesis>,
    // Capacity of the pool (typically num_beams).
    max_hypotheses: usize,
    // Used by `is_done` to bound the best score still achievable.
    max_length: usize,
    // Exponent for length normalization of scores.
    length_penalty: f32,
    // If false, `is_done` never reports completion.
    early_stopping: bool,
}

impl BeamHypothesesPool {
    fn new(
        max_hypotheses: usize,
        max_length: usize,
        length_penalty: f32,
        early_stopping: bool,
    ) -> Self {
        Self {
            hypotheses: Vec::new(),
            max_hypotheses,
            max_length,
            length_penalty,
            early_stopping,
        }
    }

    /// Inserts `hypothesis` keeping the descending-score order, then trims
    /// the pool back to capacity (dropping the worst entries).
    fn add(&mut self, hypothesis: BeamHypothesis) {
        let score = hypothesis.normalized_score(self.length_penalty);

        // The comparator is deliberately reversed (`score` vs element rather
        // than element vs `score`): binary_search_by expects ascending order
        // per the closure, and reversing the operands makes the descending
        // list appear ascending. `unwrap_or_else(|e| e)` turns a not-found
        // result into the insertion position.
        let insert_pos = self
            .hypotheses
            .binary_search_by(|h| {
                score
                    .partial_cmp(&h.normalized_score(self.length_penalty))
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .unwrap_or_else(|e| e);

        self.hypotheses.insert(insert_pos, hypothesis);

        if self.hypotheses.len() > self.max_hypotheses {
            self.hypotheses.truncate(self.max_hypotheses);
        }
    }

    /// Returns true when search can stop: the pool is full and even the
    /// best-possible extension of the active beams (`best_sum_logprobs`
    /// normalized at `max_length`) cannot beat the worst banked hypothesis.
    fn is_done(&self, best_sum_logprobs: f32) -> bool {
        if !self.early_stopping {
            return false;
        }

        if self.hypotheses.len() < self.max_hypotheses {
            return false;
        }

        let worst_score = self
            .hypotheses
            .last()
            .map(|h| h.normalized_score(self.length_penalty))
            .unwrap_or(f32::NEG_INFINITY);
        let best_possible_score =
            best_sum_logprobs / (self.max_length as f32).powf(self.length_penalty);

        worst_score >= best_possible_score
    }

    /// Consumes the pool, returning hypotheses sorted best-first.
    /// (The list is already sorted by `add`; the re-sort is a safety net.)
    fn finalize(mut self) -> Vec<BeamHypothesis> {
        self.hypotheses.sort_by(|a, b| {
            b.normalized_score(self.length_penalty)
                .partial_cmp(&a.normalized_score(self.length_penalty))
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        self.hypotheses
    }
}
517
518pub struct RepetitionPenalty;
523
524impl RepetitionPenalty {
525 pub fn apply(logits: &Tensor, generated_tokens: &[u32], penalty: f32) -> Result<Tensor> {
526 if penalty == 1.0 {
527 return Ok(logits.clone());
528 }
529
530 let mut penalized_logits = logits.clone();
531
532 for &token in generated_tokens {
534 let current_logit = penalized_logits.select(0, token as i64)?.item()?;
535 let penalized_value = if current_logit > 0.0 {
536 current_logit / penalty
537 } else {
538 current_logit * penalty
539 };
540
541 let _token_tensor = Tensor::from_vec(vec![token as i64], &[1])?;
543 let _penalty_tensor = Tensor::scalar(penalized_value)?;
544
545 let vocab_size = penalized_logits.shape().dims()[0];
547 let mut logits_vec = penalized_logits.to_vec()?;
548 logits_vec[token as usize] = penalized_value;
549 penalized_logits = Tensor::from_vec(logits_vec, &[vocab_size])?;
550 }
551
552 Ok(penalized_logits)
553 }
554}
555
556pub struct NGramRepetitionFilter {
557 no_repeat_ngram_size: usize,
558}
559
560impl NGramRepetitionFilter {
561 pub fn new(no_repeat_ngram_size: usize) -> Self {
562 Self {
563 no_repeat_ngram_size,
564 }
565 }
566
567 pub fn filter_logits(&self, logits: &Tensor, generated_tokens: &[u32]) -> Result<Tensor> {
568 if self.no_repeat_ngram_size == 0 || generated_tokens.len() < self.no_repeat_ngram_size {
569 return Ok(logits.clone());
570 }
571
572 let mut filtered_logits = logits.clone();
573 let _vocab_size = logits.shape().dims()[logits.shape().ndim() - 1];
574
575 let mut banned_tokens = std::collections::HashSet::new();
577
578 for i in 0..generated_tokens.len() - self.no_repeat_ngram_size + 1 {
579 let ngram = &generated_tokens[i..i + self.no_repeat_ngram_size - 1];
580
581 let current_context =
583 &generated_tokens[generated_tokens.len() - self.no_repeat_ngram_size + 1..];
584
585 if ngram == current_context {
586 let banned_token = generated_tokens[i + self.no_repeat_ngram_size - 1];
588 banned_tokens.insert(banned_token);
589 }
590 }
591
592 let vocab_size = filtered_logits.shape().dims()[0];
594 let mut logits_vec = filtered_logits.to_vec()?;
595 for banned_token in banned_tokens {
596 if (banned_token as usize) < logits_vec.len() {
597 logits_vec[banned_token as usize] = f32::NEG_INFINITY;
598 }
599 }
600 filtered_logits = Tensor::from_vec(logits_vec, &[vocab_size])?;
601
602 Ok(filtered_logits)
603 }
604}
605
606pub struct TextGenerator {
611 sampler: TextSampler,
612 beam_decoder: Option<BeamSearchDecoder>,
613 repetition_penalty: RepetitionPenalty,
614 ngram_filter: Option<NGramRepetitionFilter>,
615}
616
617impl TextGenerator {
618 pub fn new(config: &GenerationConfig) -> Self {
619 let beam_decoder = if config.num_beams > 1 {
620 Some(BeamSearchDecoder::new(
621 config.num_beams,
622 config.max_length,
623 config.length_penalty,
624 config.early_stopping,
625 config.eos_token_id,
626 ))
627 } else {
628 None
629 };
630
631 let ngram_filter = if config.no_repeat_ngram_size > 0 {
632 Some(NGramRepetitionFilter::new(config.no_repeat_ngram_size))
633 } else {
634 None
635 };
636
637 Self {
638 sampler: TextSampler::default(),
639 beam_decoder,
640 repetition_penalty: RepetitionPenalty,
641 ngram_filter,
642 }
643 }
644
645 pub fn generate(
646 &mut self,
647 initial_tokens: Vec<u32>,
648 vocab_size: usize,
649 config: &GenerationConfig,
650 get_logits: impl Fn(&[u32]) -> Result<Tensor> + Clone,
651 ) -> Result<Vec<Vec<u32>>> {
652 if config.num_beams > 1 {
653 if let Some(ref decoder) = self.beam_decoder {
655 let hypotheses = decoder.search(initial_tokens, vocab_size, get_logits)?;
656 Ok(hypotheses.into_iter().map(|h| h.tokens).collect())
657 } else {
658 Err(TextError::ModelError(
659 "Beam decoder not initialized".to_string(),
660 ))
661 }
662 } else {
663 let result =
665 self.generate_with_sampling(initial_tokens, vocab_size, config, get_logits)?;
666 Ok(vec![result])
667 }
668 }
669
670 fn generate_with_sampling(
671 &mut self,
672 mut tokens: Vec<u32>,
673 _vocab_size: usize,
674 config: &GenerationConfig,
675 get_logits: impl Fn(&[u32]) -> Result<Tensor>,
676 ) -> Result<Vec<u32>> {
677 for _ in 0..config.max_length {
678 let mut logits = get_logits(&tokens)?;
680
681 if config.repetition_penalty != 1.0 {
683 logits = RepetitionPenalty::apply(&logits, &tokens, config.repetition_penalty)?;
684 }
685
686 if let Some(ref filter) = self.ngram_filter {
688 logits = filter.filter_logits(&logits, &tokens)?;
689 }
690
691 let next_token = if config.do_sample {
693 self.sampler.top_k_top_p_sample(
694 &logits,
695 config.top_k,
696 config.top_p,
697 config.temperature,
698 )?
699 } else {
700 self.sampler.greedy_sample(&logits)?
701 };
702
703 tokens.push(next_token);
704
705 if let Some(eos_id) = config.eos_token_id {
707 if next_token == eos_id {
708 break;
709 }
710 }
711
712 if tokens.len() >= config.min_length {
714 if let Some(eos_id) = config.eos_token_id {
715 if next_token == eos_id {
716 break;
717 }
718 }
719 }
720 }
721
722 Ok(tokens)
723 }
724}
725
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::{device::DeviceType as Device, dtype::DType};

    /// Constructing the default sampler must not panic.
    #[test]
    fn test_text_sampler_creation() {
        let _sampler = TextSampler::default();
    }

    /// The defaults should describe plain greedy decoding.
    #[test]
    fn test_generation_config_default() {
        let cfg = GenerationConfig::default();
        assert_eq!(cfg.max_length, 50);
        assert_eq!(cfg.num_beams, 1);
        assert!(!cfg.do_sample);
    }

    /// `new` must store the tokens/score and cache the token count.
    #[test]
    fn test_beam_hypothesis() {
        let ids = vec![1, 2, 3];
        let hyp = BeamHypothesis::new(ids.clone(), -1.5);

        assert_eq!(hyp.tokens, ids);
        assert_eq!(hyp.score, -1.5);
        assert_eq!(hyp.length, 3);
    }

    /// Greedy sampling must return the argmax index (2 for these logits).
    #[test]
    fn test_greedy_sampling() {
        let _device = Device::Cpu;

        let logits = Tensor::from_vec(vec![0.1, 0.2, 0.9, 0.3], &[4])
            .unwrap()
            .to_dtype(DType::F32)
            .unwrap();

        let sampler = TextSampler::default();
        assert_eq!(sampler.greedy_sample(&logits).unwrap(), 2);
    }
}