Skip to main content

zeph_core/agent/speculative/
cache.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
//! In-flight speculative handle cache.
//!
//! Keyed by `HandleKey`: the tool name, a BLAKE3 hash of the tool's argument map, and a
//! BLAKE3 hash of the execution context.
//! Bounded by `max_in_flight`; the oldest handle (by `started_at`) is evicted and cancelled
//! when the bound is exceeded.
9
10#![allow(dead_code)]
11
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Instant;
15
16use parking_lot::Mutex;
17use tokio_util::sync::CancellationToken;
18use zeph_common::ToolName;
19use zeph_common::task_supervisor::{BlockingError, BlockingHandle};
20use zeph_tools::{ExecutionContext, ToolError, ToolOutput};
21
/// Unique key for a speculative handle: tool name + BLAKE3 hash of normalized args + context.
///
/// Used as the `HashMap` key inside [`CacheInner`]; two handles occupy the same cache slot
/// only when all three fields compare equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct HandleKey {
    /// Name of the tool the speculation was dispatched for.
    pub tool_id: ToolName,
    /// Hash of the tool's argument map.
    pub args_hash: blake3::Hash,
    /// Hash of the [`ExecutionContext`] fields. Two calls with different contexts must not share
    /// a speculative result — the resolved env/cwd would differ.
    pub context_hash: blake3::Hash,
}
31
/// An in-flight speculative execution handle.
///
/// Created when the engine dispatches a speculative tool call. Committed when the LLM
/// confirms the same call on `ToolUseStop`; cancelled on mismatch or TTL expiry.
pub struct SpeculativeHandle {
    /// Cache key identifying this speculation (tool, args hash, context hash).
    pub key: HandleKey,
    /// Handle to the running task; awaited by `commit` and aborted by `cancel`.
    pub join: BlockingHandle<Result<Option<ToolOutput>, ToolError>>,
    /// Token signalled by `cancel` before the task handle is aborted.
    pub cancel: CancellationToken,
    /// Deadline on the tokio clock; the sweeper cancels the handle once
    /// `tokio::time::Instant::now()` has reached it.
    pub ttl_deadline: tokio::time::Instant,
    /// When the speculation was started; the cache evicts the earliest `started_at` first.
    pub started_at: Instant,
}
44
45impl std::fmt::Debug for SpeculativeHandle {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        f.debug_struct("SpeculativeHandle")
48            .field("key", &self.key)
49            .field("ttl_deadline", &self.ttl_deadline)
50            .field("started_at", &self.started_at)
51            .finish_non_exhaustive()
52    }
53}
54
55impl SpeculativeHandle {
56    /// Cancel the in-flight task.
57    pub fn cancel(self) {
58        self.cancel.cancel();
59        self.join.abort();
60    }
61
62    /// Await the result; blocks until the task finishes or is cancelled.
63    ///
64    /// # Errors
65    ///
66    /// Returns [`ToolError::Execution`] if the task was cancelled or panicked.
67    pub async fn commit(self) -> Result<Option<ToolOutput>, ToolError> {
68        match self.join.join().await {
69            Ok(r) => r,
70            Err(BlockingError::Panicked) => Err(ToolError::Execution(std::io::Error::other(
71                "speculative task panicked",
72            ))),
73            Err(BlockingError::SupervisorDropped) => Err(ToolError::Execution(
74                std::io::Error::other("speculative task cancelled"),
75            )),
76            Err(_) => Err(ToolError::Execution(std::io::Error::other(
77                "speculative task failed",
78            ))),
79        }
80    }
81}
82
/// Handle storage shared between [`SpeculativeCache`] and the background TTL sweeper.
///
/// Always accessed through the surrounding `Mutex`.
pub struct CacheInner {
    /// In-flight handles indexed by their [`HandleKey`].
    pub handles: HashMap<HandleKey, SpeculativeHandle>,
}
86
/// Cache for in-flight speculative handles, bounded by `max_in_flight`.
///
/// Thread-safe; all operations hold a short `parking_lot::Mutex` lock.
/// The inner `Arc<Mutex<CacheInner>>` is shared with the background TTL sweeper so
/// both operate on the same handle set (C2: no separate empty instance in the sweeper).
pub struct SpeculativeCache {
    /// Shared handle storage; a clone of this `Arc` is what `shared_inner` returns.
    pub(crate) inner: Arc<Mutex<CacheInner>>,
    /// Maximum number of handles held at once (`new` clamps it to `1..=16`).
    max: usize,
}
96
97impl SpeculativeCache {
98    /// Create a new cache with the given capacity.
99    #[must_use]
100    pub fn new(max_in_flight: usize) -> Self {
101        Self {
102            inner: Arc::new(Mutex::new(CacheInner {
103                handles: HashMap::new(),
104            })),
105            max: max_in_flight.clamp(1, 16),
106        }
107    }
108
109    /// Return a cloned `Arc` to the inner mutex so it can be shared with a sweeper task.
110    ///
111    /// The sweeper calls [`SpeculativeCache::sweep_expired_inner`] on the shared `Arc`
112    /// instead of constructing a second `SpeculativeCache` that would have separate storage.
113    #[must_use]
114    pub fn shared_inner(&self) -> Arc<Mutex<CacheInner>> {
115        Arc::clone(&self.inner)
116    }
117
118    /// Cancel and remove all handles whose TTL deadline has passed, operating on a raw `Arc`.
119    ///
120    /// Intended for use by the sweeper task, which holds only the `Arc` (not a full
121    /// `SpeculativeCache` wrapper).
122    pub fn sweep_expired_inner(inner: &Mutex<CacheInner>) {
123        let now = tokio::time::Instant::now();
124        let mut g = inner.lock();
125        let expired: Vec<HandleKey> = g
126            .handles
127            .iter()
128            .filter(|(_, h)| h.ttl_deadline <= now)
129            .map(|(k, _)| k.clone())
130            .collect();
131        for key in expired {
132            if let Some(h) = g.handles.remove(&key) {
133                h.cancel();
134            }
135        }
136    }
137
138    /// Insert a new handle. If at capacity, evicts and cancels the oldest.
139    ///
140    /// If a handle with the same key already exists it is replaced and explicitly cancelled
141    /// so the underlying tokio task does not keep running (C4: no silent drop).
142    pub fn insert(&self, handle: SpeculativeHandle) {
143        let mut g = self.inner.lock();
144        if g.handles.len() >= self.max {
145            let oldest_key = g
146                .handles
147                .values()
148                .min_by_key(|h| h.started_at)
149                .map(|h| h.key.clone());
150            if let Some(key) = oldest_key
151                && let Some(evicted) = g.handles.remove(&key)
152            {
153                evicted.cancel();
154            }
155        }
156        if let Some(displaced) = g.handles.insert(handle.key.clone(), handle) {
157            displaced.cancel();
158        }
159    }
160
161    /// Find and remove a handle matching `tool_id` + `args_hash` + `context_hash`.
162    #[must_use]
163    pub fn take_match(
164        &self,
165        tool_id: &ToolName,
166        args_hash: &blake3::Hash,
167        context_hash: &blake3::Hash,
168    ) -> Option<SpeculativeHandle> {
169        let key = HandleKey {
170            tool_id: tool_id.clone(),
171            args_hash: *args_hash,
172            context_hash: *context_hash,
173        };
174        self.inner.lock().handles.remove(&key)
175    }
176
177    /// Remove and cancel the first handle whose `tool_id` matches, if any.
178    ///
179    /// Used when the args hash is not known (e.g., on tool-id mismatch at dispatch time).
180    pub fn cancel_by_tool_id(&self, tool_id: &ToolName) {
181        let mut g = self.inner.lock();
182        let key = g.handles.keys().find(|k| &k.tool_id == tool_id).cloned();
183        if let Some(key) = key
184            && let Some(h) = g.handles.remove(&key)
185        {
186            h.cancel();
187        }
188    }
189
190    /// Cancel and remove all handles whose TTL deadline has passed.
191    ///
192    /// Called by the sweeper task every 5 s.
193    pub fn sweep_expired(&self) {
194        Self::sweep_expired_inner(&self.inner);
195    }
196
197    /// Cancel and remove all remaining handles (called at turn boundary).
198    pub fn cancel_all(&self) {
199        let mut g = self.inner.lock();
200        for (_, h) in g.handles.drain() {
201            h.cancel();
202        }
203    }
204
205    /// Number of in-flight handles.
206    #[must_use]
207    pub fn len(&self) -> usize {
208        self.inner.lock().handles.len()
209    }
210
211    /// True when the cache is empty.
212    #[must_use]
213    pub fn is_empty(&self) -> bool {
214        self.len() == 0
215    }
216}
217
218/// Compute a BLAKE3 hash over a normalized JSON args map.
219///
220/// Keys are sorted lexicographically before hashing to ensure arg-order independence.
221#[must_use]
222pub fn hash_args(args: &serde_json::Map<String, serde_json::Value>) -> blake3::Hash {
223    let mut keys: Vec<&str> = args.keys().map(String::as_str).collect();
224    keys.sort_unstable();
225    let mut hasher = blake3::Hasher::new();
226    for k in keys {
227        hasher.update(k.as_bytes());
228        hasher.update(b"\x00");
229        let v = args[k].to_string();
230        hasher.update(v.as_bytes());
231        hasher.update(b"\x00");
232    }
233    hasher.finalize()
234}
235
236/// Compute a BLAKE3 hash over the fields of an [`ExecutionContext`] that affect execution.
237///
238/// Two `ToolCall`s with the same args but different contexts must produce different keys so
239/// the speculative cache does not serve a result resolved under the wrong env/cwd.
240#[must_use]
241pub fn hash_context(ctx: Option<&ExecutionContext>) -> blake3::Hash {
242    let mut hasher = blake3::Hasher::new();
243    if let Some(ctx) = ctx {
244        if let Some(name) = ctx.name() {
245            hasher.update(b"name\x00");
246            hasher.update(name.as_bytes());
247            hasher.update(b"\x00");
248        }
249        if let Some(cwd) = ctx.cwd() {
250            hasher.update(b"cwd\x00");
251            hasher.update(cwd.as_os_str().as_encoded_bytes());
252            hasher.update(b"\x00");
253        }
254        // env_overrides is a BTreeMap so iteration order is already deterministic.
255        for (k, v) in ctx.env_overrides() {
256            hasher.update(b"env\x00");
257            hasher.update(k.as_bytes());
258            hasher.update(b"\x00");
259            hasher.update(v.as_bytes());
260            hasher.update(b"\x00");
261        }
262        hasher.update(if ctx.is_trusted() {
263            b"trusted"
264        } else {
265            b"untrusted"
266        });
267    }
268    // No context → all-zeros input → distinct hash from any populated context.
269    hasher.finalize()
270}
271
272/// Produce a normalized args template: top-level keys with their JSON type as placeholder value.
273///
274/// Used by `PatternStore` to store a template that is stable across observations with varying
275/// argument values. Example: `{"command":"<string>","timeout":"<number>"}`.
276#[must_use]
277pub fn args_template(args: &serde_json::Map<String, serde_json::Value>) -> String {
278    let template: serde_json::Map<String, serde_json::Value> = args
279        .iter()
280        .map(|(k, v)| {
281            let placeholder = match v {
282                serde_json::Value::String(_) => serde_json::json!("<string>"),
283                serde_json::Value::Number(_) => serde_json::json!("<number>"),
284                serde_json::Value::Bool(_) => serde_json::json!("<bool>"),
285                serde_json::Value::Array(_) => serde_json::json!("<array>"),
286                serde_json::Value::Object(_) => serde_json::json!("<object>"),
287                serde_json::Value::Null => serde_json::json!(null),
288            };
289            (k.clone(), placeholder)
290        })
291        .collect();
292    serde_json::Value::Object(template).to_string()
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    type Args = serde_json::Map<String, serde_json::Value>;

    /// Build an args map, inserting the pairs in exactly the order given.
    fn args_of(pairs: &[(&str, serde_json::Value)]) -> Args {
        let mut map = Args::new();
        for (key, value) in pairs {
            map.insert((*key).to_owned(), value.clone());
        }
        map
    }

    /// Same entries inserted in opposite order must hash identically.
    #[test]
    fn hash_args_order_independent() {
        let z_first = args_of(&[("z", serde_json::json!(1)), ("a", serde_json::json!(2))]);
        let a_first = args_of(&[("a", serde_json::json!(2)), ("z", serde_json::json!(1))]);

        assert_eq!(hash_args(&z_first), hash_args(&a_first));
    }

    /// Changing a value under the same key must change the hash.
    #[test]
    fn hash_args_different_values() {
        let one = args_of(&[("x", serde_json::json!(1))]);
        let two = args_of(&[("x", serde_json::json!(2))]);

        assert_ne!(hash_args(&one), hash_args(&two));
    }

    /// String, number and bool values are each replaced by their type marker.
    #[test]
    fn args_template_replaces_values_with_type_placeholders() {
        let args = args_of(&[
            ("cmd", serde_json::json!("ls -la")),
            ("timeout", serde_json::json!(30)),
            ("flag", serde_json::json!(true)),
        ]);

        let rendered = args_template(&args);

        for marker in ["<string>", "<number>", "<bool>"] {
            assert!(rendered.contains(marker));
        }
    }
}
332}