1use std::cmp::Ordering;
8use std::collections::BinaryHeap;
9
/// Tuning parameters for [`BeamSearchDecoder`].
#[derive(Debug, Clone, PartialEq)]
pub struct BeamSearchConfig {
    /// Number of hypotheses selected at every decoding step.
    pub beam_width: usize,
    /// Upper bound on hypothesis length in tokens (the initial token included); also caps the
    /// number of decoding steps.
    pub max_length: usize,
    /// Token that ends a hypothesis; `None` disables EOS handling.
    pub eos_token_id: Option<usize>,
    /// Exponent `alpha` in the ranking score `log_prob / len^alpha`; `0.0` ranks by the raw
    /// log-probability.
    pub length_penalty: f64,
    /// An EOS token only completes a hypothesis whose length (initial token and EOS included)
    /// is greater than this value.
    pub min_length: usize,
    /// Expected length of every row of next-token log-probabilities.
    pub vocab_size: usize,
    /// Divisor applied to raw logits before the softmax when decoding; must be positive there.
    pub temperature: f64,
    /// When set to `Some(k)`, only the `k` highest-scoring tokens of each beam are considered.
    pub top_k_filter: Option<usize>,
}

impl Default for BeamSearchConfig {
    /// Width 4, max length 50, no EOS, length penalty 1.0, min length 1, vocabulary of 1000,
    /// temperature 1.0 and no top-k filtering.
    fn default() -> Self {
        Self {
            beam_width: 4,
            max_length: 50,
            eos_token_id: None,
            length_penalty: 1.0,
            min_length: 1,
            vocab_size: 1000,
            temperature: 1.0,
            top_k_filter: None,
        }
    }
}
51
/// A single candidate token sequence tracked by the search.
#[derive(Debug, Clone, PartialEq)]
pub struct BeamHypothesis {
    /// Token ids, beginning with the initial token.
    pub tokens: Vec<usize>,
    /// Accumulated log-probability of the sequence.
    pub log_prob: f64,
    /// Ranking score. `new`/`extend` set it to `log_prob`; the decoder may overwrite it with a
    /// length-penalized value.
    pub score: f64,
    /// Whether the hypothesis is finished and will not be extended further.
    pub is_done: bool,
}

impl BeamHypothesis {
    /// Starts an unfinished hypothesis holding only `initial_token`, scored by `log_prob`.
    pub fn new(initial_token: usize, log_prob: f64) -> Self {
        Self {
            tokens: vec![initial_token],
            log_prob,
            score: log_prob,
            is_done: false,
        }
    }

    /// Returns a new, unfinished hypothesis with `token` appended and `token_log_prob` added to
    /// the log-probability. The returned `score` equals the new (unpenalized) log-probability;
    /// `self` is not modified.
    pub fn extend(&self, token: usize, token_log_prob: f64) -> Self {
        let mut tokens = Vec::with_capacity(self.tokens.len() + 1);
        tokens.extend_from_slice(&self.tokens);
        tokens.push(token);
        let log_prob = self.log_prob + token_log_prob;
        Self {
            tokens,
            log_prob,
            score: log_prob,
            is_done: false,
        }
    }

    /// Returns `log_prob / len^alpha`, or the raw `log_prob` when `alpha` is `0.0` or the
    /// hypothesis has no tokens.
    pub fn length_penalized_score(&self, alpha: f64) -> f64 {
        if alpha == 0.0 || self.tokens.is_empty() {
            return self.log_prob;
        }
        self.log_prob / (self.tokens.len() as f64).powf(alpha)
    }

    /// Number of tokens in the hypothesis.
    pub fn len(&self) -> usize {
        self.tokens.len()
    }

    /// Whether the hypothesis holds no tokens at all.
    pub fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }
}
116
/// Next-token log-probabilities for one decoding step: one row per live beam, one entry per
/// vocabulary item.
#[derive(Debug, Clone, PartialEq)]
pub struct BeamStepInput {
    /// `log_probs[i][t]` is the log-probability of token `t` following beam `i`.
    pub log_probs: Vec<Vec<f64>>,
}
128
129impl BeamStepInput {
130 pub fn new(log_probs: Vec<Vec<f64>>) -> Self {
132 Self { log_probs }
133 }
134
135 pub fn from_logits(logits: Vec<Vec<f64>>, temperature: f64) -> Self {
137 let log_probs = logits
138 .into_iter()
139 .map(|row| {
140 let scaled = BeamSearchDecoder::apply_temperature(&row, temperature);
141 BeamSearchDecoder::log_softmax(&scaled)
142 })
143 .collect();
144 Self { log_probs }
145 }
146
147 pub fn num_beams(&self) -> usize {
149 self.log_probs.len()
150 }
151
152 pub fn vocab_size(&self) -> usize {
154 self.log_probs.first().map(|r| r.len()).unwrap_or(0)
155 }
156}
157
158#[derive(Debug, Clone)]
164pub struct BeamState {
165 pub beams: Vec<BeamHypothesis>,
167 pub completed: Vec<BeamHypothesis>,
169 pub step: usize,
171}
172
173impl BeamState {
174 pub fn initial(beam_width: usize, bos_token_id: usize) -> Self {
176 let beams = (0..beam_width)
177 .map(|_| BeamHypothesis::new(bos_token_id, 0.0))
178 .collect();
179 Self {
180 beams,
181 completed: Vec::new(),
182 step: 0,
183 }
184 }
185
186 pub fn is_done(&self, config: &BeamSearchConfig) -> bool {
188 self.completed.len() >= config.beam_width || self.step >= config.max_length
189 }
190
191 pub fn best_hypothesis(&self) -> Option<&BeamHypothesis> {
193 let all = self.beams.iter().chain(self.completed.iter());
194 all.max_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(Ordering::Equal))
195 }
196}
197
/// A scored (beam, token) extension competing for a slot in the next beam.
#[derive(Debug)]
struct Candidate {
    /// Index of the parent beam.
    beam_idx: usize,
    /// Token appended to the parent.
    token_id: usize,
    /// Accumulated log-probability after appending `token_id`.
    log_prob: f64,
    /// Ranking score (length-penalized log-probability).
    score: f64,
}

impl Candidate {
    /// Total order on scores: ordinary numeric order, with NaN ranked below every number so a
    /// malformed score can never win and never breaks the heap's ordering invariants.
    fn score_order(a: f64, b: f64) -> Ordering {
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
            // Two non-NaN floats are always comparable.
            (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
        }
    }
}

impl PartialEq for Candidate {
    // Derived from `cmp` so that `Eq` and `Ord` always agree.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    /// Higher score is greater. Equal scores are broken towards the lower beam index and then
    /// the lower token id, so a max-heap pops candidates in a reproducible order.
    fn cmp(&self, other: &Self) -> Ordering {
        Self::score_order(self.score, other.score)
            .then_with(|| other.beam_idx.cmp(&self.beam_idx))
            .then_with(|| other.token_id.cmp(&self.token_id))
    }
}
232
/// Errors reported by [`BeamSearchDecoder`].
#[derive(Debug, Clone, PartialEq)]
pub enum BeamSearchError {
    /// There are no active beams.
    EmptyBeams,
    /// The step input has a different number of rows than there are live beams.
    BeamWidthMismatch { expected: usize, got: usize },
    /// A row of log-probabilities does not have the configured vocabulary size.
    VocabSizeMismatch { expected: usize, got: usize },
    /// `beam_width` was 0.
    ZeroBeamWidth,
    /// `max_length` was 0.
    MaxLengthTooShort,
    /// The caller-supplied scoring function failed; carries its message.
    ScoringFunctionError(String),
    /// The configured temperature was rejected; carries the offending value.
    InvalidTemperature(f64),
}

impl std::fmt::Display for BeamSearchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptyBeams => f.write_str("beam search has no active beams"),
            Self::BeamWidthMismatch { expected, got } => write!(
                f,
                "beam width mismatch: expected {expected} beams, got {got}"
            ),
            Self::VocabSizeMismatch { expected, got } => write!(
                f,
                "vocab size mismatch: expected {expected} tokens, got {got}"
            ),
            Self::ZeroBeamWidth => f.write_str("beam_width must be at least 1"),
            Self::MaxLengthTooShort => f.write_str("max_length must be at least 1"),
            Self::ScoringFunctionError(msg) => write!(f, "scoring function error: {msg}"),
            Self::InvalidTemperature(t) => write!(f, "temperature must be positive, got {t}"),
        }
    }
}

impl std::error::Error for BeamSearchError {}
283
/// Summary statistics of a finished search.
#[derive(Debug, Clone, PartialEq)]
pub struct BeamSearchStats {
    /// Number of decoding steps that were run.
    pub total_steps: usize,
    /// Returned hypotheses whose last token is the configured EOS token.
    pub num_completed_at_eos: usize,
    /// All other returned hypotheses (length limit reached, or still live when the search
    /// stopped).
    pub num_completed_at_max_length: usize,
    /// Mean token count over the returned hypotheses (0.0 when there are none).
    pub avg_sequence_length: f64,
    /// `(min, max)` score over the returned hypotheses (`(0.0, 0.0)` when there are none).
    pub score_range: (f64, f64),
}
302
303#[derive(Debug, Clone)]
309pub struct BeamSearchResult {
310 pub hypotheses: Vec<BeamHypothesis>,
312 pub best_sequence: Vec<usize>,
314 pub best_score: f64,
316 pub stats: BeamSearchStats,
318}
319
320impl BeamSearchResult {
321 pub fn best(&self) -> Option<&BeamHypothesis> {
323 self.hypotheses.first()
324 }
325}
326
327pub struct BeamSearchDecoder {
333 pub config: BeamSearchConfig,
335}
336
337impl BeamSearchDecoder {
338 pub fn new(config: BeamSearchConfig) -> Self {
340 Self { config }
341 }
342
343 pub fn with_default() -> Self {
345 Self::new(BeamSearchConfig::default())
346 }
347
348 pub fn initial_state(&self, bos_token_id: usize) -> BeamState {
350 BeamState::initial(self.config.beam_width, bos_token_id)
351 }
352
353 pub fn step(
359 &self,
360 mut state: BeamState,
361 input: &BeamStepInput,
362 ) -> Result<BeamState, BeamSearchError> {
363 if self.config.beam_width == 0 {
364 return Err(BeamSearchError::ZeroBeamWidth);
365 }
366 if state.beams.is_empty() {
367 state.step += 1;
369 return Ok(state);
370 }
371
372 if input.num_beams() != state.beams.len() {
374 return Err(BeamSearchError::BeamWidthMismatch {
375 expected: state.beams.len(),
376 got: input.num_beams(),
377 });
378 }
379 let vocab_size = self.config.vocab_size;
380 for (i, row) in input.log_probs.iter().enumerate() {
381 if row.len() != vocab_size {
382 return Err(BeamSearchError::VocabSizeMismatch {
383 expected: vocab_size,
384 got: row.len(),
385 });
386 }
387 let _ = i;
388 }
389
390 let mut heap: BinaryHeap<Candidate> = BinaryHeap::new();
392
393 for (beam_idx, beam) in state.beams.iter().enumerate() {
394 let mut lp: Vec<f64> = input.log_probs[beam_idx].clone();
395
396 if let Some(k) = self.config.top_k_filter {
398 Self::top_k_filter_logits(&mut lp, k);
399 }
400
401 for (token_id, &token_lp) in lp.iter().enumerate() {
402 if token_lp == f64::NEG_INFINITY {
404 continue;
405 }
406 let new_log_prob = beam.log_prob + token_lp;
407 let new_len = (beam.tokens.len() + 1) as f64;
409 let score = if self.config.length_penalty == 0.0 {
410 new_log_prob
411 } else {
412 new_log_prob / new_len.powf(self.config.length_penalty)
413 };
414
415 heap.push(Candidate {
416 beam_idx,
417 token_id,
418 log_prob: new_log_prob,
419 score,
420 });
421 }
422 }
423
424 let desired = self.config.beam_width;
426 let mut new_beams: Vec<BeamHypothesis> = Vec::with_capacity(desired);
427 let mut new_completed: Vec<BeamHypothesis> = state.completed.clone();
428 let mut eos_count: usize = 0;
429 let mut taken: usize = 0;
430
431 while taken < desired {
432 let candidate = match heap.pop() {
433 Some(c) => c,
434 None => break,
435 };
436
437 let parent = &state.beams[candidate.beam_idx];
438 let mut hyp = parent.extend(candidate.token_id, 0.0);
439 hyp.log_prob = candidate.log_prob;
441 hyp.score = candidate.score;
442
443 let is_eos = self
445 .config
446 .eos_token_id
447 .map(|eos| candidate.token_id == eos)
448 .unwrap_or(false);
449
450 if is_eos && hyp.len() > self.config.min_length {
451 hyp.is_done = true;
453 eos_count += 1;
454 new_completed.push(hyp);
455 } else {
456 new_beams.push(hyp);
457 }
458 taken += 1;
459 }
460
461 let (kept_beams, maxlen_beams): (Vec<_>, Vec<_>) = new_beams
463 .into_iter()
464 .partition(|b| b.len() < self.config.max_length);
465 let new_beams = kept_beams;
466 for mut beam in maxlen_beams {
467 beam.is_done = true;
468 new_completed.push(beam);
469 }
470
471 let _ = eos_count; Ok(BeamState {
474 beams: new_beams,
475 completed: new_completed,
476 step: state.step + 1,
477 })
478 }
479
480 pub fn decode<F>(
485 &self,
486 bos_token_id: usize,
487 score_fn: F,
488 ) -> Result<BeamSearchResult, BeamSearchError>
489 where
490 F: Fn(&[&[usize]]) -> Result<Vec<Vec<f64>>, String>,
491 {
492 if self.config.beam_width == 0 {
493 return Err(BeamSearchError::ZeroBeamWidth);
494 }
495 if self.config.max_length == 0 {
496 return Err(BeamSearchError::MaxLengthTooShort);
497 }
498 if self.config.temperature <= 0.0 {
499 return Err(BeamSearchError::InvalidTemperature(self.config.temperature));
500 }
501
502 let mut state = self.initial_state(bos_token_id);
503 while !state.is_done(&self.config) {
504 if state.beams.is_empty() {
505 break;
506 }
507
508 let beam_seqs: Vec<&[usize]> =
510 state.beams.iter().map(|b| b.tokens.as_slice()).collect();
511
512 let raw_logits = score_fn(&beam_seqs).map_err(BeamSearchError::ScoringFunctionError)?;
513
514 let log_probs: Vec<Vec<f64>> = raw_logits
516 .into_iter()
517 .map(|row| {
518 let scaled = Self::apply_temperature(&row, self.config.temperature);
519 Self::log_softmax(&scaled)
520 })
521 .collect();
522
523 let input = BeamStepInput::new(log_probs);
524 state = self.step(state, &input)?;
525 }
526
527 let remaining: Vec<BeamHypothesis> = state.beams.drain(..).collect();
529 for mut beam in remaining {
530 beam.is_done = true;
531 state.completed.push(beam);
532 }
533
534 let mut eos_completed: usize = 0;
536 let mut max_len_completed: usize = 0;
537 for hyp in &state.completed {
538 if let Some(eos) = self.config.eos_token_id {
539 if hyp.tokens.last().copied() == Some(eos) {
540 eos_completed += 1;
541 } else {
542 max_len_completed += 1;
543 }
544 } else {
545 max_len_completed += 1;
546 }
547 }
548
549 let total_steps = state.step;
550
551 let mut hypotheses = state.completed;
553 hypotheses.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
554
555 let best_sequence = hypotheses
556 .first()
557 .map(|h| h.tokens.clone())
558 .unwrap_or_default();
559 let best_score = hypotheses
560 .first()
561 .map(|h| h.score)
562 .unwrap_or(f64::NEG_INFINITY);
563
564 let avg_sequence_length = if hypotheses.is_empty() {
565 0.0
566 } else {
567 hypotheses.iter().map(|h| h.len() as f64).sum::<f64>() / hypotheses.len() as f64
568 };
569
570 let score_range = if hypotheses.is_empty() {
571 (0.0, 0.0)
572 } else {
573 let min_score = hypotheses
574 .iter()
575 .map(|h| h.score)
576 .fold(f64::INFINITY, f64::min);
577 let max_score = hypotheses
578 .iter()
579 .map(|h| h.score)
580 .fold(f64::NEG_INFINITY, f64::max);
581 (min_score, max_score)
582 };
583
584 let stats = BeamSearchStats {
585 total_steps,
586 num_completed_at_eos: eos_completed,
587 num_completed_at_max_length: max_len_completed,
588 avg_sequence_length,
589 score_range,
590 };
591
592 Ok(BeamSearchResult {
593 hypotheses,
594 best_sequence,
595 best_score,
596 stats,
597 })
598 }
599
600 pub fn top_k_results(&self, state: &BeamState, k: usize) -> Vec<BeamHypothesis> {
602 let mut all: Vec<BeamHypothesis> = state
603 .beams
604 .iter()
605 .chain(state.completed.iter())
606 .cloned()
607 .collect();
608 all.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
609 all.truncate(k);
610 all
611 }
612
613 pub fn apply_temperature(logits: &[f64], temperature: f64) -> Vec<f64> {
615 if temperature == 1.0 {
616 return logits.to_vec();
617 }
618 let t = if temperature == 0.0 {
619 1e-8
620 } else {
621 temperature
622 };
623 logits.iter().map(|&x| x / t).collect()
624 }
625
626 pub fn log_softmax(logits: &[f64]) -> Vec<f64> {
631 if logits.is_empty() {
632 return Vec::new();
633 }
634 let max_val = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
635 let sum_exp: f64 = logits.iter().map(|&x| (x - max_val).exp()).sum();
636 let log_sum_exp = max_val + sum_exp.ln();
637 logits.iter().map(|&x| x - log_sum_exp).collect()
638 }
639
640 pub fn top_k_filter_logits(logits: &mut [f64], k: usize) {
642 if k == 0 || logits.is_empty() {
643 for v in logits.iter_mut() {
644 *v = f64::NEG_INFINITY;
645 }
646 return;
647 }
648 if k >= logits.len() {
649 return; }
651
652 let mut sorted: Vec<f64> = logits.to_owned();
654 sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(Ordering::Equal));
655 let threshold = sorted[k - 1];
656
657 let mut kept = 0usize;
659 for v in logits.iter_mut() {
660 if *v >= threshold && kept < k {
661 kept += 1;
662 } else {
663 *v = f64::NEG_INFINITY;
664 }
665 }
666 }
667}
668
#[cfg(test)]
mod tests {
    use super::*;

    /// Scoring function that gives every token of every beam the same logit, `-ln(vocab_size)`.
    fn uniform_score_fn(
        vocab_size: usize,
    ) -> impl Fn(&[&[usize]]) -> Result<Vec<Vec<f64>>, String> {
        let lp = -(vocab_size as f64).ln();
        move |beams: &[&[usize]]| Ok(beams.iter().map(|_| vec![lp; vocab_size]).collect())
    }

    #[test]
    fn test_beam_search_config_default() {
        let cfg = BeamSearchConfig::default();
        assert_eq!(cfg.beam_width, 4);
        assert_eq!(cfg.max_length, 50);
        assert_eq!(cfg.eos_token_id, None);
        assert_eq!(cfg.length_penalty, 1.0);
        assert_eq!(cfg.min_length, 1);
        assert_eq!(cfg.vocab_size, 1000);
        assert_eq!(cfg.temperature, 1.0);
        assert_eq!(cfg.top_k_filter, None);
    }

    #[test]
    fn test_beam_hypothesis_new() {
        let h = BeamHypothesis::new(0, -0.5);
        assert_eq!(h.len(), 1);
        assert_eq!(h.tokens, vec![0]);
        assert!(!h.is_done);
    }

    #[test]
    fn test_beam_hypothesis_extend() {
        let h = BeamHypothesis::new(0, -0.5);
        let h2 = h.extend(7, -1.0);
        assert_eq!(h2.len(), 2);
        assert_eq!(h2.tokens, vec![0, 7]);
        assert!((h2.log_prob - (-1.5)).abs() < 1e-10);
        assert!(!h2.is_done);
    }

    #[test]
    fn test_beam_hypothesis_length_penalized_score_alpha_one() {
        let h = BeamHypothesis::new(0, 0.0);
        let h2 = h.extend(1, -2.0);
        // log_prob -2.0 over 2 tokens with alpha = 1.0 -> -2.0 / 2 = -1.0
        let score = h2.length_penalized_score(1.0);
        assert!((score - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_beam_hypothesis_length_penalized_score_alpha_zero() {
        let h = BeamHypothesis::new(0, 0.0);
        let h2 = h.extend(1, -2.0);
        // alpha = 0.0 disables the penalty: the raw log-probability is returned.
        assert!((h2.length_penalized_score(0.0) - (-2.0)).abs() < 1e-10);
    }

    #[test]
    fn test_beam_step_input_from_logits() {
        let logits = vec![vec![1.0, 2.0, 3.0], vec![0.5, 0.5, 0.5]];
        let input = BeamStepInput::from_logits(logits, 1.0);
        for row in &input.log_probs {
            // Probabilities recovered from the log-probs must sum to 1.
            let sum: f64 = row.iter().map(|&lp| lp.exp()).sum();
            assert!((sum - 1.0).abs() < 1e-9, "sum was {sum}");
        }
    }

    #[test]
    fn test_beam_step_input_vocab_size() {
        let lp = vec![vec![0.1, 0.2, 0.7]; 3];
        let input = BeamStepInput::new(lp);
        assert_eq!(input.vocab_size(), 3);
        assert_eq!(input.num_beams(), 3);
    }

    #[test]
    fn test_beam_state_initial() {
        let state = BeamState::initial(4, 0);
        assert_eq!(state.beams.len(), 4);
        assert_eq!(state.completed.len(), 0);
        assert_eq!(state.step, 0);
        for b in &state.beams {
            assert_eq!(b.tokens, vec![0]);
        }
    }

    #[test]
    fn test_beam_state_is_done_max_length() {
        let cfg = BeamSearchConfig {
            max_length: 3,
            ..BeamSearchConfig::default()
        };
        let mut state = BeamState::initial(4, 0);
        assert!(!state.is_done(&cfg));
        state.step = 3;
        assert!(state.is_done(&cfg));
    }

    #[test]
    fn test_decoder_step_advances_state() {
        let cfg = BeamSearchConfig {
            beam_width: 2,
            vocab_size: 5,
            ..BeamSearchConfig::default()
        };
        let decoder = BeamSearchDecoder::new(cfg);
        let state = decoder.initial_state(0);
        let lp = BeamSearchDecoder::log_softmax(&[1.0; 5]);
        let input = BeamStepInput::new(vec![lp.clone(), lp]);
        let new_state = decoder.step(state, &input).expect("step failed");
        assert_eq!(new_state.step, 1);
    }

    #[test]
    fn test_decoder_step_beam_count() {
        let beam_width = 3;
        let vocab_size = 10;
        let cfg = BeamSearchConfig {
            beam_width,
            vocab_size,
            ..BeamSearchConfig::default()
        };
        let decoder = BeamSearchDecoder::new(cfg);
        let state = decoder.initial_state(0);
        let lp = BeamSearchDecoder::log_softmax(&vec![1.0; vocab_size]);
        let input = BeamStepInput::new(vec![lp; beam_width]);
        let new_state = decoder.step(state, &input).expect("step failed");
        // Every selected extension is either still live or completed.
        assert_eq!(
            new_state.beams.len() + new_state.completed.len(),
            beam_width
        );
    }

    #[test]
    fn test_decoder_step_eos_moves_to_completed() {
        let eos = 1_usize;
        let vocab_size = 5;
        let beam_width = 2;
        let cfg = BeamSearchConfig {
            beam_width,
            vocab_size,
            eos_token_id: Some(eos),
            min_length: 1,
            ..BeamSearchConfig::default()
        };
        let decoder = BeamSearchDecoder::new(cfg);
        let state = decoder.initial_state(0);

        // Only the EOS token is possible.
        let mut logits = vec![f64::NEG_INFINITY; vocab_size];
        logits[eos] = 100.0;
        let lp = BeamSearchDecoder::log_softmax(&logits);
        let input = BeamStepInput::new(vec![lp; beam_width]);

        let new_state = decoder.step(state, &input).expect("step failed");
        assert!(!new_state.completed.is_empty(), "expected completed beams");
        assert!(new_state
            .completed
            .iter()
            .all(|h| h.is_done && h.tokens.last() == Some(&eos)));
    }

    #[test]
    fn test_decoder_step_vocab_size_mismatch() {
        let cfg = BeamSearchConfig {
            beam_width: 2,
            vocab_size: 10,
            ..BeamSearchConfig::default()
        };
        let decoder = BeamSearchDecoder::new(cfg);
        let state = decoder.initial_state(0);
        // Rows of length 5 against a configured vocabulary of 10.
        let lp = vec![0.2; 5];
        let input = BeamStepInput::new(vec![lp; 2]);
        let result = decoder.step(state, &input);
        assert!(matches!(
            result,
            Err(BeamSearchError::VocabSizeMismatch { .. })
        ));
    }

    #[test]
    fn test_decoder_step_beam_width_mismatch() {
        let cfg = BeamSearchConfig {
            beam_width: 2,
            vocab_size: 3,
            ..BeamSearchConfig::default()
        };
        let decoder = BeamSearchDecoder::new(cfg);
        let state = decoder.initial_state(0);
        // One row for two live beams.
        let lp = BeamSearchDecoder::log_softmax(&[0.0; 3]);
        let input = BeamStepInput::new(vec![lp]);
        let result = decoder.step(state, &input);
        assert!(matches!(
            result,
            Err(BeamSearchError::BeamWidthMismatch {
                expected: 2,
                got: 1
            })
        ));
    }

    #[test]
    fn test_decoder_log_softmax_sums_to_one() {
        let logits = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let lsp = BeamSearchDecoder::log_softmax(&logits);
        let sum: f64 = lsp.iter().map(|&x| x.exp()).sum();
        assert!((sum - 1.0).abs() < 1e-9, "sum = {sum}");
    }

    #[test]
    fn test_decoder_top_k_filter() {
        let mut logits = vec![1.0, 5.0, 3.0, 2.0, 4.0];
        BeamSearchDecoder::top_k_filter_logits(&mut logits, 2);
        let kept: Vec<usize> = logits
            .iter()
            .enumerate()
            .filter(|(_, &v)| v != f64::NEG_INFINITY)
            .map(|(i, _)| i)
            .collect();
        assert_eq!(kept.len(), 2);
        // Indices of the two largest values, 5.0 and 4.0.
        assert!(kept.contains(&1));
        assert!(kept.contains(&4));
    }

    #[test]
    fn test_decoder_decode_simple() {
        let vocab_size = 8;
        let cfg = BeamSearchConfig {
            beam_width: 2,
            max_length: 5,
            vocab_size,
            ..BeamSearchConfig::default()
        };
        let decoder = BeamSearchDecoder::new(cfg);
        let score_fn = uniform_score_fn(vocab_size);
        let result = decoder.decode(0, score_fn);
        assert!(result.is_ok(), "decode returned error: {:?}", result.err());
        let result = result.expect("checked above");
        assert!(!result.hypotheses.is_empty());
        assert_eq!(result.best_sequence.first(), Some(&0));
    }

    #[test]
    fn test_decode_rejects_invalid_config() {
        let decode_with =
            |cfg: BeamSearchConfig| BeamSearchDecoder::new(cfg).decode(0, uniform_score_fn(4));
        let base = BeamSearchConfig {
            vocab_size: 4,
            ..BeamSearchConfig::default()
        };
        assert!(matches!(
            decode_with(BeamSearchConfig {
                beam_width: 0,
                ..base.clone()
            }),
            Err(BeamSearchError::ZeroBeamWidth)
        ));
        assert!(matches!(
            decode_with(BeamSearchConfig {
                max_length: 0,
                ..base.clone()
            }),
            Err(BeamSearchError::MaxLengthTooShort)
        ));
        assert!(matches!(
            decode_with(BeamSearchConfig {
                temperature: 0.0,
                ..base.clone()
            }),
            Err(BeamSearchError::InvalidTemperature(_))
        ));
        assert!(matches!(
            decode_with(BeamSearchConfig {
                temperature: -1.0,
                ..base
            }),
            Err(BeamSearchError::InvalidTemperature(_))
        ));
    }

    #[test]
    fn test_beam_search_result_best() {
        let h1 = BeamHypothesis {
            tokens: vec![0, 1],
            log_prob: -1.0,
            score: -1.0,
            is_done: true,
        };
        let h2 = BeamHypothesis {
            tokens: vec![0, 2],
            log_prob: -0.5,
            score: -0.5,
            is_done: true,
        };
        let result = BeamSearchResult {
            best_sequence: h2.tokens.clone(),
            best_score: h2.score,
            hypotheses: vec![h2.clone(), h1.clone()],
            stats: BeamSearchStats {
                total_steps: 1,
                num_completed_at_eos: 0,
                num_completed_at_max_length: 2,
                avg_sequence_length: 2.0,
                score_range: (-1.0, -0.5),
            },
        };
        let best = result.best().expect("should have best");
        assert_eq!(best.score, -0.5);
    }

    #[test]
    fn test_beam_search_stats() {
        let vocab_size = 4;
        let cfg = BeamSearchConfig {
            beam_width: 2,
            max_length: 4,
            vocab_size,
            ..BeamSearchConfig::default()
        };
        let decoder = BeamSearchDecoder::new(cfg);
        let score_fn = uniform_score_fn(vocab_size);
        let result = decoder.decode(0, score_fn).expect("decode failed");
        assert!(result.stats.total_steps > 0);
        // Every returned hypothesis is counted exactly once.
        assert_eq!(
            result.stats.num_completed_at_eos + result.stats.num_completed_at_max_length,
            result.hypotheses.len()
        );
    }

    #[test]
    fn test_top_k_results_sorted() {
        let decoder = BeamSearchDecoder::with_default();
        let make_hyp = |score: f64| BeamHypothesis {
            tokens: vec![0],
            log_prob: score,
            score,
            is_done: false,
        };
        let state = BeamState {
            beams: vec![make_hyp(-2.0), make_hyp(-0.5), make_hyp(-3.0)],
            completed: vec![make_hyp(-1.0)],
            step: 1,
        };
        let top = decoder.top_k_results(&state, 3);
        assert_eq!(top.len(), 3);
        // Descending by score.
        assert!(top[0].score >= top[1].score);
        assert!(top[1].score >= top[2].score);
        assert!((top[0].score - (-0.5)).abs() < 1e-10);
    }

    #[test]
    fn test_decoder_temperature_scaling() {
        let logits = vec![1.0, 2.0, 3.0];
        let lp1 =
            BeamSearchDecoder::log_softmax(&BeamSearchDecoder::apply_temperature(&logits, 1.0));
        let lp2 =
            BeamSearchDecoder::log_softmax(&BeamSearchDecoder::apply_temperature(&logits, 2.0));
        // Spread = max log-prob minus min log-prob.
        let spread1 = lp1.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
            - lp1.iter().cloned().fold(f64::INFINITY, f64::min);
        let spread2 = lp2.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
            - lp2.iter().cloned().fold(f64::INFINITY, f64::min);
        assert!(
            spread2 < spread1,
            "temperature=2.0 should flatten distribution"
        );
    }

    #[test]
    fn test_beam_search_error_display() {
        let errors = vec![
            BeamSearchError::EmptyBeams,
            BeamSearchError::BeamWidthMismatch {
                expected: 4,
                got: 2,
            },
            BeamSearchError::VocabSizeMismatch {
                expected: 1000,
                got: 500,
            },
            BeamSearchError::ZeroBeamWidth,
            BeamSearchError::MaxLengthTooShort,
            BeamSearchError::ScoringFunctionError("test error".to_string()),
            BeamSearchError::InvalidTemperature(-1.0),
        ];
        for err in &errors {
            let s = err.to_string();
            assert!(!s.is_empty(), "display for {err:?} was empty");
        }
    }
}