Skip to main content

snipsplit_core/
lib.rs

1//! Pure-Rust core for `snipsplit`. Token-aware greedy chunker for RAG.
2//!
3//! Algorithm:
4//! 1. Split into paragraphs on blank lines, then sentences via a regex
5//!    that handles the common abbreviation pitfalls (`Mr.`, `Dr.`, `e.g.`,
6//!    `vs.`, `etc.`, version-style `1.0`, decimal numbers).
7//! 2. Greedy-pack sentences into chunks while the running BPE token count
8//!    is `<= max_tokens`.
9//! 3. If a single sentence is too big on its own, slice it at token
10//!    boundaries instead.
11//! 4. Apply `overlap_tokens` by re-prepending the last N tokens of each
12//!    emitted chunk to the next.
13//! 5. Drop chunks shorter than `min_tokens`.
14
15#![deny(unsafe_code)]
16#![warn(missing_docs)]
17#![warn(rust_2018_idioms)]
18
19use rayon::prelude::*;
20use regex::Regex;
21use serde::{Deserialize, Serialize};
22use thiserror::Error;
23use tiktoken_rs::CoreBPE;
24
/// Crate-wide result alias, carrying [`ChunkerError`] as the error type.
pub type Result<T> = std::result::Result<T, ChunkerError>;
27
/// All errors surfaced by `snipsplit-core`.
#[derive(Error, Debug)]
pub enum ChunkerError {
    /// Unknown encoding name. Supported: `cl100k_base`, `o200k_base`.
    #[error("unknown encoding: {0} (expected cl100k_base or o200k_base)")]
    UnknownEncoding(String),
    /// Caller supplied an invalid configuration value (zero `max_tokens`,
    /// `overlap_tokens >= max_tokens`, or `min_tokens > max_tokens`).
    #[error("invalid config: {0}")]
    InvalidConfig(String),
    /// tiktoken-rs failure (encoder construction or token decoding),
    /// carried as the underlying error's display string.
    #[error("tiktoken-rs error: {0}")]
    Tiktoken(String),
}
41
/// Chunker configuration.
///
/// Validated by [`Chunker::new`]: `max_tokens` must be non-zero,
/// `overlap_tokens` strictly less than `max_tokens`, and `min_tokens`
/// at most `max_tokens`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChunkConfig {
    /// Hard cap on tokens per chunk (before any overlap is prepended).
    pub max_tokens: usize,
    /// Number of trailing tokens of each chunk re-prepended to the next.
    /// Set to 0 to disable overlap. Must be `< max_tokens`.
    pub overlap_tokens: usize,
    /// Drop chunks shorter than this (measured on the final text,
    /// overlap included).
    pub min_tokens: usize,
    /// Encoding name (`cl100k_base` or `o200k_base`).
    pub encoding: String,
}
55
56impl Default for ChunkConfig {
57    fn default() -> Self {
58        Self {
59            max_tokens: 512,
60            overlap_tokens: 0,
61            min_tokens: 1,
62            encoding: "cl100k_base".to_string(),
63        }
64    }
65}
66
/// One emitted chunk.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Chunk {
    /// The chunk text. Includes any prepended overlap from the prior chunk,
    /// so it may cover more tokens than the `[start, end)` span alone.
    pub text: String,
    /// Byte offset of the first chunk character (excluding overlap) in the
    /// original input. Two adjacent chunks may have overlapping `[start, end)`
    /// ranges if `overlap_tokens > 0`; token-sliced pieces of one over-long
    /// sentence all share that sentence's full span.
    pub start: usize,
    /// Byte offset (exclusive) of the chunk end in the original input.
    pub end: usize,
    /// Exact BPE token count of `text` (overlap tokens included).
    pub token_count: usize,
}
81
/// Token-aware chunker.
///
/// Built once via [`Chunker::new`] and reused for many inputs; `split_many`
/// shares `&self` across rayon worker threads.
pub struct Chunker {
    // BPE encoder used for all token counting and decoding.
    bpe: CoreBPE,
    // Configuration, validated in `Chunker::new`.
    cfg: ChunkConfig,
    // Sentence-boundary regex, compiled once at construction.
    sentence_re: Regex,
}
88
89impl Chunker {
90    /// Build a chunker from a config.
91    pub fn new(cfg: ChunkConfig) -> Result<Self> {
92        if cfg.max_tokens == 0 {
93            return Err(ChunkerError::InvalidConfig("max_tokens must be > 0".into()));
94        }
95        if cfg.overlap_tokens >= cfg.max_tokens {
96            return Err(ChunkerError::InvalidConfig(format!(
97                "overlap_tokens ({}) must be < max_tokens ({})",
98                cfg.overlap_tokens, cfg.max_tokens
99            )));
100        }
101        if cfg.min_tokens > cfg.max_tokens {
102            return Err(ChunkerError::InvalidConfig(format!(
103                "min_tokens ({}) must be <= max_tokens ({})",
104                cfg.min_tokens, cfg.max_tokens
105            )));
106        }
107        let bpe = match cfg.encoding.as_str() {
108            "cl100k_base" => {
109                tiktoken_rs::cl100k_base().map_err(|e| ChunkerError::Tiktoken(e.to_string()))?
110            }
111            "o200k_base" => {
112                tiktoken_rs::o200k_base().map_err(|e| ChunkerError::Tiktoken(e.to_string()))?
113            }
114            other => return Err(ChunkerError::UnknownEncoding(other.to_string())),
115        };
116
117        // Sentence boundary regex. Matches whitespace following sentence-
118        // terminating punctuation, but not when preceded by a known
119        // abbreviation. Conservative; deliberately misses some cases (e.g.
120        // numbered lists) rather than over-splitting.
121        let sentence_re = Regex::new(
122            r"(?P<term>[.!?])(?P<close>[\)\]\}\u{201d}\u{2019}\u{00bb}'\x22]?)\s+(?P<next>[A-Z\u{00c0}-\u{00de}\u{2018}\u{201c}\(\[\{])"
123        ).expect("sentence regex compiles");
124
125        Ok(Self {
126            bpe,
127            cfg,
128            sentence_re,
129        })
130    }
131
132    /// Split `text` into chunks.
133    pub fn split(&self, text: &str) -> Result<Vec<Chunk>> {
134        // Step 1: collect sentence spans (start, end) into the original text.
135        let sentences = self.split_sentences(text);
136        if sentences.is_empty() {
137            return Ok(Vec::new());
138        }
139
140        // Step 2: pre-compute token IDs for each sentence so we count once.
141        let mut s_tokens: Vec<Vec<u32>> = Vec::with_capacity(sentences.len());
142        for &(start, end) in &sentences {
143            s_tokens.push(self.bpe.encode_ordinary(&text[start..end]));
144        }
145
146        // Step 3: greedy pack. Emits raw chunks (without overlap) as a vec of
147        // owned token id sequences and the (start, end) byte spans they cover.
148        let mut raw: Vec<(Vec<u32>, usize, usize)> = Vec::new();
149        let mut cur_tokens: Vec<u32> = Vec::new();
150        let mut cur_start: Option<usize> = None;
151        let mut cur_end: usize = 0;
152        for (i, &(s_start, s_end)) in sentences.iter().enumerate() {
153            let stoks = &s_tokens[i];
154            // If this single sentence already exceeds the budget, flush
155            // whatever we have and slice the sentence at token boundaries.
156            if stoks.len() > self.cfg.max_tokens {
157                if !cur_tokens.is_empty() {
158                    raw.push((std::mem::take(&mut cur_tokens), cur_start.unwrap(), cur_end));
159                    cur_start = None;
160                }
161                self.slice_long_sentence(stoks, s_start, s_end, &mut raw);
162                continue;
163            }
164            // Would adding this sentence overflow?
165            if cur_tokens.len() + stoks.len() > self.cfg.max_tokens && !cur_tokens.is_empty() {
166                raw.push((std::mem::take(&mut cur_tokens), cur_start.unwrap(), cur_end));
167                cur_start = None;
168            }
169            if cur_start.is_none() {
170                cur_start = Some(s_start);
171            }
172            cur_tokens.extend_from_slice(stoks);
173            cur_end = s_end;
174        }
175        if !cur_tokens.is_empty() {
176            raw.push((cur_tokens, cur_start.unwrap(), cur_end));
177        }
178
179        // Step 4: apply overlap and decode.
180        let mut out: Vec<Chunk> = Vec::with_capacity(raw.len());
181        let mut prev_tail: Vec<u32> = Vec::new();
182        for (toks, start, end) in raw {
183            let mut full = Vec::with_capacity(prev_tail.len() + toks.len());
184            full.extend_from_slice(&prev_tail);
185            full.extend_from_slice(&toks);
186            let text = self
187                .bpe
188                .decode(full.clone())
189                .map_err(|e| ChunkerError::Tiktoken(e.to_string()))?;
190            // Update prev_tail for the next iteration.
191            prev_tail = if self.cfg.overlap_tokens > 0 && toks.len() > self.cfg.overlap_tokens {
192                toks[toks.len() - self.cfg.overlap_tokens..].to_vec()
193            } else if self.cfg.overlap_tokens > 0 {
194                toks.clone()
195            } else {
196                Vec::new()
197            };
198            let token_count = full.len();
199            // Skip below min_tokens.
200            if token_count < self.cfg.min_tokens {
201                continue;
202            }
203            out.push(Chunk {
204                text,
205                start,
206                end,
207                token_count,
208            });
209        }
210        Ok(out)
211    }
212
213    /// Split many texts. With `parallel = true`, distributes across the
214    /// rayon pool. Each call into `split` is independent.
215    pub fn split_many(&self, texts: &[&str], parallel: bool) -> Result<Vec<Vec<Chunk>>> {
216        if parallel {
217            texts.par_iter().map(|t| self.split(t)).collect()
218        } else {
219            texts.iter().map(|t| self.split(t)).collect()
220        }
221    }
222
223    /// Sentence boundaries, returned as `(byte_start, byte_end)` half-open
224    /// ranges into `text`. Always covers the full input — there are no
225    /// gaps. Empty input returns no sentences.
226    fn split_sentences(&self, text: &str) -> Vec<(usize, usize)> {
227        if text.is_empty() {
228            return Vec::new();
229        }
230        let mut spans: Vec<(usize, usize)> = Vec::new();
231        let mut last = 0usize;
232        for caps in self.sentence_re.captures_iter(text) {
233            let m = caps.name("term").unwrap();
234            // Cut just after the closing punctuation/quote group, BEFORE the
235            // whitespace before `next`. We use the end of the `close` group
236            // if present, otherwise end of `term`.
237            let cut = caps
238                .name("close")
239                .filter(|c| !c.as_str().is_empty())
240                .map(|c| c.end())
241                .unwrap_or_else(|| m.end());
242            // Skip if this would create an empty span.
243            if cut <= last {
244                continue;
245            }
246            // Suppress split if the substring just before `term` is a known
247            // abbreviation. Cheap heuristic; production splitters use a
248            // gazetteer.
249            if is_abbreviation(&text[..m.end()]) {
250                continue;
251            }
252            spans.push((last, cut));
253            // Advance last past any whitespace.
254            let mut next_start = cut;
255            while next_start < text.len() && text.as_bytes()[next_start].is_ascii_whitespace() {
256                next_start += 1;
257            }
258            last = next_start;
259        }
260        if last < text.len() {
261            spans.push((last, text.len()));
262        }
263        // Filter out empty/whitespace-only spans.
264        spans.retain(|&(s, e)| s < e && !text[s..e].trim().is_empty());
265        spans
266    }
267
268    /// Slice an over-long sentence at token boundaries.
269    fn slice_long_sentence(
270        &self,
271        toks: &[u32],
272        s_start: usize,
273        s_end: usize,
274        out: &mut Vec<(Vec<u32>, usize, usize)>,
275    ) {
276        // We can't recover exact byte offsets per token without re-encoding
277        // partials, so attribute the entire sentence span to every slice.
278        // Callers wanting exact offsets should bump the budget instead.
279        let mut i = 0usize;
280        while i < toks.len() {
281            let end = (i + self.cfg.max_tokens).min(toks.len());
282            out.push((toks[i..end].to_vec(), s_start, s_end));
283            i = end;
284        }
285    }
286}
287
/// Suffix-check `prefix` (the input up to and including a candidate
/// sentence-terminating `.`) against a small list of common English
/// abbreviations that produce false positives in sentence splitting.
///
/// The abbreviation must sit on a token boundary: `"Dr."` matches, but
/// `"best."` does not, even though it ends with `"st."`. The previous
/// implementation used a bare `ends_with`, so ordinary words ending in an
/// abbreviation suffix (`best.` → `st.`, `piano.` → `no.`, `march.` → `ch.`)
/// wrongly suppressed real sentence breaks.
fn is_abbreviation(prefix: &str) -> bool {
    const ABBREVS: &[&str] = &[
        "mr.", "mrs.", "ms.", "dr.", "st.", "jr.", "sr.", "prof.", "rev.", "vs.", "etc.", "e.g.",
        "i.e.", "fig.", "cf.", "no.", "vol.", "ch.", "sec.",
    ];
    // Only the tail can matter: the longest entry is 5 bytes, so an 8-char
    // tail always holds any abbreviation plus at least one boundary char.
    let tail_start = prefix
        .char_indices()
        .rev()
        .take(8)
        .last()
        .map_or(0, |(i, _)| i);
    let tail = prefix[tail_start..].to_lowercase();
    ABBREVS.iter().any(|a| {
        // `a` is pure ASCII, so byte-slicing `tail` at `tail.len() - a.len()`
        // after a successful `ends_with` lands on a char boundary.
        tail.ends_with(a)
            && tail[..tail.len() - a.len()]
                .chars()
                .next_back()
                // Start-of-string counts as a boundary, as does any
                // non-alphanumeric char (space, quote, paren, ...).
                .map_or(true, |c| !c.is_alphanumeric())
    })
}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline config: only `max_tokens` varies; everything else default
    /// (no overlap, `min_tokens = 1`, `cl100k_base`).
    fn cfg(max_tokens: usize) -> ChunkConfig {
        ChunkConfig {
            max_tokens,
            ..ChunkConfig::default()
        }
    }

    #[test]
    fn empty_input_yields_no_chunks() {
        let chunker = Chunker::new(cfg(100)).unwrap();
        assert!(chunker.split("").unwrap().is_empty());
    }

    #[test]
    fn short_text_one_chunk() {
        let chunker = Chunker::new(cfg(100)).unwrap();
        let chunks = chunker.split("hello world").unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].text, "hello world");
    }

    #[test]
    fn splits_at_sentence_boundary_under_budget() {
        let chunker = Chunker::new(cfg(8)).unwrap();
        let chunks = chunker
            .split("Alpha beta gamma. Delta epsilon zeta. Eta theta iota.")
            .unwrap();
        // Three sentences of ~5 tokens each cannot fit one 8-token chunk.
        assert!(
            chunks.len() >= 2,
            "expected >=2 chunks, got {}",
            chunks.len()
        );
        if let Some(bad) = chunks.iter().find(|ch| ch.token_count > 8) {
            panic!("chunk over budget: {} tokens", bad.token_count);
        }
    }

    #[test]
    fn long_sentence_falls_back_to_token_slicing() {
        // One sentence, many tokens: packing can't help, slicing must.
        let chunker = Chunker::new(cfg(5)).unwrap();
        let chunks = chunker
            .split("the quick brown fox jumps over the lazy dog and runs through fields")
            .unwrap();
        assert!(chunks.len() > 1);
        assert!(chunks.iter().all(|ch| ch.token_count <= 5));
    }

    #[test]
    fn overlap_re_prepends_tail_tokens() {
        let chunker = Chunker::new(ChunkConfig {
            max_tokens: 6,
            overlap_tokens: 2,
            ..ChunkConfig::default()
        })
        .unwrap();
        let chunks = chunker
            .split("Alpha beta gamma. Delta epsilon zeta. Eta theta iota.")
            .unwrap();
        // Every chunk after the first carries the prior chunk's last two
        // tokens, so token_count may exceed max_tokens by overlap_tokens.
        assert!(chunks.len() >= 2);
        assert!(chunks.iter().skip(1).all(|ch| ch.token_count <= 6 + 2));
    }

    #[test]
    fn min_tokens_drops_short_chunks() {
        // Anything under the 50-token floor is discarded after packing.
        let chunker = Chunker::new(ChunkConfig {
            max_tokens: 1000,
            min_tokens: 50,
            ..ChunkConfig::default()
        })
        .unwrap();
        assert!(chunker.split("tiny.").unwrap().is_empty());
    }

    #[test]
    fn invalid_config_overlap_ge_max() {
        let bad = ChunkConfig {
            overlap_tokens: 10,
            max_tokens: 10,
            ..Default::default()
        };
        assert!(Chunker::new(bad).is_err());
    }

    #[test]
    fn invalid_config_zero_max() {
        assert!(Chunker::new(ChunkConfig {
            max_tokens: 0,
            ..Default::default()
        })
        .is_err());
    }

    #[test]
    fn unknown_encoding_rejected() {
        let bad = ChunkConfig {
            encoding: "nope_base".into(),
            ..Default::default()
        };
        let result = Chunker::new(bad);
        assert!(matches!(result, Err(ChunkerError::UnknownEncoding(_))));
    }

    #[test]
    fn abbreviation_does_not_split_sentence() {
        let chunker = Chunker::new(cfg(1000)).unwrap();
        let sentences = chunker.split_sentences("Dr. Smith arrived. He said hello.");
        // "Dr." must not split; the period after "arrived" must.
        assert_eq!(sentences.len(), 2, "got: {:?}", sentences);
    }

    #[test]
    fn split_many_serial_and_parallel_match() {
        let chunker = Chunker::new(cfg(10)).unwrap();
        let docs = ["Alpha beta gamma.", "Delta. Epsilon. Zeta."];
        let serial = chunker.split_many(&docs, false).unwrap();
        let parallel = chunker.split_many(&docs, true).unwrap();
        assert_eq!(serial, parallel);
    }

    #[test]
    fn chunk_text_decodes_to_token_count() {
        let chunker = Chunker::new(cfg(10)).unwrap();
        let chunks = chunker
            .split("The quick brown fox jumps over the lazy dog.")
            .unwrap();
        // Re-encoding each chunk's text must reproduce its token_count.
        let bpe = tiktoken_rs::cl100k_base().unwrap();
        for ch in &chunks {
            assert_eq!(bpe.encode_ordinary(&ch.text).len(), ch.token_count);
        }
    }

    #[test]
    fn unicode_input_handled() {
        let chunker = Chunker::new(cfg(100)).unwrap();
        let chunks = chunker.split("你好世界. Hello world. 🌍 done.").unwrap();
        assert!(!chunks.is_empty());
        // Every chunk must round-trip decoding into non-empty text.
        assert!(chunks.iter().all(|ch| !ch.text.is_empty()));
    }

    #[test]
    fn min_tokens_filters_single_word_input() {
        // "hi" encodes to fewer than 5 tokens, so the lone emitted chunk
        // falls under the floor and is dropped.
        let chunker = Chunker::new(ChunkConfig {
            max_tokens: 100,
            min_tokens: 5,
            ..ChunkConfig::default()
        })
        .unwrap();
        assert!(chunker.split("hi").unwrap().is_empty());
    }
}