wraith/km/sync.rs

//! Kernel-mode synchronization primitives

use core::cell::UnsafeCell;
use core::ffi::c_void;
use core::ops::{Deref, DerefMut};

/// kernel spinlock (KSPIN_LOCK)
#[repr(transparent)]
pub struct SpinLockRaw(usize);

impl SpinLockRaw {
    /// create a zeroed spinlock; call KeInitializeSpinLock before first use
    pub const fn new() -> Self {
        Self(0)
    }
}

/// RAII spinlock guard; releases the lock and restores IRQL on drop
pub struct SpinLockGuard<'a, T> {
    lock: &'a SpinLock<T>,
    old_irql: u8,
}

impl<'a, T> Drop for SpinLockGuard<'a, T> {
    fn drop(&mut self) {
        // SAFETY: we hold the lock and old_irql is the IRQL recorded at acquire
        unsafe {
            KeReleaseSpinLock(self.lock.raw.get().cast(), self.old_irql);
        }
    }
}

impl<'a, T> Deref for SpinLockGuard<'a, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // SAFETY: we hold the lock
        unsafe { &*self.lock.data.get() }
    }
}

impl<'a, T> DerefMut for SpinLockGuard<'a, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: we hold the lock exclusively
        unsafe { &mut *self.lock.data.get() }
    }
}

/// spinlock protecting data T
///
/// the raw KSPIN_LOCK lives in an UnsafeCell because the kernel mutates it
/// through the pointers passed to KeAcquire/ReleaseSpinLock while we only
/// hold a shared reference
pub struct SpinLock<T> {
    raw: UnsafeCell<SpinLockRaw>,
    data: UnsafeCell<T>,
}

// SAFETY: SpinLock provides exclusive access through lock()
unsafe impl<T: Send> Send for SpinLock<T> {}
unsafe impl<T: Send> Sync for SpinLock<T> {}

impl<T> SpinLock<T> {
    /// create new spinlock with data
    pub const fn new(data: T) -> Self {
        Self {
            raw: UnsafeCell::new(SpinLockRaw::new()),
            data: UnsafeCell::new(data),
        }
    }

    /// initialize the spinlock (must be called before first use)
    pub fn init(&mut self) {
        // SAFETY: valid spinlock pointer, exclusive access during init
        unsafe {
            KeInitializeSpinLock(self.raw.get().cast());
        }
    }

    /// acquire spinlock (raises IRQL to DISPATCH_LEVEL)
    pub fn lock(&self) -> SpinLockGuard<'_, T> {
        let mut old_irql: u8 = 0;
        // SAFETY: valid spinlock pointer
        unsafe {
            KeAcquireSpinLock(self.raw.get().cast(), &mut old_irql);
        }
        SpinLockGuard {
            lock: self,
            old_irql,
        }
    }

    /// try to acquire spinlock without spinning
    ///
    /// caller must already be running at DISPATCH_LEVEL:
    /// KeTryToAcquireSpinLockAtDpcLevel does not raise IRQL, so the guard
    /// records the current IRQL and releasing it leaves the IRQL unchanged
    pub fn try_lock(&self) -> Option<SpinLockGuard<'_, T>> {
        // SAFETY: valid spinlock pointer; caller is at DISPATCH_LEVEL
        let acquired = unsafe { KeTryToAcquireSpinLockAtDpcLevel(self.raw.get().cast()) };

        if acquired != 0 {
            // SAFETY: querying the current IRQL is always valid
            let old_irql = unsafe { KeGetCurrentIrql() };
            Some(SpinLockGuard {
                lock: self,
                old_irql,
            })
        } else {
            None
        }
    }

    /// get mutable reference without locking
    ///
    /// # Safety
    /// caller must guarantee exclusive access for the lifetime of the
    /// returned reference (no lock() holders, no other references)
    pub unsafe fn get_unchecked(&self) -> &mut T {
        unsafe { &mut *self.data.get() }
    }
}
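
// Usage sketch for SpinLock (illustrative; `stats` and its call site are
// hypothetical). The guard restores the saved IRQL when it goes out of scope.
#[allow(dead_code)]
fn spinlock_usage(stats: &SpinLock<u64>) {
    let mut guard = stats.lock(); // raises IRQL to DISPATCH_LEVEL
    *guard += 1;                  // exclusive access via DerefMut
} // guard dropped here: KeReleaseSpinLock restores old_irql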

/// fast mutex (FAST_MUTEX)
///
/// layout mirrors the x64 FAST_MUTEX; the KEVENT is kept as an opaque,
/// 8-byte-aligned blob so field offsets match the kernel's (a flat
/// [u8; 24] would leave the event misaligned at offset 20)
#[repr(C)]
pub struct FastMutexRaw {
    count: i32,
    owner: *mut c_void,
    contention: u32,
    event: [u64; 3], // KEVENT: 0x18 bytes on x64
    old_irql: u32,
}

/// RAII fast mutex guard
pub struct FastMutexGuard<'a, T> {
    mutex: &'a FastMutex<T>,
}

impl<'a, T> Drop for FastMutexGuard<'a, T> {
    fn drop(&mut self) {
        // SAFETY: we hold the mutex
        unsafe {
            ExReleaseFastMutex(self.mutex.raw.get().cast());
        }
    }
}

impl<'a, T> Deref for FastMutexGuard<'a, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // SAFETY: we hold the mutex
        unsafe { &*self.mutex.data.get() }
    }
}

impl<'a, T> DerefMut for FastMutexGuard<'a, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: we hold the mutex exclusively
        unsafe { &mut *self.mutex.data.get() }
    }
}

/// fast mutex protecting data T (runs at APC_LEVEL; cannot be acquired at
/// DISPATCH_LEVEL or above)
pub struct FastMutex<T> {
    raw: UnsafeCell<FastMutexRaw>,
    data: UnsafeCell<T>,
}

// SAFETY: FastMutex provides exclusive access through lock()
unsafe impl<T: Send> Send for FastMutex<T> {}
unsafe impl<T: Send> Sync for FastMutex<T> {}

impl<T> FastMutex<T> {
    /// create new fast mutex (uninitialized; call init() before first use)
    pub const fn new(data: T) -> Self {
        Self {
            raw: UnsafeCell::new(FastMutexRaw {
                count: 1, // unowned; init() re-initializes via ExInitializeFastMutex
                owner: core::ptr::null_mut(),
                contention: 0,
                event: [0; 3],
                old_irql: 0,
            }),
            data: UnsafeCell::new(data),
        }
    }

    /// initialize the mutex (must be called before first use)
    pub fn init(&mut self) {
        // SAFETY: valid mutex pointer, exclusive access during init
        unsafe {
            ExInitializeFastMutex(self.raw.get().cast());
        }
    }

    /// acquire mutex (raises IRQL to APC_LEVEL)
    pub fn lock(&self) -> FastMutexGuard<'_, T> {
        // SAFETY: valid mutex pointer
        unsafe {
            ExAcquireFastMutex(self.raw.get().cast());
        }
        FastMutexGuard { mutex: self }
    }

    /// try to acquire mutex without blocking
    pub fn try_lock(&self) -> Option<FastMutexGuard<'_, T>> {
        // SAFETY: valid mutex pointer
        let acquired = unsafe { ExTryToAcquireFastMutex(self.raw.get().cast()) };

        if acquired != 0 {
            Some(FastMutexGuard { mutex: self })
        } else {
            None
        }
    }
}
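
// Usage sketch for FastMutex (illustrative; `config` is hypothetical).
// Suitable for paths that may block; callers must be at IRQL <= APC_LEVEL.
#[allow(dead_code)]
fn fast_mutex_usage(config: &FastMutex<[u8; 16]>) {
    if let Some(mut guard) = config.try_lock() {
        guard[0] = 1; // exclusive access; ExReleaseFastMutex runs on drop
    }
}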

/// pairs a value with the lock object that protects it
pub struct Guarded<T, L> {
    data: T,
    lock: L,
}

impl<T, L> Guarded<T, L> {
    /// create new guarded data
    pub const fn new(data: T, lock: L) -> Self {
        Self { data, lock }
    }

    /// split into mutable references to the data and its lock;
    /// safe because &mut self already proves exclusive access
    pub fn parts_mut(&mut self) -> (&mut T, &mut L) {
        (&mut self.data, &mut self.lock)
    }
}
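
// Usage sketch for Guarded (illustrative): keep a value and the raw lock that
// guards it side by side, e.g. inside a larger device context.
#[allow(dead_code)]
fn guarded_usage() {
    let mut entry = Guarded::new(0u32, SpinLockRaw::new());
    let (data, _lock) = entry.parts_mut();
    *data = 42; // direct access is fine while we hold &mut
}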

/// push lock (EX_PUSH_LOCK) - lightweight reader/writer lock
#[repr(transparent)]
pub struct PushLockRaw(usize);

impl PushLockRaw {
    /// create a zeroed push lock; call init() before first use
    pub const fn new() -> Self {
        Self(0)
    }
}

/// push lock wrapper
pub struct PushLock<T> {
    raw: UnsafeCell<PushLockRaw>,
    data: UnsafeCell<T>,
}

// SAFETY: PushLock provides synchronized access (exclusive writer, shared readers)
unsafe impl<T: Send> Send for PushLock<T> {}
unsafe impl<T: Send + Sync> Sync for PushLock<T> {}

impl<T> PushLock<T> {
    /// create new push lock
    pub const fn new(data: T) -> Self {
        Self {
            raw: UnsafeCell::new(PushLockRaw::new()),
            data: UnsafeCell::new(data),
        }
    }

    /// initialize push lock
    pub fn init(&mut self) {
        // SAFETY: valid pointer, exclusive access during init
        unsafe {
            ExInitializePushLock(self.raw.get().cast());
        }
    }

    /// acquire exclusive (write) lock
    pub fn lock_exclusive(&self) -> PushLockExclusiveGuard<'_, T> {
        // SAFETY: valid pointer; enter a critical region first so normal
        // kernel APCs cannot suspend us while we hold the push lock
        unsafe {
            KeEnterCriticalRegion();
            ExAcquirePushLockExclusive(self.raw.get().cast());
        }
        PushLockExclusiveGuard { lock: self }
    }

    /// acquire shared (read) lock
    pub fn lock_shared(&self) -> PushLockSharedGuard<'_, T> {
        // SAFETY: valid pointer; see lock_exclusive for the critical region
        unsafe {
            KeEnterCriticalRegion();
            ExAcquirePushLockShared(self.raw.get().cast());
        }
        PushLockSharedGuard { lock: self }
    }
}

/// exclusive guard for push lock
pub struct PushLockExclusiveGuard<'a, T> {
    lock: &'a PushLock<T>,
}

impl<'a, T> Drop for PushLockExclusiveGuard<'a, T> {
    fn drop(&mut self) {
        // SAFETY: we hold the lock and entered a critical region at acquire
        unsafe {
            ExReleasePushLockExclusive(self.lock.raw.get().cast());
            KeLeaveCriticalRegion();
        }
    }
}

impl<'a, T> Deref for PushLockExclusiveGuard<'a, T> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        // SAFETY: we hold the lock
        unsafe { &*self.lock.data.get() }
    }
}

impl<'a, T> DerefMut for PushLockExclusiveGuard<'a, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: we hold the lock exclusively
        unsafe { &mut *self.lock.data.get() }
    }
}

/// shared guard for push lock
pub struct PushLockSharedGuard<'a, T> {
    lock: &'a PushLock<T>,
}

impl<'a, T> Drop for PushLockSharedGuard<'a, T> {
    fn drop(&mut self) {
        // SAFETY: we hold the lock and entered a critical region at acquire
        unsafe {
            ExReleasePushLockShared(self.lock.raw.get().cast());
            KeLeaveCriticalRegion();
        }
    }
}

impl<'a, T> Deref for PushLockSharedGuard<'a, T> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        // SAFETY: we hold the (shared) lock
        unsafe { &*self.lock.data.get() }
    }
}
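
// Usage sketch for PushLock (illustrative; `table` is hypothetical).
// Readers take the shared lock; the single writer takes the exclusive lock.
#[allow(dead_code)]
fn push_lock_usage(table: &PushLock<u32>) {
    {
        let value = table.lock_shared(); // concurrent readers allowed
        let _ = *value;
    } // shared lock released here
    let mut value = table.lock_exclusive(); // sole writer
    *value += 1;
}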

// kernel synchronization functions
//
// NOTE: some of these names are macros rather than exports in the WDK
// headers (e.g. KeAcquireSpinLock and the push-lock operations); when
// linking directly against ntoskrnl the underlying exports may be needed
// instead (KeAcquireSpinLockRaiseToDpc, ExfAcquirePushLockExclusive, ...)
extern "system" {
    fn KeGetCurrentIrql() -> u8;
    fn KeEnterCriticalRegion();
    fn KeLeaveCriticalRegion();

    fn KeInitializeSpinLock(SpinLock: *mut c_void);
    fn KeAcquireSpinLock(SpinLock: *mut c_void, OldIrql: *mut u8);
    fn KeReleaseSpinLock(SpinLock: *mut c_void, NewIrql: u8);
    fn KeTryToAcquireSpinLockAtDpcLevel(SpinLock: *mut c_void) -> u8; // BOOLEAN

    fn ExInitializeFastMutex(FastMutex: *mut c_void);
    fn ExAcquireFastMutex(FastMutex: *mut c_void);
    fn ExReleaseFastMutex(FastMutex: *mut c_void);
    fn ExTryToAcquireFastMutex(FastMutex: *mut c_void) -> u8; // BOOLEAN

    fn ExInitializePushLock(PushLock: *mut c_void);
    fn ExAcquirePushLockExclusive(PushLock: *mut c_void);
    fn ExReleasePushLockExclusive(PushLock: *mut c_void);
    fn ExAcquirePushLockShared(PushLock: *mut c_void);
    fn ExReleasePushLockShared(PushLock: *mut c_void);
}
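
// Initialization sketch (illustrative; `Context` and `init_context` are
// hypothetical): each primitive must be init()'d once, with exclusive
// access, before it is shared (e.g. from DriverEntry at PASSIVE_LEVEL).
#[allow(dead_code)]
struct Context {
    stats: SpinLock<u64>,
    state: PushLock<u32>,
}

#[allow(dead_code)]
fn init_context(ctx: &mut Context) {
    ctx.stats.init(); // KeInitializeSpinLock
    ctx.state.init(); // ExInitializePushLock
}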