// Source: ruvector_attention/attention/speculative.rs

//! Speculative decoding with draft-verify paradigm.
//!
//! Speculative decoding (Leviathan et al., 2023) achieves 2-3x inference speedup
//! with **zero quality loss** by exploiting the asymmetry between generating and
//! verifying tokens. A small "draft" model proposes gamma candidate tokens cheaply,
//! then the large "target" model verifies all candidates in a single forward pass.
//!
//! The key insight: autoregressive generation is memory-bandwidth-bound, not
//! compute-bound. The target model's forward pass for gamma+1 positions costs
//! nearly the same as a single-token forward pass because the GPU is underutilized
//! during single-token generation. By batching gamma+1 positions, we amortize the
//! cost of the target model across multiple accepted tokens.
//!
//! The rejection sampling scheme guarantees that the output distribution is
//! **identical** to sampling from the target model alone -- no approximation.

use crate::error::{AttentionError, AttentionResult};

/// Token identifier (an index into the model vocabulary).
pub type TokenId = u32;

/// Configuration for speculative decoding.
#[derive(Clone, Debug)]
pub struct SpeculativeConfig {
    /// Number of draft tokens to generate per step (typically 4-8).
    /// `validate` requires this to be in 1..=32.
    pub gamma: usize,
    /// Sampling temperature. Values > 1.0 increase randomness.
    /// Must be > 0. NOTE(review): not read by `decode_step` in this module --
    /// presumably consumed by draft/target model implementations; confirm.
    pub temperature: f32,
    /// Nucleus sampling threshold. Tokens with cumulative probability above
    /// this are excluded. Must lie in (0, 1].
    /// NOTE(review): also not read by `decode_step` in this module.
    pub top_p: f32,
    /// Maximum sequence length for the generation. Must be > 0.
    pub max_seq_len: usize,
}

36impl SpeculativeConfig {
37    /// Creates a new configuration with the given draft length.
38    pub fn new(gamma: usize) -> Self {
39        Self {
40            gamma,
41            temperature: 1.0,
42            top_p: 1.0,
43            max_seq_len: 2048,
44        }
45    }
46
47    /// Validates the configuration parameters.
48    pub fn validate(&self) -> AttentionResult<()> {
49        let err = |msg: &str| Err(AttentionError::InvalidConfig(msg.into()));
50        if self.gamma == 0 {
51            return err("gamma must be > 0");
52        }
53        if self.gamma > 32 {
54            return err("gamma must be <= 32");
55        }
56        if self.temperature <= 0.0 {
57            return err("temperature must be > 0");
58        }
59        if self.top_p <= 0.0 || self.top_p > 1.0 {
60            return err("top_p must be in (0, 1]");
61        }
62        if self.max_seq_len == 0 {
63            return err("max_seq_len must be > 0");
64        }
65        Ok(())
66    }
67}
68
/// Draft model trait: a small, fast model that proposes candidate tokens.
pub trait DraftModel: Send + Sync {
    /// Generates `gamma` draft tokens given a prefix.
    ///
    /// Returns a vector of (token_id, probability) pairs representing the
    /// draft model's greedy/sampled choices and their probabilities under
    /// the draft distribution.
    ///
    /// `decode_step` treats an empty result as an error; a result shorter
    /// than `gamma` is simply verified over fewer positions.
    fn draft_tokens(
        &self,
        prefix: &[TokenId],
        gamma: usize,
    ) -> Vec<(TokenId, f32)>;
}

/// Target model trait: the large, accurate model that verifies drafts.
pub trait TargetModel: Send + Sync {
    /// Evaluates the target model on all draft positions in one forward pass.
    ///
    /// Given the prefix and the draft tokens, returns the target model's full
    /// probability distribution at each of the `gamma + 1` positions (gamma
    /// verification positions plus one bonus position), where `gamma` is
    /// `draft_tokens.len()`. `decode_step` errors if fewer than
    /// `draft_tokens.len() + 1` distributions are returned.
    ///
    /// Each inner `Vec<(TokenId, f32)>` is a sparse probability distribution
    /// over the vocabulary (only tokens with nonzero probability need appear).
    fn verify_batch(
        &self,
        prefix: &[TokenId],
        draft_tokens: &[TokenId],
    ) -> Vec<Vec<(TokenId, f32)>>;
}

/// Result of a single speculative decoding step.
#[derive(Clone, Debug)]
pub struct AcceptedTokens {
    /// The tokens accepted in this step (1 to gamma+1). When a draft token
    /// was rejected, the final entry is the adjusted-distribution sample
    /// rather than a draft token.
    pub tokens: Vec<TokenId>,
    /// Fraction of draft tokens that were accepted (the adjusted sample and
    /// bonus token are excluded from this count).
    pub acceptance_rate: f32,
    /// Number of draft model calls made.
    pub draft_calls: usize,
    /// Number of target model calls made (always 1 per step).
    pub target_calls: usize,
}

/// Aggregate statistics for a speculative decoding session.
///
/// NOTE(review): nothing in this module populates these fields; presumably
/// a caller accumulates them across `decode_step` calls -- confirm.
#[derive(Clone, Debug, Default)]
pub struct DecodingStats {
    /// Total tokens generated across all steps.
    pub tokens_generated: usize,
    /// Running acceptance rate.
    pub acceptance_rate: f32,
    /// Observed speedup ratio vs autoregressive decoding.
    pub speedup_ratio: f32,
    /// Average draft model latency in milliseconds.
    pub draft_latency_ms: f64,
    /// Average target model latency in milliseconds.
    pub target_latency_ms: f64,
}

/// Computes the theoretical speedup from speculative decoding.
///
/// Formula: `(gamma * alpha) / (1 + gamma * (1 - alpha))`
///
/// where `gamma` is the draft length and `alpha` is the acceptance rate
/// (clamped to [0, 1]).
/// At alpha=1.0 (all accepted) the speedup equals gamma.
/// At alpha=0.0 (all rejected) the speedup is 0 (worse than baseline).
/// A NaN acceptance rate propagates to a NaN result.
pub fn theoretical_speedup(gamma: usize, acceptance_rate: f32) -> f32 {
    let g = gamma as f32;
    let a = acceptance_rate.clamp(0.0, 1.0);
    // After clamping, g >= 0 and a in [0, 1], so the denominator
    // 1 + g * (1 - a) is always >= 1: no division-by-zero guard is needed.
    // (The previous `denominator <= 0.0` branch was unreachable.)
    (g * a) / (1.0 + g * (1.0 - a))
}

/// The core speculative decoder implementing the Leviathan et al. algorithm.
///
/// Stateless marker type: all inputs are passed to
/// [`SpeculativeDecoder::decode_step`] explicitly.
pub struct SpeculativeDecoder;

impl SpeculativeDecoder {
    /// Performs one speculative decoding step.
    ///
    /// # Algorithm
    ///
    /// 1. Draft model generates gamma candidate tokens with probabilities q_i.
    /// 2. Target model verifies all gamma+1 positions in one forward pass,
    ///    producing distributions p_i.
    /// 3. For each draft token i (left to right):
    ///    - If p_i(t_i) >= q_i(t_i): accept unconditionally.
    ///    - Otherwise: accept with probability p_i(t_i) / q_i(t_i).
    ///    - On rejection: sample from adjusted distribution max(0, p_i - q_i)
    ///      (normalized), then stop.
    /// 4. If all gamma tokens accepted: bonus sample from p_{gamma+1}.
    ///
    /// `rng_values` supplies the uniform random draws used by step 3; when
    /// `None` (or too short) each missing draw defaults to 0.0, which makes
    /// the probabilistic branch accept whenever p_i > 0.
    ///
    /// # Errors
    /// - `InvalidConfig` if `config.validate()` fails.
    /// - `EmptyInput` if the draft model proposes no tokens.
    /// - `ComputationError` if the target returns fewer than gamma+1
    ///   distributions.
    pub fn decode_step(
        prefix: &[TokenId],
        draft: &dyn DraftModel,
        target: &dyn TargetModel,
        config: &SpeculativeConfig,
        rng_values: Option<&[f32]>,
    ) -> AttentionResult<AcceptedTokens> {
        config.validate()?;

        let draft_results = draft.draft_tokens(prefix, config.gamma);
        if draft_results.is_empty() {
            return Err(AttentionError::EmptyInput(
                "draft model returned no tokens".into(),
            ));
        }

        // Split the (token, probability) pairs into parallel vectors.
        let draft_tokens: Vec<TokenId> =
            draft_results.iter().map(|(t, _)| *t).collect();
        let draft_probs: Vec<f32> =
            draft_results.iter().map(|(_, p)| *p).collect();

        let target_dists = target.verify_batch(prefix, &draft_tokens);
        if target_dists.len() < draft_tokens.len() + 1 {
            return Err(AttentionError::ComputationError(
                "target model must return gamma+1 distributions".into(),
            ));
        }

        let mut accepted = Vec::new();
        let mut rejected = false;

        for i in 0..draft_tokens.len() {
            let token = draft_tokens[i];
            let q_i = draft_probs[i];
            let p_i = prob_of_token(&target_dists[i], token);

            // Missing/short rng_values default to 0.0 (deterministic accept
            // unless p_i is exactly 0).
            let rng_val = rng_values
                .and_then(|v| v.get(i).copied())
                .unwrap_or(0.0);

            if p_i >= q_i {
                // Accept unconditionally: target agrees at least as much.
                // This branch also absorbs q_i == 0 (p_i >= 0 always), so
                // the division below never divides by zero.
                accepted.push(token);
            } else if rng_val < p_i / q_i {
                // Accept with probability p_i / q_i.
                accepted.push(token);
            } else {
                // Reject: sample from adjusted distribution max(0, p - q),
                // then stop -- later draft tokens are conditioned on a
                // prefix that no longer holds.
                let adjusted = sample_adjusted(
                    &target_dists[i],
                    &draft_tokens,
                    &draft_probs,
                    i,
                );
                accepted.push(adjusted);
                rejected = true;
                break;
            }
        }

        // If all gamma tokens accepted, bonus sample from p_{gamma+1}.
        // NOTE(review): this takes the *first* entry of the sparse
        // distribution, assuming implementations list their preferred token
        // first -- confirm against TargetModel implementors.
        if !rejected {
            let bonus_dist = &target_dists[draft_tokens.len()];
            if let Some(&(token, _)) = bonus_dist.first() {
                accepted.push(token);
            }
        }

        // On rejection the last pushed token is the adjusted sample, not a
        // draft token, so subtract it from the accepted-from-draft count.
        let num_draft = draft_tokens.len();
        let num_accepted_from_draft = if rejected {
            accepted.len().saturating_sub(1)
        } else {
            num_draft
        };
        let acceptance_rate = if num_draft > 0 {
            num_accepted_from_draft as f32 / num_draft as f32
        } else {
            0.0
        };

        Ok(AcceptedTokens {
            tokens: accepted,
            acceptance_rate,
            draft_calls: 1,
            target_calls: 1,
        })
    }
}

251/// Look up the probability of a specific token in a sparse distribution.
252fn prob_of_token(dist: &[(TokenId, f32)], token: TokenId) -> f32 {
253    dist.iter()
254        .find(|(t, _)| *t == token)
255        .map(|(_, p)| *p)
256        .unwrap_or(0.0)
257}
258
259/// Sample from the adjusted distribution max(0, p_i - q_i), normalized.
260///
261/// For simplicity, we take the token with the highest adjusted probability.
262/// In production, this would use proper categorical sampling.
263fn sample_adjusted(
264    target_dist: &[(TokenId, f32)],
265    draft_tokens: &[TokenId],
266    draft_probs: &[f32],
267    position: usize,
268) -> TokenId {
269    let mut best_token = target_dist
270        .first()
271        .map(|(t, _)| *t)
272        .unwrap_or(0);
273    let mut best_score = f32::NEG_INFINITY;
274
275    for &(token, p_target) in target_dist {
276        let p_draft = if token == draft_tokens[position] {
277            draft_probs[position]
278        } else {
279            0.0
280        };
281        let adjusted = (p_target - p_draft).max(0.0);
282        if adjusted > best_score {
283            best_score = adjusted;
284            best_token = token;
285        }
286    }
287    best_token
288}
289
// ---------------------------------------------------------------------------
// Medusa-style parallel decoding
// ---------------------------------------------------------------------------

/// A single Medusa prediction head that produces candidate tokens
/// from a shared hidden state.
pub trait MedusaHead: Send + Sync {
    /// Predicts candidate tokens for one future position.
    ///
    /// Returns a sparse distribution over the vocabulary. `medusa_decode`
    /// uses only the first entry as the greedy candidate, so
    /// implementations should list their top choice first.
    fn predict(&self, prefix: &[TokenId]) -> Vec<(TokenId, f32)>;
}

/// Result of Medusa-style tree verification.
#[derive(Clone, Debug)]
pub struct MedusaResult {
    /// Accepted tokens from the best verified path.
    pub tokens: Vec<TokenId>,
    /// Number of candidate paths evaluated. Always 1 in this simplified
    /// greedy-path implementation.
    pub paths_evaluated: usize,
}

312/// Performs simplified Medusa-style parallel decoding.
313///
314/// Instead of a single draft sequence, multiple independent heads each
315/// predict one future token, forming a tree of candidates. The target
316/// model verifies the most promising path in one forward pass.
317pub fn medusa_decode(
318    prefix: &[TokenId],
319    heads: &[&dyn MedusaHead],
320    target: &dyn TargetModel,
321    config: &SpeculativeConfig,
322) -> AttentionResult<MedusaResult> {
323    config.validate()?;
324
325    if heads.is_empty() {
326        return Err(AttentionError::EmptyInput(
327            "at least one Medusa head required".into(),
328        ));
329    }
330
331    // Each head predicts one position ahead.
332    let head_predictions: Vec<Vec<(TokenId, f32)>> = heads
333        .iter()
334        .map(|h| h.predict(prefix))
335        .collect();
336
337    // Build the greedy candidate path (top-1 from each head).
338    let candidate_path: Vec<TokenId> = head_predictions
339        .iter()
340        .filter_map(|dist| dist.first().map(|(t, _)| *t))
341        .collect();
342
343    if candidate_path.is_empty() {
344        return Err(AttentionError::EmptyInput(
345            "heads produced no predictions".into(),
346        ));
347    }
348
349    // Verify the candidate path with the target model.
350    let target_dists = target.verify_batch(prefix, &candidate_path);
351
352    // Accept tokens while the target model agrees.
353    let mut accepted = Vec::new();
354    for (i, &token) in candidate_path.iter().enumerate() {
355        if i >= target_dists.len() {
356            break;
357        }
358        let p = prob_of_token(&target_dists[i], token);
359        if p > 0.0 {
360            accepted.push(token);
361        } else {
362            break;
363        }
364    }
365
366    // If nothing was accepted, take the target model's top choice at pos 0.
367    if accepted.is_empty() {
368        if let Some(dist) = target_dists.first() {
369            if let Some(&(token, _)) = dist.first() {
370                accepted.push(token);
371            }
372        }
373    }
374
375    Ok(MedusaResult {
376        tokens: accepted,
377        paths_evaluated: 1, // greedy path only in this simplified version
378    })
379}
380
// ---------------------------------------------------------------------------
// Mock implementations for testing
// ---------------------------------------------------------------------------

/// A mock draft model with a configurable token sequence and probability.
pub struct SimpleDraftModel {
    /// Tokens the draft model will propose, cycling if gamma > len.
    pub tokens: Vec<TokenId>,
    /// Probability assigned to each drafted token (same value for all).
    pub probability: f32,
}

393impl DraftModel for SimpleDraftModel {
394    fn draft_tokens(
395        &self,
396        _prefix: &[TokenId],
397        gamma: usize,
398    ) -> Vec<(TokenId, f32)> {
399        (0..gamma)
400            .map(|i| {
401                let token = self.tokens[i % self.tokens.len()];
402                (token, self.probability)
403            })
404            .collect()
405    }
406}
407
/// A mock target model that returns configurable distributions.
pub struct SimpleTargetModel {
    /// Distributions to return for each position.
    /// If `verify_batch` requests more positions than available,
    /// the last distribution is repeated (or a degenerate `(0, 1.0)`
    /// distribution if none were configured at all).
    pub distributions: Vec<Vec<(TokenId, f32)>>,
}

416impl TargetModel for SimpleTargetModel {
417    fn verify_batch(
418        &self,
419        _prefix: &[TokenId],
420        draft_tokens: &[TokenId],
421    ) -> Vec<Vec<(TokenId, f32)>> {
422        let needed = draft_tokens.len() + 1;
423        (0..needed)
424            .map(|i| {
425                if i < self.distributions.len() {
426                    self.distributions[i].clone()
427                } else {
428                    self.distributions
429                        .last()
430                        .cloned()
431                        .unwrap_or_else(|| vec![(0, 1.0)])
432                }
433            })
434            .collect()
435    }
436}
437
/// A mock Medusa head that always predicts a fixed token.
pub struct SimpleMedusaHead {
    /// The token this head predicts, regardless of prefix.
    pub token: TokenId,
    /// Probability assigned to the prediction.
    pub probability: f32,
}

446impl MedusaHead for SimpleMedusaHead {
447    fn predict(&self, _prefix: &[TokenId]) -> Vec<(TokenId, f32)> {
448        vec![(self.token, self.probability)]
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared baseline: gamma = 4 with default sampling parameters.
    fn default_config() -> SpeculativeConfig {
        SpeculativeConfig::new(4)
    }

    // -- Config validation tests --

    #[test]
    fn test_config_valid() {
        assert!(default_config().validate().is_ok());
    }

    #[test]
    fn test_config_gamma_zero() {
        let mut cfg = default_config();
        cfg.gamma = 0;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_config_gamma_too_large() {
        let mut cfg = default_config();
        cfg.gamma = 33; // just above the 32 cap
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_config_bad_temperature() {
        let mut cfg = default_config();
        cfg.temperature = 0.0; // must be strictly positive
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_config_bad_top_p() {
        let mut cfg = default_config();
        cfg.top_p = 0.0; // below the open lower bound
        assert!(cfg.validate().is_err());

        cfg.top_p = 1.1; // above the closed upper bound
        assert!(cfg.validate().is_err());
    }

    // -- Full acceptance test --

    #[test]
    fn test_full_acceptance() {
        // Target probability >= draft probability at every position -> all accept.
        let draft = SimpleDraftModel {
            tokens: vec![10, 20, 30, 40],
            probability: 0.5,
        };
        let target = SimpleTargetModel {
            distributions: vec![
                vec![(10, 0.8)],
                vec![(20, 0.7)],
                vec![(30, 0.6)],
                vec![(40, 0.9)],
                vec![(50, 1.0)], // bonus position
            ],
        };

        let result = SpeculativeDecoder::decode_step(
            &[1, 2, 3],
            &draft,
            &target,
            &default_config(),
            None, // rng draws default to 0.0 -> deterministic acceptance
        )
        .unwrap();

        // All 4 draft tokens accepted + 1 bonus = 5 tokens.
        assert_eq!(result.tokens.len(), 5);
        assert_eq!(result.tokens, vec![10, 20, 30, 40, 50]);
        assert!((result.acceptance_rate - 1.0).abs() < f32::EPSILON);
    }

    // -- Full rejection test --

    #[test]
    fn test_full_rejection() {
        // Target probability 0 for the draft token -> immediate rejection.
        let draft = SimpleDraftModel {
            tokens: vec![10, 20, 30, 40],
            probability: 0.9,
        };
        // The target gives 0 prob to token 10, but high prob to token 99.
        let target = SimpleTargetModel {
            distributions: vec![
                vec![(99, 0.9)],
                vec![(99, 0.9)],
                vec![(99, 0.9)],
                vec![(99, 0.9)],
                vec![(99, 1.0)],
            ],
        };

        let result = SpeculativeDecoder::decode_step(
            &[1],
            &draft,
            &target,
            &default_config(),
            Some(&[1.0, 1.0, 1.0, 1.0]), // rng=1.0 forces rejection
        )
        .unwrap();

        // First token rejected, replaced by adjusted sample (token 99).
        assert_eq!(result.tokens.len(), 1);
        assert_eq!(result.tokens[0], 99);
        assert!((result.acceptance_rate - 0.0).abs() < f32::EPSILON);
    }

    // -- Partial acceptance test --

    #[test]
    fn test_partial_acceptance() {
        let draft = SimpleDraftModel {
            tokens: vec![10, 20, 30, 40],
            probability: 0.5,
        };
        // Accept first two (p >= q), reject third (p=0).
        let target = SimpleTargetModel {
            distributions: vec![
                vec![(10, 0.8)],
                vec![(20, 0.6)],
                vec![(77, 0.9)], // no prob for 30 -> reject
                vec![(40, 0.9)],
                vec![(50, 1.0)],
            ],
        };

        let result = SpeculativeDecoder::decode_step(
            &[1],
            &draft,
            &target,
            &default_config(),
            Some(&[0.0, 0.0, 1.0, 0.0]), // rng=1.0 at pos 2 forces reject
        )
        .unwrap();

        // Accepted: 10, 20, then rejected at 30 -> adjusted sample = 77.
        assert_eq!(result.tokens.len(), 3);
        assert_eq!(result.tokens[0], 10);
        assert_eq!(result.tokens[1], 20);
        assert_eq!(result.tokens[2], 77);
        // 2 of 4 draft tokens accepted (adjusted sample excluded).
        assert!((result.acceptance_rate - 0.5).abs() < f32::EPSILON);
    }

    // -- Rejection sampling produces adjusted distribution token --

    #[test]
    fn test_rejection_sampling_distribution() {
        let draft = SimpleDraftModel {
            tokens: vec![10],
            probability: 0.8,
        };
        // Target gives 0.3 to token 10 and 0.7 to token 42.
        // Adjusted: max(0, 0.3 - 0.8) = 0 for 10, max(0, 0.7 - 0) = 0.7 for 42.
        // So adjusted sample should be 42.
        let target = SimpleTargetModel {
            distributions: vec![
                vec![(10, 0.3), (42, 0.7)],
                vec![(99, 1.0)],
            ],
        };

        let cfg = SpeculativeConfig::new(1);
        let result = SpeculativeDecoder::decode_step(
            &[1],
            &draft,
            &target,
            &cfg,
            Some(&[1.0]), // force rejection
        )
        .unwrap();

        assert_eq!(result.tokens.len(), 1);
        assert_eq!(result.tokens[0], 42);
    }

    // -- Speedup calculation --

    #[test]
    fn test_theoretical_speedup() {
        // gamma=4, alpha=1.0 -> speedup = 4*1 / (1+4*0) = 4.0
        let s = theoretical_speedup(4, 1.0);
        assert!((s - 4.0).abs() < 1e-5);

        // gamma=4, alpha=0.0 -> speedup = 0 / (1+4) = 0.0
        let s = theoretical_speedup(4, 0.0);
        assert!(s.abs() < 1e-5);

        // gamma=4, alpha=0.8 -> 4*0.8 / (1+4*0.2) = 3.2 / 1.8 ~= 1.778
        let s = theoretical_speedup(4, 0.8);
        assert!((s - 3.2 / 1.8).abs() < 1e-4);

        // gamma=8, alpha=0.9 -> 7.2 / 1.8 = 4.0
        let s = theoretical_speedup(8, 0.9);
        assert!((s - 7.2 / 1.8).abs() < 1e-4);
    }

    // -- Medusa tree verification --

    #[test]
    fn test_medusa_decode() {
        let h1 = SimpleMedusaHead {
            token: 10,
            probability: 0.9,
        };
        let h2 = SimpleMedusaHead {
            token: 20,
            probability: 0.8,
        };
        // Target assigns nonzero probability to both head predictions.
        let target = SimpleTargetModel {
            distributions: vec![
                vec![(10, 0.7)],
                vec![(20, 0.6)],
                vec![(99, 1.0)],
            ],
        };

        let heads: Vec<&dyn MedusaHead> = vec![&h1, &h2];
        let result =
            medusa_decode(&[1, 2], &heads, &target, &default_config()).unwrap();

        assert_eq!(result.tokens, vec![10, 20]);
        assert_eq!(result.paths_evaluated, 1);
    }

    #[test]
    fn test_medusa_no_heads() {
        let target = SimpleTargetModel {
            distributions: vec![vec![(1, 1.0)]],
        };
        let heads: Vec<&dyn MedusaHead> = vec![];
        let result =
            medusa_decode(&[1], &heads, &target, &default_config());
        assert!(result.is_err());
    }

    // -- Edge case: probabilistic acceptance --

    #[test]
    fn test_probabilistic_acceptance() {
        // p_i(t_i) < q_i(t_i) but rng is low enough to accept.
        let draft = SimpleDraftModel {
            tokens: vec![10],
            probability: 0.8,
        };
        let target = SimpleTargetModel {
            distributions: vec![
                vec![(10, 0.4)], // p/q = 0.5
                vec![(99, 1.0)],
            ],
        };

        let cfg = SpeculativeConfig::new(1);
        // rng = 0.3 < 0.5 (p/q) -> accept
        let result = SpeculativeDecoder::decode_step(
            &[1],
            &draft,
            &target,
            &cfg,
            Some(&[0.3]),
        )
        .unwrap();

        // Accepted draft token + bonus
        assert_eq!(result.tokens, vec![10, 99]);
        assert!((result.acceptance_rate - 1.0).abs() < f32::EPSILON);
    }

    // -- Edge case: empty prefix --

    #[test]
    fn test_empty_prefix() {
        let draft = SimpleDraftModel {
            tokens: vec![5],
            probability: 0.5,
        };
        let target = SimpleTargetModel {
            distributions: vec![
                vec![(5, 0.9)],
                vec![(6, 1.0)],
            ],
        };

        let cfg = SpeculativeConfig::new(1);
        let result = SpeculativeDecoder::decode_step(
            &[], // generation may start from an empty prompt
            &draft,
            &target,
            &cfg,
            None,
        )
        .unwrap();

        assert_eq!(result.tokens, vec![5, 6]);
    }
}