use std::cmp::{max, min};
use std::fmt::{Display, Formatter};
use std::ops::RangeBounds;
use vortex_array::array::BooleanBuffer;
use vortex_array::compute::{filter, slice, try_cast};
use vortex_array::{Array, IntoArrayVariant};
use vortex_dtype::Nullability::NonNullable;
use vortex_dtype::{DType, PType};
use vortex_error::{vortex_bail, vortex_err, VortexExpect, VortexResult};
use vortex_mask::Mask;
/// A [`Mask`] anchored to the absolute row range `begin..end`.
///
/// The mask holds one entry per row of the range, so `end - begin` equals `mask.len()` for
/// every value built through [`RowMask::new`].
#[derive(Debug, Clone)]
pub struct RowMask {
/// One entry per row in `begin..end`.
mask: Mask,
/// Absolute offset of the first row covered by `mask` (inclusive).
begin: u64,
/// Absolute offset one past the last covered row; `new` sets it to `begin + mask.len()`.
end: u64,
}
// Structural equality, only compiled for tests.
#[cfg(test)]
impl PartialEq for RowMask {
    fn eq(&self, other: &Self) -> bool {
        // Compare the cheap bounds first so the masks are only compared when they agree.
        let same_bounds = (self.begin, self.end) == (other.begin, other.end);
        same_bounds && self.mask == other.mask
    }
}
// Prints only the covered row range; the mask contents are not rendered.
// NOTE(review): the label reads "RowSelector" although the type is `RowMask` — presumably a
// leftover from an earlier name; confirm nothing depends on this output before renaming it.
impl Display for RowMask {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self { begin, end, .. } = self;
        write!(f, "RowSelector [{}..{}]", begin, end)
    }
}
impl RowMask {
/// Builds a `RowMask` covering `mask.len()` rows starting at the absolute row `begin`.
///
/// The exclusive end of the range is derived as `begin + mask.len()`.
pub fn new(mask: Mask, begin: u64) -> Self {
    let row_count = mask.len() as u64;
    Self {
        end: begin + row_count,
        begin,
        mask,
    }
}
/// Creates a `RowMask` over `begin..end` whose mask has every entry set.
///
/// # Panics
///
/// Panics if `end < begin`, or if `end - begin` does not fit into a `usize`.
pub fn new_valid_between(begin: u64, end: u64) -> Self {
    // Check the bounds up front: an unchecked `end - begin` wraps in release builds and
    // would request an absurdly large buffer instead of reporting the real problem.
    assert!(
        begin <= end,
        "RowMask range end {end} must not be smaller than begin {begin}"
    );
    let length =
        usize::try_from(end - begin).vortex_expect("Range length does not fit into a usize");
    RowMask::new(Mask::from(BooleanBuffer::new_set(length)), begin)
}
/// Creates a `RowMask` over `begin..end` whose mask has every entry unset.
///
/// # Panics
///
/// Panics if `end < begin`, or if `end - begin` does not fit into a `usize`.
pub fn new_invalid_between(begin: u64, end: u64) -> Self {
    // Check the bounds up front: an unchecked `end - begin` wraps in release builds and
    // would request an absurdly large buffer instead of reporting the real problem.
    assert!(
        begin <= end,
        "RowMask range end {end} must not be smaller than begin {begin}"
    );
    let length =
        usize::try_from(end - begin).vortex_expect("Range length does not fit into a usize");
    RowMask::new(Mask::from(BooleanBuffer::new_unset(length)), begin)
}
/// Builds a `RowMask` from either an integer array or a boolean array.
///
/// Integer arrays are dispatched to `from_index_array` together with `begin` and `end`;
/// boolean arrays are dispatched to `from_mask_array`, which only receives `begin`
/// (`end` is not consulted on that path).
///
/// # Errors
///
/// Returns an error for any other dtype, or when the chosen constructor fails.
pub fn from_array(array: &Array, begin: u64, end: u64) -> VortexResult<Self> {
    let dtype = array.dtype();
    if dtype.is_int() {
        return Self::from_index_array(array, begin, end);
    }
    if dtype.is_boolean() {
        return Self::from_mask_array(array, begin);
    }
    vortex_bail!(
        "RowMask can only be created from integer or boolean arrays, got {} instead.",
        dtype
    );
}
/// Builds a `RowMask` starting at `begin` from a boolean array.
///
/// The mask is whatever `array.logical_validity()` returns; its length determines the end
/// of the covered range.
// NOTE(review): `logical_validity()` reads like "which positions are non-null" rather than
// "which positions hold `true`" — confirm the boolean values are really meant to be ignored.
fn from_mask_array(array: &Array, begin: u64) -> VortexResult<Self> {
    let validity_mask = array.logical_validity()?;
    Ok(Self::new(validity_mask, begin))
}
/// Builds a `RowMask` over `begin..end` from an integer array of row indices.
///
/// The array is cast to non-nullable `u64` and every value is used as a position inside a
/// mask of length `end - begin`; the resulting mask is anchored at `begin`.
///
/// # Errors
///
/// Returns an error if `end < begin`, if the range length or an index does not fit into a
/// `usize`, if an index is not smaller than the range length, or if the cast fails.
fn from_index_array(array: &Array, begin: u64, end: u64) -> VortexResult<Self> {
    // An unchecked `end - begin` would panic in debug builds and wrap in release builds.
    let span = end.checked_sub(begin).ok_or_else(|| {
        vortex_err!("Invalid RowMask range: end {end} is smaller than begin {begin}")
    })?;
    let length = usize::try_from(span)
        .map_err(|_| vortex_err!("Range length does not fit into a usize"))?;
    let indices =
        try_cast(array, &DType::Primitive(PType::U64, NonNullable))?.into_primitive()?;
    let mask = Mask::from_indices(
        length,
        indices
            .as_slice::<u64>()
            .iter()
            .map(|&index| {
                // Convert without truncation and reject positions outside the mask, so bad
                // input surfaces as an error instead of a wrong or panicking mask.
                usize::try_from(index)
                    .ok()
                    .filter(|&position| position < length)
                    .ok_or_else(|| {
                        vortex_err!(
                            "Row index {index} is out of bounds for a RowMask of length {length}"
                        )
                    })
            })
            .collect::<VortexResult<_>>()?,
    );
    Ok(RowMask::new(mask, begin))
}
/// Returns `true` if `range` shares no row with `self.begin..self.end`.
///
/// The range is first normalised to a half-open `start..end` pair; an unbounded start is
/// treated as `0` and an unbounded end as `u64::MAX`.
pub fn is_disjoint(&self, range: impl RangeBounds<u64>) -> bool {
    use std::ops::Bound;
    // `saturating_add` keeps bounds at `u64::MAX` from overflowing (a panic in debug builds,
    // a wrap to 0 and therefore a wrong answer in release builds). Saturating is exact here:
    // `self.end` is exclusive, so no mask can contain row `u64::MAX` itself.
    let start = match range.start_bound() {
        Bound::Included(&n) => n,
        Bound::Excluded(&n) => n.saturating_add(1),
        Bound::Unbounded => 0,
    };
    let end = match range.end_bound() {
        Bound::Included(&n) => n.saturating_add(1),
        Bound::Excluded(&n) => n,
        Bound::Unbounded => u64::MAX,
    };
    self.end <= start || end <= self.begin
}
/// Returns the absolute offset of the first row covered by this mask (inclusive).
#[inline]
pub fn begin(&self) -> u64 {
self.begin
}
/// Returns the absolute offset one past the last row covered by this mask (exclusive).
#[inline]
pub fn end(&self) -> u64 {
self.end
}
/// Returns the number of entries in the underlying mask.
#[inline]
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
self.mask.len()
}
/// Returns a reference to the underlying [`Mask`].
pub fn filter_mask(&self) -> &Mask {
&self.mask
}
pub fn slice(&self, begin: u64, end: u64) -> VortexResult<Self> {
let range_begin = max(self.begin, begin);
let range_end = min(self.end, end);
Ok(RowMask::new(
if range_begin == self.begin && range_end == self.end {
self.mask.clone()
} else {
self.mask.slice(
usize::try_from(range_begin - self.begin)
.vortex_expect("we know this must fit into usize"),
usize::try_from(range_end - range_begin)
.vortex_expect("we know this must fit into usize"),
)
},
range_begin,
))
}
pub fn filter_array(&self, array: impl AsRef<Array>) -> VortexResult<Option<Array>> {
let true_count = self.mask.true_count();
if true_count == 0 {
return Ok(None);
}
let array = array.as_ref();
let sliced = if self.len() == array.len() {
array
} else {
&slice(
array,
usize::try_from(self.begin).vortex_expect("TODO(ngates): fix this bad cast"),
usize::try_from(self.end).vortex_expect("TODO(ngates): fix this bad cast"),
)?
};
if true_count == sliced.len() {
return Ok(Some(sliced.clone()));
}
filter(sliced, &self.mask).map(Some)
}
/// Moves the mask `offset` rows towards zero, keeping the mask contents unchanged.
///
/// # Errors
///
/// Returns an error if `offset` is larger than the current `begin`.
pub fn shift(self, offset: u64) -> VortexResult<RowMask> {
    match self.begin.checked_sub(offset) {
        Some(shifted_begin) => Ok(RowMask::new(self.mask, shifted_begin)),
        None => vortex_bail!(
            "Can shift RowMask by at most {}, tried to shift by {offset}",
            self.begin
        ),
    }
}
/// Returns the number of set entries in the underlying mask.
pub fn true_count(&self) -> usize {
self.mask.true_count()
}
}
#[cfg(test)]
mod tests {
use rstest::rstest;
use vortex_array::array::PrimitiveArray;
use vortex_array::validity::Validity;
use vortex_array::{IntoArray, IntoArrayVariant};
use vortex_buffer::{buffer, Buffer};
use vortex_error::VortexUnwrap;
use vortex_mask::Mask;
use super::*;
// Each case is (input mask, requested (begin, end), expected result of `RowMask::slice`).
#[rstest]
// Request covers only the first row of the mask.
#[case(
RowMask::new(Mask::from_iter([true, true, true, false, false, false, false, false, true, true]), 0), (0, 1),
RowMask::new(Mask::from_iter([true]), 0))]
// Window that contains only unset entries.
#[case(
RowMask::new(Mask::from_iter([false, false, false, false, false, true, true, true, true, true]), 0), (2, 5),
RowMask::new(Mask::from_iter([false, false, false]), 2)
)]
// Window that straddles the set/unset boundary.
#[case(
RowMask::new(Mask::from_iter([true, true, true, true, false, false, false, false, false, false]), 0), (2, 5),
RowMask::new(Mask::from_iter([true, true, false]), 2)
)]
// Window with set entries at both edges and unset entries in between.
#[case(
RowMask::new(Mask::from_iter([true, true, true, false, false, true, true, false, false, false]), 0), (2, 6),
RowMask::new(Mask::from_iter([true, false, false, true]), 2))]
// Requested end (11) lies past the mask end (10) and is clamped.
#[case(
RowMask::new(Mask::from_iter([false, false, false, false, false, true, true, true, true, true]), 0), (7, 11),
RowMask::new(Mask::from_iter([true, true, true]), 7))]
// Requested begin (0) lies before the mask begin (3) and is clamped.
#[case(
RowMask::new(Mask::from_iter([false, true, true, true, true, true]), 3), (0, 5),
RowMask::new(Mask::from_iter([false, true]), 3))]
#[cfg_attr(miri, ignore)]
fn slice(#[case] first: RowMask, #[case] range: (u64, u64), #[case] expected: RowMask) {
assert_eq!(first.slice(range.0, range.1).vortex_unwrap(), expected);
}
#[test]
#[should_panic]
#[cfg_attr(miri, ignore)]
fn shift_invalid() {
    // The mask begins at row 5, so shifting by 7 must fail; unwrapping the error panics.
    let row_mask = RowMask::new(Mask::from_iter([true, true, true, true, true]), 5);
    let shifted = row_mask.shift(7);
    shifted.unwrap();
}
#[test]
#[cfg_attr(miri, ignore)]
fn shift() {
    // Shifting by exactly `begin` moves the mask to row 0 without touching its contents.
    let all_true = || Mask::from_iter([true, true, true, true, true]);
    let shifted = RowMask::new(all_true(), 5).shift(5).unwrap();
    let expected = RowMask::new(all_true(), 0);
    assert_eq!(shifted, expected);
}
#[test]
#[cfg_attr(miri, ignore)]
fn filter_array() {
    // A 10-row mask applied to a 20-element array: the array is sliced to the mask's range
    // before the five set positions (5..10) are kept.
    let selection = [
        false, false, false, false, false, true, true, true, true, true,
    ];
    let mask = RowMask::new(Mask::from_iter(selection), 0);
    let array = Buffer::from_iter(0..20).into_array();
    let filtered = mask.filter_array(array).unwrap().unwrap();
    let kept = filtered.into_primitive().unwrap();
    assert_eq!(kept.as_slice::<i32>(), (5..10).collect::<Vec<_>>());
}
#[test]
#[should_panic]
fn test_row_mask_type_validation() {
    // A float array is neither integer nor boolean, so `from_array` must reject it;
    // unwrapping the resulting error produces the expected panic.
    let array = PrimitiveArray::new(buffer![1.0, 2.0], Validity::AllInvalid).into_array();
    let outcome = RowMask::from_array(&array, 0, 2);
    outcome.unwrap();
}
}