1#![deny(unsafe_code)]
16#![warn(missing_docs)]
17#![warn(rust_2018_idioms)]
18
19use rayon::prelude::*;
20use regex::Regex;
21use serde::{Deserialize, Serialize};
22use thiserror::Error;
23use tiktoken_rs::CoreBPE;
24
/// Convenience alias for results whose error type is [`ChunkerError`].
pub type Result<T> = std::result::Result<T, ChunkerError>;
27
/// Errors returned by [`Chunker`] construction and splitting.
#[derive(Error, Debug)]
pub enum ChunkerError {
    /// The configured encoding name is not one of the supported tokenizers.
    #[error("unknown encoding: {0} (expected cl100k_base or o200k_base)")]
    UnknownEncoding(String),
    /// A [`ChunkConfig`] field failed validation; the message says which.
    #[error("invalid config: {0}")]
    InvalidConfig(String),
    /// An error bubbled up from the underlying `tiktoken-rs` tokenizer.
    #[error("tiktoken-rs error: {0}")]
    Tiktoken(String),
}
41
/// Configuration for a [`Chunker`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChunkConfig {
    /// Hard upper bound on tokens per chunk, before any overlap is prepended.
    pub max_tokens: usize,
    /// Number of trailing tokens of each chunk to prepend to the next chunk;
    /// must be strictly less than `max_tokens`.
    pub overlap_tokens: usize,
    /// Chunks with fewer tokens than this (overlap included) are dropped.
    pub min_tokens: usize,
    /// Tokenizer name: `"cl100k_base"` or `"o200k_base"`.
    pub encoding: String,
}
55
56impl Default for ChunkConfig {
57 fn default() -> Self {
58 Self {
59 max_tokens: 512,
60 overlap_tokens: 0,
61 min_tokens: 1,
62 encoding: "cl100k_base".to_string(),
63 }
64 }
65}
66
/// A single output chunk: decoded text plus its source span and token count.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Chunk {
    /// Decoded chunk text, including any overlap tokens prepended from the
    /// previous chunk.
    pub text: String,
    /// Byte offset into the source text where this chunk's sentences begin.
    /// Overlap tokens are not reflected in the span.
    pub start: usize,
    /// Byte offset (exclusive) into the source text where this chunk's
    /// sentences end.
    pub end: usize,
    /// Number of tokens in `text`, overlap included.
    pub token_count: usize,
}
81
/// Sentence-aware, token-budgeted text chunker backed by a tiktoken BPE.
pub struct Chunker {
    // Byte-pair encoder used to count, slice, and decode tokens.
    bpe: CoreBPE,
    // Configuration, validated by `Chunker::new`.
    cfg: ChunkConfig,
    // Heuristic sentence-boundary pattern: terminator + optional closing
    // quote/bracket + whitespace + capitalized/opening next character.
    sentence_re: Regex,
}
88
89impl Chunker {
90 pub fn new(cfg: ChunkConfig) -> Result<Self> {
92 if cfg.max_tokens == 0 {
93 return Err(ChunkerError::InvalidConfig("max_tokens must be > 0".into()));
94 }
95 if cfg.overlap_tokens >= cfg.max_tokens {
96 return Err(ChunkerError::InvalidConfig(format!(
97 "overlap_tokens ({}) must be < max_tokens ({})",
98 cfg.overlap_tokens, cfg.max_tokens
99 )));
100 }
101 if cfg.min_tokens > cfg.max_tokens {
102 return Err(ChunkerError::InvalidConfig(format!(
103 "min_tokens ({}) must be <= max_tokens ({})",
104 cfg.min_tokens, cfg.max_tokens
105 )));
106 }
107 let bpe = match cfg.encoding.as_str() {
108 "cl100k_base" => {
109 tiktoken_rs::cl100k_base().map_err(|e| ChunkerError::Tiktoken(e.to_string()))?
110 }
111 "o200k_base" => {
112 tiktoken_rs::o200k_base().map_err(|e| ChunkerError::Tiktoken(e.to_string()))?
113 }
114 other => return Err(ChunkerError::UnknownEncoding(other.to_string())),
115 };
116
117 let sentence_re = Regex::new(
122 r"(?P<term>[.!?])(?P<close>[\)\]\}\u{201d}\u{2019}\u{00bb}'\x22]?)\s+(?P<next>[A-Z\u{00c0}-\u{00de}\u{2018}\u{201c}\(\[\{])"
123 ).expect("sentence regex compiles");
124
125 Ok(Self {
126 bpe,
127 cfg,
128 sentence_re,
129 })
130 }
131
132 pub fn split(&self, text: &str) -> Result<Vec<Chunk>> {
134 let sentences = self.split_sentences(text);
136 if sentences.is_empty() {
137 return Ok(Vec::new());
138 }
139
140 let mut s_tokens: Vec<Vec<u32>> = Vec::with_capacity(sentences.len());
142 for &(start, end) in &sentences {
143 s_tokens.push(self.bpe.encode_ordinary(&text[start..end]));
144 }
145
146 let mut raw: Vec<(Vec<u32>, usize, usize)> = Vec::new();
149 let mut cur_tokens: Vec<u32> = Vec::new();
150 let mut cur_start: Option<usize> = None;
151 let mut cur_end: usize = 0;
152 for (i, &(s_start, s_end)) in sentences.iter().enumerate() {
153 let stoks = &s_tokens[i];
154 if stoks.len() > self.cfg.max_tokens {
157 if !cur_tokens.is_empty() {
158 raw.push((std::mem::take(&mut cur_tokens), cur_start.unwrap(), cur_end));
159 cur_start = None;
160 }
161 self.slice_long_sentence(stoks, s_start, s_end, &mut raw);
162 continue;
163 }
164 if cur_tokens.len() + stoks.len() > self.cfg.max_tokens && !cur_tokens.is_empty() {
166 raw.push((std::mem::take(&mut cur_tokens), cur_start.unwrap(), cur_end));
167 cur_start = None;
168 }
169 if cur_start.is_none() {
170 cur_start = Some(s_start);
171 }
172 cur_tokens.extend_from_slice(stoks);
173 cur_end = s_end;
174 }
175 if !cur_tokens.is_empty() {
176 raw.push((cur_tokens, cur_start.unwrap(), cur_end));
177 }
178
179 let mut out: Vec<Chunk> = Vec::with_capacity(raw.len());
181 let mut prev_tail: Vec<u32> = Vec::new();
182 for (toks, start, end) in raw {
183 let mut full = Vec::with_capacity(prev_tail.len() + toks.len());
184 full.extend_from_slice(&prev_tail);
185 full.extend_from_slice(&toks);
186 let text = self
187 .bpe
188 .decode(full.clone())
189 .map_err(|e| ChunkerError::Tiktoken(e.to_string()))?;
190 prev_tail = if self.cfg.overlap_tokens > 0 && toks.len() > self.cfg.overlap_tokens {
192 toks[toks.len() - self.cfg.overlap_tokens..].to_vec()
193 } else if self.cfg.overlap_tokens > 0 {
194 toks.clone()
195 } else {
196 Vec::new()
197 };
198 let token_count = full.len();
199 if token_count < self.cfg.min_tokens {
201 continue;
202 }
203 out.push(Chunk {
204 text,
205 start,
206 end,
207 token_count,
208 });
209 }
210 Ok(out)
211 }
212
213 pub fn split_many(&self, texts: &[&str], parallel: bool) -> Result<Vec<Vec<Chunk>>> {
216 if parallel {
217 texts.par_iter().map(|t| self.split(t)).collect()
218 } else {
219 texts.iter().map(|t| self.split(t)).collect()
220 }
221 }
222
223 fn split_sentences(&self, text: &str) -> Vec<(usize, usize)> {
227 if text.is_empty() {
228 return Vec::new();
229 }
230 let mut spans: Vec<(usize, usize)> = Vec::new();
231 let mut last = 0usize;
232 for caps in self.sentence_re.captures_iter(text) {
233 let m = caps.name("term").unwrap();
234 let cut = caps
238 .name("close")
239 .filter(|c| !c.as_str().is_empty())
240 .map(|c| c.end())
241 .unwrap_or_else(|| m.end());
242 if cut <= last {
244 continue;
245 }
246 if is_abbreviation(&text[..m.end()]) {
250 continue;
251 }
252 spans.push((last, cut));
253 let mut next_start = cut;
255 while next_start < text.len() && text.as_bytes()[next_start].is_ascii_whitespace() {
256 next_start += 1;
257 }
258 last = next_start;
259 }
260 if last < text.len() {
261 spans.push((last, text.len()));
262 }
263 spans.retain(|&(s, e)| s < e && !text[s..e].trim().is_empty());
265 spans
266 }
267
268 fn slice_long_sentence(
270 &self,
271 toks: &[u32],
272 s_start: usize,
273 s_end: usize,
274 out: &mut Vec<(Vec<u32>, usize, usize)>,
275 ) {
276 let mut i = 0usize;
280 while i < toks.len() {
281 let end = (i + self.cfg.max_tokens).min(toks.len());
282 out.push((toks[i..end].to_vec(), s_start, s_end));
283 i = end;
284 }
285 }
286}
287
/// Returns `true` when `prefix` (text up to and including a sentence
/// terminator) ends in a known abbreviation, meaning the terminator should
/// NOT be treated as a sentence boundary.
///
/// The abbreviation must start at a word boundary: the character immediately
/// before it must be absent or non-alphanumeric. The previous implementation
/// used a bare `ends_with`, so ordinary words that merely end with an
/// abbreviation ("August." ends with "st.", "config." ends with "fig.")
/// wrongly suppressed real sentence boundaries.
fn is_abbreviation(prefix: &str) -> bool {
    const ABBREVS: &[&str] = &[
        "mr.", "mrs.", "ms.", "dr.", "st.", "jr.", "sr.", "prof.", "rev.", "vs.", "etc.", "e.g.",
        "i.e.", "fig.", "cf.", "no.", "vol.", "ch.", "sec.",
    ];
    // Lowercased tail of the last 8 characters — long enough for the longest
    // abbreviation ("prof.", 5 chars) plus its preceding boundary character.
    let tail_start = prefix
        .char_indices()
        .rev()
        .take(8)
        .last()
        .map(|(i, _)| i)
        .unwrap_or(0);
    let lower_tail = prefix[tail_start..].to_lowercase();
    ABBREVS.iter().any(|a| {
        // `ends_with` an ASCII abbreviation guarantees the byte split below
        // lands on a char boundary.
        lower_tail.ends_with(a)
            && lower_tail[..lower_tail.len() - a.len()]
                .chars()
                .next_back()
                .map_or(true, |c| !c.is_alphanumeric())
    })
}
306
#[cfg(test)]
mod tests {
    use super::*;

    // Convenience: a config with the given token budget and all other fields
    // at their simplest values (no overlap, min of one token, cl100k_base).
    fn cfg(max_tokens: usize) -> ChunkConfig {
        ChunkConfig {
            max_tokens,
            overlap_tokens: 0,
            min_tokens: 1,
            encoding: "cl100k_base".to_string(),
        }
    }

    // Empty string in -> empty chunk list out, no error.
    #[test]
    fn empty_input_yields_no_chunks() {
        let c = Chunker::new(cfg(100)).unwrap();
        assert!(c.split("").unwrap().is_empty());
    }

    // A text well under budget round-trips as a single, unmodified chunk.
    #[test]
    fn short_text_one_chunk() {
        let c = Chunker::new(cfg(100)).unwrap();
        let r = c.split("hello world").unwrap();
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].text, "hello world");
    }

    // Three sentences with an 8-token budget must be packed into multiple
    // chunks, and no chunk may exceed the budget.
    #[test]
    fn splits_at_sentence_boundary_under_budget() {
        let c = Chunker::new(cfg(8)).unwrap();
        let text = "Alpha beta gamma. Delta epsilon zeta. Eta theta iota.";
        let chunks = c.split(text).unwrap();
        assert!(
            chunks.len() >= 2,
            "expected >=2 chunks, got {}",
            chunks.len()
        );
        for ch in &chunks {
            assert!(
                ch.token_count <= 8,
                "chunk over budget: {} tokens",
                ch.token_count
            );
        }
    }

    // A single sentence longer than the budget falls back to raw token
    // slicing; every slice still respects max_tokens.
    #[test]
    fn long_sentence_falls_back_to_token_slicing() {
        let c = Chunker::new(cfg(5)).unwrap();
        let text = "the quick brown fox jumps over the lazy dog and runs through fields";
        let chunks = c.split(text).unwrap();
        assert!(chunks.len() > 1);
        for ch in &chunks {
            assert!(ch.token_count <= 5);
        }
    }

    // With overlap configured, chunks after the first may carry up to
    // overlap_tokens extra prepended tokens (budget + overlap bound).
    #[test]
    fn overlap_re_prepends_tail_tokens() {
        let c = Chunker::new(ChunkConfig {
            max_tokens: 6,
            overlap_tokens: 2,
            min_tokens: 1,
            encoding: "cl100k_base".to_string(),
        })
        .unwrap();
        let text = "Alpha beta gamma. Delta epsilon zeta. Eta theta iota.";
        let chunks = c.split(text).unwrap();
        assert!(chunks.len() >= 2);
        for ch in chunks.iter().skip(1) {
            assert!(ch.token_count <= 6 + 2);
        }
    }

    // A chunk below min_tokens is filtered from the output entirely.
    #[test]
    fn min_tokens_drops_short_chunks() {
        let c = Chunker::new(ChunkConfig {
            max_tokens: 1000,
            overlap_tokens: 0,
            min_tokens: 50,
            encoding: "cl100k_base".to_string(),
        })
        .unwrap();
        let text = "tiny.";
        assert!(c.split(text).unwrap().is_empty());
    }

    // overlap_tokens == max_tokens must be rejected at construction.
    #[test]
    fn invalid_config_overlap_ge_max() {
        let bad = ChunkConfig {
            max_tokens: 10,
            overlap_tokens: 10,
            ..Default::default()
        };
        assert!(Chunker::new(bad).is_err());
    }

    // max_tokens == 0 must be rejected at construction.
    #[test]
    fn invalid_config_zero_max() {
        let bad = ChunkConfig {
            max_tokens: 0,
            ..Default::default()
        };
        assert!(Chunker::new(bad).is_err());
    }

    // An unsupported encoding name yields the specific UnknownEncoding error.
    #[test]
    fn unknown_encoding_rejected() {
        let bad = ChunkConfig {
            encoding: "nope_base".to_string(),
            ..Default::default()
        };
        assert!(matches!(
            Chunker::new(bad),
            Err(ChunkerError::UnknownEncoding(_))
        ));
    }

    // "Dr." must not end a sentence: exactly two sentences expected here,
    // not three.
    #[test]
    fn abbreviation_does_not_split_sentence() {
        let c = Chunker::new(cfg(1000)).unwrap();
        let text = "Dr. Smith arrived. He said hello.";
        let sentences = c.split_sentences(text);
        assert_eq!(sentences.len(), 2, "got: {:?}", sentences);
    }

    // Serial and rayon-parallel paths of split_many must agree exactly.
    #[test]
    fn split_many_serial_and_parallel_match() {
        let c = Chunker::new(cfg(10)).unwrap();
        let texts = vec!["Alpha beta gamma.", "Delta. Epsilon. Zeta."];
        let serial = c.split_many(&texts, false).unwrap();
        let parallel = c.split_many(&texts, true).unwrap();
        assert_eq!(serial, parallel);
    }

    // Re-encoding each chunk's text with the same BPE must reproduce the
    // reported token_count.
    #[test]
    fn chunk_text_decodes_to_token_count() {
        let c = Chunker::new(cfg(10)).unwrap();
        let text = "The quick brown fox jumps over the lazy dog.";
        let chunks = c.split(text).unwrap();
        let bpe = tiktoken_rs::cl100k_base().unwrap();
        for ch in &chunks {
            let actual = bpe.encode_ordinary(&ch.text).len();
            assert_eq!(actual, ch.token_count);
        }
    }

    // Multi-byte input (CJK, emoji) must split without panicking and produce
    // non-empty chunk text.
    #[test]
    fn unicode_input_handled() {
        let c = Chunker::new(cfg(100)).unwrap();
        let text = "你好世界. Hello world. 🌍 done.";
        let r = c.split(text).unwrap();
        assert!(!r.is_empty());
        for ch in &r {
            assert!(!ch.text.is_empty());
        }
    }

    // A single short word below min_tokens yields no chunks at all.
    #[test]
    fn min_tokens_filters_single_word_input() {
        let c = Chunker::new(ChunkConfig {
            max_tokens: 100,
            overlap_tokens: 0,
            min_tokens: 5,
            encoding: "cl100k_base".to_string(),
        })
        .unwrap();
        let r = c.split("hi").unwrap();
        assert!(r.is_empty());
    }
}