swiftide_indexing/transformers/
chunk_text.rs

1//! Chunk text content into smaller pieces
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use derive_builder::Builder;
6use swiftide_core::{indexing::IndexingStream, indexing::Node, ChunkerTransformer};
7use text_splitter::{Characters, ChunkConfig, TextSplitter};
8
9const DEFAULT_MAX_CHAR_SIZE: usize = 2056;
10
11#[derive(Debug, Clone, Builder)]
12#[builder(setter(strip_option))]
13/// A transformer that chunks text content into smaller pieces.
14///
15/// The transformer will split the text content into smaller pieces based on the specified
16/// `max_characters` or `range` of characters.
17///
18/// For further customization, you can use the builder to create a custom splitter. Uses
19/// `text_splitter` under the hood.
20///
21/// Technically that might work with every splitter `text_splitter` provides.
22pub struct ChunkText {
23    /// The max number of concurrent chunks to process.
24    ///
25    /// Defaults to `None`. If you use a splitter that is resource heavy, this parameter can be
26    /// tuned.
27    #[builder(default)]
28    concurrency: Option<usize>,
29
30    /// Optional maximum number of characters per chunk.
31    ///
32    /// Defaults to [`DEFAULT_MAX_CHAR_SIZE`].
33    #[builder(default = "DEFAULT_MAX_CHAR_SIZE")]
34    #[allow(dead_code)]
35    max_characters: usize,
36
37    /// A range of minimum and maximum characters per chunk.
38    ///
39    /// Chunks smaller than the range min will be ignored. `max_characters` will be ignored if this
40    /// is set.
41    ///
42    /// If you provide a custom chunker with a range, you might want to set the range as well.
43    ///
44    /// Defaults to 0..[`max_characters`]
45    #[builder(default = "0..DEFAULT_MAX_CHAR_SIZE")]
46    range: std::ops::Range<usize>,
47
48    /// The text splitter from [`text_splitter`]
49    ///
50    /// Defaults to a new [`TextSplitter`] with the specified `max_characters`.
51    #[builder(setter(into), default = "self.default_client()")]
52    chunker: Arc<TextSplitter<Characters>>,
53}
54
55impl Default for ChunkText {
56    fn default() -> Self {
57        Self::from_max_characters(DEFAULT_MAX_CHAR_SIZE)
58    }
59}
60
61impl ChunkText {
62    pub fn builder() -> ChunkTextBuilder {
63        ChunkTextBuilder::default()
64    }
65
66    /// Create a new transformer with a maximum number of characters per chunk.
67    #[allow(clippy::missing_panics_doc)]
68    pub fn from_max_characters(max_characters: usize) -> Self {
69        Self::builder()
70            .max_characters(max_characters)
71            .build()
72            .expect("Cannot fail")
73    }
74
75    /// Create a new transformer with a range of characters per chunk.
76    ///
77    /// Chunks smaller than the range will be ignored.
78    #[allow(clippy::missing_panics_doc)]
79    pub fn from_chunk_range(range: std::ops::Range<usize>) -> Self {
80        Self::builder().range(range).build().expect("Cannot fail")
81    }
82
83    /// Set the number of concurrent chunks to process.
84    #[must_use]
85    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
86        self.concurrency = Some(concurrency);
87        self
88    }
89
90    fn min_size(&self) -> usize {
91        self.range.start
92    }
93}
94
95impl ChunkTextBuilder {
96    fn default_client(&self) -> Arc<TextSplitter<Characters>> {
97        let chunk_config: ChunkConfig<Characters> = self
98            .range
99            .clone()
100            .map(ChunkConfig::<Characters>::from)
101            .or_else(|| self.max_characters.map(Into::into))
102            .unwrap_or(DEFAULT_MAX_CHAR_SIZE.into());
103
104        Arc::new(TextSplitter::new(chunk_config))
105    }
106}
107#[async_trait]
108impl ChunkerTransformer for ChunkText {
109    #[tracing::instrument(skip_all, name = "transformers.chunk_text")]
110    async fn transform_node(&self, node: Node) -> IndexingStream {
111        let chunks = self
112            .chunker
113            .chunks(&node.chunk)
114            .filter_map(|chunk| {
115                let trim = chunk.trim();
116                if trim.is_empty() || trim.len() < self.min_size() {
117                    None
118                } else {
119                    Some(chunk.to_string())
120                }
121            })
122            .collect::<Vec<String>>();
123
124        IndexingStream::iter(
125            chunks
126                .into_iter()
127                .map(move |chunk| Node::build_from_other(&node).chunk(chunk).build()),
128        )
129    }
130
131    fn concurrency(&self) -> Option<usize> {
132        self.concurrency
133    }
134}
135
136#[cfg(test)]
137mod test {
138    use super::*;
139    use futures_util::stream::TryStreamExt;
140
141    const TEXT: &str = r"
142        This is a text.
143
144        This is a paragraph.
145
146        This is another paragraph.
147        ";
148
149    #[tokio::test]
150    async fn test_transforming_with_max_characters_and_trimming() {
151        let chunker = ChunkText::from_max_characters(40);
152
153        let node = Node::new(TEXT.to_string());
154
155        let nodes: Vec<Node> = chunker
156            .transform_node(node)
157            .await
158            .try_collect()
159            .await
160            .unwrap();
161
162        for line in TEXT.lines().filter(|line| !line.trim().is_empty()) {
163            assert!(nodes.iter().any(|node| node.chunk == line.trim()));
164        }
165
166        assert_eq!(nodes.len(), 3);
167    }
168
169    #[tokio::test]
170    async fn test_always_within_range() {
171        let ranges = vec![(10..15), (20..25), (30..35), (40..45), (50..55)];
172        for range in ranges {
173            let chunker = ChunkText::from_chunk_range(range.clone());
174            let node = Node::new(TEXT.to_string());
175            let nodes: Vec<Node> = chunker
176                .transform_node(node)
177                .await
178                .try_collect()
179                .await
180                .unwrap();
181            // Assert all nodes chunk length within the range
182            assert!(
183                nodes.iter().all(|node| {
184                    let len = node.chunk.len();
185                    range.contains(&len)
186                }),
187                "{:?}, {:?}",
188                range,
189                nodes.iter().filter(|node| {
190                    let len = node.chunk.len();
191                    !range.contains(&len)
192                })
193            );
194        }
195    }
196
197    #[test]
198    fn test_builder() {
199        ChunkText::builder()
200            .chunker(text_splitter::TextSplitter::new(40))
201            .concurrency(10)
202            .range(10..20)
203            .build()
204            .unwrap();
205    }
206}