vortex_sequence/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::compute::{CastKernel, CastKernelAdapter};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::{DType, Nullability};
7use vortex_error::{VortexResult, vortex_err};
8use vortex_scalar::{Scalar, ScalarValue};
9
10use crate::{SequenceArray, SequenceVTable};
11
12impl CastKernel for SequenceVTable {
13    fn cast(&self, array: &SequenceArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
14        // SequenceArray represents arithmetic sequences (base + i * multiplier) which
15        // only makes sense for integer types. Floating-point sequences would accumulate
16        // rounding errors, and other types don't support arithmetic operations.
17        let DType::Primitive(target_ptype, target_nullability) = dtype else {
18            return Ok(None);
19        };
20
21        if !target_ptype.is_int() {
22            return Ok(None);
23        }
24
25        // Check if this is just a nullability change
26        if array.ptype() == *target_ptype && array.dtype().nullability() != *target_nullability {
27            // For SequenceArray, we can just create a new one with the same parameters
28            // but different nullability
29            return Ok(Some(
30                SequenceArray::new(
31                    array.base(),
32                    array.multiplier(),
33                    *target_ptype,
34                    *target_nullability,
35                    array.len(),
36                )?
37                .into_array(),
38            ));
39        }
40
41        // For type changes, we need to cast the base and multiplier
42        if array.ptype() != *target_ptype {
43            // Create scalars from PValues and cast them
44            let base_scalar = Scalar::new(
45                DType::Primitive(array.ptype(), Nullability::NonNullable),
46                ScalarValue::from(array.base()),
47            );
48            let multiplier_scalar = Scalar::new(
49                DType::Primitive(array.ptype(), Nullability::NonNullable),
50                ScalarValue::from(array.multiplier()),
51            );
52
53            let new_base_scalar =
54                base_scalar.cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
55            let new_multiplier_scalar = multiplier_scalar
56                .cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
57
58            // Extract PValues from the casted scalars
59            let new_base = new_base_scalar
60                .as_primitive()
61                .pvalue()
62                .ok_or_else(|| vortex_err!("Cast resulted in null base value"))?;
63            let new_multiplier = new_multiplier_scalar
64                .as_primitive()
65                .pvalue()
66                .ok_or_else(|| vortex_err!("Cast resulted in null multiplier value"))?;
67
68            return Ok(Some(
69                SequenceArray::new(
70                    new_base,
71                    new_multiplier,
72                    *target_ptype,
73                    *target_nullability,
74                    array.len(),
75                )?
76                .into_array(),
77            ));
78        }
79
80        Ok(None)
81    }
82}
83
// Registers the `CastKernel` implementation above, wrapped in `CastKernelAdapter`.
// NOTE(review): presumably this is what lets the generic `cast` compute function dispatch to
// `SequenceVTable::cast` for sequence arrays — confirm against `register_kernel!`.
register_kernel!(CastKernelAdapter(SequenceVTable).lift());
85
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::SequenceArray;

    /// Widening only the nullability keeps the element type and reports the nullable dtype.
    #[test]
    fn test_cast_sequence_nullability() {
        let seq = SequenceArray::typed_new(0u32, 1u32, Nullability::NonNullable, 4).unwrap();
        let target = DType::Primitive(PType::U32, Nullability::Nullable);

        let result = cast(seq.as_ref(), &target).unwrap();

        assert_eq!(result.dtype(), &target);
    }

    /// Changing the integer ptype must preserve every element of the sequence.
    #[test]
    fn test_cast_sequence_u32_to_i64() {
        let seq = SequenceArray::typed_new(100u32, 10u32, Nullability::NonNullable, 4).unwrap();
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);

        let result = cast(seq.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        // Decode and compare the materialized values.
        let canonical = result.to_canonical().unwrap();
        let values = canonical.into_primitive().unwrap();
        assert_eq!(values.as_slice::<i64>(), &[100i64, 110, 120, 130]);
    }

    /// A ptype change and a nullability change applied in a single cast.
    #[test]
    fn test_cast_sequence_i16_to_i32_nullable() {
        let seq = SequenceArray::typed_new(5i16, 3i16, Nullability::NonNullable, 3).unwrap();
        let target = DType::Primitive(PType::I32, Nullability::Nullable);

        let result = cast(seq.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let canonical = result.to_canonical().unwrap();
        let values = canonical.into_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>(), &[5i32, 8, 11]);
    }

    /// Float targets are not handled by the sequence kernel; the cast must still succeed
    /// through the canonical fallback and produce the converted values.
    #[test]
    fn test_cast_sequence_to_float_delegates_to_canonical() {
        let seq = SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 5).unwrap();
        let target = DType::Primitive(PType::F32, Nullability::NonNullable);

        let result = cast(seq.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let canonical = result.to_canonical().unwrap();
        let values = canonical.into_primitive().unwrap();
        let floats = values.as_slice::<f32>();
        assert_eq!(floats, &[0.0f32, 1.0, 2.0, 3.0, 4.0]);
    }

    /// Runs the shared cast conformance suite over a spread of sequence shapes.
    #[rstest]
    #[case::i32(SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 5).unwrap())]
    #[case::u64(SequenceArray::typed_new(1000u64, 100u64, Nullability::NonNullable, 4).unwrap())]
    #[case::negative_step(SequenceArray::typed_new(100i32, -10i32, Nullability::NonNullable, 5).unwrap())]
    #[case::single(SequenceArray::typed_new(42i64, 0i64, Nullability::NonNullable, 1).unwrap())]
    #[case::constant(SequenceArray::typed_new(
        100i32,
        0i32, // a zero multiplier yields a constant array
        Nullability::NonNullable,
        5,
    ).unwrap())]
    fn test_cast_sequence_conformance(#[case] sequence: SequenceArray) {
        test_cast_conformance(sequence.as_ref());
    }
}