vortex_array/array/chunked/compute/
take.rsuse itertools::Itertools;
use vortex_dtype::PType;
use vortex_error::VortexResult;
use vortex_scalar::Scalar;
use crate::array::chunked::ChunkedArray;
use crate::compute::unary::{scalar_at, subtract_scalar, try_cast};
use crate::compute::{search_sorted, slice, take, SearchSortedSide, TakeFn};
use crate::stats::ArrayStatistics;
use crate::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant, ToArrayData};
impl TakeFn for ChunkedArray {
    fn take(&self, indices: &ArrayData) -> VortexResult<ArrayData> {
        if indices
            .statistics()
            .compute_is_strict_sorted()
            .unwrap_or(false)
        {
            if self.len() == indices.len() {
                return Ok(self.to_array());
            }
            return take_strict_sorted(self, indices);
        }
        let indices = try_cast(indices, PType::U64.into())?.into_primitive()?;
        let mut chunks = Vec::new();
        let mut indices_in_chunk = Vec::new();
        let mut prev_chunk_idx = self
            .find_chunk_idx(indices.maybe_null_slice::<u64>()[0] as usize)
            .0;
        for idx in indices.maybe_null_slice::<u64>() {
            let (chunk_idx, idx_in_chunk) = self.find_chunk_idx(*idx as usize);
            if chunk_idx != prev_chunk_idx {
                let indices_in_chunk_array = indices_in_chunk.clone().into_array();
                chunks.push(take(&self.chunk(prev_chunk_idx)?, &indices_in_chunk_array)?);
                indices_in_chunk = Vec::new();
            }
            indices_in_chunk.push(idx_in_chunk as u64);
            prev_chunk_idx = chunk_idx;
        }
        if !indices_in_chunk.is_empty() {
            let indices_in_chunk_array = indices_in_chunk.into_array();
            chunks.push(take(&self.chunk(prev_chunk_idx)?, &indices_in_chunk_array)?);
        }
        Ok(Self::try_new(chunks, self.dtype().clone())?.into_array())
    }
}
fn take_strict_sorted(chunked: &ChunkedArray, indices: &ArrayData) -> VortexResult<ArrayData> {
    let mut indices_by_chunk = vec![None; chunked.nchunks()];
    let mut pos = 0;
    while pos < indices.len() {
        let idx = usize::try_from(&scalar_at(indices, pos)?)?;
        let (chunk_idx, _idx_in_chunk) = chunked.find_chunk_idx(idx);
        let chunk_begin = usize::try_from(&scalar_at(chunked.chunk_offsets(), chunk_idx)?)?;
        let chunk_end = usize::try_from(&scalar_at(chunked.chunk_offsets(), chunk_idx + 1)?)?;
        let chunk_end_pos = search_sorted(indices, chunk_end, SearchSortedSide::Left)?.to_index();
        let chunk_indices = slice(indices, pos, chunk_end_pos)?;
        let chunk_indices =
            if chunk_begin < PType::try_from(chunk_indices.dtype())?.max_value_as_u64() as usize {
                subtract_scalar(
                    &chunk_indices,
                    &Scalar::from(chunk_begin).cast(chunk_indices.dtype())?,
                )?
            } else {
                let u64_chunk_indices = try_cast(&chunk_indices, PType::U64.into())?;
                subtract_scalar(&u64_chunk_indices, &chunk_begin.into())?
            };
        indices_by_chunk[chunk_idx] = Some(chunk_indices);
        pos = chunk_end_pos;
    }
    let chunks = indices_by_chunk
        .into_iter()
        .enumerate()
        .filter_map(|(chunk_idx, indices)| indices.map(|i| (chunk_idx, i)))
        .map(|(chunk_idx, chunk_indices)| take(&chunked.chunk(chunk_idx)?, &chunk_indices))
        .try_collect()?;
    Ok(ChunkedArray::try_new(chunks, chunked.dtype().clone())?.into_array())
}
#[cfg(test)]
mod test {
    use crate::array::chunked::ChunkedArray;
    use crate::compute::take;
    use crate::{ArrayDType, IntoArrayData, IntoArrayVariant};
    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3].into_array();
        let arr = ChunkedArray::try_new(vec![a.clone(), a.clone(), a.clone()], a.dtype().clone())
            .unwrap();
        assert_eq!(arr.nchunks(), 3);
        assert_eq!(arr.len(), 9);
        let indices = vec![0u64, 0, 6, 4].into_array();
        let result = &ChunkedArray::try_from(take(arr.as_ref(), &indices).unwrap())
            .unwrap()
            .into_array()
            .into_primitive()
            .unwrap();
        assert_eq!(result.maybe_null_slice::<i32>(), &[1, 1, 1, 2]);
    }
}