1use std::{
2 borrow::Cow,
3 cell::{Ref, RefMut},
4 cmp::Ordering,
5 fmt,
6 ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive},
7 rc::Rc,
8 sync::Arc,
9};
10
11use ahash::AHashMap;
12use itertools::Itertools;
13use thiserror::Error;
14
15mod characters;
16#[cfg(feature = "tokenizers")]
17mod huggingface;
18#[cfg(feature = "tiktoken-rs")]
19mod tiktoken;
20
21use crate::trim::Trim;
22pub use characters::Characters;
23
/// Divisor applied to the maximum capacity to derive how far a probed prefix
/// may overshoot that maximum before probing stops (see `probe_max_size`,
/// which adds `max / PROBE_OVERSHOOT_DIVISOR`, but at least 1, to the max).
const PROBE_OVERSHOOT_DIVISOR: usize = 4;
/// Multiplier applied to the maximum capacity to pick the byte distance from
/// the chunk start below which lower-level boundaries are not probed
/// (see `MemoizedChunkSizer::chunk_fits_with_boundaries`).
const PROBE_START_BYTE_FACTOR: usize = 8;
26
/// Error returned when a [`ChunkCapacity`] cannot be built from the given values.
///
/// This is an opaque wrapper: the concrete failure reasons live in a private
/// enum, and the error message is forwarded from that inner value.
#[derive(Error, Debug)]
#[error(transparent)]
pub struct ChunkCapacityError(#[from] ChunkCapacityErrorRepr);
33
/// Private representation of the ways building a [`ChunkCapacity`] can fail.
#[derive(Error, Debug)]
enum ChunkCapacityErrorRepr {
    /// The requested maximum size was smaller than the desired size.
    #[error("Max chunk size must be greater than or equal to the desired chunk size")]
    MaxLessThanDesired,
}
40
/// The size range a chunk is allowed to fall into.
///
/// A chunk size is considered too small below `desired`, too large above
/// `max`, and fitting anywhere in between (see [`ChunkCapacity::fits`]).
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct ChunkCapacity {
    /// Size at or above which a chunk is no longer considered too small.
    pub(crate) desired: usize,
    /// Largest size a chunk may have. `with_max` rejects values below `desired`.
    pub(crate) max: usize,
}
70
71impl ChunkCapacity {
72 #[must_use]
74 pub fn new(size: usize) -> Self {
75 Self {
76 desired: size,
77 max: size,
78 }
79 }
80
81 #[must_use]
87 pub fn desired(&self) -> usize {
88 self.desired
89 }
90
91 #[must_use]
96 pub fn max(&self) -> usize {
97 self.max
98 }
99
100 pub fn with_max(mut self, max: usize) -> Result<Self, ChunkCapacityError> {
113 if max < self.desired {
114 Err(ChunkCapacityError(
115 ChunkCapacityErrorRepr::MaxLessThanDesired,
116 ))
117 } else {
118 self.max = max;
119 Ok(self)
120 }
121 }
122
123 #[must_use]
129 pub fn fits(&self, chunk_size: usize) -> Ordering {
130 if chunk_size < self.desired {
131 Ordering::Less
132 } else if chunk_size > self.max {
133 Ordering::Greater
134 } else {
135 Ordering::Equal
136 }
137 }
138}
139
140impl From<usize> for ChunkCapacity {
141 fn from(size: usize) -> Self {
142 ChunkCapacity::new(size)
143 }
144}
145
146impl From<Range<usize>> for ChunkCapacity {
147 fn from(range: Range<usize>) -> Self {
148 ChunkCapacity::new(range.start)
149 .with_max(range.end.saturating_sub(1).max(range.start))
150 .expect("invalid range")
151 }
152}
153
154impl From<RangeFrom<usize>> for ChunkCapacity {
155 fn from(range: RangeFrom<usize>) -> Self {
156 ChunkCapacity::new(range.start)
157 .with_max(usize::MAX)
158 .expect("invalid range")
159 }
160}
161
162impl From<RangeFull> for ChunkCapacity {
163 fn from(_: RangeFull) -> Self {
164 ChunkCapacity::new(usize::MIN)
165 .with_max(usize::MAX)
166 .expect("invalid range")
167 }
168}
169
170impl From<RangeInclusive<usize>> for ChunkCapacity {
171 fn from(range: RangeInclusive<usize>) -> Self {
172 ChunkCapacity::new(*range.start())
173 .with_max(*range.end())
174 .expect("invalid range")
175 }
176}
177
178impl From<RangeTo<usize>> for ChunkCapacity {
179 fn from(range: RangeTo<usize>) -> Self {
180 ChunkCapacity::new(usize::MIN)
181 .with_max(range.end.saturating_sub(1))
182 .expect("invalid range")
183 }
184}
185
186impl From<RangeToInclusive<usize>> for ChunkCapacity {
187 fn from(range: RangeToInclusive<usize>) -> Self {
188 ChunkCapacity::new(usize::MIN)
189 .with_max(range.end)
190 .expect("invalid range")
191 }
192}
193
/// Measures the size of a chunk of text.
///
/// The unit of the returned size is up to the implementation; the result is
/// what gets compared against a [`ChunkCapacity`].
pub trait ChunkSizer {
    /// Returns the size of `chunk` in this sizer's unit.
    fn size(&self, chunk: &str) -> usize;
}
199
200impl<T> ChunkSizer for &T
201where
202 T: ChunkSizer,
203{
204 fn size(&self, chunk: &str) -> usize {
205 (*self).size(chunk)
206 }
207}
208
209impl<T> ChunkSizer for Ref<'_, T>
210where
211 T: ChunkSizer,
212{
213 fn size(&self, chunk: &str) -> usize {
214 self.deref().size(chunk)
215 }
216}
217
218impl<T> ChunkSizer for RefMut<'_, T>
219where
220 T: ChunkSizer,
221{
222 fn size(&self, chunk: &str) -> usize {
223 self.deref().size(chunk)
224 }
225}
226
227impl<T> ChunkSizer for Box<T>
228where
229 T: ChunkSizer,
230{
231 fn size(&self, chunk: &str) -> usize {
232 self.deref().size(chunk)
233 }
234}
235
236impl<T> ChunkSizer for Cow<'_, T>
237where
238 T: ChunkSizer + ToOwned + ?Sized,
239 <T as ToOwned>::Owned: ChunkSizer,
240{
241 fn size(&self, chunk: &str) -> usize {
242 self.as_ref().size(chunk)
243 }
244}
245
246impl<T> ChunkSizer for Rc<T>
247where
248 T: ChunkSizer,
249{
250 fn size(&self, chunk: &str) -> usize {
251 self.deref().size(chunk)
252 }
253}
254
255impl<T> ChunkSizer for Arc<T>
256where
257 T: ChunkSizer,
258{
259 fn size(&self, chunk: &str) -> usize {
260 self.as_ref().size(chunk)
261 }
262}
263
/// Error returned when a [`ChunkConfig`] is given an invalid setting.
///
/// This is an opaque wrapper: the concrete failure reasons live in a private
/// enum, and the error message is forwarded from that inner value.
#[derive(Error, Debug)]
#[error(transparent)]
pub struct ChunkConfigError(#[from] ChunkConfigErrorRepr);
270
/// Private representation of the ways configuring a [`ChunkConfig`] can fail.
#[derive(Error, Debug)]
enum ChunkConfigErrorRepr {
    /// The requested overlap was not strictly smaller than the desired
    /// chunk capacity.
    #[error("The overlap is larger than or equal to the desired chunk capacity")]
    OverlapLargerThanCapacity,
}
277
/// Settings that control how chunks are produced: their size range, overlap,
/// how they are measured, and whether they are trimmed.
#[derive(Debug)]
pub struct ChunkConfig<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Size range each chunk should fit into.
    pub(crate) capacity: ChunkCapacity,
    /// Overlap setting; `with_overlap` keeps it strictly below the desired
    /// capacity. Presumably the amount consecutive chunks may share, in the
    /// sizer's unit — confirm against the splitter that consumes it.
    pub(crate) overlap: usize,
    /// Sizer used to measure chunks.
    pub(crate) sizer: Sizer,
    /// Whether trimming is enabled (defaults to `true` in `ChunkConfig::new`).
    /// What exactly gets trimmed is decided by the consumer of this flag.
    pub(crate) trim: bool,
}
293
294impl ChunkConfig<Characters> {
295 #[must_use]
303 pub fn new(capacity: impl Into<ChunkCapacity>) -> Self {
304 Self {
305 capacity: capacity.into(),
306 overlap: 0,
307 sizer: Characters,
308 trim: true,
309 }
310 }
311}
312
313impl<Sizer> ChunkConfig<Sizer>
314where
315 Sizer: ChunkSizer,
316{
317 pub fn capacity(&self) -> &ChunkCapacity {
319 &self.capacity
320 }
321
322 pub fn overlap(&self) -> usize {
324 self.overlap
325 }
326
327 pub fn with_overlap(mut self, overlap: usize) -> Result<Self, ChunkConfigError> {
333 if overlap >= self.capacity.desired {
334 Err(ChunkConfigError(
335 ChunkConfigErrorRepr::OverlapLargerThanCapacity,
336 ))
337 } else {
338 self.overlap = overlap;
339 Ok(self)
340 }
341 }
342
343 pub fn sizer(&self) -> &Sizer {
345 &self.sizer
346 }
347
348 #[must_use]
356 pub fn with_sizer<S: ChunkSizer>(self, sizer: S) -> ChunkConfig<S> {
357 ChunkConfig {
358 capacity: self.capacity,
359 overlap: self.overlap,
360 sizer,
361 trim: self.trim,
362 }
363 }
364
365 pub fn trim(&self) -> bool {
367 self.trim
368 }
369
370 #[must_use]
383 pub fn with_trim(mut self, trim: bool) -> Self {
384 self.trim = trim;
385 self
386 }
387}
388
389impl<T> From<T> for ChunkConfig<Characters>
390where
391 T: Into<ChunkCapacity>,
392{
393 fn from(capacity: T) -> Self {
394 Self::new(capacity)
395 }
396}
397
/// Wraps a borrowed [`ChunkSizer`] and caches the sizes it computes, so the
/// same span of text is only measured once until the cache is cleared.
#[derive(Debug)]
pub struct MemoizedChunkSizer<'sizer, Sizer>
where
    Sizer: ChunkSizer,
{
    /// Computed sizes, keyed by the byte range (after trimming) of the
    /// measured text.
    size_cache: AHashMap<Range<usize>, usize>,
    /// The underlying sizer that performs the actual measurement.
    sizer: &'sizer Sizer,
}
411
412impl<'sizer, Sizer> MemoizedChunkSizer<'sizer, Sizer>
413where
414 Sizer: ChunkSizer,
415{
416 pub fn new(sizer: &'sizer Sizer) -> Self {
418 Self {
419 size_cache: AHashMap::new(),
420 sizer,
421 }
422 }
423
424 pub fn chunk_size(&mut self, offset: usize, chunk: &str, trim: Trim) -> usize {
427 let (offset, chunk) = trim.trim(offset, chunk);
428 *self
429 .size_cache
430 .entry(offset..(offset + chunk.len()))
431 .or_insert_with(|| self.sizer.size(chunk))
432 }
433
434 pub fn find_correct_level<'text, L, Boundaries>(
436 &mut self,
437 offset: usize,
438 capacity: &ChunkCapacity,
439 levels_with_first_chunk: impl Iterator<Item = (L, &'text str, Option<L>)>,
440 mut lower_boundaries_for: impl FnMut(Option<L>, usize) -> Boundaries,
441 trim: Trim,
442 ) -> (Option<L>, Option<usize>)
443 where
444 L: Copy + fmt::Debug,
445 Boundaries: Iterator<Item = usize>,
446 {
447 let mut semantic_level = None;
448 let mut max_offset = None;
449
450 let levels_with_first_chunk = levels_with_first_chunk.coalesce(
452 |(a_level, a_str, a_lower_level), (b_level, b_str, b_lower_level)| {
453 if a_str.len() >= b_str.len() {
454 Ok((b_level, b_str, b_lower_level))
455 } else {
456 Err((
457 (a_level, a_str, a_lower_level),
458 (b_level, b_str, b_lower_level),
459 ))
460 }
461 },
462 );
463
464 for (level, str, lower_level) in levels_with_first_chunk {
465 let len = str.len();
467 if len > capacity.max {
468 let mut lower_boundaries = lower_boundaries_for(lower_level, offset + len);
469 let fits = self.chunk_fits_with_boundaries(
470 offset,
471 str,
472 capacity,
473 &mut lower_boundaries,
474 trim,
475 );
476 if fits.is_gt() {
478 max_offset = Some(offset + len);
481 break;
482 }
483 }
484 semantic_level = Some(level);
486 }
487
488 (semantic_level, max_offset)
489 }
490
491 fn chunk_fits_with_boundaries(
492 &mut self,
493 offset: usize,
494 chunk: &str,
495 capacity: &ChunkCapacity,
496 lower_boundaries: &mut impl Iterator<Item = usize>,
497 trim: Trim,
498 ) -> Ordering {
499 let chunk_end = offset + chunk.len();
500 let probe_max = probe_max_size(capacity);
501 let probe_start = offset
502 .saturating_add(capacity.max.saturating_mul(PROBE_START_BYTE_FACTOR).max(1))
503 .min(chunk_end);
504 if probe_start == chunk_end {
505 let chunk_size = self.chunk_size(offset, chunk, trim);
506 return capacity.fits(chunk_size);
507 }
508
509 let mut step: usize = 1;
510 let mut next_boundary =
511 lower_boundaries.find(|&boundary| boundary > offset && boundary >= probe_start);
512
513 while let Some(boundary) = next_boundary {
514 if boundary > chunk_end {
515 break;
516 }
517
518 let prefix = chunk
519 .get(..(boundary - offset))
520 .expect("valid character boundary");
521 let chunk_size = self.chunk_size(offset, prefix, trim);
522 let fits = capacity.fits(chunk_size);
523 if chunk_size > probe_max || boundary == chunk_end {
524 return fits;
525 }
526
527 let skip = step.saturating_sub(1);
528 step = step.saturating_mul(2);
529 next_boundary = lower_boundaries.nth(skip);
530 }
531
532 let chunk_size = self.chunk_size(offset, chunk, trim);
533 capacity.fits(chunk_size)
534 }
535
536 pub fn clear_cache(&mut self) {
539 self.size_cache.clear();
540 }
541}
542
543fn probe_max_size(capacity: &ChunkCapacity) -> usize {
544 capacity
545 .max
546 .saturating_add((capacity.max / PROBE_OVERSHOOT_DIVISOR).max(1))
547}
548
// Unit tests for chunk capacities, sizer wrappers, chunk configuration and
// the memoized sizer.
#[cfg(test)]
mod tests {
    use std::{
        cell::RefCell,
        sync::atomic::{self, AtomicUsize},
    };

    use crate::trim::Trim;

    use super::*;

    // A plain size acts as both desired and max.
    #[test]
    fn check_chunk_capacity() {
        let chunk = "12345";

        assert_eq!(
            ChunkCapacity::from(4).fits(Characters.size(chunk)),
            Ordering::Greater
        );
        assert_eq!(
            ChunkCapacity::from(5).fits(Characters.size(chunk)),
            Ordering::Equal
        );
        assert_eq!(
            ChunkCapacity::from(6).fits(Characters.size(chunk)),
            Ordering::Less
        );
    }

    // `a..b` is end-exclusive: the max is `b - 1`, never below `a`.
    #[test]
    fn check_chunk_capacity_for_range() {
        let chunk = "12345";

        assert_eq!(
            ChunkCapacity::from(0..0).fits(Characters.size(chunk)),
            Ordering::Greater
        );
        assert_eq!(
            ChunkCapacity::from(0..5).fits(Characters.size(chunk)),
            Ordering::Greater
        );
        assert_eq!(
            ChunkCapacity::from(5..6).fits(Characters.size(chunk)),
            Ordering::Equal
        );
        assert_eq!(
            ChunkCapacity::from(6..100).fits(Characters.size(chunk)),
            Ordering::Less
        );
    }

    // `a..` has no effective upper bound.
    #[test]
    fn check_chunk_capacity_for_range_from() {
        let chunk = "12345";

        assert_eq!(
            ChunkCapacity::from(0..).fits(Characters.size(chunk)),
            Ordering::Equal
        );
        assert_eq!(
            ChunkCapacity::from(5..).fits(Characters.size(chunk)),
            Ordering::Equal
        );
        assert_eq!(
            ChunkCapacity::from(6..).fits(Characters.size(chunk)),
            Ordering::Less
        );
    }

    // `..` accepts every size.
    #[test]
    fn check_chunk_capacity_for_range_full() {
        let chunk = "12345";

        assert_eq!(
            ChunkCapacity::from(..).fits(Characters.size(chunk)),
            Ordering::Equal
        );
    }

    // `a..=b` keeps `b` itself as the max.
    #[test]
    fn check_chunk_capacity_for_range_inclusive() {
        let chunk = "12345";

        assert_eq!(
            ChunkCapacity::from(0..=4).fits(Characters.size(chunk)),
            Ordering::Greater
        );
        assert_eq!(
            ChunkCapacity::from(5..=6).fits(Characters.size(chunk)),
            Ordering::Equal
        );
        assert_eq!(
            ChunkCapacity::from(4..=5).fits(Characters.size(chunk)),
            Ordering::Equal
        );
        assert_eq!(
            ChunkCapacity::from(6..=100).fits(Characters.size(chunk)),
            Ordering::Less
        );
    }

    // `..b` is end-exclusive with a desired size of zero.
    #[test]
    fn check_chunk_capacity_for_range_to() {
        let chunk = "12345";

        assert_eq!(
            ChunkCapacity::from(..0).fits(Characters.size(chunk)),
            Ordering::Greater
        );
        assert_eq!(
            ChunkCapacity::from(..5).fits(Characters.size(chunk)),
            Ordering::Greater
        );
        assert_eq!(
            ChunkCapacity::from(..6).fits(Characters.size(chunk)),
            Ordering::Equal
        );
    }

    // `..=b` keeps `b` itself as the max with a desired size of zero.
    #[test]
    fn check_chunk_capacity_for_range_to_inclusive() {
        let chunk = "12345";

        assert_eq!(
            ChunkCapacity::from(..=4).fits(Characters.size(chunk)),
            Ordering::Greater
        );
        assert_eq!(
            ChunkCapacity::from(..=5).fits(Characters.size(chunk)),
            Ordering::Equal
        );
        assert_eq!(
            ChunkCapacity::from(..=6).fits(Characters.size(chunk)),
            Ordering::Equal
        );
    }

    /// Sizer that counts how many times it has been asked to measure a chunk.
    #[derive(Default)]
    struct CountingSizer {
        calls: AtomicUsize,
    }

    impl ChunkSizer for CountingSizer {
        fn size(&self, chunk: &str) -> usize {
            // Record the call, then delegate the measurement to `Characters`.
            self.calls.fetch_add(1, atomic::Ordering::SeqCst);
            Characters.size(chunk)
        }
    }

    // Repeated requests for the same span hit the cache after the first call.
    #[test]
    fn memoized_sizer_only_calculates_once_per_text() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        let text = "1234567890";
        for _ in 0..10 {
            memoized_sizer.chunk_size(0, text, Trim::All);
        }

        assert_eq!(memoized_sizer.sizer.calls.load(atomic::Ordering::SeqCst), 1);
    }

    // Each distinct span is measured exactly once.
    #[test]
    fn memoized_sizer_calculates_once_per_different_text() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        let text = "1234567890";
        for i in 0..10 {
            memoized_sizer.chunk_size(0, text.get(0..i).unwrap(), Trim::All);
        }

        assert_eq!(
            memoized_sizer.sizer.calls.load(atomic::Ordering::SeqCst),
            10
        );
    }

    // Clearing the cache forces the next request to measure again.
    #[test]
    fn can_clear_cache_on_memoized_sizer() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        let text = "1234567890";
        for _ in 0..10 {
            memoized_sizer.chunk_size(0, text, Trim::All);
            memoized_sizer.clear_cache();
        }

        assert_eq!(
            memoized_sizer.sizer.calls.load(atomic::Ordering::SeqCst),
            10
        );
    }

    // An 11-byte chunk with capacity 10: the probe start is clamped to the
    // chunk end, so the whole chunk is measured once and no boundary is
    // consumed from the iterator.
    #[test]
    fn boundary_probe_is_skipped_when_probe_start_is_chunk_end() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        let mut lower_boundaries = [usize::MAX].into_iter();

        let fits = memoized_sizer.chunk_fits_with_boundaries(
            0,
            "12345678901",
            &ChunkCapacity::new(10),
            &mut lower_boundaries,
            Trim::All,
        );

        assert_eq!(fits, Ordering::Greater);
        assert_eq!(sizer.calls.load(atomic::Ordering::SeqCst), 1);
        assert_eq!(lower_boundaries.next(), Some(usize::MAX));
    }

    // A 90-byte chunk with capacity 10: the only boundary lies past the chunk
    // end, so probing stops, the whole chunk is measured once, and the
    // boundary iterator has been drained.
    #[test]
    fn boundary_probe_stops_when_next_boundary_exceeds_chunk_end() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        let chunk = "1234567890".repeat(9);
        let mut lower_boundaries = [usize::MAX].into_iter();

        let fits = memoized_sizer.chunk_fits_with_boundaries(
            0,
            &chunk,
            &ChunkCapacity::new(10),
            &mut lower_boundaries,
            Trim::All,
        );

        assert_eq!(fits, Ordering::Greater);
        assert_eq!(sizer.calls.load(atomic::Ordering::SeqCst), 1);
        assert_eq!(lower_boundaries.next(), None);
    }

    // Defaults: `Characters` sizer and trimming enabled.
    #[test]
    fn basic_chunk_config() {
        let config = ChunkConfig::new(10);
        assert_eq!(config.capacity, 10.into());
        assert_eq!(config.sizer, Characters);
        assert!(config.trim());
    }

    #[test]
    fn disable_trimming() {
        let config = ChunkConfig::new(10).with_trim(false);
        assert!(!config.trim());
    }

    // Swapping the sizer keeps the other settings.
    #[test]
    fn new_sizer() {
        #[derive(Debug, PartialEq)]
        struct BasicSizer;

        impl ChunkSizer for BasicSizer {
            fn size(&self, _chunk: &str) -> usize {
                unimplemented!()
            }
        }

        let config = ChunkConfig::new(10).with_sizer(BasicSizer);
        assert_eq!(config.capacity, 10.into());
        assert_eq!(config.sizer, BasicSizer);
        assert!(config.trim());
    }

    #[test]
    fn chunk_capacity_max_and_desired_equal() {
        let capacity = ChunkCapacity::new(10);
        assert_eq!(capacity.desired(), 10);
        assert_eq!(capacity.max(), 10);
    }

    #[test]
    fn chunk_capacity_can_adjust_max() {
        let capacity = ChunkCapacity::new(10).with_max(20).unwrap();
        assert_eq!(capacity.desired(), 10);
        assert_eq!(capacity.max(), 20);
    }

    // `ChunkCapacity` is `Copy`, so the original value is still usable after
    // the failed `with_max` call.
    #[test]
    fn chunk_capacity_max_cant_be_less_than_desired() {
        let capacity = ChunkCapacity::new(10);
        let err = capacity.with_max(5).unwrap_err();
        assert_eq!(
            err.to_string(),
            "Max chunk size must be greater than or equal to the desired chunk size"
        );
        assert_eq!(capacity.desired(), 10);
        assert_eq!(capacity.max(), 10);
    }

    #[test]
    fn set_chunk_overlap() {
        let config = ChunkConfig::new(10).with_overlap(5).unwrap();
        assert_eq!(config.overlap(), 5);
    }

    #[test]
    fn cant_set_overlap_larger_than_capacity() {
        let chunk_config = ChunkConfig::new(5);
        let err = chunk_config.with_overlap(10).unwrap_err();
        assert_eq!(
            err.to_string(),
            "The overlap is larger than or equal to the desired chunk capacity"
        );
    }

    // The overlap is checked against the desired size, not the max.
    #[test]
    fn cant_set_overlap_larger_than_desired() {
        let chunk_config = ChunkConfig::new(5..15);
        let err = chunk_config.with_overlap(10).unwrap_err();
        assert_eq!(
            err.to_string(),
            "The overlap is larger than or equal to the desired chunk capacity"
        );
    }

    // The remaining tests check that each wrapper type can be used as a sizer.
    #[test]
    fn chunk_size_reference() {
        let config = ChunkConfig::new(1).with_sizer(&Characters);
        config.sizer().size("chunk");
    }

    #[test]
    fn chunk_size_cow() {
        let sizer: Cow<'_, Characters> = Cow::Owned(Characters);
        let config = ChunkConfig::new(1).with_sizer(sizer);
        config.sizer().size("chunk");

        let sizer = Cow::Borrowed(&Characters);
        let config = ChunkConfig::new(1).with_sizer(sizer);
        config.sizer().size("chunk");
    }

    #[test]
    fn chunk_size_arc() {
        let sizer = Arc::new(Characters);
        let config = ChunkConfig::new(1).with_sizer(sizer);
        config.sizer().size("chunk");
    }

    #[test]
    fn chunk_size_ref() {
        let sizer = RefCell::new(Characters);
        let config = ChunkConfig::new(1).with_sizer(sizer.borrow());
        config.sizer().size("chunk");
    }

    #[test]
    fn chunk_size_ref_mut() {
        let sizer = RefCell::new(Characters);
        let config = ChunkConfig::new(1).with_sizer(sizer.borrow_mut());
        config.sizer().size("chunk");
    }

    #[test]
    fn chunk_size_box() {
        let sizer = Box::new(Characters);
        let config = ChunkConfig::new(1).with_sizer(sizer);
        config.sizer().size("chunk");
    }

    #[test]
    fn chunk_size_rc() {
        let sizer = Rc::new(Characters);
        let config = ChunkConfig::new(1).with_sizer(sizer);
        config.sizer().size("chunk");
    }
}