// zeph_core/agent/speculative/paste.rs
// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! PASTE — Pattern-Aware Speculative Tool Execution (issue #2409).
//!
//! Tracks per-skill tool invocation sequences in `SQLite` and surfaces top-K predicted
//! next tool calls at skill activation time. Predictions are scored using exponentially
//! decayed frequency × Wilson 95% one-sided lower bound on success rate.
//!
//! ## Scoring formula
//!
//! ```text
//! count_decayed = Σ_i  0.5 ^ ((now - t_i) / half_life_seconds)
//! p_hat         = success_raw / count_raw
//! wilson_low    = (p_hat + z²/(2n) - z·sqrt(p_hat(1-p_hat)/n + z²/(4n²))) / (1 + z²/n)
//!                 where n = count_raw, z = 1.645 (95% one-sided)
//! freq_norm     = count_decayed / total_decayed  (over sibling (skill_hash, prev_tool) rows)
//! score         = freq_norm * wilson_low
//! ```
21#![allow(dead_code)]
22
23use std::sync::Arc;
24use std::time::{Duration, SystemTime, UNIX_EPOCH};
25
26use thiserror::Error;
27use tokio::sync::Mutex as AsyncMutex;
28use tracing::{debug, warn};
29use zeph_db::DbPool;
30
31use super::prediction::{Prediction, PredictionSource};
32use crate::agent::speculative::cache::{args_template, hash_args};
33
34/// Wilson 95% one-sided z-score.
35const Z: f64 = 1.645;
36
37/// Outcome of a tool call, used when observing a transition.
38#[derive(Debug, Clone, Copy, PartialEq, Eq)]
39pub enum ToolOutcome {
40    Success,
41    Failure,
42}
43
44/// Error type for `PatternStore` operations.
45#[derive(Debug, Error)]
46pub enum PatternError {
47    #[error("database error: {0}")]
48    Db(#[from] zeph_db::sqlx::Error),
49    #[error("json error: {0}")]
50    Json(#[from] serde_json::Error),
51}
52
53/// Per-`(skill_hash, prev_tool)` refresh debounce state.
54struct RefreshState {
55    last_refresh: Option<std::time::Instant>,
56}
57
58/// SQLite-backed tool invocation pattern store for PASTE.
59///
60/// Thread-safe via `Arc`. Call [`observe`](Self::observe) after each tool completion
61/// and [`predict`](Self::predict) at skill activation to get speculative candidates.
62///
63/// # Examples
64///
65/// ```rust,no_run
66/// # use zeph_core::agent::speculative::paste::{PatternStore, ToolOutcome};
67/// # async fn example(pool: zeph_db::DbPool) -> Result<(), Box<dyn std::error::Error>> {
68/// let store = PatternStore::new(pool, 14.0);
69/// store.observe("my-skill", "abc123", None, "bash",
70///     r#"{"command":"ls"}"#, ToolOutcome::Success, 42).await?;
71/// let predictions = store.predict("my-skill", "abc123", None, 3).await?;
72/// # Ok(())
73/// # }
74/// ```
75pub struct PatternStore {
76    pool: DbPool,
77    half_life_days: f64,
78    refresh_debounce: Arc<AsyncMutex<std::collections::HashMap<String, RefreshState>>>,
79    min_observations: u32,
80}
81
82impl PatternStore {
83    /// Create a new pattern store.
84    ///
85    /// `half_life_days` controls the exponential decay; 14 days is the default.
86    #[must_use]
87    pub fn new(pool: DbPool, half_life_days: f64) -> Self {
88        Self {
89            pool,
90            half_life_days,
91            refresh_debounce: Arc::new(AsyncMutex::new(std::collections::HashMap::new())),
92            min_observations: 5,
93        }
94    }
95
96    /// Set the minimum number of raw observations required before predicting.
97    #[must_use]
98    pub fn with_min_observations(mut self, n: u32) -> Self {
99        self.min_observations = n;
100        self
101    }
102
103    /// Record a completed tool invocation.
104    ///
105    /// Updates `count_decayed` using the exponential decay formula (computed in Rust,
106    /// not via `pow()` in `SQLite` — `pow` requires `SQLITE_ENABLE_MATH_FUNCTIONS`
107    /// which is not available in bundled `libsqlite3-sys`) and appends to
108    /// `success_raw` / `count_raw`. Also triggers a debounced [`refresh`](Self::refresh).
109    ///
110    /// # Errors
111    ///
112    /// Returns [`PatternError::Db`] on `SQLite` failure.
113    #[allow(clippy::too_many_arguments)]
114    pub async fn observe(
115        &self,
116        skill_name: &str,
117        skill_hash: &str,
118        prev_tool: Option<&str>,
119        next_tool: &str,
120        args_json: &str,
121        outcome: ToolOutcome,
122        latency_ms: u64,
123    ) -> Result<(), PatternError> {
124        let now = unix_now();
125        let half_life_secs = self.half_life_days * 86_400.0;
126        let success_delta = i64::from(outcome == ToolOutcome::Success);
127        let args: serde_json::Value = serde_json::from_str(args_json)?;
128        let args_obj = args.as_object().cloned().unwrap_or_default();
129        let args_fingerprint = {
130            let h = hash_args(&args_obj);
131            h.to_hex().to_string()
132        };
133        let tmpl = args_template(&args_obj);
134        #[allow(clippy::cast_possible_wrap)]
135        let latency_i64 = latency_ms as i64;
136
137        // Fetch the existing row's count_decayed + last_seen_at so we can compute
138        // the updated decay value in Rust (C6: avoids SQLite pow() which requires
139        // SQLITE_ENABLE_MATH_FUNCTIONS not present in bundled libsqlite3-sys).
140        let existing = zeph_db::query_as::<_, (f64, i64, i64, i64)>(
141            r"
142            SELECT count_decayed, last_seen_at, count_raw, avg_latency_ms
143            FROM tool_pattern_transitions
144            WHERE skill_name = ? AND skill_hash = ?
145              AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
146              AND next_tool = ? AND args_fingerprint = ?
147            ",
148        )
149        .bind(skill_name)
150        .bind(skill_hash)
151        .bind(prev_tool)
152        .bind(prev_tool)
153        .bind(next_tool)
154        .bind(&args_fingerprint)
155        .fetch_optional(&self.pool)
156        .await?;
157
158        if let Some((old_decayed, last_seen_at, old_count_raw, old_avg_latency)) = existing {
159            #[allow(clippy::cast_precision_loss)]
160            let elapsed = (now - last_seen_at).max(0) as f64;
161            let new_decayed = old_decayed * 0.5f64.powf(elapsed / half_life_secs) + 1.0;
162            let new_count_raw = old_count_raw + 1;
163            #[allow(clippy::cast_precision_loss)]
164            let new_avg_latency = (old_avg_latency * old_count_raw + latency_i64) / new_count_raw;
165
166            zeph_db::query(
167                r"
168                UPDATE tool_pattern_transitions SET
169                    count_decayed  = ?,
170                    count_raw      = ?,
171                    success_raw    = success_raw + ?,
172                    last_seen_at   = ?,
173                    avg_latency_ms = ?
174                WHERE skill_name = ? AND skill_hash = ?
175                  AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
176                  AND next_tool = ? AND args_fingerprint = ?
177                ",
178            )
179            .bind(new_decayed)
180            .bind(new_count_raw)
181            .bind(success_delta)
182            .bind(now)
183            .bind(new_avg_latency)
184            .bind(skill_name)
185            .bind(skill_hash)
186            .bind(prev_tool)
187            .bind(prev_tool)
188            .bind(next_tool)
189            .bind(&args_fingerprint)
190            .execute(&self.pool)
191            .await?;
192        } else {
193            zeph_db::query(
194                r"
195                INSERT INTO tool_pattern_transitions
196                    (skill_name, skill_hash, prev_tool, next_tool, args_fingerprint,
197                     args_template, count_raw, success_raw, count_decayed, last_seen_at, avg_latency_ms)
198                VALUES (?, ?, ?, ?, ?, ?, 1, ?, 1.0, ?, ?)
199                ",
200            )
201            .bind(skill_name)
202            .bind(skill_hash)
203            .bind(prev_tool)
204            .bind(next_tool)
205            .bind(&args_fingerprint)
206            .bind(&tmpl)
207            .bind(success_delta)
208            .bind(now)
209            .bind(latency_i64)
210            .execute(&self.pool)
211            .await?;
212        }
213
214        self.debounced_refresh(skill_name, skill_hash, prev_tool)
215            .await;
216        Ok(())
217    }
218
219    /// Return the top-`k` predicted next tool calls for `(skill, prev_tool)`.
220    ///
221    /// Only returns predictions with `wilson_lower_bound >= 0.5` and
222    /// `count_raw >= min_observations`.
223    ///
224    /// # Errors
225    ///
226    /// Returns [`PatternError::Db`] on `SQLite` failure.
227    pub async fn predict(
228        &self,
229        skill_name: &str,
230        skill_hash: &str,
231        prev_tool: Option<&str>,
232        k: u8,
233    ) -> Result<Vec<Prediction>, PatternError> {
234        let rows = zeph_db::query_as::<_, (String, String, f64, f64, i64)>(
235            r"
236            SELECT next_tool, args_template, score, wilson_lower_bound, rank
237            FROM tool_pattern_predictions
238            WHERE skill_name = ? AND skill_hash = ?
239              AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
240              AND wilson_lower_bound >= 0.5
241            ORDER BY rank ASC
242            LIMIT ?
243            ",
244        )
245        .bind(skill_name)
246        .bind(skill_hash)
247        .bind(prev_tool)
248        .bind(prev_tool)
249        .bind(i64::from(k))
250        .fetch_all(&self.pool)
251        .await?;
252
253        let predictions = rows
254            .into_iter()
255            .enumerate()
256            .filter_map(|(i, (next_tool, args_template, score, _wilson, _rank))| {
257                let args: serde_json::Map<String, serde_json::Value> =
258                    serde_json::from_str(&args_template).ok()?;
259                Some(Prediction {
260                    tool_id: zeph_common::ToolName::new(next_tool),
261                    args,
262                    #[allow(clippy::cast_possible_truncation)]
263                    confidence: score as f32,
264                    source: PredictionSource::HistoryPattern {
265                        skill: skill_name.to_owned(),
266                        #[allow(clippy::cast_possible_truncation)]
267                        rank: i as u8,
268                    },
269                })
270            })
271            .collect();
272
273        Ok(predictions)
274    }
275
276    /// Recompute and materialize predictions for `(skill, skill_hash, prev_tool)`.
277    ///
278    /// Debounced to at most once per 60 s per `(skill_hash, prev_tool)`.
279    /// Runs DELETE + N INSERTs inside a single transaction (H3: atomic refresh).
280    ///
281    /// # Errors
282    ///
283    /// Returns [`PatternError::Db`] on `SQLite` failure.
284    pub async fn refresh(
285        &self,
286        skill_name: &str,
287        skill_hash: &str,
288        prev_tool: Option<&str>,
289    ) -> Result<(), PatternError> {
290        let min_obs = self.min_observations;
291        let half_life_secs = self.half_life_days * 86_400.0;
292        let now = unix_now();
293
294        // Fetch all sibling transitions for Wilson + normalization.
295        // Also fetch args_template so predictions carry the real type-placeholder template (H1).
296        let rows = zeph_db::query_as::<_, (String, String, String, i64, i64, f64, i64)>(
297            r"
298            SELECT next_tool, args_fingerprint, args_template,
299                   count_raw, success_raw, count_decayed, last_seen_at
300            FROM tool_pattern_transitions
301            WHERE skill_name = ? AND skill_hash = ?
302              AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
303              AND count_raw >= ?
304            ",
305        )
306        .bind(skill_name)
307        .bind(skill_hash)
308        .bind(prev_tool)
309        .bind(prev_tool)
310        .bind(i64::from(min_obs))
311        .fetch_all(&self.pool)
312        .await?;
313
314        if rows.is_empty() {
315            return Ok(());
316        }
317
318        // Recompute decayed counts and Wilson scores in Rust, then normalize and rank.
319        let scored = score_rows(rows, now, half_life_secs);
320        if scored.is_empty() {
321            return Ok(());
322        }
323
324        // Wrap DELETE + N INSERTs in a transaction to prevent partial state on crash (H3).
325        let mut tx = zeph_db::begin(&self.pool).await?;
326
327        zeph_db::query(
328            "DELETE FROM tool_pattern_predictions \
329             WHERE skill_name = ? AND skill_hash = ? \
330             AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))",
331        )
332        .bind(skill_name)
333        .bind(skill_hash)
334        .bind(prev_tool)
335        .bind(prev_tool)
336        .execute(&mut *tx)
337        .await?;
338
339        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
340        for (rank, (next_tool, args_fp, tmpl, score, wilson)) in scored.iter().enumerate().take(10)
341        {
342            let rank_i64 = rank as i64;
343            zeph_db::query(
344                r"
345                INSERT OR REPLACE INTO tool_pattern_predictions
346                    (skill_name, skill_hash, prev_tool, next_tool, args_fingerprint,
347                     args_template, score, wilson_lower_bound, rank)
348                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
349                ",
350            )
351            .bind(skill_name)
352            .bind(skill_hash)
353            .bind(prev_tool)
354            .bind(next_tool)
355            .bind(args_fp)
356            .bind(tmpl)
357            .bind(score)
358            .bind(wilson)
359            .bind(rank_i64)
360            .execute(&mut *tx)
361            .await?;
362        }
363
364        tx.commit().await?;
365
366        debug!(
367            skill = skill_name,
368            prev_tool = prev_tool.unwrap_or("<activation>"),
369            "PASTE: refreshed {} predictions",
370            scored.len().min(10)
371        );
372        Ok(())
373    }
374
375    /// Purge `tool_pattern_transitions` rows with stale `skill_hash` older than 30 days.
376    ///
377    /// # Errors
378    ///
379    /// Returns [`PatternError::Db`] on `SQLite` failure.
380    pub async fn vacuum(&self) -> Result<u64, PatternError> {
381        let cutoff = unix_now() - 30 * 86_400;
382        let result = zeph_db::query("DELETE FROM tool_pattern_transitions WHERE last_seen_at < ?")
383            .bind(cutoff)
384            .execute(&self.pool)
385            .await?;
386        let rows = result.rows_affected();
387        if rows > 0 {
388            debug!("PASTE vacuum: removed {} stale rows", rows);
389        }
390        Ok(rows)
391    }
392
393    async fn debounced_refresh(&self, skill_name: &str, skill_hash: &str, prev_tool: Option<&str>) {
394        let key = format!("{skill_hash}:{}", prev_tool.unwrap_or(""));
395        let should_refresh = {
396            let mut map = self.refresh_debounce.lock().await;
397            let state = map
398                .entry(key.clone())
399                .or_insert(RefreshState { last_refresh: None });
400            match state.last_refresh {
401                None => true,
402                Some(t) => t.elapsed() >= Duration::from_mins(1),
403            }
404        };
405        if should_refresh {
406            if let Err(e) = self.refresh(skill_name, skill_hash, prev_tool).await {
407                warn!("PASTE refresh failed: {e}");
408            }
409            let mut map = self.refresh_debounce.lock().await;
410            if let Some(state) = map.get_mut(&key) {
411                state.last_refresh = Some(std::time::Instant::now());
412            }
413        }
414    }
415}
416
417/// Compute decay-adjusted Wilson scores for a batch of transition rows and return them
418/// sorted descending by score (top-K ready).
419///
420/// `rows` tuples: `(next_tool, args_fp, args_template, count_raw, success_raw, count_decayed, last_seen_at)`
421fn score_rows(
422    rows: Vec<(String, String, String, i64, i64, f64, i64)>,
423    now: i64,
424    half_life_secs: f64,
425) -> Vec<(String, String, String, f64, f64)> {
426    let decayed: Vec<_> = rows
427        .into_iter()
428        .map(
429            |(tool, fp, tmpl, count_raw, success_raw, count_decayed, last_seen_at)| {
430                #[allow(clippy::cast_precision_loss)]
431                let elapsed = now.saturating_sub(last_seen_at) as f64;
432                let current_decay = count_decayed * 0.5f64.powf(elapsed / half_life_secs);
433                #[allow(clippy::cast_sign_loss)]
434                let wilson = wilson_lower_bound(success_raw as u64, count_raw as u64);
435                (tool, fp, tmpl, current_decay, wilson)
436            },
437        )
438        .collect();
439
440    let total: f64 = decayed.iter().map(|(_, _, _, d, _)| d).sum();
441    if total <= 0.0 {
442        return vec![];
443    }
444
445    let mut scored: Vec<_> = decayed
446        .into_iter()
447        .map(|(tool, fp, tmpl, d, wilson)| ((d / total) * wilson, tool, fp, tmpl, wilson))
448        .collect();
449    scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
450    scored
451        .into_iter()
452        .map(|(score, tool, fp, tmpl, wilson)| (tool, fp, tmpl, score, wilson))
453        .collect()
454}
455
456#[allow(clippy::cast_possible_wrap)]
457fn unix_now() -> i64 {
458    SystemTime::now()
459        .duration_since(UNIX_EPOCH)
460        .unwrap_or_default()
461        .as_secs() as i64
462}
463
464/// Wilson 95% one-sided lower confidence bound.
465#[allow(clippy::cast_precision_loss)]
466fn wilson_lower_bound(successes: u64, n: u64) -> f64 {
467    if n == 0 {
468        return 0.0;
469    }
470    let n = n as f64;
471    let p_hat = successes as f64 / n;
472    let z2 = Z * Z;
473    let numerator =
474        p_hat + z2 / (2.0 * n) - Z * (p_hat * (1.0 - p_hat) / n + z2 / (4.0 * n * n)).sqrt();
475    let denominator = 1.0 + z2 / n;
476    (numerator / denominator).max(0.0)
477}
478
479#[cfg(test)]
480mod tests {
481    use super::*;
482
483    #[test]
484    fn wilson_zero_observations() {
485        assert!((wilson_lower_bound(0, 0) - 0.0_f64).abs() < f64::EPSILON);
486    }
487
488    #[test]
489    fn wilson_all_success_small_n() {
490        // 3/3 successes → lower bound well below 1.0 (small-sample conservatism)
491        let lb = wilson_lower_bound(3, 3);
492        assert!(lb > 0.0 && lb < 1.0, "got {lb}");
493    }
494
495    #[test]
496    fn wilson_zero_success() {
497        let lb = wilson_lower_bound(0, 10);
498        assert!(lb < 0.1, "got {lb}");
499    }
500
501    #[test]
502    fn fingerprint_deterministic_different_order() {
503        fn fp(json: &str) -> String {
504            let v: serde_json::Value = serde_json::from_str(json).unwrap();
505            let obj = v.as_object().cloned().unwrap_or_default();
506            hash_args(&obj).to_hex().to_string()
507        }
508        let a = r#"{"z": 1, "a": 2}"#;
509        let b = r#"{"a": 2, "z": 1}"#;
510        assert_eq!(fp(a), fp(b));
511    }
512}