//! three_dcf_core/chunk.rs — splits a [`Document`] into retrieval-sized
//! [`ChunkRecord`]s using one of several chunking strategies.
1use hex;
2use once_cell::sync::Lazy;
3use serde::{Deserialize, Serialize};
4use sha2::{Digest, Sha256};
5use tiktoken_rs::CoreBPE;
6
7use crate::document::{CellRecord, CellType, Document};
8
/// Shared cl100k_base BPE tokenizer, built once on first use.
static TOKENIZER: Lazy<CoreBPE> = Lazy::new(|| tiktoken_rs::cl100k_base().expect("tokenizer"));
/// Version folded into every chunk id; bump when the id derivation or
/// chunking semantics change so previously stored ids are invalidated.
const CHUNK_VERSION: u32 = 1;
11
/// Strategy used to split a document into chunks.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ChunkMode {
    /// Fixed-size windows of cells with optional cell overlap.
    Cells,
    /// Token-budgeted windows with optional token overlap.
    Tokens,
    /// One chunk per heading-led section (capped by `max_tokens`).
    Headings,
    /// Contiguous runs of table cells, split every `cells_per_chunk` cells.
    TableRows,
}
20
/// Tuning knobs for [`Chunker`]; which fields apply depends on the `mode`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct ChunkConfig {
    /// Chunking strategy to use.
    pub mode: ChunkMode,
    /// Window size for `Cells` mode; per-chunk cap for `TableRows` mode.
    pub cells_per_chunk: usize,
    /// Cells re-included from the previous window (`Cells` mode only).
    pub overlap_cells: usize,
    /// Token budget per chunk (`Tokens` mode; also caps `Headings` sections).
    pub max_tokens: usize,
    /// Approximate tokens re-included from the previous window (`Tokens` mode).
    pub overlap_tokens: usize,
}
29
impl Default for ChunkConfig {
    /// Cell-windowed chunking by default: 200 cells per chunk with a 20-cell
    /// overlap; the token limits (512, overlap 64) apply to the token-based modes.
    fn default() -> Self {
        Self {
            mode: ChunkMode::Cells,
            cells_per_chunk: 200,
            overlap_cells: 20,
            max_tokens: 512,
            overlap_tokens: 64,
        }
    }
}
41
/// One chunk of a document: a contiguous cell span plus its concatenated text
/// and summary statistics. Serialized; the `#[serde(default)]` fields tolerate
/// records written before those fields existed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkRecord {
    /// Deterministic id derived from doc id, position, mode and version.
    pub chunk_id: String,
    /// Id of the source document.
    pub doc: String,
    /// 0-based position of this chunk within the document's chunk sequence.
    pub chunk_index: usize,
    /// `z` of the first cell in the chunk.
    pub z_start: u32,
    /// `z` of the last cell in the chunk.
    pub z_end: u32,
    /// Index of the first cell (inclusive) in document order.
    pub cell_start: usize,
    /// Index of the last cell (inclusive) in document order.
    pub cell_end: usize,
    /// Newline-joined non-empty cell payloads.
    pub text: String,
    /// Total tokens across all cell payloads (including whitespace-only ones).
    #[serde(default)]
    pub token_count: usize,
    /// Most frequent cell type in the chunk; defaults to `Text` on old records.
    #[serde(default = "default_cell_type")]
    pub dominant_type: CellType,
    /// Mean cell importance normalized to [0, 1] (raw importance is 0..=255).
    #[serde(default)]
    pub importance_mean: f32,
}
59
/// Splits documents into [`ChunkRecord`]s according to a [`ChunkConfig`].
pub struct Chunker {
    config: ChunkConfig,
}
63
64impl Chunker {
65    pub fn new(config: ChunkConfig) -> Self {
66        Self { config }
67    }
68
69    pub fn chunk_document(&self, document: &Document, doc_id: &str) -> Vec<ChunkRecord> {
70        let ordered = document.ordered_cells();
71        if ordered.is_empty() {
72            return Vec::new();
73        }
74        match self.config.mode {
75            ChunkMode::Cells => self.chunk_by_cells(document, doc_id, &ordered),
76            ChunkMode::Tokens => self.chunk_by_tokens(document, doc_id, &ordered),
77            ChunkMode::Headings => self.chunk_by_headings(document, doc_id, &ordered),
78            ChunkMode::TableRows => self.chunk_table_blocks(document, doc_id, &ordered),
79        }
80    }
81
82    fn chunk_by_cells(
83        &self,
84        document: &Document,
85        doc_id: &str,
86        ordered: &[CellRecord],
87    ) -> Vec<ChunkRecord> {
88        let chunk_size = self.config.cells_per_chunk.max(1);
89        let overlap = self.config.overlap_cells.min(chunk_size.saturating_sub(1));
90        let mut start = 0usize;
91        let mut chunk_index = 0usize;
92        let mut chunks = Vec::new();
93        while start < ordered.len() {
94            let end = (start + chunk_size).min(ordered.len());
95            if let Some(record) =
96                self.build_chunk(document, doc_id, chunk_index, start, end, ordered)
97            {
98                chunks.push(record);
99                chunk_index += 1;
100            }
101            if end == ordered.len() {
102                break;
103            }
104            start = if overlap == 0 {
105                end
106            } else {
107                end.saturating_sub(overlap)
108            };
109        }
110        chunks
111    }
112
113    fn chunk_by_tokens(
114        &self,
115        document: &Document,
116        doc_id: &str,
117        ordered: &[CellRecord],
118    ) -> Vec<ChunkRecord> {
119        let max_tokens = self.config.max_tokens.max(1);
120        let overlap_tokens = self.config.overlap_tokens.min(max_tokens.saturating_sub(1));
121        let tokens_per_cell = token_counts(document, ordered);
122        let mut start = 0usize;
123        let mut chunk_index = 0usize;
124        let mut chunks = Vec::new();
125        while start < ordered.len() {
126            let mut end = start;
127            let mut used_tokens = 0usize;
128            while end < ordered.len() {
129                let cell_tokens = tokens_per_cell[end].max(1);
130                if end > start && used_tokens + cell_tokens > max_tokens {
131                    break;
132                }
133                used_tokens += cell_tokens;
134                end += 1;
135            }
136            if end == start {
137                end += 1;
138            }
139            if let Some(record) =
140                self.build_chunk(document, doc_id, chunk_index, start, end, ordered)
141            {
142                chunks.push(record);
143                chunk_index += 1;
144            }
145            if end == ordered.len() {
146                break;
147            }
148            if overlap_tokens == 0 {
149                start = end;
150            } else {
151                let mut back_tokens = 0usize;
152                let mut new_start = end;
153                while new_start > start {
154                    new_start -= 1;
155                    back_tokens += tokens_per_cell[new_start].max(1);
156                    if back_tokens >= overlap_tokens {
157                        break;
158                    }
159                }
160                start = new_start;
161            }
162        }
163        chunks
164    }
165
166    fn chunk_by_headings(
167        &self,
168        document: &Document,
169        doc_id: &str,
170        ordered: &[CellRecord],
171    ) -> Vec<ChunkRecord> {
172        let mut chunks = Vec::new();
173        let tokens_per_cell = token_counts(document, ordered);
174        let mut chunk_index = 0usize;
175        let mut idx = 0usize;
176        while idx < ordered.len() {
177            if ordered[idx].cell_type != CellType::Header {
178                idx += 1;
179                continue;
180            }
181            let start = idx;
182            let mut end = idx;
183            let mut tokens = 0usize;
184            while end < ordered.len() {
185                if end > start && ordered[end].cell_type == CellType::Header {
186                    break;
187                }
188                tokens += tokens_per_cell[end];
189                if self.config.max_tokens > 0 && tokens >= self.config.max_tokens {
190                    end += 1;
191                    break;
192                }
193                end += 1;
194            }
195            if let Some(record) =
196                self.build_chunk(document, doc_id, chunk_index, start, end, ordered)
197            {
198                chunks.push(record);
199                chunk_index += 1;
200            }
201            idx = end;
202        }
203        chunks
204    }
205
206    fn chunk_table_blocks(
207        &self,
208        document: &Document,
209        doc_id: &str,
210        ordered: &[CellRecord],
211    ) -> Vec<ChunkRecord> {
212        let mut chunks = Vec::new();
213        let mut idx = 0usize;
214        let mut chunk_index = 0usize;
215        while idx < ordered.len() {
216            if ordered[idx].cell_type != CellType::Table {
217                idx += 1;
218                continue;
219            }
220            let mut block_end = idx;
221            while block_end < ordered.len() && ordered[block_end].cell_type == CellType::Table {
222                block_end += 1;
223            }
224            let mut start = idx;
225            while start < block_end {
226                let end = (start + self.config.cells_per_chunk.max(1)).min(block_end);
227                if let Some(record) =
228                    self.build_chunk(document, doc_id, chunk_index, start, end, ordered)
229                {
230                    chunks.push(record);
231                    chunk_index += 1;
232                }
233                start = end;
234            }
235            idx = block_end;
236        }
237        chunks
238    }
239
240    fn build_chunk(
241        &self,
242        document: &Document,
243        doc_id: &str,
244        chunk_index: usize,
245        start: usize,
246        end: usize,
247        ordered: &[CellRecord],
248    ) -> Option<ChunkRecord> {
249        if start >= end || start >= ordered.len() {
250            return None;
251        }
252        let slice = &ordered[start..end];
253        let mut parts = Vec::with_capacity(slice.len());
254        let mut token_total = 0usize;
255        let mut importance_sum = 0usize;
256        let mut type_hist = [0usize; 5];
257        for cell in slice {
258            if let Some(payload) = document.payload_for(&cell.code_id) {
259                if !payload.trim().is_empty() {
260                    parts.push(payload.to_string());
261                }
262                token_total += count_tokens(payload);
263            }
264            importance_sum += cell.importance as usize;
265            increment_histogram(&mut type_hist, cell.cell_type);
266        }
267        let text = parts.join("\n");
268        if text.trim().is_empty() {
269            return None;
270        }
271        let z_start = slice.first().map(|c| c.z).unwrap_or(0);
272        let z_end = slice.last().map(|c| c.z).unwrap_or(z_start);
273        let chunk_id = stable_chunk_id(
274            doc_id,
275            chunk_index,
276            start,
277            end.saturating_sub(1),
278            self.config.mode,
279            CHUNK_VERSION,
280        );
281        let dominant_type = dominant_cell_type(&type_hist);
282        let importance_mean = if slice.is_empty() {
283            0.0
284        } else {
285            importance_sum as f32 / (slice.len() as f32 * 255.0)
286        };
287        Some(ChunkRecord {
288            chunk_id,
289            doc: doc_id.to_string(),
290            chunk_index,
291            z_start,
292            z_end,
293            cell_start: start,
294            cell_end: end.saturating_sub(1),
295            text,
296            token_count: token_total,
297            dominant_type,
298            importance_mean,
299        })
300    }
301}
302
303fn increment_histogram(hist: &mut [usize; 5], cell_type: CellType) {
304    match cell_type {
305        CellType::Text => hist[0] += 1,
306        CellType::Table => hist[1] += 1,
307        CellType::Figure => hist[2] += 1,
308        CellType::Footer => hist[3] += 1,
309        CellType::Header => hist[4] += 1,
310    }
311}
312
313fn dominant_cell_type(hist: &[usize; 5]) -> CellType {
314    let mut max_idx = 0usize;
315    let mut max_val = 0usize;
316    for (idx, val) in hist.iter().enumerate() {
317        if *val > max_val {
318            max_val = *val;
319            max_idx = idx;
320        }
321    }
322    match max_idx {
323        0 => CellType::Text,
324        1 => CellType::Table,
325        2 => CellType::Figure,
326        3 => CellType::Footer,
327        _ => CellType::Header,
328    }
329}
330
/// Serde fallback for `ChunkRecord::dominant_type` on records that predate
/// the field.
fn default_cell_type() -> CellType {
    CellType::Text
}
334
335fn token_counts(document: &Document, cells: &[CellRecord]) -> Vec<usize> {
336    cells
337        .iter()
338        .map(|cell| {
339            document
340                .payload_for(&cell.code_id)
341                .map(count_tokens)
342                .unwrap_or(0)
343        })
344        .collect()
345}
346
/// Number of cl100k_base tokens in `text` (special tokens are encoded, not
/// rejected).
fn count_tokens(text: &str) -> usize {
    TOKENIZER.encode_with_special_tokens(text).len()
}
350
351fn stable_chunk_id(
352    doc_id: &str,
353    chunk_index: usize,
354    cell_start: usize,
355    cell_end: usize,
356    mode: ChunkMode,
357    version: u32,
358) -> String {
359    let mut hasher = Sha256::new();
360    hasher.update(doc_id.as_bytes());
361    hasher.update(&version.to_be_bytes());
362    hasher.update(&(mode_discriminant(mode)).to_be_bytes());
363    hasher.update(&chunk_index.to_be_bytes());
364    hasher.update(&cell_start.to_be_bytes());
365    hasher.update(&cell_end.to_be_bytes());
366    hex::encode(hasher.finalize())
367}
368
369fn mode_discriminant(mode: ChunkMode) -> u32 {
370    match mode {
371        ChunkMode::Cells => 0,
372        ChunkMode::Tokens => 1,
373        ChunkMode::Headings => 2,
374        ChunkMode::TableRows => 3,
375    }
376}