Skip to main content

vortex_array/arrays/list/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5use std::sync::Arc;
6
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_panic;
12use vortex_session::VortexSession;
13
14use crate::ArrayRef;
15use crate::DynArray;
16use crate::ExecutionCtx;
17use crate::ExecutionResult;
18use crate::IntoArray;
19use crate::Precision;
20use crate::ProstMetadata;
21use crate::arrays::ListArray;
22use crate::arrays::list::compute::PARENT_KERNELS;
23use crate::arrays::list::compute::rules::PARENT_RULES;
24use crate::arrays::listview::list_view_from_list;
25use crate::buffer::BufferHandle;
26use crate::dtype::DType;
27use crate::dtype::Nullability;
28use crate::dtype::PType;
29use crate::hash::ArrayEq;
30use crate::hash::ArrayHash;
31use crate::metadata::DeserializeMetadata;
32use crate::metadata::SerializeMetadata;
33use crate::serde::ArrayChildren;
34use crate::stats::StatsSetRef;
35use crate::validity::Validity;
36use crate::vtable;
37use crate::vtable::Array;
38use crate::vtable::ArrayId;
39use crate::vtable::VTable;
40use crate::vtable::ValidityVTableFromValidityHelper;
41use crate::vtable::validity_nchildren;
42use crate::vtable::validity_to_child;
43mod operations;
44mod validity;
45vtable!(List);
46
47#[derive(Clone, prost::Message)]
48pub struct ListMetadata {
49    #[prost(uint64, tag = "1")]
50    elements_len: u64,
51    #[prost(enumeration = "PType", tag = "2")]
52    offset_ptype: i32,
53}
54
55impl VTable for List {
56    type Array = ListArray;
57
58    type Metadata = ProstMetadata<ListMetadata>;
59    type OperationsVTable = Self;
60    type ValidityVTable = ValidityVTableFromValidityHelper;
61    fn vtable(_array: &Self::Array) -> &Self {
62        &List
63    }
64
65    fn id(&self) -> ArrayId {
66        Self::ID
67    }
68
69    fn len(array: &ListArray) -> usize {
70        array.offsets.len().saturating_sub(1)
71    }
72
73    fn dtype(array: &ListArray) -> &DType {
74        &array.dtype
75    }
76
77    fn stats(array: &ListArray) -> StatsSetRef<'_> {
78        array.stats_set.to_ref(array.as_ref())
79    }
80
81    fn array_hash<H: std::hash::Hasher>(array: &ListArray, state: &mut H, precision: Precision) {
82        array.dtype.hash(state);
83        array.elements.array_hash(state, precision);
84        array.offsets.array_hash(state, precision);
85        array.validity.array_hash(state, precision);
86    }
87
88    fn array_eq(array: &ListArray, other: &ListArray, precision: Precision) -> bool {
89        array.dtype == other.dtype
90            && array.elements.array_eq(&other.elements, precision)
91            && array.offsets.array_eq(&other.offsets, precision)
92            && array.validity.array_eq(&other.validity, precision)
93    }
94
95    fn nbuffers(_array: &ListArray) -> usize {
96        0
97    }
98
99    fn buffer(_array: &ListArray, idx: usize) -> BufferHandle {
100        vortex_panic!("ListArray buffer index {idx} out of bounds")
101    }
102
103    fn buffer_name(_array: &ListArray, idx: usize) -> Option<String> {
104        vortex_panic!("ListArray buffer_name index {idx} out of bounds")
105    }
106
107    fn nchildren(array: &ListArray) -> usize {
108        2 + validity_nchildren(&array.validity)
109    }
110
111    fn child(array: &ListArray, idx: usize) -> ArrayRef {
112        match idx {
113            0 => array.elements().clone(),
114            1 => array.offsets().clone(),
115            2 => validity_to_child(&array.validity, array.len())
116                .vortex_expect("ListArray validity child out of bounds"),
117            _ => vortex_panic!("ListArray child index {idx} out of bounds"),
118        }
119    }
120
121    fn child_name(_array: &ListArray, idx: usize) -> String {
122        match idx {
123            0 => "elements".to_string(),
124            1 => "offsets".to_string(),
125            2 => "validity".to_string(),
126            _ => vortex_panic!("ListArray child_name index {idx} out of bounds"),
127        }
128    }
129
130    fn reduce_parent(
131        array: &Array<Self>,
132        parent: &ArrayRef,
133        child_idx: usize,
134    ) -> VortexResult<Option<ArrayRef>> {
135        PARENT_RULES.evaluate(array, parent, child_idx)
136    }
137
138    fn metadata(array: &ListArray) -> VortexResult<Self::Metadata> {
139        Ok(ProstMetadata(ListMetadata {
140            elements_len: array.elements().len() as u64,
141            offset_ptype: PType::try_from(array.offsets().dtype())? as i32,
142        }))
143    }
144
145    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
146        Ok(Some(SerializeMetadata::serialize(metadata)))
147    }
148
149    fn deserialize(
150        bytes: &[u8],
151        _dtype: &DType,
152        _len: usize,
153        _buffers: &[BufferHandle],
154        _session: &VortexSession,
155    ) -> VortexResult<Self::Metadata> {
156        Ok(ProstMetadata(
157            <ProstMetadata<ListMetadata> as DeserializeMetadata>::deserialize(bytes)?,
158        ))
159    }
160
161    fn build(
162        dtype: &DType,
163        len: usize,
164        metadata: &Self::Metadata,
165        _buffers: &[BufferHandle],
166        children: &dyn ArrayChildren,
167    ) -> VortexResult<ListArray> {
168        let validity = if children.len() == 2 {
169            Validity::from(dtype.nullability())
170        } else if children.len() == 3 {
171            let validity = children.get(2, &Validity::DTYPE, len)?;
172            Validity::Array(validity)
173        } else {
174            vortex_bail!("Expected 2 or 3 children, got {}", children.len());
175        };
176
177        let DType::List(element_dtype, _) = &dtype else {
178            vortex_bail!("Expected List dtype, got {:?}", dtype);
179        };
180        let elements = children.get(
181            0,
182            element_dtype.as_ref(),
183            usize::try_from(metadata.0.elements_len)?,
184        )?;
185
186        let offsets = children.get(
187            1,
188            &DType::Primitive(metadata.0.offset_ptype(), Nullability::NonNullable),
189            len + 1,
190        )?;
191
192        ListArray::try_new(elements, offsets, validity)
193    }
194
195    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
196        vortex_ensure!(
197            children.len() == 2 || children.len() == 3,
198            "ListArray expects 2 or 3 children, got {}",
199            children.len()
200        );
201
202        let mut iter = children.into_iter();
203        let elements = iter
204            .next()
205            .vortex_expect("children length already validated");
206        let offsets = iter
207            .next()
208            .vortex_expect("children length already validated");
209        let validity = if let Some(validity_array) = iter.next() {
210            Validity::Array(validity_array)
211        } else {
212            Validity::from(array.dtype.nullability())
213        };
214
215        let new_array = ListArray::try_new(elements, offsets, validity)?;
216        *array = new_array;
217        Ok(())
218    }
219
220    fn execute(array: Arc<Array<Self>>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
221        Ok(ExecutionResult::done(
222            list_view_from_list(ListArray::clone(&array), ctx)?.into_array(),
223        ))
224    }
225
226    fn execute_parent(
227        array: &Array<Self>,
228        parent: &ArrayRef,
229        child_idx: usize,
230        ctx: &mut ExecutionCtx,
231    ) -> VortexResult<Option<ArrayRef>> {
232        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
233    }
234}
235
236#[derive(Clone, Debug)]
237pub struct List;
238
239impl List {
240    pub const ID: ArrayId = ArrayId::new_ref("vortex.list");
241}