vortex_array/arrays/extension/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use vortex_dtype::{DType, ExtDType, ExtID};
7use vortex_error::VortexResult;
8use vortex_scalar::Scalar;
9
10use crate::stats::{ArrayStats, StatsSetRef};
11use crate::vtable::{
12    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityChild,
13    ValidityVTableFromChild, VisitorVTable,
14};
15use crate::{
16    Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayRef, Canonical, EncodingId, EncodingRef,
17    IntoArray, vtable,
18};
19
20mod compute;
21mod serde;
22
23vtable!(Extension);
24
25impl VTable for ExtensionVTable {
26    type Array = ExtensionArray;
27    type Encoding = ExtensionEncoding;
28
29    type ArrayVTable = Self;
30    type CanonicalVTable = Self;
31    type OperationsVTable = Self;
32    type ValidityVTable = ValidityVTableFromChild;
33    type VisitorVTable = Self;
34    type ComputeVTable = NotSupported;
35    type EncodeVTable = NotSupported;
36    type SerdeVTable = Self;
37
38    fn id(_encoding: &Self::Encoding) -> EncodingId {
39        EncodingId::new_ref("vortex.ext")
40    }
41
42    fn encoding(_array: &Self::Array) -> EncodingRef {
43        EncodingRef::new_ref(ExtensionEncoding.as_ref())
44    }
45}
46
/// Field-less marker type identifying the extension encoding.
#[derive(Debug, Clone)]
pub struct ExtensionEncoding;
49
50#[derive(Clone, Debug)]
51pub struct ExtensionArray {
52    dtype: DType,
53    storage: ArrayRef,
54    stats_set: ArrayStats,
55}
56
57impl ExtensionArray {
58    pub fn new(ext_dtype: Arc<ExtDType>, storage: ArrayRef) -> Self {
59        assert_eq!(
60            ext_dtype.storage_dtype(),
61            storage.dtype(),
62            "ExtensionArray: storage_dtype must match storage array DType",
63        );
64        Self {
65            dtype: DType::Extension(ext_dtype),
66            storage,
67            stats_set: ArrayStats::default(),
68        }
69    }
70
71    pub fn ext_dtype(&self) -> &Arc<ExtDType> {
72        let DType::Extension(ext) = &self.dtype else {
73            unreachable!("ExtensionArray: dtype must be an ExtDType")
74        };
75        ext
76    }
77
78    pub fn storage(&self) -> &ArrayRef {
79        &self.storage
80    }
81
82    #[allow(dead_code)]
83    #[inline]
84    pub fn id(&self) -> &ExtID {
85        self.ext_dtype().id()
86    }
87}
88
89impl ArrayVTable<ExtensionVTable> for ExtensionVTable {
90    fn len(array: &ExtensionArray) -> usize {
91        array.storage.len()
92    }
93
94    fn dtype(array: &ExtensionArray) -> &DType {
95        &array.dtype
96    }
97
98    fn stats(array: &ExtensionArray) -> StatsSetRef<'_> {
99        array.stats_set.to_ref(array.as_ref())
100    }
101}
102
103impl ValidityChild<ExtensionVTable> for ExtensionVTable {
104    fn validity_child(array: &ExtensionArray) -> &dyn Array {
105        array.storage.as_ref()
106    }
107}
108
109impl CanonicalVTable<ExtensionVTable> for ExtensionVTable {
110    fn canonicalize(array: &ExtensionArray) -> VortexResult<Canonical> {
111        Ok(Canonical::Extension(array.clone()))
112    }
113}
114
115impl OperationsVTable<ExtensionVTable> for ExtensionVTable {
116    fn slice(array: &ExtensionArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
117        Ok(ExtensionArray::new(
118            array.ext_dtype().clone(),
119            array.storage().slice(start, stop)?,
120        )
121        .into_array())
122    }
123
124    fn scalar_at(array: &ExtensionArray, index: usize) -> VortexResult<Scalar> {
125        Ok(Scalar::extension(
126            array.ext_dtype().clone(),
127            array.storage().scalar_at(index)?,
128        ))
129    }
130}
131
132impl VisitorVTable<ExtensionVTable> for ExtensionVTable {
133    fn visit_buffers(_array: &ExtensionArray, _visitor: &mut dyn ArrayBufferVisitor) {}
134
135    fn visit_children(array: &ExtensionArray, visitor: &mut dyn ArrayChildVisitor) {
136        visitor.visit_child("storage", array.storage.as_ref());
137    }
138}