// vortex_compute/arrow/decimal.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use std::sync::Arc;
5
6use arrow_array::Array;
7use arrow_array::ArrayRef;
8use arrow_array::PrimitiveArray;
9use arrow_array::types::Decimal32Type;
10use arrow_array::types::Decimal64Type;
11use arrow_array::types::Decimal128Type;
12use arrow_array::types::Decimal256Type;
13use vortex_buffer::Buffer;
14use vortex_dtype::PrecisionScale;
15use vortex_dtype::i256;
16use vortex_error::VortexResult;
17use vortex_vector::decimal::DVector;
18use vortex_vector::decimal::DecimalVector;
19
20use crate::arrow::IntoArrow;
21use crate::arrow::IntoVector;
22use crate::arrow::nulls_to_mask;
23
24impl IntoArrow for DecimalVector {
25    type Output = ArrayRef;
26
27    fn into_arrow(self) -> VortexResult<Self::Output> {
28        match self {
29            DecimalVector::D8(v) => Ok(Arc::new(v.into_arrow()?)),
30            DecimalVector::D16(v) => Ok(Arc::new(v.into_arrow()?)),
31            DecimalVector::D32(v) => Ok(Arc::new(v.into_arrow()?)),
32            DecimalVector::D64(v) => Ok(Arc::new(v.into_arrow()?)),
33            DecimalVector::D128(v) => Ok(Arc::new(v.into_arrow()?)),
34            DecimalVector::D256(v) => Ok(Arc::new(v.into_arrow()?)),
35        }
36    }
37}
38
39macro_rules! impl_decimal_upcast_i32 {
40    ($T:ty) => {
41        impl IntoArrow for DVector<$T> {
42            type Output = PrimitiveArray<Decimal32Type>;
43
44            fn into_arrow(self) -> VortexResult<Self::Output> {
45                let (_, elements, validity) = self.into_parts();
46                // Upcast the DVector to Arrow's smallest decimal type (Decimal32)
47                let elements =
48                    Buffer::<i32>::from_trusted_len_iter(elements.iter().map(|i| *i as i32));
49                Ok(PrimitiveArray::<Decimal32Type>::new(
50                    elements.into_arrow_scalar_buffer(),
51                    validity.into(),
52                ))
53            }
54        }
55    };
56}
57
58impl_decimal_upcast_i32!(i8);
59impl_decimal_upcast_i32!(i16);
60
61/// Direct Arrow conversion for vectors that map directly to Arrow decimal types.
62macro_rules! impl_decimal_to_arrow {
63    ($T:ty, $A:ty) => {
64        impl IntoArrow for DVector<$T> {
65            type Output = PrimitiveArray<$A>;
66
67            fn into_arrow(self) -> VortexResult<Self::Output> {
68                let (_, elements, validity) = self.into_parts();
69                Ok(PrimitiveArray::<$A>::new(
70                    elements.into_arrow_scalar_buffer(),
71                    validity.into(),
72                ))
73            }
74        }
75    };
76}
77
78impl_decimal_to_arrow!(i32, Decimal32Type);
79impl_decimal_to_arrow!(i64, Decimal64Type);
80impl_decimal_to_arrow!(i128, Decimal128Type);
81
82impl IntoArrow for DVector<i256> {
83    type Output = PrimitiveArray<Decimal256Type>;
84
85    fn into_arrow(self) -> VortexResult<Self::Output> {
86        let (_, elements, validity) = self.into_parts();
87
88        // Transmute the elements from our i256 to Arrow's.
89        // SAFETY: we use Arrow's type internally for our layout.
90        let elements =
91            unsafe { std::mem::transmute::<Buffer<i256>, Buffer<arrow_buffer::i256>>(elements) };
92
93        Ok(PrimitiveArray::<Decimal256Type>::new(
94            elements.into_arrow_scalar_buffer(),
95            validity.into(),
96        ))
97    }
98}
99
100/// Convert a Decimal32 Arrow array to a DecimalVector.
101impl IntoVector for &PrimitiveArray<Decimal32Type> {
102    type Output = DecimalVector;
103
104    fn into_vector(self) -> VortexResult<Self::Output> {
105        let (precision, scale) = match self.data_type() {
106            arrow_schema::DataType::Decimal32(p, s) => (*p, *s),
107            _ => unreachable!("PrimitiveArray<Decimal32Type> must have Decimal32 data type"),
108        };
109
110        let elements = Buffer::<i32>::from_arrow_scalar_buffer(self.values().clone());
111        let validity = nulls_to_mask(self.nulls(), self.len());
112        let ps = PrecisionScale::<i32>::new(precision, scale);
113
114        Ok(DecimalVector::D32(DVector::new(ps, elements, validity)))
115    }
116}
117
118/// Convert a Decimal64 Arrow array to a DecimalVector.
119impl IntoVector for &PrimitiveArray<Decimal64Type> {
120    type Output = DecimalVector;
121
122    fn into_vector(self) -> VortexResult<Self::Output> {
123        let (precision, scale) = match self.data_type() {
124            arrow_schema::DataType::Decimal64(p, s) => (*p, *s),
125            _ => unreachable!("PrimitiveArray<Decimal64Type> must have Decimal64 data type"),
126        };
127
128        let elements = Buffer::<i64>::from_arrow_scalar_buffer(self.values().clone());
129        let validity = nulls_to_mask(self.nulls(), self.len());
130        let ps = PrecisionScale::<i64>::new(precision, scale);
131
132        Ok(DecimalVector::D64(DVector::new(ps, elements, validity)))
133    }
134}
135
136/// Convert a Decimal128 Arrow array to a DecimalVector.
137impl IntoVector for &PrimitiveArray<Decimal128Type> {
138    type Output = DecimalVector;
139
140    fn into_vector(self) -> VortexResult<Self::Output> {
141        let (precision, scale) = match self.data_type() {
142            arrow_schema::DataType::Decimal128(p, s) => (*p, *s),
143            _ => unreachable!("PrimitiveArray<Decimal128Type> must have Decimal128 data type"),
144        };
145
146        let elements = Buffer::<i128>::from_arrow_scalar_buffer(self.values().clone());
147        let validity = nulls_to_mask(self.nulls(), self.len());
148        let ps = PrecisionScale::<i128>::new(precision, scale);
149
150        Ok(DecimalVector::D128(DVector::new(ps, elements, validity)))
151    }
152}
153
154/// Convert a Decimal256 Arrow array to a DecimalVector.
155impl IntoVector for &PrimitiveArray<Decimal256Type> {
156    type Output = DecimalVector;
157
158    fn into_vector(self) -> VortexResult<Self::Output> {
159        let (precision, scale) = match self.data_type() {
160            arrow_schema::DataType::Decimal256(p, s) => (*p, *s),
161            _ => unreachable!("PrimitiveArray<Decimal256Type> must have Decimal256 data type"),
162        };
163
164        let elements =
165            Buffer::<arrow_buffer::i256>::from_arrow_scalar_buffer(self.values().clone());
166        // SAFETY: we use Arrow's type internally for our layout.
167        let elements =
168            unsafe { std::mem::transmute::<Buffer<arrow_buffer::i256>, Buffer<i256>>(elements) };
169        let validity = nulls_to_mask(self.nulls(), self.len());
170        let ps = PrecisionScale::<i256>::new(precision, scale);
171
172        Ok(DecimalVector::D256(DVector::new(ps, elements, validity)))
173    }
174}