Skip to main content

zeph_core/agent/speculative/
paste.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! PASTE — Pattern-Aware Speculative Tool Execution (issue #2409).
5//!
6//! Tracks per-skill tool invocation sequences in `SQLite` and surfaces top-K predicted
7//! next tool calls at skill activation time. Predictions are scored using exponentially
8//! decayed frequency × Wilson 95% one-sided lower bound on success rate.
9//!
10//! ## Scoring formula
11//!
12//! ```text
13//! count_decayed = Σ_i  0.5 ^ ((now - t_i) / half_life_seconds)
14//! p_hat         = success_raw / count_raw
15//! wilson_low    = (p_hat + z²/(2n) - z·sqrt(p_hat(1-p_hat)/n + z²/(4n²))) / (1 + z²/n)
16//!                 where n = count_raw, z = 1.645 (95% one-sided)
17//! freq_norm     = count_decayed / total_decayed  (over sibling (skill_hash, prev_tool) rows)
18//! score         = freq_norm * wilson_low
19//! ```
20
21#![allow(dead_code)]
22
23use std::sync::Arc;
24use std::time::{Duration, SystemTime, UNIX_EPOCH};
25
26use thiserror::Error;
27use tokio::sync::Mutex as AsyncMutex;
28use tracing::{debug, warn};
29use zeph_db::DbPool;
30
31use super::prediction::{Prediction, PredictionSource};
32use crate::agent::speculative::cache::{args_template, hash_args};
33
/// Standard-normal quantile for a one-sided 95% confidence level (Φ⁻¹(0.95) ≈ 1.645).
///
/// Plugged into [`wilson_lower_bound`] as the `z` term of the Wilson score interval.
const Z: f64 = 1.645;
36
/// How a completed tool call ended, as reported to `PatternStore::observe`.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolOutcome {
    /// The call succeeded; it increments both `count_raw` and `success_raw`.
    Success,
    /// The call failed; it increments `count_raw` only.
    Failure,
}
44
45#[non_exhaustive]
46/// Error type for `PatternStore` operations.
47#[derive(Debug, Error)]
48pub enum PatternError {
49    #[error("sqlx error: {0}")]
50    Db(#[from] zeph_db::sqlx::Error),
51    #[error("json error: {0}")]
52    Json(#[from] serde_json::Error),
53}
54
/// Debounce bookkeeping kept for one refresh key of a `PatternStore`.
struct RefreshState {
    /// Moment an automatic refresh for this key was last triggered; `None` means never.
    last_refresh: Option<std::time::Instant>,
}
59
60/// SQLite-backed tool invocation pattern store for PASTE.
61///
62/// Thread-safe via `Arc`. Call [`observe`](Self::observe) after each tool completion
63/// and [`predict`](Self::predict) at skill activation to get speculative candidates.
64///
65/// # Examples
66///
67/// ```rust,no_run
68/// # use zeph_core::agent::speculative::paste::{PatternStore, ToolOutcome};
69/// # async fn example(pool: zeph_db::DbPool) -> Result<(), Box<dyn std::error::Error>> {
70/// let store = PatternStore::new(pool, 14.0);
71/// store.observe("my-skill", "abc123", None, "bash",
72///     r#"{"command":"ls"}"#, ToolOutcome::Success, 42).await?;
73/// let predictions = store.predict("my-skill", "abc123", None, 3).await?;
74/// # Ok(())
75/// # }
76/// ```
77pub struct PatternStore {
78    pool: DbPool,
79    half_life_days: f64,
80    refresh_debounce: Arc<AsyncMutex<std::collections::HashMap<String, RefreshState>>>,
81    min_observations: u32,
82}
83
84impl PatternStore {
85    /// Create a new pattern store.
86    ///
87    /// `half_life_days` controls the exponential decay; 14 days is the default.
88    #[must_use]
89    pub fn new(pool: DbPool, half_life_days: f64) -> Self {
90        Self {
91            pool,
92            half_life_days,
93            refresh_debounce: Arc::new(AsyncMutex::new(std::collections::HashMap::new())),
94            min_observations: 5,
95        }
96    }
97
98    /// Set the minimum number of raw observations required before predicting.
99    #[must_use]
100    pub fn with_min_observations(mut self, n: u32) -> Self {
101        self.min_observations = n;
102        self
103    }
104
105    /// Record a completed tool invocation.
106    ///
107    /// Updates `count_decayed` using the exponential decay formula (computed in Rust,
108    /// not via `pow()` in `SQLite` — `pow` requires `SQLITE_ENABLE_MATH_FUNCTIONS`
109    /// which is not available in bundled `libsqlite3-sys`) and appends to
110    /// `success_raw` / `count_raw`. Also triggers a debounced [`refresh`](Self::refresh).
111    ///
112    /// # Errors
113    ///
114    /// Returns [`PatternError::Db`] on `SQLite` failure.
115    #[allow(clippy::too_many_arguments)] // function with many required inputs; a *Params struct would be more verbose without simplifying the call site
116    pub async fn observe(
117        &self,
118        skill_name: &str,
119        skill_hash: &str,
120        prev_tool: Option<&str>,
121        next_tool: &str,
122        args_json: &str,
123        outcome: ToolOutcome,
124        latency_ms: u64,
125    ) -> Result<(), PatternError> {
126        let now = unix_now();
127        let half_life_secs = self.half_life_days * 86_400.0;
128        let success_delta = i64::from(outcome == ToolOutcome::Success);
129        let args: serde_json::Value = serde_json::from_str(args_json)?;
130        let args_obj = args.as_object().cloned().unwrap_or_default();
131        let args_fingerprint = {
132            let h = hash_args(&args_obj);
133            h.to_hex().to_string()
134        };
135        let tmpl = args_template(&args_obj);
136        #[allow(clippy::cast_possible_wrap)]
137        let latency_i64 = latency_ms as i64;
138
139        // Fetch the existing row's count_decayed + last_seen_at so we can compute
140        // the updated decay value in Rust (C6: avoids SQLite pow() which requires
141        // SQLITE_ENABLE_MATH_FUNCTIONS not present in bundled libsqlite3-sys).
142        let existing = zeph_db::query_as::<_, (f64, i64, i64, i64)>(
143            r"
144            SELECT count_decayed, last_seen_at, count_raw, avg_latency_ms
145            FROM tool_pattern_transitions
146            WHERE skill_name = ? AND skill_hash = ?
147              AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
148              AND next_tool = ? AND args_fingerprint = ?
149            ",
150        )
151        .bind(skill_name)
152        .bind(skill_hash)
153        .bind(prev_tool)
154        .bind(prev_tool)
155        .bind(next_tool)
156        .bind(&args_fingerprint)
157        .fetch_optional(&self.pool)
158        .await?;
159
160        if let Some((old_decayed, last_seen_at, old_count_raw, old_avg_latency)) = existing {
161            #[allow(clippy::cast_precision_loss)]
162            let elapsed = (now - last_seen_at).max(0) as f64;
163            let new_decayed = old_decayed * 0.5f64.powf(elapsed / half_life_secs) + 1.0;
164            let new_count_raw = old_count_raw + 1;
165            #[allow(clippy::cast_precision_loss)]
166            let new_avg_latency = (old_avg_latency * old_count_raw + latency_i64) / new_count_raw;
167
168            zeph_db::query(
169                r"
170                UPDATE tool_pattern_transitions SET
171                    count_decayed  = ?,
172                    count_raw      = ?,
173                    success_raw    = success_raw + ?,
174                    last_seen_at   = ?,
175                    avg_latency_ms = ?
176                WHERE skill_name = ? AND skill_hash = ?
177                  AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
178                  AND next_tool = ? AND args_fingerprint = ?
179                ",
180            )
181            .bind(new_decayed)
182            .bind(new_count_raw)
183            .bind(success_delta)
184            .bind(now)
185            .bind(new_avg_latency)
186            .bind(skill_name)
187            .bind(skill_hash)
188            .bind(prev_tool)
189            .bind(prev_tool)
190            .bind(next_tool)
191            .bind(&args_fingerprint)
192            .execute(&self.pool)
193            .await?;
194        } else {
195            zeph_db::query(
196                r"
197                INSERT INTO tool_pattern_transitions
198                    (skill_name, skill_hash, prev_tool, next_tool, args_fingerprint,
199                     args_template, count_raw, success_raw, count_decayed, last_seen_at, avg_latency_ms)
200                VALUES (?, ?, ?, ?, ?, ?, 1, ?, 1.0, ?, ?)
201                ",
202            )
203            .bind(skill_name)
204            .bind(skill_hash)
205            .bind(prev_tool)
206            .bind(next_tool)
207            .bind(&args_fingerprint)
208            .bind(&tmpl)
209            .bind(success_delta)
210            .bind(now)
211            .bind(latency_i64)
212            .execute(&self.pool)
213            .await?;
214        }
215
216        self.debounced_refresh(skill_name, skill_hash, prev_tool)
217            .await;
218        Ok(())
219    }
220
221    /// Return the top-`k` predicted next tool calls for `(skill, prev_tool)`.
222    ///
223    /// Only returns predictions with `wilson_lower_bound >= 0.5` and
224    /// `count_raw >= min_observations`.
225    ///
226    /// # Errors
227    ///
228    /// Returns [`PatternError::Db`] on `SQLite` failure.
229    pub async fn predict(
230        &self,
231        skill_name: &str,
232        skill_hash: &str,
233        prev_tool: Option<&str>,
234        k: u8,
235    ) -> Result<Vec<Prediction>, PatternError> {
236        let rows = zeph_db::query_as::<_, (String, String, f64, f64, i64)>(
237            r"
238            SELECT next_tool, args_template, score, wilson_lower_bound, rank
239            FROM tool_pattern_predictions
240            WHERE skill_name = ? AND skill_hash = ?
241              AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
242              AND wilson_lower_bound >= 0.5
243            ORDER BY rank ASC
244            LIMIT ?
245            ",
246        )
247        .bind(skill_name)
248        .bind(skill_hash)
249        .bind(prev_tool)
250        .bind(prev_tool)
251        .bind(i64::from(k))
252        .fetch_all(&self.pool)
253        .await?;
254
255        let predictions = rows
256            .into_iter()
257            .enumerate()
258            .filter_map(|(i, (next_tool, args_template, score, _wilson, _rank))| {
259                let args: serde_json::Map<String, serde_json::Value> =
260                    serde_json::from_str(&args_template).ok()?;
261                Some(Prediction {
262                    tool_id: zeph_common::ToolName::new(next_tool),
263                    args,
264                    #[allow(clippy::cast_possible_truncation)]
265                    confidence: score as f32,
266                    source: PredictionSource::HistoryPattern {
267                        skill: skill_name.to_owned(),
268                        #[allow(clippy::cast_possible_truncation)]
269                        rank: i as u8,
270                    },
271                })
272            })
273            .collect();
274
275        Ok(predictions)
276    }
277
278    /// Recompute and materialize predictions for `(skill, skill_hash, prev_tool)`.
279    ///
280    /// Debounced to at most once per 60 s per `(skill_hash, prev_tool)`.
281    /// Runs DELETE + N INSERTs inside a single transaction (H3: atomic refresh).
282    ///
283    /// # Errors
284    ///
285    /// Returns [`PatternError::Db`] on `SQLite` failure.
286    pub async fn refresh(
287        &self,
288        skill_name: &str,
289        skill_hash: &str,
290        prev_tool: Option<&str>,
291    ) -> Result<(), PatternError> {
292        let min_obs = self.min_observations;
293        let half_life_secs = self.half_life_days * 86_400.0;
294        let now = unix_now();
295
296        // Fetch all sibling transitions for Wilson + normalization.
297        // Also fetch args_template so predictions carry the real type-placeholder template (H1).
298        let rows = zeph_db::query_as::<_, (String, String, String, i64, i64, f64, i64)>(
299            r"
300            SELECT next_tool, args_fingerprint, args_template,
301                   count_raw, success_raw, count_decayed, last_seen_at
302            FROM tool_pattern_transitions
303            WHERE skill_name = ? AND skill_hash = ?
304              AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
305              AND count_raw >= ?
306            ",
307        )
308        .bind(skill_name)
309        .bind(skill_hash)
310        .bind(prev_tool)
311        .bind(prev_tool)
312        .bind(i64::from(min_obs))
313        .fetch_all(&self.pool)
314        .await?;
315
316        if rows.is_empty() {
317            return Ok(());
318        }
319
320        // Recompute decayed counts and Wilson scores in Rust, then normalize and rank.
321        let scored = score_rows(rows, now, half_life_secs);
322        if scored.is_empty() {
323            return Ok(());
324        }
325
326        // Wrap DELETE + N INSERTs in a transaction to prevent partial state on crash (H3).
327        let mut tx = zeph_db::begin(&self.pool).await?;
328
329        zeph_db::query(
330            "DELETE FROM tool_pattern_predictions \
331             WHERE skill_name = ? AND skill_hash = ? \
332             AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))",
333        )
334        .bind(skill_name)
335        .bind(skill_hash)
336        .bind(prev_tool)
337        .bind(prev_tool)
338        .execute(&mut *tx)
339        .await?;
340
341        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
342        for (rank, (next_tool, args_fp, tmpl, score, wilson)) in scored.iter().enumerate().take(10)
343        {
344            let rank_i64 = rank as i64;
345            zeph_db::query(
346                r"
347                INSERT OR REPLACE INTO tool_pattern_predictions
348                    (skill_name, skill_hash, prev_tool, next_tool, args_fingerprint,
349                     args_template, score, wilson_lower_bound, rank)
350                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
351                ",
352            )
353            .bind(skill_name)
354            .bind(skill_hash)
355            .bind(prev_tool)
356            .bind(next_tool)
357            .bind(args_fp)
358            .bind(tmpl)
359            .bind(score)
360            .bind(wilson)
361            .bind(rank_i64)
362            .execute(&mut *tx)
363            .await?;
364        }
365
366        tx.commit().await?;
367
368        debug!(
369            skill = skill_name,
370            prev_tool = prev_tool.unwrap_or("<activation>"),
371            "PASTE: refreshed {} predictions",
372            scored.len().min(10)
373        );
374        Ok(())
375    }
376
377    /// Purge `tool_pattern_transitions` rows with stale `skill_hash` older than 30 days.
378    ///
379    /// # Errors
380    ///
381    /// Returns [`PatternError::Db`] on `SQLite` failure.
382    pub async fn vacuum(&self) -> Result<u64, PatternError> {
383        let cutoff = unix_now() - 30 * 86_400;
384        let result = zeph_db::query("DELETE FROM tool_pattern_transitions WHERE last_seen_at < ?")
385            .bind(cutoff)
386            .execute(&self.pool)
387            .await?;
388        let rows = result.rows_affected();
389        if rows > 0 {
390            debug!("PASTE vacuum: removed {} stale rows", rows);
391        }
392        Ok(rows)
393    }
394
395    async fn debounced_refresh(&self, skill_name: &str, skill_hash: &str, prev_tool: Option<&str>) {
396        let key = format!("{skill_hash}:{}", prev_tool.unwrap_or(""));
397        let should_refresh = {
398            let mut map = self.refresh_debounce.lock().await;
399            let state = map
400                .entry(key.clone())
401                .or_insert(RefreshState { last_refresh: None });
402            match state.last_refresh {
403                None => true,
404                Some(t) => t.elapsed() >= Duration::from_mins(1),
405            }
406        };
407        if should_refresh {
408            if let Err(e) = self.refresh(skill_name, skill_hash, prev_tool).await {
409                warn!("PASTE refresh failed: {e}");
410            }
411            let mut map = self.refresh_debounce.lock().await;
412            if let Some(state) = map.get_mut(&key) {
413                state.last_refresh = Some(std::time::Instant::now());
414            }
415        }
416    }
417}
418
419/// Compute decay-adjusted Wilson scores for a batch of transition rows and return them
420/// sorted descending by score (top-K ready).
421///
422/// `rows` tuples: `(next_tool, args_fp, args_template, count_raw, success_raw, count_decayed, last_seen_at)`
423fn score_rows(
424    rows: Vec<(String, String, String, i64, i64, f64, i64)>,
425    now: i64,
426    half_life_secs: f64,
427) -> Vec<(String, String, String, f64, f64)> {
428    let decayed: Vec<_> = rows
429        .into_iter()
430        .map(
431            |(tool, fp, tmpl, count_raw, success_raw, count_decayed, last_seen_at)| {
432                #[allow(clippy::cast_precision_loss)]
433                let elapsed = now.saturating_sub(last_seen_at) as f64;
434                let current_decay = count_decayed * 0.5f64.powf(elapsed / half_life_secs);
435                #[allow(clippy::cast_sign_loss)]
436                let wilson = wilson_lower_bound(success_raw as u64, count_raw as u64);
437                (tool, fp, tmpl, current_decay, wilson)
438            },
439        )
440        .collect();
441
442    let total: f64 = decayed.iter().map(|(_, _, _, d, _)| d).sum();
443    if total <= 0.0 {
444        return vec![];
445    }
446
447    let mut scored: Vec<_> = decayed
448        .into_iter()
449        .map(|(tool, fp, tmpl, d, wilson)| ((d / total) * wilson, tool, fp, tmpl, wilson))
450        .collect();
451    scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
452    scored
453        .into_iter()
454        .map(|(score, tool, fp, tmpl, wilson)| (tool, fp, tmpl, score, wilson))
455        .collect()
456}
457
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of an error.
#[allow(clippy::cast_possible_wrap)] // u64 seconds only exceed i64::MAX ~292 billion years out
fn unix_now() -> i64 {
    let secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_secs());
    secs as i64
}
465
466/// Wilson 95% one-sided lower confidence bound.
467#[allow(clippy::cast_precision_loss)]
468fn wilson_lower_bound(successes: u64, n: u64) -> f64 {
469    if n == 0 {
470        return 0.0;
471    }
472    let n = n as f64;
473    let p_hat = successes as f64 / n;
474    let z2 = Z * Z;
475    let numerator =
476        p_hat + z2 / (2.0 * n) - Z * (p_hat * (1.0 - p_hat) / n + z2 / (4.0 * n * n)).sqrt();
477    let denominator = 1.0 + z2 / n;
478    (numerator / denominator).max(0.0)
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// 14-day half-life in seconds, matching the documented default.
    const HALF_LIFE_SECS: f64 = 14.0 * 86_400.0;
    /// Fixed "current time" so score tests are deterministic.
    const NOW: i64 = 1_700_000_000;

    /// Build a `score_rows` input row with a deterministic fingerprint and empty template.
    fn row(
        tool: &str,
        count_raw: i64,
        success_raw: i64,
        count_decayed: f64,
        last_seen_at: i64,
    ) -> (String, String, String, i64, i64, f64, i64) {
        (
            tool.to_owned(),
            format!("fp-{tool}"),
            "{}".to_owned(),
            count_raw,
            success_raw,
            count_decayed,
            last_seen_at,
        )
    }

    #[test]
    fn wilson_zero_observations() {
        assert!((wilson_lower_bound(0, 0) - 0.0_f64).abs() < f64::EPSILON);
    }

    #[test]
    fn wilson_all_success_small_n() {
        // 3/3 successes → lower bound well below 1.0 (small-sample conservatism)
        let lb = wilson_lower_bound(3, 3);
        assert!(lb > 0.0 && lb < 1.0, "got {lb}");
    }

    #[test]
    fn wilson_zero_success() {
        let lb = wilson_lower_bound(0, 10);
        assert!(lb < 0.1, "got {lb}");
    }

    #[test]
    fn wilson_tightens_with_more_evidence() {
        // Same 80% success rate: more trials → higher bound, but never above p_hat.
        let small = wilson_lower_bound(8, 10);
        let large = wilson_lower_bound(800, 1000);
        assert!(small < large, "{small} vs {large}");
        assert!(large < 0.8, "got {large}");
    }

    #[test]
    fn score_rows_empty_input_yields_nothing() {
        assert!(score_rows(Vec::new(), NOW, HALF_LIFE_SECS).is_empty());
    }

    #[test]
    fn score_rows_sorted_descending_by_score() {
        let scored = score_rows(
            vec![row("flaky", 10, 2, 5.0, NOW), row("solid", 10, 10, 5.0, NOW)],
            NOW,
            HALF_LIFE_SECS,
        );
        assert_eq!(scored.len(), 2);
        assert_eq!(scored[0].0, "solid");
        assert!(scored.windows(2).all(|w| w[0].3 >= w[1].3));
    }

    #[test]
    fn score_rows_decay_favours_recent_rows() {
        // Four half-lives ago → decay factor 0.5^4.
        let stale_at = NOW - 4 * 14 * 86_400;
        let scored = score_rows(
            vec![row("stale", 10, 10, 5.0, stale_at), row("fresh", 10, 10, 5.0, NOW)],
            NOW,
            HALF_LIFE_SECS,
        );
        assert_eq!(scored[0].0, "fresh");
        // Identical Wilson bounds, so the score ratio equals the decay ratio.
        let ratio = scored[1].3 / scored[0].3;
        assert!((ratio - 0.0625).abs() < 1e-9, "got {ratio}");
    }

    #[test]
    fn unix_now_is_after_2020() {
        assert!(unix_now() > 1_577_836_800);
    }

    #[test]
    fn fingerprint_deterministic_different_order() {
        fn fp(json: &str) -> String {
            let v: serde_json::Value = serde_json::from_str(json).unwrap();
            let obj = v.as_object().cloned().unwrap_or_default();
            hash_args(&obj).to_hex().to_string()
        }
        let a = r#"{"z": 1, "a": 2}"#;
        let b = r#"{"a": 2, "z": 1}"#;
        assert_eq!(fp(a), fp(b));
    }
}