// vortex_sequence/compute/cast.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_array::compute::{CastKernel, CastKernelAdapter};
use vortex_array::{ArrayRef, IntoArray, register_kernel};
use vortex_dtype::{DType, Nullability, PType};
use vortex_error::{VortexResult, vortex_err};
use vortex_scalar::{PValue, Scalar, ScalarValue};

use crate::{SequenceArray, SequenceVTable};

12impl CastKernel for SequenceVTable {
13    fn cast(&self, array: &SequenceArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
14        // SequenceArray represents arithmetic sequences (base + i * multiplier) which
15        // only makes sense for integer types. Floating-point sequences would accumulate
16        // rounding errors, and other types don't support arithmetic operations.
17        let DType::Primitive(target_ptype, target_nullability) = dtype else {
18            return Ok(None);
19        };
20
21        if !target_ptype.is_int() {
22            return Ok(None);
23        }
24
25        // Check if this is just a nullability change
26        if array.ptype() == *target_ptype && array.dtype().nullability() != *target_nullability {
27            // For SequenceArray, we can just create a new one with the same parameters
28            // but different nullability
29            return Ok(Some(
30                SequenceArray::new(
31                    array.base(),
32                    array.multiplier(),
33                    *target_ptype,
34                    *target_nullability,
35                    array.len(),
36                )?
37                .into_array(),
38            ));
39        }
40
41        // For type changes, we need to cast the base and multiplier
42        if array.ptype() != *target_ptype {
43            // Create scalars from PValues and cast them
44            let base_scalar = Scalar::new(
45                DType::Primitive(array.ptype(), Nullability::NonNullable),
46                ScalarValue::from(array.base()),
47            );
48            let multiplier_scalar = Scalar::new(
49                DType::Primitive(array.ptype(), Nullability::NonNullable),
50                ScalarValue::from(array.multiplier()),
51            );
52
53            let new_base_scalar =
54                base_scalar.cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
55            let new_multiplier_scalar = multiplier_scalar
56                .cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
57
58            // Extract PValues from the casted scalars
59            let new_base = new_base_scalar
60                .as_primitive()
61                .pvalue()
62                .ok_or_else(|| vortex_err!("Cast resulted in null base value"))?;
63            let new_multiplier = new_multiplier_scalar
64                .as_primitive()
65                .pvalue()
66                .ok_or_else(|| vortex_err!("Cast resulted in null multiplier value"))?;
67
68            return Ok(Some(
69                SequenceArray::new(
70                    new_base,
71                    new_multiplier,
72                    *target_ptype,
73                    *target_nullability,
74                    array.len(),
75                )?
76                .into_array(),
77            ));
78        }
79
80        Ok(None)
81    }
82}
83
84register_kernel!(CastKernelAdapter(SequenceVTable).lift());
85
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::SequenceArray;

    /// Shorthand for a primitive `DType` with the given ptype and nullability.
    fn primitive(ptype: PType, nullability: Nullability) -> DType {
        DType::Primitive(ptype, nullability)
    }

    /// Casting a non-nullable `u32` sequence to nullable `u32` only changes the nullability.
    #[test]
    fn test_cast_sequence_nullability() {
        let target = primitive(PType::U32, Nullability::Nullable);
        let sequence = SequenceArray::typed_new(0u32, 1u32, Nullability::NonNullable, 4).unwrap();

        let result = cast(sequence.as_ref(), &target).unwrap();

        assert_eq!(result.dtype(), &target);
    }

    /// Widening `u32` to `i64` must preserve the values `100, 110, 120, 130`.
    #[test]
    fn test_cast_sequence_u32_to_i64() {
        let target = primitive(PType::I64, Nullability::NonNullable);
        let sequence =
            SequenceArray::typed_new(100u32, 10u32, Nullability::NonNullable, 4).unwrap();

        let result = cast(sequence.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let values = result.to_primitive();
        let expected: Vec<i64> = (0..4).map(|i| 100 + 10 * i).collect();
        assert_eq!(values.as_slice::<i64>(), expected.as_slice());
    }

    /// A ptype change and a nullability change applied by one cast: `5, 8, 11` as `i32?`.
    #[test]
    fn test_cast_sequence_i16_to_i32_nullable() {
        let target = primitive(PType::I32, Nullability::Nullable);
        let sequence = SequenceArray::typed_new(5i16, 3i16, Nullability::NonNullable, 3).unwrap();

        let result = cast(sequence.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let values = result.to_primitive();
        let expected: Vec<i32> = (0..3).map(|i| 5 + 3 * i).collect();
        assert_eq!(values.as_slice::<i32>(), expected.as_slice());
    }

    /// Float targets are not handled by the sequence kernel; the cast must still succeed
    /// by going through the canonical (decoded) representation.
    #[test]
    fn test_cast_sequence_to_float_delegates_to_canonical() {
        let target = primitive(PType::F32, Nullability::NonNullable);
        let sequence = SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 5).unwrap();

        let result = cast(sequence.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let values = result.to_primitive();
        let expected: Vec<f32> = (0..5u8).map(f32::from).collect();
        assert_eq!(values.as_slice::<f32>(), expected.as_slice());
    }

    /// Runs the shared cast conformance suite over a spread of sequence shapes.
    #[rstest]
    #[case::i32(SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 5).unwrap())]
    #[case::u64(SequenceArray::typed_new(1000u64, 100u64, Nullability::NonNullable, 4).unwrap())]
    #[case::negative_step(SequenceArray::typed_new(100i32, -10i32, Nullability::NonNullable, 5).unwrap())]
    #[case::single(SequenceArray::typed_new(42i64, 0i64, Nullability::NonNullable, 1).unwrap())]
    // A multiplier of 0 yields a constant array.
    #[case::constant(SequenceArray::typed_new(100i32, 0i32, Nullability::NonNullable, 5).unwrap())]
    fn test_cast_sequence_conformance(#[case] sequence: SequenceArray) {
        test_cast_conformance(sequence.as_ref());
    }
}