zeph_core/agent/speculative/
cache.rs1#![allow(dead_code)]
11
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Instant;
15
16use parking_lot::Mutex;
17use tokio_util::sync::CancellationToken;
18use zeph_common::ToolName;
19use zeph_common::task_supervisor::{BlockingError, BlockingHandle};
20use zeph_tools::{ToolError, ToolOutput};
21
22#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct HandleKey {
25 pub tool_id: ToolName,
26 pub args_hash: blake3::Hash,
27}
28
29pub struct SpeculativeHandle {
34 pub key: HandleKey,
35 pub join: BlockingHandle<Result<Option<ToolOutput>, ToolError>>,
36 pub cancel: CancellationToken,
37 pub ttl_deadline: tokio::time::Instant,
39 pub started_at: Instant,
40}
41
42impl std::fmt::Debug for SpeculativeHandle {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.debug_struct("SpeculativeHandle")
45 .field("key", &self.key)
46 .field("ttl_deadline", &self.ttl_deadline)
47 .field("started_at", &self.started_at)
48 .finish_non_exhaustive()
49 }
50}
51
52impl SpeculativeHandle {
53 pub fn cancel(self) {
55 self.cancel.cancel();
56 self.join.abort();
57 }
58
59 pub async fn commit(self) -> Result<Option<ToolOutput>, ToolError> {
65 match self.join.join().await {
66 Ok(r) => r,
67 Err(BlockingError::Panicked) => Err(ToolError::Execution(std::io::Error::other(
68 "speculative task panicked",
69 ))),
70 Err(BlockingError::SupervisorDropped) => Err(ToolError::Execution(
71 std::io::Error::other("speculative task cancelled"),
72 )),
73 }
74 }
75}
76
77pub struct CacheInner {
78 pub handles: HashMap<HandleKey, SpeculativeHandle>,
79}
80
81pub struct SpeculativeCache {
87 pub(crate) inner: Arc<Mutex<CacheInner>>,
88 max: usize,
89}
90
91impl SpeculativeCache {
92 #[must_use]
94 pub fn new(max_in_flight: usize) -> Self {
95 Self {
96 inner: Arc::new(Mutex::new(CacheInner {
97 handles: HashMap::new(),
98 })),
99 max: max_in_flight.clamp(1, 16),
100 }
101 }
102
103 #[must_use]
108 pub fn shared_inner(&self) -> Arc<Mutex<CacheInner>> {
109 Arc::clone(&self.inner)
110 }
111
112 pub fn sweep_expired_inner(inner: &Mutex<CacheInner>) {
117 let now = tokio::time::Instant::now();
118 let mut g = inner.lock();
119 let expired: Vec<HandleKey> = g
120 .handles
121 .iter()
122 .filter(|(_, h)| h.ttl_deadline <= now)
123 .map(|(k, _)| k.clone())
124 .collect();
125 for key in expired {
126 if let Some(h) = g.handles.remove(&key) {
127 h.cancel();
128 }
129 }
130 }
131
132 pub fn insert(&self, handle: SpeculativeHandle) {
137 let mut g = self.inner.lock();
138 if g.handles.len() >= self.max {
139 let oldest_key = g
140 .handles
141 .values()
142 .min_by_key(|h| h.started_at)
143 .map(|h| h.key.clone());
144 if let Some(key) = oldest_key
145 && let Some(evicted) = g.handles.remove(&key)
146 {
147 evicted.cancel();
148 }
149 }
150 if let Some(displaced) = g.handles.insert(handle.key.clone(), handle) {
151 displaced.cancel();
152 }
153 }
154
155 #[must_use]
157 pub fn take_match(
158 &self,
159 tool_id: &ToolName,
160 args_hash: &blake3::Hash,
161 ) -> Option<SpeculativeHandle> {
162 let key = HandleKey {
163 tool_id: tool_id.clone(),
164 args_hash: *args_hash,
165 };
166 self.inner.lock().handles.remove(&key)
167 }
168
169 pub fn cancel_by_tool_id(&self, tool_id: &ToolName) {
173 let mut g = self.inner.lock();
174 let key = g.handles.keys().find(|k| &k.tool_id == tool_id).cloned();
175 if let Some(key) = key
176 && let Some(h) = g.handles.remove(&key)
177 {
178 h.cancel();
179 }
180 }
181
182 pub fn sweep_expired(&self) {
186 Self::sweep_expired_inner(&self.inner);
187 }
188
189 pub fn cancel_all(&self) {
191 let mut g = self.inner.lock();
192 for (_, h) in g.handles.drain() {
193 h.cancel();
194 }
195 }
196
197 #[must_use]
199 pub fn len(&self) -> usize {
200 self.inner.lock().handles.len()
201 }
202
203 #[must_use]
205 pub fn is_empty(&self) -> bool {
206 self.len() == 0
207 }
208}
209
210#[must_use]
214pub fn hash_args(args: &serde_json::Map<String, serde_json::Value>) -> blake3::Hash {
215 let mut keys: Vec<&str> = args.keys().map(String::as_str).collect();
216 keys.sort_unstable();
217 let mut hasher = blake3::Hasher::new();
218 for k in keys {
219 hasher.update(k.as_bytes());
220 hasher.update(b"\x00");
221 let v = args[k].to_string();
222 hasher.update(v.as_bytes());
223 hasher.update(b"\x00");
224 }
225 hasher.finalize()
226}
227
228#[must_use]
233pub fn args_template(args: &serde_json::Map<String, serde_json::Value>) -> String {
234 let template: serde_json::Map<String, serde_json::Value> = args
235 .iter()
236 .map(|(k, v)| {
237 let placeholder = match v {
238 serde_json::Value::String(_) => serde_json::json!("<string>"),
239 serde_json::Value::Number(_) => serde_json::json!("<number>"),
240 serde_json::Value::Bool(_) => serde_json::json!("<bool>"),
241 serde_json::Value::Array(_) => serde_json::json!("<array>"),
242 serde_json::Value::Object(_) => serde_json::json!("<object>"),
243 serde_json::Value::Null => serde_json::json!(null),
244 };
245 (k.clone(), placeholder)
246 })
247 .collect();
248 serde_json::Value::Object(template).to_string()
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an argument map by inserting `pairs` in the given order.
    fn map_of(
        pairs: Vec<(&str, serde_json::Value)>,
    ) -> serde_json::Map<String, serde_json::Value> {
        let mut map = serde_json::Map::new();
        for (name, value) in pairs {
            map.insert(name.into(), value);
        }
        map
    }

    /// The same entries inserted in a different order must hash identically.
    #[test]
    fn hash_args_order_independent() {
        let z_first = map_of(vec![
            ("z", serde_json::json!(1)),
            ("a", serde_json::json!(2)),
        ]);
        let a_first = map_of(vec![
            ("a", serde_json::json!(2)),
            ("z", serde_json::json!(1)),
        ]);

        assert_eq!(hash_args(&z_first), hash_args(&a_first));
    }

    /// Changing a value under the same key must change the hash.
    #[test]
    fn hash_args_different_values() {
        let one = map_of(vec![("x", serde_json::json!(1))]);
        let two = map_of(vec![("x", serde_json::json!(2))]);
        assert_ne!(hash_args(&one), hash_args(&two));
    }

    /// String, number and boolean values each show up as their type placeholder.
    #[test]
    fn args_template_replaces_values_with_type_placeholders() {
        let args = map_of(vec![
            ("cmd", serde_json::json!("ls -la")),
            ("timeout", serde_json::json!(30)),
            ("flag", serde_json::json!(true)),
        ]);
        let rendered = args_template(&args);
        for placeholder in ["<string>", "<number>", "<bool>"] {
            assert!(rendered.contains(placeholder));
        }
    }
}