Skip to main content

tensorlogic_trustformers/speculative_decoding/
error.rs

1//! Error taxonomy for the speculative decoder.
2//!
3//! Kept local (not merged into [`crate::error::TrustformerError`]) so that
4//! decoder-internal diagnostics do not pollute the public transformer error
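//! enum.  A [`From`] bridge converts errors into the crate-wide type when the
//! decoder is invoked from higher-level code; the conversion keeps only the
//! error's `Display` message, so structured fields are not recoverable there.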
7
8use thiserror::Error;
9
10/// Errors that can be raised during speculative decoding.
11#[derive(Debug, Clone, Error, PartialEq)]
12pub enum SpeculativeDecodingError {
13    /// A draft and target model disagreed on `vocab_size`.
14    #[error(
15        "vocab size mismatch between draft ({draft}) and target ({target}) \
16         speculative-decoding models"
17    )]
18    VocabMismatch { draft: usize, target: usize },
19
20    /// A distribution row had the wrong width.
21    #[error("distribution row width mismatch: expected {expected}, got {got}")]
22    DistributionWidthMismatch { expected: usize, got: usize },
23
24    /// The draft model returned a proposal whose `tokens` / `token_logprobs`
25    /// / `distributions` vectors disagreed in length.
26    #[error(
27        "draft proposal shape mismatch: tokens={tokens}, token_logprobs={logprobs}, \
28         distributions={distributions}"
29    )]
30    DraftShapeMismatch {
31        tokens: usize,
32        logprobs: usize,
33        distributions: usize,
34    },
35
36    /// The target model returned the wrong number of distribution rows.
37    #[error("target verification shape mismatch: expected {expected} rows (k+1), got {got}")]
38    TargetShapeMismatch { expected: usize, got: usize },
39
40    /// The caller asked for `k == 0` draft tokens (or similar degenerate
41    /// configuration).
42    #[error("invalid configuration: {0}")]
43    InvalidConfig(String),
44
45    /// The caller supplied an empty prefix to `generate` even though the
46    /// configured models require at least one bos/sos token.
47    #[error("speculative decoding was invoked with an empty prefix")]
48    EmptyPrefix,
49
50    /// A token id produced by a model was outside the configured vocabulary.
51    #[error("token id {token} is out of range for vocabulary size {vocab_size}")]
52    TokenOutOfRange { token: usize, vocab_size: usize },
53
54    /// A probability row collapsed to zero mass — typically because
55    /// `max(0, p_target - p_draft)` was identically zero on every index, which
56    /// happens iff `p_target == p_draft` everywhere.  We fall back to the raw
57    /// target distribution in that case; this variant is reserved for cases
58    /// where even that is degenerate.
59    #[error("no mass left in adjusted distribution and target fallback is also zero")]
60    DegenerateDistribution,
61
62    /// A model implementation returned a descriptive error.
63    #[error("model error: {0}")]
64    ModelError(String),
65}
66
67/// Result alias used across the speculative-decoding module.
68pub type SpeculativeDecodingResult<T> = Result<T, SpeculativeDecodingError>;
69
70impl From<SpeculativeDecodingError> for crate::error::TrustformerError {
71    fn from(err: SpeculativeDecodingError) -> Self {
72        crate::error::TrustformerError::CompilationError(err.to_string())
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display_contains_context() {
        let err = SpeculativeDecodingError::VocabMismatch {
            draft: 10,
            target: 20,
        };
        let msg = err.to_string();
        assert!(msg.contains("10"));
        assert!(msg.contains("20"));
        assert!(msg.contains("vocab"));
    }

    #[test]
    fn bridges_into_trustformer_error() {
        let err = SpeculativeDecodingError::InvalidConfig("k must be > 0".into());
        let bridged: crate::error::TrustformerError = err.into();
        assert!(bridged.to_string().contains("k must be > 0"));
    }

    #[test]
    fn empty_prefix_is_distinct() {
        let err = SpeculativeDecodingError::EmptyPrefix;
        assert!(err.to_string().contains("empty prefix"));
    }

    /// The `\` line continuation in the `VocabMismatch` message must strip the
    /// next line's indentation and leave exactly one separating space.
    #[test]
    fn vocab_mismatch_message_is_single_spaced() {
        let err = SpeculativeDecodingError::VocabMismatch {
            draft: 10,
            target: 20,
        };
        assert_eq!(
            err.to_string(),
            "vocab size mismatch between draft (10) and target (20) speculative-decoding models"
        );
    }

    /// The draft-shape message is also split with `\`; every length must appear.
    #[test]
    fn draft_shape_mismatch_reports_every_length() {
        let err = SpeculativeDecodingError::DraftShapeMismatch {
            tokens: 3,
            logprobs: 2,
            distributions: 4,
        };
        assert_eq!(
            err.to_string(),
            "draft proposal shape mismatch: tokens=3, token_logprobs=2, distributions=4"
        );
    }

    #[test]
    fn target_shape_mismatch_mentions_k_plus_one() {
        let err = SpeculativeDecodingError::TargetShapeMismatch {
            expected: 5,
            got: 4,
        };
        assert_eq!(
            err.to_string(),
            "target verification shape mismatch: expected 5 rows (k+1), got 4"
        );
    }

    #[test]
    fn token_out_of_range_reports_token_and_vocab() {
        let err = SpeculativeDecodingError::TokenOutOfRange {
            token: 42,
            vocab_size: 10,
        };
        assert_eq!(
            err.to_string(),
            "token id 42 is out of range for vocabulary size 10"
        );
    }
}