vortex_array/arrays/constant/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::ByteBufferMut;
5use vortex_dtype::DType;
6use vortex_error::VortexResult;
7use vortex_mask::Mask;
8use vortex_scalar::Scalar;
9
10use crate::stats::{ArrayStats, StatsSet, StatsSetRef};
11use crate::vtable::{
12    ArrayVTable, NotSupported, OperationsVTable, VTable, ValidityVTable, VisitorVTable,
13};
14use crate::{
15    ArrayBufferVisitor, ArrayChildVisitor, ArrayRef, EncodingId, EncodingRef, IntoArray, vtable,
16};
17
18mod canonical;
19mod compute;
20mod encode;
21mod serde;
22
23vtable!(Constant);
24
25#[derive(Clone, Debug)]
26pub struct ConstantArray {
27    scalar: Scalar,
28    len: usize,
29    stats_set: ArrayStats,
30}
31
/// Stateless marker type identifying the constant encoding.
#[derive(Debug, Clone)]
pub struct ConstantEncoding;
34
35impl VTable for ConstantVTable {
36    type Array = ConstantArray;
37    type Encoding = ConstantEncoding;
38
39    type ArrayVTable = Self;
40    type CanonicalVTable = Self;
41    type OperationsVTable = Self;
42    type ValidityVTable = Self;
43    type VisitorVTable = Self;
44    // TODO(ngates): implement a compute kernel for elementwise operations
45    type ComputeVTable = NotSupported;
46    type EncodeVTable = Self;
47    type SerdeVTable = Self;
48
49    fn id(_encoding: &Self::Encoding) -> EncodingId {
50        EncodingId::new_ref("vortex.constant")
51    }
52
53    fn encoding(_array: &Self::Array) -> EncodingRef {
54        EncodingRef::new_ref(ConstantEncoding.as_ref())
55    }
56}
57
58impl ConstantArray {
59    pub fn new<S>(scalar: S, len: usize) -> Self
60    where
61        S: Into<Scalar>,
62    {
63        let scalar = scalar.into();
64        let stats = StatsSet::constant(scalar.clone(), len);
65        Self {
66            scalar,
67            len,
68            stats_set: ArrayStats::from(stats),
69        }
70    }
71
72    /// Returns the [`Scalar`] value of this constant array.
73    pub fn scalar(&self) -> &Scalar {
74        &self.scalar
75    }
76}
77
78impl ArrayVTable<ConstantVTable> for ConstantVTable {
79    fn len(array: &ConstantArray) -> usize {
80        array.len
81    }
82
83    fn dtype(array: &ConstantArray) -> &DType {
84        array.scalar.dtype()
85    }
86
87    fn stats(array: &ConstantArray) -> StatsSetRef<'_> {
88        array.stats_set.to_ref(array.as_ref())
89    }
90}
91
92impl OperationsVTable<ConstantVTable> for ConstantVTable {
93    fn slice(array: &ConstantArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
94        Ok(ConstantArray::new(array.scalar.clone(), stop - start).into_array())
95    }
96
97    fn scalar_at(array: &ConstantArray, _index: usize) -> VortexResult<Scalar> {
98        Ok(array.scalar.clone())
99    }
100}
101
102impl ValidityVTable<ConstantVTable> for ConstantVTable {
103    fn is_valid(array: &ConstantArray, _index: usize) -> VortexResult<bool> {
104        Ok(!array.scalar().is_null())
105    }
106
107    fn all_valid(array: &ConstantArray) -> VortexResult<bool> {
108        Ok(!array.scalar().is_null())
109    }
110
111    fn all_invalid(array: &ConstantArray) -> VortexResult<bool> {
112        Ok(array.scalar().is_null())
113    }
114
115    fn validity_mask(array: &ConstantArray) -> VortexResult<Mask> {
116        Ok(match array.scalar().is_null() {
117            true => Mask::AllFalse(array.len()),
118            false => Mask::AllTrue(array.len()),
119        })
120    }
121}
122
123impl VisitorVTable<ConstantVTable> for ConstantVTable {
124    fn visit_buffers(array: &ConstantArray, visitor: &mut dyn ArrayBufferVisitor) {
125        let buffer = array
126            .scalar
127            .value()
128            .to_protobytes::<ByteBufferMut>()
129            .freeze();
130        visitor.visit_buffer(&buffer);
131    }
132
133    fn visit_children(_array: &ConstantArray, _visitor: &mut dyn ArrayChildVisitor) {}
134}