//! tandem_memory/chunking.rs — token-based text chunking utilities.

use crate::types::{MemoryError, MemoryResult, MIN_CHUNK_LENGTH};
use tiktoken_rs::cl100k_base;

/// A contiguous piece of a larger text, produced by one of the chunkers.
#[derive(Debug, Clone)]
pub struct TextChunk {
    /// The chunk's text content.
    pub content: String,
    /// Number of BPE tokens in `content`.
    pub token_count: usize,
    /// Offset into the source text where this chunk starts.
    /// NOTE(review): offsets come from `str::len`, i.e. bytes not chars —
    /// confirm callers expect byte offsets.
    pub start_index: usize,
    /// Offset (exclusive) where this chunk ends.
    pub end_index: usize,
}

/// Tunable parameters for the chunking functions.
#[derive(Debug, Clone)]
pub struct ChunkingConfig {
    /// Maximum number of tokens per chunk.
    pub chunk_size: usize,
    /// Number of tokens shared between consecutive chunks.
    pub chunk_overlap: usize,
    /// Optional custom separator.
    /// NOTE(review): not consulted by the chunkers in this module — confirm
    /// whether it is read elsewhere.
    pub separator: Option<String>,
}

27impl Default for ChunkingConfig {
28 fn default() -> Self {
29 Self {
30 chunk_size: 512,
31 chunk_overlap: 64,
32 separator: None,
33 }
34 }
35}
36
37pub struct Tokenizer {
39 bpe: tiktoken_rs::CoreBPE,
40}
41
42impl Tokenizer {
43 pub fn new() -> MemoryResult<Self> {
44 let bpe = cl100k_base().map_err(|e| MemoryError::Tokenization(e.to_string()))?;
45 Ok(Self { bpe })
46 }
47
48 pub fn count_tokens(&self, text: &str) -> usize {
50 self.bpe.encode_with_special_tokens(text).len()
51 }
52
53 pub fn encode(&self, text: &str) -> Vec<u32> {
55 self.bpe.encode_with_special_tokens(text)
56 }
57
58 pub fn decode(&self, tokens: &[u32]) -> String {
60 self.bpe.decode(tokens.to_vec()).unwrap_or_default()
61 }
62}
63
64impl Default for Tokenizer {
65 fn default() -> Self {
66 Self::new().expect("Failed to initialize tokenizer")
67 }
68}
69
70pub fn chunk_text(text: &str, config: &ChunkingConfig) -> MemoryResult<Vec<TextChunk>> {
72 if text.is_empty() {
73 return Ok(Vec::new());
74 }
75
76 if text.len() < MIN_CHUNK_LENGTH {
78 let tokenizer = Tokenizer::new()?;
79 let token_count = tokenizer.count_tokens(text);
80 return Ok(vec![TextChunk {
81 content: text.to_string(),
82 token_count,
83 start_index: 0,
84 end_index: text.len(),
85 }]);
86 }
87
88 let tokenizer = Tokenizer::new()?;
89 let tokens = tokenizer.encode(text);
90
91 if tokens.len() <= config.chunk_size {
92 return Ok(vec![TextChunk {
94 content: text.to_string(),
95 token_count: tokens.len(),
96 start_index: 0,
97 end_index: text.len(),
98 }]);
99 }
100
101 let mut chunks = Vec::new();
102 let mut start = 0;
103
104 while start < tokens.len() {
105 let end = (start + config.chunk_size).min(tokens.len());
106 let chunk_tokens = &tokens[start..end];
107 let chunk_text = tokenizer.decode(chunk_tokens);
108
109 let start_char = if start == 0 {
111 0
112 } else {
113 let prev_tokens = &tokens[..start];
115 tokenizer.decode(prev_tokens).len()
116 };
117
118 let end_char = start_char + chunk_text.len();
119
120 chunks.push(TextChunk {
121 content: chunk_text,
122 token_count: chunk_tokens.len(),
123 start_index: start_char,
124 end_index: end_char.min(text.len()),
125 });
126
127 let step = config.chunk_size.saturating_sub(config.chunk_overlap);
129 if step == 0 {
130 start = end;
132 } else {
133 start += step;
134 }
135
136 if start >= tokens.len() {
138 break;
139 }
140 }
141
142 Ok(chunks)
143}
144
145pub fn chunk_text_semantic(text: &str, config: &ChunkingConfig) -> MemoryResult<Vec<TextChunk>> {
147 if text.is_empty() {
148 return Ok(Vec::new());
149 }
150
151 if text.len() < MIN_CHUNK_LENGTH {
153 let tokenizer = Tokenizer::new()?;
154 let token_count = tokenizer.count_tokens(text);
155 return Ok(vec![TextChunk {
156 content: text.to_string(),
157 token_count,
158 start_index: 0,
159 end_index: text.len(),
160 }]);
161 }
162
163 let tokenizer = Tokenizer::new()?;
164 let tokens = tokenizer.encode(text);
165
166 if tokens.len() <= config.chunk_size {
167 return Ok(vec![TextChunk {
168 content: text.to_string(),
169 token_count: tokens.len(),
170 start_index: 0,
171 end_index: text.len(),
172 }]);
173 }
174
175 let paragraphs: Vec<&str> = text
177 .split("\n\n")
178 .filter(|p| !p.trim().is_empty())
179 .collect();
180
181 let mut chunks = Vec::new();
182 let mut current_chunk = String::new();
183 let mut current_tokens = 0;
184 let mut chunk_start = 0;
185 let mut current_pos = 0;
186
187 for paragraph in paragraphs {
188 let para_tokens = tokenizer.count_tokens(paragraph);
189
190 if para_tokens > config.chunk_size {
191 if !current_chunk.is_empty() {
193 chunks.push(TextChunk {
194 content: current_chunk.clone(),
195 token_count: current_tokens,
196 start_index: chunk_start,
197 end_index: current_pos,
198 });
199 current_chunk.clear();
200 current_tokens = 0;
201 chunk_start = current_pos;
202 }
203
204 let sentences: Vec<&str> = paragraph
206 .split(['.', '!', '?'])
207 .filter(|s| !s.trim().is_empty())
208 .collect();
209
210 for sentence in sentences {
211 let sentence_with_punct = format!("{}.", sentence.trim());
212 let sent_tokens = tokenizer.count_tokens(&sentence_with_punct);
213
214 if current_tokens + sent_tokens > config.chunk_size && !current_chunk.is_empty() {
215 chunks.push(TextChunk {
216 content: current_chunk.clone(),
217 token_count: current_tokens,
218 start_index: chunk_start,
219 end_index: current_pos,
220 });
221 let overlap_tokens =
223 current_tokens.saturating_sub(config.chunk_size - config.chunk_overlap);
224 if overlap_tokens > 0 && overlap_tokens < current_tokens {
225 let overlap_text =
226 get_last_n_tokens(&tokenizer, ¤t_chunk, overlap_tokens);
227 current_chunk = overlap_text;
228 current_tokens = overlap_tokens;
229 chunk_start = current_pos - current_chunk.len();
230 } else {
231 current_chunk.clear();
232 current_tokens = 0;
233 chunk_start = current_pos;
234 }
235 }
236
237 current_chunk.push_str(&sentence_with_punct);
238 current_chunk.push(' ');
239 current_tokens += sent_tokens;
240 current_pos += sentence_with_punct.len() + 1;
241 }
242 } else if current_tokens + para_tokens > config.chunk_size {
243 if !current_chunk.is_empty() {
245 chunks.push(TextChunk {
246 content: current_chunk.clone(),
247 token_count: current_tokens,
248 start_index: chunk_start,
249 end_index: current_pos,
250 });
251 }
252 current_chunk = paragraph.to_string();
253 current_chunk.push('\n');
254 current_tokens = para_tokens;
255 chunk_start = current_pos;
256 current_pos += paragraph.len() + 1;
257 } else {
258 current_chunk.push_str(paragraph);
260 current_chunk.push('\n');
261 current_tokens += para_tokens;
262 current_pos += paragraph.len() + 1;
263 }
264 }
265
266 if !current_chunk.is_empty() {
268 chunks.push(TextChunk {
269 content: current_chunk.trim().to_string(),
270 token_count: current_tokens,
271 start_index: chunk_start,
272 end_index: text.len(),
273 });
274 }
275
276 Ok(chunks)
277}
278
279fn get_last_n_tokens(tokenizer: &Tokenizer, text: &str, n: usize) -> String {
281 let tokens = tokenizer.encode(text);
282 let start = tokens.len().saturating_sub(n);
283 let last_tokens = &tokens[start..];
284 tokenizer.decode(last_tokens)
285}
286
/// Cheap token-count heuristic: roughly one token per four bytes of text.
/// Use [`Tokenizer::count_tokens`] when an exact count is needed.
pub fn estimate_token_count(text: &str) -> usize {
    text.as_bytes().len() / 4
}

293pub fn truncate_to_tokens(text: &str, max_tokens: usize) -> MemoryResult<String> {
295 let tokenizer = Tokenizer::new()?;
296 let tokens = tokenizer.encode(text);
297
298 if tokens.len() <= max_tokens {
299 Ok(text.to_string())
300 } else {
301 let truncated = &tokens[..max_tokens];
302 Ok(tokenizer.decode(truncated))
303 }
304}
305
306pub fn merge_small_chunks(chunks: Vec<TextChunk>, min_tokens: usize) -> Vec<TextChunk> {
308 if chunks.len() < 2 {
309 return chunks;
310 }
311
312 let mut merged = Vec::new();
313 let mut current = chunks[0].clone();
314
315 for chunk in chunks.into_iter().skip(1) {
316 if current.token_count < min_tokens {
317 current.content.push('\n');
319 current.content.push_str(&chunk.content);
320 current.token_count += chunk.token_count;
321 current.end_index = chunk.end_index;
322 } else {
323 merged.push(current);
324 current = chunk;
325 }
326 }
327
328 merged.push(current);
329 merged
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chunk_text_empty() {
        let config = ChunkingConfig::default();
        assert!(chunk_text("", &config).unwrap().is_empty());
    }

    #[test]
    fn test_chunk_text_short() {
        let config = ChunkingConfig::default();
        let text = "This is a short text.";
        let chunks = chunk_text(text, &config).unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].content, text);
    }

    #[test]
    fn test_chunk_text_long() {
        let config = ChunkingConfig {
            chunk_size: 10,
            chunk_overlap: 2,
            separator: None,
        };
        let text = "This is a much longer text that needs to be split into multiple chunks. It contains several sentences and should be broken up appropriately.";
        let chunks = chunk_text(text, &config).unwrap();
        assert!(chunks.len() > 1);

        // Every consecutive pair of chunks must share an overlapping region.
        for pair in chunks.windows(2) {
            assert!(
                pair[1].start_index < pair[0].end_index,
                "Chunks should overlap"
            );
        }
    }

    #[test]
    fn test_tokenizer_count() {
        let tokenizer = Tokenizer::new().unwrap();
        assert!(tokenizer.count_tokens("Hello world") > 0);
    }

    #[test]
    fn test_estimate_token_count() {
        let text = "This is a test sentence with approximately twelve tokens.";
        let estimated = estimate_token_count(text);
        let actual = Tokenizer::new().unwrap().count_tokens(text);
        let diff = (estimated as i64 - actual as i64).abs();
        assert!(diff < 5, "Estimate should be close to actual");
    }
}