vortex_layout/
mask.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use std::ops::Range;
5use std::sync::Arc;
6
7use futures::future::{BoxFuture, Shared};
8use futures::{FutureExt, TryFutureExt};
9use vortex_error::{SharedVortexResult, VortexError, VortexResult, vortex_panic};
10use vortex_mask::Mask;
11
12#[derive(Clone)]
13pub struct MaskFuture {
14    inner: Shared<BoxFuture<'static, SharedVortexResult<Mask>>>,
15    len: usize,
16}
17
18impl MaskFuture {
19    pub fn new<F>(len: usize, fut: F) -> Self
20    where
21        F: Future<Output = VortexResult<Mask>> + Send + 'static,
22    {
23        Self {
24            inner: fut
25                .inspect(move |r| {
26                    if let Ok(mask) = r
27                        && mask.len() != len {
28                            vortex_panic!("MaskFuture created with future that returned mask of incorrect length (expected {}, got {})", len, mask.len());
29                        }
30                })
31                .map_err(Arc::new)
32                .boxed()
33                .shared(),
34            len,
35        }
36    }
37
38    pub fn len(&self) -> usize {
39        self.len
40    }
41
42    pub fn is_empty(&self) -> bool {
43        self.len == 0
44    }
45
46    pub fn ready(mask: Mask) -> Self {
47        Self::new(mask.len(), async move { Ok(mask) })
48    }
49
50    pub fn new_true(row_count: usize) -> Self {
51        Self::ready(Mask::new_true(row_count))
52    }
53
54    pub fn slice(&self, range: Range<usize>) -> Self {
55        let inner = self.inner.clone();
56        Self::new(range.len(), async move { Ok(inner.await?.slice(range)) })
57    }
58}
59
60impl Future for MaskFuture {
61    type Output = VortexResult<Mask>;
62
63    fn poll(
64        mut self: std::pin::Pin<&mut Self>,
65        cx: &mut std::task::Context<'_>,
66    ) -> std::task::Poll<Self::Output> {
67        self.inner.poll_unpin(cx).map_err(VortexError::from)
68    }
69}