three_dcf_core/stats.rs

use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};

use base64::{engine::general_purpose, Engine as _};
use rustc_hash::FxHashMap;
use serde::Deserialize;

use crate::decoder::Decoder;
use crate::document::Document;
use crate::error::{DcfError, Result};
use crate::serializer::TextSerializer;

/// Identifies which byte-pair-encoding vocabulary to use for token counts.
#[derive(Debug, Clone)]
pub enum TokenizerKind {
    /// OpenAI's `cl100k_base` encoding (GPT-3.5/GPT-4 family).
    Cl100k,
    /// The GPT-2 era `r50k_base` encoding.
    Gpt2,
    /// OpenAI's `o200k_base` encoding (GPT-4o family).
    O200k,
    /// Approximated with `cl100k_base`; see `anthropic_base` below.
    Anthropic,
    /// A user-supplied tokenizer spec loaded from a JSON file.
    Custom(PathBuf),
}

/// Token-count comparison between a document's raw decoded text and its
/// 3DCF textual serialization.
#[derive(Debug, Clone, PartialEq)]
pub struct Stats {
    /// Tokens in the raw decoded text.
    pub tokens_raw: usize,
    /// Tokens in the 3DCF textual serialization.
    pub tokens_3dcf: usize,
    /// Total number of cells in the document.
    pub cells: usize,
    /// Number of distinct payloads in the document dictionary.
    pub unique_payloads: usize,
    /// `tokens_raw / tokens_3dcf`; 0.0 when the 3DCF form has no tokens.
    pub savings_ratio: f32,
}

impl Stats {
    /// Builds the requested tokenizer, then measures `document` with it.
    pub fn measure(document: &Document, tokenizer: TokenizerKind) -> Result<Self> {
        let encoder = tokenizer.build()?;
        Self::measure_with_bpe(document, &encoder)
    }

    /// Measures `document` with an already-built BPE, so callers comparing
    /// many documents can avoid rebuilding the tokenizer each time.
    pub fn measure_with_bpe(document: &Document, tokenizer: &tiktoken_rs::CoreBPE) -> Result<Self> {
        let decoder = Decoder::new();
        let raw_text = decoder.to_text(document)?;
        let textual = TextSerializer::new().to_string(document)?;
        let tokens_raw = tokenizer.encode_with_special_tokens(raw_text.as_str()).len();
        let tokens_3dcf = tokenizer.encode_with_special_tokens(textual.as_str()).len();
        // Avoid dividing by zero when the serialized form is empty.
        let savings_ratio = if tokens_3dcf == 0 {
            0.0
        } else {
            tokens_raw as f32 / tokens_3dcf as f32
        };
        Ok(Self {
            tokens_raw,
            tokens_3dcf,
            cells: document.total_cells(),
            unique_payloads: document.dict.len(),
            savings_ratio,
        })
    }
}

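// A minimal usage sketch; `load_document` is a hypothetical helper, since
// this file does not show how `Document` values are constructed:
//
//     let doc = load_document()?;
//     let stats = Stats::measure(&doc, TokenizerKind::Cl100k)?;
//     println!(
//         "raw={} 3dcf={} ratio={:.2}",
//         stats.tokens_raw, stats.tokens_3dcf, stats.savings_ratio
//     );
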
/// Builds the requested tokenizer, then counts the tokens in `text`.
pub fn estimate_tokens(text: &str, tokenizer: &TokenizerKind) -> Result<usize> {
    let encoder = tokenizer.build()?;
    Ok(estimate_tokens_with_bpe(text, &encoder))
}

/// Counts the tokens in `text` with an already-built BPE.
pub fn estimate_tokens_with_bpe(text: &str, tokenizer: &tiktoken_rs::CoreBPE) -> usize {
    tokenizer.encode_with_special_tokens(text).len()
}

impl TokenizerKind {
    /// Constructs the `tiktoken_rs::CoreBPE` for this tokenizer kind.
    /// Building a BPE is comparatively expensive, so callers measuring many
    /// inputs should build once and use the `*_with_bpe` variants.
    pub fn build(&self) -> Result<tiktoken_rs::CoreBPE> {
        match self {
            TokenizerKind::Cl100k => {
                tiktoken_rs::cl100k_base().map_err(|e| DcfError::Tokenizer(e.to_string()))
            }
            // GPT-2 uses the r50k_base vocabulary; p50k_base belongs to the
            // later Codex-era models.
            TokenizerKind::Gpt2 => {
                tiktoken_rs::r50k_base().map_err(|e| DcfError::Tokenizer(e.to_string()))
            }
            TokenizerKind::O200k => {
                tiktoken_rs::o200k_base().map_err(|e| DcfError::Tokenizer(e.to_string()))
            }
            TokenizerKind::Anthropic => anthropic_base(),
            TokenizerKind::Custom(path) => load_custom_tokenizer(path),
        }
    }
}

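// A minimal sketch of amortizing tokenizer construction across inputs (the
// strings here are placeholders, not real data):
//
//     let bpe = TokenizerKind::Cl100k.build()?;
//     for text in ["alpha", "beta", "gamma"] {
//         println!("{text}: {} tokens", estimate_tokens_with_bpe(text, &bpe));
//     }
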
fn anthropic_base() -> Result<tiktoken_rs::CoreBPE> {
    // Placeholder: no Anthropic vocabulary is bundled, so cl100k_base stands
    // in for it and the resulting counts are approximations.
    tiktoken_rs::cl100k_base().map_err(|e| DcfError::Tokenizer(e.to_string()))
}

fn load_custom_tokenizer(path: &Path) -> Result<tiktoken_rs::CoreBPE> {
    let data = fs::read_to_string(path).map_err(|e| {
        DcfError::Tokenizer(format!(
            "failed to read tokenizer file {}: {e}",
            path.display()
        ))
    })?;
    let spec: CustomTokenizerSpec = serde_json::from_str(&data)
        .map_err(|e| DcfError::Tokenizer(format!("invalid tokenizer json: {e}")))?;
    // Token keys are usually base64-encoded; decode each into raw bytes.
    let mut encoder: FxHashMap<Vec<u8>, usize> = FxHashMap::default();
    for (token, rank) in spec.mergeable_ranks {
        encoder.insert(decode_token_key(&token), rank);
    }
    let special_tokens: FxHashMap<String, usize> =
        spec.special_tokens.into_iter().collect();
    tiktoken_rs::CoreBPE::new(encoder, special_tokens, &spec.pat_str)
        .map_err(|e| DcfError::Tokenizer(format!("failed to build tokenizer: {e}")))
}

/// On-disk JSON schema for `TokenizerKind::Custom` tokenizer files.
#[derive(Deserialize)]
struct CustomTokenizerSpec {
    /// Regex used to split text into pieces before BPE merging.
    pat_str: String,
    /// Map from (usually base64-encoded) token bytes to merge rank.
    mergeable_ranks: HashMap<String, usize>,
    /// Optional special tokens; defaults to empty when absent.
    #[serde(default)]
    special_tokens: HashMap<String, usize>,
}

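// A custom tokenizer file is expected to look roughly like the sketch below
// (field names come from `CustomTokenizerSpec`; the pattern and ranks are
// illustrative, not from a real vocabulary):
//
//     {
//       "pat_str": "\\S+|\\s+",
//       "mergeable_ranks": { "aGVsbG8=": 0, "d29ybGQ=": 1 },
//       "special_tokens": { "<|endoftext|>": 2 }
//     }
//
// Keys in `mergeable_ranks` may be base64-encoded byte strings; keys that
// fail to decode are used verbatim (see `decode_token_key`).
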
/// Decodes a base64 token key into raw bytes, falling back to the key's
/// literal UTF-8 bytes when it is not valid base64.
fn decode_token_key(key: &str) -> Vec<u8> {
    general_purpose::STANDARD
        .decode(key)
        .unwrap_or_else(|_| key.as_bytes().to_vec())
}
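
// Sanity checks for the base64 fallback above: `aGVsbG8=` is the standard
// encoding of `hello`, while the second key contains characters outside the
// base64 alphabet and so falls back to its literal bytes.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn decode_token_key_handles_base64_and_fallback() {
        assert_eq!(decode_token_key("aGVsbG8="), b"hello".to_vec());
        assert_eq!(decode_token_key("not base64!!"), b"not base64!!".to_vec());
    }
}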