zeph_core/agent/speculative/
mod.rs1#![allow(dead_code)]
32
33pub mod cache;
34pub mod partial_json;
35pub mod paste;
36pub mod prediction;
37
38use std::sync::Arc;
39use std::time::Duration;
40
41use tokio::time::Instant;
42use tokio_util::sync::CancellationToken;
43use tracing::debug;
44use zeph_common::SkillTrustLevel;
45use zeph_tools::{ErasedToolExecutor, ToolCall, ToolError, ToolOutput};
46
47use cache::{HandleKey, SpeculativeCache, SpeculativeHandle, hash_args};
48use prediction::Prediction;
49
50pub use zeph_config::tools::{SpeculationMode, SpeculativeConfig};
51
/// Per-turn counters describing what the speculation engine did.
///
/// All counters start at zero (`Default`) and are handed back to the caller by
/// `SpeculationEngine::end_turn`.
#[derive(Clone, Debug, Default)]
pub struct SpeculativeMetrics {
    /// Speculative handles that matched a real tool call and were committed
    /// (incremented by `SpeculationEngine::try_commit`).
    pub committed: u32,
    /// Number of `SpeculationEngine::cancel_for` invocations.
    pub cancelled: u32,
    /// NOTE(review): never written in this module; presumably counts handles
    /// evicted to make room in the cache — confirm against the cache code.
    pub evicted_oldest: u32,
    /// Dispatch attempts skipped because the executor reported that the call
    /// requires user confirmation.
    pub skipped_confirmation: u32,
    /// NOTE(review): never written in this module; presumably milliseconds of
    /// speculative work that was thrown away — confirm against the cache code.
    pub wasted_ms: u64,
}
66
67pub struct SpeculationEngine {
85 executor: Arc<dyn ErasedToolExecutor>,
86 config: SpeculativeConfig,
87 cache: SpeculativeCache,
88 metrics: parking_lot::Mutex<SpeculativeMetrics>,
89 sweeper: parking_lot::Mutex<Option<zeph_common::task_supervisor::TaskHandle>>,
91 task_supervisor: Option<Arc<zeph_common::TaskSupervisor>>,
94}
95
96impl SpeculationEngine {
97 #[must_use]
99 pub fn new(executor: Arc<dyn ErasedToolExecutor>, config: SpeculativeConfig) -> Self {
100 Self::new_with_supervisor(executor, config, None)
101 }
102
103 #[must_use]
108 pub fn new_with_supervisor(
109 executor: Arc<dyn ErasedToolExecutor>,
110 config: SpeculativeConfig,
111 supervisor: Option<Arc<zeph_common::TaskSupervisor>>,
112 ) -> Self {
113 let cache = SpeculativeCache::new(config.max_in_flight);
114
115 let shared = cache.shared_inner();
117
118 let sweeper_handle = if let Some(sup) = &supervisor {
119 Some(sup.spawn(zeph_common::task_supervisor::TaskDescriptor {
122 name: "agent.speculative.sweeper",
123 restart: zeph_common::task_supervisor::RestartPolicy::RunOnce,
124 factory: move || {
125 let shared = Arc::clone(&shared);
126 async move {
127 let mut interval = tokio::time::interval(Duration::from_secs(5));
128 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
129 loop {
130 interval.tick().await;
131 SpeculativeCache::sweep_expired_inner(&shared);
132 }
133 }
134 },
135 }))
136 } else {
137 let jh = tokio::spawn(async move {
141 let mut interval = tokio::time::interval(Duration::from_secs(5));
142 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
143 loop {
144 interval.tick().await;
145 SpeculativeCache::sweep_expired_inner(&shared);
146 }
147 });
148 let cancel = tokio_util::sync::CancellationToken::new();
150 let tmp_sup = zeph_common::TaskSupervisor::new(cancel);
151 let h = tmp_sup.spawn(zeph_common::task_supervisor::TaskDescriptor {
152 name: "agent.speculative.sweeper",
153 restart: zeph_common::task_supervisor::RestartPolicy::RunOnce,
154 factory: || async {},
155 });
156 let abort = jh.abort_handle();
159 std::mem::forget(jh); drop(abort);
167 Some(h)
168 };
169
170 Self {
171 executor,
172 config,
173 cache,
174 metrics: parking_lot::Mutex::new(SpeculativeMetrics::default()),
175 sweeper: parking_lot::Mutex::new(sweeper_handle),
176 task_supervisor: supervisor,
177 }
178 }
179
180 #[must_use]
182 pub fn mode(&self) -> SpeculationMode {
183 self.config.mode
184 }
185
186 #[must_use]
188 pub fn is_active(&self) -> bool {
189 self.config.mode != SpeculationMode::Off
190 }
191
192 pub fn try_dispatch(&self, prediction: &Prediction, trust_level: SkillTrustLevel) -> bool {
200 if trust_level != SkillTrustLevel::Trusted {
201 return false;
202 }
203
204 let tool_id = &prediction.tool_id;
205 if !self.executor.is_tool_speculatable_erased(tool_id.as_ref()) {
206 return false;
207 }
208
209 let call = prediction.to_tool_call(format!("spec-{}", uuid::Uuid::new_v4()));
210 let args_hash = hash_args(&call.params);
211
212 if self.executor.requires_confirmation_erased(&call) {
215 let mut m = self.metrics.lock();
216 m.skipped_confirmation += 1;
217 debug!(tool_id = %tool_id, "speculative skip: requires_confirmation");
218 return false;
219 }
220
221 let exec = Arc::clone(&self.executor);
222 let call_clone = call.clone();
223 let cancel = CancellationToken::new();
224 let cancel_child = cancel.child_token();
225
226 let task_name: Arc<str> = Arc::from(format!(
227 "agent.speculative.dispatch.{}",
228 uuid::Uuid::new_v4()
229 ));
230 let join = if let Some(sup) = &self.task_supervisor {
231 sup.spawn_oneshot(Arc::clone(&task_name), move || async move {
232 tokio::select! {
233 result = exec.execute_tool_call_erased(&call_clone) => result,
234 () = cancel_child.cancelled() => {
235 Err(ToolError::Execution(std::io::Error::other("speculative cancelled")))
236 }
237 }
238 })
239 } else {
240 let tmp_cancel = tokio_util::sync::CancellationToken::new();
244 let tmp_sup = Arc::new(zeph_common::TaskSupervisor::new(tmp_cancel));
245 tmp_sup.spawn_oneshot(task_name, move || async move {
246 tokio::select! {
247 result = exec.execute_tool_call_erased(&call_clone) => result,
248 () = cancel_child.cancelled() => {
249 Err(ToolError::Execution(std::io::Error::other("speculative cancelled")))
250 }
251 }
252 })
253 };
254
255 let handle = SpeculativeHandle {
256 key: HandleKey {
257 tool_id: tool_id.clone(),
258 args_hash,
259 },
260 join,
261 cancel,
262 ttl_deadline: Instant::now() + Duration::from_secs(self.config.ttl_seconds),
263 started_at: std::time::Instant::now(),
264 };
265
266 debug!(tool_id = %tool_id, confidence = prediction.confidence, "speculative dispatch");
267 self.cache.insert(handle);
268 true
269 }
270
271 pub async fn try_commit(
276 &self,
277 call: &ToolCall,
278 ) -> Option<Result<Option<ToolOutput>, ToolError>> {
279 let args_hash = hash_args(&call.params);
280 if let Some(handle) = self.cache.take_match(&call.tool_id, &args_hash) {
281 {
282 let mut m = self.metrics.lock();
283 m.committed += 1;
284 }
285 debug!(tool_id = %call.tool_id, "speculative commit");
286 Some(handle.commit().await)
287 } else {
288 None
289 }
290 }
291
292 pub fn cancel_for(&self, tool_id: &zeph_common::ToolName) {
296 debug!(tool_id = %tool_id, "speculative cancel for tool");
297 self.cache.cancel_by_tool_id(tool_id);
298 let mut m = self.metrics.lock();
299 m.cancelled += 1;
300 }
301
302 pub fn end_turn(&self) -> SpeculativeMetrics {
304 self.cache.cancel_all();
305 let m = self.metrics.lock().clone();
306 *self.metrics.lock() = SpeculativeMetrics::default();
307 m
308 }
309
310 #[must_use]
312 pub fn metrics_snapshot(&self) -> SpeculativeMetrics {
313 self.metrics.lock().clone()
314 }
315}
316
317impl Drop for SpeculationEngine {
318 fn drop(&mut self) {
319 self.cache.cancel_all();
320 if let Some(handle) = self.sweeper.lock().take() {
321 handle.abort();
322 }
323 }
324}
325
326#[cfg(test)]
327mod tests {
328 use super::*;
329 use zeph_tools::{ToolCall, ToolError, ToolExecutor, ToolOutput};
330
331 struct AlwaysOkExecutor;
332
333 impl ToolExecutor for AlwaysOkExecutor {
334 async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
335 Ok(None)
336 }
337
338 async fn execute_tool_call(
339 &self,
340 _call: &ToolCall,
341 ) -> Result<Option<ToolOutput>, ToolError> {
342 Ok(Some(ToolOutput {
343 tool_name: zeph_common::ToolName::new("test"),
344 summary: "ok".into(),
345 blocks_executed: 1,
346 filter_stats: None,
347 diff: None,
348 streamed: false,
349 terminal_id: None,
350 locations: None,
351 raw_response: None,
352 claim_source: None,
353 }))
354 }
355
356 fn is_tool_speculatable(&self, _: &str) -> bool {
357 true
358 }
359 }
360
361 #[tokio::test]
362 async fn dispatch_and_commit_succeeds() {
363 let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
364 let config = SpeculativeConfig {
365 mode: SpeculationMode::Decoding,
366 ..Default::default()
367 };
368 let engine = SpeculationEngine::new(exec, config);
369
370 let pred = Prediction {
371 tool_id: zeph_common::ToolName::new("test"),
372 args: serde_json::Map::new(),
373 confidence: 0.9,
374 source: prediction::PredictionSource::StreamPartial,
375 };
376
377 let dispatched = engine.try_dispatch(&pred, SkillTrustLevel::Trusted);
378 let _ = dispatched;
379 }
380
381 #[tokio::test]
382 async fn untrusted_skill_skips_dispatch() {
383 let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
384 let config = SpeculativeConfig {
385 mode: SpeculationMode::Decoding,
386 ..Default::default()
387 };
388 let engine = SpeculationEngine::new(exec, config);
389
390 let pred = Prediction {
391 tool_id: zeph_common::ToolName::new("test"),
392 args: serde_json::Map::new(),
393 confidence: 0.9,
394 source: prediction::PredictionSource::StreamPartial,
395 };
396
397 let dispatched = engine.try_dispatch(&pred, SkillTrustLevel::Quarantined);
398 assert!(
399 !dispatched,
400 "untrusted skill must not dispatch speculatively"
401 );
402 }
403
404 #[tokio::test]
405 async fn cancel_for_removes_handle() {
406 let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
407 let config = SpeculativeConfig {
408 mode: SpeculationMode::Decoding,
409 ..Default::default()
410 };
411 let engine = SpeculationEngine::new(exec, config);
412
413 let pred = Prediction {
414 tool_id: zeph_common::ToolName::new("test"),
415 args: serde_json::Map::new(),
416 confidence: 0.9,
417 source: prediction::PredictionSource::StreamPartial,
418 };
419
420 engine.try_dispatch(&pred, SkillTrustLevel::Trusted);
421 engine.cancel_for(&zeph_common::ToolName::new("test"));
423 assert!(
424 engine.cache.is_empty(),
425 "cancel_for must remove handle from cache"
426 );
427 }
428}