1use std::cell::RefCell;
2use std::ops::RangeInclusive;
3use std::rc::Rc;
4
5use proptest::test_runner::{Config as ProptestConfig, TestCaseError, TestRunner};
6
7use crate::generator::{clear_reject_context, state_machine_sequence_strategy};
8use crate::{
9 CorpusReport, CoverageReport, RunConfig, RunFailure, RunOutcome, RunResult, SessionDriver,
10 TraceEntry,
11};
12
13use super::assertions::attach_failure_reason;
14use super::coverage::{coverage_failure, CoverageTracker};
15use super::pre_run::run_pre_run_hook;
16use super::prepare::prepare_run;
17use super::result::{failure_result, finalize_state_machine_result, FailureContext};
18use super::state_machine::execute_state_machine_sequence;
19
/// Tuning knobs for a property-based state-machine run.
///
/// The [`Default`] implementation uses 32 cases and a `sequence_len` of `1..=20`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RunnerOptions {
    /// Number of proptest cases to execute (forwarded to `ProptestConfig::cases`).
    pub cases: u32,
    /// Inclusive range of sequence lengths handed to the sequence strategy. When no
    /// coverage rules are configured, its start is also passed to the sequence executor
    /// as the minimum length.
    pub sequence_len: RangeInclusive<usize>,
}
28
29impl Default for RunnerOptions {
30 fn default() -> Self {
31 Self {
32 cases: 32,
33 sequence_len: 1..=20,
34 }
35 }
36}
37
38pub async fn run_with_session(
46 session: &SessionDriver,
47 config: &RunConfig,
48 options: RunnerOptions,
49) -> RunResult {
50 let prepared = match prepare_run(session, config).await {
51 Ok(prepared) => prepared,
52 Err(result) => return result,
53 };
54 let prelude_trace = Rc::new(prepared.prelude_trace);
55 let tools = prepared.tools;
56 let warnings = prepared.warnings;
57 let validators = prepared.validators;
58
59 let assertions = config.assertions.clone();
60 let warnings = Rc::new(warnings);
61 let aggregate_tools = tools.clone();
62 let aggregate_tracker: Rc<RefCell<CoverageTracker<'_>>> = Rc::new(RefCell::new(
63 CoverageTracker::new(&aggregate_tools, &config.state_machine),
64 ));
65 let last_trace: Rc<RefCell<Vec<TraceEntry>>> = Rc::new(RefCell::new(Vec::new()));
66 last_trace.replace(prelude_trace.as_ref().clone());
67 let last_coverage: Rc<RefCell<Option<CoverageReport>>> = Rc::new(RefCell::new(None));
68 let last_corpus: Rc<RefCell<Option<CorpusReport>>> = Rc::new(RefCell::new(None));
69 let last_failure = Rc::new(RefCell::new(FailureContext {
70 failure: RunFailure::new(String::new()),
71 trace: Vec::new(),
72 coverage: None,
73 corpus: None,
74 }));
75 let validators = Rc::new(validators);
76 clear_reject_context();
77 let handle = tokio::runtime::Handle::current();
78 if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::CurrentThread {
79 return failure_result(
80 RunFailure::new("run_with_session requires a multi-thread Tokio runtime".to_string()),
81 Vec::new(),
82 None,
83 warnings.as_ref().clone(),
84 None,
85 None,
86 );
87 }
88
89 let mut runner = TestRunner::new(ProptestConfig {
90 cases: options.cases,
91 failure_persistence: None,
92 ..ProptestConfig::default()
93 });
94
95 let strategy = match state_machine_sequence_strategy(
96 &tools,
97 config.predicate.as_ref(),
98 &config.state_machine,
99 options.sequence_len.clone(),
100 ) {
101 Ok(strategy) => strategy,
102 Err(error) => {
103 return failure_result(
104 RunFailure::new(error.to_string()),
105 prelude_trace.as_ref().clone(),
106 None,
107 warnings.as_ref().clone(),
108 None,
109 None,
110 )
111 }
112 };
113
114 if options.cases == 0 {
115 if let Err(failure) = run_pre_run_hook(config).await {
116 return failure_result(
117 failure,
118 prelude_trace.as_ref().clone(),
119 None,
120 warnings.as_ref().clone(),
121 None,
122 None,
123 );
124 }
125 }
126
127 let run_result = runner.run(&strategy, {
128 let assertions = assertions.clone();
129 let last_trace = last_trace.clone();
130 let last_coverage = last_coverage.clone();
131 let last_corpus = last_corpus.clone();
132 let last_failure = last_failure.clone();
133 let validators = validators.clone();
134 let aggregate_tracker = aggregate_tracker.clone();
135 move |sequence| {
136 let execution: Result<Vec<TraceEntry>, FailureContext> =
137 tokio::task::block_in_place(|| {
138 let last_coverage = last_coverage.clone();
139 let last_corpus = last_corpus.clone();
140 handle.block_on(async {
141 if let Err(failure) = run_pre_run_hook(config).await {
142 return Err(FailureContext {
143 failure,
144 trace: Vec::new(),
145 coverage: None,
146 corpus: None,
147 });
148 }
149 let mut tracker = CoverageTracker::new(&tools, &config.state_machine);
150 let min_len = if config.state_machine.coverage_rules.is_empty() {
151 Some(*options.sequence_len.start())
152 } else {
153 None
154 };
155 let result = execute_state_machine_sequence(
156 session,
157 &tools,
158 &validators,
159 &assertions,
160 &sequence,
161 &mut tracker,
162 config.predicate.as_ref(),
163 min_len,
164 )
165 .await;
166 let (report, corpus_report) = {
167 let mut aggregate = aggregate_tracker.borrow_mut();
168 aggregate.merge_from(&tracker);
169 let report = aggregate.report();
170 let corpus_report = if config.state_machine.dump_corpus {
171 Some(aggregate.corpus_report())
172 } else {
173 None
174 };
175 (report, corpus_report)
176 };
177 last_coverage.replace(Some(report.clone()));
178 last_corpus.replace(corpus_report.clone());
179 match result {
180 Ok(trace) => Ok(trace),
181 Err(mut failure) => {
182 failure.coverage = Some(report);
183 failure.corpus = corpus_report;
184 Err(failure)
185 }
186 }
187 })
188 });
189 match execution {
190 Ok(trace) => {
191 let mut full_trace = prelude_trace.as_ref().clone();
192 full_trace.extend(trace);
193 last_trace.replace(full_trace);
194 Ok(())
195 }
196 Err(mut failure) => {
197 let mut full_trace = prelude_trace.as_ref().clone();
198 full_trace.extend(failure.trace);
199 failure.trace = full_trace;
200 last_failure.replace(failure.clone());
201 Err(TestCaseError::fail(failure.failure.reason.clone()))
202 }
203 }
204 }
205 });
206 let run_result = finalize_state_machine_result(
207 run_result,
208 &last_trace,
209 &last_failure,
210 &last_coverage,
211 &last_corpus,
212 warnings.as_ref(),
213 );
214 if matches!(run_result.outcome, RunOutcome::Success) {
215 if let Err(failure) = aggregate_tracker
216 .borrow()
217 .validate(&config.state_machine.coverage_rules)
218 {
219 let mut trace = last_trace.borrow().clone();
220 attach_failure_reason(&mut trace, "coverage validation failed".to_string());
221 let report = aggregate_tracker.borrow().report();
222 let corpus_report = if config.state_machine.dump_corpus {
223 Some(aggregate_tracker.borrow().corpus_report())
224 } else {
225 None
226 };
227 return failure_result(
228 coverage_failure(failure),
229 trace,
230 None,
231 warnings.as_ref().clone(),
232 Some(report),
233 corpus_report,
234 );
235 }
236 }
237 run_result
238}