1#![allow(dead_code)]
22
23use std::sync::Arc;
24use std::time::{Duration, SystemTime, UNIX_EPOCH};
25
26use thiserror::Error;
27use tokio::sync::Mutex as AsyncMutex;
28use tracing::{debug, warn};
29use zeph_db::DbPool;
30
31use super::prediction::{Prediction, PredictionSource};
32use crate::agent::speculative::cache::{args_template, hash_args};
33
/// Standard-normal z-score used by [`wilson_lower_bound`].
///
/// 1.645 is the 95th-percentile quantile of the standard normal distribution,
/// i.e. a one-sided 95% (two-sided 90%) confidence level.
const Z: f64 = 1.645;
36
/// Result of a single observed tool invocation, as reported to
/// [`PatternStore::observe`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolOutcome {
    /// The call succeeded; `observe` adds 1 to the transition's `success_raw`.
    Success,
    /// The call failed; `observe` leaves `success_raw` unchanged.
    Failure,
}
44
/// Errors returned by [`PatternStore`] operations.
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum PatternError {
    /// A database query or transaction failed.
    #[error("sqlx error: {0}")]
    Db(#[from] zeph_db::sqlx::Error),
    /// JSON (de)serialization failed, e.g. `args_json` passed to `observe`
    /// was not valid JSON.
    #[error("json error: {0}")]
    Json(#[from] serde_json::Error),
}
54
/// Debounce bookkeeping for one `skill_hash:prev_tool` key.
struct RefreshState {
    /// When a refresh was last attempted for this key; `None` if never.
    last_refresh: Option<std::time::Instant>,
}
59
/// Persistent store of observed tool-call transitions and the ranked
/// predictions derived from them.
///
/// Observations are written to `tool_pattern_transitions`; ranked predictions
/// are materialized into `tool_pattern_predictions` by `refresh`.
pub struct PatternStore {
    /// Database pool used for every query issued by the store.
    pool: DbPool,
    /// Half-life, in days, of the exponentially decayed observation count.
    half_life_days: f64,
    /// Last refresh time per `skill_hash:prev_tool` key, used to debounce
    /// prediction refreshes triggered by `observe`.
    refresh_debounce: Arc<AsyncMutex<std::collections::HashMap<String, RefreshState>>>,
    /// Minimum `count_raw` a transition needs before `refresh` considers it.
    min_observations: u32,
}
83
84impl PatternStore {
85 #[must_use]
89 pub fn new(pool: DbPool, half_life_days: f64) -> Self {
90 Self {
91 pool,
92 half_life_days,
93 refresh_debounce: Arc::new(AsyncMutex::new(std::collections::HashMap::new())),
94 min_observations: 5,
95 }
96 }
97
98 #[must_use]
100 pub fn with_min_observations(mut self, n: u32) -> Self {
101 self.min_observations = n;
102 self
103 }
104
105 #[allow(clippy::too_many_arguments)] pub async fn observe(
117 &self,
118 skill_name: &str,
119 skill_hash: &str,
120 prev_tool: Option<&str>,
121 next_tool: &str,
122 args_json: &str,
123 outcome: ToolOutcome,
124 latency_ms: u64,
125 ) -> Result<(), PatternError> {
126 let now = unix_now();
127 let half_life_secs = self.half_life_days * 86_400.0;
128 let success_delta = i64::from(outcome == ToolOutcome::Success);
129 let args: serde_json::Value = serde_json::from_str(args_json)?;
130 let args_obj = args.as_object().cloned().unwrap_or_default();
131 let args_fingerprint = {
132 let h = hash_args(&args_obj);
133 h.to_hex().to_string()
134 };
135 let tmpl = args_template(&args_obj);
136 #[allow(clippy::cast_possible_wrap)]
137 let latency_i64 = latency_ms as i64;
138
139 let existing = zeph_db::query_as::<_, (f64, i64, i64, i64)>(
143 r"
144 SELECT count_decayed, last_seen_at, count_raw, avg_latency_ms
145 FROM tool_pattern_transitions
146 WHERE skill_name = ? AND skill_hash = ?
147 AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
148 AND next_tool = ? AND args_fingerprint = ?
149 ",
150 )
151 .bind(skill_name)
152 .bind(skill_hash)
153 .bind(prev_tool)
154 .bind(prev_tool)
155 .bind(next_tool)
156 .bind(&args_fingerprint)
157 .fetch_optional(&self.pool)
158 .await?;
159
160 if let Some((old_decayed, last_seen_at, old_count_raw, old_avg_latency)) = existing {
161 #[allow(clippy::cast_precision_loss)]
162 let elapsed = (now - last_seen_at).max(0) as f64;
163 let new_decayed = old_decayed * 0.5f64.powf(elapsed / half_life_secs) + 1.0;
164 let new_count_raw = old_count_raw + 1;
165 #[allow(clippy::cast_precision_loss)]
166 let new_avg_latency = (old_avg_latency * old_count_raw + latency_i64) / new_count_raw;
167
168 zeph_db::query(
169 r"
170 UPDATE tool_pattern_transitions SET
171 count_decayed = ?,
172 count_raw = ?,
173 success_raw = success_raw + ?,
174 last_seen_at = ?,
175 avg_latency_ms = ?
176 WHERE skill_name = ? AND skill_hash = ?
177 AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
178 AND next_tool = ? AND args_fingerprint = ?
179 ",
180 )
181 .bind(new_decayed)
182 .bind(new_count_raw)
183 .bind(success_delta)
184 .bind(now)
185 .bind(new_avg_latency)
186 .bind(skill_name)
187 .bind(skill_hash)
188 .bind(prev_tool)
189 .bind(prev_tool)
190 .bind(next_tool)
191 .bind(&args_fingerprint)
192 .execute(&self.pool)
193 .await?;
194 } else {
195 zeph_db::query(
196 r"
197 INSERT INTO tool_pattern_transitions
198 (skill_name, skill_hash, prev_tool, next_tool, args_fingerprint,
199 args_template, count_raw, success_raw, count_decayed, last_seen_at, avg_latency_ms)
200 VALUES (?, ?, ?, ?, ?, ?, 1, ?, 1.0, ?, ?)
201 ",
202 )
203 .bind(skill_name)
204 .bind(skill_hash)
205 .bind(prev_tool)
206 .bind(next_tool)
207 .bind(&args_fingerprint)
208 .bind(&tmpl)
209 .bind(success_delta)
210 .bind(now)
211 .bind(latency_i64)
212 .execute(&self.pool)
213 .await?;
214 }
215
216 self.debounced_refresh(skill_name, skill_hash, prev_tool)
217 .await;
218 Ok(())
219 }
220
221 pub async fn predict(
230 &self,
231 skill_name: &str,
232 skill_hash: &str,
233 prev_tool: Option<&str>,
234 k: u8,
235 ) -> Result<Vec<Prediction>, PatternError> {
236 let rows = zeph_db::query_as::<_, (String, String, f64, f64, i64)>(
237 r"
238 SELECT next_tool, args_template, score, wilson_lower_bound, rank
239 FROM tool_pattern_predictions
240 WHERE skill_name = ? AND skill_hash = ?
241 AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
242 AND wilson_lower_bound >= 0.5
243 ORDER BY rank ASC
244 LIMIT ?
245 ",
246 )
247 .bind(skill_name)
248 .bind(skill_hash)
249 .bind(prev_tool)
250 .bind(prev_tool)
251 .bind(i64::from(k))
252 .fetch_all(&self.pool)
253 .await?;
254
255 let predictions = rows
256 .into_iter()
257 .enumerate()
258 .filter_map(|(i, (next_tool, args_template, score, _wilson, _rank))| {
259 let args: serde_json::Map<String, serde_json::Value> =
260 serde_json::from_str(&args_template).ok()?;
261 Some(Prediction {
262 tool_id: zeph_common::ToolName::new(next_tool),
263 args,
264 #[allow(clippy::cast_possible_truncation)]
265 confidence: score as f32,
266 source: PredictionSource::HistoryPattern {
267 skill: skill_name.to_owned(),
268 #[allow(clippy::cast_possible_truncation)]
269 rank: i as u8,
270 },
271 })
272 })
273 .collect();
274
275 Ok(predictions)
276 }
277
278 pub async fn refresh(
287 &self,
288 skill_name: &str,
289 skill_hash: &str,
290 prev_tool: Option<&str>,
291 ) -> Result<(), PatternError> {
292 let min_obs = self.min_observations;
293 let half_life_secs = self.half_life_days * 86_400.0;
294 let now = unix_now();
295
296 let rows = zeph_db::query_as::<_, (String, String, String, i64, i64, f64, i64)>(
299 r"
300 SELECT next_tool, args_fingerprint, args_template,
301 count_raw, success_raw, count_decayed, last_seen_at
302 FROM tool_pattern_transitions
303 WHERE skill_name = ? AND skill_hash = ?
304 AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
305 AND count_raw >= ?
306 ",
307 )
308 .bind(skill_name)
309 .bind(skill_hash)
310 .bind(prev_tool)
311 .bind(prev_tool)
312 .bind(i64::from(min_obs))
313 .fetch_all(&self.pool)
314 .await?;
315
316 if rows.is_empty() {
317 return Ok(());
318 }
319
320 let scored = score_rows(rows, now, half_life_secs);
322 if scored.is_empty() {
323 return Ok(());
324 }
325
326 let mut tx = zeph_db::begin(&self.pool).await?;
328
329 zeph_db::query(
330 "DELETE FROM tool_pattern_predictions \
331 WHERE skill_name = ? AND skill_hash = ? \
332 AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))",
333 )
334 .bind(skill_name)
335 .bind(skill_hash)
336 .bind(prev_tool)
337 .bind(prev_tool)
338 .execute(&mut *tx)
339 .await?;
340
341 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
342 for (rank, (next_tool, args_fp, tmpl, score, wilson)) in scored.iter().enumerate().take(10)
343 {
344 let rank_i64 = rank as i64;
345 zeph_db::query(
346 r"
347 INSERT OR REPLACE INTO tool_pattern_predictions
348 (skill_name, skill_hash, prev_tool, next_tool, args_fingerprint,
349 args_template, score, wilson_lower_bound, rank)
350 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
351 ",
352 )
353 .bind(skill_name)
354 .bind(skill_hash)
355 .bind(prev_tool)
356 .bind(next_tool)
357 .bind(args_fp)
358 .bind(tmpl)
359 .bind(score)
360 .bind(wilson)
361 .bind(rank_i64)
362 .execute(&mut *tx)
363 .await?;
364 }
365
366 tx.commit().await?;
367
368 debug!(
369 skill = skill_name,
370 prev_tool = prev_tool.unwrap_or("<activation>"),
371 "PASTE: refreshed {} predictions",
372 scored.len().min(10)
373 );
374 Ok(())
375 }
376
377 pub async fn vacuum(&self) -> Result<u64, PatternError> {
383 let cutoff = unix_now() - 30 * 86_400;
384 let result = zeph_db::query("DELETE FROM tool_pattern_transitions WHERE last_seen_at < ?")
385 .bind(cutoff)
386 .execute(&self.pool)
387 .await?;
388 let rows = result.rows_affected();
389 if rows > 0 {
390 debug!("PASTE vacuum: removed {} stale rows", rows);
391 }
392 Ok(rows)
393 }
394
395 async fn debounced_refresh(&self, skill_name: &str, skill_hash: &str, prev_tool: Option<&str>) {
396 let key = format!("{skill_hash}:{}", prev_tool.unwrap_or(""));
397 let should_refresh = {
398 let mut map = self.refresh_debounce.lock().await;
399 let state = map
400 .entry(key.clone())
401 .or_insert(RefreshState { last_refresh: None });
402 match state.last_refresh {
403 None => true,
404 Some(t) => t.elapsed() >= Duration::from_mins(1),
405 }
406 };
407 if should_refresh {
408 if let Err(e) = self.refresh(skill_name, skill_hash, prev_tool).await {
409 warn!("PASTE refresh failed: {e}");
410 }
411 let mut map = self.refresh_debounce.lock().await;
412 if let Some(state) = map.get_mut(&key) {
413 state.last_refresh = Some(std::time::Instant::now());
414 }
415 }
416 }
417}
418
419fn score_rows(
424 rows: Vec<(String, String, String, i64, i64, f64, i64)>,
425 now: i64,
426 half_life_secs: f64,
427) -> Vec<(String, String, String, f64, f64)> {
428 let decayed: Vec<_> = rows
429 .into_iter()
430 .map(
431 |(tool, fp, tmpl, count_raw, success_raw, count_decayed, last_seen_at)| {
432 #[allow(clippy::cast_precision_loss)]
433 let elapsed = now.saturating_sub(last_seen_at) as f64;
434 let current_decay = count_decayed * 0.5f64.powf(elapsed / half_life_secs);
435 #[allow(clippy::cast_sign_loss)]
436 let wilson = wilson_lower_bound(success_raw as u64, count_raw as u64);
437 (tool, fp, tmpl, current_decay, wilson)
438 },
439 )
440 .collect();
441
442 let total: f64 = decayed.iter().map(|(_, _, _, d, _)| d).sum();
443 if total <= 0.0 {
444 return vec![];
445 }
446
447 let mut scored: Vec<_> = decayed
448 .into_iter()
449 .map(|(tool, fp, tmpl, d, wilson)| ((d / total) * wilson, tool, fp, tmpl, wilson))
450 .collect();
451 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
452 scored
453 .into_iter()
454 .map(|(score, tool, fp, tmpl, wilson)| (tool, fp, tmpl, score, wilson))
455 .collect()
456}
457
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields 0 if the system clock reports a time before the epoch.
#[allow(clippy::cast_possible_wrap)]
fn unix_now() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs() as i64,
        Err(_) => 0,
    }
}
465
466#[allow(clippy::cast_precision_loss)]
468fn wilson_lower_bound(successes: u64, n: u64) -> f64 {
469 if n == 0 {
470 return 0.0;
471 }
472 let n = n as f64;
473 let p_hat = successes as f64 / n;
474 let z2 = Z * Z;
475 let numerator =
476 p_hat + z2 / (2.0 * n) - Z * (p_hat * (1.0 - p_hat) / n + z2 / (4.0 * n * n)).sqrt();
477 let denominator = 1.0 + z2 / n;
478 (numerator / denominator).max(0.0)
479}
480
// Unit tests for the pure helpers in this module; none of them touch the database.
#[cfg(test)]
mod tests {
    use super::*;

    /// With no observations the bound is defined as exactly 0.
    #[test]
    fn wilson_zero_observations() {
        assert!((wilson_lower_bound(0, 0) - 0.0_f64).abs() < f64::EPSILON);
    }

    /// Three successes out of three still leave the bound strictly between 0 and 1.
    #[test]
    fn wilson_all_success_small_n() {
        let lb = wilson_lower_bound(3, 3);
        assert!(lb > 0.0 && lb < 1.0, "got {lb}");
    }

    /// Zero successes out of ten keep the bound below 0.1.
    #[test]
    fn wilson_zero_success() {
        let lb = wilson_lower_bound(0, 10);
        assert!(lb < 0.1, "got {lb}");
    }

    /// The argument fingerprint must not depend on JSON key order.
    #[test]
    fn fingerprint_deterministic_different_order() {
        // Same parse-then-hash pipeline that `observe` applies to `args_json`.
        fn fp(json: &str) -> String {
            let v: serde_json::Value = serde_json::from_str(json).unwrap();
            let obj = v.as_object().cloned().unwrap_or_default();
            hash_args(&obj).to_hex().to_string()
        }
        let a = r#"{"z": 1, "a": 2}"#;
        let b = r#"{"a": 2, "z": 1}"#;
        assert_eq!(fp(a), fp(b));
    }
}
514}