1#![allow(dead_code)]
22
23use std::sync::Arc;
24use std::time::{Duration, SystemTime, UNIX_EPOCH};
25
26use thiserror::Error;
27use tokio::sync::Mutex as AsyncMutex;
28use tracing::{debug, warn};
29use zeph_db::DbPool;
30
31use super::prediction::{Prediction, PredictionSource};
32use crate::agent::speculative::cache::{args_template, hash_args};
33
/// Standard-normal critical value used by [`wilson_lower_bound`]; 1.645 is the
/// one-sided 95% (two-sided 90%) quantile.
const Z: f64 = 1.645_f64;
36
/// Result of one tool invocation as reported to [`PatternStore::observe`].
///
/// Only `Success` increments a transition's `success_raw` counter.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ToolOutcome {
    /// The tool call succeeded.
    Success,
    /// The tool call failed.
    Failure,
}
43
44#[derive(Debug, Error)]
46pub enum PatternError {
47 #[error("database error: {0}")]
48 Db(#[from] zeph_db::sqlx::Error),
49 #[error("json error: {0}")]
50 Json(#[from] serde_json::Error),
51}
52
/// Debounce bookkeeping for a single prediction-cache key.
struct RefreshState {
    /// Monotonic timestamp recorded by the debounce logic; `None` until the
    /// key has been refreshed for the first time.
    last_refresh: Option<std::time::Instant>,
}
57
58pub struct PatternStore {
76 pool: DbPool,
77 half_life_days: f64,
78 refresh_debounce: Arc<AsyncMutex<std::collections::HashMap<String, RefreshState>>>,
79 min_observations: u32,
80}
81
82impl PatternStore {
83 #[must_use]
87 pub fn new(pool: DbPool, half_life_days: f64) -> Self {
88 Self {
89 pool,
90 half_life_days,
91 refresh_debounce: Arc::new(AsyncMutex::new(std::collections::HashMap::new())),
92 min_observations: 5,
93 }
94 }
95
96 #[must_use]
98 pub fn with_min_observations(mut self, n: u32) -> Self {
99 self.min_observations = n;
100 self
101 }
102
103 #[allow(clippy::too_many_arguments)]
114 pub async fn observe(
115 &self,
116 skill_name: &str,
117 skill_hash: &str,
118 prev_tool: Option<&str>,
119 next_tool: &str,
120 args_json: &str,
121 outcome: ToolOutcome,
122 latency_ms: u64,
123 ) -> Result<(), PatternError> {
124 let now = unix_now();
125 let half_life_secs = self.half_life_days * 86_400.0;
126 let success_delta = i64::from(outcome == ToolOutcome::Success);
127 let args: serde_json::Value = serde_json::from_str(args_json)?;
128 let args_obj = args.as_object().cloned().unwrap_or_default();
129 let args_fingerprint = {
130 let h = hash_args(&args_obj);
131 h.to_hex().to_string()
132 };
133 let tmpl = args_template(&args_obj);
134 #[allow(clippy::cast_possible_wrap)]
135 let latency_i64 = latency_ms as i64;
136
137 let existing = zeph_db::query_as::<_, (f64, i64, i64, i64)>(
141 r"
142 SELECT count_decayed, last_seen_at, count_raw, avg_latency_ms
143 FROM tool_pattern_transitions
144 WHERE skill_name = ? AND skill_hash = ?
145 AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
146 AND next_tool = ? AND args_fingerprint = ?
147 ",
148 )
149 .bind(skill_name)
150 .bind(skill_hash)
151 .bind(prev_tool)
152 .bind(prev_tool)
153 .bind(next_tool)
154 .bind(&args_fingerprint)
155 .fetch_optional(&self.pool)
156 .await?;
157
158 if let Some((old_decayed, last_seen_at, old_count_raw, old_avg_latency)) = existing {
159 #[allow(clippy::cast_precision_loss)]
160 let elapsed = (now - last_seen_at).max(0) as f64;
161 let new_decayed = old_decayed * 0.5f64.powf(elapsed / half_life_secs) + 1.0;
162 let new_count_raw = old_count_raw + 1;
163 #[allow(clippy::cast_precision_loss)]
164 let new_avg_latency = (old_avg_latency * old_count_raw + latency_i64) / new_count_raw;
165
166 zeph_db::query(
167 r"
168 UPDATE tool_pattern_transitions SET
169 count_decayed = ?,
170 count_raw = ?,
171 success_raw = success_raw + ?,
172 last_seen_at = ?,
173 avg_latency_ms = ?
174 WHERE skill_name = ? AND skill_hash = ?
175 AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
176 AND next_tool = ? AND args_fingerprint = ?
177 ",
178 )
179 .bind(new_decayed)
180 .bind(new_count_raw)
181 .bind(success_delta)
182 .bind(now)
183 .bind(new_avg_latency)
184 .bind(skill_name)
185 .bind(skill_hash)
186 .bind(prev_tool)
187 .bind(prev_tool)
188 .bind(next_tool)
189 .bind(&args_fingerprint)
190 .execute(&self.pool)
191 .await?;
192 } else {
193 zeph_db::query(
194 r"
195 INSERT INTO tool_pattern_transitions
196 (skill_name, skill_hash, prev_tool, next_tool, args_fingerprint,
197 args_template, count_raw, success_raw, count_decayed, last_seen_at, avg_latency_ms)
198 VALUES (?, ?, ?, ?, ?, ?, 1, ?, 1.0, ?, ?)
199 ",
200 )
201 .bind(skill_name)
202 .bind(skill_hash)
203 .bind(prev_tool)
204 .bind(next_tool)
205 .bind(&args_fingerprint)
206 .bind(&tmpl)
207 .bind(success_delta)
208 .bind(now)
209 .bind(latency_i64)
210 .execute(&self.pool)
211 .await?;
212 }
213
214 self.debounced_refresh(skill_name, skill_hash, prev_tool)
215 .await;
216 Ok(())
217 }
218
219 pub async fn predict(
228 &self,
229 skill_name: &str,
230 skill_hash: &str,
231 prev_tool: Option<&str>,
232 k: u8,
233 ) -> Result<Vec<Prediction>, PatternError> {
234 let rows = zeph_db::query_as::<_, (String, String, f64, f64, i64)>(
235 r"
236 SELECT next_tool, args_template, score, wilson_lower_bound, rank
237 FROM tool_pattern_predictions
238 WHERE skill_name = ? AND skill_hash = ?
239 AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
240 AND wilson_lower_bound >= 0.5
241 ORDER BY rank ASC
242 LIMIT ?
243 ",
244 )
245 .bind(skill_name)
246 .bind(skill_hash)
247 .bind(prev_tool)
248 .bind(prev_tool)
249 .bind(i64::from(k))
250 .fetch_all(&self.pool)
251 .await?;
252
253 let predictions = rows
254 .into_iter()
255 .enumerate()
256 .filter_map(|(i, (next_tool, args_template, score, _wilson, _rank))| {
257 let args: serde_json::Map<String, serde_json::Value> =
258 serde_json::from_str(&args_template).ok()?;
259 Some(Prediction {
260 tool_id: zeph_common::ToolName::new(next_tool),
261 args,
262 #[allow(clippy::cast_possible_truncation)]
263 confidence: score as f32,
264 source: PredictionSource::HistoryPattern {
265 skill: skill_name.to_owned(),
266 #[allow(clippy::cast_possible_truncation)]
267 rank: i as u8,
268 },
269 })
270 })
271 .collect();
272
273 Ok(predictions)
274 }
275
276 pub async fn refresh(
285 &self,
286 skill_name: &str,
287 skill_hash: &str,
288 prev_tool: Option<&str>,
289 ) -> Result<(), PatternError> {
290 let min_obs = self.min_observations;
291 let half_life_secs = self.half_life_days * 86_400.0;
292 let now = unix_now();
293
294 let rows = zeph_db::query_as::<_, (String, String, String, i64, i64, f64, i64)>(
297 r"
298 SELECT next_tool, args_fingerprint, args_template,
299 count_raw, success_raw, count_decayed, last_seen_at
300 FROM tool_pattern_transitions
301 WHERE skill_name = ? AND skill_hash = ?
302 AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
303 AND count_raw >= ?
304 ",
305 )
306 .bind(skill_name)
307 .bind(skill_hash)
308 .bind(prev_tool)
309 .bind(prev_tool)
310 .bind(i64::from(min_obs))
311 .fetch_all(&self.pool)
312 .await?;
313
314 if rows.is_empty() {
315 return Ok(());
316 }
317
318 let scored = score_rows(rows, now, half_life_secs);
320 if scored.is_empty() {
321 return Ok(());
322 }
323
324 let mut tx = zeph_db::begin(&self.pool).await?;
326
327 zeph_db::query(
328 "DELETE FROM tool_pattern_predictions \
329 WHERE skill_name = ? AND skill_hash = ? \
330 AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))",
331 )
332 .bind(skill_name)
333 .bind(skill_hash)
334 .bind(prev_tool)
335 .bind(prev_tool)
336 .execute(&mut *tx)
337 .await?;
338
339 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
340 for (rank, (next_tool, args_fp, tmpl, score, wilson)) in scored.iter().enumerate().take(10)
341 {
342 let rank_i64 = rank as i64;
343 zeph_db::query(
344 r"
345 INSERT OR REPLACE INTO tool_pattern_predictions
346 (skill_name, skill_hash, prev_tool, next_tool, args_fingerprint,
347 args_template, score, wilson_lower_bound, rank)
348 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
349 ",
350 )
351 .bind(skill_name)
352 .bind(skill_hash)
353 .bind(prev_tool)
354 .bind(next_tool)
355 .bind(args_fp)
356 .bind(tmpl)
357 .bind(score)
358 .bind(wilson)
359 .bind(rank_i64)
360 .execute(&mut *tx)
361 .await?;
362 }
363
364 tx.commit().await?;
365
366 debug!(
367 skill = skill_name,
368 prev_tool = prev_tool.unwrap_or("<activation>"),
369 "PASTE: refreshed {} predictions",
370 scored.len().min(10)
371 );
372 Ok(())
373 }
374
375 pub async fn vacuum(&self) -> Result<u64, PatternError> {
381 let cutoff = unix_now() - 30 * 86_400;
382 let result = zeph_db::query("DELETE FROM tool_pattern_transitions WHERE last_seen_at < ?")
383 .bind(cutoff)
384 .execute(&self.pool)
385 .await?;
386 let rows = result.rows_affected();
387 if rows > 0 {
388 debug!("PASTE vacuum: removed {} stale rows", rows);
389 }
390 Ok(rows)
391 }
392
393 async fn debounced_refresh(&self, skill_name: &str, skill_hash: &str, prev_tool: Option<&str>) {
394 let key = format!("{skill_hash}:{}", prev_tool.unwrap_or(""));
395 let should_refresh = {
396 let mut map = self.refresh_debounce.lock().await;
397 let state = map
398 .entry(key.clone())
399 .or_insert(RefreshState { last_refresh: None });
400 match state.last_refresh {
401 None => true,
402 Some(t) => t.elapsed() >= Duration::from_mins(1),
403 }
404 };
405 if should_refresh {
406 if let Err(e) = self.refresh(skill_name, skill_hash, prev_tool).await {
407 warn!("PASTE refresh failed: {e}");
408 }
409 let mut map = self.refresh_debounce.lock().await;
410 if let Some(state) = map.get_mut(&key) {
411 state.last_refresh = Some(std::time::Instant::now());
412 }
413 }
414 }
415}
416
417fn score_rows(
422 rows: Vec<(String, String, String, i64, i64, f64, i64)>,
423 now: i64,
424 half_life_secs: f64,
425) -> Vec<(String, String, String, f64, f64)> {
426 let decayed: Vec<_> = rows
427 .into_iter()
428 .map(
429 |(tool, fp, tmpl, count_raw, success_raw, count_decayed, last_seen_at)| {
430 #[allow(clippy::cast_precision_loss)]
431 let elapsed = now.saturating_sub(last_seen_at) as f64;
432 let current_decay = count_decayed * 0.5f64.powf(elapsed / half_life_secs);
433 #[allow(clippy::cast_sign_loss)]
434 let wilson = wilson_lower_bound(success_raw as u64, count_raw as u64);
435 (tool, fp, tmpl, current_decay, wilson)
436 },
437 )
438 .collect();
439
440 let total: f64 = decayed.iter().map(|(_, _, _, d, _)| d).sum();
441 if total <= 0.0 {
442 return vec![];
443 }
444
445 let mut scored: Vec<_> = decayed
446 .into_iter()
447 .map(|(tool, fp, tmpl, d, wilson)| ((d / total) * wilson, tool, fp, tmpl, wilson))
448 .collect();
449 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
450 scored
451 .into_iter()
452 .map(|(score, tool, fp, tmpl, wilson)| (tool, fp, tmpl, score, wilson))
453 .collect()
454}
455
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields 0 rather than an error.
#[allow(clippy::cast_possible_wrap)] // u64 seconds exceed i64::MAX only ~292 billion years out
fn unix_now() -> i64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::ZERO,
    };
    since_epoch.as_secs() as i64
}
463
464#[allow(clippy::cast_precision_loss)]
466fn wilson_lower_bound(successes: u64, n: u64) -> f64 {
467 if n == 0 {
468 return 0.0;
469 }
470 let n = n as f64;
471 let p_hat = successes as f64 / n;
472 let z2 = Z * Z;
473 let numerator =
474 p_hat + z2 / (2.0 * n) - Z * (p_hat * (1.0 - p_hat) / n + z2 / (4.0 * n * n)).sqrt();
475 let denominator = 1.0 + z2 / n;
476 (numerator / denominator).max(0.0)
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn wilson_zero_observations() {
        assert!((wilson_lower_bound(0, 0) - 0.0_f64).abs() < f64::EPSILON);
    }

    #[test]
    fn wilson_all_success_small_n() {
        let lb = wilson_lower_bound(3, 3);
        assert!(lb > 0.0 && lb < 1.0, "got {lb}");
    }

    #[test]
    fn wilson_all_success_matches_closed_form() {
        // With p_hat = 1 the Wilson lower bound simplifies to n / (n + z^2).
        for n in [1_u64, 5, 50] {
            #[allow(clippy::cast_precision_loss)]
            let nf = n as f64;
            let expected = nf / (nf + Z * Z);
            let lb = wilson_lower_bound(n, n);
            assert!((lb - expected).abs() < 1e-12, "n={n}: got {lb}, want {expected}");
        }
    }

    #[test]
    fn wilson_grows_with_evidence() {
        assert!(wilson_lower_bound(30, 30) > wilson_lower_bound(3, 3));
    }

    #[test]
    fn wilson_zero_success() {
        let lb = wilson_lower_bound(0, 10);
        assert!(lb < 0.1, "got {lb}");
    }

    #[test]
    fn score_rows_empty_input() {
        assert!(score_rows(Vec::new(), 1_000, 86_400.0).is_empty());
    }

    #[test]
    fn score_rows_orders_by_descending_score() {
        let now = 1_000_000;
        let row = |tool: &str, decayed: f64| {
            (
                tool.to_owned(),
                format!("fp-{tool}"),
                "{}".to_owned(),
                10_i64,
                10_i64,
                decayed,
                now,
            )
        };
        let scored = score_rows(vec![row("rare", 2.0), row("common", 10.0)], now, 86_400.0);
        assert_eq!(scored.len(), 2);
        assert_eq!(scored[0].0, "common");
        assert_eq!(scored[1].0, "rare");
        assert!(scored[0].3 > scored[1].3);
        // Score = share of the total decayed count, weighted by the Wilson bound.
        let wilson = wilson_lower_bound(10, 10);
        assert!((scored[0].3 - wilson * 10.0 / 12.0).abs() < 1e-12);
    }

    #[test]
    fn fingerprint_deterministic_different_order() {
        fn fp(json: &str) -> String {
            let v: serde_json::Value = serde_json::from_str(json).unwrap();
            let obj = v.as_object().cloned().unwrap_or_default();
            hash_args(&obj).to_hex().to_string()
        }
        let a = r#"{"z": 1, "a": 2}"#;
        let b = r#"{"a": 2, "z": 1}"#;
        assert_eq!(fp(a), fp(b));
    }
}
512}