Skip to main content

sqlitegraph/inference/
sampling.rs

//! Sampling strategies for sparse inference.
//!
//! Supports greedy decoding, temperature scaling, and top-p (nucleus) sampling.

5/// Sample next token from logits using temperature + top-p (nucleus) sampling.
6///
7/// # Arguments
8/// * `logits` - Raw logits [vocab_size]
9/// * `temperature` - Temperature for scaling (1.0 = normal, <1 = sharper, >1 = more random)
10/// * `top_p` - Nucleus sampling threshold (0.9 = consider tokens covering 90% of probability mass)
11///
12/// # Returns
13/// Token ID (index into logits array)
14pub fn sample_token(logits: &[f32], temperature: f32, top_p: f32) -> usize {
15    let vocab_size = logits.len();
16    if vocab_size == 0 {
17        return 0;
18    }
19
20    // Temperature scaling
21    let scaled: Vec<f32> = if temperature > 0.001 {
22        logits.iter().map(|&l| l / temperature).collect()
23    } else {
24        // Greedy: find argmax
25        return logits
26            .iter()
27            .enumerate()
28            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
29            .map(|(i, _)| i)
30            .unwrap_or(0);
31    };
32
33    // Softmax: exp(x - max) for numerical stability
34    let max_val = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
35    let mut probs: Vec<f32> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
36
37    // Normalize
38    let sum: f32 = probs.iter().sum();
39    if sum > 0.0 {
40        for p in probs.iter_mut() {
41            *p /= sum;
42        }
43    } else {
44        // Fallback: uniform
45        let uniform = 1.0 / vocab_size as f32;
46        for p in probs.iter_mut() {
47            *p = uniform;
48        }
49    }
50
51    // Top-p (nucleus) sampling
52    if top_p < 1.0 {
53        // Create index-probability pairs, sort descending
54        let mut indexed: Vec<(usize, f32)> =
55            probs.iter().enumerate().map(|(i, &p)| (i, p)).collect();
56        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
57
58        // Find cutoff
59        let mut cumsum = 0.0f32;
60        let mut cutoff_idx = indexed.len();
61        for (i, &(_, p)) in indexed.iter().enumerate() {
62            cumsum += p;
63            if cumsum >= top_p {
64                cutoff_idx = i + 1;
65                break;
66            }
67        }
68
69        // Zero out everything after cutoff
70        for &(_, _p) in indexed.iter().skip(cutoff_idx) {
71            // We'll rebuild probs below
72        }
73
74        // Rebuild probs with only top-p tokens
75        let mut new_probs = vec![0.0f32; vocab_size];
76        let new_sum: f32 = indexed[..cutoff_idx]
77            .iter()
78            .map(|&(i, p)| {
79                new_probs[i] = p;
80                p
81            })
82            .sum();
83
84        if new_sum > 0.0 {
85            for p in new_probs.iter_mut() {
86                *p /= new_sum;
87            }
88        }
89        probs = new_probs;
90    }
91
92    // Weighted random sampling
93    let r: f32 = rand::random::<f32>();
94    let mut cumsum = 0.0f32;
95    for (i, &p) in probs.iter().enumerate() {
96        cumsum += p;
97        if cumsum >= r {
98            return i;
99        }
100    }
101
102    // Fallback: last token
103    vocab_size - 1
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_greedy_sampling() {
        let logits = vec![0.1, 0.1, 5.0, 0.1, 0.1];
        let token = sample_token(&logits, 0.0, 1.0);
        assert_eq!(token, 2); // index of max value
    }

    #[test]
    fn test_negative_temperature_is_greedy() {
        // Any temperature at or below the greedy threshold, including negatives, is argmax.
        let logits = vec![0.3, 2.0, -1.0];
        assert_eq!(sample_token(&logits, -1.0, 0.9), 1);
    }

    #[test]
    fn test_empty_logits() {
        let logits: Vec<f32> = Vec::new();
        assert_eq!(sample_token(&logits, 1.0, 1.0), 0);
    }

    #[test]
    fn test_temperature_sampling() {
        // Equal logits at temperature 1.0 give every token the same probability (0.25).
        let logits = vec![1.0, 1.0, 1.0, 1.0];
        let mut counts = vec![0usize; 4];
        for _ in 0..1000 {
            let token = sample_token(&logits, 1.0, 1.0);
            assert!(token < 4);
            counts[token] += 1;
        }
        // Expect ~250 per token (std. dev. ~14). The wide 100..400 band cannot realistically
        // flake but still catches a sampler that is badly skewed towards some tokens.
        for &c in &counts {
            assert!(c > 100 && c < 400, "Token count {} outside expected range", c);
        }
    }

    #[test]
    fn test_top_p_sampling() {
        // One dominant token, rest small
        let logits = vec![10.0, 0.0, 0.0, 0.0, 0.0];
        let token = sample_token(&logits, 1.0, 0.5);
        assert_eq!(token, 0); // top_p=0.5, dominant token covers >50%
    }
}