// text_splitter/chunk_size/huggingface.rs

use tokenizers::{Encoding, Tokenizer};
use crate::ChunkSizer;
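
/// Counts the non-padding tokens in an `Encoding`, including any tokens that
/// spilled into overflow encodings when the tokenizer truncates. Leading and
/// trailing pad tokens are skipped so padding configured on the tokenizer
/// does not inflate the reported chunk size.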
fn num_tokens_with_overflow(encoding: &Encoding, pad_id: Option<u32>) -> usize {
    // Count ids up to the first pad token, skipping any leading padding.
    let base = encoding
        .get_ids()
        .iter()
        .skip_while(|&id| pad_id.map_or(false, |pad_id| id == &pad_id))
        .take_while(|&id| pad_id.map_or(true, |pad_id| id != &pad_id))
        .count();
    // Truncation pushes the remaining tokens into overflow encodings;
    // recurse so they are counted as well.
    let overflow: usize = encoding
        .get_overflowing()
        .iter()
        .map(|enc| num_tokens_with_overflow(enc, pad_id))
        .sum();
    base + overflow
}
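
// Implementing for `&Tokenizer` lets callers pass a borrowed tokenizer
// wherever a `ChunkSizer` is expected; the owned impl below delegates here.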
impl ChunkSizer for &Tokenizer {
    fn size(&self, chunk: &str) -> usize {
        // Encode without special tokens so the size reflects only the chunk's
        // own content. Note: `expect` does not interpolate `{chunk}`, so use
        // `unwrap_or_else` + `panic!` to get the string into the message.
        let encoding = self.encode(chunk, false).unwrap_or_else(|err| {
            panic!("Unable to tokenize the following string {chunk:?}: {err}")
        });
        let pad_id = self.get_padding().map(|params| params.pad_id);
        num_tokens_with_overflow(&encoding, pad_id)
    }
}
impl ChunkSizer for Tokenizer {
    fn size(&self, chunk: &str) -> usize {
        // Delegate to the `&Tokenizer` impl above.
        (&self).size(chunk)
    }
}
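
// Usage sketch, not part of this module: with these impls a Hugging Face
// tokenizer can be passed directly as the sizer for a splitter. The
// `TextSplitter`/`ChunkConfig`/`with_sizer` names below are assumptions about
// the surrounding text-splitter API, not guaranteed by this file:
//
//     let tokenizer = Tokenizer::from_pretrained("bert-base-cased", None)?;
//     let splitter = TextSplitter::new(ChunkConfig::new(100).with_sizer(tokenizer));
//     let chunks: Vec<&str> = splitter.chunks("Some long document...").collect();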
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn returns_size() {
        let tokenizer = Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
        let size = tokenizer.size(" An apple a");
        assert_eq!(size, 3);
    }

    #[test]
    fn returns_size_handles_prefix() {
        let tokenizer =
            Tokenizer::from_file("./tests/tokenizers/huggingface.json").unwrap();
        let size = tokenizer.size("An apple a");
        assert_eq!(size, 3);
    }

    #[test]
    fn handles_padding() {
        // This tokenizer pads its encodings; pad tokens must not count
        // toward the chunk size.
        let tokenizer = Tokenizer::from_pretrained("thenlper/gte-small", None).unwrap();
        let size = tokenizer.size("An apple a");
        assert_eq!(size, 3);
    }

    #[test]
    fn handles_truncation() {
        // 100 repetitions of a 9-token sentence exceed the tokenizer's max
        // length, so the tokens past the truncation point land in overflow
        // encodings. All 900 should still be counted.
        let tokenizer = Tokenizer::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", None)
            .expect("Could not load tokenizer 'sentence-transformers/all-MiniLM-L6-v2'");
        assert_eq!(
            tokenizer.size("An apple a day keeps the doctor away.".repeat(100).as_str()),
            900
        );
    }
}