Skip to main content

vortex_array/scalar_fn/fns/
byte_length.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use vortex_buffer::Buffer;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_ensure;
10use vortex_session::VortexSession;
11use vortex_session::registry::CachedId;
12
13use crate::ArrayRef;
14use crate::ExecutionCtx;
15use crate::IntoArray;
16use crate::array::ArrayView;
17use crate::array::VTable;
18use crate::arrays::ConstantArray;
19use crate::arrays::PrimitiveArray;
20use crate::arrays::VarBinViewArray;
21use crate::arrays::scalar_fn::ExactScalarFn;
22use crate::arrays::scalar_fn::ScalarFnArrayView;
23use crate::arrays::varbinview::VarBinViewArrayExt;
24use crate::dtype::DType;
25use crate::dtype::Nullability;
26use crate::dtype::PType;
27use crate::kernel::ExecuteParentKernel;
28use crate::scalar::Scalar;
29use crate::scalar_fn::Arity;
30use crate::scalar_fn::ChildName;
31use crate::scalar_fn::EmptyOptions;
32use crate::scalar_fn::ExecutionArgs;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35
/// Encoding-specific kernel for computing the byte length of each element of an array.
///
/// Implemented on array vtables (`Self: VTable`). [`ByteLengthExecuteAdaptor`] wraps an
/// implementor and forwards `execute_parent` calls to [`ByteLengthKernel::byte_length`].
pub trait ByteLengthKernel: VTable {
    /// Computes the per-element byte lengths of `array`.
    ///
    /// NOTE(review): `Ok(None)` presumably means "this kernel declines; fall back to the
    /// generic implementation" — confirm against the `ExecuteParentKernel` contract.
    fn byte_length(
        array: ArrayView<'_, Self>,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>>;
}
42
/// Adaptor that exposes a [`ByteLengthKernel`] implementation `V` as an
/// `ExecuteParentKernel` whose parent is the [`ByteLength`] scalar function.
#[derive(Default, Debug)]
pub struct ByteLengthExecuteAdaptor<V>(pub V);
45
impl<V: ByteLengthKernel> ExecuteParentKernel<V> for ByteLengthExecuteAdaptor<V> {
    /// This kernel only applies when the parent is exactly a [`ByteLength`] scalar function.
    type Parent = ExactScalarFn<ByteLength>;

    /// Delegates to [`ByteLengthKernel::byte_length`] for the child array.
    ///
    /// The parent view is unused. Returns an error if `child_idx` is not `0`;
    /// otherwise returns whatever the kernel returns.
    fn execute_parent(
        &self,
        array: ArrayView<'_, V>,
        _parent: ScalarFnArrayView<'_, ByteLength>,
        child_idx: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        // `ByteLength` declares an arity of exactly one, so the only valid child index is 0.
        vortex_ensure!(child_idx == 0);
        V::byte_length(array, ctx)
    }
}
60
/// Byte length of each element in a Utf8 or Binary array.
///
/// The result is a `U64` primitive array that carries the same nullability as the input.
/// Lengths are counted in bytes, not characters.
#[derive(Clone)]
pub struct ByteLength;
64
65impl ScalarFnVTable for ByteLength {
66    type Options = EmptyOptions;
67
68    fn id(&self) -> ScalarFnId {
69        static ID: CachedId = CachedId::new("vortex.byte_length");
70        *ID
71    }
72
73    fn serialize(&self, _instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
74        Ok(Some(vec![]))
75    }
76
77    fn deserialize(
78        &self,
79        _metadata: &[u8],
80        _session: &VortexSession,
81    ) -> VortexResult<Self::Options> {
82        Ok(EmptyOptions)
83    }
84
85    fn arity(&self, _options: &Self::Options) -> Arity {
86        Arity::Exact(1)
87    }
88
89    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
90        match child_idx {
91            0 => ChildName::from("input"),
92            _ => unreachable!("Invalid child index {child_idx} for byte_length()"),
93        }
94    }
95
96    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
97        match &arg_dtypes[0] {
98            DType::Utf8(nullable) | DType::Binary(nullable) => {
99                Ok(DType::Primitive(PType::U64, *nullable))
100            }
101            other => vortex_bail!("byte_length() requires Utf8 or Binary, got {other}"),
102        }
103    }
104
105    fn execute(
106        &self,
107        _options: &Self::Options,
108        args: &dyn ExecutionArgs,
109        ctx: &mut ExecutionCtx,
110    ) -> VortexResult<ArrayRef> {
111        let input = args.get(0)?;
112        let nullability = input.dtype().nullability();
113
114        if let Some(scalar) = input.as_constant() {
115            let len_scalar = scalar_byte_length(&scalar, nullability)?;
116            return Ok(ConstantArray::new(len_scalar, args.row_count()).into_array());
117        }
118
119        match input.dtype() {
120            DType::Utf8(_) | DType::Binary(_) => byte_length(&input, nullability, ctx),
121            other => vortex_bail!("byte_length() requires Utf8 or Binary, got {other}"),
122        }
123    }
124
125    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
126        false
127    }
128
129    fn is_fallible(&self, _options: &Self::Options) -> bool {
130        false
131    }
132}
133
134fn scalar_byte_length(scalar: &Scalar, nullability: Nullability) -> VortexResult<Scalar> {
135    if scalar.is_null() {
136        let dtype = DType::Primitive(PType::U64, Nullability::Nullable);
137        return Ok(Scalar::null(dtype));
138    }
139    let len = match scalar.dtype() {
140        DType::Utf8(_) => scalar
141            .as_utf8()
142            .value()
143            .vortex_expect("null utf-8 scalar")
144            .len(),
145        DType::Binary(_) => scalar
146            .as_binary()
147            .value()
148            .vortex_expect("null binary scalar")
149            .len(),
150        other => vortex_bail!("byte_length() requires Utf8 or Binary, got {other}"),
151    };
152    let len: u64 = len.as_();
153    Ok(Scalar::primitive(len, nullability))
154}
155
156pub(crate) fn byte_length(
157    array: &ArrayRef,
158    nullability: Nullability,
159    ctx: &mut ExecutionCtx,
160) -> VortexResult<ArrayRef> {
161    let array = array.clone().execute::<VarBinViewArray>(ctx)?;
162    let validity = array.varbinview_validity();
163    let lengths: Buffer<u64> = array.views().iter().map(|v| v.len() as u64).collect();
164    Ok(PrimitiveArray::new(lengths, validity.union_nullability(nullability)).into_array())
165}
166
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_error::VortexResult;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::expr::byte_length;
    use crate::expr::root;
    use crate::scalar::Scalar;

    /// Lengths are measured in bytes for both string and binary `VarBinArray` inputs,
    /// including multi-byte UTF-8 text and the empty string.
    #[rstest]
    #[case(VarBinArray::from_strs(vec!["hello", "world", ""]).into_array(), vec![5u64, 5, 0])]
    #[case(VarBinArray::from_bytes(vec![b"ab".as_ref(), b"cde"]).into_array(), vec![2u64, 3])]
    #[case(VarBinArray::from_strs(vec!["Пуховички"]).into_array(), vec![18u64])]
    #[case(VarBinArray::from_bytes(vec!["Пуховички".as_ref()]).into_array(), vec![18u64])]
    fn test_bytes_byte_length(
        #[case] array: ArrayRef,
        #[case] expected_lens: Vec<u64>,
    ) -> VortexResult<()> {
        let result = array.apply(&byte_length(root()))?;
        let expected = PrimitiveArray::from_iter(expected_lens);
        assert_arrays_eq!(result, expected);
        Ok(())
    }

    /// A `VarBinViewArray` input holding both a short and a longer string.
    #[test]
    fn test_varbinview_byte_length() -> VortexResult<()> {
        let array = VarBinViewArray::from_iter_str(["short", "a longer string here"]).into_array();
        let result = array.apply(&byte_length(root()))?;
        let expected = PrimitiveArray::from_iter(vec![5u64, 20]);
        assert_arrays_eq!(result, expected);
        Ok(())
    }

    /// Null rows stay null; valid rows carry their byte length as nullable `u64` scalars.
    #[test]
    fn test_nullable_string_byte_length() -> VortexResult<()> {
        let array = VarBinArray::from_nullable_strs(vec![Some("hello"), None, Some("Пуховички")])
            .into_array();
        let result = array.apply(&byte_length(root()))?;

        // One execution context is shared by every check below.
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert!(result.is_valid(0, &mut ctx)?);
        assert!(!result.is_valid(1, &mut ctx)?);
        assert!(result.is_valid(2, &mut ctx)?);
        assert_eq!(
            result.execute_scalar(0, &mut ctx)?,
            Scalar::primitive(5u64, Nullability::Nullable),
        );
        assert_eq!(
            result.execute_scalar(2, &mut ctx)?,
            Scalar::primitive(18u64, Nullability::Nullable),
        );
        Ok(())
    }

    /// A constant null input produces a result that is null at every row.
    #[test]
    fn test_null_scalar_byte_length() -> VortexResult<()> {
        let null_scalar = Scalar::null(DType::Utf8(Nullability::Nullable));
        let array = ConstantArray::new(null_scalar, 2).into_array();
        let result = array.apply(&byte_length(root()))?;
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert!(!result.is_valid(0, &mut ctx)?);
        assert!(!result.is_valid(1, &mut ctx)?);
        Ok(())
    }

    /// The expression displays as the function id applied to the root placeholder.
    #[test]
    fn test_display() {
        let expr = byte_length(root());
        assert_eq!(expr.to_string(), "vortex.byte_length($)");
    }
}