Skip to main content

vortex_array/scalar_fn/fns/variant_get/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Display;
6use std::fmt::Formatter;
7
8use prost::Message;
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_err;
12use vortex_proto::expr as pb;
13use vortex_proto::expr::variant_path_element;
14use vortex_session::VortexSession;
15use vortex_session::registry::CachedId;
16use vortex_utils::aliases::StringEscape;
17
18use crate::ArrayRef;
19use crate::ExecutionCtx;
20use crate::IntoArray;
21use crate::arrays::ChunkedArray;
22use crate::arrays::ConstantArray;
23use crate::arrays::VariantArray;
24use crate::builders::builder_with_capacity_in;
25use crate::dtype::DType;
26use crate::dtype::FieldName;
27use crate::dtype::Nullability;
28use crate::expr::Expression;
29use crate::scalar::Scalar;
30use crate::scalar_fn::Arity;
31use crate::scalar_fn::ChildName;
32use crate::scalar_fn::ExecutionArgs;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35
/// Scalar function that pulls the value found at a field/index path out of each Variant row.
///
/// Rows where the path is absent, where traversal meets a value of the wrong shape, or where the
/// final cast fails all yield null. When no output dtype is requested the result is a nullable
/// Variant; otherwise extracted values are cast to the requested dtype, made nullable.
/// Encodings may answer perfectly shredded paths straight from shredded storage, but any path
/// that shredding does not cover must be resolved against the core Variant value.
#[derive(Clone)]
pub struct VariantGet;
44
45impl ScalarFnVTable for VariantGet {
46    type Options = VariantGetOptions;
47
48    fn id(&self) -> ScalarFnId {
49        static ID: CachedId = CachedId::new("vortex.variant_get");
50        *ID
51    }
52
53    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
54        let path = options
55            .path()
56            .elements()
57            .iter()
58            .map(VariantPathElement::to_proto)
59            .collect();
60        let dtype = options.dtype().map(TryInto::try_into).transpose()?;
61
62        Ok(Some(pb::VariantGetOpts { path, dtype }.encode_to_vec()))
63    }
64
65    fn deserialize(&self, metadata: &[u8], session: &VortexSession) -> VortexResult<Self::Options> {
66        let opts = pb::VariantGetOpts::decode(metadata)?;
67        let path = opts
68            .path
69            .into_iter()
70            .map(VariantPathElement::from_proto)
71            .collect::<VortexResult<_>>()?;
72        let dtype = opts
73            .dtype
74            .as_ref()
75            .map(|dtype| DType::from_proto(dtype, session))
76            .transpose()?;
77
78        Ok(VariantGetOptions::new(path, dtype))
79    }
80
81    fn arity(&self, _options: &Self::Options) -> Arity {
82        Arity::Exact(1)
83    }
84
85    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
86        match child_idx {
87            0 => ChildName::from("input"),
88            _ => unreachable!("Invalid child index {child_idx} for VariantGet expression"),
89        }
90    }
91
92    fn fmt_sql(
93        &self,
94        options: &Self::Options,
95        expr: &Expression,
96        f: &mut Formatter<'_>,
97    ) -> fmt::Result {
98        write!(f, "variant_get(")?;
99        expr.child(0).fmt_sql(f)?;
100        let path = options.path().to_string();
101        write!(f, ", \"{}\"", StringEscape(&path))?;
102        if let Some(dtype) = options.dtype() {
103            write!(f, ", {dtype}")?;
104        }
105        write!(f, ")")
106    }
107
108    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
109        let input_dtype = &arg_dtypes[0];
110        vortex_ensure!(
111            matches!(input_dtype, DType::Variant(_)),
112            "VariantGet input must be Variant, found {input_dtype}"
113        );
114
115        // Missing paths, traversal mismatches, and cast failures all produce nulls.
116        Ok(options
117            .dtype()
118            .map_or(DType::Variant(Nullability::Nullable), DType::as_nullable))
119    }
120
121    fn execute(
122        &self,
123        options: &Self::Options,
124        args: &dyn ExecutionArgs,
125        ctx: &mut ExecutionCtx,
126    ) -> VortexResult<ArrayRef> {
127        let input = args.get(0)?;
128        // Missing paths, traversal mismatches, and cast failures all produce nulls.
129        let dtype = options
130            .dtype()
131            .map_or(DType::Variant(Nullability::Nullable), DType::as_nullable);
132
133        if !dtype.is_variant() {
134            let mut builder = builder_with_capacity_in(ctx.allocator(), &dtype, input.len());
135            for idx in 0..input.len() {
136                let scalar = input.execute_scalar(idx, ctx)?;
137                let output = variant_get_scalar(&scalar, options, &dtype)?;
138                builder.append_scalar(&output)?;
139            }
140
141            return Ok(builder.finish_into_canonical().into_array());
142        }
143
144        // TODO(variant): replace this with a Variant builder once one exists.
145        // Chunked<Variant> canonicalizes to VariantArray, so this row-wise fallback is safe.
146        let mut chunks = Vec::with_capacity(input.len());
147
148        for idx in 0..input.len() {
149            let scalar = input.execute_scalar(idx, ctx)?;
150            let output = variant_get_scalar(&scalar, options, &dtype)?;
151            chunks.push(ConstantArray::new(output, 1).into_array());
152        }
153
154        let array = ChunkedArray::try_new(chunks, dtype)?.into_array();
155        VariantArray::try_new(array, None).map(|array| array.into_array())
156    }
157}
158
159fn variant_get_scalar(
160    scalar: &Scalar,
161    options: &VariantGetOptions,
162    output_dtype: &DType,
163) -> VortexResult<Scalar> {
164    let Some(value) = variant_path_scalar(scalar, options.path().elements())? else {
165        return Ok(Scalar::null(output_dtype.clone()));
166    };
167
168    if options.dtype().is_none_or(DType::is_variant) {
169        return Scalar::variant(value).cast(output_dtype);
170    }
171
172    if value.is_null() {
173        return Ok(Scalar::null(output_dtype.clone()));
174    }
175
176    value
177        .cast(output_dtype)
178        .or_else(|_| Ok(Scalar::null(output_dtype.clone())))
179}
180
181fn variant_path_scalar(
182    scalar: &Scalar,
183    path: &[VariantPathElement],
184) -> VortexResult<Option<Scalar>> {
185    let mut current = match variant_payload(scalar.clone()) {
186        Some(value) => value,
187        None => return Ok(None),
188    };
189
190    for element in path {
191        current = match variant_payload(current) {
192            Some(value) => value,
193            None => return Ok(None),
194        };
195
196        if current.is_null() {
197            return Ok(None);
198        }
199
200        current = match element {
201            VariantPathElement::Field(name) => {
202                let Some(struct_scalar) = current.as_struct_opt() else {
203                    return Ok(None);
204                };
205                if struct_scalar.is_null() {
206                    return Ok(None);
207                }
208                let Some(field) = struct_scalar.field(name.as_ref()) else {
209                    return Ok(None);
210                };
211                field
212            }
213            VariantPathElement::Index(index) => {
214                let Ok(index) = usize::try_from(*index) else {
215                    return Ok(None);
216                };
217                let Some(list_scalar) = current.as_list_opt() else {
218                    return Ok(None);
219                };
220                let Some(element) = list_scalar.element(index) else {
221                    return Ok(None);
222                };
223                element
224            }
225        };
226    }
227
228    Ok(variant_payload(current))
229}
230
231fn variant_payload(scalar: Scalar) -> Option<Scalar> {
232    if scalar.dtype().is_variant() {
233        scalar.as_variant().value().cloned()
234    } else {
235        Some(scalar)
236    }
237}
238
239/// Options for [`VariantGet`].
240#[derive(Clone, Debug, PartialEq, Eq, Hash)]
241pub struct VariantGetOptions {
242    path: VariantPath,
243    dtype: Option<DType>,
244}
245
246impl VariantGetOptions {
247    /// Creates options for extracting `path`, returning Variant values when `dtype` is `None`.
248    pub fn new(path: VariantPath, dtype: Option<DType>) -> Self {
249        Self { path, dtype }
250    }
251
252    /// Returns the path to extract.
253    pub fn path(&self) -> &VariantPath {
254        &self.path
255    }
256
257    /// Returns the requested output dtype, if any.
258    pub fn dtype(&self) -> Option<&DType> {
259        self.dtype.as_ref()
260    }
261}
262
263impl Display for VariantGetOptions {
264    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
265        write!(f, "{}", self.path)?;
266        if let Some(dtype) = &self.dtype {
267            write!(f, " as {dtype}")?;
268        }
269        Ok(())
270    }
271}
272
273/// A strict Variant path made from object fields and list indexes.
274#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
275pub struct VariantPath(Vec<VariantPathElement>);
276
277impl VariantPath {
278    /// Creates a path from explicit elements.
279    pub fn new(elements: impl IntoIterator<Item = VariantPathElement>) -> Self {
280        Self(elements.into_iter().collect())
281    }
282
283    /// Creates the root path.
284    pub fn root() -> Self {
285        Self::default()
286    }
287
288    /// Creates a path containing one object field.
289    pub fn field(field: impl Into<FieldName>) -> Self {
290        Self(vec![VariantPathElement::field(field)])
291    }
292
293    /// Returns the path elements.
294    pub fn elements(&self) -> &[VariantPathElement] {
295        &self.0
296    }
297
298    /// Returns whether this path references the root Variant value.
299    pub fn is_root(&self) -> bool {
300        self.0.is_empty()
301    }
302}
303
304impl FromIterator<VariantPathElement> for VariantPath {
305    fn from_iter<T: IntoIterator<Item = VariantPathElement>>(iter: T) -> Self {
306        Self(iter.into_iter().collect())
307    }
308}
309
310impl From<VariantPathElement> for VariantPath {
311    fn from(value: VariantPathElement) -> Self {
312        Self(vec![value])
313    }
314}
315
316impl Display for VariantPath {
317    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
318        write!(f, "$")?;
319        for element in self.elements() {
320            match element {
321                VariantPathElement::Field(name) => write!(f, ".{name}")?,
322                VariantPathElement::Index(index) => write!(f, "[{index}]")?,
323            }
324        }
325        Ok(())
326    }
327}
328
329/// A single field or index step in a [`VariantPath`].
330#[derive(Clone, Debug, PartialEq, Eq, Hash)]
331pub enum VariantPathElement {
332    /// Select an object field by name.
333    Field(FieldName),
334    /// Select a list element by zero-based index.
335    Index(u64),
336}
337
338impl VariantPathElement {
339    /// Creates an object-field path element.
340    pub fn field(field: impl Into<FieldName>) -> Self {
341        Self::Field(field.into())
342    }
343
344    /// Creates a list-index path element.
345    pub fn index(index: u64) -> Self {
346        Self::Index(index)
347    }
348
349    /// Decodes a path element from its protobuf representation.
350    pub fn from_proto(value: pb::VariantPathElement) -> VortexResult<Self> {
351        match value
352            .element
353            .ok_or_else(|| vortex_err!("Variant path element missing value"))?
354        {
355            variant_path_element::Element::Field(field) => Ok(Self::field(field)),
356            variant_path_element::Element::Index(index) => Ok(Self::index(index)),
357        }
358    }
359
360    /// Encodes this path element into its protobuf representation.
361    pub fn to_proto(&self) -> pb::VariantPathElement {
362        match self {
363            VariantPathElement::Field(name) => pb::VariantPathElement {
364                element: Some(variant_path_element::Element::Field(
365                    name.as_ref().to_string(),
366                )),
367            },
368            VariantPathElement::Index(index) => pb::VariantPathElement {
369                element: Some(variant_path_element::Element::Index(*index)),
370            },
371        }
372    }
373}
374
375impl From<FieldName> for VariantPathElement {
376    fn from(value: FieldName) -> Self {
377        Self::field(value)
378    }
379}
380
381impl From<&str> for VariantPathElement {
382    fn from(value: &str) -> Self {
383        Self::field(value)
384    }
385}
386
387impl From<u64> for VariantPathElement {
388    fn from(value: u64) -> Self {
389        Self::index(value)
390    }
391}
392
// Unit tests for `VariantGet`: path parsing/display, return dtypes, option and expression
// (de)serialization, and the row-wise execution fallback for typed and Variant outputs.
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;
    use vortex_error::vortex_ensure;
    use vortex_error::vortex_err;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::Chunked;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VariantArray;
    use crate::arrays::variant::VariantArrayExt;
    use crate::assert_arrays_eq;
    use crate::assert_nth_scalar_is_null;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::Expression;
    use crate::expr::proto::ExprSerializeProtoExt;
    use crate::expr::root;
    use crate::expr::variant_get;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;
    use crate::scalar_fn::ScalarFnVTable;
    use crate::scalar_fn::fns::variant_get::VariantGet;
    use crate::scalar_fn::fns::variant_get::VariantGetOptions;
    use crate::scalar_fn::fns::variant_get::VariantPath;
    use crate::scalar_fn::fns::variant_get::VariantPathElement;

    /// Builds a non-nullable struct scalar ("object") whose fields are the given values, each
    /// wrapped as a non-nullable Variant.
    fn variant_object(fields: impl IntoIterator<Item = (&'static str, Scalar)>) -> Scalar {
        let fields = fields.into_iter().collect::<Vec<_>>();
        let names = FieldNames::from_iter(fields.iter().map(|(name, _)| FieldName::from(*name)));
        let dtypes = vec![DType::Variant(Nullability::NonNullable); fields.len()];
        let values = fields
            .into_iter()
            .map(|(_, value)| Scalar::variant(value).into_value())
            .collect();
        Scalar::try_new(
            DType::Struct(StructFields::new(names, dtypes), Nullability::NonNullable),
            Some(ScalarValue::Tuple(values)),
        )
        .unwrap()
    }

    /// Builds a nullable-Variant `ChunkedArray` with one single-row `ConstantArray` chunk per
    /// input scalar (each scalar is cast to nullable Variant first).
    fn variant_rows(rows: impl IntoIterator<Item = Scalar>) -> VortexResult<ArrayRef> {
        let dtype = DType::Variant(Nullability::Nullable);
        let chunks = rows
            .into_iter()
            .map(|row| ConstantArray::new(row.cast(&dtype).unwrap(), 1).into_array())
            .collect();
        ChunkedArray::try_new(chunks, dtype).map(|array| array.into_array())
    }

    /// Test-only syntax for keeping `variant_get` cases compact without committing
    /// to a public string grammar yet.
    ///
    /// Accepts `""`/`"$"` (root), an optional leading `$` that must be followed by `.` or `[`,
    /// `.field` steps, `[index]` steps, and a bare leading field name (e.g. `data[2]`).
    fn parse_path(path: &str) -> VortexResult<VariantPath> {
        if path.is_empty() || path == "$" {
            return Ok(VariantPath::root());
        }

        let mut elements = Vec::new();
        // Skip a leading `$` if present.
        let mut pos = usize::from(path.as_bytes().first() == Some(&b'$'));
        if pos == 1
            && path
                .as_bytes()
                .get(pos)
                .is_some_and(|byte| !matches!(byte, b'.' | b'['))
        {
            vortex_bail!("Invalid Variant path {path:?}: expected '.' or '[' after '$'");
        }

        while pos < path.len() {
            match path.as_bytes()[pos] {
                b'.' => {
                    pos += 1;
                    let (field, next_pos) = parse_field(path, pos)?;
                    elements.push(VariantPathElement::field(field));
                    pos = next_pos;
                }
                b'[' => {
                    let (index, next_pos) = parse_index(path, pos + 1)?;
                    elements.push(VariantPathElement::index(index));
                    pos = next_pos;
                }
                // A bare field name is only allowed at the very start (no `$` prefix).
                _ if pos == 0 => {
                    let (field, next_pos) = parse_field(path, pos)?;
                    elements.push(VariantPathElement::field(field));
                    pos = next_pos;
                }
                _ => {
                    vortex_bail!("Invalid Variant path {path:?}: expected '.', '[', or end of path")
                }
            }
        }

        Ok(VariantPath::new(elements))
    }

    /// Consumes a non-empty run of ASCII alphanumerics/underscores starting at `start`;
    /// returns the field name and the position just past it.
    fn parse_field(path: &str, start: usize) -> VortexResult<(&str, usize)> {
        let mut pos = start;
        while path
            .as_bytes()
            .get(pos)
            .is_some_and(|byte| byte.is_ascii_alphanumeric() || *byte == b'_')
        {
            pos += 1;
        }
        vortex_ensure!(
            pos > start,
            "Invalid Variant path {path:?}: expected field name"
        );
        Ok((&path[start..pos], pos))
    }

    /// Parses a non-empty run of decimal digits starting at `start` that must be closed by `]`;
    /// returns the index and the position just past the `]`.
    fn parse_index(path: &str, start: usize) -> VortexResult<(u64, usize)> {
        let mut pos = start;
        while path
            .as_bytes()
            .get(pos)
            .is_some_and(|byte| byte.is_ascii_digit())
        {
            pos += 1;
        }
        vortex_ensure!(
            pos > start,
            "Invalid Variant path {path:?}: expected list index"
        );
        vortex_ensure!(
            path.as_bytes().get(pos) == Some(&b']'),
            "Invalid Variant path {path:?}: expected closing ']'"
        );
        // Only digits are present, so the sole failure mode is overflowing `u64`.
        let index = path[start..pos]
            .parse()
            .map_err(|_| vortex_err!("Invalid Variant path {path:?}: list index is too large"))?;
        Ok((index, pos + 1))
    }

    /// Applies `variant_get(root(), <parsed path>, dtype)` to `array` and executes the result
    /// with a fresh execution context from `array_session()`.
    fn execute_variant_get(
        array: ArrayRef,
        path: &str,
        dtype: Option<DType>,
    ) -> VortexResult<ArrayRef> {
        let expr = variant_get(root(), parse_path(path)?, dtype);
        array
            .apply(&expr)?
            .execute::<ArrayRef>(&mut array_session().create_execution_ctx())
    }

    // Parsing yields the expected steps, Display round-trips, and malformed paths are rejected.
    #[test]
    fn variant_get_path_parse_and_display() {
        let path = parse_path("$.data[1].a").unwrap();
        assert_eq!(
            path.elements(),
            &[
                VariantPathElement::field("data"),
                VariantPathElement::index(1),
                VariantPathElement::field("a")
            ]
        );
        assert_eq!(path.to_string(), "$.data[1].a");

        let bare_path = parse_path("data[2]").unwrap();
        assert_eq!(bare_path.to_string(), "$.data[2]");
        assert!(parse_path("$.").is_err());
        assert!(parse_path("$data").is_err());
        assert!(parse_path("$.data[-1]").is_err());
    }

    // Without a requested dtype the output is nullable Variant, even for non-nullable input.
    #[test]
    fn variant_get_return_dtype_is_nullable_variant_without_requested_dtype() {
        let expr = variant_get(root(), VariantPath::field("data"), None);
        let dtype = expr
            .return_dtype(&DType::Variant(Nullability::NonNullable))
            .unwrap();

        assert_eq!(dtype, DType::Variant(Nullability::Nullable));
    }

    // A non-nullable requested dtype is reported as its nullable counterpart.
    #[test]
    fn variant_get_return_dtype_makes_requested_dtype_nullable() {
        let requested = DType::Primitive(PType::I64, Nullability::NonNullable);
        let expr = variant_get(root(), VariantPath::field("data"), Some(requested));
        let dtype = expr
            .return_dtype(&DType::Variant(Nullability::NonNullable))
            .unwrap();

        assert_eq!(dtype, DType::Primitive(PType::I64, Nullability::Nullable));
    }

    // Non-Variant input is rejected by `return_dtype`.
    #[test]
    fn variant_get_rejects_non_variant_input() {
        let expr = variant_get(root(), VariantPath::field("data"), None);
        let err = expr
            .return_dtype(&DType::Utf8(Nullability::NonNullable))
            .unwrap_err();

        assert!(err.to_string().contains("VariantGet input must be Variant"));
    }

    // SQL rendering includes the input, the quoted path, and the requested dtype.
    #[test]
    fn variant_get_formats_sql() {
        let expr = variant_get(
            root(),
            parse_path("$.data[1].a").unwrap(),
            Some(DType::Utf8(Nullability::NonNullable)),
        );

        assert_eq!(expr.to_string(), "variant_get($, \"$.data[1].a\", utf8)");
    }

    // Options survive a serialize/deserialize round trip unchanged.
    #[test]
    fn variant_get_options_roundtrip_serialization() {
        let options = VariantGetOptions::new(
            VariantPath::new([
                VariantPathElement::field("data"),
                VariantPathElement::index(1),
                VariantPathElement::field("a"),
            ]),
            Some(DType::Primitive(PType::I32, Nullability::NonNullable)),
        );
        let metadata = VariantGet.serialize(&options).unwrap().unwrap();
        let actual = VariantGet
            .deserialize(&metadata, &VortexSession::empty())
            .unwrap();

        assert_eq!(actual, options);
    }

    // A whole `variant_get` expression survives a protobuf round trip unchanged.
    #[test]
    fn variant_get_expression_roundtrip_serialization() {
        let expr: Expression = variant_get(
            root(),
            parse_path("$.data[1].a").unwrap(),
            Some(DType::Primitive(PType::I32, Nullability::NonNullable)),
        );
        let proto = expr.serialize_proto().unwrap();
        let actual = Expression::from_proto(&proto, &array_session()).unwrap();

        assert_eq!(actual, expr);
    }

    // Row 0 resolves `items[1]` to 20; row 1's list is empty (out of range -> null); row 2's
    // element is the string "wrong", whose cast to i32 fails (-> null).
    #[test]
    fn variant_get_generic_fallback_extracts_field_and_list_index() -> VortexResult<()> {
        let items = Scalar::list(
            DType::Variant(Nullability::NonNullable),
            vec![
                Scalar::variant(Scalar::primitive(10i32, Nullability::NonNullable)),
                Scalar::variant(Scalar::primitive(20i32, Nullability::NonNullable)),
            ],
            Nullability::NonNullable,
        );
        let array = variant_rows([
            Scalar::variant(variant_object([("items", items)])),
            Scalar::variant(variant_object([(
                "items",
                Scalar::list_empty(
                    DType::Variant(Nullability::NonNullable).into(),
                    Nullability::NonNullable,
                ),
            )])),
            Scalar::variant(variant_object([(
                "items",
                Scalar::list(
                    DType::Variant(Nullability::NonNullable),
                    vec![
                        Scalar::variant(Scalar::utf8("x", Nullability::NonNullable)),
                        Scalar::variant(Scalar::utf8("wrong", Nullability::NonNullable)),
                    ],
                    Nullability::NonNullable,
                ),
            )])),
        ])?;

        let result = execute_variant_get(
            array,
            "$.items[1]",
            Some(DType::Primitive(PType::I32, Nullability::NonNullable)),
        )?;
        let mut ctx = array_session().create_execution_ctx();

        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(20i32), None, None]),
            &mut ctx
        );
        Ok(())
    }

    // Chunked Variant input: a missing field (row 1) and a null row (row 3) both yield null.
    #[test]
    fn variant_get_reads_chunked_variant_input() -> VortexResult<()> {
        let array = variant_rows([
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(10i32, Nullability::NonNullable),
            )])),
            Scalar::variant(variant_object([(
                "b",
                Scalar::primitive(20i32, Nullability::NonNullable),
            )])),
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(30i32, Nullability::NonNullable),
            )])),
            Scalar::null(DType::Variant(Nullability::Nullable)),
        ])?;
        assert!(array.is::<Chunked>());

        let result = execute_variant_get(
            array,
            "$.a",
            Some(DType::Primitive(PType::I32, Nullability::NonNullable)),
        )?;
        let mut ctx = array_session().create_execution_ctx();

        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None]),
            &mut ctx
        );
        Ok(())
    }

    // With a typed output the builder path is used, so the result is not a Chunked array.
    #[test]
    fn variant_get_fallback_typed_output_is_contiguous() -> VortexResult<()> {
        let array = variant_rows([
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(10i32, Nullability::NonNullable),
            )])),
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(20i32, Nullability::NonNullable),
            )])),
            Scalar::variant(variant_object([(
                "b",
                Scalar::primitive(30i32, Nullability::NonNullable),
            )])),
        ])?;

        let result = execute_variant_get(
            array,
            "$.a",
            Some(DType::Primitive(PType::I32, Nullability::NonNullable)),
        )?;

        assert!(!result.is::<Chunked>());
        let mut ctx = array_session().create_execution_ctx();
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10i32), Some(20), None]),
            &mut ctx
        );
        Ok(())
    }

    // Variant output distinguishes a Variant null stored at the path (row 2) from a SQL null
    // caused by a null row (row 1) or a missing field (row 3).
    #[test]
    fn variant_get_generic_fallback_preserves_variant_null() -> VortexResult<()> {
        let array = variant_rows([
            Scalar::variant(variant_object([(
                "a",
                Scalar::utf8("ok", Nullability::NonNullable),
            )])),
            Scalar::null(DType::Variant(Nullability::Nullable)),
            Scalar::variant(variant_object([("a", Scalar::null(DType::Null))])),
            Scalar::variant(variant_object([(
                "b",
                Scalar::primitive(2i32, Nullability::NonNullable),
            )])),
        ])?;

        let result = execute_variant_get(array, "$.a", None)?;

        let mut ctx = array_session().create_execution_ctx();
        let row0 = result.execute_scalar(0, &mut ctx)?;
        assert_eq!(
            row0.as_variant()
                .value()
                .and_then(|value| value.as_utf8().value())
                .map(|value| value.as_str()),
            Some("ok")
        );
        assert_nth_scalar_is_null!(result, 1, &mut ctx);
        assert_eq!(
            result
                .execute_scalar(2, &mut ctx)?
                .as_variant()
                .is_variant_null(),
            Some(true)
        );
        assert_nth_scalar_is_null!(result, 3, &mut ctx);
        Ok(())
    }

    // The Variant-output fallback executes to a `VariantArray` and to `Canonical::Variant`, and
    // each row still holds the extracted i32 payload.
    #[test]
    fn variant_get_fallback_variant_output_canonicalizes() -> VortexResult<()> {
        let array = variant_rows([
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(10i32, Nullability::NonNullable),
            )])),
            Scalar::variant(variant_object([(
                "a",
                Scalar::primitive(20i32, Nullability::NonNullable),
            )])),
        ])?;

        let result = execute_variant_get(array, "$.a", None)?;
        let variant = result
            .clone()
            .execute::<VariantArray>(&mut array_session().create_execution_ctx())?;
        let canonical = result.execute::<Canonical>(&mut array_session().create_execution_ctx())?;
        let Canonical::Variant(canonical_variant) = canonical else {
            vortex_bail!("expected Variant canonical array");
        };

        assert_eq!(variant.len(), 2);
        assert_eq!(canonical_variant.len(), 2);
        assert_eq!(variant.core_storage().dtype(), variant.dtype());
        assert_eq!(variant.core_storage().len(), variant.len());

        let mut ctx = array_session().create_execution_ctx();
        for (idx, expected) in [10i32, 20].into_iter().enumerate() {
            let scalar = variant.execute_scalar(idx, &mut ctx)?;
            let actual = scalar
                .as_variant()
                .value()
                .and_then(|value| value.as_primitive().as_::<i32>());
            assert_eq!(actual, Some(expected), "row {idx}");
        }
        Ok(())
    }
}