vortex_sequence/compute/cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::compute::{CastKernel, CastKernelAdapter};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::{DType, Nullability};
7use vortex_error::{VortexResult, vortex_err};
8use vortex_scalar::{Scalar, ScalarValue};
9
10use crate::{SequenceArray, SequenceVTable};
11
12impl CastKernel for SequenceVTable {
13    fn cast(&self, array: &SequenceArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
14        // SequenceArray represents arithmetic sequences (base + i * multiplier) which
15        // only makes sense for integer types. Floating-point sequences would accumulate
16        // rounding errors, and other types don't support arithmetic operations.
17        let DType::Primitive(target_ptype, target_nullability) = dtype else {
18            return Ok(None);
19        };
20
21        if !target_ptype.is_int() {
22            return Ok(None);
23        }
24
25        // Check if this is just a nullability change
26        if array.ptype() == *target_ptype && array.dtype().nullability() != *target_nullability {
27            // For SequenceArray, we can just create a new one with the same parameters
28            // but different nullability
29            return Ok(Some(
30                SequenceArray::new(
31                    array.base(),
32                    array.multiplier(),
33                    *target_ptype,
34                    *target_nullability,
35                    array.len(),
36                )?
37                .into_array(),
38            ));
39        }
40
41        // For type changes, we need to cast the base and multiplier
42        if array.ptype() != *target_ptype {
43            // Create scalars from PValues and cast them
44            let base_scalar = Scalar::new(
45                DType::Primitive(array.ptype(), Nullability::NonNullable),
46                ScalarValue::from(array.base()),
47            );
48            let multiplier_scalar = Scalar::new(
49                DType::Primitive(array.ptype(), Nullability::NonNullable),
50                ScalarValue::from(array.multiplier()),
51            );
52
53            let new_base_scalar =
54                base_scalar.cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
55            let new_multiplier_scalar = multiplier_scalar
56                .cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
57
58            // Extract PValues from the casted scalars
59            let new_base = new_base_scalar
60                .as_primitive()
61                .pvalue()
62                .ok_or_else(|| vortex_err!("Cast resulted in null base value"))?;
63            let new_multiplier = new_multiplier_scalar
64                .as_primitive()
65                .pvalue()
66                .ok_or_else(|| vortex_err!("Cast resulted in null multiplier value"))?;
67
68            return Ok(Some(
69                SequenceArray::new(
70                    new_base,
71                    new_multiplier,
72                    *target_ptype,
73                    *target_nullability,
74                    array.len(),
75                )?
76                .into_array(),
77            ));
78        }
79
80        Ok(None)
81    }
82}
83
84register_kernel!(CastKernelAdapter(SequenceVTable).lift());
85
86#[cfg(test)]
87mod tests {
88    use rstest::rstest;
89    use vortex_array::arrays::PrimitiveArray;
90    use vortex_array::compute::cast;
91    use vortex_array::compute::conformance::cast::test_cast_conformance;
92    use vortex_array::{ToCanonical, assert_arrays_eq};
93    use vortex_dtype::{DType, Nullability, PType};
94
95    use crate::SequenceArray;
96
97    #[test]
98    fn test_cast_sequence_nullability() {
99        let sequence = SequenceArray::typed_new(0u32, 1u32, Nullability::NonNullable, 4).unwrap();
100
101        // Cast to nullable
102        let casted = cast(
103            sequence.as_ref(),
104            &DType::Primitive(PType::U32, Nullability::Nullable),
105        )
106        .unwrap();
107        assert_eq!(
108            casted.dtype(),
109            &DType::Primitive(PType::U32, Nullability::Nullable)
110        );
111    }
112
113    #[test]
114    fn test_cast_sequence_u32_to_i64() {
115        let sequence =
116            SequenceArray::typed_new(100u32, 10u32, Nullability::NonNullable, 4).unwrap();
117
118        let casted = cast(
119            sequence.as_ref(),
120            &DType::Primitive(PType::I64, Nullability::NonNullable),
121        )
122        .unwrap();
123        assert_eq!(
124            casted.dtype(),
125            &DType::Primitive(PType::I64, Nullability::NonNullable)
126        );
127
128        // Verify the values
129        let decoded = casted.to_primitive();
130        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([100i64, 110, 120, 130]));
131    }
132
133    #[test]
134    fn test_cast_sequence_i16_to_i32_nullable() {
135        // Test ptype change AND nullability change in one cast
136        let sequence = SequenceArray::typed_new(5i16, 3i16, Nullability::NonNullable, 3).unwrap();
137
138        let casted = cast(
139            sequence.as_ref(),
140            &DType::Primitive(PType::I32, Nullability::Nullable),
141        )
142        .unwrap();
143        assert_eq!(
144            casted.dtype(),
145            &DType::Primitive(PType::I32, Nullability::Nullable)
146        );
147
148        // Verify the values
149        let decoded = casted.to_primitive();
150        assert_arrays_eq!(
151            decoded,
152            PrimitiveArray::from_option_iter([Some(5i32), Some(8), Some(11)])
153        );
154    }
155
156    #[test]
157    fn test_cast_sequence_to_float_delegates_to_canonical() {
158        let sequence = SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 5).unwrap();
159
160        // Cast to float should delegate to canonical (SequenceArray doesn't support float)
161        let casted = cast(
162            sequence.as_ref(),
163            &DType::Primitive(PType::F32, Nullability::NonNullable),
164        )
165        .unwrap();
166        // Should still succeed by decoding to canonical first
167        assert_eq!(
168            casted.dtype(),
169            &DType::Primitive(PType::F32, Nullability::NonNullable)
170        );
171
172        // Verify the values were correctly converted
173        let decoded = casted.to_primitive();
174        assert_arrays_eq!(
175            decoded,
176            PrimitiveArray::from_iter([0.0f32, 1.0, 2.0, 3.0, 4.0])
177        );
178    }
179
180    #[rstest]
181    #[case::i32(SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 5).unwrap())]
182    #[case::u64(SequenceArray::typed_new(1000u64, 100u64, Nullability::NonNullable, 4).unwrap())]
183    // TODO(DK): SequenceArray does not actually conform. You cannot cast this array to u8 even
184    // though all its values are representable therein.
185    //
186    // #[case::negative_step(SequenceArray::typed_new(100i32, -10i32, Nullability::NonNullable,
187    // 5).unwrap())]
188    #[case::single(SequenceArray::typed_new(42i64, 0i64, Nullability::NonNullable, 1).unwrap())]
189    #[case::constant(SequenceArray::typed_new(
190        100i32,
191        0i32, // multiplier of 0 means constant array
192        Nullability::NonNullable,
193        5,
194    ).unwrap())]
195    fn test_cast_sequence_conformance(#[case] sequence: SequenceArray) {
196        test_cast_conformance(sequence.as_ref());
197    }
198}