scirs2_neural/speculative/
types.rs1use std::fmt;
8
/// Tunable parameters for a speculative-decoding run.
#[derive(Debug, Clone)]
pub struct SpeculativeConfig {
    /// Number of tokens the draft model proposes per verification step.
    pub draft_length: usize,

    /// Sampling temperature applied to token distributions.
    pub temperature: f64,

    /// Top-k cutoff used when filtering distributions before sampling.
    pub top_k: usize,

    /// Upper bound on the number of tokens to generate.
    pub max_tokens: usize,

    /// Whether the draft length adapts during decoding.
    /// NOTE(review): the adaptation policy lives elsewhere — confirm against the decoder.
    pub adaptive_draft: bool,
}

impl Default for SpeculativeConfig {
    /// Conservative defaults: 4 draft tokens, temperature 1.0, top-k 50,
    /// up to 512 generated tokens, no adaptive drafting.
    fn default() -> Self {
        SpeculativeConfig {
            draft_length: 4,
            temperature: 1.0,
            top_k: 50,
            max_tokens: 512,
            adaptive_draft: false,
        }
    }
}

impl fmt::Display for SpeculativeConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructure once so the argument list below stays short and ordered.
        let Self {
            draft_length,
            temperature,
            top_k,
            max_tokens,
            adaptive_draft,
        } = self;
        write!(
            f,
            "SpeculativeConfig(draft_length={}, temperature={:.2}, top_k={}, max_tokens={}, adaptive={})",
            draft_length, temperature, top_k, max_tokens, adaptive_draft
        )
    }
}
61
/// Outcome of verifying one batch of draft tokens.
#[derive(Debug, Clone)]
pub struct VerificationResult {
    /// Token ids that were accepted, in order.
    pub accepted_tokens: Vec<usize>,

    /// Position of the first rejected draft token; `None` when every token passed.
    pub rejected_at: Option<usize>,

    /// Fraction of draft tokens that were accepted.
    pub acceptance_rate: f64,
}

impl VerificationResult {
    /// Packs a verification outcome into a result value.
    pub fn new(
        accepted_tokens: Vec<usize>,
        rejected_at: Option<usize>,
        acceptance_rate: f64,
    ) -> Self {
        Self {
            accepted_tokens,
            rejected_at,
            acceptance_rate,
        }
    }

    /// `true` iff no draft token was rejected.
    pub fn all_accepted(&self) -> bool {
        self.rejected_at.is_none()
    }

    /// How many tokens were accepted.
    pub fn num_accepted(&self) -> usize {
        self.accepted_tokens.len()
    }
}

impl fmt::Display for VerificationResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let accepted = self.num_accepted();
        write!(
            f,
            "VerificationResult(accepted={}, rejected_at={:?}, rate={:.2})",
            accepted, self.rejected_at, self.acceptance_rate
        )
    }
}
112
/// Running counters accumulated over a decoding session.
#[derive(Debug, Clone)]
pub struct DecodingStats {
    /// Total tokens emitted so far.
    pub total_tokens: usize,

    /// Tokens proposed by the draft model.
    pub draft_tokens: usize,

    /// Draft tokens the verifier accepted.
    pub accepted_tokens: usize,

    /// Elapsed wall-clock time in milliseconds.
    pub wall_time_ms: f64,

    /// Average tokens produced per decoding step.
    pub tokens_per_step: f64,
}

impl DecodingStats {
    /// All-zero statistics. (`Default` delegates here, so this must not
    /// call `Self::default()`.)
    pub fn new() -> Self {
        DecodingStats {
            total_tokens: 0,
            draft_tokens: 0,
            accepted_tokens: 0,
            wall_time_ms: 0.0,
            tokens_per_step: 0.0,
        }
    }

    /// Accepted / drafted ratio; 0.0 before anything has been drafted.
    pub fn acceptance_rate(&self) -> f64 {
        match self.draft_tokens {
            0 => 0.0,
            drafted => self.accepted_tokens as f64 / drafted as f64,
        }
    }

    /// Tokens divided by `wall_time_ms` (i.e. tokens per *millisecond*);
    /// 0.0 when no time has elapsed.
    pub fn throughput(&self) -> f64 {
        match self.wall_time_ms {
            // Keep the original non-positive guard (NaN falls through to the division).
            t if t <= 0.0 => 0.0,
            t => self.total_tokens as f64 / t,
        }
    }
}

impl Default for DecodingStats {
    fn default() -> Self {
        Self::new()
    }
}

impl fmt::Display for DecodingStats {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "DecodingStats(total={}, drafted={}, accepted={}, rate={:.2}, tok/step={:.2}, time={:.1}ms)",
            self.total_tokens,
            self.draft_tokens,
            self.accepted_tokens,
            self.acceptance_rate(),
            self.tokens_per_step,
            self.wall_time_ms,
        )
    }
}
186
/// A normalized probability distribution over a token vocabulary.
///
/// Invariant: `probs` is non-empty, every entry is finite and non-negative,
/// and the entries sum to 1 (up to floating-point rounding).
#[derive(Debug, Clone)]
pub struct TokenDistribution {
    probs: Vec<f64>,
}

impl TokenDistribution {
    /// Builds a distribution from raw (possibly unnormalized) probabilities.
    ///
    /// Returns `None` when `probs` is empty, contains a negative or NaN
    /// entry, or the sum is not positive and finite; otherwise entries are
    /// normalized to sum to 1.
    pub fn from_probs(probs: Vec<f64>) -> Option<Self> {
        if probs.is_empty() {
            return None;
        }
        // BUGFIX: the previous `p < 0.0` check let NaN through (NaN fails
        // every comparison), yielding a NaN-filled "distribution".
        // `!(p >= 0.0)` rejects both negatives and NaN.
        if probs.iter().any(|&p| !(p >= 0.0)) {
            return None;
        }
        let sum: f64 = probs.iter().sum();
        // BUGFIX: a +inf entry makes the sum infinite, which passed the old
        // `sum <= 0.0` guard and normalized to NaN/0 — reject it explicitly.
        if sum <= 0.0 || !sum.is_finite() {
            return None;
        }
        let normalized: Vec<f64> = probs.iter().map(|&p| p / sum).collect();
        Some(Self { probs: normalized })
    }

    /// Uniform distribution over `vocab_size` tokens; `None` for an empty vocabulary.
    pub fn uniform(vocab_size: usize) -> Option<Self> {
        if vocab_size == 0 {
            return None;
        }
        let p = 1.0 / vocab_size as f64;
        Some(Self {
            probs: vec![p; vocab_size],
        })
    }

    /// Builds a distribution from log-probabilities via a numerically stable
    /// softmax (max-subtraction before exponentiation).
    ///
    /// Returns `None` for empty input, any NaN entry (detected through the
    /// max), or a degenerate sum (all `-inf`).
    pub fn from_log_probs(log_probs: &[f64]) -> Option<Self> {
        if log_probs.is_empty() {
            return None;
        }
        let max_lp = log_probs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        if max_lp.is_nan() {
            return None;
        }
        let exps: Vec<f64> = log_probs.iter().map(|&lp| (lp - max_lp).exp()).collect();
        let sum: f64 = exps.iter().sum();
        if sum <= 0.0 || sum.is_nan() {
            return None;
        }
        let probs: Vec<f64> = exps.iter().map(|&e| e / sum).collect();
        Some(Self { probs })
    }

    /// Number of tokens in the vocabulary.
    pub fn vocab_size(&self) -> usize {
        self.probs.len()
    }

    /// Probability of `token_id`; 0.0 for an out-of-range id.
    pub fn prob(&self, token_id: usize) -> f64 {
        self.probs.get(token_id).copied().unwrap_or(0.0)
    }

    /// The full probability vector.
    pub fn probs(&self) -> &[f64] {
        &self.probs
    }

    /// Re-sharpens (`temperature < 1`) or flattens (`temperature > 1`) the
    /// distribution by scaling log-probabilities, then renormalizing.
    ///
    /// Returns `None` for a non-positive temperature. `temperature == 1.0`
    /// (within 1e-12) returns a clone unchanged.
    pub fn with_temperature(&self, temperature: f64) -> Option<Self> {
        if temperature <= 0.0 {
            return None;
        }
        if (temperature - 1.0).abs() < 1e-12 {
            return Some(self.clone());
        }
        // Zero-probability tokens map to -inf so they stay at zero mass.
        let log_probs: Vec<f64> = self
            .probs
            .iter()
            .map(|&p| {
                if p > 0.0 {
                    p.ln() / temperature
                } else {
                    f64::NEG_INFINITY
                }
            })
            .collect();
        Self::from_log_probs(&log_probs)
    }

    /// Keeps only the `k` most probable tokens and renormalizes.
    ///
    /// Returns `None` for `k == 0`; returns a clone when `k` covers the
    /// whole vocabulary. NOTE: ties at the k-th largest probability are all
    /// kept, so the result may contain more than `k` nonzero entries.
    pub fn with_top_k(&self, k: usize) -> Option<Self> {
        if k == 0 {
            return None;
        }
        if k >= self.probs.len() {
            return Some(self.clone());
        }
        // Find the k-th largest probability to use as the keep threshold.
        let mut sorted: Vec<f64> = self.probs.clone();
        sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
        let threshold = sorted[k - 1];

        let filtered: Vec<f64> = self
            .probs
            .iter()
            .map(|&p| if p >= threshold { p } else { 0.0 })
            .collect();
        Self::from_probs(filtered)
    }

    /// Inverse-CDF sampling: maps a uniform sample `u` in `[0, 1)` to a
    /// token id. `u` is clamped into range; if rounding keeps the running
    /// cumulative below `u`, the last token is returned.
    pub fn sample_with_uniform(&self, u: f64) -> usize {
        let u = u.clamp(0.0, 1.0 - f64::EPSILON);
        let mut cumulative = 0.0;
        for (i, &p) in self.probs.iter().enumerate() {
            cumulative += p;
            if u < cumulative {
                return i;
            }
        }
        self.probs.len().saturating_sub(1)
    }

    /// Id of the most probable token (first one on ties; 0 if empty, which
    /// the constructors rule out).
    pub fn argmax(&self) -> usize {
        self.probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }
}

impl fmt::Display for TokenDistribution {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let top = self.argmax();
        write!(
            f,
            "TokenDistribution(vocab={}, top_token={}, top_prob={:.4})",
            self.vocab_size(),
            top,
            self.prob(top),
        )
    }
}