use arrow_array::cast::AsArray;
use vortex_dtype::{DType, Nullability};
use vortex_error::{vortex_bail, VortexResult};
use crate::arrow::FromArrowArray;
use crate::{Array, ArrayDType, IntoCanonical};
pub trait FilterFn {
fn filter(&self, predicate: &Array) -> VortexResult<Array>;
}
pub fn filter(array: impl AsRef<Array>, predicate: impl AsRef<Array>) -> VortexResult<Array> {
let array = array.as_ref();
let predicate = predicate.as_ref();
if predicate.dtype() != &DType::Bool(Nullability::NonNullable) {
vortex_bail!(
"predicate must be non-nullable bool, has dtype {}",
predicate.dtype(),
);
}
if predicate.len() != array.len() {
vortex_bail!(
"predicate.len() is {}, does not equal array.len() of {}",
predicate.len(),
array.len()
);
}
array.with_dyn(|a| {
if let Some(filter_fn) = a.filter() {
filter_fn.filter(predicate)
} else {
let array_ref = array.clone().into_canonical()?.into_arrow()?;
let predicate_ref = predicate.clone().into_canonical()?.into_arrow()?;
let filtered =
arrow_select::filter::filter(array_ref.as_ref(), predicate_ref.as_boolean())?;
Ok(Array::from_arrow(filtered, array.dtype().is_nullable()))
}
})
}
#[cfg(test)]
mod test {
use crate::array::{BoolArray, PrimitiveArray};
use crate::compute::filter::filter;
use crate::validity::Validity;
use crate::{IntoArray, IntoCanonical};
#[test]
fn test_filter() {
let items =
PrimitiveArray::from_nullable_vec(vec![Some(0i32), None, Some(1i32), None, Some(2i32)])
.into_array();
let predicate =
BoolArray::from_vec(vec![true, false, true, false, true], Validity::NonNullable)
.into_array();
let filtered = filter(&items, &predicate).unwrap();
assert_eq!(
filtered
.into_canonical()
.unwrap()
.into_primitive()
.unwrap()
.into_maybe_null_slice::<i32>(),
vec![0i32, 1i32, 2i32]
);
}
}