Skip to main content

vortex_array/arrays/primitive/array/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexExpect;
5use vortex_error::vortex_panic;
6
7use super::PrimitiveData;
8use crate::dtype::NativePType;
9
10impl PrimitiveData {
11    /// Return a slice of the array's buffer.
12    ///
13    /// NOTE: these values may be nonsense if the validity buffer indicates that the value is null.
14    ///
15    /// # Panic
16    ///
17    /// This operation will panic if the array is not backed by host memory.
18    pub fn as_slice<T: NativePType>(&self) -> &[T] {
19        if T::PTYPE != self.ptype() {
20            vortex_panic!(
21                "Attempted to get slice of type {} from array of type {}",
22                T::PTYPE,
23                self.ptype()
24            )
25        }
26
27        let byte_buffer = self
28            .buffer
29            .as_host_opt()
30            .vortex_expect("as_slice must be called on host buffer");
31        let raw_slice = byte_buffer.as_ptr();
32
33        // SAFETY: alignment of Buffer is checked on construction
34        unsafe { std::slice::from_raw_parts(raw_slice.cast(), byte_buffer.len() / size_of::<T>()) }
35    }
36}
37
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::primitive::PrimitiveArrayExt;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::session::ArraySession;
    use crate::validity::Validity;

    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    // --- `as_slice` ---

    #[test]
    fn test_as_slice_returns_buffer_values() {
        let array = PrimitiveArray::new(buffer![1_i32, 2, 3], Validity::NonNullable);
        assert_eq!(array.as_slice::<i32>(), &[1_i32, 2, 3]);
    }

    #[test]
    fn test_as_slice_empty_buffer() {
        let array = PrimitiveArray::new(Buffer::<i64>::empty(), Validity::NonNullable);
        assert!(array.as_slice::<i64>().is_empty());
    }

    #[test]
    #[should_panic]
    fn test_as_slice_wrong_ptype_panics() {
        let array = PrimitiveArray::new(buffer![1_i32, 2, 3], Validity::NonNullable);
        // Requesting a slice of a different primitive type must panic, not reinterpret bytes.
        let _ = array.as_slice::<i64>();
    }

    // --- `narrow` ---

    #[test]
    fn test_downcast_all_invalid() {
        let array = PrimitiveArray::new(
            buffer![0_u32, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            Validity::AllInvalid,
        );

        let result = array.narrow(&mut SESSION.create_execution_ctx()).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::U8, Nullability::Nullable)
        );
        assert!(matches!(result.validity(), Ok(Validity::AllInvalid)));
    }

    #[rstest]
    #[case(vec![0_i64, 127], PType::U8)]
    #[case(vec![-128_i64, 127], PType::I8)]
    #[case(vec![-129_i64, 127], PType::I16)]
    #[case(vec![-128_i64, 128], PType::I16)]
    #[case(vec![-32768_i64, 32767], PType::I16)]
    #[case(vec![-32769_i64, 32767], PType::I32)]
    #[case(vec![-32768_i64, 32768], PType::I32)]
    #[case(vec![i32::MIN as i64, i32::MAX as i64], PType::I32)]
    fn test_downcast_signed(#[case] values: Vec<i64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow(&mut SESSION.create_execution_ctx()).unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    #[rstest]
    #[case(vec![0_u64, 255], PType::U8)]
    #[case(vec![0_u64, 256], PType::U16)]
    #[case(vec![0_u64, 65535], PType::U16)]
    #[case(vec![0_u64, 65536], PType::U32)]
    #[case(vec![0_u64, u32::MAX as u64], PType::U32)]
    fn test_downcast_unsigned(#[case] values: Vec<u64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow(&mut SESSION.create_execution_ctx()).unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    #[test]
    fn test_downcast_keeps_original_if_too_large() {
        let array = PrimitiveArray::from_iter(vec![0_u64, u64::MAX]);
        let result = array.narrow(&mut SESSION.create_execution_ctx()).unwrap();
        assert_eq!(result.ptype(), PType::U64);
    }

    #[test]
    fn test_downcast_preserves_nullability() {
        let array = PrimitiveArray::from_option_iter([Some(0_i32), None, Some(127)]);
        let result = array.narrow(&mut SESSION.create_execution_ctx()).unwrap();
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::U8, Nullability::Nullable)
        );
        // Check that validity is preserved (the array should still have nullable values)
        assert!(matches!(result.validity(), Ok(Validity::Array(_))));
    }

    #[test]
    fn test_downcast_preserves_values() {
        let values = vec![-100_i16, 0, 100];
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow(&mut SESSION.create_execution_ctx()).unwrap();

        assert_eq!(result.ptype(), PType::I8);
        // Check that the values were properly downscaled
        let downscaled_values: Vec<i8> = result.as_slice::<i8>().to_vec();
        assert_eq!(downscaled_values, vec![-100_i8, 0, 100]);
    }

    #[test]
    fn test_downcast_with_mixed_signs_chooses_signed() {
        let array = PrimitiveArray::from_iter(vec![-1_i32, 200]);
        let result = array.narrow(&mut SESSION.create_execution_ctx()).unwrap();
        assert_eq!(result.ptype(), PType::I16);
    }

    #[test]
    fn test_downcast_floats() {
        let array = PrimitiveArray::from_iter(vec![1.0_f32, 2.0, 3.0]);
        let result = array.narrow(&mut SESSION.create_execution_ctx()).unwrap();
        // Floats should remain unchanged since they can't be downscaled to integers
        assert_eq!(result.ptype(), PType::F32);
    }

    #[test]
    fn test_downcast_empty_array() {
        let array = PrimitiveArray::new(Buffer::<i32>::empty(), Validity::AllInvalid);
        let result = array.narrow(&mut SESSION.create_execution_ctx()).unwrap();
        let array2 = PrimitiveArray::new(Buffer::<i64>::empty(), Validity::NonNullable);
        let result2 = array2.narrow(&mut SESSION.create_execution_ctx()).unwrap();
        // Empty arrays should not have their validity changed
        assert!(matches!(result.validity(), Ok(Validity::AllInvalid)));
        assert!(matches!(result2.validity(), Ok(Validity::NonNullable)));
    }
}