tooltest_core/runner/
execution.rs

1use std::cell::RefCell;
2use std::ops::RangeInclusive;
3use std::rc::Rc;
4
5use proptest::test_runner::{Config as ProptestConfig, TestCaseError, TestRunner};
6
7use crate::generator::{clear_reject_context, state_machine_sequence_strategy};
8use crate::{
9    CorpusReport, CoverageReport, RunConfig, RunFailure, RunOutcome, RunResult, SessionDriver,
10    TraceEntry,
11};
12
13use super::assertions::attach_failure_reason;
14use super::coverage::{coverage_failure, CoverageTracker};
15use super::pre_run::run_pre_run_hook;
16use super::prepare::prepare_run;
17use super::result::{failure_result, finalize_state_machine_result, FailureContext};
18use super::state_machine::execute_state_machine_sequence;
19
/// Knobs that control how proptest drives a tooltest run.
///
/// The [`Default`] value executes 32 cases whose generated sequences contain
/// between 1 and 20 invocations (inclusive).
#[derive(Clone, Debug)]
pub struct RunnerOptions {
    /// How many proptest cases to execute.
    pub cases: u32,
    /// Inclusive bounds on how many invocations each generated sequence holds.
    pub sequence_len: RangeInclusive<usize>,
}
28
29impl Default for RunnerOptions {
30    fn default() -> Self {
31        Self {
32            cases: 32,
33            sequence_len: 1..=20,
34        }
35    }
36}
37
38/// Execute a tooltest run using a pre-initialized session.
39///
40/// Runs apply default assertions that fail on error responses and validate
41/// structured output against declared output schemas, plus any user-supplied
42/// assertion rules.
43///
44/// Requires a multi-thread Tokio runtime; current-thread runtimes are rejected.
45pub async fn run_with_session(
46    session: &SessionDriver,
47    config: &RunConfig,
48    options: RunnerOptions,
49) -> RunResult {
50    let prepared = match prepare_run(session, config).await {
51        Ok(prepared) => prepared,
52        Err(result) => return result,
53    };
54    let prelude_trace = Rc::new(prepared.prelude_trace);
55    let tools = prepared.tools;
56    let warnings = prepared.warnings;
57    let validators = prepared.validators;
58
59    let assertions = config.assertions.clone();
60    let warnings = Rc::new(warnings);
61    let aggregate_tools = tools.clone();
62    let aggregate_tracker: Rc<RefCell<CoverageTracker<'_>>> = Rc::new(RefCell::new(
63        CoverageTracker::new(&aggregate_tools, &config.state_machine),
64    ));
65    let last_trace: Rc<RefCell<Vec<TraceEntry>>> = Rc::new(RefCell::new(Vec::new()));
66    last_trace.replace(prelude_trace.as_ref().clone());
67    let last_coverage: Rc<RefCell<Option<CoverageReport>>> = Rc::new(RefCell::new(None));
68    let last_corpus: Rc<RefCell<Option<CorpusReport>>> = Rc::new(RefCell::new(None));
69    let last_failure = Rc::new(RefCell::new(FailureContext {
70        failure: RunFailure::new(String::new()),
71        trace: Vec::new(),
72        coverage: None,
73        corpus: None,
74    }));
75    let validators = Rc::new(validators);
76    clear_reject_context();
77    let handle = tokio::runtime::Handle::current();
78    if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::CurrentThread {
79        return failure_result(
80            RunFailure::new("run_with_session requires a multi-thread Tokio runtime".to_string()),
81            Vec::new(),
82            None,
83            warnings.as_ref().clone(),
84            None,
85            None,
86        );
87    }
88
89    let mut runner = TestRunner::new(ProptestConfig {
90        cases: options.cases,
91        failure_persistence: None,
92        ..ProptestConfig::default()
93    });
94
95    let strategy = match state_machine_sequence_strategy(
96        &tools,
97        config.predicate.as_ref(),
98        &config.state_machine,
99        options.sequence_len.clone(),
100    ) {
101        Ok(strategy) => strategy,
102        Err(error) => {
103            return failure_result(
104                RunFailure::new(error.to_string()),
105                prelude_trace.as_ref().clone(),
106                None,
107                warnings.as_ref().clone(),
108                None,
109                None,
110            )
111        }
112    };
113
114    if options.cases == 0 {
115        if let Err(failure) = run_pre_run_hook(config).await {
116            return failure_result(
117                failure,
118                prelude_trace.as_ref().clone(),
119                None,
120                warnings.as_ref().clone(),
121                None,
122                None,
123            );
124        }
125    }
126
127    let run_result = runner.run(&strategy, {
128        let assertions = assertions.clone();
129        let last_trace = last_trace.clone();
130        let last_coverage = last_coverage.clone();
131        let last_corpus = last_corpus.clone();
132        let last_failure = last_failure.clone();
133        let validators = validators.clone();
134        let aggregate_tracker = aggregate_tracker.clone();
135        move |sequence| {
136            let execution: Result<Vec<TraceEntry>, FailureContext> =
137                tokio::task::block_in_place(|| {
138                    let last_coverage = last_coverage.clone();
139                    let last_corpus = last_corpus.clone();
140                    handle.block_on(async {
141                        if let Err(failure) = run_pre_run_hook(config).await {
142                            return Err(FailureContext {
143                                failure,
144                                trace: Vec::new(),
145                                coverage: None,
146                                corpus: None,
147                            });
148                        }
149                        let mut tracker = CoverageTracker::new(&tools, &config.state_machine);
150                        let min_len = if config.state_machine.coverage_rules.is_empty() {
151                            Some(*options.sequence_len.start())
152                        } else {
153                            None
154                        };
155                        let result = execute_state_machine_sequence(
156                            session,
157                            &tools,
158                            &validators,
159                            &assertions,
160                            &sequence,
161                            &mut tracker,
162                            config.predicate.as_ref(),
163                            min_len,
164                        )
165                        .await;
166                        let (report, corpus_report) = {
167                            let mut aggregate = aggregate_tracker.borrow_mut();
168                            aggregate.merge_from(&tracker);
169                            let report = aggregate.report();
170                            let corpus_report = if config.state_machine.dump_corpus {
171                                Some(aggregate.corpus_report())
172                            } else {
173                                None
174                            };
175                            (report, corpus_report)
176                        };
177                        last_coverage.replace(Some(report.clone()));
178                        last_corpus.replace(corpus_report.clone());
179                        match result {
180                            Ok(trace) => Ok(trace),
181                            Err(mut failure) => {
182                                failure.coverage = Some(report);
183                                failure.corpus = corpus_report;
184                                Err(failure)
185                            }
186                        }
187                    })
188                });
189            match execution {
190                Ok(trace) => {
191                    let mut full_trace = prelude_trace.as_ref().clone();
192                    full_trace.extend(trace);
193                    last_trace.replace(full_trace);
194                    Ok(())
195                }
196                Err(mut failure) => {
197                    let mut full_trace = prelude_trace.as_ref().clone();
198                    full_trace.extend(failure.trace);
199                    failure.trace = full_trace;
200                    last_failure.replace(failure.clone());
201                    Err(TestCaseError::fail(failure.failure.reason.clone()))
202                }
203            }
204        }
205    });
206    let run_result = finalize_state_machine_result(
207        run_result,
208        &last_trace,
209        &last_failure,
210        &last_coverage,
211        &last_corpus,
212        warnings.as_ref(),
213    );
214    if matches!(run_result.outcome, RunOutcome::Success) {
215        if let Err(failure) = aggregate_tracker
216            .borrow()
217            .validate(&config.state_machine.coverage_rules)
218        {
219            let mut trace = last_trace.borrow().clone();
220            attach_failure_reason(&mut trace, "coverage validation failed".to_string());
221            let report = aggregate_tracker.borrow().report();
222            let corpus_report = if config.state_machine.dump_corpus {
223                Some(aggregate_tracker.borrow().corpus_report())
224            } else {
225                None
226            };
227            return failure_result(
228                coverage_failure(failure),
229                trace,
230                None,
231                warnings.as_ref().clone(),
232                Some(report),
233                corpus_report,
234            );
235        }
236    }
237    run_result
238}