1use crate::error::{AttentionError, AttentionResult};
18
/// Integer identifier of a token, as exchanged between the draft, target and Medusa models.
pub type TokenId = u32;
21
22#[derive(Clone, Debug)]
24pub struct SpeculativeConfig {
25 pub gamma: usize,
27 pub temperature: f32,
29 pub top_p: f32,
32 pub max_seq_len: usize,
34}
35
36impl SpeculativeConfig {
37 pub fn new(gamma: usize) -> Self {
39 Self {
40 gamma,
41 temperature: 1.0,
42 top_p: 1.0,
43 max_seq_len: 2048,
44 }
45 }
46
47 pub fn validate(&self) -> AttentionResult<()> {
49 let err = |msg: &str| Err(AttentionError::InvalidConfig(msg.into()));
50 if self.gamma == 0 {
51 return err("gamma must be > 0");
52 }
53 if self.gamma > 32 {
54 return err("gamma must be <= 32");
55 }
56 if self.temperature <= 0.0 {
57 return err("temperature must be > 0");
58 }
59 if self.top_p <= 0.0 || self.top_p > 1.0 {
60 return err("top_p must be in (0, 1]");
61 }
62 if self.max_seq_len == 0 {
63 return err("max_seq_len must be > 0");
64 }
65 Ok(())
66 }
67}
68
/// Small, fast model that proposes candidate continuations for speculative decoding.
pub trait DraftModel: Send + Sync {
    /// Proposes continuation tokens for `prefix`; `gamma` is the number requested.
    ///
    /// `SpeculativeDecoder::decode_step` reads each pair as `(token, q)`, where `q` is the
    /// draft model's probability for that token. It returns an `EmptyInput` error if the vector
    /// is empty.
    fn draft_tokens(&self, prefix: &[TokenId], gamma: usize) -> Vec<(TokenId, f32)>;
}
78
/// Model whose output distribution speculative decoding must reproduce; verifies drafts in bulk.
pub trait TargetModel: Send + Sync {
    /// Scores `prefix` extended by `draft_tokens` in a single call.
    ///
    /// Returns one distribution per position, each as `(token, probability)` pairs.
    /// `SpeculativeDecoder::decode_step` uses entry `i` to score `draft_tokens[i]`. It also
    /// requires one extra entry (at index `draft_tokens.len()`), which supplies the bonus token
    /// when every draft token is accepted. Fewer than `draft_tokens.len() + 1` entries is
    /// reported as a `ComputationError`.
    fn verify_batch(
        &self,
        prefix: &[TokenId],
        draft_tokens: &[TokenId],
    ) -> Vec<Vec<(TokenId, f32)>>;
}
95
/// Outcome of one `SpeculativeDecoder::decode_step`.
#[derive(Clone, Debug)]
pub struct AcceptedTokens {
    /// Tokens emitted by the step: the accepted draft tokens, followed by at most one extra
    /// token. That extra token is either the replacement chosen at the first rejection or, when
    /// nothing was rejected, a bonus token from the target's final distribution.
    pub tokens: Vec<TokenId>,
    /// Fraction of the draft tokens that were accepted, in `[0, 1]`.
    pub acceptance_rate: f32,
    /// Number of draft-model calls made (`decode_step` always reports 1).
    pub draft_calls: usize,
    /// Number of target-model calls made (`decode_step` always reports 1).
    pub target_calls: usize,
}
108
/// Aggregate decoding statistics.
///
/// NOTE(review): nothing in this file populates this type. The field descriptions below are
/// inferred from the field names; confirm them against the code that fills it in.
#[derive(Clone, Debug, Default)]
pub struct DecodingStats {
    /// Total number of tokens generated.
    pub tokens_generated: usize,
    /// Fraction of draft tokens accepted.
    pub acceptance_rate: f32,
    /// Observed speedup over non-speculative decoding (presumably comparable to
    /// `theoretical_speedup`).
    pub speedup_ratio: f32,
    /// Time spent in the draft model, in milliseconds (per the `_ms` suffix).
    pub draft_latency_ms: f64,
    /// Time spent in the target model, in milliseconds (per the `_ms` suffix).
    pub target_latency_ms: f64,
}
123
/// Speedup estimate `γ·α / (1 + γ·(1 − α))` for draft length `gamma` (γ) and acceptance rate α.
///
/// `acceptance_rate` is clamped to `[0, 1]` first, so with full acceptance the estimate equals
/// `gamma`, and with zero acceptance (or `gamma == 0`) it is `0`. A NaN acceptance rate
/// propagates to a NaN result.
pub fn theoretical_speedup(gamma: usize, acceptance_rate: f32) -> f32 {
    let alpha = acceptance_rate.clamp(0.0, 1.0);
    let draft_len = gamma as f32;
    let expected_accepted = draft_len * alpha;

    // For alpha in [0, 1] the divisor is always >= 1; the guard only protects against
    // pathological inputs and deliberately lets NaN fall through to the division.
    match 1.0 + draft_len * (1.0 - alpha) {
        divisor if divisor <= 0.0 => 0.0,
        divisor => expected_accepted / divisor,
    }
}
140
141pub struct SpeculativeDecoder;
143
144impl SpeculativeDecoder {
145 pub fn decode_step(
159 prefix: &[TokenId],
160 draft: &dyn DraftModel,
161 target: &dyn TargetModel,
162 config: &SpeculativeConfig,
163 rng_values: Option<&[f32]>,
164 ) -> AttentionResult<AcceptedTokens> {
165 config.validate()?;
166
167 let draft_results = draft.draft_tokens(prefix, config.gamma);
168 if draft_results.is_empty() {
169 return Err(AttentionError::EmptyInput(
170 "draft model returned no tokens".into(),
171 ));
172 }
173
174 let draft_tokens: Vec<TokenId> = draft_results.iter().map(|(t, _)| *t).collect();
175 let draft_probs: Vec<f32> = draft_results.iter().map(|(_, p)| *p).collect();
176
177 let target_dists = target.verify_batch(prefix, &draft_tokens);
178 if target_dists.len() < draft_tokens.len() + 1 {
179 return Err(AttentionError::ComputationError(
180 "target model must return gamma+1 distributions".into(),
181 ));
182 }
183
184 let mut accepted = Vec::new();
185 let mut rejected = false;
186
187 for i in 0..draft_tokens.len() {
188 let token = draft_tokens[i];
189 let q_i = draft_probs[i];
190 let p_i = prob_of_token(&target_dists[i], token);
191
192 let rng_val = rng_values.and_then(|v| v.get(i).copied()).unwrap_or(0.0);
193
194 if p_i >= q_i {
195 accepted.push(token);
197 } else if rng_val < p_i / q_i {
198 accepted.push(token);
200 } else {
201 let adjusted = sample_adjusted(&target_dists[i], &draft_tokens, &draft_probs, i);
203 accepted.push(adjusted);
204 rejected = true;
205 break;
206 }
207 }
208
209 if !rejected {
211 let bonus_dist = &target_dists[draft_tokens.len()];
212 if let Some(&(token, _)) = bonus_dist.first() {
213 accepted.push(token);
214 }
215 }
216
217 let num_draft = draft_tokens.len();
218 let num_accepted_from_draft = if rejected {
219 accepted.len().saturating_sub(1)
220 } else {
221 num_draft
222 };
223 let acceptance_rate = if num_draft > 0 {
224 num_accepted_from_draft as f32 / num_draft as f32
225 } else {
226 0.0
227 };
228
229 Ok(AcceptedTokens {
230 tokens: accepted,
231 acceptance_rate,
232 draft_calls: 1,
233 target_calls: 1,
234 })
235 }
236}
237
238fn prob_of_token(dist: &[(TokenId, f32)], token: TokenId) -> f32 {
240 dist.iter()
241 .find(|(t, _)| *t == token)
242 .map(|(_, p)| *p)
243 .unwrap_or(0.0)
244}
245
246fn sample_adjusted(
251 target_dist: &[(TokenId, f32)],
252 draft_tokens: &[TokenId],
253 draft_probs: &[f32],
254 position: usize,
255) -> TokenId {
256 let mut best_token = target_dist.first().map(|(t, _)| *t).unwrap_or(0);
257 let mut best_score = f32::NEG_INFINITY;
258
259 for &(token, p_target) in target_dist {
260 let p_draft = if token == draft_tokens[position] {
261 draft_probs[position]
262 } else {
263 0.0
264 };
265 let adjusted = (p_target - p_draft).max(0.0);
266 if adjusted > best_score {
267 best_score = adjusted;
268 best_token = token;
269 }
270 }
271 best_token
272}
273
/// One Medusa prediction head: proposes a token for a fixed offset beyond `prefix`.
pub trait MedusaHead: Send + Sync {
    /// Returns this head's candidate `(token, probability)` pairs for `prefix`.
    ///
    /// `medusa_decode` uses only the first pair as the head's proposal.
    fn predict(&self, prefix: &[TokenId]) -> Vec<(TokenId, f32)>;
}
286
/// Outcome of one `medusa_decode` step.
#[derive(Clone, Debug)]
pub struct MedusaResult {
    /// Tokens accepted from the candidate path. When none were accepted, this holds a single
    /// fallback token taken from the target's first distribution, if there is one.
    pub tokens: Vec<TokenId>,
    /// Number of candidate paths verified (`medusa_decode` evaluates one path and reports 1).
    pub paths_evaluated: usize,
}
295
296pub fn medusa_decode(
302 prefix: &[TokenId],
303 heads: &[&dyn MedusaHead],
304 target: &dyn TargetModel,
305 config: &SpeculativeConfig,
306) -> AttentionResult<MedusaResult> {
307 config.validate()?;
308
309 if heads.is_empty() {
310 return Err(AttentionError::EmptyInput(
311 "at least one Medusa head required".into(),
312 ));
313 }
314
315 let head_predictions: Vec<Vec<(TokenId, f32)>> =
317 heads.iter().map(|h| h.predict(prefix)).collect();
318
319 let candidate_path: Vec<TokenId> = head_predictions
321 .iter()
322 .filter_map(|dist| dist.first().map(|(t, _)| *t))
323 .collect();
324
325 if candidate_path.is_empty() {
326 return Err(AttentionError::EmptyInput(
327 "heads produced no predictions".into(),
328 ));
329 }
330
331 let target_dists = target.verify_batch(prefix, &candidate_path);
333
334 let mut accepted = Vec::new();
336 for (i, &token) in candidate_path.iter().enumerate() {
337 if i >= target_dists.len() {
338 break;
339 }
340 let p = prob_of_token(&target_dists[i], token);
341 if p > 0.0 {
342 accepted.push(token);
343 } else {
344 break;
345 }
346 }
347
348 if accepted.is_empty() {
350 if let Some(dist) = target_dists.first() {
351 if let Some(&(token, _)) = dist.first() {
352 accepted.push(token);
353 }
354 }
355 }
356
357 Ok(MedusaResult {
358 tokens: accepted,
359 paths_evaluated: 1, })
361}
362
363pub struct SimpleDraftModel {
369 pub tokens: Vec<TokenId>,
371 pub probability: f32,
373}
374
375impl DraftModel for SimpleDraftModel {
376 fn draft_tokens(&self, _prefix: &[TokenId], gamma: usize) -> Vec<(TokenId, f32)> {
377 (0..gamma)
378 .map(|i| {
379 let token = self.tokens[i % self.tokens.len()];
380 (token, self.probability)
381 })
382 .collect()
383 }
384}
385
386pub struct SimpleTargetModel {
388 pub distributions: Vec<Vec<(TokenId, f32)>>,
392}
393
394impl TargetModel for SimpleTargetModel {
395 fn verify_batch(
396 &self,
397 _prefix: &[TokenId],
398 draft_tokens: &[TokenId],
399 ) -> Vec<Vec<(TokenId, f32)>> {
400 let needed = draft_tokens.len() + 1;
401 (0..needed)
402 .map(|i| {
403 if i < self.distributions.len() {
404 self.distributions[i].clone()
405 } else {
406 self.distributions
407 .last()
408 .cloned()
409 .unwrap_or_else(|| vec![(0, 1.0)])
410 }
411 })
412 .collect()
413 }
414}
415
416pub struct SimpleMedusaHead {
418 pub token: TokenId,
420 pub probability: f32,
422}
423
424impl MedusaHead for SimpleMedusaHead {
425 fn predict(&self, _prefix: &[TokenId]) -> Vec<(TokenId, f32)> {
426 vec![(self.token, self.probability)]
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Config shared by most tests: gamma = 4 with the default sampling parameters.
    fn default_config() -> SpeculativeConfig {
        SpeculativeConfig::new(4)
    }

    // ---- SpeculativeConfig::validate ----

    #[test]
    fn test_config_valid() {
        assert!(default_config().validate().is_ok());
    }

    #[test]
    fn test_config_gamma_zero() {
        let mut cfg = default_config();
        cfg.gamma = 0;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_config_gamma_too_large() {
        let mut cfg = default_config();
        // One past the maximum accepted gamma of 32.
        cfg.gamma = 33;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_config_bad_temperature() {
        let mut cfg = default_config();
        cfg.temperature = 0.0;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_config_bad_top_p() {
        let mut cfg = default_config();
        // Both ends of the open-closed interval (0, 1] are checked.
        cfg.top_p = 0.0;
        assert!(cfg.validate().is_err());

        cfg.top_p = 1.1;
        assert!(cfg.validate().is_err());
    }

    // ---- SpeculativeDecoder::decode_step ----

    /// Every target probability is >= the draft probability 0.5, so all four draft tokens are
    /// accepted deterministically and the bonus token 50 is appended.
    #[test]
    fn test_full_acceptance() {
        let draft = SimpleDraftModel {
            tokens: vec![10, 20, 30, 40],
            probability: 0.5,
        };
        let target = SimpleTargetModel {
            distributions: vec![
                vec![(10, 0.8)],
                vec![(20, 0.7)],
                vec![(30, 0.6)],
                vec![(40, 0.9)],
                // Extra distribution supplying the bonus token.
                vec![(50, 1.0)],
            ],
        };

        let result =
            SpeculativeDecoder::decode_step(&[1, 2, 3], &draft, &target, &default_config(), None)
                .unwrap();

        assert_eq!(result.tokens.len(), 5);
        assert_eq!(result.tokens, vec![10, 20, 30, 40, 50]);
        assert!((result.acceptance_rate - 1.0).abs() < f32::EPSILON);
    }

    /// The target never contains the drafted tokens (p = 0), so the first draft token is
    /// rejected and replaced by the residual's best token, 99.
    #[test]
    fn test_full_rejection() {
        let draft = SimpleDraftModel {
            tokens: vec![10, 20, 30, 40],
            probability: 0.9,
        };
        let target = SimpleTargetModel {
            distributions: vec![
                vec![(99, 0.9)],
                vec![(99, 0.9)],
                vec![(99, 0.9)],
                vec![(99, 0.9)],
                vec![(99, 1.0)],
            ],
        };

        let result = SpeculativeDecoder::decode_step(
            &[1],
            &draft,
            &target,
            &default_config(),
            // rng = 1.0 can never be below p/q, so probabilistic acceptance is ruled out.
            Some(&[1.0, 1.0, 1.0, 1.0]),
        )
        .unwrap();

        assert_eq!(result.tokens.len(), 1);
        assert_eq!(result.tokens[0], 99);
        assert!((result.acceptance_rate - 0.0).abs() < f32::EPSILON);
    }

    /// Positions 0 and 1 are accepted (0.8 and 0.6 >= 0.5). Position 2's target distribution
    /// lacks token 30, so it is rejected and replaced by 77. Acceptance is 2 of 4.
    #[test]
    fn test_partial_acceptance() {
        let draft = SimpleDraftModel {
            tokens: vec![10, 20, 30, 40],
            probability: 0.5,
        };
        let target = SimpleTargetModel {
            distributions: vec![
                vec![(10, 0.8)],
                vec![(20, 0.6)],
                // Drafted token 30 is absent here (p = 0).
                vec![(77, 0.9)],
                vec![(40, 0.9)],
                vec![(50, 1.0)],
            ],
        };

        let result = SpeculativeDecoder::decode_step(
            &[1],
            &draft,
            &target,
            &default_config(),
            Some(&[0.0, 0.0, 1.0, 0.0]),
        )
        .unwrap();

        assert_eq!(result.tokens.len(), 3);
        assert_eq!(result.tokens[0], 10);
        assert_eq!(result.tokens[1], 20);
        assert_eq!(result.tokens[2], 77);
        assert!((result.acceptance_rate - 0.5).abs() < f32::EPSILON);
    }

    /// p = 0.3 < q = 0.8 and rng = 1.0 >= p/q, so token 10 is rejected. The residual
    /// max(0, p - q) is 0 for token 10 and 0.7 for token 42, so 42 is chosen.
    #[test]
    fn test_rejection_sampling_distribution() {
        let draft = SimpleDraftModel {
            tokens: vec![10],
            probability: 0.8,
        };
        let target = SimpleTargetModel {
            distributions: vec![vec![(10, 0.3), (42, 0.7)], vec![(99, 1.0)]],
        };

        let cfg = SpeculativeConfig::new(1);
        let result = SpeculativeDecoder::decode_step(
            &[1],
            &draft,
            &target,
            &cfg,
            Some(&[1.0]),
        )
        .unwrap();

        assert_eq!(result.tokens.len(), 1);
        assert_eq!(result.tokens[0], 42);
    }

    // ---- theoretical_speedup ----

    /// Checks the estimate g*a / (1 + g*(1 - a)) at several points.
    #[test]
    fn test_theoretical_speedup() {
        // a = 1: 4 / 1.
        let s = theoretical_speedup(4, 1.0);
        assert!((s - 4.0).abs() < 1e-5);

        // a = 0: numerator is 0.
        let s = theoretical_speedup(4, 0.0);
        assert!(s.abs() < 1e-5);

        // 4 * 0.8 / (1 + 4 * 0.2).
        let s = theoretical_speedup(4, 0.8);
        assert!((s - 3.2 / 1.8).abs() < 1e-4);

        // 8 * 0.9 / (1 + 8 * 0.1).
        let s = theoretical_speedup(8, 0.9);
        assert!((s - 7.2 / 1.8).abs() < 1e-4);
    }

    // ---- medusa_decode ----

    /// Both heads' tokens have non-zero target probability, so both are accepted; no bonus
    /// token is appended by medusa_decode.
    #[test]
    fn test_medusa_decode() {
        let h1 = SimpleMedusaHead {
            token: 10,
            probability: 0.9,
        };
        let h2 = SimpleMedusaHead {
            token: 20,
            probability: 0.8,
        };
        let target = SimpleTargetModel {
            distributions: vec![vec![(10, 0.7)], vec![(20, 0.6)], vec![(99, 1.0)]],
        };

        let heads: Vec<&dyn MedusaHead> = vec![&h1, &h2];
        let result = medusa_decode(&[1, 2], &heads, &target, &default_config()).unwrap();

        assert_eq!(result.tokens, vec![10, 20]);
        assert_eq!(result.paths_evaluated, 1);
    }

    #[test]
    fn test_medusa_no_heads() {
        let target = SimpleTargetModel {
            distributions: vec![vec![(1, 1.0)]],
        };
        let heads: Vec<&dyn MedusaHead> = vec![];
        let result = medusa_decode(&[1], &heads, &target, &default_config());
        assert!(result.is_err());
    }

    // ---- decode_step edge cases ----

    /// p/q = 0.4 / 0.8 = 0.5 and rng = 0.3 < 0.5, so token 10 is accepted probabilistically,
    /// followed by the bonus token 99.
    #[test]
    fn test_probabilistic_acceptance() {
        let draft = SimpleDraftModel {
            tokens: vec![10],
            probability: 0.8,
        };
        let target = SimpleTargetModel {
            distributions: vec![
                vec![(10, 0.4)],
                vec![(99, 1.0)],
            ],
        };

        let cfg = SpeculativeConfig::new(1);
        let result =
            SpeculativeDecoder::decode_step(&[1], &draft, &target, &cfg, Some(&[0.3])).unwrap();

        assert_eq!(result.tokens, vec![10, 99]);
        assert!((result.acceptance_rate - 1.0).abs() < f32::EPSILON);
    }

    /// An empty prefix is accepted; the test models ignore the prefix entirely.
    #[test]
    fn test_empty_prefix() {
        let draft = SimpleDraftModel {
            tokens: vec![5],
            probability: 0.5,
        };
        let target = SimpleTargetModel {
            distributions: vec![vec![(5, 0.9)], vec![(6, 1.0)]],
        };

        let cfg = SpeculativeConfig::new(1);
        let result = SpeculativeDecoder::decode_step(&[], &draft, &target, &cfg, None).unwrap();

        assert_eq!(result.tokens, vec![5, 6]);
    }
}