Skip to main content

tango_bench/
lib.rs

1#[cfg(feature = "async")]
2pub use asynchronous::async_benchmark_fn;
3use core::ptr;
4use dylib::ffi::TANGO_API_VERSION;
5use metrics::WallClock;
6use num_traits::ToPrimitive;
7use rand::{rngs::SmallRng, Rng, SeedableRng};
8use std::{
9    cmp::Ordering,
10    hint::black_box,
11    io,
12    marker::PhantomData,
13    mem,
14    ops::{Deref, RangeInclusive},
15    str::Utf8Error,
16    time::Duration,
17};
18use thiserror::Error;
19
20pub mod cli;
21pub mod dylib;
22#[cfg(target_os = "linux")]
23pub mod linux;
24pub mod platform;
25#[cfg(target_os = "windows")]
26pub mod windows;
27
28#[derive(Debug, Error)]
29pub enum Error {
30    #[error("No measurements given")]
31    NoMeasurements,
32
33    #[error("Invalid string pointer from FFI")]
34    InvalidFFIString(Utf8Error),
35
36    #[error("Spi::self() was already called")]
37    SpiSelfWasMoved,
38
39    #[error("Benchmark is missing")]
40    BenchmarkNotFound,
41
42    #[error("Unable to load benchmark")]
43    UnableToLoadBenchmark(#[source] libloading::Error),
44
45    #[cfg(target_os = "windows")]
46    #[error("Unable to patch IAT")]
47    UnableToPatchIat(#[source] windows::Error),
48
49    #[error("Unable to load library symbol: {0}")]
50    UnableToLoadSymbol(String, #[source] libloading::Error),
51
52    #[error("Unknown sampler type. Available options are: flat and linear")]
53    UnknownSamplerType,
54
55    #[error("Invalid test name given")]
56    InvalidTestName,
57
58    #[error("IO Error")]
59    IOError(#[from] io::Error),
60
61    #[error("FFI Error: {0}")]
62    FFIError(String),
63
64    #[error("Unknown FFI Error")]
65    UnknownFFIError,
66
67    #[error(
68        "Non matching tango version. Expected: {}, got: {0}",
69        TANGO_API_VERSION
70    )]
71    IncorrectVersion(u32),
72}
73
/// Registers benchmarks in the system
///
/// The macro accepts a comma-separated list of expressions (a trailing comma is allowed). Each expression
/// produces any [`IntoBenchmarks`] type. All of the benchmarks created by those expressions are registered
/// in the harness through the generated `tango_init()` function.
///
/// ## Example
/// ```rust
/// use std::time::Instant;
/// use tango_bench::{benchmark_fn, IntoBenchmarks, tango_benchmarks};
///
/// fn time_benchmarks() -> impl IntoBenchmarks {
///     [benchmark_fn("current_time", |b| b.iter(|| Instant::now()))]
/// }
///
/// tango_benchmarks!(time_benchmarks());
/// ```
#[macro_export]
macro_rules! tango_benchmarks {
    ($($func_expr:expr),+ $(,)?) => {
        /// Type checking tango_init() function
        const TANGO_INIT: $crate::dylib::ffi::InitFn = tango_init;

        /// Exported function for initializing the benchmark harness
        #[no_mangle]
        unsafe extern "C" fn tango_init() {
            let mut benchmarks = vec![];
            $(benchmarks.extend($crate::IntoBenchmarks::into_benchmarks($func_expr));)*
            $crate::dylib::__tango_init(benchmarks)
        }
    };
}
101            $crate::dylib::__tango_init(benchmarks)
102        }
103
104    };
105}
106
107/// Main entrypoint for benchmarks
108///
109/// This macro generate `main()` function for the benchmark harness. Can be used in a form with providing
110/// measurement settings:
111/// ```rust
112/// use tango_bench::{tango_main, tango_benchmarks, MeasurementSettings};
113///
114/// // Register benchmarks
115/// tango_benchmarks!([]);
116///
117/// tango_main!(MeasurementSettings {
118///     samples_per_haystack: 1000,
119///     min_iterations_per_sample: 10,
120///     max_iterations_per_sample: 10_000,
121///     ..Default::default()
122/// });
123/// ```
124#[macro_export]
125macro_rules! tango_main {
126    ($settings:expr) => {
127        fn main() -> $crate::cli::Result<std::process::ExitCode> {
128            // Initialize Tango for SelfVTable usage
129            unsafe { tango_init() };
130            $crate::cli::run($settings)
131        }
132    };
133    () => {
134        tango_main! {$crate::MeasurementSettings::default()}
135    };
136}
137
138pub struct BenchmarkParams {
139    pub seed: u64,
140}
141
142pub struct Bencher<M> {
143    params: BenchmarkParams,
144    metric: PhantomData<M>,
145}
146
147impl<M> Deref for Bencher<M> {
148    type Target = BenchmarkParams;
149
150    fn deref(&self) -> &Self::Target {
151        &self.params
152    }
153}
154
155impl<M: Metric + 'static> Bencher<M> {
156    pub fn metric<T: Metric>(self) -> Bencher<T> {
157        Bencher {
158            params: self.params,
159            metric: PhantomData::<T>,
160        }
161    }
162
163    pub fn iter<O: 'static, F: FnMut() -> O + 'static>(self, func: F) -> Box<dyn ErasedSampler> {
164        Box::new(Sampler::<_, M>::new(func))
165    }
166}
167
168struct Sampler<F, M> {
169    func: F,
170    metric: PhantomData<M>,
171}
172
173impl<F, M> Sampler<F, M> {
174    fn new(func: F) -> Self {
175        Self {
176            func,
177            metric: PhantomData::<M>,
178        }
179    }
180}
181
182pub trait ErasedSampler {
183    /// Measures the performance if the function
184    ///
185    /// Returns the cumulative execution time (all iterations) with nanoseconds precision,
186    /// but not necessarily accuracy. Usually this time is get by `clock_gettime()` call or some other
187    /// platform-specific call.
188    ///
189    /// This method should use the same arguments for measuring the test function unless [`prepare_state()`]
190    /// method is called. Only then new set of input arguments should be generated. It is NOT allowed
191    /// to call this method without first calling [`prepare_state()`].
192    ///
193    /// [`prepare_state()`]: Self::prepare_state()
194    fn measure(&mut self, iterations: usize) -> u64;
195
196    /// Estimates the number of iterations achievable within given time.
197    ///
198    /// Time span is given in milliseconds (`time_ms`). Estimate can be an approximation and it is important
199    /// for implementation to be fast (in the order of 10 ms).
200    /// If possible the same input arguments should be used when building the estimate.
201    /// If the single call of a function is longer than provided timespan the implementation should return 0.
202    fn estimate_iterations(&mut self, time_ms: u32) -> usize {
203        let mut iters = 1;
204        let time_ns = Duration::from_millis(time_ms as u64).as_nanos() as u64;
205
206        for _ in 0..5 {
207            // Never believe short measurements because they are very unreliable. Pretending that
208            // measurement at least took 1us guarantees that we won't end up with an unreasonably large number
209            // of iterations
210            let time = self.measure(iters).max(1_000);
211            let time_per_iteration = (time / iters as u64).max(1);
212            let new_iters = (time_ns / time_per_iteration) as usize;
213
214            // Do early stop if new estimate has the same order of magnitude. It is good enough.
215            if new_iters < 2 * iters {
216                return new_iters;
217            }
218
219            iters = new_iters;
220        }
221
222        iters
223    }
224}
225
226impl<O, F: FnMut() -> O, M: Metric> ErasedSampler for Sampler<F, M> {
227    fn measure(&mut self, iterations: usize) -> u64 {
228        M::measure_fn(|| {
229            for _ in 0..iterations {
230                black_box((self.func)());
231            }
232        })
233    }
234}
235
236pub struct Benchmark {
237    name: String,
238    sampler_factory: Box<dyn SamplerFactory>,
239}
240
241pub fn benchmark_fn<F: FnMut(Bencher<WallClock>) -> Box<dyn ErasedSampler> + 'static>(
242    name: impl Into<String>,
243    sampler_factory: F,
244) -> Benchmark {
245    let name = name.into();
246    assert!(!name.is_empty());
247    Benchmark {
248        name,
249        sampler_factory: Box::new(SyncSampleFactory(sampler_factory)),
250    }
251}
252
253pub trait SamplerFactory {
254    fn create_sampler(&mut self, params: BenchmarkParams) -> Box<dyn ErasedSampler>;
255}
256
257struct SyncSampleFactory<F>(F);
258
259impl<F: FnMut(Bencher<WallClock>) -> Box<dyn ErasedSampler>> SamplerFactory
260    for SyncSampleFactory<F>
261{
262    fn create_sampler(&mut self, params: BenchmarkParams) -> Box<dyn ErasedSampler> {
263        (self.0)(Bencher {
264            params,
265            metric: PhantomData::<WallClock>,
266        })
267    }
268}
269
270impl Benchmark {
271    /// Generates next haystack for the measurement
272    ///
273    /// Calling this method should update internal haystack used for measurement.
274    /// Returns `true` if update happens, `false` if implementation doesn't support haystack generation.
275    /// Haystack/Needle distinction is described in [`Generator`] trait.
276    pub fn prepare_state(&mut self, seed: u64) -> Box<dyn ErasedSampler> {
277        self.sampler_factory
278            .create_sampler(BenchmarkParams { seed })
279    }
280
281    /// Name of the benchmark
282    pub fn name(&self) -> &str {
283        self.name.as_str()
284    }
285}
286
287/// Converts the implementing type into a vector of [`Benchmark`].
288pub trait IntoBenchmarks {
289    fn into_benchmarks(self) -> Vec<Benchmark>;
290}
291
292impl<const N: usize> IntoBenchmarks for [Benchmark; N] {
293    fn into_benchmarks(self) -> Vec<Benchmark> {
294        self.into_iter().collect()
295    }
296}
297
298impl IntoBenchmarks for Vec<Benchmark> {
299    fn into_benchmarks(self) -> Vec<Benchmark> {
300        self
301    }
302}
303
304/// Describes basic settings for the benchmarking process
305///
306/// This structure is passed to [`cli::run()`].
307///
308/// Should be created only with overriding needed properties, like so:
309/// ```rust
310/// use tango_bench::MeasurementSettings;
311///
312/// let settings = MeasurementSettings {
313///     min_iterations_per_sample: 1000,
314///     ..Default::default()
315/// };
316/// ```
317#[derive(Clone, Copy, Debug)]
318pub struct MeasurementSettings {
319    pub filter_outliers: bool,
320
321    /// The number of samples per one generated haystack
322    pub samples_per_haystack: usize,
323
324    /// Minimum number of iterations in a sample for each of 2 tested functions
325    pub min_iterations_per_sample: usize,
326
327    /// The number of iterations in a sample for each of 2 tested functions
328    pub max_iterations_per_sample: usize,
329
330    pub sampler_type: SampleLengthKind,
331
332    /// If true scheduler performs warmup iterations before measuring function
333    pub warmup_enabled: bool,
334
335    /// Size of a CPU cache firewall in KBytes
336    ///
337    /// If set, the scheduler will perform a dummy data read between samples generation to spoil the CPU cache
338    ///
339    /// Cache firewall is a way to reduce the impact of the CPU cache on the benchmarking process. It tries
340    /// to minimize discrepancies in performance between two algorithms due to the CPU cache state.
341    pub cache_firewall: Option<usize>,
342
343    /// If true, scheduler will perform a yield of control back to the OS before taking each sample
344    ///
345    /// Yielding control to the OS is a way to reduce the impact of OS scheduler on the benchmarking process.
346    pub yield_before_sample: bool,
347
348    /// If set, use alloca to allocate a random offset for the stack each sample.
349    /// This to reduce memory alignment effects on the benchmarking process.
350    ///
351    /// May cause UB if the allocation is larger then the thread stack size.
352    pub randomize_stack: Option<usize>,
353}
354
355#[derive(Clone, Copy, Debug)]
356pub enum SampleLengthKind {
357    Flat,
358    Linear,
359    Random,
360}
361
362/// Performs a dummy reads from memory to spoil given amount of CPU cache
363///
364/// Uses cache aligned data arrays to perform minimum amount of reads possible to spoil the cache
365struct CacheFirewall {
366    cache_lines: Vec<CacheLine>,
367}
368
369impl CacheFirewall {
370    fn new(bytes: usize) -> Self {
371        let n = bytes / mem::size_of::<CacheLine>();
372        let cache_lines = vec![CacheLine::default(); n];
373        Self { cache_lines }
374    }
375
376    fn issue_read(&self) {
377        for line in &self.cache_lines {
378            // Because CacheLine is aligned on 64 bytes it is enough to read single element from the array
379            // to spoil the whole cache line
380            unsafe { ptr::read_volatile(&line.0[0]) };
381        }
382    }
383}
384
385#[repr(C)]
386#[repr(align(64))]
387#[derive(Default, Clone, Copy)]
388struct CacheLine([u16; 32]);
389
390pub const DEFAULT_SETTINGS: MeasurementSettings = MeasurementSettings {
391    filter_outliers: false,
392    samples_per_haystack: 1,
393    min_iterations_per_sample: 1,
394    max_iterations_per_sample: 5000,
395    sampler_type: SampleLengthKind::Random,
396    cache_firewall: None,
397    yield_before_sample: false,
398    warmup_enabled: true,
399    randomize_stack: None,
400};
401
402impl Default for MeasurementSettings {
403    fn default() -> Self {
404        DEFAULT_SETTINGS
405    }
406}
407
408/// Responsible for determining the number of iterations to run for each sample
409///
410/// Different sampler strategies can influence the results heavily. For example, if function is dependent heavily
411/// on a memory subsystem, then it should be tested with different number of iterations to be representative
412/// for different memory access patterns and cache states.
413trait SampleLength {
414    /// Returns the number of iterations to run for the next sample
415    ///
416    /// Accepts the number of iteration being run starting from 0 and cumulative time spent by both functions
417    fn next_sample_iterations(&mut self, iteration_no: usize, estimate: usize) -> usize;
418}
419
420/// Runs the same number of iterations for each sample
421///
422/// Estimates the number of iterations based on the number of iterations achieved in 10 ms and uses
423/// this number as a base for the number of iterations for each sample.
424struct FlatSampleLength {
425    min: usize,
426    max: usize,
427}
428
429impl FlatSampleLength {
430    fn new(settings: &MeasurementSettings) -> Self {
431        FlatSampleLength {
432            min: settings.min_iterations_per_sample.max(1),
433            max: settings.max_iterations_per_sample,
434        }
435    }
436}
437
438impl SampleLength for FlatSampleLength {
439    fn next_sample_iterations(&mut self, _iteration_no: usize, estimate: usize) -> usize {
440        estimate.clamp(self.min, self.max)
441    }
442}
443
444struct LinearSampleLength {
445    min: usize,
446    max: usize,
447}
448
449impl LinearSampleLength {
450    fn new(settings: &MeasurementSettings) -> Self {
451        Self {
452            min: settings.min_iterations_per_sample.max(1),
453            max: settings.max_iterations_per_sample,
454        }
455    }
456}
457
458impl SampleLength for LinearSampleLength {
459    fn next_sample_iterations(&mut self, iteration_no: usize, estimate: usize) -> usize {
460        let estimate = estimate.clamp(self.min, self.max);
461        (iteration_no % estimate) + 1
462    }
463}
464
465/// Sampler that randomly determines the number of iterations to run for each sample
466///
467/// This sampler uses a random number generator to decide the number of iterations for each sample.
468struct RandomSampleLength {
469    rng: SmallRng,
470    min: usize,
471    max: usize,
472}
473
474impl RandomSampleLength {
475    pub fn new(settings: &MeasurementSettings, seed: u64) -> Self {
476        Self {
477            rng: SmallRng::seed_from_u64(seed),
478            min: settings.min_iterations_per_sample.max(1),
479            max: settings.max_iterations_per_sample,
480        }
481    }
482}
483
484impl SampleLength for RandomSampleLength {
485    fn next_sample_iterations(&mut self, _iteration_no: usize, estimate: usize) -> usize {
486        let estimate = estimate.clamp(self.min, self.max);
487        self.rng.gen_range(1..=estimate)
488    }
489}
490
491/// Calculates the result of the benchmarking run
492///
493/// Return None if no measurements were made
494pub(crate) fn calculate_run_result<N: Into<String>>(
495    name: N,
496    baseline: &[u64],
497    candidate: &[u64],
498    iterations_per_sample: &[usize],
499    filter_outliers: bool,
500) -> Option<RunResult> {
501    assert!(baseline.len() == candidate.len());
502    assert!(baseline.len() == iterations_per_sample.len());
503
504    let mut iterations_per_sample = iterations_per_sample.to_vec();
505
506    let mut diff = candidate
507        .iter()
508        .zip(baseline.iter())
509        // Calculating difference between candidate and baseline
510        .map(|(&c, &b)| c as f64 - b as f64)
511        .zip(iterations_per_sample.iter())
512        // Normalizing difference to iterations count
513        .map(|(diff, &iters)| diff / iters as f64)
514        .collect::<Vec<_>>();
515
516    // need to save number of original samples to calculate number of outliers correctly
517    let n = diff.len();
518
519    // Normalizing measurements to iterations count
520    let mut baseline = baseline
521        .iter()
522        .zip(iterations_per_sample.iter())
523        .map(|(&v, &iters)| (v as f64) / (iters as f64))
524        .collect::<Vec<_>>();
525    let mut candidate = candidate
526        .iter()
527        .zip(iterations_per_sample.iter())
528        .map(|(&v, &iters)| (v as f64) / (iters as f64))
529        .collect::<Vec<_>>();
530
531    // Calculating measurements range. All measurements outside this interval considered outliers
532    let range = if filter_outliers {
533        iqr_variance_thresholds(diff.to_vec())
534    } else {
535        None
536    };
537
538    // Cleaning measurements from outliers if needed
539    if let Some(range) = range {
540        // We filtering outliers to build statistical Summary and the order of elements in arrays
541        // doesn't matter, therefore swap_remove() is used. But we need to make sure that all arrays
542        // has the same length
543        assert_eq!(diff.len(), baseline.len());
544        assert_eq!(diff.len(), candidate.len());
545
546        let mut i = 0;
547        while i < diff.len() {
548            if range.contains(&diff[i]) {
549                i += 1;
550            } else {
551                diff.swap_remove(i);
552                iterations_per_sample.swap_remove(i);
553                baseline.swap_remove(i);
554                candidate.swap_remove(i);
555            }
556        }
557    };
558
559    let diff_summary = Summary::from(&diff)?;
560    let baseline_summary = Summary::from(&baseline)?;
561    let candidate_summary = Summary::from(&candidate)?;
562
563    let diff_estimate = DiffEstimate::build(&baseline_summary, &diff_summary);
564
565    Some(RunResult {
566        baseline: baseline_summary,
567        candidate: candidate_summary,
568        diff: diff_summary,
569        name: name.into(),
570        diff_estimate,
571        outliers: n - diff_summary.n,
572    })
573}
574
575/// Contains the estimation of how much faster or slower is candidate function compared to baseline
576pub(crate) struct DiffEstimate {
577    // Percentage of difference between candidate and baseline
578    //
579    // Negative value means that candidate is faster than baseline, positive - slower.
580    pct: f64,
581
582    // Is the difference statistically significant
583    significant: bool,
584}
585
586impl DiffEstimate {
587    /// Builds [`DiffEstimate`] from flat sampling
588    ///
589    /// Flat sampling is a sampling where each measurement is normalized by the number of iterations.
590    /// This is needed to make measurements comparable between each other. Linear sampling is more
591    /// robust to outliers, but it is requiring more iterations.
592    ///
593    /// It is assumed that baseline and candidate are already normalized by iterations count.
594    fn build(baseline: &Summary<f64>, diff: &Summary<f64>) -> Self {
595        let std_dev = diff.variance.sqrt();
596        let std_err = std_dev / (diff.n as f64).sqrt();
597        let z_score = diff.mean / std_err;
598
599        // significant result is far away from 0 and have more than 0.5% base/candidate difference
600        // z_score = 2.6 corresponds to 99% significance level
601        let significant = z_score.abs() >= 2.6 && (diff.mean / baseline.mean).abs() > 0.005;
602        let pct = diff.mean / baseline.mean * 100.0;
603
604        Self { pct, significant }
605    }
606}
607
608/// Describes the results of a single benchmark run
609pub(crate) struct RunResult {
610    /// name of a test
611    name: String,
612
613    /// statistical summary of baseline function measurements
614    baseline: Summary<f64>,
615
616    /// statistical summary of candidate function measurements
617    candidate: Summary<f64>,
618
619    /// individual measurements of a benchmark (candidate - baseline)
620    diff: Summary<f64>,
621
622    diff_estimate: DiffEstimate,
623
624    /// Numbers of detected and filtered outliers
625    outliers: usize,
626}
627
628/// Statistical summary for a given iterator of numbers.
629///
630/// Calculates all the information using single pass over the data. Mean and variance are calculated using
631/// streaming algorithm described in _Art of Computer Programming, Vol 2, page 232_.
632#[derive(Clone, Copy)]
633pub struct Summary<T> {
634    pub n: usize,
635    pub min: T,
636    pub max: T,
637    pub mean: f64,
638    pub variance: f64,
639}
640
641impl<T: PartialOrd> Summary<T> {
642    pub fn from<'a, C>(values: C) -> Option<Self>
643    where
644        C: IntoIterator<Item = &'a T>,
645        T: ToPrimitive + Copy + Default + 'a,
646    {
647        Self::running(values.into_iter().copied()).last()
648    }
649
650    pub fn running<I>(iter: I) -> impl Iterator<Item = Summary<T>>
651    where
652        T: ToPrimitive + Copy + Default,
653        I: Iterator<Item = T>,
654    {
655        RunningSummary {
656            iter,
657            n: 0,
658            min: T::default(),
659            max: T::default(),
660            mean: 0.,
661            s: 0.,
662        }
663    }
664}
665
666struct RunningSummary<T, I> {
667    iter: I,
668    n: usize,
669    min: T,
670    max: T,
671    mean: f64,
672    s: f64,
673}
674
675impl<T, I> Iterator for RunningSummary<T, I>
676where
677    T: Copy + PartialOrd,
678    I: Iterator<Item = T>,
679    T: ToPrimitive,
680{
681    type Item = Summary<T>;
682
683    fn next(&mut self) -> Option<Self::Item> {
684        let value = self.iter.next()?;
685        let fvalue = value.to_f64().expect("f64 overflow detected");
686
687        if self.n == 0 {
688            self.min = value;
689            self.max = value;
690        }
691
692        if let Some(Ordering::Less) = value.partial_cmp(&self.min) {
693            self.min = value;
694        }
695        if let Some(Ordering::Greater) = value.partial_cmp(&self.max) {
696            self.max = value;
697        }
698
699        self.n += 1;
700        let mean_p = self.mean;
701        self.mean += (fvalue - self.mean) / self.n as f64;
702        self.s += (fvalue - mean_p) * (fvalue - self.mean);
703        let variance = if self.n > 1 {
704            self.s / (self.n - 1) as f64
705        } else {
706            0.
707        };
708
709        Some(Summary {
710            n: self.n,
711            min: self.min,
712            max: self.max,
713            mean: self.mean,
714            variance,
715        })
716    }
717}
718
719/// Outlier detection algorithm based on interquartile range
720///
721/// Observations that are 1.5 IQR away from the corresponding quartile are consideted as outliers
722/// as described in original Tukey's paper.
723pub fn iqr_variance_thresholds(mut input: Vec<f64>) -> Option<RangeInclusive<f64>> {
724    const MINIMUM_IQR: f64 = 1.;
725
726    input.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
727    let (q1, q3) = (input.len() / 4, input.len() * 3 / 4 - 1);
728    if q1 >= q3 || q3 >= input.len() {
729        return None;
730    }
731    // In case q1 and q3 are equal, we need to make sure that IQR is not 0
732    // In the future it would be nice to measure system timer precision empirically.
733    let iqr = (input[q3] - input[q1]).max(MINIMUM_IQR);
734
735    let low_threshold = input[q1] - iqr * 1.5;
736    let high_threshold = input[q3] + iqr * 1.5;
737
738    // Calculating the indices of the thresholds in an dataset
739    let low_threshold_idx =
740        match input[0..q1].binary_search_by(|probe| probe.total_cmp(&low_threshold)) {
741            Ok(idx) => idx,
742            Err(idx) => idx,
743        };
744
745    let high_threshold_idx =
746        match input[q3..].binary_search_by(|probe| probe.total_cmp(&high_threshold)) {
747            Ok(idx) => idx,
748            Err(idx) => idx,
749        };
750
751    if low_threshold_idx == 0 || high_threshold_idx >= input.len() {
752        return None;
753    }
754
755    // Calculating the equal number of observations which should be removed from each "side" of observations
756    let outliers_cnt = low_threshold_idx.min(input.len() - high_threshold_idx);
757
758    Some(input[outliers_cnt]..=(input[input.len() - outliers_cnt - 1]))
759}
760
761/// This trait allows to define strategy for measuring metric of interest about the code
762pub trait Metric {
763    /// Measures current metric on a given code
764    fn measure_fn(f: impl FnMut()) -> u64;
765}
766
767pub mod metrics {
768    use crate::Metric;
769
770    pub struct WallClock;
771
772    impl Metric for WallClock {
773        /// Implementation of wall clock timer that uses standard OS time source
774        #[cfg(not(all(feature = "hw-timer", target_arch = "x86_64")))]
775        fn measure_fn(mut f: impl FnMut()) -> u64 {
776            use std::time::Instant;
777            let start = Instant::now();
778            f();
779            start.elapsed().as_nanos() as u64
780        }
781
782        /// Implementation of wall clock timer that uses rdtscp on x86
783        #[cfg(all(feature = "hw-timer", target_arch = "x86_64"))]
784        fn measure_fn(mut f: impl FnMut()) -> u64 {
785            use std::arch::x86_64::{__rdtscp, _mm_mfence};
786            let start = unsafe {
787                _mm_mfence();
788                __rdtscp(&mut 0)
789            };
790            f();
791            unsafe {
792                let end = __rdtscp(&mut 0);
793                _mm_mfence();
794                end - start
795            }
796        }
797    }
798}
799
800#[cfg(feature = "async")]
801pub mod asynchronous {
802    use crate::metrics::WallClock;
803
804    use super::{Benchmark, BenchmarkParams, ErasedSampler, Sampler, SamplerFactory};
805    use std::{future::Future, ops::Deref};
806
807    pub fn async_benchmark_fn<R, F>(
808        name: impl Into<String>,
809        runtime: R,
810        sampler_factory: F,
811    ) -> Benchmark
812    where
813        R: AsyncRuntime + 'static,
814        F: FnMut(AsyncBencher<R>) -> Box<dyn ErasedSampler> + 'static,
815    {
816        let name = name.into();
817        assert!(!name.is_empty());
818        Benchmark {
819            name,
820            sampler_factory: Box::new(AsyncSampleFactory(sampler_factory, runtime)),
821        }
822    }
823
824    pub struct AsyncSampleFactory<F, R>(pub F, pub R);
825
826    impl<R: AsyncRuntime, F: FnMut(AsyncBencher<R>) -> Box<dyn ErasedSampler>> SamplerFactory
827        for AsyncSampleFactory<F, R>
828    {
829        fn create_sampler(&mut self, params: BenchmarkParams) -> Box<dyn ErasedSampler> {
830            (self.0)(AsyncBencher {
831                params,
832                runtime: self.1,
833            })
834        }
835    }
836
837    pub struct AsyncBencher<R> {
838        params: BenchmarkParams,
839        runtime: R,
840    }
841
842    impl<R: AsyncRuntime + 'static> AsyncBencher<R> {
843        pub fn iter<O, Fut, F>(self, func: F) -> Box<dyn ErasedSampler>
844        where
845            O: 'static,
846            Fut: Future<Output = O>,
847            F: FnMut() -> Fut + Copy + 'static,
848        {
849            Box::new(Sampler::<_, WallClock>::new(move || {
850                self.runtime.block_on(func)
851            }))
852        }
853    }
854
855    impl<R> Deref for AsyncBencher<R> {
856        type Target = BenchmarkParams;
857
858        fn deref(&self) -> &Self::Target {
859            &self.params
860        }
861    }
862
863    pub trait AsyncRuntime: Copy {
864        fn block_on<O, Fut: Future<Output = O>, F: FnMut() -> Fut>(&self, f: F) -> O;
865    }
866
867    #[cfg(feature = "async-tokio")]
868    pub mod tokio {
869        use super::*;
870        use ::tokio::runtime::Builder;
871
872        #[derive(Copy, Clone, Default)]
873        pub struct TokioRuntime;
874
875        impl AsyncRuntime for TokioRuntime {
876            fn block_on<O, Fut: Future<Output = O>, F: FnMut() -> Fut>(&self, mut f: F) -> O {
877                let mut builder = Builder::new_current_thread();
878                #[cfg(feature = "async-tokio-all-drivers")]
879                builder.enable_all();
880                let runtime = builder.build().unwrap();
881                runtime.block_on(f())
882            }
883        }
884    }
885}
886
887#[cfg(test)]
888mod tests {
889    use super::*;
890    use rand::{rngs::SmallRng, Rng, RngCore, SeedableRng};
891    use std::{
892        iter::Sum,
893        ops::{Add, Div},
894        thread,
895        time::Duration,
896    };
897
898    #[test]
899    fn check_iqr_variance_thresholds() {
900        let mut rng = SmallRng::from_entropy();
901
902        // Generate 20 random values in range [-50, 50]
903        // and add 10 outliers in each of two ranges [-1000, -200] and [200, 1000]
904        // This way IQR is no more than 100 and thresholds should be within [-50, 50] range
905        let mut values = vec![];
906        values.extend((0..20).map(|_| rng.gen_range(-50.0..=50.)));
907        values.extend((0..10).map(|_| rng.gen_range(-1000.0..=-200.0)));
908        values.extend((0..10).map(|_| rng.gen_range(200.0..=1000.0)));
909
910        let thresholds = iqr_variance_thresholds(values).unwrap();
911
912        assert!(
913            -50. <= *thresholds.start() && *thresholds.end() <= 50.,
914            "Invalid range: {:?}",
915            thresholds
916        );
917    }
918
919    /// This tests checks that the algorithm is stable in case of zero difference between 25 and 75 percentiles
920    #[test]
921    fn check_outliers_zero_iqr() {
922        let mut rng = SmallRng::from_entropy();
923
924        let mut values = vec![];
925        values.extend([0.; 20]);
926        values.extend((0..10).map(|_| rng.gen_range(-1000.0..=-200.0)));
927        values.extend((0..10).map(|_| rng.gen_range(200.0..=1000.0)));
928
929        let thresholds = iqr_variance_thresholds(values).unwrap();
930
931        assert!(
932            0. <= *thresholds.start() && *thresholds.end() <= 0.,
933            "Invalid range: {:?}",
934            thresholds
935        );
936    }
937
938    #[test]
939    fn check_summary_statistics() {
940        for i in 2u32..100 {
941            let range = 1..=i;
942            let values = range.collect::<Vec<_>>();
943            let stat = Summary::from(&values).unwrap();
944
945            let sum = (i * (i + 1)) as f64 / 2.;
946            let expected_mean = sum / i as f64;
947            let expected_variance = naive_variance(values.as_slice());
948
949            assert_eq!(stat.min, 1);
950            assert_eq!(stat.n, i as usize);
951            assert_eq!(stat.max, i);
952            assert!(
953                (stat.mean - expected_mean).abs() < 1e-5,
954                "Expected close to: {}, given: {}",
955                expected_mean,
956                stat.mean
957            );
958            assert!(
959                (stat.variance - expected_variance).abs() < 1e-5,
960                "Expected close to: {}, given: {}",
961                expected_variance,
962                stat.variance
963            );
964        }
965    }
966
967    #[test]
968    fn check_summary_statistics_types() {
969        Summary::from(<&[i64]>::default());
970        Summary::from(<&[u32]>::default());
971        Summary::from(&Vec::<i64>::default());
972    }
973
974    #[test]
975    fn check_naive_variance() {
976        assert_eq!(naive_variance(&[1, 2, 3]), 1.0);
977        assert_eq!(naive_variance(&[1, 2, 3, 4, 5]), 2.5);
978    }
979
980    #[test]
981    fn check_running_variance() {
982        let input = [1i64, 2, 3, 4, 5, 6, 7];
983        let variances = Summary::running(input.into_iter())
984            .map(|s| s.variance)
985            .collect::<Vec<_>>();
986        let expected = &[0., 0.5, 1., 1.6666, 2.5, 3.5, 4.6666];
987
988        assert_eq!(variances.len(), expected.len());
989
990        for (value, expected_value) in variances.iter().zip(expected) {
991            assert!(
992                (value - expected_value).abs() < 1e-3,
993                "Expected close to: {}, given: {}",
994                expected_value,
995                value
996            );
997        }
998    }
999
1000    #[test]
1001    fn check_running_variance_stress_test() {
1002        let rng = RngIterator(SmallRng::seed_from_u64(0)).map(|i| i as i64);
1003        let mut variances = Summary::running(rng).map(|s| s.variance);
1004
1005        assert!(variances.nth(1_000_000).unwrap() > 0.)
1006    }
1007
1008    /// Basic check of measurement code
1009    ///
1010    /// This test is quite brittle. There is no guarantee the OS scheduler will wake up the thread
1011    /// soon enough to meet measurement target. We try to mitigate this possibility using several strategies:
1012    /// 1. repeating test several times and taking median as target measurement.
1013    /// 2. using more liberal checking condition (allowing 1 order of magnitude error in measurement)
1014    #[test]
1015    fn check_measure_time() {
1016        let expected_delay = 100;
1017        let mut target = benchmark_fn("foo", move |b| {
1018            b.metric::<WallClock>()
1019                .iter(move || thread::sleep(Duration::from_millis(expected_delay)))
1020        });
1021        target.prepare_state(0);
1022
1023        let median = median_execution_time(&mut target, 10).as_millis() as u64;
1024        assert!(median < expected_delay * 10, "Median {median} is too large");
1025    }
1026
1027    struct RngIterator<T>(T);
1028
1029    impl<T: RngCore> Iterator for RngIterator<T> {
1030        type Item = u32;
1031
1032        fn next(&mut self) -> Option<Self::Item> {
1033            Some(self.0.next_u32())
1034        }
1035    }
1036
1037    fn naive_variance<T>(values: &[T]) -> f64
1038    where
1039        T: Sum + Copy,
1040        f64: From<T>,
1041    {
1042        let n = values.len() as f64;
1043        let mean = f64::from(values.iter().copied().sum::<T>()) / n;
1044        let mut sum_of_squares = 0.;
1045        for value in values.iter().copied() {
1046            sum_of_squares += (f64::from(value) - mean).powi(2);
1047        }
1048        sum_of_squares / (n - 1.)
1049    }
1050
1051    fn median_execution_time(target: &mut Benchmark, iterations: u32) -> Duration {
1052        assert!(iterations >= 1);
1053        let mut state = target.prepare_state(0);
1054        let measures: Vec<_> = (0..iterations).map(|_| state.measure(1)).collect();
1055        let time = median(measures).max(1);
1056        Duration::from_nanos(time)
1057    }
1058
1059    fn median<T: Copy + Ord + Add<Output = T> + Div<Output = T>>(mut measures: Vec<T>) -> T {
1060        assert!(!measures.is_empty(), "Vec is empty");
1061        measures.sort_unstable();
1062        measures[measures.len() / 2]
1063    }
1064}