stats_alloc_helper/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    alloc::GlobalAlloc,
5    sync::atomic::{AtomicUsize, Ordering},
6    thread::sleep,
7    time::Duration,
8};
9
10#[cfg(feature = "async_tokio")]
11use std::future::Future;
12
13use stats_alloc::{Stats, StatsAlloc};
14
15#[cfg(feature = "async_tokio")]
16use tokio::{runtime, task::spawn_blocking};
17
// Sentinel value: no thread holds the allocator lock; any thread may operate.
const STATE_UNLOCKED: usize = 0;
// Sentinel value: some thread is currently inside a single allocator operation.
// Any other value stored in `LockedAllocator::locked` is the pthread id of the
// thread that holds the measurement lock.
// NOTE(review): this assumes pthread_self() never returns 0 or 1, which would
// collide with the sentinels — TODO confirm on all supported platforms.
const STATE_IN_USE: usize = 1;

// Back-off interval used while spinning on the allocator state word.
const SLEEP: Duration = Duration::from_micros(50);
22
/// A [GlobalAlloc] wrapper that serializes every allocator operation and can be
/// locked to a single thread, so that allocation statistics collected while
/// locked are attributable to that thread alone.
pub struct LockedAllocator<T>
where
    T: GlobalAlloc,
{
    // State word: STATE_UNLOCKED, STATE_IN_USE, or the pthread id of the
    // thread that currently holds the measurement lock (see `lock`).
    locked: AtomicUsize,
    // The wrapped allocator that records the allocation statistics.
    inner: StatsAlloc<T>,
}
30
31impl<T> LockedAllocator<T>
32where
33    T: GlobalAlloc,
34{
35    pub const fn new(inner: StatsAlloc<T>) -> Self {
36        let locked = AtomicUsize::new(0);
37        Self { locked, inner }
38    }
39
40    /// An allocation free way to get the current thread id.
41    fn current_thread_id() -> usize {
42        unsafe { libc::pthread_self() as usize }
43    }
44
45    /// An allocation free serialization code that runs prior to any allocator operation.
46    /// Returns whether the current thread locked the allocator.
47    fn before_op(&self) -> bool {
48        let current_thread_id = Self::current_thread_id();
49
50        loop {
51            match self.locked.compare_exchange(
52                STATE_UNLOCKED,
53                STATE_IN_USE,
54                Ordering::SeqCst,
55                Ordering::SeqCst,
56            ) {
57                Ok(_) => break,
58                Err(existing) => {
59                    if existing == current_thread_id {
60                        return true;
61                    }
62                }
63            }
64
65            sleep(SLEEP);
66        }
67
68        false
69    }
70
71    /// An allocation free serialization code that runs after to any allocator operation.
72    fn after_op(&self) {
73        let current_thread_id = Self::current_thread_id();
74
75        loop {
76            match self.locked.compare_exchange(
77                STATE_IN_USE,
78                STATE_UNLOCKED,
79                Ordering::SeqCst,
80                Ordering::SeqCst,
81            ) {
82                Ok(_) => break,
83                Err(existing) => {
84                    if existing == current_thread_id {
85                        break;
86                    }
87                }
88            }
89
90            sleep(SLEEP);
91        }
92    }
93
94    /// A serialization wrapper to use for all allocator operations.
95    fn serialized<F, O>(&self, op: F) -> O
96    where
97        F: FnOnce(bool) -> O,
98    {
99        let locked = self.before_op();
100        let result = op(locked);
101        self.after_op();
102
103        result
104    }
105
106    /// Lock the allocator to only allow operations from the current thread.
107    fn lock(&self) {
108        let current_thread_id = Self::current_thread_id();
109
110        loop {
111            let r = self.locked.compare_exchange(
112                STATE_UNLOCKED,
113                current_thread_id,
114                Ordering::SeqCst,
115                Ordering::SeqCst,
116            );
117
118            if r.is_ok() {
119                break;
120            }
121
122            sleep(SLEEP);
123        }
124    }
125
126    /// Unlocks the allocator to allow operations from any thread.
127    fn unlock(&self) {
128        let expected = Self::current_thread_id();
129
130        assert_eq!(
131            expected,
132            self.locked
133                .compare_exchange(expected, STATE_UNLOCKED, Ordering::SeqCst, Ordering::SeqCst)
134                .unwrap()
135        );
136    }
137
138    /// Returns [Stats] from the wrapped [StatsAlloc].
139    fn stats(&self) -> Stats {
140        self.inner.stats()
141    }
142}
143
144unsafe impl<T> GlobalAlloc for LockedAllocator<T>
145where
146    T: GlobalAlloc,
147{
148    unsafe fn alloc(&self, layout: std::alloc::Layout) -> *mut u8 {
149        self.serialized(|is_locked| {
150            if is_locked {
151                probe::probe!(LockedAllocator, alloc_locked);
152            }
153
154            self.inner.alloc(layout)
155        })
156    }
157
158    unsafe fn dealloc(&self, ptr: *mut u8, layout: std::alloc::Layout) {
159        self.serialized(|is_locked| {
160            if is_locked {
161                probe::probe!(LockedAllocator, dealloc_locked);
162            }
163
164            self.inner.dealloc(ptr, layout)
165        })
166    }
167
168    unsafe fn realloc(&self, ptr: *mut u8, layout: std::alloc::Layout, new_size: usize) -> *mut u8 {
169        self.serialized(|is_locked| {
170            if is_locked {
171                probe::probe!(LockedAllocator, realloc_locked);
172            }
173
174            self.inner.realloc(ptr, layout, new_size)
175        })
176    }
177}
178
179/// Measure memory and return [Stats] object for the runtime of the passed closure.
180pub fn memory_measured<A, F>(alloc: &LockedAllocator<A>, f: F) -> Stats
181where
182    A: GlobalAlloc,
183    F: FnOnce(),
184{
185    alloc.lock();
186
187    let before = alloc.stats();
188
189    f();
190
191    let after = alloc.stats();
192
193    alloc.unlock();
194
195    after - before
196}
197
/// Measure memory and return [Stats] object for the runtime of the passed future.
#[cfg(feature = "async_tokio")]
pub async fn memory_measured_future<A, F>(alloc: &'static LockedAllocator<A>, f: F) -> Stats
where
    A: GlobalAlloc + Send + Sync,
    F: Future<Output = ()> + Send + 'static,
{
    // Tokio runtime cannot be created from a thread that is a part of a runtime already.
    spawn_blocking(|| {
        // A dedicated current-thread runtime keeps the measured future pinned
        // to a single OS thread, matching the thread-id based locking of the
        // allocator.
        let runtime = runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        runtime.block_on(async {
            // Restrict allocator operations to this thread for the measured
            // region, mirroring `memory_measured`.
            alloc.lock();

            let before = alloc.stats();

            f.await;

            let after = alloc.stats();

            alloc.unlock();

            after - before
        })
    })
    .await
    .unwrap()
}
229
#[cfg(test)]
mod tests {
    use std::{
        alloc::System,
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        thread::{sleep, spawn},
        time::Duration,
    };

    use super::*;

    // Install the instrumented allocator as the global allocator so every
    // allocation in the test binary flows through it.
    #[global_allocator]
    static GLOBAL: LockedAllocator<System> = LockedAllocator::new(StatsAlloc::system());

    #[test]
    fn it_works() {
        let mut length = 0;

        // Three heap allocations inside the closure: `to_owned`, `replace`,
        // and the trailing `to_owned` (which clones the String again).
        let stats = memory_measured(&GLOBAL, || {
            let s = "whoa".to_owned().replace("whoa", "wow").to_owned();

            length = s.len();
        });

        assert_eq!(length, 3);

        assert_eq!(
            stats,
            Stats {
                allocations: 3,
                deallocations: 3,
                reallocations: 0,
                bytes_allocated: 15,
                bytes_deallocated: 15,
                bytes_reallocated: 0
            }
        );

        // One allocation (the initial vec of five i32s) plus one reallocation
        // when the sixth element exceeds capacity.
        let stats = memory_measured(&GLOBAL, || {
            let mut v = vec![1, 2, 3, 4, 5];

            v.push(6);

            length = v.len();
        });

        assert_eq!(length, 6);

        assert_eq!(
            stats,
            Stats {
                allocations: 1,
                deallocations: 1,
                reallocations: 1,
                bytes_allocated: 40,
                bytes_deallocated: 40,
                bytes_reallocated: 20
            }
        );
    }

    #[test]
    fn test_parallel() {
        let stop = Arc::new(AtomicBool::new(false));

        // Background thread that keeps allocating (vec growth) while the
        // measurement runs; its allocations must not leak into the stats.
        {
            let stop = stop.clone();
            spawn(move || {
                let mut vec = vec![];
                while !stop.load(Ordering::Relaxed) {
                    vec.push(1);
                    sleep(Duration::from_micros(1));
                }
            });
        }

        let mut length = 0;
        let step = Duration::from_millis(150);

        // Same three-allocation workload as `it_works`, with a sleep to give
        // the background thread ample time to attempt allocations meanwhile.
        let stats = memory_measured(&GLOBAL, || {
            let s = "whoa".to_owned().replace("whoa", "wow").to_owned();

            sleep(step);

            length = s.len();
        });

        stop.store(true, Ordering::Relaxed);

        assert_eq!(length, 3);

        // The stats match the single-threaded run exactly, proving the lock
        // excluded the background thread's allocations.
        assert_eq!(
            stats,
            Stats {
                allocations: 3,
                deallocations: 3,
                reallocations: 0,
                bytes_allocated: 15,
                bytes_deallocated: 15,
                bytes_reallocated: 0
            }
        );
    }

    #[tokio::test]
    #[cfg(feature = "async_tokio")]
    async fn test_tokio() {
        // A single vec of four i32s: one 16-byte allocation, freed on drop.
        let stats = memory_measured_future(&GLOBAL, async {
            let _ = vec![1, 2, 3, 4];
        })
        .await;

        assert_eq!(
            stats,
            Stats {
                allocations: 1,
                deallocations: 1,
                reallocations: 0,
                bytes_allocated: 16,
                bytes_deallocated: 16,
                bytes_reallocated: 0
            }
        );
    }
}