vortex_vector/decimal/
generic.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Definition and implementation of [`DVector<D>`].
5
6use std::fmt::Debug;
7use std::ops::BitAnd;
8use std::ops::RangeBounds;
9
10use vortex_buffer::Buffer;
11use vortex_dtype::NativeDecimalType;
12use vortex_dtype::PrecisionScale;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_mask::Mask;
17
18use crate::VectorOps;
19use crate::decimal::DScalar;
20use crate::decimal::DVectorMut;
21
22/// An immutable vector of generic decimal values.
23///
24/// `D` is bound by [`NativeDecimalType`], which can be one of the native integer types (`i8`,
25/// `i16`, `i32`, `i64`, `i128`) or `i256`. `D` is used to store the decimal values.
26///
27/// The decimal vector maintains a [`PrecisionScale<D>`] that defines the precision (total number of
28/// digits) and scale (digits after the decimal point) for all values in the vector.
29#[derive(Debug, Clone)]
30pub struct DVector<D> {
31    /// The precision and scale of each decimal in the decimal vector.
32    pub(super) ps: PrecisionScale<D>,
33    /// The buffer representing the vector decimal elements.
34    pub(super) elements: Buffer<D>,
35    /// The validity mask (where `true` represents an element is **not** null).
36    pub(super) validity: Mask,
37}
38
39impl<D: NativeDecimalType + PartialEq> PartialEq for DVector<D> {
40    fn eq(&self, other: &Self) -> bool {
41        if self.elements.len() != other.elements.len() {
42            return false;
43        }
44        // Precision and scale must match
45        if self.ps != other.ps {
46            return false;
47        }
48        // Validity patterns must match
49        if self.validity != other.validity {
50            return false;
51        }
52        // Compare all elements, OR with !validity to ignore invalid positions
53        self.elements
54            .iter()
55            .zip(other.elements.iter())
56            .enumerate()
57            .all(|(i, (a, b))| !self.validity.value(i) | (a == b))
58    }
59}
60
61impl<D: NativeDecimalType + Eq> Eq for DVector<D> {}
62
63impl<D: NativeDecimalType> DVector<D> {
64    /// Creates a new [`DVector<D>`] from the given [`PrecisionScale`], elements buffer, and
65    /// validity mask.
66    ///
67    /// # Panics
68    ///
69    /// Panics if:
70    ///
71    /// - The lengths of the `elements` and `validity` do not match.
72    /// - Any of the elements are out of bounds for the given [`PrecisionScale`].
73    pub fn new(ps: PrecisionScale<D>, elements: Buffer<D>, validity: Mask) -> Self {
74        Self::try_new(ps, elements, validity).vortex_expect("Failed to create `DVector`")
75    }
76
77    /// Tries to create a new [`DVector<D>`] from the given [`PrecisionScale`], elements buffer, and
78    /// validity mask.
79    ///
80    /// # Errors
81    ///
82    /// Returns an error if:
83    ///
84    /// - The lengths of the `elements` and `validity` do not match.
85    /// - Any of the elements are out of bounds for the given [`PrecisionScale`].
86    pub fn try_new(
87        ps: PrecisionScale<D>,
88        elements: Buffer<D>,
89        validity: Mask,
90    ) -> VortexResult<Self> {
91        if elements.len() != validity.len() {
92            vortex_bail!(
93                "Elements length {} does not match validity length {}",
94                elements.len(),
95                validity.len()
96            );
97        }
98
99        // TODO(0ax1): iteration based on mask density via threshold_iter
100
101        // We assert that each non-null element is within bounds for the given precision/scale.
102        for (element, is_valid) in elements.iter().zip(validity.to_bit_buffer().iter()) {
103            if is_valid && !ps.is_valid(*element) {
104                vortex_bail!(
105                    "One or more elements (e.g. {element}) are out of bounds for precision {} and scale {}",
106                    ps.precision(),
107                    ps.scale(),
108                );
109            }
110        }
111
112        Ok(Self {
113            ps,
114            elements,
115            validity,
116        })
117    }
118
119    /// Creates a new [`DVector<D>`] from the given [`PrecisionScale`], elements buffer, and
120    /// validity mask, _without_ validation.
121    ///
122    /// # Safety
123    ///
124    /// The caller must ensure:
125    ///
126    /// - The lengths of the elements and validity are equal.
127    /// - All elements are in bounds for the given [`PrecisionScale`].
128    pub unsafe fn new_unchecked(
129        ps: PrecisionScale<D>,
130        elements: Buffer<D>,
131        validity: Mask,
132    ) -> Self {
133        if cfg!(debug_assertions) {
134            Self::try_new(ps, elements, validity).vortex_expect("Failed to create `DVector`")
135        } else {
136            Self {
137                ps,
138                elements,
139                validity,
140            }
141        }
142    }
143
144    /// Decomposes the decimal vector into its constituent parts ([`PrecisionScale`], decimal
145    /// buffer, and validity).
146    pub fn into_parts(self) -> (PrecisionScale<D>, Buffer<D>, Mask) {
147        (self.ps, self.elements, self.validity)
148    }
149
150    /// Get the precision/scale of the decimal vector.
151    pub fn precision_scale(&self) -> PrecisionScale<D> {
152        self.ps
153    }
154
155    /// Returns the precision of the decimal vector.
156    pub fn precision(&self) -> u8 {
157        self.ps.precision()
158    }
159
160    /// Returns the scale of the decimal vector.
161    pub fn scale(&self) -> i8 {
162        self.ps.scale()
163    }
164
165    /// Returns a reference to the underlying elements buffer containing the decimal data.
166    pub fn elements(&self) -> &Buffer<D> {
167        &self.elements
168    }
169
170    /// Gets a nullable element at the given index, panicking on out-of-bounds.
171    ///
172    /// If the element at the given index is null, returns `None`. Otherwise, returns `Some(x)`,
173    /// where `x: D`.
174    ///
175    /// Note that this `get` method is different from the standard library [`slice::get`], which
176    /// returns `None` if the index is out of bounds. This method will panic if the index is out of
177    /// bounds, and return `None` if the elements is null.
178    ///
179    /// # Panics
180    ///
181    /// Panics if the index is out of bounds.
182    pub fn get(&self, index: usize) -> Option<&D> {
183        self.validity.value(index).then(|| &self.elements[index])
184    }
185}
186
187impl<D: NativeDecimalType> AsRef<[D]> for DVector<D> {
188    fn as_ref(&self) -> &[D] {
189        &self.elements
190    }
191}
192
193impl<D: NativeDecimalType> VectorOps for DVector<D> {
194    type Mutable = DVectorMut<D>;
195    type Scalar = DScalar<D>;
196
197    fn len(&self) -> usize {
198        self.elements.len()
199    }
200
201    fn validity(&self) -> &Mask {
202        &self.validity
203    }
204
205    fn mask_validity(&mut self, mask: &Mask) {
206        self.validity = self.validity.bitand(mask);
207    }
208
209    fn scalar_at(&self, index: usize) -> DScalar<D> {
210        assert!(index < self.len());
211
212        let is_valid = self.validity.value(index);
213        let value = is_valid.then(|| self.elements[index]);
214
215        // SAFETY: We have already checked the validity on construction of the vector
216        unsafe { DScalar::<D>::new_unchecked(self.ps, value) }
217    }
218
219    fn slice(&self, range: impl RangeBounds<usize> + Clone + Debug) -> Self {
220        let elements = self.elements.slice(range.clone());
221        let validity = self.validity.slice(range);
222        Self {
223            ps: self.ps,
224            elements,
225            validity,
226        }
227    }
228
229    fn clear(&mut self) {
230        self.elements.clear();
231        self.validity.clear();
232    }
233
234    fn try_into_mut(self) -> Result<DVectorMut<D>, Self> {
235        let elements = match self.elements.try_into_mut() {
236            Ok(elements) => elements,
237            Err(elements) => {
238                return Err(Self {
239                    ps: self.ps,
240                    elements,
241                    validity: self.validity,
242                });
243            }
244        };
245
246        match self.validity.try_into_mut() {
247            Ok(validity_mut) => Ok(DVectorMut {
248                ps: self.ps,
249                elements,
250                validity: validity_mut,
251            }),
252            Err(validity) => Err(Self {
253                ps: self.ps,
254                elements: elements.freeze(),
255                validity,
256            }),
257        }
258    }
259
260    fn into_mut(self) -> DVectorMut<D> {
261        DVectorMut {
262            ps: self.ps,
263            elements: self.elements.into_mut(),
264            validity: self.validity.into_mut(),
265        }
266    }
267}