use itertools::Itertools;
use vortex_dtype::PType;
use vortex_error::VortexResult;
use vortex_scalar::Scalar;
use crate::array::chunked::ChunkedArray;
use crate::array::primitive::PrimitiveArray;
use crate::compute::unary::{scalar_at, subtract_scalar, try_cast};
use crate::compute::{search_sorted, slice, take, SearchSortedSide, TakeFn};
use crate::stats::ArrayStatistics;
use crate::{Array, ArrayDType, IntoArray, ToArray};
impl TakeFn for ChunkedArray {
    /// Gathers the elements at the logical positions `indices` (positions across all
    /// chunks) into a new `ChunkedArray`.
    ///
    /// Strictly sorted indices are delegated to `take_strict_sorted`; when they are also
    /// as long as the array, the array itself is returned. Otherwise the indices are cast
    /// to `u64`, and each maximal run of consecutive indices landing in the same chunk is
    /// taken from that chunk, producing one output chunk per run.
    ///
    /// # Errors
    /// Propagates failures from casting `indices` to `u64`, from the per-chunk `take`,
    /// and from building the resulting `ChunkedArray`.
    fn take(&self, indices: &Array) -> VortexResult<Array> {
        if indices
            .statistics()
            .compute_is_strict_sorted()
            .unwrap_or(false)
        {
            // A strictly sorted index list as long as the array can only be 0..len
            // (given in-bounds indices), so the take is the identity.
            if self.len() == indices.len() {
                return Ok(self.to_array());
            }
            return take_strict_sorted(self, indices);
        }

        let indices = PrimitiveArray::try_from(try_cast(indices, PType::U64.into())?)?;
        let idx_values = indices.maybe_null_slice::<u64>();

        // Nothing to select: return an empty array rather than indexing element 0.
        let Some(&first_idx) = idx_values.first() else {
            return Ok(Self::try_new(Vec::new(), self.dtype().clone())?.into_array());
        };

        let mut chunks = Vec::new();
        let mut indices_in_chunk: Vec<u64> = Vec::new();
        let mut prev_chunk_idx = self.find_chunk_idx(first_idx as usize).0;
        for &idx in idx_values {
            let (chunk_idx, idx_in_chunk) = self.find_chunk_idx(idx as usize);
            if chunk_idx != prev_chunk_idx {
                // Flush the run for the previous chunk. `mem::take` moves the buffer out
                // without cloning and leaves an empty Vec for the next run.
                let run = std::mem::take(&mut indices_in_chunk).into_array();
                chunks.push(take(
                    &self.chunk(prev_chunk_idx).expect("chunk not found"),
                    &run,
                )?);
            }
            indices_in_chunk.push(idx_in_chunk as u64);
            prev_chunk_idx = chunk_idx;
        }

        // The final run is never flushed inside the loop.
        if !indices_in_chunk.is_empty() {
            let run = indices_in_chunk.into_array();
            chunks.push(take(
                &self.chunk(prev_chunk_idx).expect("chunk not found"),
                &run,
            )?);
        }

        Ok(Self::try_new(chunks, self.dtype().clone())?.into_array())
    }
}
/// Takes `indices` from `chunked`, assuming `indices` is strictly sorted.
///
/// Walks the indices once. For the chunk containing the current index, the run of
/// indices falling inside that chunk is located by binary search, sliced out, rebased
/// to be chunk-relative, and later taken from that chunk. Chunks that receive no
/// indices contribute nothing to the output.
///
/// # Errors
/// Propagates failures from `scalar_at`, `search_sorted`, `slice`, the casts and
/// subtraction used to rebase indices, the per-chunk `take`, and building the result.
fn take_strict_sorted(chunked: &ChunkedArray, indices: &Array) -> VortexResult<Array> {
    let Some(last_pos) = indices.len().checked_sub(1) else {
        // No indices: every chunk is skipped, giving an empty result.
        return Ok(ChunkedArray::try_new(Vec::new(), chunked.dtype().clone())?.into_array());
    };
    // Sorted input, so this is the largest index.
    let last_idx = usize::try_from(&scalar_at(indices, last_pos)?)?;

    // Loop-invariant: fetch the chunk boundaries once.
    let chunk_ends = chunked.chunk_ends();
    let mut indices_by_chunk = vec![None; chunked.nchunks()];
    let mut pos = 0;
    while pos < indices.len() {
        // Locate the chunk holding the current index and that chunk's [begin, end) range.
        let idx = usize::try_from(&scalar_at(indices, pos)?)?;
        let (chunk_idx, _idx_in_chunk) = chunked.find_chunk_idx(idx);
        let chunk_begin = usize::try_from(&scalar_at(&chunk_ends, chunk_idx)?)?;
        let chunk_end = usize::try_from(&scalar_at(&chunk_ends, chunk_idx + 1)?)?;

        // First position in `indices` that lies beyond this chunk. If every remaining
        // index is inside the chunk, the answer is simply `len`; skipping the search in
        // that case also avoids comparing against a `chunk_end` that may not fit in the
        // index type. Otherwise `chunk_end <= last_idx`, so it is representable.
        let chunk_end_pos = if chunk_end > last_idx {
            indices.len()
        } else {
            search_sorted(indices, chunk_end, SearchSortedSide::Left)?.to_index()
        };

        let chunk_indices = slice(indices, pos, chunk_end_pos)?;
        // Rebase onto the chunk. Subtract in the index type when `chunk_begin` fits it;
        // otherwise widen to u64 first (indices are expected to be non-negative).
        let chunk_indices = if chunk_begin < PType::try_from(chunk_indices.dtype())?.max_value() {
            subtract_scalar(
                &chunk_indices,
                &Scalar::from(chunk_begin).cast(chunk_indices.dtype())?,
            )?
        } else {
            let u64_chunk_indices = try_cast(&chunk_indices, PType::U64.into())?;
            subtract_scalar(&u64_chunk_indices, &chunk_begin.into())?
        };

        indices_by_chunk[chunk_idx] = Some(chunk_indices);
        pos = chunk_end_pos;
    }

    let chunks = indices_by_chunk
        .iter()
        .enumerate()
        .filter_map(|(chunk_idx, indices)| indices.as_ref().map(|i| (chunk_idx, i)))
        .map(|(chunk_idx, chunk_indices)| {
            take(
                &chunked.chunk(chunk_idx).expect("chunk not found"),
                chunk_indices,
            )
        })
        .try_collect()?;

    Ok(ChunkedArray::try_new(chunks, chunked.dtype().clone())?.into_array())
}
#[cfg(test)]
mod test {
    use crate::array::chunked::ChunkedArray;
    use crate::compute::take;
    use crate::{ArrayDType, AsArray, IntoArray, IntoArrayVariant};

    /// Unsorted indices with repeats take from each chunk in request order.
    #[test]
    fn test_take() {
        let a = vec![1i32, 2, 3].into_array();
        let arr = ChunkedArray::try_new(vec![a.clone(), a.clone(), a.clone()], a.dtype().clone())
            .unwrap();
        assert_eq!(arr.nchunks(), 3);
        assert_eq!(arr.len(), 9);
        let indices = vec![0u64, 0, 6, 4].into_array();
        let result = &ChunkedArray::try_from(take(arr.as_array_ref(), &indices).unwrap())
            .unwrap()
            .into_array()
            .into_primitive()
            .unwrap();
        assert_eq!(result.maybe_null_slice::<i32>(), &[1, 1, 1, 2]);
    }

    /// Strictly increasing indices spanning every chunk (the input shape routed to the
    /// strictly-sorted path) are rebased per chunk correctly.
    #[test]
    fn test_take_strict_sorted() {
        let a = vec![1i32, 2, 3].into_array();
        let arr = ChunkedArray::try_new(vec![a.clone(), a.clone(), a.clone()], a.dtype().clone())
            .unwrap();
        let indices = vec![1u64, 4, 8].into_array();
        let result = ChunkedArray::try_from(take(arr.as_array_ref(), &indices).unwrap())
            .unwrap()
            .into_array()
            .into_primitive()
            .unwrap();
        assert_eq!(result.maybe_null_slice::<i32>(), &[2, 2, 3]);
    }
}