vortex_layout/segments/
shared.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use futures::future::{BoxFuture, WeakShared};
7use futures::{FutureExt, TryFutureExt};
8use vortex_buffer::ByteBuffer;
9use vortex_error::{SharedVortexResult, VortexError, VortexExpect};
10use vortex_utils::aliases::dash_map::{DashMap, Entry};
11
12use crate::segments::{SegmentFuture, SegmentId, SegmentSource};
13
14/// A [`SegmentSource`] that allows multiple requesters to await the same underlying segment
15/// request.
16pub struct SharedSegmentSource<S> {
17    inner: S,
18    in_flight: DashMap<SegmentId, WeakShared<SharedSegmentFuture>>,
19}
20
21type SharedSegmentFuture = BoxFuture<'static, SharedVortexResult<ByteBuffer>>;
22
23impl<S: SegmentSource> SharedSegmentSource<S> {
24    /// Create a new `SharedSegmentSource` wrapping the provided inner source.
25    pub fn new(inner: S) -> Self {
26        Self {
27            inner,
28            in_flight: DashMap::default(),
29        }
30    }
31}
32
33impl<S: SegmentSource> SegmentSource for SharedSegmentSource<S> {
34    fn request(&self, id: SegmentId) -> SegmentFuture {
35        loop {
36            match self.in_flight.entry(id) {
37                Entry::Occupied(e) => {
38                    if let Some(shared_future) = e.get().upgrade() {
39                        return shared_future.map_err(VortexError::from).boxed();
40                    } else {
41                        // The future has been dropped, remove the entry and try again.
42                        e.remove();
43                    }
44                }
45                Entry::Vacant(e) => {
46                    let future = self.inner.request(id).map_err(Arc::new).boxed().shared();
47                    e.insert(
48                        future
49                            .downgrade()
50                            .vortex_expect("just created, cannot be polled to completion"),
51                    );
52                    return future.map_err(VortexError::from).boxed();
53                }
54            }
55        }
56    }
57}
58
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use vortex_buffer::ByteBuffer;

    use super::*;
    use crate::segments::{SegmentSink, TestSegments};
    use crate::sequence::SequenceId;

    /// Test source that delegates to [`TestSegments`] and counts every `request` call.
    #[derive(Default, Clone)]
    struct CountingSegmentSource {
        segments: TestSegments,
        request_count: Arc<AtomicUsize>,
    }

    impl SegmentSource for CountingSegmentSource {
        fn request(&self, id: SegmentId) -> SegmentFuture {
            // Count first, then forward to the backing test segments.
            self.request_count.fetch_add(1, Ordering::SeqCst);
            self.segments.request(id)
        }
    }

    /// Builds a counting source that has had one segment containing `bytes` written to
    /// it, returning the source together with the written buffer.
    async fn source_with_segment(bytes: Vec<u8>) -> (CountingSegmentSource, ByteBuffer) {
        let source = CountingSegmentSource::default();
        let payload = ByteBuffer::from(bytes);
        let seq_id = SequenceId::root().downgrade();
        source
            .segments
            .write(seq_id, vec![payload.clone()])
            .await
            .unwrap();
        (source, payload)
    }

    #[tokio::test]
    async fn test_shared_source_deduplicates_concurrent_requests() {
        let (source, data) = source_with_segment(vec![1, 2, 3, 4]).await;
        let shared_source = SharedSegmentSource::new(source.clone());

        // Issue two requests for the same segment before awaiting either of them.
        let id = SegmentId::from(0);
        let first = shared_source.request(id);
        let second = shared_source.request(id);

        // Both requesters must observe the same bytes.
        let (first_result, second_result) = futures::join!(first, second);
        assert_eq!(first_result.unwrap(), data);
        assert_eq!(second_result.unwrap(), data);

        // Only one request may have reached the inner source.
        assert_eq!(source.request_count.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_shared_source_handles_dropped_futures() {
        let (source, data) = source_with_segment(vec![5, 6, 7, 8]).await;
        let shared_source = SharedSegmentSource::new(source.clone());
        let id = SegmentId::from(0);

        // Issue a request and drop its future without ever polling it.
        drop(shared_source.request(id));

        // A later request for the same id must still succeed.
        let result = shared_source.request(id).await;
        assert_eq!(result.unwrap(), data);

        // The first future was dropped before completion, so the inner source is hit twice.
        assert_eq!(source.request_count.load(Ordering::Relaxed), 2);
    }
}