// vecstore/text_splitter.rs

//! Text Splitting for RAG Applications
//!
//! Provides text chunking strategies for breaking down large documents into
//! embedding-sized chunks. Essential for RAG (Retrieval-Augmented Generation)
//! systems that need to store and search long documents.
//!
//! # Strategies
//!
//! - **RecursiveCharacterTextSplitter**: Splits on paragraph/sentence/word boundaries
//! - **TokenTextSplitter**: Splits based on an approximate token count (for LLMs with token limits)
//! - **MarkdownTextSplitter**: Markdown-aware splitting that respects header hierarchy
//! - **CodeTextSplitter**: Code-aware splitting that respects function/class boundaries
//! - **SemanticTextSplitter**: Embedding-based splitting that groups semantically similar content
//!
//! # Example
//!
//! ```no_run
//! use vecstore::text_splitter::{RecursiveCharacterTextSplitter, TextSplitter};
//!
//! let splitter = RecursiveCharacterTextSplitter::new(500, 50);
//! let chunks = splitter.split_text("Long document text...")?;
//!
//! for (i, chunk) in chunks.iter().enumerate() {
//!     println!("Chunk {}: {} chars", i, chunk.chars().count());
//! }
//! # Ok::<(), anyhow::Error>(())
//! ```
28
29use crate::error::{Result, VecStoreError};
30
31/// Trait for text splitting strategies
32pub trait TextSplitter {
33    /// Split text into chunks
34    fn split_text(&self, text: &str) -> Result<Vec<String>>;
35
36    /// Split text into chunks with metadata (position, length, etc.)
37    fn split_with_metadata(&self, text: &str) -> Result<Vec<TextChunk>> {
38        let chunks = self.split_text(text)?;
39        Ok(chunks
40            .into_iter()
41            .enumerate()
42            .map(|(i, content)| TextChunk {
43                index: i,
44                content,
45                char_start: 0, // Simplified - could track actual positions
46                char_end: 0,
47            })
48            .collect())
49    }
50}
51
/// A text chunk with metadata
///
/// All fields are plain `usize`/`String`, so the type also derives `Eq` and
/// `Hash`: chunks can be deduplicated or used as `HashSet`/`HashMap` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TextChunk {
    /// Zero-based position of this chunk in the splitter's output
    pub index: usize,
    /// Chunk content
    pub content: String,
    /// Character start position in the original text
    /// (0 when the producing splitter does not track positions)
    pub char_start: usize,
    /// Character end position in the original text
    /// (0 when the producing splitter does not track positions)
    pub char_end: usize,
}
64
/// Recursive character-based text splitter
///
/// Tries to split on natural boundaries in this order:
/// 1. Double newlines (paragraphs)
/// 2. Single newlines (lines)
/// 3. Sentences (periods, question marks, exclamation points)
/// 4. Words (spaces)
/// 5. Characters (last resort)
///
/// Sizes are measured in characters (`char`s), not bytes. The text is first cut
/// into contiguous pieces of at most `chunk_size` characters. When
/// `chunk_overlap > 0`, the last `chunk_overlap` characters of the preceding
/// piece are then prepended to every piece after the first. Chunks after the
/// first can therefore be up to `chunk_size + chunk_overlap` characters long.
///
/// # Example
///
/// ```no_run
/// use vecstore::text_splitter::{RecursiveCharacterTextSplitter, TextSplitter};
///
/// let splitter = RecursiveCharacterTextSplitter::new(1000, 100);
/// let text = "First paragraph.\n\nSecond paragraph with more content...";
/// let chunks = splitter.split_text(text)?;
/// # Ok::<(), anyhow::Error>(())
/// ```
pub struct RecursiveCharacterTextSplitter {
    /// Maximum chunk size in characters (before overlap is prepended)
    chunk_size: usize,
    /// Overlap between chunks in characters
    chunk_overlap: usize,
    /// Separators to try, in order of preference
    separators: Vec<String>,
}

impl RecursiveCharacterTextSplitter {
    /// Create a new recursive splitter
    ///
    /// # Arguments
    /// * `chunk_size` - Maximum characters per chunk (before overlap)
    /// * `chunk_overlap` - Characters to overlap between chunks (for context continuity)
    ///
    /// # Example
    ///
    /// ```no_run
    /// use vecstore::text_splitter::RecursiveCharacterTextSplitter;
    ///
    /// // 500 char chunks with 50 char overlap
    /// let splitter = RecursiveCharacterTextSplitter::new(500, 50);
    /// ```
    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Self {
        Self {
            chunk_size,
            chunk_overlap,
            separators: vec![
                "\n\n".to_string(), // Paragraphs
                "\n".to_string(),   // Lines
                ". ".to_string(),   // Sentences
                "! ".to_string(),
                "? ".to_string(),
                " ".to_string(), // Words
                "".to_string(),  // Characters
            ],
        }
    }

    /// Create with custom separators
    ///
    /// Separators are tried in order. An empty string means "split into
    /// individual characters". Characters are also the fallback once the list
    /// is exhausted.
    pub fn with_separators(mut self, separators: Vec<String>) -> Self {
        self.separators = separators;
        self
    }

    /// Split `text` into chunks and apply the configured overlap exactly once.
    fn split_recursive(&self, text: &str, separators: &[String]) -> Vec<String> {
        // Overlap is applied only after the whole recursion has finished.
        // Applying it inside the recursion would prefix sub-chunks that already
        // carry overlap, duplicating text once per recursion level.
        let pieces = self.split_pieces(text, separators);
        self.add_overlap(pieces)
    }

    /// Recursively cut `text` into contiguous, non-overlapping pieces of at
    /// most `chunk_size` characters.
    ///
    /// Concatenating the returned pieces reproduces `text` exactly.
    fn split_pieces(&self, text: &str, separators: &[String]) -> Vec<String> {
        if text.chars().count() <= self.chunk_size {
            return vec![text.to_string()];
        }

        let (sep, remaining_seps) = match separators.split_first() {
            Some((sep, rest)) if !sep.is_empty() => (sep, rest),
            // No separators left, or the empty separator: split into characters.
            _ => return self.split_by_chars(text),
        };

        let parts: Vec<&str> = text.split(sep.as_str()).collect();
        let last = parts.len() - 1; // `split` always yields at least one part

        let mut pieces = Vec::new();
        let mut current = String::new();
        let mut current_len = 0; // length of `current` in chars

        for (i, part) in parts.into_iter().enumerate() {
            // Re-attach the separator so no text is lost.
            let mut part_with_sep = part.to_string();
            if i < last {
                part_with_sep.push_str(sep);
            }
            let part_len = part_with_sep.chars().count();

            // A part that is too big on its own is split with finer separators.
            if part_len > self.chunk_size {
                if !current.is_empty() {
                    pieces.push(std::mem::take(&mut current));
                    current_len = 0;
                }
                pieces.extend(self.split_pieces(&part_with_sep, remaining_seps));
                continue;
            }

            if current_len + part_len <= self.chunk_size {
                current.push_str(&part_with_sep);
                current_len += part_len;
            } else {
                // Current piece is full; start a new one.
                if !current.is_empty() {
                    pieces.push(current);
                }
                current = part_with_sep;
                current_len = part_len;
            }
        }

        if !current.is_empty() {
            pieces.push(current);
        }

        pieces
    }

    /// Cut `text` into consecutive pieces of `chunk_size` characters (the last
    /// may be shorter). The pieces do not overlap and never split a `char`.
    fn split_by_chars(&self, text: &str) -> Vec<String> {
        let chars: Vec<char> = text.chars().collect();
        // `chunks(0)` panics; a zero size is rejected by `split_text` anyway.
        chars
            .chunks(self.chunk_size.max(1))
            .map(|piece| piece.iter().collect())
            .collect()
    }

    /// Prepend the last `chunk_overlap` characters of each chunk to the chunk
    /// that follows it.
    fn add_overlap(&self, chunks: Vec<String>) -> Vec<String> {
        if self.chunk_overlap == 0 || chunks.len() <= 1 {
            return chunks;
        }

        let mut result = Vec::with_capacity(chunks.len());
        result.push(chunks[0].clone());

        for pair in chunks.windows(2) {
            let (prev, chunk) = (&pair[0], &pair[1]);
            let prev_len = prev.chars().count();
            let overlap: String = prev
                .chars()
                .skip(prev_len.saturating_sub(self.chunk_overlap))
                .collect();
            result.push(format!("{}{}", overlap, chunk));
        }

        result
    }
}
239
240impl TextSplitter for RecursiveCharacterTextSplitter {
241    fn split_text(&self, text: &str) -> Result<Vec<String>> {
242        if text.is_empty() {
243            return Ok(vec![]);
244        }
245
246        if self.chunk_size == 0 {
247            return Err(VecStoreError::invalid_parameter(
248                "chunk_size",
249                "must be greater than 0",
250            ));
251        }
252
253        if self.chunk_overlap >= self.chunk_size {
254            return Err(VecStoreError::invalid_parameter(
255                "chunk_overlap",
256                "must be less than chunk_size",
257            ));
258        }
259
260        Ok(self.split_recursive(text, &self.separators))
261    }
262}
263
/// Token-based text splitter
///
/// Splits by an *estimated* token budget instead of a raw character budget,
/// which is handy when chunks are fed to LLMs with token limits.
///
/// Tokens are estimated from characters using `chars_per_token`. The default
/// of 4 is a rough approximation for English text.
///
/// # Example
///
/// ```no_run
/// use vecstore::text_splitter::{TokenTextSplitter, TextSplitter};
///
/// // Split into ~512 token chunks
/// let splitter = TokenTextSplitter::new(512, 50);
/// let chunks = splitter.split_text("Long document...")?;
/// # Ok::<(), anyhow::Error>(())
/// ```
pub struct TokenTextSplitter {
    /// Upper bound on (estimated) tokens per chunk
    max_tokens: usize,
    /// (Estimated) tokens shared between neighbouring chunks
    token_overlap: usize,
    /// Number of characters assumed to make up one token
    chars_per_token: usize,
}

impl TokenTextSplitter {
    /// Characters-per-token ratio used by [`TokenTextSplitter::new`].
    const DEFAULT_CHARS_PER_TOKEN: usize = 4;

    /// Build a splitter for chunks of about `max_tokens` tokens, with
    /// `token_overlap` tokens repeated between consecutive chunks.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use vecstore::text_splitter::TokenTextSplitter;
    ///
    /// // 512 token chunks with 50 token overlap
    /// let splitter = TokenTextSplitter::new(512, 50);
    /// ```
    pub fn new(max_tokens: usize, token_overlap: usize) -> Self {
        TokenTextSplitter {
            chars_per_token: Self::DEFAULT_CHARS_PER_TOKEN,
            max_tokens,
            token_overlap,
        }
    }

    /// Override the characters-per-token estimate (default: 4).
    pub fn with_chars_per_token(self, chars_per_token: usize) -> Self {
        TokenTextSplitter {
            chars_per_token,
            ..self
        }
    }
}
319
320impl TextSplitter for TokenTextSplitter {
321    fn split_text(&self, text: &str) -> Result<Vec<String>> {
322        if text.is_empty() {
323            return Ok(vec![]);
324        }
325
326        // Convert token limits to character limits
327        let chunk_size = self.max_tokens * self.chars_per_token;
328        let chunk_overlap = self.token_overlap * self.chars_per_token;
329
330        // Use recursive splitter with character-based limits
331        let char_splitter = RecursiveCharacterTextSplitter::new(chunk_size, chunk_overlap);
332        char_splitter.split_text(text)
333    }
334}
335
/// Markdown-aware text splitter
///
/// Splits markdown documents while respecting header hierarchy.
/// **HYBRID**: Simple by default, powerful when needed.
///
/// ATX headers (`#` through `######`, followed by a space, a tab, or the end of
/// the line) start new sections. Lines inside fenced code blocks (opened with
/// three backticks or `~~~`) are never treated as headers, so comments in code
/// samples stay with their section.
///
/// # Simple Usage (Default)
///
/// ```no_run
/// use vecstore::text_splitter::{MarkdownTextSplitter, TextSplitter};
///
/// // Just works - splits on markdown boundaries
/// let splitter = MarkdownTextSplitter::new(500, 50);
/// let chunks = splitter.split_text("# Title\n\nContent...")?;
/// # Ok::<(), anyhow::Error>(())
/// ```
///
/// # Advanced Usage (Optional)
///
/// ```no_run
/// use vecstore::text_splitter::{MarkdownTextSplitter, TextSplitter};
///
/// // Preserve header context in each chunk
/// let splitter = MarkdownTextSplitter::new(500, 50)
///     .with_preserve_headers(true);
/// let chunks = splitter.split_text("# Title\n## Section\nContent...")?;
/// # Ok::<(), anyhow::Error>(())
/// ```
pub struct MarkdownTextSplitter {
    /// Maximum chunk size in characters
    chunk_size: usize,
    /// Overlap between chunks
    chunk_overlap: usize,
    /// Whether to preserve header hierarchy in chunks
    preserve_headers: bool,
}

impl MarkdownTextSplitter {
    /// Create a new markdown splitter (simple, just works)
    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Self {
        Self {
            chunk_size,
            chunk_overlap,
            preserve_headers: false, // Simple by default
        }
    }

    /// Preserve header context in chunks (advanced, opt-in)
    pub fn with_preserve_headers(mut self, preserve: bool) -> Self {
        self.preserve_headers = preserve;
        self
    }

    /// Parse markdown into sections, recording each section's header hierarchy.
    ///
    /// Text before the first header becomes a level-0 section with an empty
    /// header. Each section's `header_chain` lists the headers of its enclosing
    /// sections, followed by its own header.
    fn parse_sections(&self, text: &str) -> Vec<MarkdownSection> {
        let mut sections = Vec::new();
        let mut current_section = MarkdownSection {
            level: 0,
            header: String::new(),
            content: String::new(),
            header_chain: Vec::new(),
        };

        let mut header_stack: Vec<(usize, String)> = Vec::new();
        // Marker ("```" or "~~~") of the fenced code block we are inside, if any.
        let mut open_fence: Option<&'static str> = None;

        for line in text.lines() {
            let header_level = if let Some(marker) = Self::fence_marker(line) {
                // Opening or closing fence. Update the state and keep the line as content.
                match open_fence {
                    None => open_fence = Some(marker),
                    Some(open) if open == marker => open_fence = None,
                    Some(_) => {} // other fence style inside a fence is plain content
                }
                None
            } else if open_fence.is_some() {
                None // `#` lines inside code blocks are comments, not headers
            } else {
                self.parse_header_level(line)
            };

            if let Some(level) = header_level {
                // Save previous section
                if !current_section.content.is_empty() || !current_section.header.is_empty() {
                    sections.push(current_section);
                }

                // Trim indentation first so the hashes are actually stripped.
                let header_text = line.trim_start().trim_start_matches('#').trim().to_string();

                // Update header stack (track hierarchy)
                header_stack.retain(|(l, _)| *l < level);
                header_stack.push((level, header_text.clone()));

                // Start new section
                current_section = MarkdownSection {
                    level,
                    header: header_text,
                    content: String::new(),
                    header_chain: header_stack.iter().map(|(_, h)| h.clone()).collect(),
                };
            } else {
                // Add content line
                if !current_section.content.is_empty() {
                    current_section.content.push('\n');
                }
                current_section.content.push_str(line);
            }
        }

        // Save final section
        if !current_section.content.is_empty() || !current_section.header.is_empty() {
            sections.push(current_section);
        }

        sections
    }

    /// Parse header level from line (e.g., "### Header" -> 3)
    ///
    /// Returns `None` for more than six hashes, and when the hashes are not
    /// followed by a space, a tab, or the end of the line (so `#hashtag` is
    /// content).
    fn parse_header_level(&self, line: &str) -> Option<usize> {
        let trimmed = line.trim_start();
        let level = trimmed.chars().take_while(|&c| c == '#').count();
        if !(1..=6).contains(&level) {
            return None;
        }

        // `#` is one byte, so `level` is also the byte offset of what follows.
        let rest = &trimmed[level..];
        if rest.is_empty() || rest.starts_with(|c: char| c == ' ' || c == '\t') {
            Some(level)
        } else {
            None
        }
    }

    /// Return the fence marker if `line` opens or closes a fenced code block.
    fn fence_marker(line: &str) -> Option<&'static str> {
        let trimmed = line.trim_start();
        if trimmed.starts_with("```") {
            Some("```")
        } else if trimmed.starts_with("~~~") {
            Some("~~~")
        } else {
            None
        }
    }
}

/// Markdown section with header hierarchy
#[derive(Debug, Clone)]
struct MarkdownSection {
    /// Header level (1-6); 0 for text that precedes the first header
    level: usize,
    /// Header text without the leading `#`s
    header: String,
    /// Lines between this header and the next, joined with `\n`
    content: String,
    /// Full hierarchy: ["H1", "H2", "H3"]
    header_chain: Vec<String>,
}
463
464impl TextSplitter for MarkdownTextSplitter {
465    fn split_text(&self, text: &str) -> Result<Vec<String>> {
466        if text.is_empty() {
467            return Ok(vec![]);
468        }
469
470        if self.chunk_size == 0 {
471            return Err(VecStoreError::invalid_parameter(
472                "chunk_size",
473                "must be greater than 0",
474            ));
475        }
476
477        // Parse into markdown sections
478        let sections = self.parse_sections(text);
479
480        let mut chunks = Vec::new();
481        let mut current_chunk = String::new();
482        let mut current_header_context = String::new();
483
484        for section in sections {
485            // Build header context if preserving headers
486            if self.preserve_headers && !section.header_chain.is_empty() {
487                current_header_context = section
488                    .header_chain
489                    .iter()
490                    .enumerate()
491                    .map(|(i, h)| format!("{} {}", "#".repeat(i + 1), h))
492                    .collect::<Vec<_>>()
493                    .join("\n");
494                current_header_context.push_str("\n\n");
495            }
496
497            let section_text = if section.header.is_empty() {
498                section.content.clone()
499            } else {
500                format!(
501                    "{} {}\n\n{}",
502                    "#".repeat(section.level),
503                    section.header,
504                    section.content
505                )
506            };
507
508            // If section fits in current chunk, add it
509            let chunk_with_section = if self.preserve_headers {
510                format!(
511                    "{}{}{}",
512                    current_chunk, current_header_context, section_text
513                )
514            } else {
515                format!("{}{}", current_chunk, section_text)
516            };
517
518            if chunk_with_section.len() <= self.chunk_size {
519                current_chunk = chunk_with_section;
520            } else {
521                // Current chunk is full, save it and start new one
522                if !current_chunk.is_empty() {
523                    chunks.push(current_chunk.trim().to_string());
524                }
525
526                // If section itself is too large, split it with RecursiveCharacterTextSplitter
527                if section_text.len() > self.chunk_size {
528                    let splitter = RecursiveCharacterTextSplitter::new(
529                        self.chunk_size.saturating_sub(current_header_context.len()),
530                        self.chunk_overlap,
531                    );
532                    let sub_chunks = splitter.split_text(&section_text)?;
533
534                    for sub_chunk in sub_chunks {
535                        if self.preserve_headers && !current_header_context.is_empty() {
536                            chunks.push(format!("{}{}", current_header_context, sub_chunk));
537                        } else {
538                            chunks.push(sub_chunk);
539                        }
540                    }
541                    current_chunk = String::new();
542                } else {
543                    current_chunk = if self.preserve_headers {
544                        format!("{}{}", current_header_context, section_text)
545                    } else {
546                        section_text
547                    };
548                }
549            }
550        }
551
552        // Save final chunk
553        if !current_chunk.is_empty() {
554            chunks.push(current_chunk.trim().to_string());
555        }
556
557        Ok(chunks)
558    }
559}
560
/// Code-aware text splitter
///
/// Splits source code while respecting function and class boundaries.
/// **HYBRID**: Simple by default, language-aware when needed.
///
/// # Simple Usage (Default)
///
/// ```no_run
/// use vecstore::text_splitter::{CodeTextSplitter, TextSplitter};
///
/// // Just works - splits on smart boundaries
/// let splitter = CodeTextSplitter::new(800, 50);
/// let chunks = splitter.split_text("fn main() { ... }")?;
/// # Ok::<(), anyhow::Error>(())
/// ```
///
/// # Advanced Usage (Optional)
///
/// ```no_run
/// use vecstore::text_splitter::{CodeTextSplitter, TextSplitter};
///
/// // Language-specific splitting
/// let splitter = CodeTextSplitter::new(800, 50)
///     .with_language("rust");
/// let chunks = splitter.split_text("fn main() { ... }")?;
/// # Ok::<(), anyhow::Error>(())
/// ```
pub struct CodeTextSplitter {
    /// Maximum chunk size in characters
    chunk_size: usize,
    /// Overlap between chunks
    chunk_overlap: usize,
    /// Optional language hint ("rust", "python", "javascript", etc.)
    language: Option<String>,
}

impl CodeTextSplitter {
    /// Create a new code splitter (simple, language-agnostic)
    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Self {
        Self {
            chunk_size,
            chunk_overlap,
            language: None, // Simple by default - works for all languages
        }
    }

    /// Set language for smarter splitting (advanced, opt-in)
    pub fn with_language(mut self, language: impl Into<String>) -> Self {
        self.language = Some(language.into());
        self
    }

    /// Detect if a line starts a code block (function, class, etc.)
    ///
    /// These are line-prefix heuristics, not a parser. Unknown or missing
    /// languages fall back to a small language-agnostic keyword set.
    fn is_code_block_start(&self, line: &str) -> bool {
        let trimmed = line.trim_start();
        let starts_with_any =
            |prefixes: &[&str]| prefixes.iter().any(|prefix| trimmed.starts_with(*prefix));

        match self.language.as_deref() {
            Some("rust") => {
                // Strip visibility so `pub`, `pub(crate)` etc. items are all recognised.
                let item = Self::strip_rust_visibility(trimmed);
                [
                    "fn ",
                    "async fn ",
                    "const fn ",
                    "unsafe fn ",
                    "struct ",
                    "enum ",
                    "trait ",
                    "impl ",
                    "impl<",
                    "unsafe impl ",
                ]
                .iter()
                .any(|prefix| item.starts_with(*prefix))
            }
            Some("python") => starts_with_any(&["def ", "class ", "async def "]),
            Some("javascript") | Some("typescript") => {
                starts_with_any(&["function ", "class ", "async function ", "export "])
                    // `const`/`let` open blocks (e.g. arrow functions) only at top
                    // level; indented ones are local variables inside a body.
                    || (trimmed.len() == line.len() && starts_with_any(&["const ", "let "]))
            }
            Some("java") | Some("c") | Some("cpp") => {
                // Simple heuristic: look for function-like signatures
                let has_parens = trimmed.contains('(') && trimmed.contains(')');
                let has_keyword = ["public", "private", "void", "int"]
                    .iter()
                    .any(|kw| trimmed.contains(*kw));
                // A trailing `;` marks a statement or prototype
                // (e.g. `System.out.println(x);`), not the start of a body.
                let is_statement = trimmed.trim_end().ends_with(';');
                (has_parens && has_keyword && !is_statement) || trimmed.starts_with("class ")
            }
            Some("go") => starts_with_any(&["func ", "type ", "struct "]),
            _ => {
                // Language-agnostic heuristics
                starts_with_any(&["fn ", "function ", "def ", "class "])
            }
        }
    }

    /// Strip a leading Rust visibility modifier (`pub `, `pub(crate) `,
    /// `pub(in path) `), returning the rest of the line.
    fn strip_rust_visibility(line: &str) -> &str {
        match line.strip_prefix("pub") {
            Some(rest) if rest.starts_with('(') => match rest.find(')') {
                Some(close) => rest[close + 1..].trim_start(),
                None => line,
            },
            Some(rest) if rest.starts_with(char::is_whitespace) => rest.trim_start(),
            // Not a visibility modifier (e.g. an identifier such as `public`).
            _ => line,
        }
    }

    /// Get code-specific separators, most preferred first
    fn get_separators(&self) -> Vec<String> {
        [
            "\n\n",  // Blank line (between functions/blocks, or paragraphs)
            "\n}\n", // Closing brace at line start (end of block)
            "\n",    // Lines
            "; ",    // Statements
            " ",     // Words
            "",      // Characters
        ]
        .iter()
        .map(|sep| sep.to_string())
        .collect()
    }
}
679
680impl TextSplitter for CodeTextSplitter {
681    fn split_text(&self, text: &str) -> Result<Vec<String>> {
682        if text.is_empty() {
683            return Ok(vec![]);
684        }
685
686        if self.chunk_size == 0 {
687            return Err(VecStoreError::invalid_parameter(
688                "chunk_size",
689                "must be greater than 0",
690            ));
691        }
692
693        // Use recursive splitter with code-aware separators
694        let separators = self.get_separators();
695        let splitter = RecursiveCharacterTextSplitter::new(self.chunk_size, self.chunk_overlap)
696            .with_separators(separators);
697
698        // If we have language hints, try to split on code block boundaries first
699        if self.language.is_some() {
700            let mut chunks = Vec::new();
701            let mut current_chunk = String::new();
702            let mut current_block = String::new();
703
704            for line in text.lines() {
705                let line_with_newline = format!("{}\n", line);
706
707                // Check if this starts a new code block
708                if self.is_code_block_start(line) && !current_block.is_empty() {
709                    // Save previous block
710                    if current_chunk.len() + current_block.len() <= self.chunk_size {
711                        current_chunk.push_str(&current_block);
712                        current_block.clear();
713                    } else {
714                        if !current_chunk.is_empty() {
715                            chunks.push(current_chunk.clone());
716                        }
717                        current_chunk = current_block.clone();
718                        current_block.clear();
719                    }
720                }
721
722                current_block.push_str(&line_with_newline);
723
724                // If block is getting too large, flush it
725                if current_block.len() > self.chunk_size {
726                    if !current_chunk.is_empty() {
727                        chunks.push(current_chunk.clone());
728                        current_chunk.clear();
729                    }
730
731                    // Split oversized block with standard splitter
732                    let sub_chunks = splitter.split_text(&current_block)?;
733                    chunks.extend(sub_chunks);
734                    current_block.clear();
735                }
736            }
737
738            // Save remaining content
739            if !current_block.is_empty() {
740                current_chunk.push_str(&current_block);
741            }
742            if !current_chunk.is_empty() {
743                chunks.push(current_chunk);
744            }
745
746            return Ok(chunks);
747        }
748
749        // Fallback: use standard recursive splitter with code separators
750        splitter.split_text(text)
751    }
752}
753
754/// Simple trait for embedding text (used by SemanticTextSplitter)
755///
756/// **HYBRID**: Any embedder works - users provide their own implementation.
757/// No forced dependencies on specific embedding libraries.
758///
759/// # Example Implementation
760///
761/// ```no_run
762/// use vecstore::text_splitter::Embedder;
763/// use anyhow::Result;
764///
765/// struct MyEmbedder;
766///
767/// impl Embedder for MyEmbedder {
768///     fn embed(&self, text: &str) -> Result<Vec<f32>> {
769///         // Your embedding logic here
770///         Ok(vec![0.0; 384]) // Example: 384-dim vector
771///     }
772/// }
773/// ```
774pub trait Embedder {
775    /// Embed a text into a vector
776    fn embed(&self, text: &str) -> Result<Vec<f32>>;
777}
778
779/// Semantic text splitter
780///
781/// Splits text based on semantic similarity using embeddings.
782/// Groups semantically similar content together.
783/// **HYBRID**: Requires embedder (advanced), but composable with any embedding model.
784///
785/// # Usage
786///
787/// ```no_run
788/// use vecstore::text_splitter::{SemanticTextSplitter, TextSplitter, Embedder};
789/// use anyhow::Result;
790///
791/// // Provide your own embedder (no forced dependency)
792/// struct MyEmbedder;
793/// impl Embedder for MyEmbedder {
794///     fn embed(&self, text: &str) -> Result<Vec<f32>> {
795///         Ok(vec![0.0; 384])
796///     }
797/// }
798///
799/// let embedder = Box::new(MyEmbedder);
800/// let splitter = SemanticTextSplitter::new(embedder, 500, 50);
801/// let chunks = splitter.split_text("Long document...")?;
802/// # Ok::<(), anyhow::Error>(())
803/// ```
804pub struct SemanticTextSplitter {
805    /// Embedder for computing semantic similarity
806    embedder: Box<dyn Embedder>,
807    /// Maximum chunk size in characters
808    max_chunk_size: usize,
809    /// Minimum chunk size in characters
810    min_chunk_size: usize,
811    /// Similarity threshold (0.0-1.0) for grouping sentences
812    similarity_threshold: f32,
813}
814
815impl SemanticTextSplitter {
816    /// Create a new semantic splitter
817    ///
818    /// # Arguments
819    /// * `embedder` - Any embedder implementing the Embedder trait (HYBRID: bring your own)
820    /// * `max_chunk_size` - Maximum characters per chunk
821    /// * `min_chunk_size` - Minimum characters per chunk (avoid tiny chunks)
822    pub fn new(embedder: Box<dyn Embedder>, max_chunk_size: usize, min_chunk_size: usize) -> Self {
823        Self {
824            embedder,
825            max_chunk_size,
826            min_chunk_size,
827            similarity_threshold: 0.7, // Default: group similar content
828        }
829    }
830
831    /// Set similarity threshold (advanced, opt-in)
832    ///
833    /// Higher = more similar content required for grouping
834    /// Lower = more aggressive grouping
835    pub fn with_similarity_threshold(mut self, threshold: f32) -> Self {
836        self.similarity_threshold = threshold.clamp(0.0, 1.0);
837        self
838    }
839
840    /// Compute cosine similarity between two vectors
841    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
842        if a.len() != b.len() {
843            return 0.0;
844        }
845
846        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
847        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
848        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
849
850        if norm_a == 0.0 || norm_b == 0.0 {
851            return 0.0;
852        }
853
854        dot_product / (norm_a * norm_b)
855    }
856
857    /// Split text into sentences (simple heuristic)
858    fn split_sentences(&self, text: &str) -> Vec<String> {
859        // Simple sentence splitting on common boundaries
860        text.split(&['.', '!', '?'][..])
861            .filter(|s| !s.trim().is_empty())
862            .map(|s| s.trim().to_string())
863            .collect()
864    }
865}
866
867impl TextSplitter for SemanticTextSplitter {
868    fn split_text(&self, text: &str) -> Result<Vec<String>> {
869        if text.is_empty() {
870            return Ok(vec![]);
871        }
872
873        if self.max_chunk_size == 0 {
874            return Err(VecStoreError::invalid_parameter(
875                "max_chunk_size",
876                "must be greater than 0",
877            ));
878        }
879
880        // Split into sentences
881        let sentences = self.split_sentences(text);
882
883        if sentences.is_empty() {
884            return Ok(vec![]);
885        }
886
887        // Compute embeddings for all sentences
888        let mut sentence_embeddings = Vec::new();
889        for sentence in &sentences {
890            let embedding = self.embedder.embed(sentence)?;
891            sentence_embeddings.push(embedding);
892        }
893
894        // Group sentences into chunks based on semantic similarity
895        let mut chunks = Vec::new();
896        let mut current_chunk = String::new();
897        let mut current_embedding: Option<Vec<f32>> = None;
898
899        for (i, sentence) in sentences.iter().enumerate() {
900            let sentence_with_space = if current_chunk.is_empty() {
901                sentence.clone()
902            } else {
903                format!(" {}", sentence)
904            };
905
906            // Check if adding this sentence would exceed max size
907            if current_chunk.len() + sentence_with_space.len() > self.max_chunk_size {
908                // Save current chunk if it meets minimum size
909                if current_chunk.len() >= self.min_chunk_size {
910                    chunks.push(current_chunk.clone());
911                    current_chunk.clear();
912                    current_embedding = None;
913                }
914            }
915
916            // Compute similarity with current chunk
917            let should_add = if let Some(ref chunk_emb) = current_embedding {
918                let similarity = self.cosine_similarity(chunk_emb, &sentence_embeddings[i]);
919                similarity >= self.similarity_threshold
920            } else {
921                true // First sentence always added
922            };
923
924            if should_add || current_chunk.is_empty() {
925                // Add sentence to current chunk
926                current_chunk.push_str(&sentence_with_space);
927
928                // Update chunk embedding (average of all sentence embeddings)
929                if let Some(ref mut chunk_emb) = current_embedding {
930                    // Simple averaging (could be weighted)
931                    for (j, val) in sentence_embeddings[i].iter().enumerate() {
932                        chunk_emb[j] = (chunk_emb[j] + val) / 2.0;
933                    }
934                } else {
935                    current_embedding = Some(sentence_embeddings[i].clone());
936                }
937            } else {
938                // Similarity too low - start new chunk
939                if current_chunk.len() >= self.min_chunk_size {
940                    chunks.push(current_chunk.clone());
941                }
942                current_chunk = sentence.clone();
943                current_embedding = Some(sentence_embeddings[i].clone());
944            }
945        }
946
947        // Save final chunk
948        if !current_chunk.is_empty() && current_chunk.len() >= self.min_chunk_size {
949            chunks.push(current_chunk);
950        }
951
952        // Fallback: if no chunks created, use character splitter
953        if chunks.is_empty() {
954            let fallback =
955                RecursiveCharacterTextSplitter::new(self.max_chunk_size, self.min_chunk_size / 2);
956            return fallback.split_text(text);
957        }
958
959        Ok(chunks)
960    }
961}
962
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_recursive_splitter_basic() {
        let splitter = RecursiveCharacterTextSplitter::new(20, 0);
        let text = "Short text.";
        let chunks = splitter.split_text(text).unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], text);
    }

    #[test]
    fn test_recursive_splitter_paragraphs() {
        let splitter = RecursiveCharacterTextSplitter::new(50, 0);
        let text = "First paragraph.\n\nSecond paragraph.";
        let chunks = splitter.split_text(text).unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_recursive_splitter_overlap() {
        let splitter = RecursiveCharacterTextSplitter::new(20, 5);
        let text = "This is a longer text that should be split into multiple chunks.";
        let chunks = splitter.split_text(text).unwrap();
        assert!(chunks.len() > 1);
    }

    #[test]
    fn test_token_splitter() {
        let splitter = TokenTextSplitter::new(10, 2); // 10 tokens ~ 40 chars
        let text = "This is a test. This text should be split based on token count.";
        let chunks = splitter.split_text(text).unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_empty_text() {
        let splitter = RecursiveCharacterTextSplitter::new(100, 10);
        let chunks = splitter.split_text("").unwrap();
        assert_eq!(chunks.len(), 0);
    }

    #[test]
    fn test_invalid_chunk_size() {
        let splitter = RecursiveCharacterTextSplitter::new(0, 0);
        let result = splitter.split_text("test");
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_overlap() {
        let splitter = RecursiveCharacterTextSplitter::new(100, 100);
        let result = splitter.split_text("test");
        assert!(result.is_err());
    }

    // Markdown splitter tests
    #[test]
    fn test_markdown_splitter_basic() {
        let splitter = MarkdownTextSplitter::new(200, 20);
        let text = "# Header 1\n\nSome content here.\n\n## Header 2\n\nMore content.";
        let chunks = splitter.split_text(text).unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_markdown_splitter_preserve_headers() {
        let splitter = MarkdownTextSplitter::new(200, 20).with_preserve_headers(true);
        let text = "# Main\n\nContent 1\n\n## Section\n\nContent 2";
        let chunks = splitter.split_text(text).unwrap();

        // When preserving headers, chunks should contain header context
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_markdown_header_parsing() {
        let splitter = MarkdownTextSplitter::new(100, 10);

        // Test various header levels
        assert_eq!(splitter.parse_header_level("# H1"), Some(1));
        assert_eq!(splitter.parse_header_level("## H2"), Some(2));
        assert_eq!(splitter.parse_header_level("### H3"), Some(3));
        assert_eq!(splitter.parse_header_level("Not a header"), None);
        assert_eq!(splitter.parse_header_level("####### Too many"), None);
    }

    #[test]
    fn test_markdown_simple_by_default() {
        // Default behavior: simple splitting without header preservation
        let splitter = MarkdownTextSplitter::new(500, 50);
        assert!(!splitter.preserve_headers);
    }

    // Code splitter tests
    #[test]
    fn test_code_splitter_basic() {
        let splitter = CodeTextSplitter::new(200, 20);
        let code = "fn main() {\n    println!(\"Hello\");\n}\n\nfn test() {\n    // test\n}";
        let chunks = splitter.split_text(code).unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_code_splitter_with_language() {
        let splitter = CodeTextSplitter::new(300, 30).with_language("rust");
        let code =
            "fn main() {\n    println!(\"Hello\");\n}\n\nfn test() {\n    println!(\"Test\");\n}";
        let chunks = splitter.split_text(code).unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_code_block_detection() {
        let splitter = CodeTextSplitter::new(100, 10).with_language("rust");
        assert!(splitter.is_code_block_start("fn main() {"));
        assert!(splitter.is_code_block_start("pub fn test() {"));
        assert!(splitter.is_code_block_start("struct Foo {"));
        assert!(!splitter.is_code_block_start("    let x = 5;"));
    }

    #[test]
    fn test_code_splitter_simple_by_default() {
        // Default behavior: language-agnostic
        let splitter = CodeTextSplitter::new(500, 50);
        assert!(splitter.language.is_none());
    }

    // Semantic splitter tests (using mock embedder)
    struct MockEmbedder;

    impl Embedder for MockEmbedder {
        fn embed(&self, text: &str) -> Result<Vec<f32>> {
            // Simple mock: use text length as "embedding".
            // Every non-empty text maps to a positive multiple of [0.01, 0.02, 0.04], so
            // any two embeddings point the same way (cosine similarity ~1.0). This mock
            // therefore never exercises the "similarity too low" path.
            let len = text.len() as f32;
            Ok(vec![len / 100.0, len / 50.0, len / 25.0])
        }
    }

    #[test]
    fn test_semantic_splitter_basic() {
        let embedder = Box::new(MockEmbedder);
        let splitter = SemanticTextSplitter::new(embedder, 200, 20);
        let text =
            "First sentence. Second sentence here. Third one is different. Fourth continues.";
        let chunks = splitter.split_text(text).unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_semantic_splitter_with_threshold() {
        let embedder = Box::new(MockEmbedder);
        let splitter = SemanticTextSplitter::new(embedder, 300, 30).with_similarity_threshold(0.8);
        let text = "Sentence one. Sentence two. Sentence three.";
        let chunks = splitter.split_text(text).unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_semantic_splitter_cosine_similarity() {
        let embedder = Box::new(MockEmbedder);
        let splitter = SemanticTextSplitter::new(embedder, 100, 10);

        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];
        let v3 = vec![0.0, 1.0, 0.0];

        // Identical vectors should have similarity 1.0
        let sim1 = splitter.cosine_similarity(&v1, &v2);
        assert!((sim1 - 1.0).abs() < 0.01);

        // Orthogonal vectors should have similarity 0.0
        let sim2 = splitter.cosine_similarity(&v1, &v3);
        assert!(sim2.abs() < 0.01);
    }

    #[test]
    fn test_embedder_trait_composable() {
        // Test that Embedder trait is composable (HYBRID principle)
        struct CustomEmbedder;
        impl Embedder for CustomEmbedder {
            fn embed(&self, _text: &str) -> Result<Vec<f32>> {
                Ok(vec![1.0, 2.0, 3.0])
            }
        }

        let embedder = Box::new(CustomEmbedder);
        let splitter = SemanticTextSplitter::new(embedder, 500, 50);

        let text = "Test text.";
        let result = splitter.split_text(text);
        assert!(result.is_ok());
    }
}