// vortex_array/array/primitive/compute/compare.rs
use arrow_buffer::bit_util::ceil;
use arrow_buffer::{BooleanBuffer, MutableBuffer};
use vortex_dtype::{match_each_native_ptype, NativePType};
use vortex_error::{vortex_err, VortexExpect, VortexResult};
use vortex_scalar::PrimitiveScalar;
use crate::array::primitive::PrimitiveArray;
use crate::array::{BoolArray, ConstantArray};
use crate::compute::{MaybeCompareFn, Operator};
use crate::variants::PrimitiveArrayTrait;
use crate::{Array, ArrayDType, IntoArray};
impl MaybeCompareFn for PrimitiveArray {
    fn maybe_compare(&self, other: &Array, operator: Operator) -> Option<VortexResult<Array>> {
        if let Ok(const_array) = ConstantArray::try_from(other) {
            return Some(primitive_const_compare(self, const_array, operator));
        }
        if let Ok(primitive) = PrimitiveArray::try_from(other) {
            let match_mask = match_each_native_ptype!(self.ptype(), |$T| {
                apply_predicate(self.maybe_null_slice::<$T>(), primitive.maybe_null_slice::<$T>(), operator.to_fn::<$T>())
            });
            let validity = self
                .validity()
                .and(primitive.validity())
                .map(|v| v.into_nullable());
            return Some(
                validity
                    .and_then(|v| BoolArray::try_new(match_mask, v))
                    .map(|a| a.into_array()),
            );
        }
        None
    }
}
/// Compares every element of `this` against the single value held by the constant `other`.
///
/// The result's validity is `this`'s validity converted to nullable; nothing about the
/// constant's validity is merged in here.
///
/// # Errors
///
/// Returns an error if the constant yields no value of `this`'s ptype
/// (`typed_value` returns `None`).
///
/// # Panics
///
/// Panics if `other`'s dtype and value cannot form a [`PrimitiveScalar`].
fn primitive_const_compare(
    this: &PrimitiveArray,
    other: ConstantArray,
    operator: Operator,
) -> VortexResult<Array> {
    let scalar = PrimitiveScalar::try_new(other.dtype(), other.scalar_value())
        .vortex_expect("Expected a primitive scalar");

    let matches = match_each_native_ptype!(this.ptype(), |$T| {
        match scalar.typed_value::<$T>() {
            Some(rhs) => primitive_value_compare::<$T>(this, rhs, operator),
            None => return Err(vortex_err!("Type mismatch between array and constant")),
        }
    });

    let validity = this.validity().into_nullable();
    BoolArray::try_new(matches, validity).map(|array| array.into_array())
}
/// Evaluates `op(element, value)` for every element of `this`, one result bit per element.
///
/// Validity is not consulted: null slots are compared like any other slot, and callers are
/// expected to attach validity to the resulting bitmap themselves.
fn primitive_value_compare<T: NativePType>(
    this: &PrimitiveArray,
    value: T,
    op: Operator,
) -> BooleanBuffer {
    let predicate = op.to_fn::<T>();
    let values = this.maybe_null_slice::<T>();
    let len = this.len();

    BooleanBuffer::collect_bool(len, |i| {
        // SAFETY: `collect_bool` only yields indices in `0..len`. This assumes the slice from
        // `maybe_null_slice` holds exactly `this.len()` elements (not checked here).
        let lhs = unsafe { *values.get_unchecked(i) };
        predicate(lhs, value)
    })
}
/// Evaluates `f(lhs[i], rhs[i])` for every index `i` and packs the results into a bitmap.
///
/// Results are packed 64 at a time: each full block of `u64::BITS` element pairs becomes one
/// `u64` word (bit `k` holds the result for element `k` of the block). A trailing partial
/// block, if any, becomes one final word whose unused high bits are zero.
///
/// # Panics
///
/// Panics if `lhs` and `rhs` have different lengths.
fn apply_predicate<T: NativePType, F: Fn(T, T) -> bool>(
    lhs: &[T],
    rhs: &[T],
    f: F,
) -> BooleanBuffer {
    const BLOCK_SIZE: usize = u64::BITS as usize;

    // Each output bit pairs `lhs[i]` with `rhs[i]`. Without this check, a shorter `rhs` would
    // be read out of bounds.
    assert_eq!(
        lhs.len(),
        rhs.len(),
        "apply_predicate requires slices of equal length"
    );
    let len = lhs.len();

    // Packs up to BLOCK_SIZE predicate results into one word, least-significant bit first.
    let pack = |lhs_block: &[T], rhs_block: &[T]| -> u64 {
        lhs_block
            .iter()
            .zip(rhs_block)
            .enumerate()
            .fold(0_u64, |word, (bit_idx, (&l, &r))| {
                word | (u64::from(f(l, r)) << bit_idx)
            })
    };

    // One 8-byte word per (possibly partial) block is reserved up front, so `push` never
    // has to grow the buffer.
    let mut buffer = MutableBuffer::new(ceil(len, BLOCK_SIZE) * 8);
    let lhs_blocks = lhs.chunks_exact(BLOCK_SIZE);
    let rhs_blocks = rhs.chunks_exact(BLOCK_SIZE);
    let (lhs_remainder, rhs_remainder) = (lhs_blocks.remainder(), rhs_blocks.remainder());

    for (lhs_block, rhs_block) in lhs_blocks.zip(rhs_blocks) {
        buffer.push(pack(lhs_block, rhs_block));
    }
    if !lhs_remainder.is_empty() {
        buffer.push(pack(lhs_remainder, rhs_remainder));
    }

    BooleanBuffer::new(buffer.into(), 0, len)
}
#[cfg(test)]
// Tests return `VortexResult` so fallible setup can use `?`; asserting inside them is intended.
#[allow(clippy::panic_in_result_fn)]
mod test {
    use itertools::Itertools;

    use super::*;
    use crate::compute::compare;
    use crate::IntoArrayVariant;

    /// Returns the indices of all slots in `indices_bits` that are both valid and `true`.
    fn to_int_indices(indices_bits: BoolArray) -> Vec<u64> {
        // Look the validity up once instead of once per element.
        let validity = indices_bits.validity();
        let filtered = indices_bits
            .boolean_buffer()
            .iter()
            .enumerate()
            .filter_map(|(idx, v)| {
                let valid_and_true = validity.is_valid(idx) & v;
                valid_and_true.then_some(idx as u64)
            })
            .collect_vec();
        filtered
    }

    #[test]
    fn test_basic_comparisons() -> VortexResult<()> {
        let arr = PrimitiveArray::from_nullable_vec(vec![
            Some(1i32),
            Some(2),
            Some(3),
            Some(4),
            None,
            Some(5),
            Some(6),
            Some(7),
            Some(8),
            None,
            Some(9),
            None,
        ])
        .into_array();
        let matches = compare(&arr, &arr, Operator::Eq)?.into_bool()?;
        assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]);
        let matches = compare(&arr, &arr, Operator::NotEq)?.into_bool()?;
        let empty: [u64; 0] = [];
        assert_eq!(to_int_indices(matches), empty);
        let other = PrimitiveArray::from_nullable_vec(vec![
            Some(1i32),
            Some(2),
            Some(3),
            Some(4),
            None,
            Some(6),
            Some(7),
            Some(8),
            Some(9),
            None,
            Some(10),
            None,
        ])
        .into_array();
        let matches = compare(&arr, &other, Operator::Lte)?.into_bool()?;
        assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]);
        let matches = compare(&arr, &other, Operator::Lt)?.into_bool()?;
        assert_eq!(to_int_indices(matches), [5u64, 6, 7, 8, 10]);
        let matches = compare(&other, &arr, Operator::Gte)?.into_bool()?;
        assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]);
        let matches = compare(&other, &arr, Operator::Gt)?.into_bool()?;
        assert_eq!(to_int_indices(matches), [5u64, 6, 7, 8, 10]);
        Ok(())
    }

    #[test]
    fn test_comparisons_span_multiple_blocks() -> VortexResult<()> {
        // 130 elements = two full 64-bit blocks plus a 2-element remainder, so both the
        // full-block and the trailing-block paths of `apply_predicate` are exercised.
        let lhs = PrimitiveArray::from_nullable_vec((0..130i32).map(Some).collect_vec())
            .into_array();
        let rhs = PrimitiveArray::from_nullable_vec(
            (0..130i32)
                .map(|i| Some(if i % 3 == 0 { i } else { i + 1 }))
                .collect_vec(),
        )
        .into_array();

        let matches = compare(&lhs, &rhs, Operator::Eq)?.into_bool()?;
        let expected = (0..130u64).filter(|i| i % 3 == 0).collect_vec();
        assert_eq!(to_int_indices(matches), expected);

        let matches = compare(&lhs, &rhs, Operator::Lt)?.into_bool()?;
        let expected = (0..130u64).filter(|i| i % 3 != 0).collect_vec();
        assert_eq!(to_int_indices(matches), expected);

        let matches = compare(&lhs, &rhs, Operator::NotEq)?.into_bool()?;
        assert_eq!(to_int_indices(matches), expected);
        Ok(())
    }
}