vortex_array/
mask_future.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::future::Future;
5use std::ops::Range;
6use std::sync::Arc;
7
8use futures::FutureExt;
9use futures::TryFutureExt;
10use futures::future::BoxFuture;
11use futures::future::Shared;
12use vortex_error::SharedVortexResult;
13use vortex_error::VortexError;
14use vortex_error::VortexResult;
15use vortex_error::vortex_panic;
16use vortex_mask::Mask;
17
18/// A future that resolves to a mask.
19#[derive(Clone)]
20pub struct MaskFuture {
21    inner: Shared<BoxFuture<'static, SharedVortexResult<Mask>>>,
22    len: usize,
23}
24
25impl MaskFuture {
26    /// Create a new MaskFuture from a future that returns a mask.
27    pub fn new<F>(len: usize, fut: F) -> Self
28    where
29        F: Future<Output = VortexResult<Mask>> + Send + 'static,
30    {
31        Self {
32            inner: fut
33                .inspect(move |r| {
34                    if let Ok(mask) = r
35                        && mask.len() != len {
36                            vortex_panic!("MaskFuture created with future that returned mask of incorrect length (expected {}, got {})", len, mask.len());
37                        }
38                })
39                .map_err(Arc::new)
40                .boxed()
41                .shared(),
42            len,
43        }
44    }
45
46    /// Returns the length of the mask.
47    pub fn len(&self) -> usize {
48        self.len
49    }
50
51    /// Returns true if the mask is empty.
52    pub fn is_empty(&self) -> bool {
53        self.len == 0
54    }
55
56    /// Create a MaskFuture from a ready mask.
57    pub fn ready(mask: Mask) -> Self {
58        Self::new(mask.len(), async move { Ok(mask) })
59    }
60
61    /// Create a MaskFuture that resolves to a mask with all values set to true.
62    pub fn new_true(row_count: usize) -> Self {
63        Self::ready(Mask::new_true(row_count))
64    }
65
66    /// Create a MaskFuture that resolves to a slice of the original mask.
67    pub fn slice(&self, range: Range<usize>) -> Self {
68        let inner = self.inner.clone();
69        Self::new(range.len(), async move { Ok(inner.await?.slice(range)) })
70    }
71
72    pub fn inspect(
73        self,
74        f: impl FnOnce(&SharedVortexResult<Mask>) + 'static + Send + Sync,
75    ) -> Self {
76        let len = self.len;
77
78        Self {
79            inner: self.inner.inspect(f).boxed().shared(),
80            len,
81        }
82    }
83}
84
85impl Future for MaskFuture {
86    type Output = VortexResult<Mask>;
87
88    fn poll(
89        mut self: std::pin::Pin<&mut Self>,
90        cx: &mut std::task::Context<'_>,
91    ) -> std::task::Poll<Self::Output> {
92        self.inner.poll_unpin(cx).map_err(VortexError::from)
93    }
94}