vortex_array/arrays/list/compute/is_constant.rs

1use vortex_error::VortexResult;
2use vortex_scalar::NumericOperator;
3
4use crate::arrays::{ListArray, ListVTable};
5use crate::compute::{IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, numeric};
6use crate::register_kernel;
7
8const SMALL_ARRAY_THRESHOLD: usize = 64;
9
10impl IsConstantKernel for ListVTable {
11    fn is_constant(&self, array: &ListArray, opts: &IsConstantOpts) -> VortexResult<Option<bool>> {
12        // At this point, we're guaranteed:
13        // - Array has at least 2 elements
14        // - All elements are valid (no nulls)
15
16        let manual_check_until = std::cmp::min(SMALL_ARRAY_THRESHOLD, array.len());
17
18        let first_list_len = array.offset_at(1) - array.offset_at(0);
19        for i in 1..manual_check_until {
20            let current_list_len = array.offset_at(i + 1) - array.offset_at(i);
21            if current_list_len != first_list_len {
22                return Ok(Some(false));
23            }
24        }
25
26        if opts.is_negligible_cost() {
27            return Ok(None);
28        }
29
30        if array.len() > SMALL_ARRAY_THRESHOLD {
31            // check the rest of the element lengths
32            let start_offsets = array.offsets.slice(SMALL_ARRAY_THRESHOLD, array.len())?;
33            let end_offsets = array
34                .offsets
35                .slice(SMALL_ARRAY_THRESHOLD + 1, array.len() + 1)?;
36            let list_lengths = numeric(&end_offsets, &start_offsets, NumericOperator::Sub)?;
37
38            if !list_lengths.is_constant() {
39                return Ok(Some(false));
40            }
41        }
42
43        // If all lists have the same length, compare the actual list contents
44        let first_scalar = array.scalar_at(0)?;
45        for i in 1..array.len() {
46            let current_scalar = array.scalar_at(i)?;
47            if current_scalar != first_scalar {
48                return Ok(Some(false));
49            }
50        }
51
52        Ok(Some(true))
53    }
54}
55
56register_kernel!(IsConstantKernelAdapter(ListVTable).lift());
57
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::FieldNames;

    use crate::IntoArray;
    use crate::arrays::{ListArray, PrimitiveArray, StructArray};
    use crate::compute::is_constant;
    use crate::validity::Validity;

    /// Builds a non-nullable list-of-`i32` array from flat `elements` and `u32` `offsets`.
    fn i32_lists(elements: Vec<i32>, offsets: Vec<u32>) -> ListArray {
        ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Runs the `is_constant` compute function on `list_array`, unwrapping error and option.
    fn list_is_constant(list_array: ListArray) -> bool {
        is_constant(&list_array.into_array()).unwrap().unwrap()
    }

    #[test]
    fn test_is_constant_nested_list() {
        // [[0, 1], [0, 1]] wrapped in a single-field struct.
        let xs = i32_lists(vec![0, 1, 0, 1], vec![0, 2, 4]);

        let struct_of_lists = StructArray::try_new(
            FieldNames::from(["xs"]),
            vec![xs.into_array()],
            2,
            Validity::NonNullable,
        )
        .unwrap();
        assert!(
            is_constant(&struct_of_lists.clone().into_array())
                .unwrap()
                .unwrap()
        );
        assert!(struct_of_lists.is_constant());
    }

    #[rstest]
    #[case(
        // [1,2], [1, 2], [1, 2]
        vec![1i32, 2, 1, 2, 1, 2],
        vec![0u32, 2, 4, 6],
        true
    )]
    #[case(
        // [1, 2], [3], [4, 5]
        vec![1i32, 2, 3, 4, 5],
        vec![0u32, 2, 3, 5],
        false
    )]
    #[case(
        // [1, 2], [3, 4]
        vec![1i32, 2, 3, 4],
        vec![0u32, 2, 4],
        false
    )]
    #[case(
        // [], [], []
        vec![],
        vec![0u32, 0, 0, 0],
        true
    )]
    fn test_list_is_constant(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) {
        assert_eq!(list_is_constant(i32_lists(elements, offsets)), expected);
    }

    #[test]
    fn test_list_is_constant_nested_lists() {
        // Inner lists: [1], [2], [1], [2]
        let inner_lists = i32_lists(vec![1, 2, 1, 2], vec![0, 1, 2, 3, 4]);

        let outer_list = ListArray::try_new(
            inner_lists.into_array(),
            PrimitiveArray::from_iter([0u32, 2, 4]).into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        // Both outer lists contain [[1], [2]], so should be constant
        assert!(list_is_constant(outer_list));
    }

    #[rstest]
    #[case(
        // 100 identical [1, 2] lists
        [1i32, 2].repeat(100),
        (0..101).map(|i| (i * 2) as u32).collect(),
        true
    )]
    #[case(
        // Difference after threshold: 64 identical [1, 2] + one [3, 4]
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend_from_slice(&[3, 4]);
            elements
        },
        (0..66).map(|i| (i * 2) as u32).collect(),
        false
    )]
    #[case(
        // Difference in first 64: 63 identical [1, 2] + one [3, 4] + 36 identical [1, 2]
        {
            let mut elements = [1i32, 2].repeat(63);
            elements.extend_from_slice(&[3, 4]);
            elements.extend([1i32, 2].repeat(36));
            elements
        },
        (0..101).map(|i| (i * 2) as u32).collect(),
        false
    )]
    #[case(
        // Length mismatch right at the threshold: 64 identical [1, 2] + one [1, 2, 3]
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend_from_slice(&[1, 2, 3]);
            elements
        },
        (0..=64u32).map(|i| i * 2).chain(std::iter::once(131)).collect(),
        false
    )]
    #[case(
        // Tail lengths uniform among themselves but different from the head:
        // 64 identical [1, 2] + 36 identical [1, 2, 3]
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend([1i32, 2, 3].repeat(36));
            elements
        },
        (0..=64u32).map(|i| i * 2).chain((1..=36u32).map(|i| 128 + i * 3)).collect(),
        false
    )]
    #[case(
        // 100 empty lists
        vec![],
        vec![0u32; 101],
        true
    )]
    fn test_large_list_is_constant(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) {
        assert_eq!(list_is_constant(i32_lists(elements, offsets)), expected);
    }
}