1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
//
// Copyright (c) 2023 ZettaScale Technology
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License 2.0 which is available at
// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
//
// Contributors:
//   Pierre Avital, <pierre.avital@me.com>
//

use core::{
    ops::Deref,
    sync::atomic::{AtomicU8, Ordering},
};
/// Used in `#[stabby::import(canaries)]`
///
/// Wraps an imported item together with a canary function; the canary is invoked by the
/// first dereference of the wrapper (see its `Deref` implementation).
#[crate::stabby]
pub struct CanariedImport<F> {
    /// The wrapped import, handed out by `Deref`.
    result: F,
    /// `0` until the first dereference swaps it to `1` and runs `canary`.
    checked: AtomicU8,
    /// Invoked by whichever dereference first observes `checked == 0`.
    canary: extern "C" fn(),
}
// SAFETY: the remaining fields (`AtomicU8`, `extern "C" fn()`) are `Send`; moving the wrapper
// moves its `F`, so `F` itself must be `Send`.
unsafe impl<F: Send> Send for CanariedImport<F> {}
// SAFETY: shared access only hands out `&F` (through `Deref`) and touches `checked` atomically,
// so sharing the wrapper across threads is sound exactly when `F: Sync`.
unsafe impl<F: Sync> Sync for CanariedImport<F> {}
impl<F> CanariedImport<F> {
    /// Used in `#[stabby::import(canaries)]`
    ///
    /// Wraps `source` so that `canary_caller` runs on the first dereference; the wrapper
    /// starts out in the "not yet checked" state.
    pub const fn new(source: F, canary_caller: extern "C" fn()) -> Self {
        let unchecked = AtomicU8::new(0);
        Self {
            canary: canary_caller,
            checked: unchecked,
            result: source,
        }
    }
}
impl<F> Deref for CanariedImport<F> {
    type Target = F;
    /// Gives access to the imported item, first running the canary if no earlier
    /// dereference has done so.
    fn deref(&self) -> &F {
        // Only the caller that flips the flag away from 0 runs the canary. The flag is set
        // before the canary runs, so concurrent callers do not wait for it to finish.
        let already_checked = self.checked.swap(1, Ordering::Relaxed) != 0;
        if !already_checked {
            let run_canary = self.canary;
            run_canary();
        }
        &self.result
    }
}

/// Used in `#[stabby::import]`
///
/// Lazily validated import: on first access, `checker` is called with `local_report` and,
/// if it returns `Some`, the value is cached in `result` (see `CheckedImport::as_ref`).
#[crate::stabby]
pub struct CheckedImport<F> {
    /// Holds the checker's output once validation succeeded; uninitialized otherwise.
    result: core::cell::UnsafeCell<core::mem::MaybeUninit<F>>,
    /// Validation state: one of `UNCHECKED`, `VALIDATED`, `INVALIDATED` or `LOCKED`.
    checked: AtomicU8,
    /// Yields the import when it accepts the given report; `None` is treated as a mismatch.
    #[allow(improper_ctypes_definitions)]
    checker: unsafe extern "C" fn(&crate::report::TypeReport) -> Option<F>,
    /// Fetches the loaded side's report, used to build a `ReportMismatch` on failure.
    get_report: unsafe extern "C" fn() -> &'static crate::report::TypeReport,
    /// The report on the loader's side, passed to `checker`.
    local_report: &'static crate::report::TypeReport,
}
// SAFETY: besides `result`, the fields are function pointers, an `AtomicU8` and `&'static`
// report references (presumably read-only — TODO confirm `TypeReport` has no interior
// mutability). Moving the import moves a possibly-initialized `F`, hence `F: Send`.
unsafe impl<F: Send> Send for CheckedImport<F> {}
// SAFETY: shared access hands `&F` to any thread (`F: Sync`) and may store an `F` produced on
// one thread that is later dropped on another (`F: Send`) — the same bounds `OnceLock` uses.
unsafe impl<F: Send + Sync> Sync for CheckedImport<F> {}

/// When the reports of the loader and the loadee mismatch, both reports are exposed to allow
/// debugging the issue.
#[crate::stabby]
#[derive(Debug, Clone, Copy)]
pub struct ReportMismatch {
    /// The report on loader side.
    pub local: &'static crate::report::TypeReport,
    /// The report on loadee side.
    pub loaded: &'static crate::report::TypeReport,
}
impl core::fmt::Display for ReportMismatch {
    /// Displays the mismatch through its `Debug` representation; formatter flags such as
    /// `{:#}` are forwarded, so `{:#}` pretty-prints both reports.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        <Self as core::fmt::Debug>::fmt(self, f)
    }
}
// `ReportMismatch` already provides `Debug` (derived) and `Display`, so with the `std` feature
// it can be used as a standard error type (e.g. boxed as `dyn std::error::Error`).
#[cfg(feature = "std")]
impl std::error::Error for ReportMismatch {}

// States of `CheckedImport::checked`.
/// Initial state: no check has been recorded yet.
const UNCHECKED: u8 = 0;
/// The checker returned `Some` and the value has been written to `result`.
const VALIDATED: u8 = 1;
/// The checker returned `None`; accesses report a `ReportMismatch`.
const INVALIDATED: u8 = 2;
/// A thread is writing `result`; other callers spin until the state changes.
const LOCKED: u8 = 3;
impl<F> CheckedImport<F> {
    /// Used by `#[stabby::import]` proc-macro
    ///
    /// Builds an unchecked import: validation is deferred until the first call to
    /// [`Self::as_ref`] (or dereference).
    #[allow(improper_ctypes_definitions)]
    pub const fn new(
        checker: unsafe extern "C" fn(&crate::report::TypeReport) -> Option<F>,
        get_report: unsafe extern "C" fn() -> &'static crate::report::TypeReport,
        local_report: &'static crate::report::TypeReport,
    ) -> Self {
        Self {
            checked: AtomicU8::new(UNCHECKED),
            checker,
            get_report,
            local_report,
            result: core::cell::UnsafeCell::new(core::mem::MaybeUninit::uninit()),
        }
    }
    /// Pairs the local report with the one exposed by the loaded side.
    fn error_report(&self) -> ReportMismatch {
        ReportMismatch {
            local: self.local_report,
            // SAFETY: assumes `get_report` (supplied to `new` by `#[stabby::import]`) is always
            // safe to call — TODO confirm against the proc-macro.
            loaded: unsafe { (self.get_report)() },
        }
    }
    /// Returns the validated import, running the check on first use.
    ///
    /// The checker is invoked with the local report; its outcome is cached in `checked` so
    /// later calls return immediately. Callers that find another thread publishing the result
    /// (`LOCKED`) spin until it is done.
    ///
    /// # Errors
    /// Returns a [`ReportMismatch`] if the local and loaded reports differ.
    pub fn as_ref(&self) -> Result<&F, ReportMismatch> {
        loop {
            // Acquire pairs with the SeqCst (release) store of VALIDATED below: a Relaxed load
            // would not guarantee that the write to `result` is visible to this thread.
            match self.checked.load(Ordering::Acquire) {
                // SAFETY: assumes `checker` (supplied by `#[stabby::import]`) is safe to call
                // with the local report — TODO confirm against the proc-macro.
                UNCHECKED => match unsafe { (self.checker)(self.local_report) } {
                    Some(result) => {
                        // Strong CAS: a spurious failure of the weak variant would discard this
                        // result and re-run the (FFI) checker for nothing.
                        if self
                            .checked
                            .compare_exchange(
                                UNCHECKED,
                                LOCKED,
                                Ordering::SeqCst,
                                Ordering::Relaxed,
                            )
                            .is_ok()
                        {
                            // SAFETY: winning UNCHECKED -> LOCKED makes this thread the only
                            // writer, and readers only touch `result` after seeing VALIDATED.
                            unsafe {
                                (*self.result.get()).write(result);
                            }
                            self.checked.store(VALIDATED, Ordering::SeqCst);
                            // SAFETY: `result` was initialized just above.
                            return Ok(unsafe { (*self.result.get()).assume_init_ref() });
                        }
                        // Lost the race: our value is dropped and the loop re-reads the state.
                    }
                    None => {
                        // Only leave UNCHECKED; never overwrite VALIDATED (which would make
                        // `Drop` leak the stored value) or a concurrent LOCKED.
                        let _ = self.checked.compare_exchange(
                            UNCHECKED,
                            INVALIDATED,
                            Ordering::Relaxed,
                            Ordering::Relaxed,
                        );
                        return Err(self.error_report());
                    }
                },
                // SAFETY: VALIDATED is stored only after `result` is initialized, and the
                // Acquire load above synchronizes with that store.
                VALIDATED => return Ok(unsafe { (*self.result.get()).assume_init_ref() }),
                INVALIDATED => return Err(self.error_report()),
                // LOCKED: another thread is writing `result`; wait for it to finish.
                _ => {}
            }
            core::hint::spin_loop();
        }
    }
}
impl<F> core::ops::Deref for CheckedImport<F> {
    type Target = F;
    fn deref(&self) -> &Self::Target {
        self.as_ref().unwrap()
    }
}
impl<F> Drop for CheckedImport<F> {
    /// Drops the cached import, which only exists once validation has succeeded.
    fn drop(&mut self) {
        // `&mut self` guarantees exclusive access, so the state can be read non-atomically.
        let state = *self.checked.get_mut();
        if state != VALIDATED {
            return;
        }
        // SAFETY: `result` is initialized whenever the state is VALIDATED.
        unsafe { self.result.get_mut().assume_init_drop() }
    }
}