1use candle_core::Tensor;
2use candle_transformers::generation::{LogitsProcessor, Sampling};
3
4use crate::{Result, WaxError};
5
/// Decoding parameters for [`Sampler`].
///
/// Check a configuration with [`SamplingConfig::validate`] before using it;
/// [`Sampler::new`] does so automatically.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SamplingConfig {
    /// Softmax temperature. Any value `<= 0.0` selects greedy (arg-max) decoding,
    /// in which case `top_k` / `top_p` have no effect.
    pub temperature: f64,
    /// If set, sampling is restricted to the `k` most likely tokens. Must be non-zero.
    pub top_k: Option<usize>,
    /// If set, the nucleus (top-p) threshold; see [`SamplingConfig::validate`] for
    /// the accepted range.
    pub top_p: Option<f64>,
    /// Penalty applied by [`Sampler`] to recently generated tokens; `1.0` disables it.
    pub repetition_penalty: f32,
    /// How many of the most recent tokens the repetition penalty considers.
    pub repeat_last_n: usize,
    /// Seed for the sampler's random number generator, so runs with identical
    /// inputs are reproducible.
    pub seed: u64,
}
15
16impl Default for SamplingConfig {
17 fn default() -> Self {
18 Self {
19 temperature: 0.0,
20 top_k: None,
21 top_p: None,
22 repetition_penalty: 1.0,
23 repeat_last_n: 128,
24 seed: 299_792_458,
25 }
26 }
27}
28
29impl SamplingConfig {
30 pub fn validate(&self) -> Result<()> {
31 if !self.temperature.is_finite() || self.temperature < 0.0 {
32 return Err(WaxError::InvalidRequest(
33 "temperature must be finite and >= 0".to_string(),
34 ));
35 }
36 if matches!(self.top_k, Some(0)) {
37 return Err(WaxError::InvalidRequest("top-k must be > 0".to_string()));
38 }
39 if let Some(top_p) = self.top_p {
40 if !top_p.is_finite() || !(0.0..=1.0).contains(&top_p) {
41 return Err(WaxError::InvalidRequest(
42 "top-p must be finite and between 0 and 1".to_string(),
43 ));
44 }
45 }
46 if !self.repetition_penalty.is_finite() || self.repetition_penalty <= 0.0 {
47 return Err(WaxError::InvalidRequest(
48 "repetition penalty must be finite and > 0".to_string(),
49 ));
50 }
51 Ok(())
52 }
53
54 pub fn processor(&self) -> Result<LogitsProcessor> {
55 self.validate()?;
56 Ok(LogitsProcessor::from_sampling(self.seed, self.sampling()))
57 }
58
59 fn sampling(&self) -> Sampling {
60 if self.temperature <= 0.0 {
61 return Sampling::ArgMax;
62 }
63
64 match (self.top_k, self.top_p) {
65 (None, None) => Sampling::All {
66 temperature: self.temperature,
67 },
68 (Some(k), None) => Sampling::TopK {
69 k,
70 temperature: self.temperature,
71 },
72 (None, Some(p)) => Sampling::TopP {
73 p,
74 temperature: self.temperature,
75 },
76 (Some(k), Some(p)) => Sampling::TopKThenTopP {
77 k,
78 p,
79 temperature: self.temperature,
80 },
81 }
82 }
83}
84
85pub struct Sampler {
86 config: SamplingConfig,
87 processor: LogitsProcessor,
88}
89
90impl Sampler {
91 pub fn new(config: SamplingConfig) -> Result<Self> {
92 Ok(Self {
93 config,
94 processor: config.processor()?,
95 })
96 }
97
98 pub fn sample(&mut self, logits: &Tensor, tokens: &[u32]) -> Result<u32> {
99 let logits = if self.config.repetition_penalty == 1.0 {
100 logits.clone()
101 } else {
102 let start_at = tokens.len().saturating_sub(self.config.repeat_last_n);
103 candle_transformers::utils::apply_repeat_penalty(
104 logits,
105 self.config.repetition_penalty,
106 &tokens[start_at..],
107 )?
108 };
109
110 Ok(self.processor.sample(&logits)?)
111 }
112}
113
#[cfg(test)]
mod tests {
    use candle_core::{Device, Tensor};

    use super::{Sampler, SamplingConfig};

    /// Builds a 1-D `f32` logits tensor on the CPU.
    fn cpu_logits(values: &[f32]) -> Tensor {
        Tensor::from_slice(values, values.len(), &Device::Cpu)
            .expect("building a CPU logits tensor should not fail")
    }

    #[test]
    fn greedy_selects_argmax() {
        let config = SamplingConfig {
            temperature: 0.0,
            ..SamplingConfig::default()
        };
        let mut sampler = Sampler::new(config).expect("greedy config is valid");

        let picked = sampler
            .sample(&cpu_logits(&[0.1, 4.0, 0.2]), &[])
            .expect("greedy sampling succeeds");

        assert_eq!(picked, 1);
    }

    #[test]
    fn seeded_sampling_is_deterministic() {
        let logits = cpu_logits(&[1.0, 2.0, 3.0, 4.0]);
        let config = SamplingConfig {
            temperature: 0.8,
            top_k: Some(3),
            top_p: Some(0.9),
            seed: 42,
            ..SamplingConfig::default()
        };

        // Two independent samplers with the same seed must agree on their first draw.
        let draws: Vec<u32> = (0..2)
            .map(|_| {
                let mut sampler = Sampler::new(config).expect("seeded config is valid");
                sampler.sample(&logits, &[]).expect("seeded sampling succeeds")
            })
            .collect();

        assert_eq!(draws[0], draws[1]);
    }

    #[test]
    fn rejects_invalid_top_k() {
        let config = SamplingConfig {
            top_k: Some(0),
            ..SamplingConfig::default()
        };

        let err = config
            .validate()
            .expect_err("a top-k of zero must be rejected");

        assert!(err.to_string().contains("top-k"));
    }
}