swiftide_indexing/transformers/
chunk_text.rs1use std::sync::Arc;
3
4use async_trait::async_trait;
5use derive_builder::Builder;
6use swiftide_core::{indexing::IndexingStream, indexing::Node, ChunkerTransformer};
7use text_splitter::{Characters, ChunkConfig, TextSplitter};
8
/// Fallback chunk size, in characters.
///
/// It is used in three places:
/// - the default for `max_characters`;
/// - the upper bound of the default `range`;
/// - the splitter configuration when the builder was given neither a `range`
///   nor `max_characters`.
const DEFAULT_MAX_CHAR_SIZE: usize = 2_056;
10
11#[derive(Debug, Clone, Builder)]
12#[builder(setter(strip_option))]
13pub struct ChunkText {
23 #[builder(default)]
28 concurrency: Option<usize>,
29
30 #[builder(default = "DEFAULT_MAX_CHAR_SIZE")]
34 #[allow(dead_code)]
35 max_characters: usize,
36
37 #[builder(default = "0..DEFAULT_MAX_CHAR_SIZE")]
46 range: std::ops::Range<usize>,
47
48 #[builder(setter(into), default = "self.default_client()")]
52 chunker: Arc<TextSplitter<Characters>>,
53}
54
55impl Default for ChunkText {
56 fn default() -> Self {
57 Self::from_max_characters(DEFAULT_MAX_CHAR_SIZE)
58 }
59}
60
61impl ChunkText {
62 pub fn builder() -> ChunkTextBuilder {
63 ChunkTextBuilder::default()
64 }
65
66 #[allow(clippy::missing_panics_doc)]
68 pub fn from_max_characters(max_characters: usize) -> Self {
69 Self::builder()
70 .max_characters(max_characters)
71 .build()
72 .expect("Cannot fail")
73 }
74
75 #[allow(clippy::missing_panics_doc)]
79 pub fn from_chunk_range(range: std::ops::Range<usize>) -> Self {
80 Self::builder().range(range).build().expect("Cannot fail")
81 }
82
83 #[must_use]
85 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
86 self.concurrency = Some(concurrency);
87 self
88 }
89
90 fn min_size(&self) -> usize {
91 self.range.start
92 }
93}
94
95impl ChunkTextBuilder {
96 fn default_client(&self) -> Arc<TextSplitter<Characters>> {
97 let chunk_config: ChunkConfig<Characters> = self
98 .range
99 .clone()
100 .map(ChunkConfig::<Characters>::from)
101 .or_else(|| self.max_characters.map(Into::into))
102 .unwrap_or(DEFAULT_MAX_CHAR_SIZE.into());
103
104 Arc::new(TextSplitter::new(chunk_config))
105 }
106}
107#[async_trait]
108impl ChunkerTransformer for ChunkText {
109 #[tracing::instrument(skip_all, name = "transformers.chunk_text")]
110 async fn transform_node(&self, node: Node) -> IndexingStream {
111 let chunks = self
112 .chunker
113 .chunks(&node.chunk)
114 .filter_map(|chunk| {
115 let trim = chunk.trim();
116 if trim.is_empty() || trim.len() < self.min_size() {
117 None
118 } else {
119 Some(chunk.to_string())
120 }
121 })
122 .collect::<Vec<String>>();
123
124 IndexingStream::iter(
125 chunks
126 .into_iter()
127 .map(move |chunk| Node::build_from_other(&node).chunk(chunk).build()),
128 )
129 }
130
131 fn concurrency(&self) -> Option<usize> {
132 self.concurrency
133 }
134}
135
#[cfg(test)]
mod test {
    use super::*;
    use futures_util::stream::TryStreamExt;

    const TEXT: &str = r"
        This is a text.

        This is a paragraph.

        This is another paragraph.
    ";

    #[tokio::test]
    async fn test_transforming_with_max_characters_and_trimming() {
        let chunker = ChunkText::from_max_characters(40);

        let node = Node::new(TEXT.to_string());

        let nodes: Vec<Node> = chunker
            .transform_node(node)
            .await
            .try_collect()
            .await
            .unwrap();

        for line in TEXT.lines().filter(|line| !line.trim().is_empty()) {
            assert!(nodes.iter().any(|node| node.chunk == line.trim()));
        }

        assert_eq!(nodes.len(), 3);
    }

    #[tokio::test]
    async fn test_always_within_range() {
        let ranges = [10..15, 20..25, 30..35, 40..45, 50..55];
        for range in ranges {
            let chunker = ChunkText::from_chunk_range(range.clone());
            let node = Node::new(TEXT.to_string());
            let nodes: Vec<Node> = chunker
                .transform_node(node)
                .await
                .try_collect()
                .await
                .unwrap();

            // Collect only the offending nodes. Formatting a lazy `Filter`
            // with `{:?}` prints its inner iterator, i.e. every node.
            let out_of_range: Vec<&Node> = nodes
                .iter()
                .filter(|node| !range.contains(&node.chunk.chars().count()))
                .collect();

            assert!(
                out_of_range.is_empty(),
                "chunks outside {range:?}: {out_of_range:?}"
            );
        }
    }

    #[test]
    fn test_builder() {
        ChunkText::builder()
            .chunker(text_splitter::TextSplitter::new(40))
            .concurrency(10)
            .range(10..20)
            .build()
            .unwrap();
    }
}