1use crate::error::{AttentionError, AttentionResult};
18
/// Identifier of a single token, as exchanged between the draft, target and
/// Medusa models in this module.
pub type TokenId = u32;
21
/// Tuning parameters shared by speculative decoding and Medusa decoding.
///
/// Build one with `SpeculativeConfig::new` and check it with `validate`;
/// both `SpeculativeDecoder::decode_step` and `medusa_decode` call `validate`
/// before doing any work.
#[derive(Clone, Debug)]
pub struct SpeculativeConfig {
    /// Number of tokens requested from the draft model per step
    /// (`validate` requires `1..=32`).
    pub gamma: usize,
    /// Sampling temperature; `validate` requires it to be positive.
    /// NOTE(review): not read by the decoders in this module.
    pub temperature: f32,
    /// Nucleus-sampling threshold; `validate` requires `(0, 1]`.
    /// NOTE(review): not read by the decoders in this module.
    pub top_p: f32,
    /// Maximum sequence length; `validate` requires it to be positive.
    /// NOTE(review): not enforced by the decoders in this module.
    pub max_seq_len: usize,
}
35
36impl SpeculativeConfig {
37 pub fn new(gamma: usize) -> Self {
39 Self {
40 gamma,
41 temperature: 1.0,
42 top_p: 1.0,
43 max_seq_len: 2048,
44 }
45 }
46
47 pub fn validate(&self) -> AttentionResult<()> {
49 let err = |msg: &str| Err(AttentionError::InvalidConfig(msg.into()));
50 if self.gamma == 0 {
51 return err("gamma must be > 0");
52 }
53 if self.gamma > 32 {
54 return err("gamma must be <= 32");
55 }
56 if self.temperature <= 0.0 {
57 return err("temperature must be > 0");
58 }
59 if self.top_p <= 0.0 || self.top_p > 1.0 {
60 return err("top_p must be in (0, 1]");
61 }
62 if self.max_seq_len == 0 {
63 return err("max_seq_len must be > 0");
64 }
65 Ok(())
66 }
67}
68
69pub trait DraftModel: Send + Sync {
71 fn draft_tokens(
77 &self,
78 prefix: &[TokenId],
79 gamma: usize,
80 ) -> Vec<(TokenId, f32)>;
81}
82
83pub trait TargetModel: Send + Sync {
85 fn verify_batch(
94 &self,
95 prefix: &[TokenId],
96 draft_tokens: &[TokenId],
97 ) -> Vec<Vec<(TokenId, f32)>>;
98}
99
100#[derive(Clone, Debug)]
102pub struct AcceptedTokens {
103 pub tokens: Vec<TokenId>,
105 pub acceptance_rate: f32,
107 pub draft_calls: usize,
109 pub target_calls: usize,
111}
112
/// Aggregate decoding statistics.
///
/// NOTE(review): nothing in this module populates this type; the field
/// descriptions below are inferred from the field names — confirm against
/// the code that fills it in.
#[derive(Clone, Debug, Default)]
pub struct DecodingStats {
    /// Number of tokens produced (presumably summed over all steps).
    pub tokens_generated: usize,
    /// Overall draft-token acceptance rate (presumably in `[0, 1]`).
    pub acceptance_rate: f32,
    /// Observed or estimated speedup relative to plain decoding.
    pub speedup_ratio: f32,
    /// Time spent in the draft model; the `_ms` suffix suggests milliseconds.
    pub draft_latency_ms: f64,
    /// Time spent in the target model; the `_ms` suffix suggests milliseconds.
    pub target_latency_ms: f64,
}
127
/// Heuristic speedup estimate for speculative decoding.
///
/// Returns `γ·α / (1 + γ·(1 − α))`, where `γ = gamma` and `α` is
/// `acceptance_rate` clamped to `[0, 1]`. For example, `gamma = 4` with full
/// acceptance yields `4.0`, and zero acceptance yields `0.0`. A NaN
/// `acceptance_rate` propagates to a NaN result.
///
/// NOTE(review): this is a simple heuristic, not the expected-tokens-per-step
/// formula `(1 − α^(γ+1)) / (1 − α)` common in the speculative decoding
/// literature — confirm which one is intended.
pub fn theoretical_speedup(gamma: usize, acceptance_rate: f32) -> f32 {
    let alpha = acceptance_rate.clamp(0.0, 1.0);
    let gamma_f = gamma as f32;
    let expected_accepted = gamma_f * alpha;
    // With `alpha` clamped the cost is always >= 1; the guard is defensive.
    match 1.0 + gamma_f * (1.0 - alpha) {
        cost if cost <= 0.0 => 0.0,
        cost => expected_accepted / cost,
    }
}
144
145pub struct SpeculativeDecoder;
147
148impl SpeculativeDecoder {
149 pub fn decode_step(
163 prefix: &[TokenId],
164 draft: &dyn DraftModel,
165 target: &dyn TargetModel,
166 config: &SpeculativeConfig,
167 rng_values: Option<&[f32]>,
168 ) -> AttentionResult<AcceptedTokens> {
169 config.validate()?;
170
171 let draft_results = draft.draft_tokens(prefix, config.gamma);
172 if draft_results.is_empty() {
173 return Err(AttentionError::EmptyInput(
174 "draft model returned no tokens".into(),
175 ));
176 }
177
178 let draft_tokens: Vec<TokenId> =
179 draft_results.iter().map(|(t, _)| *t).collect();
180 let draft_probs: Vec<f32> =
181 draft_results.iter().map(|(_, p)| *p).collect();
182
183 let target_dists = target.verify_batch(prefix, &draft_tokens);
184 if target_dists.len() < draft_tokens.len() + 1 {
185 return Err(AttentionError::ComputationError(
186 "target model must return gamma+1 distributions".into(),
187 ));
188 }
189
190 let mut accepted = Vec::new();
191 let mut rejected = false;
192
193 for i in 0..draft_tokens.len() {
194 let token = draft_tokens[i];
195 let q_i = draft_probs[i];
196 let p_i = prob_of_token(&target_dists[i], token);
197
198 let rng_val = rng_values
199 .and_then(|v| v.get(i).copied())
200 .unwrap_or(0.0);
201
202 if p_i >= q_i {
203 accepted.push(token);
205 } else if rng_val < p_i / q_i {
206 accepted.push(token);
208 } else {
209 let adjusted = sample_adjusted(
211 &target_dists[i],
212 &draft_tokens,
213 &draft_probs,
214 i,
215 );
216 accepted.push(adjusted);
217 rejected = true;
218 break;
219 }
220 }
221
222 if !rejected {
224 let bonus_dist = &target_dists[draft_tokens.len()];
225 if let Some(&(token, _)) = bonus_dist.first() {
226 accepted.push(token);
227 }
228 }
229
230 let num_draft = draft_tokens.len();
231 let num_accepted_from_draft = if rejected {
232 accepted.len().saturating_sub(1)
233 } else {
234 num_draft
235 };
236 let acceptance_rate = if num_draft > 0 {
237 num_accepted_from_draft as f32 / num_draft as f32
238 } else {
239 0.0
240 };
241
242 Ok(AcceptedTokens {
243 tokens: accepted,
244 acceptance_rate,
245 draft_calls: 1,
246 target_calls: 1,
247 })
248 }
249}
250
251fn prob_of_token(dist: &[(TokenId, f32)], token: TokenId) -> f32 {
253 dist.iter()
254 .find(|(t, _)| *t == token)
255 .map(|(_, p)| *p)
256 .unwrap_or(0.0)
257}
258
259fn sample_adjusted(
264 target_dist: &[(TokenId, f32)],
265 draft_tokens: &[TokenId],
266 draft_probs: &[f32],
267 position: usize,
268) -> TokenId {
269 let mut best_token = target_dist
270 .first()
271 .map(|(t, _)| *t)
272 .unwrap_or(0);
273 let mut best_score = f32::NEG_INFINITY;
274
275 for &(token, p_target) in target_dist {
276 let p_draft = if token == draft_tokens[position] {
277 draft_probs[position]
278 } else {
279 0.0
280 };
281 let adjusted = (p_target - p_draft).max(0.0);
282 if adjusted > best_score {
283 best_score = adjusted;
284 best_token = token;
285 }
286 }
287 best_token
288}
289
290pub trait MedusaHead: Send + Sync {
297 fn predict(&self, prefix: &[TokenId]) -> Vec<(TokenId, f32)>;
301}
302
303#[derive(Clone, Debug)]
305pub struct MedusaResult {
306 pub tokens: Vec<TokenId>,
308 pub paths_evaluated: usize,
310}
311
312pub fn medusa_decode(
318 prefix: &[TokenId],
319 heads: &[&dyn MedusaHead],
320 target: &dyn TargetModel,
321 config: &SpeculativeConfig,
322) -> AttentionResult<MedusaResult> {
323 config.validate()?;
324
325 if heads.is_empty() {
326 return Err(AttentionError::EmptyInput(
327 "at least one Medusa head required".into(),
328 ));
329 }
330
331 let head_predictions: Vec<Vec<(TokenId, f32)>> = heads
333 .iter()
334 .map(|h| h.predict(prefix))
335 .collect();
336
337 let candidate_path: Vec<TokenId> = head_predictions
339 .iter()
340 .filter_map(|dist| dist.first().map(|(t, _)| *t))
341 .collect();
342
343 if candidate_path.is_empty() {
344 return Err(AttentionError::EmptyInput(
345 "heads produced no predictions".into(),
346 ));
347 }
348
349 let target_dists = target.verify_batch(prefix, &candidate_path);
351
352 let mut accepted = Vec::new();
354 for (i, &token) in candidate_path.iter().enumerate() {
355 if i >= target_dists.len() {
356 break;
357 }
358 let p = prob_of_token(&target_dists[i], token);
359 if p > 0.0 {
360 accepted.push(token);
361 } else {
362 break;
363 }
364 }
365
366 if accepted.is_empty() {
368 if let Some(dist) = target_dists.first() {
369 if let Some(&(token, _)) = dist.first() {
370 accepted.push(token);
371 }
372 }
373 }
374
375 Ok(MedusaResult {
376 tokens: accepted,
377 paths_evaluated: 1, })
379}
380
381pub struct SimpleDraftModel {
387 pub tokens: Vec<TokenId>,
389 pub probability: f32,
391}
392
393impl DraftModel for SimpleDraftModel {
394 fn draft_tokens(
395 &self,
396 _prefix: &[TokenId],
397 gamma: usize,
398 ) -> Vec<(TokenId, f32)> {
399 (0..gamma)
400 .map(|i| {
401 let token = self.tokens[i % self.tokens.len()];
402 (token, self.probability)
403 })
404 .collect()
405 }
406}
407
408pub struct SimpleTargetModel {
410 pub distributions: Vec<Vec<(TokenId, f32)>>,
414}
415
416impl TargetModel for SimpleTargetModel {
417 fn verify_batch(
418 &self,
419 _prefix: &[TokenId],
420 draft_tokens: &[TokenId],
421 ) -> Vec<Vec<(TokenId, f32)>> {
422 let needed = draft_tokens.len() + 1;
423 (0..needed)
424 .map(|i| {
425 if i < self.distributions.len() {
426 self.distributions[i].clone()
427 } else {
428 self.distributions
429 .last()
430 .cloned()
431 .unwrap_or_else(|| vec![(0, 1.0)])
432 }
433 })
434 .collect()
435 }
436}
437
438pub struct SimpleMedusaHead {
440 pub token: TokenId,
442 pub probability: f32,
444}
445
446impl MedusaHead for SimpleMedusaHead {
447 fn predict(&self, _prefix: &[TokenId]) -> Vec<(TokenId, f32)> {
448 vec![(self.token, self.probability)]
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared baseline: gamma = 4 with the `SpeculativeConfig::new` defaults.
    fn default_config() -> SpeculativeConfig {
        SpeculativeConfig::new(4)
    }

    /// Draft model that cycles `tokens`, reporting `probability` for each.
    fn draft_of(tokens: &[TokenId], probability: f32) -> SimpleDraftModel {
        SimpleDraftModel {
            tokens: tokens.to_vec(),
            probability,
        }
    }

    /// Target model that replays `distributions` position by position.
    fn target_of(distributions: Vec<Vec<(TokenId, f32)>>) -> SimpleTargetModel {
        SimpleTargetModel { distributions }
    }

    /// Runs one speculative step and unwraps the result.
    fn step(
        prefix: &[TokenId],
        draft: &SimpleDraftModel,
        target: &SimpleTargetModel,
        config: &SpeculativeConfig,
        rng_values: Option<&[f32]>,
    ) -> AcceptedTokens {
        SpeculativeDecoder::decode_step(prefix, draft, target, config, rng_values)
            .expect("decode_step should succeed")
    }

    /// Asserts an acceptance rate to within `f32::EPSILON`.
    fn assert_rate(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < f32::EPSILON,
            "acceptance rate {} != {}",
            actual,
            expected
        );
    }

    #[test]
    fn test_config_valid() {
        let cfg = default_config();
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_config_gamma_zero() {
        let cfg = SpeculativeConfig {
            gamma: 0,
            ..default_config()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_config_gamma_too_large() {
        let cfg = SpeculativeConfig {
            gamma: 33,
            ..default_config()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_config_bad_temperature() {
        let cfg = SpeculativeConfig {
            temperature: 0.0,
            ..default_config()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_config_bad_top_p() {
        // Both ends of the (0, 1] interval must be enforced.
        for &top_p in &[0.0_f32, 1.1] {
            let cfg = SpeculativeConfig {
                top_p,
                ..default_config()
            };
            assert!(cfg.validate().is_err(), "top_p = {} should be rejected", top_p);
        }
    }

    #[test]
    fn test_full_acceptance() {
        // Every target probability exceeds the draft's 0.5, so all four draft
        // tokens are accepted deterministically and the bonus token 50 follows.
        let draft = draft_of(&[10, 20, 30, 40], 0.5);
        let target = target_of(vec![
            vec![(10, 0.8)],
            vec![(20, 0.7)],
            vec![(30, 0.6)],
            vec![(40, 0.9)],
            vec![(50, 1.0)],
        ]);

        let out = step(&[1, 2, 3], &draft, &target, &default_config(), None);

        assert_eq!(out.tokens.len(), 5);
        assert_eq!(out.tokens, vec![10, 20, 30, 40, 50]);
        assert_rate(out.acceptance_rate, 1.0);
    }

    #[test]
    fn test_full_rejection() {
        // The target never lists the draft tokens (p = 0) and every draw is
        // 1.0, so the first draft token is rejected and replaced by 99.
        let draft = draft_of(&[10, 20, 30, 40], 0.9);
        let target = target_of(vec![
            vec![(99, 0.9)],
            vec![(99, 0.9)],
            vec![(99, 0.9)],
            vec![(99, 0.9)],
            vec![(99, 1.0)],
        ]);
        let draws: [f32; 4] = [1.0; 4];

        let out = step(&[1], &draft, &target, &default_config(), Some(&draws[..]));

        assert_eq!(out.tokens.len(), 1);
        assert_eq!(out.tokens[0], 99);
        assert_rate(out.acceptance_rate, 0.0);
    }

    #[test]
    fn test_partial_acceptance() {
        // Positions 0 and 1 pass (p >= q); position 2 lacks token 30, so p = 0
        // and the 1.0 draw rejects it, with 77 chosen as the replacement.
        let draft = draft_of(&[10, 20, 30, 40], 0.5);
        let target = target_of(vec![
            vec![(10, 0.8)],
            vec![(20, 0.6)],
            vec![(77, 0.9)],
            vec![(40, 0.9)],
            vec![(50, 1.0)],
        ]);
        let draws: [f32; 4] = [0.0, 0.0, 1.0, 0.0];

        let out = step(&[1], &draft, &target, &default_config(), Some(&draws[..]));

        assert_eq!(out.tokens.len(), 3);
        assert_eq!(out.tokens, vec![10, 20, 77]);
        assert_rate(out.acceptance_rate, 0.5);
    }

    #[test]
    fn test_rejection_sampling_distribution() {
        // Draft q = 0.8 for token 10, target p = 0.3, draw 1.0 rejects it. The
        // residual max(0, p - q) is 0 for 10 and 0.7 for 42, so 42 is chosen.
        let draft = draft_of(&[10], 0.8);
        let target = target_of(vec![vec![(10, 0.3), (42, 0.7)], vec![(99, 1.0)]]);
        let draws: [f32; 1] = [1.0];

        let out = step(
            &[1],
            &draft,
            &target,
            &SpeculativeConfig::new(1),
            Some(&draws[..]),
        );

        assert_eq!(out.tokens.len(), 1);
        assert_eq!(out.tokens[0], 42);
    }

    #[test]
    fn test_theoretical_speedup() {
        // (gamma, acceptance rate, expected speedup, tolerance)
        let cases: [(usize, f32, f32, f32); 4] = [
            (4, 1.0, 4.0, 1e-5),
            (4, 0.0, 0.0, 1e-5),
            (4, 0.8, 3.2 / 1.8, 1e-4),
            (8, 0.9, 7.2 / 1.8, 1e-4),
        ];
        for &(gamma, rate, expected, tolerance) in &cases {
            let speedup = theoretical_speedup(gamma, rate);
            assert!(
                (speedup - expected).abs() < tolerance,
                "gamma={} rate={}: got {}, expected {}",
                gamma,
                rate,
                speedup,
                expected
            );
        }
    }

    #[test]
    fn test_medusa_decode() {
        // Both head tokens have non-zero target probability, so the whole path
        // is accepted; the trailing (99, 1.0) distribution is not appended.
        let first = SimpleMedusaHead {
            token: 10,
            probability: 0.9,
        };
        let second = SimpleMedusaHead {
            token: 20,
            probability: 0.8,
        };
        let target = target_of(vec![vec![(10, 0.7)], vec![(20, 0.6)], vec![(99, 1.0)]]);
        let heads: [&dyn MedusaHead; 2] = [&first, &second];

        let out = medusa_decode(&[1, 2], &heads, &target, &default_config())
            .expect("medusa_decode should succeed");

        assert_eq!(out.tokens, vec![10, 20]);
        assert_eq!(out.paths_evaluated, 1);
    }

    #[test]
    fn test_medusa_no_heads() {
        let target = target_of(vec![vec![(1, 1.0)]]);
        let heads: Vec<&dyn MedusaHead> = Vec::new();

        let outcome = medusa_decode(&[1], &heads, &target, &default_config());

        assert!(outcome.is_err());
    }

    #[test]
    fn test_probabilistic_acceptance() {
        // p / q = 0.4 / 0.8 = 0.5 and the draw 0.3 < 0.5, so token 10 is
        // accepted probabilistically and the bonus token 99 follows.
        let draft = draft_of(&[10], 0.8);
        let target = target_of(vec![vec![(10, 0.4)], vec![(99, 1.0)]]);
        let draws: [f32; 1] = [0.3];

        let out = step(
            &[1],
            &draft,
            &target,
            &SpeculativeConfig::new(1),
            Some(&draws[..]),
        );

        assert_eq!(out.tokens, vec![10, 99]);
        assert_rate(out.acceptance_rate, 1.0);
    }

    #[test]
    fn test_empty_prefix() {
        let draft = draft_of(&[5], 0.5);
        let target = target_of(vec![vec![(5, 0.9)], vec![(6, 1.0)]]);

        let out = step(&[], &draft, &target, &SpeculativeConfig::new(1), None);

        assert_eq!(out.tokens, vec![5, 6]);
    }
}