vortex_array/array/sparse/compute/
take.rsuse std::convert::identity;
use itertools::Itertools;
use vortex_dtype::match_each_integer_ptype;
use vortex_error::VortexResult;
use crate::aliases::hash_map::HashMap;
use crate::array::primitive::PrimitiveArray;
use crate::array::sparse::SparseArray;
use crate::compute::{take, TakeFn};
use crate::variants::PrimitiveArrayTrait;
use crate::{ArrayData, IntoArrayData, IntoArrayVariant};
impl TakeFn for SparseArray {
    fn take(&self, indices: &ArrayData) -> VortexResult<ArrayData> {
        let flat_indices = indices.clone().into_primitive()?;
        let (positions, physical_take_indices) = if indices.len() > 128 {
            take_map(self, &flat_indices)?
        } else {
            take_search_sorted(self, &flat_indices)?
        };
        let taken_values = take(self.values(), physical_take_indices)?;
        Ok(Self::try_new(
            positions.into_array(),
            taken_values,
            indices.len(),
            self.fill_value().clone(),
        )?
        .into_array())
    }
}
fn take_map(
    array: &SparseArray,
    indices: &PrimitiveArray,
) -> VortexResult<(PrimitiveArray, PrimitiveArray)> {
    let indices_map: HashMap<u64, u64> = array
        .resolved_indices()
        .iter()
        .enumerate()
        .map(|(i, r)| (*r as u64, i as u64))
        .collect();
    let min_index = array.min_index().unwrap_or_default() as u64;
    let max_index = array.max_index().unwrap_or_default() as u64;
    let (positions, patch_indices): (Vec<u64>, Vec<u64>) = match_each_integer_ptype!(indices.ptype(), |$P| {
        indices.maybe_null_slice::<$P>()
            .iter()
            .map(|pi| *pi as u64)
            .enumerate()
            .filter(|(_, pi)| *pi >= min_index && *pi <= max_index) .filter_map(|(i, pi)| indices_map.get(&pi).map(|phy_idx| (i as u64, phy_idx)))
            .unzip()
    });
    Ok((
        PrimitiveArray::from(positions),
        PrimitiveArray::from(patch_indices),
    ))
}
fn take_search_sorted(
    array: &SparseArray,
    indices: &PrimitiveArray,
) -> VortexResult<(PrimitiveArray, PrimitiveArray)> {
    let min_index = array.min_index().unwrap_or_default() as u64;
    let max_index = array.max_index().unwrap_or_default() as u64;
    let resolved = match_each_integer_ptype!(indices.ptype(), |$P| {
        indices
            .maybe_null_slice::<$P>()
            .iter()
            .enumerate()
            .filter(|(_, i)| **i as u64 >= min_index && **i as u64 <= max_index) .map(|(pos, i)| {
                array
                    .search_index(*i as usize)
                    .map(|r| r.to_found().map(|ii| (pos as u64, ii as u64)))
            })
            .filter_map_ok(identity)
            .collect::<VortexResult<Vec<_>>>()?
    });
    let (positions, patch_indices): (Vec<u64>, Vec<u64>) = resolved.into_iter().unzip();
    Ok((
        PrimitiveArray::from(positions),
        PrimitiveArray::from(patch_indices),
    ))
}
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_scalar::ScalarValue;
    use crate::array::primitive::PrimitiveArray;
    use crate::array::sparse::compute::take::take_map;
    use crate::array::sparse::SparseArray;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{ArrayData, IntoArrayData, IntoArrayVariant};
    fn sparse_array() -> ArrayData {
        SparseArray::try_new(
            PrimitiveArray::from(vec![0u64, 37, 47, 99]).into_array(),
            PrimitiveArray::from_vec(vec![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid)
                .into_array(),
            100,
            ScalarValue::Null,
        )
        .unwrap()
        .into_array()
    }
    #[test]
    fn sparse_take() {
        let sparse = sparse_array();
        let taken =
            SparseArray::try_from(take(sparse, vec![0, 47, 47, 0, 99].into_array()).unwrap())
                .unwrap();
        assert_eq!(
            taken
                .indices()
                .into_primitive()
                .unwrap()
                .maybe_null_slice::<u64>(),
            [0, 1, 2, 3, 4]
        );
        assert_eq!(
            taken
                .values()
                .into_primitive()
                .unwrap()
                .maybe_null_slice::<f64>(),
            [1.23f64, 9.99, 9.99, 1.23, 3.5]
        );
    }
    #[test]
    fn nonexistent_take() {
        let sparse = sparse_array();
        let taken = SparseArray::try_from(take(sparse, vec![69].into_array()).unwrap()).unwrap();
        assert!(taken
            .indices()
            .into_primitive()
            .unwrap()
            .maybe_null_slice::<u64>()
            .is_empty());
        assert!(taken
            .values()
            .into_primitive()
            .unwrap()
            .maybe_null_slice::<f64>()
            .is_empty());
    }
    #[test]
    fn ordered_take() {
        let sparse = sparse_array();
        let taken =
            SparseArray::try_from(take(&sparse, vec![69, 37].into_array()).unwrap()).unwrap();
        assert_eq!(
            taken
                .indices()
                .into_primitive()
                .unwrap()
                .maybe_null_slice::<u64>(),
            [1]
        );
        assert_eq!(
            taken
                .values()
                .into_primitive()
                .unwrap()
                .maybe_null_slice::<f64>(),
            [0.47f64]
        );
        assert_eq!(taken.len(), 2);
    }
    #[test]
    fn test_take_map() {
        let sparse = SparseArray::try_from(sparse_array()).unwrap();
        let indices = PrimitiveArray::from((0u64..100).collect_vec());
        let (positions, patch_indices) = take_map(&sparse, &indices).unwrap();
        assert_eq!(
            positions.maybe_null_slice::<u64>(),
            sparse
                .indices()
                .into_primitive()
                .unwrap()
                .maybe_null_slice::<u64>()
        );
        assert_eq!(patch_indices.maybe_null_slice::<u64>(), [0u64, 1, 2, 3]);
    }
}