sqlitegraph/inference/
sampling.rs1pub fn sample_token(logits: &[f32], temperature: f32, top_p: f32) -> usize {
15 let vocab_size = logits.len();
16 if vocab_size == 0 {
17 return 0;
18 }
19
20 let scaled: Vec<f32> = if temperature > 0.001 {
22 logits.iter().map(|&l| l / temperature).collect()
23 } else {
24 return logits
26 .iter()
27 .enumerate()
28 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
29 .map(|(i, _)| i)
30 .unwrap_or(0);
31 };
32
33 let max_val = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
35 let mut probs: Vec<f32> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
36
37 let sum: f32 = probs.iter().sum();
39 if sum > 0.0 {
40 for p in probs.iter_mut() {
41 *p /= sum;
42 }
43 } else {
44 let uniform = 1.0 / vocab_size as f32;
46 for p in probs.iter_mut() {
47 *p = uniform;
48 }
49 }
50
51 if top_p < 1.0 {
53 let mut indexed: Vec<(usize, f32)> =
55 probs.iter().enumerate().map(|(i, &p)| (i, p)).collect();
56 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
57
58 let mut cumsum = 0.0f32;
60 let mut cutoff_idx = indexed.len();
61 for (i, &(_, p)) in indexed.iter().enumerate() {
62 cumsum += p;
63 if cumsum >= top_p {
64 cutoff_idx = i + 1;
65 break;
66 }
67 }
68
69 for &(_, _p) in indexed.iter().skip(cutoff_idx) {
71 }
73
74 let mut new_probs = vec![0.0f32; vocab_size];
76 let new_sum: f32 = indexed[..cutoff_idx]
77 .iter()
78 .map(|&(i, p)| {
79 new_probs[i] = p;
80 p
81 })
82 .sum();
83
84 if new_sum > 0.0 {
85 for p in new_probs.iter_mut() {
86 *p /= new_sum;
87 }
88 }
89 probs = new_probs;
90 }
91
92 let r: f32 = rand::random::<f32>();
94 let mut cumsum = 0.0f32;
95 for (i, &p) in probs.iter().enumerate() {
96 cumsum += p;
97 if cumsum >= r {
98 return i;
99 }
100 }
101
102 vocab_size - 1
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// A temperature of zero takes the greedy path, so the largest logit wins.
    #[test]
    fn test_greedy_sampling() {
        let scores = [0.1_f32, 0.1, 5.0, 0.1, 0.1];
        assert_eq!(sample_token(&scores, 0.0, 1.0), 2);
    }

    /// Equal logits at temperature 1.0: over 1000 draws every index must be
    /// in range and each of the four tokens must be picked more than 100 times.
    #[test]
    fn test_temperature_sampling() {
        let scores = [1.0_f32; 4];
        let mut histogram = [0usize; 4];

        for _ in 0..1000 {
            let picked = sample_token(&scores, 1.0, 1.0);
            assert!(picked < 4);
            histogram[picked] += 1;
        }

        histogram
            .iter()
            .for_each(|&c| assert!(c > 100, "Token count {} too low", c));
    }

    /// With one dominant logit and `top_p = 0.5`, only that token survives the
    /// nucleus filter, so it is always the one returned.
    #[test]
    fn test_top_p_sampling() {
        let scores = [10.0_f32, 0.0, 0.0, 0.0, 0.0];
        assert_eq!(sample_token(&scores, 1.0, 0.5), 0);
    }
}