zeph_core/agent/speculative/
cache.rs1#![allow(dead_code)]
11
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Instant;
15
16use parking_lot::Mutex;
17use tokio_util::sync::CancellationToken;
18use zeph_common::ToolName;
19use zeph_common::task_supervisor::{BlockingError, BlockingHandle};
20use zeph_tools::{ExecutionContext, ToolError, ToolOutput};
21
22#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct HandleKey {
25 pub tool_id: ToolName,
26 pub args_hash: blake3::Hash,
27 pub context_hash: blake3::Hash,
30}
31
32pub struct SpeculativeHandle {
37 pub key: HandleKey,
38 pub join: BlockingHandle<Result<Option<ToolOutput>, ToolError>>,
39 pub cancel: CancellationToken,
40 pub ttl_deadline: tokio::time::Instant,
42 pub started_at: Instant,
43}
44
45impl std::fmt::Debug for SpeculativeHandle {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 f.debug_struct("SpeculativeHandle")
48 .field("key", &self.key)
49 .field("ttl_deadline", &self.ttl_deadline)
50 .field("started_at", &self.started_at)
51 .finish_non_exhaustive()
52 }
53}
54
55impl SpeculativeHandle {
56 pub fn cancel(self) {
58 self.cancel.cancel();
59 self.join.abort();
60 }
61
62 pub async fn commit(self) -> Result<Option<ToolOutput>, ToolError> {
68 match self.join.join().await {
69 Ok(r) => r,
70 Err(BlockingError::Panicked) => Err(ToolError::Execution(std::io::Error::other(
71 "speculative task panicked",
72 ))),
73 Err(BlockingError::SupervisorDropped) => Err(ToolError::Execution(
74 std::io::Error::other("speculative task cancelled"),
75 )),
76 }
77 }
78}
79
80pub struct CacheInner {
81 pub handles: HashMap<HandleKey, SpeculativeHandle>,
82}
83
84pub struct SpeculativeCache {
90 pub(crate) inner: Arc<Mutex<CacheInner>>,
91 max: usize,
92}
93
94impl SpeculativeCache {
95 #[must_use]
97 pub fn new(max_in_flight: usize) -> Self {
98 Self {
99 inner: Arc::new(Mutex::new(CacheInner {
100 handles: HashMap::new(),
101 })),
102 max: max_in_flight.clamp(1, 16),
103 }
104 }
105
106 #[must_use]
111 pub fn shared_inner(&self) -> Arc<Mutex<CacheInner>> {
112 Arc::clone(&self.inner)
113 }
114
115 pub fn sweep_expired_inner(inner: &Mutex<CacheInner>) {
120 let now = tokio::time::Instant::now();
121 let mut g = inner.lock();
122 let expired: Vec<HandleKey> = g
123 .handles
124 .iter()
125 .filter(|(_, h)| h.ttl_deadline <= now)
126 .map(|(k, _)| k.clone())
127 .collect();
128 for key in expired {
129 if let Some(h) = g.handles.remove(&key) {
130 h.cancel();
131 }
132 }
133 }
134
135 pub fn insert(&self, handle: SpeculativeHandle) {
140 let mut g = self.inner.lock();
141 if g.handles.len() >= self.max {
142 let oldest_key = g
143 .handles
144 .values()
145 .min_by_key(|h| h.started_at)
146 .map(|h| h.key.clone());
147 if let Some(key) = oldest_key
148 && let Some(evicted) = g.handles.remove(&key)
149 {
150 evicted.cancel();
151 }
152 }
153 if let Some(displaced) = g.handles.insert(handle.key.clone(), handle) {
154 displaced.cancel();
155 }
156 }
157
158 #[must_use]
160 pub fn take_match(
161 &self,
162 tool_id: &ToolName,
163 args_hash: &blake3::Hash,
164 context_hash: &blake3::Hash,
165 ) -> Option<SpeculativeHandle> {
166 let key = HandleKey {
167 tool_id: tool_id.clone(),
168 args_hash: *args_hash,
169 context_hash: *context_hash,
170 };
171 self.inner.lock().handles.remove(&key)
172 }
173
174 pub fn cancel_by_tool_id(&self, tool_id: &ToolName) {
178 let mut g = self.inner.lock();
179 let key = g.handles.keys().find(|k| &k.tool_id == tool_id).cloned();
180 if let Some(key) = key
181 && let Some(h) = g.handles.remove(&key)
182 {
183 h.cancel();
184 }
185 }
186
187 pub fn sweep_expired(&self) {
191 Self::sweep_expired_inner(&self.inner);
192 }
193
194 pub fn cancel_all(&self) {
196 let mut g = self.inner.lock();
197 for (_, h) in g.handles.drain() {
198 h.cancel();
199 }
200 }
201
202 #[must_use]
204 pub fn len(&self) -> usize {
205 self.inner.lock().handles.len()
206 }
207
208 #[must_use]
210 pub fn is_empty(&self) -> bool {
211 self.len() == 0
212 }
213}
214
215#[must_use]
219pub fn hash_args(args: &serde_json::Map<String, serde_json::Value>) -> blake3::Hash {
220 let mut keys: Vec<&str> = args.keys().map(String::as_str).collect();
221 keys.sort_unstable();
222 let mut hasher = blake3::Hasher::new();
223 for k in keys {
224 hasher.update(k.as_bytes());
225 hasher.update(b"\x00");
226 let v = args[k].to_string();
227 hasher.update(v.as_bytes());
228 hasher.update(b"\x00");
229 }
230 hasher.finalize()
231}
232
233#[must_use]
238pub fn hash_context(ctx: Option<&ExecutionContext>) -> blake3::Hash {
239 let mut hasher = blake3::Hasher::new();
240 if let Some(ctx) = ctx {
241 if let Some(name) = ctx.name() {
242 hasher.update(b"name\x00");
243 hasher.update(name.as_bytes());
244 hasher.update(b"\x00");
245 }
246 if let Some(cwd) = ctx.cwd() {
247 hasher.update(b"cwd\x00");
248 hasher.update(cwd.as_os_str().as_encoded_bytes());
249 hasher.update(b"\x00");
250 }
251 for (k, v) in ctx.env_overrides() {
253 hasher.update(b"env\x00");
254 hasher.update(k.as_bytes());
255 hasher.update(b"\x00");
256 hasher.update(v.as_bytes());
257 hasher.update(b"\x00");
258 }
259 hasher.update(if ctx.is_trusted() {
260 b"trusted"
261 } else {
262 b"untrusted"
263 });
264 }
265 hasher.finalize()
267}
268
269#[must_use]
274pub fn args_template(args: &serde_json::Map<String, serde_json::Value>) -> String {
275 let template: serde_json::Map<String, serde_json::Value> = args
276 .iter()
277 .map(|(k, v)| {
278 let placeholder = match v {
279 serde_json::Value::String(_) => serde_json::json!("<string>"),
280 serde_json::Value::Number(_) => serde_json::json!("<number>"),
281 serde_json::Value::Bool(_) => serde_json::json!("<bool>"),
282 serde_json::Value::Array(_) => serde_json::json!("<array>"),
283 serde_json::Value::Object(_) => serde_json::json!("<object>"),
284 serde_json::Value::Null => serde_json::json!(null),
285 };
286 (k.clone(), placeholder)
287 })
288 .collect();
289 serde_json::Value::Object(template).to_string()
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Map, Value, json};

    /// Builds an argument map by inserting `pairs` in the order given.
    fn args_of(pairs: &[(&str, Value)]) -> Map<String, Value> {
        let mut map = Map::new();
        for (key, value) in pairs {
            map.insert((*key).to_owned(), value.clone());
        }
        map
    }

    #[test]
    fn hash_args_order_independent() {
        let z_first = args_of(&[("z", json!(1)), ("a", json!(2))]);
        let a_first = args_of(&[("a", json!(2)), ("z", json!(1))]);
        assert_eq!(hash_args(&z_first), hash_args(&a_first));
    }

    #[test]
    fn hash_args_different_values() {
        let one = args_of(&[("x", json!(1))]);
        let two = args_of(&[("x", json!(2))]);
        assert_ne!(hash_args(&one), hash_args(&two));
    }

    #[test]
    fn args_template_replaces_values_with_type_placeholders() {
        let call = args_of(&[
            ("cmd", json!("ls -la")),
            ("timeout", json!(30)),
            ("flag", json!(true)),
        ]);
        let template = args_template(&call);
        for placeholder in ["<string>", "<number>", "<bool>"] {
            assert!(template.contains(placeholder));
        }
    }
}