Skip to main content

tensorlogic_trustformers/speculative_decoding/
traits.rs

//! Core abstractions for model-level speculative decoding.
//!
//! The speculative decoding protocol of Leviathan et al. (2023) and
//! Chen et al. (2023) is driven by two cooperating models:
//!
//! * A cheap **draft model** that extends the current prefix by *k* candidate
//!   tokens (cf. [`DraftModel`]).
//! * An expensive **target model** that, in a single parallel forward pass,
//!   scores every prefix-continuation of the draft (cf. [`TargetModel`]).
//!
//! The engine then runs the Bernoulli acceptance test
//! `accept = min(1, p_target / p_draft)` position by position, re-sampling
//! the first rejected position from the adjusted target distribution
//! `max(0, p_target - p_draft)`, renormalized to sum to 1 (see
//! `acceptance.rs`).  Because this trait layer only exchanges log-probabilities
//! — the engine converts between log-probs and probs at the call sites — the
//! draft and target models can run on CPU or GPU, eager or graph-compiled,
//! without leaking those concerns into the acceptance math.
//!
//! ## Why full distributions, not just token log-probs
//!
//! The naive signature
//! `propose(prefix, k) -> Vec<(TokenId, LogProb)>`
//! collapses each step into a single `(token, logprob)` pair.  That is
//! insufficient: the adjusted re-sampling distribution
//! `max(0, p_target - p_draft)` is defined over the **entire vocabulary**,
//! so both draft and target must return full per-position distributions.
//! The trait shapes encode this explicitly.
//!
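//! ## Shape invariants expected of [`DraftProposal`] / [`TargetScores`]
//!
//! Both types expose public fields and do not validate them on construction,
//! so implementations are responsible for upholding:
//!
//! * `DraftProposal`: `tokens.len() == token_logprobs.len() == distributions.len() == k`.
//! * `TargetScores`: `distributions.len() == k + 1` (the extra row is the bonus position).
//! * Every distribution row has length `vocab_size` for the configured vocab.
//! * Every `LogProb` row is normalized (log-sum-exp ≈ 0).  The traits do not
//!   re-normalize — it is the implementation's responsibility.
//!
//! The engine defensively checks shapes at runtime and short-circuits with a
//! [`crate::speculative_decoding::SpeculativeDecodingError`] if anything is malformed.
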
40use crate::speculative_decoding::error::SpeculativeDecodingResult;
41
/// Vocabulary-scoped token identifier.
///
/// Matches the convention used by `rule_guided_decoder::TokenId` so the two
/// decoders can share mappers in future work.
pub type TokenId = usize;

/// Natural-log probability.
///
/// Deliberately a type alias rather than a newtype, so callers can freely mix
/// it with `f64` arithmetic. The engine is the only place where the domain
/// (log vs. linear) matters, and it converts locally.
pub type LogProb = f64;

53/// Output of a single [`DraftModel::propose`] call.
54///
55/// Fields are aligned index-wise: `tokens[i]` is the draft's sampled token at
56/// step *i*, `token_logprobs[i]` is its log-probability under the draft, and
57/// `distributions[i]` is the draft's **full** log-probability row over the
58/// vocabulary for that step (needed by the engine for the rejection test and
59/// the adjusted re-sampling).
60#[derive(Debug, Clone, PartialEq)]
61pub struct DraftProposal {
62    /// The `k` tokens the draft model proposes to extend the prefix with.
63    pub tokens: Vec<TokenId>,
64    /// Log-probability of each chosen token under the draft distribution.
65    pub token_logprobs: Vec<LogProb>,
66    /// Full per-step log-probability rows — `distributions[i]` has length
67    /// `vocab_size` and sums (in linear space) to 1.
68    pub distributions: Vec<Vec<LogProb>>,
69}
70
71impl DraftProposal {
72    /// Length of the proposal (number of draft positions, commonly `k`).
73    pub fn len(&self) -> usize {
74        self.tokens.len()
75    }
76
77    /// Is the proposal empty (no tokens drafted)?
78    pub fn is_empty(&self) -> bool {
79        self.tokens.is_empty()
80    }
81}
82
83/// Output of a single [`TargetModel::verify`] call.
84///
85/// For `k` draft tokens the target must return `k + 1` distributions: the
86/// first `k` at the draft-covered positions (used by the acceptance test),
87/// plus one **bonus** distribution at position `k + 1` that the engine uses
88/// if every draft token is accepted — see Leviathan et al. 2023 §3.2.
89#[derive(Debug, Clone, PartialEq)]
90pub struct TargetScores {
91    /// `k + 1` log-probability rows, each of length `vocab_size`.
92    pub distributions: Vec<Vec<LogProb>>,
93}
94
95impl TargetScores {
96    /// Number of positions scored (always `k + 1` in canonical usage).
97    pub fn len(&self) -> usize {
98        self.distributions.len()
99    }
100
101    /// Are there no scored positions at all?
102    pub fn is_empty(&self) -> bool {
103        self.distributions.is_empty()
104    }
105}
106
107/// A model capable of *cheaply* extending a prefix by `k` tokens while
108/// exposing full vocabulary distributions at every step.
109///
110/// Implementations must be deterministic w.r.t. the supplied RNG so that the
111/// engine's empirical-distribution tests are reproducible.
112pub trait DraftModel {
113    /// Vocabulary cardinality the model emits log-probs over.
114    fn vocab_size(&self) -> usize;
115
116    /// Extend `prefix` by `k` draft tokens; return the chosen tokens, their
117    /// log-probabilities and the **full** per-position distributions.
118    ///
119    /// Note `rng` is a dyn-compatible shim: the callee can down-mix it into
120    /// whatever PRNG it likes internally, but the engine always drives a
121    /// single `StdRng` to keep the acceptance branch of the algorithm
122    /// reproducible.
123    fn propose(
124        &self,
125        prefix: &[TokenId],
126        k: usize,
127        rng: &mut dyn crate::speculative_decoding::rng::SpecRng,
128    ) -> SpeculativeDecodingResult<DraftProposal>;
129}
130
131/// A model that, given a prefix and up to `k` draft continuations, returns
132/// per-position distributions (as log-probs) in a single forward pass.
133pub trait TargetModel {
134    /// Vocabulary cardinality the target emits log-probs over.  Must match
135    /// the draft's `vocab_size()`.
136    fn vocab_size(&self) -> usize;
137
138    /// Score `prefix` concatenated with `draft_tokens`: return `k + 1`
139    /// distributions (the `k` draft-covered positions plus the bonus).
140    fn verify(
141        &self,
142        prefix: &[TokenId],
143        draft_tokens: &[TokenId],
144    ) -> SpeculativeDecodingResult<TargetScores>;
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn draft_proposal_len_matches_tokens() {
        let p = DraftProposal {
            tokens: vec![1, 2, 3],
            token_logprobs: vec![-0.1, -0.2, -0.3],
            distributions: vec![vec![-0.1; 4], vec![-0.2; 4], vec![-0.3; 4]],
        };
        assert_eq!(p.len(), 3);
        assert!(!p.is_empty());
    }

    #[test]
    fn empty_proposal_is_empty() {
        let p = DraftProposal {
            tokens: vec![],
            token_logprobs: vec![],
            distributions: vec![],
        };
        assert!(p.is_empty());
        assert_eq!(p.len(), 0);
    }

    #[test]
    fn target_scores_len_matches_rows() {
        let t = TargetScores {
            distributions: vec![vec![-0.5; 4], vec![-0.5; 4]],
        };
        assert_eq!(t.len(), 2);
        assert!(!t.is_empty());
    }

    // Covers the `true` branch of `TargetScores::is_empty`, which the
    // tests above never reach.
    #[test]
    fn empty_target_scores_is_empty() {
        let t = TargetScores {
            distributions: vec![],
        };
        assert!(t.is_empty());
        assert_eq!(t.len(), 0);
    }
}