zeph_core/agent/speculative/
cache.rs1#![allow(dead_code)]
11
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Instant;
15
16use parking_lot::Mutex;
17use tokio::task::JoinHandle;
18use tokio_util::sync::CancellationToken;
19use zeph_common::ToolName;
20use zeph_tools::{ToolError, ToolOutput};
21
22#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct HandleKey {
25 pub tool_id: ToolName,
26 pub args_hash: blake3::Hash,
27}
28
29pub struct SpeculativeHandle {
34 pub key: HandleKey,
35 pub join: JoinHandle<Result<Option<ToolOutput>, ToolError>>,
36 pub cancel: CancellationToken,
37 pub ttl_deadline: tokio::time::Instant,
39 pub started_at: Instant,
40}
41
42impl std::fmt::Debug for SpeculativeHandle {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.debug_struct("SpeculativeHandle")
45 .field("key", &self.key)
46 .field("ttl_deadline", &self.ttl_deadline)
47 .field("started_at", &self.started_at)
48 .finish_non_exhaustive()
49 }
50}
51
52impl SpeculativeHandle {
53 pub fn cancel(self) {
55 self.cancel.cancel();
56 self.join.abort();
57 }
58
59 pub async fn commit(self) -> Result<Option<ToolOutput>, ToolError> {
65 match self.join.await {
66 Ok(r) => r,
67 Err(e) if e.is_cancelled() => Err(ToolError::Execution(std::io::Error::other(
68 "speculative task cancelled",
69 ))),
70 Err(e) => Err(ToolError::Execution(std::io::Error::other(e.to_string()))),
71 }
72 }
73}
74
75pub struct CacheInner {
76 pub handles: HashMap<HandleKey, SpeculativeHandle>,
77}
78
79pub struct SpeculativeCache {
85 pub(crate) inner: Arc<Mutex<CacheInner>>,
86 max: usize,
87}
88
89impl SpeculativeCache {
90 #[must_use]
92 pub fn new(max_in_flight: usize) -> Self {
93 Self {
94 inner: Arc::new(Mutex::new(CacheInner {
95 handles: HashMap::new(),
96 })),
97 max: max_in_flight.clamp(1, 16),
98 }
99 }
100
101 #[must_use]
106 pub fn shared_inner(&self) -> Arc<Mutex<CacheInner>> {
107 Arc::clone(&self.inner)
108 }
109
110 pub fn sweep_expired_inner(inner: &Mutex<CacheInner>) {
115 let now = tokio::time::Instant::now();
116 let mut g = inner.lock();
117 let expired: Vec<HandleKey> = g
118 .handles
119 .iter()
120 .filter(|(_, h)| h.ttl_deadline <= now)
121 .map(|(k, _)| k.clone())
122 .collect();
123 for key in expired {
124 if let Some(h) = g.handles.remove(&key) {
125 h.cancel();
126 }
127 }
128 }
129
130 pub fn insert(&self, handle: SpeculativeHandle) {
135 let mut g = self.inner.lock();
136 if g.handles.len() >= self.max {
137 let oldest_key = g
138 .handles
139 .values()
140 .min_by_key(|h| h.started_at)
141 .map(|h| h.key.clone());
142 if let Some(key) = oldest_key
143 && let Some(evicted) = g.handles.remove(&key)
144 {
145 evicted.cancel();
146 }
147 }
148 if let Some(displaced) = g.handles.insert(handle.key.clone(), handle) {
149 displaced.cancel();
150 }
151 }
152
153 #[must_use]
155 pub fn take_match(
156 &self,
157 tool_id: &ToolName,
158 args_hash: &blake3::Hash,
159 ) -> Option<SpeculativeHandle> {
160 let key = HandleKey {
161 tool_id: tool_id.clone(),
162 args_hash: *args_hash,
163 };
164 self.inner.lock().handles.remove(&key)
165 }
166
167 pub fn cancel_by_tool_id(&self, tool_id: &ToolName) {
171 let mut g = self.inner.lock();
172 let key = g.handles.keys().find(|k| &k.tool_id == tool_id).cloned();
173 if let Some(key) = key
174 && let Some(h) = g.handles.remove(&key)
175 {
176 h.cancel();
177 }
178 }
179
180 pub fn sweep_expired(&self) {
184 Self::sweep_expired_inner(&self.inner);
185 }
186
187 pub fn cancel_all(&self) {
189 let mut g = self.inner.lock();
190 for (_, h) in g.handles.drain() {
191 h.cancel();
192 }
193 }
194
195 #[must_use]
197 pub fn len(&self) -> usize {
198 self.inner.lock().handles.len()
199 }
200
201 #[must_use]
203 pub fn is_empty(&self) -> bool {
204 self.len() == 0
205 }
206}
207
208#[must_use]
212pub fn hash_args(args: &serde_json::Map<String, serde_json::Value>) -> blake3::Hash {
213 let mut keys: Vec<&str> = args.keys().map(String::as_str).collect();
214 keys.sort_unstable();
215 let mut hasher = blake3::Hasher::new();
216 for k in keys {
217 hasher.update(k.as_bytes());
218 hasher.update(b"\x00");
219 let v = args[k].to_string();
220 hasher.update(v.as_bytes());
221 hasher.update(b"\x00");
222 }
223 hasher.finalize()
224}
225
226#[must_use]
231pub fn args_template(args: &serde_json::Map<String, serde_json::Value>) -> String {
232 let template: serde_json::Map<String, serde_json::Value> = args
233 .iter()
234 .map(|(k, v)| {
235 let placeholder = match v {
236 serde_json::Value::String(_) => serde_json::json!("<string>"),
237 serde_json::Value::Number(_) => serde_json::json!("<number>"),
238 serde_json::Value::Bool(_) => serde_json::json!("<bool>"),
239 serde_json::Value::Array(_) => serde_json::json!("<array>"),
240 serde_json::Value::Object(_) => serde_json::json!("<object>"),
241 serde_json::Value::Null => serde_json::json!(null),
242 };
243 (k.clone(), placeholder)
244 })
245 .collect();
246 serde_json::Value::Object(template).to_string()
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an argument map by inserting `pairs` in the given order.
    fn map_of(pairs: &[(&str, serde_json::Value)]) -> serde_json::Map<String, serde_json::Value> {
        pairs
            .iter()
            .map(|(key, value)| ((*key).to_owned(), value.clone()))
            .collect()
    }

    #[test]
    fn hash_args_order_independent() {
        let forward = map_of(&[("z", serde_json::json!(1)), ("a", serde_json::json!(2))]);
        let reversed = map_of(&[("a", serde_json::json!(2)), ("z", serde_json::json!(1))]);
        assert_eq!(hash_args(&forward), hash_args(&reversed));
    }

    #[test]
    fn hash_args_different_values() {
        let one = map_of(&[("x", serde_json::json!(1))]);
        let two = map_of(&[("x", serde_json::json!(2))]);
        assert_ne!(hash_args(&one), hash_args(&two));
    }

    #[test]
    fn args_template_replaces_values_with_type_placeholders() {
        let args = map_of(&[
            ("cmd", serde_json::json!("ls -la")),
            ("timeout", serde_json::json!(30)),
            ("flag", serde_json::json!(true)),
        ]);
        let rendered = args_template(&args);
        for placeholder in ["<string>", "<number>", "<bool>"] {
            assert!(
                rendered.contains(placeholder),
                "missing {placeholder} in {rendered}"
            );
        }
    }
}