Skip to main content

vortex_array/scalar/
truncation.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
//! Produce lower or upper bounds of scalars via truncation.
5
6use vortex_buffer::BufferString;
7use vortex_buffer::ByteBuffer;
8use vortex_dtype::Nullability;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12
13use crate::scalar::Scalar;
14use crate::scalar::StringLike;
15
16/// A trait for truncating [`Scalar`]s to a given length in bytes.
17#[allow(clippy::len_without_is_empty)]
18pub trait ScalarTruncation: Send + Sized {
19    /// Unwrap a Scalar into a ScalarTruncation object
20    ///
21    /// # Errors
22    /// If the scalar doesn't match the truncations dtype.
23    fn from_scalar(value: Scalar) -> VortexResult<Option<Self>>;
24
25    /// The length of the value in bytes.
26    fn len(&self) -> usize;
27
28    /// Convert the value into a [`Scalar`] with the given nullability.
29    fn into_scalar(self, nullability: Nullability) -> Scalar;
30
31    /// Constructs the next [`Scalar`] at most `max_length` bytes that's lexicographically greater
32    /// than this.
33    ///
34    /// Returns `None` if the value is null or if constructing a greater value would overflow.
35    fn upper_bound(self, max_length: usize) -> Option<Self>;
36
37    /// Construct a [`ByteBuffer`] at most `max_length` in size that's less than or equal to
38    /// ourselves.
39    fn lower_bound(self, max_length: usize) -> Self;
40}
41
42impl ScalarTruncation for ByteBuffer {
43    fn from_scalar(value: Scalar) -> VortexResult<Option<Self>> {
44        vortex_ensure!(
45            value.dtype().is_binary(),
46            "Expected binary scalar, got {}",
47            value.dtype()
48        );
49        Ok(value.into_value().map(|b| b.into_binary()))
50    }
51
52    fn len(&self) -> usize {
53        ByteBuffer::len(self)
54    }
55
56    fn into_scalar(self, nullability: Nullability) -> Scalar {
57        Scalar::binary(self, nullability)
58    }
59
60    fn upper_bound(self, max_length: usize) -> Option<Self> {
61        let sliced = self.slice(0..max_length);
62        let mut sliced_mut = sliced.into_mut();
63        for b in sliced_mut.iter_mut().rev() {
64            let (incr, overflow) = b.overflowing_add(1);
65            *b = incr;
66            if !overflow {
67                return Some(sliced_mut.freeze());
68            }
69        }
70        None
71    }
72
73    fn lower_bound(self, max_length: usize) -> Self {
74        self.slice(0..max_length)
75    }
76}
77
78impl ScalarTruncation for BufferString {
79    fn from_scalar(value: Scalar) -> VortexResult<Option<Self>> {
80        vortex_ensure!(
81            value.dtype().is_utf8(),
82            "Expected utf8 scalar, got {}",
83            value.dtype()
84        );
85        Ok(value.into_value().map(|b| b.into_utf8()))
86    }
87
88    fn len(&self) -> usize {
89        self.inner().len()
90    }
91
92    fn into_scalar(self, nullability: Nullability) -> Scalar {
93        Scalar::utf8(self, nullability)
94    }
95
96    /// Constructs the next [`BufferString`] at most `max_length` bytes that's lexicographically greater
97    /// than this.
98    ///
99    /// Returns `None` if the value is null or if constructing a greater value would overflow.
100    fn upper_bound(self, max_length: usize) -> Option<Self> {
101        let utf8_split_pos = (max_length.saturating_sub(3)..=max_length)
102            .rfind(|p| self.is_char_boundary(*p))
103            .vortex_expect("Failed to find utf8 character boundary");
104
105        // SAFETY: we slice to a char boundary so the sliced range contains valid UTF-8.
106        let sliced =
107            unsafe { BufferString::new_unchecked(self.into_inner().slice(..utf8_split_pos)) };
108        sliced.increment().ok()
109    }
110
111    /// Construct a [`BufferString`] at most `max_length` in size that's less than or equal to
112    /// ourselves.
113    fn lower_bound(self, max_length: usize) -> Self {
114        // UTF-8 characters are at most 4 bytes. Since we know that `BufferString` is
115        // valid UTF-8, we must have a valid character boundary.
116        let utf8_split_pos = (max_length.saturating_sub(3)..=max_length)
117            .rfind(|p| self.is_char_boundary(*p))
118            .vortex_expect("Failed to find utf8 character boundary");
119
120        unsafe { BufferString::new_unchecked(self.into_inner().slice(..utf8_split_pos)) }
121    }
122}
123
124/// Truncate the value to be less than max_length in bytes and be lexicographically smaller than the value itself
125pub fn lower_bound(
126    value: Option<impl ScalarTruncation>,
127    max_length: usize,
128    nullability: Nullability,
129) -> Option<(Scalar, bool)> {
130    let value = value?;
131    if value.len() > max_length {
132        Some((value.lower_bound(max_length).into_scalar(nullability), true))
133    } else {
134        Some((value.into_scalar(nullability), false))
135    }
136}
137
138/// Truncate the value to be less than max_length in bytes and be lexicographically greater than the value itself
139pub fn upper_bound(
140    value: Option<impl ScalarTruncation>,
141    max_length: usize,
142    nullability: Nullability,
143) -> Option<(Scalar, bool)> {
144    let value = value?;
145    if value.len() > max_length {
146        value
147            .upper_bound(max_length)
148            .map(|v| (v.into_scalar(nullability), true))
149    } else {
150        Some((value.into_scalar(nullability), false))
151    }
152}
153
#[cfg(test)]
mod tests {
    use vortex_buffer::BufferString;
    use vortex_buffer::ByteBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;

    use crate::scalar::truncation::ScalarTruncation;
    use crate::scalar::truncation::lower_bound;
    use crate::scalar::truncation::upper_bound;

    #[test]
    fn binary_lower_bound() {
        let binary = buffer![0u8, 5, 47, 33, 129];
        let expected = buffer![0u8, 5];
        assert_eq!(binary.lower_bound(2), expected,);
    }

    #[test]
    fn binary_upper_bound() {
        let binary = buffer![0u8, 5, 255, 234, 23];
        let expected = buffer![0u8, 6, 0];
        assert_eq!(binary.upper_bound(3).unwrap(), expected,);
    }

    #[test]
    fn binary_upper_bound_no_carry() {
        let binary = buffer![1u8, 2, 3, 4];
        let expected = buffer![1u8, 3];
        assert_eq!(binary.upper_bound(2).unwrap(), expected);
    }

    #[test]
    fn binary_upper_bound_overflow() {
        let binary = buffer![255u8, 255, 255];
        assert!(binary.upper_bound(2).is_none());
    }

    #[test]
    fn binary_upper_bound_null() {
        assert!(upper_bound(Option::<ByteBuffer>::None, 10, Nullability::Nullable).is_none());
    }

    #[test]
    fn binary_lower_bound_null() {
        assert!(lower_bound(Option::<ByteBuffer>::None, 10, Nullability::Nullable).is_none());
    }

    #[test]
    fn utf8_lower_bound() {
        let utf8 = BufferString::from("snowman⛄️snowman");
        let expected = BufferString::from("snowman");
        assert_eq!(utf8.lower_bound(9), expected);
    }

    #[test]
    fn utf8_lower_bound_at_char_boundary() {
        // U+26C4 (SNOWMAN WITHOUT SNOW) is 3 bytes and occupies bytes 7..10, so 10 is a boundary.
        let utf8 = BufferString::from("snowman\u{26C4}snowman");
        let expected = BufferString::from("snowman\u{26C4}");
        assert_eq!(utf8.lower_bound(10), expected);
    }

    #[test]
    fn utf8_upper_bound() {
        let utf8 = BufferString::from("char🪩");
        let expected = BufferString::from("chas");
        assert_eq!(utf8.upper_bound(5).unwrap(), expected);
    }

    #[test]
    fn utf8_upper_bound_overflow() {
        let utf8 = BufferString::from("🂑🂒🂓");
        assert!(utf8.upper_bound(2).is_none());
    }

    #[test]
    fn utf8_upper_bound_null() {
        assert!(upper_bound(Option::<BufferString>::None, 10, Nullability::Nullable).is_none());
    }

    #[test]
    fn utf8_lower_bound_null() {
        assert!(lower_bound(Option::<BufferString>::None, 10, Nullability::Nullable).is_none());
    }

    #[test]
    fn lower_bound_reports_truncation() {
        let truncated = lower_bound(Some(buffer![1u8, 2, 3]), 2, Nullability::NonNullable);
        assert_eq!(truncated.map(|(_, was_truncated)| was_truncated), Some(true));

        let untouched = lower_bound(Some(buffer![1u8, 2]), 2, Nullability::NonNullable);
        assert_eq!(untouched.map(|(_, was_truncated)| was_truncated), Some(false));
    }

    #[test]
    fn upper_bound_reports_truncation() {
        let truncated = upper_bound(
            Some(BufferString::from("char\u{1FAA9}")),
            5,
            Nullability::NonNullable,
        );
        assert_eq!(truncated.map(|(_, was_truncated)| was_truncated), Some(true));

        let untouched = upper_bound(
            Some(BufferString::from("char")),
            5,
            Nullability::NonNullable,
        );
        assert_eq!(untouched.map(|(_, was_truncated)| was_truncated), Some(false));
    }

    #[test]
    fn upper_bound_overflow_is_none() {
        assert!(
            upper_bound(
                Some(buffer![255u8, 255, 255]),
                2,
                Nullability::NonNullable
            )
            .is_none()
        );
    }
}