Skip to main content

tango_bench/
dylib.rs

//! Loading benchmark libraries (.dylib/.so/.dll) and resolving their FFI symbols
2
3use self::ffi::SELF_VTABLE;
4use crate::{Benchmark, ErasedSampler, Error};
5use ffi::VTable;
6use libloading::{Library, Symbol};
7use std::{
8    cell::UnsafeCell,
9    ffi::{c_char, c_ulonglong},
10    path::Path,
11    ptr::addr_of,
12    slice, str,
13    sync::mpsc::{channel, Receiver, Sender},
14    thread::{self, JoinHandle},
15};
16
/// Position of a benchmark function in the FFI API (the value passed to `tango_select()`).
pub type FunctionIdx = usize;

/// A benchmark function exported through the FFI API, together with its name.
#[derive(Clone, Debug)]
pub struct NamedFunction {
    /// Human-readable name reported by `tango_get_test_name()`.
    pub name: String,

    /// Function index in FFI API
    pub idx: FunctionIdx,
}
26
/// Handle for calling the FFI API of a benchmark library (or of the current executable).
pub(crate) struct Spi {
    /// Benchmarks exported by the library, as discovered by `enumerate_tests()`.
    tests: Vec<NamedFunction>,
    /// Function chosen via `Spi::select()`; `None` until a selection has been recorded.
    selected_function: Option<FunctionIdx>,
    /// How calls are executed: on the caller's thread or on a dedicated worker thread.
    mode: SpiMode,
}
32
/// Execution strategy used by an `Spi` handle when calling into a benchmark library.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum SpiModeKind {
    /// Benchmarks are executed synchronously, on the calling thread, when the SPI is called.
    ///
    /// The dispatcher switches between baseline and candidate after each sample.
    Synchronous,

    /// Benchmarks are executed on dedicated worker threads.
    ///
    /// The dispatcher creates a separate thread for baseline and candidate, but synchronizes them
    /// after each benchmark.
    Asynchronous,
}
45
/// Runtime state backing each `SpiModeKind`.
enum SpiMode {
    /// Calls go straight to the vtable on the caller's thread.
    Synchronous {
        vt: VTable,
        /// Result of the most recent `Spi::measure()`, returned by `Spi::read_sample()`.
        last_measurement: u64,
    },
    /// Calls are forwarded over channels to a worker thread running `spi_worker()`.
    Asynchronous {
        /// Worker thread handle; taken and joined when the `Spi` is dropped.
        worker: Option<JoinHandle<()>>,
        /// Requests sent to the worker.
        tx: Sender<SpiRequest>,
        /// Replies received from the worker.
        rx: Receiver<SpiReply>,
    },
}
57
58impl Spi {
59    pub(crate) fn for_library(path: impl AsRef<Path>, mode: SpiModeKind) -> Result<Spi, Error> {
60        let path = path.as_ref();
61        if path.exists() {
62            #[cfg(target_os = "windows")]
63            let lib = {
64                use libloading::os::windows::Library as WinLibrary;
65                use windows::Win32::Foundation::HMODULE;
66
67                let lib = unsafe { WinLibrary::new(path) }.map_err(Error::UnableToLoadBenchmark)?;
68
69                // Get the raw module handle and patch the IAT
70                // This is needed because Windows doesn't properly resolve imports when loading
71                // an EXE file as a library
72                let raw_handle = lib.into_raw();
73                let handle = HMODULE(raw_handle as _);
74                unsafe {
75                    crate::windows::patch_iat(handle).map_err(Error::UnableToPatchIat)?;
76                }
77                // Reconstruct the library from the raw handle
78                Library::from(unsafe { WinLibrary::from_raw(raw_handle) })
79            };
80
81            #[cfg(not(target_os = "windows"))]
82            let lib = unsafe { Library::new(path) }.map_err(Error::UnableToLoadBenchmark)?;
83
84            Ok(spi_handle_for_vtable(ffi::VTable::new(lib)?, mode))
85        } else {
86            Err(Error::BenchmarkNotFound)
87        }
88    }
89
90    pub(crate) fn for_self(mode: SpiModeKind) -> Option<Spi> {
91        SELF_VTABLE
92            .lock()
93            .unwrap()
94            .take()
95            .map(|vt| spi_handle_for_vtable(vt, mode))
96    }
97
98    pub(crate) fn tests(&self) -> &[NamedFunction] {
99        &self.tests
100    }
101
102    pub(crate) fn lookup(&self, name: &str) -> Option<&NamedFunction> {
103        self.tests.iter().find(|f| f.name == name)
104    }
105
106    pub(crate) fn run(&mut self, iterations: usize) -> Result<u64, Error> {
107        match &self.mode {
108            SpiMode::Synchronous { vt, .. } => vt.run(iterations as c_ulonglong),
109            SpiMode::Asynchronous { worker: _, tx, rx } => {
110                tx.send(SpiRequest::Run { iterations }).unwrap();
111                match rx.recv().unwrap() {
112                    SpiReply::Run(time) => time,
113                    r => panic!("Unexpected response: {:?}", r),
114                }
115            }
116        }
117    }
118
119    pub(crate) fn measure(&mut self, iterations: usize) -> Result<(), Error> {
120        match &mut self.mode {
121            SpiMode::Synchronous {
122                vt,
123                last_measurement,
124            } => {
125                *last_measurement = vt.run(iterations as c_ulonglong)?;
126            }
127            SpiMode::Asynchronous { tx, .. } => {
128                tx.send(SpiRequest::Measure { iterations }).unwrap();
129            }
130        }
131        Ok(())
132    }
133
134    pub(crate) fn read_sample(&mut self) -> Result<u64, Error> {
135        match &self.mode {
136            SpiMode::Synchronous {
137                last_measurement, ..
138            } => Ok(*last_measurement),
139            SpiMode::Asynchronous { rx, .. } => match rx.recv().unwrap() {
140                SpiReply::Measure(time) => time,
141                r => panic!("Unexpected response: {:?}", r),
142            },
143        }
144    }
145
146    pub(crate) fn estimate_iterations(&mut self, time_ms: u32) -> Result<usize, Error> {
147        match &self.mode {
148            SpiMode::Synchronous { vt, .. } => vt.estimate_iterations(time_ms),
149            SpiMode::Asynchronous { tx, rx, .. } => {
150                tx.send(SpiRequest::EstimateIterations { time_ms }).unwrap();
151                match rx.recv().unwrap() {
152                    SpiReply::EstimateIterations(iters) => iters,
153                    r => panic!("Unexpected response: {:?}", r),
154                }
155            }
156        }
157    }
158
159    pub(crate) fn prepare_state(&mut self, seed: u64) -> Result<(), Error> {
160        match &self.mode {
161            SpiMode::Synchronous { vt, .. } => vt.prepare_state(seed),
162            SpiMode::Asynchronous { tx, rx, .. } => {
163                tx.send(SpiRequest::PrepareState { seed }).unwrap();
164                match rx.recv().unwrap() {
165                    SpiReply::PrepareState(result) => result,
166                    r => panic!("Unexpected response: {:?}", r),
167                }
168            }
169        }
170    }
171
172    pub(crate) fn select(&mut self, idx: usize) {
173        match &self.mode {
174            SpiMode::Synchronous { vt, .. } => vt.select(idx as c_ulonglong),
175            SpiMode::Asynchronous { tx, rx, .. } => {
176                tx.send(SpiRequest::Select { idx }).unwrap();
177                match rx.recv().unwrap() {
178                    SpiReply::Select => self.selected_function = Some(idx),
179                    r => panic!("Unexpected response: {:?}", r),
180                }
181            }
182        }
183    }
184}
185
186impl Drop for Spi {
187    fn drop(&mut self) {
188        if let SpiMode::Asynchronous { worker, tx, .. } = &mut self.mode {
189            if let Some(worker) = worker.take() {
190                tx.send(SpiRequest::Shutdown).unwrap();
191                worker.join().unwrap();
192            }
193        }
194    }
195}
196
197fn spi_worker(vt: &VTable, rx: Receiver<SpiRequest>, tx: Sender<SpiReply>) {
198    use SpiReply as Rp;
199    use SpiRequest as Rq;
200
201    while let Ok(req) = rx.recv() {
202        let reply = match req {
203            Rq::EstimateIterations { time_ms } => {
204                Rp::EstimateIterations(vt.estimate_iterations(time_ms))
205            }
206            Rq::PrepareState { seed } => Rp::PrepareState(vt.prepare_state(seed)),
207            Rq::Select { idx } => {
208                vt.select(idx as c_ulonglong);
209                Rp::Select
210            }
211            Rq::Run { iterations } => Rp::Run(vt.run(iterations as c_ulonglong)),
212            Rq::Measure { iterations } => Rp::Measure(vt.run(iterations as c_ulonglong)),
213            Rq::Shutdown => break,
214        };
215        tx.send(reply).unwrap();
216    }
217}
218
219fn spi_handle_for_vtable(vt: VTable, mode: SpiModeKind) -> Spi {
220    vt.init();
221    let tests = enumerate_tests(&vt).unwrap();
222
223    match mode {
224        SpiModeKind::Asynchronous => {
225            let (request_tx, request_rx) = channel();
226            let (reply_tx, reply_rx) = channel();
227            let worker = thread::spawn(move || {
228                spi_worker(&vt, request_rx, reply_tx);
229            });
230
231            Spi {
232                tests,
233                selected_function: None,
234                mode: SpiMode::Asynchronous {
235                    worker: Some(worker),
236                    tx: request_tx,
237                    rx: reply_rx,
238                },
239            }
240        }
241        SpiModeKind::Synchronous => Spi {
242            tests,
243            selected_function: None,
244            mode: SpiMode::Synchronous {
245                vt,
246                last_measurement: 0,
247            },
248        },
249    }
250}
251
252fn enumerate_tests(vt: &VTable) -> Result<Vec<NamedFunction>, Error> {
253    let mut tests = vec![];
254    for idx in 0..vt.count() {
255        vt.select(idx);
256
257        let mut length = 0;
258        let name_ptr: *const c_char = c"".as_ptr();
259        vt.get_test_name(addr_of!(name_ptr) as _, &mut length);
260        if length > 0 {
261            let slice = unsafe { slice::from_raw_parts(name_ptr as *const u8, length as usize) };
262            let name = str::from_utf8(slice)
263                .map_err(Error::InvalidFFIString)?
264                .to_string();
265            let idx = idx as usize;
266            tests.push(NamedFunction { name, idx });
267        }
268    }
269    Ok(tests)
270}
271
/// Request sent from an asynchronous `Spi` to its worker thread (`spi_worker()`).
///
/// The worker answers every request except `Shutdown` with the matching `SpiReply` variant. For
/// `Measure`, the `Spi` does not wait right away; the reply is collected by `Spi::read_sample()`.
enum SpiRequest {
    /// Estimate iterations fitting into `time_ms` milliseconds.
    EstimateIterations { time_ms: u32 },
    /// Prepare the selected benchmark's state with `seed`.
    PrepareState { seed: u64 },
    /// Select the benchmark with FFI index `idx`.
    Select { idx: usize },
    /// Run `iterations` iterations; the reply is awaited immediately.
    Run { iterations: usize },
    /// Run `iterations` iterations; the reply is awaited later.
    Measure { iterations: usize },
    /// Stop the worker loop (no reply is sent).
    Shutdown,
}
280
/// Reply sent by the worker thread for each answered `SpiRequest`.
///
/// `Debug` is needed so unexpected replies can be reported in the `Spi` panic messages.
#[derive(Debug)]
enum SpiReply {
    EstimateIterations(Result<usize, Error>),
    PrepareState(Result<(), Error>),
    /// Acknowledges `SpiRequest::Select`.
    Select,
    Run(Result<u64, Error>),
    Measure(Result<u64, Error>),
}
289
/// State which holds the information about list of benchmarks and which one is selected.
/// Used in FFI API (`tango_*` functions).
struct State {
    /// Benchmarks registered by `__tango_init()`.
    benchmarks: Vec<Benchmark>,
    /// `(index, sampler)` of the selected benchmark. The sampler is `None` until
    /// `tango_prepare_state()` is called, and is reset when a different benchmark is selected.
    selected_function: Option<(usize, Option<Box<dyn ErasedSampler>>)>,
    /// Message of the last panic caught at the FFI boundary, reported by `tango_get_last_error()`.
    last_error: Option<String>,
}
297
298impl State {
299    fn selected(&self) -> &Benchmark {
300        &self.benchmarks[self.ensure_selected()]
301    }
302
303    fn ensure_selected(&self) -> usize {
304        self.selected_function
305            .as_ref()
306            .map(|(idx, _)| *idx)
307            .expect("No function was selected. Call tango_select() first")
308    }
309
310    fn selected_state_mut(&mut self) -> Option<&mut Box<dyn ErasedSampler>> {
311        self.selected_function
312            .as_mut()
313            .and_then(|(_, state)| state.as_mut())
314    }
315}
316
/// Global state of the benchmarking library
///
/// Set by `__tango_init()` and cleared by `tango_free()`.
static STATE: StateWrapper = StateWrapper(UnsafeCell::new(None));

/// Lets the unsynchronized `State` cell live in a `static`.
struct StateWrapper(UnsafeCell<Option<State>>);
// SAFETY(review): `StateWrapper` performs no synchronization. Declaring it `Sync` is only sound
// if `STATE` is never accessed from several threads at the same time. This is presumably
// guaranteed by callers issuing one FFI call at a time. TODO confirm: asynchronous `Spi` mode
// may drive the self-vtable from a worker thread.
unsafe impl Sync for StateWrapper {}

impl StateWrapper {
    /// Returns a shared reference to the state, or `None` if it is not initialized.
    ///
    /// # Safety
    /// No mutable reference obtained from `as_mut()` may be used while the returned reference
    /// is alive, and the state must not be replaced (e.g. by `tango_free()`) meanwhile.
    unsafe fn as_ref(&self) -> Option<&State> {
        (*self.0.get()).as_ref()
    }

    /// Returns a mutable reference to the state, or `None` if it is not initialized.
    ///
    /// # Safety
    /// The caller must ensure no other reference into the state is alive while the returned
    /// reference is used.
    // Handing out `&mut` from `&self` is the point of this wrapper: `STATE` is a `static` and is
    // only reachable through a shared reference.
    #[allow(clippy::mut_from_ref)]
    unsafe fn as_mut(&self) -> Option<&mut State> {
        (*self.0.get()).as_mut()
    }
}
333
334/// `tango_init()` implementation
335///
336/// This function is not exported from the library, but is used by the `tango_init()` functions
337/// generated by the `tango_benchmark!()` macro.
338pub fn __tango_init(benchmarks: Vec<Benchmark>) {
339    if unsafe { STATE.as_ref().is_none() } {
340        let state = Some(State {
341            benchmarks,
342            selected_function: None,
343            last_error: None,
344        });
345        unsafe { *STATE.0.get() = state }
346    }
347}
348
349/// Defines all the foundation types and exported symbols for the FFI communication API between two
350/// executables.
351///
352/// Tango execution model implies simultaneous execution of the code from two binaries. To achieve that
353/// Tango benchmark is compiled in a way that executable is also a shared library (.dll, .so, .dylib). This
354/// way two executables can coexist in the single process at the same time.
355pub mod ffi {
356    use super::*;
357    use std::{
358        ffi::{c_int, c_uint, c_ulonglong},
359        os::raw::c_char,
360        panic::{catch_unwind, UnwindSafe},
361        ptr::null,
362        sync::Mutex,
363    };
364
365    /// Signature types of all FFI API functions
366    pub type InitFn = unsafe extern "C" fn();
367    type CountFn = unsafe extern "C" fn() -> c_ulonglong;
368    type GetTestNameFn = unsafe extern "C" fn(*mut *const c_char, *mut c_ulonglong);
369    type SelectFn = unsafe extern "C" fn(c_ulonglong);
370    type RunFn = unsafe extern "C" fn(c_ulonglong, *mut c_ulonglong) -> c_int;
371    type EstimateIterationsFn = unsafe extern "C" fn(c_uint) -> c_ulonglong;
372    type PrepareStateFn = unsafe extern "C" fn(c_ulonglong) -> c_int;
373    type GetLastErrorFn = unsafe extern "C" fn(*mut *const c_char, *mut c_ulonglong) -> c_int;
374    type ApiVersionFn = unsafe extern "C" fn() -> c_uint;
375    type FreeFn = unsafe extern "C" fn();
376
377    pub(super) static SELF_VTABLE: Mutex<Option<VTable>> = Mutex::new(Some(VTable::for_self()));
378    pub const TANGO_API_VERSION: u32 = 3;
379
380    #[no_mangle]
381    unsafe extern "C" fn tango_count() -> c_ulonglong {
382        STATE
383            .as_ref()
384            .map(|s| s.benchmarks.len() as c_ulonglong)
385            .unwrap_or(0)
386    }
387
388    #[no_mangle]
389    unsafe extern "C" fn tango_api_version() -> c_uint {
390        TANGO_API_VERSION
391    }
392
393    #[no_mangle]
394    unsafe extern "C" fn tango_select(idx: c_ulonglong) {
395        if let Some(s) = STATE.as_mut() {
396            let idx = idx as usize;
397            assert!(idx < s.benchmarks.len());
398
399            s.selected_function = Some(match s.selected_function.take() {
400                // Preserving state if the same function is selected
401                Some((selected, state)) if selected == idx => (selected, state),
402                _ => (idx, None),
403            });
404        }
405    }
406
407    #[no_mangle]
408    unsafe extern "C" fn tango_get_test_name(name: *mut *const c_char, length: *mut c_ulonglong) {
409        if let Some(s) = STATE.as_ref() {
410            let n = s.selected().name();
411            *name = n.as_ptr() as _;
412            *length = n.len() as _;
413        } else {
414            *name = null();
415            *length = 0;
416        }
417    }
418
419    /// Returns C-string to a description of last error (if any)
420    ///
421    /// Returns: 0 if last error was returned, -1 otherwise
422    #[no_mangle]
423    unsafe extern "C" fn tango_get_last_error(
424        name: *mut *const c_char,
425        length: *mut c_ulonglong,
426    ) -> c_int {
427        if let Some(err) = STATE.as_ref().and_then(|s| s.last_error.as_ref()) {
428            *name = err.as_ptr() as _;
429            *length = err.len() as _;
430            0
431        } else {
432            *name = null();
433            *length = 0;
434            -1
435        }
436    }
437
438    #[no_mangle]
439    unsafe extern "C" fn tango_run(iterations: c_ulonglong, time: *mut c_ulonglong) -> c_int {
440        let measurement = catch(|| {
441            STATE.as_mut().map(|s| {
442                s.selected_state_mut()
443                    .expect("no tango_prepare_state() was called")
444                    .measure(iterations as usize)
445            })
446        })
447        .flatten();
448        *time = measurement.unwrap_or(0);
449        if measurement.is_some() {
450            0
451        } else {
452            -1
453        }
454    }
455
456    /// Returns an estimation of number of iterations needed to spent given amount of time
457    ///
458    /// Returns: the number of iterations (minimum of 1) or 0 if error happens during building the estimate.
459    #[no_mangle]
460    unsafe extern "C" fn tango_estimate_iterations(time_ms: c_uint) -> c_ulonglong {
461        catch(|| {
462            if let Some(s) = STATE.as_mut() {
463                s.selected_state_mut()
464                    .expect("no tango_prepare_state() was called")
465                    .as_mut()
466                    .estimate_iterations(time_ms)
467                    .max(1) as c_ulonglong
468            } else {
469                0
470            }
471        })
472        .unwrap_or(0)
473    }
474
475    /// Prepares benchmark internal state
476    ///
477    /// Should be called once benchmark was selected ([`tango_select`]) to initialize all needed state.
478    ///
479    /// Returns: 0 if success, otherwise preparing state was failed
480    #[no_mangle]
481    unsafe extern "C" fn tango_prepare_state(seed: c_ulonglong) -> c_int {
482        catch(|| {
483            if let Some(s) = STATE.as_mut() {
484                let Some((idx, state)) = &mut s.selected_function else {
485                    panic!("No tango_select() was called")
486                };
487                *state = Some(s.benchmarks[*idx].prepare_state(seed));
488            }
489            0
490        })
491        .unwrap_or(-1)
492    }
493
494    #[no_mangle]
495    unsafe extern "C" fn tango_free() {
496        unsafe { *STATE.0.get() = None }
497    }
498
499    /// Since unwinding cannot cross FFI boundaries, we catch all panics here
500    /// and print their message for debugging, while returning None to indicate failure.
501    fn catch<T>(f: impl FnOnce() -> T + UnwindSafe) -> Option<T> {
502        match catch_unwind(f) {
503            Ok(r) => Some(r),
504            Err(e) => {
505                // Here we're assuming state is already initialized, because f was running some operations on it
506                let state = unsafe { STATE.as_mut().unwrap() };
507                if let Some(msg) = e.downcast_ref::<&str>() {
508                    state.last_error = Some(msg.to_string());
509                } else {
510                    state.last_error = e.downcast_ref::<String>().cloned();
511                }
512                None
513            }
514        }
515    }
516
517    pub(super) struct VTable {
518        /// SAFETY: using plain function pointers instead of [`Symbol`] here to generalize over the case
519        /// when we have to have `VTable` for functions defined in our own address space
520        /// (so called [self VTable](Self::for_self()))
521        ///
522        /// This is is sound because:
523        ///  (1) this struct is private and field can not be accessed outside
524        ///  (2) rust has drop order guarantee (fields are dropped in declaration order)
525        init_fn: InitFn,
526        count_fn: CountFn,
527        select_fn: SelectFn,
528        get_test_name_fn: GetTestNameFn,
529        get_last_error_fn: GetLastErrorFn,
530        run_fn: RunFn,
531        estimate_iterations_fn: EstimateIterationsFn,
532        prepare_state_fn: PrepareStateFn,
533        free_fn: FreeFn,
534
535        /// SAFETY: This field should be last because it should be dropped last
536        _library: Option<Box<Library>>,
537    }
538
539    impl VTable {
540        pub(super) fn new(lib: Library) -> Result<Self, Error> {
541            let api_version_fn = *lookup_symbol::<ApiVersionFn>(&lib, "tango_api_version")?;
542            let api_version = unsafe { (api_version_fn)() };
543            if api_version != TANGO_API_VERSION {
544                return Err(Error::IncorrectVersion(api_version));
545            }
546            Ok(Self {
547                init_fn: *lookup_symbol(&lib, "tango_init")?,
548                count_fn: *lookup_symbol(&lib, "tango_count")?,
549                select_fn: *lookup_symbol(&lib, "tango_select")?,
550                get_test_name_fn: *lookup_symbol(&lib, "tango_get_test_name")?,
551                run_fn: *lookup_symbol(&lib, "tango_run")?,
552                estimate_iterations_fn: *lookup_symbol(&lib, "tango_estimate_iterations")?,
553                prepare_state_fn: *lookup_symbol(&lib, "tango_prepare_state")?,
554                get_last_error_fn: *lookup_symbol(&lib, "tango_get_last_error")?,
555                free_fn: *lookup_symbol(&lib, "tango_free")?,
556                // SAFETY: symbols are valid as long as _library member is alive
557                _library: Some(Box::new(lib)),
558            })
559        }
560
561        const fn for_self() -> Self {
562            unsafe extern "C" fn no_init() {
563                // In executable mode `tango_init` is already called by the main function
564            }
565            Self {
566                init_fn: no_init,
567                count_fn: ffi::tango_count,
568                select_fn: ffi::tango_select,
569                get_test_name_fn: ffi::tango_get_test_name,
570                run_fn: ffi::tango_run,
571                estimate_iterations_fn: ffi::tango_estimate_iterations,
572                prepare_state_fn: ffi::tango_prepare_state,
573                get_last_error_fn: ffi::tango_get_last_error,
574                free_fn: ffi::tango_free,
575                _library: None,
576            }
577        }
578
579        pub(super) fn init(&self) {
580            unsafe { (self.init_fn)() }
581        }
582
583        pub(super) fn count(&self) -> c_ulonglong {
584            unsafe { (self.count_fn)() }
585        }
586
587        pub(super) fn select(&self, func_idx: c_ulonglong) {
588            unsafe { (self.select_fn)(func_idx) }
589        }
590
591        pub(super) fn get_test_name(&self, ptr: *mut *const c_char, len: *mut c_ulonglong) {
592            unsafe { (self.get_test_name_fn)(ptr, len) }
593        }
594
595        pub(super) fn run(&self, iterations: c_ulonglong) -> Result<u64, Error> {
596            let mut measurement = 0u64;
597            match unsafe { (self.run_fn)(iterations, &mut measurement) } {
598                0 => Ok(measurement),
599                _ => Err(self.last_error()?),
600            }
601        }
602
603        pub(super) fn estimate_iterations(&self, time_ms: c_uint) -> Result<usize, Error> {
604            match unsafe { (self.estimate_iterations_fn)(time_ms) } {
605                0 => Err(self.last_error()?),
606                iters => Ok(iters as usize),
607            }
608        }
609
610        pub(super) fn prepare_state(&self, seed: c_ulonglong) -> Result<(), Error> {
611            match unsafe { (self.prepare_state_fn)(seed) } {
612                0 => Ok(()),
613                _ => Err(self.last_error()?),
614            }
615        }
616
617        fn last_error(&self) -> Result<Error, Error> {
618            let mut length = 0;
619            let mut name = null();
620            if unsafe { (self.get_last_error_fn)(&mut name, &mut length) } != 0 {
621                Err(Error::UnknownFFIError)
622            } else {
623                let name = unsafe { slice::from_raw_parts(name as *const u8, length as usize) };
624                str::from_utf8(name)
625                    .map_err(Error::InvalidFFIString)
626                    .map(str::to_string)
627                    .map(Error::FFIError)
628            }
629        }
630    }
631
632    impl Drop for VTable {
633        fn drop(&mut self) {
634            unsafe { (self.free_fn)() }
635        }
636    }
637
638    fn lookup_symbol<'l, T>(library: &'l Library, name: &str) -> Result<Symbol<'l, T>, Error> {
639        unsafe {
640            library
641                .get(name.as_bytes())
642                .map_err(|e| Error::UnableToLoadSymbol(name.to_string(), e))
643        }
644    }
645}