// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Regex-based tool output compressor with self-evolution support.
//!
//! Rules are loaded from `SQLite` and stored as compiled `regex::Regex` values in a
//! `parking_lot::RwLock<Vec<CompiledRule>>`. Hit counts are tracked separately in a
//! `dashmap::DashMap<String, AtomicU64>` so that a rules-vec swap (on reload) cannot
//! lose unflushed counters.
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};

use dashmap::DashMap;
use parking_lot::RwLock;
use zeph_common::ToolName;

use super::{CompressionError, CompressionRuleStore, OutputCompressor, safe_compile};
21
/// A compiled compression rule ready for matching.
struct CompiledRule {
    /// Rule identifier from the store; used as the deterministic sort key and
    /// as the key into the compressor's hit-count map.
    id: String,
    /// When `Some`, this rule only applies to tools whose name matches the glob.
    glob: Option<globset::GlobMatcher>,
    /// Compiled regular expression applied to tool output.
    pattern: regex::Regex,
    /// Replacement text passed to `Regex::replace_all`; the regex crate
    /// expands `$`-style capture references inside it.
    replacement_template: String,
}
30
/// Regex-based compressor that applies operator- and LLM-evolved rules to tool output.
///
/// Rules are sorted deterministically by `id` to ensure stable application order.
/// Hit counts are stored in `hits` keyed by `rule.id`; the `rules` vec can be swapped
/// on reload without losing any unflushed counts.
///
/// # Invariants
///
/// - Rules are applied in `id`-ascending order (deterministic).
/// - `compress` returns the first successful match (earliest rule wins).
/// - A rule is skipped when `glob` is set and does not match `tool_name`.
/// - `regex::Regex::replace_all` guarantees linear time (no catastrophic backtracking).
///   No `catch_unwind` is needed around `replace_all`.
pub struct RuleBasedCompressor {
    /// Compiled rules sorted by `id`; swapped wholesale on `reload`.
    rules: RwLock<Vec<CompiledRule>>,
    /// Pending (unflushed) hit counters keyed by rule id.
    hits: DashMap<String, AtomicU64>,
    /// Backing store used for rule (re)loading and hit-count persistence.
    store: Arc<CompressionRuleStore>,
    /// Output line budget from config; not consulted by the logic in this
    /// file — presumably enforced by a caller. NOTE(review): confirm usage.
    max_output_lines: usize,
    /// DoS-safe regex compilation timeout forwarded to `safe_compile`.
    regex_timeout_ms: u64,
}
51
52impl std::fmt::Debug for RuleBasedCompressor {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("RuleBasedCompressor")
55            .field("rules_count", &self.rules.read().len())
56            .field("max_output_lines", &self.max_output_lines)
57            .field("regex_timeout_ms", &self.regex_timeout_ms)
58            .finish_non_exhaustive()
59    }
60}
61
62impl RuleBasedCompressor {
63    /// Load all active rules from the store and compile them.
64    ///
65    /// Rules that fail compilation are skipped and logged as warnings.
66    ///
67    /// `regex_timeout_ms` controls the DoS-safe regex compilation timeout passed to
68    /// [`super::safe_compile`]. Sourced from `[tools.compression] regex_compile_timeout_ms`.
69    ///
70    /// # Errors
71    ///
72    /// Returns [`CompressionError::Db`] if the store query fails.
73    pub async fn load(
74        store: Arc<CompressionRuleStore>,
75        max_output_lines: usize,
76        regex_timeout_ms: u64,
77    ) -> Result<Self, CompressionError> {
78        let raw_rules = store.list_active().await?;
79        let mut compiled = Vec::with_capacity(raw_rules.len());
80        let hits = DashMap::new();
81
82        for rule in raw_rules {
83            let glob = if let Some(ref g) = rule.tool_glob {
84                match globset::Glob::new(g) {
85                    Ok(glob) => Some(glob.compile_matcher()),
86                    Err(e) => {
87                        tracing::warn!(rule_id = %rule.id, pattern = %g, error = %e, "rule: invalid glob, skipping");
88                        continue;
89                    }
90                }
91            } else {
92                None
93            };
94
95            match super::safe_compile(&rule.pattern, regex_timeout_ms).await {
96                Ok(re) => {
97                    hits.insert(rule.id.clone(), AtomicU64::new(0));
98                    compiled.push(CompiledRule {
99                        id: rule.id,
100                        glob,
101                        pattern: re,
102                        replacement_template: rule.replacement_template,
103                    });
104                }
105                Err(e) => {
106                    tracing::warn!(rule_id = %rule.id, error = %e, "rule: compile failed, skipping");
107                }
108            }
109        }
110
111        compiled.sort_unstable_by(|a, b| a.id.cmp(&b.id));
112
113        Ok(Self {
114            rules: RwLock::new(compiled),
115            hits,
116            store,
117            max_output_lines,
118            regex_timeout_ms,
119        })
120    }
121
122    /// Reload rules from the store, preserving hit counts for still-present rules
123    /// and flushing counts for rules that no longer exist.
124    ///
125    /// # Errors
126    ///
127    /// Returns [`CompressionError::Db`] if the store query fails.
128    pub async fn reload(&self) -> Result<(), CompressionError> {
129        let raw_rules = self.store.list_active().await?;
130        let mut compiled = Vec::with_capacity(raw_rules.len());
131
132        // Flush and remove hits for rules that are no longer in the store.
133        let active_ids: std::collections::HashSet<&str> =
134            raw_rules.iter().map(|r| r.id.as_str()).collect();
135        let stale_ids: Vec<String> = self
136            .hits
137            .iter()
138            .filter(|e| !active_ids.contains(e.key().as_str()))
139            .map(|e| e.key().clone())
140            .collect();
141        for id in stale_ids {
142            self.hits.remove(&id);
143        }
144
145        for rule in raw_rules {
146            let glob = if let Some(ref g) = rule.tool_glob {
147                match globset::Glob::new(g) {
148                    Ok(glob) => Some(glob.compile_matcher()),
149                    Err(e) => {
150                        tracing::warn!(rule_id = %rule.id, error = %e, "reload: invalid glob");
151                        continue;
152                    }
153                }
154            } else {
155                None
156            };
157
158            match safe_compile(&rule.pattern, self.regex_timeout_ms).await {
159                Ok(re) => {
160                    self.hits
161                        .entry(rule.id.clone())
162                        .or_insert_with(|| AtomicU64::new(0));
163                    compiled.push(CompiledRule {
164                        id: rule.id,
165                        glob,
166                        pattern: re,
167                        replacement_template: rule.replacement_template,
168                    });
169                }
170                Err(e) => {
171                    tracing::warn!(rule_id = %rule.id, error = %e, "reload: compile failed");
172                }
173            }
174        }
175
176        compiled.sort_unstable_by(|a, b| a.id.cmp(&b.id));
177        *self.rules.write() = compiled;
178        Ok(())
179    }
180
181    /// Drain pending hit counts into a batch and write them to the store.
182    ///
183    /// Called during the `maybe_autodream` maintenance pass. Resets all counters
184    /// to zero after flushing.
185    ///
186    /// # Errors
187    ///
188    /// Returns a database error if the batch write fails.
189    pub async fn flush_hit_counts(&self) -> Result<(), CompressionError> {
190        let batch: Vec<(String, u64)> = self
191            .hits
192            .iter()
193            .filter_map(|e| {
194                let delta = e.value().swap(0, Ordering::Relaxed);
195                if delta > 0 {
196                    Some((e.key().clone(), delta))
197                } else {
198                    None
199                }
200            })
201            .collect();
202
203        if batch.is_empty() {
204            return Ok(());
205        }
206
207        self.store.increment_hits(&batch).await?;
208        Ok(())
209    }
210}
211
212impl OutputCompressor for RuleBasedCompressor {
213    fn compress<'a>(
214        &'a self,
215        tool_name: &'a ToolName,
216        output: &'a str,
217    ) -> Pin<Box<dyn Future<Output = Result<Option<String>, CompressionError>> + Send + 'a>> {
218        Box::pin(async move {
219            // Drop span guard before first await; EnteredSpan is not Send.
220            drop(
221                tracing::info_span!("tools.compression.compress", tool = %tool_name.as_str())
222                    .entered(),
223            );
224            let rules = self.rules.read();
225            for rule in rules.iter() {
226                if rule
227                    .glob
228                    .as_ref()
229                    .is_some_and(|g| !g.is_match(tool_name.as_str()))
230                {
231                    continue;
232                }
233                if rule.pattern.is_match(output) {
234                    let compressed = rule
235                        .pattern
236                        .replace_all(output, rule.replacement_template.as_str())
237                        .into_owned();
238
239                    if compressed.len() < output.len() {
240                        if let Some(entry) = self.hits.get(&rule.id) {
241                            entry.fetch_add(1, Ordering::Relaxed);
242                        }
243                        tracing::debug!(
244                            rule_id = %rule.id,
245                            tool = %tool_name.as_str(),
246                            original_bytes = output.len(),
247                            compressed_bytes = compressed.len(),
248                            "compression applied"
249                        );
250                        return Ok(Some(compressed));
251                    }
252                }
253            }
254            Ok(None)
255        })
256    }
257
258    fn name(&self) -> &'static str {
259        "rule_based"
260    }
261}
262
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use zeph_common::ToolName;

    use super::*;
    use crate::compression::{CompressionRuleStore, OutputCompressor, store::CompressionRule};

    /// Build an in-memory `SQLite`-backed store seeded with one rule per
    /// `(pattern, replacement)` pair; ids are `rule-0`, `rule-1`, ...
    async fn make_store_with_rules(rules: &[(&str, &str)]) -> Arc<CompressionRuleStore> {
        let pool = sqlx::SqlitePool::connect(":memory:").await.unwrap();
        sqlx::query(
            "CREATE TABLE compression_rules (\
             id TEXT PRIMARY KEY, tool_glob TEXT, pattern TEXT NOT NULL, \
             replacement_template TEXT NOT NULL, hit_count INTEGER NOT NULL DEFAULT 0, \
             source TEXT NOT NULL DEFAULT 'operator', created_at TEXT NOT NULL, \
             UNIQUE(tool_glob, pattern))",
        )
        .execute(&pool)
        .await
        .unwrap();

        let store = Arc::new(CompressionRuleStore::new(Arc::new(pool)));
        for (i, &(pattern, replacement)) in rules.iter().enumerate() {
            let row = CompressionRule {
                id: format!("rule-{i}"),
                tool_glob: None,
                pattern: pattern.to_owned(),
                replacement_template: replacement.to_owned(),
                hit_count: 0,
                source: "operator".to_owned(),
                created_at: "2026-01-01T00:00:00Z".to_owned(),
            };
            store.upsert(&row).await.unwrap();
        }
        store
    }

    #[tokio::test]
    async fn compress_returns_none_when_no_rule_matches() {
        let store = make_store_with_rules(&[(r"\d+", "N")]).await;
        let sut = RuleBasedCompressor::load(store, 2, 500).await.unwrap();
        let tool = ToolName::new("shell");
        // Input has no digits — pattern won't match.
        let input = "line\n".repeat(10);
        let outcome = sut.compress(&tool, &input).await.unwrap();
        assert!(outcome.is_none(), "no rule should match non-digit input");
    }

    #[tokio::test]
    async fn compress_applies_matching_rule() {
        // Replace every digit sequence with "N".
        let store = make_store_with_rules(&[(r"\d+", "N")]).await;
        let sut = RuleBasedCompressor::load(store, 2, 500).await.unwrap();
        let tool = ToolName::new("shell");
        // 10 lines, each "12345\n" → replaced with "N\n".
        let input: String = "12345\n".repeat(10);
        let outcome = sut.compress(&tool, &input).await.unwrap();
        assert!(outcome.is_some(), "rule should have matched");
        let shrunk = outcome.unwrap();
        assert!(shrunk.len() < input.len(), "compressed must be shorter");
        assert!(shrunk.contains('N'));
    }

    #[tokio::test]
    async fn compress_returns_none_when_not_shorter() {
        // Replacement is longer than original — should not be returned.
        let long_replacement = "VERY_LONG_REPLACEMENT_THAT_IS_DEFINITELY_LONGER_THAN_ORIGINAL";
        let store = make_store_with_rules(&[(r"\d", long_replacement)]).await;
        let sut = RuleBasedCompressor::load(store, 2, 500).await.unwrap();
        let tool = ToolName::new("shell");
        let input = "1\n".repeat(10);
        let outcome = sut.compress(&tool, &input).await.unwrap();
        assert!(
            outcome.is_none(),
            "compression that doesn't reduce size should return None"
        );
    }
}