tensorlogic_trustformers/speculative_decoding/
error.rs1use thiserror::Error;
9
10#[derive(Debug, Clone, Error, PartialEq)]
12pub enum SpeculativeDecodingError {
13 #[error(
15 "vocab size mismatch between draft ({draft}) and target ({target}) \
16 speculative-decoding models"
17 )]
18 VocabMismatch { draft: usize, target: usize },
19
20 #[error("distribution row width mismatch: expected {expected}, got {got}")]
22 DistributionWidthMismatch { expected: usize, got: usize },
23
24 #[error(
27 "draft proposal shape mismatch: tokens={tokens}, token_logprobs={logprobs}, \
28 distributions={distributions}"
29 )]
30 DraftShapeMismatch {
31 tokens: usize,
32 logprobs: usize,
33 distributions: usize,
34 },
35
36 #[error("target verification shape mismatch: expected {expected} rows (k+1), got {got}")]
38 TargetShapeMismatch { expected: usize, got: usize },
39
40 #[error("invalid configuration: {0}")]
43 InvalidConfig(String),
44
45 #[error("speculative decoding was invoked with an empty prefix")]
48 EmptyPrefix,
49
50 #[error("token id {token} is out of range for vocabulary size {vocab_size}")]
52 TokenOutOfRange { token: usize, vocab_size: usize },
53
54 #[error("no mass left in adjusted distribution and target fallback is also zero")]
60 DegenerateDistribution,
61
62 #[error("model error: {0}")]
64 ModelError(String),
65}
66
67pub type SpeculativeDecodingResult<T> = Result<T, SpeculativeDecodingError>;
69
70impl From<SpeculativeDecodingError> for crate::error::TrustformerError {
71 fn from(err: SpeculativeDecodingError) -> Self {
72 crate::error::TrustformerError::CompilationError(err.to_string())
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// The vocab-mismatch message must mention both sizes and the word "vocab".
    #[test]
    fn display_contains_context() {
        let rendered = SpeculativeDecodingError::VocabMismatch {
            draft: 10,
            target: 20,
        }
        .to_string();
        for needle in ["10", "20", "vocab"] {
            assert!(
                rendered.contains(needle),
                "expected {:?} to appear in {:?}",
                needle,
                rendered
            );
        }
    }

    /// Converting into the crate-wide error keeps the original message text.
    #[test]
    fn bridges_into_trustformer_error() {
        let source = SpeculativeDecodingError::InvalidConfig(String::from("k must be > 0"));
        let bridged = crate::error::TrustformerError::from(source);
        assert!(bridged.to_string().contains("k must be > 0"));
    }

    /// The empty-prefix variant renders a message naming the condition.
    #[test]
    fn empty_prefix_is_distinct() {
        let text = SpeculativeDecodingError::EmptyPrefix.to_string();
        assert!(text.contains("empty prefix"));
    }
}