1use crate::execution::TestSuiteExecution;
2use crate::output::TestRunnerOutput;
3use clap::{Parser, ValueEnum};
4use std::ffi::OsString;
5use std::num::NonZero;
6use std::str::FromStr;
7use std::sync::Arc;
8use std::time::Duration;
9
10#[derive(Parser, Debug, Clone, Default)]
16#[command(
17    help_template = "USAGE: [OPTIONS] [FILTERS...]\n\n{all-args}\n",
18    disable_version_flag = true
19)]
20pub struct Arguments {
21    #[arg(long = "include-ignored")]
23    pub include_ignored: bool,
24
25    #[arg(long = "ignored")]
27    pub ignored: bool,
28
29    #[arg(long = "exclude-should-panic")]
31    pub exclude_should_panic: bool,
32
33    #[arg(long = "test", conflicts_with = "bench")]
35    pub test: bool,
36
37    #[arg(long = "bench")]
39    pub bench: bool,
40
41    #[arg(long = "list")]
43    pub list: bool,
44
45    #[arg(long = "logfile", value_name = "PATH")]
47    pub logfile: Option<String>,
48
49    #[arg(long = "nocapture")]
51    pub nocapture: bool,
52
53    #[arg(long = "test-threads")]
55    pub test_threads: Option<usize>,
56
57    #[arg(long = "skip", value_name = "FILTER")]
59    pub skip: Vec<String>,
60
61    #[arg(short = 'q', long = "quiet", conflicts_with = "format")]
64    pub quiet: bool,
65
66    #[arg(long = "exact")]
68    pub exact: bool,
69
70    #[arg(long = "color", value_enum, value_name = "auto|always|never")]
72    pub color: Option<ColorSetting>,
73
74    #[arg(long = "format", value_enum, value_name = "pretty|terse|json|junit")]
76    pub format: Option<FormatSetting>,
77
78    #[arg(long = "show-output")]
80    pub show_output: bool,
81
82    #[arg(short = 'Z')]
84    pub unstable_flags: Option<UnstableFlags>,
85
86    #[arg(long = "report-time")]
94    pub report_time: bool,
95
96    #[arg(long = "ensure-time")]
102    pub ensure_time: bool,
103
104    #[arg(long = "shuffle", conflicts_with = "shuffle_seed")]
106    pub shuffle: bool,
107
108    #[arg(long = "shuffle-seed", value_name = "SEED", conflicts_with = "shuffle")]
110    pub shuffle_seed: Option<u64>,
111
112    #[arg(long = "show-stats")]
114    pub show_stats: bool,
115
116    #[arg(value_name = "FILTER")]
120    pub filter: Option<String>,
121
122    #[arg(long = "flaky-run", value_name = "COUNT")]
125    pub flaky_run: Option<usize>,
126
127    #[arg(long = "ipc", hide = true)]
131    pub ipc: Option<String>,
132
133    #[arg(long = "spawn-workers", hide = true)]
135    pub spawn_workers: bool,
136}
137
138impl Arguments {
139    pub fn from_args() -> Self {
145        let mut result: Self = Parser::parse();
146        if result.shuffle && result.shuffle_seed.is_none() {
147            result.shuffle_seed = Some(rand::random());
149            result.shuffle = false;
150        }
151        result
152    }
153
154    pub fn to_args(&self) -> Vec<OsString> {
156        let mut result = Vec::new();
157
158        if self.include_ignored {
159            result.push(OsString::from("--include-ignored"));
160        }
161
162        if self.ignored {
163            result.push(OsString::from("--ignored"));
164        }
165
166        if self.exclude_should_panic {
167            result.push(OsString::from("--exclude-should-panic"));
168        }
169
170        if self.test {
171            result.push(OsString::from("--test"));
172        }
173
174        if self.bench {
175            result.push(OsString::from("--bench"));
176        }
177
178        if self.list {
179            result.push(OsString::from("--list"));
180        }
181
182        if let Some(logfile) = &self.logfile {
183            result.push(OsString::from("--logfile"));
184            result.push(OsString::from(logfile));
185        }
186
187        if self.nocapture {
188            result.push(OsString::from("--nocapture"));
189        }
190
191        if let Some(test_threads) = self.test_threads {
192            result.push(OsString::from("--test-threads"));
193            result.push(OsString::from(test_threads.to_string()));
194        }
195
196        for skip in &self.skip {
197            result.push(OsString::from("--skip"));
198            result.push(OsString::from(skip));
199        }
200
201        if self.quiet {
202            result.push(OsString::from("--quiet"));
203        }
204
205        if self.exact {
206            result.push(OsString::from("--exact"));
207        }
208
209        if let Some(color) = self.color {
210            result.push(OsString::from("--color"));
211            match color {
212                ColorSetting::Auto => result.push(OsString::from("auto")),
213                ColorSetting::Always => result.push(OsString::from("always")),
214                ColorSetting::Never => result.push(OsString::from("never")),
215            }
216        }
217
218        if let Some(format) = self.format {
219            result.push(OsString::from("--format"));
220            match format {
221                FormatSetting::Pretty => result.push(OsString::from("pretty")),
222                FormatSetting::Terse => result.push(OsString::from("terse")),
223                FormatSetting::Json => result.push(OsString::from("json")),
224                FormatSetting::Junit => result.push(OsString::from("junit")),
225            }
226        }
227
228        if self.show_output {
229            result.push(OsString::from("--show-output"));
230        }
231
232        if let Some(unstable_flags) = &self.unstable_flags {
233            result.push(OsString::from("-Z"));
234            match unstable_flags {
235                UnstableFlags::UnstableOptions => result.push(OsString::from("unstable-options")),
236            }
237        }
238
239        if self.report_time {
240            result.push(OsString::from("--report-time"));
241        }
242
243        if self.ensure_time {
244            result.push(OsString::from("--ensure-time"));
245        }
246
247        if self.shuffle {
248            result.push(OsString::from("--shuffle"));
249        }
250
251        if let Some(shuffle_seed) = &self.shuffle_seed {
252            result.push(OsString::from("--shuffle-seed"));
253            result.push(OsString::from(shuffle_seed.to_string()));
254        }
255
256        if self.show_stats {
257            result.push(OsString::from("--show-stats"));
258        }
259
260        if let Some(filter) = &self.filter {
261            result.push(OsString::from(filter));
262        }
263
264        if let Some(flaky_run) = &self.flaky_run {
265            result.push(OsString::from("--flaky-run"));
266            result.push(OsString::from(flaky_run.to_string()));
267        }
268
269        if let Some(ipc) = &self.ipc {
270            result.push(OsString::from("--ipc"));
271            result.push(OsString::from(ipc));
272        }
273
274        if self.spawn_workers {
275            result.push(OsString::from("--spawn-workers"));
276        }
277
278        result
279    }
280
281    pub fn unit_test_threshold(&self) -> TimeThreshold {
282        TimeThreshold::from_env_var("RUST_TEST_TIME_UNIT").unwrap_or(TimeThreshold::new(
283            Duration::from_millis(50),
284            Duration::from_millis(100),
285        ))
286    }
287
288    pub fn integration_test_threshold(&self) -> TimeThreshold {
289        TimeThreshold::from_env_var("RUST_TEST_TIME_INTEGRATION").unwrap_or(TimeThreshold::new(
290            Duration::from_millis(500),
291            Duration::from_millis(1000),
292        ))
293    }
294
295    pub(crate) fn test_threads(&self) -> NonZero<usize> {
296        if self.ipc.is_some() {
297            NonZero::new(1).unwrap()
299        } else {
300            self.test_threads
301                .and_then(NonZero::new)
302                .or_else(|| std::thread::available_parallelism().ok())
303                .unwrap_or(NonZero::new(1).unwrap())
304        }
305    }
306
307    pub(crate) fn finalize_for_execution(
309        &mut self,
310        execution: &TestSuiteExecution,
311        output: Arc<dyn TestRunnerOutput>,
312    ) {
313        let requires_capturing = execution.requires_capturing(!self.nocapture);
314
315        if !requires_capturing || self.ipc.is_some() {
316            } else {
319            self.spawn_workers = true;
321
322            if self.test_threads().get() > 1 {
323                if execution.has_dependencies() {
327                    if execution.remaining() > 1 {
328                        output.warning("Cannot run tests in parallel when tests have shared dependencies and output capturing is on. Using a single thread.");
330                    }
331                    self.test_threads = Some(1); }
333            }
334        }
335    }
336}
337
338impl<A: Into<OsString> + Clone> FromIterator<A> for Arguments {
339    fn from_iter<T: IntoIterator<Item = A>>(iter: T) -> Self {
340        Parser::parse_from(iter)
341    }
342}
343
344#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)]
346pub enum ColorSetting {
347    #[default]
349    Auto,
350
351    Always,
353
354    Never,
356}
357
358#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
360pub enum UnstableFlags {
361    UnstableOptions,
363}
364
365#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)]
367pub enum FormatSetting {
368    #[default]
370    Pretty,
371
372    Terse,
374
375    Json,
377
378    Junit,
380}
381
/// A pair of execution-time limits: a warning level and a critical level.
///
/// Both limits are inclusive: a duration equal to a limit counts as reaching it (see
/// `is_warn` and `is_critical`).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct TimeThreshold {
    /// Durations at or above this value reach the warning level.
    pub warn: Duration,
    /// Durations at or above this value reach the critical level.
    pub critical: Duration,
}
390
391impl TimeThreshold {
392    pub fn new(warn: Duration, critical: Duration) -> Self {
394        Self { warn, critical }
395    }
396
397    pub fn from_env_var(env_var_name: &str) -> Option<Self> {
407        let durations_str = std::env::var(env_var_name).ok()?;
408        let (warn_str, critical_str) = durations_str.split_once(',').unwrap_or_else(|| {
409            panic!(
410                "Duration variable {env_var_name} expected to have 2 numbers separated by comma, but got {durations_str}"
411            )
412        });
413
414        let parse_u64 = |v| {
415            u64::from_str(v).unwrap_or_else(|_| {
416                panic!(
417                    "Duration value in variable {env_var_name} is expected to be a number, but got {v}"
418                )
419            })
420        };
421
422        let warn = parse_u64(warn_str);
423        let critical = parse_u64(critical_str);
424        if warn > critical {
425            panic!("Test execution warn time should be less or equal to the critical time");
426        }
427
428        Some(Self::new(
429            Duration::from_millis(warn),
430            Duration::from_millis(critical),
431        ))
432    }
433
434    pub fn is_critical(&self, duration: &Duration) -> bool {
435        *duration >= self.critical
436    }
437
438    pub fn is_warn(&self, duration: &Duration) -> bool {
439        *duration >= self.warn
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `args` as the runner would, with a placeholder program name prepended.
    fn parse(args: &[&str]) -> Arguments {
        std::iter::once("test-runner")
            .chain(args.iter().copied())
            .collect()
    }

    #[test]
    fn verify_cli() {
        use clap::CommandFactory;
        Arguments::command().debug_assert();
    }

    #[test]
    fn default_arguments_produce_empty_command_line() {
        assert!(Arguments::default().to_args().is_empty());
    }

    #[test]
    fn to_args_reproduces_parsed_command_line() {
        // Listed in the order `to_args` emits them, so the output must match verbatim.
        let args = [
            "--include-ignored",
            "--exclude-should-panic",
            "--test",
            "--list",
            "--logfile",
            "log.txt",
            "--nocapture",
            "--test-threads",
            "4",
            "--skip",
            "a",
            "--skip",
            "b",
            "--exact",
            "--color",
            "never",
            "--format",
            "json",
            "--show-output",
            "-Z",
            "unstable-options",
            "--report-time",
            "--ensure-time",
            "--shuffle-seed",
            "42",
            "--show-stats",
            "my_filter",
            "--flaky-run",
            "3",
            "--ipc",
            "worker-1",
            "--spawn-workers",
        ];
        let expected: Vec<OsString> = args.iter().map(|&arg| OsString::from(arg)).collect();
        assert_eq!(parse(&args).to_args(), expected);
    }

    #[test]
    fn conflicting_flags_are_rejected() {
        assert!(
            Arguments::try_parse_from(["test-runner", "--shuffle", "--shuffle-seed", "1"]).is_err()
        );
        assert!(Arguments::try_parse_from(["test-runner", "--quiet", "--format", "json"]).is_err());
        assert!(Arguments::try_parse_from(["test-runner", "--test", "--bench"]).is_err());
    }

    #[test]
    fn test_threads_resolution() {
        assert_eq!(parse(&["--test-threads", "3"]).test_threads().get(), 3);
        // `--ipc` overrides any explicit thread count.
        assert_eq!(
            parse(&["--test-threads", "3", "--ipc", "worker"])
                .test_threads()
                .get(),
            1
        );
        // Zero behaves like an omitted `--test-threads`.
        assert_eq!(
            parse(&["--test-threads", "0"]).test_threads(),
            parse(&[]).test_threads()
        );
    }
}