zeph_core/agent/speculative/
mod.rs1#![allow(dead_code)]
32
33pub mod cache;
34pub mod partial_json;
35pub mod paste;
36pub mod prediction;
37
38use std::sync::Arc;
39use std::time::Duration;
40
41use tokio::time::Instant;
42use tokio_util::sync::CancellationToken;
43use tracing::debug;
44use zeph_common::SkillTrustLevel;
45use zeph_tools::{ErasedToolExecutor, ToolCall, ToolError, ToolOutput};
46
47use cache::{HandleKey, SpeculativeCache, SpeculativeHandle, hash_args};
48use prediction::Prediction;
49
50pub use zeph_tools::config::{SpeculationMode, SpeculativeConfig};
51
/// Speculation counters accumulated over one agent turn.
///
/// `SpeculationEngine::end_turn` returns the current values and resets every counter to zero;
/// `SpeculationEngine::metrics_snapshot` returns a copy without resetting.
#[derive(Clone, Debug, Default)]
pub struct SpeculativeMetrics {
    /// Speculative results claimed by a matching real call (`SpeculationEngine::try_commit`).
    pub committed: u32,
    /// Number of `SpeculationEngine::cancel_for` invocations.
    pub cancelled: u32,
    /// NOTE(review): presumably handles evicted when the in-flight cap is hit; not written
    /// anywhere in this module — confirm where (or whether) it is updated.
    pub evicted_oldest: u32,
    /// Dispatches skipped because the executor reported the call needs user confirmation.
    pub skipped_confirmation: u32,
    /// NOTE(review): presumably milliseconds spent on speculative work that was discarded;
    /// not written anywhere in this module — confirm.
    pub wasted_ms: u64,
}
66
67pub struct SpeculationEngine {
85 executor: Arc<dyn ErasedToolExecutor>,
86 config: SpeculativeConfig,
87 cache: SpeculativeCache,
88 metrics: parking_lot::Mutex<SpeculativeMetrics>,
89 sweeper: parking_lot::Mutex<Option<tokio::task::JoinHandle<()>>>,
91}
92
93impl SpeculationEngine {
94 #[must_use]
96 pub fn new(executor: Arc<dyn ErasedToolExecutor>, config: SpeculativeConfig) -> Self {
97 let cache = SpeculativeCache::new(config.max_in_flight);
98
99 let shared = cache.shared_inner();
101
102 let sweeper = tokio::spawn(async move {
103 let mut interval = tokio::time::interval(Duration::from_secs(5));
104 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
105 loop {
106 interval.tick().await;
107 SpeculativeCache::sweep_expired_inner(&shared);
108 }
109 });
110
111 Self {
112 executor,
113 config,
114 cache,
115 metrics: parking_lot::Mutex::new(SpeculativeMetrics::default()),
116 sweeper: parking_lot::Mutex::new(Some(sweeper)),
117 }
118 }
119
120 #[must_use]
122 pub fn mode(&self) -> SpeculationMode {
123 self.config.mode
124 }
125
126 #[must_use]
128 pub fn is_active(&self) -> bool {
129 self.config.mode != SpeculationMode::Off
130 }
131
132 pub fn try_dispatch(&self, prediction: &Prediction, trust_level: SkillTrustLevel) -> bool {
140 if trust_level != SkillTrustLevel::Trusted {
141 return false;
142 }
143
144 let tool_id = &prediction.tool_id;
145 if !self.executor.is_tool_speculatable_erased(tool_id.as_ref()) {
146 return false;
147 }
148
149 let call = prediction.to_tool_call(format!("spec-{}", uuid::Uuid::new_v4()));
150 let args_hash = hash_args(&call.params);
151
152 if self.executor.requires_confirmation_erased(&call) {
155 let mut m = self.metrics.lock();
156 m.skipped_confirmation += 1;
157 debug!(tool_id = %tool_id, "speculative skip: requires_confirmation");
158 return false;
159 }
160
161 let exec = Arc::clone(&self.executor);
162 let call_clone = call.clone();
163 let cancel = CancellationToken::new();
164 let cancel_child = cancel.child_token();
165
166 let join = tokio::spawn(async move {
167 tokio::select! {
168 result = exec.execute_tool_call_erased(&call_clone) => result,
169 () = cancel_child.cancelled() => {
170 Err(ToolError::Execution(std::io::Error::other("speculative cancelled")))
171 }
172 }
173 });
174
175 let handle = SpeculativeHandle {
176 key: HandleKey {
177 tool_id: tool_id.clone(),
178 args_hash,
179 },
180 join,
181 cancel,
182 ttl_deadline: Instant::now() + Duration::from_secs(self.config.ttl_seconds),
183 started_at: std::time::Instant::now(),
184 };
185
186 debug!(tool_id = %tool_id, confidence = prediction.confidence, "speculative dispatch");
187 self.cache.insert(handle);
188 true
189 }
190
191 pub async fn try_commit(
196 &self,
197 call: &ToolCall,
198 ) -> Option<Result<Option<ToolOutput>, ToolError>> {
199 let args_hash = hash_args(&call.params);
200 if let Some(handle) = self.cache.take_match(&call.tool_id, &args_hash) {
201 {
202 let mut m = self.metrics.lock();
203 m.committed += 1;
204 }
205 debug!(tool_id = %call.tool_id, "speculative commit");
206 Some(handle.commit().await)
207 } else {
208 None
209 }
210 }
211
212 pub fn cancel_for(&self, tool_id: &zeph_common::ToolName) {
216 debug!(tool_id = %tool_id, "speculative cancel for tool");
217 self.cache.cancel_by_tool_id(tool_id);
218 let mut m = self.metrics.lock();
219 m.cancelled += 1;
220 }
221
222 pub fn end_turn(&self) -> SpeculativeMetrics {
224 self.cache.cancel_all();
225 let m = self.metrics.lock().clone();
226 *self.metrics.lock() = SpeculativeMetrics::default();
227 m
228 }
229
230 #[must_use]
232 pub fn metrics_snapshot(&self) -> SpeculativeMetrics {
233 self.metrics.lock().clone()
234 }
235}
236
237impl Drop for SpeculationEngine {
238 fn drop(&mut self) {
239 self.cache.cancel_all();
240 if let Some(sweeper) = self.sweeper.lock().take() {
241 sweeper.abort();
242 }
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_tools::{ToolCall, ToolError, ToolExecutor, ToolOutput};

    /// Executor that marks every tool as speculatable and always returns a successful output.
    struct AlwaysOkExecutor;

    impl ToolExecutor for AlwaysOkExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            _call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ToolOutput {
                tool_name: zeph_common::ToolName::new("test"),
                summary: "ok".into(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
                locations: None,
                raw_response: None,
                claim_source: None,
            }))
        }

        fn is_tool_speculatable(&self, _: &str) -> bool {
            true
        }
    }

    /// Engine in `Decoding` mode backed by `AlwaysOkExecutor`.
    /// Must be called inside a Tokio runtime (the engine spawns its sweeper).
    fn decoding_engine() -> SpeculationEngine {
        let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
        let config = SpeculativeConfig {
            mode: SpeculationMode::Decoding,
            ..Default::default()
        };
        SpeculationEngine::new(exec, config)
    }

    /// High-confidence prediction for the `test` tool with empty arguments.
    fn test_prediction() -> Prediction {
        Prediction {
            tool_id: zeph_common::ToolName::new("test"),
            args: serde_json::Map::new(),
            confidence: 0.9,
            source: prediction::PredictionSource::StreamPartial,
        }
    }

    #[tokio::test]
    async fn dispatch_and_commit_succeeds() {
        let engine = decoding_engine();
        let pred = test_prediction();

        assert!(
            engine.try_dispatch(&pred, SkillTrustLevel::Trusted),
            "trusted, speculatable tool must dispatch"
        );

        // The call id differs from the speculative one; matching is by tool id + args hash.
        let call = pred.to_tool_call("real-call".to_owned());
        let result = engine.try_commit(&call).await;
        assert!(
            matches!(result, Some(Ok(_))),
            "commit must return the speculative execution result"
        );
        assert_eq!(engine.metrics_snapshot().committed, 1);
    }

    #[tokio::test]
    async fn untrusted_skill_skips_dispatch() {
        let engine = decoding_engine();

        let dispatched = engine.try_dispatch(&test_prediction(), SkillTrustLevel::Quarantined);
        assert!(
            !dispatched,
            "untrusted skill must not dispatch speculatively"
        );
    }

    #[tokio::test]
    async fn cancel_for_removes_handle() {
        let engine = decoding_engine();

        assert!(engine.try_dispatch(&test_prediction(), SkillTrustLevel::Trusted));
        // Without this precondition the emptiness check below would pass vacuously.
        assert!(
            !engine.cache.is_empty(),
            "dispatch must register a handle before cancellation"
        );

        engine.cancel_for(&zeph_common::ToolName::new("test"));
        assert!(
            engine.cache.is_empty(),
            "cancel_for must remove handle from cache"
        );
        assert_eq!(engine.metrics_snapshot().cancelled, 1);
    }
}