1#![doc = include_str!("../README.md")]
2
3use std::{
4 alloc::GlobalAlloc,
5 sync::atomic::{AtomicUsize, Ordering},
6 thread::sleep,
7 time::Duration,
8};
9
10#[cfg(feature = "async_tokio")]
11use std::future::Future;
12
13use stats_alloc::{Stats, StatsAlloc};
14
15#[cfg(feature = "async_tokio")]
16use tokio::{runtime, task::spawn_blocking};
17
/// Value of the state word while no thread is using or measuring the allocator.
const STATE_UNLOCKED: usize = 0;
/// Value of the state word while a single allocator call made outside of a
/// measurement is in flight. Any other non-zero value stored in the state word
/// is the thread id of the thread that currently holds the measurement lock.
const STATE_IN_USE: usize = 1;

/// Pause between two attempts to change the state word when it is contended.
const SLEEP: Duration = Duration::from_micros(50);
22
23pub struct LockedAllocator<T>
24where
25 T: GlobalAlloc,
26{
27 locked: AtomicUsize,
28 inner: StatsAlloc<T>,
29}
30
31impl<T> LockedAllocator<T>
32where
33 T: GlobalAlloc,
34{
35 pub const fn new(inner: StatsAlloc<T>) -> Self {
36 let locked = AtomicUsize::new(0);
37 Self { locked, inner }
38 }
39
40 fn current_thread_id() -> usize {
42 unsafe { libc::pthread_self() as usize }
43 }
44
45 fn before_op(&self) -> bool {
48 let current_thread_id = Self::current_thread_id();
49
50 loop {
51 match self.locked.compare_exchange(
52 STATE_UNLOCKED,
53 STATE_IN_USE,
54 Ordering::SeqCst,
55 Ordering::SeqCst,
56 ) {
57 Ok(_) => break,
58 Err(existing) => {
59 if existing == current_thread_id {
60 return true;
61 }
62 }
63 }
64
65 sleep(SLEEP);
66 }
67
68 false
69 }
70
71 fn after_op(&self) {
73 let current_thread_id = Self::current_thread_id();
74
75 loop {
76 match self.locked.compare_exchange(
77 STATE_IN_USE,
78 STATE_UNLOCKED,
79 Ordering::SeqCst,
80 Ordering::SeqCst,
81 ) {
82 Ok(_) => break,
83 Err(existing) => {
84 if existing == current_thread_id {
85 break;
86 }
87 }
88 }
89
90 sleep(SLEEP);
91 }
92 }
93
94 fn serialized<F, O>(&self, op: F) -> O
96 where
97 F: FnOnce(bool) -> O,
98 {
99 let locked = self.before_op();
100 let result = op(locked);
101 self.after_op();
102
103 result
104 }
105
106 fn lock(&self) {
108 let current_thread_id = Self::current_thread_id();
109
110 loop {
111 let r = self.locked.compare_exchange(
112 STATE_UNLOCKED,
113 current_thread_id,
114 Ordering::SeqCst,
115 Ordering::SeqCst,
116 );
117
118 if r.is_ok() {
119 break;
120 }
121
122 sleep(SLEEP);
123 }
124 }
125
126 fn unlock(&self) {
128 let expected = Self::current_thread_id();
129
130 assert_eq!(
131 expected,
132 self.locked
133 .compare_exchange(expected, STATE_UNLOCKED, Ordering::SeqCst, Ordering::SeqCst)
134 .unwrap()
135 );
136 }
137
138 fn stats(&self) -> Stats {
140 self.inner.stats()
141 }
142}
143
144unsafe impl<T> GlobalAlloc for LockedAllocator<T>
145where
146 T: GlobalAlloc,
147{
148 unsafe fn alloc(&self, layout: std::alloc::Layout) -> *mut u8 {
149 self.serialized(|is_locked| {
150 if is_locked {
151 probe::probe!(LockedAllocator, alloc_locked);
152 }
153
154 self.inner.alloc(layout)
155 })
156 }
157
158 unsafe fn dealloc(&self, ptr: *mut u8, layout: std::alloc::Layout) {
159 self.serialized(|is_locked| {
160 if is_locked {
161 probe::probe!(LockedAllocator, dealloc_locked);
162 }
163
164 self.inner.dealloc(ptr, layout)
165 })
166 }
167
168 unsafe fn realloc(&self, ptr: *mut u8, layout: std::alloc::Layout, new_size: usize) -> *mut u8 {
169 self.serialized(|is_locked| {
170 if is_locked {
171 probe::probe!(LockedAllocator, realloc_locked);
172 }
173
174 self.inner.realloc(ptr, layout, new_size)
175 })
176 }
177}
178
179pub fn memory_measured<A, F>(alloc: &LockedAllocator<A>, f: F) -> Stats
181where
182 A: GlobalAlloc,
183 F: FnOnce(),
184{
185 alloc.lock();
186
187 let before = alloc.stats();
188
189 f();
190
191 let after = alloc.stats();
192
193 alloc.unlock();
194
195 after - before
196}
197
/// Drives `f` to completion while holding the measurement lock of `alloc` and
/// returns the difference between the allocator counters sampled right after
/// and right before the future ran.
///
/// The work is moved to a `spawn_blocking` thread that builds its own
/// current-thread runtime, so taking the lock, polling `f` and releasing the
/// lock all happen on that one thread. The runtime is built before the lock
/// is taken and is therefore not part of the measurement.
///
/// The lock is released even when `f` panics.
///
/// # Panics
///
/// Panics when the runtime cannot be built or when the blocking task itself
/// panicked (for example because `f` panicked).
#[cfg(feature = "async_tokio")]
pub async fn memory_measured_future<A, F>(alloc: &'static LockedAllocator<A>, f: F) -> Stats
where
    A: GlobalAlloc + Send + Sync,
    F: Future<Output = ()> + Send + 'static,
{
    /// Releases the measurement lock when dropped, including during unwinding.
    struct Unlock<T: GlobalAlloc + 'static>(&'static LockedAllocator<T>);

    impl<T: GlobalAlloc + 'static> Drop for Unlock<T> {
        fn drop(&mut self) {
            self.0.unlock();
        }
    }

    // `move` copies the `'static` reference and moves the future into the
    // closure, which has to be `'static` for `spawn_blocking`.
    spawn_blocking(move || {
        let runtime = runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        runtime.block_on(async move {
            alloc.lock();
            // From here on the lock is released on every exit path.
            let guard = Unlock(alloc);

            let before = alloc.stats();

            f.await;

            let after = alloc.stats();

            // Release before doing anything that is not part of the measurement.
            drop(guard);

            after - before
        })
    })
    .await
    .unwrap()
}
229
#[cfg(test)]
mod tests {
    use std::alloc::System;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread::{sleep, spawn};
    use std::time::Duration;

    use super::*;

    /// Allocator under test, installed as the global allocator of the test binary.
    #[global_allocator]
    static GLOBAL: LockedAllocator<System> = LockedAllocator::new(StatsAlloc::system());

    /// Measures a short string manipulation and a growing vector and checks
    /// the exact counters reported for each.
    #[test]
    fn it_works() {
        let mut length = 0;

        // The statements inside the measured closures are kept exactly as
        // they are: the expected counters below depend on them.
        let string_stats = memory_measured(&GLOBAL, || {
            let s = "whoa".to_owned().replace("whoa", "wow").to_owned();

            length = s.len();
        });

        assert_eq!(length, 3);

        let expected_string_stats = Stats {
            allocations: 3,
            deallocations: 3,
            reallocations: 0,
            bytes_allocated: 15,
            bytes_deallocated: 15,
            bytes_reallocated: 0,
        };
        assert_eq!(string_stats, expected_string_stats);

        let vec_stats = memory_measured(&GLOBAL, || {
            let mut v = vec![1, 2, 3, 4, 5];

            v.push(6);

            length = v.len();
        });

        assert_eq!(length, 6);

        let expected_vec_stats = Stats {
            allocations: 1,
            deallocations: 1,
            reallocations: 1,
            bytes_allocated: 40,
            bytes_deallocated: 40,
            bytes_reallocated: 20,
        };
        assert_eq!(vec_stats, expected_vec_stats);
    }

    /// Checks that a second thread that keeps pushing to a vector does not
    /// show up in the counters of a measurement taken on this thread.
    #[test]
    fn test_parallel() {
        let stop = Arc::new(AtomicBool::new(false));

        // Background thread that keeps growing a vector until told to stop.
        let worker_stop = Arc::clone(&stop);
        spawn(move || {
            let mut vec = vec![];
            while !worker_stop.load(Ordering::Relaxed) {
                vec.push(1);
                sleep(Duration::from_micros(1));
            }
        });

        let mut length = 0;
        let step = Duration::from_millis(150);

        let measured = memory_measured(&GLOBAL, || {
            let s = "whoa".to_owned().replace("whoa", "wow").to_owned();

            // Keep the measurement open long enough for the worker to run into it.
            sleep(step);

            length = s.len();
        });

        stop.store(true, Ordering::Relaxed);

        assert_eq!(length, 3);

        let expected = Stats {
            allocations: 3,
            deallocations: 3,
            reallocations: 0,
            bytes_allocated: 15,
            bytes_deallocated: 15,
            bytes_reallocated: 0,
        };
        assert_eq!(measured, expected);
    }

    /// Measures a future that creates and drops one small vector.
    #[cfg(feature = "async_tokio")]
    #[tokio::test]
    async fn test_tokio() {
        let measured = memory_measured_future(&GLOBAL, async {
            let _ = vec![1, 2, 3, 4];
        })
        .await;

        let expected = Stats {
            allocations: 1,
            deallocations: 1,
            reallocations: 0,
            bytes_allocated: 16,
            bytes_deallocated: 16,
            bytes_reallocated: 0,
        };
        assert_eq!(measured, expected);
    }
}