//! Text Chunking Module
//!
//! Splits text into chunks with configurable token size and overlap.

use tiktoken_rs::cl100k_base;

use crate::types::{MemoryError, MemoryResult, MIN_CHUNK_LENGTH};

/// A text chunk with metadata.
///
/// Offsets are byte positions into the source text the chunk was cut from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TextChunk {
    /// The chunk's text content.
    pub content: String,
    /// Number of tokens in `content` as counted by the tokenizer.
    pub token_count: usize,
    /// Byte offset where this chunk starts in the source text.
    pub start_index: usize,
    /// Byte offset where this chunk ends in the source text (exclusive).
    pub end_index: usize,
}

/// Chunking configuration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChunkingConfig {
    /// Target chunk size in tokens.
    pub chunk_size: usize,
    /// Overlap between consecutive chunks in tokens.
    /// Should be smaller than `chunk_size`.
    pub chunk_overlap: usize,
    /// Separator to use for splitting (if None, uses token boundaries).
    pub separator: Option<String>,
}

27impl Default for ChunkingConfig {
28    fn default() -> Self {
29        Self {
30            chunk_size: 512,
31            chunk_overlap: 64,
32            separator: None,
33        }
34    }
35}
36
/// Tokenizer wrapper for counting tokens.
///
/// Thin wrapper around the `tiktoken` cl100k_base byte-pair encoder,
/// providing token counting, encoding, and decoding.
pub struct Tokenizer {
    // Underlying BPE; built once in `new()` and reused for all operations.
    bpe: tiktoken_rs::CoreBPE,
}

42impl Tokenizer {
43    pub fn new() -> MemoryResult<Self> {
44        let bpe = cl100k_base().map_err(|e| MemoryError::Tokenization(e.to_string()))?;
45        Ok(Self { bpe })
46    }
47
48    /// Count tokens in text
49    pub fn count_tokens(&self, text: &str) -> usize {
50        self.bpe.encode_with_special_tokens(text).len()
51    }
52
53    /// Encode text to tokens
54    pub fn encode(&self, text: &str) -> Vec<u32> {
55        self.bpe.encode_with_special_tokens(text)
56    }
57
58    /// Decode tokens to text
59    pub fn decode(&self, tokens: &[u32]) -> String {
60        self.bpe.decode(tokens.to_vec()).unwrap_or_default()
61    }
62}
63
impl Default for Tokenizer {
    /// # Panics
    ///
    /// Panics if the cl100k_base vocabulary cannot be loaded; use
    /// [`Tokenizer::new`] to handle that failure as a `Result` instead.
    fn default() -> Self {
        Self::new().expect("Failed to initialize tokenizer")
    }
}

70/// Chunk text into pieces with overlap
71pub fn chunk_text(text: &str, config: &ChunkingConfig) -> MemoryResult<Vec<TextChunk>> {
72    if text.is_empty() {
73        return Ok(Vec::new());
74    }
75
76    // If text is short enough, return as single chunk
77    if text.len() < MIN_CHUNK_LENGTH {
78        let tokenizer = Tokenizer::new()?;
79        let token_count = tokenizer.count_tokens(text);
80        return Ok(vec![TextChunk {
81            content: text.to_string(),
82            token_count,
83            start_index: 0,
84            end_index: text.len(),
85        }]);
86    }
87
88    let tokenizer = Tokenizer::new()?;
89    let tokens = tokenizer.encode(text);
90
91    if tokens.len() <= config.chunk_size {
92        // Text fits in a single chunk
93        return Ok(vec![TextChunk {
94            content: text.to_string(),
95            token_count: tokens.len(),
96            start_index: 0,
97            end_index: text.len(),
98        }]);
99    }
100
101    let mut chunks = Vec::new();
102    let mut start = 0;
103
104    while start < tokens.len() {
105        let end = (start + config.chunk_size).min(tokens.len());
106        let chunk_tokens = &tokens[start..end];
107        let chunk_text = tokenizer.decode(chunk_tokens);
108
109        // Find character boundaries
110        let start_char = if start == 0 {
111            0
112        } else {
113            // Find the character position that corresponds to this token
114            let prev_tokens = &tokens[..start];
115            tokenizer.decode(prev_tokens).len()
116        };
117
118        let end_char = start_char + chunk_text.len();
119
120        chunks.push(TextChunk {
121            content: chunk_text,
122            token_count: chunk_tokens.len(),
123            start_index: start_char,
124            end_index: end_char.min(text.len()),
125        });
126
127        // Move start forward by chunk_size - overlap
128        let step = config.chunk_size.saturating_sub(config.chunk_overlap);
129        if step == 0 {
130            // Prevent infinite loop if overlap equals or exceeds chunk_size
131            start = end;
132        } else {
133            start += step;
134        }
135
136        // Ensure we make progress
137        if start >= tokens.len() {
138            break;
139        }
140    }
141
142    Ok(chunks)
143}
144
/// Chunk text using semantic boundaries (paragraphs, sentences).
///
/// Splits on blank-line paragraph breaks first; a paragraph that alone
/// exceeds `config.chunk_size` tokens is further split on `.`, `!`, `?`
/// sentence boundaries. Adjacent pieces are packed into a chunk until the
/// token budget is reached, carrying a token-overlap tail between chunks.
///
/// NOTE(review): byte positions are tracked by accumulating `len() + 1` per
/// piece, but paragraphs are separated by "\n\n" (2 bytes) and empty
/// paragraphs are filtered out, so `start_index`/`end_index` can drift from
/// the true offsets in `text` — confirm before relying on them.
///
/// # Errors
/// Returns `MemoryError::Tokenization` if the tokenizer cannot be initialized.
pub fn chunk_text_semantic(text: &str, config: &ChunkingConfig) -> MemoryResult<Vec<TextChunk>> {
    if text.is_empty() {
        return Ok(Vec::new());
    }

    // If text is short enough, return as single chunk
    if text.len() < MIN_CHUNK_LENGTH {
        let tokenizer = Tokenizer::new()?;
        let token_count = tokenizer.count_tokens(text);
        return Ok(vec![TextChunk {
            content: text.to_string(),
            token_count,
            start_index: 0,
            end_index: text.len(),
        }]);
    }

    let tokenizer = Tokenizer::new()?;
    let tokens = tokenizer.encode(text);

    // Whole text fits within the token budget: one chunk, no splitting.
    if tokens.len() <= config.chunk_size {
        return Ok(vec![TextChunk {
            content: text.to_string(),
            token_count: tokens.len(),
            start_index: 0,
            end_index: text.len(),
        }]);
    }

    // Split by paragraphs first (blank-line separated). Filtered-out empty
    // paragraphs contribute nothing to position tracking (see NOTE above).
    let paragraphs: Vec<&str> = text
        .split("\n\n")
        .filter(|p| !p.trim().is_empty())
        .collect();

    let mut chunks = Vec::new();
    let mut current_chunk = String::new(); // text accumulated for the chunk being built
    let mut current_tokens = 0; // token count of current_chunk
    let mut chunk_start = 0; // byte offset where current_chunk began (approximate)
    let mut current_pos = 0; // running byte offset (approximate)

    for paragraph in paragraphs {
        let para_tokens = tokenizer.count_tokens(paragraph);

        if para_tokens > config.chunk_size {
            // Paragraph is too long, split by sentences — but first flush
            // whatever chunk was in progress.
            if !current_chunk.is_empty() {
                chunks.push(TextChunk {
                    content: current_chunk.clone(),
                    token_count: current_tokens,
                    start_index: chunk_start,
                    end_index: current_pos,
                });
                current_chunk.clear();
                current_tokens = 0;
                chunk_start = current_pos;
            }

            // Split long paragraph by sentences.
            // NOTE(review): the original terminator ('!' or '?') is
            // normalized to '.' when the sentence is rebuilt below.
            let sentences: Vec<&str> = paragraph
                .split(['.', '!', '?'])
                .filter(|s| !s.trim().is_empty())
                .collect();

            for sentence in sentences {
                let sentence_with_punct = format!("{}.", sentence.trim());
                let sent_tokens = tokenizer.count_tokens(&sentence_with_punct);

                // Budget exceeded: emit the chunk, then seed the next one
                // with the token-overlap tail of the emitted text.
                if current_tokens + sent_tokens > config.chunk_size && !current_chunk.is_empty() {
                    chunks.push(TextChunk {
                        content: current_chunk.clone(),
                        token_count: current_tokens,
                        start_index: chunk_start,
                        end_index: current_pos,
                    });
                    // Keep overlap
                    // NOTE(review): `chunk_size - chunk_overlap` underflows
                    // (panics in debug) if chunk_overlap > chunk_size —
                    // confirm the config is validated upstream.
                    let overlap_tokens =
                        current_tokens.saturating_sub(config.chunk_size - config.chunk_overlap);
                    if overlap_tokens > 0 && overlap_tokens < current_tokens {
                        let overlap_text =
                            get_last_n_tokens(&tokenizer, &current_chunk, overlap_tokens);
                        current_chunk = overlap_text;
                        current_tokens = overlap_tokens;
                        chunk_start = current_pos - current_chunk.len();
                    } else {
                        current_chunk.clear();
                        current_tokens = 0;
                        chunk_start = current_pos;
                    }
                }

                current_chunk.push_str(&sentence_with_punct);
                current_chunk.push(' ');
                current_tokens += sent_tokens;
                current_pos += sentence_with_punct.len() + 1;
            }
        } else if current_tokens + para_tokens > config.chunk_size {
            // Paragraph doesn't fit in the current chunk: start a new chunk.
            if !current_chunk.is_empty() {
                chunks.push(TextChunk {
                    content: current_chunk.clone(),
                    token_count: current_tokens,
                    start_index: chunk_start,
                    end_index: current_pos,
                });
            }
            current_chunk = paragraph.to_string();
            current_chunk.push('\n');
            current_tokens = para_tokens;
            chunk_start = current_pos;
            current_pos += paragraph.len() + 1;
        } else {
            // Add to current chunk
            current_chunk.push_str(paragraph);
            current_chunk.push('\n');
            current_tokens += para_tokens;
            current_pos += paragraph.len() + 1;
        }
    }

    // Don't forget the last chunk
    if !current_chunk.is_empty() {
        chunks.push(TextChunk {
            content: current_chunk.trim().to_string(),
            token_count: current_tokens,
            start_index: chunk_start,
            end_index: text.len(),
        });
    }

    Ok(chunks)
}

279/// Get the last n tokens from text
280fn get_last_n_tokens(tokenizer: &Tokenizer, text: &str, n: usize) -> String {
281    let tokens = tokenizer.encode(text);
282    let start = tokens.len().saturating_sub(n);
283    let last_tokens = &tokens[start..];
284    tokenizer.decode(last_tokens)
285}
286
/// Estimate token count without running the tokenizer (fast, approximate).
///
/// Uses the rule of thumb that English text averages about four bytes of
/// UTF-8 per token; integer division floors the estimate.
pub fn estimate_token_count(text: &str) -> usize {
    const AVG_BYTES_PER_TOKEN: usize = 4;
    text.len() / AVG_BYTES_PER_TOKEN
}

293/// Truncate text to fit within token budget
294pub fn truncate_to_tokens(text: &str, max_tokens: usize) -> MemoryResult<String> {
295    let tokenizer = Tokenizer::new()?;
296    let tokens = tokenizer.encode(text);
297
298    if tokens.len() <= max_tokens {
299        Ok(text.to_string())
300    } else {
301        let truncated = &tokens[..max_tokens];
302        Ok(tokenizer.decode(truncated))
303    }
304}
305
306/// Merge small chunks to reduce overhead
307pub fn merge_small_chunks(chunks: Vec<TextChunk>, min_tokens: usize) -> Vec<TextChunk> {
308    if chunks.len() < 2 {
309        return chunks;
310    }
311
312    let mut merged = Vec::new();
313    let mut current = chunks[0].clone();
314
315    for chunk in chunks.into_iter().skip(1) {
316        if current.token_count < min_tokens {
317            // Merge with current
318            current.content.push('\n');
319            current.content.push_str(&chunk.content);
320            current.token_count += chunk.token_count;
321            current.end_index = chunk.end_index;
322        } else {
323            merged.push(current);
324            current = chunk;
325        }
326    }
327
328    merged.push(current);
329    merged
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty input yields no chunks.
    #[test]
    fn test_chunk_text_empty() {
        let config = ChunkingConfig::default();
        let chunks = chunk_text("", &config).unwrap();
        assert!(chunks.is_empty());
    }

    /// Input below MIN_CHUNK_LENGTH comes back as one unmodified chunk.
    #[test]
    fn test_chunk_text_short() {
        let config = ChunkingConfig::default();
        let text = "This is a short text.";
        let chunks = chunk_text(text, &config).unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].content, text);
    }

    /// Text larger than the token budget splits into several chunks whose
    /// byte ranges overlap (since chunk_overlap > 0).
    #[test]
    fn test_chunk_text_long() {
        let config = ChunkingConfig {
            chunk_size: 10,
            chunk_overlap: 2,
            separator: None,
        };
        let text = "This is a much longer text that needs to be split into multiple chunks. It contains several sentences and should be broken up appropriately.";
        let chunks = chunk_text(text, &config).unwrap();
        assert!(chunks.len() > 1);

        // Check overlap
        for i in 1..chunks.len() {
            let prev_end = chunks[i - 1].end_index;
            let curr_start = chunks[i].start_index;
            assert!(curr_start < prev_end, "Chunks should overlap");
        }
    }

    /// The tokenizer reports a non-zero count for non-empty input.
    #[test]
    fn test_tokenizer_count() {
        let tokenizer = Tokenizer::new().unwrap();
        let count = tokenizer.count_tokens("Hello world");
        assert!(count > 0);
    }

    /// The len/4 heuristic should land within a few tokens of the real
    /// count for plain English text.
    #[test]
    fn test_estimate_token_count() {
        let text = "This is a test sentence with approximately twelve tokens.";
        let estimated = estimate_token_count(text);
        let tokenizer = Tokenizer::new().unwrap();
        let actual = tokenizer.count_tokens(text);

        // Estimate should be in the ballpark
        let diff = (estimated as i64 - actual as i64).abs();
        assert!(diff < 5, "Estimate should be close to actual");
    }
}