Skip to main content

tooltest_core/runner/
execution.rs

1use std::cell::RefCell;
2use std::ops::RangeInclusive;
3use std::rc::Rc;
4
5use proptest::test_runner::{Config as ProptestConfig, TestCaseError, TestError, TestRunner};
6
7use crate::generator::{
8    clear_reject_context, state_machine_sequence_strategy, StateMachineSequence,
9};
10use crate::{
11    CorpusReport, CoverageReport, RunConfig, RunFailure, RunOutcome, RunResult, RunWarning,
12    SessionDriver, TraceEntry,
13};
14
15use super::coverage::CoverageTracker;
16use super::linting::{evaluate_run_phase, lint_phases};
17use super::pre_run::run_pre_run_hook;
18use super::prepare::prepare_run;
19use super::result::{failure_result, finalize_state_machine_result, FailureContext};
20use super::state_machine::{execute_state_machine_sequence, StateMachineExecution};
21
/// Tuning knobs for the proptest-driven runner.
///
/// The fields are private, so downstream crates cannot bypass validation with
/// a struct literal; obtain values through [`RunnerOptions::new`] or
/// [`Default::default`].
///
/// ```rust,compile_fail
/// use tooltest_core::RunnerOptions;
///
/// let _ = RunnerOptions {
///     cases: 0,
///     sequence_len: 0..=0,
/// };
/// ```
#[derive(Debug, Clone)]
pub struct RunnerOptions {
    /// How many proptest cases the runner executes.
    cases: u32,
    /// Inclusive bounds on the number of invocations in each generated sequence.
    sequence_len: RangeInclusive<usize>,
}

impl RunnerOptions {
    /// Builds validated runner options.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message, checked in this order, when:
    /// - `cases` is zero;
    /// - the sequence range starts at zero;
    /// - the range start is greater than its end.
    pub fn new(cases: u32, sequence_len: RangeInclusive<usize>) -> Result<Self, String> {
        let (lower, upper) = (*sequence_len.start(), *sequence_len.end());
        let problem = if cases == 0 {
            Some("cases must be at least 1")
        } else if lower == 0 {
            Some("min-sequence-len must be at least 1")
        } else if lower > upper {
            Some("min-sequence-len must be <= max-sequence-len")
        } else {
            None
        };
        match problem {
            Some(message) => Err(message.to_owned()),
            None => Ok(Self {
                cases,
                sequence_len,
            }),
        }
    }

    /// Number of proptest cases the runner will execute.
    pub fn cases(&self) -> u32 {
        self.cases
    }

    /// A copy of the inclusive sequence-length bounds.
    pub fn sequence_len(&self) -> RangeInclusive<usize> {
        RangeInclusive::clone(&self.sequence_len)
    }

    /// Lower bound of the sequence-length range.
    pub(crate) fn min_sequence_len(&self) -> usize {
        let lower = self.sequence_len.start();
        *lower
    }
}

impl Default for RunnerOptions {
    /// 32 cases, each generating between 1 and 20 invocations.
    fn default() -> Self {
        const DEFAULT_CASES: u32 = 32;
        const DEFAULT_MIN_LEN: usize = 1;
        const DEFAULT_MAX_LEN: usize = 20;
        Self {
            cases: DEFAULT_CASES,
            sequence_len: DEFAULT_MIN_LEN..=DEFAULT_MAX_LEN,
        }
    }
}
86
/// Execute a tooltest run using a pre-initialized session.
///
/// Runs apply default assertions that fail on MCP protocol errors, schema-invalid
/// responses, and (when configured) tool result error responses, plus any
/// user-supplied assertion rules.
///
/// Requires a multi-thread Tokio runtime; current-thread runtimes are rejected.
///
/// # Flow
///
/// 1. `prepare_run` produces the tools, validators, warnings, and prelude trace.
///    If it fails, the `RunResult` it returns is passed straight back.
/// 2. A proptest `TestRunner` (`options.cases()` cases, no failure persistence)
///    generates `StateMachineSequence`s. For each case:
///    - the pre-run hook runs;
///    - the sequence is executed against `session`;
///    - the case's coverage is merged into a run-wide aggregate tracker.
/// 3. Panics inside a case, and panics escaping the runner, are reported as
///    `run_panicked` failures instead of unwinding into the caller.
/// 4. Run-phase lints are evaluated against the last aggregate coverage and
///    corpus. A lint failure only changes the outcome of an otherwise
///    successful run.
pub async fn run_with_session(
    session: &SessionDriver,
    config: &RunConfig,
    options: RunnerOptions,
) -> RunResult {
    // Lints are grouped by phase:
    // - `list` feeds `prepare_run`;
    // - `response` is handed to every case's execution;
    // - `run` is evaluated after all cases.
    let lint_phases = lint_phases(&config.lints);
    let prepared = match prepare_run(session, config, &lint_phases.list).await {
        Ok(prepared) => prepared,
        Err(result) => return result,
    };
    let prelude_trace = Rc::new(prepared.prelude_trace);
    let tools = prepared.tools;
    let warnings = prepared.warnings;
    let validators = prepared.validators;

    // This state is shared between this function and the proptest case
    // closure. `TestRunner::run` takes an `Fn` closure, so mutation goes
    // through `Rc<RefCell<..>>`.
    let assertions = config.assertions.clone();
    let warnings = Rc::new(RefCell::new(warnings));
    // The aggregate tracker borrows this separate copy of the tools.
    // NOTE(review): presumably this copy is needed because `tools` itself is
    // captured by the `move` case closure below — confirm against its type.
    let aggregate_tools = tools.clone();
    let aggregate_tracker: Rc<RefCell<CoverageTracker<'_>>> =
        Rc::new(RefCell::new(CoverageTracker::new(
            &aggregate_tools,
            &config.state_machine,
            config.uncallable_limit(),
        )));
    // Seeded with the prelude trace, then replaced after each case that either
    // completes successfully or panics.
    let last_trace: Rc<RefCell<Vec<TraceEntry>>> = Rc::new(RefCell::new(Vec::new()));
    last_trace.replace(prelude_trace.as_ref().clone());
    let last_coverage: Rc<RefCell<Option<CoverageReport>>> = Rc::new(RefCell::new(None));
    let last_corpus: Rc<RefCell<Option<CorpusReport>>> = Rc::new(RefCell::new(None));
    // Placeholder only; overwritten whenever a case fails or panics.
    let last_failure = Rc::new(RefCell::new(FailureContext {
        failure: RunFailure::new(String::new()),
        trace: Vec::new(),
        coverage: None,
        corpus: None,
        positive_error: false,
    }));
    let validators = Rc::new(validators);
    // NOTE(review): presumably clears generator reject state left over from a
    // previous run — confirm in `crate::generator`.
    clear_reject_context();
    let handle = tokio::runtime::Handle::current();
    // `tokio::task::block_in_place` (used per case below) panics on a
    // current-thread runtime, so such runtimes are rejected here.
    // NOTE(review): this check runs only after `prepare_run` (which receives
    // the session) has completed. Its failure also carries an empty trace,
    // unlike the strategy-error path below. Consider checking earlier.
    if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::CurrentThread {
        return failure_result(
            RunFailure::new("run_with_session requires a multi-thread Tokio runtime".to_string()),
            Vec::new(),
            None,
            warnings.borrow().clone(),
            None,
            None,
        );
    }

    // Failure persistence is disabled, so proptest records no regression files.
    let mut runner = TestRunner::new(ProptestConfig {
        cases: options.cases(),
        failure_persistence: None,
        ..ProptestConfig::default()
    });

    // If the sequence strategy cannot be built, the run fails with the
    // prelude trace attached.
    let strategy = match state_machine_sequence_strategy(
        &tools,
        config.predicate.as_ref(),
        &config.state_machine,
        options.sequence_len(),
    ) {
        Ok(strategy) => strategy,
        Err(error) => {
            return failure_result(
                RunFailure::new(error.to_string()),
                prelude_trace.as_ref().clone(),
                None,
                warnings.borrow().clone(),
                None,
                None,
            )
        }
    };

    // Outer guard: a panic that escapes `runner.run` (outside the per-case
    // guard below) is caught here and converted by `finalize_run_result`.
    let run_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        runner.run(&strategy, {
            // Per-closure handles to the shared state; the `move` closure
            // below takes ownership of these clones.
            let assertions = assertions.clone();
            let last_trace = last_trace.clone();
            let last_coverage = last_coverage.clone();
            let last_corpus = last_corpus.clone();
            let last_failure = last_failure.clone();
            let validators = validators.clone();
            let aggregate_tracker = aggregate_tracker.clone();
            let response_lints = lint_phases.response.clone();
            let warnings = warnings.clone();
            let trace_sink = config.trace_sink.clone();
            // Incremented once per closure invocation, including proptest's
            // shrinking re-runs, so each executed case gets a distinct index.
            let case_counter = Rc::new(RefCell::new(0u64));
            move |sequence| {
                // Inner guard: a panic while executing one case is recorded and
                // reported to proptest as a test-case failure.
                let execution = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                    // proptest invokes this closure synchronously.
                    // `block_in_place` lets the worker thread block while
                    // `handle.block_on` drives the async case body to completion.
                    tokio::task::block_in_place(|| {
                        let last_coverage = last_coverage.clone();
                        let last_corpus = last_corpus.clone();
                        let case_counter = case_counter.clone();
                        handle.block_on(async {
                            // A failing pre-run hook fails the case before any
                            // sequence execution, with no trace or coverage.
                            if let Err(failure) = run_pre_run_hook(config).await {
                                return Err(FailureContext {
                                    failure,
                                    trace: Vec::new(),
                                    coverage: None,
                                    corpus: None,
                                    positive_error: true,
                                });
                            }
                            // Fresh per-case tracker; merged into the aggregate below.
                            let mut tracker = CoverageTracker::new(
                                &tools,
                                &config.state_machine,
                                config.uncallable_limit(),
                            );
                            // Enforce min-sequence-len even when coverage lint is enabled, so a
                            // run cannot "succeed" with 0 tool calls when all tools are uncallable.
                            let min_len = Some(options.min_sequence_len());
                            // Take the current counter value, then advance it.
                            let case_index = {
                                let mut counter = case_counter.borrow_mut();
                                let index = *counter;
                                *counter += 1;
                                index
                            };
                            let execution = StateMachineExecution {
                                session,
                                tools: &tools,
                                validators: &validators,
                                assertions: &assertions,
                                predicate: config.predicate.as_ref(),
                                min_len,
                                in_band_error_forbidden: config.in_band_error_forbidden,
                                full_trace: config.full_trace,
                                warnings: warnings.clone(),
                                response_lints: response_lints.clone(),
                                case_index,
                                trace_sink: trace_sink.clone(),
                            };
                            let result =
                                execute_state_machine_sequence(&sequence, &execution, &mut tracker)
                                    .await;
                            // Merge this case into the run-wide tracker and take a
                            // snapshot of the aggregate report:
                            // - uncallable traces are stripped unless `show_uncallable`;
                            // - the corpus is produced only when `dump_corpus` is set.
                            let (report, corpus_report) = {
                                let mut aggregate = aggregate_tracker.borrow_mut();
                                aggregate.merge_from(&tracker);
                                let mut report = aggregate.report();
                                apply_uncallable_traces(&mut report, config.show_uncallable);
                                let corpus_report = if config.state_machine.dump_corpus {
                                    Some(aggregate.corpus_report())
                                } else {
                                    None
                                };
                                (report, corpus_report)
                            };
                            last_coverage.replace(Some(report.clone()));
                            last_corpus.replace(corpus_report.clone());
                            match result {
                                Ok(trace) => Ok(trace),
                                Err(mut failure) => {
                                    // Failures flagged `positive_error` are reported
                                    // without a coverage report.
                                    if failure.positive_error {
                                        failure.coverage = None;
                                    } else {
                                        failure.coverage = Some(report);
                                    }
                                    failure.corpus = corpus_report;
                                    Err(failure)
                                }
                            }
                        })
                    })
                }));
                let execution: Result<Vec<TraceEntry>, FailureContext> = match execution {
                    Ok(execution) => execution,
                    Err(payload) => {
                        // The case panicked. Record a `run_panicked` failure that
                        // carries only the prelude trace and no coverage.
                        let reason = panic_message(payload);
                        let failure = FailureContext {
                            failure: RunFailure {
                                reason: format!("run panicked: {reason}"),
                                code: Some("run_panicked".to_string()),
                                details: None,
                            },
                            trace: prelude_trace.as_ref().clone(),
                            coverage: None,
                            corpus: None,
                            positive_error: true,
                        };
                        last_failure.replace(failure.clone());
                        last_trace.replace(failure.trace.clone());
                        return Err(TestCaseError::fail(failure.failure.reason.clone()));
                    }
                };
                // Prepend the prelude trace so reported traces begin with the
                // prelude entries.
                match execution {
                    Ok(trace) => {
                        let mut full_trace = prelude_trace.as_ref().clone();
                        full_trace.extend(trace);
                        last_trace.replace(full_trace);
                        Ok(())
                    }
                    Err(mut failure) => {
                        let mut full_trace = prelude_trace.as_ref().clone();
                        full_trace.extend(failure.trace);
                        failure.trace = full_trace;
                        last_failure.replace(failure.clone());
                        Err(TestCaseError::fail(failure.failure.reason.clone()))
                    }
                }
            }
        })
    }));
    let run_result = finalize_run_result(
        run_result,
        &last_trace,
        &last_failure,
        &last_coverage,
        &last_corpus,
        warnings.borrow().clone(),
    );
    let mut run_result = run_result;
    let outcome = run_result.outcome.clone();
    // Run-phase lints see the aggregate coverage and corpus from the last case
    // that reached the merge step.
    let coverage = last_coverage.borrow();
    let corpus = last_corpus.borrow();
    let context = crate::RunLintContext {
        coverage: coverage.as_ref(),
        corpus: corpus.as_ref(),
        coverage_allowlist: config.state_machine.coverage_allowlist.as_deref(),
        coverage_blocklist: config.state_machine.coverage_blocklist.as_deref(),
        outcome: &outcome,
    };
    // `evaluate_run_phase` may add to `run_result.warnings`. A returned lint
    // failure only overrides a successful outcome.
    // NOTE(review): the overriding failure is built with an empty trace.
    if let Some(failure) = evaluate_run_phase(&lint_phases.run, &context, &mut run_result.warnings)
    {
        if matches!(outcome, RunOutcome::Success) {
            return failure_result(
                failure,
                Vec::new(),
                None,
                run_result.warnings.clone(),
                run_result.coverage.clone(),
                run_result.corpus.clone(),
            );
        }
    }
    run_result
}
329
330fn apply_uncallable_traces(report: &mut CoverageReport, show_uncallable: bool) {
331    if !show_uncallable {
332        report.uncallable_traces.clear();
333    }
334}
335
336fn finalize_run_result(
337    run_result: std::thread::Result<Result<(), TestError<StateMachineSequence>>>,
338    last_trace: &Rc<RefCell<Vec<TraceEntry>>>,
339    last_failure: &Rc<RefCell<FailureContext>>,
340    last_coverage: &Rc<RefCell<Option<CoverageReport>>>,
341    last_corpus: &Rc<RefCell<Option<CorpusReport>>>,
342    warnings: Vec<RunWarning>,
343) -> RunResult {
344    match run_result {
345        Ok(run_result) => finalize_state_machine_result(
346            run_result,
347            last_trace,
348            last_failure,
349            last_coverage,
350            last_corpus,
351            &warnings,
352        ),
353        Err(payload) => run_result_from_panic(
354            payload,
355            last_trace.borrow().clone(),
356            warnings,
357            last_coverage.borrow().clone(),
358            last_corpus.borrow().clone(),
359        ),
360    }
361}
362
363fn run_result_from_panic(
364    payload: Box<dyn std::any::Any + Send>,
365    trace: Vec<TraceEntry>,
366    warnings: Vec<RunWarning>,
367    coverage: Option<CoverageReport>,
368    corpus: Option<CorpusReport>,
369) -> RunResult {
370    let reason = panic_message(payload);
371    failure_result(
372        RunFailure {
373            reason: format!("run panicked: {reason}"),
374            code: Some("run_panicked".to_string()),
375            details: None,
376        },
377        trace,
378        None,
379        warnings,
380        coverage,
381        corpus,
382    )
383}
384
/// Extracts a human-readable message from a panic payload.
///
/// `String` and `&'static str` payloads (the types `panic!` produces) yield
/// their text. Any other payload type yields `"unknown panic"`.
fn panic_message(payload: Box<dyn std::any::Any + Send>) -> String {
    match payload.downcast::<String>() {
        Ok(owned) => *owned,
        Err(other) => match other.downcast_ref::<&'static str>() {
            Some(text) => (*text).to_owned(),
            None => String::from("unknown panic"),
        },
    }
}
394
#[cfg(test)]
mod tests {
    use super::{
        apply_uncallable_traces, finalize_run_result, panic_message, run_result_from_panic,
        FailureContext,
    };
    use crate::generator::StateMachineSequence;
    use crate::{
        CallToolResult, CoverageReport, RunOutcome, RunWarning, ToolInvocation, UncallableToolCall,
    };
    use proptest::test_runner::TestError;
    use std::cell::RefCell;
    use std::collections::BTreeMap;
    use std::rc::Rc;

    /// Shape of the value produced by `catch_unwind(|| runner.run(..))`.
    type CaughtRun = std::thread::Result<Result<(), TestError<StateMachineSequence>>>;

    /// Boxes `value` as a panic payload.
    fn payload<T: std::any::Any + Send>(value: T) -> Box<dyn std::any::Any + Send> {
        Box::new(value)
    }

    /// True when the outcome is a failure of any kind.
    fn is_failure(outcome: &RunOutcome) -> bool {
        matches!(outcome, RunOutcome::Failure(_))
    }

    /// Runs `finalize_run_result` with empty shared state and no warnings.
    fn finalize_with_empty_state(run_result: CaughtRun) -> crate::RunResult {
        let trace = Rc::new(RefCell::new(Vec::new()));
        let failure = Rc::new(RefCell::new(FailureContext {
            failure: crate::RunFailure::new(String::new()),
            trace: Vec::new(),
            coverage: None,
            corpus: None,
            positive_error: false,
        }));
        let coverage = Rc::new(RefCell::new(None));
        let corpus = Rc::new(RefCell::new(None));
        finalize_run_result(run_result, &trace, &failure, &coverage, &corpus, Vec::new())
    }

    #[test]
    fn panic_message_handles_str() {
        assert_eq!(panic_message(payload("boom")), "boom");
    }

    #[test]
    fn panic_message_handles_string() {
        assert_eq!(panic_message(payload(String::from("boom"))), "boom");
    }

    #[test]
    fn panic_message_handles_unknown() {
        assert_eq!(panic_message(payload(42_u64)), "unknown panic");
    }

    #[test]
    fn run_result_from_panic_builds_failure() {
        let result = run_result_from_panic(
            payload("boom"),
            Vec::new(),
            Vec::<RunWarning>::new(),
            None,
            None,
        );
        match &result.outcome {
            RunOutcome::Failure(failure) => {
                assert!(failure.reason.contains("run panicked: boom"));
                assert_eq!(failure.code.as_deref(), Some("run_panicked"));
            }
            _ => panic!("expected a failure outcome"),
        }
    }

    #[test]
    fn finalize_run_result_handles_panic() {
        let run_result: CaughtRun = Err(payload("boom"));
        let result = finalize_with_empty_state(run_result);
        assert!(is_failure(&result.outcome));
    }

    #[test]
    fn finalize_run_result_handles_success() {
        let run_result: CaughtRun = Ok(Ok(()));
        let result = finalize_with_empty_state(run_result);
        assert!(!is_failure(&result.outcome));
    }

    /// A coverage report holding exactly one uncallable trace, for tool `echo`.
    fn sample_report() -> CoverageReport {
        let echo_call = UncallableToolCall {
            input: ToolInvocation {
                name: "echo".to_string().into(),
                arguments: None,
            },
            output: Some(CallToolResult::success(vec![])),
            error: None,
            timestamp: "2026-01-01T00:00:00Z".to_string(),
        };
        CoverageReport {
            counts: BTreeMap::new(),
            failures: BTreeMap::new(),
            warnings: Vec::new(),
            uncallable_traces: BTreeMap::from([("echo".to_string(), vec![echo_call])]),
        }
    }

    #[test]
    fn apply_uncallable_traces_clears_when_disabled() {
        let mut report = sample_report();
        apply_uncallable_traces(&mut report, false);
        assert!(report.uncallable_traces.is_empty());
    }

    #[test]
    fn apply_uncallable_traces_retains_when_enabled() {
        let mut report = sample_report();
        apply_uncallable_traces(&mut report, true);
        assert_eq!(report.uncallable_traces.len(), 1);
    }
}