Skip to main content

vortex_array/scalar_fn/fns/
byte_length.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use vortex_buffer::Buffer;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_ensure;
10use vortex_session::VortexSession;
11use vortex_session::registry::CachedId;
12
13use crate::ArrayRef;
14use crate::ExecutionCtx;
15use crate::IntoArray;
16use crate::array::ArrayView;
17use crate::array::VTable;
18use crate::arrays::ConstantArray;
19use crate::arrays::PrimitiveArray;
20use crate::arrays::VarBinViewArray;
21use crate::arrays::scalar_fn::ExactScalarFn;
22use crate::arrays::scalar_fn::ScalarFnArrayView;
23use crate::arrays::varbinview::VarBinViewArrayExt;
24use crate::dtype::DType;
25use crate::dtype::Nullability;
26use crate::dtype::PType;
27use crate::expr::Expression;
28use crate::kernel::ExecuteParentKernel;
29use crate::scalar::Scalar;
30use crate::scalar_fn::Arity;
31use crate::scalar_fn::ChildName;
32use crate::scalar_fn::EmptyOptions;
33use crate::scalar_fn::ExecutionArgs;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36
37pub trait ByteLengthKernel: VTable {
38    fn byte_length(
39        array: ArrayView<'_, Self>,
40        ctx: &mut ExecutionCtx,
41    ) -> VortexResult<Option<ArrayRef>>;
42}
43
44#[derive(Default, Debug)]
45pub struct ByteLengthExecuteAdaptor<V>(pub V);
46
47impl<V: ByteLengthKernel> ExecuteParentKernel<V> for ByteLengthExecuteAdaptor<V> {
48    type Parent = ExactScalarFn<ByteLength>;
49
50    fn execute_parent(
51        &self,
52        array: ArrayView<'_, V>,
53        _parent: ScalarFnArrayView<'_, ByteLength>,
54        child_idx: usize,
55        ctx: &mut ExecutionCtx,
56    ) -> VortexResult<Option<ArrayRef>> {
57        vortex_ensure!(child_idx == 0);
58        V::byte_length(array, ctx)
59    }
60}
61
/// Byte length of each element in a Utf8 or Binary array.
///
/// The result is a `U64` primitive array whose nullability matches the input's, and whose
/// validity mirrors the input: null elements yield null lengths. For `Utf8` input the
/// length counts UTF-8 encoded bytes, not characters.
#[derive(Clone, Debug)]
pub struct ByteLength;
65
66impl ScalarFnVTable for ByteLength {
67    type Options = EmptyOptions;
68
69    fn id(&self) -> ScalarFnId {
70        static ID: CachedId = CachedId::new("vortex.byte_length");
71        *ID
72    }
73
74    fn serialize(&self, _instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
75        Ok(Some(vec![]))
76    }
77
78    fn deserialize(
79        &self,
80        _metadata: &[u8],
81        _session: &VortexSession,
82    ) -> VortexResult<Self::Options> {
83        Ok(EmptyOptions)
84    }
85
86    fn arity(&self, _options: &Self::Options) -> Arity {
87        Arity::Exact(1)
88    }
89
90    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
91        match child_idx {
92            0 => ChildName::from("input"),
93            _ => unreachable!("Invalid child index {child_idx} for byte_length()"),
94        }
95    }
96
97    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
98        match &arg_dtypes[0] {
99            DType::Utf8(nullable) | DType::Binary(nullable) => {
100                Ok(DType::Primitive(PType::U64, *nullable))
101            }
102            other => vortex_bail!("byte_length() requires Utf8 or Binary, got {other}"),
103        }
104    }
105
106    fn execute(
107        &self,
108        _options: &Self::Options,
109        args: &dyn ExecutionArgs,
110        ctx: &mut ExecutionCtx,
111    ) -> VortexResult<ArrayRef> {
112        let input = args.get(0)?;
113        let nullability = input.dtype().nullability();
114
115        if let Some(scalar) = input.as_constant() {
116            let len_scalar = scalar_byte_length(&scalar, nullability)?;
117            return Ok(ConstantArray::new(len_scalar, args.row_count()).into_array());
118        }
119
120        match input.dtype() {
121            DType::Utf8(_) | DType::Binary(_) => byte_length(&input, nullability, ctx),
122            other => vortex_bail!("byte_length() requires Utf8 or Binary, got {other}"),
123        }
124    }
125
126    fn validity(
127        &self,
128        _: &Self::Options,
129        expression: &Expression,
130    ) -> VortexResult<Option<Expression>> {
131        Ok(Some(expression.child(0).validity()?))
132    }
133
134    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
135        false
136    }
137
138    fn is_fallible(&self, _options: &Self::Options) -> bool {
139        false
140    }
141}
142
143fn scalar_byte_length(scalar: &Scalar, nullability: Nullability) -> VortexResult<Scalar> {
144    if scalar.is_null() {
145        let dtype = DType::Primitive(PType::U64, Nullability::Nullable);
146        return Ok(Scalar::null(dtype));
147    }
148    let len = match scalar.dtype() {
149        DType::Utf8(_) => scalar
150            .as_utf8()
151            .value()
152            .vortex_expect("null utf-8 scalar")
153            .len(),
154        DType::Binary(_) => scalar
155            .as_binary()
156            .value()
157            .vortex_expect("null binary scalar")
158            .len(),
159        other => vortex_bail!("byte_length() requires Utf8 or Binary, got {other}"),
160    };
161    let len: u64 = len.as_();
162    Ok(Scalar::primitive(len, nullability))
163}
164
165pub(crate) fn byte_length(
166    array: &ArrayRef,
167    nullability: Nullability,
168    ctx: &mut ExecutionCtx,
169) -> VortexResult<ArrayRef> {
170    let array = array.clone().execute::<VarBinViewArray>(ctx)?;
171    let validity = array.varbinview_validity();
172    let lengths: Buffer<u64> = array.views().iter().map(|v| v.len() as u64).collect();
173    Ok(PrimitiveArray::new(lengths, validity.union_nullability(nullability)).into_array())
174}
175
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_error::VortexResult;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::expr::byte_length;
    use crate::expr::root;
    use crate::scalar::Scalar;

    #[rstest]
    #[case(VarBinArray::from_strs(vec!["hello", "world", ""]).into_array(), vec![5u64, 5, 0])]
    #[case(VarBinArray::from_bytes(vec![b"ab".as_ref(), b"cde"]).into_array(), vec![2u64, 3])]
    #[case(VarBinArray::from_strs(vec!["Пуховички"]).into_array(), vec![18u64])]
    #[case(VarBinArray::from_bytes(vec!["Пуховички".as_ref()]).into_array(), vec![18u64])]
    fn test_bytes_byte_length(
        #[case] array: ArrayRef,
        #[case] expected_lens: Vec<u64>,
    ) -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let result = array.apply(&byte_length(root()))?;
        let expected = PrimitiveArray::from_iter(expected_lens);
        assert_arrays_eq!(result, expected, &mut ctx);
        Ok(())
    }

    #[test]
    fn test_varbinview_byte_length() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let array = VarBinViewArray::from_iter_str(["short", "a longer string here"]).into_array();
        let result = array.apply(&byte_length(root()))?;
        let expected = PrimitiveArray::from_iter(vec![5u64, 20]);
        assert_arrays_eq!(result, expected, &mut ctx);
        Ok(())
    }

    #[test]
    fn test_nullable_string_byte_length() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let array = VarBinArray::from_nullable_strs(vec![Some("hello"), None, Some("Пуховички")])
            .into_array();
        let result = array.apply(&byte_length(root()))?;

        assert!(result.is_valid(0, &mut ctx)?);
        assert!(!result.is_valid(1, &mut ctx)?);
        assert!(result.is_valid(2, &mut ctx)?);
        // Reuse the one execution context rather than creating a fresh one per lookup.
        assert_eq!(
            result.execute_scalar(0, &mut ctx)?,
            Scalar::primitive(5u64, Nullability::Nullable),
        );
        assert_eq!(
            result.execute_scalar(2, &mut ctx)?,
            Scalar::primitive(18u64, Nullability::Nullable),
        );
        Ok(())
    }

    /// A valid constant input goes through the constant fast path and must yield the
    /// scalar's byte length (not its character count) for every row.
    #[test]
    fn test_constant_scalar_byte_length() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let scalar = Scalar::utf8("Пуховички", Nullability::NonNullable);
        let array = ConstantArray::new(scalar, 3).into_array();
        let result = array.apply(&byte_length(root()))?;
        let expected = PrimitiveArray::from_iter(vec![18u64, 18, 18]);
        assert_arrays_eq!(result, expected, &mut ctx);
        Ok(())
    }

    #[test]
    fn test_null_scalar_byte_length() -> VortexResult<()> {
        let null_scalar = Scalar::null(DType::Utf8(Nullability::Nullable));
        let array = ConstantArray::new(null_scalar, 2).into_array();
        let result = array.apply(&byte_length(root()))?;
        let mut ctx = array_session().create_execution_ctx();
        assert!(!result.is_valid(0, &mut ctx)?);
        assert!(!result.is_valid(1, &mut ctx)?);
        Ok(())
    }

    #[test]
    fn test_display() {
        let expr = byte_length(root());
        assert_eq!(expr.to_string(), "vortex.byte_length($)");
    }
}