Skip to main content

wordchipper/decoders/
slab_index_decoder.rs

1//! # Slab Index Decoder
2
3use core::marker::PhantomData;
4
5use crate::{
6    TokenType,
7    WCResult,
8    alloc::{
9        sync::Arc,
10        vec,
11        vec::Vec,
12    },
13    decoders::{
14        DecodeResult,
15        TokenDecoder,
16    },
17    vocab::{
18        DEFAULT_BYTE_PER_TOKEN_RATIO,
19        TokenSpanMap,
20        UnifiedTokenVocab,
21    },
22};
23
24/// A [`TokenDecoder<T>`] which keeps a dense array index into a shared slab.
25///
26/// It is expected that all tokens (single-byte and multibyte words,
27/// and special tokens) are stored in the slab index.
28///
29/// ## Style Hints
30///
/// When there is no local ambiguity, instance names should prefer `decoder`;
/// and expand to `slab_index_decoder` when there is ambiguity.
33#[derive(Clone)]
34pub struct SlabIndexDecoder<T: TokenType> {
35    index: Vec<(usize, usize)>,
36    slab: Vec<u8>,
37
38    expected_bytes_per_token: f32,
39    _marker: PhantomData<T>,
40}
41
42impl<T: TokenType> SlabIndexDecoder<T> {
43    /// Build a [`SlabIndexDecoder`] from this [`UnifiedTokenVocab`].
44    ///
45    /// ## Arguments
46    /// * `unified_vocab` - The unified token vocabulary to build the decoder
47    ///   from.
48    pub fn from_vocab(vocab: Arc<UnifiedTokenVocab<T>>) -> Self {
49        Self::new(vocab.unified_dictionary())
50    }
51
52    /// Creates a new Decoder.
53    ///
54    /// ## Arguments
55    /// * `token_spans` - The token to word mapping.
56    pub fn new(token_spans: TokenSpanMap<T>) -> Self {
57        let max_token = token_spans.keys().max().unwrap().to_usize().unwrap();
58        let mut index = vec![(0, 0); max_token + 1];
59
60        let total_bytes = token_spans.values().map(|span| span.len()).sum();
61        let mut slab = Vec::with_capacity(total_bytes);
62
63        let mut tokens: Vec<T> = token_spans.keys().copied().collect();
64        tokens.sort_unstable();
65
66        for token in tokens {
67            let idx = token.to_usize().unwrap();
68            let span = token_spans.get(&token).unwrap();
69            index[idx] = (slab.len(), slab.len() + span.len());
70            slab.extend_from_slice(span);
71        }
72
73        Self {
74            index,
75            slab,
76            expected_bytes_per_token: DEFAULT_BYTE_PER_TOKEN_RATIO,
77            _marker: PhantomData,
78        }
79    }
80
81    /// Get the expected bytes per token.
82    pub fn expected_bytes_per_token(&self) -> f32 {
83        self.expected_bytes_per_token
84    }
85
86    /// Sets the expected bytes per token.
87    ///
88    /// This is used to bias the capacity of the output buffer in
89    /// `try_decode_to_bytes`.
90    pub fn with_expected_bytes_per_token(
91        mut self,
92        expected: f32,
93    ) -> Self {
94        self.expected_bytes_per_token = expected;
95        self
96    }
97
98    /// Predict the capacity needed when pre-allocating output buffers.
99    pub fn predicted_byte_buffer_size(
100        &self,
101        tokens: &[T],
102    ) -> usize {
103        (tokens.len() as f32 * 1.1 * self.expected_bytes_per_token) as usize
104    }
105
106    /// Lookup a token.
107    pub fn lookup_span(
108        &self,
109        token: &T,
110    ) -> Option<&[u8]> {
111        let idx = token.to_usize().unwrap();
112        if idx >= self.index.len() {
113            return None;
114        }
115        let (start, end) = &self.index[idx];
116        if end > start {
117            Some(&self.slab[*start..*end])
118        } else {
119            None
120        }
121    }
122}
123
124impl<T: TokenType> TokenDecoder<T> for SlabIndexDecoder<T> {
125    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, tokens)))]
126    fn try_decode_to_bytes(
127        &self,
128        tokens: &[T],
129    ) -> WCResult<DecodeResult<Vec<u8>>> {
130        let capacity = self.predicted_byte_buffer_size(tokens);
131        let mut value = Vec::with_capacity(capacity);
132
133        let mut consumed = 0;
134        for t in tokens {
135            if let Some(w) = self.lookup_span(t) {
136                value.extend(w);
137                consumed += 1;
138            } else {
139                break;
140            }
141        }
142        Ok(DecodeResult::new(value, Some(tokens.len() - consumed)))
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::sync::Arc,
        decoders::utility::testing::common_decoder_tests,
        pretrained::openai::OA_CL100K_BASE_PATTERN,
        spanners::TextSpanningConfig,
        vocab::{
            UnifiedTokenVocab,
            utility::testing::{
                build_test_shift_byte_vocab,
                build_test_vocab,
            },
        },
    };

    /// Builds a test vocabulary, checks the bytes-per-token builder
    /// round-trips, and runs the shared decoder test-suite.
    #[test]
    fn test_decoder() {
        type T = u16;

        // Constructed in the same order as the original nested call.
        let byte_vocab = build_test_shift_byte_vocab(10);
        let spanning = TextSpanningConfig::from_pattern(OA_CL100K_BASE_PATTERN);
        let vocab: Arc<UnifiedTokenVocab<T>> = build_test_vocab(byte_vocab, spanning).into();

        let decoder = SlabIndexDecoder::from_vocab(vocab.clone());
        let decoder = decoder.with_expected_bytes_per_token(7.5);

        assert_eq!(decoder.expected_bytes_per_token(), 7.5);

        common_decoder_tests(vocab, Arc::new(decoder));
    }
}