// tycho_core/block_strider/provider/box_provider.rs

1use std::sync::atomic::{AtomicPtr, Ordering};
2
3use anyhow::Result;
4use futures_util::FutureExt;
5use futures_util::future::BoxFuture;
6use tycho_block_util::block::BlockIdRelation;
7use tycho_types::models::BlockId;
8
9use crate::block_strider::provider::{BlockProvider, OptionalBlockStuff};
10
11pub struct BoxBlockProvider {
12    data: AtomicPtr<()>,
13    vtable: &'static Vtable,
14}
15
16impl BoxBlockProvider {
17    pub fn new<P>(provider: P) -> Self
18    where
19        P: BlockProvider,
20    {
21        let ptr = Box::into_raw(Box::new(provider));
22
23        Self {
24            data: AtomicPtr::new(ptr.cast()),
25            vtable: const { Vtable::new::<P>() },
26        }
27    }
28}
29
30impl BlockProvider for BoxBlockProvider {
31    type GetNextBlockFut<'a> = GetBlockFut<'a>;
32    type GetBlockFut<'a> = GetBlockFut<'a>;
33    type CleanupFut<'a> = ClenaupFut<'a>;
34
35    fn get_next_block<'a>(&'a self, prev_block_id: &'a BlockId) -> Self::GetNextBlockFut<'a> {
36        unsafe { (self.vtable.get_next_block)(&self.data, prev_block_id) }
37    }
38
39    fn get_block<'a>(&'a self, block_id_relation: &'a BlockIdRelation) -> Self::GetBlockFut<'a> {
40        unsafe { (self.vtable.get_block)(&self.data, block_id_relation) }
41    }
42
43    fn cleanup_until(&self, mc_seqno: u32) -> Self::CleanupFut<'_> {
44        unsafe { (self.vtable.cleanup_until)(&self.data, mc_seqno) }
45    }
46}
47
48impl Drop for BoxBlockProvider {
49    fn drop(&mut self) {
50        unsafe { (self.vtable.drop)(&mut self.data) }
51    }
52}
53
54// Vtable must enforce this behavior
55unsafe impl Send for BoxBlockProvider {}
56unsafe impl Sync for BoxBlockProvider {}
57
58struct Vtable {
59    get_next_block: GetNextBlockFn,
60    get_block: GetBlockFn,
61    cleanup_until: CleanupFn,
62    drop: DropFn,
63}
64
65impl Vtable {
66    const fn new<P: BlockProvider>() -> &'static Self {
67        &Self {
68            get_next_block: |ptr, prev_block_id| {
69                let provider = unsafe { &*ptr.load(Ordering::Relaxed).cast::<P>() };
70                provider.get_next_block(prev_block_id).boxed()
71            },
72            get_block: |ptr, block_id_relation| {
73                let provider = unsafe { &*ptr.load(Ordering::Relaxed).cast::<P>() };
74                provider.get_block(block_id_relation).boxed()
75            },
76            cleanup_until: |ptr, mc_seqno| {
77                let provider = unsafe { &*ptr.load(Ordering::Relaxed).cast::<P>() };
78                provider.cleanup_until(mc_seqno).boxed()
79            },
80            drop: |ptr| {
81                drop(unsafe { Box::<P>::from_raw(ptr.get_mut().cast::<P>()) });
82            },
83        }
84    }
85}
86
87type GetNextBlockFn = for<'a> unsafe fn(&AtomicPtr<()>, &'a BlockId) -> GetBlockFut<'a>;
88type GetBlockFn = for<'a> unsafe fn(&AtomicPtr<()>, &'a BlockIdRelation) -> GetBlockFut<'a>;
89type CleanupFn = for<'a> unsafe fn(&AtomicPtr<()>, u32) -> ClenaupFut<'_>;
90type DropFn = unsafe fn(&mut AtomicPtr<()>);
91
92type GetBlockFut<'a> = BoxFuture<'a, OptionalBlockStuff>;
93type ClenaupFut<'a> = BoxFuture<'a, Result<()>>;
94
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;

    use anyhow::Result;
    use tycho_block_util::block::BlockIdExt;

    use super::*;

    /// Checks that each wrapper method reaches exactly the matching method of the
    /// wrapped provider, and that dropping the wrapper drops the provider once.
    #[tokio::test]
    async fn boxed_provider_works() -> Result<()> {
        /// Call counters shared between the test body and the provider under test.
        #[derive(Default)]
        struct ProviderState {
            get_next_called: AtomicUsize,
            get_called: AtomicUsize,
            cleanup_called: AtomicUsize,
            dropped: AtomicUsize,
        }

        impl ProviderState {
            /// Current counters as `[get_next, get, cleanup, dropped]`.
            fn snapshot(&self) -> [usize; 4] {
                [
                    &self.get_next_called,
                    &self.get_called,
                    &self.cleanup_called,
                    &self.dropped,
                ]
                .map(|counter| counter.load(Ordering::Acquire))
            }
        }

        /// Provider that only records which of its methods were invoked.
        struct TestProvider {
            state: Arc<ProviderState>,
        }

        impl Drop for TestProvider {
            fn drop(&mut self) {
                self.state.dropped.fetch_add(1, Ordering::Relaxed);
            }
        }

        impl BlockProvider for TestProvider {
            type GetNextBlockFut<'a> = futures_util::future::Ready<OptionalBlockStuff>;
            type GetBlockFut<'a> = futures_util::future::Ready<OptionalBlockStuff>;
            type CleanupFut<'a> = futures_util::future::Ready<Result<()>>;

            fn get_next_block<'a>(&'a self, _: &'a BlockId) -> Self::GetNextBlockFut<'a> {
                self.state.get_next_called.fetch_add(1, Ordering::Relaxed);
                futures_util::future::ready(None)
            }

            fn get_block<'a>(&'a self, _: &'a BlockIdRelation) -> Self::GetBlockFut<'a> {
                self.state.get_called.fetch_add(1, Ordering::Relaxed);
                futures_util::future::ready(None)
            }

            fn cleanup_until(&self, _: u32) -> Self::CleanupFut<'_> {
                self.state.cleanup_called.fetch_add(1, Ordering::Relaxed);
                futures_util::future::ready(Ok(()))
            }
        }

        let state = Arc::new(ProviderState::default());
        let boxed = BoxBlockProvider::new(TestProvider {
            state: Arc::clone(&state),
        });

        // Nothing has been called yet.
        assert_eq!(state.snapshot(), [0, 0, 0, 0]);

        // Two `get_next_block` calls bump only the first counter.
        let mc_block_id = BlockId::default();
        for expected in 1..=2 {
            assert!(boxed.get_next_block(&mc_block_id).await.is_none());
            assert_eq!(state.snapshot(), [expected, 0, 0, 0]);
        }

        // Two `get_block` calls bump only the second counter.
        let relation = mc_block_id.relative_to_self();
        for expected in 1..=2 {
            assert!(boxed.get_block(&relation).await.is_none());
            assert_eq!(state.snapshot(), [2, expected, 0, 0]);
        }

        // Two `cleanup_until` calls bump only the third counter.
        for (expected, mc_seqno) in [(1, 123), (2, 321)] {
            boxed.cleanup_until(mc_seqno).await.unwrap();
            assert_eq!(state.snapshot(), [2, 2, expected, 0]);
        }

        // The wrapper still holds the provider's clone of the shared state.
        assert_eq!(Arc::strong_count(&state), 2);
        drop(boxed);

        // Dropping the wrapper drops the provider exactly once.
        let [get_next, get, _cleanup, dropped] = state.snapshot();
        assert_eq!(get_next, 2);
        assert_eq!(get, 2);
        assert_eq!(dropped, 1);

        assert_eq!(Arc::strong_count(&state), 1);

        Ok(())
    }
}