use std::cmp::Ordering;
use std::cmp::Ordering::{Equal, Greater, Less};
use std::fmt::{Debug, Display, Formatter};
use std::hint;
use itertools::Itertools;
use vortex_error::{vortex_bail, VortexError, VortexResult};
use vortex_scalar::Scalar;
use crate::compute::scalar_at;
use crate::encoding::Encoding;
use crate::{ArrayDType, ArrayData};
/// Selects which end of a run of equal values a sorted search reports.
///
/// For an array `[0, 1, 2, 2, 2, 2, 3]` and a target of `2`, `Left` yields index `2`
/// (the first matching element) while `Right` yields index `6` (one past the last
/// matching element).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SearchSortedSide {
    /// Report the index of the first element equal to the target.
    Left,
    /// Report the index one past the last element equal to the target.
    Right,
}
impl Display for SearchSortedSide {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
SearchSortedSide::Left => write!(f, "left"),
SearchSortedSide::Right => write!(f, "right"),
}
}
}
/// Outcome of searching a sorted sequence for a value.
///
/// Both variants carry an index; which variant is returned tells the caller whether an
/// element equal to the target exists.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum SearchResult {
    /// An equal element exists. The index marks the boundary selected by the requested
    /// [`SearchSortedSide`].
    Found(usize),
    /// No equal element exists. The index is the position at which the target could be
    /// inserted while keeping the sequence sorted.
    NotFound(usize),
}
impl SearchResult {
    /// Returns the index when the value was found, and `None` otherwise.
    pub fn to_found(self) -> Option<usize> {
        match self {
            Self::Found(i) => Some(i),
            Self::NotFound(_) => None,
        }
    }

    /// Returns the carried index regardless of whether the value was found.
    pub fn to_index(self) -> usize {
        match self {
            Self::Found(i) | Self::NotFound(i) => i,
        }
    }

    /// Converts the result into an index usable against a sequence of `len` entries,
    /// stepping back by one where the raw index points past the entry of interest.
    ///
    /// * `Found(i)` maps to `i`, except `i == len` which maps to `len - 1`.
    /// * `NotFound(i)` maps to `i - 1`, clamped at `0`.
    ///
    /// All subtractions saturate, so `len == 0` yields `0` instead of underflowing.
    ///
    /// NOTE(review): presumably intended for arrays of offsets — confirm with callers.
    pub fn to_offsets_index(self, len: usize) -> usize {
        match self {
            Self::Found(i) if i == len => i.saturating_sub(1),
            Self::Found(i) => i,
            Self::NotFound(i) => i.saturating_sub(1),
        }
    }

    /// Converts the result into an index clamped to the last entry of a sequence of
    /// `len` entries: an index equal to `len` maps to `len - 1`, anything else is
    /// returned unchanged.
    ///
    /// The subtraction saturates, so `len == 0` yields `0` instead of underflowing.
    ///
    /// NOTE(review): presumably intended for arrays of run ends — confirm with callers.
    pub fn to_ends_index(self, len: usize) -> usize {
        let idx = self.to_index();
        if idx == len {
            idx.saturating_sub(1)
        } else {
            idx
        }
    }

    /// Applies `f` to the carried index while preserving the variant.
    #[inline]
    pub fn map<F>(self, f: F) -> SearchResult
    where
        F: FnOnce(usize) -> usize,
    {
        match self {
            SearchResult::Found(i) => SearchResult::Found(f(i)),
            SearchResult::NotFound(i) => SearchResult::NotFound(f(i)),
        }
    }
}
impl Display for SearchResult {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
SearchResult::Found(i) => write!(f, "Found({i})"),
SearchResult::NotFound(i) => write!(f, "NotFound({i})"),
}
}
}
/// Encoding-level hook for searching a sorted `Array` for a [`Scalar`].
pub trait SearchSortedFn<Array> {
    /// Searches `array` for `value`, reporting the boundary selected by `side`.
    fn search_sorted(
        &self,
        array: &Array,
        value: &Scalar,
        side: SearchSortedSide,
    ) -> VortexResult<SearchResult>;

    /// Searches `array` for every entry of `values`, in order.
    ///
    /// The default implementation calls [`SearchSortedFn::search_sorted`] once per value
    /// and stops at the first error.
    fn search_sorted_many(
        &self,
        array: &Array,
        values: &[Scalar],
        side: SearchSortedSide,
    ) -> VortexResult<Vec<SearchResult>> {
        let mut results = Vec::with_capacity(values.len());
        for value in values {
            results.push(self.search_sorted(array, value, side)?);
        }
        Ok(results)
    }
}
/// Encoding-level hook for searching a sorted `Array` for a plain `usize` target.
pub trait SearchSortedUsizeFn<Array> {
    /// Searches `array` for `value`, reporting the boundary selected by `side`.
    fn search_sorted_usize(
        &self,
        array: &Array,
        value: usize,
        side: SearchSortedSide,
    ) -> VortexResult<SearchResult>;

    /// Searches `array` for every entry of `values`, in order.
    ///
    /// The default implementation calls [`SearchSortedUsizeFn::search_sorted_usize`]
    /// once per value and stops at the first error.
    fn search_sorted_usize_many(
        &self,
        array: &Array,
        values: &[usize],
        side: SearchSortedSide,
    ) -> VortexResult<Vec<SearchResult>> {
        let mut results = Vec::with_capacity(values.len());
        for &value in values {
            results.push(self.search_sorted_usize(array, value, side)?);
        }
        Ok(results)
    }
}
/// Adapts an encoding's typed [`SearchSortedFn`] implementation to the type-erased
/// [`ArrayData`] entry point: the array is downcast to `E::Array` and the call is
/// forwarded to the typed implementation.
impl<E: Encoding> SearchSortedFn<ArrayData> for E
where
    E: SearchSortedFn<E::Array>,
    for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
    /// Downcasts `array` and forwards to the typed `search_sorted`.
    ///
    /// Propagates the error from `try_downcast_ref` if the downcast fails.
    fn search_sorted(
        &self,
        array: &ArrayData,
        value: &Scalar,
        side: SearchSortedSide,
    ) -> VortexResult<SearchResult> {
        let (typed_array, typed_encoding) = array.try_downcast_ref::<E>()?;
        SearchSortedFn::search_sorted(typed_encoding, typed_array, value, side)
    }

    /// Downcasts `array` and forwards to the typed `search_sorted_many`.
    ///
    /// Propagates the error from `try_downcast_ref` if the downcast fails.
    fn search_sorted_many(
        &self,
        array: &ArrayData,
        values: &[Scalar],
        side: SearchSortedSide,
    ) -> VortexResult<Vec<SearchResult>> {
        let (typed_array, typed_encoding) = array.try_downcast_ref::<E>()?;
        SearchSortedFn::search_sorted_many(typed_encoding, typed_array, values, side)
    }
}
/// Adapts an encoding's typed [`SearchSortedUsizeFn`] implementation to the type-erased
/// [`ArrayData`] entry point: the array is downcast to `E::Array` and the call is
/// forwarded to the typed implementation.
impl<E: Encoding> SearchSortedUsizeFn<ArrayData> for E
where
    E: SearchSortedUsizeFn<E::Array>,
    for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
    /// Downcasts `array` and forwards to the typed `search_sorted_usize`.
    ///
    /// Propagates the error from `try_downcast_ref` if the downcast fails.
    fn search_sorted_usize(
        &self,
        array: &ArrayData,
        value: usize,
        side: SearchSortedSide,
    ) -> VortexResult<SearchResult> {
        let (typed_array, typed_encoding) = array.try_downcast_ref::<E>()?;
        SearchSortedUsizeFn::search_sorted_usize(typed_encoding, typed_array, value, side)
    }

    /// Downcasts `array` and forwards to the typed `search_sorted_usize_many`.
    ///
    /// Propagates the error from `try_downcast_ref` if the downcast fails.
    fn search_sorted_usize_many(
        &self,
        array: &ArrayData,
        values: &[usize],
        side: SearchSortedSide,
    ) -> VortexResult<Vec<SearchResult>> {
        let (typed_array, typed_encoding) = array.try_downcast_ref::<E>()?;
        SearchSortedUsizeFn::search_sorted_usize_many(typed_encoding, typed_array, values, side)
    }
}
/// Searches the sorted `array` for `target`.
///
/// The target is first cast to the array's dtype. Dispatch then prefers the encoding's
/// own `search_sorted_fn`; if there is none but the encoding supports `scalar_at`, the
/// generic binary search from [`SearchSorted`] is used.
///
/// A target that cannot be cast to the array's dtype is reported as
/// `NotFound(array.len())`.
///
/// NOTE(review): every cast failure is treated as "greater than all elements". A cast
/// can presumably also fail for targets below the dtype's range (e.g. a negative value
/// into an unsigned array), where `array.len()` would not be the insertion point —
/// confirm.
///
/// # Errors
///
/// Fails when the cast target is null, or when the encoding offers neither a
/// `search_sorted_fn` nor a `scalar_at_fn`.
pub fn search_sorted<T: Into<Scalar>>(
    array: &ArrayData,
    target: T,
    side: SearchSortedSide,
) -> VortexResult<SearchResult> {
    let scalar = match target.into().cast(array.dtype()) {
        Ok(cast) => cast,
        Err(_) => return Ok(SearchResult::NotFound(array.len())),
    };

    if scalar.is_null() {
        vortex_bail!("Search sorted with null value is not supported");
    }

    match array.encoding().search_sorted_fn() {
        Some(f) => f.search_sorted(array, &scalar, side),
        None if array.encoding().scalar_at_fn().is_some() => {
            Ok(SearchSorted::search_sorted(array, &scalar, side))
        }
        None => vortex_bail!(
            NotImplemented: "search_sorted",
            array.encoding().id()
        ),
    }
}
/// Searches the sorted `array` for a `usize` target.
///
/// Dispatch order:
/// 1. the encoding's `search_sorted_usize_fn`, which receives the raw `usize`;
/// 2. the encoding's `search_sorted_fn`, with the target cast to the array's dtype;
/// 3. the generic binary search from [`SearchSorted`], if the encoding supports
///    `scalar_at`.
///
/// On paths 2 and 3 a target that cannot be cast to the array's dtype is reported as
/// `NotFound(array.len())`.
///
/// # Errors
///
/// Fails when the encoding offers none of the three capabilities above, or when the
/// selected implementation fails.
pub fn search_sorted_usize(
    array: &ArrayData,
    target: usize,
    side: SearchSortedSide,
) -> VortexResult<SearchResult> {
    if let Some(f) = array.encoding().search_sorted_usize_fn() {
        return f.search_sorted_usize(array, target, side);
    }

    // Both remaining paths need a scalar of the array's dtype, so cast exactly once.
    let Ok(target) = Scalar::from(target).cast(array.dtype()) else {
        return Ok(SearchResult::NotFound(array.len()));
    };

    if let Some(f) = array.encoding().search_sorted_fn() {
        return f.search_sorted(array, &target, side);
    }

    if array.encoding().scalar_at_fn().is_some() {
        return Ok(SearchSorted::search_sorted(array, &target, side));
    }

    vortex_bail!(
        NotImplemented: "search_sorted_usize",
        array.encoding().id()
    )
}
pub fn search_sorted_many<T: Into<Scalar> + Clone>(
array: &ArrayData,
targets: &[T],
side: SearchSortedSide,
) -> VortexResult<Vec<SearchResult>> {
if let Some(f) = array.encoding().search_sorted_fn() {
let mut too_big_cast_idxs = Vec::new();
let values = targets
.iter()
.cloned()
.enumerate()
.filter_map(|(i, t)| {
let Ok(c) = t.into().cast(array.dtype()) else {
too_big_cast_idxs.push(i);
return None;
};
Some(c)
})
.collect::<Vec<_>>();
let mut results = f.search_sorted_many(array, &values, side)?;
for too_big_idx in too_big_cast_idxs {
results.insert(too_big_idx, SearchResult::NotFound(array.len()));
}
return Ok(results);
}
targets
.iter()
.map(|target| search_sorted(array, target.clone(), side))
.try_collect()
}
/// Searches the sorted `array` for every entry of `targets`, returning one result per
/// target in the same order.
///
/// Uses the encoding's `search_sorted_usize_fn` for the whole batch when available;
/// otherwise each target goes through [`search_sorted_usize`] and the first error
/// aborts the batch.
pub fn search_sorted_usize_many(
    array: &ArrayData,
    targets: &[usize],
    side: SearchSortedSide,
) -> VortexResult<Vec<SearchResult>> {
    match array.encoding().search_sorted_usize_fn() {
        Some(f) => f.search_sorted_usize_many(array, targets, side),
        None => targets
            .iter()
            .copied()
            .map(|needle| search_sorted_usize(array, needle, side))
            .try_collect(),
    }
}
/// Comparison of the element stored at an index against an external value.
///
/// Only [`IndexOrd::index_cmp`] must be provided; the boolean helpers are derived from
/// it and all return `false` when the two values are incomparable (`None`).
pub trait IndexOrd<V> {
    /// Compares the element at `idx` with `elem`, or returns `None` when the two cannot
    /// be ordered.
    fn index_cmp(&self, idx: usize, elem: &V) -> Option<Ordering>;

    /// `true` when the element at `idx` is strictly less than `elem`.
    fn index_lt(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem) == Some(Less)
    }

    /// `true` when the element at `idx` is less than or equal to `elem`.
    fn index_le(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_le)
    }

    /// `true` when the element at `idx` is strictly greater than `elem`.
    fn index_gt(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem) == Some(Greater)
    }

    /// `true` when the element at `idx` is greater than or equal to `elem`.
    fn index_ge(&self, idx: usize, elem: &V) -> bool {
        self.index_cmp(idx, elem).is_some_and(Ordering::is_ge)
    }
}
/// Minimal "has a length" abstraction.
///
/// Together with [`IndexOrd`] it is the bound required by the blanket [`SearchSorted`]
/// implementation, which uses the length as the upper bound of its binary search.
// Only `len` is needed by this module, so the lint asking for a companion `is_empty`
// method is silenced.
#[allow(clippy::len_without_is_empty)]
pub trait Len {
/// Returns the number of elements.
fn len(&self) -> usize;
}
/// Binary search over a sorted, index-addressable collection.
pub trait SearchSorted<T> {
    /// Searches for `value`, reporting the boundary selected by `side` when equal
    /// elements exist.
    ///
    /// Two comparators are handed to [`SearchSorted::search_sorted_by`]:
    /// * the first locates any equal element, treating incomparable elements as `Less`;
    /// * the second never returns `Equal` and is used to find the edge of the run of
    ///   equal elements: strictly-less elements sort before the boundary for `Left`,
    ///   less-or-equal elements for `Right`.
    fn search_sorted(&self, value: &T, side: SearchSortedSide) -> SearchResult
    where
        Self: IndexOrd<T>,
    {
        let locate = |idx: usize| self.index_cmp(idx, value).unwrap_or(Less);
        let refine = |idx: usize| {
            let before_boundary = match side {
                SearchSortedSide::Left => self.index_lt(idx, value),
                SearchSortedSide::Right => self.index_le(idx, value),
            };
            if before_boundary {
                Less
            } else {
                Greater
            }
        };
        self.search_sorted_by(locate, refine, side)
    }

    /// Searches using caller-supplied comparators.
    ///
    /// `find` orders the element at an index relative to the target and is used to
    /// locate a match; `side_find` is used afterwards to locate the boundary of the run
    /// of matching elements on the requested `side`.
    fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
        &self,
        find: F,
        side_find: N,
        side: SearchSortedSide,
    ) -> SearchResult;
}
/// Blanket binary-search implementation for anything that can compare by index and
/// report its length.
impl<S, T> SearchSorted<T> for S
where
    S: IndexOrd<T> + Len + ?Sized,
{
    /// Runs `find` over the whole collection. When it lands on an equal element, runs
    /// `side_find` over the part of the collection on the requested side of that element
    /// to locate the boundary of the run of equal elements.
    ///
    /// # Panics
    ///
    /// Panics if `side_find` ever returns `Equal`.
    fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
        &self,
        find: F,
        side_find: N,
        side: SearchSortedSide,
    ) -> SearchResult {
        let located = search_sorted_side_idx(find, 0, self.len());
        let SearchResult::Found(anchor) = located else {
            // No equal element: the insertion point is already the answer.
            return located;
        };

        // Left looks before the anchor, Right from the anchor onwards.
        let (lower, upper) = match side {
            SearchSortedSide::Left => (0, anchor),
            SearchSortedSide::Right => (anchor, self.len()),
        };

        match search_sorted_side_idx(side_find, lower, upper) {
            SearchResult::NotFound(boundary) => SearchResult::Found(boundary),
            SearchResult::Found(_) => {
                unreachable!("searching amongst equal values should never return Found result")
            }
        }
    }
}
/// Binary search over the index range `from..to` driven by the comparator `find`.
///
/// `find(i)` must describe how the element at `i` orders relative to the target, and
/// must be consistent with a sorted sequence (all `Less` before any `Equal` before any
/// `Greater`).
///
/// Returns `Found(i)` for some `i` in `from..to` with `find(i) == Equal`, or
/// `NotFound(i)` with `from <= i <= to` giving the insertion point. An empty range, or
/// an inverted one (`to < from`), yields `NotFound(from)` without calling `find`.
fn search_sorted_side_idx<F: FnMut(usize) -> Ordering>(
    mut find: F,
    from: usize,
    to: usize,
) -> SearchResult {
    // Saturate so an inverted range behaves like an empty one instead of underflowing.
    let mut size = to.saturating_sub(from);
    if size == 0 {
        // The insertion point of an empty range is its start, not index 0.
        return SearchResult::NotFound(from);
    }

    // Loop invariant: `base + size <= to` and `size >= 1`.
    let mut base = from;
    while size > 1 {
        let half = size / 2;
        let mid = base + half;
        let cmp = find(mid);
        // Keep the lower half when the probe is too large, otherwise move up to `mid`.
        base = if cmp == Greater { base } else { mid };
        size -= half;
    }

    let cmp = find(base);
    if cmp == Equal {
        // SAFETY: the loop invariant `base + size <= to` holds with `size == 1` here,
        // hence `base < to`.
        unsafe { hint::assert_unchecked(base < to) };
        SearchResult::Found(base)
    } else {
        let result = base + usize::from(cmp == Less);
        // SAFETY: `base < to` (see above) and at most one is added, hence `result <= to`.
        unsafe { hint::assert_unchecked(result <= to) };
        SearchResult::NotFound(result)
    }
}
/// Compares array elements against a [`Scalar`] by materialising the element with
/// `scalar_at`.
impl IndexOrd<Scalar> for ArrayData {
    /// Returns `None` when the element at `idx` cannot be fetched or when the two
    /// scalars are incomparable.
    fn index_cmp(&self, idx: usize, elem: &Scalar) -> Option<Ordering> {
        match scalar_at(self, idx) {
            Ok(current) => current.partial_cmp(elem),
            Err(_) => None,
        }
    }
}
/// Compares slice elements against a value of the same type.
impl<T: PartialOrd> IndexOrd<T> for [T] {
    /// Returns the ordering of `self[idx]` relative to `elem`, or `None` when the two
    /// are incomparable.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is out of bounds. This method is safe and publicly callable, so
    /// the access is bounds-checked rather than relying on callers to pass a valid
    /// index.
    fn index_cmp(&self, idx: usize, elem: &T) -> Option<Ordering> {
        self[idx].partial_cmp(elem)
    }
}
impl Len for ArrayData {
#[allow(clippy::same_name_method)]
fn len(&self) -> usize {
Self::len(self)
}
}
/// Exposes the slice length through the [`Len`] trait.
impl<T> Len for [T] {
    fn len(&self) -> usize {
        // Name the inherent slice method explicitly so it is obvious this does not
        // recurse into the trait method.
        <[T]>::len(self)
    }
}
// Unit tests for the slice-based binary search and for the cast-failure handling of the
// array-level entry points.
#[cfg(test)]
mod test {
use vortex_buffer::buffer;
use crate::compute::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
use crate::compute::{search_sorted, search_sorted_many};
use crate::IntoArrayData;
// Left side of a run of equal values in the middle: index of the first match.
#[test]
fn left_side_equal() {
let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
let res = arr.search_sorted(&2, SearchSortedSide::Left);
assert_eq!(arr[res.to_index()], 2);
assert_eq!(res, SearchResult::Found(2));
}
// Right side of a run of equal values in the middle: one past the last match.
#[test]
fn right_side_equal() {
let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
let res = arr.search_sorted(&2, SearchSortedSide::Right);
assert_eq!(arr[res.to_index() - 1], 2);
assert_eq!(res, SearchResult::Found(6));
}
// Left side when the run of equal values starts the array.
#[test]
fn left_side_equal_beginning() {
let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
let res = arr.search_sorted(&0, SearchSortedSide::Left);
assert_eq!(arr[res.to_index()], 0);
assert_eq!(res, SearchResult::Found(0));
}
// Right side when the run of equal values starts the array.
#[test]
fn right_side_equal_beginning() {
let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
let res = arr.search_sorted(&0, SearchSortedSide::Right);
assert_eq!(arr[res.to_index() - 1], 0);
assert_eq!(res, SearchResult::Found(4));
}
// Left side when the run of equal values ends the array.
#[test]
fn left_side_equal_end() {
let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
let res = arr.search_sorted(&9, SearchSortedSide::Left);
assert_eq!(arr[res.to_index()], 9);
assert_eq!(res, SearchResult::Found(9));
}
// Right side when the run of equal values ends the array: the result equals the
// array length.
#[test]
fn right_side_equal_end() {
let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
let res = arr.search_sorted(&9, SearchSortedSide::Right);
assert_eq!(arr[res.to_index() - 1], 9);
assert_eq!(res, SearchResult::Found(13));
}
// A target that does not fit the array's `u8` dtype is reported as
// `NotFound(len)` rather than as an error.
#[test]
fn failed_cast() {
let arr = buffer![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9].into_array();
let res = search_sorted(&arr, 256, SearchSortedSide::Left).unwrap();
assert_eq!(res, SearchResult::NotFound(arr.len()));
}
// Same cast-failure handling through the batched entry point.
#[test]
fn search_sorted_many_failed_cast() {
let arr = buffer![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9].into_array();
let res = search_sorted_many(&arr, &[256], SearchSortedSide::Left).unwrap();
assert_eq!(res, vec![SearchResult::NotFound(arr.len())]);
}
}