1use std::{
2 borrow::Cow,
3 cell::{Ref, RefMut},
4 cmp::Ordering,
5 fmt,
6 ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive},
7 rc::Rc,
8 sync::Arc,
9};
10
11use ahash::AHashMap;
12use itertools::Itertools;
13use thiserror::Error;
14
15mod characters;
16#[cfg(feature = "tokenizers")]
17mod huggingface;
18#[cfg(feature = "tiktoken-rs")]
19mod tiktoken;
20
21use crate::trim::Trim;
22pub use characters::Characters;
23
24#[derive(Error, Debug)]
28#[error(transparent)]
29pub struct ChunkCapacityError(#[from] ChunkCapacityErrorRepr);
30
31#[derive(Error, Debug)]
33enum ChunkCapacityErrorRepr {
34 #[error("Max chunk size must be greater than or equal to the desired chunk size")]
35 MaxLessThanDesired,
36}
37
/// Acceptable size of a chunk, expressed as an inclusive `desired..=max` range in the
/// units reported by the active [`ChunkSizer`].
///
/// Besides `Copy`/`Clone`/`Debug`/`PartialEq`, the type also implements `Eq` and `Hash`
/// so capacities can be used as keys in maps and sets.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ChunkCapacity {
    /// Smallest size that [`ChunkCapacity::fits`] reports as `Ordering::Equal`.
    pub(crate) desired: usize,
    /// Largest size that [`ChunkCapacity::fits`] reports as `Ordering::Equal`.
    pub(crate) max: usize,
}
67
68impl ChunkCapacity {
69 #[must_use]
71 pub fn new(size: usize) -> Self {
72 Self {
73 desired: size,
74 max: size,
75 }
76 }
77
78 #[must_use]
84 pub fn desired(&self) -> usize {
85 self.desired
86 }
87
88 #[must_use]
93 pub fn max(&self) -> usize {
94 self.max
95 }
96
97 pub fn with_max(mut self, max: usize) -> Result<Self, ChunkCapacityError> {
110 if max < self.desired {
111 Err(ChunkCapacityError(
112 ChunkCapacityErrorRepr::MaxLessThanDesired,
113 ))
114 } else {
115 self.max = max;
116 Ok(self)
117 }
118 }
119
120 #[must_use]
126 pub fn fits(&self, chunk_size: usize) -> Ordering {
127 if chunk_size < self.desired {
128 Ordering::Less
129 } else if chunk_size > self.max {
130 Ordering::Greater
131 } else {
132 Ordering::Equal
133 }
134 }
135}
136
137impl From<usize> for ChunkCapacity {
138 fn from(size: usize) -> Self {
139 ChunkCapacity::new(size)
140 }
141}
142
143impl From<Range<usize>> for ChunkCapacity {
144 fn from(range: Range<usize>) -> Self {
145 ChunkCapacity::new(range.start)
146 .with_max(range.end.saturating_sub(1).max(range.start))
147 .expect("invalid range")
148 }
149}
150
151impl From<RangeFrom<usize>> for ChunkCapacity {
152 fn from(range: RangeFrom<usize>) -> Self {
153 ChunkCapacity::new(range.start)
154 .with_max(usize::MAX)
155 .expect("invalid range")
156 }
157}
158
159impl From<RangeFull> for ChunkCapacity {
160 fn from(_: RangeFull) -> Self {
161 ChunkCapacity::new(usize::MIN)
162 .with_max(usize::MAX)
163 .expect("invalid range")
164 }
165}
166
167impl From<RangeInclusive<usize>> for ChunkCapacity {
168 fn from(range: RangeInclusive<usize>) -> Self {
169 ChunkCapacity::new(*range.start())
170 .with_max(*range.end())
171 .expect("invalid range")
172 }
173}
174
175impl From<RangeTo<usize>> for ChunkCapacity {
176 fn from(range: RangeTo<usize>) -> Self {
177 ChunkCapacity::new(usize::MIN)
178 .with_max(range.end.saturating_sub(1))
179 .expect("invalid range")
180 }
181}
182
183impl From<RangeToInclusive<usize>> for ChunkCapacity {
184 fn from(range: RangeToInclusive<usize>) -> Self {
185 ChunkCapacity::new(usize::MIN)
186 .with_max(range.end)
187 .expect("invalid range")
188 }
189}
190
/// Measures a piece of text for chunking.
///
/// The unit is up to the implementor (for example [`Characters`]). The values returned
/// here are what a [`ChunkCapacity`] is compared against.
pub trait ChunkSizer {
    /// Returns the size of `chunk` in this sizer's unit.
    fn size(&self, chunk: &str) -> usize;
}

// The forwarding impls below accept `?Sized` targets. This lets trait objects such as
// `&dyn ChunkSizer`, `Box<dyn ChunkSizer>`, `Rc<dyn ChunkSizer>` and `Arc<dyn ChunkSizer>`
// be used anywhere a `ChunkSizer` is expected, e.g. when the sizer is chosen at runtime.

impl<T> ChunkSizer for &T
where
    T: ChunkSizer + ?Sized,
{
    fn size(&self, chunk: &str) -> usize {
        (**self).size(chunk)
    }
}

impl<T> ChunkSizer for Ref<'_, T>
where
    T: ChunkSizer + ?Sized,
{
    fn size(&self, chunk: &str) -> usize {
        self.deref().size(chunk)
    }
}

impl<T> ChunkSizer for RefMut<'_, T>
where
    T: ChunkSizer + ?Sized,
{
    fn size(&self, chunk: &str) -> usize {
        self.deref().size(chunk)
    }
}

impl<T> ChunkSizer for Box<T>
where
    T: ChunkSizer + ?Sized,
{
    fn size(&self, chunk: &str) -> usize {
        self.deref().size(chunk)
    }
}

impl<T> ChunkSizer for Cow<'_, T>
where
    T: ChunkSizer + ToOwned + ?Sized,
    <T as ToOwned>::Owned: ChunkSizer,
{
    fn size(&self, chunk: &str) -> usize {
        self.as_ref().size(chunk)
    }
}

impl<T> ChunkSizer for Rc<T>
where
    T: ChunkSizer + ?Sized,
{
    fn size(&self, chunk: &str) -> usize {
        self.deref().size(chunk)
    }
}

impl<T> ChunkSizer for Arc<T>
where
    T: ChunkSizer + ?Sized,
{
    fn size(&self, chunk: &str) -> usize {
        self.deref().size(chunk)
    }
}
260
261#[derive(Error, Debug)]
265#[error(transparent)]
266pub struct ChunkConfigError(#[from] ChunkConfigErrorRepr);
267
268#[derive(Error, Debug)]
270enum ChunkConfigErrorRepr {
271 #[error("The overlap is larger than or equal to the desired chunk capacity")]
272 OverlapLargerThanCapacity,
273}
274
275#[derive(Debug)]
277pub struct ChunkConfig<Sizer>
278where
279 Sizer: ChunkSizer,
280{
281 pub(crate) capacity: ChunkCapacity,
283 pub(crate) overlap: usize,
285 pub(crate) sizer: Sizer,
287 pub(crate) trim: bool,
289}
290
291impl ChunkConfig<Characters> {
292 #[must_use]
300 pub fn new(capacity: impl Into<ChunkCapacity>) -> Self {
301 Self {
302 capacity: capacity.into(),
303 overlap: 0,
304 sizer: Characters,
305 trim: true,
306 }
307 }
308}
309
310impl<Sizer> ChunkConfig<Sizer>
311where
312 Sizer: ChunkSizer,
313{
314 pub fn capacity(&self) -> &ChunkCapacity {
316 &self.capacity
317 }
318
319 pub fn overlap(&self) -> usize {
321 self.overlap
322 }
323
324 pub fn with_overlap(mut self, overlap: usize) -> Result<Self, ChunkConfigError> {
330 if overlap >= self.capacity.desired {
331 Err(ChunkConfigError(
332 ChunkConfigErrorRepr::OverlapLargerThanCapacity,
333 ))
334 } else {
335 self.overlap = overlap;
336 Ok(self)
337 }
338 }
339
340 pub fn sizer(&self) -> &Sizer {
342 &self.sizer
343 }
344
345 #[must_use]
353 pub fn with_sizer<S: ChunkSizer>(self, sizer: S) -> ChunkConfig<S> {
354 ChunkConfig {
355 capacity: self.capacity,
356 overlap: self.overlap,
357 sizer,
358 trim: self.trim,
359 }
360 }
361
362 pub fn trim(&self) -> bool {
364 self.trim
365 }
366
367 #[must_use]
380 pub fn with_trim(mut self, trim: bool) -> Self {
381 self.trim = trim;
382 self
383 }
384}
385
386impl<T> From<T> for ChunkConfig<Characters>
387where
388 T: Into<ChunkCapacity>,
389{
390 fn from(capacity: T) -> Self {
391 Self::new(capacity)
392 }
393}
394
395#[derive(Debug)]
399pub struct MemoizedChunkSizer<'sizer, Sizer>
400where
401 Sizer: ChunkSizer,
402{
403 size_cache: AHashMap<Range<usize>, usize>,
405 sizer: &'sizer Sizer,
407}
408
409impl<'sizer, Sizer> MemoizedChunkSizer<'sizer, Sizer>
410where
411 Sizer: ChunkSizer,
412{
413 pub fn new(sizer: &'sizer Sizer) -> Self {
415 Self {
416 size_cache: AHashMap::new(),
417 sizer,
418 }
419 }
420
421 pub fn chunk_size(&mut self, offset: usize, chunk: &str, trim: Trim) -> usize {
424 let (offset, chunk) = trim.trim(offset, chunk);
425 *self
426 .size_cache
427 .entry(offset..(offset + chunk.len()))
428 .or_insert_with(|| self.sizer.size(chunk))
429 }
430
431 pub fn find_correct_level<'text, L: fmt::Debug>(
433 &mut self,
434 offset: usize,
435 capacity: &ChunkCapacity,
436 levels_with_first_chunk: impl Iterator<Item = (L, &'text str)>,
437 trim: Trim,
438 ) -> (Option<L>, Option<usize>) {
439 let mut semantic_level = None;
440 let mut max_offset = None;
441
442 let levels_with_first_chunk =
444 levels_with_first_chunk.coalesce(|(a_level, a_str), (b_level, b_str)| {
445 if a_str.len() >= b_str.len() {
446 Ok((b_level, b_str))
447 } else {
448 Err(((a_level, a_str), (b_level, b_str)))
449 }
450 });
451
452 for (level, str) in levels_with_first_chunk {
453 let len = str.len();
455 if len > capacity.max {
456 let chunk_size = self.chunk_size(offset, str, trim);
457 let fits = capacity.fits(chunk_size);
458 if fits.is_gt() {
460 max_offset = Some(offset + len);
461 break;
462 }
463 }
464 semantic_level = Some(level);
466 }
467
468 (semantic_level, max_offset)
469 }
470
471 pub fn clear_cache(&mut self) {
474 self.size_cache.clear();
475 }
476}
477
#[cfg(test)]
mod tests {
    use std::{
        cell::RefCell,
        sync::atomic::{self, AtomicUsize},
    };

    use crate::trim::Trim;

    use super::*;

    /// Five-character sample measured by every capacity test below.
    const SAMPLE: &str = "12345";

    /// Message produced whenever an overlap is rejected.
    const OVERLAP_ERROR: &str = "The overlap is larger than or equal to the desired chunk capacity";

    /// Converts `capacity` and classifies the [`Characters`] size of [`SAMPLE`] against it.
    fn sample_fit(capacity: impl Into<ChunkCapacity>) -> Ordering {
        let capacity: ChunkCapacity = capacity.into();
        capacity.fits(Characters.size(SAMPLE))
    }

    #[test]
    fn check_chunk_capacity() {
        assert_eq!(sample_fit(4), Ordering::Greater);
        assert_eq!(sample_fit(5), Ordering::Equal);
        assert_eq!(sample_fit(6), Ordering::Less);
    }

    #[test]
    fn check_chunk_capacity_for_range() {
        assert_eq!(sample_fit(0..0), Ordering::Greater);
        assert_eq!(sample_fit(0..5), Ordering::Greater);
        assert_eq!(sample_fit(5..6), Ordering::Equal);
        assert_eq!(sample_fit(6..100), Ordering::Less);
    }

    #[test]
    fn check_chunk_capacity_for_range_from() {
        assert_eq!(sample_fit(0..), Ordering::Equal);
        assert_eq!(sample_fit(5..), Ordering::Equal);
        assert_eq!(sample_fit(6..), Ordering::Less);
    }

    #[test]
    fn check_chunk_capacity_for_range_full() {
        assert_eq!(sample_fit(..), Ordering::Equal);
    }

    #[test]
    fn check_chunk_capacity_for_range_inclusive() {
        assert_eq!(sample_fit(0..=4), Ordering::Greater);
        assert_eq!(sample_fit(5..=6), Ordering::Equal);
        assert_eq!(sample_fit(4..=5), Ordering::Equal);
        assert_eq!(sample_fit(6..=100), Ordering::Less);
    }

    #[test]
    fn check_chunk_capacity_for_range_to() {
        assert_eq!(sample_fit(..0), Ordering::Greater);
        assert_eq!(sample_fit(..5), Ordering::Greater);
        assert_eq!(sample_fit(..6), Ordering::Equal);
    }

    #[test]
    fn check_chunk_capacity_for_range_to_inclusive() {
        assert_eq!(sample_fit(..=4), Ordering::Greater);
        assert_eq!(sample_fit(..=5), Ordering::Equal);
        assert_eq!(sample_fit(..=6), Ordering::Equal);
    }

    /// Sizer that counts how many times it has been asked to measure text.
    #[derive(Default)]
    struct CountingSizer {
        calls: AtomicUsize,
    }

    impl ChunkSizer for CountingSizer {
        fn size(&self, chunk: &str) -> usize {
            self.calls.fetch_add(1, atomic::Ordering::SeqCst);
            Characters.size(chunk)
        }
    }

    /// Number of times the sizer wrapped by `memoized` has actually been invoked.
    fn sizer_calls(memoized: &MemoizedChunkSizer<'_, CountingSizer>) -> usize {
        memoized.sizer.calls.load(atomic::Ordering::SeqCst)
    }

    #[test]
    fn memoized_sizer_only_calculates_once_per_text() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        let text = "1234567890";
        (0..10).for_each(|_| {
            memoized_sizer.chunk_size(0, text, Trim::All);
        });

        assert_eq!(sizer_calls(&memoized_sizer), 1);
    }

    #[test]
    fn memoized_sizer_calculates_once_per_different_text() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        let text = "1234567890";
        for end in 0..10 {
            memoized_sizer.chunk_size(0, &text[..end], Trim::All);
        }

        assert_eq!(sizer_calls(&memoized_sizer), 10);
    }

    #[test]
    fn can_clear_cache_on_memoized_sizer() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        let text = "1234567890";
        (0..10).for_each(|_| {
            memoized_sizer.chunk_size(0, text, Trim::All);
            memoized_sizer.clear_cache();
        });

        assert_eq!(sizer_calls(&memoized_sizer), 10);
    }

    #[test]
    fn basic_chunk_config() {
        let config = ChunkConfig::new(10);
        assert_eq!(config.capacity, ChunkCapacity::from(10));
        assert_eq!(config.sizer, Characters);
        assert!(config.trim());
    }

    #[test]
    fn disable_trimming() {
        let config = ChunkConfig::new(10).with_trim(false);
        assert!(!config.trim());
    }

    #[test]
    fn new_sizer() {
        #[derive(Debug, PartialEq)]
        struct BasicSizer;

        impl ChunkSizer for BasicSizer {
            fn size(&self, _chunk: &str) -> usize {
                unimplemented!()
            }
        }

        let config = ChunkConfig::new(10).with_sizer(BasicSizer);
        assert_eq!(config.capacity, ChunkCapacity::from(10));
        assert_eq!(config.sizer, BasicSizer);
        assert!(config.trim());
    }

    #[test]
    fn chunk_capacity_max_and_desired_equal() {
        let capacity = ChunkCapacity::new(10);
        assert_eq!((capacity.desired(), capacity.max()), (10, 10));
    }

    #[test]
    fn chunk_capacity_can_adjust_max() {
        let capacity = ChunkCapacity::new(10).with_max(20).unwrap();
        assert_eq!((capacity.desired(), capacity.max()), (10, 20));
    }

    #[test]
    fn chunk_capacity_max_cant_be_less_than_desired() {
        let capacity = ChunkCapacity::new(10);
        let message = capacity.with_max(5).unwrap_err().to_string();
        assert_eq!(
            message,
            "Max chunk size must be greater than or equal to the desired chunk size"
        );
        // `with_max` consumes a copy, so the original capacity is unchanged.
        assert_eq!((capacity.desired(), capacity.max()), (10, 10));
    }

    #[test]
    fn set_chunk_overlap() {
        let config = ChunkConfig::new(10).with_overlap(5).unwrap();
        assert_eq!(config.overlap(), 5);
    }

    #[test]
    fn cant_set_overlap_larger_than_capacity() {
        let err = ChunkConfig::new(5).with_overlap(10).unwrap_err();
        assert_eq!(err.to_string(), OVERLAP_ERROR);
    }

    #[test]
    fn cant_set_overlap_larger_than_desired() {
        let err = ChunkConfig::new(5..15).with_overlap(10).unwrap_err();
        assert_eq!(err.to_string(), OVERLAP_ERROR);
    }

    /// Builds a single-unit config around `sizer` and measures a fixed word through it.
    fn measure_with<S: ChunkSizer>(sizer: S) -> usize {
        ChunkConfig::new(1).with_sizer(sizer).sizer().size("chunk")
    }

    #[test]
    fn chunk_size_reference() {
        measure_with(&Characters);
    }

    #[test]
    fn chunk_size_cow() {
        let owned: Cow<'_, Characters> = Cow::Owned(Characters);
        measure_with(owned);

        measure_with(Cow::Borrowed(&Characters));
    }

    #[test]
    fn chunk_size_arc() {
        measure_with(Arc::new(Characters));
    }

    #[test]
    fn chunk_size_ref() {
        let cell = RefCell::new(Characters);
        measure_with(cell.borrow());
    }

    #[test]
    fn chunk_size_ref_mut() {
        let cell = RefCell::new(Characters);
        measure_with(cell.borrow_mut());
    }

    #[test]
    fn chunk_size_box() {
        measure_with(Box::new(Characters));
    }

    #[test]
    fn chunk_size_rc() {
        measure_with(Rc::new(Characters));
    }
}