tensorlogic_trustformers/speculative_decoding/traits.rs
1//! Core abstractions for model-level speculative decoding.
2//!
3//! The speculative decoding protocol of Leviathan et al. (2023) and
4//! Chen et al. (2023) is driven by two cooperating models:
5//!
6//! * A cheap **draft model** that extends the current prefix by *k* candidate
7//! tokens (cf. [`DraftModel`]).
8//! * An expensive **target model** that, in a single parallel forward pass,
9//! scores every prefix-continuation of the draft (cf. [`TargetModel`]).
10//!
//! The engine then runs the Bernoulli acceptance test position by position:
//! draft token `x` is accepted with probability
//! `min(1, p_target(x) / p_draft(x))`, and the first rejected position is
//! re-sampled from the normalized residual distribution
//! `norm(max(0, p_target - p_draft))` (see `acceptance.rs`). The traits in this
//! module exchange log-probabilities only; conversion to linear space happens
//! inside the engine at the call sites. Draft and target models can therefore
//! be hosted on CPU or GPU, eager or graph-compiled, without leaking those
//! concerns into the acceptance math.
19//!
20//! ## Why full distributions, not just token log-probs
21//!
22//! The naive signature
23//! `propose(prefix, k) -> Vec<(TokenId, LogProb)>`
24//! collapses each step into a single `(token, logprob)` pair. That is
//! insufficient: the residual re-sampling distribution
//! `norm(max(0, p_target - p_draft))` is defined over the **entire vocabulary**,
//! so both draft and target must return full per-position distributions.
28//! The trait shapes encode this explicitly.
29//!
//! ## Shape invariants for [`DraftProposal`] / [`TargetScores`]
//!
//! * Draft: `tokens.len() == token_logprobs.len() == distributions.len() == k`.
//! * Target: `distributions.len() == k + 1` (the last row is the bonus position).
//! * Every distribution row has length `vocab_size` for the configured vocab.
//! * Every `LogProb` row is normalized (log-sum-exp ≈ 0). The traits do not
//!   re-normalize — that is the implementation's responsibility.
//!
//! The structs themselves do not enforce these invariants. The engine
//! defensively checks shapes at runtime and short-circuits with a
//! [`crate::speculative_decoding::SpeculativeDecodingError`] if anything is malformed.
39
40use crate::speculative_decoding::error::SpeculativeDecodingResult;
41
/// Vocabulary-scoped token identifier.
///
/// Matches the convention used by `rule_guided_decoder::TokenId` so the two
/// decoders can share mappers in future work. The alias itself places no
/// bound on the value; ids are presumably expected to lie in `0..vocab_size`
/// for the model that produced them.
pub type TokenId = usize;

/// Natural-log probability. We deliberately use a type alias rather than a
/// newtype so callers can freely mix with `f64` arithmetic; the engine is the
/// only place where domains matter (log vs. linear) and it converts locally.
///
/// A log-probability of a valid probability is always `<= 0.0`.
pub type LogProb = f64;
52
53/// Output of a single [`DraftModel::propose`] call.
54///
55/// Fields are aligned index-wise: `tokens[i]` is the draft's sampled token at
56/// step *i*, `token_logprobs[i]` is its log-probability under the draft, and
57/// `distributions[i]` is the draft's **full** log-probability row over the
58/// vocabulary for that step (needed by the engine for the rejection test and
59/// the adjusted re-sampling).
60#[derive(Debug, Clone, PartialEq)]
61pub struct DraftProposal {
62 /// The `k` tokens the draft model proposes to extend the prefix with.
63 pub tokens: Vec<TokenId>,
64 /// Log-probability of each chosen token under the draft distribution.
65 pub token_logprobs: Vec<LogProb>,
66 /// Full per-step log-probability rows — `distributions[i]` has length
67 /// `vocab_size` and sums (in linear space) to 1.
68 pub distributions: Vec<Vec<LogProb>>,
69}
70
71impl DraftProposal {
72 /// Length of the proposal (number of draft positions, commonly `k`).
73 pub fn len(&self) -> usize {
74 self.tokens.len()
75 }
76
77 /// Is the proposal empty (no tokens drafted)?
78 pub fn is_empty(&self) -> bool {
79 self.tokens.is_empty()
80 }
81}
82
83/// Output of a single [`TargetModel::verify`] call.
84///
85/// For `k` draft tokens the target must return `k + 1` distributions: the
86/// first `k` at the draft-covered positions (used by the acceptance test),
87/// plus one **bonus** distribution at position `k + 1` that the engine uses
88/// if every draft token is accepted — see Leviathan et al. 2023 §3.2.
89#[derive(Debug, Clone, PartialEq)]
90pub struct TargetScores {
91 /// `k + 1` log-probability rows, each of length `vocab_size`.
92 pub distributions: Vec<Vec<LogProb>>,
93}
94
95impl TargetScores {
96 /// Number of positions scored (always `k + 1` in canonical usage).
97 pub fn len(&self) -> usize {
98 self.distributions.len()
99 }
100
101 /// Are there no scored positions at all?
102 pub fn is_empty(&self) -> bool {
103 self.distributions.is_empty()
104 }
105}
106
/// A model capable of *cheaply* extending a prefix by `k` tokens while
/// exposing full vocabulary distributions at every step.
///
/// Implementations must be deterministic w.r.t. the supplied RNG so that the
/// engine's empirical-distribution tests are reproducible.
pub trait DraftModel {
    /// Vocabulary cardinality the model emits log-probs over; every row of
    /// [`DraftProposal::distributions`] is expected to have this length.
    fn vocab_size(&self) -> usize;

    /// Extend `prefix` by `k` draft tokens; return the chosen tokens, their
    /// log-probabilities and the **full** per-position distributions.
    ///
    /// The returned [`DraftProposal`] is expected to hold `k` entries in each
    /// of its fields. Every distribution row should have length
    /// [`vocab_size`](DraftModel::vocab_size) and be normalized in log space;
    /// the trait does not re-normalize.
    ///
    /// Note `rng` is a dyn-compatible shim: the callee can down-mix it into
    /// whatever PRNG it likes internally, but the engine always drives a
    /// single `StdRng` to keep the acceptance branch of the algorithm
    /// reproducible.
    ///
    /// # Errors
    ///
    /// Implementation-specific failures are reported through the error arm of
    /// [`SpeculativeDecodingResult`].
    fn propose(
        &self,
        prefix: &[TokenId],
        k: usize,
        rng: &mut dyn crate::speculative_decoding::rng::SpecRng,
    ) -> SpeculativeDecodingResult<DraftProposal>;
}
130
/// A model that, given a prefix and a run of draft tokens, returns
/// per-position distributions (as log-probs) in a single forward pass.
pub trait TargetModel {
    /// Vocabulary cardinality the target emits log-probs over. Must match
    /// the draft's `vocab_size()`.
    fn vocab_size(&self) -> usize;

    /// Score `prefix` concatenated with `draft_tokens`.
    ///
    /// Must return `draft_tokens.len() + 1` distributions. Row `i` (for
    /// `i < draft_tokens.len()`) scores the position where `draft_tokens[i]`
    /// was proposed, i.e. the continuation of `prefix` followed by
    /// `draft_tokens[..i]`. The final row is the bonus distribution that
    /// follows the complete draft. Each row should have length
    /// [`vocab_size`](TargetModel::vocab_size) and be normalized in log space.
    ///
    /// # Errors
    ///
    /// Implementation-specific failures are reported through the error arm of
    /// [`SpeculativeDecodingResult`].
    fn verify(
        &self,
        prefix: &[TokenId],
        draft_tokens: &[TokenId],
    ) -> SpeculativeDecodingResult<TargetScores>;
}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// Uniform log-probability row over `vocab` tokens. It satisfies the
    /// documented normalization invariant (log-sum-exp = 0), so fixtures stay
    /// realistic.
    fn uniform_row(vocab: usize) -> Vec<LogProb> {
        vec![-(vocab as f64).ln(); vocab]
    }

    #[test]
    fn uniform_row_is_normalized() {
        let total: f64 = uniform_row(4).iter().map(|lp| lp.exp()).sum();
        assert!((total - 1.0).abs() < 1e-12);
    }

    #[test]
    fn draft_proposal_len_matches_tokens() {
        let row = uniform_row(4);
        let p = DraftProposal {
            tokens: vec![1, 2, 3],
            // Each chosen token's log-prob is read from its own row, so the
            // fixture is internally consistent.
            token_logprobs: vec![row[1], row[2], row[3]],
            distributions: vec![row; 3],
        };
        assert_eq!(p.len(), 3);
        assert!(!p.is_empty());
    }

    #[test]
    fn empty_proposal_is_empty() {
        let p = DraftProposal {
            tokens: vec![],
            token_logprobs: vec![],
            distributions: vec![],
        };
        assert!(p.is_empty());
        assert_eq!(p.len(), 0);
    }

    #[test]
    fn target_scores_len_matches_rows() {
        // Two draft tokens => two draft-covered rows plus one bonus row.
        let t = TargetScores {
            distributions: vec![uniform_row(4); 3],
        };
        assert_eq!(t.len(), 3);
        assert!(!t.is_empty());
    }

    #[test]
    fn empty_target_scores_is_empty() {
        let t = TargetScores {
            distributions: vec![],
        };
        assert!(t.is_empty());
        assert_eq!(t.len(), 0);
    }
}