1use std::collections::HashSet;
12
/// Returns the built-in set of words that do not end a sentence when they are
/// immediately followed by a period.
///
/// Entries are stored without their final period; multi-part forms such as
/// `"e.g"` keep their internal periods. Matching is case-sensitive, which is
/// why a few entries (`"Fig"`/`"fig"`, `"Vol"`/`"vol"`) appear in both cases.
fn builtin_abbreviations() -> HashSet<String> {
    const ABBREVIATIONS: &[&str] = &[
        // Titles, honorifics and ranks.
        "Mr", "Mrs", "Ms", "Miss", "Dr", "Prof", "Rev", "Gen", "Col", "Capt", "Lt", "Sgt", "Cpl",
        "Pte", "Sr", "Jr", "Esq",
        // Street / place designators.
        "St", "Ave", "Blvd", "Rd", "Ln", "Ct", "Pl", "Mt", "Ft",
        // Months (May is never abbreviated).
        "Jan", "Feb", "Mar", "Apr", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
        // Weekdays.
        "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun",
        // General-purpose abbreviations.
        "etc", "vs", "approx", "est", "dept", "corp", "co", "inc",
        // Document references.
        "Fig", "fig", "Vol", "vol", "No", "Nos", "pp", "Ch", "Sec",
        // Latin forms.
        "e.g", "i.e", "et", "al", "n.b", "N.B",
    ];
    ABBREVIATIONS.iter().map(|&abbr| abbr.to_owned()).collect()
}
32
/// Rule-based splitter that breaks text into sentences on `.`, `!` and `?`,
/// using a set of known abbreviations to avoid false splits.
#[derive(Debug, Clone)]
pub struct SentenceSegmenter {
    /// Words that, when immediately followed by a period, do not end a sentence.
    /// Stored without the trailing period (e.g. `"Dr"`, `"e.g"`).
    abbreviations: HashSet<String>,
    /// Byte length below which `segment_owned` appends a sentence to the
    /// preceding one instead of emitting it on its own.
    pub min_sentence_len: usize,
}
61
62impl Default for SentenceSegmenter {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68impl SentenceSegmenter {
69 pub fn new() -> Self {
71 Self {
72 abbreviations: builtin_abbreviations(),
73 min_sentence_len: 10,
74 }
75 }
76
77 pub fn with_abbreviations(abbrevs: Vec<String>) -> Self {
79 let mut set = builtin_abbreviations();
80 for a in abbrevs {
81 set.insert(a);
82 }
83 Self {
84 abbreviations: set,
85 min_sentence_len: 10,
86 }
87 }
88
89 pub fn segment<'a>(&self, text: &'a str) -> Vec<&'a str> {
91 if text.trim().is_empty() {
92 return Vec::new();
93 }
94 let boundaries = self.find_boundaries(text);
95 let mut result: Vec<&'a str> = Vec::new();
96 let mut start = 0;
97
98 for end in boundaries {
99 let slice = text[start..end].trim();
100 if !slice.is_empty() {
101 result.push(slice);
102 }
103 start = end;
104 }
105
106 let tail = text[start..].trim();
107 if !tail.is_empty() {
108 result.push(tail);
109 }
110
111 result
112 }
113
114 pub fn segment_owned(&self, text: &str) -> Vec<String> {
116 if text.trim().is_empty() {
117 return Vec::new();
118 }
119 let raw: Vec<String> = self.segment(text).iter().map(|s| s.to_string()).collect();
120
121 let mut result: Vec<String> = Vec::new();
123 for s in raw {
124 if s.len() < self.min_sentence_len && !result.is_empty() {
125 if let Some(last) = result.last_mut() {
126 last.push(' ');
127 last.push_str(&s);
128 }
129 } else {
130 result.push(s);
131 }
132 }
133 result
134 }
135
136 fn find_boundaries(&self, text: &str) -> Vec<usize> {
141 let chars: Vec<(usize, char)> = text.char_indices().collect();
142 let n = chars.len();
143 let mut boundaries: Vec<usize> = Vec::new();
144
145 let mut i = 0usize;
146 while i < n {
147 let (byte_pos, ch) = chars[i];
148
149 if ch == '.' || ch == '!' || ch == '?' {
150 if ch == '.' && i + 2 < n && chars[i + 1].1 == '.' && chars[i + 2].1 == '.' {
152 i += 3;
153 continue;
154 }
155
156 if ch == '.' && self.is_abbreviation_period(text, byte_pos) {
158 i += 1;
159 continue;
160 }
161
162 if ch == '.' && self.is_decimal_period(text, byte_pos) {
164 i += 1;
165 continue;
166 }
167
168 let mut end_i = i + 1;
170 while end_i < n
171 && (chars[end_i].1 == '!' || chars[end_i].1 == '?' || chars[end_i].1 == '.')
172 {
173 end_i += 1;
174 }
175
176 let boundary_byte = if end_i < n {
177 chars[end_i].0
178 } else {
179 text.len()
180 };
181
182 if self.is_sentence_boundary(text, boundary_byte) {
183 boundaries.push(boundary_byte);
184 }
185
186 i = end_i;
187 continue;
188 }
189
190 i += 1;
191 }
192
193 boundaries
194 }
195
196 fn is_abbreviation_period(&self, text: &str, period_byte: usize) -> bool {
197 let prefix = &text[..period_byte];
198 let word = prefix
199 .rsplit(|c: char| !c.is_alphabetic() && c != '.')
200 .next()
201 .unwrap_or("");
202 self.abbreviations.contains(word)
203 || self.abbreviations.contains(&word.to_lowercase())
204 || (word.len() == 1 && word.chars().next().is_some_and(|c| c.is_uppercase()))
205 }
206
207 fn is_decimal_period(&self, text: &str, period_byte: usize) -> bool {
208 let before = text[..period_byte]
210 .chars()
211 .next_back()
212 .is_some_and(|c| c.is_ascii_digit());
213 let after = text[period_byte + 1..]
214 .chars()
215 .next()
216 .is_some_and(|c| c.is_ascii_digit());
217 before && after
218 }
219
220 fn is_sentence_boundary(&self, text: &str, pos: usize) -> bool {
221 if pos >= text.len() {
222 return true;
223 }
224 let after = &text[pos..];
225 let trimmed = after.trim_start();
226 if trimmed.is_empty() {
227 return true;
228 }
229 trimmed.chars().next().is_some_and(|c| {
230 c.is_uppercase()
231 || c.is_ascii_digit()
232 || matches!(c, '"' | '\'' | '(' | '[' | '\u{201C}' | '\u{2018}')
233 })
234 }
235}
236
/// One chunk produced by `TextChunker`, with its position in the source text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TextChunk {
    /// The chunk's text.
    pub text: String,
    /// Byte offset in the source text where the chunk begins.
    pub start: usize,
    /// Exclusive byte offset in the source text where the chunk ends.
    pub end: usize,
    /// Zero-based position of this chunk in the output sequence.
    pub chunk_index: usize,
    /// Number of chunks produced for the source text.
    pub total_chunks: usize,
}
255
/// Splits text into (optionally overlapping) chunks measured in
/// whitespace-separated tokens.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TextChunker {
    /// Maximum number of whitespace-separated tokens per chunk. In sentence mode
    /// a single sentence longer than this still forms a chunk of its own.
    pub chunk_size: usize,
    /// Number of tokens repeated between consecutive chunks; `TextChunker::new`
    /// clamps it below `chunk_size`. In sentence mode it is converted to a whole
    /// number of sentences.
    pub overlap: usize,
    /// When `true`, chunks are built from whole sentences instead of raw tokens.
    pub by_sentence: bool,
}
275
276impl Default for TextChunker {
277 fn default() -> Self {
278 Self::new(512, 50)
279 }
280}
281
282impl TextChunker {
283 pub fn new(chunk_size: usize, overlap: usize) -> Self {
285 let safe_overlap = if overlap >= chunk_size {
286 chunk_size.saturating_sub(1)
287 } else {
288 overlap
289 };
290 Self {
291 chunk_size,
292 overlap: safe_overlap,
293 by_sentence: false,
294 }
295 }
296
297 pub fn with_sentence_boundaries(mut self) -> Self {
299 self.by_sentence = true;
300 self
301 }
302
303 pub fn chunk(&self, text: &str) -> Vec<String> {
305 self.chunk_with_metadata(text)
306 .into_iter()
307 .map(|c| c.text)
308 .collect()
309 }
310
311 pub fn chunk_with_metadata(&self, text: &str) -> Vec<TextChunk> {
313 if text.is_empty() {
314 return Vec::new();
315 }
316
317 if self.by_sentence {
318 self.chunk_by_sentence(text)
319 } else {
320 self.chunk_by_tokens(text)
321 }
322 }
323
324 fn chunk_by_tokens(&self, text: &str) -> Vec<TextChunk> {
329 let tokens: Vec<(usize, usize)> = token_byte_ranges(text);
330
331 if tokens.is_empty() {
332 return Vec::new();
333 }
334
335 let step = self.chunk_size.saturating_sub(self.overlap).max(1);
336 let n = tokens.len();
337
338 let mut raw: Vec<(usize, usize)> = Vec::new();
339 let mut start_idx = 0usize;
340 while start_idx < n {
341 let end_idx = (start_idx + self.chunk_size).min(n);
342 let chunk_start_byte = tokens[start_idx].0;
343 let chunk_end_byte = tokens[end_idx - 1].1;
344 raw.push((chunk_start_byte, chunk_end_byte));
345 if end_idx >= n {
346 break;
347 }
348 start_idx += step;
349 }
350
351 let total = raw.len();
352 raw.into_iter()
353 .enumerate()
354 .map(|(idx, (start, end))| TextChunk {
355 text: text[start..end].to_string(),
356 start,
357 end,
358 chunk_index: idx,
359 total_chunks: total,
360 })
361 .collect()
362 }
363
364 fn chunk_by_sentence(&self, text: &str) -> Vec<TextChunk> {
369 let segmenter = SentenceSegmenter::new();
370 let sentences = segmenter.segment(text);
371
372 if sentences.is_empty() {
373 return Vec::new();
374 }
375
376 let mut chunks_data: Vec<(String, usize, usize)> = Vec::new();
377 let overlap_sentences = (self.overlap / 10).max(1);
378 let mut i = 0;
379
380 while i < sentences.len() {
381 let mut word_count = 0;
382 let mut j = i;
383 let mut chunk_parts: Vec<&str> = Vec::new();
384
385 while j < sentences.len() {
386 let sentence = sentences[j];
387 let wc = sentence.split_whitespace().count();
388 if word_count + wc > self.chunk_size && !chunk_parts.is_empty() {
389 break;
390 }
391 chunk_parts.push(sentence);
392 word_count += wc;
393 j += 1;
394 }
395
396 if !chunk_parts.is_empty() {
397 let combined = chunk_parts.join(" ");
398 let start_byte = text.find(chunk_parts[0]).unwrap_or(0);
399 let last = chunk_parts[chunk_parts.len() - 1];
400 let last_start = text.rfind(last).unwrap_or(start_byte);
401 let end_byte = (last_start + last.len()).min(text.len());
402 chunks_data.push((combined, start_byte, end_byte));
403 }
404
405 let advance = (j - i).saturating_sub(overlap_sentences).max(1);
406 i += advance;
407 }
408
409 let total = chunks_data.len();
410 chunks_data
411 .into_iter()
412 .enumerate()
413 .map(|(idx, (text_s, start, end))| TextChunk {
414 text: text_s,
415 start,
416 end,
417 chunk_index: idx,
418 total_chunks: total,
419 })
420 .collect()
421 }
422}
423
/// Returns the `[start, end)` byte range of every maximal run of non-whitespace
/// characters in `text`, in order of appearance.
///
/// Whitespace is decided by [`char::is_whitespace`], so Unicode spaces (e.g.
/// U+3000) separate tokens too. Offsets are always on `char` boundaries.
pub fn token_byte_ranges(text: &str) -> Vec<(usize, usize)> {
    let mut ranges = Vec::new();
    // Byte offset where the current token began, if we are inside one.
    let mut open: Option<usize> = None;

    for (pos, c) in text.char_indices() {
        match (c.is_whitespace(), open) {
            (true, Some(begin)) => {
                ranges.push((begin, pos));
                open = None;
            }
            (false, None) => open = Some(pos),
            _ => {}
        }
    }

    // A token running to the end of the text is closed at `text.len()`.
    if let Some(begin) = open {
        ranges.push((begin, text.len()));
    }
    ranges
}
450
#[cfg(test)]
mod tests {
    //! Unit tests for `SentenceSegmenter`, `TextChunker` and `token_byte_ranges`.
    //! Segmentation tests pin the exact sentences produced, not just counts.
    use super::*;

    #[test]
    fn test_basic_segmentation() {
        let seg = SentenceSegmenter::new();
        let sentences = seg.segment("Hello world. How are you? I am fine.");
        assert_eq!(
            sentences,
            vec!["Hello world.", "How are you?", "I am fine."],
            "Expected 3 sentences, got {:?}",
            sentences
        );
    }

    #[test]
    fn test_abbreviation_not_split() {
        let seg = SentenceSegmenter::new();
        let sentences = seg.segment("We met Dr. Smith today. He is well.");
        assert_eq!(
            sentences,
            vec!["We met Dr. Smith today.", "He is well."],
            "Abbreviation should not create extra splits: {:?}",
            sentences
        );
    }

    #[test]
    fn test_exclamation_and_question() {
        let seg = SentenceSegmenter::new();
        let sentences = seg.segment("Amazing! Really? Yes absolutely.");
        assert_eq!(sentences, vec!["Amazing!", "Really?", "Yes absolutely."]);
    }

    #[test]
    fn test_segment_owned() {
        let seg = SentenceSegmenter::new();
        let sentences = seg.segment_owned("First sentence. Second sentence. Third sentence.");
        // Every sentence is at least `min_sentence_len` bytes, so none are merged.
        assert_eq!(
            sentences,
            vec!["First sentence.", "Second sentence.", "Third sentence."]
        );
    }

    #[test]
    fn test_empty_text_returns_empty() {
        let seg = SentenceSegmenter::new();
        assert!(seg.segment("").is_empty());
        assert!(seg.segment_owned("").is_empty());
    }

    #[test]
    fn test_single_sentence() {
        let seg = SentenceSegmenter::new();
        let result = seg.segment("This is just one sentence");
        assert_eq!(result, vec!["This is just one sentence"]);
    }

    #[test]
    fn test_with_abbreviations() {
        let seg = SentenceSegmenter::with_abbreviations(vec!["Esq".to_string()]);
        let result = seg.segment("John Smith, Esq. is present. He said hello.");
        assert_eq!(
            result,
            vec!["John Smith, Esq. is present.", "He said hello."],
            "Got {:?}",
            result
        );
    }

    #[test]
    fn test_no_false_split_on_decimal() {
        let seg = SentenceSegmenter::new();
        let result = seg.segment("Pi is about 3.14159 in value. That is a fact.");
        assert_eq!(
            result,
            vec!["Pi is about 3.14159 in value.", "That is a fact."],
            "Got {:?}",
            result
        );
    }

    #[test]
    fn test_chunker_basic() {
        let chunker = TextChunker::new(5, 1);
        let text = "one two three four five six seven eight nine ten";
        let chunks = chunker.chunk(text);
        assert!(!chunks.is_empty());
        for chunk in &chunks {
            let wc = chunk.split_whitespace().count();
            assert!(wc <= 5, "Chunk '{}' has {} words", chunk, wc);
        }
    }

    #[test]
    fn test_chunker_overlap() {
        let chunker = TextChunker::new(4, 2);
        let text = "a b c d e f g h";
        let chunks = chunker.chunk(text);
        assert!(chunks.len() >= 2);
        // The last two tokens of a chunk must open the next chunk.
        let words_0: Vec<&str> = chunks[0].split_whitespace().collect();
        let words_1: Vec<&str> = chunks[1].split_whitespace().collect();
        let last_two: Vec<&str> = words_0.iter().rev().take(2).rev().copied().collect();
        let first_two: Vec<&str> = words_1.iter().take(2).copied().collect();
        assert_eq!(last_two, first_two, "Overlap should share tokens");
    }

    #[test]
    fn test_chunker_with_metadata() {
        let chunker = TextChunker::new(3, 0);
        let text = "alpha beta gamma delta epsilon";
        let chunks = chunker.chunk_with_metadata(text);
        for (i, chunk) in chunks.iter().enumerate() {
            assert_eq!(chunk.chunk_index, i);
            assert_eq!(chunk.total_chunks, chunks.len());
            assert_eq!(&text[chunk.start..chunk.end], chunk.text.as_str());
        }
    }

    #[test]
    fn test_chunker_empty_text() {
        let chunker = TextChunker::new(10, 2);
        assert!(chunker.chunk("").is_empty());
        assert!(chunker.chunk_with_metadata("").is_empty());
    }

    #[test]
    fn test_chunker_short_text() {
        let chunker = TextChunker::new(100, 10);
        let text = "just three words";
        let chunks = chunker.chunk(text);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], text);
    }

    #[test]
    fn test_chunker_by_sentence() {
        let chunker = TextChunker::new(20, 5).with_sentence_boundaries();
        let text = "The quick brown fox jumps. A lazy dog sleeps. The sun is shining.";
        let chunks = chunker.chunk(text);
        assert!(!chunks.is_empty());
        // All 13 words fit in one 20-word chunk, so the first chunk is the whole text.
        assert_eq!(chunks[0], text);
    }

    #[test]
    fn test_chunker_overlap_clamped() {
        let chunker = TextChunker::new(3, 10);
        assert!(chunker.overlap < chunker.chunk_size);
    }

    #[test]
    fn test_token_byte_ranges() {
        let text = "hello world foo";
        let ranges = token_byte_ranges(text);
        assert_eq!(ranges.len(), 3);
        assert_eq!(&text[ranges[0].0..ranges[0].1], "hello");
        assert_eq!(&text[ranges[1].0..ranges[1].1], "world");
        assert_eq!(&text[ranges[2].0..ranges[2].1], "foo");
    }
}