Skip to main content

zeph_core/agent/speculative/
cache.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0
3
//! In-flight speculative handle cache.
//!
//! Keyed by `(ToolName, blake3::Hash)` where the hash covers the tool's argument map.
//! Bounded by `max_in_flight`; oldest handle (by `started_at`) is evicted and cancelled
//! when the bound is exceeded.
9
10#![allow(dead_code)]
11
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Instant;
15
16use parking_lot::Mutex;
17use tokio_util::sync::CancellationToken;
18use zeph_common::ToolName;
19use zeph_common::task_supervisor::{BlockingError, BlockingHandle};
20use zeph_tools::{ToolError, ToolOutput};
21
22/// Unique key for a speculative handle: tool name + BLAKE3 hash of normalized args.
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct HandleKey {
25    pub tool_id: ToolName,
26    pub args_hash: blake3::Hash,
27}
28
29/// An in-flight speculative execution handle.
30///
31/// Created when the engine dispatches a speculative tool call. Committed when the LLM
32/// confirms the same call on `ToolUseStop`; cancelled on mismatch or TTL expiry.
33pub struct SpeculativeHandle {
34    pub key: HandleKey,
35    pub join: BlockingHandle<Result<Option<ToolOutput>, ToolError>>,
36    pub cancel: CancellationToken,
37    /// Absolute wall-clock deadline; handle is cancelled by the sweeper when exceeded.
38    pub ttl_deadline: tokio::time::Instant,
39    pub started_at: Instant,
40}
41
42impl std::fmt::Debug for SpeculativeHandle {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        f.debug_struct("SpeculativeHandle")
45            .field("key", &self.key)
46            .field("ttl_deadline", &self.ttl_deadline)
47            .field("started_at", &self.started_at)
48            .finish_non_exhaustive()
49    }
50}
51
52impl SpeculativeHandle {
53    /// Cancel the in-flight task.
54    pub fn cancel(self) {
55        self.cancel.cancel();
56        self.join.abort();
57    }
58
59    /// Await the result; blocks until the task finishes or is cancelled.
60    ///
61    /// # Errors
62    ///
63    /// Returns [`ToolError::Execution`] if the task was cancelled or panicked.
64    pub async fn commit(self) -> Result<Option<ToolOutput>, ToolError> {
65        match self.join.join().await {
66            Ok(r) => r,
67            Err(BlockingError::Panicked) => Err(ToolError::Execution(std::io::Error::other(
68                "speculative task panicked",
69            ))),
70            Err(BlockingError::SupervisorDropped) => Err(ToolError::Execution(
71                std::io::Error::other("speculative task cancelled"),
72            )),
73        }
74    }
75}
76
77pub struct CacheInner {
78    pub handles: HashMap<HandleKey, SpeculativeHandle>,
79}
80
81/// Cache for in-flight speculative handles, bounded by `max_in_flight`.
82///
83/// Thread-safe; all operations hold a short `parking_lot::Mutex` lock.
84/// The inner `Arc<Mutex<CacheInner>>` is shared with the background TTL sweeper so
85/// both operate on the same handle set (C2: no separate empty instance in the sweeper).
86pub struct SpeculativeCache {
87    pub(crate) inner: Arc<Mutex<CacheInner>>,
88    max: usize,
89}
90
91impl SpeculativeCache {
92    /// Create a new cache with the given capacity.
93    #[must_use]
94    pub fn new(max_in_flight: usize) -> Self {
95        Self {
96            inner: Arc::new(Mutex::new(CacheInner {
97                handles: HashMap::new(),
98            })),
99            max: max_in_flight.clamp(1, 16),
100        }
101    }
102
103    /// Return a cloned `Arc` to the inner mutex so it can be shared with a sweeper task.
104    ///
105    /// The sweeper calls [`SpeculativeCache::sweep_expired_inner`] on the shared `Arc`
106    /// instead of constructing a second `SpeculativeCache` that would have separate storage.
107    #[must_use]
108    pub fn shared_inner(&self) -> Arc<Mutex<CacheInner>> {
109        Arc::clone(&self.inner)
110    }
111
112    /// Cancel and remove all handles whose TTL deadline has passed, operating on a raw `Arc`.
113    ///
114    /// Intended for use by the sweeper task, which holds only the `Arc` (not a full
115    /// `SpeculativeCache` wrapper).
116    pub fn sweep_expired_inner(inner: &Mutex<CacheInner>) {
117        let now = tokio::time::Instant::now();
118        let mut g = inner.lock();
119        let expired: Vec<HandleKey> = g
120            .handles
121            .iter()
122            .filter(|(_, h)| h.ttl_deadline <= now)
123            .map(|(k, _)| k.clone())
124            .collect();
125        for key in expired {
126            if let Some(h) = g.handles.remove(&key) {
127                h.cancel();
128            }
129        }
130    }
131
132    /// Insert a new handle. If at capacity, evicts and cancels the oldest.
133    ///
134    /// If a handle with the same key already exists it is replaced and explicitly cancelled
135    /// so the underlying tokio task does not keep running (C4: no silent drop).
136    pub fn insert(&self, handle: SpeculativeHandle) {
137        let mut g = self.inner.lock();
138        if g.handles.len() >= self.max {
139            let oldest_key = g
140                .handles
141                .values()
142                .min_by_key(|h| h.started_at)
143                .map(|h| h.key.clone());
144            if let Some(key) = oldest_key
145                && let Some(evicted) = g.handles.remove(&key)
146            {
147                evicted.cancel();
148            }
149        }
150        if let Some(displaced) = g.handles.insert(handle.key.clone(), handle) {
151            displaced.cancel();
152        }
153    }
154
155    /// Find and remove a handle matching `tool_id` + `args_hash`.
156    #[must_use]
157    pub fn take_match(
158        &self,
159        tool_id: &ToolName,
160        args_hash: &blake3::Hash,
161    ) -> Option<SpeculativeHandle> {
162        let key = HandleKey {
163            tool_id: tool_id.clone(),
164            args_hash: *args_hash,
165        };
166        self.inner.lock().handles.remove(&key)
167    }
168
169    /// Remove and cancel the first handle whose `tool_id` matches, if any.
170    ///
171    /// Used when the args hash is not known (e.g., on tool-id mismatch at dispatch time).
172    pub fn cancel_by_tool_id(&self, tool_id: &ToolName) {
173        let mut g = self.inner.lock();
174        let key = g.handles.keys().find(|k| &k.tool_id == tool_id).cloned();
175        if let Some(key) = key
176            && let Some(h) = g.handles.remove(&key)
177        {
178            h.cancel();
179        }
180    }
181
182    /// Cancel and remove all handles whose TTL deadline has passed.
183    ///
184    /// Called by the sweeper task every 5 s.
185    pub fn sweep_expired(&self) {
186        Self::sweep_expired_inner(&self.inner);
187    }
188
189    /// Cancel and remove all remaining handles (called at turn boundary).
190    pub fn cancel_all(&self) {
191        let mut g = self.inner.lock();
192        for (_, h) in g.handles.drain() {
193            h.cancel();
194        }
195    }
196
197    /// Number of in-flight handles.
198    #[must_use]
199    pub fn len(&self) -> usize {
200        self.inner.lock().handles.len()
201    }
202
203    /// True when the cache is empty.
204    #[must_use]
205    pub fn is_empty(&self) -> bool {
206        self.len() == 0
207    }
208}
209
210/// Compute a BLAKE3 hash over a normalized JSON args map.
211///
212/// Keys are sorted lexicographically before hashing to ensure arg-order independence.
213#[must_use]
214pub fn hash_args(args: &serde_json::Map<String, serde_json::Value>) -> blake3::Hash {
215    let mut keys: Vec<&str> = args.keys().map(String::as_str).collect();
216    keys.sort_unstable();
217    let mut hasher = blake3::Hasher::new();
218    for k in keys {
219        hasher.update(k.as_bytes());
220        hasher.update(b"\x00");
221        let v = args[k].to_string();
222        hasher.update(v.as_bytes());
223        hasher.update(b"\x00");
224    }
225    hasher.finalize()
226}
227
228/// Produce a normalized args template: top-level keys with their JSON type as placeholder value.
229///
230/// Used by `PatternStore` to store a template that is stable across observations with varying
231/// argument values. Example: `{"command":"<string>","timeout":"<number>"}`.
232#[must_use]
233pub fn args_template(args: &serde_json::Map<String, serde_json::Value>) -> String {
234    let template: serde_json::Map<String, serde_json::Value> = args
235        .iter()
236        .map(|(k, v)| {
237            let placeholder = match v {
238                serde_json::Value::String(_) => serde_json::json!("<string>"),
239                serde_json::Value::Number(_) => serde_json::json!("<number>"),
240                serde_json::Value::Bool(_) => serde_json::json!("<bool>"),
241                serde_json::Value::Array(_) => serde_json::json!("<array>"),
242                serde_json::Value::Object(_) => serde_json::json!("<object>"),
243                serde_json::Value::Null => serde_json::json!(null),
244            };
245            (k.clone(), placeholder)
246        })
247        .collect();
248    serde_json::Value::Object(template).to_string()
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Map, Value, json};

    /// Build an argument map, inserting the pairs in the order given.
    fn map_of(pairs: &[(&str, Value)]) -> Map<String, Value> {
        let mut map = Map::new();
        for (name, value) in pairs {
            map.insert((*name).to_owned(), value.clone());
        }
        map
    }

    /// Insertion order of the keys must not influence the digest.
    #[test]
    fn hash_args_order_independent() {
        let z_first = map_of(&[("z", json!(1)), ("a", json!(2))]);
        let a_first = map_of(&[("a", json!(2)), ("z", json!(1))]);

        assert_eq!(hash_args(&z_first), hash_args(&a_first));
    }

    /// The same key with a different value must produce a different digest.
    #[test]
    fn hash_args_different_values() {
        let one = map_of(&[("x", json!(1))]);
        let two = map_of(&[("x", json!(2))]);

        assert_ne!(hash_args(&one), hash_args(&two));
    }

    /// String, number and bool values are each replaced by their type placeholder.
    #[test]
    fn args_template_replaces_values_with_type_placeholders() {
        let args = map_of(&[
            ("cmd", json!("ls -la")),
            ("timeout", json!(30)),
            ("flag", json!(true)),
        ]);

        let rendered = args_template(&args);

        assert!(rendered.contains("<string>"));
        assert!(rendered.contains("<number>"));
        assert!(rendered.contains("<bool>"));
    }
}