Skip to main content

wax_core/
sampler.rs

1use candle_core::Tensor;
2use candle_transformers::generation::{LogitsProcessor, Sampling};
3
4use crate::{Result, WaxError};
5
/// User-tunable settings that control how the next token is drawn from a logits vector.
///
/// The values are plain data and are not checked on construction; call
/// [`SamplingConfig::validate`] (or [`SamplingConfig::processor`], which validates
/// first) before relying on them.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SamplingConfig {
    /// Sampling temperature; must be finite and non-negative. `0.0` selects greedy
    /// (arg-max) decoding, in which case `top_k` and `top_p` are not used for sampling.
    pub temperature: f64,
    /// Optional top-k cutoff; when set it must be greater than zero.
    pub top_k: Option<usize>,
    /// Optional top-p threshold; when set it must be finite and lie in `0.0..=1.0`.
    pub top_p: Option<f64>,
    /// Repetition penalty; must be finite and strictly positive. A value of exactly
    /// `1.0` skips the penalty step altogether.
    pub repetition_penalty: f32,
    /// Number of most recent tokens the repetition penalty is applied over.
    pub repeat_last_n: usize,
    /// Seed handed to the `LogitsProcessor` built from this configuration.
    pub seed: u64,
}
15
16impl Default for SamplingConfig {
17    fn default() -> Self {
18        Self {
19            temperature: 0.0,
20            top_k: None,
21            top_p: None,
22            repetition_penalty: 1.0,
23            repeat_last_n: 128,
24            seed: 299_792_458,
25        }
26    }
27}
28
29impl SamplingConfig {
30    pub fn validate(&self) -> Result<()> {
31        if !self.temperature.is_finite() || self.temperature < 0.0 {
32            return Err(WaxError::InvalidRequest(
33                "temperature must be finite and >= 0".to_string(),
34            ));
35        }
36        if matches!(self.top_k, Some(0)) {
37            return Err(WaxError::InvalidRequest("top-k must be > 0".to_string()));
38        }
39        if let Some(top_p) = self.top_p {
40            if !top_p.is_finite() || !(0.0..=1.0).contains(&top_p) {
41                return Err(WaxError::InvalidRequest(
42                    "top-p must be finite and between 0 and 1".to_string(),
43                ));
44            }
45        }
46        if !self.repetition_penalty.is_finite() || self.repetition_penalty <= 0.0 {
47            return Err(WaxError::InvalidRequest(
48                "repetition penalty must be finite and > 0".to_string(),
49            ));
50        }
51        Ok(())
52    }
53
54    pub fn processor(&self) -> Result<LogitsProcessor> {
55        self.validate()?;
56        Ok(LogitsProcessor::from_sampling(self.seed, self.sampling()))
57    }
58
59    fn sampling(&self) -> Sampling {
60        if self.temperature <= 0.0 {
61            return Sampling::ArgMax;
62        }
63
64        match (self.top_k, self.top_p) {
65            (None, None) => Sampling::All {
66                temperature: self.temperature,
67            },
68            (Some(k), None) => Sampling::TopK {
69                k,
70                temperature: self.temperature,
71            },
72            (None, Some(p)) => Sampling::TopP {
73                p,
74                temperature: self.temperature,
75            },
76            (Some(k), Some(p)) => Sampling::TopKThenTopP {
77                k,
78                p,
79                temperature: self.temperature,
80            },
81        }
82    }
83}
84
85pub struct Sampler {
86    config: SamplingConfig,
87    processor: LogitsProcessor,
88}
89
90impl Sampler {
91    pub fn new(config: SamplingConfig) -> Result<Self> {
92        Ok(Self {
93            config,
94            processor: config.processor()?,
95        })
96    }
97
98    pub fn sample(&mut self, logits: &Tensor, tokens: &[u32]) -> Result<u32> {
99        let logits = if self.config.repetition_penalty == 1.0 {
100            logits.clone()
101        } else {
102            let start_at = tokens.len().saturating_sub(self.config.repeat_last_n);
103            candle_transformers::utils::apply_repeat_penalty(
104                logits,
105                self.config.repetition_penalty,
106                &tokens[start_at..],
107            )?
108        };
109
110        Ok(self.processor.sample(&logits)?)
111    }
112}
113
#[cfg(test)]
mod tests {
    use candle_core::{Device, Tensor};

    use super::{Sampler, SamplingConfig};

    /// With temperature zero the sampler must return the index of the largest logit.
    #[test]
    fn greedy_selects_argmax() {
        let config = SamplingConfig {
            temperature: 0.0,
            ..SamplingConfig::default()
        };
        let mut greedy = Sampler::new(config).unwrap();
        // Index 1 carries the largest value.
        let scores = Tensor::new(&[0.1f32, 4.0, 0.2], &Device::Cpu).unwrap();

        let picked = greedy.sample(&scores, &[]).unwrap();

        assert_eq!(picked, 1);
    }

    /// Two samplers built from the same seeded configuration must agree on the
    /// first token they draw from identical logits.
    #[test]
    fn seeded_sampling_is_deterministic() {
        let scores = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &Device::Cpu).unwrap();
        let config = SamplingConfig {
            seed: 42,
            temperature: 0.8,
            top_p: Some(0.9),
            top_k: Some(3),
            ..SamplingConfig::default()
        };

        let [first, second] =
            [config, config].map(|cfg| Sampler::new(cfg).unwrap().sample(&scores, &[]).unwrap());

        assert_eq!(first, second);
    }

    /// A top-k of zero is rejected, and the error message names the offending field.
    #[test]
    fn rejects_invalid_top_k() {
        let config = SamplingConfig {
            top_k: Some(0),
            ..SamplingConfig::default()
        };

        let message = config.validate().unwrap_err().to_string();

        assert!(message.contains("top-k"));
    }
}
164}