scheme_rs/gc/
collection.rs

1//! An implementation of the algorithm described in the paper Concurrent
2//! Cycle Collection in Reference Counted Systems by David F. Bacon and
3//! V.T. Rajan.
4
5use std::{
6    alloc::Layout,
7    ptr::NonNull,
8    sync::OnceLock,
9    time::{Duration, Instant},
10};
11use tokio::{
12    sync::{
13        mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
14        Semaphore, SemaphorePermit,
15    },
16    task::JoinHandle,
17};
18
19use super::{Color, GcInner, OpaqueGc, OpaqueGcPtr, Trace};
20
/// A single buffered reference-count operation paired with the object it
/// targets. Produced by `inc_rc`/`dec_rc` and consumed by the collector in
/// `process_mutation_buffer`.
#[derive(Copy, Clone, Debug)]
pub struct Mutation {
    // Whether this is an increment or a decrement.
    kind: MutationKind,
    // The garbage-collected object whose reference count changes.
    gc: NonNull<OpaqueGc>,
}
26
27impl Mutation {
28    fn new(kind: MutationKind, gc: NonNull<OpaqueGc>) -> Self {
29        Self { kind, gc }
30    }
31}
32
// SAFETY: a Mutation is just a (kind, raw pointer) pair; moving or sharing it
// across threads does not itself touch the pointee. NOTE(review): soundness
// ultimately rests on the collector being the only thread that dereferences
// the pointer — confirm against the mutator side.
unsafe impl Send for Mutation {}
unsafe impl Sync for Mutation {}
35
/// The direction of a buffered reference-count change.
#[derive(Copy, Clone, Debug)]
pub enum MutationKind {
    /// A new reference to the object was created.
    Inc,
    /// A reference to the object was dropped.
    Dec,
}
41
/// Instead of mutations being atomic (via an atomic variable), they're buffered into
/// "epochs", and handled by precisely one thread.
struct MutationBuffer {
    // Sender half; `send` takes `&self`, so any thread may push concurrently.
    mutation_buffer_tx: UnboundedSender<Mutation>,
    // Receiver half; requires `&mut` and is drained only by
    // `process_mutation_buffer`.
    mutation_buffer_rx: UnboundedReceiver<Mutation>,
}
48
49impl Default for MutationBuffer {
50    fn default() -> Self {
51        let (mutation_buffer_tx, mutation_buffer_rx) = unbounded_channel();
52        Self {
53            mutation_buffer_tx,
54            mutation_buffer_rx,
55        }
56    }
57}
58
59static mut MUTATION_BUFFER: OnceLock<MutationBuffer> = OnceLock::new();
60
61pub(super) fn inc_rc<T: Trace>(gc: NonNull<GcInner<T>>) {
62    // SAFETY: send takes an immutable reference and is atomic
63    unsafe {
64        (&raw const MUTATION_BUFFER)
65            .as_ref()
66            .unwrap()
67            .get()
68            .unwrap()
69            .mutation_buffer_tx
70            .send(Mutation::new(MutationKind::Inc, gc as NonNull<OpaqueGc>))
71            .unwrap();
72    }
73}
74
75pub(super) fn dec_rc<T: Trace>(gc: NonNull<GcInner<T>>) {
76    // SAFETY: send takes an immutable reference and is atomic
77    unsafe {
78        (&raw const MUTATION_BUFFER)
79            .as_ref()
80            .unwrap()
81            .get()
82            .unwrap()
83            .mutation_buffer_tx
84            .send(Mutation::new(MutationKind::Dec, gc as NonNull<OpaqueGc>))
85            .unwrap();
86    }
87}
88
89static COLLECTOR_TASK: OnceLock<JoinHandle<()>> = OnceLock::new();
90
91pub fn init_gc() {
92    // SAFETY: We DO NOT mutate MUTATION_BUFFER, we mutate the _interior once lock_.
93    let _ = unsafe {
94        (&raw const MUTATION_BUFFER)
95            .as_ref()
96            .unwrap()
97            .get_or_init(MutationBuffer::default)
98    };
99    let _ = COLLECTOR_TASK.get_or_init(|| {
100        tokio::task::spawn(async {
101            let mut last_epoch = Instant::now();
102            loop {
103                epoch(&mut last_epoch).await
104            }
105        })
106    });
107}
108
/// Test-only initialization: set up the mutation buffer without spawning the
/// background collector task, so tests can drive epochs manually via
/// `process_mutation_buffer` and `process_cycles`.
#[cfg(test)]
pub fn init_gc_test() {
    // SAFETY: only the interior OnceLock is initialized; the static itself is
    // never written.
    let _ = unsafe {
        (&raw const MUTATION_BUFFER)
            .as_ref()
            .unwrap()
            .get_or_init(MutationBuffer::default)
    };
}
118
119async fn epoch(last_epoch: &mut Instant) {
120    process_mutation_buffer().await;
121    let duration_since_last_epoch = Instant::now() - *last_epoch;
122    if duration_since_last_epoch > Duration::from_millis(100) {
123        tokio::task::spawn_blocking(process_cycles).await.unwrap();
124        *last_epoch = Instant::now();
125    }
126}
127
/// Drain one batch of buffered mutations and apply them to the object headers.
///
/// SAFETY: this function is _not reentrant_, may only be called once per epoch,
/// and must _complete_ before the next epoch.
pub async fn process_mutation_buffer() {
    const MAX_MUTATIONS_PER_EPOCH: usize = 10_000; // No idea what a good value is here.

    let mut mutation_buffer: Vec<_> = Vec::with_capacity(MAX_MUTATIONS_PER_EPOCH);
    // SAFETY: This function has _exclusive access_ to the receive buffer.
    unsafe {
        // Waits until at least one mutation is available, then drains up to
        // MAX_MUTATIONS_PER_EPOCH in a single batch (tokio's recv_many).
        (&raw mut MUTATION_BUFFER)
            .as_mut()
            .unwrap()
            .get_mut()
            .unwrap()
            .mutation_buffer_rx
            .recv_many(&mut mutation_buffer, MAX_MUTATIONS_PER_EPOCH)
            .await;
    }

    // SAFETY: This function has _exclusive access_ to mutate the header of
    // _every_ garbage collected object. It does so _only now_.

    // Apply mutations in FIFO order, so inc/dec pairs are observed in the
    // order the mutator submitted them.
    for mutation in mutation_buffer.into_iter() {
        match mutation.kind {
            MutationKind::Inc => increment(mutation.gc),
            MutationKind::Dec => decrement(mutation.gc),
        }
    }
}
156
// SAFETY: These values can only be accessed by one thread at once.
//
// Candidate roots of cyclic garbage (nodes marked Purple by `possible_root`).
static mut ROOTS: Vec<OpaqueGcPtr> = Vec::new();
// Cycles detected in earlier passes, awaiting the delta/sigma tests.
static mut CYCLE_BUFFER: Vec<Vec<OpaqueGcPtr>> = Vec::new();
// Scratch buffer filled by `collect_white` for the cycle being gathered.
static mut CURRENT_CYCLE: Vec<OpaqueGcPtr> = Vec::new();
161
/// Apply a buffered increment: bump the count and eagerly re-blacken the
/// subgraph so it cannot be misclassified as part of a garbage cycle.
fn increment(s: OpaqueGcPtr) {
    *rc(s) += 1;
    scan_black(s);
}
166
167fn decrement(s: OpaqueGcPtr) {
168    *rc(s) -= 1;
169    if *rc(s) == 0 {
170        release(s);
171    } else {
172        possible_root(s);
173    }
174}
175
/// The reference count hit zero: decrement all children, then free the object
/// unless it still sits in a collector buffer (in which case the cycle
/// collector frees it later).
fn release(s: OpaqueGcPtr) {
    for_each_child(s, decrement);
    *color(s) = Color::Black;
    if !*buffered(s) {
        free(s);
    }
}
183
/// The object survived a decrement, so it may be the root of a garbage cycle:
/// re-blacken its subgraph, mark it Purple, and buffer it for the next cycle
/// collection unless it is already buffered.
fn possible_root(s: OpaqueGcPtr) {
    scan_black(s);
    *color(s) = Color::Purple;
    if !*buffered(s) {
        *buffered(s) = true;
        unsafe { (&raw mut ROOTS).as_mut().unwrap().push(s) };
    }
}
192
/// One full cycle-collection pass. Order matters: cycles buffered by earlier
/// passes are tested and freed first, new candidates are then gathered, and
/// finally the freshly buffered cycles are prepared for the next sigma test.
fn process_cycles() {
    free_cycles();
    collect_cycles();
    sigma_preparation();
}
198
/// The synchronous mark / scan / collect phases of the Bacon-Rajan algorithm.
fn collect_cycles() {
    mark_roots();
    scan_roots();
    collect_roots();
}
204
205// SAFETY: No function called by mark_roots may access ROOTS
206fn mark_roots() {
207    let mut new_roots = Vec::new();
208    for s in unsafe { (&raw const ROOTS).as_ref().unwrap().iter() } {
209        if *color(*s) == Color::Purple && *rc(*s) > 0 {
210            mark_gray(*s);
211            new_roots.push(*s);
212        } else {
213            *buffered(*s) = false;
214            if *rc(*s) == 0 {
215                free(*s);
216            }
217        }
218    }
219    unsafe { ROOTS = new_roots }
220}
221
/// Scan phase: classify each marked root's subgraph as garbage (White) or
/// externally reachable (re-blackened) based on the CRCs from `mark_roots`.
fn scan_roots() {
    for s in unsafe { (&raw const ROOTS).as_ref().unwrap().iter() } {
        scan(*s)
    }
}
227
/// Collect phase: move every White (garbage) subgraph reachable from the roots
/// into CYCLE_BUFFER as one candidate cycle; un-buffer everything else.
fn collect_roots() {
    for s in unsafe { std::mem::take((&raw mut ROOTS).as_mut().unwrap()) } {
        if *color(s) == Color::White {
            // collect_white fills CURRENT_CYCLE with this cycle's members;
            // move them into the cycle buffer as a single group.
            collect_white(s);
            unsafe {
                let current_cycle = std::mem::take((&raw mut CURRENT_CYCLE).as_mut().unwrap());
                (&raw mut CYCLE_BUFFER)
                    .as_mut()
                    .unwrap()
                    .push(current_cycle);
            }
        } else {
            *buffered(s) = false;
        }
    }
}
244
/// Mark a subgraph Gray and compute cyclic reference counts: CRC starts at RC
/// and each edge from a gray parent decrements the child's CRC, so a CRC of
/// zero later means all references to the node come from inside the subgraph.
fn mark_gray(s: OpaqueGcPtr) {
    if *color(s) != Color::Gray {
        *color(s) = Color::Gray;
        *crc(s) = *rc(s) as isize;
        for_each_child(s, |t| {
            mark_gray(t);
            // Clamp at zero; external references discovered later are handled
            // by the scan phase.
            if *crc(t) > 0 {
                *crc(t) -= 1;
            }
        });
    }
}
257
/// Scan a Gray node: CRC == 0 means it is referenced only from within the
/// candidate subgraph, so tentatively whiten it and recurse; otherwise it is
/// externally reachable and its whole subgraph must be re-blackened.
fn scan(s: OpaqueGcPtr) {
    if *color(s) == Color::Gray {
        if *crc(s) == 0 {
            *color(s) = Color::White;
            for_each_child(s, scan);
        } else {
            scan_black(s);
        }
    }
}
268
/// Re-blacken a subgraph known to be reachable, undoing any Gray/White marks.
fn scan_black(s: OpaqueGcPtr) {
    if *color(s) != Color::Black {
        *color(s) = Color::Black;
        for_each_child(s, scan_black);
    }
}
275
/// Gather a White subgraph into CURRENT_CYCLE, coloring it Orange and marking
/// it buffered so `release` will not free it out from under the collector.
fn collect_white(s: OpaqueGcPtr) {
    if *color(s) == Color::White {
        *color(s) = Color::Orange;
        *buffered(s) = true;
        unsafe {
            (&raw mut CURRENT_CYCLE).as_mut().unwrap().push(s);
        }
        for_each_child(s, collect_white);
    }
}
286
/// Prepare the CRCs consulted by the next pass's sigma test: for each buffered
/// cycle, reset every member's CRC to its RC, then subtract one per
/// intra-cycle reference (an edge landing on a still-Red member). Members are
/// left Orange so the delta test can detect outside interference.
fn sigma_preparation() {
    for c in unsafe { (&raw const CYCLE_BUFFER).as_ref().unwrap() } {
        for n in c {
            *color(*n) = Color::Red;
            *crc(*n) = *rc(*n) as isize;
        }
        for n in c {
            for_each_child(*n, |m| {
                if *color(m) == Color::Red && *crc(m) > 0 {
                    *crc(m) -= 1;
                }
            });
        }
        for n in c {
            *color(*n) = Color::Orange;
        }
    }
}
305
/// Test every buffered candidate cycle (newest first, hence `rev`): free those
/// that pass both the delta and sigma tests — still undisturbed and with no
/// external references — and refurbish the rest back into circulation.
fn free_cycles() {
    for c in unsafe {
        std::mem::take((&raw mut CYCLE_BUFFER).as_mut().unwrap())
            .into_iter()
            .rev()
    } {
        if delta_test(&c) && sigma_test(&c) {
            free_cycle(&c);
        } else {
            refurbish(&c);
        }
    }
}
319
320fn delta_test(c: &[OpaqueGcPtr]) -> bool {
321    for n in c {
322        if *color(*n) != Color::Orange {
323            return false;
324        }
325    }
326    true
327}
328
329fn sigma_test(c: &[OpaqueGcPtr]) -> bool {
330    let mut sum = 0;
331    for n in c {
332        sum += *crc(*n);
333    }
334    sum == 0
335    /*
336    // NOTE: This is the only function so far that I have not implemented
337    // _exactly_ as the text reads. I do not understand why I would have to
338    // continue iterating if I see a CRC > 0, as CRCs cannot be negative.
339    for n in c {
340        if *crc(*n) > 0 {
341            return false;
342        }
343    }
344    true
345    */
346}
347
/// A buffered cycle failed its tests: re-buffer the plausible roots (the first
/// member if still Orange, plus any member that turned Purple) into ROOTS, and
/// return every other member to the normal Black/unbuffered state.
fn refurbish(c: &[OpaqueGcPtr]) {
    for (i, n) in c.iter().enumerate() {
        match (i, *color(*n)) {
            (0, Color::Orange) | (_, Color::Purple) => {
                *color(*n) = Color::Purple;
                unsafe {
                    (&raw mut ROOTS).as_mut().unwrap().push(*n);
                }
            }
            _ => {
                *color(*n) = Color::Black;
                *buffered(*n) = false;
            }
        }
    }
}
364
/// Free a confirmed-dead cycle: mark every member Red so intra-cycle edges are
/// skipped by `cyclic_decrement`, decrement all children pointing out of the
/// cycle, then deallocate every member.
fn free_cycle(c: &[OpaqueGcPtr]) {
    for n in c {
        *color(*n) = Color::Red;
    }
    for n in c {
        for_each_child(*n, cyclic_decrement);
    }
    for n in c {
        free(*n);
    }
}
376
377fn cyclic_decrement(m: OpaqueGcPtr) {
378    if *color(m) != Color::Red {
379        if *color(m) == Color::Orange {
380            *rc(m) -= 1;
381            *crc(m) -= 1;
382        } else {
383            decrement(m);
384        }
385    }
386}
387
// SAFETY (applies to all header accessors below): each forges an unconstrained
// lifetime `'a` over a field inside the object's header cell. This is sound
// only because header mutation is confined to the collector (see the SAFETY
// notes in `process_mutation_buffer`). NOTE(review): the returned reference
// must not outlive the allocation — confirm callers uphold this.

// The object's tri-color-ish marking state for the collector.
fn color<'a>(s: OpaqueGcPtr) -> &'a mut Color {
    unsafe { &mut (*s.as_ref().header.get()).color }
}

// The object's reference count.
fn rc<'a>(s: OpaqueGcPtr) -> &'a mut usize {
    unsafe { &mut (*s.as_ref().header.get()).rc }
}

// The cyclic reference count used by the trial-deletion phases.
fn crc<'a>(s: OpaqueGcPtr) -> &'a mut isize {
    unsafe { &mut (*s.as_ref().header.get()).crc }
}

// Whether the object currently sits in ROOTS or a cycle buffer.
fn buffered<'a>(s: OpaqueGcPtr) -> &'a mut bool {
    unsafe { &mut (*s.as_ref().header.get()).buffered }
}

// The per-object semaphore acquired by `for_each_child` before visiting data.
fn semaphore<'a>(s: OpaqueGcPtr) -> &'a Semaphore {
    unsafe { &(*s.as_ref().header.get()).semaphore }
}
407
/// Busy-wait until a permit on the given semaphore can be acquired.
///
/// NOTE(review): this is a raw spin loop with no backoff; if a holder keeps
/// the semaphore for long, this burns a CPU on the blocking collector thread.
/// Consider `std::hint::spin_loop()` in the loop body — confirm intent.
fn acquire_permit(semaphore: &'_ Semaphore) -> SemaphorePermit<'_> {
    loop {
        if let Ok(permit) = semaphore.try_acquire() {
            return permit;
        }
    }
}
415
/// View the object's data as a `&mut dyn Trace` with an unconstrained
/// lifetime. NOTE(review): takes no semaphore permit — callers must guarantee
/// exclusive access (e.g. the object is already garbage, as in `free`).
fn trace<'a>(s: OpaqueGcPtr) -> &'a mut dyn Trace {
    unsafe { &mut *s.as_ref().data.get() }
}
419
/// Invoke `visitor` on every Gc child of `s`, holding the object's semaphore
/// permit for the duration so writers cannot mutate the child set mid-visit.
fn for_each_child(s: OpaqueGcPtr, visitor: fn(OpaqueGcPtr)) {
    let permit = acquire_permit(semaphore(s));
    unsafe { (*s.as_ref().data.get()).visit_children(visitor) }
    drop(permit);
}
425
/// Finalize and deallocate an object known to be garbage.
fn free(s: OpaqueGcPtr) {
    // Safety: No need to acquire a permit, s is guaranteed to be garbage.
    let trace = trace(s);
    unsafe {
        // NOTE(review): `Layout::for_value(trace)` describes only the data
        // payload, while `s.as_ptr()` points at the whole allocation (which
        // also holds the header per the accessors above). Confirm this layout
        // matches the one used at allocation time; a mismatch makes this
        // dealloc undefined behavior.
        let layout = Layout::for_value(trace);
        trace.finalize();
        std::alloc::dealloc(s.as_ptr() as *mut u8, layout);
    }
}
435
#[cfg(test)]
mod test {
    // NOTE(review): `use collection::…` resolves from the crate root (or to an
    // external crate on edition 2018+); `use super::{…}` looks like the
    // intended path for this submodule — confirm this compiles as-is.
    use collection::{init_gc_test, process_cycles};

    use crate::gc::*;

    // Builds the cycle a -> b -> c -> a with one external Arc held by `b`,
    // drops all three handles, and checks that the collector frees the cycle
    // (observed via the Arc's strong count falling back to 1).
    #[tokio::test]
    async fn cycles() {
        #[derive(Default, Trace)]
        struct Cyclic {
            next: Option<Gc<Cyclic>>,
            out: Option<Arc<()>>,
        }

        let out_ptr = Arc::new(());

        init_gc_test();

        let a = Gc::new(Cyclic::default());
        let b = Gc::new(Cyclic::default());
        let c = Gc::new(Cyclic::default());

        // a -> b -> c -
        // ^----------/
        a.write().await.next = Some(b.clone());
        b.write().await.next = Some(c.clone());
        b.write().await.out = Some(out_ptr.clone());
        c.write().await.next = Some(a.clone());

        assert_eq!(Arc::strong_count(&out_ptr), 2);

        drop(a);
        drop(b);
        drop(c);
        process_mutation_buffer().await;
        // Two passes: the first buffers the detected cycle, and — because
        // free_cycles runs at the start of a pass — the second tests and
        // frees it.
        process_cycles();
        process_cycles();
        assert_eq!(Arc::strong_count(&out_ptr), 1);
    }
}
475}