Skip to main content

sqz_engine/
mdl_selector.rs

//! Minimum Description Length (MDL) stage selection.
//!
//! MDL principle (Rissanen 1978): the best compression is the one that
//! minimizes description_length(model) + description_length(data|model).
//!
//! Instead of running all 16 stages on every input, this module selects the
//! optimal subset of stages for each content type. Stages that add overhead
//! (headers, legends) without sufficient compression are skipped.
9
10use crate::error::Result;
11
/// A stage candidate with its estimated cost and benefit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StageCandidate {
    /// Stage name.
    pub name: String,
    /// Estimated overhead in tokens (headers, legends, markers).
    pub overhead_tokens: u32,
    /// Estimated savings in tokens.
    pub savings_tokens: u32,
    /// Whether this stage is applicable to the content type.
    pub applicable: bool,
}

impl StageCandidate {
    /// Net benefit in tokens: `savings_tokens - overhead_tokens`.
    ///
    /// A negative value means the stage costs more than it saves.
    ///
    /// The difference is computed in `i64` and clamped to the `i32` range.
    /// This prevents two failure modes of plain `as i32` casts:
    /// - savings above `i32::MAX` wrapping to a negative number;
    /// - the subtraction overflowing, which panics in debug builds.
    pub fn net_benefit(&self) -> i32 {
        let diff = i64::from(self.savings_tokens) - i64::from(self.overhead_tokens);
        // The clamp guarantees the value fits in i32, so this cast cannot truncate.
        diff.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
    }
}
31
/// Result of MDL stage selection.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MdlSelection {
    /// Stages to enable, in selection order.
    ///
    /// `select_stages` puts `ansi_strip` first, followed by the chosen stages
    /// in descending order of estimated net benefit.
    pub enabled_stages: Vec<String>,
    /// Stages skipped, either because they are not applicable to the content
    /// or because their overhead outweighs their estimated savings.
    pub skipped_stages: Vec<String>,
    /// Total estimated net savings (in tokens) of the enabled stages.
    pub estimated_net_savings: i32,
}
42
/// Content characteristics used for stage selection.
///
/// Usually produced by `profile_content`. `Default` yields an all-false,
/// all-zero profile, which is convenient with struct-update syntax
/// (`ContentProfile { is_json: true, ..Default::default() }`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ContentProfile {
    /// Is the content JSON?
    pub is_json: bool,
    /// Content length in bytes.
    pub length: usize,
    /// Estimated token count (`profile_content` assumes ~4 bytes per token).
    pub tokens: u32,
    /// Does the content have repeated lines?
    pub has_repetition: bool,
    /// Does the content look like a diff?
    pub is_diff: bool,
    /// Does the content look like prose?
    pub is_prose: bool,
    /// Does the content look like log output?
    pub is_log: bool,
    /// Number of null fields (JSON only).
    pub null_count: usize,
    /// Number of array elements (JSON only; a rough estimate).
    pub array_element_count: usize,
}
65
66/// Select the optimal subset of compression stages for the given content.
67///
68/// Uses the MDL principle: only enable stages where the expected savings
69/// exceed the overhead. This avoids the problem where stages like dict
70/// compression or RLE add headers that make small payloads larger.
71pub fn select_stages(profile: &ContentProfile) -> MdlSelection {
72    let mut candidates = build_candidates(profile);
73
74    // Sort by net benefit descending
75    candidates.sort_by(|a, b| b.net_benefit().cmp(&a.net_benefit()));
76
77    let mut enabled = Vec::new();
78    let mut skipped = Vec::new();
79    let mut total_net = 0i32;
80
81    for candidate in &candidates {
82        if !candidate.applicable {
83            skipped.push(candidate.name.clone());
84            continue;
85        }
86
87        if candidate.net_benefit() > 0 {
88            enabled.push(candidate.name.clone());
89            total_net += candidate.net_benefit();
90        } else {
91            skipped.push(candidate.name.clone());
92        }
93    }
94
95    // Always include ansi_strip (zero overhead, always beneficial)
96    if !enabled.contains(&"ansi_strip".to_string()) {
97        enabled.insert(0, "ansi_strip".to_string());
98    }
99
100    MdlSelection {
101        enabled_stages: enabled,
102        skipped_stages: skipped,
103        estimated_net_savings: total_net,
104    }
105}
106
107/// Build stage candidates with estimated costs and benefits.
108fn build_candidates(p: &ContentProfile) -> Vec<StageCandidate> {
109    vec![
110        StageCandidate {
111            name: "strip_nulls".to_string(),
112            overhead_tokens: 0,
113            savings_tokens: if p.is_json { (p.null_count as u32) * 3 } else { 0 },
114            applicable: p.is_json && p.null_count > 0,
115        },
116        StageCandidate {
117            name: "condense".to_string(),
118            overhead_tokens: 0,
119            savings_tokens: if p.has_repetition { p.tokens / 4 } else { 0 },
120            applicable: !p.is_json && p.has_repetition,
121        },
122        StageCandidate {
123            name: "git_diff_fold".to_string(),
124            overhead_tokens: 2, // "[N unchanged lines]" markers
125            savings_tokens: if p.is_diff { p.tokens / 3 } else { 0 },
126            applicable: p.is_diff,
127        },
128        StageCandidate {
129            name: "collapse_arrays".to_string(),
130            overhead_tokens: 5, // summary string or table header
131            savings_tokens: if p.is_json && p.array_element_count > 10 {
132                (p.array_element_count as u32 - 5) * 4
133            } else {
134                0
135            },
136            applicable: p.is_json && p.array_element_count > 10,
137        },
138        StageCandidate {
139            name: "flatten".to_string(),
140            overhead_tokens: 0,
141            savings_tokens: if p.is_json { p.tokens / 20 } else { 0 },
142            applicable: p.is_json && p.tokens > 50,
143        },
144        StageCandidate {
145            name: "truncate_strings".to_string(),
146            overhead_tokens: 1, // "..." per truncation
147            savings_tokens: if p.is_json { p.tokens / 10 } else { 0 },
148            applicable: p.is_json && p.tokens > 100,
149        },
150        StageCandidate {
151            name: "rle".to_string(),
152            overhead_tokens: 3, // "[×N]" markers
153            savings_tokens: if p.has_repetition && !p.is_json { p.tokens / 5 } else { 0 },
154            applicable: !p.is_json && p.has_repetition && p.length > 200,
155        },
156        StageCandidate {
157            name: "sliding_window_dedup".to_string(),
158            overhead_tokens: 2, // "[→LN]" markers
159            savings_tokens: if p.has_repetition && !p.is_json { p.tokens / 8 } else { 0 },
160            applicable: !p.is_json && p.has_repetition && p.length > 300,
161        },
162        StageCandidate {
163            name: "entropy_truncate".to_string(),
164            overhead_tokens: 3, // "[N segments omitted]"
165            savings_tokens: if p.is_prose && p.tokens > 100 { p.tokens / 6 } else { 0 },
166            applicable: !p.is_json && p.is_prose && p.length > 500,
167        },
168        StageCandidate {
169            name: "token_prune".to_string(),
170            overhead_tokens: 0,
171            savings_tokens: if p.is_prose { p.tokens / 10 } else { 0 },
172            applicable: !p.is_json && p.is_prose && p.length > 100,
173        },
174        StageCandidate {
175            name: "dict_compress".to_string(),
176            overhead_tokens: 15, // §dict§ header
177            savings_tokens: if p.is_json && p.tokens > 50 { p.tokens / 8 } else { 0 },
178            applicable: p.is_json && p.tokens > 120, // only worth it for larger JSON
179        },
180        StageCandidate {
181            name: "toon_encode".to_string(),
182            overhead_tokens: 2, // "TOON:" prefix
183            savings_tokens: if p.is_json { p.tokens / 5 } else { 0 },
184            applicable: p.is_json,
185        },
186        StageCandidate {
187            name: "textrank".to_string(),
188            overhead_tokens: 0,
189            savings_tokens: if p.is_prose && p.tokens > 200 { p.tokens / 3 } else { 0 },
190            applicable: p.is_prose && p.tokens > 200,
191        },
192    ]
193}
194
195/// Profile content to determine its characteristics.
196pub fn profile_content(text: &str) -> ContentProfile {
197    let is_json = text.trim().starts_with('{') || text.trim().starts_with('[');
198    let lines: Vec<&str> = text.lines().collect();
199    let tokens = (text.len() as u32 + 3) / 4;
200
201    // Check for repetition
202    let mut has_repetition = false;
203    if lines.len() > 2 {
204        let mut prev = "";
205        let mut run = 0;
206        for line in &lines {
207            if *line == prev {
208                run += 1;
209                if run >= 2 {
210                    has_repetition = true;
211                    break;
212                }
213            } else {
214                run = 0;
215            }
216            prev = line;
217        }
218    }
219
220    let is_diff = text.contains("\n+") && text.contains("\n-") && text.contains("@@");
221    let is_log = lines.iter().take(10).any(|l| {
222        l.contains("[INFO]") || l.contains("[ERROR]") || l.contains("[WARN]")
223    });
224
225    // Prose heuristic
226    let mut prose_lines = 0;
227    let mut code_lines = 0;
228    for line in lines.iter().take(20) {
229        let t = line.trim();
230        if t.is_empty() { continue; }
231        if t.ends_with('{') || t.ends_with(';') || t.contains("::") || t.contains("->") {
232            code_lines += 1;
233        } else {
234            prose_lines += 1;
235        }
236    }
237    let is_prose = prose_lines > code_lines && !is_json && !is_diff && !is_log;
238
239    // JSON-specific metrics
240    let null_count = if is_json { text.matches(":null").count() + text.matches(": null").count() } else { 0 };
241    let array_element_count = if is_json {
242        text.matches('{').count().saturating_sub(1) // rough estimate
243    } else {
244        0
245    };
246
247    ContentProfile {
248        is_json,
249        length: text.len(),
250        tokens,
251        has_repetition,
252        is_diff,
253        is_prose,
254        is_log,
255        null_count,
256        array_element_count,
257    }
258}
259
260// ── Tests ─────────────────────────────────────────────────────────────────
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Profile with every flag off and every count at zero; each test
    /// overrides only the fields it cares about via struct-update syntax.
    fn blank_profile() -> ContentProfile {
        ContentProfile {
            is_json: false,
            length: 0,
            tokens: 0,
            has_repetition: false,
            is_diff: false,
            is_prose: false,
            is_log: false,
            null_count: 0,
            array_element_count: 0,
        }
    }

    /// True when `stage` is among the stages `sel` enabled.
    fn enabled(sel: &MdlSelection, stage: &str) -> bool {
        sel.enabled_stages.iter().any(|s| s == stage)
    }

    /// Applicable candidate named "test" with the given cost and benefit.
    fn candidate(overhead_tokens: u32, savings_tokens: u32) -> StageCandidate {
        StageCandidate {
            name: "test".to_string(),
            overhead_tokens,
            savings_tokens,
            applicable: true,
        }
    }

    #[test]
    fn test_json_profile() {
        let profile = profile_content(r#"{"a":1,"b":null,"c":null}"#);
        assert!(profile.is_json);
        assert_eq!(profile.null_count, 2);
        assert!(!profile.is_prose);
    }

    #[test]
    fn test_prose_profile() {
        let profile = profile_content(
            "This is a normal sentence about something interesting and important.",
        );
        assert!(profile.is_prose);
        assert!(!profile.is_json);
    }

    #[test]
    fn test_diff_profile() {
        let diff = "@@ -1,5 +1,5 @@\n-old\n+new\n context\n";
        assert!(profile_content(diff).is_diff);
    }

    #[test]
    fn test_log_profile() {
        let log = "[INFO] Started\n[ERROR] Failed\n[WARN] Slow\n";
        assert!(profile_content(log).is_log);
    }

    #[test]
    fn test_select_stages_json() {
        let profile = ContentProfile {
            is_json: true,
            length: 500,
            tokens: 125,
            null_count: 5,
            ..blank_profile()
        };
        let sel = select_stages(&profile);
        assert!(enabled(&sel, "strip_nulls"));
        assert!(enabled(&sel, "toon_encode"));
        assert!(!enabled(&sel, "condense"));
    }

    #[test]
    fn test_select_stages_prose() {
        let profile = ContentProfile {
            length: 1000,
            tokens: 250,
            is_prose: true,
            ..blank_profile()
        };
        let sel = select_stages(&profile);
        assert!(enabled(&sel, "token_prune"));
        assert!(enabled(&sel, "textrank"));
        assert!(!enabled(&sel, "strip_nulls"));
    }

    #[test]
    fn test_select_stages_always_includes_ansi_strip() {
        let profile = ContentProfile {
            length: 10,
            tokens: 3,
            ..blank_profile()
        };
        assert!(enabled(&select_stages(&profile), "ansi_strip"));
    }

    #[test]
    fn test_select_stages_skips_overhead_for_small_json() {
        let profile = ContentProfile {
            is_json: true,
            length: 50,
            tokens: 12,
            ..blank_profile()
        };
        // dict_compress carries a 15-token header and only applies to JSON
        // above 120 tokens — a 12-token payload must not enable it.
        assert!(!enabled(&select_stages(&profile), "dict_compress"));
    }

    #[test]
    fn test_net_benefit_calculation() {
        assert_eq!(candidate(5, 20).net_benefit(), 15);
        assert_eq!(candidate(20, 5).net_benefit(), -15);
    }

    #[test]
    fn test_repetitive_log_profile() {
        let profile = profile_content("[INFO] ok\n[INFO] ok\n[INFO] ok\n[ERROR] fail\n");
        assert!(profile.is_log);
        assert!(profile.has_repetition);
    }
}