Skip to main content

ruvector_attention/attention/
speculative.rs

1//! Speculative decoding with draft-verify paradigm.
2//!
3//! Speculative decoding (Leviathan et al., 2023) achieves 2-3x inference speedup
4//! with **zero quality loss** by exploiting the asymmetry between generating and
5//! verifying tokens. A small "draft" model proposes gamma candidate tokens cheaply,
6//! then the large "target" model verifies all candidates in a single forward pass.
7//!
8//! The key insight: autoregressive generation is memory-bandwidth-bound, not
9//! compute-bound. The target model's forward pass for gamma+1 positions costs
10//! nearly the same as a single-token forward pass because the GPU is underutilized
11//! during single-token generation. By batching gamma+1 positions, we amortize the
12//! cost of the target model across multiple accepted tokens.
13//!
14//! The rejection sampling scheme guarantees that the output distribution is
15//! **identical** to sampling from the target model alone -- no approximation.
16
17use crate::error::{AttentionError, AttentionResult};
18
/// Identifier of a single vocabulary token, stored as an unsigned 32-bit integer.
pub type TokenId = u32;
21
/// Tunable parameters for a speculative decoding run.
///
/// The accepted ranges for each field are enforced by
/// [`SpeculativeConfig::validate`], not by construction, so a value built
/// by hand should be validated before use.
#[derive(Debug, Clone)]
pub struct SpeculativeConfig {
    /// Draft length: how many candidate tokens the draft model is asked for
    /// in each step (typically 4-8; validation accepts 1 through 32).
    pub gamma: usize,
    /// Sampling temperature; must be strictly positive. Values above 1.0
    /// increase randomness.
    pub temperature: f32,
    /// Nucleus (top-p) sampling threshold in `(0, 1]`. Tokens whose
    /// cumulative probability lies above this threshold are excluded.
    pub top_p: f32,
    /// Upper bound on the generated sequence length; must be non-zero.
    pub max_seq_len: usize,
}
35
36impl SpeculativeConfig {
37    /// Creates a new configuration with the given draft length.
38    pub fn new(gamma: usize) -> Self {
39        Self {
40            gamma,
41            temperature: 1.0,
42            top_p: 1.0,
43            max_seq_len: 2048,
44        }
45    }
46
47    /// Validates the configuration parameters.
48    pub fn validate(&self) -> AttentionResult<()> {
49        let err = |msg: &str| Err(AttentionError::InvalidConfig(msg.into()));
50        if self.gamma == 0 {
51            return err("gamma must be > 0");
52        }
53        if self.gamma > 32 {
54            return err("gamma must be <= 32");
55        }
56        if self.temperature <= 0.0 {
57            return err("temperature must be > 0");
58        }
59        if self.top_p <= 0.0 || self.top_p > 1.0 {
60            return err("top_p must be in (0, 1]");
61        }
62        if self.max_seq_len == 0 {
63            return err("max_seq_len must be > 0");
64        }
65        Ok(())
66    }
67}
68
69/// Draft model trait: a small, fast model that proposes candidate tokens.
70pub trait DraftModel: Send + Sync {
71    /// Generates `gamma` draft tokens given a prefix.
72    ///
73    /// Returns a vector of (token_id, probability) pairs representing the
74    /// draft model's greedy/sampled choices and their probabilities under
75    /// the draft distribution.
76    fn draft_tokens(&self, prefix: &[TokenId], gamma: usize) -> Vec<(TokenId, f32)>;
77}
78
79/// Target model trait: the large, accurate model that verifies drafts.
80pub trait TargetModel: Send + Sync {
81    /// Evaluates the target model on all draft positions in one forward pass.
82    ///
83    /// Given the prefix and the draft tokens, returns the target model's full
84    /// probability distribution at each of the `gamma + 1` positions (gamma
85    /// verification positions plus one bonus position).
86    ///
87    /// Each inner `Vec<(TokenId, f32)>` is a sparse probability distribution
88    /// over the vocabulary (only tokens with nonzero probability need appear).
89    fn verify_batch(
90        &self,
91        prefix: &[TokenId],
92        draft_tokens: &[TokenId],
93    ) -> Vec<Vec<(TokenId, f32)>>;
94}
95
96/// Result of a single speculative decoding step.
97#[derive(Clone, Debug)]
98pub struct AcceptedTokens {
99    /// The tokens accepted in this step (1 to gamma+1).
100    pub tokens: Vec<TokenId>,
101    /// Fraction of draft tokens that were accepted.
102    pub acceptance_rate: f32,
103    /// Number of draft model calls made.
104    pub draft_calls: usize,
105    /// Number of target model calls made (always 1 per step).
106    pub target_calls: usize,
107}
108
/// Session-level counters and timings for speculative decoding.
///
/// `Default` yields an all-zero record.
#[derive(Debug, Clone, Default)]
pub struct DecodingStats {
    /// Total number of tokens generated over all steps.
    pub tokens_generated: usize,
    /// Running acceptance rate.
    pub acceptance_rate: f32,
    /// Observed speedup relative to plain autoregressive decoding.
    pub speedup_ratio: f32,
    /// Mean latency of the draft model, in milliseconds.
    pub draft_latency_ms: f64,
    /// Mean latency of the target model, in milliseconds.
    pub target_latency_ms: f64,
}
123
/// Computes the theoretical speedup from speculative decoding.
///
/// Formula: `(gamma * alpha) / (1 + gamma * (1 - alpha))`
///
/// where `gamma` is the draft length and `alpha` is the acceptance rate,
/// clamped to `[0, 1]` before use.
///
/// - At alpha = 1.0 (all accepted) the result equals `gamma`.
/// - At alpha = 0.0 (all rejected) the result is 0 (worse than baseline).
/// - A NaN acceptance rate carries no information and yields 0.0 rather
///   than propagating NaN to the caller.
///
/// NOTE(review): this closed form is not the expected-accepted-tokens
/// expression given by Leviathan et al.; confirm it is the intended estimate.
pub fn theoretical_speedup(gamma: usize, acceptance_rate: f32) -> f32 {
    if acceptance_rate.is_nan() {
        return 0.0;
    }
    let g = gamma as f32;
    let a = acceptance_rate.clamp(0.0, 1.0);
    // With g >= 0 and (1 - a) in [0, 1] the denominator is always >= 1, so
    // no division-by-zero guard is needed.
    (g * a) / (1.0 + g * (1.0 - a))
}
140
141/// The core speculative decoder implementing the Leviathan et al. algorithm.
142pub struct SpeculativeDecoder;
143
144impl SpeculativeDecoder {
145    /// Performs one speculative decoding step.
146    ///
147    /// # Algorithm
148    ///
149    /// 1. Draft model generates gamma candidate tokens with probabilities q_i.
150    /// 2. Target model verifies all gamma+1 positions in one forward pass,
151    ///    producing distributions p_i.
152    /// 3. For each draft token i (left to right):
153    ///    - If p_i(t_i) >= q_i(t_i): accept unconditionally.
154    ///    - Otherwise: accept with probability p_i(t_i) / q_i(t_i).
155    ///    - On rejection: sample from adjusted distribution max(0, p_i - q_i)
156    ///      (normalized), then stop.
157    /// 4. If all gamma tokens accepted: bonus sample from p_{gamma+1}.
158    pub fn decode_step(
159        prefix: &[TokenId],
160        draft: &dyn DraftModel,
161        target: &dyn TargetModel,
162        config: &SpeculativeConfig,
163        rng_values: Option<&[f32]>,
164    ) -> AttentionResult<AcceptedTokens> {
165        config.validate()?;
166
167        let draft_results = draft.draft_tokens(prefix, config.gamma);
168        if draft_results.is_empty() {
169            return Err(AttentionError::EmptyInput(
170                "draft model returned no tokens".into(),
171            ));
172        }
173
174        let draft_tokens: Vec<TokenId> = draft_results.iter().map(|(t, _)| *t).collect();
175        let draft_probs: Vec<f32> = draft_results.iter().map(|(_, p)| *p).collect();
176
177        let target_dists = target.verify_batch(prefix, &draft_tokens);
178        if target_dists.len() < draft_tokens.len() + 1 {
179            return Err(AttentionError::ComputationError(
180                "target model must return gamma+1 distributions".into(),
181            ));
182        }
183
184        let mut accepted = Vec::new();
185        let mut rejected = false;
186
187        for i in 0..draft_tokens.len() {
188            let token = draft_tokens[i];
189            let q_i = draft_probs[i];
190            let p_i = prob_of_token(&target_dists[i], token);
191
192            let rng_val = rng_values.and_then(|v| v.get(i).copied()).unwrap_or(0.0);
193
194            if p_i >= q_i {
195                // Accept unconditionally: target agrees at least as much.
196                accepted.push(token);
197            } else if rng_val < p_i / q_i {
198                // Accept with probability p_i / q_i.
199                accepted.push(token);
200            } else {
201                // Reject: sample from adjusted distribution max(0, p - q).
202                let adjusted = sample_adjusted(&target_dists[i], &draft_tokens, &draft_probs, i);
203                accepted.push(adjusted);
204                rejected = true;
205                break;
206            }
207        }
208
209        // If all gamma tokens accepted, bonus sample from p_{gamma+1}.
210        if !rejected {
211            let bonus_dist = &target_dists[draft_tokens.len()];
212            if let Some(&(token, _)) = bonus_dist.first() {
213                accepted.push(token);
214            }
215        }
216
217        let num_draft = draft_tokens.len();
218        let num_accepted_from_draft = if rejected {
219            accepted.len().saturating_sub(1)
220        } else {
221            num_draft
222        };
223        let acceptance_rate = if num_draft > 0 {
224            num_accepted_from_draft as f32 / num_draft as f32
225        } else {
226            0.0
227        };
228
229        Ok(AcceptedTokens {
230            tokens: accepted,
231            acceptance_rate,
232            draft_calls: 1,
233            target_calls: 1,
234        })
235    }
236}
237
238/// Look up the probability of a specific token in a sparse distribution.
239fn prob_of_token(dist: &[(TokenId, f32)], token: TokenId) -> f32 {
240    dist.iter()
241        .find(|(t, _)| *t == token)
242        .map(|(_, p)| *p)
243        .unwrap_or(0.0)
244}
245
246/// Sample from the adjusted distribution max(0, p_i - q_i), normalized.
247///
248/// For simplicity, we take the token with the highest adjusted probability.
249/// In production, this would use proper categorical sampling.
250fn sample_adjusted(
251    target_dist: &[(TokenId, f32)],
252    draft_tokens: &[TokenId],
253    draft_probs: &[f32],
254    position: usize,
255) -> TokenId {
256    let mut best_token = target_dist.first().map(|(t, _)| *t).unwrap_or(0);
257    let mut best_score = f32::NEG_INFINITY;
258
259    for &(token, p_target) in target_dist {
260        let p_draft = if token == draft_tokens[position] {
261            draft_probs[position]
262        } else {
263            0.0
264        };
265        let adjusted = (p_target - p_draft).max(0.0);
266        if adjusted > best_score {
267            best_score = adjusted;
268            best_token = token;
269        }
270    }
271    best_token
272}
273
274// ---------------------------------------------------------------------------
275// Medusa-style parallel decoding
276// ---------------------------------------------------------------------------
277
278/// A single Medusa prediction head that produces candidate tokens
279/// from a shared hidden state.
280pub trait MedusaHead: Send + Sync {
281    /// Predicts candidate tokens for one future position.
282    ///
283    /// Returns a sparse distribution over the vocabulary.
284    fn predict(&self, prefix: &[TokenId]) -> Vec<(TokenId, f32)>;
285}
286
287/// Result of Medusa-style tree verification.
288#[derive(Clone, Debug)]
289pub struct MedusaResult {
290    /// Accepted tokens from the best verified path.
291    pub tokens: Vec<TokenId>,
292    /// Number of candidate paths evaluated.
293    pub paths_evaluated: usize,
294}
295
296/// Performs simplified Medusa-style parallel decoding.
297///
298/// Instead of a single draft sequence, multiple independent heads each
299/// predict one future token, forming a tree of candidates. The target
300/// model verifies the most promising path in one forward pass.
301pub fn medusa_decode(
302    prefix: &[TokenId],
303    heads: &[&dyn MedusaHead],
304    target: &dyn TargetModel,
305    config: &SpeculativeConfig,
306) -> AttentionResult<MedusaResult> {
307    config.validate()?;
308
309    if heads.is_empty() {
310        return Err(AttentionError::EmptyInput(
311            "at least one Medusa head required".into(),
312        ));
313    }
314
315    // Each head predicts one position ahead.
316    let head_predictions: Vec<Vec<(TokenId, f32)>> =
317        heads.iter().map(|h| h.predict(prefix)).collect();
318
319    // Build the greedy candidate path (top-1 from each head).
320    let candidate_path: Vec<TokenId> = head_predictions
321        .iter()
322        .filter_map(|dist| dist.first().map(|(t, _)| *t))
323        .collect();
324
325    if candidate_path.is_empty() {
326        return Err(AttentionError::EmptyInput(
327            "heads produced no predictions".into(),
328        ));
329    }
330
331    // Verify the candidate path with the target model.
332    let target_dists = target.verify_batch(prefix, &candidate_path);
333
334    // Accept tokens while the target model agrees.
335    let mut accepted = Vec::new();
336    for (i, &token) in candidate_path.iter().enumerate() {
337        if i >= target_dists.len() {
338            break;
339        }
340        let p = prob_of_token(&target_dists[i], token);
341        if p > 0.0 {
342            accepted.push(token);
343        } else {
344            break;
345        }
346    }
347
348    // If nothing was accepted, take the target model's top choice at pos 0.
349    if accepted.is_empty() {
350        if let Some(dist) = target_dists.first() {
351            if let Some(&(token, _)) = dist.first() {
352                accepted.push(token);
353            }
354        }
355    }
356
357    Ok(MedusaResult {
358        tokens: accepted,
359        paths_evaluated: 1, // greedy path only in this simplified version
360    })
361}
362
363// ---------------------------------------------------------------------------
364// Mock implementations for testing
365// ---------------------------------------------------------------------------
366
367/// A mock draft model with a configurable token sequence and probability.
368pub struct SimpleDraftModel {
369    /// Tokens the draft model will propose, cycling if gamma > len.
370    pub tokens: Vec<TokenId>,
371    /// Probability assigned to each drafted token.
372    pub probability: f32,
373}
374
375impl DraftModel for SimpleDraftModel {
376    fn draft_tokens(&self, _prefix: &[TokenId], gamma: usize) -> Vec<(TokenId, f32)> {
377        (0..gamma)
378            .map(|i| {
379                let token = self.tokens[i % self.tokens.len()];
380                (token, self.probability)
381            })
382            .collect()
383    }
384}
385
386/// A mock target model that returns configurable distributions.
387pub struct SimpleTargetModel {
388    /// Distributions to return for each position.
389    /// If `verify_batch` requests more positions than available,
390    /// the last distribution is repeated.
391    pub distributions: Vec<Vec<(TokenId, f32)>>,
392}
393
394impl TargetModel for SimpleTargetModel {
395    fn verify_batch(
396        &self,
397        _prefix: &[TokenId],
398        draft_tokens: &[TokenId],
399    ) -> Vec<Vec<(TokenId, f32)>> {
400        let needed = draft_tokens.len() + 1;
401        (0..needed)
402            .map(|i| {
403                if i < self.distributions.len() {
404                    self.distributions[i].clone()
405                } else {
406                    self.distributions
407                        .last()
408                        .cloned()
409                        .unwrap_or_else(|| vec![(0, 1.0)])
410                }
411            })
412            .collect()
413    }
414}
415
416/// A mock Medusa head that always predicts a fixed token.
417pub struct SimpleMedusaHead {
418    /// The token this head predicts.
419    pub token: TokenId,
420    /// Probability assigned to the prediction.
421    pub probability: f32,
422}
423
424impl MedusaHead for SimpleMedusaHead {
425    fn predict(&self, _prefix: &[TokenId]) -> Vec<(TokenId, f32)> {
426        vec![(self.token, self.probability)]
427    }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration with a draft length of four, shared by most tests.
    fn default_config() -> SpeculativeConfig {
        SpeculativeConfig::new(4)
    }

    /// Mock draft model proposing `tokens`, each with the same probability.
    fn draft_of(tokens: &[TokenId], probability: f32) -> SimpleDraftModel {
        SimpleDraftModel {
            tokens: tokens.to_vec(),
            probability,
        }
    }

    /// Mock target model returning `distributions` position by position.
    fn target_of(distributions: Vec<Vec<(TokenId, f32)>>) -> SimpleTargetModel {
        SimpleTargetModel { distributions }
    }

    /// True when two rates differ by less than `f32::EPSILON`.
    fn same_rate(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < f32::EPSILON
    }

    // -- Config validation tests --

    #[test]
    fn test_config_valid() {
        assert!(default_config().validate().is_ok());
    }

    #[test]
    fn test_config_gamma_zero() {
        let cfg = SpeculativeConfig {
            gamma: 0,
            ..default_config()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_config_gamma_too_large() {
        let cfg = SpeculativeConfig {
            gamma: 33,
            ..default_config()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_config_bad_temperature() {
        let cfg = SpeculativeConfig {
            temperature: 0.0,
            ..default_config()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_config_bad_top_p() {
        // Both the lower bound (exclusive) and the upper bound are checked.
        for &bad in &[0.0_f32, 1.1] {
            let cfg = SpeculativeConfig {
                top_p: bad,
                ..default_config()
            };
            assert!(cfg.validate().is_err());
        }
    }

    // -- Full acceptance test --

    #[test]
    fn test_full_acceptance() {
        // The target is at least as confident as the draft everywhere, so
        // every draft token is accepted unconditionally.
        let draft = draft_of(&[10, 20, 30, 40], 0.5);
        let target = target_of(vec![
            vec![(10, 0.8)],
            vec![(20, 0.7)],
            vec![(30, 0.6)],
            vec![(40, 0.9)],
            vec![(50, 1.0)], // bonus position
        ]);

        let outcome =
            SpeculativeDecoder::decode_step(&[1, 2, 3], &draft, &target, &default_config(), None)
                .unwrap();

        // Four accepted draft tokens followed by the bonus token.
        assert_eq!(outcome.tokens.len(), 5);
        assert_eq!(outcome.tokens, vec![10, 20, 30, 40, 50]);
        assert!(same_rate(outcome.acceptance_rate, 1.0));
    }

    // -- Full rejection test --

    #[test]
    fn test_full_rejection() {
        // The target gives the drafted tokens no probability at all and
        // puts its mass on token 99 instead.
        let draft = draft_of(&[10, 20, 30, 40], 0.9);
        let target = target_of(vec![
            vec![(99, 0.9)],
            vec![(99, 0.9)],
            vec![(99, 0.9)],
            vec![(99, 0.9)],
            vec![(99, 1.0)],
        ]);
        let forced_reject = [1.0, 1.0, 1.0, 1.0]; // rng=1.0 forces rejection

        let outcome = SpeculativeDecoder::decode_step(
            &[1],
            &draft,
            &target,
            &default_config(),
            Some(&forced_reject),
        )
        .unwrap();

        // The very first token is rejected and replaced by token 99.
        assert_eq!(outcome.tokens.len(), 1);
        assert_eq!(outcome.tokens[0], 99);
        assert!(same_rate(outcome.acceptance_rate, 0.0));
    }

    // -- Partial acceptance test --

    #[test]
    fn test_partial_acceptance() {
        let draft = draft_of(&[10, 20, 30, 40], 0.5);
        // Positions 0 and 1 are accepted (p >= q); position 2 has p = 0.
        let target = target_of(vec![
            vec![(10, 0.8)],
            vec![(20, 0.6)],
            vec![(77, 0.9)], // no prob for 30 -> reject
            vec![(40, 0.9)],
            vec![(50, 1.0)],
        ]);
        let rng = [0.0, 0.0, 1.0, 0.0]; // rng=1.0 at pos 2 forces reject

        let outcome =
            SpeculativeDecoder::decode_step(&[1], &draft, &target, &default_config(), Some(&rng))
                .unwrap();

        // 10 and 20 accepted, 30 rejected and replaced by 77.
        assert_eq!(outcome.tokens.len(), 3);
        assert_eq!(outcome.tokens[0], 10);
        assert_eq!(outcome.tokens[1], 20);
        assert_eq!(outcome.tokens[2], 77);
        assert!(same_rate(outcome.acceptance_rate, 0.5));
    }

    // -- Rejection sampling produces adjusted distribution token --

    #[test]
    fn test_rejection_sampling_distribution() {
        let draft = draft_of(&[10], 0.8);
        // Target: 0.3 on token 10, 0.7 on token 42.
        // Adjusted: max(0, 0.3 - 0.8) = 0 for 10 and max(0, 0.7 - 0) = 0.7
        // for 42, so the replacement must be 42.
        let target = target_of(vec![vec![(10, 0.3), (42, 0.7)], vec![(99, 1.0)]]);
        let single = SpeculativeConfig::new(1);

        let outcome = SpeculativeDecoder::decode_step(
            &[1],
            &draft,
            &target,
            &single,
            Some(&[1.0]), // force rejection
        )
        .unwrap();

        assert_eq!(outcome.tokens.len(), 1);
        assert_eq!(outcome.tokens[0], 42);
    }

    // -- Speedup calculation --

    #[test]
    fn test_theoretical_speedup() {
        // (gamma, alpha, expected speedup, tolerance)
        let cases: [(usize, f32, f32, f32); 4] = [
            // 4*1 / (1 + 4*0) = 4.0
            (4, 1.0, 4.0, 1e-5),
            // 0 / (1 + 4) = 0.0
            (4, 0.0, 0.0, 1e-5),
            // 4*0.8 / (1 + 4*0.2) = 3.2 / 1.8 ~= 1.778
            (4, 0.8, 3.2 / 1.8, 1e-4),
            // 7.2 / 1.8 = 4.0
            (8, 0.9, 7.2 / 1.8, 1e-4),
        ];

        for &(gamma, alpha, expected, tolerance) in &cases {
            let speedup = theoretical_speedup(gamma, alpha);
            assert!((speedup - expected).abs() < tolerance);
        }
    }

    // -- Medusa tree verification --

    #[test]
    fn test_medusa_decode() {
        let first = SimpleMedusaHead {
            token: 10,
            probability: 0.9,
        };
        let second = SimpleMedusaHead {
            token: 20,
            probability: 0.8,
        };
        let target = target_of(vec![vec![(10, 0.7)], vec![(20, 0.6)], vec![(99, 1.0)]]);

        let heads: Vec<&dyn MedusaHead> = vec![&first, &second];
        let outcome = medusa_decode(&[1, 2], &heads, &target, &default_config()).unwrap();

        assert_eq!(outcome.tokens, vec![10, 20]);
        assert_eq!(outcome.paths_evaluated, 1);
    }

    #[test]
    fn test_medusa_no_heads() {
        let target = target_of(vec![vec![(1, 1.0)]]);
        let heads: Vec<&dyn MedusaHead> = Vec::new();
        assert!(medusa_decode(&[1], &heads, &target, &default_config()).is_err());
    }

    // -- Edge case: probabilistic acceptance --

    #[test]
    fn test_probabilistic_acceptance() {
        // p_i(t_i) < q_i(t_i), but the random draw is low enough to accept.
        let draft = draft_of(&[10], 0.8);
        let target = target_of(vec![
            vec![(10, 0.4)], // p/q = 0.5
            vec![(99, 1.0)],
        ]);
        let single = SpeculativeConfig::new(1);

        // rng = 0.3 < 0.5 (p/q) -> accept
        let outcome =
            SpeculativeDecoder::decode_step(&[1], &draft, &target, &single, Some(&[0.3])).unwrap();

        // Accepted draft token followed by the bonus token.
        assert_eq!(outcome.tokens, vec![10, 99]);
        assert!(same_rate(outcome.acceptance_rate, 1.0));
    }

    // -- Edge case: empty prefix --

    #[test]
    fn test_empty_prefix() {
        let draft = draft_of(&[5], 0.5);
        let target = target_of(vec![vec![(5, 0.9)], vec![(6, 1.0)]]);
        let single = SpeculativeConfig::new(1);

        let outcome =
            SpeculativeDecoder::decode_step(&[], &draft, &target, &single, None).unwrap();

        assert_eq!(outcome.tokens, vec![5, 6]);
    }
}