1use wasm_bindgen::prelude::{wasm_bindgen, JsError};
11
12#[wasm_bindgen]
17pub struct Encoding {
18 name: &'static str,
20 bpe: &'static tiktoken::CoreBpe,
22}
23
24#[wasm_bindgen]
25impl Encoding {
26 pub fn encode(&self, text: &str) -> Vec<u32> {
31 self.bpe.encode(text)
32 }
33
34 #[wasm_bindgen(js_name = encodeWithSpecialTokens)]
39 pub fn encode_with_special_tokens(&self, text: &str) -> Vec<u32> {
40 self.bpe.encode_with_special_tokens(text)
41 }
42
43 pub fn decode(&self, tokens: &[u32]) -> String {
47 let bytes = self.bpe.decode(tokens);
48 String::from_utf8_lossy(&bytes).into_owned()
49 }
50
51 pub fn count(&self, text: &str) -> usize {
55 self.bpe.count(text)
56 }
57
58 #[wasm_bindgen(js_name = countWithSpecialTokens)]
63 pub fn count_with_special_tokens(&self, text: &str) -> usize {
64 self.bpe.count_with_special_tokens(text)
65 }
66
67 #[wasm_bindgen(js_name = vocabSize, getter)]
69 pub fn vocab_size(&self) -> usize {
70 self.bpe.vocab_size()
71 }
72
73 #[wasm_bindgen(js_name = numSpecialTokens, getter)]
75 pub fn num_special_tokens(&self) -> usize {
76 self.bpe.num_special_tokens()
77 }
78
79 #[wasm_bindgen(getter)]
81 pub fn name(&self) -> String {
82 self.name.to_string()
83 }
84}
85
86#[wasm_bindgen(js_name = listEncodings)]
90pub fn list_encodings() -> Vec<String> {
91 tiktoken::list_encodings()
92 .iter()
93 .map(|s| s.to_string())
94 .collect()
95}
96
97#[wasm_bindgen(js_name = getEncoding)]
112pub fn get_encoding(name: &str) -> Result<Encoding, JsError> {
113 let static_name = tiktoken::list_encodings()
115 .iter()
116 .find(|&&n| n == name)
117 .ok_or_else(|| JsError::new(&format!("unknown encoding: {name}")))?;
118 let bpe = tiktoken::get_encoding(name)
119 .ok_or_else(|| JsError::new(&format!("unknown encoding: {name}")))?;
120 Ok(Encoding {
121 name: static_name,
122 bpe,
123 })
124}
125
126#[wasm_bindgen(js_name = encodingForModel)]
132pub fn encoding_for_model(model: &str) -> Result<Encoding, JsError> {
133 let name = tiktoken::model_to_encoding(model)
134 .ok_or_else(|| JsError::new(&format!("unknown model: {model}")))?;
135 let bpe = tiktoken::get_encoding(name)
136 .ok_or_else(|| JsError::new(&format!("unknown encoding: {name}")))?;
137 Ok(Encoding { name, bpe })
138}
139
140#[wasm_bindgen(js_name = modelToEncoding)]
144pub fn model_to_encoding(model: &str) -> Option<String> {
145 tiktoken::model_to_encoding(model).map(|s| s.to_string())
146}
147
148#[wasm_bindgen(js_name = estimateCost)]
153pub fn estimate_cost(
154 model_id: &str,
155 input_tokens: u32,
156 output_tokens: u32,
157) -> Result<f64, JsError> {
158 tiktoken::pricing::estimate_cost(model_id, input_tokens as u64, output_tokens as u64)
159 .ok_or_else(|| JsError::new(&format!("unknown model: {model_id}")))
160}
161
162#[wasm_bindgen(js_name = getModelInfo)]
169pub fn get_model_info(model_id: &str) -> Result<ModelInfo, JsError> {
170 let model = tiktoken::pricing::get_model(model_id)
171 .ok_or_else(|| JsError::new(&format!("unknown model: {model_id}")))?;
172 Ok(convert_model(model))
173}
174
175#[wasm_bindgen(js_name = allModels)]
179pub fn all_models() -> Vec<ModelInfo> {
180 tiktoken::pricing::all_models()
181 .iter()
182 .map(convert_model)
183 .collect()
184}
185
186#[wasm_bindgen(js_name = modelsByProvider)]
191pub fn models_by_provider(provider: &str) -> Vec<ModelInfo> {
192 let Some(provider) = parse_provider(provider) else {
193 return Vec::new();
194 };
195
196 tiktoken::pricing::models_by_provider(provider)
197 .iter()
198 .map(|m| convert_model(m))
199 .collect()
200}
201
202fn convert_model(m: &tiktoken::pricing::Model) -> ModelInfo {
203 ModelInfo {
204 id: m.id,
205 provider: m.provider.to_string(),
206 input_per_1m: m.pricing.input_per_1m,
207 output_per_1m: m.pricing.output_per_1m,
208 cached_input_per_1m: m.pricing.cached_input_per_1m,
209 context_window: m.context_window,
210 max_output: m.max_output,
211 }
212}
213
214fn parse_provider(s: &str) -> Option<tiktoken::pricing::Provider> {
215 match s {
216 "OpenAI" => Some(tiktoken::pricing::Provider::OpenAI),
217 "Anthropic" => Some(tiktoken::pricing::Provider::Anthropic),
218 "Google" => Some(tiktoken::pricing::Provider::Google),
219 "Meta" => Some(tiktoken::pricing::Provider::Meta),
220 "DeepSeek" => Some(tiktoken::pricing::Provider::DeepSeek),
221 "Alibaba" => Some(tiktoken::pricing::Provider::Alibaba),
222 "Mistral" => Some(tiktoken::pricing::Provider::Mistral),
223 _ => None,
224 }
225}
226
227#[wasm_bindgen]
229#[derive(Clone)]
230pub struct ModelInfo {
231 id: &'static str,
232 provider: String,
233 input_per_1m: f64,
234 output_per_1m: f64,
235 cached_input_per_1m: Option<f64>,
236 context_window: u32,
237 max_output: u32,
238}
239
240#[wasm_bindgen]
241impl ModelInfo {
242 #[wasm_bindgen(getter)]
243 pub fn id(&self) -> String {
244 self.id.to_string()
245 }
246 #[wasm_bindgen(getter)]
247 pub fn provider(&self) -> String {
248 self.provider.clone()
249 }
250 #[wasm_bindgen(getter, js_name = inputPer1m)]
251 pub fn input_per_1m(&self) -> f64 {
252 self.input_per_1m
253 }
254 #[wasm_bindgen(getter, js_name = outputPer1m)]
255 pub fn output_per_1m(&self) -> f64 {
256 self.output_per_1m
257 }
258 #[wasm_bindgen(getter, js_name = cachedInputPer1m)]
259 pub fn cached_input_per_1m(&self) -> Option<f64> {
260 self.cached_input_per_1m
261 }
262 #[wasm_bindgen(getter, js_name = contextWindow)]
263 pub fn context_window(&self) -> u32 {
264 self.context_window
265 }
266 #[wasm_bindgen(getter, js_name = maxOutput)]
267 pub fn max_output(&self) -> u32 {
268 self.max_output
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): the error-path tests below make the bindings construct a
    // `JsError`, which is backed by wasm-bindgen's JS glue; confirm these tests
    // run under a target/runner where that works rather than plain native
    // `cargo test`.

    /// Every built-in encoding must reproduce mixed ASCII / CJK / emoji text exactly.
    #[test]
    fn all_encodings_roundtrip() {
        let sample = "hello world 你好 🚀";
        for &name in tiktoken::list_encodings() {
            let enc = get_encoding(name).unwrap();
            let roundtripped = enc.decode(&enc.encode(sample));
            assert_eq!(roundtripped, sample, "roundtrip failed for {name}");
        }
    }

    #[test]
    fn encoding_for_known_models() {
        const MODELS: [&str; 7] = [
            "gpt-4o",
            "gpt-4",
            "gpt-3.5-turbo",
            "llama-4",
            "deepseek-r1",
            "qwen3",
            "mistral-large",
        ];
        for model in MODELS {
            assert!(
                encoding_for_model(model).is_ok(),
                "encoding_for_model failed for {model}"
            );
        }
    }

    #[test]
    fn list_encodings_count() {
        assert_eq!(list_encodings().len(), 9);
    }

    #[test]
    fn all_models_count() {
        assert_eq!(all_models().len(), tiktoken::pricing::all_models().len());
    }

    #[test]
    fn models_by_valid_provider() {
        let openai_models = models_by_provider("OpenAI");
        assert!(!openai_models.is_empty());
        for info in openai_models.iter() {
            assert_eq!(info.provider, "OpenAI");
        }
    }

    #[test]
    fn models_by_invalid_provider() {
        assert!(models_by_provider("NonExistent").is_empty());
    }

    #[test]
    fn estimate_cost_known_model() {
        assert!(estimate_cost("gpt-4o", 1000, 1000).unwrap() > 0.0);
    }

    #[test]
    fn estimate_cost_unknown_model() {
        assert!(estimate_cost("fake-model", 1000, 1000).is_err());
    }

    #[test]
    fn get_model_info_known() {
        let gpt4o = get_model_info("gpt-4o").unwrap();
        assert_eq!(gpt4o.id(), "gpt-4o");
        assert_eq!(gpt4o.provider(), "OpenAI");
        assert!(gpt4o.context_window() > 0);
    }

    #[test]
    fn get_model_info_unknown() {
        assert!(get_model_info("fake-model").is_err());
    }

    #[test]
    fn unknown_encoding_error() {
        assert!(get_encoding("nonexistent").is_err());
    }

    #[test]
    fn unknown_model_encoding_error() {
        assert!(encoding_for_model("nonexistent-model-xyz").is_err());
    }

    #[test]
    fn model_to_encoding_known() {
        assert_eq!(model_to_encoding("gpt-4o").as_deref(), Some("o200k_base"));
    }

    #[test]
    fn model_to_encoding_unknown() {
        assert!(model_to_encoding("fake-model").is_none());
    }

    #[test]
    fn parse_provider_all_variants() {
        let known = [
            "OpenAI",
            "Anthropic",
            "Google",
            "Meta",
            "DeepSeek",
            "Alibaba",
            "Mistral",
        ];
        for name in known {
            assert!(parse_provider(name).is_some());
        }
        assert!(parse_provider("Unknown").is_none());
    }
}