sapient_tokenizers/
tokenizer.rs1use std::path::Path;
4
5use anyhow::{Context, Result};
6use tokenizers::Tokenizer;
7
/// Options controlling how [`SapientTokenizer::encode`] post-processes token ids.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenizerOptions {
    /// Request that the BOS id be present at the start of every encoding.
    pub add_bos: bool,
    /// Request that the primary EOS id be appended after encoding.
    pub add_eos: bool,
    /// Truncate encodings to at most this many ids; `0` disables truncation.
    pub max_length: usize,
}
19
20impl Default for TokenizerOptions {
21 fn default() -> Self {
22 Self {
23 add_bos: true,
24 add_eos: false,
25 max_length: 0,
26 }
27 }
28}
29
/// Token strings treated as end-of-sequence markers, in priority order.
///
/// The constructors look each candidate up in the vocabulary. The first one found
/// becomes `SapientTokenizer::eos_id`. Every one found (deduplicated) is stored in
/// `SapientTokenizer::eos_ids` and recognised by `SapientTokenizer::is_eos`.
/// Because of that, the order of this list is significant.
const EOS_CANDIDATES: &[&str] = &[
    "</s>",
    "<eos>",
    "<|endoftext|>",
    "<|end_of_text|>",
    "<|eot_id|>",
    "<|im_end|>",
    "<end_of_turn>",
    // NOTE(review): this looks like a placeholder/redacted token name; confirm it
    // matches a real vocabulary entry for some supported model.
    "<|redacted_EOS|>",
];
48
/// Wrapper around a `tokenizers::Tokenizer` that resolves common BOS/EOS/PAD
/// special-token ids and applies [`TokenizerOptions`] when encoding.
///
/// NOTE(review): the id fields are `pub`, so callers can modify them; nothing
/// keeps `eos_id` and `eos_ids` consistent after construction.
pub struct SapientTokenizer {
    /// The underlying Hugging Face tokenizer.
    inner: Tokenizer,
    /// Id of the first BOS candidate found in the vocabulary by the constructor, if any.
    pub bos_id: Option<u32>,
    /// Primary EOS id: the constructors set this to the first entry of `eos_ids`.
    /// This is the id `encode` appends when `add_eos` is enabled.
    pub eos_id: Option<u32>,
    /// Ids of every `EOS_CANDIDATES` entry present in the vocabulary, deduplicated,
    /// in candidate order; `is_eos` checks membership in this list.
    pub eos_ids: Vec<u32>,
    /// Id of the first PAD candidate found in the vocabulary by the constructor, if any.
    pub pad_id: Option<u32>,
    /// Options applied by `encode`.
    opts: TokenizerOptions,
}
58
59impl SapientTokenizer {
60 pub fn from_file(path: &Path, opts: TokenizerOptions) -> Result<Self> {
62 match Tokenizer::from_file(path) {
63 Ok(inner) => Self::from_inner(inner, opts),
64 Err(first_err) => {
65 let normalized = normalize_tokenizer_json(path).with_context(|| {
66 format!("Failed to load tokenizer and could not normalize it: {first_err}")
67 })?;
68 let inner = Tokenizer::from_bytes(&normalized)
69 .map_err(|e| anyhow::anyhow!("Failed to load normalized tokenizer: {e}"))?;
70 Self::from_inner(inner, opts)
71 }
72 }
73 }
74
75 pub fn from_pretrained(model_id: &str) -> Result<Self> {
77 let inner = Tokenizer::from_pretrained(model_id, None)
78 .map_err(|e| anyhow::anyhow!("Failed to load tokenizer for '{model_id}': {e}"))?;
79
80 let bos_id = Self::special_token_id(&inner, &["<s>", "<bos>", "<|begin_of_text|>"]);
81 let eos_ids = Self::all_special_token_ids(&inner, EOS_CANDIDATES);
82 let eos_id = eos_ids.first().copied();
83 let pad_id = Self::special_token_id(&inner, &["<pad>"]);
84
85 Ok(Self {
86 inner,
87 bos_id,
88 eos_id,
89 eos_ids,
90 pad_id,
91 opts: TokenizerOptions::default(),
92 })
93 }
94
95 pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
97 let encoding = self
98 .inner
99 .encode(text, true)
100 .map_err(|e| anyhow::anyhow!("Tokenizer encode error: {e}"))?;
101
102 let mut ids = encoding.get_ids().to_vec();
103
104 if self.opts.add_bos {
106 if let Some(bos) = self.bos_id {
107 if ids.first() != Some(&bos) {
108 ids.insert(0, bos);
109 }
110 }
111 }
112
113 if self.opts.add_eos {
115 if let Some(eos) = self.eos_id {
116 ids.push(eos);
117 }
118 }
119
120 if self.opts.max_length > 0 && ids.len() > self.opts.max_length {
122 ids.truncate(self.opts.max_length);
123 }
124
125 Ok(ids)
126 }
127
128 pub fn decode(&self, ids: &[u32], skip_special: bool) -> Result<String> {
130 self.inner
131 .decode(ids, skip_special)
132 .map_err(|e| anyhow::anyhow!("Tokenizer decode error: {e}"))
133 }
134
135 pub fn decode_token(&self, id: u32) -> Result<String> {
137 self.decode(&[id], true)
138 }
139
140 pub fn vocab_size(&self) -> usize {
142 self.inner.get_vocab_size(true)
143 }
144
145 pub fn is_eos(&self, id: u32) -> bool {
149 self.eos_ids.contains(&id)
150 }
151
152 fn special_token_id(tok: &Tokenizer, candidates: &[&str]) -> Option<u32> {
153 for c in candidates {
154 if let Some(id) = tok.token_to_id(c) {
155 return Some(id);
156 }
157 }
158 None
159 }
160
161 fn all_special_token_ids(tok: &Tokenizer, candidates: &[&str]) -> Vec<u32> {
163 let mut ids = Vec::new();
164 for c in candidates {
165 if let Some(id) = tok.token_to_id(c) {
166 if !ids.contains(&id) {
167 ids.push(id);
168 }
169 }
170 }
171 ids
172 }
173
174 fn from_inner(inner: Tokenizer, opts: TokenizerOptions) -> Result<Self> {
175 let bos_id =
176 Self::special_token_id(&inner, &["<s>", "<bos>", "<|begin_of_text|>", "[BOS]"]);
177 let eos_ids = Self::all_special_token_ids(&inner, EOS_CANDIDATES);
178 let eos_id = eos_ids.first().copied();
179 let pad_id =
180 Self::special_token_id(&inner, &["<pad>", "<|finetune_right_pad_id|>", "[PAD]"]);
181
182 Ok(Self {
183 inner,
184 bos_id,
185 eos_id,
186 eos_ids,
187 pad_id,
188 opts,
189 })
190 }
191}
192
193fn normalize_tokenizer_json(path: &Path) -> Result<Vec<u8>> {
196 let text = std::fs::read_to_string(path)?;
197 let mut value: serde_json::Value = serde_json::from_str(&text)?;
198
199 let Some(model) = value.get_mut("model") else {
200 anyhow::bail!("tokenizer.json missing model section");
201 };
202 let Some(merges) = model.get_mut("merges") else {
203 anyhow::bail!("tokenizer.json missing BPE merges");
204 };
205 let Some(arr) = merges.as_array_mut() else {
206 anyhow::bail!("tokenizer merges are not an array");
207 };
208 if arr.is_empty() {
209 return Ok(text.into_bytes());
210 }
211 if !arr[0].is_array() {
212 anyhow::bail!("tokenizer merges already use string format");
213 }
214
215 let normalized: Vec<String> = arr
216 .iter()
217 .filter_map(|entry| {
218 let pair = entry.as_array()?;
219 if pair.len() != 2 {
220 return None;
221 }
222 Some(format!("{} {}", pair[0].as_str()?, pair[1].as_str()?))
223 })
224 .collect();
225
226 if normalized.len() != arr.len() {
227 anyhow::bail!("failed to normalize all BPE merges");
228 }
229
230 *merges = serde_json::Value::Array(
231 normalized
232 .into_iter()
233 .map(serde_json::Value::String)
234 .collect(),
235 );
236
237 Ok(serde_json::to_vec(&value)?)
238}
239
// Unit tests for this module.
#[cfg(test)]
mod tests {
    // NOTE(review): empty. Candidates that need no model download:
    // - `normalize_tokenizer_json` on a small temporary `tokenizer.json`
    //   (pair-array merges, string merges, empty merges, missing sections);
    // - `TokenizerOptions::default()`.
}