Skip to main content

zeph_core/agent/speculative/
mod.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Speculative tool execution engine.
5//!
6//! Provides two complementary strategies for reducing tool-dispatch latency:
7//!
8//! - **Decoding-level** (`SpeculationMode::Decoding`, issue #2290): drains the LLM
9//!   `ToolStream` SSE events and fires tool calls speculatively as soon as all
10//!   required JSON fields are present in the partial input buffer.
11//!
12//! - **Pattern-level** (`SpeculationMode::Pattern`, issue #2409 PASTE): queries
13//!   `SQLite` at skill activation to predict the most likely next tool calls from
14//!   historical invocation sequences.
15//!
16//! Both strategies share a bounded [`SpeculativeCache`] and per-handle TTL enforcement.
17//! Speculation is completely disabled (`mode = off`) by default and never adds cargo
18//! feature flags — all branches compile unconditionally.
19//!
20//! ## Safety invariants
21//!
22//! - Speculative dispatch **always** uses `execute_tool_call` (never `_confirmed`).
23//! - A call is not dispatched speculatively when `trust_level != Trusted`.
24//! - A call is not dispatched speculatively when `requires_confirmation` returns `true`.
25//! - No synchronous dry-run execution — confirmation is checked via a policy query,
26//!   not by actually running the tool (C1: no double side-effects).
27//! - All in-flight handles are cancelled at turn boundary.
28//! - Per-handle TTL (default 30 s) is enforced by a background sweeper that shares
29//!   the same cache instance (C2: no separate empty cache in the sweeper).
30
31pub mod cache;
32pub mod partial_json;
33pub mod paste;
34pub mod prediction;
35pub mod stream_drainer;
36
37use std::sync::Arc;
38use std::time::Duration;
39
40use tokio::time::Instant;
41use tokio_util::sync::CancellationToken;
42use tracing::debug;
43use zeph_common::SkillTrustLevel;
44use zeph_tools::{ErasedToolExecutor, ToolCall, ToolError, ToolOutput};
45
46use cache::{HandleKey, SpeculativeCache, SpeculativeHandle, hash_args, hash_context};
47use prediction::Prediction;
48
49pub use zeph_config::tools::{SpeculationMode, SpeculativeConfig};
50
51enum SweepHandle {
52    Supervised(zeph_common::task_supervisor::TaskHandle),
53    Raw(tokio::task::JoinHandle<()>),
54}
55
56impl SweepHandle {
57    fn abort(self) {
58        match self {
59            SweepHandle::Supervised(h) => h.abort(),
60            SweepHandle::Raw(h) => h.abort(),
61        }
62    }
63}
64
/// Counters describing speculative activity during a single agent turn.
///
/// Every counter starts at zero (`Default`); `SpeculationEngine::end_turn` takes the
/// accumulated values and resets them.
#[derive(Clone, Debug, Default)]
pub struct SpeculativeMetrics {
    /// Speculative handles whose result was consumed by a matching real call.
    pub committed: u32,
    /// Handles cancelled because of a mismatch, TTL expiry, or the end of the turn.
    // NOTE(review): within mod.rs only `cancel_for` increments this; `end_turn` and the TTL
    // sweeper do not — confirm whether that is intended.
    pub cancelled: u32,
    /// Handles evicted because `max_in_flight` was already saturated.
    // NOTE(review): not incremented anywhere in mod.rs — confirm where it is updated.
    pub evicted_oldest: u32,
    /// Dispatches skipped because `requires_confirmation` returned `true`.
    pub skipped_confirmation: u32,
    /// Total wall-clock milliseconds spent on speculative work that turned out to be wasted.
    // NOTE(review): not incremented anywhere in mod.rs — confirm where it is updated.
    pub wasted_ms: u64,
}
79
80/// Speculative execution engine.
81///
82/// Holds a reference to the underlying executor, the shared cache, and the active
83/// configuration. Create one instance per agent session and share via `Arc`.
84///
85/// # Examples
86///
87/// ```rust,no_run
88/// use std::sync::Arc;
89/// use zeph_config::tools::SpeculativeConfig;
90/// use zeph_core::agent::speculative::SpeculationEngine;
91///
92/// # async fn example(executor: Arc<dyn zeph_tools::ErasedToolExecutor>) {
93/// let config = SpeculativeConfig::default(); // mode = off
94/// let engine = SpeculationEngine::new(executor, config);
95/// # }
96/// ```
97pub struct SpeculationEngine {
98    executor: Arc<dyn ErasedToolExecutor>,
99    config: SpeculativeConfig,
100    cache: SpeculativeCache,
101    metrics: parking_lot::Mutex<SpeculativeMetrics>,
102    sweeper: Option<SweepHandle>,
103    /// Optional session-level supervisor for task registration. `None` in test harnesses
104    /// that construct `SpeculationEngine` without a supervisor.
105    task_supervisor: Option<Arc<zeph_common::TaskSupervisor>>,
106}
107
108impl SpeculationEngine {
109    /// Create a new engine with the given executor and config.
110    #[must_use]
111    pub fn new(executor: Arc<dyn ErasedToolExecutor>, config: SpeculativeConfig) -> Self {
112        Self::new_with_supervisor(executor, config, None)
113    }
114
115    /// Create a new engine with an optional session-level supervisor for task registration.
116    ///
117    /// When `supervisor` is `Some`, the background sweeper and speculative dispatch tasks are
118    /// registered for observability and graceful shutdown. Pass `None` in test harnesses.
119    #[must_use]
120    pub fn new_with_supervisor(
121        executor: Arc<dyn ErasedToolExecutor>,
122        config: SpeculativeConfig,
123        supervisor: Option<Arc<zeph_common::TaskSupervisor>>,
124    ) -> Self {
125        let cache = SpeculativeCache::new(config.max_in_flight);
126
127        // Share the inner Arc so the sweeper operates on the *same* handle set (fixes C2).
128        let shared = cache.shared_inner();
129
130        let sweeper_handle = if let Some(sup) = &supervisor {
131            // `factory` must be `Fn` (not `FnOnce`) because `TaskSupervisor::spawn` may restart
132            // the task. Clone the `Arc` on each factory invocation so `shared` stays available.
133            let task_handle = sup.spawn(zeph_common::task_supervisor::TaskDescriptor {
134                name: "agent.speculative.sweeper",
135                restart: zeph_common::task_supervisor::RestartPolicy::RunOnce,
136                factory: move || {
137                    let shared = Arc::clone(&shared);
138                    async move {
139                        let mut interval = tokio::time::interval(Duration::from_secs(5));
140                        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
141                        loop {
142                            interval.tick().await;
143                            SpeculativeCache::sweep_expired_inner(&shared);
144                        }
145                    }
146                },
147            });
148            Some(SweepHandle::Supervised(task_handle))
149        } else {
150            let jh = tokio::spawn(async move {
151                let mut interval = tokio::time::interval(Duration::from_secs(5));
152                interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
153                loop {
154                    interval.tick().await;
155                    SpeculativeCache::sweep_expired_inner(&shared);
156                }
157            });
158            Some(SweepHandle::Raw(jh))
159        };
160
161        Self {
162            executor,
163            config,
164            cache,
165            metrics: parking_lot::Mutex::new(SpeculativeMetrics::default()),
166            sweeper: sweeper_handle,
167            task_supervisor: supervisor,
168        }
169    }
170
171    /// Current speculation mode.
172    #[must_use]
173    pub fn mode(&self) -> SpeculationMode {
174        self.config.mode
175    }
176
177    /// Returns `true` when speculation is not `Off`.
178    #[must_use]
179    pub fn is_active(&self) -> bool {
180        self.config.mode != SpeculationMode::Off
181    }
182
183    /// Minimum confidence score `[0.0, 1.0]` required to dispatch a speculative task.
184    #[must_use]
185    pub fn confidence_threshold(&self) -> f32 {
186        self.config.confidence_threshold
187    }
188
189    /// Try to dispatch `prediction` speculatively.
190    ///
191    /// Returns `false` when the call is skipped (not speculatable, trust gate, confirmation
192    /// gate, or circuit-breaker). Returns `true` when the handle was inserted in the cache.
193    ///
194    /// The confirmation check is performed via `requires_confirmation_erased` — a pure policy
195    /// query that does **not** execute the tool (fixes C1: no double side-effects).
196    pub fn try_dispatch(&self, prediction: &Prediction, trust_level: SkillTrustLevel) -> bool {
197        if trust_level != SkillTrustLevel::Trusted {
198            return false;
199        }
200
201        let tool_id = &prediction.tool_id;
202        if !self.executor.is_tool_speculatable_erased(tool_id.as_ref()) {
203            return false;
204        }
205
206        let call = prediction.to_tool_call(format!("spec-{}", uuid::Uuid::new_v4()));
207        let args_hash = hash_args(&call.params);
208        let context_hash = hash_context(call.context.as_ref());
209
210        // Policy check: skip if the tool would require user confirmation.
211        // This is a pure metadata query — no execution, no side-effects (C1).
212        if self.executor.requires_confirmation_erased(&call) {
213            let mut m = self.metrics.lock();
214            m.skipped_confirmation += 1;
215            debug!(tool_id = %tool_id, "speculative skip: requires_confirmation");
216            return false;
217        }
218
219        let exec = Arc::clone(&self.executor);
220        let call_clone = call.clone();
221        let cancel = CancellationToken::new();
222        let cancel_child = cancel.child_token();
223
224        let task_name: Arc<str> = Arc::from(format!(
225            "agent.speculative.dispatch.{}",
226            uuid::Uuid::new_v4()
227        ));
228        let join = if let Some(sup) = &self.task_supervisor {
229            sup.spawn_oneshot(Arc::clone(&task_name), move || async move {
230                tokio::select! {
231                    result = exec.execute_tool_call_erased(&call_clone) => result,
232                    () = cancel_child.cancelled() => {
233                        Err(ToolError::Execution(std::io::Error::other("speculative cancelled")))
234                    }
235                }
236            })
237        } else {
238            // No supervisor available (test harness or early construction path):
239            // fall back to a throwaway supervisor so SpeculativeHandle retains a
240            // BlockingHandle<R> regardless of code path.
241            let tmp_cancel = tokio_util::sync::CancellationToken::new();
242            let tmp_sup = Arc::new(zeph_common::TaskSupervisor::new(tmp_cancel));
243            tmp_sup.spawn_oneshot(task_name, move || async move {
244                tokio::select! {
245                    result = exec.execute_tool_call_erased(&call_clone) => result,
246                    () = cancel_child.cancelled() => {
247                        Err(ToolError::Execution(std::io::Error::other("speculative cancelled")))
248                    }
249                }
250            })
251        };
252
253        let handle = SpeculativeHandle {
254            key: HandleKey {
255                tool_id: tool_id.clone(),
256                args_hash,
257                context_hash,
258            },
259            join,
260            cancel,
261            ttl_deadline: Instant::now() + Duration::from_secs(self.config.ttl_seconds),
262            started_at: std::time::Instant::now(),
263        };
264
265        debug!(tool_id = %tool_id, confidence = prediction.confidence, "speculative dispatch");
266        self.cache.insert(handle);
267        true
268    }
269
270    /// Attempt to commit a speculative handle for `call`.
271    ///
272    /// If a matching handle exists (same `tool_id` + `args_hash`), awaits its result and
273    /// returns it. If no match, returns `None` — caller should fall through to normal dispatch.
274    pub async fn try_commit(
275        &self,
276        call: &ToolCall,
277    ) -> Option<Result<Option<ToolOutput>, ToolError>> {
278        let args_hash = hash_args(&call.params);
279        let context_hash = hash_context(call.context.as_ref());
280        if let Some(handle) = self
281            .cache
282            .take_match(&call.tool_id, &args_hash, &context_hash)
283        {
284            {
285                let mut m = self.metrics.lock();
286                m.committed += 1;
287            }
288            debug!(tool_id = %call.tool_id, "speculative commit");
289            Some(handle.commit().await)
290        } else {
291            None
292        }
293    }
294
295    /// Cancel and remove the in-flight handle for `tool_id`, if any.
296    ///
297    /// Performs an actual cache lookup and task abort (fixes C3: was metrics-only no-op).
298    pub fn cancel_for(&self, tool_id: &zeph_common::ToolName) {
299        debug!(tool_id = %tool_id, "speculative cancel for tool");
300        self.cache.cancel_by_tool_id(tool_id);
301        let mut m = self.metrics.lock();
302        m.cancelled += 1;
303    }
304
305    /// Cancel all in-flight handles at turn boundary and return metrics snapshot.
306    pub fn end_turn(&self) -> SpeculativeMetrics {
307        self.cache.cancel_all();
308        std::mem::take(&mut *self.metrics.lock())
309    }
310
311    /// Snapshot current metrics without resetting.
312    #[must_use]
313    pub fn metrics_snapshot(&self) -> SpeculativeMetrics {
314        self.metrics.lock().clone()
315    }
316}
317
318impl Drop for SpeculationEngine {
319    fn drop(&mut self) {
320        self.cache.cancel_all();
321        if let Some(handle) = self.sweeper.take() {
322            handle.abort();
323        }
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_tools::{ToolCall, ToolError, ToolExecutor, ToolOutput};

    /// Test executor: every tool is speculatable and every call succeeds with a fixed output.
    struct AlwaysOkExecutor;

    impl ToolExecutor for AlwaysOkExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            _call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ToolOutput {
                tool_name: zeph_common::ToolName::new("test"),
                summary: "ok".into(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
                locations: None,
                raw_response: None,
                claim_source: None,
            }))
        }

        fn is_tool_speculatable(&self, _: &str) -> bool {
            true
        }
    }

    /// Engine in `Decoding` mode backed by `AlwaysOkExecutor`, built without a supervisor.
    fn decoding_engine() -> SpeculationEngine {
        let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
        let config = SpeculativeConfig {
            mode: SpeculationMode::Decoding,
            ..Default::default()
        };
        SpeculationEngine::new(exec, config)
    }

    /// High-confidence prediction for the `test` tool with empty arguments.
    fn test_prediction() -> Prediction {
        Prediction {
            tool_id: zeph_common::ToolName::new("test"),
            args: serde_json::Map::new(),
            confidence: 0.9,
            source: prediction::PredictionSource::StreamPartial,
        }
    }

    #[tokio::test]
    async fn dispatch_and_commit_succeeds() {
        let engine = decoding_engine();
        let pred = test_prediction();

        assert!(
            engine.try_dispatch(&pred, SkillTrustLevel::Trusted),
            "trusted, speculatable tool must be dispatched"
        );

        // The real call carries a different id but the same tool, args and context, so it
        // must match the speculative handle.
        let call = pred.to_tool_call("real-call".to_owned());
        let committed = engine.try_commit(&call).await;
        assert!(
            committed.is_some(),
            "matching call must commit the speculative handle"
        );
        assert_eq!(
            engine.metrics_snapshot().committed,
            1,
            "commit must be counted"
        );
    }

    #[tokio::test]
    async fn commit_without_dispatch_returns_none() {
        let engine = decoding_engine();
        let call = test_prediction().to_tool_call("real-call".to_owned());

        assert!(
            engine.try_commit(&call).await.is_none(),
            "commit with an empty cache must fall through to normal dispatch"
        );
        assert_eq!(engine.metrics_snapshot().committed, 0);
    }

    #[tokio::test]
    async fn untrusted_skill_skips_dispatch() {
        let engine = decoding_engine();

        let dispatched = engine.try_dispatch(&test_prediction(), SkillTrustLevel::Quarantined);
        assert!(
            !dispatched,
            "untrusted skill must not dispatch speculatively"
        );
    }

    #[tokio::test]
    async fn cancel_for_removes_handle() {
        let engine = decoding_engine();

        engine.try_dispatch(&test_prediction(), SkillTrustLevel::Trusted);
        // After cancel_for the cache should be empty.
        engine.cancel_for(&zeph_common::ToolName::new("test"));
        assert!(
            engine.cache.is_empty(),
            "cancel_for must remove handle from cache"
        );
    }

    #[tokio::test]
    async fn end_turn_cancels_handles_and_resets_metrics() {
        let engine = decoding_engine();

        engine.try_dispatch(&test_prediction(), SkillTrustLevel::Trusted);
        assert!(
            !engine.cache.is_empty(),
            "precondition: handle must be in cache before end_turn"
        );

        let _metrics = engine.end_turn();
        assert!(
            engine.cache.is_empty(),
            "end_turn must cancel all in-flight handles"
        );

        // After end_turn, metrics are reset to zero.
        let snapshot = engine.metrics_snapshot();
        assert_eq!(snapshot.committed, 0, "metrics must reset after end_turn");
        assert_eq!(snapshot.cancelled, 0, "metrics must reset after end_turn");
    }

    #[tokio::test]
    async fn is_active_reflects_mode() {
        let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);

        let engine_off = SpeculationEngine::new(
            Arc::clone(&exec),
            SpeculativeConfig {
                mode: SpeculationMode::Off,
                ..Default::default()
            },
        );
        assert!(!engine_off.is_active(), "mode=Off means is_active()=false");

        let engine_on = SpeculationEngine::new(
            exec,
            SpeculativeConfig {
                mode: SpeculationMode::Decoding,
                ..Default::default()
            },
        );
        assert!(
            engine_on.is_active(),
            "mode=Decoding means is_active()=true"
        );
    }

    /// Dropping an engine built without a supervisor (`SweepHandle::Raw`) must run the
    /// sweeper-abort path cleanly and leave unrelated tasks untouched.
    #[tokio::test]
    async fn sweeper_aborted_on_drop() {
        let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
        let config = SpeculativeConfig {
            mode: SpeculationMode::Decoding,
            ..Default::default()
        };
        let engine = SpeculationEngine::new(Arc::clone(&exec), config);

        // Independent witness task: dropping the engine must abort only its own sweeper.
        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        let witness = tokio::spawn(async move {
            // Signal that the task started, then park indefinitely.
            let _ = tx.send(());
            tokio::time::sleep(Duration::from_hours(1)).await;
        });
        let _ = rx.await;

        drop(engine);
        // Give the runtime a chance to process the abort.
        tokio::task::yield_now().await;

        assert!(!witness.is_finished(), "unrelated task must not be aborted");
        witness.abort();

        // A freshly built engine must hold a sweeper handle, and dropping it must not panic.
        let engine2 = SpeculationEngine::new(exec, SpeculativeConfig::default());
        assert!(
            engine2.sweeper.is_some(),
            "sweeper handle must be Some after construction"
        );
        drop(engine2); // Must not panic — exercises SweepHandle::Raw abort path.
    }

    /// Verify sweeper abort via the supervised path (`SweepHandle::Supervised`).
    #[tokio::test]
    async fn sweeper_supervised_aborted_on_drop() {
        let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
        let config = SpeculativeConfig {
            mode: SpeculationMode::Decoding,
            ..Default::default()
        };

        let cancel = tokio_util::sync::CancellationToken::new();
        let supervisor = Arc::new(zeph_common::TaskSupervisor::new(cancel));

        let engine =
            SpeculationEngine::new_with_supervisor(Arc::clone(&exec), config, Some(supervisor));
        assert!(
            engine.sweeper.is_some(),
            "sweeper handle must be Some with supervisor"
        );
        drop(engine); // Must not panic — exercises SweepHandle::Supervised abort path.
    }
}