Skip to main content

scirs2_neural/speculative/
types.rs

1//! Types for speculative decoding.
2//!
3//! This module defines the core data structures used throughout the speculative
4//! decoding pipeline: configuration, verification results, decoding statistics,
5//! and token probability distributions.
6
7use std::fmt;
8
/// Tunable parameters for a speculative decoding session.
///
/// These settings govern the draft-then-verify loop: how many tokens the
/// draft model proposes each step, how sampling is shaped (temperature and
/// top-k), the overall token budget, and whether the draft length adapts to
/// the observed acceptance rate.
#[derive(Debug, Clone)]
pub struct SpeculativeConfig {
    /// How many tokens the draft model proposes in a single step.
    ///
    /// Larger values spread the cost of each target-model call over more
    /// tokens, at the risk of a lower acceptance rate. Typical range: 2..=8.
    pub draft_length: usize,

    /// Sampling temperature shared by the draft and target distributions.
    ///
    /// Below 1.0 the distribution is sharpened (closer to greedy); above 1.0
    /// it is flattened (more random).
    pub temperature: f64,

    /// Top-k cut-off: sampling only considers the `top_k` most probable
    /// tokens. A value of 0 disables the filter.
    pub top_k: usize,

    /// Upper bound on the total number of tokens (the prompt counts too).
    pub max_tokens: usize,

    /// If `true`, `draft_length` is adjusted on the fly from the rolling
    /// acceptance rate.
    pub adaptive_draft: bool,
}

impl Default for SpeculativeConfig {
    /// Baseline settings: 4 draft tokens per step, temperature 1.0, top-k 50,
    /// a 512-token budget, and a fixed (non-adaptive) draft length.
    fn default() -> Self {
        SpeculativeConfig {
            adaptive_draft: false,
            max_tokens: 512,
            top_k: 50,
            temperature: 1.0,
            draft_length: 4,
        }
    }
}

impl fmt::Display for SpeculativeConfig {
    /// Renders every field on one line; the temperature is shown with two
    /// decimal places.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructure once so the argument list mirrors the format string.
        let SpeculativeConfig {
            draft_length,
            temperature,
            top_k,
            max_tokens,
            adaptive_draft,
        } = self;
        f.write_fmt(format_args!(
            "SpeculativeConfig(draft_length={}, temperature={:.2}, top_k={}, max_tokens={}, adaptive={})",
            draft_length, temperature, top_k, max_tokens, adaptive_draft
        ))
    }
}
61
/// Outcome of checking one batch of draft tokens against the target model.
#[derive(Debug, Clone)]
pub struct VerificationResult {
    /// Ids of the tokens that survived the rejection-sampling step.
    pub accepted_tokens: Vec<usize>,

    /// 0-based draft position of the first rejected token, or `None` if the
    /// whole draft was accepted.
    pub rejected_at: Option<usize>,

    /// Share of the draft tokens that were accepted (0.0..=1.0).
    pub acceptance_rate: f64,
}

impl VerificationResult {
    /// Bundle the three pieces of a verification outcome into a result.
    ///
    /// The values are stored exactly as given; nothing is validated or
    /// recomputed here.
    pub fn new(
        accepted_tokens: Vec<usize>,
        rejected_at: Option<usize>,
        acceptance_rate: f64,
    ) -> Self {
        VerificationResult {
            acceptance_rate,
            rejected_at,
            accepted_tokens,
        }
    }

    /// `true` if no rejection position was recorded, i.e. every draft token
    /// was accepted.
    pub fn all_accepted(&self) -> bool {
        match self.rejected_at {
            None => true,
            Some(_) => false,
        }
    }

    /// How many tokens were accepted.
    pub fn num_accepted(&self) -> usize {
        let accepted: &[usize] = &self.accepted_tokens;
        accepted.len()
    }
}

impl fmt::Display for VerificationResult {
    /// One-line summary: accepted count, rejection position (debug form) and
    /// the acceptance rate with two decimals.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let accepted = self.num_accepted();
        let rejected_at = self.rejected_at;
        let rate = self.acceptance_rate;
        f.write_fmt(format_args!(
            "VerificationResult(accepted={}, rejected_at={:?}, rate={:.2})",
            accepted, rejected_at, rate
        ))
    }
}
112
/// Aggregate statistics collected during a full speculative decoding run.
///
/// All counters start at zero (see [`DecodingStats::new`] / `Default`).
#[derive(Debug, Clone, Default)]
pub struct DecodingStats {
    /// Total number of tokens produced (final output length minus prompt length).
    pub total_tokens: usize,

    /// Total number of draft tokens proposed across all steps.
    pub draft_tokens: usize,

    /// Total number of draft tokens accepted across all steps.
    pub accepted_tokens: usize,

    /// Wall-clock time for the decoding run, in milliseconds.
    pub wall_time_ms: f64,

    /// Average number of tokens produced per decoding step.
    ///
    /// For pure autoregressive decoding this is 1.0; speculative decoding
    /// aims for values > 1.0.
    pub tokens_per_step: f64,
}

impl DecodingStats {
    /// Create empty statistics: every counter and timing field is zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Overall acceptance rate across the entire decoding run.
    ///
    /// Computed as `accepted_tokens / draft_tokens`. Returns 0.0 when no
    /// draft tokens were proposed, so the result is never a division by zero.
    pub fn acceptance_rate(&self) -> f64 {
        if self.draft_tokens == 0 {
            0.0
        } else {
            self.accepted_tokens as f64 / self.draft_tokens as f64
        }
    }

    /// Tokens generated per millisecond (`total_tokens / wall_time_ms`).
    ///
    /// Returns 0.0 whenever `wall_time_ms` is not a strictly positive number,
    /// which covers zero, negative, and NaN timings.
    pub fn throughput(&self) -> f64 {
        // Written as a positive test on purpose: NaN fails `> 0.0` and so
        // takes the 0.0 branch instead of propagating NaN to the caller.
        if self.wall_time_ms > 0.0 {
            self.total_tokens as f64 / self.wall_time_ms
        } else {
            0.0
        }
    }
}

impl fmt::Display for DecodingStats {
    /// One-line summary of the counters, the derived acceptance rate, the
    /// tokens-per-step average and the wall time.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "DecodingStats(total={}, drafted={}, accepted={}, rate={:.2}, tok/step={:.2}, time={:.1}ms)",
            self.total_tokens,
            self.draft_tokens,
            self.accepted_tokens,
            self.acceptance_rate(),
            self.tokens_per_step,
            self.wall_time_ms,
        )
    }
}
186
/// A probability distribution over a vocabulary of tokens.
///
/// Wraps a dense vector of finite, non-negative values that sum to 1.0
/// (within floating-point tolerance). The index into the vector is the
/// token id. Every constructor enforces this invariant, so the methods below
/// never have to deal with NaN or infinite probabilities.
#[derive(Debug, Clone)]
pub struct TokenDistribution {
    /// Probability assigned to each token in the vocabulary.
    probs: Vec<f64>,
}

impl TokenDistribution {
    /// Create a distribution from a vector of non-negative weights.
    ///
    /// The weights are normalized to sum to 1.0.
    ///
    /// Returns `None` if `probs` is empty, contains a negative, NaN or
    /// infinite value, or if the weights do not add up to a finite,
    /// strictly positive total (all zeros, or a sum that overflows).
    pub fn from_probs(mut probs: Vec<f64>) -> Option<Self> {
        if probs.is_empty() {
            return None;
        }
        // `p < 0.0` alone is not enough: every comparison with NaN is false,
        // so NaN would slip through and poison the normalized vector.
        if probs.iter().any(|&p| !p.is_finite() || p < 0.0) {
            return None;
        }
        let sum: f64 = probs.iter().sum();
        // Finite inputs can still overflow to +inf when summed.
        if !sum.is_finite() || sum <= 0.0 {
            return None;
        }
        // Normalize in place; the input vector is already owned.
        for p in &mut probs {
            *p /= sum;
        }
        Some(Self { probs })
    }

    /// Create a uniform distribution over `vocab_size` tokens.
    ///
    /// Returns `None` if `vocab_size` is 0.
    pub fn uniform(vocab_size: usize) -> Option<Self> {
        if vocab_size == 0 {
            return None;
        }
        let p = 1.0 / vocab_size as f64;
        Some(Self {
            probs: vec![p; vocab_size],
        })
    }

    /// Create a distribution from log-probabilities (or unnormalized logits).
    ///
    /// Applies the softmax transformation: `p_i = exp(logp_i) / sum(exp(logp_j))`.
    /// The maximum is subtracted before exponentiating for numerical
    /// stability. Entries equal to negative infinity become probability 0.
    ///
    /// Returns `None` if `log_probs` is empty, contains NaN, or has no finite
    /// maximum (all entries negative infinity, or any entry positive
    /// infinity).
    pub fn from_log_probs(log_probs: &[f64]) -> Option<Self> {
        if log_probs.is_empty() || log_probs.iter().any(|lp| lp.is_nan()) {
            return None;
        }
        let max_lp = log_probs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        // -inf: no token has any mass. +inf: `inf - inf` would yield NaN.
        if !max_lp.is_finite() {
            return None;
        }
        let mut probs: Vec<f64> = log_probs.iter().map(|&lp| (lp - max_lp).exp()).collect();
        // The maximum entry contributes exp(0) = 1 and every term is in
        // [0, 1], so `sum` is finite and at least 1.
        let sum: f64 = probs.iter().sum();
        for p in &mut probs {
            *p /= sum;
        }
        Some(Self { probs })
    }

    /// Vocabulary size.
    pub fn vocab_size(&self) -> usize {
        self.probs.len()
    }

    /// Probability of a given token.
    ///
    /// Returns 0.0 for out-of-range token ids.
    pub fn prob(&self, token_id: usize) -> f64 {
        self.probs.get(token_id).copied().unwrap_or(0.0)
    }

    /// Borrow the raw probability vector.
    pub fn probs(&self) -> &[f64] {
        &self.probs
    }

    /// Apply temperature scaling, returning a new distribution.
    ///
    /// Each probability is raised to the power `1 / temperature` and the
    /// result is renormalized. Temperature < 1.0 sharpens; temperature > 1.0
    /// flattens. Zero-probability tokens stay at zero.
    ///
    /// Returns `None` if temperature is non-positive or NaN.
    pub fn with_temperature(&self, temperature: f64) -> Option<Self> {
        if temperature.is_nan() || temperature <= 0.0 {
            return None;
        }
        // Temperature 1.0 is the identity; skip the round trip through logs.
        if (temperature - 1.0).abs() < 1e-12 {
            return Some(self.clone());
        }
        // Work in log space for stability
        let log_probs: Vec<f64> = self
            .probs
            .iter()
            .map(|&p| {
                if p > 0.0 {
                    p.ln() / temperature
                } else {
                    f64::NEG_INFINITY
                }
            })
            .collect();
        Self::from_log_probs(&log_probs)
    }

    /// Apply top-k filtering, zeroing out all but the `k` highest-probability
    /// tokens and renormalizing the rest.
    ///
    /// Tokens tied with the k-th highest probability are all kept, so the
    /// result may retain more than `k` tokens. If `k` is at least the
    /// vocabulary size the distribution is returned unchanged.
    ///
    /// Returns `None` if `k` is 0.
    pub fn with_top_k(&self, k: usize) -> Option<Self> {
        if k == 0 {
            return None;
        }
        if k >= self.probs.len() {
            return Some(self.clone());
        }
        // Only the k-th largest value is needed, so a partial selection
        // (O(n) on average) replaces a full sort. The comparator is reversed
        // to order descending; `k - 1 < len` is guaranteed by the check above.
        let mut scratch: Vec<f64> = self.probs.clone();
        let (_, kth, _) = scratch.select_nth_unstable_by(k - 1, |a, b| b.total_cmp(a));
        let threshold = *kth;

        // Keep only tokens with prob >= threshold (may keep slightly more than k if ties)
        let filtered: Vec<f64> = self
            .probs
            .iter()
            .map(|&p| if p >= threshold { p } else { 0.0 })
            .collect();
        Self::from_probs(filtered)
    }

    /// Sample a token from this distribution using the provided random value
    /// `u` in \[0, 1).
    ///
    /// Performs inverse-CDF sampling: walks the cumulative sum and returns
    /// the first token id whose cumulative probability exceeds `u`. Values of
    /// `u` outside the range are clamped into it.
    ///
    /// If no cumulative sum exceeds `u` (rounding shortfall, or `u` is NaN),
    /// the last token with non-zero probability is returned, so a
    /// zero-probability token is never sampled.
    pub fn sample_with_uniform(&self, u: f64) -> usize {
        let u = u.clamp(0.0, 1.0 - f64::EPSILON);
        let mut cumulative = 0.0;
        for (i, &p) in self.probs.iter().enumerate() {
            cumulative += p;
            if u < cumulative {
                return i;
            }
        }
        // Fallback: pick the last token that actually carries mass rather
        // than blindly the last id, which may have been zeroed by filtering.
        self.probs
            .iter()
            .rposition(|&p| p > 0.0)
            .unwrap_or_else(|| self.probs.len().saturating_sub(1))
    }

    /// Token id with the highest probability (argmax).
    ///
    /// When several tokens share the maximum probability, the one with the
    /// highest token id is returned.
    pub fn argmax(&self) -> usize {
        self.probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }
}

impl fmt::Display for TokenDistribution {
    /// One-line summary: vocabulary size, the argmax token and its
    /// probability with four decimals.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let top = self.argmax();
        write!(
            f,
            "TokenDistribution(vocab={}, top_token={}, top_prob={:.4})",
            self.vocab_size(),
            top,
            self.prob(top),
        )
    }
}
354}