1use std::sync::Arc;
7
8use zeph_db::DbPool;
9
10#[derive(Debug, Clone, sqlx::FromRow)]
12pub struct CompressionRule {
13 pub id: String,
15 pub tool_glob: Option<String>,
17 pub pattern: String,
19 pub replacement_template: String,
21 pub hit_count: i64,
23 pub source: String,
25 pub created_at: String,
27}
28
29#[derive(Clone)]
35pub struct CompressionRuleStore {
36 pool: Arc<DbPool>,
37}
38
39impl CompressionRuleStore {
40 #[must_use]
42 pub fn new(pool: Arc<DbPool>) -> Self {
43 Self { pool }
44 }
45
46 pub async fn list_active(&self) -> Result<Vec<CompressionRule>, zeph_db::SqlxError> {
52 sqlx::query_as(zeph_db::sql!(
53 "SELECT id, tool_glob, pattern, replacement_template, hit_count, source, created_at \
54 FROM compression_rules ORDER BY hit_count ASC"
55 ))
56 .fetch_all(self.pool.as_ref())
57 .await
58 }
59
60 pub async fn upsert(&self, rule: &CompressionRule) -> Result<(), zeph_db::SqlxError> {
66 sqlx::query(zeph_db::sql!(
67 "INSERT INTO compression_rules \
68 (id, tool_glob, pattern, replacement_template, hit_count, source, created_at) \
69 VALUES (?, ?, ?, ?, ?, ?, ?) \
70 ON CONFLICT(tool_glob, pattern) DO UPDATE SET \
71 replacement_template = excluded.replacement_template, \
72 source = excluded.source"
73 ))
74 .bind(&rule.id)
75 .bind(&rule.tool_glob)
76 .bind(&rule.pattern)
77 .bind(&rule.replacement_template)
78 .bind(rule.hit_count)
79 .bind(&rule.source)
80 .bind(&rule.created_at)
81 .execute(self.pool.as_ref())
82 .await?;
83 Ok(())
84 }
85
86 pub async fn increment_hits(&self, batch: &[(String, u64)]) -> Result<(), zeph_db::SqlxError> {
96 for (id, delta) in batch {
97 sqlx::query(zeph_db::sql!(
98 "UPDATE compression_rules SET hit_count = hit_count + ? WHERE id = ?"
99 ))
100 .bind((*delta).cast_signed())
101 .bind(id.as_str())
102 .execute(self.pool.as_ref())
103 .await?;
104 }
105 Ok(())
106 }
107
108 pub async fn delete(&self, id: &str) -> Result<(), zeph_db::SqlxError> {
114 sqlx::query(zeph_db::sql!("DELETE FROM compression_rules WHERE id = ?"))
115 .bind(id)
116 .execute(self.pool.as_ref())
117 .await?;
118 Ok(())
119 }
120
121 pub async fn prune_lowest_hits(&self, max_rules: u32) -> Result<u64, zeph_db::SqlxError> {
129 let count: i64 =
130 sqlx::query_scalar(zeph_db::sql!("SELECT COUNT(*) FROM compression_rules"))
131 .fetch_one(self.pool.as_ref())
132 .await?;
133
134 if count <= i64::from(max_rules) {
135 return Ok(0);
136 }
137
138 let to_delete = count - i64::from(max_rules);
139 let result = sqlx::query(zeph_db::sql!(
140 "DELETE FROM compression_rules WHERE id IN \
141 (SELECT id FROM compression_rules ORDER BY hit_count ASC LIMIT ?)"
142 ))
143 .bind(to_delete)
144 .execute(self.pool.as_ref())
145 .await?;
146
147 Ok(result.rows_affected())
148 }
149}
150
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::{CompressionRule, CompressionRuleStore};

    /// Opens a fresh in-memory SQLite database, creates the `compression_rules`
    /// schema, and wraps it in a store. The raw pool is handed back as well.
    async fn make_store() -> (CompressionRuleStore, sqlx::SqlitePool) {
        let pool = sqlx::SqlitePool::connect(":memory:").await.unwrap();
        sqlx::query(
            "CREATE TABLE compression_rules (\
             id TEXT PRIMARY KEY, tool_glob TEXT, pattern TEXT NOT NULL, \
             replacement_template TEXT NOT NULL, hit_count INTEGER NOT NULL DEFAULT 0, \
             source TEXT NOT NULL DEFAULT 'operator', created_at TEXT NOT NULL, \
             UNIQUE(tool_glob, pattern))",
        )
        .execute(&pool)
        .await
        .unwrap();
        (CompressionRuleStore::new(Arc::new(pool.clone())), pool)
    }

    /// Builds a rule with a fixed `created_at` timestamp.
    fn rule(
        id: &str,
        tool_glob: Option<&str>,
        pattern: &str,
        replacement: &str,
        hits: i64,
        source: &str,
    ) -> CompressionRule {
        let created_at = String::from("2026-01-01T00:00:00Z");
        CompressionRule {
            id: String::from(id),
            tool_glob: tool_glob.map(String::from),
            pattern: String::from(pattern),
            replacement_template: String::from(replacement),
            hit_count: hits,
            source: String::from(source),
            created_at,
        }
    }

    /// Upserts every rule in order, panicking on the first failure.
    async fn seed(store: &CompressionRuleStore, rules: &[CompressionRule]) {
        for r in rules {
            store.upsert(r).await.unwrap();
        }
    }

    /// Three operator rules `r1..r3` with patterns `p1..p3` and hits `1..3`.
    fn three_ranked_rules() -> Vec<CompressionRule> {
        vec![
            rule("r1", None, "p1", "t1", 1, "operator"),
            rule("r2", None, "p2", "t2", 2, "operator"),
            rule("r3", None, "p3", "t3", 3, "operator"),
        ]
    }

    #[tokio::test]
    async fn list_active_empty() {
        let (store, _pool) = make_store().await;
        assert!(store.list_active().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn list_active_returns_ordered_by_hits_asc() {
        let (store, _pool) = make_store().await;
        seed(
            &store,
            &[
                rule("a", None, "pa", "ra", 10, "operator"),
                rule("b", None, "pb", "rb", 0, "operator"),
                rule("c", None, "pc", "rc", 5, "operator"),
            ],
        )
        .await;

        let hits: Vec<i64> = store
            .list_active()
            .await
            .unwrap()
            .iter()
            .map(|r| r.hit_count)
            .collect();
        assert_eq!(hits, [0, 5, 10]);
    }

    #[tokio::test]
    async fn upsert_inserts_new_rule() {
        let (store, _pool) = make_store().await;
        seed(&store, &[rule("r1", Some("shell"), "pat", "tmpl", 0, "operator")]).await;

        let listed = store.list_active().await.unwrap();
        assert_eq!(listed.len(), 1);
        let stored = &listed[0];
        assert_eq!(stored.id, "r1");
        assert_eq!(stored.tool_glob.as_deref(), Some("shell"));
        assert_eq!(stored.pattern, "pat");
        assert_eq!(stored.replacement_template, "tmpl");
        assert_eq!(stored.source, "operator");
    }

    #[tokio::test]
    async fn upsert_conflict_updates_template_and_source() {
        let (store, _pool) = make_store().await;
        // Same (tool_glob, pattern) under a different id: must update in place.
        seed(
            &store,
            &[
                rule("r1", Some("shell"), "pat", "old-tmpl", 5, "operator"),
                rule("r2", Some("shell"), "pat", "new-tmpl", 0, "llm-evolved"),
            ],
        )
        .await;

        let listed = store.list_active().await.unwrap();
        assert_eq!(listed.len(), 1);
        let stored = &listed[0];
        assert_eq!(stored.id, "r1");
        assert_eq!(stored.replacement_template, "new-tmpl");
        assert_eq!(stored.source, "llm-evolved");
        assert_eq!(stored.hit_count, 5);
    }

    #[tokio::test]
    async fn upsert_null_tool_glob_distinct() {
        let (store, _pool) = make_store().await;
        // NULL tool_glob values never collide under a UNIQUE constraint.
        seed(
            &store,
            &[
                rule("r1", None, "same-pat", "ra", 0, "operator"),
                rule("r2", None, "same-pat", "rb", 0, "operator"),
            ],
        )
        .await;

        assert_eq!(store.list_active().await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn upsert_preserves_hit_count_on_conflict() {
        let (store, _pool) = make_store().await;
        seed(
            &store,
            &[
                rule("r1", Some("shell"), "pat", "tmpl", 5, "operator"),
                rule("r2", Some("shell"), "pat", "tmpl2", 0, "operator"),
            ],
        )
        .await;

        let listed = store.list_active().await.unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(
            listed[0].hit_count, 5,
            "hit_count must not be reset by ON CONFLICT"
        );
    }

    #[tokio::test]
    async fn increment_hits_single() {
        let (store, _pool) = make_store().await;
        seed(&store, &[rule("r1", None, "pat", "tmpl", 0, "operator")]).await;

        store.increment_hits(&[("r1".to_owned(), 3)]).await.unwrap();

        let listed = store.list_active().await.unwrap();
        assert_eq!(listed[0].hit_count, 3);
    }

    #[tokio::test]
    async fn increment_hits_batch() {
        let (store, _pool) = make_store().await;
        seed(
            &store,
            &[
                rule("r1", None, "p1", "t1", 0, "operator"),
                rule("r2", None, "p2", "t2", 10, "operator"),
                rule("r3", None, "p3", "t3", 0, "operator"),
            ],
        )
        .await;

        let batch = [
            ("r1".to_owned(), 2),
            ("r2".to_owned(), 5),
            ("r3".to_owned(), 1),
        ];
        store.increment_hits(&batch).await.unwrap();

        let listed = store.list_active().await.unwrap();
        let hits_for = |id: &str| {
            listed
                .iter()
                .find(|r| r.id == id)
                .map(|r| r.hit_count)
                .unwrap()
        };
        assert_eq!(hits_for("r1"), 2);
        assert_eq!(hits_for("r2"), 15);
        assert_eq!(hits_for("r3"), 1);
    }

    #[tokio::test]
    async fn increment_hits_nonexistent_id() {
        let (store, _pool) = make_store().await;
        // An id with no matching row must not be an error.
        store
            .increment_hits(&[("ghost".to_owned(), 1)])
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn increment_hits_empty_batch() {
        let (store, _pool) = make_store().await;
        seed(&store, &[rule("r1", None, "pat", "tmpl", 7, "operator")]).await;

        store.increment_hits(&[]).await.unwrap();

        let listed = store.list_active().await.unwrap();
        assert_eq!(
            listed[0].hit_count, 7,
            "empty batch must not modify existing rules"
        );
    }

    #[tokio::test]
    async fn delete_removes_rule() {
        let (store, _pool) = make_store().await;
        seed(&store, &[rule("r1", None, "pat", "tmpl", 0, "operator")]).await;

        store.delete("r1").await.unwrap();

        assert!(store.list_active().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn delete_nonexistent_is_noop() {
        let (store, _pool) = make_store().await;
        store.delete("ghost").await.unwrap();
    }

    #[tokio::test]
    async fn prune_fast_path_no_deletion() {
        let (store, _pool) = make_store().await;
        seed(
            &store,
            &[
                rule("r1", None, "p1", "t1", 1, "operator"),
                rule("r2", None, "p2", "t2", 2, "operator"),
            ],
        )
        .await;

        assert_eq!(store.prune_lowest_hits(5).await.unwrap(), 0);
        assert_eq!(store.list_active().await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn prune_deletes_lowest_hit_rules() {
        let (store, _pool) = make_store().await;
        for (i, hits) in (1i64..=5).enumerate() {
            let id = format!("r{i}");
            let pattern = format!("p{i}");
            store
                .upsert(&rule(&id, None, &pattern, "t", hits, "operator"))
                .await
                .unwrap();
        }

        assert_eq!(store.prune_lowest_hits(3).await.unwrap(), 2);

        let survivors = store.list_active().await.unwrap();
        assert_eq!(survivors.len(), 3);
        assert!(survivors.iter().all(|r| r.hit_count >= 3));
    }

    #[tokio::test]
    async fn prune_exact_boundary() {
        let (store, _pool) = make_store().await;
        seed(&store, &three_ranked_rules()).await;

        // Exactly at the limit: nothing to prune.
        assert_eq!(store.prune_lowest_hits(3).await.unwrap(), 0);
        assert_eq!(store.list_active().await.unwrap().len(), 3);
    }

    #[tokio::test]
    async fn prune_max_rules_zero_deletes_all() {
        let (store, _pool) = make_store().await;
        seed(&store, &three_ranked_rules()).await;

        assert_eq!(store.prune_lowest_hits(0).await.unwrap(), 3);
        assert!(store.list_active().await.unwrap().is_empty());
    }
}