use ahash::{AHashMap as HashMap, AHashSet as HashSet};
use derive_getters::Getters;
use std::collections::BTreeMap;
use thiserror::Error;
use wasm_bindgen::{prelude::wasm_bindgen, JsError};

#[derive(Debug, Error)]
pub enum TokenizerError {
    #[error("failed to parse vocabulary: {0}")]
    FailedToParseVocabulary(serde_json::Error),
    #[error("no matching token found")]
    NoMatchingTokenFound,
    #[error("out of range token: {0}")]
    OutOfRangeToken(u32),
}

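/// A byte-level tokenizer that greedily matches the longest known token.
///
/// `first_bytes_to_lengths` maps each possible two-byte prefix to the candidate token
/// lengths for that prefix (longest first), `bytes_to_token_index` maps a token's byte
/// pattern to its index, and `token_index_to_bytes` maps an index back to its bytes
/// for decoding.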
#[derive(Debug, Clone, Getters)]
pub struct Tokenizer {
    first_bytes_to_lengths: Vec<Box<[u16]>>,
    bytes_to_token_index: HashMap<Vec<u8>, u32>,
    token_index_to_bytes: Vec<Vec<u8>>,
}

#[derive(serde::Serialize, serde::Deserialize)]
#[serde(untagged)]
enum StrOrBytes {
    Str(String),
    Bytes(Vec<u8>),
}

impl Tokenizer {
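    /// Builds a tokenizer from a JSON vocabulary that maps token indices to
    /// either UTF-8 strings or raw byte arrays.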
    pub fn new(vocab: &str) -> Result<Self, TokenizerError> {
        let map: BTreeMap<u32, StrOrBytes> =
            serde_json::from_str(vocab).map_err(TokenizerError::FailedToParseVocabulary)?;

        let list: Vec<(Vec<u8>, u32)> = map
            .into_iter()
            .map(|(token, pattern)| {
                let pattern = match pattern {
                    StrOrBytes::Str(string) => string.into_bytes(),
                    StrOrBytes::Bytes(bytes) => bytes,
                };
                (pattern, token)
            })
            .collect();

        // For every possible two-byte prefix, start with a candidate length of 1 so
        // that single-byte tokens always remain a fallback during encoding. The table
        // needs one slot per possible u16 key, i.e. u16::MAX + 1 entries.
        let mut first_bytes_to_lengths = Vec::new();
        first_bytes_to_lengths.resize(u16::MAX as usize + 1, {
            let mut set = HashSet::new();
            set.insert(1u16);
            set
        });

        let mut token_index_to_bytes = Vec::new();
        let max_token_index = list.iter().map(|(_, index)| *index).max().unwrap_or(0) as usize;
        token_index_to_bytes.resize_with(max_token_index + 1, Vec::new);

        let mut bytes_to_token_index = HashMap::new();
        for (token_bytes, token_index) in list {
            if token_bytes.len() >= 2 {
                // Record this token's length under its two-byte prefix.
                let key = u16::from_ne_bytes([token_bytes[0], token_bytes[1]]) as usize;
                first_bytes_to_lengths[key].insert(token_bytes.len() as u16);
            }

            bytes_to_token_index.insert(token_bytes.clone(), token_index);
            token_index_to_bytes[token_index as usize] = token_bytes;
        }

        // Sort each prefix's candidate lengths in descending order (keyed by bitwise
        // NOT of the length) so that the encoder tries the longest match first.
        let first_bytes_to_lengths: Vec<Box<[_]>> = first_bytes_to_lengths
            .into_iter()
            .map(|inner| {
                let mut inner: Vec<_> = inner.into_iter().collect();
                inner.sort_unstable_by_key(|l| !*l);
                inner.into_boxed_slice()
            })
            .collect();

        Ok(Tokenizer {
            first_bytes_to_lengths,
            bytes_to_token_index,
            token_index_to_bytes,
        })
    }

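    /// Encodes `input` into a newly allocated vector of token indices.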
    pub fn encode(&self, input: &[u8]) -> Result<Vec<u32>, TokenizerError> {
        let mut output = Vec::new();
        self.encode_into(input, &mut output)?;
        Ok(output)
    }

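    /// Decodes `tokens` into a newly allocated byte vector.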
    pub fn decode(&self, tokens: &[u32]) -> Result<Vec<u8>, TokenizerError> {
        let mut output = Vec::with_capacity(tokens.len());
        self.decode_into(tokens, &mut output)?;
        Ok(output)
    }
}

impl Tokenizer {
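    /// Encodes `input` by appending token indices to `output`. At each position the
    /// longest matching token is taken; if nothing matches, the encoder bails out
    /// with `TokenizerError::NoMatchingTokenFound`.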
    pub fn encode_into(
        &self,
        mut input: &[u8],
        output: &mut Vec<u32>,
    ) -> Result<(), TokenizerError> {
        'next_token: while !input.is_empty() {
            // Look up the candidate token lengths for the current two-byte prefix;
            // with only one byte left, a single-byte token is the only option.
            let lengths = if input.len() >= 2 {
                let key = u16::from_ne_bytes([input[0], input[1]]) as usize;
                &self.first_bytes_to_lengths[key][..]
            } else {
                &[1][..]
            };

            // Candidate lengths are ordered longest first, so the first hit is the
            // longest token that matches at the current position.
            for &length in lengths {
                let length = length as usize;
                if length > input.len() {
                    continue;
                }

                if let Some(&token_index) = self.bytes_to_token_index.get(&input[..length]) {
                    output.push(token_index);
                    input = &input[length..];
                    continue 'next_token;
                }
            }

            return Err(TokenizerError::NoMatchingTokenFound);
        }

        Ok(())
    }

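    /// Decodes `tokens` by appending each token's byte pattern to `output`; an index
    /// outside the vocabulary yields `TokenizerError::OutOfRangeToken`.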
    pub fn decode_into(&self, tokens: &[u32], output: &mut Vec<u8>) -> Result<(), TokenizerError> {
        for &token in tokens {
            let bytes = self
                .token_index_to_bytes
                .get(token as usize)
                .ok_or(TokenizerError::OutOfRangeToken(token))?;

            output.extend_from_slice(bytes);
        }

        Ok(())
    }
}

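/// JavaScript-facing wrapper around `Tokenizer`, exported to wasm as `Tokenizer`.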
#[wasm_bindgen(js_name = Tokenizer)]
pub struct JsTokenizer(Tokenizer);

#[wasm_bindgen(js_class = Tokenizer)]
impl JsTokenizer {
    #[wasm_bindgen(constructor)]
    pub fn new(vocab: &str) -> Result<Self, JsError> {
        Ok(Self(Tokenizer::new(vocab)?))
    }

    pub fn encode(&self, input: &[u8]) -> Result<Vec<u32>, JsError> {
        Ok(self.0.encode(input)?)
    }

    pub fn decode(&self, tokens: &[u32]) -> Result<Vec<u8>, JsError> {
        Ok(self.0.decode(tokens)?)
    }
}
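
// A minimal round-trip sanity check. The tiny vocabulary below is purely
// illustrative and not taken from any real model's token set.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_prefers_longest_match_and_round_trips() {
        let vocab = r#"{"0": "a", "1": "b", "2": "ab"}"#;
        let tokenizer = Tokenizer::new(vocab).unwrap();

        // "ab" is matched as a single token rather than as "a" + "b".
        assert_eq!(tokenizer.encode(b"ab").unwrap(), vec![2]);
        // No two-byte token starts with "ba", so it falls back to single bytes.
        assert_eq!(tokenizer.encode(b"ba").unwrap(), vec![1, 0]);

        // Decoding concatenates the byte patterns of the given tokens.
        assert_eq!(tokenizer.decode(&[2, 1]).unwrap(), b"abb".to_vec());
    }
}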