vortex_layout/segments/
shared.rs1use std::sync::Arc;
5
6use futures::FutureExt;
7use futures::TryFutureExt;
8use futures::future::BoxFuture;
9use futures::future::WeakShared;
10use vortex_buffer::BufferHandle;
11use vortex_error::SharedVortexResult;
12use vortex_error::VortexError;
13use vortex_error::VortexExpect;
14use vortex_utils::aliases::dash_map::DashMap;
15use vortex_utils::aliases::dash_map::Entry;
16
17use crate::segments::SegmentFuture;
18use crate::segments::SegmentId;
19use crate::segments::SegmentSource;
20
21pub struct SharedSegmentSource<S> {
24 inner: S,
25 in_flight: DashMap<SegmentId, WeakShared<SharedSegmentFuture>>,
26}
27
28type SharedSegmentFuture = BoxFuture<'static, SharedVortexResult<BufferHandle>>;
29
30impl<S: SegmentSource> SharedSegmentSource<S> {
31 pub fn new(inner: S) -> Self {
33 Self {
34 inner,
35 in_flight: DashMap::default(),
36 }
37 }
38}
39
40impl<S: SegmentSource> SegmentSource for SharedSegmentSource<S> {
41 fn request(&self, id: SegmentId) -> SegmentFuture {
42 loop {
43 match self.in_flight.entry(id) {
44 Entry::Occupied(e) => {
45 if let Some(shared_future) = e.get().upgrade() {
46 return shared_future.map_err(VortexError::from).boxed();
47 } else {
48 e.remove();
50 }
51 }
52 Entry::Vacant(e) => {
53 let future = self.inner.request(id).map_err(Arc::new).boxed().shared();
54 e.insert(
55 future
56 .downgrade()
57 .vortex_expect("just created, cannot be polled to completion"),
58 );
59 return future.map_err(VortexError::from).boxed();
60 }
61 }
62 }
63 }
64}
65
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    use vortex_buffer::ByteBuffer;

    use super::*;
    use crate::segments::SegmentSink;
    use crate::segments::TestSegments;
    use crate::sequence::SequenceId;

    /// Test double that serves segments from [`TestSegments`] and counts every request that
    /// reaches it. Tests use the count to tell whether the wrapper forwarded or shared a request.
    #[derive(Default, Clone)]
    struct CountingSegmentSource {
        segments: TestSegments,
        request_count: Arc<AtomicUsize>,
    }

    impl CountingSegmentSource {
        /// How many requests have been forwarded to this source so far.
        fn forwarded(&self) -> usize {
            self.request_count.load(Ordering::Relaxed)
        }
    }

    impl SegmentSource for CountingSegmentSource {
        fn request(&self, id: SegmentId) -> SegmentFuture {
            self.request_count.fetch_add(1, Ordering::SeqCst);
            self.segments.request(id)
        }
    }

    /// Creates a counting source and writes `payload` as its first segment.
    async fn counting_source_with(payload: &ByteBuffer) -> CountingSegmentSource {
        let counting = CountingSegmentSource::default();
        let sequence = SequenceId::root().downgrade();
        counting
            .segments
            .write(sequence, vec![payload.clone()])
            .await
            .unwrap();
        counting
    }

    #[tokio::test]
    async fn test_shared_source_deduplicates_concurrent_requests() {
        let payload = ByteBuffer::from(vec![1, 2, 3, 4]);
        let counting = counting_source_with(&payload).await;
        let wrapper = SharedSegmentSource::new(counting.clone());
        let segment = SegmentId::from(0);

        // Create both requests before polling either, so the second has to join the first.
        let first = wrapper.request(segment);
        let second = wrapper.request(segment);
        let (first, second) = futures::join!(first, second);

        for outcome in [first, second] {
            assert_eq!(*outcome.unwrap().bytes(), payload);
        }
        assert_eq!(counting.forwarded(), 1);
    }

    #[tokio::test]
    async fn test_shared_source_handles_dropped_futures() {
        let payload = ByteBuffer::from(vec![5, 6, 7, 8]);
        let counting = counting_source_with(&payload).await;
        let wrapper = SharedSegmentSource::new(counting.clone());
        let segment = SegmentId::from(0);

        // Start a request and abandon it without ever polling it.
        drop(wrapper.request(segment));

        // The abandoned request cannot be joined, so a fresh inner request must be issued.
        let outcome = wrapper.request(segment).await;
        assert_eq!(*outcome.unwrap().bytes(), payload);
        assert_eq!(counting.forwarded(), 2);
    }
}