Skip to main content

tensorlogic_trustformers/speculative_decoding/
mock_models.rs

//! Deterministic mock models for testing speculative decoding.
//!
//! These mocks return **fixed** categorical distributions independent of the
//! supplied prefix. That is exactly what makes empirical-distribution tests of
//! the speculative-sampling correctness theorem (Leviathan et al., 2023)
//! tractable: if target and draft both ignore history, then the marginal
//! distribution over emitted tokens is precisely `p_target`, and the engine's
//! output can be chi-squared against `p_target`.
//!
//! Two mock kinds are exposed:
//!
//! * [`FixedDistDraftModel`] / [`FixedDistTargetModel`] — constant log-probs
//!   over a tiny vocabulary.  Used in both the module-private unit tests and
//!   the crate-level integration test.
//! * [`MockDraftModel`] / [`MockTargetModel`] — thin type aliases exported for
//!   downstream consumers that want a "just works" pair.
16
17use crate::speculative_decoding::error::{SpeculativeDecodingError, SpeculativeDecodingResult};
18use crate::speculative_decoding::rng::SpecRng;
19use crate::speculative_decoding::traits::{
20    DraftModel, DraftProposal, LogProb, TargetModel, TargetScores, TokenId,
21};
22
23/// Draft model that always returns the same categorical distribution and
24/// samples tokens from it.
25#[derive(Debug, Clone)]
26pub struct FixedDistDraftModel {
27    probs: Vec<f64>,
28    logprobs: Vec<LogProb>,
29}
30
31impl FixedDistDraftModel {
32    /// Build a mock draft from a linear-space probability vector.  The vector
33    /// must be non-empty and sum to ≈ 1.
34    pub fn new(probs: Vec<f64>) -> SpeculativeDecodingResult<Self> {
35        if probs.is_empty() {
36            return Err(SpeculativeDecodingError::InvalidConfig(
37                "FixedDistDraftModel requires a non-empty probability vector".into(),
38            ));
39        }
40        let sum: f64 = probs.iter().copied().sum();
41        if !(sum > 0.0 && sum.is_finite()) {
42            return Err(SpeculativeDecodingError::InvalidConfig(
43                "FixedDistDraftModel probabilities must have positive finite mass".into(),
44            ));
45        }
46        let normalized: Vec<f64> = probs.iter().map(|p| *p / sum).collect();
47        let logprobs: Vec<f64> = normalized
48            .iter()
49            .map(|p| if *p > 0.0 { p.ln() } else { f64::NEG_INFINITY })
50            .collect();
51        Ok(Self {
52            probs: normalized,
53            logprobs,
54        })
55    }
56
57    /// Access the (normalized) linear-space probabilities.
58    pub fn probs(&self) -> &[f64] {
59        &self.probs
60    }
61
62    /// Access the log-probability row.
63    pub fn logprobs(&self) -> &[LogProb] {
64        &self.logprobs
65    }
66}
67
68impl DraftModel for FixedDistDraftModel {
69    fn vocab_size(&self) -> usize {
70        self.probs.len()
71    }
72
73    fn propose(
74        &self,
75        _prefix: &[TokenId],
76        k: usize,
77        rng: &mut dyn SpecRng,
78    ) -> SpeculativeDecodingResult<DraftProposal> {
79        let mut tokens = Vec::with_capacity(k);
80        let mut token_logprobs = Vec::with_capacity(k);
81        let mut distributions = Vec::with_capacity(k);
82        for _ in 0..k {
83            let idx = sample_categorical(&self.probs, rng)?;
84            tokens.push(idx);
85            token_logprobs.push(self.logprobs[idx]);
86            distributions.push(self.logprobs.clone());
87        }
88        Ok(DraftProposal {
89            tokens,
90            token_logprobs,
91            distributions,
92        })
93    }
94}
95
96/// Target model returning the same categorical distribution for every
97/// position.
98#[derive(Debug, Clone)]
99pub struct FixedDistTargetModel {
100    probs: Vec<f64>,
101    logprobs: Vec<LogProb>,
102}
103
104impl FixedDistTargetModel {
105    /// Build a mock target from a linear-space probability vector.  See
106    /// [`FixedDistDraftModel::new`] for invariants.
107    pub fn new(probs: Vec<f64>) -> SpeculativeDecodingResult<Self> {
108        if probs.is_empty() {
109            return Err(SpeculativeDecodingError::InvalidConfig(
110                "FixedDistTargetModel requires a non-empty probability vector".into(),
111            ));
112        }
113        let sum: f64 = probs.iter().copied().sum();
114        if !(sum > 0.0 && sum.is_finite()) {
115            return Err(SpeculativeDecodingError::InvalidConfig(
116                "FixedDistTargetModel probabilities must have positive finite mass".into(),
117            ));
118        }
119        let normalized: Vec<f64> = probs.iter().map(|p| *p / sum).collect();
120        let logprobs: Vec<f64> = normalized
121            .iter()
122            .map(|p| if *p > 0.0 { p.ln() } else { f64::NEG_INFINITY })
123            .collect();
124        Ok(Self {
125            probs: normalized,
126            logprobs,
127        })
128    }
129
130    /// Access the (normalized) linear-space probabilities.
131    pub fn probs(&self) -> &[f64] {
132        &self.probs
133    }
134
135    /// Access the log-probability row.
136    pub fn logprobs(&self) -> &[LogProb] {
137        &self.logprobs
138    }
139}
140
141impl TargetModel for FixedDistTargetModel {
142    fn vocab_size(&self) -> usize {
143        self.probs.len()
144    }
145
146    fn verify(
147        &self,
148        _prefix: &[TokenId],
149        draft_tokens: &[TokenId],
150    ) -> SpeculativeDecodingResult<TargetScores> {
151        let rows = draft_tokens.len() + 1;
152        let distributions: Vec<Vec<LogProb>> = (0..rows).map(|_| self.logprobs.clone()).collect();
153        Ok(TargetScores { distributions })
154    }
155}
156
157/// Helper: sample a categorical index via inverse-CDF against `rng`.
158pub(crate) fn sample_categorical(
159    probs: &[f64],
160    rng: &mut dyn SpecRng,
161) -> SpeculativeDecodingResult<TokenId> {
162    if probs.is_empty() {
163        return Err(SpeculativeDecodingError::DegenerateDistribution);
164    }
165    let u = rng.next_unit_f64();
166    let mut cum = 0.0;
167    for (i, p) in probs.iter().enumerate() {
168        cum += *p;
169        if u < cum {
170            return Ok(i);
171        }
172    }
173    Ok(probs.len() - 1)
174}
175
176/// Public alias used by tests: fixed-distribution draft model.
177pub type MockDraftModel = FixedDistDraftModel;
178
179/// Public alias used by tests: fixed-distribution target model.
180pub type MockTargetModel = FixedDistTargetModel;
181
#[cfg(test)]
mod tests {
    //! Unit tests for the fixed-distribution mocks: construction-time
    //! validation, output shapes, and seed reproducibility.

    use super::*;
    use scirs2_core::random::{SeedableRng, StdRng};

    #[test]
    fn draft_model_normalizes_input() {
        let model = FixedDistDraftModel::new(vec![2.0, 2.0]).expect("normalize");
        assert!(model.probs().iter().all(|&p| (p - 0.5).abs() < 1e-9));
    }

    #[test]
    fn draft_model_rejects_empty() {
        assert!(FixedDistDraftModel::new(Vec::new()).is_err());
    }

    #[test]
    fn draft_model_rejects_zero_mass() {
        assert!(FixedDistDraftModel::new(vec![0.0; 3]).is_err());
    }

    #[test]
    fn propose_shapes_are_consistent() {
        let model = FixedDistDraftModel::new(vec![0.25; 4]).expect("d");
        let mut rng = StdRng::seed_from_u64(1);
        let proposal = model.propose(&[0, 1, 2], 3, &mut rng).expect("propose");

        assert_eq!(proposal.tokens.len(), 3);
        assert_eq!(proposal.token_logprobs.len(), 3);
        assert_eq!(proposal.distributions.len(), 3);
        assert!(proposal.distributions.iter().all(|row| row.len() == 4));
    }

    #[test]
    fn verify_returns_k_plus_one_rows() {
        let model = FixedDistTargetModel::new(vec![0.25; 4]).expect("t");
        let draft: [TokenId; 3] = [1, 2, 3];
        let scores = model.verify(&[0, 1], &draft).expect("verify");

        assert_eq!(scores.distributions.len(), draft.len() + 1);
        assert!(scores.distributions.iter().all(|row| row.len() == 4));
    }

    #[test]
    fn propose_is_reproducible_with_seed() {
        let model = FixedDistDraftModel::new(vec![0.1, 0.2, 0.3, 0.4]).expect("d");
        let mut rng_a = StdRng::seed_from_u64(7);
        let mut rng_b = StdRng::seed_from_u64(7);

        let tokens_a = model.propose(&[0], 8, &mut rng_a).expect("p1").tokens;
        let tokens_b = model.propose(&[0], 8, &mut rng_b).expect("p2").tokens;
        assert_eq!(tokens_a, tokens_b);
    }
}