// src/chunking.rs
// Token-based chunking for E5 model (512 token limit)
use crate::constants::{CHUNK_OVERLAP_TOKENS, CHUNK_SIZE_TOKENS, EMBEDDING_DIM};

// Conservative heuristic to reduce the risk of underestimating the real token
// count for Markdown, code, and multilingual text. The previous value of
// 4 chars/token allowed chunks that were too large for some real documents.
const CHARS_PER_TOKEN: usize = 2;
/// Maximum chunk size expressed in characters (token budget x chars/token).
pub const CHUNK_SIZE_CHARS: usize = CHUNK_SIZE_TOKENS * CHARS_PER_TOKEN;
/// Overlap between consecutive chunks, expressed in characters.
pub const CHUNK_OVERLAP_CHARS: usize = CHUNK_OVERLAP_TOKENS * CHARS_PER_TOKEN;
12
/// One contiguous slice of a document, produced by `split_into_chunks`.
#[derive(Debug, Clone)]
pub struct Chunk {
    /// The chunk's text, copied out of the source body.
    pub text: String,
    /// Byte offset into the original body where this chunk starts (inclusive).
    pub start_offset: usize,
    /// Byte offset into the original body where this chunk ends (exclusive).
    pub end_offset: usize,
    /// Rough token estimate: character count divided by `CHARS_PER_TOKEN`.
    pub token_count_approx: usize,
}
20
21pub fn needs_chunking(body: &str) -> bool {
22    body.len() > CHUNK_SIZE_CHARS
23}
24
/// Splits `body` into overlapping chunks of at most `CHUNK_SIZE_CHARS` bytes,
/// preferring natural break points (paragraph, sentence, then word) over
/// hard cuts.
///
/// Bodies that fit in a single chunk are returned unmodified. All offsets in
/// the returned `Chunk`s are byte offsets into `body` and land on UTF-8
/// character boundaries.
pub fn split_into_chunks(body: &str) -> Vec<Chunk> {
    if !needs_chunking(body) {
        return vec![Chunk {
            token_count_approx: body.chars().count() / CHARS_PER_TOKEN,
            text: body.to_string(),
            start_offset: 0,
            end_offset: body.len(),
        }];
    }

    let mut chunks = Vec::new();
    let mut start = 0usize;

    while start < body.len() {
        // Keep `start` on a char boundary so the slice below cannot panic.
        start = next_char_boundary(body, start);
        // Largest acceptable end for this chunk, snapped back to a boundary.
        let desired_end = previous_char_boundary(body, (start + CHUNK_SIZE_CHARS).min(body.len()));
        let end = if desired_end < body.len() {
            // Not the final chunk: try to break at a natural boundary.
            find_split_boundary(body, start, desired_end)
        } else {
            desired_end
        };

        // Defensive fallback: if boundary search made no forward progress,
        // take a full-size slice, or the rest of the body as a last resort.
        let end = if end <= start {
            let fallback = previous_char_boundary(body, (start + CHUNK_SIZE_CHARS).min(body.len()));
            if fallback > start {
                fallback
            } else {
                body.len()
            }
        } else {
            end
        };

        let text = body[start..end].to_string();
        let token_count_approx = text.chars().count() / CHARS_PER_TOKEN;
        chunks.push(Chunk {
            text,
            start_offset: start,
            end_offset: end,
            token_count_approx,
        });

        if end >= body.len() {
            break;
        }

        // Step back by the overlap so adjacent chunks share context; if this
        // chunk was shorter than the overlap, continue from `end` instead so
        // the loop is guaranteed to make forward progress.
        let next_start = next_char_boundary(body, end.saturating_sub(CHUNK_OVERLAP_CHARS));
        start = if next_start >= end { end } else { next_start };
    }

    chunks
}
77
/// Chooses a split position inside `body[start..desired_end]`, preferring
/// (in order) the last paragraph break, the last sentence end, then the last
/// space. Returns the byte offset just past the chosen separator, or
/// `desired_end` when the window contains none of them.
fn find_split_boundary(body: &str, start: usize, desired_end: usize) -> usize {
    let window = &body[start..desired_end];
    // Separators from most to least desirable; the split lands immediately
    // after the matched separator.
    ["\n\n", ". ", " "]
        .iter()
        .find_map(|sep| window.rfind(sep).map(|pos| start + pos + sep.len()))
        .unwrap_or(desired_end)
}
91
/// Returns the largest char-boundary byte index `<=` the clamped `idx`.
/// Index 0 is always a boundary, so the search cannot fall through.
fn previous_char_boundary(body: &str, idx: usize) -> usize {
    let capped = idx.min(body.len());
    (0..=capped)
        .rev()
        .find(|&candidate| body.is_char_boundary(candidate))
        .unwrap_or(0)
}
99
/// Returns the smallest char-boundary byte index `>=` the clamped `idx`.
/// `body.len()` is always a boundary, so the search cannot fall through.
fn next_char_boundary(body: &str, idx: usize) -> usize {
    let capped = idx.min(body.len());
    (capped..=body.len())
        .find(|&candidate| body.is_char_boundary(candidate))
        .unwrap_or(body.len())
}
107
108pub fn aggregate_embeddings(chunk_embeddings: &[Vec<f32>]) -> Vec<f32> {
109    if chunk_embeddings.is_empty() {
110        return vec![0.0f32; EMBEDDING_DIM];
111    }
112    if chunk_embeddings.len() == 1 {
113        return chunk_embeddings[0].clone();
114    }
115
116    let dim = chunk_embeddings[0].len();
117    let mut mean = vec![0.0f32; dim];
118    for emb in chunk_embeddings {
119        for (i, v) in emb.iter().enumerate() {
120            mean[i] += v;
121        }
122    }
123    let n = chunk_embeddings.len() as f32;
124    for v in &mut mean {
125        *v /= n;
126    }
127
128    let norm: f32 = mean.iter().map(|x| x * x).sum::<f32>().sqrt();
129    if norm > 1e-9 {
130        for v in &mut mean {
131            *v /= norm;
132        }
133    }
134    mean
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    // A body at or under the size limit must come back as exactly one
    // unmodified chunk.
    #[test]
    fn test_short_body_no_chunking() {
        let body = "short text";
        assert!(!needs_chunking(body));
        let chunks = split_into_chunks(body);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].text, body);
    }

    // An oversized ASCII body must be split into more than one chunk.
    #[test]
    fn test_long_body_produces_multiple_chunks() {
        let body = "word ".repeat(1000);
        assert!(needs_chunking(&body));
        let chunks = split_into_chunks(&body);
        assert!(chunks.len() > 1);
    }

    // Multibyte (UTF-8) input: every chunk offset must land on a char
    // boundary, chunks must be non-empty, and the loop must make forward
    // progress (monotonically advancing offsets across chunks).
    #[test]
    fn test_multibyte_body_preserves_progress_and_boundaries() {
        let body = "ação útil ".repeat(1000);
        let chunks = split_into_chunks(&body);
        assert!(chunks.len() > 1);
        for chunk in &chunks {
            assert!(!chunk.text.is_empty());
            assert!(body.is_char_boundary(chunk.start_offset));
            assert!(body.is_char_boundary(chunk.end_offset));
            assert!(chunk.end_offset > chunk.start_offset);
        }
        for pair in chunks.windows(2) {
            assert!(pair[1].start_offset >= pair[0].start_offset);
            assert!(pair[1].end_offset > pair[0].start_offset);
        }
    }

    // The mean of two orthogonal unit vectors must come back L2-normalized.
    #[test]
    fn test_aggregate_embeddings_normalizes() {
        let embs = vec![vec![1.0f32, 0.0], vec![0.0f32, 1.0]];
        let agg = aggregate_embeddings(&embs);
        let norm: f32 = agg.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }
}