Skip to main content

vortex_array/scalar_fn/fns/variant_get/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Display;
6use std::fmt::Formatter;
7
8use prost::Message;
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_err;
12use vortex_proto::expr as pb;
13use vortex_proto::expr::variant_path_element;
14use vortex_session::VortexSession;
15use vortex_session::registry::CachedId;
16use vortex_utils::aliases::StringEscape;
17
18use crate::ArrayRef;
19use crate::ExecutionCtx;
20use crate::IntoArray;
21use crate::arrays::ChunkedArray;
22use crate::arrays::ConstantArray;
23use crate::arrays::VariantArray;
24use crate::builders::builder_with_capacity_in;
25use crate::dtype::DType;
26use crate::dtype::FieldName;
27use crate::dtype::Nullability;
28use crate::expr::Expression;
29use crate::scalar::Scalar;
30use crate::scalar_fn::Arity;
31use crate::scalar_fn::ChildName;
32use crate::scalar_fn::ExecutionArgs;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35
/// Scalar function that walks a path of object fields and list indexes into each Variant value.
///
/// Rows where the path is absent, where traversal meets a value of the wrong shape, or where the
/// final cast fails evaluate to null. When no dtype is requested the output is a nullable Variant;
/// otherwise it is the requested dtype made nullable. Encodings may answer fully shredded paths
/// straight from their shredded storage, but any path that storage does not represent must be
/// resolved against the core Variant value instead.
#[derive(Clone)]
pub struct VariantGet;
44
45impl ScalarFnVTable for VariantGet {
46    type Options = VariantGetOptions;
47
48    fn id(&self) -> ScalarFnId {
49        static ID: CachedId = CachedId::new("vortex.variant_get");
50        *ID
51    }
52
53    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
54        let path = options
55            .path()
56            .elements()
57            .iter()
58            .map(|element| match element {
59                VariantPathElement::Field(name) => pb::VariantPathElement {
60                    element: Some(variant_path_element::Element::Field(
61                        name.as_ref().to_string(),
62                    )),
63                },
64                VariantPathElement::Index(index) => pb::VariantPathElement {
65                    element: Some(variant_path_element::Element::Index(*index)),
66                },
67            })
68            .collect();
69        let dtype = options.dtype().map(TryInto::try_into).transpose()?;
70
71        Ok(Some(pb::VariantGetOpts { path, dtype }.encode_to_vec()))
72    }
73
74    fn deserialize(&self, metadata: &[u8], session: &VortexSession) -> VortexResult<Self::Options> {
75        let opts = pb::VariantGetOpts::decode(metadata)?;
76        let path = opts
77            .path
78            .into_iter()
79            .map(VariantPathElement::from_proto)
80            .collect::<VortexResult<_>>()?;
81        let dtype = opts
82            .dtype
83            .as_ref()
84            .map(|dtype| DType::from_proto(dtype, session))
85            .transpose()?;
86
87        Ok(VariantGetOptions::new(path, dtype))
88    }
89
90    fn arity(&self, _options: &Self::Options) -> Arity {
91        Arity::Exact(1)
92    }
93
94    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
95        match child_idx {
96            0 => ChildName::from("input"),
97            _ => unreachable!("Invalid child index {child_idx} for VariantGet expression"),
98        }
99    }
100
101    fn fmt_sql(
102        &self,
103        options: &Self::Options,
104        expr: &Expression,
105        f: &mut Formatter<'_>,
106    ) -> fmt::Result {
107        write!(f, "variant_get(")?;
108        expr.child(0).fmt_sql(f)?;
109        let path = options.path().to_string();
110        write!(f, ", \"{}\"", StringEscape(&path))?;
111        if let Some(dtype) = options.dtype() {
112            write!(f, ", {dtype}")?;
113        }
114        write!(f, ")")
115    }
116
117    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
118        let input_dtype = &arg_dtypes[0];
119        vortex_ensure!(
120            matches!(input_dtype, DType::Variant(_)),
121            "VariantGet input must be Variant, found {input_dtype}"
122        );
123
124        // Missing paths, traversal mismatches, and cast failures all produce nulls.
125        Ok(options
126            .dtype()
127            .map_or(DType::Variant(Nullability::Nullable), DType::as_nullable))
128    }
129
130    fn execute(
131        &self,
132        options: &Self::Options,
133        args: &dyn ExecutionArgs,
134        ctx: &mut ExecutionCtx,
135    ) -> VortexResult<ArrayRef> {
136        let input = args.get(0)?;
137        // Missing paths, traversal mismatches, and cast failures all produce nulls.
138        let dtype = options
139            .dtype()
140            .map_or(DType::Variant(Nullability::Nullable), DType::as_nullable);
141
142        if !dtype.is_variant() {
143            let mut builder = builder_with_capacity_in(ctx.allocator(), &dtype, input.len());
144            for idx in 0..input.len() {
145                let scalar = input.execute_scalar(idx, ctx)?;
146                let output = variant_get_scalar(&scalar, options, &dtype)?;
147                builder.append_scalar(&output)?;
148            }
149
150            return Ok(builder.finish_into_canonical().into_array());
151        }
152
153        // TODO(variant): replace this with a Variant builder once one exists.
154        // Chunked<Variant> canonicalizes to VariantArray, so this row-wise fallback is safe.
155        let mut chunks = Vec::with_capacity(input.len());
156
157        for idx in 0..input.len() {
158            let scalar = input.execute_scalar(idx, ctx)?;
159            let output = variant_get_scalar(&scalar, options, &dtype)?;
160            chunks.push(ConstantArray::new(output, 1).into_array());
161        }
162
163        let array = ChunkedArray::try_new(chunks, dtype)?.into_array();
164        VariantArray::try_new(array, None).map(|array| array.into_array())
165    }
166}
167
168fn variant_get_scalar(
169    scalar: &Scalar,
170    options: &VariantGetOptions,
171    output_dtype: &DType,
172) -> VortexResult<Scalar> {
173    let Some(value) = variant_path_scalar(scalar, options.path().elements())? else {
174        return Ok(Scalar::null(output_dtype.clone()));
175    };
176
177    if options.dtype().is_none_or(DType::is_variant) {
178        return Scalar::variant(value).cast(output_dtype);
179    }
180
181    if value.is_null() {
182        return Ok(Scalar::null(output_dtype.clone()));
183    }
184
185    value
186        .cast(output_dtype)
187        .or_else(|_| Ok(Scalar::null(output_dtype.clone())))
188}
189
190fn variant_path_scalar(
191    scalar: &Scalar,
192    path: &[VariantPathElement],
193) -> VortexResult<Option<Scalar>> {
194    let mut current = match variant_payload(scalar.clone()) {
195        Some(value) => value,
196        None => return Ok(None),
197    };
198
199    for element in path {
200        current = match variant_payload(current) {
201            Some(value) => value,
202            None => return Ok(None),
203        };
204
205        if current.is_null() {
206            return Ok(None);
207        }
208
209        current = match element {
210            VariantPathElement::Field(name) => {
211                let Some(struct_scalar) = current.as_struct_opt() else {
212                    return Ok(None);
213                };
214                if struct_scalar.is_null() {
215                    return Ok(None);
216                }
217                let Some(field) = struct_scalar.field(name.as_ref()) else {
218                    return Ok(None);
219                };
220                field
221            }
222            VariantPathElement::Index(index) => {
223                let Ok(index) = usize::try_from(*index) else {
224                    return Ok(None);
225                };
226                let Some(list_scalar) = current.as_list_opt() else {
227                    return Ok(None);
228                };
229                let Some(element) = list_scalar.element(index) else {
230                    return Ok(None);
231                };
232                element
233            }
234        };
235    }
236
237    Ok(variant_payload(current))
238}
239
240fn variant_payload(scalar: Scalar) -> Option<Scalar> {
241    if scalar.dtype().is_variant() {
242        scalar.as_variant().value().cloned()
243    } else {
244        Some(scalar)
245    }
246}
247
248/// Options for [`VariantGet`].
249#[derive(Clone, Debug, PartialEq, Eq, Hash)]
250pub struct VariantGetOptions {
251    path: VariantPath,
252    dtype: Option<DType>,
253}
254
255impl VariantGetOptions {
256    /// Creates options for extracting `path`, returning Variant values when `dtype` is `None`.
257    pub fn new(path: VariantPath, dtype: Option<DType>) -> Self {
258        Self { path, dtype }
259    }
260
261    /// Returns the path to extract.
262    pub fn path(&self) -> &VariantPath {
263        &self.path
264    }
265
266    /// Returns the requested output dtype, if any.
267    pub fn dtype(&self) -> Option<&DType> {
268        self.dtype.as_ref()
269    }
270}
271
272impl Display for VariantGetOptions {
273    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
274        write!(f, "{}", self.path)?;
275        if let Some(dtype) = &self.dtype {
276            write!(f, " as {dtype}")?;
277        }
278        Ok(())
279    }
280}
281
282/// A strict Variant path made from object fields and list indexes.
283#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
284pub struct VariantPath(Vec<VariantPathElement>);
285
286impl VariantPath {
287    /// Creates a path from explicit elements.
288    pub fn new(elements: impl IntoIterator<Item = VariantPathElement>) -> Self {
289        Self(elements.into_iter().collect())
290    }
291
292    /// Creates the root path.
293    pub fn root() -> Self {
294        Self::default()
295    }
296
297    /// Creates a path containing one object field.
298    pub fn field(field: impl Into<FieldName>) -> Self {
299        Self(vec![VariantPathElement::field(field)])
300    }
301
302    /// Returns the path elements.
303    pub fn elements(&self) -> &[VariantPathElement] {
304        &self.0
305    }
306
307    /// Returns whether this path references the root Variant value.
308    pub fn is_root(&self) -> bool {
309        self.0.is_empty()
310    }
311}
312
313impl FromIterator<VariantPathElement> for VariantPath {
314    fn from_iter<T: IntoIterator<Item = VariantPathElement>>(iter: T) -> Self {
315        Self(iter.into_iter().collect())
316    }
317}
318
319impl From<VariantPathElement> for VariantPath {
320    fn from(value: VariantPathElement) -> Self {
321        Self(vec![value])
322    }
323}
324
325impl Display for VariantPath {
326    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
327        write!(f, "$")?;
328        for element in self.elements() {
329            match element {
330                VariantPathElement::Field(name) => write!(f, ".{name}")?,
331                VariantPathElement::Index(index) => write!(f, "[{index}]")?,
332            }
333        }
334        Ok(())
335    }
336}
337
338/// A single field or index step in a [`VariantPath`].
339#[derive(Clone, Debug, PartialEq, Eq, Hash)]
340pub enum VariantPathElement {
341    /// Select an object field by name.
342    Field(FieldName),
343    /// Select a list element by zero-based index.
344    Index(u64),
345}
346
347impl VariantPathElement {
348    /// Creates an object-field path element.
349    pub fn field(field: impl Into<FieldName>) -> Self {
350        Self::Field(field.into())
351    }
352
353    /// Creates a list-index path element.
354    pub fn index(index: u64) -> Self {
355        Self::Index(index)
356    }
357
358    fn from_proto(value: pb::VariantPathElement) -> VortexResult<Self> {
359        match value
360            .element
361            .ok_or_else(|| vortex_err!("VariantGet path element missing value"))?
362        {
363            variant_path_element::Element::Field(field) => Ok(Self::field(field)),
364            variant_path_element::Element::Index(index) => Ok(Self::index(index)),
365        }
366    }
367}
368
369impl From<FieldName> for VariantPathElement {
370    fn from(value: FieldName) -> Self {
371        Self::field(value)
372    }
373}
374
375impl From<&str> for VariantPathElement {
376    fn from(value: &str) -> Self {
377        Self::field(value)
378    }
379}
380
381impl From<u64> for VariantPathElement {
382    fn from(value: u64) -> Self {
383        Self::index(value)
384    }
385}
386
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;
    use vortex_error::vortex_ensure;
    use vortex_error::vortex_err;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::Chunked;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VariantArray;
    use crate::arrays::variant::VariantArrayExt;
    use crate::assert_arrays_eq;
    use crate::assert_nth_scalar_is_null;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::Expression;
    use crate::expr::proto::ExprSerializeProtoExt;
    use crate::expr::root;
    use crate::expr::variant_get;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;
    use crate::scalar_fn::ScalarFnVTable;
    use crate::scalar_fn::fns::variant_get::VariantGet;
    use crate::scalar_fn::fns::variant_get::VariantGetOptions;
    use crate::scalar_fn::fns::variant_get::VariantPath;
    use crate::scalar_fn::fns::variant_get::VariantPathElement;

    /// Builds a non-nullable struct scalar that stands in for a Variant object.
    ///
    /// Every field is declared as a non-nullable Variant, and its value is the given scalar
    /// wrapped with `Scalar::variant`.
    fn variant_object(fields: impl IntoIterator<Item = (&'static str, Scalar)>) -> Scalar {
        let fields = fields.into_iter().collect::<Vec<_>>();
        let names = FieldNames::from_iter(fields.iter().map(|(name, _)| FieldName::from(*name)));
        let dtypes = vec![DType::Variant(Nullability::NonNullable); fields.len()];
        let values = fields
            .into_iter()
            .map(|(_, value)| Scalar::variant(value).into_value())
            .collect();
        Scalar::try_new(
            DType::Struct(StructFields::new(names, dtypes), Nullability::NonNullable),
            Some(ScalarValue::Tuple(values)),
        )
        .unwrap()
    }

    /// Builds a nullable-Variant `ChunkedArray` with one single-row constant chunk per input
    /// scalar.
    ///
    /// Each row is cast to nullable Variant first.
    fn variant_rows(rows: impl IntoIterator<Item = Scalar>) -> VortexResult<ArrayRef> {
        let dtype = DType::Variant(Nullability::Nullable);
        let chunks = rows
            .into_iter()
            .map(|row| ConstantArray::new(row.cast(&dtype).unwrap(), 1).into_array())
            .collect();
        ChunkedArray::try_new(chunks, dtype).map(|array| array.into_array())
    }

    /// Test-only syntax for keeping `variant_get` cases compact without committing
    /// to a public string grammar yet.
    ///
    /// Accepted forms:
    /// - `""` or `"$"`, which give the root path;
    /// - `$`-prefixed paths, which must continue with `.` or `[`;
    /// - bare paths that start with a field name.
    ///
    /// Each step is `.name` or `[digits]`.
    fn parse_path(path: &str) -> VortexResult<VariantPath> {
        if path.is_empty() || path == "$" {
            return Ok(VariantPath::root());
        }

        let mut elements = Vec::new();
        // Skip a leading `$` if present; `pos == 0` afterwards means a bare path.
        let mut pos = usize::from(path.as_bytes().first() == Some(&b'$'));
        if pos == 1
            && path
                .as_bytes()
                .get(pos)
                .is_some_and(|byte| !matches!(byte, b'.' | b'['))
        {
            vortex_bail!("Invalid Variant path {path:?}: expected '.' or '[' after '$'");
        }

        while pos < path.len() {
            match path.as_bytes()[pos] {
                b'.' => {
                    pos += 1;
                    let (field, next_pos) = parse_field(path, pos)?;
                    elements.push(VariantPathElement::field(field));
                    pos = next_pos;
                }
                b'[' => {
                    let (index, next_pos) = parse_index(path, pos + 1)?;
                    elements.push(VariantPathElement::index(index));
                    pos = next_pos;
                }
                // A bare path (no `$`) may open with a field name without a leading `.`.
                _ if pos == 0 => {
                    let (field, next_pos) = parse_field(path, pos)?;
                    elements.push(VariantPathElement::field(field));
                    pos = next_pos;
                }
                _ => {
                    vortex_bail!("Invalid Variant path {path:?}: expected '.', '[', or end of path")
                }
            }
        }

        Ok(VariantPath::new(elements))
    }

    /// Consumes a non-empty run of ASCII alphanumeric or `_` bytes starting at `start`.
    ///
    /// Returns the field name and the position just past it.
    fn parse_field(path: &str, start: usize) -> VortexResult<(&str, usize)> {
        let mut pos = start;
        while path
            .as_bytes()
            .get(pos)
            .is_some_and(|byte| byte.is_ascii_alphanumeric() || *byte == b'_')
        {
            pos += 1;
        }
        vortex_ensure!(
            pos > start,
            "Invalid Variant path {path:?}: expected field name"
        );
        Ok((&path[start..pos], pos))
    }

    /// Parses a non-empty run of ASCII digits starting at `start`, which must be followed by
    /// `]`.
    ///
    /// Returns the index and the position just past the `]`.
    fn parse_index(path: &str, start: usize) -> VortexResult<(u64, usize)> {
        let mut pos = start;
        while path
            .as_bytes()
            .get(pos)
            .is_some_and(|byte| byte.is_ascii_digit())
        {
            pos += 1;
        }
        vortex_ensure!(
            pos > start,
            "Invalid Variant path {path:?}: expected list index"
        );
        vortex_ensure!(
            path.as_bytes().get(pos) == Some(&b']'),
            "Invalid Variant path {path:?}: expected closing ']'"
        );
        // Only digits were consumed, so a parse failure here can only be u64 overflow.
        let index = path[start..pos]
            .parse()
            .map_err(|_| vortex_err!("Invalid Variant path {path:?}: list index is too large"))?;
        Ok((index, pos + 1))
    }

    /// Applies `variant_get(root(), <parsed path>, dtype)` to `array`.
    ///
    /// Executes the result with a fresh `LEGACY_SESSION` execution context.
    fn execute_variant_get(
        array: ArrayRef,
        path: &str,
        dtype: Option<DType>,
    ) -> VortexResult<ArrayRef> {
        let expr = variant_get(root(), parse_path(path)?, dtype);
        array
            .apply(&expr)?
            .execute::<ArrayRef>(&mut LEGACY_SESSION.create_execution_ctx())
    }

    #[test]
    fn variant_get_path_parse_and_display() {
        let path = parse_path("$.data[1].a").unwrap();
        assert_eq!(
            path.elements(),
            &[
                VariantPathElement::field("data"),
                VariantPathElement::index(1),
                VariantPathElement::field("a")
            ]
        );
        assert_eq!(path.to_string(), "$.data[1].a");

        // Bare paths are accepted on input but always displayed with a `$.` prefix.
        let bare_path = parse_path("data[2]").unwrap();
        assert_eq!(bare_path.to_string(), "$.data[2]");
        assert!(parse_path("$.").is_err());
        assert!(parse_path("$data").is_err());
        assert!(parse_path("$.data[-1]").is_err());
    }

    #[test]
    fn variant_get_return_dtype_is_nullable_variant_without_requested_dtype() {
        let expr = variant_get(root(), VariantPath::field("data"), None);
        let dtype = expr
            .return_dtype(&DType::Variant(Nullability::NonNullable))
            .unwrap();

        assert_eq!(dtype, DType::Variant(Nullability::Nullable));
    }

    #[test]
    fn variant_get_return_dtype_makes_requested_dtype_nullable() {
        let requested = DType::Primitive(PType::I64, Nullability::NonNullable);
        let expr = variant_get(root(), VariantPath::field("data"), Some(requested));
        let dtype = expr
            .return_dtype(&DType::Variant(Nullability::NonNullable))
            .unwrap();

        assert_eq!(dtype, DType::Primitive(PType::I64, Nullability::Nullable));
    }

    #[test]
    fn variant_get_rejects_non_variant_input() {
        let expr = variant_get(root(), VariantPath::field("data"), None);
        let err = expr
            .return_dtype(&DType::Utf8(Nullability::NonNullable))
            .unwrap_err();

        assert!(err.to_string().contains("VariantGet input must be Variant"));
    }

    #[test]
    fn variant_get_formats_sql() {
        let expr = variant_get(
            root(),
            parse_path("$.data[1].a").unwrap(),
            Some(DType::Utf8(Nullability::NonNullable)),
        );

        assert_eq!(expr.to_string(), "variant_get($, \"$.data[1].a\", utf8)");
    }

    #[test]
    fn variant_get_options_roundtrip_serialization() {
        let options = VariantGetOptions::new(
            VariantPath::new([
                VariantPathElement::field("data"),
                VariantPathElement::index(1),
                VariantPathElement::field("a"),
            ]),
            Some(DType::Primitive(PType::I32, Nullability::NonNullable)),
        );
        let metadata = VariantGet.serialize(&options).unwrap().unwrap();
        let actual = VariantGet
            .deserialize(&metadata, &VortexSession::empty())
            .unwrap();

        assert_eq!(actual, options);
    }

    #[test]
    fn variant_get_expression_roundtrip_serialization() {
        let expr: Expression = variant_get(
            root(),
            parse_path("$.data[1].a").unwrap(),
            Some(DType::Primitive(PType::I32, Nullability::NonNullable)),
        );
        let proto = expr.serialize_proto().unwrap();
        let actual = Expression::from_proto(&proto, &LEGACY_SESSION).unwrap();

        assert_eq!(actual, expr);
    }

    /// Index 1 resolves in row 0, is out of range in row 1 (empty list), and holds a string
    /// that cannot become i32 in row 2.
    #[test]
    fn variant_get_generic_fallback_extracts_field_and_list_index() -> VortexResult<()> {
        let items = Scalar::list(
            DType::Variant(Nullability::NonNullable),
            vec![
                Scalar::variant(Scalar::primitive(10i32, Nullability::NonNullable)),
                Scalar::variant(Scalar::primitive(20i32, Nullability::NonNullable)),
            ],
            Nullability::NonNullable,
        );
        let array = variant_rows([
            Scalar::variant(variant_object([("items", items)])),
            Scalar::variant(variant_object([(
                "items",
                Scalar::list_empty(
                    DType::Variant(Nullability::NonNullable).into(),
                    Nullability::NonNullable,
                ),
            )])),
            Scalar::variant(variant_object([(
                "items",
                Scalar::list(
                    DType::Variant(Nullability::NonNullable),
                    vec![
                        Scalar::variant(Scalar::utf8("x", Nullability::NonNullable)),
                        Scalar::variant(Scalar::utf8("wrong", Nullability::NonNullable)),
                    ],
                    Nullability::NonNullable,
                ),
            )])),
        ])?;

        let result = execute_variant_get(
            array,
            "$.items[1]",
            Some(DType::Primitive(PType::I32, Nullability::NonNullable)),
        )?;

        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(20i32), None, None])
        );
        Ok(())
    }

    /// Row 1 lacks field `a` and row 3 is a null Variant; both become nulls.
    #[test]
    fn variant_get_reads_chunked_variant_input() -> VortexResult<()> {
        let array = variant_rows([
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(10i32, Nullability::NonNullable),
            )])),
            Scalar::variant(variant_object([(
                "b",
                Scalar::primitive(20i32, Nullability::NonNullable),
            )])),
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(30i32, Nullability::NonNullable),
            )])),
            Scalar::null(DType::Variant(Nullability::Nullable)),
        ])?;
        assert!(array.is::<Chunked>());

        let result = execute_variant_get(
            array,
            "$.a",
            Some(DType::Primitive(PType::I32, Nullability::NonNullable)),
        )?;

        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None])
        );
        Ok(())
    }

    /// Typed output is built with a builder, so it must not come back as a chunked array.
    #[test]
    fn variant_get_fallback_typed_output_is_contiguous() -> VortexResult<()> {
        let array = variant_rows([
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(10i32, Nullability::NonNullable),
            )])),
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(20i32, Nullability::NonNullable),
            )])),
            Scalar::variant(variant_object([(
                "b",
                Scalar::primitive(30i32, Nullability::NonNullable),
            )])),
        ])?;

        let result = execute_variant_get(
            array,
            "$.a",
            Some(DType::Primitive(PType::I32, Nullability::NonNullable)),
        )?;

        assert!(!result.is::<Chunked>());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10i32), Some(20), None])
        );
        Ok(())
    }

    /// Distinguishes the two kinds of null:
    /// - an SQL-null row (row 1) or a missing field (row 3) yields a null result;
    /// - a field holding a Variant null (row 2) yields a Variant-null value.
    #[test]
    fn variant_get_generic_fallback_preserves_variant_null() -> VortexResult<()> {
        let array = variant_rows([
            Scalar::variant(variant_object([(
                "a",
                Scalar::utf8("ok", Nullability::NonNullable),
            )])),
            Scalar::null(DType::Variant(Nullability::Nullable)),
            Scalar::variant(variant_object([("a", Scalar::null(DType::Null))])),
            Scalar::variant(variant_object([(
                "b",
                Scalar::primitive(2i32, Nullability::NonNullable),
            )])),
        ])?;

        let result = execute_variant_get(array, "$.a", None)?;

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let row0 = result.execute_scalar(0, &mut ctx)?;
        assert_eq!(
            row0.as_variant()
                .value()
                .and_then(|value| value.as_utf8().value())
                .map(|value| value.as_str()),
            Some("ok")
        );
        assert_nth_scalar_is_null!(result, 1);
        assert_eq!(
            result
                .execute_scalar(2, &mut ctx)?
                .as_variant()
                .is_variant_null(),
            Some(true)
        );
        assert_nth_scalar_is_null!(result, 3);
        Ok(())
    }

    /// Variant output must execute both as a `VariantArray` and as `Canonical::Variant`, with
    /// core storage matching the wrapper's dtype and length.
    #[test]
    fn variant_get_fallback_variant_output_canonicalizes() -> VortexResult<()> {
        let array = variant_rows([
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(10i32, Nullability::NonNullable),
            )])),
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(20i32, Nullability::NonNullable),
            )])),
        ])?;

        let result = execute_variant_get(array, "$.a", None)?;
        let variant = result
            .clone()
            .execute::<VariantArray>(&mut LEGACY_SESSION.create_execution_ctx())?;
        let canonical = result.execute::<Canonical>(&mut LEGACY_SESSION.create_execution_ctx())?;
        let Canonical::Variant(canonical_variant) = canonical else {
            vortex_bail!("expected Variant canonical array");
        };

        assert_eq!(variant.len(), 2);
        assert_eq!(canonical_variant.len(), 2);
        assert_eq!(variant.core_storage().dtype(), variant.dtype());
        assert_eq!(variant.core_storage().len(), variant.len());

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        for (idx, expected) in [10i32, 20].into_iter().enumerate() {
            let scalar = variant.execute_scalar(idx, &mut ctx)?;
            let actual = scalar
                .as_variant()
                .value()
                .and_then(|value| value.as_primitive().as_::<i32>());
            assert_eq!(actual, Some(expected), "row {idx}");
        }
        Ok(())
    }
}