use std::{
cmp::Ordering,
ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive},
};
use ahash::AHashMap;
mod characters;
#[cfg(feature = "tokenizers")]
mod huggingface;
#[cfg(feature = "tiktoken-rs")]
mod tiktoken;
pub use characters::Characters;
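/// The resulting size of a chunk, as determined by a [`ChunkSizer`], along
/// with how it compares to a [`ChunkCapacity`]: whether it fits, the measured
/// size, and the byte offset of the largest prefix that still fits.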
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct ChunkSize {
fits: Ordering,
max_chunk_size_offset: Option<usize>,
size: usize,
}
impl ChunkSize {
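    /// Generate a `ChunkSize` from a pre-computed size, checked against the
    /// capacity. No offset information is available here, so
    /// `max_chunk_size_offset` is left as `None`.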
pub fn from_size(size: usize, capacity: &impl ChunkCapacity) -> Self {
Self {
fits: capacity.fits(size),
max_chunk_size_offset: None,
size,
}
}
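    /// Generate a `ChunkSize` from an iterator of byte ranges, one per unit
    /// (character, token, ...). The size is the number of ranges, and
    /// `max_chunk_size_offset` records the end of the last range that still
    /// fits within the capacity.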
pub fn from_offsets(
offsets: impl Iterator<Item = Range<usize>>,
capacity: &impl ChunkCapacity,
) -> Self {
let mut chunk_size = offsets.fold(
Self {
fits: Ordering::Less,
max_chunk_size_offset: None,
size: 0,
},
|mut acc, range| {
acc.size += 1;
if acc.size <= capacity.end() {
acc.max_chunk_size_offset = Some(range.end);
}
acc
},
);
chunk_size.fits = capacity.fits(chunk_size.size);
chunk_size
}
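    /// How the chunk size compares to the capacity: `Less` (below it),
    /// `Equal` (within it), or `Greater` (above it).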
#[must_use]
pub fn fits(&self) -> Ordering {
self.fits
}
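    /// Byte offset of the end of the last unit that still fits within the
    /// capacity, if offsets were provided.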
#[must_use]
pub fn max_chunk_size_offset(&self) -> Option<usize> {
self.max_chunk_size_offset
}
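    /// The size of the chunk, in whatever units the sizer counts
    /// (characters, tokens, ...).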
#[must_use]
pub fn size(&self) -> usize {
self.size
}
}
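/// Determines the size of a given chunk of text.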
pub trait ChunkSizer {
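    /// Compute the [`ChunkSize`] of `chunk`, checked against `capacity`.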
fn chunk_size(&self, chunk: &str, capacity: &impl ChunkCapacity) -> ChunkSize;
}
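/// A [`ChunkSizer`] wrapper that caches sizes keyed by the byte range of each
/// chunk within the text, so the underlying sizer is only invoked once per
/// range.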
#[derive(Debug)]
pub struct MemoizedChunkSizer<'sizer, C, S>
where
C: ChunkCapacity,
S: ChunkSizer,
{
cache: AHashMap<Range<usize>, ChunkSize>,
chunk_capacity: C,
sizer: &'sizer S,
}
impl<'sizer, C, S> MemoizedChunkSizer<'sizer, C, S>
where
C: ChunkCapacity,
S: ChunkSizer,
{
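    /// Wrap a sizer and capacity, starting with an empty cache.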
pub fn new(chunk_capacity: C, sizer: &'sizer S) -> Self {
Self {
cache: AHashMap::new(),
chunk_capacity,
sizer,
}
}
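    /// Get the size of the chunk starting at `offset` in the text, computing
    /// it at most once per byte range and caching the result.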
pub fn chunk_size(&mut self, offset: usize, chunk: &str) -> ChunkSize {
*self
.cache
.entry(offset..(offset + chunk.len()))
.or_insert_with(|| self.sizer.chunk_size(chunk, &self.chunk_capacity))
}
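    /// Size a chunk and shift its `max_chunk_size_offset` from chunk-relative
    /// to text-relative coordinates by adding the chunk's starting `offset`.
    /// Only the returned copy is adjusted; the cached value stays
    /// chunk-relative.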
pub fn check_capacity(&mut self, (offset, chunk): (usize, &str)) -> ChunkSize {
let mut chunk_size = self.chunk_size(offset, chunk);
if let Some(max_chunk_size_offset) = chunk_size.max_chunk_size_offset.as_mut() {
*max_chunk_size_offset += offset;
}
chunk_size
}
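    /// Clear all cached sizes. Cached byte ranges are only meaningful for the
    /// text they were computed against, so clear the cache before sizing a
    /// different text.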
pub fn clear_cache(&mut self) {
self.cache.clear();
}
}
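/// Describes the acceptable size range for a chunk. Implemented for `usize`
/// and the standard range types, so a capacity can be an exact size or a
/// range of sizes.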
pub trait ChunkCapacity {
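    /// Optional lower bound for the desired chunk size. Defaults to `None`,
    /// meaning there is no minimum.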
fn start(&self) -> Option<usize> {
None
}
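    /// The maximum size a chunk is allowed to be (treated as inclusive by
    /// `fits`).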
#[must_use]
fn end(&self) -> usize;
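    /// Compare a chunk size against the capacity: `Less` if it is below
    /// `start`, `Greater` if it is above `end`, and `Equal` if it falls
    /// within the range. Without a `start`, the size is compared directly to
    /// `end`.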
fn fits(&self, chunk_size: usize) -> Ordering {
let end = self.end();
match self.start() {
Some(start) => {
if chunk_size < start {
Ordering::Less
} else if chunk_size > end {
Ordering::Greater
} else {
Ordering::Equal
}
}
None => chunk_size.cmp(&end),
}
}
}
impl ChunkCapacity for usize {
fn end(&self) -> usize {
*self
}
}
impl ChunkCapacity for Range<usize> {
fn start(&self) -> Option<usize> {
Some(self.start)
}
fn end(&self) -> usize {
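        // Convert the exclusive upper bound to an inclusive maximum, never
        // dropping below `start` (so `0..0` still permits a size of 0).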
self.end.saturating_sub(1).max(self.start)
}
}
impl ChunkCapacity for RangeFrom<usize> {
fn start(&self) -> Option<usize> {
Some(self.start)
}
fn end(&self) -> usize {
usize::MAX
}
}
impl ChunkCapacity for RangeFull {
fn start(&self) -> Option<usize> {
Some(usize::MIN)
}
fn end(&self) -> usize {
usize::MAX
}
}
impl ChunkCapacity for RangeInclusive<usize> {
fn start(&self) -> Option<usize> {
Some(*self.start())
}
fn end(&self) -> usize {
*self.end()
}
}
impl ChunkCapacity for RangeTo<usize> {
fn start(&self) -> Option<usize> {
Some(usize::MIN)
}
fn end(&self) -> usize {
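        // Convert the exclusive upper bound to an inclusive maximum; `..0`
        // saturates to 0.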
self.end.saturating_sub(1)
}
}
impl ChunkCapacity for RangeToInclusive<usize> {
fn start(&self) -> Option<usize> {
Some(usize::MIN)
}
fn end(&self) -> usize {
self.end
}
}
#[cfg(test)]
mod tests {
use std::sync::atomic::{self, AtomicUsize};
use super::*;
#[test]
fn check_chunk_capacity() {
let chunk = "12345";
assert_eq!(Characters.chunk_size(chunk, &4).fits, Ordering::Greater);
assert_eq!(Characters.chunk_size(chunk, &5).fits, Ordering::Equal);
assert_eq!(Characters.chunk_size(chunk, &6).fits, Ordering::Less);
}
#[test]
fn check_chunk_capacity_for_range() {
let chunk = "12345";
assert_eq!(
Characters.chunk_size(chunk, &(0..0)).fits,
Ordering::Greater
);
assert_eq!(
Characters.chunk_size(chunk, &(0..5)).fits,
Ordering::Greater
);
assert_eq!(Characters.chunk_size(chunk, &(5..6)).fits, Ordering::Equal);
assert_eq!(Characters.chunk_size(chunk, &(6..100)).fits, Ordering::Less);
}
#[test]
fn check_chunk_capacity_for_range_from() {
let chunk = "12345";
assert_eq!(Characters.chunk_size(chunk, &(0..)).fits, Ordering::Equal);
assert_eq!(Characters.chunk_size(chunk, &(5..)).fits, Ordering::Equal);
assert_eq!(Characters.chunk_size(chunk, &(6..)).fits, Ordering::Less);
}
#[test]
fn check_chunk_capacity_for_range_full() {
let chunk = "12345";
assert_eq!(Characters.chunk_size(chunk, &..).fits, Ordering::Equal);
}
#[test]
fn check_chunk_capacity_for_range_inclusive() {
let chunk = "12345";
assert_eq!(
Characters.chunk_size(chunk, &(0..=4)).fits,
Ordering::Greater
);
assert_eq!(Characters.chunk_size(chunk, &(5..=6)).fits, Ordering::Equal);
assert_eq!(Characters.chunk_size(chunk, &(4..=5)).fits, Ordering::Equal);
assert_eq!(
Characters.chunk_size(chunk, &(6..=100)).fits,
Ordering::Less
);
}
#[test]
fn check_chunk_capacity_for_range_to() {
let chunk = "12345";
assert_eq!(Characters.chunk_size(chunk, &(..0)).fits, Ordering::Greater);
assert_eq!(Characters.chunk_size(chunk, &(..5)).fits, Ordering::Greater);
assert_eq!(Characters.chunk_size(chunk, &(..6)).fits, Ordering::Equal);
}
#[test]
fn check_chunk_capacity_for_range_to_inclusive() {
let chunk = "12345";
assert_eq!(
Characters.chunk_size(chunk, &(..=4)).fits,
Ordering::Greater
);
assert_eq!(Characters.chunk_size(chunk, &(..=5)).fits, Ordering::Equal);
assert_eq!(Characters.chunk_size(chunk, &(..=6)).fits, Ordering::Equal);
}
#[test]
fn chunk_size_from_offsets() {
let offsets = [0..1, 1..2, 2..3];
let chunk_size = ChunkSize::from_offsets(offsets.clone().into_iter(), &1);
assert_eq!(
ChunkSize {
fits: Ordering::Greater,
size: offsets.len(),
max_chunk_size_offset: Some(1)
},
chunk_size
);
}
#[test]
fn chunk_size_from_empty_offsets() {
let offsets = [];
let chunk_size = ChunkSize::from_offsets(offsets.clone().into_iter(), &1);
assert_eq!(
ChunkSize {
fits: Ordering::Less,
size: offsets.len(),
max_chunk_size_offset: None
},
chunk_size
);
}
#[test]
fn chunk_size_from_small_offsets() {
let offsets = [0..1, 1..2, 2..3];
let chunk_size = ChunkSize::from_offsets(offsets.clone().into_iter(), &4);
assert_eq!(
ChunkSize {
fits: Ordering::Less,
size: offsets.len(),
max_chunk_size_offset: Some(3)
},
chunk_size
);
}
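    // Test sizer that delegates to `Characters` but counts how many times it
    // is called, so memoization behaviour can be observed.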
#[derive(Default)]
struct CountingSizer {
calls: AtomicUsize,
}
impl ChunkSizer for CountingSizer {
fn chunk_size(&self, chunk: &str, capacity: &impl ChunkCapacity) -> ChunkSize {
self.calls.fetch_add(1, atomic::Ordering::SeqCst);
Characters.chunk_size(chunk, capacity)
}
}
#[test]
fn memoized_sizer_only_calculates_once_per_text() {
let sizer = CountingSizer::default();
let mut memoized_sizer = MemoizedChunkSizer::new(10, &sizer);
let text = "1234567890";
for _ in 0..10 {
memoized_sizer.chunk_size(0, text);
}
assert_eq!(memoized_sizer.sizer.calls.load(atomic::Ordering::SeqCst), 1);
}
#[test]
fn memoized_sizer_calculates_once_per_different_text() {
let sizer = CountingSizer::default();
let mut memoized_sizer = MemoizedChunkSizer::new(10, &sizer);
let text = "1234567890";
for i in 0..10 {
memoized_sizer.chunk_size(0, text.get(0..i).unwrap());
}
assert_eq!(
memoized_sizer.sizer.calls.load(atomic::Ordering::SeqCst),
10
);
}
#[test]
fn can_clear_cache_on_memoized_sizer() {
let sizer = CountingSizer::default();
let mut memoized_sizer = MemoizedChunkSizer::new(10, &sizer);
let text = "1234567890";
for _ in 0..10 {
memoized_sizer.chunk_size(0, text);
memoized_sizer.clear_cache();
}
assert_eq!(
memoized_sizer.sizer.calls.load(atomic::Ordering::SeqCst),
10
);
}
#[test]
    fn chunk_size_from_size() {
let chunk_size = ChunkSize::from_size(10, &10);
assert_eq!(
ChunkSize {
fits: Ordering::Equal,
size: 10,
max_chunk_size_offset: None
},
chunk_size
);
}
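    // Illustrative sketch, not part of the original suite: a sizer that
    // reports chunk-relative byte offsets via `from_offsets`, used to show
    // that `check_capacity` shifts `max_chunk_size_offset` into the
    // coordinates of the full text. The sizer name and the chosen offsets are
    // assumptions made for this example.
    struct ByteOffsetSizer;
    impl ChunkSizer for ByteOffsetSizer {
        fn chunk_size(&self, chunk: &str, capacity: &impl ChunkCapacity) -> ChunkSize {
            ChunkSize::from_offsets(
                chunk.char_indices().map(|(i, c)| i..(i + c.len_utf8())),
                capacity,
            )
        }
    }
    #[test]
    fn check_capacity_shifts_offsets_to_text_coordinates() {
        let sizer = ByteOffsetSizer;
        let mut memoized_sizer = MemoizedChunkSizer::new(2, &sizer);
        // Pretend "abc" starts at byte offset 10 of the full text.
        let chunk_size = memoized_sizer.check_capacity((10, "abc"));
        assert_eq!(chunk_size.size(), 3);
        assert_eq!(chunk_size.fits(), Ordering::Greater);
        // Only the first 2 characters fit; their chunk-relative end offset (2)
        // is shifted by the chunk's starting offset (10).
        assert_eq!(chunk_size.max_chunk_size_offset(), Some(12));
    }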
}