1use std::{
2 cmp::Ordering,
3 fmt,
4 ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive},
5};
6
7use ahash::AHashMap;
8use itertools::Itertools;
9use thiserror::Error;
10
11mod characters;
12#[cfg(feature = "tokenizers")]
13mod huggingface;
14#[cfg(feature = "rust-tokenizers")]
15mod rust_tokenizers;
16#[cfg(feature = "tiktoken-rs")]
17mod tiktoken;
18
19use crate::trim::Trim;
20pub use characters::Characters;
21
22#[derive(Error, Debug)]
26#[error(transparent)]
27pub struct ChunkCapacityError(#[from] ChunkCapacityErrorRepr);
28
29#[derive(Error, Debug)]
31enum ChunkCapacityErrorRepr {
32 #[error("Max chunk size must be greater than or equal to the desired chunk size")]
33 MaxLessThanDesired,
34}
35
/// The size range a chunk is allowed to occupy.
///
/// A chunk size is judged against both bounds by [`ChunkCapacity::fits`]:
/// anything below `desired` is too small, anything above `max` is too large,
/// and everything in between is acceptable.
///
/// Both fields are plain `usize` values, so equality is total and the type
/// derives `Eq` alongside `PartialEq`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct ChunkCapacity {
    /// Smallest size at which a chunk is considered to fit.
    pub(crate) desired: usize,
    /// Largest size a chunk may have and still fit.
    pub(crate) max: usize,
}
65
66impl ChunkCapacity {
67 #[must_use]
69 pub fn new(size: usize) -> Self {
70 Self {
71 desired: size,
72 max: size,
73 }
74 }
75
76 #[must_use]
82 pub fn desired(&self) -> usize {
83 self.desired
84 }
85
86 #[must_use]
91 pub fn max(&self) -> usize {
92 self.max
93 }
94
95 pub fn with_max(mut self, max: usize) -> Result<Self, ChunkCapacityError> {
108 if max < self.desired {
109 Err(ChunkCapacityError(
110 ChunkCapacityErrorRepr::MaxLessThanDesired,
111 ))
112 } else {
113 self.max = max;
114 Ok(self)
115 }
116 }
117
118 #[must_use]
124 pub fn fits(&self, chunk_size: usize) -> Ordering {
125 if chunk_size < self.desired {
126 Ordering::Less
127 } else if chunk_size > self.max {
128 Ordering::Greater
129 } else {
130 Ordering::Equal
131 }
132 }
133}
134
135impl From<usize> for ChunkCapacity {
136 fn from(size: usize) -> Self {
137 ChunkCapacity::new(size)
138 }
139}
140
141impl From<Range<usize>> for ChunkCapacity {
142 fn from(range: Range<usize>) -> Self {
143 ChunkCapacity::new(range.start)
144 .with_max(range.end.saturating_sub(1).max(range.start))
145 .expect("invalid range")
146 }
147}
148
149impl From<RangeFrom<usize>> for ChunkCapacity {
150 fn from(range: RangeFrom<usize>) -> Self {
151 ChunkCapacity::new(range.start)
152 .with_max(usize::MAX)
153 .expect("invalid range")
154 }
155}
156
157impl From<RangeFull> for ChunkCapacity {
158 fn from(_: RangeFull) -> Self {
159 ChunkCapacity::new(usize::MIN)
160 .with_max(usize::MAX)
161 .expect("invalid range")
162 }
163}
164
165impl From<RangeInclusive<usize>> for ChunkCapacity {
166 fn from(range: RangeInclusive<usize>) -> Self {
167 ChunkCapacity::new(*range.start())
168 .with_max(*range.end())
169 .expect("invalid range")
170 }
171}
172
173impl From<RangeTo<usize>> for ChunkCapacity {
174 fn from(range: RangeTo<usize>) -> Self {
175 ChunkCapacity::new(usize::MIN)
176 .with_max(range.end.saturating_sub(1))
177 .expect("invalid range")
178 }
179}
180
181impl From<RangeToInclusive<usize>> for ChunkCapacity {
182 fn from(range: RangeToInclusive<usize>) -> Self {
183 ChunkCapacity::new(usize::MIN)
184 .with_max(range.end)
185 .expect("invalid range")
186 }
187}
188
/// Strategy for measuring how large a piece of text is.
///
/// The unit of measurement is up to the implementor; the result is what gets
/// compared against a [`ChunkCapacity`].
pub trait ChunkSizer {
    /// Returns the size of `chunk` in this sizer's unit.
    fn size(&self, chunk: &str) -> usize;
}
194
195#[derive(Error, Debug)]
199#[error(transparent)]
200pub struct ChunkConfigError(#[from] ChunkConfigErrorRepr);
201
202#[derive(Error, Debug)]
204enum ChunkConfigErrorRepr {
205 #[error("The overlap is larger than or equal to the desired chunk capacity")]
206 OverlapLargerThanCapacity,
207}
208
209#[derive(Debug)]
211pub struct ChunkConfig<Sizer>
212where
213 Sizer: ChunkSizer,
214{
215 pub(crate) capacity: ChunkCapacity,
217 pub(crate) overlap: usize,
219 pub(crate) sizer: Sizer,
221 pub(crate) trim: bool,
223}
224
225impl ChunkConfig<Characters> {
226 #[must_use]
234 pub fn new(capacity: impl Into<ChunkCapacity>) -> Self {
235 Self {
236 capacity: capacity.into(),
237 overlap: 0,
238 sizer: Characters,
239 trim: true,
240 }
241 }
242}
243
244impl<Sizer> ChunkConfig<Sizer>
245where
246 Sizer: ChunkSizer,
247{
248 pub fn capacity(&self) -> &ChunkCapacity {
250 &self.capacity
251 }
252
253 pub fn overlap(&self) -> usize {
255 self.overlap
256 }
257
258 pub fn with_overlap(mut self, overlap: usize) -> Result<Self, ChunkConfigError> {
264 if overlap >= self.capacity.desired {
265 Err(ChunkConfigError(
266 ChunkConfigErrorRepr::OverlapLargerThanCapacity,
267 ))
268 } else {
269 self.overlap = overlap;
270 Ok(self)
271 }
272 }
273
274 pub fn sizer(&self) -> &Sizer {
276 &self.sizer
277 }
278
279 #[must_use]
287 pub fn with_sizer<S: ChunkSizer>(self, sizer: S) -> ChunkConfig<S> {
288 ChunkConfig {
289 capacity: self.capacity,
290 overlap: self.overlap,
291 sizer,
292 trim: self.trim,
293 }
294 }
295
296 pub fn trim(&self) -> bool {
298 self.trim
299 }
300
301 #[must_use]
314 pub fn with_trim(mut self, trim: bool) -> Self {
315 self.trim = trim;
316 self
317 }
318}
319
320impl<T> From<T> for ChunkConfig<Characters>
321where
322 T: Into<ChunkCapacity>,
323{
324 fn from(capacity: T) -> Self {
325 Self::new(capacity)
326 }
327}
328
329#[derive(Debug)]
333pub struct MemoizedChunkSizer<'sizer, Sizer>
334where
335 Sizer: ChunkSizer,
336{
337 size_cache: AHashMap<Range<usize>, usize>,
339 sizer: &'sizer Sizer,
341}
342
343impl<'sizer, Sizer> MemoizedChunkSizer<'sizer, Sizer>
344where
345 Sizer: ChunkSizer,
346{
347 pub fn new(sizer: &'sizer Sizer) -> Self {
349 Self {
350 size_cache: AHashMap::new(),
351 sizer,
352 }
353 }
354
355 pub fn chunk_size(&mut self, offset: usize, chunk: &str, trim: Trim) -> usize {
358 let (offset, chunk) = trim.trim(offset, chunk);
359 *self
360 .size_cache
361 .entry(offset..(offset + chunk.len()))
362 .or_insert_with(|| self.sizer.size(chunk))
363 }
364
365 pub fn find_correct_level<'text, L: fmt::Debug>(
367 &mut self,
368 offset: usize,
369 capacity: &ChunkCapacity,
370 levels_with_first_chunk: impl Iterator<Item = (L, &'text str)>,
371 trim: Trim,
372 ) -> (Option<L>, Option<usize>) {
373 let mut semantic_level = None;
374 let mut max_offset = None;
375
376 let levels_with_first_chunk =
378 levels_with_first_chunk.coalesce(|(a_level, a_str), (b_level, b_str)| {
379 if a_str.len() >= b_str.len() {
380 Ok((b_level, b_str))
381 } else {
382 Err(((a_level, a_str), (b_level, b_str)))
383 }
384 });
385
386 for (level, str) in levels_with_first_chunk {
387 let len = str.len();
389 if len > capacity.max {
390 let chunk_size = self.chunk_size(offset, str, trim);
391 let fits = capacity.fits(chunk_size);
392 if fits.is_gt() {
394 max_offset = Some(offset + len);
395 break;
396 }
397 }
398 semantic_level = Some(level);
400 }
401
402 (semantic_level, max_offset)
403 }
404
405 pub fn clear_cache(&mut self) {
408 self.size_cache.clear();
409 }
410}
411
#[cfg(test)]
mod tests {
    use std::sync::atomic::{self, AtomicUsize};

    use crate::trim::Trim;

    use super::*;

    /// Text measured by the capacity checks.
    const CHUNK: &str = "12345";

    /// Text measured by the memoization checks.
    const TEXT: &str = "1234567890";

    /// Message expected from a rejected overlap.
    const OVERLAP_MESSAGE: &str =
        "The overlap is larger than or equal to the desired chunk capacity";

    /// Compares the `Characters` size of [`CHUNK`] against `capacity`.
    fn fit(capacity: impl Into<ChunkCapacity>) -> Ordering {
        capacity.into().fits(Characters.size(CHUNK))
    }

    #[test]
    fn check_chunk_capacity() {
        assert_eq!(fit(4), Ordering::Greater);
        assert_eq!(fit(5), Ordering::Equal);
        assert_eq!(fit(6), Ordering::Less);
    }

    #[test]
    fn check_chunk_capacity_for_range() {
        assert_eq!(fit(0..0), Ordering::Greater);
        assert_eq!(fit(0..5), Ordering::Greater);
        assert_eq!(fit(5..6), Ordering::Equal);
        assert_eq!(fit(6..100), Ordering::Less);
    }

    #[test]
    fn check_chunk_capacity_for_range_from() {
        assert_eq!(fit(0..), Ordering::Equal);
        assert_eq!(fit(5..), Ordering::Equal);
        assert_eq!(fit(6..), Ordering::Less);
    }

    #[test]
    fn check_chunk_capacity_for_range_full() {
        assert_eq!(fit(..), Ordering::Equal);
    }

    #[test]
    fn check_chunk_capacity_for_range_inclusive() {
        assert_eq!(fit(0..=4), Ordering::Greater);
        assert_eq!(fit(5..=6), Ordering::Equal);
        assert_eq!(fit(4..=5), Ordering::Equal);
        assert_eq!(fit(6..=100), Ordering::Less);
    }

    #[test]
    fn check_chunk_capacity_for_range_to() {
        assert_eq!(fit(..0), Ordering::Greater);
        assert_eq!(fit(..5), Ordering::Greater);
        assert_eq!(fit(..6), Ordering::Equal);
    }

    #[test]
    fn check_chunk_capacity_for_range_to_inclusive() {
        assert_eq!(fit(..=4), Ordering::Greater);
        assert_eq!(fit(..=5), Ordering::Equal);
        assert_eq!(fit(..=6), Ordering::Equal);
    }

    /// Sizer that counts how often it is asked to measure something.
    #[derive(Default)]
    struct CountingSizer {
        calls: AtomicUsize,
    }

    impl CountingSizer {
        /// Number of `size` calls made so far.
        fn call_count(&self) -> usize {
            self.calls.load(atomic::Ordering::SeqCst)
        }
    }

    impl ChunkSizer for CountingSizer {
        fn size(&self, chunk: &str) -> usize {
            self.calls.fetch_add(1, atomic::Ordering::SeqCst);
            Characters.size(chunk)
        }
    }

    #[test]
    fn memoized_sizer_only_calculates_once_per_text() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        (0..10).for_each(|_| {
            memoized_sizer.chunk_size(0, TEXT, Trim::All);
        });

        assert_eq!(sizer.call_count(), 1);
    }

    #[test]
    fn memoized_sizer_calculates_once_per_different_text() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        for end in 0..10 {
            memoized_sizer.chunk_size(0, &TEXT[..end], Trim::All);
        }

        assert_eq!(sizer.call_count(), 10);
    }

    #[test]
    fn can_clear_cache_on_memoized_sizer() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        (0..10).for_each(|_| {
            memoized_sizer.chunk_size(0, TEXT, Trim::All);
            memoized_sizer.clear_cache();
        });

        assert_eq!(sizer.call_count(), 10);
    }

    #[test]
    fn basic_chunk_config() {
        let config = ChunkConfig::new(10);
        assert_eq!(config.capacity, ChunkCapacity::from(10));
        assert_eq!(config.sizer, Characters);
        assert!(config.trim());
    }

    #[test]
    fn disable_trimming() {
        let trimming = ChunkConfig::new(10).with_trim(false).trim();
        assert!(!trimming);
    }

    #[test]
    fn new_sizer() {
        #[derive(Debug, PartialEq)]
        struct BasicSizer;

        impl ChunkSizer for BasicSizer {
            fn size(&self, _chunk: &str) -> usize {
                unimplemented!()
            }
        }

        let config = ChunkConfig::new(10).with_sizer(BasicSizer);
        assert_eq!(config.capacity, ChunkCapacity::from(10));
        assert_eq!(config.sizer, BasicSizer);
        assert!(config.trim());
    }

    #[test]
    fn chunk_capacity_max_and_desired_equal() {
        let capacity = ChunkCapacity::new(10);
        assert_eq!((capacity.desired(), capacity.max()), (10, 10));
    }

    #[test]
    fn chunk_capacity_can_adjust_max() {
        let capacity = ChunkCapacity::new(10).with_max(20).unwrap();
        assert_eq!((capacity.desired(), capacity.max()), (10, 20));
    }

    #[test]
    fn chunk_capacity_max_cant_be_less_than_desired() {
        let capacity = ChunkCapacity::new(10);
        let message = capacity.with_max(5).unwrap_err().to_string();
        assert_eq!(
            message,
            "Max chunk size must be greater than or equal to the desired chunk size"
        );
        // `ChunkCapacity` is `Copy`, so the original value is still usable
        // and unchanged after the failed call.
        assert_eq!((capacity.desired(), capacity.max()), (10, 10));
    }

    #[test]
    fn set_chunk_overlap() {
        let overlap = ChunkConfig::new(10).with_overlap(5).unwrap().overlap();
        assert_eq!(overlap, 5);
    }

    #[test]
    fn cant_set_overlap_larger_than_capacity() {
        let err = ChunkConfig::new(5).with_overlap(10).unwrap_err();
        assert_eq!(err.to_string(), OVERLAP_MESSAGE);
    }

    #[test]
    fn cant_set_overlap_larger_than_desired() {
        let err = ChunkConfig::new(5..15).with_overlap(10).unwrap_err();
        assert_eq!(err.to_string(), OVERLAP_MESSAGE);
    }
}