Skip to main content

vortex_array/scalar_fn/fns/
ext_storage.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_bail;
6use vortex_session::VortexSession;
7use vortex_session::registry::CachedId;
8
9use crate::ArrayRef;
10use crate::ExecutionCtx;
11use crate::IntoArray;
12use crate::arrays::ConstantArray;
13use crate::arrays::ExtensionArray;
14use crate::arrays::extension::ExtensionArrayExt;
15use crate::dtype::DType;
16use crate::expr::Expression;
17use crate::scalar_fn::Arity;
18use crate::scalar_fn::ChildName;
19use crate::scalar_fn::EmptyOptions;
20use crate::scalar_fn::ExecutionArgs;
21use crate::scalar_fn::ScalarFnId;
22use crate::scalar_fn::ScalarFnVTable;
23
/// Scalar function that extracts the storage values from an extension array.
///
/// The function takes a single extension-typed input and yields that input's
/// underlying storage, typed with the extension's storage dtype. It is
/// registered under the id `vortex.ext.storage` and carries no options.
#[derive(Debug, Clone)]
pub struct ExtStorage;
27
28impl ScalarFnVTable for ExtStorage {
29    type Options = EmptyOptions;
30
31    fn id(&self) -> ScalarFnId {
32        static ID: CachedId = CachedId::new("vortex.ext.storage");
33        *ID
34    }
35
36    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
37        Ok(Some(vec![]))
38    }
39
40    fn deserialize(
41        &self,
42        _metadata: &[u8],
43        _session: &VortexSession,
44    ) -> VortexResult<Self::Options> {
45        Ok(EmptyOptions)
46    }
47
48    fn arity(&self, _options: &Self::Options) -> Arity {
49        Arity::Exact(1)
50    }
51
52    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
53        match child_idx {
54            0 => ChildName::from("input"),
55            _ => unreachable!("Invalid child index {child_idx} for ext_storage()"),
56        }
57    }
58
59    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
60        let DType::Extension(ext_dtype) = &arg_dtypes[0] else {
61            vortex_bail!("ext_storage() requires Extension, got {}", arg_dtypes[0]);
62        };
63
64        Ok(ext_dtype.storage_dtype().clone())
65    }
66
67    fn execute(
68        &self,
69        _options: &Self::Options,
70        args: &dyn ExecutionArgs,
71        ctx: &mut ExecutionCtx,
72    ) -> VortexResult<ArrayRef> {
73        let input = args.get(0)?;
74
75        if !matches!(input.dtype(), DType::Extension(_)) {
76            vortex_bail!("ext_storage() requires Extension, got {}", input.dtype());
77        }
78
79        if let Some(scalar) = input.as_constant() {
80            let storage_scalar = scalar.as_extension().to_storage_scalar();
81            return Ok(ConstantArray::new(storage_scalar, args.row_count()).into_array());
82        }
83
84        let input = input.execute::<ExtensionArray>(ctx)?;
85        Ok(input.storage_array().clone())
86    }
87
88    fn validity(
89        &self,
90        _options: &Self::Options,
91        expression: &Expression,
92    ) -> VortexResult<Option<Expression>> {
93        Ok(Some(expression.child(0).validity()?))
94    }
95
96    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
97        false
98    }
99
100    fn is_fallible(&self, _options: &Self::Options) -> bool {
101        false
102    }
103}
104
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::ExtensionArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::extension::ExtDTypeRef;
    use crate::expr::ext_storage;
    use crate::expr::root;
    use crate::extension::datetime::TimeUnit;
    use crate::extension::datetime::Timestamp;
    use crate::scalar::Scalar;

    /// Erased nanosecond timestamp extension dtype with the requested nullability.
    fn ext_dtype(nullability: Nullability) -> ExtDTypeRef {
        let timestamp = Timestamp::new(TimeUnit::Nanoseconds, nullability);
        timestamp.erased()
    }

    /// The `i64` primitive dtype the tests expect as the storage dtype.
    fn i64_dtype(nullability: Nullability) -> DType {
        DType::Primitive(PType::I64, nullability)
    }

    #[test]
    fn extracts_extension_storage_array() -> VortexResult<()> {
        let values = buffer![2i64, 4, 6].into_array();
        let wrapped = ExtensionArray::new(ext_dtype(Nullability::NonNullable), values.clone());
        let input = wrapped.into_array();

        let expr = ext_storage(root());
        let unwrapped = input.apply(&expr)?;

        assert_eq!(unwrapped.dtype(), &i64_dtype(Nullability::NonNullable));
        assert_arrays_eq!(
            unwrapped,
            values,
            &mut crate::array_session().create_execution_ctx()
        );
        Ok(())
    }

    #[test]
    fn extracts_nullable_extension_storage_array() -> VortexResult<()> {
        let values = PrimitiveArray::from_option_iter([Some(2i64), None, Some(6)]).into_array();
        let wrapped = ExtensionArray::new(ext_dtype(Nullability::Nullable), values.clone());
        let input = wrapped.into_array();

        let expr = ext_storage(root());
        let unwrapped = input.apply(&expr)?;

        assert_eq!(unwrapped.dtype(), &i64_dtype(Nullability::Nullable));
        assert_arrays_eq!(
            unwrapped,
            values,
            &mut crate::array_session().create_execution_ctx()
        );
        Ok(())
    }

    #[test]
    fn extracts_constant_extension_storage_scalar() -> VortexResult<()> {
        let inner = Scalar::primitive(4i64, Nullability::NonNullable);
        let outer = Scalar::extension_ref(ext_dtype(Nullability::NonNullable), inner.clone());
        let input = ConstantArray::new(outer, 3).into_array();

        let expr = ext_storage(root());
        let unwrapped = input.apply(&expr)?;

        assert_eq!(unwrapped.dtype(), &i64_dtype(Nullability::NonNullable));
        assert_arrays_eq!(
            unwrapped,
            ConstantArray::new(inner, 3),
            &mut crate::array_session().create_execution_ctx()
        );
        Ok(())
    }

    #[test]
    fn rejects_non_extension_input() {
        let not_an_extension = DType::Primitive(PType::U64, Nullability::NonNullable);
        let expr = ext_storage(root());
        let message = expr
            .return_dtype(&not_an_extension)
            .unwrap_err()
            .to_string();
        assert!(message.contains("requires Extension"));
    }

    #[test]
    fn test_display() {
        let rendered = ext_storage(root()).to_string();
        assert_eq!(rendered, "vortex.ext.storage($)");
    }
}