1use crate::error::{Result, VecStoreError};
30
31pub trait TextSplitter {
33 fn split_text(&self, text: &str) -> Result<Vec<String>>;
35
36 fn split_with_metadata(&self, text: &str) -> Result<Vec<TextChunk>> {
38 let chunks = self.split_text(text)?;
39 Ok(chunks
40 .into_iter()
41 .enumerate()
42 .map(|(i, content)| TextChunk {
43 index: i,
44 content,
45 char_start: 0, char_end: 0,
47 })
48 .collect())
49 }
50}
51
/// A piece of split text together with its position metadata.
#[derive(Clone, Debug, PartialEq)]
pub struct TextChunk {
    /// Position of this chunk in the splitter's output.
    pub index: usize,
    /// The text of the chunk.
    pub content: String,
    /// Start offset of the chunk; presumably measured in characters of the
    /// source text (confirm against the code that fills it in).
    pub char_start: usize,
    /// End offset of the chunk; presumably measured in characters of the
    /// source text (confirm against the code that fills it in).
    pub char_end: usize,
}
64
/// Splits text by trying a list of separators from coarsest to finest.
///
/// Pieces produced by a separator are packed greedily into chunks; a piece
/// that is still too large is split again with the next separator. When the
/// separators are exhausted (or an empty separator is reached) the text is cut
/// into fixed-width character windows.
pub struct RecursiveCharacterTextSplitter {
    /// Target chunk length. It is compared against byte lengths (`str::len`)
    /// while packing pieces and against character counts in the fixed-window
    /// fallback.
    chunk_size: usize,
    /// Number of trailing characters of a chunk repeated at the start of the
    /// following chunk.
    chunk_overlap: usize,
    /// Separators tried in order; an empty string selects the fixed-window
    /// fallback.
    separators: Vec<String>,
}

impl RecursiveCharacterTextSplitter {
    /// Creates a splitter with the default separators: blank line, newline,
    /// `". "`, `"! "`, `"? "`, single space and finally the empty string.
    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Self {
        Self {
            chunk_size,
            chunk_overlap,
            separators: ["\n\n", "\n", ". ", "! ", "? ", " ", ""]
                .iter()
                .copied()
                .map(String::from)
                .collect(),
        }
    }

    /// Replaces the separator list; separators are tried in the given order.
    pub fn with_separators(mut self, separators: Vec<String>) -> Self {
        self.separators = separators;
        self
    }

    /// Splits `text` using `separators` and returns the final chunks,
    /// including overlap.
    ///
    /// Text that already fits is returned as a single chunk. If no usable
    /// separator is available the text is cut into character windows whose
    /// stride provides the overlap, so those chunks never exceed `chunk_size`
    /// characters. Otherwise the text is split on separators and the overlap
    /// is prepended afterwards, so such chunks may exceed `chunk_size` by up
    /// to `chunk_overlap` characters.
    fn split_recursive(&self, text: &str, separators: &[String]) -> Vec<String> {
        if text.len() <= self.chunk_size {
            return vec![text.to_string()];
        }

        if separators.first().map_or(true, |sep| sep.is_empty()) {
            return self.split_by_chars(text);
        }

        // Overlap is added exactly once, after the whole recursion has produced
        // contiguous pieces. Adding it at every recursion level would prepend
        // the same tail repeatedly and emit text that is not in the input.
        let pieces = self.split_without_overlap(text, separators);
        self.add_overlap(pieces)
    }

    /// Recursively splits `text` into contiguous, non-overlapping pieces whose
    /// concatenation is exactly `text`.
    fn split_without_overlap(&self, text: &str, separators: &[String]) -> Vec<String> {
        if text.len() <= self.chunk_size {
            return vec![text.to_string()];
        }

        let (sep, finer) = match separators.split_first() {
            Some((sep, finer)) if !sep.is_empty() => (sep, finer),
            // No separator left: plain windows without overlap, because the
            // caller adds the overlap itself.
            _ => return self.char_windows(text, self.chunk_size),
        };

        let mut chunks = Vec::new();
        let mut current = String::new();
        let mut parts = text.split(sep.as_str()).peekable();

        while let Some(part) = parts.next() {
            // Every part except the last gets its separator back so that no
            // text is lost.
            let mut piece = String::with_capacity(part.len() + sep.len());
            piece.push_str(part);
            if parts.peek().is_some() {
                piece.push_str(sep);
            }

            if piece.len() > self.chunk_size {
                // Too large even on its own: flush what we have and split the
                // piece with the finer separators.
                if !current.is_empty() {
                    chunks.push(std::mem::take(&mut current));
                }
                chunks.extend(self.split_without_overlap(&piece, finer));
            } else if current.len() + piece.len() <= self.chunk_size {
                current.push_str(&piece);
            } else {
                let full = std::mem::replace(&mut current, piece);
                if !full.is_empty() {
                    chunks.push(full);
                }
            }
        }

        if !current.is_empty() {
            chunks.push(current);
        }

        chunks
    }

    /// Cuts `text` into windows of `chunk_size` characters, advancing by
    /// `chunk_size - chunk_overlap` characters between windows.
    fn split_by_chars(&self, text: &str) -> Vec<String> {
        self.char_windows(text, self.chunk_size.saturating_sub(self.chunk_overlap))
    }

    /// Cuts `text` into windows of `chunk_size` characters, moving forward by
    /// `stride` characters each time.
    ///
    /// Both the window width and the stride are forced to be at least one so
    /// the loop always makes progress.
    fn char_windows(&self, text: &str, stride: usize) -> Vec<String> {
        let chars: Vec<char> = text.chars().collect();
        let width = self.chunk_size.max(1);
        let stride = stride.max(1);

        let mut chunks = Vec::new();
        let mut start = 0;
        while start < chars.len() {
            let end = start.saturating_add(width).min(chars.len());
            chunks.push(chars[start..end].iter().collect());

            if end >= chars.len() {
                break;
            }
            start += stride;
        }

        chunks
    }

    /// Prepends the last `chunk_overlap` characters of each chunk to the chunk
    /// that follows it. The first chunk is returned unchanged.
    fn add_overlap(&self, chunks: Vec<String>) -> Vec<String> {
        if self.chunk_overlap == 0 || chunks.len() <= 1 {
            return chunks;
        }

        let mut result = Vec::with_capacity(chunks.len());
        result.push(chunks[0].clone());

        for pair in chunks.windows(2) {
            let (prev, next) = (&pair[0], &pair[1]);
            // Byte index of the `chunk_overlap`-th character from the end, or
            // the start of `prev` when it is shorter than the overlap.
            let tail_start = prev
                .char_indices()
                .rev()
                .nth(self.chunk_overlap - 1)
                .map_or(0, |(idx, _)| idx);
            result.push(format!("{}{}", &prev[tail_start..], next));
        }

        result
    }
}
239
240impl TextSplitter for RecursiveCharacterTextSplitter {
241 fn split_text(&self, text: &str) -> Result<Vec<String>> {
242 if text.is_empty() {
243 return Ok(vec![]);
244 }
245
246 if self.chunk_size == 0 {
247 return Err(VecStoreError::invalid_parameter(
248 "chunk_size",
249 "must be greater than 0",
250 ));
251 }
252
253 if self.chunk_overlap >= self.chunk_size {
254 return Err(VecStoreError::invalid_parameter(
255 "chunk_overlap",
256 "must be less than chunk_size",
257 ));
258 }
259
260 Ok(self.split_recursive(text, &self.separators))
261 }
262}
263
/// Splitter whose limits are expressed in approximate tokens instead of
/// characters.
pub struct TokenTextSplitter {
    /// Maximum number of tokens per chunk.
    max_tokens: usize,
    /// Number of tokens shared between consecutive chunks.
    token_overlap: usize,
    /// How many characters are counted as one token.
    chars_per_token: usize,
}

impl TokenTextSplitter {
    /// Creates a splitter that counts four characters per token.
    pub fn new(max_tokens: usize, token_overlap: usize) -> Self {
        let chars_per_token = 4;
        Self {
            max_tokens,
            token_overlap,
            chars_per_token,
        }
    }

    /// Overrides the number of characters counted as one token.
    pub fn with_chars_per_token(self, chars_per_token: usize) -> Self {
        Self {
            chars_per_token,
            ..self
        }
    }
}
319
320impl TextSplitter for TokenTextSplitter {
321 fn split_text(&self, text: &str) -> Result<Vec<String>> {
322 if text.is_empty() {
323 return Ok(vec![]);
324 }
325
326 let chunk_size = self.max_tokens * self.chars_per_token;
328 let chunk_overlap = self.token_overlap * self.chars_per_token;
329
330 let char_splitter = RecursiveCharacterTextSplitter::new(chunk_size, chunk_overlap);
332 char_splitter.split_text(text)
333 }
334}
335
/// Splitter that cuts Markdown documents along their `#` headings.
pub struct MarkdownTextSplitter {
    /// Maximum chunk length, compared against byte lengths.
    chunk_size: usize,
    /// Overlap handed to the character splitter used for oversized sections.
    chunk_overlap: usize,
    /// Whether chunks are prefixed with the chain of enclosing headings.
    preserve_headers: bool,
}

impl MarkdownTextSplitter {
    /// Creates a splitter that does not repeat parent headings in its chunks.
    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Self {
        Self {
            chunk_size,
            chunk_overlap,
            preserve_headers: false,
        }
    }

    /// Enables or disables prefixing chunks with their heading chain.
    pub fn with_preserve_headers(mut self, preserve: bool) -> Self {
        self.preserve_headers = preserve;
        self
    }

    /// Breaks `text` into sections, one per heading, plus an optional leading
    /// section for text that appears before the first heading.
    ///
    /// Each section records the headings that enclose it (`header_chain`),
    /// ending with its own heading.
    fn parse_sections(&self, text: &str) -> Vec<MarkdownSection> {
        let mut sections = Vec::new();
        let mut current = MarkdownSection {
            level: 0,
            header: String::new(),
            content: String::new(),
            header_chain: Vec::new(),
        };
        // Headings currently "open", from outermost to innermost.
        let mut header_stack: Vec<(usize, String)> = Vec::new();

        for line in text.lines() {
            let level = match self.parse_header_level(line) {
                Some(level) => level,
                None => {
                    if !current.content.is_empty() {
                        current.content.push('\n');
                    }
                    current.content.push_str(line);
                    continue;
                }
            };

            // Drop the indentation first; otherwise the `#` markers of an
            // indented heading would not be at the start and would end up in
            // the heading text.
            let header_text = line
                .trim_start()
                .trim_start_matches('#')
                .trim()
                .to_string();

            // A heading closes every open heading of the same or deeper level.
            header_stack.retain(|(open_level, _)| *open_level < level);
            header_stack.push((level, header_text.clone()));

            let finished = std::mem::replace(
                &mut current,
                MarkdownSection {
                    level,
                    header: header_text,
                    content: String::new(),
                    header_chain: header_stack.iter().map(|(_, h)| h.clone()).collect(),
                },
            );
            if !finished.content.is_empty() || !finished.header.is_empty() {
                sections.push(finished);
            }
        }

        if !current.content.is_empty() || !current.header.is_empty() {
            sections.push(current);
        }

        sections
    }

    /// Returns the heading level (1 to 6) if `line` is an ATX heading.
    ///
    /// Leading whitespace is ignored. The run of `#` characters must be
    /// followed by a space, a tab or the end of the line, so words such as
    /// `#hashtag` are not headings.
    fn parse_header_level(&self, line: &str) -> Option<usize> {
        let trimmed = line.trim_start();
        let level = trimmed.chars().take_while(|&c| c == '#').count();
        if !(1..=6).contains(&level) {
            return None;
        }

        // `#` is a single byte, so `level` is also the byte offset just past
        // the marker.
        match trimmed[level..].chars().next() {
            None => Some(level),
            Some(c) if c == ' ' || c == '\t' => Some(level),
            Some(_) => None,
        }
    }
}

/// One heading together with the text that follows it.
#[derive(Debug, Clone)]
struct MarkdownSection {
    /// Heading level, or `0` for text before the first heading.
    level: usize,
    /// Heading text without the `#` markers; empty for leading text.
    header: String,
    /// Lines of the section body joined with `\n`.
    content: String,
    /// Enclosing headings from outermost to this section's own heading.
    header_chain: Vec<String>,
}
463
464impl TextSplitter for MarkdownTextSplitter {
465 fn split_text(&self, text: &str) -> Result<Vec<String>> {
466 if text.is_empty() {
467 return Ok(vec![]);
468 }
469
470 if self.chunk_size == 0 {
471 return Err(VecStoreError::invalid_parameter(
472 "chunk_size",
473 "must be greater than 0",
474 ));
475 }
476
477 let sections = self.parse_sections(text);
479
480 let mut chunks = Vec::new();
481 let mut current_chunk = String::new();
482 let mut current_header_context = String::new();
483
484 for section in sections {
485 if self.preserve_headers && !section.header_chain.is_empty() {
487 current_header_context = section
488 .header_chain
489 .iter()
490 .enumerate()
491 .map(|(i, h)| format!("{} {}", "#".repeat(i + 1), h))
492 .collect::<Vec<_>>()
493 .join("\n");
494 current_header_context.push_str("\n\n");
495 }
496
497 let section_text = if section.header.is_empty() {
498 section.content.clone()
499 } else {
500 format!(
501 "{} {}\n\n{}",
502 "#".repeat(section.level),
503 section.header,
504 section.content
505 )
506 };
507
508 let chunk_with_section = if self.preserve_headers {
510 format!(
511 "{}{}{}",
512 current_chunk, current_header_context, section_text
513 )
514 } else {
515 format!("{}{}", current_chunk, section_text)
516 };
517
518 if chunk_with_section.len() <= self.chunk_size {
519 current_chunk = chunk_with_section;
520 } else {
521 if !current_chunk.is_empty() {
523 chunks.push(current_chunk.trim().to_string());
524 }
525
526 if section_text.len() > self.chunk_size {
528 let splitter = RecursiveCharacterTextSplitter::new(
529 self.chunk_size.saturating_sub(current_header_context.len()),
530 self.chunk_overlap,
531 );
532 let sub_chunks = splitter.split_text(§ion_text)?;
533
534 for sub_chunk in sub_chunks {
535 if self.preserve_headers && !current_header_context.is_empty() {
536 chunks.push(format!("{}{}", current_header_context, sub_chunk));
537 } else {
538 chunks.push(sub_chunk);
539 }
540 }
541 current_chunk = String::new();
542 } else {
543 current_chunk = if self.preserve_headers {
544 format!("{}{}", current_header_context, section_text)
545 } else {
546 section_text
547 };
548 }
549 }
550 }
551
552 if !current_chunk.is_empty() {
554 chunks.push(current_chunk.trim().to_string());
555 }
556
557 Ok(chunks)
558 }
559}
560
/// Splitter for source code that prefers to cut at definition boundaries.
pub struct CodeTextSplitter {
    /// Maximum chunk length, compared against byte lengths.
    chunk_size: usize,
    /// Overlap handed to the underlying character splitter.
    chunk_overlap: usize,
    /// Language used for definition detection; `None` disables it.
    language: Option<String>,
}

impl CodeTextSplitter {
    /// Creates a splitter without a language, i.e. purely separator based.
    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Self {
        Self {
            chunk_size,
            chunk_overlap,
            language: None,
        }
    }

    /// Selects the language whose definitions mark block boundaries.
    ///
    /// Recognized names are `rust`, `python`, `javascript`, `typescript`,
    /// `java`, `c`, `cpp` and `go`; any other name uses a generic heuristic.
    pub fn with_language(mut self, language: impl Into<String>) -> Self {
        self.language = Some(language.into());
        self
    }

    /// Heuristically decides whether `line` starts a new definition (function,
    /// type, class, ...) in the configured language. Leading whitespace is
    /// ignored.
    fn is_code_block_start(&self, line: &str) -> bool {
        let trimmed = line.trim_start();
        let starts_with_any =
            |prefixes: &[&str]| prefixes.iter().any(|prefix| trimmed.starts_with(*prefix));

        match self.language.as_deref() {
            Some("rust") => starts_with_any(&[
                "fn ",
                "pub fn ",
                "struct ",
                "pub struct ",
                "enum ",
                "pub enum ",
                "impl ",
                "trait ",
                "pub trait ",
            ]),
            Some("python") => starts_with_any(&["def ", "class ", "async def "]),
            Some("javascript") | Some("typescript") => starts_with_any(&[
                "function ",
                "class ",
                "const ",
                "let ",
                "async function ",
                "export ",
            ]),
            Some("java") | Some("c") | Some("cpp") => {
                // Keywords must match as whole words; a plain substring test
                // would find "int" inside calls such as `println(...)`.
                let looks_like_signature = trimmed.contains('(')
                    && trimmed.contains(')')
                    && ["public", "private", "void", "int"]
                        .iter()
                        .any(|keyword| Self::contains_word(trimmed, keyword));
                looks_like_signature || trimmed.starts_with("class ")
            }
            Some("go") => starts_with_any(&["func ", "type ", "struct "]),
            _ => starts_with_any(&["fn ", "function ", "def ", "class "]),
        }
    }

    /// Returns `true` when `word` occurs in `haystack` without an identifier
    /// character (letter, digit or `_`) directly before or after it.
    fn contains_word(haystack: &str, word: &str) -> bool {
        fn is_ident(c: char) -> bool {
            c.is_alphanumeric() || c == '_'
        }

        haystack.match_indices(word).any(|(start, _)| {
            let before = haystack[..start].chars().next_back();
            let after = haystack[start + word.len()..].chars().next();
            !before.map_or(false, is_ident) && !after.map_or(false, is_ident)
        })
    }

    /// Separators for code, from coarsest to finest: blank line, closing
    /// brace on its own line, newline, `"; "`, space and finally the empty
    /// string (fixed-width fallback).
    fn get_separators(&self) -> Vec<String> {
        ["\n\n", "\n}\n", "\n", "; ", " ", ""]
            .iter()
            .copied()
            .map(String::from)
            .collect()
    }
}
679
680impl TextSplitter for CodeTextSplitter {
681 fn split_text(&self, text: &str) -> Result<Vec<String>> {
682 if text.is_empty() {
683 return Ok(vec![]);
684 }
685
686 if self.chunk_size == 0 {
687 return Err(VecStoreError::invalid_parameter(
688 "chunk_size",
689 "must be greater than 0",
690 ));
691 }
692
693 let separators = self.get_separators();
695 let splitter = RecursiveCharacterTextSplitter::new(self.chunk_size, self.chunk_overlap)
696 .with_separators(separators);
697
698 if self.language.is_some() {
700 let mut chunks = Vec::new();
701 let mut current_chunk = String::new();
702 let mut current_block = String::new();
703
704 for line in text.lines() {
705 let line_with_newline = format!("{}\n", line);
706
707 if self.is_code_block_start(line) && !current_block.is_empty() {
709 if current_chunk.len() + current_block.len() <= self.chunk_size {
711 current_chunk.push_str(¤t_block);
712 current_block.clear();
713 } else {
714 if !current_chunk.is_empty() {
715 chunks.push(current_chunk.clone());
716 }
717 current_chunk = current_block.clone();
718 current_block.clear();
719 }
720 }
721
722 current_block.push_str(&line_with_newline);
723
724 if current_block.len() > self.chunk_size {
726 if !current_chunk.is_empty() {
727 chunks.push(current_chunk.clone());
728 current_chunk.clear();
729 }
730
731 let sub_chunks = splitter.split_text(¤t_block)?;
733 chunks.extend(sub_chunks);
734 current_block.clear();
735 }
736 }
737
738 if !current_block.is_empty() {
740 current_chunk.push_str(¤t_block);
741 }
742 if !current_chunk.is_empty() {
743 chunks.push(current_chunk);
744 }
745
746 return Ok(chunks);
747 }
748
749 splitter.split_text(text)
751 }
752}
753
754pub trait Embedder {
775 fn embed(&self, text: &str) -> Result<Vec<f32>>;
777}
778
779pub struct SemanticTextSplitter {
805 embedder: Box<dyn Embedder>,
807 max_chunk_size: usize,
809 min_chunk_size: usize,
811 similarity_threshold: f32,
813}
814
815impl SemanticTextSplitter {
816 pub fn new(embedder: Box<dyn Embedder>, max_chunk_size: usize, min_chunk_size: usize) -> Self {
823 Self {
824 embedder,
825 max_chunk_size,
826 min_chunk_size,
827 similarity_threshold: 0.7, }
829 }
830
831 pub fn with_similarity_threshold(mut self, threshold: f32) -> Self {
836 self.similarity_threshold = threshold.clamp(0.0, 1.0);
837 self
838 }
839
840 fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
842 if a.len() != b.len() {
843 return 0.0;
844 }
845
846 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
847 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
848 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
849
850 if norm_a == 0.0 || norm_b == 0.0 {
851 return 0.0;
852 }
853
854 dot_product / (norm_a * norm_b)
855 }
856
857 fn split_sentences(&self, text: &str) -> Vec<String> {
859 text.split(&['.', '!', '?'][..])
861 .filter(|s| !s.trim().is_empty())
862 .map(|s| s.trim().to_string())
863 .collect()
864 }
865}
866
867impl TextSplitter for SemanticTextSplitter {
868 fn split_text(&self, text: &str) -> Result<Vec<String>> {
869 if text.is_empty() {
870 return Ok(vec![]);
871 }
872
873 if self.max_chunk_size == 0 {
874 return Err(VecStoreError::invalid_parameter(
875 "max_chunk_size",
876 "must be greater than 0",
877 ));
878 }
879
880 let sentences = self.split_sentences(text);
882
883 if sentences.is_empty() {
884 return Ok(vec![]);
885 }
886
887 let mut sentence_embeddings = Vec::new();
889 for sentence in &sentences {
890 let embedding = self.embedder.embed(sentence)?;
891 sentence_embeddings.push(embedding);
892 }
893
894 let mut chunks = Vec::new();
896 let mut current_chunk = String::new();
897 let mut current_embedding: Option<Vec<f32>> = None;
898
899 for (i, sentence) in sentences.iter().enumerate() {
900 let sentence_with_space = if current_chunk.is_empty() {
901 sentence.clone()
902 } else {
903 format!(" {}", sentence)
904 };
905
906 if current_chunk.len() + sentence_with_space.len() > self.max_chunk_size {
908 if current_chunk.len() >= self.min_chunk_size {
910 chunks.push(current_chunk.clone());
911 current_chunk.clear();
912 current_embedding = None;
913 }
914 }
915
916 let should_add = if let Some(ref chunk_emb) = current_embedding {
918 let similarity = self.cosine_similarity(chunk_emb, &sentence_embeddings[i]);
919 similarity >= self.similarity_threshold
920 } else {
921 true };
923
924 if should_add || current_chunk.is_empty() {
925 current_chunk.push_str(&sentence_with_space);
927
928 if let Some(ref mut chunk_emb) = current_embedding {
930 for (j, val) in sentence_embeddings[i].iter().enumerate() {
932 chunk_emb[j] = (chunk_emb[j] + val) / 2.0;
933 }
934 } else {
935 current_embedding = Some(sentence_embeddings[i].clone());
936 }
937 } else {
938 if current_chunk.len() >= self.min_chunk_size {
940 chunks.push(current_chunk.clone());
941 }
942 current_chunk = sentence.clone();
943 current_embedding = Some(sentence_embeddings[i].clone());
944 }
945 }
946
947 if !current_chunk.is_empty() && current_chunk.len() >= self.min_chunk_size {
949 chunks.push(current_chunk);
950 }
951
952 if chunks.is_empty() {
954 let fallback =
955 RecursiveCharacterTextSplitter::new(self.max_chunk_size, self.min_chunk_size / 2);
956 return fallback.split_text(text);
957 }
958
959 Ok(chunks)
960 }
961}
962
#[cfg(test)]
mod tests {
    use super::*;

    // ---- RecursiveCharacterTextSplitter / TokenTextSplitter ----

    #[test]
    fn test_recursive_splitter_basic() {
        let input = "Short text.";
        let chunks = RecursiveCharacterTextSplitter::new(20, 0)
            .split_text(input)
            .unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], input);
    }

    #[test]
    fn test_recursive_splitter_paragraphs() {
        let chunks = RecursiveCharacterTextSplitter::new(50, 0)
            .split_text("First paragraph.\n\nSecond paragraph.")
            .unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_recursive_splitter_overlap() {
        let chunks = RecursiveCharacterTextSplitter::new(20, 5)
            .split_text("This is a longer text that should be split into multiple chunks.")
            .unwrap();
        assert!(chunks.len() > 1);
    }

    #[test]
    fn test_token_splitter() {
        let splitter = TokenTextSplitter::new(10, 2);
        let input = "This is a test. This text should be split based on token count.";
        let chunks = splitter.split_text(input).unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_empty_text() {
        let chunks = RecursiveCharacterTextSplitter::new(100, 10)
            .split_text("")
            .unwrap();
        assert_eq!(chunks.len(), 0);
    }

    #[test]
    fn test_invalid_chunk_size() {
        let outcome = RecursiveCharacterTextSplitter::new(0, 0).split_text("test");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_invalid_overlap() {
        let outcome = RecursiveCharacterTextSplitter::new(100, 100).split_text("test");
        assert!(outcome.is_err());
    }

    // ---- MarkdownTextSplitter ----

    #[test]
    fn test_markdown_splitter_basic() {
        let document = "# Header 1\n\nSome content here.\n\n## Header 2\n\nMore content.";
        let chunks = MarkdownTextSplitter::new(200, 20)
            .split_text(document)
            .unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_markdown_splitter_preserve_headers() {
        let document = "# Main\n\nContent 1\n\n## Section\n\nContent 2";
        let chunks = MarkdownTextSplitter::new(200, 20)
            .with_preserve_headers(true)
            .split_text(document)
            .unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_markdown_header_parsing() {
        let splitter = MarkdownTextSplitter::new(100, 10);
        let cases = [
            ("# H1", Some(1)),
            ("## H2", Some(2)),
            ("### H3", Some(3)),
            ("Not a header", None),
            ("####### Too many", None),
        ];
        for (line, expected) in cases.iter() {
            assert_eq!(splitter.parse_header_level(line), *expected);
        }
    }

    #[test]
    fn test_markdown_simple_by_default() {
        let splitter = MarkdownTextSplitter::new(500, 50);
        assert!(!splitter.preserve_headers);
    }

    // ---- CodeTextSplitter ----

    #[test]
    fn test_code_splitter_basic() {
        let code = "fn main() {\n println!(\"Hello\");\n}\n\nfn test() {\n // test\n}";
        let chunks = CodeTextSplitter::new(200, 20).split_text(code).unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_code_splitter_with_language() {
        let code =
            "fn main() {\n println!(\"Hello\");\n}\n\nfn test() {\n println!(\"Test\");\n}";
        let chunks = CodeTextSplitter::new(300, 30)
            .with_language("rust")
            .split_text(code)
            .unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_code_block_detection() {
        let splitter = CodeTextSplitter::new(100, 10).with_language("rust");
        for line in ["fn main() {", "pub fn test() {", "struct Foo {"].iter() {
            assert!(splitter.is_code_block_start(line));
        }
        assert!(!splitter.is_code_block_start(" let x = 5;"));
    }

    #[test]
    fn test_code_splitter_simple_by_default() {
        let splitter = CodeTextSplitter::new(500, 50);
        assert!(splitter.language.is_none());
    }

    // ---- SemanticTextSplitter ----

    /// Embedder whose vectors depend only on the length of the text.
    struct MockEmbedder;

    impl Embedder for MockEmbedder {
        fn embed(&self, text: &str) -> Result<Vec<f32>> {
            let len = text.len() as f32;
            Ok(vec![len / 100.0, len / 50.0, len / 25.0])
        }
    }

    #[test]
    fn test_semantic_splitter_basic() {
        let splitter = SemanticTextSplitter::new(Box::new(MockEmbedder), 200, 20);
        let input =
            "First sentence. Second sentence here. Third one is different. Fourth continues.";
        let chunks = splitter.split_text(input).unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_semantic_splitter_with_threshold() {
        let splitter = SemanticTextSplitter::new(Box::new(MockEmbedder), 300, 30)
            .with_similarity_threshold(0.8);
        let chunks = splitter
            .split_text("Sentence one. Sentence two. Sentence three.")
            .unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_semantic_splitter_cosine_similarity() {
        let splitter = SemanticTextSplitter::new(Box::new(MockEmbedder), 100, 10);

        let x_axis = vec![1.0, 0.0, 0.0];
        let x_axis_again = vec![1.0, 0.0, 0.0];
        let y_axis = vec![0.0, 1.0, 0.0];

        let same_direction = splitter.cosine_similarity(&x_axis, &x_axis_again);
        assert!((same_direction - 1.0).abs() < 0.01);

        let orthogonal = splitter.cosine_similarity(&x_axis, &y_axis);
        assert!(orthogonal.abs() < 0.01);
    }

    #[test]
    fn test_embedder_trait_composable() {
        struct CustomEmbedder;

        impl Embedder for CustomEmbedder {
            fn embed(&self, _text: &str) -> Result<Vec<f32>> {
                Ok(vec![1.0, 2.0, 3.0])
            }
        }

        let splitter = SemanticTextSplitter::new(Box::new(CustomEmbedder), 500, 50);
        let outcome = splitter.split_text("Test text.");
        assert!(outcome.is_ok());
    }
}