1use std::fmt;
7
/// Errors produced while validating sampler parameters or drawing tokens.
#[derive(Debug, Clone)]
pub enum SamplingError {
    /// The logit/probability slice handed to a sampler was empty.
    EmptyDistribution,
    /// Temperature was NaN or not strictly positive.
    InvalidTemperature(f64),
    /// top-p was NaN or outside (0, 1].
    InvalidTopP { p: f64 },
    /// top-k was zero.
    InvalidTopK { k: usize },
    /// The distribution had no usable mass (all-zero or NaN).
    NormalizationFailure,
    /// Catch-all for otherwise malformed probability arrays.
    InvalidProbabilities(String),
}

impl fmt::Display for SamplingError {
    /// Render a human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SamplingError::EmptyDistribution => f.write_str("logit/probability vector is empty"),
            SamplingError::InvalidTemperature(t) => {
                write!(f, "temperature must be > 0.0, got {t}")
            }
            SamplingError::InvalidTopP { p } => write!(f, "top_p must be in (0, 1], got {p}"),
            SamplingError::InvalidTopK { k } => write!(f, "top_k must be >= 1, got {k}"),
            SamplingError::NormalizationFailure => {
                f.write_str("probability distribution could not be normalised (all-zero or NaN)")
            }
            SamplingError::InvalidProbabilities(msg) => {
                write!(f, "invalid probability array: {msg}")
            }
        }
    }
}

// Marker impl: Debug and Display are provided above, so no methods are needed.
impl std::error::Error for SamplingError {}
56
/// The outcome of drawing a single token from a distribution.
#[derive(Debug, Clone)]
pub struct SampledToken {
    /// Index of the chosen token within the logit/probability vector.
    pub token_id: usize,
    /// Natural log of `prob`; `-inf` when the chosen token has zero mass.
    pub log_prob: f64,
    /// Probability the (possibly filtered) distribution assigned to the token.
    pub prob: f64,
}
71
/// Knobs controlling the [`ConfigurableSampler`] pipeline.
#[derive(Debug, Clone)]
pub struct SamplingConfig {
    /// Softmax temperature; must be > 0 and non-NaN.
    pub temperature: f64,
    /// Keep only the `k` highest-scoring tokens when `Some`.
    pub top_k: Option<usize>,
    /// Keep the smallest nucleus with cumulative mass >= `p` when `Some`.
    pub top_p: Option<f64>,
    /// Repetition penalty; a value of 1.0 disables it.
    pub repetition_penalty: f64,
    /// RNG seed; `None` makes the sampler fall back to a fixed default.
    pub seed: Option<u64>,
}

impl Default for SamplingConfig {
    /// Plain ancestral sampling: unit temperature, no truncation, no
    /// repetition penalty, and no explicit seed.
    fn default() -> Self {
        SamplingConfig {
            temperature: 1.0,
            repetition_penalty: 1.0,
            top_k: None,
            top_p: None,
            seed: None,
        }
    }
}
104
/// Minimal deterministic PRNG (64-bit LCG using Knuth's MMIX constants).
///
/// Not cryptographically secure; it exists only so sampling is reproducible
/// for a given seed.
#[derive(Debug, Clone)]
struct SimpleRng {
    state: u64,
}

impl SimpleRng {
    /// Seed the generator; one LCG step mixes the raw seed immediately.
    fn new(seed: u64) -> Self {
        Self {
            state: seed
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407),
        }
    }

    /// Advance the LCG and return the state shifted right by 11 bits.
    ///
    /// Note: only the low 53 bits of the result can be non-zero, which is
    /// exactly what `next_f64` needs, though the name oversells the entropy.
    fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        self.state >> 11
    }

    /// Uniform f64 in [0, 1) built from 53 random mantissa bits.
    fn next_f64(&mut self) -> f64 {
        let mantissa = self.next_u64() & ((1u64 << 53) - 1);
        mantissa as f64 / (1u64 << 53) as f64
    }

    /// Inverse-CDF draw from `probs` (expected to sum to roughly 1).
    ///
    /// If floating-point error leaves the cursor past the accumulated mass,
    /// fall back to the last index with positive probability (or the final
    /// index when none exists).
    fn sample_categorical(&mut self, probs: &[f64]) -> usize {
        let cursor = self.next_f64();
        let mut mass = 0.0_f64;
        for (idx, &p) in probs.iter().enumerate() {
            mass += p;
            if cursor < mass {
                return idx;
            }
        }
        probs
            .iter()
            .rposition(|&p| p > 0.0)
            .unwrap_or(probs.len().saturating_sub(1))
    }
}
164
/// Numerically stable softmax over `logits`.
///
/// Subtracts the maximum before exponentiating (max-shift trick). An empty
/// input yields an empty vector; if the exponential weights sum to zero or
/// NaN, the unnormalised weights are returned as-is and callers are expected
/// to detect the bad mass themselves.
pub fn softmax(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    let peak = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut weights: Vec<f64> = logits.iter().map(|&l| (l - peak).exp()).collect();
    let total: f64 = weights.iter().sum();
    if total > 0.0 {
        weights.iter_mut().for_each(|w| *w /= total);
    }
    weights
}
186
/// Log-domain softmax: `logits[i] - logsumexp(logits)`.
///
/// Uses the max-shift trick for numerical stability; an empty input yields an
/// empty vector.
pub fn log_softmax(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    let peak = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let shifted_mass: f64 = logits.iter().map(|&l| (l - peak).exp()).sum();
    let lse = shifted_mass.ln() + peak;
    logits.iter().map(|&l| l - lse).collect()
}
201
/// Shannon entropy (in nats) of a probability vector.
///
/// Zero- and negative-mass entries are skipped, so the `0 * ln 0` limit
/// contributes nothing.
pub fn entropy(probs: &[f64]) -> f64 {
    let mut h = 0.0_f64;
    for &p in probs {
        if p > 0.0 {
            h += -p * p.ln();
        }
    }
    h
}
212
/// Perplexity: `exp` of the mean negative log-likelihood.
///
/// Defined as 1.0 for an empty input (no tokens, no surprise).
pub fn perplexity(log_probs: &[f64]) -> f64 {
    if log_probs.is_empty() {
        return 1.0;
    }
    let total: f64 = log_probs.iter().sum();
    (-total / log_probs.len() as f64).exp()
}
221
/// Divide every logit by `temperature`.
///
/// Callers are responsible for guaranteeing `temperature > 0`; the public
/// sampler constructors all validate it before reaching this point.
fn scale_by_temperature(logits: &[f64], temperature: f64) -> Vec<f64> {
    let mut scaled = logits.to_vec();
    for v in &mut scaled {
        *v /= temperature;
    }
    scaled
}
230
231fn sample_from_probs(probs: &[f64], rng: &mut SimpleRng) -> Result<SampledToken, SamplingError> {
233 let sum: f64 = probs.iter().sum();
234 if sum <= 0.0 || sum.is_nan() {
235 return Err(SamplingError::NormalizationFailure);
236 }
237 let token_id = rng.sample_categorical(probs);
238 let prob = probs[token_id];
239 let log_prob = if prob > 0.0 {
240 prob.ln()
241 } else {
242 f64::NEG_INFINITY
243 };
244 Ok(SampledToken {
245 token_id,
246 log_prob,
247 prob,
248 })
249}
250
/// Deterministic argmax decoder: always picks the single most likely token.
#[derive(Debug, Clone)]
pub struct GreedyDecoder;
258
259impl GreedyDecoder {
260 pub fn new() -> Self {
262 Self
263 }
264
265 pub fn decode(&self, logits: &[f64]) -> Result<SampledToken, SamplingError> {
267 if logits.is_empty() {
268 return Err(SamplingError::EmptyDistribution);
269 }
270 let token_id = logits
271 .iter()
272 .enumerate()
273 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
274 .map(|(i, _)| i)
275 .ok_or(SamplingError::EmptyDistribution)?;
276
277 let probs = softmax(logits);
278 let prob = probs[token_id];
279 let log_prob = if prob > 0.0 {
280 prob.ln()
281 } else {
282 f64::NEG_INFINITY
283 };
284 Ok(SampledToken {
285 token_id,
286 log_prob,
287 prob,
288 })
289 }
290
291 pub fn decode_batch(&self, logits: &[Vec<f64>]) -> Result<Vec<SampledToken>, SamplingError> {
293 logits.iter().map(|row| self.decode(row)).collect()
294 }
295}
296
297impl Default for GreedyDecoder {
298 fn default() -> Self {
299 Self::new()
300 }
301}
302
/// Draws tokens from the full softmax distribution after temperature scaling.
#[derive(Debug)]
pub struct TemperatureSampler {
    /// Softmax temperature; validated to be > 0 and non-NaN in `new`.
    pub temperature: f64,
    // Private deterministic RNG, seeded once at construction.
    rng: SimpleRng,
}
318
319impl TemperatureSampler {
320 pub fn new(temperature: f64, seed: u64) -> Result<Self, SamplingError> {
324 if temperature <= 0.0 || temperature.is_nan() {
325 return Err(SamplingError::InvalidTemperature(temperature));
326 }
327 Ok(Self {
328 temperature,
329 rng: SimpleRng::new(seed),
330 })
331 }
332
333 pub fn sample(&mut self, logits: &[f64]) -> Result<SampledToken, SamplingError> {
335 if logits.is_empty() {
336 return Err(SamplingError::EmptyDistribution);
337 }
338 let scaled = scale_by_temperature(logits, self.temperature);
339 let probs = softmax(&scaled);
340 sample_from_probs(&probs, &mut self.rng)
341 }
342
343 pub fn sample_batch(
345 &mut self,
346 logits: &[Vec<f64>],
347 ) -> Result<Vec<SampledToken>, SamplingError> {
348 logits.iter().map(|row| self.sample(row)).collect()
349 }
350}
351
/// Samples from only the `k` highest-scoring tokens, after temperature scaling.
#[derive(Debug)]
pub struct TopKSampler {
    /// Number of candidates kept; validated to be >= 1 in `new`.
    pub k: usize,
    /// Softmax temperature; validated to be > 0 and non-NaN in `new`.
    pub temperature: f64,
    // Private deterministic RNG, seeded once at construction.
    rng: SimpleRng,
}
365
366impl TopKSampler {
367 pub fn new(k: usize, temperature: f64, seed: u64) -> Result<Self, SamplingError> {
371 if k == 0 {
372 return Err(SamplingError::InvalidTopK { k });
373 }
374 if temperature <= 0.0 || temperature.is_nan() {
375 return Err(SamplingError::InvalidTemperature(temperature));
376 }
377 Ok(Self {
378 k,
379 temperature,
380 rng: SimpleRng::new(seed),
381 })
382 }
383
384 pub fn sample(&mut self, logits: &[f64]) -> Result<SampledToken, SamplingError> {
386 if logits.is_empty() {
387 return Err(SamplingError::EmptyDistribution);
388 }
389 let filtered = Self::apply_top_k(logits, self.k);
390 let scaled = scale_by_temperature(&filtered, self.temperature);
391 let probs = softmax(&scaled);
392 sample_from_probs(&probs, &mut self.rng)
393 }
394
395 pub fn apply_top_k(logits: &[f64], k: usize) -> Vec<f64> {
398 if logits.is_empty() || k == 0 {
399 return logits.to_vec();
400 }
401 let effective_k = k.min(logits.len());
402
403 let mut indexed: Vec<(f64, usize)> = logits
405 .iter()
406 .copied()
407 .enumerate()
408 .map(|(i, v)| (v, i))
409 .collect();
410 indexed.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
411
412 let top_k_indices: std::collections::HashSet<usize> =
414 indexed.iter().take(effective_k).map(|&(_, i)| i).collect();
415
416 logits
417 .iter()
418 .enumerate()
419 .map(|(i, &v)| {
420 if top_k_indices.contains(&i) {
421 v
422 } else {
423 f64::NEG_INFINITY
424 }
425 })
426 .collect()
427 }
428}
429
/// Nucleus (top-p) sampler: keeps the smallest high-probability set whose
/// cumulative mass reaches `p`, after temperature scaling.
#[derive(Debug)]
pub struct TopPSampler {
    /// Nucleus mass threshold; validated to lie in (0, 1] in `new`.
    pub p: f64,
    /// Softmax temperature; validated to be > 0 and non-NaN in `new`.
    pub temperature: f64,
    // Private deterministic RNG, seeded once at construction.
    rng: SimpleRng,
}
444
445impl TopPSampler {
446 pub fn new(p: f64, temperature: f64, seed: u64) -> Result<Self, SamplingError> {
450 if p <= 0.0 || p > 1.0 || p.is_nan() {
451 return Err(SamplingError::InvalidTopP { p });
452 }
453 if temperature <= 0.0 || temperature.is_nan() {
454 return Err(SamplingError::InvalidTemperature(temperature));
455 }
456 Ok(Self {
457 p,
458 temperature,
459 rng: SimpleRng::new(seed),
460 })
461 }
462
463 pub fn sample(&mut self, logits: &[f64]) -> Result<SampledToken, SamplingError> {
465 if logits.is_empty() {
466 return Err(SamplingError::EmptyDistribution);
467 }
468 let scaled = scale_by_temperature(logits, self.temperature);
469 let probs = softmax(&scaled);
470 let filtered_logits = Self::apply_top_p(&probs, self.p);
471 let filtered_probs = softmax(&filtered_logits);
472 sample_from_probs(&filtered_probs, &mut self.rng)
473 }
474
475 pub fn apply_top_p(probs: &[f64], p: f64) -> Vec<f64> {
483 if probs.is_empty() {
484 return Vec::new();
485 }
486 let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
488 sorted_indices.sort_by(|&a, &b| {
489 probs[b]
490 .partial_cmp(&probs[a])
491 .unwrap_or(std::cmp::Ordering::Equal)
492 });
493
494 let mut cumsum = 0.0_f64;
496 let mut nucleus: std::collections::HashSet<usize> = std::collections::HashSet::new();
497 for &idx in &sorted_indices {
498 nucleus.insert(idx);
499 cumsum += probs[idx];
500 if cumsum >= p {
501 break;
502 }
503 }
504
505 probs
507 .iter()
508 .enumerate()
509 .map(|(i, &prob)| {
510 if nucleus.contains(&i) {
511 if prob > 0.0 {
513 prob.ln()
514 } else {
515 f64::NEG_INFINITY
516 }
517 } else {
518 f64::NEG_INFINITY
519 }
520 })
521 .collect()
522 }
523}
524
/// Full sampling pipeline driven by a [`SamplingConfig`]: repetition penalty,
/// temperature scaling, then optional top-k and top-p truncation.
#[derive(Debug)]
pub struct ConfigurableSampler {
    /// The validated configuration this sampler was built from.
    pub config: SamplingConfig,
    // Private deterministic RNG, seeded from `config.seed` (fallback 42).
    rng: SimpleRng,
}
540
541impl ConfigurableSampler {
542 pub fn new(config: SamplingConfig) -> Result<Self, SamplingError> {
546 if config.temperature <= 0.0 || config.temperature.is_nan() {
547 return Err(SamplingError::InvalidTemperature(config.temperature));
548 }
549 if let Some(k) = config.top_k {
550 if k == 0 {
551 return Err(SamplingError::InvalidTopK { k });
552 }
553 }
554 if let Some(p) = config.top_p {
555 if p <= 0.0 || p > 1.0 || p.is_nan() {
556 return Err(SamplingError::InvalidTopP { p });
557 }
558 }
559 let seed = config.seed.unwrap_or(42);
560 Ok(Self {
561 config,
562 rng: SimpleRng::new(seed),
563 })
564 }
565
566 pub fn with_default() -> Self {
571 Self {
572 config: SamplingConfig::default(),
573 rng: SimpleRng::new(42),
574 }
575 }
576
577 pub fn apply_repetition_penalty(logits: &mut [f64], context: &[usize], penalty: f64) {
585 if (penalty - 1.0).abs() < f64::EPSILON {
586 return; }
588 for &token_id in context {
589 if token_id < logits.len() {
590 let v = logits[token_id];
591 logits[token_id] = if v >= 0.0 { v / penalty } else { v * penalty };
592 }
593 }
594 }
595
596 pub fn sample(
604 &mut self,
605 logits: &[f64],
606 context: &[usize],
607 ) -> Result<SampledToken, SamplingError> {
608 if logits.is_empty() {
609 return Err(SamplingError::EmptyDistribution);
610 }
611
612 let mut working = logits.to_vec();
614 Self::apply_repetition_penalty(&mut working, context, self.config.repetition_penalty);
615
616 let mut working = scale_by_temperature(&working, self.config.temperature);
618
619 if let Some(k) = self.config.top_k {
621 working = TopKSampler::apply_top_k(&working, k);
622 }
623
624 if let Some(p) = self.config.top_p {
626 let probs = softmax(&working);
627 working = TopPSampler::apply_top_p(&probs, p);
628 }
629
630 let probs = softmax(&working);
632 sample_from_probs(&probs, &mut self.rng)
633 }
634}
635
#[cfg(test)]
mod tests {
    use super::*;

    /// Five-token logit fixture; index 1 carries the maximum (3.5).
    fn logits_5() -> Vec<f64> {
        vec![0.1, 3.5, 1.2, -1.0, 2.0]
    }

    // ---- GreedyDecoder ----

    #[test]
    fn test_greedy_decoder_argmax() {
        let decoder = GreedyDecoder::new();
        let token = decoder.decode(&logits_5()).expect("decode should succeed");
        assert_eq!(token.token_id, 1);
    }

    #[test]
    fn test_greedy_decoder_empty() {
        let decoder = GreedyDecoder::new();
        let result = decoder.decode(&[]);
        assert!(
            matches!(result, Err(SamplingError::EmptyDistribution)),
            "expected EmptyDistribution, got {result:?}"
        );
    }

    // ---- TemperatureSampler ----

    #[test]
    fn test_temperature_sampler_valid() {
        let sampler = TemperatureSampler::new(1.0, 0);
        assert!(sampler.is_ok(), "construction with temp=1.0 should succeed");
    }

    #[test]
    fn test_temperature_sampler_zero_temp_error() {
        let result = TemperatureSampler::new(0.0, 0);
        assert!(
            matches!(result, Err(SamplingError::InvalidTemperature(t)) if t == 0.0),
            "expected InvalidTemperature, got {result:?}"
        );
    }

    #[test]
    fn test_temperature_sampler_sample_returns_valid_token() {
        let mut sampler = TemperatureSampler::new(1.0, 42).expect("valid");
        let lgs = logits_5();
        let token = sampler.sample(&lgs).expect("sample should succeed");
        assert!(token.token_id < lgs.len(), "token_id out of vocab");
    }

    #[test]
    fn test_temperature_sampler_prob_in_range() {
        let mut sampler = TemperatureSampler::new(1.0, 7).expect("valid");
        let token = sampler.sample(&logits_5()).expect("sample should succeed");
        assert!(
            (0.0..=1.0).contains(&token.prob),
            "prob {} is out of [0, 1]",
            token.prob
        );
    }

    // ---- TopKSampler ----

    #[test]
    fn test_top_k_apply_filter_keeps_k() {
        let logits = logits_5();
        let k = 2_usize;
        let filtered = TopKSampler::apply_top_k(&logits, k);
        let finite_count = filtered.iter().filter(|&&v| v.is_finite()).count();
        assert_eq!(
            finite_count, k,
            "expected exactly {k} finite values, got {finite_count}"
        );
    }

    #[test]
    fn test_top_k_sampler_sample_within_vocab() {
        let mut sampler = TopKSampler::new(3, 1.0, 99).expect("valid");
        let lgs = logits_5();
        let token = sampler.sample(&lgs).expect("sample should succeed");
        assert!(token.token_id < lgs.len(), "token_id out of vocab");
    }

    #[test]
    fn test_top_k_zero_k_error() {
        let result = TopKSampler::new(0, 1.0, 0);
        assert!(
            matches!(result, Err(SamplingError::InvalidTopK { k: 0 })),
            "expected InvalidTopK, got {result:?}"
        );
    }

    // ---- TopPSampler ----

    #[test]
    fn test_top_p_apply_filter() {
        let probs = vec![0.5, 0.3, 0.15, 0.04, 0.01];
        let p = 0.8_f64;
        let filtered_logits = TopPSampler::apply_top_p(&probs, p);
        // The kept entries are log-probs, so exponentiate to recover mass.
        let nucleus_prob_sum: f64 = filtered_logits
            .iter()
            .filter(|&&v| v.is_finite())
            .map(|&v| v.exp())
            .sum();
        assert!(
            nucleus_prob_sum >= p - 1e-9,
            "nucleus prob sum {nucleus_prob_sum} < p={p}"
        );
    }

    #[test]
    fn test_top_p_sampler_sample_valid() {
        let mut sampler = TopPSampler::new(0.9, 1.0, 1).expect("valid");
        let lgs = logits_5();
        let token = sampler.sample(&lgs).expect("sample should succeed");
        assert!(token.token_id < lgs.len());
    }

    #[test]
    fn test_top_p_invalid_p_error() {
        let result = TopPSampler::new(1.5, 1.0, 0);
        assert!(
            matches!(result, Err(SamplingError::InvalidTopP { p }) if p == 1.5),
            "expected InvalidTopP, got {result:?}"
        );
    }

    // ---- ConfigurableSampler ----

    #[test]
    fn test_configurable_sampler_default() {
        let sampler = ConfigurableSampler::with_default();
        assert_eq!(sampler.config.temperature, 1.0);
    }

    #[test]
    fn test_configurable_sampler_with_top_k() {
        let config = SamplingConfig {
            temperature: 1.0,
            top_k: Some(5),
            top_p: None,
            repetition_penalty: 1.0,
            seed: Some(0),
        };
        let mut sampler = ConfigurableSampler::new(config).expect("valid config");
        let lgs = logits_5();
        let token = sampler.sample(&lgs, &[]).expect("sample should succeed");
        assert!(token.token_id < lgs.len());
    }

    #[test]
    fn test_repetition_penalty_reduces_seen_tokens() {
        let logits = vec![1.0, 2.0, 3.0];
        let mut working = logits.clone();
        let context = vec![2_usize];
        ConfigurableSampler::apply_repetition_penalty(&mut working, &context, 2.0);
        assert!(
            working[2] < logits[2],
            "expected logit[2] to decrease; was {}, now {}",
            logits[2],
            working[2]
        );
        // Tokens absent from the context must be untouched.
        assert_eq!(working[0], logits[0]);
        assert_eq!(working[1], logits[1]);
    }

    // ---- Math utilities ----

    #[test]
    fn test_softmax_sums_to_one() {
        let logits = vec![1.0, 2.0, 3.0, 0.5, -1.0];
        let probs = softmax(&logits);
        let total: f64 = probs.iter().sum();
        assert!((total - 1.0).abs() < 1e-12, "softmax sum={total}");
    }

    #[test]
    fn test_softmax_numerical_stability() {
        // Large logits would overflow a naive exp() without the max-shift.
        let logits = vec![1000.0, 999.0, 998.0];
        let probs = softmax(&logits);
        for &p in &probs {
            assert!(p.is_finite() && p >= 0.0, "non-finite probability: {p}");
        }
        let total: f64 = probs.iter().sum();
        assert!((total - 1.0).abs() < 1e-12, "softmax sum={total}");
    }

    #[test]
    fn test_log_softmax_matches_log_of_softmax() {
        let logits = vec![0.5, -1.0, 2.3, 0.0];
        let sm = softmax(&logits);
        let lsm = log_softmax(&logits);
        for (s, ls) in sm.iter().zip(lsm.iter()) {
            let expected = s.ln();
            assert!(
                (expected - ls).abs() < 1e-10,
                "log(softmax)={expected} vs log_softmax={ls}"
            );
        }
    }

    #[test]
    fn test_entropy_uniform() {
        // A uniform distribution over n outcomes has entropy ln(n).
        let probs = vec![0.5, 0.5];
        let h = entropy(&probs);
        let expected = (2.0_f64).ln();
        assert!(
            (h - expected).abs() < 1e-12,
            "entropy={h} expected={expected}"
        );
    }

    #[test]
    fn test_perplexity_basic() {
        // A single token with log-prob -1 gives perplexity e^1.
        let log_probs = vec![-1.0_f64];
        let ppl = perplexity(&log_probs);
        let expected = 1.0_f64.exp();
        assert!(
            (ppl - expected).abs() < 1e-12,
            "perplexity={ppl} expected={expected}"
        );
    }
}