// swiftide_indexing/transformers/chunk_text.rs

1//! Chunk text content into smaller pieces
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use derive_builder::Builder;
6use swiftide_core::{ChunkerTransformer, indexing::IndexingStream, indexing::TextNode};
7use text_splitter::{Characters, ChunkConfig, TextSplitter};
8
9const DEFAULT_MAX_CHAR_SIZE: usize = 2056;
10
#[derive(Debug, Clone, Builder)]
#[builder(setter(strip_option))]
/// A transformer that chunks text content into smaller pieces.
///
/// The text is split into pieces based on either `max_characters` or a
/// `range` of characters; `range` takes precedence when both are set.
///
/// For further customization, you can use the builder to create a custom splitter. Uses
/// `text_splitter` under the hood.
///
/// Technically that might work with every splitter `text_splitter` provides.
pub struct ChunkText {
    /// The max number of concurrent chunks to process.
    ///
    /// Defaults to `None`. If you use a splitter that is resource heavy, this parameter can be
    /// tuned.
    #[builder(default)]
    concurrency: Option<usize>,

    /// Optional maximum number of characters per chunk.
    ///
    /// Defaults to [`DEFAULT_MAX_CHAR_SIZE`].
    // Only read by the builder's `default_client()`; never read after
    // construction, hence the `dead_code` allowance.
    #[builder(default = "DEFAULT_MAX_CHAR_SIZE")]
    #[allow(dead_code)]
    max_characters: usize,

    /// A range of minimum and maximum characters per chunk.
    ///
    /// Chunks whose trimmed length is below the range start are dropped by
    /// `transform_node`. `max_characters` is ignored when this is set.
    ///
    /// If you provide a custom chunker with a range, you might want to set the range as well.
    ///
    /// Defaults to `0..DEFAULT_MAX_CHAR_SIZE`.
    #[builder(default = "0..DEFAULT_MAX_CHAR_SIZE")]
    range: std::ops::Range<usize>,

    /// The text splitter from [`text_splitter`]
    ///
    /// Defaults to a new [`TextSplitter`] with the specified `max_characters`
    /// (or `range`, when provided) — see `ChunkTextBuilder::default_client`.
    #[builder(setter(into), default = "self.default_client()")]
    chunker: Arc<TextSplitter<Characters>>,
}
54
55impl Default for ChunkText {
56    fn default() -> Self {
57        Self::from_max_characters(DEFAULT_MAX_CHAR_SIZE)
58    }
59}
60
61impl ChunkText {
62    pub fn builder() -> ChunkTextBuilder {
63        ChunkTextBuilder::default()
64    }
65
66    /// Create a new transformer with a maximum number of characters per chunk.
67    #[allow(clippy::missing_panics_doc)]
68    pub fn from_max_characters(max_characters: usize) -> Self {
69        Self::builder()
70            .max_characters(max_characters)
71            .build()
72            .expect("Cannot fail")
73    }
74
75    /// Create a new transformer with a range of characters per chunk.
76    ///
77    /// Chunks smaller than the range will be ignored.
78    #[allow(clippy::missing_panics_doc)]
79    pub fn from_chunk_range(range: std::ops::Range<usize>) -> Self {
80        Self::builder().range(range).build().expect("Cannot fail")
81    }
82
83    /// Set the number of concurrent chunks to process.
84    #[must_use]
85    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
86        self.concurrency = Some(concurrency);
87        self
88    }
89
90    fn min_size(&self) -> usize {
91        self.range.start
92    }
93}
94
95impl ChunkTextBuilder {
96    fn default_client(&self) -> Arc<TextSplitter<Characters>> {
97        let chunk_config: ChunkConfig<Characters> = self
98            .range
99            .clone()
100            .map(ChunkConfig::<Characters>::from)
101            .or_else(|| self.max_characters.map(Into::into))
102            .unwrap_or(DEFAULT_MAX_CHAR_SIZE.into());
103
104        Arc::new(TextSplitter::new(chunk_config))
105    }
106}
107#[async_trait]
108impl ChunkerTransformer for ChunkText {
109    type Input = String;
110    type Output = String;
111
112    #[tracing::instrument(skip_all, name = "transformers.chunk_text")]
113    async fn transform_node(&self, node: TextNode) -> IndexingStream<String> {
114        let chunks = self
115            .chunker
116            .chunks(&node.chunk)
117            .filter_map(|chunk| {
118                let trim = chunk.trim();
119                if trim.is_empty() || trim.len() < self.min_size() {
120                    None
121                } else {
122                    Some(chunk.to_string())
123                }
124            })
125            .collect::<Vec<String>>();
126
127        IndexingStream::iter(
128            chunks
129                .into_iter()
130                .map(move |chunk| TextNode::build_from_other(&node).chunk(chunk).build()),
131        )
132    }
133
134    fn concurrency(&self) -> Option<usize> {
135        self.concurrency
136    }
137}
138
#[cfg(test)]
mod test {
    use super::*;
    use futures_util::stream::TryStreamExt;

    const TEXT: &str = r"
        This is a text.

        This is a paragraph.

        This is another paragraph.
        ";

    #[tokio::test]
    async fn test_transforming_with_max_characters_and_trimming() {
        let chunker = ChunkText::from_max_characters(40);

        let node = TextNode::new(TEXT.to_string());

        let nodes: Vec<TextNode> = chunker
            .transform_node(node)
            .await
            .try_collect()
            .await
            .unwrap();

        // Every non-blank input line should survive as its own trimmed chunk.
        for line in TEXT.lines().filter(|line| !line.trim().is_empty()) {
            assert!(nodes.iter().any(|node| node.chunk == line.trim()));
        }

        assert_eq!(nodes.len(), 3);
    }

    #[tokio::test]
    async fn test_always_within_range() {
        let ranges = vec![(10..15), (20..25), (30..35), (40..45), (50..55)];
        for range in ranges {
            let chunker = ChunkText::from_chunk_range(range.clone());
            let node = TextNode::new(TEXT.to_string());
            let nodes: Vec<TextNode> = chunker
                .transform_node(node)
                .await
                .try_collect()
                .await
                .unwrap();
            // Collect the offenders eagerly: formatting a bare `Filter`
            // adaptor with `{:?}` prints the adaptor itself, not the failing
            // nodes, so the previous failure message carried no diagnostics.
            let out_of_range: Vec<_> = nodes
                .iter()
                .filter(|node| !range.contains(&node.chunk.len()))
                .collect();
            // Assert all node chunk lengths are within the range.
            assert!(
                out_of_range.is_empty(),
                "{:?}, {:?}",
                range,
                out_of_range
            );
        }
    }

    #[test]
    fn test_builder() {
        // Smoke test: all builder setters compose and `build()` succeeds.
        ChunkText::builder()
            .chunker(text_splitter::TextSplitter::new(40))
            .concurrency(10)
            .range(10..20)
            .build()
            .unwrap();
    }
}