1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
use vortex_error::VortexResult;
use vortex_scalar::Scalar;

use crate::array::sparse::SparseArray;
use crate::compute::unary::{scalar_at, ScalarAtFn};
use crate::compute::{
    search_sorted, ArrayCompute, SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn,
};
use crate::ArrayDType;

mod slice;
mod take;

/// Advertises the compute kernels available for `SparseArray`.
///
/// Each accessor returns `self` viewed through the corresponding kernel trait.
impl ArrayCompute for SparseArray {
    /// Exposes this array as its own `ScalarAtFn` kernel.
    fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
        let kernel: &dyn ScalarAtFn = self;
        Some(kernel)
    }

    /// Exposes this array as its own `SearchSortedFn` kernel.
    fn search_sorted(&self) -> Option<&dyn SearchSortedFn> {
        let kernel: &dyn SearchSortedFn = self;
        Some(kernel)
    }

    /// Exposes this array as its own `SliceFn` kernel.
    fn slice(&self) -> Option<&dyn SliceFn> {
        let kernel: &dyn SliceFn = self;
        Some(kernel)
    }

    /// Exposes this array as its own `TakeFn` kernel.
    fn take(&self) -> Option<&dyn TakeFn> {
        let kernel: &dyn TakeFn = self;
        Some(kernel)
    }
}

impl ScalarAtFn for SparseArray {
    /// Returns the element at `index`, cast to this array's dtype.
    ///
    /// `find_index` decides where the element comes from: when it yields a
    /// slot, the scalar is read from `values()` at that slot; when it yields
    /// `None`, the fill value is used instead.
    ///
    /// # Errors
    ///
    /// Propagates any error from `find_index`, from reading the values array,
    /// or from the final cast.
    fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
        let raw = if let Some(value_idx) = self.find_index(index)? {
            scalar_at(&self.values(), value_idx)?
        } else {
            self.fill_value().clone()
        };
        raw.cast(self.dtype())
    }
}

impl SearchSortedFn for SparseArray {
    /// Searches the stored values for `value` and translates the outcome into
    /// a position within this sparse array.
    ///
    /// The search itself runs over `values()`. The resulting slot is then
    /// mapped through `indices()` and shifted by `indices_offset()`:
    ///
    /// * `Found(i)` becomes `Found(indices[i] - indices_offset)`.
    /// * `NotFound(0)` becomes `NotFound(indices[0] - indices_offset)`.
    /// * `NotFound(i)` with `i > 0` becomes
    ///   `NotFound(indices[i - 1] + 1 - indices_offset)`, i.e. the position
    ///   right after the preceding stored entry.
    ///
    /// # Errors
    ///
    /// Returns an error if the search over the values fails, if the index
    /// scalar cannot be read, or if it cannot be converted to `usize`. The
    /// conversion failure is propagated to the caller instead of panicking.
    fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<SearchResult> {
        match search_sorted(&self.values(), value.clone(), side)? {
            SearchResult::Found(i) => {
                let index: usize = scalar_at(&self.indices(), i)?.as_ref().try_into()?;
                Ok(SearchResult::Found(index - self.indices_offset()))
            }
            SearchResult::NotFound(i) => {
                // Look at the stored entry just before the insertion point; when the
                // insertion point is the very front there is none, so use the first entry.
                let neighbour = if i == 0 { 0 } else { i - 1 };
                let index: usize = scalar_at(&self.indices(), neighbour)?
                    .as_ref()
                    .try_into()?;
                // In front of the first entry we report that entry's own position,
                // otherwise the position immediately after the preceding entry.
                let position = if i == 0 { index } else { index + 1 };
                Ok(SearchResult::NotFound(position - self.indices_offset()))
            }
        }
    }
}

#[cfg(test)]
mod test {
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::array::primitive::PrimitiveArray;
    use crate::array::sparse::SparseArray;
    use crate::compute::{search_sorted, slice, SearchResult, SearchSortedSide};
    use crate::validity::Validity;
    use crate::{Array, IntoArray};

    /// Shared fixture: a sparse array of length 20 holding 33, 44 and 55 at
    /// positions 2, 9 and 15, with a nullable-i32 null as the fill value.
    fn array() -> Array {
        let indices = PrimitiveArray::from(vec![2u64, 9, 15]).into_array();
        let values = PrimitiveArray::from_vec(vec![33, 44, 55], Validity::AllValid).into_array();
        let fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));

        SparseArray::try_new(indices, values, 20, fill_value)
            .unwrap()
            .into_array()
    }

    /// A needle above every stored value lands right after the last stored position.
    #[test]
    pub fn search_larger_than() {
        let sparse = array();
        let outcome = search_sorted(&sparse, 66, SearchSortedSide::Left).unwrap();
        assert_eq!(outcome, SearchResult::NotFound(16));
    }

    /// A needle below every stored value is reported at the first stored position.
    #[test]
    pub fn search_less_than() {
        let sparse = array();
        let outcome = search_sorted(&sparse, 22, SearchSortedSide::Left).unwrap();
        assert_eq!(outcome, SearchResult::NotFound(2));
    }

    /// A stored value is reported at the position it occupies in the sparse array.
    #[test]
    pub fn search_found() {
        let sparse = array();
        let outcome = search_sorted(&sparse, 44, SearchSortedSide::Left).unwrap();
        assert_eq!(outcome, SearchResult::Found(9));
    }

    /// After slicing to 7..20 the reported position is relative to the slice start.
    #[test]
    pub fn search_sliced() {
        let sliced = slice(&array(), 7, 20).unwrap();
        let outcome = search_sorted(&sliced, 22, SearchSortedSide::Left).unwrap();
        assert_eq!(outcome, SearchResult::NotFound(2));
    }
}