// vortex_layout/segments/shared.rs

use std::sync::Arc;

use futures::future::{BoxFuture, Shared, WeakShared};
use futures::{FutureExt, TryFutureExt};
use vortex_buffer::ByteBuffer;
use vortex_error::{SharedVortexResult, VortexError, VortexExpect};
use vortex_utils::aliases::dash_map::{DashMap, Entry};

use crate::segments::{SegmentFuture, SegmentId, SegmentSource};
13
14pub struct SharedSegmentSource<S> {
17 inner: S,
18 in_flight: DashMap<SegmentId, WeakShared<SharedSegmentFuture>>,
19}
20
21type SharedSegmentFuture = BoxFuture<'static, SharedVortexResult<ByteBuffer>>;
22
23impl<S: SegmentSource> SharedSegmentSource<S> {
24 pub fn new(inner: S) -> Self {
26 Self {
27 inner,
28 in_flight: DashMap::default(),
29 }
30 }
31}
32
33impl<S: SegmentSource> SegmentSource for SharedSegmentSource<S> {
34 fn request(&self, id: SegmentId) -> SegmentFuture {
35 loop {
36 match self.in_flight.entry(id) {
37 Entry::Occupied(e) => {
38 if let Some(shared_future) = e.get().upgrade() {
39 return shared_future.map_err(VortexError::from).boxed();
40 } else {
41 e.remove();
43 }
44 }
45 Entry::Vacant(e) => {
46 let future = self.inner.request(id).map_err(Arc::new).boxed().shared();
47 e.insert(
48 future
49 .downgrade()
50 .vortex_expect("just created, cannot be polled to completion"),
51 );
52 return future.map_err(VortexError::from).boxed();
53 }
54 }
55 }
56 }
57}
58
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use vortex_buffer::ByteBuffer;

    use super::*;
    use crate::segments::{SegmentSink, TestSegments};
    use crate::sequence::SequenceId;

    /// Test source that delegates to [`TestSegments`] and counts every request it receives.
    #[derive(Default, Clone)]
    struct CountingSegmentSource {
        segments: TestSegments,
        request_count: Arc<AtomicUsize>,
    }

    impl CountingSegmentSource {
        /// Number of requests that have reached this source so far.
        fn requests(&self) -> usize {
            self.request_count.load(Ordering::Relaxed)
        }
    }

    impl SegmentSource for CountingSegmentSource {
        fn request(&self, id: SegmentId) -> SegmentFuture {
            self.request_count.fetch_add(1, Ordering::SeqCst);
            self.segments.request(id)
        }
    }

    /// Builds a counting source whose backing segments hold one buffer equal to `data`.
    ///
    /// The tests below read it back as `SegmentId::from(0)`.
    async fn source_holding(data: &ByteBuffer) -> CountingSegmentSource {
        let source = CountingSegmentSource::default();
        let seq_id = SequenceId::root().downgrade();
        source
            .segments
            .write(seq_id, vec![data.clone()])
            .await
            .unwrap();
        source
    }

    #[tokio::test]
    async fn test_shared_source_deduplicates_concurrent_requests() {
        let expected = ByteBuffer::from(vec![1, 2, 3, 4]);
        let counting = source_holding(&expected).await;
        let shared_source = SharedSegmentSource::new(counting.clone());
        let segment = SegmentId::from(0);

        // Both requests exist before either is polled, so the second joins the first.
        let first = shared_source.request(segment);
        let second = shared_source.request(segment);
        let (first, second) = futures::join!(first, second);

        assert_eq!(first.unwrap(), expected);
        assert_eq!(second.unwrap(), expected);
        // Only a single request made it through to the wrapped source.
        assert_eq!(counting.requests(), 1);
    }

    #[tokio::test]
    async fn test_shared_source_handles_dropped_futures() {
        let expected = ByteBuffer::from(vec![5, 6, 7, 8]);
        let counting = source_holding(&expected).await;
        let shared_source = SharedSegmentSource::new(counting.clone());
        let segment = SegmentId::from(0);

        // Issue a request and drop it without polling, so nobody holds it any more.
        drop(shared_source.request(segment));

        let fetched = shared_source.request(segment).await;
        assert_eq!(fetched.unwrap(), expected);
        // The abandoned request cannot be joined, so a second request reaches the wrapped source.
        assert_eq!(counting.requests(), 2);
    }
}