//! Segment identifiers, prioritized collection, pruning and streaming
//! (vortex_layout/segments.rs).

1use std::collections::BTreeMap;
2use std::fmt::Display;
3use std::ops::{Bound, Deref, Range, RangeBounds};
4use std::sync::{Arc, RwLock};
5use std::task::Poll;
6
7use async_trait::async_trait;
8use futures::channel::mpsc;
9use futures::{SinkExt, Stream, StreamExt};
10use range_union_find::RangeUnionFind;
11use vortex_buffer::{Buffer, ByteBuffer};
12use vortex_error::{VortexExpect, VortexResult, vortex_err};
13use vortex_metrics::VortexMetrics;
14
15use crate::range_intersection;
16
/// The identifier for a single segment.
///
/// A thin newtype over `u32`; derives `Ord`/`Hash` so it can be used as a map
/// key, and `Default` yields `SegmentId(0)`.
// TODO(ngates): should this be a `[u8]` instead? Allowing for arbitrary segment identifiers?
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SegmentId(u32);
21
22impl From<u32> for SegmentId {
23    fn from(value: u32) -> Self {
24        Self(value)
25    }
26}
27
/// Transparent read access to the underlying `u32` (e.g. `*id as usize`).
impl Deref for SegmentId {
    type Target = u32;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
35
36impl Display for SegmentId {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        write!(f, "SegmentId({})", self.0)
39    }
40}
41
/// An asynchronous reader that resolves a [`SegmentId`] to its raw bytes.
#[async_trait]
pub trait AsyncSegmentReader: 'static + Send + Sync {
    /// Attempt to get the data associated with a given segment ID.
    ///
    /// Implementations return an error when the segment cannot be retrieved.
    async fn get(&self, id: SegmentId) -> VortexResult<ByteBuffer>;
}
47
/// A sink for serialized segments; assigns an identifier to each write.
pub trait SegmentWriter {
    /// Write the given data into a segment and return its identifier.
    /// The provided buffers are concatenated together to form the segment.
    ///
    // TODO(ngates): in order to support aligned Direct I/O, it is preferable for all segments to
    //  be aligned to the logical block size (typically 512, but could be 4096). For this reason,
    //  if we know we're going to read an entire FlatLayout together, then we should probably
    //  serialize it into a single segment that is 512 byte aligned? Or else, we should guarantee
    //  to align the first segment to 512, and then assume that coalescing captures the rest.
    fn put(&mut self, buffer: &[ByteBuffer]) -> SegmentId;
}
59
/// Why a segment is needed by the scan.
///
/// The discriminant doubles as a scheduling priority: a *lower* value means a
/// *higher* priority (see `SegmentCollector::with_priority_hint`, which keeps
/// the `min`, and `TOP_PRIORITY`, which uses `Pruning`).
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord)]
pub enum RequiredSegmentKind {
    Pruning = 1,
    Filter = 2,
    #[default]
    Projection = 3,
}
67
/// Ordering key for segments in the store.
///
/// Field order matters: the derived `Ord` compares `row_end` first, then
/// `kind`, then `row_start`, so segments that end earliest (and, on ties,
/// higher-priority kinds) sort first in the `BTreeMap`.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord)]
pub struct SegmentPriority {
    row_end: u64, // sort by row_end first
    kind: RequiredSegmentKind,
    row_start: u64,
}
74
75impl SegmentPriority {
76    fn new(row_start: u64, row_end: u64, kind: RequiredSegmentKind) -> Self {
77        SegmentPriority {
78            row_end,
79            kind,
80            row_start,
81        }
82    }
83}
84
/// The smallest possible priority key (`row_end = 0`, `Pruning`, `row_start = 0`):
/// iterating the store from here (`range(TOP_PRIORITY..)`) visits every entry
/// in ascending priority order.
const TOP_PRIORITY: SegmentPriority = SegmentPriority {
    row_end: 0,
    kind: RequiredSegmentKind::Pruning,
    row_start: 0,
};
90
/// Segments keyed by scheduling priority; one key may map to several segments.
type SegmentStore = BTreeMap<SegmentPriority, Vec<SegmentId>>;
92
/// Accumulates required segments (with row ranges and kinds) into a shared,
/// priority-ordered store.
#[derive(Default)]
pub struct SegmentCollector {
    // Shared with forks created by `with_priority_hint` and, after `finish`,
    // with the resulting pruner and stream.
    store: Arc<RwLock<SegmentStore>>,
    // Priority kind applied to every segment pushed through this collector.
    pub kind: RequiredSegmentKind,
    metrics: VortexMetrics,
}
99
100impl SegmentCollector {
101    pub fn new(metrics: VortexMetrics) -> Self {
102        Self {
103            metrics,
104            ..Default::default()
105        }
106    }
107
108    pub fn with_priority_hint(&self, kind: RequiredSegmentKind) -> Self {
109        Self {
110            store: self.store.clone(),
111            // highest priority wins
112            kind: kind.min(self.kind),
113            metrics: self.metrics.clone(),
114        }
115    }
116
117    pub fn push(&mut self, row_start: u64, row_end: u64, segment: SegmentId) {
118        let (start, end) = match self.kind {
119            // row offset inside the stats table is not our concern
120            RequiredSegmentKind::Pruning => (0, 0),
121            _ => (row_start, row_end),
122        };
123        self.increment_metrics();
124        let priority = SegmentPriority::new(start, end, self.kind);
125        self.store
126            .write()
127            .vortex_expect("poisoned lock")
128            .entry(priority)
129            .or_default()
130            .push(segment);
131    }
132
133    pub fn finish(self) -> (RowRangePruner, SegmentStream) {
134        let (cancellations_tx, cancellations_rx) = mpsc::unbounded();
135        (
136            RowRangePruner {
137                store: self.store.clone(),
138                cancellations_tx,
139                excluded_ranges: Default::default(),
140                metrics: self.metrics.clone(),
141            },
142            SegmentStream {
143                store: self.store,
144                cancellations_rx,
145                current_key: TOP_PRIORITY,
146                current_idx: 0,
147            },
148        )
149    }
150
151    fn increment_metrics(&self) {
152        self.metrics
153            .counter("vortex.scan.segments.count.total")
154            .inc();
155        self.metrics
156            .counter(format!("vortex.scan.segments.count.{:?}", self.kind))
157            .inc();
158    }
159}
160
/// Removes segments from the shared store once their row ranges are known to be
/// excluded from the scan, notifying removals over a cancellation channel.
#[derive(Debug, Clone)]
pub struct RowRangePruner {
    store: Arc<RwLock<SegmentStore>>,
    // Carries the ids of removed segments so in-flight requests can be cancelled.
    cancellations_tx: mpsc::UnboundedSender<SegmentId>,
    // Union of all row ranges excluded so far; merges adjacent/overlapping
    // exclusions so segments are never cancelled twice.
    excluded_ranges: Arc<RwLock<RangeUnionFind<u64>>>,
    metrics: VortexMetrics,
}
168
169impl RowRangePruner {
170    // Remove all segments fully encompassed by the given row range. Removals
171    // of each matching segment is notified to the cancellation channel.
172    pub async fn remove(&mut self, to_exclude: Range<u64>) -> VortexResult<()> {
173        let to_exclude = {
174            let mut excluded_ranges = self.excluded_ranges.write().vortex_expect("poisoned lock");
175            excluded_ranges
176                .insert_range(&to_exclude)
177                .map_err(|e| vortex_err!("invalid range: {e}"))?;
178            excluded_ranges
179                .find_range_with_element(&to_exclude.start)
180                .map_err(|_| vortex_err!("can not find range just inserted"))?
181        };
182        let first_row = match to_exclude.start_bound() {
183            Bound::Included(idx) => *idx,
184            Bound::Excluded(idx) => *idx + 1,
185            Bound::Unbounded => 0,
186        };
187
188        let last_row = match to_exclude.end_bound() {
189            Bound::Included(idx) => *idx + 1,
190            Bound::Excluded(idx) => *idx,
191            Bound::Unbounded => u64::MAX,
192        };
193
194        let cancelled_segments: Vec<_> = {
195            let mut store = self.store.write()?;
196            let to_remove: Vec<_> = store
197                .keys()
198                .filter(|key| key.kind != RequiredSegmentKind::Pruning)
199                .skip_while(|key| key.row_end < first_row)
200                .take_while(|key| key.row_end <= last_row)
201                .filter(|key| first_row <= key.row_start)
202                .copied()
203                .collect();
204            to_remove
205                .iter()
206                .flat_map(|key| store.remove(key).unwrap_or_default())
207                .collect()
208        };
209        self.metrics
210            .counter("vortex.scan.segments.cancel_sent")
211            .add(cancelled_segments.len() as i64);
212        for id in cancelled_segments {
213            self.cancellations_tx
214                .send(id)
215                .await
216                .map_err(|_| vortex_err!("channel closed"))?;
217        }
218        Ok(())
219    }
220
221    /// Bulk remove row_indices. It is intended to be used for
222    /// pruning row indices known to be excluded before the scan.
223    /// It does not notify the cancellation channel.
224    pub fn retain_matching(&mut self, row_indices: Buffer<u64>) {
225        if row_indices.is_empty() {
226            return;
227        }
228        self.store
229            .write()
230            .vortex_expect("poisoned lock")
231            .retain(|key, _| {
232                if key.kind == RequiredSegmentKind::Pruning {
233                    return true; // keep segments required for pruning
234                }
235                let keep =
236                    range_intersection(&(key.row_start..key.row_end), &row_indices).is_some();
237                if !keep {
238                    self.metrics
239                        .counter("vortex.scan.segment.pruned_by_row_indices")
240                        .inc();
241                }
242                keep
243            });
244    }
245}
246
/// A stream of [`SegmentEvent`]s: cancellations are yielded first, then segment
/// requests in ascending priority-key order from the shared store.
pub struct SegmentStream {
    store: Arc<RwLock<SegmentStore>>,
    cancellations_rx: mpsc::UnboundedReceiver<SegmentId>,
    // Cursor into the store: the priority key currently being served...
    current_key: SegmentPriority,
    // ...and the index of the next segment to yield within that key's vector.
    current_idx: usize,
}
253
/// Events emitted by [`SegmentStream`].
pub enum SegmentEvent {
    /// The segment is no longer needed; any in-flight request may be aborted.
    Cancel(SegmentId),
    /// The segment should be fetched.
    Request(SegmentId),
}
258
impl Stream for SegmentStream {
    type Item = SegmentEvent;

    /// Yield pending cancellations first, then segment requests in priority
    /// order, resuming from the `(current_key, current_idx)` cursor.
    ///
    /// NOTE(review): the store itself has no waker; once it is drained, this
    /// stream is only woken again via the cancellation channel. Presumably all
    /// segments are pushed before polling begins — confirm with callers.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // cancellations take priority over the next segment in store
        let channel_closed = match self.cancellations_rx.poll_next_unpin(cx) {
            Poll::Ready(Some(segment)) => return Poll::Ready(Some(SegmentEvent::Cancel(segment))),
            Poll::Ready(None) => true,
            Poll::Pending => false,
        };

        {
            // Clone the Arc so the read guard does not borrow `self`, which is
            // mutated (cursor updates) inside the loop below.
            let store_clone = self.store.clone();
            let store_guard = store_clone.read().vortex_expect("poisoned lock");
            // Entries strictly before `current_key` have already been served.
            let store_iter = store_guard.range(self.current_key..);
            for (&key, segments) in store_iter {
                match key == self.current_key {
                    // Current entry fully served — advance to the next key.
                    true if self.current_idx >= segments.len() => continue,
                    // New key — restart its segment index.
                    false => {
                        self.current_idx = 0;
                        self.current_key = key;
                    }
                    _ => {}
                };
                // NOTE(review): assumes segment vectors are never empty; the
                // collector always pushes at least one id per key — confirm no
                // other writer inserts empty vectors.
                let segment_to_yield = segments[self.current_idx];
                self.current_idx += 1;
                return Poll::Ready(Some(SegmentEvent::Request(segment_to_yield)));
            }
        }
        // store is exhausted if we are here
        if channel_closed {
            return Poll::Ready(None);
        }
        // Re-check the channel in case a cancellation arrived while the store
        // was being scanned; this also (re)registers our waker when Pending.
        match self.cancellations_rx.poll_next_unpin(cx) {
            Poll::Ready(Some(segment)) => Poll::Ready(Some(SegmentEvent::Cancel(segment))),
            Poll::Ready(None) => Poll::Ready(None), // channel closed, end stream
            Poll::Pending => Poll::Pending,
        }
    }
}
302
/// Test doubles and unit tests for segment collection, pruning and cancellation.
#[cfg(test)]
pub mod test {
    use futures::executor::block_on;
    use vortex_array::aliases::hash_map::HashMap;
    use vortex_array::aliases::hash_set::HashSet;
    use vortex_buffer::ByteBufferMut;
    use vortex_error::{VortexExpect, vortex_err};

    use super::*;

    /// In-memory segment storage used as a test double for both the writer and
    /// the async reader traits. Segment ids are indices into `segments`.
    #[derive(Default)]
    pub struct TestSegments {
        segments: Vec<ByteBuffer>,
    }

    impl SegmentWriter for TestSegments {
        fn put(&mut self, data: &[ByteBuffer]) -> SegmentId {
            let id = u32::try_from(self.segments.len())
                .vortex_expect("Cannot store more than u32::MAX segments");

            // Combine all the buffers since we're only a test implementation
            let mut buffer = ByteBufferMut::empty();
            for segment in data {
                buffer.extend_from_slice(segment.as_ref());
            }
            self.segments.push(buffer.freeze());

            id.into()
        }
    }

    #[async_trait]
    impl AsyncSegmentReader for TestSegments {
        async fn get(&self, id: SegmentId) -> VortexResult<ByteBuffer> {
            self.segments
                .get(*id as usize)
                .cloned()
                .ok_or_else(|| vortex_err!("Segment not found"))
        }
    }

    /// Build a store with five segments of mixed kinds and overlapping row
    /// ranges, shared by all tests below.
    fn setup_store() -> Arc<RwLock<SegmentStore>> {
        let mut store = BTreeMap::new();

        // Add segments that span different ranges
        store.insert(
            SegmentPriority::new(0, 100, RequiredSegmentKind::Projection),
            vec![SegmentId(1)],
        );
        store.insert(
            SegmentPriority::new(50, 150, RequiredSegmentKind::Projection),
            vec![SegmentId(2)],
        );
        store.insert(
            SegmentPriority::new(150, 250, RequiredSegmentKind::Filter),
            vec![SegmentId(3)],
        );
        store.insert(
            SegmentPriority::new(200, 300, RequiredSegmentKind::Projection),
            vec![SegmentId(4)],
        );
        store.insert(
            SegmentPriority::new(0, 0, RequiredSegmentKind::Pruning),
            vec![SegmentId(5)],
        );

        Arc::new(RwLock::new(store))
    }

    // Only segments fully inside the excluded range are removed and cancelled.
    #[test]
    fn test_remove_fully_encompassed_segments() {
        block_on(async {
            // Setup
            let store = setup_store();
            let (tx, mut rx) = mpsc::unbounded();
            let mut pruner = RowRangePruner {
                store: store.clone(),
                cancellations_tx: tx,
                excluded_ranges: Default::default(),
                metrics: Default::default(),
            };

            // Test removing segments in range 0..200
            let result = pruner.remove(0..200).await;
            assert!(result.is_ok(), "Removal operation should succeed");

            // Check that the correct segments were removed from the store
            let store_lock = store.read().vortex_expect("poisoned lock");
            assert!(!store_lock.contains_key(&SegmentPriority::new(
                0,
                100,
                RequiredSegmentKind::Projection
            )));
            assert!(!store_lock.contains_key(&SegmentPriority::new(
                50,
                150,
                RequiredSegmentKind::Projection
            )));
            assert!(store_lock.contains_key(&SegmentPriority::new(
                150,
                250,
                RequiredSegmentKind::Filter
            ))); // Not fully encompassed
            assert!(store_lock.contains_key(&SegmentPriority::new(
                200,
                300,
                RequiredSegmentKind::Projection
            )));
            assert!(store_lock.contains_key(&SegmentPriority::new(
                0,
                0,
                RequiredSegmentKind::Pruning
            )));

            // Check that the correct cancellation messages were sent
            let mut received_cancellations = HashSet::new();
            while let Ok(Some(id)) = rx.try_next() {
                received_cancellations.insert(id);
            }

            assert!(received_cancellations.contains(&SegmentId(1)));
            assert!(received_cancellations.contains(&SegmentId(2)));
            assert!(!received_cancellations.contains(&SegmentId(3))); // Not fully encompassed
            assert!(!received_cancellations.contains(&SegmentId(4)));
            assert!(!received_cancellations.contains(&SegmentId(5)));
        })
    }

    // Overlapping removals must not cancel the same segment twice (the
    // excluded-range union-find deduplicates).
    #[test]
    fn test_no_double_cancellation() {
        block_on(async {
            // Setup
            let store = setup_store();
            let (tx, mut rx) = mpsc::unbounded();
            let mut pruner = RowRangePruner {
                store: store.clone(),
                cancellations_tx: tx,
                excluded_ranges: Default::default(),
                metrics: Default::default(),
            };

            // First removal (0..100)
            let result = pruner.remove(0..100).await;
            assert!(result.is_ok(), "First removal operation should succeed");

            // Second removal with overlapping range (50..150)
            let result = pruner.remove(50..150).await;
            assert!(result.is_ok(), "Second removal operation should succeed");

            // Third removal with broader range (0..200)
            let result = pruner.remove(0..200).await;
            assert!(result.is_ok(), "Third removal operation should succeed");

            // Check all cancellation messages
            let mut received_cancellations = Vec::new();
            while let Ok(Some(id)) = rx.try_next() {
                received_cancellations.push(id);
            }

            // Count occurrences of each segment ID
            let mut id_counts = HashMap::new();
            for id in received_cancellations {
                *id_counts.entry(id).or_insert(0) += 1;
            }

            // Verify no segment was cancelled more than once
            for (id, count) in id_counts {
                assert_eq!(count, 1, "Segment {:?} was cancelled {} times", id, count);
            }
        })
    }

    // Adjacent/overlapping removals should merge into one excluded range that
    // encompasses segments no single removal would have covered.
    #[test]
    fn test_range_merging() {
        block_on(async {
            // Setup
            let store = setup_store();
            let (tx, _rx) = mpsc::unbounded();
            let mut pruner = RowRangePruner {
                store: store.clone(),
                cancellations_tx: tx,
                excluded_ranges: Default::default(),
                metrics: Default::default(),
            };

            // First removal (0..75)
            let result = pruner.remove(0..75).await;
            assert!(result.is_ok());

            // Second removal with adjacent range (75..150)
            let result = pruner.remove(75..150).await;
            assert!(result.is_ok());

            // Third removal with overlapping range (125..200)
            let result = pruner.remove(125..200).await;
            assert!(result.is_ok());

            // Check the store to confirm proper range merging behavior
            let store_lock = store.read().vortex_expect("poisoned lock");
            assert!(!store_lock.contains_key(&SegmentPriority::new(
                0,
                100,
                RequiredSegmentKind::Projection
            )));
            assert!(!store_lock.contains_key(&SegmentPriority::new(
                50,
                150,
                RequiredSegmentKind::Projection
            )));
            assert!(store_lock.contains_key(&SegmentPriority::new(
                150,
                250,
                RequiredSegmentKind::Filter
            ))); // Not fully encompassed
        })
    }

    // retain_matching keeps intersecting segments and always keeps Pruning ones.
    #[test]
    fn test_retain_matching_with_pruning_segments() {
        block_on(async {
            // Setup
            let store = setup_store();
            let (tx, _rx) = mpsc::unbounded();
            let mut pruner = RowRangePruner {
                store: store.clone(),
                cancellations_tx: tx,
                excluded_ranges: Default::default(),
                metrics: Default::default(),
            };

            // Create a buffer with specific row indices
            let row_indices = Buffer::from_iter(vec![75, 125, 175, 225, 325, 375]);

            // Call retain_matching
            pruner.retain_matching(row_indices);

            // Check that the correct segments were retained
            let store_lock = store.read().vortex_expect("poisoned lock");

            // Segments that intersect with the row indices should be kept
            assert!(store_lock.contains_key(&SegmentPriority::new(
                0,
                100,
                RequiredSegmentKind::Projection
            ))); // Contains 75
            assert!(store_lock.contains_key(&SegmentPriority::new(
                50,
                150,
                RequiredSegmentKind::Projection
            ))); // Contains 75, 125
            assert!(store_lock.contains_key(&SegmentPriority::new(
                150,
                250,
                RequiredSegmentKind::Filter
            ))); // Contains 175, 225
            assert!(store_lock.contains_key(&SegmentPriority::new(
                200,
                300,
                RequiredSegmentKind::Projection
            ))); // Contains 225

            // PRUNING segments should always be kept
            assert!(store_lock.contains_key(&SegmentPriority::new(
                0,
                0,
                RequiredSegmentKind::Pruning
            )));
        })
    }

    // remove() must surface an error once the cancellation receiver is gone.
    #[test]
    fn test_cancellation_channel_closed() {
        block_on(async {
            let store = setup_store();
            let (tx, rx) = mpsc::unbounded();
            let mut pruner = RowRangePruner {
                store: store.clone(),
                cancellations_tx: tx,
                excluded_ranges: Default::default(),
                metrics: Default::default(),
            };

            // Drop the receiver to close the channel
            drop(rx);

            // Attempt to remove segments
            let result = pruner.remove(0..100).await;

            // Should fail with channel closed error
            assert!(
                result.is_err(),
                "Removal should fail when channel is closed"
            );
        })
    }

    // A removal spanning everything cancels Filter and Projection segments but
    // leaves the Pruning segment in place.
    #[test]
    fn test_segments_of_different_kinds() {
        block_on(async {
            // Setup
            let store = setup_store();
            let (tx, mut rx) = mpsc::unbounded();
            let mut pruner = RowRangePruner {
                store: store.clone(),
                cancellations_tx: tx,
                excluded_ranges: Default::default(),
                metrics: Default::default(),
            };

            // Test removing segments that cover the entire range
            let result = pruner.remove(0..400).await;
            assert!(result.is_ok(), "Removal operation should succeed");

            // Check that segments of all kinds that are fully encompassed were removed
            let store_lock = store.read().vortex_expect("poisoned lock");
            assert_eq!(store_lock.len(), 1);

            // Verify the cancellations
            let mut received_cancellations = HashSet::new();
            while let Ok(Some(id)) = rx.try_next() {
                received_cancellations.insert(id);
            }

            assert!(received_cancellations.contains(&SegmentId(1))); // PROJECTION
            assert!(received_cancellations.contains(&SegmentId(2))); // PROJECTION
            assert!(received_cancellations.contains(&SegmentId(3))); // FILTER
            assert!(received_cancellations.contains(&SegmentId(4))); // PROJECTION
        })
    }
}