// tensorlogic_trustformers/speculative_decoding/engine.rs
use std::marker::PhantomData;
27
28use scirs2_core::random::{SeedableRng, StdRng};
29
30use crate::speculative_decoding::acceptance::{
31 accept, resample_from_adjusted_target, sample_from_logprobs,
32};
33use crate::speculative_decoding::error::{SpeculativeDecodingError, SpeculativeDecodingResult};
34use crate::speculative_decoding::metrics::SpeculativeMetrics;
35use crate::speculative_decoding::rng::SpecRng;
36use crate::speculative_decoding::traits::{
37 DraftModel, DraftProposal, TargetModel, TargetScores, TokenId,
38};
39
40#[derive(Debug, Clone, PartialEq)]
42pub struct SpeculativeDecoderConfig {
43 pub k: usize,
45 pub cost_ratio: f32,
47 pub stop_on_eos: bool,
50 pub eos_token: Option<TokenId>,
52}
53
54impl Default for SpeculativeDecoderConfig {
55 fn default() -> Self {
56 Self {
57 k: 4,
58 cost_ratio: 0.125,
59 stop_on_eos: false,
60 eos_token: None,
61 }
62 }
63}
64
65impl SpeculativeDecoderConfig {
66 pub fn with_k(mut self, k: usize) -> Self {
68 self.k = k;
69 self
70 }
71
72 pub fn with_cost_ratio(mut self, r: f32) -> Self {
74 self.cost_ratio = r;
75 self
76 }
77
78 pub fn with_eos(mut self, eos: TokenId) -> Self {
80 self.eos_token = Some(eos);
81 self.stop_on_eos = true;
82 self
83 }
84
85 pub fn validate(&self) -> SpeculativeDecodingResult<()> {
88 if self.k == 0 {
89 return Err(SpeculativeDecodingError::InvalidConfig(
90 "draft depth `k` must be at least 1".into(),
91 ));
92 }
93 Ok(())
94 }
95}
96
97pub struct SpeculativeDecoder<D: DraftModel, T: TargetModel> {
103 draft: D,
104 target: T,
105 config: SpeculativeDecoderConfig,
106 metrics: SpeculativeMetrics,
107 _pd: PhantomData<()>,
108}
109
110impl<D: DraftModel + std::fmt::Debug, T: TargetModel + std::fmt::Debug> std::fmt::Debug
111 for SpeculativeDecoder<D, T>
112{
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 f.debug_struct("SpeculativeDecoder")
115 .field("draft", &self.draft)
116 .field("target", &self.target)
117 .field("config", &self.config)
118 .field("metrics", &self.metrics)
119 .finish()
120 }
121}
122
123impl<D: DraftModel, T: TargetModel> SpeculativeDecoder<D, T> {
124 pub fn new(
127 draft: D,
128 target: T,
129 config: SpeculativeDecoderConfig,
130 ) -> SpeculativeDecodingResult<Self> {
131 config.validate()?;
132 if draft.vocab_size() != target.vocab_size() {
133 return Err(SpeculativeDecodingError::VocabMismatch {
134 draft: draft.vocab_size(),
135 target: target.vocab_size(),
136 });
137 }
138 let metrics = SpeculativeMetrics::new().with_cost_ratio(config.cost_ratio);
139 Ok(Self {
140 draft,
141 target,
142 config,
143 metrics,
144 _pd: PhantomData,
145 })
146 }
147
148 pub fn metrics(&self) -> &SpeculativeMetrics {
150 &self.metrics
151 }
152
153 pub fn reset_metrics(&mut self) {
155 self.metrics.reset();
156 }
157
158 pub fn config(&self) -> &SpeculativeDecoderConfig {
160 &self.config
161 }
162
163 pub fn generate(
170 &mut self,
171 prefix: &[TokenId],
172 max_tokens: usize,
173 ) -> SpeculativeDecodingResult<Vec<TokenId>> {
174 let mut rng = StdRng::seed_from_u64(42);
175 self.generate_with_rng(prefix, max_tokens, &mut rng)
176 }
177
178 pub fn generate_with_rng(
180 &mut self,
181 prefix: &[TokenId],
182 max_tokens: usize,
183 rng: &mut dyn SpecRng,
184 ) -> SpeculativeDecodingResult<Vec<TokenId>> {
185 if prefix.is_empty() {
186 return Err(SpeculativeDecodingError::EmptyPrefix);
187 }
188
189 let vocab = self.draft.vocab_size();
190 let k = self.config.k;
191 let mut working = prefix.to_vec();
192 let mut output: Vec<TokenId> = Vec::with_capacity(max_tokens);
193
194 while output.len() < max_tokens {
195 let remaining = max_tokens - output.len();
196 let round_k = k.min(remaining.max(1));
197
198 let proposal = self.draft.propose(&working, round_k, rng)?;
199 validate_proposal(&proposal, round_k, vocab)?;
200
201 let target_scores = self.target.verify(&working, &proposal.tokens)?;
202 validate_target_scores(&target_scores, round_k, vocab)?;
203
204 let (accepted_count, emitted) =
205 run_rejection_loop(&proposal, &target_scores, round_k, vocab, rng)?;
206
207 let mut committed_this_round = 0u32;
208 for token in emitted.into_iter() {
209 output.push(token);
210 working.push(token);
211 committed_this_round += 1;
212 if output.len() >= max_tokens {
213 break;
214 }
215 if self.config.stop_on_eos
216 && self
217 .config
218 .eos_token
219 .map(|eos| eos == token)
220 .unwrap_or(false)
221 {
222 break;
223 }
224 }
225
226 self.metrics.record_round(
227 round_k as u32,
228 accepted_count as u32,
229 committed_this_round,
230 round_k as u32,
231 );
232
233 if self.config.stop_on_eos {
234 if let Some(eos) = self.config.eos_token {
235 if output.last().copied() == Some(eos) {
236 break;
237 }
238 }
239 }
240 }
241
242 Ok(output)
243 }
244}
245
246fn validate_proposal(p: &DraftProposal, k: usize, vocab: usize) -> SpeculativeDecodingResult<()> {
248 if p.tokens.len() != k || p.token_logprobs.len() != k || p.distributions.len() != k {
249 return Err(SpeculativeDecodingError::DraftShapeMismatch {
250 tokens: p.tokens.len(),
251 logprobs: p.token_logprobs.len(),
252 distributions: p.distributions.len(),
253 });
254 }
255 for row in &p.distributions {
256 if row.len() != vocab {
257 return Err(SpeculativeDecodingError::DistributionWidthMismatch {
258 expected: vocab,
259 got: row.len(),
260 });
261 }
262 }
263 for &t in &p.tokens {
264 if t >= vocab {
265 return Err(SpeculativeDecodingError::TokenOutOfRange {
266 token: t,
267 vocab_size: vocab,
268 });
269 }
270 }
271 Ok(())
272}
273
274fn validate_target_scores(
276 t: &TargetScores,
277 k: usize,
278 vocab: usize,
279) -> SpeculativeDecodingResult<()> {
280 if t.distributions.len() != k + 1 {
281 return Err(SpeculativeDecodingError::TargetShapeMismatch {
282 expected: k + 1,
283 got: t.distributions.len(),
284 });
285 }
286 for row in &t.distributions {
287 if row.len() != vocab {
288 return Err(SpeculativeDecodingError::DistributionWidthMismatch {
289 expected: vocab,
290 got: row.len(),
291 });
292 }
293 }
294 Ok(())
295}
296
297fn run_rejection_loop(
302 proposal: &DraftProposal,
303 target_scores: &TargetScores,
304 k: usize,
305 vocab: usize,
306 rng: &mut dyn SpecRng,
307) -> SpeculativeDecodingResult<(usize, Vec<TokenId>)> {
308 let mut emitted: Vec<TokenId> = Vec::with_capacity(k + 1);
309 let mut accepted: usize = 0;
310
311 for i in 0..k {
312 let draft_token = proposal.tokens[i];
313 let target_row = &target_scores.distributions[i];
314 let draft_row = &proposal.distributions[i];
315
316 let draft_lp = draft_row[draft_token];
317 let target_lp = target_row[draft_token];
318
319 if accept(draft_lp, target_lp, rng) {
320 emitted.push(draft_token);
321 accepted += 1;
322 continue;
323 }
324
325 let resampled = resample_from_adjusted_target(target_row, draft_row, rng)?;
327 if resampled >= vocab {
328 return Err(SpeculativeDecodingError::TokenOutOfRange {
329 token: resampled,
330 vocab_size: vocab,
331 });
332 }
333 emitted.push(resampled);
334 return Ok((accepted, emitted));
335 }
336
337 let bonus_row = &target_scores.distributions[k];
339 let bonus = sample_from_logprobs(bonus_row, rng)?;
340 if bonus >= vocab {
341 return Err(SpeculativeDecodingError::TokenOutOfRange {
342 token: bonus,
343 vocab_size: vocab,
344 });
345 }
346 emitted.push(bonus);
347 Ok((accepted, emitted))
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration has draft depth 4 and passes validation.
    #[test]
    fn config_default_is_sensible() {
        let defaults = SpeculativeDecoderConfig::default();
        assert_eq!(defaults.k, 4);
        assert!(defaults.validate().is_ok());
    }

    /// A draft depth of zero must not validate.
    #[test]
    fn config_k_zero_rejected() {
        let zero_depth = SpeculativeDecoderConfig::default().with_k(0);
        assert!(zero_depth.validate().is_err());
    }

    /// Chained builders each leave their mark on the final configuration.
    #[test]
    fn config_builders_compose() {
        let SpeculativeDecoderConfig {
            k,
            cost_ratio,
            stop_on_eos,
            eos_token,
        } = SpeculativeDecoderConfig::default()
            .with_k(2)
            .with_cost_ratio(0.05)
            .with_eos(7);

        assert_eq!(k, 2);
        assert!((cost_ratio - 0.05).abs() < 1e-6);
        assert_eq!(eos_token, Some(7));
        assert!(stop_on_eos);
    }
}