zeph_core/agent/speculative/
mod.rs1pub mod cache;
32pub mod partial_json;
33pub mod paste;
34pub mod prediction;
35pub mod stream_drainer;
36
37use std::sync::Arc;
38use std::time::Duration;
39
40use tokio::time::Instant;
41use tokio_util::sync::CancellationToken;
42use tracing::debug;
43use zeph_common::SkillTrustLevel;
44use zeph_tools::{ErasedToolExecutor, ToolCall, ToolError, ToolOutput};
45
46use cache::{HandleKey, SpeculativeCache, SpeculativeHandle, hash_args, hash_context};
47use prediction::Prediction;
48
49pub use zeph_config::tools::{SpeculationMode, SpeculativeConfig};
50
51enum SweepHandle {
52 Supervised(zeph_common::task_supervisor::TaskHandle),
53 Raw(tokio::task::JoinHandle<()>),
54}
55
56impl SweepHandle {
57 fn abort(self) {
58 match self {
59 SweepHandle::Supervised(h) => h.abort(),
60 SweepHandle::Raw(h) => h.abort(),
61 }
62 }
63}
64
/// Counters describing how speculative execution fared during the current turn.
///
/// A snapshot is available via `SpeculationEngine::metrics_snapshot`. `SpeculationEngine::end_turn`
/// returns the accumulated values and resets every counter to zero.
#[derive(Clone, Debug, Default)]
pub struct SpeculativeMetrics {
    /// Speculative results claimed by `SpeculationEngine::try_commit` (counted on a cache hit,
    /// whether or not the speculative task itself succeeded).
    pub committed: u32,
    /// Calls to `SpeculationEngine::cancel_for` (one per call, not one per cancelled handle).
    pub cancelled: u32,
    /// Handles evicted to make room for new ones.
    /// NOTE(review): not written by the engine in this module — presumably maintained by the
    /// cache's eviction path; confirm.
    pub evicted_oldest: u32,
    /// Predictions not dispatched because the executor said the call requires confirmation.
    pub skipped_confirmation: u32,
    /// Milliseconds of speculative work that was thrown away.
    /// NOTE(review): not written in this module — confirm where it is accumulated.
    pub wasted_ms: u64,
}
79
80pub struct SpeculationEngine {
98 executor: Arc<dyn ErasedToolExecutor>,
99 config: SpeculativeConfig,
100 cache: SpeculativeCache,
101 metrics: parking_lot::Mutex<SpeculativeMetrics>,
102 sweeper: Option<SweepHandle>,
103 task_supervisor: Option<Arc<zeph_common::TaskSupervisor>>,
106}
107
108impl SpeculationEngine {
109 #[must_use]
111 pub fn new(executor: Arc<dyn ErasedToolExecutor>, config: SpeculativeConfig) -> Self {
112 Self::new_with_supervisor(executor, config, None)
113 }
114
115 #[must_use]
120 pub fn new_with_supervisor(
121 executor: Arc<dyn ErasedToolExecutor>,
122 config: SpeculativeConfig,
123 supervisor: Option<Arc<zeph_common::TaskSupervisor>>,
124 ) -> Self {
125 let cache = SpeculativeCache::new(config.max_in_flight);
126
127 let shared = cache.shared_inner();
129
130 let sweeper_handle = if let Some(sup) = &supervisor {
131 let task_handle = sup.spawn(zeph_common::task_supervisor::TaskDescriptor {
134 name: "agent.speculative.sweeper",
135 restart: zeph_common::task_supervisor::RestartPolicy::RunOnce,
136 factory: move || {
137 let shared = Arc::clone(&shared);
138 async move {
139 let mut interval = tokio::time::interval(Duration::from_secs(5));
140 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
141 loop {
142 interval.tick().await;
143 SpeculativeCache::sweep_expired_inner(&shared);
144 }
145 }
146 },
147 });
148 Some(SweepHandle::Supervised(task_handle))
149 } else {
150 let jh = tokio::spawn(async move {
151 let mut interval = tokio::time::interval(Duration::from_secs(5));
152 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
153 loop {
154 interval.tick().await;
155 SpeculativeCache::sweep_expired_inner(&shared);
156 }
157 });
158 Some(SweepHandle::Raw(jh))
159 };
160
161 Self {
162 executor,
163 config,
164 cache,
165 metrics: parking_lot::Mutex::new(SpeculativeMetrics::default()),
166 sweeper: sweeper_handle,
167 task_supervisor: supervisor,
168 }
169 }
170
171 #[must_use]
173 pub fn mode(&self) -> SpeculationMode {
174 self.config.mode
175 }
176
177 #[must_use]
179 pub fn is_active(&self) -> bool {
180 self.config.mode != SpeculationMode::Off
181 }
182
183 #[must_use]
185 pub fn confidence_threshold(&self) -> f32 {
186 self.config.confidence_threshold
187 }
188
189 pub fn try_dispatch(&self, prediction: &Prediction, trust_level: SkillTrustLevel) -> bool {
197 if trust_level != SkillTrustLevel::Trusted {
198 return false;
199 }
200
201 let tool_id = &prediction.tool_id;
202 if !self.executor.is_tool_speculatable_erased(tool_id.as_ref()) {
203 return false;
204 }
205
206 let call = prediction.to_tool_call(format!("spec-{}", uuid::Uuid::new_v4()));
207 let args_hash = hash_args(&call.params);
208 let context_hash = hash_context(call.context.as_ref());
209
210 if self.executor.requires_confirmation_erased(&call) {
213 let mut m = self.metrics.lock();
214 m.skipped_confirmation += 1;
215 debug!(tool_id = %tool_id, "speculative skip: requires_confirmation");
216 return false;
217 }
218
219 let exec = Arc::clone(&self.executor);
220 let call_clone = call.clone();
221 let cancel = CancellationToken::new();
222 let cancel_child = cancel.child_token();
223
224 let task_name: Arc<str> = Arc::from(format!(
225 "agent.speculative.dispatch.{}",
226 uuid::Uuid::new_v4()
227 ));
228 let join = if let Some(sup) = &self.task_supervisor {
229 sup.spawn_oneshot(Arc::clone(&task_name), move || async move {
230 tokio::select! {
231 result = exec.execute_tool_call_erased(&call_clone) => result,
232 () = cancel_child.cancelled() => {
233 Err(ToolError::Execution(std::io::Error::other("speculative cancelled")))
234 }
235 }
236 })
237 } else {
238 let tmp_cancel = tokio_util::sync::CancellationToken::new();
242 let tmp_sup = Arc::new(zeph_common::TaskSupervisor::new(tmp_cancel));
243 tmp_sup.spawn_oneshot(task_name, move || async move {
244 tokio::select! {
245 result = exec.execute_tool_call_erased(&call_clone) => result,
246 () = cancel_child.cancelled() => {
247 Err(ToolError::Execution(std::io::Error::other("speculative cancelled")))
248 }
249 }
250 })
251 };
252
253 let handle = SpeculativeHandle {
254 key: HandleKey {
255 tool_id: tool_id.clone(),
256 args_hash,
257 context_hash,
258 },
259 join,
260 cancel,
261 ttl_deadline: Instant::now() + Duration::from_secs(self.config.ttl_seconds),
262 started_at: std::time::Instant::now(),
263 };
264
265 debug!(tool_id = %tool_id, confidence = prediction.confidence, "speculative dispatch");
266 self.cache.insert(handle);
267 true
268 }
269
270 pub async fn try_commit(
275 &self,
276 call: &ToolCall,
277 ) -> Option<Result<Option<ToolOutput>, ToolError>> {
278 let args_hash = hash_args(&call.params);
279 let context_hash = hash_context(call.context.as_ref());
280 if let Some(handle) = self
281 .cache
282 .take_match(&call.tool_id, &args_hash, &context_hash)
283 {
284 {
285 let mut m = self.metrics.lock();
286 m.committed += 1;
287 }
288 debug!(tool_id = %call.tool_id, "speculative commit");
289 Some(handle.commit().await)
290 } else {
291 None
292 }
293 }
294
295 pub fn cancel_for(&self, tool_id: &zeph_common::ToolName) {
299 debug!(tool_id = %tool_id, "speculative cancel for tool");
300 self.cache.cancel_by_tool_id(tool_id);
301 let mut m = self.metrics.lock();
302 m.cancelled += 1;
303 }
304
305 pub fn end_turn(&self) -> SpeculativeMetrics {
307 self.cache.cancel_all();
308 std::mem::take(&mut *self.metrics.lock())
309 }
310
311 #[must_use]
313 pub fn metrics_snapshot(&self) -> SpeculativeMetrics {
314 self.metrics.lock().clone()
315 }
316}
317
318impl Drop for SpeculationEngine {
319 fn drop(&mut self) {
320 self.cache.cancel_all();
321 if let Some(handle) = self.sweeper.take() {
322 handle.abort();
323 }
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_tools::{ToolCall, ToolError, ToolExecutor, ToolOutput};

    /// Executor that allows speculation for every tool and always succeeds with summary "ok".
    struct AlwaysOkExecutor;

    impl ToolExecutor for AlwaysOkExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            _call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ToolOutput {
                tool_name: zeph_common::ToolName::new("test"),
                summary: "ok".into(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
                locations: None,
                raw_response: None,
                claim_source: None,
            }))
        }

        fn is_tool_speculatable(&self, _: &str) -> bool {
            true
        }
    }

    /// Config with speculation enabled (`Decoding`), everything else default.
    fn decoding_config() -> SpeculativeConfig {
        SpeculativeConfig {
            mode: SpeculationMode::Decoding,
            ..Default::default()
        }
    }

    /// Engine over `AlwaysOkExecutor` in `Decoding` mode. Must run inside a Tokio runtime.
    fn decoding_engine() -> SpeculationEngine {
        let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
        SpeculationEngine::new(exec, decoding_config())
    }

    /// High-confidence prediction for the `test` tool with no arguments.
    fn test_prediction() -> Prediction {
        Prediction {
            tool_id: zeph_common::ToolName::new("test"),
            args: serde_json::Map::new(),
            confidence: 0.9,
            source: prediction::PredictionSource::StreamPartial,
        }
    }

    #[tokio::test]
    async fn dispatch_and_commit_succeeds() {
        let engine = decoding_engine();
        let pred = test_prediction();

        assert!(
            engine.try_dispatch(&pred, SkillTrustLevel::Trusted),
            "trusted, speculatable, confirmation-free call must dispatch"
        );

        // The real call differs only in its id, which is not part of the cache key.
        let call = pred.to_tool_call("real-call".to_owned());
        let committed = engine.try_commit(&call).await;
        assert!(
            committed.is_some(),
            "matching call must hit the speculative cache"
        );
        assert_eq!(
            engine.metrics_snapshot().committed,
            1,
            "a cache hit must be counted as committed"
        );
        assert!(
            engine.cache.is_empty(),
            "committed handle must be removed from the cache"
        );
    }

    #[tokio::test]
    async fn untrusted_skill_skips_dispatch() {
        let engine = decoding_engine();

        let dispatched = engine.try_dispatch(&test_prediction(), SkillTrustLevel::Quarantined);
        assert!(
            !dispatched,
            "untrusted skill must not dispatch speculatively"
        );
    }

    #[tokio::test]
    async fn cancel_for_removes_handle() {
        let engine = decoding_engine();

        engine.try_dispatch(&test_prediction(), SkillTrustLevel::Trusted);
        engine.cancel_for(&zeph_common::ToolName::new("test"));

        assert!(
            engine.cache.is_empty(),
            "cancel_for must remove handle from cache"
        );
        assert_eq!(
            engine.metrics_snapshot().cancelled,
            1,
            "cancel_for must be counted"
        );
    }

    #[tokio::test]
    async fn end_turn_cancels_handles_and_resets_metrics() {
        let engine = decoding_engine();

        engine.try_dispatch(&test_prediction(), SkillTrustLevel::Trusted);
        assert!(
            !engine.cache.is_empty(),
            "precondition: handle must be in cache before end_turn"
        );

        let _metrics = engine.end_turn();
        assert!(
            engine.cache.is_empty(),
            "end_turn must cancel all in-flight handles"
        );

        let snapshot = engine.metrics_snapshot();
        assert_eq!(snapshot.committed, 0, "metrics must reset after end_turn");
        assert_eq!(snapshot.cancelled, 0, "metrics must reset after end_turn");
    }

    #[tokio::test]
    async fn is_active_reflects_mode() {
        let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);

        let engine_off = SpeculationEngine::new(
            Arc::clone(&exec),
            SpeculativeConfig {
                mode: SpeculationMode::Off,
                ..Default::default()
            },
        );
        assert!(!engine_off.is_active(), "mode=Off means is_active()=false");

        let engine_on = SpeculationEngine::new(exec, decoding_config());
        assert!(
            engine_on.is_active(),
            "mode=Decoding means is_active()=true"
        );
    }

    #[tokio::test]
    async fn sweeper_aborted_on_drop() {
        let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
        let engine = SpeculationEngine::new(Arc::clone(&exec), decoding_config());
        assert!(
            engine.sweeper.is_some(),
            "sweeper handle must be Some after construction"
        );

        // The unsupervised sweeper task owns a clone of the cache's shared state. Once the task
        // is aborted and its future dropped, this probe is the only remaining owner.
        let probe = engine.cache.shared_inner();

        // An unrelated long-running task that must survive the engine being dropped.
        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        let witness = tokio::spawn(async move {
            let _ = tx.send(());
            tokio::time::sleep(Duration::from_hours(1)).await;
        });
        let _ = rx.await;

        drop(engine);

        // Aborted tasks are torn down asynchronously; allow a bounded number of scheduling
        // rounds for the runtime to drop the sweeper's future.
        let mut released = false;
        for _ in 0..1_000 {
            if Arc::strong_count(&probe) == 1 {
                released = true;
                break;
            }
            tokio::task::yield_now().await;
        }
        assert!(released, "dropping the engine must abort the sweeper task");

        assert!(!witness.is_finished(), "unrelated task must not be aborted");
        witness.abort();

        // Construction with the default config must start a sweeper as well.
        let engine2 = SpeculationEngine::new(exec, SpeculativeConfig::default());
        assert!(
            engine2.sweeper.is_some(),
            "sweeper handle must be Some after construction"
        );
        drop(engine2);
    }

    #[tokio::test]
    async fn sweeper_supervised_aborted_on_drop() {
        let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);

        let cancel = tokio_util::sync::CancellationToken::new();
        let supervisor = Arc::new(zeph_common::TaskSupervisor::new(cancel));

        let engine = SpeculationEngine::new_with_supervisor(
            Arc::clone(&exec),
            decoding_config(),
            Some(supervisor),
        );
        assert!(
            engine.sweeper.is_some(),
            "sweeper handle must be Some with supervisor"
        );
        // Dropping must not panic; abort goes through the supervisor's TaskHandle.
        drop(engine);
    }
}