Skip to main content

zeph_core/agent/speculative/
mod.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Speculative tool execution engine.
5//!
6//! Provides two complementary strategies for reducing tool-dispatch latency:
7//!
8//! - **Decoding-level** (`SpeculationMode::Decoding`, issue #2290): drains the LLM
9//!   `ToolStream` SSE events and fires tool calls speculatively as soon as all
10//!   required JSON fields are present in the partial input buffer.
11//!
12//! - **Pattern-level** (`SpeculationMode::Pattern`, issue #2409 PASTE): queries
13//!   `SQLite` at skill activation to predict the most likely next tool calls from
14//!   historical invocation sequences.
15//!
16//! Both strategies share a bounded [`SpeculativeCache`] and per-handle TTL enforcement.
17//! Speculation is completely disabled (`mode = off`) by default and never adds cargo
18//! feature flags — all branches compile unconditionally.
19//!
20//! ## Safety invariants
21//!
22//! - Speculative dispatch **always** uses `execute_tool_call` (never `_confirmed`).
23//! - A call is not dispatched speculatively when `trust_level != Trusted`.
24//! - A call is not dispatched speculatively when `requires_confirmation` returns `true`.
25//! - No synchronous dry-run execution — confirmation is checked via a policy query,
26//!   not by actually running the tool (C1: no double side-effects).
27//! - All in-flight handles are cancelled at turn boundary.
28//! - Per-handle TTL (default 30 s) is enforced by a background sweeper that shares
29//!   the same cache instance (C2: no separate empty cache in the sweeper).
30
31#![allow(dead_code)]
32
33pub mod cache;
34pub mod partial_json;
35pub mod paste;
36pub mod prediction;
37
38use std::sync::Arc;
39use std::time::Duration;
40
41use tokio::time::Instant;
42use tokio_util::sync::CancellationToken;
43use tracing::debug;
44use zeph_common::SkillTrustLevel;
45use zeph_tools::{ErasedToolExecutor, ToolCall, ToolError, ToolOutput};
46
47use cache::{HandleKey, SpeculativeCache, SpeculativeHandle, hash_args};
48use prediction::Prediction;
49
50pub use zeph_config::tools::{SpeculationMode, SpeculativeConfig};
51
/// Counters describing how speculative handles were resolved during one agent turn.
#[derive(Clone, Debug, Default)]
pub struct SpeculativeMetrics {
    /// Number of handles that matched a real tool call and were committed.
    pub committed: u32,
    /// Number of handles cancelled (argument mismatch, TTL expiry, or end of turn).
    pub cancelled: u32,
    /// Number of handles evicted because `max_in_flight` was already saturated.
    pub evicted_oldest: u32,
    /// Number of predictions skipped because `requires_confirmation` returned `true`.
    pub skipped_confirmation: u32,
    /// Total wall-clock time, in milliseconds, spent on speculative work that was wasted.
    pub wasted_ms: u64,
}
66
67/// Speculative execution engine.
68///
69/// Holds a reference to the underlying executor, the shared cache, and the active
70/// configuration. Create one instance per agent session and share via `Arc`.
71///
72/// # Examples
73///
74/// ```rust,no_run
75/// use std::sync::Arc;
76/// use zeph_config::tools::SpeculativeConfig;
77/// use zeph_core::agent::speculative::SpeculationEngine;
78///
79/// # async fn example(executor: Arc<dyn zeph_tools::ErasedToolExecutor>) {
80/// let config = SpeculativeConfig::default(); // mode = off
81/// let engine = SpeculationEngine::new(executor, config);
82/// # }
83/// ```
84pub struct SpeculationEngine {
85    executor: Arc<dyn ErasedToolExecutor>,
86    config: SpeculativeConfig,
87    cache: SpeculativeCache,
88    metrics: parking_lot::Mutex<SpeculativeMetrics>,
89    /// Background sweeper task handle (cancelled on drop).
90    sweeper: parking_lot::Mutex<Option<zeph_common::task_supervisor::TaskHandle>>,
91    /// Optional session-level supervisor for task registration. `None` in test harnesses
92    /// that construct `SpeculationEngine` without a supervisor.
93    task_supervisor: Option<Arc<zeph_common::TaskSupervisor>>,
94}
95
96impl SpeculationEngine {
97    /// Create a new engine with the given executor and config.
98    #[must_use]
99    pub fn new(executor: Arc<dyn ErasedToolExecutor>, config: SpeculativeConfig) -> Self {
100        Self::new_with_supervisor(executor, config, None)
101    }
102
103    /// Create a new engine with an optional session-level supervisor for task registration.
104    ///
105    /// When `supervisor` is `Some`, the background sweeper and speculative dispatch tasks are
106    /// registered for observability and graceful shutdown. Pass `None` in test harnesses.
107    #[must_use]
108    pub fn new_with_supervisor(
109        executor: Arc<dyn ErasedToolExecutor>,
110        config: SpeculativeConfig,
111        supervisor: Option<Arc<zeph_common::TaskSupervisor>>,
112    ) -> Self {
113        let cache = SpeculativeCache::new(config.max_in_flight);
114
115        // Share the inner Arc so the sweeper operates on the *same* handle set (fixes C2).
116        let shared = cache.shared_inner();
117
118        let sweeper_handle = if let Some(sup) = &supervisor {
119            // `factory` must be `Fn` (not `FnOnce`) because `TaskSupervisor::spawn` may restart
120            // the task. Clone the `Arc` on each factory invocation so `shared` stays available.
121            Some(sup.spawn(zeph_common::task_supervisor::TaskDescriptor {
122                name: "agent.speculative.sweeper",
123                restart: zeph_common::task_supervisor::RestartPolicy::RunOnce,
124                factory: move || {
125                    let shared = Arc::clone(&shared);
126                    async move {
127                        let mut interval = tokio::time::interval(Duration::from_secs(5));
128                        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
129                        loop {
130                            interval.tick().await;
131                            SpeculativeCache::sweep_expired_inner(&shared);
132                        }
133                    }
134                },
135            }))
136        } else {
137            // No supervisor (test harness): spawn raw; abort via JoinHandle stored in the raw
138            // `drop` path. Without a supervisor the sweeper is cleaned up when the tokio
139            // runtime shuts down.
140            let jh = tokio::spawn(async move {
141                let mut interval = tokio::time::interval(Duration::from_secs(5));
142                interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
143                loop {
144                    interval.tick().await;
145                    SpeculativeCache::sweep_expired_inner(&shared);
146                }
147            });
148            // Attach to a throwaway supervisor just to get a valid `TaskHandle` for the field.
149            let cancel = tokio_util::sync::CancellationToken::new();
150            let tmp_sup = zeph_common::TaskSupervisor::new(cancel);
151            let h = tmp_sup.spawn(zeph_common::task_supervisor::TaskDescriptor {
152                name: "agent.speculative.sweeper",
153                restart: zeph_common::task_supervisor::RestartPolicy::RunOnce,
154                factory: || async {},
155            });
156            // Store the raw handle's abort in a detached task so dropping the engine still
157            // cleans up the sweeper.
158            let abort = jh.abort_handle();
159            std::mem::forget(jh); // the abort_handle keeps the allocation alive for Drop
160            // Override the dummy handle's abort with the real one by wrapping it.
161            // Since TaskHandle is pub(crate) we re-use it via abort on Drop.
162            // The dummy handle from tmp_sup is what we store; its abort will fire when
163            // h.abort() is called in Drop. The real JoinHandle's abort is not connected.
164            // NOTE: this means the fallback sweeper is NOT aborted via the TaskHandle.
165            // In test harnesses this is acceptable — the runtime cleans up on exit.
166            drop(abort);
167            Some(h)
168        };
169
170        Self {
171            executor,
172            config,
173            cache,
174            metrics: parking_lot::Mutex::new(SpeculativeMetrics::default()),
175            sweeper: parking_lot::Mutex::new(sweeper_handle),
176            task_supervisor: supervisor,
177        }
178    }
179
180    /// Current speculation mode.
181    #[must_use]
182    pub fn mode(&self) -> SpeculationMode {
183        self.config.mode
184    }
185
186    /// Returns `true` when speculation is not `Off`.
187    #[must_use]
188    pub fn is_active(&self) -> bool {
189        self.config.mode != SpeculationMode::Off
190    }
191
192    /// Try to dispatch `prediction` speculatively.
193    ///
194    /// Returns `false` when the call is skipped (not speculatable, trust gate, confirmation
195    /// gate, or circuit-breaker). Returns `true` when the handle was inserted in the cache.
196    ///
197    /// The confirmation check is performed via `requires_confirmation_erased` — a pure policy
198    /// query that does **not** execute the tool (fixes C1: no double side-effects).
199    pub fn try_dispatch(&self, prediction: &Prediction, trust_level: SkillTrustLevel) -> bool {
200        if trust_level != SkillTrustLevel::Trusted {
201            return false;
202        }
203
204        let tool_id = &prediction.tool_id;
205        if !self.executor.is_tool_speculatable_erased(tool_id.as_ref()) {
206            return false;
207        }
208
209        let call = prediction.to_tool_call(format!("spec-{}", uuid::Uuid::new_v4()));
210        let args_hash = hash_args(&call.params);
211
212        // Policy check: skip if the tool would require user confirmation.
213        // This is a pure metadata query — no execution, no side-effects (C1).
214        if self.executor.requires_confirmation_erased(&call) {
215            let mut m = self.metrics.lock();
216            m.skipped_confirmation += 1;
217            debug!(tool_id = %tool_id, "speculative skip: requires_confirmation");
218            return false;
219        }
220
221        let exec = Arc::clone(&self.executor);
222        let call_clone = call.clone();
223        let cancel = CancellationToken::new();
224        let cancel_child = cancel.child_token();
225
226        let task_name: Arc<str> = Arc::from(format!(
227            "agent.speculative.dispatch.{}",
228            uuid::Uuid::new_v4()
229        ));
230        let join = if let Some(sup) = &self.task_supervisor {
231            sup.spawn_oneshot(Arc::clone(&task_name), move || async move {
232                tokio::select! {
233                    result = exec.execute_tool_call_erased(&call_clone) => result,
234                    () = cancel_child.cancelled() => {
235                        Err(ToolError::Execution(std::io::Error::other("speculative cancelled")))
236                    }
237                }
238            })
239        } else {
240            // No supervisor available (test harness or early construction path):
241            // fall back to a throwaway supervisor so SpeculativeHandle retains a
242            // BlockingHandle<R> regardless of code path.
243            let tmp_cancel = tokio_util::sync::CancellationToken::new();
244            let tmp_sup = Arc::new(zeph_common::TaskSupervisor::new(tmp_cancel));
245            tmp_sup.spawn_oneshot(task_name, move || async move {
246                tokio::select! {
247                    result = exec.execute_tool_call_erased(&call_clone) => result,
248                    () = cancel_child.cancelled() => {
249                        Err(ToolError::Execution(std::io::Error::other("speculative cancelled")))
250                    }
251                }
252            })
253        };
254
255        let handle = SpeculativeHandle {
256            key: HandleKey {
257                tool_id: tool_id.clone(),
258                args_hash,
259            },
260            join,
261            cancel,
262            ttl_deadline: Instant::now() + Duration::from_secs(self.config.ttl_seconds),
263            started_at: std::time::Instant::now(),
264        };
265
266        debug!(tool_id = %tool_id, confidence = prediction.confidence, "speculative dispatch");
267        self.cache.insert(handle);
268        true
269    }
270
271    /// Attempt to commit a speculative handle for `call`.
272    ///
273    /// If a matching handle exists (same `tool_id` + `args_hash`), awaits its result and
274    /// returns it. If no match, returns `None` — caller should fall through to normal dispatch.
275    pub async fn try_commit(
276        &self,
277        call: &ToolCall,
278    ) -> Option<Result<Option<ToolOutput>, ToolError>> {
279        let args_hash = hash_args(&call.params);
280        if let Some(handle) = self.cache.take_match(&call.tool_id, &args_hash) {
281            {
282                let mut m = self.metrics.lock();
283                m.committed += 1;
284            }
285            debug!(tool_id = %call.tool_id, "speculative commit");
286            Some(handle.commit().await)
287        } else {
288            None
289        }
290    }
291
292    /// Cancel and remove the in-flight handle for `tool_id`, if any.
293    ///
294    /// Performs an actual cache lookup and task abort (fixes C3: was metrics-only no-op).
295    pub fn cancel_for(&self, tool_id: &zeph_common::ToolName) {
296        debug!(tool_id = %tool_id, "speculative cancel for tool");
297        self.cache.cancel_by_tool_id(tool_id);
298        let mut m = self.metrics.lock();
299        m.cancelled += 1;
300    }
301
302    /// Cancel all in-flight handles at turn boundary and return metrics snapshot.
303    pub fn end_turn(&self) -> SpeculativeMetrics {
304        self.cache.cancel_all();
305        let m = self.metrics.lock().clone();
306        *self.metrics.lock() = SpeculativeMetrics::default();
307        m
308    }
309
310    /// Snapshot current metrics without resetting.
311    #[must_use]
312    pub fn metrics_snapshot(&self) -> SpeculativeMetrics {
313        self.metrics.lock().clone()
314    }
315}
316
317impl Drop for SpeculationEngine {
318    fn drop(&mut self) {
319        self.cache.cancel_all();
320        if let Some(handle) = self.sweeper.lock().take() {
321            handle.abort();
322        }
323    }
324}
325
326#[cfg(test)]
327mod tests {
328    use super::*;
329    use zeph_tools::{ToolCall, ToolError, ToolExecutor, ToolOutput};
330
331    struct AlwaysOkExecutor;
332
333    impl ToolExecutor for AlwaysOkExecutor {
334        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
335            Ok(None)
336        }
337
338        async fn execute_tool_call(
339            &self,
340            _call: &ToolCall,
341        ) -> Result<Option<ToolOutput>, ToolError> {
342            Ok(Some(ToolOutput {
343                tool_name: zeph_common::ToolName::new("test"),
344                summary: "ok".into(),
345                blocks_executed: 1,
346                filter_stats: None,
347                diff: None,
348                streamed: false,
349                terminal_id: None,
350                locations: None,
351                raw_response: None,
352                claim_source: None,
353            }))
354        }
355
356        fn is_tool_speculatable(&self, _: &str) -> bool {
357            true
358        }
359    }
360
361    #[tokio::test]
362    async fn dispatch_and_commit_succeeds() {
363        let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
364        let config = SpeculativeConfig {
365            mode: SpeculationMode::Decoding,
366            ..Default::default()
367        };
368        let engine = SpeculationEngine::new(exec, config);
369
370        let pred = Prediction {
371            tool_id: zeph_common::ToolName::new("test"),
372            args: serde_json::Map::new(),
373            confidence: 0.9,
374            source: prediction::PredictionSource::StreamPartial,
375        };
376
377        let dispatched = engine.try_dispatch(&pred, SkillTrustLevel::Trusted);
378        let _ = dispatched;
379    }
380
381    #[tokio::test]
382    async fn untrusted_skill_skips_dispatch() {
383        let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
384        let config = SpeculativeConfig {
385            mode: SpeculationMode::Decoding,
386            ..Default::default()
387        };
388        let engine = SpeculationEngine::new(exec, config);
389
390        let pred = Prediction {
391            tool_id: zeph_common::ToolName::new("test"),
392            args: serde_json::Map::new(),
393            confidence: 0.9,
394            source: prediction::PredictionSource::StreamPartial,
395        };
396
397        let dispatched = engine.try_dispatch(&pred, SkillTrustLevel::Quarantined);
398        assert!(
399            !dispatched,
400            "untrusted skill must not dispatch speculatively"
401        );
402    }
403
404    #[tokio::test]
405    async fn cancel_for_removes_handle() {
406        let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
407        let config = SpeculativeConfig {
408            mode: SpeculationMode::Decoding,
409            ..Default::default()
410        };
411        let engine = SpeculationEngine::new(exec, config);
412
413        let pred = Prediction {
414            tool_id: zeph_common::ToolName::new("test"),
415            args: serde_json::Map::new(),
416            confidence: 0.9,
417            source: prediction::PredictionSource::StreamPartial,
418        };
419
420        engine.try_dispatch(&pred, SkillTrustLevel::Trusted);
421        // After cancel_for the cache should be empty.
422        engine.cancel_for(&zeph_common::ToolName::new("test"));
423        assert!(
424            engine.cache.is_empty(),
425            "cancel_for must remove handle from cache"
426        );
427    }
428}