stabby_abi/
checked_import.rs

1//
2// Copyright (c) 2023 ZettaScale Technology
3//
4// This program and the accompanying materials are made available under the
5// terms of the Eclipse Public License 2.0 which is available at
6// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
7// which is available at https://www.apache.org/licenses/LICENSE-2.0.
8//
9// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
10//
11// Contributors:
12//   Pierre Avital, <pierre.avital@me.com>
13//
14
15use core::{
16    ops::Deref,
17    sync::atomic::{AtomicU8, Ordering},
18};
/// Used in `#[stabby::import(canaries)]`
///
/// Wraps an imported value together with a canary function that is invoked
/// the first time the wrapper is dereferenced.
#[crate::stabby]
pub struct CanariedImport<F> {
    // The wrapped import; this is what `Deref` hands out.
    result: F,
    // 0 until the first dereference, 1 afterwards.
    checked: AtomicU8,
    // Invoked by the dereference that flips `checked` from 0 to 1.
    canary: extern "C" fn(),
}
// SAFETY: `CanariedImport`s always refer to `Send + Sync` things.
// NOTE(review): `F` is unbounded here, so the compiler does not enforce that
// invariant; it presumably relies on `#[stabby::import(canaries)]` being the
// only user of `CanariedImport::new` — confirm.
unsafe impl<F> Send for CanariedImport<F> {}
// SAFETY: `CanariedImport`s always refer to `Send + Sync` things.
// NOTE(review): same caveat as for `Send` — `F: Sync` is assumed, not checked.
unsafe impl<F> Sync for CanariedImport<F> {}
30impl<F> CanariedImport<F> {
31    /// Used in `#[stabby::import(canaries)]`
32    pub const fn new(source: F, canary_caller: extern "C" fn()) -> Self {
33        Self {
34            result: source,
35            checked: AtomicU8::new(0),
36            canary: canary_caller,
37        }
38    }
39}
40impl<F> Deref for CanariedImport<F> {
41    type Target = F;
42    fn deref(&self) -> &Self::Target {
43        if self.checked.swap(1, Ordering::Relaxed) == 0 {
44            (self.canary)()
45        }
46        &self.result
47    }
48}
49
/// Used in `#[stabby::import]`
///
/// Lazily validated import: the first access runs `checker`, and the outcome
/// is cached in `checked` (and `result` on success).
#[crate::stabby]
pub struct CheckedImport<F> {
    // Storage for the value returned by `checker`; only initialized once
    // `checked` holds `VALIDATED`.
    result: core::cell::UnsafeCell<core::mem::MaybeUninit<F>>,
    // Current state: one of `UNCHECKED`, `VALIDATED`, `INVALIDATED`, `LOCKED`.
    checked: AtomicU8,
    // Called with `local_report`; `Some(F)` means the import is accepted,
    // `None` that it is rejected.
    #[allow(improper_ctypes_definitions)]
    checker: unsafe extern "C" fn(&crate::report::TypeReport) -> Option<F>,
    // Called to fill the `loaded` side of a `ReportMismatch`.
    get_report: unsafe extern "C" fn() -> &'static crate::report::TypeReport,
    // Passed to `checker` and used as the `local` side of a `ReportMismatch`.
    local_report: &'static crate::report::TypeReport,
}
// SAFETY: `CheckedImport`s always refer to functions.
// NOTE(review): `F` is unbounded here, so "always a function" is a convention
// of the `#[stabby::import]` macro rather than something the type system
// enforces — confirm no other caller constructs this type.
unsafe impl<F> Send for CheckedImport<F> {}
// SAFETY: `CheckedImport`s always refer to functions.
// NOTE(review): shared access to the `UnsafeCell` field is coordinated through
// the `checked` state in `as_ref`; `F: Sync` itself is assumed, not checked.
unsafe impl<F> Sync for CheckedImport<F> {}
64
/// When reports mismatch between loader and loadee, both reports are exposed to allow debugging the issue.
#[crate::stabby]
#[derive(Debug, Clone, Copy)]
pub struct ReportMismatch {
    /// The report on loader side.
    pub local: &'static crate::report::TypeReport,
    /// The report on loadee side.
    pub loaded: &'static crate::report::TypeReport,
}
74impl core::fmt::Display for ReportMismatch {
75    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
76        core::fmt::Debug::fmt(&self, f)
77    }
78}
// `std::error::Error` lives in `std`, so this impl is only compiled when the
// `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for ReportMismatch {}
81
// States stored in `CheckedImport::checked`.
// Initial state: the checker has not produced a cached outcome yet.
const UNCHECKED: u8 = 0;
// The checker accepted the import and `result` has been written.
const VALIDATED: u8 = 1;
// The checker returned `None`; accesses report a `ReportMismatch`.
const INVALIDATED: u8 = 2;
// A thread won the race to write `result` and has not published it yet.
const LOCKED: u8 = 3;
86impl<F> CheckedImport<F> {
87    /// Used by `#[stabby::import]` proc-macro
88    #[allow(improper_ctypes_definitions)]
89    pub const fn new(
90        checker: unsafe extern "C" fn(&crate::report::TypeReport) -> Option<F>,
91        get_report: unsafe extern "C" fn() -> &'static crate::report::TypeReport,
92        local_report: &'static crate::report::TypeReport,
93    ) -> Self {
94        Self {
95            checked: AtomicU8::new(UNCHECKED),
96            checker,
97            get_report,
98            local_report,
99            result: core::cell::UnsafeCell::new(core::mem::MaybeUninit::uninit()),
100        }
101    }
102    fn error_report(&self) -> ReportMismatch {
103        ReportMismatch {
104            local: self.local_report,
105            loaded: unsafe { (self.get_report)() },
106        }
107    }
108    /// # Errors
109    /// Returns a [`ReportMismatch`] if the local and loaded reports differ.
110    pub fn as_ref(&self) -> Result<&F, ReportMismatch> {
111        loop {
112            match self.checked.load(Ordering::Relaxed) {
113                UNCHECKED => match unsafe { (self.checker)(self.local_report) } {
114                    Some(result) => {
115                        if self
116                            .checked
117                            .compare_exchange_weak(
118                                UNCHECKED,
119                                LOCKED,
120                                Ordering::SeqCst,
121                                Ordering::Relaxed,
122                            )
123                            .is_ok()
124                        {
125                            unsafe {
126                                (*self.result.get()).write(result);
127                                self.checked.store(VALIDATED, Ordering::SeqCst);
128                                return Ok((*self.result.get()).assume_init_ref());
129                            }
130                        }
131                    }
132                    None => {
133                        self.checked.store(INVALIDATED, Ordering::Relaxed);
134                        return Err(self.error_report());
135                    }
136                },
137                VALIDATED => return Ok(unsafe { (*self.result.get()).assume_init_ref() }),
138                INVALIDATED => return Err(self.error_report()),
139                _ => {}
140            }
141            core::hint::spin_loop();
142        }
143    }
144}
145impl<F> core::ops::Deref for CheckedImport<F> {
146    type Target = F;
147    fn deref(&self) -> &Self::Target {
148        self.as_ref().unwrap()
149    }
150}
151impl<F> Drop for CheckedImport<F> {
152    fn drop(&mut self) {
153        if self.checked.load(Ordering::Relaxed) == VALIDATED {
154            unsafe { self.result.get_mut().assume_init_drop() }
155        }
156    }
157}