1#![deny(unsafe_code)]
16#![warn(missing_docs)]
17#![warn(rust_2018_idioms)]
18
19use rayon::prelude::*;
20use regex::Regex;
21use serde::{Deserialize, Serialize};
22use thiserror::Error;
23use tiktoken_rs::CoreBPE;
24
/// Convenience alias: a `Result` whose error type is [`ChunkerError`].
pub type Result<T> = std::result::Result<T, ChunkerError>;
27
/// Errors produced while building a [`Chunker`] or splitting text.
#[derive(Error, Debug)]
pub enum ChunkerError {
    /// The configured encoding name is not one of the supported tokenizers.
    #[error("unknown encoding: {0} (expected cl100k_base or o200k_base)")]
    UnknownEncoding(String),
    /// The [`ChunkConfig`] failed validation; the message explains why.
    #[error("invalid config: {0}")]
    InvalidConfig(String),
    /// An error surfaced from the underlying `tiktoken-rs` tokenizer.
    #[error("tiktoken-rs error: {0}")]
    Tiktoken(String),
}
41
/// Tunable parameters controlling how text is split into chunks.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChunkConfig {
    /// Hard upper bound on tokens per chunk (before any overlap is prepended).
    pub max_tokens: usize,
    /// Trailing tokens of each chunk re-prepended to the next chunk.
    /// Must be strictly less than `max_tokens`.
    pub overlap_tokens: usize,
    /// Chunks whose total token count (including overlap) falls below this
    /// threshold are dropped from the output.
    pub min_tokens: usize,
    /// Tokenizer name: `"cl100k_base"` or `"o200k_base"`.
    pub encoding: String,
    /// When true, blank-line-separated paragraphs are chunked independently,
    /// so no chunk ever straddles a paragraph boundary.
    #[serde(default)]
    pub preserve_paragraphs: bool,
}
61
62impl Default for ChunkConfig {
63 fn default() -> Self {
64 Self {
65 max_tokens: 512,
66 overlap_tokens: 0,
67 min_tokens: 1,
68 encoding: "cl100k_base".to_string(),
69 preserve_paragraphs: false,
70 }
71 }
72}
73
/// One chunk of the input: its decoded text, byte span, and token count.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Chunk {
    /// Decoded chunk text (includes any overlap tokens prepended from the
    /// previous chunk).
    pub text: String,
    /// Byte offset into the source text where this chunk's own content
    /// starts. NOTE: overlap text is not reflected in the span, so `text`
    /// may be longer than `end - start`.
    pub start: usize,
    /// Byte offset (exclusive) into the source text where this chunk ends.
    pub end: usize,
    /// Number of tokens in `text`, including overlap tokens.
    pub token_count: usize,
}
88
/// Token-budgeted text chunker backed by a tiktoken BPE tokenizer.
pub struct Chunker {
    // Loaded BPE tokenizer used to count, slice, and decode tokens.
    bpe: CoreBPE,
    // Validated configuration (see `Chunker::new`).
    cfg: ChunkConfig,
    // Compiled heuristic for detecting sentence boundaries.
    sentence_re: Regex,
}
95
96impl Chunker {
97 pub fn new(cfg: ChunkConfig) -> Result<Self> {
99 if cfg.max_tokens == 0 {
100 return Err(ChunkerError::InvalidConfig("max_tokens must be > 0".into()));
101 }
102 if cfg.overlap_tokens >= cfg.max_tokens {
103 return Err(ChunkerError::InvalidConfig(format!(
104 "overlap_tokens ({}) must be < max_tokens ({})",
105 cfg.overlap_tokens, cfg.max_tokens
106 )));
107 }
108 if cfg.min_tokens > cfg.max_tokens {
109 return Err(ChunkerError::InvalidConfig(format!(
110 "min_tokens ({}) must be <= max_tokens ({})",
111 cfg.min_tokens, cfg.max_tokens
112 )));
113 }
114 let bpe = match cfg.encoding.as_str() {
115 "cl100k_base" => {
116 tiktoken_rs::cl100k_base().map_err(|e| ChunkerError::Tiktoken(e.to_string()))?
117 }
118 "o200k_base" => {
119 tiktoken_rs::o200k_base().map_err(|e| ChunkerError::Tiktoken(e.to_string()))?
120 }
121 other => return Err(ChunkerError::UnknownEncoding(other.to_string())),
122 };
123
124 let sentence_re = Regex::new(
129 r"(?P<term>[.!?])(?P<close>[\)\]\}\u{201d}\u{2019}\u{00bb}'\x22]?)\s+(?P<next>[A-Z\u{00c0}-\u{00de}\u{2018}\u{201c}\(\[\{])"
130 ).expect("sentence regex compiles");
131
132 Ok(Self {
133 bpe,
134 cfg,
135 sentence_re,
136 })
137 }
138
139 pub fn split(&self, text: &str) -> Result<Vec<Chunk>> {
141 if self.cfg.preserve_paragraphs {
145 let mut all: Vec<Chunk> = Vec::new();
146 let paragraphs = split_paragraphs(text);
147 for (p_start, p_end) in paragraphs {
148 let para_text = &text[p_start..p_end];
149 let mut chunks = self.split_internal(para_text)?;
150 for c in &mut chunks {
153 c.start += p_start;
154 c.end += p_start;
155 }
156 all.extend(chunks);
157 }
158 return Ok(all);
159 }
160 self.split_internal(text)
161 }
162
163 fn split_internal(&self, text: &str) -> Result<Vec<Chunk>> {
166 let sentences = self.split_sentences(text);
168 if sentences.is_empty() {
169 return Ok(Vec::new());
170 }
171
172 let mut s_tokens: Vec<Vec<u32>> = Vec::with_capacity(sentences.len());
174 for &(start, end) in &sentences {
175 s_tokens.push(self.bpe.encode_ordinary(&text[start..end]));
176 }
177
178 let mut raw: Vec<(Vec<u32>, usize, usize)> = Vec::new();
181 let mut cur_tokens: Vec<u32> = Vec::new();
182 let mut cur_start: Option<usize> = None;
183 let mut cur_end: usize = 0;
184 for (i, &(s_start, s_end)) in sentences.iter().enumerate() {
185 let stoks = &s_tokens[i];
186 if stoks.len() > self.cfg.max_tokens {
189 if !cur_tokens.is_empty() {
190 raw.push((std::mem::take(&mut cur_tokens), cur_start.unwrap(), cur_end));
191 cur_start = None;
192 }
193 self.slice_long_sentence(stoks, s_start, s_end, &mut raw);
194 continue;
195 }
196 if cur_tokens.len() + stoks.len() > self.cfg.max_tokens && !cur_tokens.is_empty() {
198 raw.push((std::mem::take(&mut cur_tokens), cur_start.unwrap(), cur_end));
199 cur_start = None;
200 }
201 if cur_start.is_none() {
202 cur_start = Some(s_start);
203 }
204 cur_tokens.extend_from_slice(stoks);
205 cur_end = s_end;
206 }
207 if !cur_tokens.is_empty() {
208 raw.push((cur_tokens, cur_start.unwrap(), cur_end));
209 }
210
211 let mut out: Vec<Chunk> = Vec::with_capacity(raw.len());
213 let mut prev_tail: Vec<u32> = Vec::new();
214 for (toks, start, end) in raw {
215 let mut full = Vec::with_capacity(prev_tail.len() + toks.len());
216 full.extend_from_slice(&prev_tail);
217 full.extend_from_slice(&toks);
218 let text = self
219 .bpe
220 .decode(full.clone())
221 .map_err(|e| ChunkerError::Tiktoken(e.to_string()))?;
222 prev_tail = if self.cfg.overlap_tokens > 0 && toks.len() > self.cfg.overlap_tokens {
224 toks[toks.len() - self.cfg.overlap_tokens..].to_vec()
225 } else if self.cfg.overlap_tokens > 0 {
226 toks.clone()
227 } else {
228 Vec::new()
229 };
230 let token_count = full.len();
231 if token_count < self.cfg.min_tokens {
233 continue;
234 }
235 out.push(Chunk {
236 text,
237 start,
238 end,
239 token_count,
240 });
241 }
242 Ok(out)
243 }
244
245 pub fn split_many(&self, texts: &[&str], parallel: bool) -> Result<Vec<Vec<Chunk>>> {
248 if parallel {
249 texts.par_iter().map(|t| self.split(t)).collect()
250 } else {
251 texts.iter().map(|t| self.split(t)).collect()
252 }
253 }
254
255 fn split_sentences(&self, text: &str) -> Vec<(usize, usize)> {
259 if text.is_empty() {
260 return Vec::new();
261 }
262 let mut spans: Vec<(usize, usize)> = Vec::new();
263 let mut last = 0usize;
264 for caps in self.sentence_re.captures_iter(text) {
265 let m = caps.name("term").unwrap();
266 let cut = caps
270 .name("close")
271 .filter(|c| !c.as_str().is_empty())
272 .map(|c| c.end())
273 .unwrap_or_else(|| m.end());
274 if cut <= last {
276 continue;
277 }
278 if is_abbreviation(&text[..m.end()]) {
282 continue;
283 }
284 spans.push((last, cut));
285 let mut next_start = cut;
287 while next_start < text.len() && text.as_bytes()[next_start].is_ascii_whitespace() {
288 next_start += 1;
289 }
290 last = next_start;
291 }
292 if last < text.len() {
293 spans.push((last, text.len()));
294 }
295 spans.retain(|&(s, e)| s < e && !text[s..e].trim().is_empty());
297 spans
298 }
299
300 fn slice_long_sentence(
302 &self,
303 toks: &[u32],
304 s_start: usize,
305 s_end: usize,
306 out: &mut Vec<(Vec<u32>, usize, usize)>,
307 ) {
308 let mut i = 0usize;
312 while i < toks.len() {
313 let end = (i + self.cfg.max_tokens).min(toks.len());
314 out.push((toks[i..end].to_vec(), s_start, s_end));
315 i = end;
316 }
317 }
318}
319
/// Splits `text` into paragraph byte spans separated by blank lines.
///
/// A paragraph break is two or more consecutive line breaks; both `\n` and
/// `\r\n` line endings are recognized (the original implementation only
/// handled bare `\n`, so Windows-formatted text was never split). Returned
/// spans exclude the break bytes themselves, and no empty spans are
/// produced.
fn split_paragraphs(text: &str) -> Vec<(usize, usize)> {
    if text.is_empty() {
        return Vec::new();
    }
    let bytes = text.as_bytes();
    let mut spans: Vec<(usize, usize)> = Vec::new();
    let mut start: Option<usize> = None;
    let mut i = 0usize;
    while i < bytes.len() {
        // Measure the run of consecutive line breaks starting at `i`.
        let mut j = i;
        let mut breaks = 0usize;
        loop {
            if bytes.get(j) == Some(&b'\n') {
                j += 1;
            } else if bytes.get(j) == Some(&b'\r') && bytes.get(j + 1) == Some(&b'\n') {
                j += 2;
            } else {
                break;
            }
            breaks += 1;
        }
        if breaks >= 2 {
            // Blank line(s): close the current paragraph, if any is open.
            if let Some(s) = start.take() {
                spans.push((s, i));
            }
            i = j;
            continue;
        }
        // Content byte (or a single line break inside a paragraph).
        if start.is_none() {
            start = Some(i);
        }
        i += 1;
    }
    if let Some(s) = start {
        spans.push((s, text.len()));
    }
    spans
}
359
/// Returns true when `prefix` (text up to and including a sentence
/// terminator) ends in a known abbreviation, so the terminator should not
/// be treated as a sentence boundary.
///
/// Fix: the previous implementation did a bare suffix match on the last 8
/// characters, so ordinary words that merely end like an abbreviation
/// ("first." / "best." vs "st.") were misclassified and suppressed real
/// sentence splits. The match is now anchored to the final
/// whitespace-delimited word, and an alphanumeric character immediately
/// before the candidate abbreviation rejects it.
fn is_abbreviation(prefix: &str) -> bool {
    const ABBREVS: &[&str] = &[
        "mr.", "mrs.", "ms.", "dr.", "st.", "jr.", "sr.", "prof.", "rev.", "vs.", "etc.", "e.g.",
        "i.e.", "fig.", "cf.", "no.", "vol.", "ch.", "sec.",
    ];
    // Only the last word can carry the abbreviation.
    let word = prefix
        .rsplit(char::is_whitespace)
        .next()
        .unwrap_or(prefix)
        .to_lowercase();
    ABBREVS.iter().any(|abbr| {
        // `ends_with` guarantees the slice below lands on a char boundary
        // (all abbreviations are ASCII).
        word.ends_with(abbr)
            && !word[..word.len() - abbr.len()]
                .chars()
                .next_back()
                .map_or(false, |c| c.is_alphanumeric())
    })
}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with the given token budget; every other knob stays at its
    /// default (no overlap, min 1, cl100k_base, paragraphs off).
    fn cfg(max_tokens: usize) -> ChunkConfig {
        ChunkConfig {
            max_tokens,
            ..ChunkConfig::default()
        }
    }

    #[test]
    fn empty_input_yields_no_chunks() {
        let chunker = Chunker::new(cfg(100)).unwrap();
        assert!(chunker.split("").unwrap().is_empty());
    }

    #[test]
    fn short_text_one_chunk() {
        let chunker = Chunker::new(cfg(100)).unwrap();
        let chunks = chunker.split("hello world").unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].text, "hello world");
    }

    #[test]
    fn splits_at_sentence_boundary_under_budget() {
        let chunker = Chunker::new(cfg(8)).unwrap();
        let chunks = chunker
            .split("Alpha beta gamma. Delta epsilon zeta. Eta theta iota.")
            .unwrap();
        assert!(
            chunks.len() >= 2,
            "expected >=2 chunks, got {}",
            chunks.len()
        );
        for chunk in &chunks {
            assert!(
                chunk.token_count <= 8,
                "chunk over budget: {} tokens",
                chunk.token_count
            );
        }
    }

    #[test]
    fn long_sentence_falls_back_to_token_slicing() {
        let chunker = Chunker::new(cfg(5)).unwrap();
        let chunks = chunker
            .split("the quick brown fox jumps over the lazy dog and runs through fields")
            .unwrap();
        assert!(chunks.len() > 1);
        assert!(chunks.iter().all(|chunk| chunk.token_count <= 5));
    }

    #[test]
    fn overlap_re_prepends_tail_tokens() {
        let chunker = Chunker::new(ChunkConfig {
            max_tokens: 6,
            overlap_tokens: 2,
            ..ChunkConfig::default()
        })
        .unwrap();
        let chunks = chunker
            .split("Alpha beta gamma. Delta epsilon zeta. Eta theta iota.")
            .unwrap();
        assert!(chunks.len() >= 2);
        // Chunks after the first may carry up to 2 extra overlap tokens.
        assert!(chunks.iter().skip(1).all(|chunk| chunk.token_count <= 6 + 2));
    }

    #[test]
    fn min_tokens_drops_short_chunks() {
        let chunker = Chunker::new(ChunkConfig {
            max_tokens: 1000,
            min_tokens: 50,
            ..ChunkConfig::default()
        })
        .unwrap();
        assert!(chunker.split("tiny.").unwrap().is_empty());
    }

    #[test]
    fn invalid_config_overlap_ge_max() {
        let config = ChunkConfig {
            max_tokens: 10,
            overlap_tokens: 10,
            ..Default::default()
        };
        assert!(Chunker::new(config).is_err());
    }

    #[test]
    fn invalid_config_zero_max() {
        let config = ChunkConfig {
            max_tokens: 0,
            ..Default::default()
        };
        assert!(Chunker::new(config).is_err());
    }

    #[test]
    fn unknown_encoding_rejected() {
        let config = ChunkConfig {
            encoding: "nope_base".to_string(),
            ..Default::default()
        };
        assert!(matches!(
            Chunker::new(config),
            Err(ChunkerError::UnknownEncoding(_))
        ));
    }

    #[test]
    fn abbreviation_does_not_split_sentence() {
        let chunker = Chunker::new(cfg(1000)).unwrap();
        let sentences = chunker.split_sentences("Dr. Smith arrived. He said hello.");
        assert_eq!(sentences.len(), 2, "got: {:?}", sentences);
    }

    #[test]
    fn split_many_serial_and_parallel_match() {
        let chunker = Chunker::new(cfg(10)).unwrap();
        let texts = vec!["Alpha beta gamma.", "Delta. Epsilon. Zeta."];
        let serial = chunker.split_many(&texts, false).unwrap();
        let parallel = chunker.split_many(&texts, true).unwrap();
        assert_eq!(serial, parallel);
    }

    #[test]
    fn chunk_text_decodes_to_token_count() {
        let chunker = Chunker::new(cfg(10)).unwrap();
        let chunks = chunker
            .split("The quick brown fox jumps over the lazy dog.")
            .unwrap();
        let bpe = tiktoken_rs::cl100k_base().unwrap();
        for chunk in &chunks {
            assert_eq!(bpe.encode_ordinary(&chunk.text).len(), chunk.token_count);
        }
    }

    #[test]
    fn unicode_input_handled() {
        let chunker = Chunker::new(cfg(100)).unwrap();
        let chunks = chunker.split("你好世界. Hello world. 🌍 done.").unwrap();
        assert!(!chunks.is_empty());
        assert!(chunks.iter().all(|chunk| !chunk.text.is_empty()));
    }

    #[test]
    fn min_tokens_filters_single_word_input() {
        let chunker = Chunker::new(ChunkConfig {
            max_tokens: 100,
            min_tokens: 5,
            ..ChunkConfig::default()
        })
        .unwrap();
        assert!(chunker.split("hi").unwrap().is_empty());
    }

    #[test]
    fn preserve_paragraphs_emits_per_paragraph_chunks() {
        let chunker = Chunker::new(ChunkConfig {
            max_tokens: 100,
            preserve_paragraphs: true,
            ..ChunkConfig::default()
        })
        .unwrap();
        let text = "First paragraph here.\n\nSecond paragraph here.";
        let chunks = chunker.split(text).unwrap();
        assert_eq!(chunks.len(), 2);
        for chunk in &chunks {
            assert!(chunk.end <= text.len());
            assert!(text.get(chunk.start..chunk.end).is_some());
        }
    }

    #[test]
    fn preserve_paragraphs_respects_token_budget_per_paragraph() {
        let chunker = Chunker::new(ChunkConfig {
            max_tokens: 5,
            preserve_paragraphs: true,
            ..ChunkConfig::default()
        })
        .unwrap();
        let chunks = chunker
            .split("alpha beta gamma delta epsilon zeta\n\nshort.")
            .unwrap();
        assert!(chunks.len() >= 2);
    }

    #[test]
    fn split_paragraphs_helper_returns_disjoint_spans() {
        let text = "para 1\n\n\npara 2\n\npara 3";
        let spans = split_paragraphs(text);
        let expected = ["para 1", "para 2", "para 3"];
        assert_eq!(spans.len(), expected.len());
        for (&(s, e), want) in spans.iter().zip(expected) {
            assert_eq!(&text[s..e], want);
        }
    }

    #[test]
    fn preserve_paragraphs_default_off_keeps_existing_behavior() {
        let chunker = Chunker::new(cfg(100)).unwrap();
        let chunks = chunker
            .split("First paragraph here.\n\nSecond paragraph here.")
            .unwrap();
        assert_eq!(chunks.len(), 1);
    }
}