// tensorlogic_infer/sampling.rs
1//! Sampling strategies for generative model token selection.
2//!
3//! Provides greedy decoding, temperature sampling, top-k, top-p (nucleus),
4//! and a configurable sampler combining all of the above with repetition penalty.
5
6use std::fmt;
7
8// ---------------------------------------------------------------------------
9// Error type
10// ---------------------------------------------------------------------------
11
/// Errors that can occur during sampling operations.
#[derive(Debug, Clone)]
pub enum SamplingError {
    /// The logit/probability vector was empty.
    EmptyDistribution,
    /// Temperature value was not strictly positive (or was NaN).
    InvalidTemperature(f64),
    /// Top-p value was outside the half-open interval (0, 1].
    InvalidTopP { p: f64 },
    /// Top-k value was zero.
    InvalidTopK { k: usize },
    /// Normalization of the distribution failed (e.g., all-zero softmax).
    NormalizationFailure,
    /// The probability array contained invalid values (NaN, negative, etc.).
    InvalidProbabilities(String),
}
28
29impl fmt::Display for SamplingError {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        match self {
32            Self::EmptyDistribution => write!(f, "logit/probability vector is empty"),
33            Self::InvalidTemperature(t) => {
34                write!(f, "temperature must be > 0.0, got {t}")
35            }
36            Self::InvalidTopP { p } => {
37                write!(f, "top_p must be in (0, 1], got {p}")
38            }
39            Self::InvalidTopK { k } => {
40                write!(f, "top_k must be >= 1, got {k}")
41            }
42            Self::NormalizationFailure => {
43                write!(
44                    f,
45                    "probability distribution could not be normalised (all-zero or NaN)"
46                )
47            }
48            Self::InvalidProbabilities(msg) => {
49                write!(f, "invalid probability array: {msg}")
50            }
51        }
52    }
53}
54
55impl std::error::Error for SamplingError {}
56
57// ---------------------------------------------------------------------------
58// Result of a single sampling step
59// ---------------------------------------------------------------------------
60
/// The result of sampling a single token.
#[derive(Debug, Clone)]
pub struct SampledToken {
    /// Index of the selected token in the vocabulary.
    pub token_id: usize,
    /// Natural-log probability of the selected token: ln(prob).
    /// `NEG_INFINITY` if the token was drawn with zero recorded probability.
    pub log_prob: f64,
    /// Linear probability of the selected token after softmax, in [0, 1].
    pub prob: f64,
}
71
72// ---------------------------------------------------------------------------
73// SamplingConfig
74// ---------------------------------------------------------------------------
75
/// Configuration for the [`ConfigurableSampler`].
///
/// Filters compose in the pipeline order used by
/// [`ConfigurableSampler::sample`]: repetition penalty → temperature →
/// top-k → top-p.
#[derive(Debug, Clone)]
pub struct SamplingConfig {
    /// Scale applied to logits before softmax. 1.0 = no scaling.
    /// Must be strictly positive.
    pub temperature: f64,
    /// If `Some(k)`, only the top-k logits participate in sampling.
    /// Must be >= 1 when present.
    pub top_k: Option<usize>,
    /// If `Some(p)`, nucleus sampling keeps the fewest tokens whose
    /// cumulative probability meets or exceeds `p`. Must lie in (0, 1].
    pub top_p: Option<f64>,
    /// Penalty applied to tokens that already appear in the context.
    /// Values > 1.0 reduce the probability of repetition; 1.0 = no effect.
    pub repetition_penalty: f64,
    /// Optional seed for the internal LCG RNG; `None` falls back to a
    /// fixed default seed, so unseeded runs are still reproducible.
    pub seed: Option<u64>,
}
92
93impl Default for SamplingConfig {
94    fn default() -> Self {
95        Self {
96            temperature: 1.0,
97            top_k: None,
98            top_p: None,
99            repetition_penalty: 1.0,
100            seed: None,
101        }
102    }
103}
104
105// ---------------------------------------------------------------------------
106// SimpleRng – minimal LCG, no `rand` dependency
107// ---------------------------------------------------------------------------
108
/// A minimal Linear Congruential Generator for reproducible sampling.
///
/// This avoids any external `rand` crate dependency while still providing
/// adequate statistical quality for token-sampling purposes.
#[derive(Debug, Clone)]
struct SimpleRng {
    /// Current 64-bit LCG state; advanced on every draw.
    state: u64,
}
117
impl SimpleRng {
    /// Seed the generator. The multiplier/increment pair matches the
    /// widely used 64-bit LCG constants (Knuth's MMIX parameters).
    fn new(seed: u64) -> Self {
        // Mix the seed slightly so seed=0 is not degenerate.
        let state = seed
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        Self { state }
    }

    /// Advance the LCG and return the top 53 bits of the new state.
    ///
    /// NOTE(review): despite the name, only 53 bits are populated
    /// (`state >> 11`); the low-order bits of an LCG are statistically
    /// weak, so they are deliberately discarded.
    fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        self.state >> 11
    }

    /// Return a uniform float in [0.0, 1.0).
    fn next_f64(&mut self) -> f64 {
        // Keep 53 bits — the full f64 mantissa. The mask is a no-op after
        // the `>> 11` in `next_u64`, but documents the intent.
        (self.next_u64() & ((1u64 << 53) - 1)) as f64 / (1u64 << 53) as f64
    }

    /// Draw from a categorical distribution defined by `probs` (must sum ≈ 1).
    ///
    /// Uses inverse CDF (linear scan); robust to small floating-point errors.
    fn sample_categorical(&mut self, probs: &[f64]) -> usize {
        let u = self.next_f64();
        let mut cumsum = 0.0_f64;
        for (idx, &p) in probs.iter().enumerate() {
            cumsum += p;
            if u < cumsum {
                return idx;
            }
        }
        // Fallback: rounding left cumsum slightly below u — return the last
        // index that carries any probability mass.
        probs
            .iter()
            .enumerate()
            .rev()
            .find(|(_, &p)| p > 0.0)
            .map(|(i, _)| i)
            .unwrap_or(probs.len().saturating_sub(1))
    }
}
164
165// ---------------------------------------------------------------------------
166// Utility functions
167// ---------------------------------------------------------------------------
168
/// Compute softmax with the log-sum-exp trick for numerical stability.
///
/// Returns a probability vector that sums to 1.0 for any input with at
/// least one finite value. An empty input yields an empty vector. An input
/// that is entirely `NEG_INFINITY` (fully masked) yields all zeros — the
/// downstream samplers then report `NormalizationFailure` instead of
/// propagating NaN.
pub fn softmax(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    let max_val = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    // Fully masked input: `x - max_val` would be `-inf - -inf = NaN` below,
    // producing a NaN vector. Return zeros instead.
    if max_val == f64::NEG_INFINITY {
        return vec![0.0; logits.len()];
    }
    let mut exps: Vec<f64> = logits.iter().map(|&x| (x - max_val).exp()).collect();
    let sum: f64 = exps.iter().sum();
    // With a finite max, sum >= exp(0) = 1, but guard anyway (NaN inputs).
    if sum > 0.0 {
        for e in &mut exps {
            *e /= sum;
        }
    }
    exps
}
186
/// Compute log-softmax: log(softmax(x)) via the log-sum-exp trick, which
/// avoids overflow for large logits.
pub fn log_softmax(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    let max_val = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    // log(sum(exp(x))) = max + log(sum(exp(x - max)))
    let sum_exp: f64 = logits.iter().map(|&x| (x - max_val).exp()).sum();
    let log_sum_exp = sum_exp.ln() + max_val;
    logits.iter().map(|&x| x - log_sum_exp).collect()
}
201
/// Shannon entropy of a probability distribution (in nats).
///
/// Tokens with zero probability contribute 0 to the sum (0 · ln 0 = 0).
pub fn entropy(probs: &[f64]) -> f64 {
    let mut h = 0.0_f64;
    for &p in probs {
        if p > 0.0 {
            h -= p * p.ln();
        }
    }
    h
}
212
/// Perplexity: exp(mean negative log-prob) over a sequence of log-probabilities.
///
/// An empty sequence is defined to have perplexity 1.0.
pub fn perplexity(log_probs: &[f64]) -> f64 {
    match log_probs.len() {
        0 => 1.0,
        n => {
            let total: f64 = log_probs.iter().sum();
            (-total / n as f64).exp()
        }
    }
}
221
222// ---------------------------------------------------------------------------
223// Internal helpers
224// ---------------------------------------------------------------------------
225
/// Divide every logit by `temperature`, returning a new `Vec<f64>`.
/// (Caller guarantees `temperature > 0`.)
fn scale_by_temperature(logits: &[f64], temperature: f64) -> Vec<f64> {
    let mut scaled = Vec::with_capacity(logits.len());
    for &logit in logits {
        scaled.push(logit / temperature);
    }
    scaled
}
230
231/// Given a probability vector, sample one token; return `(token_id, prob, log_prob)`.
232fn sample_from_probs(probs: &[f64], rng: &mut SimpleRng) -> Result<SampledToken, SamplingError> {
233    let sum: f64 = probs.iter().sum();
234    if sum <= 0.0 || sum.is_nan() {
235        return Err(SamplingError::NormalizationFailure);
236    }
237    let token_id = rng.sample_categorical(probs);
238    let prob = probs[token_id];
239    let log_prob = if prob > 0.0 {
240        prob.ln()
241    } else {
242        f64::NEG_INFINITY
243    };
244    Ok(SampledToken {
245        token_id,
246        log_prob,
247        prob,
248    })
249}
250
251// ---------------------------------------------------------------------------
252// GreedyDecoder
253// ---------------------------------------------------------------------------
254
255/// Always selects the token with the highest logit (argmax decoding).
256#[derive(Debug, Clone)]
257pub struct GreedyDecoder;
258
259impl GreedyDecoder {
260    /// Create a new `GreedyDecoder`.
261    pub fn new() -> Self {
262        Self
263    }
264
265    /// Decode a single logit vector, returning the argmax token.
266    pub fn decode(&self, logits: &[f64]) -> Result<SampledToken, SamplingError> {
267        if logits.is_empty() {
268            return Err(SamplingError::EmptyDistribution);
269        }
270        let token_id = logits
271            .iter()
272            .enumerate()
273            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
274            .map(|(i, _)| i)
275            .ok_or(SamplingError::EmptyDistribution)?;
276
277        let probs = softmax(logits);
278        let prob = probs[token_id];
279        let log_prob = if prob > 0.0 {
280            prob.ln()
281        } else {
282            f64::NEG_INFINITY
283        };
284        Ok(SampledToken {
285            token_id,
286            log_prob,
287            prob,
288        })
289    }
290
291    /// Decode a batch of logit vectors, one argmax per row.
292    pub fn decode_batch(&self, logits: &[Vec<f64>]) -> Result<Vec<SampledToken>, SamplingError> {
293        logits.iter().map(|row| self.decode(row)).collect()
294    }
295}
296
297impl Default for GreedyDecoder {
298    fn default() -> Self {
299        Self::new()
300    }
301}
302
303// ---------------------------------------------------------------------------
304// TemperatureSampler
305// ---------------------------------------------------------------------------
306
/// Samples from a softmax distribution after dividing logits by `temperature`.
///
/// - `temperature > 1.0`: flatter distribution (more randomness).
/// - `temperature < 1.0`: sharper distribution (more peaked).
/// - `temperature = 1.0`: unmodified softmax.
#[derive(Debug)]
pub struct TemperatureSampler {
    /// The temperature value used to scale logits (strictly positive,
    /// validated at construction).
    pub temperature: f64,
    // Private RNG driving the categorical draw; seeded at construction.
    rng: SimpleRng,
}
318
319impl TemperatureSampler {
320    /// Construct a `TemperatureSampler`.
321    ///
322    /// Returns `Err(SamplingError::InvalidTemperature)` if `temperature <= 0.0`.
323    pub fn new(temperature: f64, seed: u64) -> Result<Self, SamplingError> {
324        if temperature <= 0.0 || temperature.is_nan() {
325            return Err(SamplingError::InvalidTemperature(temperature));
326        }
327        Ok(Self {
328            temperature,
329            rng: SimpleRng::new(seed),
330        })
331    }
332
333    /// Sample one token from `logits`.
334    pub fn sample(&mut self, logits: &[f64]) -> Result<SampledToken, SamplingError> {
335        if logits.is_empty() {
336            return Err(SamplingError::EmptyDistribution);
337        }
338        let scaled = scale_by_temperature(logits, self.temperature);
339        let probs = softmax(&scaled);
340        sample_from_probs(&probs, &mut self.rng)
341    }
342
343    /// Sample one token for each row in a batch.
344    pub fn sample_batch(
345        &mut self,
346        logits: &[Vec<f64>],
347    ) -> Result<Vec<SampledToken>, SamplingError> {
348        logits.iter().map(|row| self.sample(row)).collect()
349    }
350}
351
352// ---------------------------------------------------------------------------
353// TopKSampler
354// ---------------------------------------------------------------------------
355
/// Zeroes out all logits except the top-k, then applies temperature sampling.
#[derive(Debug)]
pub struct TopKSampler {
    /// Number of top tokens to keep (>= 1, validated at construction).
    pub k: usize,
    /// Temperature applied after the top-k filter (strictly positive).
    pub temperature: f64,
    // Private RNG driving the categorical draw; seeded at construction.
    rng: SimpleRng,
}
365
366impl TopKSampler {
367    /// Construct a `TopKSampler`.
368    ///
369    /// Fails if `k == 0` or `temperature <= 0.0`.
370    pub fn new(k: usize, temperature: f64, seed: u64) -> Result<Self, SamplingError> {
371        if k == 0 {
372            return Err(SamplingError::InvalidTopK { k });
373        }
374        if temperature <= 0.0 || temperature.is_nan() {
375            return Err(SamplingError::InvalidTemperature(temperature));
376        }
377        Ok(Self {
378            k,
379            temperature,
380            rng: SimpleRng::new(seed),
381        })
382    }
383
384    /// Sample one token from `logits` using the top-k filter.
385    pub fn sample(&mut self, logits: &[f64]) -> Result<SampledToken, SamplingError> {
386        if logits.is_empty() {
387            return Err(SamplingError::EmptyDistribution);
388        }
389        let filtered = Self::apply_top_k(logits, self.k);
390        let scaled = scale_by_temperature(&filtered, self.temperature);
391        let probs = softmax(&scaled);
392        sample_from_probs(&probs, &mut self.rng)
393    }
394
395    /// Return a copy of `logits` where all but the top-`k` entries are
396    /// set to `f64::NEG_INFINITY`.
397    pub fn apply_top_k(logits: &[f64], k: usize) -> Vec<f64> {
398        if logits.is_empty() || k == 0 {
399            return logits.to_vec();
400        }
401        let effective_k = k.min(logits.len());
402
403        // Build a list of (value, original_index), sort descending by value.
404        let mut indexed: Vec<(f64, usize)> = logits
405            .iter()
406            .copied()
407            .enumerate()
408            .map(|(i, v)| (v, i))
409            .collect();
410        indexed.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
411
412        // Collect the indices of the top-k tokens.
413        let top_k_indices: std::collections::HashSet<usize> =
414            indexed.iter().take(effective_k).map(|&(_, i)| i).collect();
415
416        logits
417            .iter()
418            .enumerate()
419            .map(|(i, &v)| {
420                if top_k_indices.contains(&i) {
421                    v
422                } else {
423                    f64::NEG_INFINITY
424                }
425            })
426            .collect()
427    }
428}
429
430// ---------------------------------------------------------------------------
431// TopPSampler
432// ---------------------------------------------------------------------------
433
/// Nucleus (top-p) sampler: keeps the smallest set of tokens whose cumulative
/// probability is at least `p`, then samples from that nucleus.
#[derive(Debug)]
pub struct TopPSampler {
    /// Cumulative probability threshold in (0, 1], validated at construction.
    pub p: f64,
    /// Temperature applied before the nucleus filter (strictly positive).
    pub temperature: f64,
    // Private RNG driving the categorical draw; seeded at construction.
    rng: SimpleRng,
}
444
445impl TopPSampler {
446    /// Construct a `TopPSampler`.
447    ///
448    /// Fails if `p <= 0.0 || p > 1.0` or if `temperature <= 0.0`.
449    pub fn new(p: f64, temperature: f64, seed: u64) -> Result<Self, SamplingError> {
450        if p <= 0.0 || p > 1.0 || p.is_nan() {
451            return Err(SamplingError::InvalidTopP { p });
452        }
453        if temperature <= 0.0 || temperature.is_nan() {
454            return Err(SamplingError::InvalidTemperature(temperature));
455        }
456        Ok(Self {
457            p,
458            temperature,
459            rng: SimpleRng::new(seed),
460        })
461    }
462
463    /// Sample one token from `logits` using nucleus sampling.
464    pub fn sample(&mut self, logits: &[f64]) -> Result<SampledToken, SamplingError> {
465        if logits.is_empty() {
466            return Err(SamplingError::EmptyDistribution);
467        }
468        let scaled = scale_by_temperature(logits, self.temperature);
469        let probs = softmax(&scaled);
470        let filtered_logits = Self::apply_top_p(&probs, self.p);
471        let filtered_probs = softmax(&filtered_logits);
472        sample_from_probs(&filtered_probs, &mut self.rng)
473    }
474
475    /// Given a probability vector `probs`, return a logit vector in which
476    /// tokens outside the nucleus are set to `f64::NEG_INFINITY`.
477    ///
478    /// Algorithm:
479    /// 1. Sort indices by descending probability.
480    /// 2. Accumulate until the sum >= `p`.
481    /// 3. All tokens beyond the cutoff become `NEG_INFINITY`.
482    pub fn apply_top_p(probs: &[f64], p: f64) -> Vec<f64> {
483        if probs.is_empty() {
484            return Vec::new();
485        }
486        // Sort indices by descending probability.
487        let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
488        sorted_indices.sort_by(|&a, &b| {
489            probs[b]
490                .partial_cmp(&probs[a])
491                .unwrap_or(std::cmp::Ordering::Equal)
492        });
493
494        // Find nucleus: smallest prefix whose cumulative prob >= p.
495        let mut cumsum = 0.0_f64;
496        let mut nucleus: std::collections::HashSet<usize> = std::collections::HashSet::new();
497        for &idx in &sorted_indices {
498            nucleus.insert(idx);
499            cumsum += probs[idx];
500            if cumsum >= p {
501                break;
502            }
503        }
504
505        // Build output: nucleus tokens keep their log-prob; others = NEG_INFINITY.
506        probs
507            .iter()
508            .enumerate()
509            .map(|(i, &prob)| {
510                if nucleus.contains(&i) {
511                    // Convert probability back to logit space (log-prob is suitable).
512                    if prob > 0.0 {
513                        prob.ln()
514                    } else {
515                        f64::NEG_INFINITY
516                    }
517                } else {
518                    f64::NEG_INFINITY
519                }
520            })
521            .collect()
522    }
523}
524
525// ---------------------------------------------------------------------------
526// ConfigurableSampler
527// ---------------------------------------------------------------------------
528
/// A sampler that combines temperature scaling, top-k filtering, top-p
/// (nucleus) filtering, and repetition penalty into a single configurable
/// pipeline.
///
/// Pipeline order: repetition_penalty → temperature → top-k → top-p → sample.
#[derive(Debug)]
pub struct ConfigurableSampler {
    /// The complete sampling configuration (validated by [`Self::new`]).
    pub config: SamplingConfig,
    // Private RNG driving the categorical draw; seeded from `config.seed`.
    rng: SimpleRng,
}
540
541impl ConfigurableSampler {
542    /// Construct a `ConfigurableSampler` from a [`SamplingConfig`].
543    ///
544    /// Validates temperature, top_k, and top_p at construction time.
545    pub fn new(config: SamplingConfig) -> Result<Self, SamplingError> {
546        if config.temperature <= 0.0 || config.temperature.is_nan() {
547            return Err(SamplingError::InvalidTemperature(config.temperature));
548        }
549        if let Some(k) = config.top_k {
550            if k == 0 {
551                return Err(SamplingError::InvalidTopK { k });
552            }
553        }
554        if let Some(p) = config.top_p {
555            if p <= 0.0 || p > 1.0 || p.is_nan() {
556                return Err(SamplingError::InvalidTopP { p });
557            }
558        }
559        let seed = config.seed.unwrap_or(42);
560        Ok(Self {
561            config,
562            rng: SimpleRng::new(seed),
563        })
564    }
565
566    /// Construct a `ConfigurableSampler` with the default configuration.
567    ///
568    /// This is equivalent to `ConfigurableSampler::new(SamplingConfig::default())` but
569    /// is infallible because the defaults are always valid.
570    pub fn with_default() -> Self {
571        Self {
572            config: SamplingConfig::default(),
573            rng: SimpleRng::new(42),
574        }
575    }
576
577    /// Apply repetition penalty in-place.
578    ///
579    /// For each token that appears in `context`:
580    /// - If the logit is positive → divide by `penalty` (move toward 0).
581    /// - If the logit is negative → multiply by `penalty` (move away from 0).
582    ///
583    /// A `penalty` of 1.0 is a no-op.
584    pub fn apply_repetition_penalty(logits: &mut [f64], context: &[usize], penalty: f64) {
585        if (penalty - 1.0).abs() < f64::EPSILON {
586            return; // Fast path: no penalty.
587        }
588        for &token_id in context {
589            if token_id < logits.len() {
590                let v = logits[token_id];
591                logits[token_id] = if v >= 0.0 { v / penalty } else { v * penalty };
592            }
593        }
594    }
595
596    /// Run the full sampling pipeline:
597    ///
598    /// 1. Apply repetition penalty to `logits` for tokens in `context`.
599    /// 2. Scale by temperature.
600    /// 3. Apply top-k filter (if configured).
601    /// 4. Apply top-p (nucleus) filter (if configured).
602    /// 5. Sample from the resulting distribution.
603    pub fn sample(
604        &mut self,
605        logits: &[f64],
606        context: &[usize],
607    ) -> Result<SampledToken, SamplingError> {
608        if logits.is_empty() {
609            return Err(SamplingError::EmptyDistribution);
610        }
611
612        // Step 1: repetition penalty.
613        let mut working = logits.to_vec();
614        Self::apply_repetition_penalty(&mut working, context, self.config.repetition_penalty);
615
616        // Step 2: temperature scaling.
617        let mut working = scale_by_temperature(&working, self.config.temperature);
618
619        // Step 3: top-k.
620        if let Some(k) = self.config.top_k {
621            working = TopKSampler::apply_top_k(&working, k);
622        }
623
624        // Step 4: top-p (operate on probabilities derived from current logits).
625        if let Some(p) = self.config.top_p {
626            let probs = softmax(&working);
627            working = TopPSampler::apply_top_p(&probs, p);
628        }
629
630        // Step 5: sample.
631        let probs = softmax(&working);
632        sample_from_probs(&probs, &mut self.rng)
633    }
634}
635
636// ---------------------------------------------------------------------------
637// Tests
638// ---------------------------------------------------------------------------
639
#[cfg(test)]
mod tests {
    use super::*;

    // Helper: build a small 5-token logit vector (argmax at index 1).
    fn logits_5() -> Vec<f64> {
        vec![0.1, 3.5, 1.2, -1.0, 2.0]
    }

    // ------------------------------------------------------------------
    // GreedyDecoder
    // ------------------------------------------------------------------

    #[test]
    fn test_greedy_decoder_argmax() {
        let decoder = GreedyDecoder::new();
        // Index 1 has the highest logit (3.5).
        let token = decoder.decode(&logits_5()).expect("decode should succeed");
        assert_eq!(token.token_id, 1);
    }

    #[test]
    fn test_greedy_decoder_empty() {
        let decoder = GreedyDecoder::new();
        let result = decoder.decode(&[]);
        assert!(
            matches!(result, Err(SamplingError::EmptyDistribution)),
            "expected EmptyDistribution, got {result:?}"
        );
    }

    // ------------------------------------------------------------------
    // TemperatureSampler
    // ------------------------------------------------------------------

    #[test]
    fn test_temperature_sampler_valid() {
        let sampler = TemperatureSampler::new(1.0, 0);
        assert!(sampler.is_ok(), "construction with temp=1.0 should succeed");
    }

    #[test]
    fn test_temperature_sampler_zero_temp_error() {
        let result = TemperatureSampler::new(0.0, 0);
        assert!(
            matches!(result, Err(SamplingError::InvalidTemperature(t)) if t == 0.0),
            "expected InvalidTemperature, got {result:?}"
        );
    }

    #[test]
    fn test_temperature_sampler_sample_returns_valid_token() {
        let mut sampler = TemperatureSampler::new(1.0, 42).expect("valid");
        let lgs = logits_5();
        let token = sampler.sample(&lgs).expect("sample should succeed");
        assert!(token.token_id < lgs.len(), "token_id out of vocab");
    }

    #[test]
    fn test_temperature_sampler_prob_in_range() {
        let mut sampler = TemperatureSampler::new(1.0, 7).expect("valid");
        let token = sampler.sample(&logits_5()).expect("sample should succeed");
        assert!(
            (0.0..=1.0).contains(&token.prob),
            "prob {} is out of [0, 1]",
            token.prob
        );
    }

    // ------------------------------------------------------------------
    // TopKSampler / apply_top_k
    // ------------------------------------------------------------------

    #[test]
    fn test_top_k_apply_filter_keeps_k() {
        let logits = logits_5();
        let k = 2_usize;
        let filtered = TopKSampler::apply_top_k(&logits, k);
        // Filtered-out entries become NEG_INFINITY, so exactly k stay finite.
        let finite_count = filtered.iter().filter(|&&v| v.is_finite()).count();
        assert_eq!(
            finite_count, k,
            "expected exactly {k} finite values, got {finite_count}"
        );
    }

    #[test]
    fn test_top_k_sampler_sample_within_vocab() {
        let mut sampler = TopKSampler::new(3, 1.0, 99).expect("valid");
        let lgs = logits_5();
        let token = sampler.sample(&lgs).expect("sample should succeed");
        assert!(token.token_id < lgs.len(), "token_id out of vocab");
    }

    #[test]
    fn test_top_k_zero_k_error() {
        let result = TopKSampler::new(0, 1.0, 0);
        assert!(
            matches!(result, Err(SamplingError::InvalidTopK { k: 0 })),
            "expected InvalidTopK, got {result:?}"
        );
    }

    // ------------------------------------------------------------------
    // TopPSampler / apply_top_p
    // ------------------------------------------------------------------

    #[test]
    fn test_top_p_apply_filter() {
        // Use a peaked distribution so that nucleus is well-defined.
        let probs = vec![0.5, 0.3, 0.15, 0.04, 0.01];
        let p = 0.8_f64;
        let filtered_logits = TopPSampler::apply_top_p(&probs, p);
        // Finite entries are log-probs of nucleus members; exp recovers them.
        let nucleus_prob_sum: f64 = filtered_logits
            .iter()
            .filter(|&&v| v.is_finite())
            .map(|&v| v.exp())
            .sum();
        // The nucleus should account for at least p of the total mass.
        assert!(
            nucleus_prob_sum >= p - 1e-9,
            "nucleus prob sum {nucleus_prob_sum} < p={p}"
        );
    }

    #[test]
    fn test_top_p_sampler_sample_valid() {
        let mut sampler = TopPSampler::new(0.9, 1.0, 1).expect("valid");
        let lgs = logits_5();
        let token = sampler.sample(&lgs).expect("sample should succeed");
        assert!(token.token_id < lgs.len());
    }

    #[test]
    fn test_top_p_invalid_p_error() {
        let result = TopPSampler::new(1.5, 1.0, 0);
        assert!(
            matches!(result, Err(SamplingError::InvalidTopP { p }) if p == 1.5),
            "expected InvalidTopP, got {result:?}"
        );
    }

    // ------------------------------------------------------------------
    // ConfigurableSampler
    // ------------------------------------------------------------------

    #[test]
    fn test_configurable_sampler_default() {
        let sampler = ConfigurableSampler::with_default();
        assert_eq!(sampler.config.temperature, 1.0);
    }

    #[test]
    fn test_configurable_sampler_with_top_k() {
        let config = SamplingConfig {
            temperature: 1.0,
            top_k: Some(5),
            top_p: None,
            repetition_penalty: 1.0,
            seed: Some(0),
        };
        let mut sampler = ConfigurableSampler::new(config).expect("valid config");
        let lgs = logits_5();
        let token = sampler.sample(&lgs, &[]).expect("sample should succeed");
        assert!(token.token_id < lgs.len());
    }

    #[test]
    fn test_repetition_penalty_reduces_seen_tokens() {
        let logits = vec![1.0, 2.0, 3.0];
        let mut working = logits.clone();
        let context = vec![2_usize]; // token 2 has logit 3.0
        ConfigurableSampler::apply_repetition_penalty(&mut working, &context, 2.0);
        // Positive logit → divided by penalty.
        assert!(
            working[2] < logits[2],
            "expected logit[2] to decrease; was {}, now {}",
            logits[2],
            working[2]
        );
        // Token 0 and 1 should be unchanged.
        assert_eq!(working[0], logits[0]);
        assert_eq!(working[1], logits[1]);
    }

    // ------------------------------------------------------------------
    // Utility functions
    // ------------------------------------------------------------------

    #[test]
    fn test_softmax_sums_to_one() {
        let logits = vec![1.0, 2.0, 3.0, 0.5, -1.0];
        let probs = softmax(&logits);
        let total: f64 = probs.iter().sum();
        assert!((total - 1.0).abs() < 1e-12, "softmax sum={total}");
    }

    #[test]
    fn test_softmax_numerical_stability() {
        // Very large values must not produce NaN or infinity.
        let logits = vec![1000.0, 999.0, 998.0];
        let probs = softmax(&logits);
        for &p in &probs {
            assert!(p.is_finite() && p >= 0.0, "non-finite probability: {p}");
        }
        let total: f64 = probs.iter().sum();
        assert!((total - 1.0).abs() < 1e-12, "softmax sum={total}");
    }

    #[test]
    fn test_log_softmax_matches_log_of_softmax() {
        let logits = vec![0.5, -1.0, 2.3, 0.0];
        let sm = softmax(&logits);
        let lsm = log_softmax(&logits);
        for (s, ls) in sm.iter().zip(lsm.iter()) {
            let expected = s.ln();
            assert!(
                (expected - ls).abs() < 1e-10,
                "log(softmax)={expected} vs log_softmax={ls}"
            );
        }
    }

    #[test]
    fn test_entropy_uniform() {
        // entropy([0.5, 0.5]) in nats = ln(2) ≈ 0.693147
        let probs = vec![0.5, 0.5];
        let h = entropy(&probs);
        let expected = (2.0_f64).ln();
        assert!(
            (h - expected).abs() < 1e-12,
            "entropy={h} expected={expected}"
        );
    }

    #[test]
    fn test_perplexity_basic() {
        // perplexity([-1.0]) = exp(1.0) ≈ 2.71828
        let log_probs = vec![-1.0_f64];
        let ppl = perplexity(&log_probs);
        let expected = 1.0_f64.exp();
        assert!(
            (ppl - expected).abs() < 1e-12,
            "perplexity={ppl} expected={expected}"
        );
    }
}