// tensorlogic_infer/beam_search.rs
//! Beam Search Decoder for sequence generation.
//!
//! Implements beam search, a heuristic search algorithm that explores the best-B
//! candidate sequences (beams) at each decoding step, keeping only the highest-scoring
//! hypotheses by cumulative log-probability.

use std::cmp::Ordering;
use std::collections::BinaryHeap;

// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------

/// Tuning knobs for beam search decoding.
#[derive(Debug, Clone)]
pub struct BeamSearchConfig {
    /// How many candidate hypotheses (beams) survive pruning at each step.
    pub beam_width: usize,
    /// Hard cap on sequence length (inclusive).
    pub max_length: usize,
    /// End-of-sequence token, if any. A beam that emits this token is moved
    /// to the completed list (subject to `min_length`).
    pub eos_token_id: Option<usize>,
    /// Exponent α of the length penalty: `score = log_prob / length^α`.
    /// `1.0` is plain length normalisation; `0.0` turns it off.
    pub length_penalty: f64,
    /// Shortest sequence length at which EOS may terminate a beam.
    pub min_length: usize,
    /// Size of the token vocabulary.
    pub vocab_size: usize,
    /// Softmax temperature applied to logits; `1.0` leaves them unchanged.
    pub temperature: f64,
    /// If set, keep only the `k` largest logits at each step.
    pub top_k_filter: Option<usize>,
}

37impl Default for BeamSearchConfig {
38    fn default() -> Self {
39        Self {
40            beam_width: 4,
41            max_length: 50,
42            eos_token_id: None,
43            length_penalty: 1.0,
44            min_length: 1,
45            vocab_size: 1000,
46            temperature: 1.0,
47            top_k_filter: None,
48        }
49    }
50}
51
// ---------------------------------------------------------------------------
// BeamHypothesis
// ---------------------------------------------------------------------------

/// One candidate sequence tracked during beam search.
#[derive(Debug, Clone)]
pub struct BeamHypothesis {
    /// Token IDs generated so far, including the seed/BOS token.
    pub tokens: Vec<usize>,
    /// Cumulative log-probability of the whole sequence.
    pub log_prob: f64,
    /// Length-penalised score used when ranking hypotheses.
    pub score: f64,
    /// Set once the hypothesis terminates (EOS emitted or max length hit).
    pub is_done: bool,
}

69impl BeamHypothesis {
70    /// Create a new hypothesis seeded with a single token.
71    pub fn new(initial_token: usize, log_prob: f64) -> Self {
72        let tokens = vec![initial_token];
73        let score = log_prob; // length = 1, any alpha => log_prob / 1.0
74        Self {
75            tokens,
76            log_prob,
77            score,
78            is_done: false,
79        }
80    }
81
82    /// Extend this hypothesis by one token, returning a new hypothesis.
83    pub fn extend(&self, token: usize, token_log_prob: f64) -> Self {
84        let mut tokens = self.tokens.clone();
85        tokens.push(token);
86        let log_prob = self.log_prob + token_log_prob;
87        let score = log_prob; // score is updated by caller via length_penalized_score
88        Self {
89            tokens,
90            log_prob,
91            score,
92            is_done: false,
93        }
94    }
95
96    /// Compute the length-penalised score: `log_prob / length^alpha`.
97    pub fn length_penalized_score(&self, alpha: f64) -> f64 {
98        let len = self.tokens.len() as f64;
99        if alpha == 0.0 || len == 0.0 {
100            self.log_prob
101        } else {
102            self.log_prob / len.powf(alpha)
103        }
104    }
105
106    /// Number of tokens in this hypothesis.
107    pub fn len(&self) -> usize {
108        self.tokens.len()
109    }
110
111    /// Returns `true` if the hypothesis contains no tokens.
112    pub fn is_empty(&self) -> bool {
113        self.tokens.is_empty()
114    }
115}
116
// ---------------------------------------------------------------------------
// BeamStepInput
// ---------------------------------------------------------------------------

/// Per-beam log-probabilities (or raw logits) for one decoding step.
///
/// Shape: `[beam_width][vocab_size]` — one row per beam.
pub struct BeamStepInput {
    /// `log_probs[beam_i][vocab_j]` = log P(token j | history of beam i).
    pub log_probs: Vec<Vec<f64>>,
}

129impl BeamStepInput {
130    /// Construct directly from pre-computed log-probabilities.
131    pub fn new(log_probs: Vec<Vec<f64>>) -> Self {
132        Self { log_probs }
133    }
134
135    /// Construct from raw logits: apply temperature scaling then log-softmax.
136    pub fn from_logits(logits: Vec<Vec<f64>>, temperature: f64) -> Self {
137        let log_probs = logits
138            .into_iter()
139            .map(|row| {
140                let scaled = BeamSearchDecoder::apply_temperature(&row, temperature);
141                BeamSearchDecoder::log_softmax(&scaled)
142            })
143            .collect();
144        Self { log_probs }
145    }
146
147    /// Number of beams (rows) in this input.
148    pub fn num_beams(&self) -> usize {
149        self.log_probs.len()
150    }
151
152    /// Vocabulary size inferred from the first row (0 if empty).
153    pub fn vocab_size(&self) -> usize {
154        self.log_probs.first().map(|r| r.len()).unwrap_or(0)
155    }
156}
157
158// ---------------------------------------------------------------------------
159// BeamState
160// ---------------------------------------------------------------------------
161
162/// Complete state of a beam search at a given decoding step.
163#[derive(Debug, Clone)]
164pub struct BeamState {
165    /// Currently active (non-terminated) beams.
166    pub beams: Vec<BeamHypothesis>,
167    /// Completed beams (those that emitted EOS or were finalised at max length).
168    pub completed: Vec<BeamHypothesis>,
169    /// Current step index (0-based; incremented after each call to `step`).
170    pub step: usize,
171}
172
173impl BeamState {
174    /// Create an initial state: `beam_width` identical hypotheses seeded with `bos_token_id`.
175    pub fn initial(beam_width: usize, bos_token_id: usize) -> Self {
176        let beams = (0..beam_width)
177            .map(|_| BeamHypothesis::new(bos_token_id, 0.0))
178            .collect();
179        Self {
180            beams,
181            completed: Vec::new(),
182            step: 0,
183        }
184    }
185
186    /// Returns `true` if search is complete: enough completed beams or step limit reached.
187    pub fn is_done(&self, config: &BeamSearchConfig) -> bool {
188        self.completed.len() >= config.beam_width || self.step >= config.max_length
189    }
190
191    /// Return the highest-scored hypothesis across active and completed beams.
192    pub fn best_hypothesis(&self) -> Option<&BeamHypothesis> {
193        let all = self.beams.iter().chain(self.completed.iter());
194        all.max_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(Ordering::Equal))
195    }
196}
197
// ---------------------------------------------------------------------------
// Candidate (internal helper for BinaryHeap)
// ---------------------------------------------------------------------------

/// Internal (beam, token) expansion candidate ranked during one beam step.
#[derive(Debug)]
struct Candidate {
    /// Index of the parent beam being extended.
    beam_idx: usize,
    /// Token appended to the parent beam.
    token_id: usize,
    /// Cumulative log-probability if this candidate is chosen.
    log_prob: f64,
    /// Length-penalised score used for heap ordering.
    score: f64,
}

211impl PartialEq for Candidate {
212    fn eq(&self, other: &Self) -> bool {
213        self.score == other.score
214    }
215}
216
217impl Eq for Candidate {}
218
219impl PartialOrd for Candidate {
220    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
221        Some(self.cmp(other))
222    }
223}
224
225impl Ord for Candidate {
226    fn cmp(&self, other: &Self) -> Ordering {
227        self.score
228            .partial_cmp(&other.score)
229            .unwrap_or(Ordering::Equal)
230    }
231}
232
// ---------------------------------------------------------------------------
// BeamSearchError
// ---------------------------------------------------------------------------

/// Failure modes of beam search.
#[derive(Debug, Clone)]
pub enum BeamSearchError {
    /// No active beams remain.
    EmptyBeams,
    /// `BeamStepInput` carries a different number of beams than the state.
    BeamWidthMismatch { expected: usize, got: usize },
    /// `BeamStepInput` rows do not match the configured vocabulary size.
    VocabSizeMismatch { expected: usize, got: usize },
    /// A `beam_width` of zero is invalid.
    ZeroBeamWidth,
    /// `max_length` is too short to produce any output.
    MaxLengthTooShort,
    /// The user-supplied scoring function reported an error.
    ScoringFunctionError(String),
    /// Temperature must be strictly positive.
    InvalidTemperature(f64),
}

256impl std::fmt::Display for BeamSearchError {
257    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        match self {
259            BeamSearchError::EmptyBeams => write!(f, "beam search has no active beams"),
260            BeamSearchError::BeamWidthMismatch { expected, got } => write!(
261                f,
262                "beam width mismatch: expected {expected} beams, got {got}"
263            ),
264            BeamSearchError::VocabSizeMismatch { expected, got } => write!(
265                f,
266                "vocab size mismatch: expected {expected} tokens, got {got}"
267            ),
268            BeamSearchError::ZeroBeamWidth => write!(f, "beam_width must be at least 1"),
269            BeamSearchError::MaxLengthTooShort => {
270                write!(f, "max_length must be at least 1")
271            }
272            BeamSearchError::ScoringFunctionError(msg) => {
273                write!(f, "scoring function error: {msg}")
274            }
275            BeamSearchError::InvalidTemperature(t) => {
276                write!(f, "temperature must be positive, got {t}")
277            }
278        }
279    }
280}
281
282impl std::error::Error for BeamSearchError {}
283
// ---------------------------------------------------------------------------
// BeamSearchStats
// ---------------------------------------------------------------------------

/// Aggregate statistics from one complete beam search run.
#[derive(Debug, Clone)]
pub struct BeamSearchStats {
    /// Total number of decoding steps taken.
    pub total_steps: usize,
    /// Beams that terminated by generating the EOS token.
    pub num_completed_at_eos: usize,
    /// Beams that terminated by hitting `max_length`.
    pub num_completed_at_max_length: usize,
    /// Mean sequence length over all final hypotheses.
    pub avg_sequence_length: f64,
    /// `(min_score, max_score)` over all final hypotheses.
    pub score_range: (f64, f64),
}

303// ---------------------------------------------------------------------------
304// BeamSearchResult
305// ---------------------------------------------------------------------------
306
307/// Final output of a complete beam search run.
308#[derive(Debug, Clone)]
309pub struct BeamSearchResult {
310    /// All final hypotheses, sorted by score descending.
311    pub hypotheses: Vec<BeamHypothesis>,
312    /// Token sequence of the best hypothesis.
313    pub best_sequence: Vec<usize>,
314    /// Score of the best hypothesis.
315    pub best_score: f64,
316    /// Aggregate statistics.
317    pub stats: BeamSearchStats,
318}
319
320impl BeamSearchResult {
321    /// Return a reference to the best hypothesis, or `None` if empty.
322    pub fn best(&self) -> Option<&BeamHypothesis> {
323        self.hypotheses.first()
324    }
325}
326
327// ---------------------------------------------------------------------------
328// BeamSearchDecoder
329// ---------------------------------------------------------------------------
330
331/// Beam search decoder.
332pub struct BeamSearchDecoder {
333    /// Configuration controlling all search behaviour.
334    pub config: BeamSearchConfig,
335}
336
337impl BeamSearchDecoder {
338    /// Create a decoder with the supplied configuration.
339    pub fn new(config: BeamSearchConfig) -> Self {
340        Self { config }
341    }
342
343    /// Create a decoder with default configuration.
344    pub fn with_default() -> Self {
345        Self::new(BeamSearchConfig::default())
346    }
347
348    /// Create an initial `BeamState` seeded with `bos_token_id`.
349    pub fn initial_state(&self, bos_token_id: usize) -> BeamState {
350        BeamState::initial(self.config.beam_width, bos_token_id)
351    }
352
353    /// Perform one step of beam search.
354    ///
355    /// Given the current `state` and per-beam log-probabilities `input`, advance
356    /// each beam by one token, prune to the top-`beam_width` candidates, and
357    /// handle EOS/completion logic.
358    pub fn step(
359        &self,
360        mut state: BeamState,
361        input: &BeamStepInput,
362    ) -> Result<BeamState, BeamSearchError> {
363        if self.config.beam_width == 0 {
364            return Err(BeamSearchError::ZeroBeamWidth);
365        }
366        if state.beams.is_empty() {
367            // All beams may have already completed; nothing to advance.
368            state.step += 1;
369            return Ok(state);
370        }
371
372        // Validate input dimensions.
373        if input.num_beams() != state.beams.len() {
374            return Err(BeamSearchError::BeamWidthMismatch {
375                expected: state.beams.len(),
376                got: input.num_beams(),
377            });
378        }
379        let vocab_size = self.config.vocab_size;
380        for (i, row) in input.log_probs.iter().enumerate() {
381            if row.len() != vocab_size {
382                return Err(BeamSearchError::VocabSizeMismatch {
383                    expected: vocab_size,
384                    got: row.len(),
385                });
386            }
387            let _ = i;
388        }
389
390        // Build a max-heap of all (beam, token) candidates.
391        let mut heap: BinaryHeap<Candidate> = BinaryHeap::new();
392
393        for (beam_idx, beam) in state.beams.iter().enumerate() {
394            let mut lp: Vec<f64> = input.log_probs[beam_idx].clone();
395
396            // Apply top-k filter if configured.
397            if let Some(k) = self.config.top_k_filter {
398                Self::top_k_filter_logits(&mut lp, k);
399            }
400
401            for (token_id, &token_lp) in lp.iter().enumerate() {
402                // Skip -inf entries (filtered out by top-k).
403                if token_lp == f64::NEG_INFINITY {
404                    continue;
405                }
406                let new_log_prob = beam.log_prob + token_lp;
407                // Compute penalised score based on hypothetical new length.
408                let new_len = (beam.tokens.len() + 1) as f64;
409                let score = if self.config.length_penalty == 0.0 {
410                    new_log_prob
411                } else {
412                    new_log_prob / new_len.powf(self.config.length_penalty)
413                };
414
415                heap.push(Candidate {
416                    beam_idx,
417                    token_id,
418                    log_prob: new_log_prob,
419                    score,
420                });
421            }
422        }
423
424        // Select top beam_width candidates.
425        let desired = self.config.beam_width;
426        let mut new_beams: Vec<BeamHypothesis> = Vec::with_capacity(desired);
427        let mut new_completed: Vec<BeamHypothesis> = state.completed.clone();
428        let mut eos_count: usize = 0;
429        let mut taken: usize = 0;
430
431        while taken < desired {
432            let candidate = match heap.pop() {
433                Some(c) => c,
434                None => break,
435            };
436
437            let parent = &state.beams[candidate.beam_idx];
438            let mut hyp = parent.extend(candidate.token_id, 0.0);
439            // Override the log_prob computed in extend (which adds 0.0) with the real value.
440            hyp.log_prob = candidate.log_prob;
441            hyp.score = candidate.score;
442
443            // Check whether this is an EOS token.
444            let is_eos = self
445                .config
446                .eos_token_id
447                .map(|eos| candidate.token_id == eos)
448                .unwrap_or(false);
449
450            if is_eos && hyp.len() > self.config.min_length {
451                // +1 because len includes the BOS token.
452                hyp.is_done = true;
453                eos_count += 1;
454                new_completed.push(hyp);
455            } else {
456                new_beams.push(hyp);
457            }
458            taken += 1;
459        }
460
461        // Finalise any beams that have reached max_length.
462        let (kept_beams, maxlen_beams): (Vec<_>, Vec<_>) = new_beams
463            .into_iter()
464            .partition(|b| b.len() < self.config.max_length);
465        let new_beams = kept_beams;
466        for mut beam in maxlen_beams {
467            beam.is_done = true;
468            new_completed.push(beam);
469        }
470
471        let _ = eos_count; // suppress unused warning
472
473        Ok(BeamState {
474            beams: new_beams,
475            completed: new_completed,
476            step: state.step + 1,
477        })
478    }
479
480    /// Run a full beam search.
481    ///
482    /// `score_fn` is called at each step with the current beam sequences and must
483    /// return log-probabilities of shape `[num_active_beams][vocab_size]`.
484    pub fn decode<F>(
485        &self,
486        bos_token_id: usize,
487        score_fn: F,
488    ) -> Result<BeamSearchResult, BeamSearchError>
489    where
490        F: Fn(&[&[usize]]) -> Result<Vec<Vec<f64>>, String>,
491    {
492        if self.config.beam_width == 0 {
493            return Err(BeamSearchError::ZeroBeamWidth);
494        }
495        if self.config.max_length == 0 {
496            return Err(BeamSearchError::MaxLengthTooShort);
497        }
498        if self.config.temperature <= 0.0 {
499            return Err(BeamSearchError::InvalidTemperature(self.config.temperature));
500        }
501
502        let mut state = self.initial_state(bos_token_id);
503        while !state.is_done(&self.config) {
504            if state.beams.is_empty() {
505                break;
506            }
507
508            // Build token slices for the score function.
509            let beam_seqs: Vec<&[usize]> =
510                state.beams.iter().map(|b| b.tokens.as_slice()).collect();
511
512            let raw_logits = score_fn(&beam_seqs).map_err(BeamSearchError::ScoringFunctionError)?;
513
514            // Apply temperature and convert to log-probabilities.
515            let log_probs: Vec<Vec<f64>> = raw_logits
516                .into_iter()
517                .map(|row| {
518                    let scaled = Self::apply_temperature(&row, self.config.temperature);
519                    Self::log_softmax(&scaled)
520                })
521                .collect();
522
523            let input = BeamStepInput::new(log_probs);
524            state = self.step(state, &input)?;
525        }
526
527        // Finalise any remaining active beams.
528        let remaining: Vec<BeamHypothesis> = state.beams.drain(..).collect();
529        for mut beam in remaining {
530            beam.is_done = true;
531            state.completed.push(beam);
532        }
533
534        // Count completion reasons.
535        let mut eos_completed: usize = 0;
536        let mut max_len_completed: usize = 0;
537        for hyp in &state.completed {
538            if let Some(eos) = self.config.eos_token_id {
539                if hyp.tokens.last().copied() == Some(eos) {
540                    eos_completed += 1;
541                } else {
542                    max_len_completed += 1;
543                }
544            } else {
545                max_len_completed += 1;
546            }
547        }
548
549        let total_steps = state.step;
550
551        // Sort completed by score descending.
552        let mut hypotheses = state.completed;
553        hypotheses.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
554
555        let best_sequence = hypotheses
556            .first()
557            .map(|h| h.tokens.clone())
558            .unwrap_or_default();
559        let best_score = hypotheses
560            .first()
561            .map(|h| h.score)
562            .unwrap_or(f64::NEG_INFINITY);
563
564        let avg_sequence_length = if hypotheses.is_empty() {
565            0.0
566        } else {
567            hypotheses.iter().map(|h| h.len() as f64).sum::<f64>() / hypotheses.len() as f64
568        };
569
570        let score_range = if hypotheses.is_empty() {
571            (0.0, 0.0)
572        } else {
573            let min_score = hypotheses
574                .iter()
575                .map(|h| h.score)
576                .fold(f64::INFINITY, f64::min);
577            let max_score = hypotheses
578                .iter()
579                .map(|h| h.score)
580                .fold(f64::NEG_INFINITY, f64::max);
581            (min_score, max_score)
582        };
583
584        let stats = BeamSearchStats {
585            total_steps,
586            num_completed_at_eos: eos_completed,
587            num_completed_at_max_length: max_len_completed,
588            avg_sequence_length,
589            score_range,
590        };
591
592        Ok(BeamSearchResult {
593            hypotheses,
594            best_sequence,
595            best_score,
596            stats,
597        })
598    }
599
600    /// Extract the top-`k` hypotheses from a beam state, sorted by score descending.
601    pub fn top_k_results(&self, state: &BeamState, k: usize) -> Vec<BeamHypothesis> {
602        let mut all: Vec<BeamHypothesis> = state
603            .beams
604            .iter()
605            .chain(state.completed.iter())
606            .cloned()
607            .collect();
608        all.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
609        all.truncate(k);
610        all
611    }
612
613    /// Apply temperature scaling to logits: `logits[i] / temperature`.
614    pub fn apply_temperature(logits: &[f64], temperature: f64) -> Vec<f64> {
615        if temperature == 1.0 {
616            return logits.to_vec();
617        }
618        let t = if temperature == 0.0 {
619            1e-8
620        } else {
621            temperature
622        };
623        logits.iter().map(|&x| x / t).collect()
624    }
625
626    /// Compute numerically stable log-softmax.
627    ///
628    /// Uses the log-sum-exp trick:
629    /// `lse = max + log(sum(exp(x - max)))`,  `log_softmax(x_i) = x_i - lse`.
630    pub fn log_softmax(logits: &[f64]) -> Vec<f64> {
631        if logits.is_empty() {
632            return Vec::new();
633        }
634        let max_val = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
635        let sum_exp: f64 = logits.iter().map(|&x| (x - max_val).exp()).sum();
636        let log_sum_exp = max_val + sum_exp.ln();
637        logits.iter().map(|&x| x - log_sum_exp).collect()
638    }
639
640    /// Zero out all but the top-`k` logits in place (set others to `NEG_INFINITY`).
641    pub fn top_k_filter_logits(logits: &mut [f64], k: usize) {
642        if k == 0 || logits.is_empty() {
643            for v in logits.iter_mut() {
644                *v = f64::NEG_INFINITY;
645            }
646            return;
647        }
648        if k >= logits.len() {
649            return; // Nothing to filter.
650        }
651
652        // Find the k-th largest value as a threshold.
653        let mut sorted: Vec<f64> = logits.to_owned();
654        sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(Ordering::Equal));
655        let threshold = sorted[k - 1];
656
657        // Keep only those >= threshold (tie-breaking: we allow exactly k through).
658        let mut kept = 0usize;
659        for v in logits.iter_mut() {
660            if *v >= threshold && kept < k {
661                kept += 1;
662            } else {
663                *v = f64::NEG_INFINITY;
664            }
665        }
666    }
667}
668
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: build a constant score function that returns uniform log-probs.
    ///
    /// Every beam receives `log(1/vocab_size)` for every token; the rows are
    /// already normalised, so the decoder's log-softmax leaves them unchanged.
    fn uniform_score_fn(
        vocab_size: usize,
    ) -> impl Fn(&[&[usize]]) -> Result<Vec<Vec<f64>>, String> {
        let lp = -(vocab_size as f64).ln();
        move |beams: &[&[usize]]| Ok(beams.iter().map(|_| vec![lp; vocab_size]).collect())
    }

    #[test]
    fn test_beam_search_config_default() {
        // Pin the documented default configuration values.
        let cfg = BeamSearchConfig::default();
        assert_eq!(cfg.beam_width, 4);
        assert_eq!(cfg.max_length, 50);
        assert_eq!(cfg.eos_token_id, None);
        assert_eq!(cfg.length_penalty, 1.0);
        assert_eq!(cfg.temperature, 1.0);
    }

    #[test]
    fn test_beam_hypothesis_new() {
        // A fresh hypothesis holds exactly the seed token and is not done.
        let h = BeamHypothesis::new(0, -0.5);
        assert_eq!(h.len(), 1);
        assert_eq!(h.tokens, vec![0]);
        assert!(!h.is_done);
    }

    #[test]
    fn test_beam_hypothesis_extend() {
        // Extending appends the token and accumulates the log-probability.
        let h = BeamHypothesis::new(0, -0.5);
        let h2 = h.extend(7, -1.0);
        assert_eq!(h2.len(), 2);
        assert_eq!(h2.tokens, vec![0, 7]);
        assert!((h2.log_prob - (-1.5)).abs() < 1e-10);
        assert!(!h2.is_done);
    }

    #[test]
    fn test_beam_hypothesis_length_penalized_score_no_penalty() {
        // NOTE(review): despite the "no_penalty" name, alpha = 1.0 applies
        // full length normalisation: score = log_prob / length.
        let h = BeamHypothesis::new(0, 0.0);
        let h2 = h.extend(1, -2.0);
        // log_prob = -2.0, length = 2 => score = -1.0
        let score = h2.length_penalized_score(1.0);
        assert!((score - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_beam_step_input_from_logits() {
        let logits = vec![vec![1.0, 2.0, 3.0], vec![0.5, 0.5, 0.5]];
        let input = BeamStepInput::from_logits(logits, 1.0);
        // Each row must be a valid log-distribution: its exp values sum to ~1.
        for row in &input.log_probs {
            let sum: f64 = row.iter().map(|&lp| lp.exp()).sum();
            assert!((sum - 1.0).abs() < 1e-9, "sum was {sum}");
        }
    }

    #[test]
    fn test_beam_step_input_vocab_size() {
        // Dimensions are inferred from the nested vector shape.
        let lp = vec![vec![0.1, 0.2, 0.7]; 3];
        let input = BeamStepInput::new(lp);
        assert_eq!(input.vocab_size(), 3);
        assert_eq!(input.num_beams(), 3);
    }

    #[test]
    fn test_beam_state_initial() {
        // The initial state has beam_width identical BOS-seeded hypotheses.
        let state = BeamState::initial(4, 0);
        assert_eq!(state.beams.len(), 4);
        assert_eq!(state.completed.len(), 0);
        assert_eq!(state.step, 0);
        for b in &state.beams {
            assert_eq!(b.tokens, vec![0]);
        }
    }

    #[test]
    fn test_beam_state_is_done_max_length() {
        // Reaching the step limit terminates the search.
        let cfg = BeamSearchConfig {
            max_length: 3,
            ..BeamSearchConfig::default()
        };
        let mut state = BeamState::initial(4, 0);
        assert!(!state.is_done(&cfg));
        state.step = 3;
        assert!(state.is_done(&cfg));
    }

    #[test]
    fn test_decoder_step_advances_state() {
        // A single step increments the step counter.
        let cfg = BeamSearchConfig {
            beam_width: 2,
            vocab_size: 5,
            ..BeamSearchConfig::default()
        };
        let decoder = BeamSearchDecoder::new(cfg);
        let state = decoder.initial_state(0);
        let lp = BeamSearchDecoder::log_softmax(&[1.0; 5]);
        let input = BeamStepInput::new(vec![lp.clone(), lp]);
        let new_state = decoder.step(state, &input).expect("step failed");
        assert_eq!(new_state.step, 1);
    }

    #[test]
    fn test_decoder_step_beam_count() {
        // Pruning keeps exactly beam_width hypotheses per step, split between
        // the active and completed lists.
        let beam_width = 3;
        let vocab_size = 10;
        let cfg = BeamSearchConfig {
            beam_width,
            vocab_size,
            ..BeamSearchConfig::default()
        };
        let decoder = BeamSearchDecoder::new(cfg);
        let state = decoder.initial_state(0);
        let lp = BeamSearchDecoder::log_softmax(&vec![1.0; vocab_size]);
        let input = BeamStepInput::new(vec![lp; beam_width]);
        let new_state = decoder.step(state, &input).expect("step failed");
        // Active + completed should together contain beam_width hypotheses.
        assert_eq!(
            new_state.beams.len() + new_state.completed.len(),
            beam_width
        );
    }

    #[test]
    fn test_decoder_step_eos_moves_to_completed() {
        let eos = 1_usize;
        let vocab_size = 5;
        let beam_width = 2;
        let cfg = BeamSearchConfig {
            beam_width,
            vocab_size,
            eos_token_id: Some(eos),
            min_length: 1,
            ..BeamSearchConfig::default()
        };
        let decoder = BeamSearchDecoder::new(cfg);
        let state = decoder.initial_state(0);

        // Strongly bias logits towards token 1 (EOS) for all beams.
        let mut logits = vec![f64::NEG_INFINITY; vocab_size];
        logits[eos] = 100.0; // overwhelming preference for EOS
        let lp = BeamSearchDecoder::log_softmax(&logits);
        let input = BeamStepInput::new(vec![lp; beam_width]);

        let new_state = decoder.step(state, &input).expect("step failed");
        // With EOS strongly preferred, we expect some beams to complete.
        // (All beam_width candidates emit EOS, so they all go to completed.)
        assert!(!new_state.completed.is_empty(), "expected completed beams");
    }

    #[test]
    fn test_decoder_step_vocab_size_mismatch() {
        // Rows shorter than config.vocab_size must be rejected.
        let cfg = BeamSearchConfig {
            beam_width: 2,
            vocab_size: 10,
            ..BeamSearchConfig::default()
        };
        let decoder = BeamSearchDecoder::new(cfg);
        let state = decoder.initial_state(0);
        // Provide 5 tokens instead of 10.
        let lp = vec![0.2; 5];
        let input = BeamStepInput::new(vec![lp; 2]);
        let result = decoder.step(state, &input);
        assert!(matches!(
            result,
            Err(BeamSearchError::VocabSizeMismatch { .. })
        ));
    }

    #[test]
    fn test_decoder_log_softmax_sums_to_one() {
        // exp(log_softmax(x)) must form a probability distribution.
        let logits = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let lsp = BeamSearchDecoder::log_softmax(&logits);
        let sum: f64 = lsp.iter().map(|&x| x.exp()).sum();
        assert!((sum - 1.0).abs() < 1e-9, "sum = {sum}");
    }

    #[test]
    fn test_decoder_top_k_filter() {
        let mut logits = vec![1.0, 5.0, 3.0, 2.0, 4.0];
        BeamSearchDecoder::top_k_filter_logits(&mut logits, 2);
        // Only the two largest (5.0 at index 1, 4.0 at index 4) should survive.
        let non_neg_inf: Vec<usize> = logits
            .iter()
            .enumerate()
            .filter(|(_, &v)| v != f64::NEG_INFINITY)
            .map(|(i, _)| i)
            .collect();
        assert_eq!(non_neg_inf.len(), 2);
        // Indices 1 and 4 are kept.
        assert!(non_neg_inf.contains(&1));
        assert!(non_neg_inf.contains(&4));
    }

872    #[test]
873    fn test_decoder_decode_simple() {
874        let vocab_size = 8;
875        let cfg = BeamSearchConfig {
876            beam_width: 2,
877            max_length: 5,
878            vocab_size,
879            ..BeamSearchConfig::default()
880        };
881        let decoder = BeamSearchDecoder::new(cfg);
882        let score_fn = uniform_score_fn(vocab_size);
883        let result = decoder.decode(0, score_fn);
884        assert!(result.is_ok(), "decode returned error: {:?}", result.err());
885    }
886
887    #[test]
888    fn test_beam_search_result_best() {
889        let h1 = BeamHypothesis {
890            tokens: vec![0, 1],
891            log_prob: -1.0,
892            score: -1.0,
893            is_done: true,
894        };
895        let h2 = BeamHypothesis {
896            tokens: vec![0, 2],
897            log_prob: -0.5,
898            score: -0.5,
899            is_done: true,
900        };
901        let result = BeamSearchResult {
902            best_sequence: h2.tokens.clone(),
903            best_score: h2.score,
904            hypotheses: vec![h2.clone(), h1.clone()],
905            stats: BeamSearchStats {
906                total_steps: 1,
907                num_completed_at_eos: 0,
908                num_completed_at_max_length: 2,
909                avg_sequence_length: 2.0,
910                score_range: (-1.0, -0.5),
911            },
912        };
913        let best = result.best().expect("should have best");
914        assert_eq!(best.score, -0.5);
915    }
916
917    #[test]
918    fn test_beam_search_stats() {
919        let vocab_size = 4;
920        let cfg = BeamSearchConfig {
921            beam_width: 2,
922            max_length: 4,
923            vocab_size,
924            ..BeamSearchConfig::default()
925        };
926        let decoder = BeamSearchDecoder::new(cfg);
927        let score_fn = uniform_score_fn(vocab_size);
928        let result = decoder.decode(0, score_fn).expect("decode failed");
929        assert!(result.stats.total_steps > 0);
930    }
931
932    #[test]
933    fn test_top_k_results_sorted() {
934        let decoder = BeamSearchDecoder::with_default();
935        let make_hyp = |score: f64| BeamHypothesis {
936            tokens: vec![0],
937            log_prob: score,
938            score,
939            is_done: false,
940        };
941        let state = BeamState {
942            beams: vec![make_hyp(-2.0), make_hyp(-0.5), make_hyp(-3.0)],
943            completed: vec![make_hyp(-1.0)],
944            step: 1,
945        };
946        let top = decoder.top_k_results(&state, 3);
947        assert_eq!(top.len(), 3);
948        // Sorted descending.
949        assert!(top[0].score >= top[1].score);
950        assert!(top[1].score >= top[2].score);
951        assert!((top[0].score - (-0.5)).abs() < 1e-10);
952    }
953
954    #[test]
955    fn test_decoder_temperature_scaling() {
956        let logits = vec![1.0, 2.0, 3.0];
957        let lp1 =
958            BeamSearchDecoder::log_softmax(&BeamSearchDecoder::apply_temperature(&logits, 1.0));
959        let lp2 =
960            BeamSearchDecoder::log_softmax(&BeamSearchDecoder::apply_temperature(&logits, 2.0));
961        // Higher temperature => flatter distribution (less spread in log-probs).
962        let spread1 = lp1.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
963            - lp1.iter().cloned().fold(f64::INFINITY, f64::min);
964        let spread2 = lp2.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
965            - lp2.iter().cloned().fold(f64::INFINITY, f64::min);
966        assert!(
967            spread2 < spread1,
968            "temperature=2.0 should flatten distribution"
969        );
970    }
971
972    #[test]
973    fn test_beam_search_error_display() {
974        let errors = vec![
975            BeamSearchError::EmptyBeams,
976            BeamSearchError::BeamWidthMismatch {
977                expected: 4,
978                got: 2,
979            },
980            BeamSearchError::VocabSizeMismatch {
981                expected: 1000,
982                got: 500,
983            },
984            BeamSearchError::ZeroBeamWidth,
985            BeamSearchError::MaxLengthTooShort,
986            BeamSearchError::ScoringFunctionError("test error".to_string()),
987            BeamSearchError::InvalidTemperature(-1.0),
988        ];
989        for err in &errors {
990            let s = err.to_string();
991            assert!(!s.is_empty(), "display for {err:?} was empty");
992        }
993    }
994}