Skip to main content

zeph_core/agent/speculative/
cache.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! In-flight speculative handle cache.
5//!
6//! Keyed by `(ToolName, blake3::Hash)` where the hash covers the tool's argument map.
7//! Bounded by `max_in_flight`; oldest handle (by `started_at`) is evicted and cancelled
8//! when the bound is exceeded.
9
10#![allow(dead_code)]
11
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Instant;
15
16use parking_lot::Mutex;
17use tokio_util::sync::CancellationToken;
18use zeph_common::ToolName;
19use zeph_common::task_supervisor::{BlockingError, BlockingHandle};
20use zeph_tools::{ExecutionContext, ToolError, ToolOutput};
21
22/// Unique key for a speculative handle: tool name + BLAKE3 hash of normalized args + context.
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct HandleKey {
25    pub tool_id: ToolName,
26    pub args_hash: blake3::Hash,
27    /// Hash of the [`ExecutionContext`] fields. Two calls with different contexts must not share
28    /// a speculative result — the resolved env/cwd would differ.
29    pub context_hash: blake3::Hash,
30}
31
32/// An in-flight speculative execution handle.
33///
34/// Created when the engine dispatches a speculative tool call. Committed when the LLM
35/// confirms the same call on `ToolUseStop`; cancelled on mismatch or TTL expiry.
36pub struct SpeculativeHandle {
37    pub key: HandleKey,
38    pub join: BlockingHandle<Result<Option<ToolOutput>, ToolError>>,
39    pub cancel: CancellationToken,
40    /// Absolute wall-clock deadline; handle is cancelled by the sweeper when exceeded.
41    pub ttl_deadline: tokio::time::Instant,
42    pub started_at: Instant,
43}
44
45impl std::fmt::Debug for SpeculativeHandle {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        f.debug_struct("SpeculativeHandle")
48            .field("key", &self.key)
49            .field("ttl_deadline", &self.ttl_deadline)
50            .field("started_at", &self.started_at)
51            .finish_non_exhaustive()
52    }
53}
54
55impl SpeculativeHandle {
56    /// Cancel the in-flight task.
57    pub fn cancel(self) {
58        self.cancel.cancel();
59        self.join.abort();
60    }
61
62    /// Await the result; blocks until the task finishes or is cancelled.
63    ///
64    /// # Errors
65    ///
66    /// Returns [`ToolError::Execution`] if the task was cancelled or panicked.
67    pub async fn commit(self) -> Result<Option<ToolOutput>, ToolError> {
68        match self.join.join().await {
69            Ok(r) => r,
70            Err(BlockingError::Panicked) => Err(ToolError::Execution(std::io::Error::other(
71                "speculative task panicked",
72            ))),
73            Err(BlockingError::SupervisorDropped) => Err(ToolError::Execution(
74                std::io::Error::other("speculative task cancelled"),
75            )),
76        }
77    }
78}
79
80pub struct CacheInner {
81    pub handles: HashMap<HandleKey, SpeculativeHandle>,
82}
83
84/// Cache for in-flight speculative handles, bounded by `max_in_flight`.
85///
86/// Thread-safe; all operations hold a short `parking_lot::Mutex` lock.
87/// The inner `Arc<Mutex<CacheInner>>` is shared with the background TTL sweeper so
88/// both operate on the same handle set (C2: no separate empty instance in the sweeper).
89pub struct SpeculativeCache {
90    pub(crate) inner: Arc<Mutex<CacheInner>>,
91    max: usize,
92}
93
94impl SpeculativeCache {
95    /// Create a new cache with the given capacity.
96    #[must_use]
97    pub fn new(max_in_flight: usize) -> Self {
98        Self {
99            inner: Arc::new(Mutex::new(CacheInner {
100                handles: HashMap::new(),
101            })),
102            max: max_in_flight.clamp(1, 16),
103        }
104    }
105
106    /// Return a cloned `Arc` to the inner mutex so it can be shared with a sweeper task.
107    ///
108    /// The sweeper calls [`SpeculativeCache::sweep_expired_inner`] on the shared `Arc`
109    /// instead of constructing a second `SpeculativeCache` that would have separate storage.
110    #[must_use]
111    pub fn shared_inner(&self) -> Arc<Mutex<CacheInner>> {
112        Arc::clone(&self.inner)
113    }
114
115    /// Cancel and remove all handles whose TTL deadline has passed, operating on a raw `Arc`.
116    ///
117    /// Intended for use by the sweeper task, which holds only the `Arc` (not a full
118    /// `SpeculativeCache` wrapper).
119    pub fn sweep_expired_inner(inner: &Mutex<CacheInner>) {
120        let now = tokio::time::Instant::now();
121        let mut g = inner.lock();
122        let expired: Vec<HandleKey> = g
123            .handles
124            .iter()
125            .filter(|(_, h)| h.ttl_deadline <= now)
126            .map(|(k, _)| k.clone())
127            .collect();
128        for key in expired {
129            if let Some(h) = g.handles.remove(&key) {
130                h.cancel();
131            }
132        }
133    }
134
135    /// Insert a new handle. If at capacity, evicts and cancels the oldest.
136    ///
137    /// If a handle with the same key already exists it is replaced and explicitly cancelled
138    /// so the underlying tokio task does not keep running (C4: no silent drop).
139    pub fn insert(&self, handle: SpeculativeHandle) {
140        let mut g = self.inner.lock();
141        if g.handles.len() >= self.max {
142            let oldest_key = g
143                .handles
144                .values()
145                .min_by_key(|h| h.started_at)
146                .map(|h| h.key.clone());
147            if let Some(key) = oldest_key
148                && let Some(evicted) = g.handles.remove(&key)
149            {
150                evicted.cancel();
151            }
152        }
153        if let Some(displaced) = g.handles.insert(handle.key.clone(), handle) {
154            displaced.cancel();
155        }
156    }
157
158    /// Find and remove a handle matching `tool_id` + `args_hash` + `context_hash`.
159    #[must_use]
160    pub fn take_match(
161        &self,
162        tool_id: &ToolName,
163        args_hash: &blake3::Hash,
164        context_hash: &blake3::Hash,
165    ) -> Option<SpeculativeHandle> {
166        let key = HandleKey {
167            tool_id: tool_id.clone(),
168            args_hash: *args_hash,
169            context_hash: *context_hash,
170        };
171        self.inner.lock().handles.remove(&key)
172    }
173
174    /// Remove and cancel the first handle whose `tool_id` matches, if any.
175    ///
176    /// Used when the args hash is not known (e.g., on tool-id mismatch at dispatch time).
177    pub fn cancel_by_tool_id(&self, tool_id: &ToolName) {
178        let mut g = self.inner.lock();
179        let key = g.handles.keys().find(|k| &k.tool_id == tool_id).cloned();
180        if let Some(key) = key
181            && let Some(h) = g.handles.remove(&key)
182        {
183            h.cancel();
184        }
185    }
186
187    /// Cancel and remove all handles whose TTL deadline has passed.
188    ///
189    /// Called by the sweeper task every 5 s.
190    pub fn sweep_expired(&self) {
191        Self::sweep_expired_inner(&self.inner);
192    }
193
194    /// Cancel and remove all remaining handles (called at turn boundary).
195    pub fn cancel_all(&self) {
196        let mut g = self.inner.lock();
197        for (_, h) in g.handles.drain() {
198            h.cancel();
199        }
200    }
201
202    /// Number of in-flight handles.
203    #[must_use]
204    pub fn len(&self) -> usize {
205        self.inner.lock().handles.len()
206    }
207
208    /// True when the cache is empty.
209    #[must_use]
210    pub fn is_empty(&self) -> bool {
211        self.len() == 0
212    }
213}
214
215/// Compute a BLAKE3 hash over a normalized JSON args map.
216///
217/// Keys are sorted lexicographically before hashing to ensure arg-order independence.
218#[must_use]
219pub fn hash_args(args: &serde_json::Map<String, serde_json::Value>) -> blake3::Hash {
220    let mut keys: Vec<&str> = args.keys().map(String::as_str).collect();
221    keys.sort_unstable();
222    let mut hasher = blake3::Hasher::new();
223    for k in keys {
224        hasher.update(k.as_bytes());
225        hasher.update(b"\x00");
226        let v = args[k].to_string();
227        hasher.update(v.as_bytes());
228        hasher.update(b"\x00");
229    }
230    hasher.finalize()
231}
232
233/// Compute a BLAKE3 hash over the fields of an [`ExecutionContext`] that affect execution.
234///
235/// Two `ToolCall`s with the same args but different contexts must produce different keys so
236/// the speculative cache does not serve a result resolved under the wrong env/cwd.
237#[must_use]
238pub fn hash_context(ctx: Option<&ExecutionContext>) -> blake3::Hash {
239    let mut hasher = blake3::Hasher::new();
240    if let Some(ctx) = ctx {
241        if let Some(name) = ctx.name() {
242            hasher.update(b"name\x00");
243            hasher.update(name.as_bytes());
244            hasher.update(b"\x00");
245        }
246        if let Some(cwd) = ctx.cwd() {
247            hasher.update(b"cwd\x00");
248            hasher.update(cwd.as_os_str().as_encoded_bytes());
249            hasher.update(b"\x00");
250        }
251        // env_overrides is a BTreeMap so iteration order is already deterministic.
252        for (k, v) in ctx.env_overrides() {
253            hasher.update(b"env\x00");
254            hasher.update(k.as_bytes());
255            hasher.update(b"\x00");
256            hasher.update(v.as_bytes());
257            hasher.update(b"\x00");
258        }
259        hasher.update(if ctx.is_trusted() {
260            b"trusted"
261        } else {
262            b"untrusted"
263        });
264    }
265    // No context → all-zeros input → distinct hash from any populated context.
266    hasher.finalize()
267}
268
269/// Produce a normalized args template: top-level keys with their JSON type as placeholder value.
270///
271/// Used by `PatternStore` to store a template that is stable across observations with varying
272/// argument values. Example: `{"command":"<string>","timeout":"<number>"}`.
273#[must_use]
274pub fn args_template(args: &serde_json::Map<String, serde_json::Value>) -> String {
275    let template: serde_json::Map<String, serde_json::Value> = args
276        .iter()
277        .map(|(k, v)| {
278            let placeholder = match v {
279                serde_json::Value::String(_) => serde_json::json!("<string>"),
280                serde_json::Value::Number(_) => serde_json::json!("<number>"),
281                serde_json::Value::Bool(_) => serde_json::json!("<bool>"),
282                serde_json::Value::Array(_) => serde_json::json!("<array>"),
283                serde_json::Value::Object(_) => serde_json::json!("<object>"),
284                serde_json::Value::Null => serde_json::json!(null),
285            };
286            (k.clone(), placeholder)
287        })
288        .collect();
289    serde_json::Value::Object(template).to_string()
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn hash_args_order_independent() {
        let mut a = serde_json::Map::new();
        a.insert("z".into(), serde_json::json!(1));
        a.insert("a".into(), serde_json::json!(2));

        let mut b = serde_json::Map::new();
        b.insert("a".into(), serde_json::json!(2));
        b.insert("z".into(), serde_json::json!(1));

        assert_eq!(hash_args(&a), hash_args(&b));
    }

    #[test]
    fn hash_args_different_values() {
        let mut a = serde_json::Map::new();
        a.insert("x".into(), serde_json::json!(1));
        let mut b = serde_json::Map::new();
        b.insert("x".into(), serde_json::json!(2));
        assert_ne!(hash_args(&a), hash_args(&b));
    }

    #[test]
    fn hash_args_distinguishes_key_names() {
        let mut a = serde_json::Map::new();
        a.insert("x".into(), serde_json::json!(1));
        let mut b = serde_json::Map::new();
        b.insert("y".into(), serde_json::json!(1));
        assert_ne!(hash_args(&a), hash_args(&b));
    }

    #[test]
    fn hash_args_distinguishes_string_from_number() {
        let mut a = serde_json::Map::new();
        a.insert("x".into(), serde_json::json!("1"));
        let mut b = serde_json::Map::new();
        b.insert("x".into(), serde_json::json!(1));
        assert_ne!(hash_args(&a), hash_args(&b));
    }

    #[test]
    fn args_template_replaces_values_with_type_placeholders() {
        let mut m = serde_json::Map::new();
        m.insert("cmd".into(), serde_json::json!("ls -la"));
        m.insert("timeout".into(), serde_json::json!(30));
        m.insert("flag".into(), serde_json::json!(true));
        let t = args_template(&m);
        assert!(t.contains("<string>"));
        assert!(t.contains("<number>"));
        assert!(t.contains("<bool>"));
    }

    #[test]
    fn args_template_collapses_containers_and_keeps_null() {
        // Keys inserted in sorted order so the expected text holds with or without
        // serde_json's `preserve_order` feature.
        let mut m = serde_json::Map::new();
        m.insert("a".into(), serde_json::json!(null));
        m.insert("b".into(), serde_json::json!([1, 2]));
        m.insert("c".into(), serde_json::json!({"x": 1}));
        assert_eq!(
            args_template(&m),
            r#"{"a":null,"b":"<array>","c":"<object>"}"#
        );
    }

    #[test]
    fn args_template_empty_map() {
        assert_eq!(args_template(&serde_json::Map::new()), "{}");
    }
}