Skip to main content

vortex_sequence/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::IntoArray;
7use vortex_array::dtype::DType;
8use vortex_array::dtype::Nullability;
9use vortex_array::scalar::Scalar;
10use vortex_array::scalar::ScalarValue;
11use vortex_array::scalar_fn::fns::cast::CastReduce;
12use vortex_error::VortexResult;
13use vortex_error::vortex_err;
14
15use crate::Sequence;
16impl CastReduce for Sequence {
17    fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18        // SequenceArray represents arithmetic sequences (base + i * multiplier) which
19        // only makes sense for integer types. Floating-point sequences would accumulate
20        // rounding errors, and other types don't support arithmetic operations.
21        let DType::Primitive(target_ptype, target_nullability) = dtype else {
22            return Ok(None);
23        };
24
25        if !target_ptype.is_int() {
26            return Ok(None);
27        }
28
29        // Check if this is just a nullability change
30        if array.ptype() == *target_ptype && array.dtype().nullability() != *target_nullability {
31            // For SequenceArray, we can just create a new one with the same parameters
32            // but different nullability
33            return Ok(Some(
34                Sequence::try_new(
35                    array.base(),
36                    array.multiplier(),
37                    *target_ptype,
38                    *target_nullability,
39                    array.len(),
40                )?
41                .into_array(),
42            ));
43        }
44
45        // For type changes, we need to cast the base and multiplier
46        if array.ptype() != *target_ptype {
47            // Create scalars from PValues and cast them
48            let base_scalar = Scalar::try_new(
49                DType::Primitive(array.ptype(), Nullability::NonNullable),
50                Some(ScalarValue::Primitive(array.base())),
51            )?;
52            let multiplier_scalar = Scalar::try_new(
53                DType::Primitive(array.ptype(), Nullability::NonNullable),
54                Some(ScalarValue::Primitive(array.multiplier())),
55            )?;
56
57            let new_base_scalar =
58                base_scalar.cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
59            let new_multiplier_scalar = multiplier_scalar
60                .cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
61
62            // Extract PValues from the casted scalars
63            let new_base = new_base_scalar
64                .as_primitive()
65                .pvalue()
66                .ok_or_else(|| vortex_err!("Cast resulted in null base value"))?;
67            let new_multiplier = new_multiplier_scalar
68                .as_primitive()
69                .pvalue()
70                .ok_or_else(|| vortex_err!("Cast resulted in null multiplier value"))?;
71
72            return Ok(Some(
73                Sequence::try_new(
74                    new_base,
75                    new_multiplier,
76                    *target_ptype,
77                    *target_nullability,
78                    array.len(),
79                )?
80                .into_array(),
81            ));
82        }
83
84        Ok(None)
85    }
86}
87
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::session::ArraySession;
    use vortex_session::VortexSession;

    use crate::Sequence;
    use crate::SequenceArray;

    /// Session used to execute casted arrays into canonical primitive arrays.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// A nullability-only cast keeps the ptype and reports the nullable dtype.
    #[test]
    fn test_cast_sequence_nullability() {
        let target = DType::Primitive(PType::U32, Nullability::Nullable);
        let input = Sequence::try_new_typed(0u32, 1u32, Nullability::NonNullable, 4).unwrap();

        let output = input.into_array().cast(target.clone()).unwrap();

        assert_eq!(output.dtype(), &target);
    }

    /// A ptype change produces the target dtype and the same element values.
    #[test]
    fn test_cast_sequence_u32_to_i64() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let input = Sequence::try_new_typed(100u32, 10u32, Nullability::NonNullable, 4).unwrap();

        let output = input.into_array().cast(target.clone()).unwrap();
        assert_eq!(output.dtype(), &target);

        let mut exec_ctx = SESSION.create_execution_ctx();
        let values = output.execute::<PrimitiveArray>(&mut exec_ctx).unwrap();
        assert_arrays_eq!(values, PrimitiveArray::from_iter([100i64, 110, 120, 130]));
    }

    /// A ptype change and a nullability change can be applied in a single cast.
    #[test]
    fn test_cast_sequence_i16_to_i32_nullable() {
        let target = DType::Primitive(PType::I32, Nullability::Nullable);
        let input = Sequence::try_new_typed(5i16, 3i16, Nullability::NonNullable, 3).unwrap();

        let output = input.into_array().cast(target.clone()).unwrap();
        assert_eq!(output.dtype(), &target);

        let mut exec_ctx = SESSION.create_execution_ctx();
        let values = output.execute::<PrimitiveArray>(&mut exec_ctx).unwrap();
        let expected = PrimitiveArray::from_option_iter([Some(5i32), Some(8), Some(11)]);
        assert_arrays_eq!(values, expected);
    }

    /// Float targets are not handled by the sequence rewrite. The cast still succeeds via the
    /// canonical (decoded) path.
    #[test]
    fn test_cast_sequence_to_float_delegates_to_canonical() {
        let target = DType::Primitive(PType::F32, Nullability::NonNullable);
        let input = Sequence::try_new_typed(0i32, 1i32, Nullability::NonNullable, 5).unwrap();

        let output = input.into_array().cast(target.clone()).unwrap();
        assert_eq!(output.dtype(), &target);

        let mut exec_ctx = SESSION.create_execution_ctx();
        let values = output.execute::<PrimitiveArray>(&mut exec_ctx).unwrap();
        assert_arrays_eq!(values, PrimitiveArray::from_iter([0.0f32, 1.0, 2.0, 3.0, 4.0]));
    }

    /// Runs the shared cast conformance suite against several sequence shapes.
    #[rstest]
    #[case::i32(Sequence::try_new_typed(0i32, 1i32, Nullability::NonNullable, 5).unwrap())]
    #[case::u64(Sequence::try_new_typed(1000u64, 100u64, Nullability::NonNullable, 4).unwrap())]
    // TODO(DK): re-enable once a negative-step sequence can be cast to an unsigned type such
    // as u8. Every element is representable there, but the negative multiplier itself is not.
    //
    // #[case::negative_step(Sequence::try_new_typed(100i32, -10i32, Nullability::NonNullable,
    // 5).unwrap())]
    #[case::single(Sequence::try_new_typed(42i64, 0i64, Nullability::NonNullable, 1).unwrap())]
    // A multiplier of 0 yields a constant array.
    #[case::constant(Sequence::try_new_typed(100i32, 0i32, Nullability::NonNullable, 5).unwrap())]
    fn test_cast_sequence_conformance(#[case] sequence: SequenceArray) {
        let as_array = sequence.into_array();
        test_cast_conformance(&as_array);
    }
}