swiftide_indexing/transformers/
chunk_text.rs1use std::sync::Arc;
3
4use async_trait::async_trait;
5use derive_builder::Builder;
6use swiftide_core::{ChunkerTransformer, indexing::IndexingStream, indexing::TextNode};
7use text_splitter::{Characters, ChunkConfig, TextSplitter};
8
/// Character budget applied when no explicit chunk size is configured: it is the
/// builder default for `max_characters`, the upper bound of the default `range`,
/// and the last-resort size of the default splitter.
const DEFAULT_MAX_CHAR_SIZE: usize = 2_056;

11#[derive(Debug, Clone, Builder)]
12#[builder(setter(strip_option))]
13pub struct ChunkText {
23 #[builder(default)]
28 concurrency: Option<usize>,
29
30 #[builder(default = "DEFAULT_MAX_CHAR_SIZE")]
34 #[allow(dead_code)]
35 max_characters: usize,
36
37 #[builder(default = "0..DEFAULT_MAX_CHAR_SIZE")]
46 range: std::ops::Range<usize>,
47
48 #[builder(setter(into), default = "self.default_client()")]
52 chunker: Arc<TextSplitter<Characters>>,
53}
54
55impl Default for ChunkText {
56 fn default() -> Self {
57 Self::from_max_characters(DEFAULT_MAX_CHAR_SIZE)
58 }
59}
60
61impl ChunkText {
62 pub fn builder() -> ChunkTextBuilder {
63 ChunkTextBuilder::default()
64 }
65
66 #[allow(clippy::missing_panics_doc)]
68 pub fn from_max_characters(max_characters: usize) -> Self {
69 Self::builder()
70 .max_characters(max_characters)
71 .build()
72 .expect("Cannot fail")
73 }
74
75 #[allow(clippy::missing_panics_doc)]
79 pub fn from_chunk_range(range: std::ops::Range<usize>) -> Self {
80 Self::builder().range(range).build().expect("Cannot fail")
81 }
82
83 #[must_use]
85 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
86 self.concurrency = Some(concurrency);
87 self
88 }
89
90 fn min_size(&self) -> usize {
91 self.range.start
92 }
93}
94
95impl ChunkTextBuilder {
96 fn default_client(&self) -> Arc<TextSplitter<Characters>> {
97 let chunk_config: ChunkConfig<Characters> = self
98 .range
99 .clone()
100 .map(ChunkConfig::<Characters>::from)
101 .or_else(|| self.max_characters.map(Into::into))
102 .unwrap_or(DEFAULT_MAX_CHAR_SIZE.into());
103
104 Arc::new(TextSplitter::new(chunk_config))
105 }
106}
107#[async_trait]
108impl ChunkerTransformer for ChunkText {
109 type Input = String;
110 type Output = String;
111
112 #[tracing::instrument(skip_all, name = "transformers.chunk_text")]
113 async fn transform_node(&self, node: TextNode) -> IndexingStream<String> {
114 let chunks = self
115 .chunker
116 .chunks(&node.chunk)
117 .filter_map(|chunk| {
118 let trim = chunk.trim();
119 if trim.is_empty() || trim.len() < self.min_size() {
120 None
121 } else {
122 Some(chunk.to_string())
123 }
124 })
125 .collect::<Vec<String>>();
126
127 IndexingStream::iter(
128 chunks
129 .into_iter()
130 .map(move |chunk| TextNode::build_from_other(&node).chunk(chunk).build()),
131 )
132 }
133
134 fn concurrency(&self) -> Option<usize> {
135 self.concurrency
136 }
137}
138
#[cfg(test)]
mod test {
    use super::*;
    use futures_util::stream::TryStreamExt;

    const TEXT: &str = r"
        This is a text.

        This is a paragraph.

        This is another paragraph.
        ";

    #[tokio::test]
    async fn test_transforming_with_max_characters_and_trimming() {
        let chunker = ChunkText::from_max_characters(40);

        let node = TextNode::new(TEXT.to_string());

        let nodes: Vec<TextNode> = chunker
            .transform_node(node)
            .await
            .try_collect()
            .await
            .unwrap();

        for line in TEXT.lines().filter(|line| !line.trim().is_empty()) {
            assert!(nodes.iter().any(|node| node.chunk == line.trim()));
        }

        assert_eq!(nodes.len(), 3);
    }

    #[tokio::test]
    async fn test_always_within_range() {
        let ranges = [10..15, 20..25, 30..35, 40..45, 50..55];
        for range in ranges {
            let chunker = ChunkText::from_chunk_range(range.clone());
            let node = TextNode::new(TEXT.to_string());
            let nodes: Vec<TextNode> = chunker
                .transform_node(node)
                .await
                .try_collect()
                .await
                .unwrap();

            // Collect the offenders so the failure message lists only them. A
            // `Filter` adapter's Debug output prints the entire underlying iterator.
            // Length is measured in characters, matching the splitter's sizer.
            let out_of_range: Vec<&TextNode> = nodes
                .iter()
                .filter(|node| !range.contains(&node.chunk.chars().count()))
                .collect();

            assert!(out_of_range.is_empty(), "{range:?}, {out_of_range:?}");
        }
    }

    #[test]
    fn test_builder() {
        ChunkText::builder()
            .chunker(text_splitter::TextSplitter::new(40))
            .concurrency(10)
            .range(10..20)
            .build()
            .unwrap();
    }
}