vortex_layout/layouts/flat/writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use async_trait::async_trait;
5use futures::StreamExt;
6use vortex_array::serde::SerializeOptions;
7use vortex_array::stats::{Precision, Stat, StatsProvider};
8use vortex_array::{Array, ArrayContext};
9use vortex_dtype::DType;
10use vortex_error::{VortexResult, vortex_bail};
11use vortex_io::runtime::Handle;
12use vortex_scalar::{BinaryScalar, Utf8Scalar};
13
14use crate::layouts::flat::FlatLayout;
15use crate::layouts::zoned::{lower_bound, upper_bound};
16use crate::segments::SegmentSinkRef;
17use crate::sequence::{SendableSequentialStream, SequencePointer};
18use crate::{IntoLayout, LayoutRef, LayoutStrategy};
19
/// Layout strategy that writes a single array chunk into one segment and describes it with a
/// flat layout.
///
/// Use [`Default`] for the standard settings, or set the public fields directly.
#[derive(Clone, Debug)]
pub struct FlatLayoutStrategy {
    /// Whether to include padding for memory-mapped reads.
    pub include_padding: bool,
    /// Maximum length of variable-length (`Utf8` / `Binary`) `Min`/`Max` statistics.
    ///
    /// Longer pre-computed values are truncated when the chunk is written and are then stored
    /// as inexact bounds.
    pub max_variable_length_statistics_size: usize,
}
27
28impl Default for FlatLayoutStrategy {
29    fn default() -> Self {
30        Self {
31            include_padding: true,
32            max_variable_length_statistics_size: 64,
33        }
34    }
35}
36
37#[async_trait]
38impl LayoutStrategy for FlatLayoutStrategy {
39    async fn write_stream(
40        &self,
41        ctx: ArrayContext,
42        segment_sink: SegmentSinkRef,
43        mut stream: SendableSequentialStream,
44        _eof: SequencePointer,
45        _handle: Handle,
46    ) -> VortexResult<LayoutRef> {
47        let ctx = ctx.clone();
48        let options = self.clone();
49        let Some(chunk) = stream.next().await else {
50            vortex_bail!("flat layout needs a single chunk");
51        };
52        let (sequence_id, chunk) = chunk?;
53
54        let row_count = chunk.len() as u64;
55
56        match chunk.dtype() {
57            DType::Utf8(_) => {
58                if let Some(sv) = chunk.statistics().get(Stat::Min) {
59                    let (value, truncated) = lower_bound::<Utf8Scalar>(
60                        sv.into_inner(),
61                        options.max_variable_length_statistics_size,
62                    )?;
63                    if truncated {
64                        chunk
65                            .statistics()
66                            .set(Stat::Min, Precision::Inexact(value.into_value()));
67                    }
68                }
69
70                if let Some(sv) = chunk.statistics().get(Stat::Max) {
71                    let (value, truncated) = upper_bound::<Utf8Scalar>(
72                        sv.into_inner(),
73                        options.max_variable_length_statistics_size,
74                    )?;
75                    if let Some(upper_bound) = value {
76                        if truncated {
77                            chunk
78                                .statistics()
79                                .set(Stat::Max, Precision::Inexact(upper_bound.into_value()));
80                        }
81                    } else {
82                        chunk.statistics().clear(Stat::Max)
83                    }
84                }
85            }
86            DType::Binary(_) => {
87                if let Some(sv) = chunk.statistics().get(Stat::Min) {
88                    let (value, truncated) = lower_bound::<BinaryScalar>(
89                        sv.into_inner(),
90                        options.max_variable_length_statistics_size,
91                    )?;
92                    if truncated {
93                        chunk
94                            .statistics()
95                            .set(Stat::Min, Precision::Inexact(value.into_value()));
96                    }
97                }
98
99                if let Some(sv) = chunk.statistics().get(Stat::Max) {
100                    let (value, truncated) = upper_bound::<BinaryScalar>(
101                        sv.into_inner(),
102                        options.max_variable_length_statistics_size,
103                    )?;
104                    if let Some(upper_bound) = value {
105                        if truncated {
106                            chunk
107                                .statistics()
108                                .set(Stat::Max, Precision::Inexact(upper_bound.into_value()));
109                        }
110                    } else {
111                        chunk.statistics().clear(Stat::Max)
112                    }
113                }
114            }
115            _ => {}
116        }
117
118        // TODO(os): spawn serialization
119        let buffers = chunk.serialize(
120            &ctx,
121            &SerializeOptions {
122                offset: 0,
123                include_padding: options.include_padding,
124            },
125        )?;
126        let segment_id = segment_sink.write(sequence_id, buffers).await?;
127
128        let None = stream.next().await else {
129            vortex_bail!("flat layout received stream with more than a single chunk");
130        };
131        Ok(
132            FlatLayout::new(row_count, stream.dtype().clone(), segment_id, ctx.clone())
133                .into_layout(),
134        )
135    }
136
137    fn buffered_bytes(&self) -> u64 {
138        // FlatLayoutStrategy is a leaf strategy with no child strategies and no buffering
139        0
140    }
141}
142
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_buffer::BooleanBufferBuilder;
    use vortex_array::arrays::{BoolArray, PrimitiveArray, StructArray};
    use vortex_array::builders::{ArrayBuilder, VarBinViewBuilder};
    use vortex_array::stats::{Precision, Stat, StatsProviderExt};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayContext, ArrayRef, IntoArray, MaskFuture, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};
    use vortex_error::VortexUnwrap;
    use vortex_expr::root;
    use vortex_io::runtime::single::block_on;
    use vortex_mask::AllOr;

    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::{SequenceId, SequentialArrayStreamExt};

    /// Currently, flat layouts do not force-compute stats during write; they only retain
    /// pre-computed stats. No stats are computed on this array, so the final `IsSorted`
    /// assertion is expected to fail.
    ///
    /// The `expected` substring matches the `assert_eq!` failure message. The test therefore
    /// only passes when it panics at that assertion, not when writing or reading panics first.
    #[should_panic(expected = "left == right")]
    #[test]
    fn flat_stats() {
        block_on(|handle| async {
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let array = PrimitiveArray::new(buffer![1, 2, 3, 4, 5], Validity::AllValid);
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    segments.clone(),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    handle,
                )
                .await
                .unwrap();

            let result = layout
                .new_reader("".into(), segments)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            assert_eq!(
                result.statistics().get_as::<bool>(Stat::IsSorted),
                Some(Precision::Exact(true))
            );
        })
    }

    /// Pre-computed `Min`/`Max` on long strings are truncated to the default 64-character limit
    /// and stored as inexact bounds.
    #[test]
    fn truncates_variable_size_stats() {
        block_on(|handle| async {
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let mut builder =
                VarBinViewBuilder::with_capacity(DType::Utf8(Nullability::NonNullable), 2);
            builder.append_value("Long value to test that the statistics are actually truncated, it needs a bit of extra padding though");
            builder.append_value("Another string that's meant to be smaller than the previous value, though still need extra padding");
            let array = builder.finish();
            array.statistics().set_iter(
                array
                    .statistics()
                    .compute_all(&Stat::all().collect::<Vec<_>>())
                    .vortex_unwrap()
                    .into_iter(),
            );

            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    segments.clone(),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    handle,
                )
                .await
                .unwrap();

            let result = layout
                .new_reader("".into(), segments)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            assert_eq!(
                result.statistics().get_as::<String>(Stat::Min),
                // The min is cut to its first 64 characters, so "value" becomes "valu"; this
                // is intentional, not a typo.
                Some(Precision::Inexact(
                    "Another string that's meant to be smaller than the previous valu".to_string()
                ))
            );
            assert_eq!(
                result.statistics().get_as::<String>(Stat::Max),
                // The max keeps a 64-character prefix whose last character is bumped from 'i'
                // to 'j', so the stored value is still an upper bound.
                Some(Precision::Inexact(
                    "Long value to test that the statistics are actually truncated, j".to_string()
                ))
            );
        })
    }

    /// A nullable struct array survives a write/read round trip with validity and child values
    /// intact.
    #[test]
    fn struct_array_round_trip() {
        block_on(|handle| async {
            let mut validity_builder = BooleanBufferBuilder::new(2);
            validity_builder.append(true);
            validity_builder.append(false);
            let validity_boolean_buffer = validity_builder.finish();
            let validity = Validity::Array(
                BoolArray::from_bool_buffer(validity_boolean_buffer.clone(), Validity::NonNullable)
                    .into_array(),
            );
            let array = StructArray::try_new(
                FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
                vec![
                    buffer![1_u64, 2].into_array(),
                    buffer![3_u64, 4].into_array(),
                ],
                2,
                validity,
            )
            .unwrap();

            let ctx = ArrayContext::empty();

            // Write the array into a byte buffer.
            let (layout, segments) = {
                let segments = Arc::new(TestSegments::default());
                let (ptr, eof) = SequenceId::root().split();
                let layout = FlatLayoutStrategy::default()
                    .write_stream(
                        ctx,
                        segments.clone(),
                        array.to_array_stream().sequenced(ptr),
                        eof,
                        handle,
                    )
                    .await
                    .unwrap();

                (layout, segments)
            };

            // We should be able to read the array we just wrote.
            let result: ArrayRef = layout
                .new_reader("".into(), segments)
                .unwrap()
                .projection_evaluation(
                    &(0..layout.row_count()),
                    &root(),
                    MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap();

            assert_eq!(
                result.validity_mask().boolean_buffer(),
                AllOr::Some(&validity_boolean_buffer)
            );
            assert_eq!(
                result
                    .to_struct()
                    .field_by_name("a")
                    .unwrap()
                    .to_primitive()
                    .as_slice::<u64>(),
                &[1, 2]
            );
            assert_eq!(
                result
                    .to_struct()
                    .field_by_name("b")
                    .unwrap()
                    .to_primitive()
                    .as_slice::<u64>(),
                &[3, 4]
            );
        })
    }
}