Skip to main content

tensorlogic_trustformers/speculative_decoding/
engine.rs

1//! Speculative decoding engine.
2//!
3//! [`SpeculativeDecoder`] composes a [`DraftModel`] and a [`TargetModel`] and
4//! runs the Leviathan / Chen speculative generation loop:
5//!
6//! ```text
//! loop until max_tokens reached (or the EOS token is emitted, if configured):
8//!   1. draft.propose(prefix, k)       → k candidate tokens + distributions
9//!   2. target.verify(prefix, draft)   → k+1 target distributions
10//!   3. for i in 0..k:
11//!        if accept(p_draft_i, p_target_i, rng):
12//!            append draft[i]
13//!        else:
14//!            append resample_from_adjusted_target(p_target_i, p_draft_i, rng)
15//!            break
16//!   4. if all k accepted:
17//!        append sample_from_logprobs(p_target_{k+1}, rng)  (bonus)
18//!   5. update metrics; continue.
19//! ```
20//!
21//! The **number of tokens appended per round** is therefore in `1..=k+1`,
//! and — crucially — each appended token, conditioned on everything generated
//! before it, is distributed *exactly* as `p_target(· | prefix)`.  That property
//! is what the empirical chi-square test in `tests.rs` checks over 10 000 samples.
25
26use std::marker::PhantomData;
27
28use scirs2_core::random::{SeedableRng, StdRng};
29
30use crate::speculative_decoding::acceptance::{
31    accept, resample_from_adjusted_target, sample_from_logprobs,
32};
33use crate::speculative_decoding::error::{SpeculativeDecodingError, SpeculativeDecodingResult};
34use crate::speculative_decoding::metrics::SpeculativeMetrics;
35use crate::speculative_decoding::rng::SpecRng;
36use crate::speculative_decoding::traits::{
37    DraftModel, DraftProposal, TargetModel, TargetScores, TokenId,
38};
39
/// Configuration for the speculative decoder.
///
/// Start from [`Default::default`] and adjust with the `with_*` builder
/// methods. [`SpeculativeDecoder::new`] calls
/// [`SpeculativeDecoderConfig::validate`] before accepting a configuration.
#[derive(Debug, Clone, PartialEq)]
pub struct SpeculativeDecoderConfig {
    /// Number of draft tokens to propose per round (default `4`).
    ///
    /// Must be at least `1`. A round then appends at most `k + 1` tokens
    /// (the accepted drafts plus one resampled or bonus token).
    pub k: usize,
    /// Cost ratio `c_draft / c_target` for speedup modeling (default `0.125`).
    ///
    /// Only forwarded to the metrics collector; the generation loop itself
    /// never reads it, so it does not influence which tokens are produced.
    pub cost_ratio: f32,
    /// If `true`, the engine halts the generation loop on the first
    /// `eos_token` it emits (the EOS token is kept in the output).
    /// Has no effect while `eos_token` is `None`.
    pub stop_on_eos: bool,
    /// Optional end-of-sequence token id (ignored unless `stop_on_eos`).
    pub eos_token: Option<TokenId>,
}
53
54impl Default for SpeculativeDecoderConfig {
55    fn default() -> Self {
56        Self {
57            k: 4,
58            cost_ratio: 0.125,
59            stop_on_eos: false,
60            eos_token: None,
61        }
62    }
63}
64
65impl SpeculativeDecoderConfig {
66    /// Convenience builder: set draft depth.
67    pub fn with_k(mut self, k: usize) -> Self {
68        self.k = k;
69        self
70    }
71
72    /// Convenience builder: set cost ratio for the speedup estimate.
73    pub fn with_cost_ratio(mut self, r: f32) -> Self {
74        self.cost_ratio = r;
75        self
76    }
77
78    /// Convenience builder: attach an EOS token and enable early stopping.
79    pub fn with_eos(mut self, eos: TokenId) -> Self {
80        self.eos_token = Some(eos);
81        self.stop_on_eos = true;
82        self
83    }
84
85    /// Validate the configuration, returning an [`SpeculativeDecodingError`]
86    /// if any invariant is violated.
87    pub fn validate(&self) -> SpeculativeDecodingResult<()> {
88        if self.k == 0 {
89            return Err(SpeculativeDecodingError::InvalidConfig(
90                "draft depth `k` must be at least 1".into(),
91            ));
92        }
93        Ok(())
94    }
95}
96
97/// Speculative decoder composing a draft model and a target model.
98///
99/// Trait bounds are intentionally deferred to the `impl` blocks rather than
100/// baked into the struct definition so callers can hold a
101/// `SpeculativeDecoder` whose inner models do not implement `Debug`.
102pub struct SpeculativeDecoder<D: DraftModel, T: TargetModel> {
103    draft: D,
104    target: T,
105    config: SpeculativeDecoderConfig,
106    metrics: SpeculativeMetrics,
107    _pd: PhantomData<()>,
108}
109
110impl<D: DraftModel + std::fmt::Debug, T: TargetModel + std::fmt::Debug> std::fmt::Debug
111    for SpeculativeDecoder<D, T>
112{
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        f.debug_struct("SpeculativeDecoder")
115            .field("draft", &self.draft)
116            .field("target", &self.target)
117            .field("config", &self.config)
118            .field("metrics", &self.metrics)
119            .finish()
120    }
121}
122
123impl<D: DraftModel, T: TargetModel> SpeculativeDecoder<D, T> {
124    /// Build a decoder.  Returns an error if the draft and target disagree on
125    /// vocabulary size, or if the config is invalid.
126    pub fn new(
127        draft: D,
128        target: T,
129        config: SpeculativeDecoderConfig,
130    ) -> SpeculativeDecodingResult<Self> {
131        config.validate()?;
132        if draft.vocab_size() != target.vocab_size() {
133            return Err(SpeculativeDecodingError::VocabMismatch {
134                draft: draft.vocab_size(),
135                target: target.vocab_size(),
136            });
137        }
138        let metrics = SpeculativeMetrics::new().with_cost_ratio(config.cost_ratio);
139        Ok(Self {
140            draft,
141            target,
142            config,
143            metrics,
144            _pd: PhantomData,
145        })
146    }
147
148    /// Read-only access to the current metrics snapshot.
149    pub fn metrics(&self) -> &SpeculativeMetrics {
150        &self.metrics
151    }
152
153    /// Reset the metrics counters.
154    pub fn reset_metrics(&mut self) {
155        self.metrics.reset();
156    }
157
158    /// Read-only access to the configuration.
159    pub fn config(&self) -> &SpeculativeDecoderConfig {
160        &self.config
161    }
162
163    /// Run speculative decoding starting from `prefix` and producing at most
164    /// `max_tokens` *new* tokens.  The returned vector contains **only** the
165    /// generated continuation, not the original prefix.
166    ///
167    /// Uses an internally-seeded deterministic [`StdRng`] (seed `42`).
168    /// See [`Self::generate_with_rng`] for caller-controlled seeding.
169    pub fn generate(
170        &mut self,
171        prefix: &[TokenId],
172        max_tokens: usize,
173    ) -> SpeculativeDecodingResult<Vec<TokenId>> {
174        let mut rng = StdRng::seed_from_u64(42);
175        self.generate_with_rng(prefix, max_tokens, &mut rng)
176    }
177
178    /// Run speculative decoding with a caller-supplied RNG.
179    pub fn generate_with_rng(
180        &mut self,
181        prefix: &[TokenId],
182        max_tokens: usize,
183        rng: &mut dyn SpecRng,
184    ) -> SpeculativeDecodingResult<Vec<TokenId>> {
185        if prefix.is_empty() {
186            return Err(SpeculativeDecodingError::EmptyPrefix);
187        }
188
189        let vocab = self.draft.vocab_size();
190        let k = self.config.k;
191        let mut working = prefix.to_vec();
192        let mut output: Vec<TokenId> = Vec::with_capacity(max_tokens);
193
194        while output.len() < max_tokens {
195            let remaining = max_tokens - output.len();
196            let round_k = k.min(remaining.max(1));
197
198            let proposal = self.draft.propose(&working, round_k, rng)?;
199            validate_proposal(&proposal, round_k, vocab)?;
200
201            let target_scores = self.target.verify(&working, &proposal.tokens)?;
202            validate_target_scores(&target_scores, round_k, vocab)?;
203
204            let (accepted_count, emitted) =
205                run_rejection_loop(&proposal, &target_scores, round_k, vocab, rng)?;
206
207            let mut committed_this_round = 0u32;
208            for token in emitted.into_iter() {
209                output.push(token);
210                working.push(token);
211                committed_this_round += 1;
212                if output.len() >= max_tokens {
213                    break;
214                }
215                if self.config.stop_on_eos
216                    && self
217                        .config
218                        .eos_token
219                        .map(|eos| eos == token)
220                        .unwrap_or(false)
221                {
222                    break;
223                }
224            }
225
226            self.metrics.record_round(
227                round_k as u32,
228                accepted_count as u32,
229                committed_this_round,
230                round_k as u32,
231            );
232
233            if self.config.stop_on_eos {
234                if let Some(eos) = self.config.eos_token {
235                    if output.last().copied() == Some(eos) {
236                        break;
237                    }
238                }
239            }
240        }
241
242        Ok(output)
243    }
244}
245
246/// Validate the shape and content of a draft proposal.
247fn validate_proposal(p: &DraftProposal, k: usize, vocab: usize) -> SpeculativeDecodingResult<()> {
248    if p.tokens.len() != k || p.token_logprobs.len() != k || p.distributions.len() != k {
249        return Err(SpeculativeDecodingError::DraftShapeMismatch {
250            tokens: p.tokens.len(),
251            logprobs: p.token_logprobs.len(),
252            distributions: p.distributions.len(),
253        });
254    }
255    for row in &p.distributions {
256        if row.len() != vocab {
257            return Err(SpeculativeDecodingError::DistributionWidthMismatch {
258                expected: vocab,
259                got: row.len(),
260            });
261        }
262    }
263    for &t in &p.tokens {
264        if t >= vocab {
265            return Err(SpeculativeDecodingError::TokenOutOfRange {
266                token: t,
267                vocab_size: vocab,
268            });
269        }
270    }
271    Ok(())
272}
273
274/// Validate the shape and content of target-verification scores.
275fn validate_target_scores(
276    t: &TargetScores,
277    k: usize,
278    vocab: usize,
279) -> SpeculativeDecodingResult<()> {
280    if t.distributions.len() != k + 1 {
281        return Err(SpeculativeDecodingError::TargetShapeMismatch {
282            expected: k + 1,
283            got: t.distributions.len(),
284        });
285    }
286    for row in &t.distributions {
287        if row.len() != vocab {
288            return Err(SpeculativeDecodingError::DistributionWidthMismatch {
289                expected: vocab,
290                got: row.len(),
291            });
292        }
293    }
294    Ok(())
295}
296
297/// Execute the acceptance / rejection sweep for a single speculative round.
298///
299/// Returns the tuple `(accepted_count, emitted_tokens)` where `emitted_tokens`
300/// has length in `1..=k+1`.
301fn run_rejection_loop(
302    proposal: &DraftProposal,
303    target_scores: &TargetScores,
304    k: usize,
305    vocab: usize,
306    rng: &mut dyn SpecRng,
307) -> SpeculativeDecodingResult<(usize, Vec<TokenId>)> {
308    let mut emitted: Vec<TokenId> = Vec::with_capacity(k + 1);
309    let mut accepted: usize = 0;
310
311    for i in 0..k {
312        let draft_token = proposal.tokens[i];
313        let target_row = &target_scores.distributions[i];
314        let draft_row = &proposal.distributions[i];
315
316        let draft_lp = draft_row[draft_token];
317        let target_lp = target_row[draft_token];
318
319        if accept(draft_lp, target_lp, rng) {
320            emitted.push(draft_token);
321            accepted += 1;
322            continue;
323        }
324
325        // Rejection: resample from adjusted target distribution.
326        let resampled = resample_from_adjusted_target(target_row, draft_row, rng)?;
327        if resampled >= vocab {
328            return Err(SpeculativeDecodingError::TokenOutOfRange {
329                token: resampled,
330                vocab_size: vocab,
331            });
332        }
333        emitted.push(resampled);
334        return Ok((accepted, emitted));
335    }
336
337    // All k accepted — draw bonus token from target's (k+1)-th distribution.
338    let bonus_row = &target_scores.distributions[k];
339    let bonus = sample_from_logprobs(bonus_row, rng)?;
340    if bonus >= vocab {
341        return Err(SpeculativeDecodingError::TokenOutOfRange {
342            token: bonus,
343            vocab_size: vocab,
344        });
345    }
346    emitted.push(bonus);
347    Ok((accepted, emitted))
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Every documented default is pinned, not just `k`.
    #[test]
    fn config_default_is_sensible() {
        let c = SpeculativeDecoderConfig::default();
        assert_eq!(c.k, 4);
        assert!((c.cost_ratio - 0.125).abs() < 1e-6);
        assert!(!c.stop_on_eos);
        assert_eq!(c.eos_token, None);
        assert!(c.validate().is_ok());
    }

    #[test]
    fn config_k_zero_rejected() {
        let c = SpeculativeDecoderConfig::default().with_k(0);
        assert!(c.validate().is_err());
    }

    #[test]
    fn config_builders_compose() {
        let c = SpeculativeDecoderConfig::default()
            .with_k(2)
            .with_cost_ratio(0.05)
            .with_eos(7);
        assert_eq!(c.k, 2);
        assert!((c.cost_ratio - 0.05).abs() < 1e-6);
        assert_eq!(c.eos_token, Some(7));
        assert!(c.stop_on_eos);
    }

    /// Each builder must touch only its own field(s).
    #[test]
    fn config_builders_leave_other_fields_untouched() {
        let base = SpeculativeDecoderConfig::default();
        assert_eq!(
            base.clone().with_k(8),
            SpeculativeDecoderConfig {
                k: 8,
                ..base.clone()
            }
        );
        assert_eq!(
            base.clone().with_cost_ratio(0.5),
            SpeculativeDecoderConfig {
                cost_ratio: 0.5,
                ..base
            }
        );
    }
}
378}