1use std::{
2 collections::{BTreeSet, HashMap},
3 time::Instant,
4 fs::{OpenOptions, symlink_metadata},
5 path::{
6 Path,
7 Component::Normal,
8 },
9 io::{BufReader, BufRead , Result, ErrorKind, Error},
10 hint::black_box,
11};
12
/// Nanosecond durations, as produced by `Duration::as_nanos`.
pub type Ns = u128;

/// Measurements collected between two calls to [`Bench::next_run`].
#[derive(Debug, Default)]
struct Run {
    /// Mean nanoseconds per timed call, keyed by the label given to `measure*`.
    times: HashMap<&'static str, Ns>,
    /// Nanoseconds elapsed (monotonic `Instant`) from the start of the run until it was closed.
    total_ns: Ns,
}

/// Records per-label timing averages grouped into runs and exports them as CSV.
#[derive(Debug)]
pub struct Bench {
    /// Start of the run that is currently open.
    cur_start: Instant,
    /// Averages recorded so far in the open run, keyed by label.
    cur_times: HashMap<&'static str, Ns>,
    /// Closed runs that have not been written out yet, oldest first.
    runs: Vec<Run>,
}
34
35impl Bench {
36 pub fn new() -> Self {
38 Self {
39 cur_start: Instant::now(),
40 cur_times: HashMap::new(),
41 runs: Vec::new(),
42 }
43 }
44
45 pub fn measure<R, F>(&mut self, label: &'static str, mut f: F) -> R
48 where
49 F: FnMut() -> R,
50 {
51 let warmup_runs = 200;
52 let concerned_runs = 1000;
53
54 for _ in 0..warmup_runs {
55 black_box(f());
56 }
57
58 let mut total_ns: u128 = 0;
59 let mut function_output: Option<R> = None;
60
61 for _ in 0..concerned_runs {
62 let t0 = Instant::now();
63 let out = black_box(f());
64 let ns = t0.elapsed().as_nanos();
65 total_ns = total_ns.saturating_add(ns); function_output = Some(out);
67 }
68
69 let average = total_ns / concerned_runs as u128;
70 self.cur_times.insert(label, average);
71 function_output.expect("no measured runs completed (closure panicked before first store?)")
72
73 }
74
75 pub fn measure_with_custom_runs_and_warmup<R, F>(
76 &mut self,
77 label: &'static str,
78 mut f: F,
79 runs: usize,
80 warmup: usize,
81 ) -> R
82 where
83 F: FnMut() -> R,
84 {
85 assert!(runs > warmup, "runs must be > warmup");
86 let conerened_runs = runs - warmup;
87
88 assert!(conerened_runs > 0, "must have at least one timed run");
89
90 for _ in 0..warmup {
92 black_box(f());
93 }
94
95 let mut total_ns: u128 = 0;
97 let mut last_out: Option<R> = None;
98
99 for _ in 0..conerened_runs {
100 let t0 = Instant::now();
101 let out = std::hint::black_box(f());
102 let ns = t0.elapsed().as_nanos();
103 total_ns = total_ns.saturating_add(ns); last_out = Some(out);
105 }
106
107 let average = total_ns / conerened_runs as u128;
108 self.cur_times.insert(label, average);
109 last_out.expect("no measured runs completed (error running the given function?)")
110 }
111
112
113
114 pub fn next_run(&mut self) {
117 let total_ns = self.cur_start.elapsed().as_nanos();
118 self.runs.push(Run {
119 times: std::mem::take(&mut self.cur_times),
120 total_ns,
121 });
122 self.cur_start = Instant::now();
123 }
124
125 pub fn save_to_csv<P: AsRef<Path>>(&mut self, path: P) -> csv::Result<()> { let path = path.as_ref();
131 ensure_cwd_csv(path).map_err(csv::Error::from)?; self.next_run();
135
136 let mut labels: BTreeSet<&'static str> = BTreeSet::new();
138 for run in &self.runs {
139 labels.extend(run.times.keys());
140 }
141
142 let (mut wtr, start_idx) = if path.exists() {
143 let f = OpenOptions::new().read(true).open(path)?;
145 let mut rdr = BufReader::new(&f);
146 let mut last = 0;
147 let mut line = String::new();
148 while rdr.read_line(&mut line)? != 0 {
149 if let Some(first) = line.split(',').next() {
150 last = first.trim().parse::<usize>().unwrap_or(last);
151 }
152 line.clear();
153 }
154 let f = OpenOptions::new().append(true).open(path)?;
156 let w = csv::WriterBuilder::new().has_headers(false).from_writer(f);
157 (w, last)
158 } else {
159 let mut w = csv::Writer::from_path(path)?;
161 let mut header: Vec<String> = Vec::with_capacity(labels.len() + 2);
162 header.push("run".into());
163 for l in &labels {
164 header.push(format!("{l}_ns"));
165 }
166 header.push("total_ns".into());
167 w.write_record(&header)?;
168 (w, 0)
169 };
170
171 for (idx, run) in self.runs.iter().enumerate() {
173 let mut row: Vec<String> = Vec::with_capacity(labels.len() + 2);
174 row.push((start_idx + idx + 1).to_string()); for l in &labels {
176 row.push(
177 run.times
178 .get(l)
179 .map_or(String::new(), |v| v.to_string()),
180 );
181 }
182 row.push(run.total_ns.to_string());
183 wtr.write_record(&row)?;
184 }
185 wtr.flush()?;
186 Ok(())
187 }
188}
189
/// Times `$code` with the target's `measure` method under the label `$name`.
///
/// The macro evaluates to whatever the body returns. For example,
/// `bench!(b, "parse", { parse(input) })` expands to
/// `b.measure("parse", || { parse(input) })`.
#[macro_export]
macro_rules! bench {
    ($target:expr, $name:expr, $code:block) => {
        {
            $target.measure($name, || $code)
        }
    };
}
201
/// Checks that `path` names a `.csv` file directly inside the current working directory.
///
/// Only a bare file name (`results.csv`) or one prefixed with `./` (`./results.csv`)
/// is accepted.
///
/// # Errors
///
/// Returns an `ErrorKind::InvalidInput` error when `path`:
/// - is absolute;
/// - has any component other than an optional leading `.` and the file name
///   (e.g. `../x.csv`, `dir/x.csv`);
/// - currently refers to a directory;
/// - does not have exactly the extension `csv`;
/// - is an existing symbolic link.
fn ensure_cwd_csv(path: &Path) -> Result<()> {
    if path.is_absolute() {
        return Err(Error::new(ErrorKind::InvalidInput, "absolute paths are not allowed"));
    }

    // `components()` keeps a `.` only as the first component, so these are exactly
    // `name` and `./name`. Any other leading component (`..`, a directory, a Windows
    // drive prefix) would point outside the current directory.
    let mut comps = path.components();
    let in_cwd = matches!(
        (comps.next(), comps.next(), comps.next()),
        (Some(Normal(_)), None, None)
            | (Some(std::path::Component::CurDir), Some(Normal(_)), None)
    );
    if !in_cwd {
        return Err(Error::new(
            ErrorKind::InvalidInput,
            "only filenames in the current directory are allowed",
        ));
    }

    if path.is_dir() {
        return Err(Error::new(ErrorKind::InvalidInput, "path points to a directory"));
    }

    match path.extension() {
        Some(ext) if ext == "csv" => {}
        Some(_) => {
            return Err(Error::new(ErrorKind::InvalidInput, "file extension must be .csv"));
        }
        None => {
            return Err(Error::new(ErrorKind::InvalidInput, "file must have .csv extension"));
        }
    }

    // A missing file is fine (it will be created); only an existing symlink is rejected.
    if let Ok(meta) = symlink_metadata(path) {
        if meta.file_type().is_symlink() {
            return Err(Error::new(ErrorKind::InvalidInput, "symlinks are not allowed"));
        }
    }
    Ok(())
}