text_splitter/
chunk_size.rs

1use std::{
2    cmp::Ordering,
3    fmt,
4    ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive},
5};
6
7use ahash::AHashMap;
8use itertools::Itertools;
9use thiserror::Error;
10
11mod characters;
12#[cfg(feature = "tokenizers")]
13mod huggingface;
14#[cfg(feature = "rust-tokenizers")]
15mod rust_tokenizers;
16#[cfg(feature = "tiktoken-rs")]
17mod tiktoken;
18
19use crate::trim::Trim;
20pub use characters::Characters;
21
22/// Indicates there was an error with the chunk capacity configuration.
23/// The `Display` implementation will provide a human-readable error message to
24/// help debug the issue that caused the error.
25#[derive(Error, Debug)]
26#[error(transparent)]
27pub struct ChunkCapacityError(#[from] ChunkCapacityErrorRepr);
28
29/// Private error and free to change across minor version of the crate.
30#[derive(Error, Debug)]
31enum ChunkCapacityErrorRepr {
32    #[error("Max chunk size must be greater than or equal to the desired chunk size")]
33    MaxLessThanDesired,
34}
35
/// Describes the valid chunk size(s) that can be generated.
///
/// The `desired` size is the target size for the chunk. In most cases, this
/// will also serve as the maximum size of the chunk. It is always possible
/// that a chunk may be returned that is less than the `desired` value, as
/// adding the next piece of text may have made it larger than the `desired`
/// capacity.
///
/// The `max` size is the maximum possible chunk size that can be generated.
/// By setting this to a larger value than `desired`, it means that the chunk
/// should be as close to `desired` as possible, but can be larger if it means
/// staying at a larger semantic level.
///
/// The splitter will consume text until the chunk reaches a size somewhere
/// between `desired` and `max`, if they differ, but never above `max`.
///
/// If you need to ensure a fixed size, set `desired` and `max` to the same
/// value. For example, if you are trying to maximize the context window for an
/// embedding.
///
/// If you are loosely targeting a size, but have some extra room, for example
/// in a RAG use case where you roughly want a certain part of a document, you
/// can set `max` to your absolute maximum, and the splitter can stay at a
/// higher semantic level when determining the chunk.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct ChunkCapacity {
    /// Target size for a chunk.
    pub(crate) desired: usize,
    /// Largest size a chunk is allowed to have.
    pub(crate) max: usize,
}
65
66impl ChunkCapacity {
67    /// Create a new `ChunkCapacity` with the same `desired` and `max` size.
68    #[must_use]
69    pub fn new(size: usize) -> Self {
70        Self {
71            desired: size,
72            max: size,
73        }
74    }
75
76    /// The `desired` size is the target size for the chunk. In most cases, this
77    /// will also serve as the maximum size of the chunk. It is always possible
78    /// that a chunk may be returned that is less than the `desired` value, as
79    /// adding the next piece of text may have made it larger than the `desired`
80    /// capacity.
81    #[must_use]
82    pub fn desired(&self) -> usize {
83        self.desired
84    }
85
86    /// The `max` size is the maximum possible chunk size that can be generated.
87    /// By setting this to a larger value than `desired`, it means that the chunk
88    /// should be as close to `desired` as possible, but can be larger if it means
89    /// staying at a larger semantic level.
90    #[must_use]
91    pub fn max(&self) -> usize {
92        self.max
93    }
94
95    /// If you need to ensure a fixed size, set `desired` and `max` to the same
96    /// value. For example, if you are trying to maximize the context window for an
97    /// embedding.
98    ///
99    /// If you are loosely targeting a size, but have some extra room, for example
100    /// in a RAG use case where you roughly want a certain part of a document, you
101    /// can set `max` to your absolute maxumum, and the splitter can stay at a
102    /// higher semantic level when determining the chunk.
103    ///
104    /// # Errors
105    ///
106    /// If the `max` size is less than the `desired` size, an error is returned.
107    pub fn with_max(mut self, max: usize) -> Result<Self, ChunkCapacityError> {
108        if max < self.desired {
109            Err(ChunkCapacityError(
110                ChunkCapacityErrorRepr::MaxLessThanDesired,
111            ))
112        } else {
113            self.max = max;
114            Ok(self)
115        }
116    }
117
118    /// Validate if a given chunk fits within the capacity
119    ///
120    /// - `Ordering::Less` indicates more could be added
121    /// - `Ordering::Equal` indicates the chunk is within the capacity range
122    /// - `Ordering::Greater` indicates the chunk is larger than the capacity
123    #[must_use]
124    pub fn fits(&self, chunk_size: usize) -> Ordering {
125        if chunk_size < self.desired {
126            Ordering::Less
127        } else if chunk_size > self.max {
128            Ordering::Greater
129        } else {
130            Ordering::Equal
131        }
132    }
133}
134
135impl From<usize> for ChunkCapacity {
136    fn from(size: usize) -> Self {
137        ChunkCapacity::new(size)
138    }
139}
140
141impl From<Range<usize>> for ChunkCapacity {
142    fn from(range: Range<usize>) -> Self {
143        ChunkCapacity::new(range.start)
144            .with_max(range.end.saturating_sub(1).max(range.start))
145            .expect("invalid range")
146    }
147}
148
149impl From<RangeFrom<usize>> for ChunkCapacity {
150    fn from(range: RangeFrom<usize>) -> Self {
151        ChunkCapacity::new(range.start)
152            .with_max(usize::MAX)
153            .expect("invalid range")
154    }
155}
156
157impl From<RangeFull> for ChunkCapacity {
158    fn from(_: RangeFull) -> Self {
159        ChunkCapacity::new(usize::MIN)
160            .with_max(usize::MAX)
161            .expect("invalid range")
162    }
163}
164
165impl From<RangeInclusive<usize>> for ChunkCapacity {
166    fn from(range: RangeInclusive<usize>) -> Self {
167        ChunkCapacity::new(*range.start())
168            .with_max(*range.end())
169            .expect("invalid range")
170    }
171}
172
173impl From<RangeTo<usize>> for ChunkCapacity {
174    fn from(range: RangeTo<usize>) -> Self {
175        ChunkCapacity::new(usize::MIN)
176            .with_max(range.end.saturating_sub(1))
177            .expect("invalid range")
178    }
179}
180
181impl From<RangeToInclusive<usize>> for ChunkCapacity {
182    fn from(range: RangeToInclusive<usize>) -> Self {
183        ChunkCapacity::new(usize::MIN)
184            .with_max(range.end)
185            .expect("invalid range")
186    }
187}
188
/// Strategy for measuring how large a chunk of text is.
pub trait ChunkSizer {
    /// Return the size of `chunk`; this is the value used when validating a
    /// chunk against a capacity.
    fn size(&self, chunk: &str) -> usize;
}
194
195/// Indicates there was an error with the chunk configuration.
196/// The `Display` implementation will provide a human-readable error message to
197/// help debug the issue that caused the error.
198#[derive(Error, Debug)]
199#[error(transparent)]
200pub struct ChunkConfigError(#[from] ChunkConfigErrorRepr);
201
202/// Private error and free to change across minor version of the crate.
203#[derive(Error, Debug)]
204enum ChunkConfigErrorRepr {
205    #[error("The overlap is larger than or equal to the desired chunk capacity")]
206    OverlapLargerThanCapacity,
207}
208
209/// Configuration for how chunks should be created
210#[derive(Debug)]
211pub struct ChunkConfig<Sizer>
212where
213    Sizer: ChunkSizer,
214{
215    /// The chunk capacity to use for filling chunks
216    pub(crate) capacity: ChunkCapacity,
217    /// The amount of overlap between chunks. Defaults to 0.
218    pub(crate) overlap: usize,
219    /// The chunk sizer to use for determining the size of each chunk
220    pub(crate) sizer: Sizer,
221    /// Whether whitespace will be trimmed from the beginning and end of each chunk
222    pub(crate) trim: bool,
223}
224
225impl ChunkConfig<Characters> {
226    /// Create a basic configuration for chunking with only the required value a chunk capacity.
227    ///
228    /// By default, chunk sizes will be calculated based on the number of characters in each chunk.
229    /// You can set a custom chunk sizer by calling [`Self::with_sizer`].
230    ///
231    /// By default, chunks will be trimmed. If you want to preserve whitespace,
232    /// call [`Self::with_trim`] and set it to `false`.
233    #[must_use]
234    pub fn new(capacity: impl Into<ChunkCapacity>) -> Self {
235        Self {
236            capacity: capacity.into(),
237            overlap: 0,
238            sizer: Characters,
239            trim: true,
240        }
241    }
242}
243
244impl<Sizer> ChunkConfig<Sizer>
245where
246    Sizer: ChunkSizer,
247{
248    /// Retrieve a reference to the chunk capacity for this configuration.
249    pub fn capacity(&self) -> &ChunkCapacity {
250        &self.capacity
251    }
252
253    /// Retrieve the amount of overlap between chunks.
254    pub fn overlap(&self) -> usize {
255        self.overlap
256    }
257
258    /// Set the amount of overlap between chunks.
259    ///
260    /// # Errors
261    ///
262    /// Will return an error if the overlap is larger than or equal to the chunk capacity.
263    pub fn with_overlap(mut self, overlap: usize) -> Result<Self, ChunkConfigError> {
264        if overlap >= self.capacity.desired {
265            Err(ChunkConfigError(
266                ChunkConfigErrorRepr::OverlapLargerThanCapacity,
267            ))
268        } else {
269            self.overlap = overlap;
270            Ok(self)
271        }
272    }
273
274    /// Retrieve a reference to the chunk sizer for this configuration.
275    pub fn sizer(&self) -> &Sizer {
276        &self.sizer
277    }
278
279    /// Set a custom chunk sizer to use for determining the size of each chunk
280    ///
281    /// ```
282    /// use text_splitter::{Characters, ChunkConfig};
283    ///
284    /// let config = ChunkConfig::new(512).with_sizer(Characters);
285    /// ```
286    #[must_use]
287    pub fn with_sizer<S: ChunkSizer>(self, sizer: S) -> ChunkConfig<S> {
288        ChunkConfig {
289            capacity: self.capacity,
290            overlap: self.overlap,
291            sizer,
292            trim: self.trim,
293        }
294    }
295
296    /// Whether chunkd should have whitespace trimmed from the beginning and end or not.
297    pub fn trim(&self) -> bool {
298        self.trim
299    }
300
301    /// Specify whether chunks should have whitespace trimmed from the
302    /// beginning and end or not.
303    ///
304    /// If `false` (default), joining all chunks should return the original
305    /// string.
306    /// If `true`, all chunks will have whitespace removed from beginning and end.
307    ///
308    /// ```
309    /// use text_splitter::ChunkConfig;
310    ///
311    /// let config = ChunkConfig::new(512).with_trim(false);
312    /// ```
313    #[must_use]
314    pub fn with_trim(mut self, trim: bool) -> Self {
315        self.trim = trim;
316        self
317    }
318}
319
320impl<T> From<T> for ChunkConfig<Characters>
321where
322    T: Into<ChunkCapacity>,
323{
324    fn from(capacity: T) -> Self {
325        Self::new(capacity)
326    }
327}
328
329/// A memoized chunk sizer that caches the size of chunks.
330/// Very helpful when the same chunk is being validated multiple times, which
331/// happens often, and can be expensive to compute, such as with tokenizers.
332#[derive(Debug)]
333pub struct MemoizedChunkSizer<'sizer, Sizer>
334where
335    Sizer: ChunkSizer,
336{
337    /// Cache of chunk sizes per byte offset range for base capacity
338    size_cache: AHashMap<Range<usize>, usize>,
339    /// The sizer used for caluclating chunk sizes
340    sizer: &'sizer Sizer,
341}
342
343impl<'sizer, Sizer> MemoizedChunkSizer<'sizer, Sizer>
344where
345    Sizer: ChunkSizer,
346{
347    /// Wrap any chunk sizer for memoization
348    pub fn new(sizer: &'sizer Sizer) -> Self {
349        Self {
350            size_cache: AHashMap::new(),
351            sizer,
352        }
353    }
354
355    /// Determine the size of a given chunk to use for validation,
356    /// returning a cached value if it exists, and storing the result if not.
357    pub fn chunk_size(&mut self, offset: usize, chunk: &str, trim: Trim) -> usize {
358        let (offset, chunk) = trim.trim(offset, chunk);
359        *self
360            .size_cache
361            .entry(offset..(offset + chunk.len()))
362            .or_insert_with(|| self.sizer.size(chunk))
363    }
364
365    /// Find the best level to start splitting the text
366    pub fn find_correct_level<'text, L: fmt::Debug>(
367        &mut self,
368        offset: usize,
369        capacity: &ChunkCapacity,
370        levels_with_first_chunk: impl Iterator<Item = (L, &'text str)>,
371        trim: Trim,
372    ) -> (Option<L>, Option<usize>) {
373        let mut semantic_level = None;
374        let mut max_offset = None;
375
376        // We assume that larger levels are also longer. We can skip lower levels if going to a higher level would result in a shorter text
377        let levels_with_first_chunk =
378            levels_with_first_chunk.coalesce(|(a_level, a_str), (b_level, b_str)| {
379                if a_str.len() >= b_str.len() {
380                    Ok((b_level, b_str))
381                } else {
382                    Err(((a_level, a_str), (b_level, b_str)))
383                }
384            });
385
386        for (level, str) in levels_with_first_chunk {
387            // Skip tokenizing levels that we know are too small anyway.
388            let len = str.len();
389            if len > capacity.max {
390                let chunk_size = self.chunk_size(offset, str, trim);
391                let fits = capacity.fits(chunk_size);
392                // If this no longer fits, we use the level we are at.
393                if fits.is_gt() {
394                    max_offset = Some(offset + len);
395                    break;
396                }
397            }
398            // Otherwise break up the text with the next level
399            semantic_level = Some(level);
400        }
401
402        (semantic_level, max_offset)
403    }
404
405    /// Clear the cached values. Once we've moved the cursor,
406    /// we don't need to keep the old values around.
407    pub fn clear_cache(&mut self) {
408        self.size_cache.clear();
409    }
410}
411
#[cfg(test)]
mod tests {
    use std::sync::atomic::{self, AtomicUsize};

    use crate::trim::Trim;

    use super::*;

    /// Sample chunk shared by the capacity tests; it is five characters long.
    const CHUNK: &str = "12345";

    /// Text used by the memoization tests.
    const TEXT: &str = "1234567890";

    /// Message expected whenever an overlap is rejected.
    const OVERLAP_ERROR: &str =
        "The overlap is larger than or equal to the desired chunk capacity";

    /// How [`CHUNK`], measured with [`Characters`], fits the given capacity.
    fn fit_of(capacity: impl Into<ChunkCapacity>) -> Ordering {
        capacity.into().fits(Characters.size(CHUNK))
    }

    #[test]
    fn check_chunk_capacity() {
        assert_eq!(fit_of(4), Ordering::Greater);
        assert_eq!(fit_of(5), Ordering::Equal);
        assert_eq!(fit_of(6), Ordering::Less);
    }

    #[test]
    fn check_chunk_capacity_for_range() {
        assert_eq!(fit_of(0..0), Ordering::Greater);
        assert_eq!(fit_of(0..5), Ordering::Greater);
        assert_eq!(fit_of(5..6), Ordering::Equal);
        assert_eq!(fit_of(6..100), Ordering::Less);
    }

    #[test]
    fn check_chunk_capacity_for_range_from() {
        assert_eq!(fit_of(0..), Ordering::Equal);
        assert_eq!(fit_of(5..), Ordering::Equal);
        assert_eq!(fit_of(6..), Ordering::Less);
    }

    #[test]
    fn check_chunk_capacity_for_range_full() {
        assert_eq!(fit_of(..), Ordering::Equal);
    }

    #[test]
    fn check_chunk_capacity_for_range_inclusive() {
        assert_eq!(fit_of(0..=4), Ordering::Greater);
        assert_eq!(fit_of(5..=6), Ordering::Equal);
        assert_eq!(fit_of(4..=5), Ordering::Equal);
        assert_eq!(fit_of(6..=100), Ordering::Less);
    }

    #[test]
    fn check_chunk_capacity_for_range_to() {
        assert_eq!(fit_of(..0), Ordering::Greater);
        assert_eq!(fit_of(..5), Ordering::Greater);
        assert_eq!(fit_of(..6), Ordering::Equal);
    }

    #[test]
    fn check_chunk_capacity_for_range_to_inclusive() {
        assert_eq!(fit_of(..=4), Ordering::Greater);
        assert_eq!(fit_of(..=5), Ordering::Equal);
        assert_eq!(fit_of(..=6), Ordering::Equal);
    }

    /// Sizer that reports character counts while recording how often it ran.
    #[derive(Default)]
    struct CountingSizer {
        calls: AtomicUsize,
    }

    impl CountingSizer {
        /// Number of times `size` has been invoked so far.
        fn call_count(&self) -> usize {
            self.calls.load(atomic::Ordering::SeqCst)
        }
    }

    impl ChunkSizer for CountingSizer {
        // Count the call, then defer to the character-based size.
        fn size(&self, chunk: &str) -> usize {
            self.calls.fetch_add(1, atomic::Ordering::SeqCst);
            Characters.size(chunk)
        }
    }

    #[test]
    fn memoized_sizer_only_calculates_once_per_text() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        (0..10).for_each(|_| {
            memoized_sizer.chunk_size(0, TEXT, Trim::All);
        });

        assert_eq!(sizer.call_count(), 1);
    }

    #[test]
    fn memoized_sizer_calculates_once_per_different_text() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        // Ten prefixes of different lengths give ten distinct cache keys.
        for end in 0..10 {
            let prefix = TEXT.get(0..end).unwrap();
            memoized_sizer.chunk_size(0, prefix, Trim::All);
        }

        assert_eq!(sizer.call_count(), 10);
    }

    #[test]
    fn can_clear_cache_on_memoized_sizer() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        // Clearing after every lookup forces a fresh calculation each time.
        for _ in 0..10 {
            memoized_sizer.chunk_size(0, TEXT, Trim::All);
            memoized_sizer.clear_cache();
        }

        assert_eq!(sizer.call_count(), 10);
    }

    #[test]
    fn basic_chunk_config() {
        let config = ChunkConfig::new(10);
        assert_eq!(config.capacity, ChunkCapacity::from(10));
        assert_eq!(config.sizer, Characters);
        assert!(config.trim());
    }

    #[test]
    fn disable_trimming() {
        let config = ChunkConfig::new(10).with_trim(false);
        assert!(!config.trim());
    }

    #[test]
    fn new_sizer() {
        /// Sizer that is only compared, never asked for a size.
        #[derive(Debug, PartialEq)]
        struct BasicSizer;

        impl ChunkSizer for BasicSizer {
            fn size(&self, _chunk: &str) -> usize {
                unimplemented!()
            }
        }

        let config = ChunkConfig::new(10).with_sizer(BasicSizer);
        assert_eq!(config.capacity, ChunkCapacity::from(10));
        assert_eq!(config.sizer, BasicSizer);
        assert!(config.trim());
    }

    #[test]
    fn chunk_capacity_max_and_desired_equal() {
        let capacity = ChunkCapacity::new(10);
        assert_eq!((capacity.desired(), capacity.max()), (10, 10));
    }

    #[test]
    fn chunk_capacity_can_adjust_max() {
        let capacity = ChunkCapacity::new(10).with_max(20).unwrap();
        assert_eq!((capacity.desired(), capacity.max()), (10, 20));
    }

    #[test]
    fn chunk_capacity_max_cant_be_less_than_desired() {
        let capacity = ChunkCapacity::new(10);
        let message = capacity.with_max(5).unwrap_err().to_string();
        assert_eq!(
            message,
            "Max chunk size must be greater than or equal to the desired chunk size"
        );
        // `ChunkCapacity` is `Copy`, so the original value is still usable
        // after the failed call and keeps its sizes.
        assert_eq!((capacity.desired(), capacity.max()), (10, 10));
    }

    #[test]
    fn set_chunk_overlap() {
        let config = ChunkConfig::new(10).with_overlap(5).unwrap();
        assert_eq!(config.overlap(), 5);
    }

    #[test]
    fn cant_set_overlap_larger_than_capacity() {
        let err = ChunkConfig::new(5).with_overlap(10).unwrap_err();
        assert_eq!(err.to_string(), OVERLAP_ERROR);
    }

    #[test]
    fn cant_set_overlap_larger_than_desired() {
        // The overlap (10) is below `max` (14) but above `desired` (5).
        let err = ChunkConfig::new(5..15).with_overlap(10).unwrap_err();
        assert_eq!(err.to_string(), OVERLAP_ERROR);
    }
}