1use std::collections::BTreeMap;
2use std::fmt::Display;
3use std::ops::{Bound, Deref, Range, RangeBounds};
4use std::sync::{Arc, RwLock};
5use std::task::Poll;
6
7use async_trait::async_trait;
8use futures::channel::mpsc;
9use futures::{SinkExt, Stream, StreamExt};
10use range_union_find::RangeUnionFind;
11use vortex_buffer::{Buffer, ByteBuffer};
12use vortex_error::{VortexExpect, VortexResult, vortex_err};
13use vortex_metrics::VortexMetrics;
14
15use crate::range_intersection;
16
/// Identifier of a stored segment: a newtype over the raw `u32` index.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SegmentId(u32);
21
22impl From<u32> for SegmentId {
23 fn from(value: u32) -> Self {
24 Self(value)
25 }
26}
27
/// Dereferences to the underlying `u32`, so `*id` yields the raw value.
impl Deref for SegmentId {
    type Target = u32;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
35
36impl Display for SegmentId {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 write!(f, "SegmentId({})", self.0)
39 }
40}
41
/// Asynchronous source of segment bytes, addressed by `SegmentId`.
#[async_trait]
pub trait AsyncSegmentReader: 'static + Send + Sync {
    /// Fetches the buffer for `id`, or an error if the segment is unknown.
    async fn get(&self, id: SegmentId) -> VortexResult<ByteBuffer>;
}
47
/// Sink for segment data.
pub trait SegmentWriter {
    /// Writes the data, provided as one or more byte buffers, and returns
    /// the id under which the segment can later be read back.
    fn put(&mut self, buffer: &[ByteBuffer]) -> SegmentId;
}
59
/// Why a segment is required by a scan.
///
/// The explicit discriminants define urgency: lower values compare smaller
/// under the derived `Ord` and therefore sort first inside `SegmentPriority`
/// (pruning before filtering before projection).
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord)]
pub enum RequiredSegmentKind {
    Pruning = 1,
    Filter = 2,
    #[default]
    Projection = 3,
}
67
68#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord)]
69pub struct SegmentPriority {
70 row_end: u64, kind: RequiredSegmentKind,
72 row_start: u64,
73}
74
75impl SegmentPriority {
76 fn new(row_start: u64, row_end: u64, kind: RequiredSegmentKind) -> Self {
77 SegmentPriority {
78 row_end,
79 kind,
80 row_start,
81 }
82 }
83}
84
/// The minimum possible `SegmentPriority` under the derived ordering
/// (`row_end == 0`, smallest `kind`, `row_start == 0`); used as the initial
/// cursor so `store.range(TOP_PRIORITY..)` covers the whole store.
const TOP_PRIORITY: SegmentPriority = SegmentPriority {
    row_end: 0,
    kind: RequiredSegmentKind::Pruning,
    row_start: 0,
};
90
/// Segments grouped by priority; `BTreeMap` iteration order doubles as the
/// order in which segment requests are issued.
type SegmentStore = BTreeMap<SegmentPriority, Vec<SegmentId>>;
92
/// Accumulates required segments, keyed by priority, into a store that is
/// shared with the `RowRangePruner`/`SegmentStream` pair produced by
/// `finish`.
#[derive(Default)]
pub struct SegmentCollector {
    // Shared with clones created by `with_priority_hint` and with the
    // pruner/stream returned from `finish`.
    store: Arc<RwLock<SegmentStore>>,
    // Priority kind assigned to segments pushed through this collector.
    pub kind: RequiredSegmentKind,
    metrics: VortexMetrics,
}
99
100impl SegmentCollector {
101 pub fn new(metrics: VortexMetrics) -> Self {
102 Self {
103 metrics,
104 ..Default::default()
105 }
106 }
107
108 pub fn with_priority_hint(&self, kind: RequiredSegmentKind) -> Self {
109 Self {
110 store: self.store.clone(),
111 kind: kind.min(self.kind),
113 metrics: self.metrics.clone(),
114 }
115 }
116
117 pub fn push(&mut self, row_start: u64, row_end: u64, segment: SegmentId) {
118 let (start, end) = match self.kind {
119 RequiredSegmentKind::Pruning => (0, 0),
121 _ => (row_start, row_end),
122 };
123 self.increment_metrics();
124 let priority = SegmentPriority::new(start, end, self.kind);
125 self.store
126 .write()
127 .vortex_expect("poisoned lock")
128 .entry(priority)
129 .or_default()
130 .push(segment);
131 }
132
133 pub fn finish(self) -> (RowRangePruner, SegmentStream) {
134 let (cancellations_tx, cancellations_rx) = mpsc::unbounded();
135 (
136 RowRangePruner {
137 store: self.store.clone(),
138 cancellations_tx,
139 excluded_ranges: Default::default(),
140 metrics: self.metrics.clone(),
141 },
142 SegmentStream {
143 store: self.store,
144 cancellations_rx,
145 current_key: TOP_PRIORITY,
146 current_idx: 0,
147 },
148 )
149 }
150
151 fn increment_metrics(&self) {
152 self.metrics
153 .counter("vortex.scan.segments.count.total")
154 .inc();
155 self.metrics
156 .counter(format!("vortex.scan.segments.count.{:?}", self.kind))
157 .inc();
158 }
159}
160
/// Removes row ranges from the shared segment store and notifies the
/// consuming `SegmentStream` of cancelled segment ids over a channel.
#[derive(Debug, Clone)]
pub struct RowRangePruner {
    store: Arc<RwLock<SegmentStore>>,
    // Cancelled segment ids are sent here and surfaced by `SegmentStream`.
    cancellations_tx: mpsc::UnboundedSender<SegmentId>,
    // Union of all row ranges excluded so far; overlapping/adjacent removals
    // merge here, which prevents double-cancelling the same segment.
    excluded_ranges: Arc<RwLock<RangeUnionFind<u64>>>,
    metrics: VortexMetrics,
}
168
169impl RowRangePruner {
170 pub async fn remove(&mut self, to_exclude: Range<u64>) -> VortexResult<()> {
173 let to_exclude = {
174 let mut excluded_ranges = self.excluded_ranges.write().vortex_expect("poisoned lock");
175 excluded_ranges
176 .insert_range(&to_exclude)
177 .map_err(|e| vortex_err!("invalid range: {e}"))?;
178 excluded_ranges
179 .find_range_with_element(&to_exclude.start)
180 .map_err(|_| vortex_err!("can not find range just inserted"))?
181 };
182 let first_row = match to_exclude.start_bound() {
183 Bound::Included(idx) => *idx,
184 Bound::Excluded(idx) => *idx + 1,
185 Bound::Unbounded => 0,
186 };
187
188 let last_row = match to_exclude.end_bound() {
189 Bound::Included(idx) => *idx + 1,
190 Bound::Excluded(idx) => *idx,
191 Bound::Unbounded => u64::MAX,
192 };
193
194 let cancelled_segments: Vec<_> = {
195 let mut store = self.store.write()?;
196 let to_remove: Vec<_> = store
197 .keys()
198 .filter(|key| key.kind != RequiredSegmentKind::Pruning)
199 .skip_while(|key| key.row_end < first_row)
200 .take_while(|key| key.row_end <= last_row)
201 .filter(|key| first_row <= key.row_start)
202 .copied()
203 .collect();
204 to_remove
205 .iter()
206 .flat_map(|key| store.remove(key).unwrap_or_default())
207 .collect()
208 };
209 self.metrics
210 .counter("vortex.scan.segments.cancel_sent")
211 .add(cancelled_segments.len() as i64);
212 for id in cancelled_segments {
213 self.cancellations_tx
214 .send(id)
215 .await
216 .map_err(|_| vortex_err!("channel closed"))?;
217 }
218 Ok(())
219 }
220
221 pub fn retain_matching(&mut self, row_indices: Buffer<u64>) {
225 if row_indices.is_empty() {
226 return;
227 }
228 self.store
229 .write()
230 .vortex_expect("poisoned lock")
231 .retain(|key, _| {
232 if key.kind == RequiredSegmentKind::Pruning {
233 return true; }
235 let keep =
236 range_intersection(&(key.row_start..key.row_end), &row_indices).is_some();
237 if !keep {
238 self.metrics
239 .counter("vortex.scan.segment.pruned_by_row_indices")
240 .inc();
241 }
242 keep
243 });
244 }
245}
246
/// Stream of `SegmentEvent`s: cancellations arriving from the pruner are
/// yielded with priority over segment requests, which are walked in
/// `SegmentStore` key order.
pub struct SegmentStream {
    store: Arc<RwLock<SegmentStore>>,
    cancellations_rx: mpsc::UnboundedReceiver<SegmentId>,
    // Cursor: the store key currently being drained...
    current_key: SegmentPriority,
    // ...and the index of the next segment to yield within that key's vec.
    current_idx: usize,
}
253
/// Event yielded by `SegmentStream`.
pub enum SegmentEvent {
    /// The segment is no longer required.
    Cancel(SegmentId),
    /// The segment should be fetched.
    Request(SegmentId),
}
258
impl Stream for SegmentStream {
    type Item = SegmentEvent;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Cancellations take priority: surface one immediately if ready.
        // Ready(None) means the pruner side has been dropped; remember that
        // so the stream can terminate once the store is also exhausted.
        let channel_closed = match self.cancellations_rx.poll_next_unpin(cx) {
            Poll::Ready(Some(segment)) => return Poll::Ready(Some(SegmentEvent::Cancel(segment))),
            Poll::Ready(None) => true,
            Poll::Pending => false,
        };

        {
            // Clone the Arc so the read guard does not borrow `self`, which
            // is mutated below while walking the store.
            let store_clone = self.store.clone();
            let store_guard = store_clone.read().vortex_expect("poisoned lock");
            // Resume from the last position; BTreeMap::range yields keys in
            // ascending priority order starting at `current_key`.
            let store_iter = store_guard.range(self.current_key..);
            for (&key, segments) in store_iter {
                match key == self.current_key {
                    // Current key fully drained: advance to the next key.
                    true if self.current_idx >= segments.len() => continue,
                    // Reached a new key: reset the per-key cursor.
                    false => {
                        self.current_idx = 0;
                        self.current_key = key;
                    }
                    // Same key with segments still pending: fall through.
                    _ => {}
                };
                let segment_to_yield = segments[self.current_idx];
                self.current_idx += 1;
                return Poll::Ready(Some(SegmentEvent::Request(segment_to_yield)));
            }
        }
        // No requests left to yield; if the cancellation side is closed too,
        // the stream is finished.
        if channel_closed {
            return Poll::Ready(None);
        }
        // Re-poll the channel: a cancellation may have arrived while the
        // store was being scanned, and polling again keeps the waker
        // registered for the Pending case.
        match self.cancellations_rx.poll_next_unpin(cx) {
            Poll::Ready(Some(segment)) => Poll::Ready(Some(SegmentEvent::Cancel(segment))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}
302
303#[cfg(test)]
304pub mod test {
305 use futures::executor::block_on;
306 use vortex_array::aliases::hash_map::HashMap;
307 use vortex_array::aliases::hash_set::HashSet;
308 use vortex_buffer::ByteBufferMut;
309 use vortex_error::{VortexExpect, vortex_err};
310
311 use super::*;
312
    /// In-memory `SegmentWriter`/`AsyncSegmentReader` used by unit tests;
    /// segment ids are indices into `segments`.
    #[derive(Default)]
    pub struct TestSegments {
        segments: Vec<ByteBuffer>,
    }
317
318 impl SegmentWriter for TestSegments {
319 fn put(&mut self, data: &[ByteBuffer]) -> SegmentId {
320 let id = u32::try_from(self.segments.len())
321 .vortex_expect("Cannot store more than u32::MAX segments");
322
323 let mut buffer = ByteBufferMut::empty();
325 for segment in data {
326 buffer.extend_from_slice(segment.as_ref());
327 }
328 self.segments.push(buffer.freeze());
329
330 id.into()
331 }
332 }
333
334 #[async_trait]
335 impl AsyncSegmentReader for TestSegments {
336 async fn get(&self, id: SegmentId) -> VortexResult<ByteBuffer> {
337 self.segments
338 .get(*id as usize)
339 .cloned()
340 .ok_or_else(|| vortex_err!("Segment not found"))
341 }
342 }
343
344 fn setup_store() -> Arc<RwLock<SegmentStore>> {
345 let mut store = BTreeMap::new();
346
347 store.insert(
349 SegmentPriority::new(0, 100, RequiredSegmentKind::Projection),
350 vec![SegmentId(1)],
351 );
352 store.insert(
353 SegmentPriority::new(50, 150, RequiredSegmentKind::Projection),
354 vec![SegmentId(2)],
355 );
356 store.insert(
357 SegmentPriority::new(150, 250, RequiredSegmentKind::Filter),
358 vec![SegmentId(3)],
359 );
360 store.insert(
361 SegmentPriority::new(200, 300, RequiredSegmentKind::Projection),
362 vec![SegmentId(4)],
363 );
364 store.insert(
365 SegmentPriority::new(0, 0, RequiredSegmentKind::Pruning),
366 vec![SegmentId(5)],
367 );
368
369 Arc::new(RwLock::new(store))
370 }
371
372 #[test]
373 fn test_remove_fully_encompassed_segments() {
374 block_on(async {
375 let store = setup_store();
377 let (tx, mut rx) = mpsc::unbounded();
378 let mut pruner = RowRangePruner {
379 store: store.clone(),
380 cancellations_tx: tx,
381 excluded_ranges: Default::default(),
382 metrics: Default::default(),
383 };
384
385 let result = pruner.remove(0..200).await;
387 assert!(result.is_ok(), "Removal operation should succeed");
388
389 let store_lock = store.read().vortex_expect("poisoned lock");
391 assert!(!store_lock.contains_key(&SegmentPriority::new(
392 0,
393 100,
394 RequiredSegmentKind::Projection
395 )));
396 assert!(!store_lock.contains_key(&SegmentPriority::new(
397 50,
398 150,
399 RequiredSegmentKind::Projection
400 )));
401 assert!(store_lock.contains_key(&SegmentPriority::new(
402 150,
403 250,
404 RequiredSegmentKind::Filter
405 ))); assert!(store_lock.contains_key(&SegmentPriority::new(
407 200,
408 300,
409 RequiredSegmentKind::Projection
410 )));
411 assert!(store_lock.contains_key(&SegmentPriority::new(
412 0,
413 0,
414 RequiredSegmentKind::Pruning
415 )));
416
417 let mut received_cancellations = HashSet::new();
419 while let Ok(Some(id)) = rx.try_next() {
420 received_cancellations.insert(id);
421 }
422
423 assert!(received_cancellations.contains(&SegmentId(1)));
424 assert!(received_cancellations.contains(&SegmentId(2)));
425 assert!(!received_cancellations.contains(&SegmentId(3))); assert!(!received_cancellations.contains(&SegmentId(4)));
427 assert!(!received_cancellations.contains(&SegmentId(5)));
428 })
429 }
430
431 #[test]
432 fn test_no_double_cancellation() {
433 block_on(async {
434 let store = setup_store();
436 let (tx, mut rx) = mpsc::unbounded();
437 let mut pruner = RowRangePruner {
438 store: store.clone(),
439 cancellations_tx: tx,
440 excluded_ranges: Default::default(),
441 metrics: Default::default(),
442 };
443
444 let result = pruner.remove(0..100).await;
446 assert!(result.is_ok(), "First removal operation should succeed");
447
448 let result = pruner.remove(50..150).await;
450 assert!(result.is_ok(), "Second removal operation should succeed");
451
452 let result = pruner.remove(0..200).await;
454 assert!(result.is_ok(), "Third removal operation should succeed");
455
456 let mut received_cancellations = Vec::new();
458 while let Ok(Some(id)) = rx.try_next() {
459 received_cancellations.push(id);
460 }
461
462 let mut id_counts = HashMap::new();
464 for id in received_cancellations {
465 *id_counts.entry(id).or_insert(0) += 1;
466 }
467
468 for (id, count) in id_counts {
470 assert_eq!(count, 1, "Segment {:?} was cancelled {} times", id, count);
471 }
472 })
473 }
474
475 #[test]
476 fn test_range_merging() {
477 block_on(async {
478 let store = setup_store();
480 let (tx, _rx) = mpsc::unbounded();
481 let mut pruner = RowRangePruner {
482 store: store.clone(),
483 cancellations_tx: tx,
484 excluded_ranges: Default::default(),
485 metrics: Default::default(),
486 };
487
488 let result = pruner.remove(0..75).await;
490 assert!(result.is_ok());
491
492 let result = pruner.remove(75..150).await;
494 assert!(result.is_ok());
495
496 let result = pruner.remove(125..200).await;
498 assert!(result.is_ok());
499
500 let store_lock = store.read().vortex_expect("poisoned lock");
502 assert!(!store_lock.contains_key(&SegmentPriority::new(
503 0,
504 100,
505 RequiredSegmentKind::Projection
506 )));
507 assert!(!store_lock.contains_key(&SegmentPriority::new(
508 50,
509 150,
510 RequiredSegmentKind::Projection
511 )));
512 assert!(store_lock.contains_key(&SegmentPriority::new(
513 150,
514 250,
515 RequiredSegmentKind::Filter
516 ))); })
518 }
519
520 #[test]
521 fn test_retain_matching_with_pruning_segments() {
522 block_on(async {
523 let store = setup_store();
525 let (tx, _rx) = mpsc::unbounded();
526 let mut pruner = RowRangePruner {
527 store: store.clone(),
528 cancellations_tx: tx,
529 excluded_ranges: Default::default(),
530 metrics: Default::default(),
531 };
532
533 let row_indices = Buffer::from_iter(vec![75, 125, 175, 225, 325, 375]);
535
536 pruner.retain_matching(row_indices);
538
539 let store_lock = store.read().vortex_expect("poisoned lock");
541
542 assert!(store_lock.contains_key(&SegmentPriority::new(
544 0,
545 100,
546 RequiredSegmentKind::Projection
547 ))); assert!(store_lock.contains_key(&SegmentPriority::new(
549 50,
550 150,
551 RequiredSegmentKind::Projection
552 ))); assert!(store_lock.contains_key(&SegmentPriority::new(
554 150,
555 250,
556 RequiredSegmentKind::Filter
557 ))); assert!(store_lock.contains_key(&SegmentPriority::new(
559 200,
560 300,
561 RequiredSegmentKind::Projection
562 ))); assert!(store_lock.contains_key(&SegmentPriority::new(
566 0,
567 0,
568 RequiredSegmentKind::Pruning
569 )));
570 })
571 }
572
573 #[test]
574 fn test_cancellation_channel_closed() {
575 block_on(async {
576 let store = setup_store();
577 let (tx, rx) = mpsc::unbounded();
578 let mut pruner = RowRangePruner {
579 store: store.clone(),
580 cancellations_tx: tx,
581 excluded_ranges: Default::default(),
582 metrics: Default::default(),
583 };
584
585 drop(rx);
587
588 let result = pruner.remove(0..100).await;
590
591 assert!(
593 result.is_err(),
594 "Removal should fail when channel is closed"
595 );
596 })
597 }
598
599 #[test]
600 fn test_segments_of_different_kinds() {
601 block_on(async {
602 let store = setup_store();
604 let (tx, mut rx) = mpsc::unbounded();
605 let mut pruner = RowRangePruner {
606 store: store.clone(),
607 cancellations_tx: tx,
608 excluded_ranges: Default::default(),
609 metrics: Default::default(),
610 };
611
612 let result = pruner.remove(0..400).await;
614 assert!(result.is_ok(), "Removal operation should succeed");
615
616 let store_lock = store.read().vortex_expect("poisoned lock");
618 assert_eq!(store_lock.len(), 1);
619
620 let mut received_cancellations = HashSet::new();
622 while let Ok(Some(id)) = rx.try_next() {
623 received_cancellations.insert(id);
624 }
625
626 assert!(received_cancellations.contains(&SegmentId(1))); assert!(received_cancellations.contains(&SegmentId(2))); assert!(received_cancellations.contains(&SegmentId(3))); assert!(received_cancellations.contains(&SegmentId(4))); })
631 }
632}