Skip to main content

zeph_core/agent/speculative/
cache.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! In-flight speculative handle cache.
5//!
6//! Keyed by `(ToolName, blake3::Hash)` where the hash covers the tool's argument map.
7//! Bounded by `max_in_flight`; oldest handle (by `started_at`) is evicted and cancelled
8//! when the bound is exceeded.
9
10#![allow(dead_code)]
11
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Instant;
15
16use parking_lot::Mutex;
17use tokio::task::JoinHandle;
18use tokio_util::sync::CancellationToken;
19use zeph_common::ToolName;
20use zeph_tools::{ToolError, ToolOutput};
21
22/// Unique key for a speculative handle: tool name + BLAKE3 hash of normalized args.
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct HandleKey {
25    pub tool_id: ToolName,
26    pub args_hash: blake3::Hash,
27}
28
29/// An in-flight speculative execution handle.
30///
31/// Created when the engine dispatches a speculative tool call. Committed when the LLM
32/// confirms the same call on `ToolUseStop`; cancelled on mismatch or TTL expiry.
33pub struct SpeculativeHandle {
34    pub key: HandleKey,
35    pub join: JoinHandle<Result<Option<ToolOutput>, ToolError>>,
36    pub cancel: CancellationToken,
37    /// Absolute wall-clock deadline; handle is cancelled by the sweeper when exceeded.
38    pub ttl_deadline: tokio::time::Instant,
39    pub started_at: Instant,
40}
41
42impl std::fmt::Debug for SpeculativeHandle {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        f.debug_struct("SpeculativeHandle")
45            .field("key", &self.key)
46            .field("ttl_deadline", &self.ttl_deadline)
47            .field("started_at", &self.started_at)
48            .finish_non_exhaustive()
49    }
50}
51
52impl SpeculativeHandle {
53    /// Cancel the in-flight task.
54    pub fn cancel(self) {
55        self.cancel.cancel();
56        self.join.abort();
57    }
58
59    /// Await the result; blocks until the task finishes or is cancelled.
60    ///
61    /// # Errors
62    ///
63    /// Returns [`ToolError::Execution`] if the task was cancelled or panicked.
64    pub async fn commit(self) -> Result<Option<ToolOutput>, ToolError> {
65        match self.join.await {
66            Ok(r) => r,
67            Err(e) if e.is_cancelled() => Err(ToolError::Execution(std::io::Error::other(
68                "speculative task cancelled",
69            ))),
70            Err(e) => Err(ToolError::Execution(std::io::Error::other(e.to_string()))),
71        }
72    }
73}
74
75pub struct CacheInner {
76    pub handles: HashMap<HandleKey, SpeculativeHandle>,
77}
78
79/// Cache for in-flight speculative handles, bounded by `max_in_flight`.
80///
81/// Thread-safe; all operations hold a short `parking_lot::Mutex` lock.
82/// The inner `Arc<Mutex<CacheInner>>` is shared with the background TTL sweeper so
83/// both operate on the same handle set (C2: no separate empty instance in the sweeper).
84pub struct SpeculativeCache {
85    pub(crate) inner: Arc<Mutex<CacheInner>>,
86    max: usize,
87}
88
89impl SpeculativeCache {
90    /// Create a new cache with the given capacity.
91    #[must_use]
92    pub fn new(max_in_flight: usize) -> Self {
93        Self {
94            inner: Arc::new(Mutex::new(CacheInner {
95                handles: HashMap::new(),
96            })),
97            max: max_in_flight.clamp(1, 16),
98        }
99    }
100
101    /// Return a cloned `Arc` to the inner mutex so it can be shared with a sweeper task.
102    ///
103    /// The sweeper calls [`SpeculativeCache::sweep_expired_inner`] on the shared `Arc`
104    /// instead of constructing a second `SpeculativeCache` that would have separate storage.
105    #[must_use]
106    pub fn shared_inner(&self) -> Arc<Mutex<CacheInner>> {
107        Arc::clone(&self.inner)
108    }
109
110    /// Cancel and remove all handles whose TTL deadline has passed, operating on a raw `Arc`.
111    ///
112    /// Intended for use by the sweeper task, which holds only the `Arc` (not a full
113    /// `SpeculativeCache` wrapper).
114    pub fn sweep_expired_inner(inner: &Mutex<CacheInner>) {
115        let now = tokio::time::Instant::now();
116        let mut g = inner.lock();
117        let expired: Vec<HandleKey> = g
118            .handles
119            .iter()
120            .filter(|(_, h)| h.ttl_deadline <= now)
121            .map(|(k, _)| k.clone())
122            .collect();
123        for key in expired {
124            if let Some(h) = g.handles.remove(&key) {
125                h.cancel();
126            }
127        }
128    }
129
130    /// Insert a new handle. If at capacity, evicts and cancels the oldest.
131    ///
132    /// If a handle with the same key already exists it is replaced and explicitly cancelled
133    /// so the underlying tokio task does not keep running (C4: no silent drop).
134    pub fn insert(&self, handle: SpeculativeHandle) {
135        let mut g = self.inner.lock();
136        if g.handles.len() >= self.max {
137            let oldest_key = g
138                .handles
139                .values()
140                .min_by_key(|h| h.started_at)
141                .map(|h| h.key.clone());
142            if let Some(key) = oldest_key
143                && let Some(evicted) = g.handles.remove(&key)
144            {
145                evicted.cancel();
146            }
147        }
148        if let Some(displaced) = g.handles.insert(handle.key.clone(), handle) {
149            displaced.cancel();
150        }
151    }
152
153    /// Find and remove a handle matching `tool_id` + `args_hash`.
154    #[must_use]
155    pub fn take_match(
156        &self,
157        tool_id: &ToolName,
158        args_hash: &blake3::Hash,
159    ) -> Option<SpeculativeHandle> {
160        let key = HandleKey {
161            tool_id: tool_id.clone(),
162            args_hash: *args_hash,
163        };
164        self.inner.lock().handles.remove(&key)
165    }
166
167    /// Remove and cancel the first handle whose `tool_id` matches, if any.
168    ///
169    /// Used when the args hash is not known (e.g., on tool-id mismatch at dispatch time).
170    pub fn cancel_by_tool_id(&self, tool_id: &ToolName) {
171        let mut g = self.inner.lock();
172        let key = g.handles.keys().find(|k| &k.tool_id == tool_id).cloned();
173        if let Some(key) = key
174            && let Some(h) = g.handles.remove(&key)
175        {
176            h.cancel();
177        }
178    }
179
180    /// Cancel and remove all handles whose TTL deadline has passed.
181    ///
182    /// Called by the sweeper task every 5 s.
183    pub fn sweep_expired(&self) {
184        Self::sweep_expired_inner(&self.inner);
185    }
186
187    /// Cancel and remove all remaining handles (called at turn boundary).
188    pub fn cancel_all(&self) {
189        let mut g = self.inner.lock();
190        for (_, h) in g.handles.drain() {
191            h.cancel();
192        }
193    }
194
195    /// Number of in-flight handles.
196    #[must_use]
197    pub fn len(&self) -> usize {
198        self.inner.lock().handles.len()
199    }
200
201    /// True when the cache is empty.
202    #[must_use]
203    pub fn is_empty(&self) -> bool {
204        self.len() == 0
205    }
206}
207
208/// Compute a BLAKE3 hash over a normalized JSON args map.
209///
210/// Keys are sorted lexicographically before hashing to ensure arg-order independence.
211#[must_use]
212pub fn hash_args(args: &serde_json::Map<String, serde_json::Value>) -> blake3::Hash {
213    let mut keys: Vec<&str> = args.keys().map(String::as_str).collect();
214    keys.sort_unstable();
215    let mut hasher = blake3::Hasher::new();
216    for k in keys {
217        hasher.update(k.as_bytes());
218        hasher.update(b"\x00");
219        let v = args[k].to_string();
220        hasher.update(v.as_bytes());
221        hasher.update(b"\x00");
222    }
223    hasher.finalize()
224}
225
226/// Produce a normalized args template: top-level keys with their JSON type as placeholder value.
227///
228/// Used by `PatternStore` to store a template that is stable across observations with varying
229/// argument values. Example: `{"command":"<string>","timeout":"<number>"}`.
230#[must_use]
231pub fn args_template(args: &serde_json::Map<String, serde_json::Value>) -> String {
232    let template: serde_json::Map<String, serde_json::Value> = args
233        .iter()
234        .map(|(k, v)| {
235            let placeholder = match v {
236                serde_json::Value::String(_) => serde_json::json!("<string>"),
237                serde_json::Value::Number(_) => serde_json::json!("<number>"),
238                serde_json::Value::Bool(_) => serde_json::json!("<bool>"),
239                serde_json::Value::Array(_) => serde_json::json!("<array>"),
240                serde_json::Value::Object(_) => serde_json::json!("<object>"),
241                serde_json::Value::Null => serde_json::json!(null),
242            };
243            (k.clone(), placeholder)
244        })
245        .collect();
246    serde_json::Value::Object(template).to_string()
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an args map by inserting `pairs` in the given order.
    fn args_of(
        pairs: &[(&str, serde_json::Value)],
    ) -> serde_json::Map<String, serde_json::Value> {
        let mut map = serde_json::Map::new();
        for (key, value) in pairs {
            map.insert((*key).to_owned(), value.clone());
        }
        map
    }

    /// The same entries inserted in a different order must produce the same hash.
    #[test]
    fn hash_args_order_independent() {
        let forward = args_of(&[("z", serde_json::json!(1)), ("a", serde_json::json!(2))]);
        let reverse = args_of(&[("a", serde_json::json!(2)), ("z", serde_json::json!(1))]);

        assert_eq!(hash_args(&forward), hash_args(&reverse));
    }

    /// Changing a value under the same key must change the hash.
    #[test]
    fn hash_args_different_values() {
        let one = args_of(&[("x", serde_json::json!(1))]);
        let two = args_of(&[("x", serde_json::json!(2))]);

        assert_ne!(hash_args(&one), hash_args(&two));
    }

    /// String, number and bool values are each replaced by their type placeholder.
    #[test]
    fn args_template_replaces_values_with_type_placeholders() {
        let args = args_of(&[
            ("cmd", serde_json::json!("ls -la")),
            ("timeout", serde_json::json!(30)),
            ("flag", serde_json::json!(true)),
        ]);
        let rendered = args_template(&args);
        for placeholder in ["<string>", "<number>", "<bool>"] {
            assert!(rendered.contains(placeholder));
        }
    }
}