Skip to main content

vortex_array/scalar_fn/fns/variant_get/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Display;
6use std::fmt::Formatter;
7
8use prost::Message;
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_err;
12use vortex_proto::expr as pb;
13use vortex_proto::expr::variant_path_element;
14use vortex_session::VortexSession;
15use vortex_utils::aliases::StringEscape;
16
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::IntoArray;
20use crate::arrays::ChunkedArray;
21use crate::arrays::ConstantArray;
22use crate::arrays::VariantArray;
23use crate::builders::builder_with_capacity_in;
24use crate::dtype::DType;
25use crate::dtype::FieldName;
26use crate::dtype::Nullability;
27use crate::expr::Expression;
28use crate::scalar::Scalar;
29use crate::scalar_fn::Arity;
30use crate::scalar_fn::ChildName;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34
/// Scalar function that looks up a field/index path inside Variant values.
///
/// A row evaluates to null whenever the path is absent, a traversal step meets a value of the
/// wrong shape, or the final cast fails. When no dtype is requested the output is nullable
/// Variant; when one is requested the output is that dtype made nullable. Encodings are free to
/// answer perfectly shredded paths from their shredded columns, but any path that shredded
/// storage does not cover must be resolved against the core Variant value.
#[derive(Clone)]
pub struct VariantGet;
43
44impl ScalarFnVTable for VariantGet {
45    type Options = VariantGetOptions;
46
47    fn id(&self) -> ScalarFnId {
48        ScalarFnId::new("vortex.variant_get")
49    }
50
51    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
52        let path = options
53            .path()
54            .elements()
55            .iter()
56            .map(|element| match element {
57                VariantPathElement::Field(name) => pb::VariantPathElement {
58                    element: Some(variant_path_element::Element::Field(
59                        name.as_ref().to_string(),
60                    )),
61                },
62                VariantPathElement::Index(index) => pb::VariantPathElement {
63                    element: Some(variant_path_element::Element::Index(*index)),
64                },
65            })
66            .collect();
67        let dtype = options.dtype().map(TryInto::try_into).transpose()?;
68
69        Ok(Some(pb::VariantGetOpts { path, dtype }.encode_to_vec()))
70    }
71
72    fn deserialize(&self, metadata: &[u8], session: &VortexSession) -> VortexResult<Self::Options> {
73        let opts = pb::VariantGetOpts::decode(metadata)?;
74        let path = opts
75            .path
76            .into_iter()
77            .map(VariantPathElement::from_proto)
78            .collect::<VortexResult<_>>()?;
79        let dtype = opts
80            .dtype
81            .as_ref()
82            .map(|dtype| DType::from_proto(dtype, session))
83            .transpose()?;
84
85        Ok(VariantGetOptions::new(path, dtype))
86    }
87
88    fn arity(&self, _options: &Self::Options) -> Arity {
89        Arity::Exact(1)
90    }
91
92    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
93        match child_idx {
94            0 => ChildName::from("input"),
95            _ => unreachable!("Invalid child index {child_idx} for VariantGet expression"),
96        }
97    }
98
99    fn fmt_sql(
100        &self,
101        options: &Self::Options,
102        expr: &Expression,
103        f: &mut Formatter<'_>,
104    ) -> fmt::Result {
105        write!(f, "variant_get(")?;
106        expr.child(0).fmt_sql(f)?;
107        let path = options.path().to_string();
108        write!(f, ", \"{}\"", StringEscape(&path))?;
109        if let Some(dtype) = options.dtype() {
110            write!(f, ", {dtype}")?;
111        }
112        write!(f, ")")
113    }
114
115    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
116        let input_dtype = &arg_dtypes[0];
117        vortex_ensure!(
118            matches!(input_dtype, DType::Variant(_)),
119            "VariantGet input must be Variant, found {input_dtype}"
120        );
121
122        // Missing paths, traversal mismatches, and cast failures all produce nulls.
123        Ok(options
124            .dtype()
125            .map_or(DType::Variant(Nullability::Nullable), DType::as_nullable))
126    }
127
128    fn execute(
129        &self,
130        options: &Self::Options,
131        args: &dyn ExecutionArgs,
132        ctx: &mut ExecutionCtx,
133    ) -> VortexResult<ArrayRef> {
134        let input = args.get(0)?;
135        // Missing paths, traversal mismatches, and cast failures all produce nulls.
136        let dtype = options
137            .dtype()
138            .map_or(DType::Variant(Nullability::Nullable), DType::as_nullable);
139
140        if !dtype.is_variant() {
141            let mut builder = builder_with_capacity_in(ctx.allocator(), &dtype, input.len());
142            for idx in 0..input.len() {
143                let scalar = input.execute_scalar(idx, ctx)?;
144                let output = variant_get_scalar(&scalar, options, &dtype)?;
145                builder.append_scalar(&output)?;
146            }
147
148            return Ok(builder.finish_into_canonical().into_array());
149        }
150
151        // TODO(variant): replace this with a Variant builder once one exists.
152        // Chunked<Variant> canonicalizes to VariantArray, so this row-wise fallback is safe.
153        let mut chunks = Vec::with_capacity(input.len());
154
155        for idx in 0..input.len() {
156            let scalar = input.execute_scalar(idx, ctx)?;
157            let output = variant_get_scalar(&scalar, options, &dtype)?;
158            chunks.push(ConstantArray::new(output, 1).into_array());
159        }
160
161        let array = ChunkedArray::try_new(chunks, dtype)?.into_array();
162        VariantArray::try_new(array, None).map(|array| array.into_array())
163    }
164}
165
166fn variant_get_scalar(
167    scalar: &Scalar,
168    options: &VariantGetOptions,
169    output_dtype: &DType,
170) -> VortexResult<Scalar> {
171    let Some(value) = variant_path_scalar(scalar, options.path().elements())? else {
172        return Ok(Scalar::null(output_dtype.clone()));
173    };
174
175    if options.dtype().is_none_or(DType::is_variant) {
176        return Scalar::variant(value).cast(output_dtype);
177    }
178
179    if value.is_null() {
180        return Ok(Scalar::null(output_dtype.clone()));
181    }
182
183    value
184        .cast(output_dtype)
185        .or_else(|_| Ok(Scalar::null(output_dtype.clone())))
186}
187
188fn variant_path_scalar(
189    scalar: &Scalar,
190    path: &[VariantPathElement],
191) -> VortexResult<Option<Scalar>> {
192    let mut current = match variant_payload(scalar.clone()) {
193        Some(value) => value,
194        None => return Ok(None),
195    };
196
197    for element in path {
198        current = match variant_payload(current) {
199            Some(value) => value,
200            None => return Ok(None),
201        };
202
203        if current.is_null() {
204            return Ok(None);
205        }
206
207        current = match element {
208            VariantPathElement::Field(name) => {
209                let Some(struct_scalar) = current.as_struct_opt() else {
210                    return Ok(None);
211                };
212                if struct_scalar.is_null() {
213                    return Ok(None);
214                }
215                let Some(field) = struct_scalar.field(name.as_ref()) else {
216                    return Ok(None);
217                };
218                field
219            }
220            VariantPathElement::Index(index) => {
221                let Ok(index) = usize::try_from(*index) else {
222                    return Ok(None);
223                };
224                let Some(list_scalar) = current.as_list_opt() else {
225                    return Ok(None);
226                };
227                let Some(element) = list_scalar.element(index) else {
228                    return Ok(None);
229                };
230                element
231            }
232        };
233    }
234
235    Ok(variant_payload(current))
236}
237
238fn variant_payload(scalar: Scalar) -> Option<Scalar> {
239    if scalar.dtype().is_variant() {
240        scalar.as_variant().value().cloned()
241    } else {
242        Some(scalar)
243    }
244}
245
246/// Options for [`VariantGet`].
247#[derive(Clone, Debug, PartialEq, Eq, Hash)]
248pub struct VariantGetOptions {
249    path: VariantPath,
250    dtype: Option<DType>,
251}
252
253impl VariantGetOptions {
254    /// Creates options for extracting `path`, returning Variant values when `dtype` is `None`.
255    pub fn new(path: VariantPath, dtype: Option<DType>) -> Self {
256        Self { path, dtype }
257    }
258
259    /// Returns the path to extract.
260    pub fn path(&self) -> &VariantPath {
261        &self.path
262    }
263
264    /// Returns the requested output dtype, if any.
265    pub fn dtype(&self) -> Option<&DType> {
266        self.dtype.as_ref()
267    }
268}
269
270impl Display for VariantGetOptions {
271    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
272        write!(f, "{}", self.path)?;
273        if let Some(dtype) = &self.dtype {
274            write!(f, " as {dtype}")?;
275        }
276        Ok(())
277    }
278}
279
280/// A strict Variant path made from object fields and list indexes.
281#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
282pub struct VariantPath(Vec<VariantPathElement>);
283
284impl VariantPath {
285    /// Creates a path from explicit elements.
286    pub fn new(elements: impl IntoIterator<Item = VariantPathElement>) -> Self {
287        Self(elements.into_iter().collect())
288    }
289
290    /// Creates the root path.
291    pub fn root() -> Self {
292        Self::default()
293    }
294
295    /// Creates a path containing one object field.
296    pub fn field(field: impl Into<FieldName>) -> Self {
297        Self(vec![VariantPathElement::field(field)])
298    }
299
300    /// Returns the path elements.
301    pub fn elements(&self) -> &[VariantPathElement] {
302        &self.0
303    }
304
305    /// Returns whether this path references the root Variant value.
306    pub fn is_root(&self) -> bool {
307        self.0.is_empty()
308    }
309}
310
311impl FromIterator<VariantPathElement> for VariantPath {
312    fn from_iter<T: IntoIterator<Item = VariantPathElement>>(iter: T) -> Self {
313        Self(iter.into_iter().collect())
314    }
315}
316
317impl From<VariantPathElement> for VariantPath {
318    fn from(value: VariantPathElement) -> Self {
319        Self(vec![value])
320    }
321}
322
323impl Display for VariantPath {
324    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
325        write!(f, "$")?;
326        for element in self.elements() {
327            match element {
328                VariantPathElement::Field(name) => write!(f, ".{name}")?,
329                VariantPathElement::Index(index) => write!(f, "[{index}]")?,
330            }
331        }
332        Ok(())
333    }
334}
335
336/// A single field or index step in a [`VariantPath`].
337#[derive(Clone, Debug, PartialEq, Eq, Hash)]
338pub enum VariantPathElement {
339    /// Select an object field by name.
340    Field(FieldName),
341    /// Select a list element by zero-based index.
342    Index(u64),
343}
344
345impl VariantPathElement {
346    /// Creates an object-field path element.
347    pub fn field(field: impl Into<FieldName>) -> Self {
348        Self::Field(field.into())
349    }
350
351    /// Creates a list-index path element.
352    pub fn index(index: u64) -> Self {
353        Self::Index(index)
354    }
355
356    fn from_proto(value: pb::VariantPathElement) -> VortexResult<Self> {
357        match value
358            .element
359            .ok_or_else(|| vortex_err!("VariantGet path element missing value"))?
360        {
361            variant_path_element::Element::Field(field) => Ok(Self::field(field)),
362            variant_path_element::Element::Index(index) => Ok(Self::index(index)),
363        }
364    }
365}
366
367impl From<FieldName> for VariantPathElement {
368    fn from(value: FieldName) -> Self {
369        Self::field(value)
370    }
371}
372
373impl From<&str> for VariantPathElement {
374    fn from(value: &str) -> Self {
375        Self::field(value)
376    }
377}
378
379impl From<u64> for VariantPathElement {
380    fn from(value: u64) -> Self {
381        Self::index(value)
382    }
383}
384
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;
    use vortex_error::vortex_ensure;
    use vortex_error::vortex_err;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::Chunked;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VariantArray;
    use crate::arrays::variant::VariantArrayExt;
    use crate::assert_arrays_eq;
    use crate::assert_nth_scalar_is_null;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::Expression;
    use crate::expr::proto::ExprSerializeProtoExt;
    use crate::expr::root;
    use crate::expr::variant_get;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;
    use crate::scalar_fn::ScalarFnVTable;
    use crate::scalar_fn::fns::variant_get::VariantGet;
    use crate::scalar_fn::fns::variant_get::VariantGetOptions;
    use crate::scalar_fn::fns::variant_get::VariantPath;
    use crate::scalar_fn::fns::variant_get::VariantPathElement;

    /// Builds a non-nullable struct scalar whose fields are all Variant-wrapped, i.e. the shape
    /// of a Variant object payload.
    fn variant_object(fields: impl IntoIterator<Item = (&'static str, Scalar)>) -> Scalar {
        let fields = fields.into_iter().collect::<Vec<_>>();
        let names = FieldNames::from_iter(fields.iter().map(|(name, _)| FieldName::from(*name)));
        let dtypes = vec![DType::Variant(Nullability::NonNullable); fields.len()];
        let values = fields
            .into_iter()
            .map(|(_, value)| Scalar::variant(value).into_value())
            .collect();
        Scalar::try_new(
            DType::Struct(StructFields::new(names, dtypes), Nullability::NonNullable),
            Some(ScalarValue::Tuple(values)),
        )
        .unwrap()
    }

    /// Builds a nullable-Variant chunked array with one single-row constant chunk per row.
    ///
    /// Cast failures are propagated as errors instead of panicking.
    fn variant_rows(rows: impl IntoIterator<Item = Scalar>) -> VortexResult<ArrayRef> {
        let dtype = DType::Variant(Nullability::Nullable);
        let chunks = rows
            .into_iter()
            .map(|row| {
                row.cast(&dtype)
                    .map(|row| ConstantArray::new(row, 1).into_array())
            })
            .collect::<VortexResult<Vec<_>>>()?;
        ChunkedArray::try_new(chunks, dtype).map(|array| array.into_array())
    }

    /// Test-only syntax for keeping `variant_get` cases compact without committing
    /// to a public string grammar yet.
    fn parse_path(path: &str) -> VortexResult<VariantPath> {
        if path.is_empty() || path == "$" {
            return Ok(VariantPath::root());
        }

        let mut elements = Vec::new();
        let mut pos = usize::from(path.as_bytes().first() == Some(&b'$'));
        if pos == 1
            && path
                .as_bytes()
                .get(pos)
                .is_some_and(|byte| !matches!(byte, b'.' | b'['))
        {
            vortex_bail!("Invalid Variant path {path:?}: expected '.' or '[' after '$'");
        }

        while pos < path.len() {
            match path.as_bytes()[pos] {
                b'.' => {
                    pos += 1;
                    let (field, next_pos) = parse_field(path, pos)?;
                    elements.push(VariantPathElement::field(field));
                    pos = next_pos;
                }
                b'[' => {
                    let (index, next_pos) = parse_index(path, pos + 1)?;
                    elements.push(VariantPathElement::index(index));
                    pos = next_pos;
                }
                _ if pos == 0 => {
                    let (field, next_pos) = parse_field(path, pos)?;
                    elements.push(VariantPathElement::field(field));
                    pos = next_pos;
                }
                _ => {
                    vortex_bail!("Invalid Variant path {path:?}: expected '.', '[', or end of path")
                }
            }
        }

        Ok(VariantPath::new(elements))
    }

    /// Parses an `[A-Za-z0-9_]+` field name starting at `start`, returning it and the next position.
    fn parse_field(path: &str, start: usize) -> VortexResult<(&str, usize)> {
        let mut pos = start;
        while path
            .as_bytes()
            .get(pos)
            .is_some_and(|byte| byte.is_ascii_alphanumeric() || *byte == b'_')
        {
            pos += 1;
        }
        vortex_ensure!(
            pos > start,
            "Invalid Variant path {path:?}: expected field name"
        );
        Ok((&path[start..pos], pos))
    }

    /// Parses a decimal index terminated by `]`, returning it and the position after the `]`.
    fn parse_index(path: &str, start: usize) -> VortexResult<(u64, usize)> {
        let mut pos = start;
        while path
            .as_bytes()
            .get(pos)
            .is_some_and(|byte| byte.is_ascii_digit())
        {
            pos += 1;
        }
        vortex_ensure!(
            pos > start,
            "Invalid Variant path {path:?}: expected list index"
        );
        vortex_ensure!(
            path.as_bytes().get(pos) == Some(&b']'),
            "Invalid Variant path {path:?}: expected closing ']'"
        );
        // Only digits were accepted above, so the only possible parse failure is overflow.
        let index = path[start..pos]
            .parse()
            .map_err(|_| vortex_err!("Invalid Variant path {path:?}: list index is too large"))?;
        Ok((index, pos + 1))
    }

    /// Applies `variant_get(root(), path, dtype)` to `array` and executes the result.
    fn execute_variant_get(
        array: ArrayRef,
        path: &str,
        dtype: Option<DType>,
    ) -> VortexResult<ArrayRef> {
        let expr = variant_get(root(), parse_path(path)?, dtype);
        array
            .apply(&expr)?
            .execute::<ArrayRef>(&mut LEGACY_SESSION.create_execution_ctx())
    }

    #[test]
    fn variant_get_path_parse_and_display() {
        let path = parse_path("$.data[1].a").unwrap();
        assert_eq!(
            path.elements(),
            &[
                VariantPathElement::field("data"),
                VariantPathElement::index(1),
                VariantPathElement::field("a")
            ]
        );
        assert_eq!(path.to_string(), "$.data[1].a");

        let bare_path = parse_path("data[2]").unwrap();
        assert_eq!(bare_path.to_string(), "$.data[2]");
        assert!(parse_path("$.").is_err());
        assert!(parse_path("$data").is_err());
        assert!(parse_path("$.data[-1]").is_err());
    }

    #[test]
    fn variant_path_root_and_conversions() {
        let root_path = VariantPath::root();
        assert!(root_path.is_root());
        assert!(root_path.elements().is_empty());
        assert_eq!(root_path.to_string(), "$");
        assert_eq!(parse_path("$").unwrap(), root_path);
        assert_eq!(parse_path("").unwrap(), root_path);

        assert_eq!(VariantPathElement::from("a"), VariantPathElement::field("a"));
        assert_eq!(
            VariantPathElement::from(FieldName::from("a")),
            VariantPathElement::field("a")
        );
        assert_eq!(VariantPathElement::from(3u64), VariantPathElement::index(3));
        assert_eq!(
            VariantPath::from(VariantPathElement::field("a")),
            VariantPath::field("a")
        );

        let collected: VariantPath = [VariantPathElement::field("a"), VariantPathElement::index(0)]
            .into_iter()
            .collect();
        assert!(!collected.is_root());
        assert_eq!(collected.to_string(), "$.a[0]");
    }

    #[test]
    fn variant_get_options_display() {
        let typed = VariantGetOptions::new(
            VariantPath::field("a"),
            Some(DType::Utf8(Nullability::NonNullable)),
        );
        assert_eq!(typed.to_string(), "$.a as utf8");

        let untyped = VariantGetOptions::new(VariantPath::field("a"), None);
        assert_eq!(untyped.to_string(), "$.a");
        assert!(untyped.dtype().is_none());
    }

    #[test]
    fn variant_get_return_dtype_is_nullable_variant_without_requested_dtype() {
        let expr = variant_get(root(), VariantPath::field("data"), None);
        let dtype = expr
            .return_dtype(&DType::Variant(Nullability::NonNullable))
            .unwrap();

        assert_eq!(dtype, DType::Variant(Nullability::Nullable));
    }

    #[test]
    fn variant_get_return_dtype_makes_requested_dtype_nullable() {
        let requested = DType::Primitive(PType::I64, Nullability::NonNullable);
        let expr = variant_get(root(), VariantPath::field("data"), Some(requested));
        let dtype = expr
            .return_dtype(&DType::Variant(Nullability::NonNullable))
            .unwrap();

        assert_eq!(dtype, DType::Primitive(PType::I64, Nullability::Nullable));
    }

    #[test]
    fn variant_get_rejects_non_variant_input() {
        let expr = variant_get(root(), VariantPath::field("data"), None);
        let err = expr
            .return_dtype(&DType::Utf8(Nullability::NonNullable))
            .unwrap_err();

        assert!(err.to_string().contains("VariantGet input must be Variant"));
    }

    #[test]
    fn variant_get_formats_sql() {
        let expr = variant_get(
            root(),
            parse_path("$.data[1].a").unwrap(),
            Some(DType::Utf8(Nullability::NonNullable)),
        );

        assert_eq!(expr.to_string(), "variant_get($, \"$.data[1].a\", utf8)");
    }

    #[test]
    fn variant_get_options_roundtrip_serialization() {
        let options = VariantGetOptions::new(
            VariantPath::new([
                VariantPathElement::field("data"),
                VariantPathElement::index(1),
                VariantPathElement::field("a"),
            ]),
            Some(DType::Primitive(PType::I32, Nullability::NonNullable)),
        );
        let metadata = VariantGet.serialize(&options).unwrap().unwrap();
        let actual = VariantGet
            .deserialize(&metadata, &VortexSession::empty())
            .unwrap();

        assert_eq!(actual, options);
    }

    #[test]
    fn variant_get_expression_roundtrip_serialization() {
        let expr: Expression = variant_get(
            root(),
            parse_path("$.data[1].a").unwrap(),
            Some(DType::Primitive(PType::I32, Nullability::NonNullable)),
        );
        let proto = expr.serialize_proto().unwrap();
        let actual = Expression::from_proto(&proto, &LEGACY_SESSION).unwrap();

        assert_eq!(actual, expr);
    }

    #[test]
    fn variant_get_generic_fallback_extracts_field_and_list_index() -> VortexResult<()> {
        let items = Scalar::list(
            DType::Variant(Nullability::NonNullable),
            vec![
                Scalar::variant(Scalar::primitive(10i32, Nullability::NonNullable)),
                Scalar::variant(Scalar::primitive(20i32, Nullability::NonNullable)),
            ],
            Nullability::NonNullable,
        );
        let array = variant_rows([
            Scalar::variant(variant_object([("items", items)])),
            Scalar::variant(variant_object([(
                "items",
                Scalar::list_empty(
                    DType::Variant(Nullability::NonNullable).into(),
                    Nullability::NonNullable,
                ),
            )])),
            Scalar::variant(variant_object([(
                "items",
                Scalar::list(
                    DType::Variant(Nullability::NonNullable),
                    vec![
                        Scalar::variant(Scalar::utf8("x", Nullability::NonNullable)),
                        Scalar::variant(Scalar::utf8("wrong", Nullability::NonNullable)),
                    ],
                    Nullability::NonNullable,
                ),
            )])),
        ])?;

        let result = execute_variant_get(
            array,
            "$.items[1]",
            Some(DType::Primitive(PType::I32, Nullability::NonNullable)),
        )?;

        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(20i32), None, None])
        );
        Ok(())
    }

    #[test]
    fn variant_get_traverses_nested_objects_and_nulls_missing_steps() -> VortexResult<()> {
        let array = variant_rows([
            Scalar::variant(variant_object([(
                "a",
                variant_object([("b", Scalar::primitive(7i32, Nullability::NonNullable))]),
            )])),
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(1i32, Nullability::NonNullable),
            )])),
            Scalar::variant(Scalar::primitive(5i32, Nullability::NonNullable)),
        ])?;

        let result = execute_variant_get(
            array,
            "$.a.b",
            Some(DType::Primitive(PType::I32, Nullability::NonNullable)),
        )?;

        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(7i32), None, None])
        );
        Ok(())
    }

    #[test]
    fn variant_get_reads_chunked_variant_input() -> VortexResult<()> {
        let array = variant_rows([
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(10i32, Nullability::NonNullable),
            )])),
            Scalar::variant(variant_object([(
                "b",
                Scalar::primitive(20i32, Nullability::NonNullable),
            )])),
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(30i32, Nullability::NonNullable),
            )])),
            Scalar::null(DType::Variant(Nullability::Nullable)),
        ])?;
        assert!(array.is::<Chunked>());

        let result = execute_variant_get(
            array,
            "$.a",
            Some(DType::Primitive(PType::I32, Nullability::NonNullable)),
        )?;

        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None])
        );
        Ok(())
    }

    #[test]
    fn variant_get_fallback_typed_output_is_contiguous() -> VortexResult<()> {
        let array = variant_rows([
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(10i32, Nullability::NonNullable),
            )])),
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(20i32, Nullability::NonNullable),
            )])),
            Scalar::variant(variant_object([(
                "b",
                Scalar::primitive(30i32, Nullability::NonNullable),
            )])),
        ])?;

        let result = execute_variant_get(
            array,
            "$.a",
            Some(DType::Primitive(PType::I32, Nullability::NonNullable)),
        )?;

        assert!(!result.is::<Chunked>());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10i32), Some(20), None])
        );
        Ok(())
    }

    #[test]
    fn variant_get_generic_fallback_preserves_variant_null() -> VortexResult<()> {
        let array = variant_rows([
            Scalar::variant(variant_object([(
                "a",
                Scalar::utf8("ok", Nullability::NonNullable),
            )])),
            Scalar::null(DType::Variant(Nullability::Nullable)),
            Scalar::variant(variant_object([("a", Scalar::null(DType::Null))])),
            Scalar::variant(variant_object([(
                "b",
                Scalar::primitive(2i32, Nullability::NonNullable),
            )])),
        ])?;

        let result = execute_variant_get(array, "$.a", None)?;

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let row0 = result.execute_scalar(0, &mut ctx)?;
        assert_eq!(
            row0.as_variant()
                .value()
                .and_then(|value| value.as_utf8().value())
                .map(|value| value.as_str()),
            Some("ok")
        );
        assert_nth_scalar_is_null!(result, 1);
        assert_eq!(
            result
                .execute_scalar(2, &mut ctx)?
                .as_variant()
                .is_variant_null(),
            Some(true)
        );
        assert_nth_scalar_is_null!(result, 3);
        Ok(())
    }

    #[test]
    fn variant_get_fallback_variant_output_canonicalizes() -> VortexResult<()> {
        let array = variant_rows([
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(10i32, Nullability::NonNullable),
            )])),
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(20i32, Nullability::NonNullable),
            )])),
        ])?;

        let result = execute_variant_get(array, "$.a", None)?;
        let variant = result
            .clone()
            .execute::<VariantArray>(&mut LEGACY_SESSION.create_execution_ctx())?;
        let canonical = result.execute::<Canonical>(&mut LEGACY_SESSION.create_execution_ctx())?;
        let Canonical::Variant(canonical_variant) = canonical else {
            vortex_bail!("expected Variant canonical array");
        };

        assert_eq!(variant.len(), 2);
        assert_eq!(canonical_variant.len(), 2);
        assert_eq!(variant.core_storage().dtype(), variant.dtype());
        assert_eq!(variant.core_storage().len(), variant.len());

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        for (idx, expected) in [10i32, 20].into_iter().enumerate() {
            let scalar = variant.execute_scalar(idx, &mut ctx)?;
            let actual = scalar
                .as_variant()
                .value()
                .and_then(|value| value.as_primitive().as_::<i32>());
            assert_eq!(actual, Some(expected), "row {idx}");
        }
        Ok(())
    }
}