vortex_layout/segments/
shared.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use futures::FutureExt;
7use futures::TryFutureExt;
8use futures::future::BoxFuture;
9use futures::future::WeakShared;
10use vortex_buffer::BufferHandle;
11use vortex_error::SharedVortexResult;
12use vortex_error::VortexError;
13use vortex_error::VortexExpect;
14use vortex_utils::aliases::dash_map::DashMap;
15use vortex_utils::aliases::dash_map::Entry;
16
17use crate::segments::SegmentFuture;
18use crate::segments::SegmentId;
19use crate::segments::SegmentSource;
20
21/// A [`SegmentSource`] that allows multiple requesters to await the same underlying segment
22/// request.
23pub struct SharedSegmentSource<S> {
24    inner: S,
25    in_flight: DashMap<SegmentId, WeakShared<SharedSegmentFuture>>,
26}
27
28type SharedSegmentFuture = BoxFuture<'static, SharedVortexResult<BufferHandle>>;
29
30impl<S: SegmentSource> SharedSegmentSource<S> {
31    /// Create a new `SharedSegmentSource` wrapping the provided inner source.
32    pub fn new(inner: S) -> Self {
33        Self {
34            inner,
35            in_flight: DashMap::default(),
36        }
37    }
38}
39
40impl<S: SegmentSource> SegmentSource for SharedSegmentSource<S> {
41    fn request(&self, id: SegmentId) -> SegmentFuture {
42        loop {
43            match self.in_flight.entry(id) {
44                Entry::Occupied(e) => {
45                    if let Some(shared_future) = e.get().upgrade() {
46                        return shared_future.map_err(VortexError::from).boxed();
47                    } else {
48                        // The future has been dropped, remove the entry and try again.
49                        e.remove();
50                    }
51                }
52                Entry::Vacant(e) => {
53                    let future = self.inner.request(id).map_err(Arc::new).boxed().shared();
54                    e.insert(
55                        future
56                            .downgrade()
57                            .vortex_expect("just created, cannot be polled to completion"),
58                    );
59                    return future.map_err(VortexError::from).boxed();
60                }
61            }
62        }
63    }
64}
65
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    use vortex_buffer::ByteBuffer;

    use super::*;
    use crate::segments::SegmentSink;
    use crate::segments::TestSegments;
    use crate::sequence::SequenceId;

    /// Test source that wraps [`TestSegments`] and counts every request that reaches it.
    #[derive(Default, Clone)]
    struct CountingSegmentSource {
        segments: TestSegments,
        request_count: Arc<AtomicUsize>,
    }

    impl SegmentSource for CountingSegmentSource {
        fn request(&self, id: SegmentId) -> SegmentFuture {
            self.request_count.fetch_add(1, Ordering::SeqCst);
            self.segments.request(id)
        }
    }

    impl CountingSegmentSource {
        /// Number of requests that have reached the underlying test segments so far.
        fn requests(&self) -> usize {
            self.request_count.load(Ordering::Relaxed)
        }
    }

    /// Build a counting source holding one segment made of `bytes`. Returns the source together
    /// with the buffer that was written into it.
    async fn source_with_one_segment(bytes: Vec<u8>) -> (CountingSegmentSource, ByteBuffer) {
        let source = CountingSegmentSource::default();
        let payload = ByteBuffer::from(bytes);
        let sequence = SequenceId::root().downgrade();
        source
            .segments
            .write(sequence, vec![payload.clone()])
            .await
            .unwrap();
        (source, payload)
    }

    #[tokio::test]
    async fn test_shared_source_deduplicates_concurrent_requests() {
        let (source, expected) = source_with_one_segment(vec![1, 2, 3, 4]).await;
        let shared_source = SharedSegmentSource::new(source.clone());

        // Issue both requests for segment 0 before either one is polled.
        let segment = SegmentId::from(0);
        let first = shared_source.request(segment);
        let second = shared_source.request(segment);

        // Each waiter must observe the segment's bytes.
        let (first_result, second_result) = futures::join!(first, second);
        assert_eq!(*first_result.unwrap().bytes(), expected);
        assert_eq!(*second_result.unwrap().bytes(), expected);

        // Only a single request may have reached the inner source.
        assert_eq!(source.requests(), 1);
    }

    #[tokio::test]
    async fn test_shared_source_handles_dropped_futures() {
        let (source, expected) = source_with_one_segment(vec![5, 6, 7, 8]).await;
        let shared_source = SharedSegmentSource::new(source.clone());
        let segment = SegmentId::from(0);

        // Abandon a request without ever polling it.
        drop(shared_source.request(segment));

        // A later request must still resolve to the segment's bytes.
        let fetched = shared_source.request(segment).await;
        assert_eq!(*fetched.unwrap().bytes(), expected);

        // The abandoned request cannot be reused, so the inner source is hit twice.
        assert_eq!(source.requests(), 2);
    }
}