1use self::ffi::SELF_VTABLE;
4use crate::{Benchmark, ErasedSampler, Error};
5use ffi::VTable;
6use libloading::{Library, Symbol};
7use std::{
8 cell::UnsafeCell,
9 ffi::{c_char, c_ulonglong},
10 path::Path,
11 ptr::addr_of,
12 slice, str,
13 sync::mpsc::{channel, Receiver, Sender},
14 thread::{self, JoinHandle},
15};
16
/// Position of a benchmark function in a library's benchmark list; this is the value
/// handed to the library's select entry point.
pub type FunctionIdx = usize;

/// A benchmark function discovered in a library, together with its index.
#[derive(Debug, Clone)]
pub struct NamedFunction {
    /// Name reported by the library through its `get_test_name` entry point.
    pub name: String,

    /// Index used to make this function current (see `Spi::select`).
    pub idx: FunctionIdx,
}
26
/// Handle to a benchmark library (or to the current executable) through its C SPI.
pub(crate) struct Spi {
    /// Benchmarks enumerated from the library when the handle was created.
    tests: Vec<NamedFunction>,
    /// Index of the function chosen through `Spi::select`, as tracked by that method.
    selected_function: Option<FunctionIdx>,
    /// How calls reach the vtable: directly, or through a worker thread.
    mode: SpiMode,
}
32
/// Selects how an SPI handle dispatches calls to the benchmark library.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum SpiModeKind {
    /// Every call is made directly on the caller's thread.
    Synchronous,

    /// Calls are forwarded over channels to a dedicated worker thread. This lets a
    /// measurement be started and its result collected later.
    Asynchronous,
}
45
/// Runtime state backing each `SpiModeKind`.
enum SpiMode {
    /// The vtable is called directly on the caller's thread.
    Synchronous {
        vt: VTable,
        /// Result of the most recent successful `Spi::measure`, returned by
        /// `Spi::read_sample` (0 until the first measurement).
        last_measurement: u64,
    },
    /// The vtable lives on a worker thread; requests and replies travel over channels.
    Asynchronous {
        /// Worker thread; taken (left as `None`) when it is joined in `Drop`.
        worker: Option<JoinHandle<()>>,
        /// Requests sent to the worker.
        tx: Sender<SpiRequest>,
        /// Replies received from the worker.
        rx: Receiver<SpiReply>,
    },
}
57
58impl Spi {
59 pub(crate) fn for_library(path: impl AsRef<Path>, mode: SpiModeKind) -> Result<Spi, Error> {
60 let path = path.as_ref();
61 if path.exists() {
62 #[cfg(target_os = "windows")]
63 let lib = {
64 use libloading::os::windows::Library as WinLibrary;
65 use windows::Win32::Foundation::HMODULE;
66
67 let lib = unsafe { WinLibrary::new(path) }.map_err(Error::UnableToLoadBenchmark)?;
68
69 let raw_handle = lib.into_raw();
73 let handle = HMODULE(raw_handle as _);
74 unsafe {
75 crate::windows::patch_iat(handle).map_err(Error::UnableToPatchIat)?;
76 }
77 Library::from(unsafe { WinLibrary::from_raw(raw_handle) })
79 };
80
81 #[cfg(not(target_os = "windows"))]
82 let lib = unsafe { Library::new(path) }.map_err(Error::UnableToLoadBenchmark)?;
83
84 Ok(spi_handle_for_vtable(ffi::VTable::new(lib)?, mode))
85 } else {
86 Err(Error::BenchmarkNotFound)
87 }
88 }
89
90 pub(crate) fn for_self(mode: SpiModeKind) -> Option<Spi> {
91 SELF_VTABLE
92 .lock()
93 .unwrap()
94 .take()
95 .map(|vt| spi_handle_for_vtable(vt, mode))
96 }
97
98 pub(crate) fn tests(&self) -> &[NamedFunction] {
99 &self.tests
100 }
101
102 pub(crate) fn lookup(&self, name: &str) -> Option<&NamedFunction> {
103 self.tests.iter().find(|f| f.name == name)
104 }
105
106 pub(crate) fn run(&mut self, iterations: usize) -> Result<u64, Error> {
107 match &self.mode {
108 SpiMode::Synchronous { vt, .. } => vt.run(iterations as c_ulonglong),
109 SpiMode::Asynchronous { worker: _, tx, rx } => {
110 tx.send(SpiRequest::Run { iterations }).unwrap();
111 match rx.recv().unwrap() {
112 SpiReply::Run(time) => time,
113 r => panic!("Unexpected response: {:?}", r),
114 }
115 }
116 }
117 }
118
119 pub(crate) fn measure(&mut self, iterations: usize) -> Result<(), Error> {
120 match &mut self.mode {
121 SpiMode::Synchronous {
122 vt,
123 last_measurement,
124 } => {
125 *last_measurement = vt.run(iterations as c_ulonglong)?;
126 }
127 SpiMode::Asynchronous { tx, .. } => {
128 tx.send(SpiRequest::Measure { iterations }).unwrap();
129 }
130 }
131 Ok(())
132 }
133
134 pub(crate) fn read_sample(&mut self) -> Result<u64, Error> {
135 match &self.mode {
136 SpiMode::Synchronous {
137 last_measurement, ..
138 } => Ok(*last_measurement),
139 SpiMode::Asynchronous { rx, .. } => match rx.recv().unwrap() {
140 SpiReply::Measure(time) => time,
141 r => panic!("Unexpected response: {:?}", r),
142 },
143 }
144 }
145
146 pub(crate) fn estimate_iterations(&mut self, time_ms: u32) -> Result<usize, Error> {
147 match &self.mode {
148 SpiMode::Synchronous { vt, .. } => vt.estimate_iterations(time_ms),
149 SpiMode::Asynchronous { tx, rx, .. } => {
150 tx.send(SpiRequest::EstimateIterations { time_ms }).unwrap();
151 match rx.recv().unwrap() {
152 SpiReply::EstimateIterations(iters) => iters,
153 r => panic!("Unexpected response: {:?}", r),
154 }
155 }
156 }
157 }
158
159 pub(crate) fn prepare_state(&mut self, seed: u64) -> Result<(), Error> {
160 match &self.mode {
161 SpiMode::Synchronous { vt, .. } => vt.prepare_state(seed),
162 SpiMode::Asynchronous { tx, rx, .. } => {
163 tx.send(SpiRequest::PrepareState { seed }).unwrap();
164 match rx.recv().unwrap() {
165 SpiReply::PrepareState(result) => result,
166 r => panic!("Unexpected response: {:?}", r),
167 }
168 }
169 }
170 }
171
172 pub(crate) fn select(&mut self, idx: usize) {
173 match &self.mode {
174 SpiMode::Synchronous { vt, .. } => vt.select(idx as c_ulonglong),
175 SpiMode::Asynchronous { tx, rx, .. } => {
176 tx.send(SpiRequest::Select { idx }).unwrap();
177 match rx.recv().unwrap() {
178 SpiReply::Select => self.selected_function = Some(idx),
179 r => panic!("Unexpected response: {:?}", r),
180 }
181 }
182 }
183 }
184}
185
186impl Drop for Spi {
187 fn drop(&mut self) {
188 if let SpiMode::Asynchronous { worker, tx, .. } = &mut self.mode {
189 if let Some(worker) = worker.take() {
190 tx.send(SpiRequest::Shutdown).unwrap();
191 worker.join().unwrap();
192 }
193 }
194 }
195}
196
197fn spi_worker(vt: &VTable, rx: Receiver<SpiRequest>, tx: Sender<SpiReply>) {
198 use SpiReply as Rp;
199 use SpiRequest as Rq;
200
201 while let Ok(req) = rx.recv() {
202 let reply = match req {
203 Rq::EstimateIterations { time_ms } => {
204 Rp::EstimateIterations(vt.estimate_iterations(time_ms))
205 }
206 Rq::PrepareState { seed } => Rp::PrepareState(vt.prepare_state(seed)),
207 Rq::Select { idx } => {
208 vt.select(idx as c_ulonglong);
209 Rp::Select
210 }
211 Rq::Run { iterations } => Rp::Run(vt.run(iterations as c_ulonglong)),
212 Rq::Measure { iterations } => Rp::Measure(vt.run(iterations as c_ulonglong)),
213 Rq::Shutdown => break,
214 };
215 tx.send(reply).unwrap();
216 }
217}
218
219fn spi_handle_for_vtable(vt: VTable, mode: SpiModeKind) -> Spi {
220 vt.init();
221 let tests = enumerate_tests(&vt).unwrap();
222
223 match mode {
224 SpiModeKind::Asynchronous => {
225 let (request_tx, request_rx) = channel();
226 let (reply_tx, reply_rx) = channel();
227 let worker = thread::spawn(move || {
228 spi_worker(&vt, request_rx, reply_tx);
229 });
230
231 Spi {
232 tests,
233 selected_function: None,
234 mode: SpiMode::Asynchronous {
235 worker: Some(worker),
236 tx: request_tx,
237 rx: reply_rx,
238 },
239 }
240 }
241 SpiModeKind::Synchronous => Spi {
242 tests,
243 selected_function: None,
244 mode: SpiMode::Synchronous {
245 vt,
246 last_measurement: 0,
247 },
248 },
249 }
250}
251
252fn enumerate_tests(vt: &VTable) -> Result<Vec<NamedFunction>, Error> {
253 let mut tests = vec![];
254 for idx in 0..vt.count() {
255 vt.select(idx);
256
257 let mut length = 0;
258 let name_ptr: *const c_char = c"".as_ptr();
259 vt.get_test_name(addr_of!(name_ptr) as _, &mut length);
260 if length > 0 {
261 let slice = unsafe { slice::from_raw_parts(name_ptr as *const u8, length as usize) };
262 let name = str::from_utf8(slice)
263 .map_err(Error::InvalidFFIString)?
264 .to_string();
265 let idx = idx as usize;
266 tests.push(NamedFunction { name, idx });
267 }
268 }
269 Ok(tests)
270}
271
/// Request sent from an asynchronous SPI handle to its worker thread.
///
/// Derives `Debug` like its counterpart `SpiReply`, so both sides of the protocol can be
/// printed when diagnosing a mismatch.
#[derive(Debug)]
enum SpiRequest {
    /// Estimate how many iterations fit into `time_ms` milliseconds.
    EstimateIterations { time_ms: u32 },
    /// Prepare the selected benchmark's state with `seed`.
    PrepareState { seed: u64 },
    /// Make benchmark `idx` current.
    Select { idx: usize },
    /// Run the selected benchmark and reply with `SpiReply::Run`.
    Run { iterations: usize },
    /// Run the selected benchmark and reply with `SpiReply::Measure`.
    Measure { iterations: usize },
    /// Stop the worker loop; no reply is sent.
    Shutdown,
}
280
/// Reply sent by `spi_worker` for every `SpiRequest` except `Shutdown`.
#[derive(Debug)]
enum SpiReply {
    /// Result of `VTable::estimate_iterations`.
    EstimateIterations(Result<usize, Error>),
    /// Result of `VTable::prepare_state`.
    PrepareState(Result<(), Error>),
    /// Acknowledges that the requested benchmark was selected.
    Select,
    /// Result of a `Run` request.
    Run(Result<u64, Error>),
    /// Result of a `Measure` request, consumed by `Spi::read_sample`.
    Measure(Result<u64, Error>),
}
289
/// Benchmark registry of the current executable, served by the `tango_*` entry points.
struct State {
    /// All registered benchmarks; `tango_select` indexes into this list.
    benchmarks: Vec<Benchmark>,
    /// Selected benchmark index and, once `tango_prepare_state` has run for it, its sampler.
    selected_function: Option<(usize, Option<Box<dyn ErasedSampler>>)>,
    /// Message of the last panic caught in an entry point, exposed by `tango_get_last_error`.
    last_error: Option<String>,
}
297
298impl State {
299 fn selected(&self) -> &Benchmark {
300 &self.benchmarks[self.ensure_selected()]
301 }
302
303 fn ensure_selected(&self) -> usize {
304 self.selected_function
305 .as_ref()
306 .map(|(idx, _)| *idx)
307 .expect("No function was selected. Call tango_select() first")
308 }
309
310 fn selected_state_mut(&mut self) -> Option<&mut Box<dyn ErasedSampler>> {
311 self.selected_function
312 .as_mut()
313 .and_then(|(_, state)| state.as_mut())
314 }
315}
316
317static STATE: StateWrapper = StateWrapper(UnsafeCell::new(None));
319
320struct StateWrapper(UnsafeCell<Option<State>>);
321unsafe impl Sync for StateWrapper {}
322
323impl StateWrapper {
324 unsafe fn as_ref(&self) -> Option<&State> {
325 (*self.0.get()).as_ref()
326 }
327
328 #[allow(clippy::mut_from_ref)]
329 unsafe fn as_mut(&self) -> Option<&mut State> {
330 (*self.0.get()).as_mut()
331 }
332}
333
334pub fn __tango_init(benchmarks: Vec<Benchmark>) {
339 if unsafe { STATE.as_ref().is_none() } {
340 let state = Some(State {
341 benchmarks,
342 selected_function: None,
343 last_error: None,
344 });
345 unsafe { *STATE.0.get() = state }
346 }
347}
348
349pub mod ffi {
356 use super::*;
357 use std::{
358 ffi::{c_int, c_uint, c_ulonglong},
359 os::raw::c_char,
360 panic::{catch_unwind, UnwindSafe},
361 ptr::null,
362 sync::Mutex,
363 };
364
365 pub type InitFn = unsafe extern "C" fn();
367 type CountFn = unsafe extern "C" fn() -> c_ulonglong;
368 type GetTestNameFn = unsafe extern "C" fn(*mut *const c_char, *mut c_ulonglong);
369 type SelectFn = unsafe extern "C" fn(c_ulonglong);
370 type RunFn = unsafe extern "C" fn(c_ulonglong, *mut c_ulonglong) -> c_int;
371 type EstimateIterationsFn = unsafe extern "C" fn(c_uint) -> c_ulonglong;
372 type PrepareStateFn = unsafe extern "C" fn(c_ulonglong) -> c_int;
373 type GetLastErrorFn = unsafe extern "C" fn(*mut *const c_char, *mut c_ulonglong) -> c_int;
374 type ApiVersionFn = unsafe extern "C" fn() -> c_uint;
375 type FreeFn = unsafe extern "C" fn();
376
377 pub(super) static SELF_VTABLE: Mutex<Option<VTable>> = Mutex::new(Some(VTable::for_self()));
378 pub const TANGO_API_VERSION: u32 = 3;
379
380 #[no_mangle]
381 unsafe extern "C" fn tango_count() -> c_ulonglong {
382 STATE
383 .as_ref()
384 .map(|s| s.benchmarks.len() as c_ulonglong)
385 .unwrap_or(0)
386 }
387
388 #[no_mangle]
389 unsafe extern "C" fn tango_api_version() -> c_uint {
390 TANGO_API_VERSION
391 }
392
393 #[no_mangle]
394 unsafe extern "C" fn tango_select(idx: c_ulonglong) {
395 if let Some(s) = STATE.as_mut() {
396 let idx = idx as usize;
397 assert!(idx < s.benchmarks.len());
398
399 s.selected_function = Some(match s.selected_function.take() {
400 Some((selected, state)) if selected == idx => (selected, state),
402 _ => (idx, None),
403 });
404 }
405 }
406
407 #[no_mangle]
408 unsafe extern "C" fn tango_get_test_name(name: *mut *const c_char, length: *mut c_ulonglong) {
409 if let Some(s) = STATE.as_ref() {
410 let n = s.selected().name();
411 *name = n.as_ptr() as _;
412 *length = n.len() as _;
413 } else {
414 *name = null();
415 *length = 0;
416 }
417 }
418
419 #[no_mangle]
423 unsafe extern "C" fn tango_get_last_error(
424 name: *mut *const c_char,
425 length: *mut c_ulonglong,
426 ) -> c_int {
427 if let Some(err) = STATE.as_ref().and_then(|s| s.last_error.as_ref()) {
428 *name = err.as_ptr() as _;
429 *length = err.len() as _;
430 0
431 } else {
432 *name = null();
433 *length = 0;
434 -1
435 }
436 }
437
438 #[no_mangle]
439 unsafe extern "C" fn tango_run(iterations: c_ulonglong, time: *mut c_ulonglong) -> c_int {
440 let measurement = catch(|| {
441 STATE.as_mut().map(|s| {
442 s.selected_state_mut()
443 .expect("no tango_prepare_state() was called")
444 .measure(iterations as usize)
445 })
446 })
447 .flatten();
448 *time = measurement.unwrap_or(0);
449 if measurement.is_some() {
450 0
451 } else {
452 -1
453 }
454 }
455
456 #[no_mangle]
460 unsafe extern "C" fn tango_estimate_iterations(time_ms: c_uint) -> c_ulonglong {
461 catch(|| {
462 if let Some(s) = STATE.as_mut() {
463 s.selected_state_mut()
464 .expect("no tango_prepare_state() was called")
465 .as_mut()
466 .estimate_iterations(time_ms)
467 .max(1) as c_ulonglong
468 } else {
469 0
470 }
471 })
472 .unwrap_or(0)
473 }
474
475 #[no_mangle]
481 unsafe extern "C" fn tango_prepare_state(seed: c_ulonglong) -> c_int {
482 catch(|| {
483 if let Some(s) = STATE.as_mut() {
484 let Some((idx, state)) = &mut s.selected_function else {
485 panic!("No tango_select() was called")
486 };
487 *state = Some(s.benchmarks[*idx].prepare_state(seed));
488 }
489 0
490 })
491 .unwrap_or(-1)
492 }
493
494 #[no_mangle]
495 unsafe extern "C" fn tango_free() {
496 unsafe { *STATE.0.get() = None }
497 }
498
499 fn catch<T>(f: impl FnOnce() -> T + UnwindSafe) -> Option<T> {
502 match catch_unwind(f) {
503 Ok(r) => Some(r),
504 Err(e) => {
505 let state = unsafe { STATE.as_mut().unwrap() };
507 if let Some(msg) = e.downcast_ref::<&str>() {
508 state.last_error = Some(msg.to_string());
509 } else {
510 state.last_error = e.downcast_ref::<String>().cloned();
511 }
512 None
513 }
514 }
515 }
516
517 pub(super) struct VTable {
518 init_fn: InitFn,
526 count_fn: CountFn,
527 select_fn: SelectFn,
528 get_test_name_fn: GetTestNameFn,
529 get_last_error_fn: GetLastErrorFn,
530 run_fn: RunFn,
531 estimate_iterations_fn: EstimateIterationsFn,
532 prepare_state_fn: PrepareStateFn,
533 free_fn: FreeFn,
534
535 _library: Option<Box<Library>>,
537 }
538
539 impl VTable {
540 pub(super) fn new(lib: Library) -> Result<Self, Error> {
541 let api_version_fn = *lookup_symbol::<ApiVersionFn>(&lib, "tango_api_version")?;
542 let api_version = unsafe { (api_version_fn)() };
543 if api_version != TANGO_API_VERSION {
544 return Err(Error::IncorrectVersion(api_version));
545 }
546 Ok(Self {
547 init_fn: *lookup_symbol(&lib, "tango_init")?,
548 count_fn: *lookup_symbol(&lib, "tango_count")?,
549 select_fn: *lookup_symbol(&lib, "tango_select")?,
550 get_test_name_fn: *lookup_symbol(&lib, "tango_get_test_name")?,
551 run_fn: *lookup_symbol(&lib, "tango_run")?,
552 estimate_iterations_fn: *lookup_symbol(&lib, "tango_estimate_iterations")?,
553 prepare_state_fn: *lookup_symbol(&lib, "tango_prepare_state")?,
554 get_last_error_fn: *lookup_symbol(&lib, "tango_get_last_error")?,
555 free_fn: *lookup_symbol(&lib, "tango_free")?,
556 _library: Some(Box::new(lib)),
558 })
559 }
560
561 const fn for_self() -> Self {
562 unsafe extern "C" fn no_init() {
563 }
565 Self {
566 init_fn: no_init,
567 count_fn: ffi::tango_count,
568 select_fn: ffi::tango_select,
569 get_test_name_fn: ffi::tango_get_test_name,
570 run_fn: ffi::tango_run,
571 estimate_iterations_fn: ffi::tango_estimate_iterations,
572 prepare_state_fn: ffi::tango_prepare_state,
573 get_last_error_fn: ffi::tango_get_last_error,
574 free_fn: ffi::tango_free,
575 _library: None,
576 }
577 }
578
579 pub(super) fn init(&self) {
580 unsafe { (self.init_fn)() }
581 }
582
583 pub(super) fn count(&self) -> c_ulonglong {
584 unsafe { (self.count_fn)() }
585 }
586
587 pub(super) fn select(&self, func_idx: c_ulonglong) {
588 unsafe { (self.select_fn)(func_idx) }
589 }
590
591 pub(super) fn get_test_name(&self, ptr: *mut *const c_char, len: *mut c_ulonglong) {
592 unsafe { (self.get_test_name_fn)(ptr, len) }
593 }
594
595 pub(super) fn run(&self, iterations: c_ulonglong) -> Result<u64, Error> {
596 let mut measurement = 0u64;
597 match unsafe { (self.run_fn)(iterations, &mut measurement) } {
598 0 => Ok(measurement),
599 _ => Err(self.last_error()?),
600 }
601 }
602
603 pub(super) fn estimate_iterations(&self, time_ms: c_uint) -> Result<usize, Error> {
604 match unsafe { (self.estimate_iterations_fn)(time_ms) } {
605 0 => Err(self.last_error()?),
606 iters => Ok(iters as usize),
607 }
608 }
609
610 pub(super) fn prepare_state(&self, seed: c_ulonglong) -> Result<(), Error> {
611 match unsafe { (self.prepare_state_fn)(seed) } {
612 0 => Ok(()),
613 _ => Err(self.last_error()?),
614 }
615 }
616
617 fn last_error(&self) -> Result<Error, Error> {
618 let mut length = 0;
619 let mut name = null();
620 if unsafe { (self.get_last_error_fn)(&mut name, &mut length) } != 0 {
621 Err(Error::UnknownFFIError)
622 } else {
623 let name = unsafe { slice::from_raw_parts(name as *const u8, length as usize) };
624 str::from_utf8(name)
625 .map_err(Error::InvalidFFIString)
626 .map(str::to_string)
627 .map(Error::FFIError)
628 }
629 }
630 }
631
632 impl Drop for VTable {
633 fn drop(&mut self) {
634 unsafe { (self.free_fn)() }
635 }
636 }
637
638 fn lookup_symbol<'l, T>(library: &'l Library, name: &str) -> Result<Symbol<'l, T>, Error> {
639 unsafe {
640 library
641 .get(name.as_bytes())
642 .map_err(|e| Error::UnableToLoadSymbol(name.to_string(), e))
643 }
644 }
645}