wordchipper/decoders/
slab_index_decoder.rs1use core::marker::PhantomData;
4
5use crate::{
6 TokenType,
7 WCResult,
8 alloc::{
9 sync::Arc,
10 vec,
11 vec::Vec,
12 },
13 decoders::{
14 DecodeResult,
15 TokenDecoder,
16 },
17 vocab::{
18 DEFAULT_BYTE_PER_TOKEN_RATIO,
19 TokenSpanMap,
20 UnifiedTokenVocab,
21 },
22};
23
24#[derive(Clone)]
34pub struct SlabIndexDecoder<T: TokenType> {
35 index: Vec<(usize, usize)>,
36 slab: Vec<u8>,
37
38 expected_bytes_per_token: f32,
39 _marker: PhantomData<T>,
40}
41
42impl<T: TokenType> SlabIndexDecoder<T> {
43 pub fn from_vocab(vocab: Arc<UnifiedTokenVocab<T>>) -> Self {
49 Self::new(vocab.unified_dictionary())
50 }
51
52 pub fn new(token_spans: TokenSpanMap<T>) -> Self {
57 let max_token = token_spans.keys().max().unwrap().to_usize().unwrap();
58 let mut index = vec![(0, 0); max_token + 1];
59
60 let total_bytes = token_spans.values().map(|span| span.len()).sum();
61 let mut slab = Vec::with_capacity(total_bytes);
62
63 let mut tokens: Vec<T> = token_spans.keys().copied().collect();
64 tokens.sort_unstable();
65
66 for token in tokens {
67 let idx = token.to_usize().unwrap();
68 let span = token_spans.get(&token).unwrap();
69 index[idx] = (slab.len(), slab.len() + span.len());
70 slab.extend_from_slice(span);
71 }
72
73 Self {
74 index,
75 slab,
76 expected_bytes_per_token: DEFAULT_BYTE_PER_TOKEN_RATIO,
77 _marker: PhantomData,
78 }
79 }
80
81 pub fn expected_bytes_per_token(&self) -> f32 {
83 self.expected_bytes_per_token
84 }
85
86 pub fn with_expected_bytes_per_token(
91 mut self,
92 expected: f32,
93 ) -> Self {
94 self.expected_bytes_per_token = expected;
95 self
96 }
97
98 pub fn predicted_byte_buffer_size(
100 &self,
101 tokens: &[T],
102 ) -> usize {
103 (tokens.len() as f32 * 1.1 * self.expected_bytes_per_token) as usize
104 }
105
106 pub fn lookup_span(
108 &self,
109 token: &T,
110 ) -> Option<&[u8]> {
111 let idx = token.to_usize().unwrap();
112 if idx >= self.index.len() {
113 return None;
114 }
115 let (start, end) = &self.index[idx];
116 if end > start {
117 Some(&self.slab[*start..*end])
118 } else {
119 None
120 }
121 }
122}
123
124impl<T: TokenType> TokenDecoder<T> for SlabIndexDecoder<T> {
125 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, tokens)))]
126 fn try_decode_to_bytes(
127 &self,
128 tokens: &[T],
129 ) -> WCResult<DecodeResult<Vec<u8>>> {
130 let capacity = self.predicted_byte_buffer_size(tokens);
131 let mut value = Vec::with_capacity(capacity);
132
133 let mut consumed = 0;
134 for t in tokens {
135 if let Some(w) = self.lookup_span(t) {
136 value.extend(w);
137 consumed += 1;
138 } else {
139 break;
140 }
141 }
142 Ok(DecodeResult::new(value, Some(tokens.len() - consumed)))
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::sync::Arc,
        decoders::utility::testing::common_decoder_tests,
        pretrained::openai::OA_CL100K_BASE_PATTERN,
        spanners::TextSpanningConfig,
        vocab::{
            UnifiedTokenVocab,
            utility::testing::{
                build_test_shift_byte_vocab,
                build_test_vocab,
            },
        },
    };

    /// Builds a small test vocab, checks the bytes-per-token override, then
    /// runs the shared decoder test suite against the slab decoder.
    #[test]
    fn test_decoder() {
        type T = u16;

        let byte_vocab = build_test_shift_byte_vocab(10);
        let spanning = TextSpanningConfig::from_pattern(OA_CL100K_BASE_PATTERN);
        let vocab: Arc<UnifiedTokenVocab<T>> = build_test_vocab(byte_vocab, spanning).into();

        let ratio = 7.5;
        let decoder =
            SlabIndexDecoder::from_vocab(Arc::clone(&vocab)).with_expected_bytes_per_token(ratio);
        assert_eq!(decoder.expected_bytes_per_token(), ratio);

        common_decoder_tests(vocab, Arc::new(decoder));
    }
}