Skip to main content

vortex_sequence/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::dtype::DType;
7use vortex_array::dtype::Nullability;
8use vortex_array::scalar::Scalar;
9use vortex_array::scalar::ScalarValue;
10use vortex_array::scalar_fn::fns::cast::CastReduce;
11use vortex_error::VortexResult;
12use vortex_error::vortex_err;
13
14use crate::Sequence;
15use crate::SequenceArray;
16
17impl CastReduce for Sequence {
18    fn cast(array: &SequenceArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
19        // SequenceArray represents arithmetic sequences (base + i * multiplier) which
20        // only makes sense for integer types. Floating-point sequences would accumulate
21        // rounding errors, and other types don't support arithmetic operations.
22        let DType::Primitive(target_ptype, target_nullability) = dtype else {
23            return Ok(None);
24        };
25
26        if !target_ptype.is_int() {
27            return Ok(None);
28        }
29
30        // Check if this is just a nullability change
31        if array.ptype() == *target_ptype && array.dtype().nullability() != *target_nullability {
32            // For SequenceArray, we can just create a new one with the same parameters
33            // but different nullability
34            return Ok(Some(
35                SequenceArray::try_new(
36                    array.base(),
37                    array.multiplier(),
38                    *target_ptype,
39                    *target_nullability,
40                    array.len(),
41                )?
42                .into_array(),
43            ));
44        }
45
46        // For type changes, we need to cast the base and multiplier
47        if array.ptype() != *target_ptype {
48            // Create scalars from PValues and cast them
49            let base_scalar = Scalar::try_new(
50                DType::Primitive(array.ptype(), Nullability::NonNullable),
51                Some(ScalarValue::Primitive(array.base())),
52            )?;
53            let multiplier_scalar = Scalar::try_new(
54                DType::Primitive(array.ptype(), Nullability::NonNullable),
55                Some(ScalarValue::Primitive(array.multiplier())),
56            )?;
57
58            let new_base_scalar =
59                base_scalar.cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
60            let new_multiplier_scalar = multiplier_scalar
61                .cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
62
63            // Extract PValues from the casted scalars
64            let new_base = new_base_scalar
65                .as_primitive()
66                .pvalue()
67                .ok_or_else(|| vortex_err!("Cast resulted in null base value"))?;
68            let new_multiplier = new_multiplier_scalar
69                .as_primitive()
70                .pvalue()
71                .ok_or_else(|| vortex_err!("Cast resulted in null multiplier value"))?;
72
73            return Ok(Some(
74                SequenceArray::try_new(
75                    new_base,
76                    new_multiplier,
77                    *target_ptype,
78                    *target_nullability,
79                    array.len(),
80                )?
81                .into_array(),
82            ));
83        }
84
85        Ok(None)
86    }
87}
88
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;

    use crate::SequenceArray;

    /// Changing only the nullability keeps the `u32` ptype and reports the new dtype.
    #[test]
    fn test_cast_sequence_nullability() {
        let target = DType::Primitive(PType::U32, Nullability::Nullable);
        let seq = SequenceArray::try_new_typed(0u32, 1u32, Nullability::NonNullable, 4).unwrap();

        let result = seq.into_array().cast(target.clone()).unwrap();

        assert_eq!(result.dtype(), &target);
    }

    /// Widening an unsigned sequence to `i64` keeps every decoded value.
    #[test]
    fn test_cast_sequence_u32_to_i64() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let seq =
            SequenceArray::try_new_typed(100u32, 10u32, Nullability::NonNullable, 4).unwrap();

        let result = seq.into_array().cast(target.clone()).unwrap();
        assert_eq!(result.dtype(), &target);

        // Expected: 100 + i * 10 for i in 0..4.
        let expected = PrimitiveArray::from_iter([100i64, 110, 120, 130]);
        let decoded = result.to_primitive();
        assert_arrays_eq!(decoded, expected);
    }

    /// Changes the ptype (i16 -> i32) and the nullability in a single cast.
    #[test]
    fn test_cast_sequence_i16_to_i32_nullable() {
        let target = DType::Primitive(PType::I32, Nullability::Nullable);
        let seq = SequenceArray::try_new_typed(5i16, 3i16, Nullability::NonNullable, 3).unwrap();

        let result = seq.into_array().cast(target.clone()).unwrap();
        assert_eq!(result.dtype(), &target);

        // Expected: 5 + i * 3 for i in 0..3, now nullable i32.
        let expected = PrimitiveArray::from_option_iter([Some(5i32), Some(8), Some(11)]);
        let decoded = result.to_primitive();
        assert_arrays_eq!(decoded, expected);
    }

    /// Float targets are not handled by the sequence fast path; the cast still succeeds by
    /// going through the canonical representation.
    #[test]
    fn test_cast_sequence_to_float_delegates_to_canonical() {
        let target = DType::Primitive(PType::F32, Nullability::NonNullable);
        let seq = SequenceArray::try_new_typed(0i32, 1i32, Nullability::NonNullable, 5).unwrap();

        let result = seq.into_array().cast(target.clone()).unwrap();
        assert_eq!(result.dtype(), &target);

        let expected = PrimitiveArray::from_iter([0.0f32, 1.0, 2.0, 3.0, 4.0]);
        let decoded = result.to_primitive();
        assert_arrays_eq!(decoded, expected);
    }

    /// Runs the shared cast conformance suite against a range of sequence shapes.
    #[rstest]
    #[case::i32(SequenceArray::try_new_typed(0i32, 1i32, Nullability::NonNullable, 5).unwrap())]
    #[case::u64(SequenceArray::try_new_typed(1000u64, 100u64, Nullability::NonNullable, 4).unwrap())]
    // TODO(DK): SequenceArray does not actually conform. You cannot cast this array to u8 even
    // though all its values are representable therein.
    //
    // #[case::negative_step(SequenceArray::try_new_typed(100i32, -10i32, Nullability::NonNullable,
    // 5).unwrap())]
    #[case::single(SequenceArray::try_new_typed(42i64, 0i64, Nullability::NonNullable, 1).unwrap())]
    #[case::constant(SequenceArray::try_new_typed(
        100i32,
        0i32, // a zero multiplier yields a constant array
        Nullability::NonNullable,
        5,
    ).unwrap())]
    fn test_cast_sequence_conformance(#[case] sequence: SequenceArray) {
        let array = sequence.into_array();
        test_cast_conformance(&array);
    }
}