tensorlogic_trustformers/speculative_decoding/
mock_models.rs1use crate::speculative_decoding::error::{SpeculativeDecodingError, SpeculativeDecodingResult};
18use crate::speculative_decoding::rng::SpecRng;
19use crate::speculative_decoding::traits::{
20 DraftModel, DraftProposal, LogProb, TargetModel, TargetScores, TokenId,
21};
22
23#[derive(Debug, Clone)]
26pub struct FixedDistDraftModel {
27 probs: Vec<f64>,
28 logprobs: Vec<LogProb>,
29}
30
31impl FixedDistDraftModel {
32 pub fn new(probs: Vec<f64>) -> SpeculativeDecodingResult<Self> {
35 if probs.is_empty() {
36 return Err(SpeculativeDecodingError::InvalidConfig(
37 "FixedDistDraftModel requires a non-empty probability vector".into(),
38 ));
39 }
40 let sum: f64 = probs.iter().copied().sum();
41 if !(sum > 0.0 && sum.is_finite()) {
42 return Err(SpeculativeDecodingError::InvalidConfig(
43 "FixedDistDraftModel probabilities must have positive finite mass".into(),
44 ));
45 }
46 let normalized: Vec<f64> = probs.iter().map(|p| *p / sum).collect();
47 let logprobs: Vec<f64> = normalized
48 .iter()
49 .map(|p| if *p > 0.0 { p.ln() } else { f64::NEG_INFINITY })
50 .collect();
51 Ok(Self {
52 probs: normalized,
53 logprobs,
54 })
55 }
56
57 pub fn probs(&self) -> &[f64] {
59 &self.probs
60 }
61
62 pub fn logprobs(&self) -> &[LogProb] {
64 &self.logprobs
65 }
66}
67
impl DraftModel for FixedDistDraftModel {
    /// Vocabulary size, i.e. the length of the stored probability vector.
    fn vocab_size(&self) -> usize {
        self.probs.len()
    }

    /// Samples `k` tokens independently from the fixed distribution.
    ///
    /// The prefix is ignored. Every proposed position carries the sampled token,
    /// that token's log-probability, and a full copy of the log-probability vector.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `sample_categorical`.
    fn propose(
        &self,
        _prefix: &[TokenId],
        k: usize,
        rng: &mut dyn SpecRng,
    ) -> SpeculativeDecodingResult<DraftProposal> {
        // Draws happen in position order, one RNG value per position; collecting into
        // a `Result` stops at the first failure, like `?` inside a loop would.
        let tokens = (0..k)
            .map(|_| sample_categorical(&self.probs, &mut *rng))
            .collect::<Result<Vec<_>, _>>()?;
        let token_logprobs: Vec<LogProb> =
            tokens.iter().map(|&token| self.logprobs[token]).collect();
        // The distribution never changes, so each position gets the same row.
        let distributions = vec![self.logprobs.clone(); k];
        Ok(DraftProposal {
            tokens,
            token_logprobs,
            distributions,
        })
    }
}
95
96#[derive(Debug, Clone)]
99pub struct FixedDistTargetModel {
100 probs: Vec<f64>,
101 logprobs: Vec<LogProb>,
102}
103
104impl FixedDistTargetModel {
105 pub fn new(probs: Vec<f64>) -> SpeculativeDecodingResult<Self> {
108 if probs.is_empty() {
109 return Err(SpeculativeDecodingError::InvalidConfig(
110 "FixedDistTargetModel requires a non-empty probability vector".into(),
111 ));
112 }
113 let sum: f64 = probs.iter().copied().sum();
114 if !(sum > 0.0 && sum.is_finite()) {
115 return Err(SpeculativeDecodingError::InvalidConfig(
116 "FixedDistTargetModel probabilities must have positive finite mass".into(),
117 ));
118 }
119 let normalized: Vec<f64> = probs.iter().map(|p| *p / sum).collect();
120 let logprobs: Vec<f64> = normalized
121 .iter()
122 .map(|p| if *p > 0.0 { p.ln() } else { f64::NEG_INFINITY })
123 .collect();
124 Ok(Self {
125 probs: normalized,
126 logprobs,
127 })
128 }
129
130 pub fn probs(&self) -> &[f64] {
132 &self.probs
133 }
134
135 pub fn logprobs(&self) -> &[LogProb] {
137 &self.logprobs
138 }
139}
140
141impl TargetModel for FixedDistTargetModel {
142 fn vocab_size(&self) -> usize {
143 self.probs.len()
144 }
145
146 fn verify(
147 &self,
148 _prefix: &[TokenId],
149 draft_tokens: &[TokenId],
150 ) -> SpeculativeDecodingResult<TargetScores> {
151 let rows = draft_tokens.len() + 1;
152 let distributions: Vec<Vec<LogProb>> = (0..rows).map(|_| self.logprobs.clone()).collect();
153 Ok(TargetScores { distributions })
154 }
155}
156
157pub(crate) fn sample_categorical(
159 probs: &[f64],
160 rng: &mut dyn SpecRng,
161) -> SpeculativeDecodingResult<TokenId> {
162 if probs.is_empty() {
163 return Err(SpeculativeDecodingError::DegenerateDistribution);
164 }
165 let u = rng.next_unit_f64();
166 let mut cum = 0.0;
167 for (i, p) in probs.iter().enumerate() {
168 cum += *p;
169 if u < cum {
170 return Ok(i);
171 }
172 }
173 Ok(probs.len() - 1)
174}
175
176pub type MockDraftModel = FixedDistDraftModel;
178
179pub type MockTargetModel = FixedDistTargetModel;
181
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::{SeedableRng, StdRng};

    /// Equal weights `[2, 2]` must be normalized to `[0.5, 0.5]`.
    #[test]
    fn draft_model_normalizes_input() {
        let model = FixedDistDraftModel::new(vec![2.0, 2.0]).expect("normalize");
        model
            .probs()
            .iter()
            .for_each(|p| assert!((p - 0.5).abs() < 1e-9));
    }

    /// An empty weight vector is not a valid distribution.
    #[test]
    fn draft_model_rejects_empty() {
        assert!(FixedDistDraftModel::new(vec![]).is_err());
    }

    /// All-zero weights carry no mass and cannot be normalized.
    #[test]
    fn draft_model_rejects_zero_mass() {
        assert!(FixedDistDraftModel::new(vec![0.0, 0.0, 0.0]).is_err());
    }

    /// A proposal of `k = 3` over a 4-token vocabulary has three entries in each
    /// field, and every distribution row spans the whole vocabulary.
    #[test]
    fn propose_shapes_are_consistent() {
        let model = FixedDistDraftModel::new(vec![0.25; 4]).expect("d");
        let mut rng = StdRng::seed_from_u64(1);
        let proposal = model.propose(&[0, 1, 2], 3, &mut rng).expect("propose");
        assert_eq!(proposal.tokens.len(), 3);
        assert_eq!(proposal.token_logprobs.len(), 3);
        assert_eq!(proposal.distributions.len(), 3);
        proposal
            .distributions
            .iter()
            .for_each(|row| assert_eq!(row.len(), 4));
    }

    /// Verifying three draft tokens yields four rows, each spanning the vocabulary.
    #[test]
    fn verify_returns_k_plus_one_rows() {
        let target = FixedDistTargetModel::new(vec![0.25; 4]).expect("t");
        let scores = target.verify(&[0, 1], &[1, 2, 3]).expect("verify");
        assert_eq!(scores.distributions.len(), 4);
        scores
            .distributions
            .iter()
            .for_each(|row| assert_eq!(row.len(), 4));
    }

    /// Two generators seeded identically must produce the same proposed tokens.
    #[test]
    fn propose_is_reproducible_with_seed() {
        let model = FixedDistDraftModel::new(vec![0.1, 0.2, 0.3, 0.4]).expect("d");
        let mut first_rng = StdRng::seed_from_u64(7);
        let mut second_rng = StdRng::seed_from_u64(7);
        let first = model.propose(&[0], 8, &mut first_rng).expect("p1");
        let second = model.propose(&[0], 8, &mut second_rng).expect("p2");
        assert_eq!(first.tokens, second.tokens);
    }
}