Skip to main content

zeph_orchestration/
verify_predicate.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Per-subtask verification predicates (predicate gate).
5//!
6//! Each task in a DAG may carry a `VerifyPredicate` that must be satisfied by
7//! the task's output before downstream tasks may consume it. Evaluation is
8//! LLM-based via `PredicateEvaluator`.
9//!
10//! # Design
11//!
//! - [`VerifyPredicate`] is an enum stored in `TaskNode.verify_predicate`. Only the
//!   `Natural(String)` variant is evaluatable in v1 — `Expression` is reserved, and
//!   [`VerifyPredicate::as_natural`] returns an error if the planner ever emits one.
15//! - [`PredicateOutcome`] is persisted on `TaskNode` via `GraphPersistence::save` (wired in
16//!   `zeph-core` scheduler loop and `handle_plan_confirm`). After a crash, rehydrating the
17//!   graph via `/plan resume <id>` restores `predicate_outcome` so the gate is not re-evaluated
18//!   for already-completed tasks.
19//! - [`PredicateEvaluator`] wraps any [`LlmProvider`] and produces [`PredicateOutcome`]
20//!   values. The evaluation prompt is intentionally minimal and model-agnostic.
21
22use std::time::Duration;
23
24use serde::{Deserialize, Serialize};
25use zeph_llm::provider::{LlmProvider, Message, Role};
26use zeph_sanitizer::{ContentSanitizer, ContentSource, ContentSourceKind};
27
28use super::error::OrchestrationError;
29
/// A verification criterion attached to a task node.
///
/// The planner populates this from the `verify_criteria` field in its JSON output.
/// Both variants can be constructed and (de)serialized, but only `Natural` is
/// evaluatable in v1. If the planner emits `Expression`, the
/// scheduler returns `OrchestrationError::PredicateNotSupported` rather than
/// silently ignoring the criterion (see [`VerifyPredicate::as_natural`]).
///
/// # Serialization
///
/// The enum is adjacently tagged with snake_case variant names, so a `Natural`
/// predicate is written as `{"kind": "natural", "value": "<criterion>"}` and an
/// `Expression` predicate as `{"kind": "expression", "value": "<expr>"}`.
///
/// # Examples
///
/// ```rust
/// use zeph_orchestration::VerifyPredicate;
///
/// let pred = VerifyPredicate::Natural("output must contain a valid JSON object".to_string());
/// assert!(matches!(pred, VerifyPredicate::Natural(_)));
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
pub enum VerifyPredicate {
    /// Free-form natural-language criterion evaluated by the LLM judge.
    Natural(String),
    /// Symbolic expression (reserved, not supported in v1). The payload is kept
    /// verbatim so it can be reported in the "not supported" error message.
    Expression(String),
}
53
54impl VerifyPredicate {
55    /// Returns `Ok(&criterion)` for `Natural` predicates; `Err(PredicateNotSupported)`
56    /// for unsupported variants.
57    ///
58    /// # Errors
59    ///
60    /// Returns [`OrchestrationError::PredicateNotSupported`] when the variant is not
61    /// evaluatable in the current version.
62    pub fn as_natural(&self) -> Result<&str, OrchestrationError> {
63        match self {
64            VerifyPredicate::Natural(s) => Ok(s.as_str()),
65            VerifyPredicate::Expression(s) => Err(OrchestrationError::PredicateNotSupported(
66                format!("Expression predicate '{s}' is not supported in v1; use Natural"),
67            )),
68        }
69    }
70}
71
/// Result of evaluating a [`VerifyPredicate`] against a task's output.
///
/// Stored on `TaskNode::predicate_outcome`. As described in the module-level design
/// notes, the outcome is persisted together with the graph, so rehydrating a graph
/// after a crash restores it and the gate is not re-evaluated for already-completed
/// tasks. A `None` value signals "not yet evaluated"; consumers
/// should re-emit `SchedulerAction::VerifyPredicate` on the next tick.
///
/// Fail-open outcomes produced by `PredicateEvaluator` (LLM error, timeout,
/// unsupported predicate variant) are recognisable by `passed = true` combined with
/// `confidence = 0.0`.
///
/// # Examples
///
/// ```rust
/// use zeph_orchestration::PredicateOutcome;
///
/// let outcome = PredicateOutcome { passed: true, confidence: 0.9, reason: "output is valid JSON".to_string() };
/// assert!(outcome.passed);
/// assert!(outcome.confidence > 0.8);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredicateOutcome {
    /// Whether the predicate was satisfied.
    pub passed: bool,
    /// Confidence score in [0.0, 1.0]. Values < 0.5 with `passed = true` log a warn.
    pub confidence: f32,
    /// Human-readable explanation from the LLM judge, or a description of why the
    /// evaluation was skipped or failed for fail-open outcomes.
    pub reason: String,
}
96
/// LLM-backed predicate evaluator.
///
/// Evaluates a [`VerifyPredicate`] against task output by calling the configured
/// LLM provider with a judge prompt. Fail-open: evaluation errors, timeouts, and
/// unsupported predicate variants all produce a
/// permissive `passed = true` outcome with `confidence = 0.0` and log a warning
/// rather than aborting the scheduler.
///
/// Task output is sanitized via [`ContentSanitizer`] before being embedded in the
/// judge prompt, mirroring the same defence used by `PlanVerifier`.
///
/// # Examples
///
/// ```rust,no_run
/// use zeph_orchestration::{PredicateEvaluator, VerifyPredicate};
/// use zeph_sanitizer::{ContentSanitizer, ContentIsolationConfig};
///
/// # async fn example<P: zeph_llm::provider::LlmProvider>(provider: P) {
/// let sanitizer = ContentSanitizer::new(&ContentIsolationConfig::default());
/// // The last argument is the LLM call timeout in seconds.
/// let evaluator = PredicateEvaluator::new(provider, sanitizer, 30);
/// let outcome = evaluator
///     .evaluate(
///         &VerifyPredicate::Natural("output must include a summary".to_string()),
///         "Here is the summary: ...",
///         None,
///     )
///     .await;
/// assert!(outcome.confidence >= 0.0);
/// # }
/// ```
pub struct PredicateEvaluator<P: LlmProvider> {
    /// Judge model; queried once per evaluation for a typed response.
    provider: P,
    /// Applied to the task output before it is placed in the user message.
    sanitizer: ContentSanitizer,
    /// Upper bound on a single judge call; exceeding it yields a fail-open outcome.
    timeout: Duration,
}
131
132impl<P: LlmProvider> PredicateEvaluator<P> {
133    /// Create a new evaluator backed by `provider`.
134    ///
135    /// `sanitizer` is applied to task output before it is embedded in the judge prompt.
136    /// `timeout_secs` bounds the LLM call; on timeout the evaluator returns a fail-open
137    /// outcome (`passed = true`, `confidence = 0.0`) and logs a warning.
138    pub fn new(provider: P, sanitizer: ContentSanitizer, timeout_secs: u64) -> Self {
139        Self {
140            provider,
141            sanitizer,
142            timeout: Duration::from_secs(timeout_secs),
143        }
144    }
145
146    /// Evaluate `predicate` against `output`.
147    ///
148    /// `prior_failure_reason` is injected into the prompt on re-runs so the model
149    /// knows why the previous attempt failed. Pass `None` on the first evaluation.
150    ///
151    /// On LLM or parse error, returns a permissive outcome (`passed = true,
152    /// confidence = 0.0`) and logs a warning — fail-open per the orchestration
153    /// error policy.
154    pub async fn evaluate(
155        &self,
156        predicate: &VerifyPredicate,
157        output: &str,
158        prior_failure_reason: Option<&str>,
159    ) -> PredicateOutcome {
160        let criterion = match predicate.as_natural() {
161            Ok(s) => s,
162            Err(e) => {
163                tracing::warn!(error = %e, "unsupported predicate variant, skipping evaluation (fail-open)");
164                return PredicateOutcome {
165                    passed: true,
166                    confidence: 0.0,
167                    reason: format!("predicate not evaluated: {e}"),
168                };
169            }
170        };
171
172        let prior_note = prior_failure_reason
173            .map(|r| {
174                // Truncate and wrap in XML tags to prevent injection from compromised judge output.
175                let truncated: String = r.chars().take(256).collect();
176                format!(
177                    "\n\n<prior_failure_reason>{truncated}</prior_failure_reason>\n\
178                     Note: a previous evaluation failed with this reason. Take it into account."
179                )
180            })
181            .unwrap_or_default();
182
183        let system = format!(
184            "You are a strict output verifier. Evaluate whether the task output satisfies \
185             the given criterion. Respond with a JSON object: \
186             {{\"passed\": true/false, \"confidence\": 0.0-1.0, \"reason\": \"...\"}}\n\
187             Criterion: {criterion}{prior_note}"
188        );
189
190        // Sanitize task output before embedding it in the judge prompt (prompt-injection defence).
191        let source = ContentSource::new(ContentSourceKind::ToolResult)
192            .with_identifier("predicate-evaluator-input");
193        let sanitized = self.sanitizer.sanitize(output, source);
194        let user = format!("Task output:\n\n{}", sanitized.body);
195
196        let messages = vec![
197            Message::from_legacy(Role::System, system),
198            Message::from_legacy(Role::User, user),
199        ];
200
201        match tokio::time::timeout(
202            self.timeout,
203            self.provider.chat_typed::<EvalResponse>(&messages),
204        )
205        .await
206        {
207            Ok(Ok(resp)) => {
208                let outcome = PredicateOutcome {
209                    passed: resp.passed,
210                    confidence: resp.confidence.clamp(0.0, 1.0),
211                    reason: resp.reason,
212                };
213                if outcome.passed && outcome.confidence < 0.5 {
214                    tracing::warn!(
215                        confidence = outcome.confidence,
216                        reason = %outcome.reason,
217                        "weak predicate pass (confidence < 0.5)"
218                    );
219                }
220                outcome
221            }
222            Ok(Err(e)) => {
223                tracing::warn!(
224                    error = %e,
225                    "predicate evaluation LLM call failed, returning fail-open outcome"
226                );
227                PredicateOutcome {
228                    passed: true,
229                    confidence: 0.0,
230                    reason: format!("evaluation failed: {e}"),
231                }
232            }
233            Err(_elapsed) => {
234                tracing::warn!(
235                    timeout_secs = self.timeout.as_secs(),
236                    "predicate evaluation timed out, returning fail-open outcome"
237                );
238                PredicateOutcome {
239                    passed: true,
240                    confidence: 0.0,
241                    reason: "evaluation timed out".to_string(),
242                }
243            }
244        }
245    }
246}
247
/// Internal response shape for predicate evaluation.
///
/// This is the type requested from the provider's `chat_typed` call in
/// `PredicateEvaluator::evaluate`; its fields are copied into a
/// [`PredicateOutcome`] after the confidence has been clamped.
// NOTE(review): the `JsonSchema` derive is presumably what lets the provider request
// structured output for this type — confirm against `LlmProvider::chat_typed`.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
struct EvalResponse {
    /// Judge verdict: `true` when the output satisfies the criterion.
    passed: bool,
    /// Judge-reported confidence; not trusted to be within [0.0, 1.0] as received.
    confidence: f32,
    /// Judge's explanation for the verdict.
    reason: String,
}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// `Natural` exposes its criterion text unchanged.
    #[test]
    fn natural_predicate_as_natural() {
        let pred = VerifyPredicate::Natural("must contain JSON".to_string());
        assert_eq!(pred.as_natural().unwrap(), "must contain JSON");
    }

    /// `Expression` is rejected with the dedicated "not supported" error variant.
    #[test]
    fn expression_predicate_returns_error() {
        let pred = VerifyPredicate::Expression("len(output) > 0".to_string());
        assert!(matches!(
            pred.as_natural(),
            Err(OrchestrationError::PredicateNotSupported(_))
        ));
    }

    /// An outcome survives a JSON round trip with all fields intact.
    #[test]
    fn predicate_outcome_serde_roundtrip() {
        let o = PredicateOutcome {
            passed: true,
            confidence: 0.85,
            reason: "looks good".to_string(),
        };
        let json = serde_json::to_string(&o).expect("serialize");
        let restored: PredicateOutcome = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(restored.passed, o.passed);
        assert!((restored.confidence - o.confidence).abs() < f32::EPSILON);
        assert_eq!(restored.reason, o.reason);
    }

    #[test]
    fn verify_predicate_serde_roundtrip_natural() {
        let pred = VerifyPredicate::Natural("criterion".to_string());
        let json = serde_json::to_string(&pred).expect("serialize");
        let restored: VerifyPredicate = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(pred, restored);
    }

    /// The reserved variant must still round-trip so persisted graphs stay loadable.
    #[test]
    fn verify_predicate_serde_roundtrip_expression() {
        let pred = VerifyPredicate::Expression("len(output) > 0".to_string());
        let json = serde_json::to_string(&pred).expect("serialize");
        let restored: VerifyPredicate = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(pred, restored);
    }

    /// Pins the adjacently tagged, snake_case wire format that persisted graphs rely on.
    #[test]
    fn verify_predicate_wire_format_is_adjacently_tagged() {
        let pred = VerifyPredicate::Natural("criterion".to_string());
        let val = serde_json::to_value(&pred).expect("serialize");
        assert_eq!(
            val,
            serde_json::json!({ "kind": "natural", "value": "criterion" })
        );
    }

    #[test]
    fn task_node_missing_predicate_fields_deserialize_as_none() {
        // TaskNode lives in graph.rs, so mirror only its two predicate fields here and
        // check that this module's types tolerate being absent from an old JSON blob.
        #[derive(serde::Deserialize)]
        struct LegacyNodeFields {
            #[serde(default)]
            verify_predicate: Option<VerifyPredicate>,
            #[serde(default)]
            predicate_outcome: Option<PredicateOutcome>,
        }

        // Simulate old JSON blob without predicate fields — #[serde(default)] must handle it.
        let json = r#"{
            "id": 0,
            "title": "t",
            "description": "d",
            "agent_hint": null,
            "status": "pending",
            "depends_on": [],
            "result": null,
            "assigned_agent": null,
            "retry_count": 0,
            "failure_strategy": null,
            "max_retries": null
        }"#;
        let node: LegacyNodeFields = serde_json::from_str(json).expect("deserialize");
        assert!(node.verify_predicate.is_none());
        assert!(node.predicate_outcome.is_none());
        // Actual TaskNode deserialization is tested in graph.rs tests.
    }
}