1use std::future::Future;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicU64, Ordering};
15
16use dashmap::DashMap;
17use parking_lot::RwLock;
18use zeph_common::ToolName;
19
20use super::{CompressionError, CompressionRuleStore, OutputCompressor, safe_compile};
21
/// A single compression rule with its matchers pre-compiled for fast reuse.
struct CompiledRule {
    /// Stable rule identifier; also the key into the hit-counter map.
    id: String,
    /// Restricts the rule to tools whose name matches this glob;
    /// `None` means the rule applies to every tool.
    glob: Option<globset::GlobMatcher>,
    /// Regex applied to tool output to find compressible spans.
    pattern: regex::Regex,
    /// Replacement text handed to `Regex::replace_all`; may reference
    /// capture groups from `pattern`.
    replacement_template: String,
}
30
/// Compresses tool output by applying regex rewrite rules loaded from a
/// [`CompressionRuleStore`], tracking per-rule hit counts.
pub struct RuleBasedCompressor {
    /// Compiled rules, sorted by id; replaced wholesale on `reload`.
    rules: RwLock<Vec<CompiledRule>>,
    /// Per-rule hit counters, flushed to the store by `flush_hit_counts`.
    hits: DashMap<String, AtomicU64>,
    /// Backing store the rules were loaded from.
    store: Arc<CompressionRuleStore>,
    // NOTE(review): not referenced by any logic visible in this file (only
    // echoed in the `Debug` impl) — confirm whether it is still needed.
    max_output_lines: usize,
    /// Timeout budget passed to `safe_compile` when compiling rule regexes.
    regex_timeout_ms: u64,
}
51
52impl std::fmt::Debug for RuleBasedCompressor {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("RuleBasedCompressor")
55 .field("rules_count", &self.rules.read().len())
56 .field("max_output_lines", &self.max_output_lines)
57 .field("regex_timeout_ms", &self.regex_timeout_ms)
58 .finish_non_exhaustive()
59 }
60}
61
62impl RuleBasedCompressor {
63 pub async fn load(
74 store: Arc<CompressionRuleStore>,
75 max_output_lines: usize,
76 regex_timeout_ms: u64,
77 ) -> Result<Self, CompressionError> {
78 let raw_rules = store.list_active().await?;
79 let mut compiled = Vec::with_capacity(raw_rules.len());
80 let hits = DashMap::new();
81
82 for rule in raw_rules {
83 let glob = if let Some(ref g) = rule.tool_glob {
84 match globset::Glob::new(g) {
85 Ok(glob) => Some(glob.compile_matcher()),
86 Err(e) => {
87 tracing::warn!(rule_id = %rule.id, pattern = %g, error = %e, "rule: invalid glob, skipping");
88 continue;
89 }
90 }
91 } else {
92 None
93 };
94
95 match super::safe_compile(&rule.pattern, regex_timeout_ms).await {
96 Ok(re) => {
97 hits.insert(rule.id.clone(), AtomicU64::new(0));
98 compiled.push(CompiledRule {
99 id: rule.id,
100 glob,
101 pattern: re,
102 replacement_template: rule.replacement_template,
103 });
104 }
105 Err(e) => {
106 tracing::warn!(rule_id = %rule.id, error = %e, "rule: compile failed, skipping");
107 }
108 }
109 }
110
111 compiled.sort_unstable_by(|a, b| a.id.cmp(&b.id));
112
113 Ok(Self {
114 rules: RwLock::new(compiled),
115 hits,
116 store,
117 max_output_lines,
118 regex_timeout_ms,
119 })
120 }
121
122 pub async fn reload(&self) -> Result<(), CompressionError> {
129 let raw_rules = self.store.list_active().await?;
130 let mut compiled = Vec::with_capacity(raw_rules.len());
131
132 let active_ids: std::collections::HashSet<&str> =
134 raw_rules.iter().map(|r| r.id.as_str()).collect();
135 let stale_ids: Vec<String> = self
136 .hits
137 .iter()
138 .filter(|e| !active_ids.contains(e.key().as_str()))
139 .map(|e| e.key().clone())
140 .collect();
141 for id in stale_ids {
142 self.hits.remove(&id);
143 }
144
145 for rule in raw_rules {
146 let glob = if let Some(ref g) = rule.tool_glob {
147 match globset::Glob::new(g) {
148 Ok(glob) => Some(glob.compile_matcher()),
149 Err(e) => {
150 tracing::warn!(rule_id = %rule.id, error = %e, "reload: invalid glob");
151 continue;
152 }
153 }
154 } else {
155 None
156 };
157
158 match safe_compile(&rule.pattern, self.regex_timeout_ms).await {
159 Ok(re) => {
160 self.hits
161 .entry(rule.id.clone())
162 .or_insert_with(|| AtomicU64::new(0));
163 compiled.push(CompiledRule {
164 id: rule.id,
165 glob,
166 pattern: re,
167 replacement_template: rule.replacement_template,
168 });
169 }
170 Err(e) => {
171 tracing::warn!(rule_id = %rule.id, error = %e, "reload: compile failed");
172 }
173 }
174 }
175
176 compiled.sort_unstable_by(|a, b| a.id.cmp(&b.id));
177 *self.rules.write() = compiled;
178 Ok(())
179 }
180
181 pub async fn flush_hit_counts(&self) -> Result<(), CompressionError> {
190 let batch: Vec<(String, u64)> = self
191 .hits
192 .iter()
193 .filter_map(|e| {
194 let delta = e.value().swap(0, Ordering::Relaxed);
195 if delta > 0 {
196 Some((e.key().clone(), delta))
197 } else {
198 None
199 }
200 })
201 .collect();
202
203 if batch.is_empty() {
204 return Ok(());
205 }
206
207 self.store.increment_hits(&batch).await?;
208 Ok(())
209 }
210}
211
212impl OutputCompressor for RuleBasedCompressor {
213 fn compress<'a>(
214 &'a self,
215 tool_name: &'a ToolName,
216 output: &'a str,
217 ) -> Pin<Box<dyn Future<Output = Result<Option<String>, CompressionError>> + Send + 'a>> {
218 Box::pin(async move {
219 drop(
221 tracing::info_span!("tools.compression.compress", tool = %tool_name.as_str())
222 .entered(),
223 );
224 let rules = self.rules.read();
225 for rule in rules.iter() {
226 if rule
227 .glob
228 .as_ref()
229 .is_some_and(|g| !g.is_match(tool_name.as_str()))
230 {
231 continue;
232 }
233 if rule.pattern.is_match(output) {
234 let compressed = rule
235 .pattern
236 .replace_all(output, rule.replacement_template.as_str())
237 .into_owned();
238
239 if compressed.len() < output.len() {
240 if let Some(entry) = self.hits.get(&rule.id) {
241 entry.fetch_add(1, Ordering::Relaxed);
242 }
243 tracing::debug!(
244 rule_id = %rule.id,
245 tool = %tool_name.as_str(),
246 original_bytes = output.len(),
247 compressed_bytes = compressed.len(),
248 "compression applied"
249 );
250 return Ok(Some(compressed));
251 }
252 }
253 }
254 Ok(None)
255 })
256 }
257
258 fn name(&self) -> &'static str {
259 "rule_based"
260 }
261}
262
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use zeph_common::ToolName;

    use super::*;
    use crate::compression::{CompressionRuleStore, OutputCompressor, store::CompressionRule};

    /// Builds an in-memory SQLite-backed store seeded with the given
    /// `(pattern, replacement)` pairs, ids assigned as `rule-0`, `rule-1`, ...
    async fn make_store_with_rules(rules: &[(&str, &str)]) -> Arc<CompressionRuleStore> {
        let pool = sqlx::SqlitePool::connect(":memory:").await.unwrap();
        sqlx::query(
            "CREATE TABLE compression_rules (\
             id TEXT PRIMARY KEY, tool_glob TEXT, pattern TEXT NOT NULL, \
             replacement_template TEXT NOT NULL, hit_count INTEGER NOT NULL DEFAULT 0, \
             source TEXT NOT NULL DEFAULT 'operator', created_at TEXT NOT NULL, \
             UNIQUE(tool_glob, pattern))",
        )
        .execute(&pool)
        .await
        .unwrap();

        let store = Arc::new(CompressionRuleStore::new(Arc::new(pool)));
        for (i, (pat, repl)) in rules.iter().enumerate() {
            let row = CompressionRule {
                id: format!("rule-{i}"),
                tool_glob: None,
                pattern: (*pat).to_owned(),
                replacement_template: (*repl).to_owned(),
                hit_count: 0,
                source: "operator".to_owned(),
                created_at: "2026-01-01T00:00:00Z".to_owned(),
            };
            store.upsert(&row).await.unwrap();
        }
        store
    }

    #[test]
    fn _suppress_unused() {}

    #[tokio::test]
    async fn compress_returns_none_when_no_rule_matches() {
        let compressor =
            RuleBasedCompressor::load(make_store_with_rules(&[(r"\d+", "N")]).await, 2, 500)
                .await
                .unwrap();
        let payload = "line\n".repeat(10);
        let outcome = compressor
            .compress(&ToolName::new("shell"), &payload)
            .await
            .unwrap();
        assert!(outcome.is_none(), "no rule should match non-digit input");
    }

    #[tokio::test]
    async fn compress_applies_matching_rule() {
        let compressor =
            RuleBasedCompressor::load(make_store_with_rules(&[(r"\d+", "N")]).await, 2, 500)
                .await
                .unwrap();
        let digits: String = "12345\n".repeat(10);
        let outcome = compressor
            .compress(&ToolName::new("shell"), &digits)
            .await
            .unwrap();
        assert!(outcome.is_some(), "rule should have matched");
        let compressed = outcome.unwrap();
        assert!(compressed.len() < digits.len(), "compressed must be shorter");
        assert!(compressed.contains('N'));
    }

    #[tokio::test]
    async fn compress_returns_none_when_not_shorter() {
        let long_replacement = "VERY_LONG_REPLACEMENT_THAT_IS_DEFINITELY_LONGER_THAN_ORIGINAL";
        let compressor = RuleBasedCompressor::load(
            make_store_with_rules(&[(r"\d", long_replacement)]).await,
            2,
            500,
        )
        .await
        .unwrap();
        let ones = "1\n".repeat(10);
        let outcome = compressor
            .compress(&ToolName::new("shell"), &ones)
            .await
            .unwrap();
        assert!(
            outcome.is_none(),
            "compression that doesn't reduce size should return None"
        );
    }
}