vortex_array/arrays/constant/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Range;
5
6pub use compute::ConstantOperator;
7use vortex_buffer::ByteBufferMut;
8use vortex_dtype::DType;
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use crate::stats::{ArrayStats, StatsSetRef};
13use crate::vtable::{
14    ArrayVTable, NotSupported, OperationsVTable, VTable, ValidityVTable, VisitorVTable,
15};
16use crate::{
17    ArrayBufferVisitor, ArrayChildVisitor, ArrayRef, EncodingId, EncodingRef, IntoArray, vtable,
18};
19
20mod canonical;
21mod compute;
22mod encode;
23mod serde;
24
25vtable!(Constant);
26
27#[derive(Clone, Debug)]
28pub struct ConstantArray {
29    scalar: Scalar,
30    len: usize,
31    stats_set: ArrayStats,
32}
33
/// Unit type naming the constant encoding; used as `VTable::Encoding` for
/// `ConstantVTable` and registered under the id `vortex.constant`.
#[derive(Debug, Clone)]
pub struct ConstantEncoding;
36
37impl VTable for ConstantVTable {
38    type Array = ConstantArray;
39    type Encoding = ConstantEncoding;
40
41    type ArrayVTable = Self;
42    type CanonicalVTable = Self;
43    type OperationsVTable = Self;
44    type ValidityVTable = Self;
45    type VisitorVTable = Self;
46    // TODO(ngates): implement a compute kernel for elementwise operations
47    type ComputeVTable = NotSupported;
48    type EncodeVTable = Self;
49    type PipelineVTable = Self;
50    type SerdeVTable = Self;
51
52    fn id(_encoding: &Self::Encoding) -> EncodingId {
53        EncodingId::new_ref("vortex.constant")
54    }
55
56    fn encoding(_array: &Self::Array) -> EncodingRef {
57        EncodingRef::new_ref(ConstantEncoding.as_ref())
58    }
59}
60
61impl ConstantArray {
62    pub fn new<S>(scalar: S, len: usize) -> Self
63    where
64        S: Into<Scalar>,
65    {
66        let scalar = scalar.into();
67        Self {
68            scalar,
69            len,
70            stats_set: Default::default(),
71        }
72    }
73
74    /// Returns the [`Scalar`] value of this constant array.
75    pub fn scalar(&self) -> &Scalar {
76        &self.scalar
77    }
78}
79
80impl ArrayVTable<ConstantVTable> for ConstantVTable {
81    fn len(array: &ConstantArray) -> usize {
82        array.len
83    }
84
85    fn dtype(array: &ConstantArray) -> &DType {
86        array.scalar.dtype()
87    }
88
89    fn stats(array: &ConstantArray) -> StatsSetRef<'_> {
90        array.stats_set.to_ref(array.as_ref())
91    }
92}
93
94impl OperationsVTable<ConstantVTable> for ConstantVTable {
95    fn slice(array: &ConstantArray, range: Range<usize>) -> ArrayRef {
96        ConstantArray::new(array.scalar.clone(), range.len()).into_array()
97    }
98
99    fn scalar_at(array: &ConstantArray, _index: usize) -> Scalar {
100        array.scalar.clone()
101    }
102}
103
104impl ValidityVTable<ConstantVTable> for ConstantVTable {
105    fn is_valid(array: &ConstantArray, _index: usize) -> bool {
106        !array.scalar().is_null()
107    }
108
109    fn all_valid(array: &ConstantArray) -> bool {
110        !array.scalar().is_null()
111    }
112
113    fn all_invalid(array: &ConstantArray) -> bool {
114        array.scalar().is_null()
115    }
116
117    fn validity_mask(array: &ConstantArray) -> Mask {
118        match array.scalar().is_null() {
119            true => Mask::AllFalse(array.len()),
120            false => Mask::AllTrue(array.len()),
121        }
122    }
123}
124
125impl VisitorVTable<ConstantVTable> for ConstantVTable {
126    fn visit_buffers(array: &ConstantArray, visitor: &mut dyn ArrayBufferVisitor) {
127        let buffer = array
128            .scalar
129            .value()
130            .to_protobytes::<ByteBufferMut>()
131            .freeze();
132        visitor.visit_buffer(&buffer);
133    }
134
135    fn visit_children(_array: &ConstantArray, _visitor: &mut dyn ArrayChildVisitor) {}
136}