1use std::cell::RefCell;
2use std::ops::RangeInclusive;
3use std::rc::Rc;
4
5use proptest::test_runner::{Config as ProptestConfig, TestCaseError, TestError, TestRunner};
6
7use crate::generator::{
8 clear_reject_context, state_machine_sequence_strategy, StateMachineSequence,
9};
10use crate::{
11 CorpusReport, CoverageReport, RunConfig, RunFailure, RunOutcome, RunResult, RunWarning,
12 SessionDriver, TraceEntry,
13};
14
15use super::coverage::CoverageTracker;
16use super::linting::{evaluate_run_phase, lint_phases};
17use super::pre_run::run_pre_run_hook;
18use super::prepare::prepare_run;
19use super::result::{failure_result, finalize_state_machine_result, FailureContext};
20use super::state_machine::{execute_state_machine_sequence, StateMachineExecution};
21
/// Limits that shape a generated run: how many cases to execute and how long
/// each generated call sequence may be.
///
/// Values built through [`RunnerOptions::new`] are validated; [`Default`]
/// supplies 32 cases with sequences of 1 to 20 calls.
#[derive(Clone, Debug)]
pub struct RunnerOptions {
    /// Number of cases handed to the proptest runner.
    cases: u32,
    /// Inclusive lower and upper bound on the generated sequence length.
    sequence_len: RangeInclusive<usize>,
}

impl RunnerOptions {
    /// Builds validated options.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when, checked in this order:
    /// - `cases` is zero,
    /// - the lower bound of `sequence_len` is zero,
    /// - the lower bound of `sequence_len` exceeds its upper bound.
    pub fn new(cases: u32, sequence_len: RangeInclusive<usize>) -> Result<Self, String> {
        // Arms are ordered so the first violated rule decides the message.
        let violation = match (cases, *sequence_len.start(), *sequence_len.end()) {
            (0, _, _) => Some("cases must be at least 1"),
            (_, 0, _) => Some("min-sequence-len must be at least 1"),
            (_, shortest, longest) if shortest > longest => {
                Some("min-sequence-len must be <= max-sequence-len")
            }
            _ => None,
        };
        match violation {
            Some(message) => Err(message.to_string()),
            None => Ok(Self {
                cases,
                sequence_len,
            }),
        }
    }

    /// Number of cases to run.
    pub fn cases(&self) -> u32 {
        self.cases
    }

    /// Inclusive sequence-length bounds, returned as an owned copy.
    pub fn sequence_len(&self) -> RangeInclusive<usize> {
        self.sequence_len.clone()
    }

    /// Lower bound of the sequence-length range.
    pub(crate) fn min_sequence_len(&self) -> usize {
        *self.sequence_len.start()
    }
}

impl Default for RunnerOptions {
    /// 32 cases, each with a sequence of 1 to 20 calls.
    fn default() -> Self {
        let sequence_len = 1..=20;
        Self {
            cases: 32,
            sequence_len,
        }
    }
}
86
87pub async fn run_with_session(
95 session: &SessionDriver,
96 config: &RunConfig,
97 options: RunnerOptions,
98) -> RunResult {
99 let lint_phases = lint_phases(&config.lints);
100 let prepared = match prepare_run(session, config, &lint_phases.list).await {
101 Ok(prepared) => prepared,
102 Err(result) => return result,
103 };
104 let prelude_trace = Rc::new(prepared.prelude_trace);
105 let tools = prepared.tools;
106 let warnings = prepared.warnings;
107 let validators = prepared.validators;
108
109 let assertions = config.assertions.clone();
110 let warnings = Rc::new(RefCell::new(warnings));
111 let aggregate_tools = tools.clone();
112 let aggregate_tracker: Rc<RefCell<CoverageTracker<'_>>> =
113 Rc::new(RefCell::new(CoverageTracker::new(
114 &aggregate_tools,
115 &config.state_machine,
116 config.uncallable_limit(),
117 )));
118 let last_trace: Rc<RefCell<Vec<TraceEntry>>> = Rc::new(RefCell::new(Vec::new()));
119 last_trace.replace(prelude_trace.as_ref().clone());
120 let last_coverage: Rc<RefCell<Option<CoverageReport>>> = Rc::new(RefCell::new(None));
121 let last_corpus: Rc<RefCell<Option<CorpusReport>>> = Rc::new(RefCell::new(None));
122 let last_failure = Rc::new(RefCell::new(FailureContext {
123 failure: RunFailure::new(String::new()),
124 trace: Vec::new(),
125 coverage: None,
126 corpus: None,
127 positive_error: false,
128 }));
129 let validators = Rc::new(validators);
130 clear_reject_context();
131 let handle = tokio::runtime::Handle::current();
132 if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::CurrentThread {
133 return failure_result(
134 RunFailure::new("run_with_session requires a multi-thread Tokio runtime".to_string()),
135 Vec::new(),
136 None,
137 warnings.borrow().clone(),
138 None,
139 None,
140 );
141 }
142
143 let mut runner = TestRunner::new(ProptestConfig {
144 cases: options.cases(),
145 failure_persistence: None,
146 ..ProptestConfig::default()
147 });
148
149 let strategy = match state_machine_sequence_strategy(
150 &tools,
151 config.predicate.as_ref(),
152 &config.state_machine,
153 options.sequence_len(),
154 ) {
155 Ok(strategy) => strategy,
156 Err(error) => {
157 return failure_result(
158 RunFailure::new(error.to_string()),
159 prelude_trace.as_ref().clone(),
160 None,
161 warnings.borrow().clone(),
162 None,
163 None,
164 )
165 }
166 };
167
168 let run_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
169 runner.run(&strategy, {
170 let assertions = assertions.clone();
171 let last_trace = last_trace.clone();
172 let last_coverage = last_coverage.clone();
173 let last_corpus = last_corpus.clone();
174 let last_failure = last_failure.clone();
175 let validators = validators.clone();
176 let aggregate_tracker = aggregate_tracker.clone();
177 let response_lints = lint_phases.response.clone();
178 let warnings = warnings.clone();
179 let trace_sink = config.trace_sink.clone();
180 let case_counter = Rc::new(RefCell::new(0u64));
181 move |sequence| {
182 let execution = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
183 tokio::task::block_in_place(|| {
184 let last_coverage = last_coverage.clone();
185 let last_corpus = last_corpus.clone();
186 let case_counter = case_counter.clone();
187 handle.block_on(async {
188 if let Err(failure) = run_pre_run_hook(config).await {
189 return Err(FailureContext {
190 failure,
191 trace: Vec::new(),
192 coverage: None,
193 corpus: None,
194 positive_error: true,
195 });
196 }
197 let mut tracker = CoverageTracker::new(
198 &tools,
199 &config.state_machine,
200 config.uncallable_limit(),
201 );
202 let min_len = Some(options.min_sequence_len());
205 let case_index = {
206 let mut counter = case_counter.borrow_mut();
207 let index = *counter;
208 *counter += 1;
209 index
210 };
211 let execution = StateMachineExecution {
212 session,
213 tools: &tools,
214 validators: &validators,
215 assertions: &assertions,
216 predicate: config.predicate.as_ref(),
217 min_len,
218 in_band_error_forbidden: config.in_band_error_forbidden,
219 full_trace: config.full_trace,
220 warnings: warnings.clone(),
221 response_lints: response_lints.clone(),
222 case_index,
223 trace_sink: trace_sink.clone(),
224 };
225 let result =
226 execute_state_machine_sequence(&sequence, &execution, &mut tracker)
227 .await;
228 let (report, corpus_report) = {
229 let mut aggregate = aggregate_tracker.borrow_mut();
230 aggregate.merge_from(&tracker);
231 let mut report = aggregate.report();
232 apply_uncallable_traces(&mut report, config.show_uncallable);
233 let corpus_report = if config.state_machine.dump_corpus {
234 Some(aggregate.corpus_report())
235 } else {
236 None
237 };
238 (report, corpus_report)
239 };
240 last_coverage.replace(Some(report.clone()));
241 last_corpus.replace(corpus_report.clone());
242 match result {
243 Ok(trace) => Ok(trace),
244 Err(mut failure) => {
245 if failure.positive_error {
246 failure.coverage = None;
247 } else {
248 failure.coverage = Some(report);
249 }
250 failure.corpus = corpus_report;
251 Err(failure)
252 }
253 }
254 })
255 })
256 }));
257 let execution: Result<Vec<TraceEntry>, FailureContext> = match execution {
258 Ok(execution) => execution,
259 Err(payload) => {
260 let reason = panic_message(payload);
261 let failure = FailureContext {
262 failure: RunFailure {
263 reason: format!("run panicked: {reason}"),
264 code: Some("run_panicked".to_string()),
265 details: None,
266 },
267 trace: prelude_trace.as_ref().clone(),
268 coverage: None,
269 corpus: None,
270 positive_error: true,
271 };
272 last_failure.replace(failure.clone());
273 last_trace.replace(failure.trace.clone());
274 return Err(TestCaseError::fail(failure.failure.reason.clone()));
275 }
276 };
277 match execution {
278 Ok(trace) => {
279 let mut full_trace = prelude_trace.as_ref().clone();
280 full_trace.extend(trace);
281 last_trace.replace(full_trace);
282 Ok(())
283 }
284 Err(mut failure) => {
285 let mut full_trace = prelude_trace.as_ref().clone();
286 full_trace.extend(failure.trace);
287 failure.trace = full_trace;
288 last_failure.replace(failure.clone());
289 Err(TestCaseError::fail(failure.failure.reason.clone()))
290 }
291 }
292 }
293 })
294 }));
295 let run_result = finalize_run_result(
296 run_result,
297 &last_trace,
298 &last_failure,
299 &last_coverage,
300 &last_corpus,
301 warnings.borrow().clone(),
302 );
303 let mut run_result = run_result;
304 let outcome = run_result.outcome.clone();
305 let coverage = last_coverage.borrow();
306 let corpus = last_corpus.borrow();
307 let context = crate::RunLintContext {
308 coverage: coverage.as_ref(),
309 corpus: corpus.as_ref(),
310 coverage_allowlist: config.state_machine.coverage_allowlist.as_deref(),
311 coverage_blocklist: config.state_machine.coverage_blocklist.as_deref(),
312 outcome: &outcome,
313 };
314 if let Some(failure) = evaluate_run_phase(&lint_phases.run, &context, &mut run_result.warnings)
315 {
316 if matches!(outcome, RunOutcome::Success) {
317 return failure_result(
318 failure,
319 Vec::new(),
320 None,
321 run_result.warnings.clone(),
322 run_result.coverage.clone(),
323 run_result.corpus.clone(),
324 );
325 }
326 }
327 run_result
328}
329
330fn apply_uncallable_traces(report: &mut CoverageReport, show_uncallable: bool) {
331 if !show_uncallable {
332 report.uncallable_traces.clear();
333 }
334}
335
336fn finalize_run_result(
337 run_result: std::thread::Result<Result<(), TestError<StateMachineSequence>>>,
338 last_trace: &Rc<RefCell<Vec<TraceEntry>>>,
339 last_failure: &Rc<RefCell<FailureContext>>,
340 last_coverage: &Rc<RefCell<Option<CoverageReport>>>,
341 last_corpus: &Rc<RefCell<Option<CorpusReport>>>,
342 warnings: Vec<RunWarning>,
343) -> RunResult {
344 match run_result {
345 Ok(run_result) => finalize_state_machine_result(
346 run_result,
347 last_trace,
348 last_failure,
349 last_coverage,
350 last_corpus,
351 &warnings,
352 ),
353 Err(payload) => run_result_from_panic(
354 payload,
355 last_trace.borrow().clone(),
356 warnings,
357 last_coverage.borrow().clone(),
358 last_corpus.borrow().clone(),
359 ),
360 }
361}
362
363fn run_result_from_panic(
364 payload: Box<dyn std::any::Any + Send>,
365 trace: Vec<TraceEntry>,
366 warnings: Vec<RunWarning>,
367 coverage: Option<CoverageReport>,
368 corpus: Option<CorpusReport>,
369) -> RunResult {
370 let reason = panic_message(payload);
371 failure_result(
372 RunFailure {
373 reason: format!("run panicked: {reason}"),
374 code: Some("run_panicked".to_string()),
375 details: None,
376 },
377 trace,
378 None,
379 warnings,
380 coverage,
381 corpus,
382 )
383}
384
/// Extracts the text of a panic payload.
///
/// Payloads holding a `String` or a `&str` yield that text; any other payload
/// type yields `"unknown panic"`.
fn panic_message(payload: Box<dyn std::any::Any + Send>) -> String {
    payload
        .downcast::<String>()
        .map(|owned| *owned)
        // A failed downcast hands the payload back, so the next type can be tried.
        .or_else(|other| other.downcast::<&str>().map(|text| (*text).to_string()))
        .unwrap_or_else(|_| "unknown panic".to_string())
}
394
#[cfg(test)]
mod tests {
    use super::{
        apply_uncallable_traces, finalize_run_result, panic_message, run_result_from_panic,
        FailureContext,
    };
    use crate::generator::StateMachineSequence;
    use crate::{
        CallToolResult, CoverageReport, RunOutcome, RunWarning, ToolInvocation, UncallableToolCall,
    };
    use proptest::test_runner::TestError;
    use std::cell::RefCell;
    use std::collections::BTreeMap;
    use std::rc::Rc;

    /// Payload type produced by `std::panic::catch_unwind`.
    type PanicPayload = Box<dyn std::any::Any + Send>;
    /// Shape of the guarded proptest result consumed by `finalize_run_result`.
    type GuardedRun = std::thread::Result<Result<(), TestError<StateMachineSequence>>>;

    /// Runs `finalize_run_result` with empty trace/coverage/corpus cells, a
    /// blank failure context and no warnings.
    fn finalize_with_blank_state(run_result: GuardedRun) -> crate::RunResult {
        let trace = Rc::new(RefCell::new(Vec::new()));
        let failure = Rc::new(RefCell::new(FailureContext {
            failure: crate::RunFailure::new(String::new()),
            trace: Vec::new(),
            coverage: None,
            corpus: None,
            positive_error: false,
        }));
        let coverage = Rc::new(RefCell::new(None));
        let corpus = Rc::new(RefCell::new(None));
        finalize_run_result(run_result, &trace, &failure, &coverage, &corpus, Vec::new())
    }

    /// A coverage report whose only content is one uncallable trace for `echo`.
    fn sample_report() -> CoverageReport {
        let echo_call = UncallableToolCall {
            input: ToolInvocation {
                name: "echo".to_string().into(),
                arguments: None,
            },
            output: Some(CallToolResult::success(vec![])),
            error: None,
            timestamp: "2026-01-01T00:00:00Z".to_string(),
        };
        CoverageReport {
            counts: BTreeMap::new(),
            failures: BTreeMap::new(),
            warnings: Vec::new(),
            uncallable_traces: BTreeMap::from([("echo".to_string(), vec![echo_call])]),
        }
    }

    #[test]
    fn panic_message_handles_str() {
        let payload: PanicPayload = Box::new("boom");
        assert_eq!(panic_message(payload), "boom");
    }

    #[test]
    fn panic_message_handles_string() {
        let payload: PanicPayload = Box::new("boom".to_string());
        assert_eq!(panic_message(payload), "boom");
    }

    #[test]
    fn panic_message_handles_unknown() {
        let payload: PanicPayload = Box::new(42_u64);
        assert_eq!(panic_message(payload), "unknown panic");
    }

    #[test]
    fn run_result_from_panic_builds_failure() {
        let payload: PanicPayload = Box::new("boom");
        let result =
            run_result_from_panic(payload, Vec::new(), Vec::<RunWarning>::new(), None, None);
        match &result.outcome {
            RunOutcome::Failure(failure) => {
                assert!(failure.reason.contains("run panicked: boom"));
                assert_eq!(failure.code.as_deref(), Some("run_panicked"));
            }
            _ => panic!("expected a failure outcome"),
        }
    }

    #[test]
    fn finalize_run_result_handles_panic() {
        let payload: PanicPayload = Box::new("boom");
        let result = finalize_with_blank_state(Err(payload));
        assert!(matches!(result.outcome, RunOutcome::Failure(_)));
    }

    #[test]
    fn finalize_run_result_handles_success() {
        let result = finalize_with_blank_state(Ok(Ok(())));
        assert!(!matches!(result.outcome, RunOutcome::Failure(_)));
    }

    #[test]
    fn apply_uncallable_traces_clears_when_disabled() {
        let mut report = sample_report();
        apply_uncallable_traces(&mut report, false);
        assert!(report.uncallable_traces.is_empty());
    }

    #[test]
    fn apply_uncallable_traces_retains_when_enabled() {
        let mut report = sample_report();
        apply_uncallable_traces(&mut report, true);
        assert_eq!(report.uncallable_traces.len(), 1);
    }
}