web_rwkv/tokenizer.rs

use ahash::{AHashMap as HashMap, AHashSet as HashSet};
use derive_getters::Getters;
use std::collections::BTreeMap;
use thiserror::Error;
use wasm_bindgen::{prelude::wasm_bindgen, JsError};

#[derive(Debug, Error)]
pub enum TokenizerError {
    #[error("failed to parse vocabulary: {0}")]
    FailedToParseVocabulary(serde_json::Error),
    #[error("no matching token found")]
    NoMatchingTokenFound,
    #[error("out of range token: {0}")]
    OutOfRangeToken(u32),
}

#[derive(Debug, Clone, Getters)]
pub struct Tokenizer {
    /// Maps the first two bytes of a token (read as a native-endian `u16`) to the
    /// candidate token lengths starting with those bytes, sorted longest first.
    first_bytes_to_lengths: Vec<Box<[u16]>>,
    /// Maps a token's byte pattern to its token index.
    bytes_to_token_index: HashMap<Vec<u8>, u32>,
    /// Maps a token index back to its byte pattern.
    token_index_to_bytes: Vec<Vec<u8>>,
}

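// A minimal sketch of the vocabulary format this enum is meant to parse; the concrete
// entries below are illustrative, not taken from a real vocabulary file. The JSON maps
// token indices to either a UTF-8 string or an explicit byte array, and the untagged
// representation lets both forms deserialize into the same enum:
//
//     { "1": "hello", "2": [255, 254] }
//
// The string form covers tokens that are valid UTF-8; the byte-array form covers raw
// byte sequences that are not.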
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(untagged)]
enum StrOrBytes {
    Str(String),
    Bytes(Vec<u8>),
}

impl Tokenizer {
    pub fn new(vocab: &str) -> Result<Self, TokenizerError> {
        // The vocabulary is a JSON map from token index to token pattern.
        let map: BTreeMap<u32, StrOrBytes> =
            serde_json::from_str(vocab).map_err(TokenizerError::FailedToParseVocabulary)?;

        // Normalize every pattern to raw bytes, keeping its token index alongside it.
        let list: Vec<(Vec<u8>, u32)> = map
            .into_iter()
            .map(|(token, pattern)| {
                let pattern = match pattern {
                    StrOrBytes::Str(string) => string.into_bytes(),
                    StrOrBytes::Bytes(bytes) => bytes,
                };
                (pattern, token)
            })
            .collect();

        // Both tables are indexed by a token's first two bytes interpreted as a
        // native-endian u16, so they need one slot for every possible u16 value.
        let mut first_bytes_to_len = Vec::new();
        first_bytes_to_len.resize(u16::MAX as usize + 1, 2);

        let mut first_bytes_to_lengths = Vec::new();
        first_bytes_to_lengths.resize(u16::MAX as usize + 1, {
            // Length 1 is always a candidate, so single-byte tokens stay reachable.
            let mut set = HashSet::new();
            set.insert(1);
            set
        });

        let mut token_index_to_bytes = Vec::new();
        // Find the max token index to determine the size of the vector.
        let max_token_index = list.iter().map(|(_, index)| *index).max().unwrap_or(0) as usize;
        token_index_to_bytes.resize_with(max_token_index + 1, Vec::new);

        let mut bytes_to_token_index = HashMap::new();
        for (token_bytes, token_index) in list {
            if token_bytes.len() >= 2 {
                let key = u16::from_ne_bytes([token_bytes[0], token_bytes[1]]) as usize;
                let max_length = &mut first_bytes_to_len[key];
                if token_bytes.len() > *max_length {
                    *max_length = token_bytes.len();
                }

                // Record this token's length as a candidate for its two-byte prefix.
                first_bytes_to_lengths[key].insert(token_bytes.len() as u16);
            }

            bytes_to_token_index.insert(token_bytes.clone(), token_index);
            token_index_to_bytes[token_index as usize] = token_bytes;
        }

        // Sort the candidate lengths for each prefix in descending order, so the
        // encoder always tries the longest possible match first.
        let first_bytes_to_lengths: Vec<Box<[_]>> = first_bytes_to_lengths
            .into_iter()
            .map(|inner| {
                let mut inner: Vec<_> = inner.into_iter().collect();
                inner.sort_unstable_by_key(|l| !*l);
                inner.into_boxed_slice()
            })
            .collect();

        Ok(Tokenizer {
            first_bytes_to_lengths,
            bytes_to_token_index,
            token_index_to_bytes,
        })
    }

    pub fn encode(&self, input: &[u8]) -> Result<Vec<u32>, TokenizerError> {
        let mut output = Vec::new();
        self.encode_into(input, &mut output)?;
        Ok(output)
    }

    pub fn decode(&self, tokens: &[u32]) -> Result<Vec<u8>, TokenizerError> {
        let mut output = Vec::with_capacity(tokens.len());
        self.decode_into(tokens, &mut output)?;
        Ok(output)
    }
}

impl Tokenizer {
    /// Encodes `input`, appending token indices to `output`, using greedy
    /// longest-match tokenization.
    pub fn encode_into(
        &self,
        mut input: &[u8],
        output: &mut Vec<u32>,
    ) -> Result<(), TokenizerError> {
        'next_token: while !input.is_empty() {
            // Candidate token lengths for the current position, longest first. With
            // fewer than two bytes left, only a single-byte token can match.
            let lengths = if input.len() >= 2 {
                let key = u16::from_ne_bytes([input[0], input[1]]) as usize;
                &self.first_bytes_to_lengths[key][..]
            } else {
                &[1][..]
            };

            for &length in lengths {
                let length = length as usize;
                if length > input.len() {
                    continue;
                }

                if let Some(&token_index) = self.bytes_to_token_index.get(&input[..length]) {
                    output.push(token_index);
                    input = &input[length..];
                    continue 'next_token;
                }
            }

            return Err(TokenizerError::NoMatchingTokenFound);
        }

        Ok(())
    }

    pub fn decode_into(&self, tokens: &[u32], output: &mut Vec<u8>) -> Result<(), TokenizerError> {
        for &token in tokens {
            // Look up each token's byte pattern and append it to the output.
            let bytes = self
                .token_index_to_bytes
                .get(token as usize)
                .ok_or(TokenizerError::OutOfRangeToken(token))?;

            output.extend_from_slice(bytes);
        }

        Ok(())
    }
}

// Thin wasm-bindgen wrapper exposing the tokenizer to JavaScript under the name `Tokenizer`.
#[wasm_bindgen(js_name = Tokenizer)]
pub struct JsTokenizer(Tokenizer);

#[wasm_bindgen(js_class = Tokenizer)]
impl JsTokenizer {
    #[wasm_bindgen(constructor)]
    pub fn new(vocab: &str) -> Result<Self, JsError> {
        Ok(Self(Tokenizer::new(vocab)?))
    }

    pub fn encode(&self, input: &[u8]) -> Result<Vec<u32>, JsError> {
        Ok(self.0.encode(input)?)
    }

    pub fn decode(&self, tokens: &[u32]) -> Result<Vec<u8>, JsError> {
        Ok(self.0.decode(tokens)?)
    }
}
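
// A minimal round-trip sketch, not part of the original file: it exercises `new`,
// `encode`, and `decode` against a tiny, made-up vocabulary. The token indices and
// strings below are illustrative assumptions, not entries from a real RWKV vocabulary.
#[cfg(test)]
mod tests {
    use super::Tokenizer;

    #[test]
    fn roundtrip_with_toy_vocabulary() {
        // Single-character tokens are included so that every input byte stays
        // encodable when no longer match is available.
        let vocab = r#"{ "1": "a", "2": "b", "3": "c", "4": "ab", "5": "abc" }"#;
        let tokenizer = Tokenizer::new(vocab).unwrap();

        // Greedy longest-match: "abc" is consumed as the single token 5, while "abca"
        // falls back to "abc" followed by "a".
        assert_eq!(tokenizer.encode(b"abc").unwrap(), vec![5]);
        assert_eq!(tokenizer.encode(b"abca").unwrap(), vec![5, 1]);

        // Decoding concatenates the byte patterns of the tokens.
        let tokens = tokenizer.encode(b"abcab").unwrap();
        assert_eq!(tokenizer.decode(&tokens).unwrap(), b"abcab".to_vec());
    }
}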