zeph_core/agent/speculative/
cache.rs1#![allow(dead_code)]
11
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Instant;
15
16use parking_lot::Mutex;
17use tokio_util::sync::CancellationToken;
18use zeph_common::ToolName;
19use zeph_common::task_supervisor::{BlockingError, BlockingHandle};
20use zeph_tools::{ExecutionContext, ToolError, ToolOutput};
21
22#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct HandleKey {
25 pub tool_id: ToolName,
26 pub args_hash: blake3::Hash,
27 pub context_hash: blake3::Hash,
30}
31
32pub struct SpeculativeHandle {
37 pub key: HandleKey,
38 pub join: BlockingHandle<Result<Option<ToolOutput>, ToolError>>,
39 pub cancel: CancellationToken,
40 pub ttl_deadline: tokio::time::Instant,
42 pub started_at: Instant,
43}
44
45impl std::fmt::Debug for SpeculativeHandle {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 f.debug_struct("SpeculativeHandle")
48 .field("key", &self.key)
49 .field("ttl_deadline", &self.ttl_deadline)
50 .field("started_at", &self.started_at)
51 .finish_non_exhaustive()
52 }
53}
54
55impl SpeculativeHandle {
56 pub fn cancel(self) {
58 self.cancel.cancel();
59 self.join.abort();
60 }
61
62 pub async fn commit(self) -> Result<Option<ToolOutput>, ToolError> {
68 match self.join.join().await {
69 Ok(r) => r,
70 Err(BlockingError::Panicked) => Err(ToolError::Execution(std::io::Error::other(
71 "speculative task panicked",
72 ))),
73 Err(BlockingError::SupervisorDropped) => Err(ToolError::Execution(
74 std::io::Error::other("speculative task cancelled"),
75 )),
76 Err(_) => Err(ToolError::Execution(std::io::Error::other(
77 "speculative task failed",
78 ))),
79 }
80 }
81}
82
83pub struct CacheInner {
84 pub handles: HashMap<HandleKey, SpeculativeHandle>,
85}
86
87pub struct SpeculativeCache {
93 pub(crate) inner: Arc<Mutex<CacheInner>>,
94 max: usize,
95}
96
97impl SpeculativeCache {
98 #[must_use]
100 pub fn new(max_in_flight: usize) -> Self {
101 Self {
102 inner: Arc::new(Mutex::new(CacheInner {
103 handles: HashMap::new(),
104 })),
105 max: max_in_flight.clamp(1, 16),
106 }
107 }
108
109 #[must_use]
114 pub fn shared_inner(&self) -> Arc<Mutex<CacheInner>> {
115 Arc::clone(&self.inner)
116 }
117
118 pub fn sweep_expired_inner(inner: &Mutex<CacheInner>) {
123 let now = tokio::time::Instant::now();
124 let mut g = inner.lock();
125 let expired: Vec<HandleKey> = g
126 .handles
127 .iter()
128 .filter(|(_, h)| h.ttl_deadline <= now)
129 .map(|(k, _)| k.clone())
130 .collect();
131 for key in expired {
132 if let Some(h) = g.handles.remove(&key) {
133 h.cancel();
134 }
135 }
136 }
137
138 pub fn insert(&self, handle: SpeculativeHandle) {
143 let mut g = self.inner.lock();
144 if g.handles.len() >= self.max {
145 let oldest_key = g
146 .handles
147 .values()
148 .min_by_key(|h| h.started_at)
149 .map(|h| h.key.clone());
150 if let Some(key) = oldest_key
151 && let Some(evicted) = g.handles.remove(&key)
152 {
153 evicted.cancel();
154 }
155 }
156 if let Some(displaced) = g.handles.insert(handle.key.clone(), handle) {
157 displaced.cancel();
158 }
159 }
160
161 #[must_use]
163 pub fn take_match(
164 &self,
165 tool_id: &ToolName,
166 args_hash: &blake3::Hash,
167 context_hash: &blake3::Hash,
168 ) -> Option<SpeculativeHandle> {
169 let key = HandleKey {
170 tool_id: tool_id.clone(),
171 args_hash: *args_hash,
172 context_hash: *context_hash,
173 };
174 self.inner.lock().handles.remove(&key)
175 }
176
177 pub fn cancel_by_tool_id(&self, tool_id: &ToolName) {
181 let mut g = self.inner.lock();
182 let key = g.handles.keys().find(|k| &k.tool_id == tool_id).cloned();
183 if let Some(key) = key
184 && let Some(h) = g.handles.remove(&key)
185 {
186 h.cancel();
187 }
188 }
189
190 pub fn sweep_expired(&self) {
194 Self::sweep_expired_inner(&self.inner);
195 }
196
197 pub fn cancel_all(&self) {
199 let mut g = self.inner.lock();
200 for (_, h) in g.handles.drain() {
201 h.cancel();
202 }
203 }
204
205 #[must_use]
207 pub fn len(&self) -> usize {
208 self.inner.lock().handles.len()
209 }
210
211 #[must_use]
213 pub fn is_empty(&self) -> bool {
214 self.len() == 0
215 }
216}
217
218#[must_use]
222pub fn hash_args(args: &serde_json::Map<String, serde_json::Value>) -> blake3::Hash {
223 let mut keys: Vec<&str> = args.keys().map(String::as_str).collect();
224 keys.sort_unstable();
225 let mut hasher = blake3::Hasher::new();
226 for k in keys {
227 hasher.update(k.as_bytes());
228 hasher.update(b"\x00");
229 let v = args[k].to_string();
230 hasher.update(v.as_bytes());
231 hasher.update(b"\x00");
232 }
233 hasher.finalize()
234}
235
236#[must_use]
241pub fn hash_context(ctx: Option<&ExecutionContext>) -> blake3::Hash {
242 let mut hasher = blake3::Hasher::new();
243 if let Some(ctx) = ctx {
244 if let Some(name) = ctx.name() {
245 hasher.update(b"name\x00");
246 hasher.update(name.as_bytes());
247 hasher.update(b"\x00");
248 }
249 if let Some(cwd) = ctx.cwd() {
250 hasher.update(b"cwd\x00");
251 hasher.update(cwd.as_os_str().as_encoded_bytes());
252 hasher.update(b"\x00");
253 }
254 for (k, v) in ctx.env_overrides() {
256 hasher.update(b"env\x00");
257 hasher.update(k.as_bytes());
258 hasher.update(b"\x00");
259 hasher.update(v.as_bytes());
260 hasher.update(b"\x00");
261 }
262 hasher.update(if ctx.is_trusted() {
263 b"trusted"
264 } else {
265 b"untrusted"
266 });
267 }
268 hasher.finalize()
270}
271
272#[must_use]
277pub fn args_template(args: &serde_json::Map<String, serde_json::Value>) -> String {
278 let template: serde_json::Map<String, serde_json::Value> = args
279 .iter()
280 .map(|(k, v)| {
281 let placeholder = match v {
282 serde_json::Value::String(_) => serde_json::json!("<string>"),
283 serde_json::Value::Number(_) => serde_json::json!("<number>"),
284 serde_json::Value::Bool(_) => serde_json::json!("<bool>"),
285 serde_json::Value::Array(_) => serde_json::json!("<array>"),
286 serde_json::Value::Object(_) => serde_json::json!("<object>"),
287 serde_json::Value::Null => serde_json::json!(null),
288 };
289 (k.clone(), placeholder)
290 })
291 .collect();
292 serde_json::Value::Object(template).to_string()
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an argument map, inserting `pairs` in the order given.
    fn args_of(pairs: &[(&str, serde_json::Value)]) -> serde_json::Map<String, serde_json::Value> {
        pairs
            .iter()
            .map(|(k, v)| ((*k).to_owned(), v.clone()))
            .collect()
    }

    #[test]
    fn hash_args_order_independent() {
        let forward = args_of(&[("z", serde_json::json!(1)), ("a", serde_json::json!(2))]);
        let backward = args_of(&[("a", serde_json::json!(2)), ("z", serde_json::json!(1))]);
        assert_eq!(hash_args(&forward), hash_args(&backward));
    }

    #[test]
    fn hash_args_different_values() {
        let one = args_of(&[("x", serde_json::json!(1))]);
        let two = args_of(&[("x", serde_json::json!(2))]);
        assert_ne!(hash_args(&one), hash_args(&two));
    }

    #[test]
    fn args_template_replaces_values_with_type_placeholders() {
        let rendered = args_template(&args_of(&[
            ("cmd", serde_json::json!("ls -la")),
            ("timeout", serde_json::json!(30)),
            ("flag", serde_json::json!(true)),
        ]));
        for placeholder in ["<string>", "<number>", "<bool>"] {
            assert!(rendered.contains(placeholder));
        }
    }
}