//! Source file: `three_dcf_core/bench.rs`.

1use std::cell::RefCell;
2use std::fs::{self, File};
3use std::io::Write;
4use std::path::{Path, PathBuf};
5use std::time::Instant;
6
7use serde::Serialize;
8use sysinfo::{Pid, System};
9use walkdir::WalkDir;
10
11use crate::document::Document;
12use crate::encoder::Encoder;
13use crate::error::Result;
14use crate::metrics::{cer, numeric_stats, wer};
15use crate::stats::{Stats, TokenizerKind};
16
/// Which benchmark cycles a run executes.
///
/// Serialized in `snake_case` (`"encode"`, `"decode"`, `"full"`).
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum BenchMode {
    /// Run only the encode cycle (once per configured budget).
    Encode,
    /// Run only the decode cycle over existing `.3dcf` files.
    Decode,
    /// Run every encode cycle, then the decode cycle.
    Full,
}
24
25#[derive(Debug, Clone, Copy)]
26enum BenchStage {
27    Encode,
28    Decode,
29}
30
31impl BenchStage {
32    fn as_mode(self) -> BenchMode {
33        match self {
34            BenchStage::Encode => BenchMode::Encode,
35            BenchStage::Decode => BenchMode::Decode,
36        }
37    }
38}
39
/// Settings for a [`BenchRunner`].
#[derive(Debug, Clone)]
pub struct BenchConfig {
    /// Which cycles to run.
    pub mode: BenchMode,
    /// Directory walked recursively for input files; document names in result
    /// rows are reported relative to it.
    pub root: PathBuf,
    /// Optional directory of reference ("gold") text files. When `None`, no
    /// accuracy metrics or page rows are produced.
    pub gold_root: Option<PathBuf>,
    /// Optional file to which each result row is appended as one JSON line.
    pub output: Option<PathBuf>,
    /// Encoder preset name; also used as the prefix of every run id.
    pub preset: String,
    /// Tokenizer used for token counting.
    pub tokenizer: TokenizerKind,
    /// Budgets to benchmark, one encode cycle per entry. An empty list is
    /// treated as a single cycle with no budget.
    pub budgets: Vec<Option<usize>>,
}
50
/// One document-level benchmark row.
#[derive(Debug, Clone, Serialize)]
pub struct BenchResult {
    /// Row discriminator; document rows are emitted with `"doc"`.
    pub row_type: &'static str,
    /// Identifier of the run: the preset name plus a budget, `auto`, or `decode` suffix.
    pub run_id: String,
    /// Stage that produced this row (`Encode` or `Decode`).
    pub mode: BenchMode,
    /// Document path relative to the benchmark root.
    pub doc: String,
    /// Encoder preset name.
    pub preset: String,
    /// Wall-clock milliseconds spent encoding; decode-stage rows carry `0`.
    pub encode_ms: u128,
    /// Wall-clock milliseconds spent decoding the document to text.
    pub decode_ms: u128,
    /// `cer` metric of the decoded text against the gold text, when gold data exists.
    pub cer: Option<f64>,
    /// `wer` metric of the decoded text against the gold text, when gold data exists.
    pub wer: Option<f64>,
    /// `f1` field of the numeric comparison against the gold text, when gold data exists.
    pub numguard_f1: Option<f64>,
    /// `units_ok` field of the numeric comparison against the gold text, when gold data exists.
    pub units_ok: Option<f64>,
    /// Raw token count reported by `Stats`.
    pub tokens_raw: usize,
    /// Encoded (3DCF) token count reported by `Stats`.
    pub tokens_3dcf: usize,
    /// Savings ratio reported by `Stats`.
    pub savings_ratio: f64,
    /// Total cells divided by total pages (`0.0` for a document with no pages).
    pub avg_cells_kept_per_page: f64,
    /// Number of pages in the document.
    pub pages: usize,
    /// Budget the encoder was built with, if any.
    pub budget: Option<usize>,
    /// Number of entries returned by the document's `numguard_mismatches()`.
    pub numguard_mismatches: usize,
    /// Pages per second for encoding; `0.0` when `encode_ms` is zero.
    pub encode_pages_per_s: f64,
    /// Pages per second for decoding; `0.0` when `decode_ms` is zero.
    pub decode_pages_per_s: f64,
    /// Process memory sample taken after this document was measured.
    pub mem_peak_mb: f64,
}
75
/// One page-level benchmark row; produced only for pages that have gold text.
#[derive(Debug, Clone, Serialize)]
pub struct BenchPageRow {
    /// Row discriminator; page rows are emitted with `"page"`.
    pub row_type: &'static str,
    /// Identifier of the run this page belongs to.
    pub run_id: String,
    /// Document path relative to the benchmark root.
    pub doc: String,
    /// Encoder preset name.
    pub preset: String,
    /// Zero-based page index.
    pub page_idx: u32,
    /// `cer` metric of the decoded page against the gold page text.
    pub cer_page: f64,
    /// `1.0 - cer_page`, clamped to `[0.0, 1.0]`.
    pub precision_page: f64,
    /// Token count of the gold page text.
    pub tokens_gold_page: usize,
    /// Token count of the decoded page text.
    pub tokens_3dcf_page: usize,
    /// Gold tokens divided by decoded tokens; `0.0` when the decoded page has no tokens.
    pub compression_ratio: f64,
    /// Budget the encoder was built with, if any.
    pub budget: Option<usize>,
}
90
/// Corpus-level summary returned by [`BenchRunner::run`].
#[derive(Debug, Clone, Serialize)]
pub struct CorpusMetrics {
    /// Every document row collected during the run.
    pub results: Vec<BenchResult>,
    /// Mean of `savings_ratio` over all rows (`0.0` when there are none).
    pub mean_savings: f64,
    /// Median of `savings_ratio` as computed by `run` (`0.0` when there are no rows).
    pub median_savings: f64,
    /// 50th percentile of `encode_ms`.
    pub encode_p50_ms: f64,
    /// 95th percentile of `encode_ms`.
    pub encode_p95_ms: f64,
    /// 50th percentile of `decode_ms`.
    pub decode_p50_ms: f64,
    /// 95th percentile of `decode_ms`.
    pub decode_p95_ms: f64,
    /// Mean of `encode_pages_per_s`.
    pub mean_encode_pages_per_s: f64,
    /// Mean of `decode_pages_per_s`.
    pub mean_decode_pages_per_s: f64,
    /// Largest process memory sample observed during the run.
    pub max_mem_mb: f64,
}
104
/// Runs encode/decode benchmarks over a directory tree and collects metrics.
pub struct BenchRunner {
    /// Run settings.
    config: BenchConfig,
    /// BPE tokenizer built from `config.tokenizer`, used for token counts.
    tokenizer: tiktoken_rs::CoreBPE,
    /// Process-information handle; wrapped in `RefCell` because it is refreshed
    /// from `&self` methods.
    sys: RefCell<System>,
    /// Pid of the current process, used for the memory samples.
    pid: Pid,
    /// Largest memory sample seen since the start of the current `run`.
    mem_peak_mb: RefCell<f64>,
}
112
113impl BenchRunner {
114    pub fn new(config: BenchConfig) -> Result<Self> {
115        let tokenizer = config.tokenizer.build()?;
116        let pid =
117            sysinfo::get_current_pid().map_err(|e| crate::error::DcfError::Other(e.to_string()))?;
118        let mut sys = System::new();
119        sys.refresh_process(pid);
120        Ok(Self {
121            config,
122            tokenizer,
123            sys: RefCell::new(sys),
124            pid,
125            mem_peak_mb: RefCell::new(0.0),
126        })
127    }
128
129    pub fn run(&self) -> Result<CorpusMetrics> {
130        let budgets = if self.config.budgets.is_empty() {
131            vec![None]
132        } else {
133            self.config.budgets.clone()
134        };
135
136        let mut doc_rows = Vec::new();
137        *self.mem_peak_mb.borrow_mut() = 0.0;
138        match self.config.mode {
139            BenchMode::Encode => {
140                for budget in budgets {
141                    doc_rows.extend(self.run_encode_cycle(budget)?);
142                }
143            }
144            BenchMode::Decode => {
145                doc_rows.extend(self.run_decode_cycle()?);
146            }
147            BenchMode::Full => {
148                for budget in budgets {
149                    doc_rows.extend(self.run_encode_cycle(budget)?);
150                }
151                doc_rows.extend(self.run_decode_cycle()?);
152            }
153        }
154
155        let mean = if doc_rows.is_empty() {
156            0.0
157        } else {
158            doc_rows.iter().map(|r| r.savings_ratio).sum::<f64>() / doc_rows.len() as f64
159        };
160        let mut ordered = doc_rows.iter().map(|r| r.savings_ratio).collect::<Vec<_>>();
161        ordered.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
162        let median = if ordered.is_empty() {
163            0.0
164        } else {
165            let mid = ordered.len() / 2;
166            ordered[mid]
167        };
168        let encode_ms_vals = doc_rows
169            .iter()
170            .map(|r| r.encode_ms as f64)
171            .collect::<Vec<_>>();
172        let decode_ms_vals = doc_rows
173            .iter()
174            .map(|r| r.decode_ms as f64)
175            .collect::<Vec<_>>();
176        let encode_p50 = percentile(&encode_ms_vals, 0.5);
177        let encode_p95 = percentile(&encode_ms_vals, 0.95);
178        let decode_p50 = percentile(&decode_ms_vals, 0.5);
179        let decode_p95 = percentile(&decode_ms_vals, 0.95);
180        let mean_encode_pages = if doc_rows.is_empty() {
181            0.0
182        } else {
183            doc_rows.iter().map(|r| r.encode_pages_per_s).sum::<f64>() / doc_rows.len() as f64
184        };
185        let mean_decode_pages = if doc_rows.is_empty() {
186            0.0
187        } else {
188            doc_rows.iter().map(|r| r.decode_pages_per_s).sum::<f64>() / doc_rows.len() as f64
189        };
190        Ok(CorpusMetrics {
191            results: doc_rows,
192            mean_savings: mean,
193            median_savings: median,
194            encode_p50_ms: encode_p50,
195            encode_p95_ms: encode_p95,
196            decode_p50_ms: decode_p50,
197            decode_p95_ms: decode_p95,
198            mean_encode_pages_per_s: mean_encode_pages,
199            mean_decode_pages_per_s: mean_decode_pages,
200            max_mem_mb: *self.mem_peak_mb.borrow(),
201        })
202    }
203
204    fn run_encode_cycle(&self, budget: Option<usize>) -> Result<Vec<BenchResult>> {
205        let encoder = self.build_encoder(budget)?;
206        let mut rows = Vec::new();
207        for entry in WalkDir::new(&self.config.root)
208            .into_iter()
209            .filter_map(|e| e.ok())
210            .filter(|e| e.file_type().is_file())
211        {
212            let path = entry.path();
213            if !is_supported_source(path) {
214                continue;
215            }
216            let (row, pages) = self.process_encode_doc(&encoder, path, budget)?;
217            self.append_row(&row)?;
218            for page in pages {
219                self.append_page_row(&page)?;
220            }
221            rows.push(row);
222        }
223        Ok(rows)
224    }
225
226    fn run_decode_cycle(&self) -> Result<Vec<BenchResult>> {
227        let mut rows = Vec::new();
228        for entry in WalkDir::new(&self.config.root)
229            .into_iter()
230            .filter_map(|e| e.ok())
231            .filter(|e| e.file_type().is_file())
232        {
233            let path = entry.path();
234            if path.extension().and_then(|e| e.to_str()) != Some("3dcf") {
235                continue;
236            }
237            let (row, pages) = self.process_decode_doc(path)?;
238            self.append_row(&row)?;
239            for page in pages {
240                self.append_page_row(&page)?;
241            }
242            rows.push(row);
243        }
244        Ok(rows)
245    }
246
247    fn process_encode_doc(
248        &self,
249        encoder: &Encoder,
250        path: &Path,
251        budget: Option<usize>,
252    ) -> Result<(BenchResult, Vec<BenchPageRow>)> {
253        let encode_start = Instant::now();
254        let (doc, _) = encoder.encode_path(path)?;
255        let encode_ms = encode_start.elapsed().as_millis();
256        self.measure_doc(path, &doc, encode_ms, BenchStage::Encode, budget)
257    }
258
259    fn process_decode_doc(&self, path: &Path) -> Result<(BenchResult, Vec<BenchPageRow>)> {
260        let load_start = Instant::now();
261        let doc = Document::load_bin(path)?;
262        let _load_ms = load_start.elapsed().as_millis();
263        let (row, pages) = self.measure_doc(path, &doc, 0, BenchStage::Decode, None)?;
264        Ok((row, pages))
265    }
266
267    fn measure_doc(
268        &self,
269        path: &Path,
270        doc: &Document,
271        encode_ms: u128,
272        stage: BenchStage,
273        budget: Option<usize>,
274    ) -> Result<(BenchResult, Vec<BenchPageRow>)> {
275        let decode_start = Instant::now();
276        let decoded = doc.decode_to_text();
277        let decode_ms = decode_start.elapsed().as_millis();
278        let stats = Stats::measure_with_bpe(doc, &self.tokenizer)?;
279        let rel = self.relative_path(path);
280        let run_id = self.run_id(stage, budget);
281        let gold = self.load_gold(&rel, doc.total_pages())?;
282
283        let (cer_doc, wer_doc, num_stats) = if let Some(gold_doc) = &gold {
284            let gold_text = gold_doc
285                .doc
286                .clone()
287                .unwrap_or_else(|| gold_doc.joined_pages());
288            (
289                Some(cer(&decoded, &gold_text)),
290                Some(wer(&decoded, &gold_text)),
291                Some(numeric_stats(&decoded, &gold_text)),
292            )
293        } else {
294            (None, None, None)
295        };
296
297        let avg_cells = if doc.total_pages() == 0 {
298            0.0
299        } else {
300            doc.total_cells() as f64 / doc.total_pages() as f64
301        };
302
303        let numguard_alerts = doc.numguard_mismatches();
304        let mem_mb = self.observe_memory_mb();
305
306        let pages_f = doc.total_pages().max(1) as f64;
307
308        let row = BenchResult {
309            row_type: "doc",
310            run_id: run_id.clone(),
311            mode: stage.as_mode(),
312            doc: rel.clone(),
313            preset: self.config.preset.clone(),
314            encode_ms,
315            decode_ms,
316            cer: cer_doc,
317            wer: wer_doc,
318            numguard_f1: num_stats.map(|n| n.f1),
319            units_ok: num_stats.map(|n| n.units_ok),
320            tokens_raw: stats.tokens_raw,
321            tokens_3dcf: stats.tokens_3dcf,
322            savings_ratio: stats.savings_ratio as f64,
323            avg_cells_kept_per_page: avg_cells,
324            pages: doc.total_pages(),
325            budget,
326            numguard_mismatches: numguard_alerts.len(),
327            encode_pages_per_s: if encode_ms == 0 {
328                0.0
329            } else {
330                pages_f / (encode_ms as f64 / 1000.0)
331            },
332            decode_pages_per_s: if decode_ms == 0 {
333                0.0
334            } else {
335                pages_f / (decode_ms as f64 / 1000.0)
336            },
337            mem_peak_mb: mem_mb,
338        };
339
340        let page_rows = match gold {
341            Some(gold_doc) => self.page_metrics(&run_id, &rel, doc, &gold_doc, budget)?,
342            None => Vec::new(),
343        };
344
345        Ok((row, page_rows))
346    }
347
348    fn page_metrics(
349        &self,
350        run_id: &str,
351        rel: &str,
352        doc: &Document,
353        gold: &GoldDoc,
354        budget: Option<usize>,
355    ) -> Result<Vec<BenchPageRow>> {
356        let mut rows = Vec::new();
357        for (idx, gold_page) in gold.pages.iter().enumerate() {
358            let gold_text = match gold_page {
359                Some(text) => text,
360                None => continue,
361            };
362            let pred = doc.decode_page_to_text(idx as u32);
363            let cer_page = cer(&pred, gold_text);
364            let precision = (1.0 - cer_page).clamp(0.0, 1.0);
365            let tokens_gold = self.tokenizer.encode_with_special_tokens(gold_text).len();
366            let tokens_pred = self
367                .tokenizer
368                .encode_with_special_tokens(pred.as_str())
369                .len();
370            let compression = if tokens_pred == 0 {
371                0.0
372            } else {
373                tokens_gold as f64 / tokens_pred as f64
374            };
375            rows.push(BenchPageRow {
376                row_type: "page",
377                run_id: run_id.to_string(),
378                doc: rel.to_string(),
379                preset: self.config.preset.clone(),
380                page_idx: idx as u32,
381                cer_page,
382                precision_page: precision,
383                tokens_gold_page: tokens_gold,
384                tokens_3dcf_page: tokens_pred,
385                compression_ratio: compression,
386                budget,
387            });
388        }
389        Ok(rows)
390    }
391
392    fn load_gold(&self, rel: &str, page_count: usize) -> Result<Option<GoldDoc>> {
393        let root = match &self.config.gold_root {
394            Some(path) => path,
395            None => return Ok(None),
396        };
397        let rel_path = Path::new(rel);
398        let mut doc_path = root.join(rel_path);
399        doc_path.set_extension("txt");
400        let doc_text = fs::read_to_string(&doc_path).ok();
401        let mut base = doc_path.clone();
402        base.set_extension("");
403        let mut pages = Vec::with_capacity(page_count);
404        for idx in 0..page_count {
405            let page_path = base.join(format!("page_{idx:04}.txt"));
406            pages.push(fs::read_to_string(&page_path).ok());
407        }
408        if doc_text.is_none() && pages.iter().all(|p| p.is_none()) {
409            return Ok(None);
410        }
411        Ok(Some(GoldDoc {
412            doc: doc_text,
413            pages,
414        }))
415    }
416
417    fn append_row(&self, row: &BenchResult) -> Result<()> {
418        if let Some(out) = &self.config.output {
419            append_json_line(out, row)?;
420        }
421        Ok(())
422    }
423
424    fn append_page_row(&self, row: &BenchPageRow) -> Result<()> {
425        if let Some(out) = &self.config.output {
426            append_json_line(out, row)?;
427        }
428        Ok(())
429    }
430
431    fn relative_path(&self, path: &Path) -> String {
432        path.strip_prefix(&self.config.root)
433            .unwrap_or(path)
434            .to_string_lossy()
435            .to_string()
436    }
437
438    fn run_id(&self, stage: BenchStage, budget: Option<usize>) -> String {
439        match stage {
440            BenchStage::Encode => match budget {
441                Some(b) => format!("{}-{}", self.config.preset, b),
442                None => format!("{}-auto", self.config.preset),
443            },
444            BenchStage::Decode => format!("{}-decode", self.config.preset),
445        }
446    }
447
448    fn build_encoder(&self, budget: Option<usize>) -> Result<Encoder> {
449        let mut builder = Encoder::builder(&self.config.preset)?;
450        if let Some(b) = budget {
451            builder = builder.budget(Some(b));
452        }
453        Ok(builder.build())
454    }
455
456    fn observe_memory_mb(&self) -> f64 {
457        let mem = self.current_memory_mb();
458        let mut peak = self.mem_peak_mb.borrow_mut();
459        if mem > *peak {
460            *peak = mem;
461        }
462        mem
463    }
464
465    fn current_memory_mb(&self) -> f64 {
466        let mut sys = self.sys.borrow_mut();
467        sys.refresh_process(self.pid);
468        if let Some(proc) = sys.process(self.pid) {
469            proc.memory() as f64 / 1024.0
470        } else {
471            0.0
472        }
473    }
474}
/// Reference text for one document.
struct GoldDoc {
    /// Whole-document gold text, when a document-level file was found.
    doc: Option<String>,
    /// Per-page gold text, indexed by page; `None` where no page file was found.
    pages: Vec<Option<String>>,
}

impl GoldDoc {
    /// Concatenates the available page texts in page order, separated by a
    /// single newline. Missing pages contribute nothing, not even a separator.
    fn joined_pages(&self) -> String {
        let present: Vec<&str> = self
            .pages
            .iter()
            .flatten()
            .map(String::as_str)
            .collect();
        present.join("\n")
    }
}
490
/// Reports whether `path` should be fed to the encoder.
///
/// A file is accepted when it has no extension (or one that is not valid
/// UTF-8), or when its lowercased extension is one of the known source types.
fn is_supported_source(path: &Path) -> bool {
    const SUPPORTED: [&str; 10] = [
        "pdf", "txt", "text", "md", "markdown", "html", "htm", "json", "tex", "bib",
    ];

    match path.extension().and_then(|ext| ext.to_str()) {
        None => true,
        Some(ext) => SUPPORTED.contains(&ext.to_lowercase().as_str()),
    }
}
504
505fn append_json_line<T: Serialize>(path: &Path, value: &T) -> Result<()> {
506    let mut file = File::options().append(true).create(true).open(path)?;
507    serde_json::to_writer(&mut file, value)?;
508    file.write_all(b"\n")?;
509    Ok(())
510}
511
/// Nearest-rank percentile of `values`.
///
/// `quantile` is clamped to `[0, 1]` and mapped to the index
/// `round((len - 1) * quantile)` of the ascending-sorted values. Returns
/// `0.0` for an empty slice. The input slice is not modified.
fn percentile(values: &[f64], quantile: f64) -> f64 {
    match values.len() {
        0 => 0.0,
        len => {
            let mut ordered: Vec<f64> = values.iter().copied().collect();
            ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
            let fraction = quantile.clamp(0.0, 1.0);
            let rank = ((len - 1) as f64 * fraction).round() as usize;
            ordered[rank]
        }
    }
}