use std::cmp::{max, min};
use std::fmt::{Display, Formatter};
use arrow_buffer::{BooleanBuffer, MutableBuffer};
use croaring::Bitmap;
use vortex_array::array::{BoolArray, PrimitiveArray};
use vortex_array::compute::{filter, slice, take};
use vortex_array::validity::Validity;
use vortex_array::{iterate_integer_array, Array, IntoArray};
use vortex_dtype::PType;
use vortex_error::{vortex_bail, vortex_err, VortexResult};
/// Selected-row density (selected rows / total rows) below which `filter_array` gathers rows
/// with `take` on explicit indices instead of applying a boolean `filter` mask.
///
/// The value is one selected row in every 1024 (`1 / 2^10`).
const PREFER_TAKE_TO_FILTER_DENSITY: f64 = 1.0 / (1_u32 << 10) as f64;
/// A selection of rows within the half-open row range `begin..end`.
///
/// Selected rows are stored in `values` as offsets relative to `begin`, so a valid mask only
/// contains values in `0..(end - begin)`.
#[derive(Debug, Clone, Default)]
#[derive(PartialEq, Eq)]
pub struct RowMask {
    /// Offsets (relative to `begin`) of the selected rows.
    values: Bitmap,
    /// First row covered by the mask (inclusive).
    begin: usize,
    /// Row one past the last row covered by the mask (exclusive).
    end: usize,
}
impl Display for RowMask {
    /// Renders only the covered row range, e.g. `RowSelector [0..10]`; the selected values are
    /// not printed.
    ///
    /// NOTE(review): the label reads `RowSelector` although the type is `RowMask`; it is kept
    /// as-is so existing output stays stable.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self { begin, end, .. } = self;
        write!(f, "RowSelector [{begin}..{end}]")
    }
}
impl RowMask {
pub fn try_new(values: Bitmap, begin: usize, end: usize) -> VortexResult<Self> {
if values
.maximum()
.map(|m| m > (end - begin) as u32)
.unwrap_or(false)
{
vortex_bail!("Values bitmap must be in 0..(end-begin) range")
}
Ok(Self { values, begin, end })
}
pub fn new_valid_between(begin: usize, end: usize) -> Self {
unsafe { RowMask::new_unchecked(Bitmap::from_range(0..(end - begin) as u32), begin, end) }
}
pub unsafe fn new_unchecked(values: Bitmap, begin: usize, end: usize) -> Self {
Self { values, begin, end }
}
pub fn from_mask_array(array: &Array, begin: usize, end: usize) -> VortexResult<Self> {
array.with_dyn(|a| {
a.as_bool_array()
.ok_or_else(|| vortex_err!("Must be a bool array"))
.map(|b| {
let mut bitmap = Bitmap::new();
for (sb, se) in b.maybe_null_slices_iter() {
bitmap.add_range(sb as u32..se as u32);
}
unsafe { RowMask::new_unchecked(bitmap, begin, end) }
})
})
}
pub fn from_index_array(array: &Array, begin: usize, end: usize) -> VortexResult<Self> {
array.with_dyn(|a| {
let err = || vortex_err!(InvalidArgument: "index array must be integers in the range [0, 2^32)");
let array = a.as_primitive_array().ok_or_else(err)?;
if !array.ptype().is_int() {
return Err(err());
}
let mut bitmap = Bitmap::new();
iterate_integer_array!(array, |$P, $iterator| {
for batch in $iterator {
for index in batch.data() {
bitmap.add(u32::try_from(*index).map_err(|_| err())?);
}
}
});
Ok(unsafe { RowMask::new_unchecked(bitmap, begin, end) })
})
}
pub fn is_empty(&self) -> bool {
self.values.is_empty()
}
pub fn begin(&self) -> usize {
self.begin
}
pub fn end(&self) -> usize {
self.end
}
pub fn len(&self) -> usize {
self.end - self.begin
}
pub fn slice(&self, begin: usize, end: usize) -> Self {
let range_begin = max(self.begin, begin);
let range_end = min(self.end, end);
let mask =
Bitmap::from_range((range_begin - self.begin) as u32..(range_end - self.begin) as u32);
unsafe {
RowMask::new_unchecked(
self.values
.and(&mask)
.add_offset(-((range_begin - self.begin) as i64)),
range_begin,
range_end,
)
}
}
pub fn and_inplace(&mut self, other: &RowMask) -> VortexResult<()> {
if self.begin != other.begin || self.end != other.end {
vortex_bail!(
"begin and ends must match: {}-{} {}-{}",
self.begin,
self.end,
other.begin,
other.end
);
}
self.values.and_inplace(&other.values);
Ok(())
}
pub fn filter_array(&self, array: impl AsRef<Array>) -> VortexResult<Option<Array>> {
let true_count = self.values.cardinality();
if true_count == 0 {
return Ok(None);
}
let array = array.as_ref();
let sliced = if self.len() == array.len() {
array
} else {
&slice(array, self.begin, self.end)?
};
if true_count == sliced.len() as u64 {
return Ok(Some(sliced.clone()));
}
if (true_count as f64 / sliced.len() as f64) < PREFER_TAKE_TO_FILTER_DENSITY {
let indices = self.to_indices_array()?;
take(sliced, indices).map(Some)
} else {
let mask = self.to_mask_array()?;
filter(sliced, mask).map(Some)
}
}
pub fn to_indices_array(&self) -> VortexResult<Array> {
Ok(PrimitiveArray::from_vec(self.values.to_vec(), Validity::NonNullable).into_array())
}
pub fn to_mask_array(&self) -> VortexResult<Array> {
let bitset = self
.values
.to_bitset()
.ok_or_else(|| vortex_err!("Couldn't create bitset for RowSelection"))?;
let byte_length = self.len().div_ceil(8);
let mut buffer = MutableBuffer::with_capacity(byte_length);
buffer.extend_from_slice(bitset.as_slice());
if byte_length > bitset.size_in_bytes() {
buffer.extend_zeros(byte_length - bitset.size_in_bytes());
}
BoolArray::try_new(
BooleanBuffer::new(buffer.into(), 0, self.len()),
Validity::NonNullable,
)
.map(IntoArray::into_array)
}
pub fn shift(self, offset: usize) -> VortexResult<RowMask> {
let valid_shift = self.begin >= offset;
if !valid_shift {
vortex_bail!(
"Can shift RowMask by at most {}, tried to shift by {offset}",
self.begin
)
}
Ok(unsafe { RowMask::new_unchecked(self.values, self.begin - offset, self.end - offset) })
}
}
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::array::PrimitiveArray;
    use vortex_array::{IntoArray, IntoArrayVariant};

    use crate::file::read::mask::RowMask;

    /// Builds a validated `RowMask` over `begin..end` selecting the given relative offsets.
    fn row_mask(values: impl IntoIterator<Item = u32>, begin: usize, end: usize) -> RowMask {
        RowMask::try_new(values.into_iter().collect(), begin, end).unwrap()
    }

    #[rstest]
    #[case(row_mask((0..2).chain(9..10), 0, 10), (0, 1), row_mask(0..1, 0, 1))]
    #[case(row_mask((5..8).chain(9..10), 0, 10), (2, 5), row_mask(std::iter::empty(), 2, 5))]
    #[case(row_mask(0..4, 0, 10), (2, 5), row_mask(0..2, 2, 5))]
    #[case(row_mask((0..3).chain(5..6), 0, 10), (2, 6), row_mask((0..1).chain(3..4), 2, 6))]
    #[case(row_mask(5..10, 0, 10), (7, 11), row_mask(0..3, 7, 10))]
    #[case(row_mask(1..6, 3, 9), (0, 5), row_mask(1..2, 3, 5))]
    #[cfg_attr(miri, ignore)]
    fn slice(#[case] first: RowMask, #[case] range: (usize, usize), #[case] expected: RowMask) {
        let (begin, end) = range;
        let actual = first.slice(begin, end);
        assert_eq!(actual, expected);
    }

    #[test]
    #[should_panic]
    #[cfg_attr(miri, ignore)]
    fn test_new() {
        // Offsets 5..10 do not fit into a range of length 5, so construction must fail.
        let _ = row_mask(5..10, 5, 10);
    }

    #[test]
    #[should_panic]
    #[cfg_attr(miri, ignore)]
    fn shift_invalid() {
        // A mask starting at row 5 cannot be shifted down by 7 rows.
        row_mask(0..5, 5, 10).shift(7).unwrap();
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn shift() {
        let shifted = row_mask(0..5, 5, 10).shift(5).unwrap();
        let expected = row_mask(0..5, 0, 5);
        assert_eq!(shifted, expected);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn filter_array() {
        let mask = row_mask(5..10, 0, 10);
        let array = PrimitiveArray::from((0..20).collect::<Vec<_>>()).into_array();
        let filtered = mask
            .filter_array(array)
            .unwrap()
            .unwrap()
            .into_primitive()
            .unwrap();
        let expected: Vec<i32> = (5..10).collect();
        assert_eq!(filtered.maybe_null_slice::<i32>(), expected);
    }
}