vortex_layout/layouts/flat/
writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use async_trait::async_trait;
5use futures::StreamExt;
6use vortex_array::serde::SerializeOptions;
7use vortex_array::stats::{Precision, Stat, StatsProvider};
8use vortex_array::{Array, ArrayContext};
9use vortex_dtype::DType;
10use vortex_error::{VortexResult, vortex_bail};
11use vortex_scalar::{BinaryScalar, Utf8Scalar};
12
13use crate::layouts::flat::FlatLayout;
14use crate::layouts::zoned::{lower_bound, upper_bound};
15use crate::segments::SequenceWriter;
16use crate::{IntoLayout, LayoutRef, LayoutStrategy, SendableSequentialStream};
17
/// Layout strategy that writes a stream containing exactly one array chunk as a single flat
/// segment.
#[derive(Clone, Debug)]
pub struct FlatLayoutStrategy {
    /// Whether to include padding for memory-mapped reads.
    pub include_padding: bool,
    /// Size limit passed to `lower_bound`/`upper_bound` when shrinking the `Min`/`Max` statistics
    /// of `Utf8` and `Binary` chunks; bounds reported as truncated are stored as inexact values.
    pub max_variable_length_statistics_size: usize,
}
25
26impl Default for FlatLayoutStrategy {
27    fn default() -> Self {
28        Self {
29            include_padding: true,
30            max_variable_length_statistics_size: 64,
31        }
32    }
33}
34
35#[async_trait]
36impl LayoutStrategy for FlatLayoutStrategy {
37    async fn write_stream(
38        &self,
39        ctx: &ArrayContext,
40        sequence_writer: SequenceWriter,
41        mut stream: SendableSequentialStream,
42    ) -> VortexResult<LayoutRef> {
43        let ctx = ctx.clone();
44        let options = self.clone();
45        let Some(chunk) = stream.next().await else {
46            vortex_bail!("flat layout needs a single chunk");
47        };
48        let (sequence_id, chunk) = chunk?;
49
50        let row_count = chunk.len() as u64;
51
52        match chunk.dtype() {
53            DType::Utf8(_) => {
54                if let Some(sv) = chunk.statistics().get(Stat::Min) {
55                    let (value, truncated) = lower_bound::<Utf8Scalar>(
56                        sv.into_inner(),
57                        options.max_variable_length_statistics_size,
58                    )?;
59                    if truncated {
60                        chunk
61                            .statistics()
62                            .set(Stat::Min, Precision::Inexact(value.into_value()));
63                    }
64                }
65
66                if let Some(sv) = chunk.statistics().get(Stat::Max) {
67                    let (value, truncated) = upper_bound::<Utf8Scalar>(
68                        sv.into_inner(),
69                        options.max_variable_length_statistics_size,
70                    )?;
71                    if let Some(upper_bound) = value {
72                        if truncated {
73                            chunk
74                                .statistics()
75                                .set(Stat::Max, Precision::Inexact(upper_bound.into_value()));
76                        }
77                    } else {
78                        chunk.statistics().clear(Stat::Max)
79                    }
80                }
81            }
82            DType::Binary(_) => {
83                if let Some(sv) = chunk.statistics().get(Stat::Min) {
84                    let (value, truncated) = lower_bound::<BinaryScalar>(
85                        sv.into_inner(),
86                        options.max_variable_length_statistics_size,
87                    )?;
88                    if truncated {
89                        chunk
90                            .statistics()
91                            .set(Stat::Min, Precision::Inexact(value.into_value()));
92                    }
93                }
94
95                if let Some(sv) = chunk.statistics().get(Stat::Max) {
96                    let (value, truncated) = upper_bound::<BinaryScalar>(
97                        sv.into_inner(),
98                        options.max_variable_length_statistics_size,
99                    )?;
100                    if let Some(upper_bound) = value {
101                        if truncated {
102                            chunk
103                                .statistics()
104                                .set(Stat::Max, Precision::Inexact(upper_bound.into_value()));
105                        }
106                    } else {
107                        chunk.statistics().clear(Stat::Max)
108                    }
109                }
110            }
111            _ => {}
112        }
113
114        // TODO(os): spawn serialization
115        let buffers = chunk.serialize(
116            &ctx,
117            &SerializeOptions {
118                offset: 0,
119                include_padding: options.include_padding,
120            },
121        )?;
122        let segment_id = sequence_writer.put(sequence_id, buffers).await?;
123
124        let None = stream.next().await else {
125            vortex_bail!("flat layout received stream with more than a single chunk");
126        };
127        Ok(
128            FlatLayout::new(row_count, stream.dtype().clone(), segment_id, ctx.clone())
129                .into_layout(),
130        )
131    }
132}
133
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_buffer::BooleanBufferBuilder;
    use futures::executor::block_on;
    use futures::stream;
    use vortex_array::arrays::{BoolArray, PrimitiveArray, StructArray};
    use vortex_array::builders::{ArrayBuilder, VarBinViewBuilder};
    use vortex_array::stats::{Precision, Stat, StatsProviderExt};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayContext, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};
    use vortex_error::VortexUnwrap;
    use vortex_expr::root;
    use vortex_mask::AllOr;

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::segments::{SegmentSource, SequenceWriter, TestSegments};
    use crate::sequence::SequenceId;
    use crate::{
        LayoutStrategy, MaskFuture, SendableSequentialStream, SequentialStreamAdapter,
        SequentialStreamExt as _,
    };

    /// Wraps `array` in a sequential stream that yields it once under the root sequence id.
    fn stream_only(array: ArrayRef) -> SendableSequentialStream {
        let dtype = array.dtype().clone();
        SequentialStreamAdapter::new(
            dtype,
            stream::once(async move { Ok((SequenceId::root().downgrade(), array)) }),
        )
        .sendable()
    }

    /// Writes `array` with a default `FlatLayoutStrategy` into in-memory test segments, then
    /// evaluates the root projection over every row of the resulting layout.
    async fn write_then_read(array: ArrayRef) -> ArrayRef {
        let ctx = ArrayContext::empty();
        let segments = TestSegments::default();
        let sequence_writer = SequenceWriter::new(Box::new(segments.clone()));
        let layout = FlatLayoutStrategy::default()
            .write_stream(&ctx, sequence_writer, stream_only(array))
            .await
            .unwrap();
        let source: Arc<dyn SegmentSource> = Arc::new(segments);
        let row_count = layout.row_count();

        let read_back = layout
            .new_reader("".into(), source)
            .unwrap()
            .projection_evaluation(&(0..row_count), &root())
            .unwrap()
            .invoke(MaskFuture::new_true(row_count.try_into().unwrap()))
            .await
            .unwrap();
        read_back
    }

    // Currently, flat layouts do not force compute stats during write, they only retain
    // pre-computed stats.
    #[should_panic]
    #[test]
    fn flat_stats() {
        let array = PrimitiveArray::new(buffer![1, 2, 3, 4, 5], Validity::AllValid);
        let result = block_on(write_then_read(array.to_array()));

        assert_eq!(
            result.statistics().get_as::<bool>(Stat::IsSorted),
            Some(Precision::Exact(true))
        );
    }

    #[test]
    fn truncates_variable_size_stats() {
        let mut builder =
            VarBinViewBuilder::with_capacity(DType::Utf8(Nullability::NonNullable), 2);
        builder.append_value("Long value to test that the statistics are actually truncated, it needs a bit of extra padding though");
        builder.append_value("Another string that's meant to be smaller than the previous value, though still need extra padding");
        let array = builder.finish();

        // Pre-compute every statistic so the writer has Min/Max values to truncate.
        let every_stat = Stat::all().collect::<Vec<_>>();
        let computed = array
            .statistics()
            .compute_all(&every_stat)
            .vortex_unwrap();
        array.statistics().set_iter(computed.into_iter());

        let result = block_on(write_then_read(array.to_array()));

        assert_eq!(
            result.statistics().get_as::<String>(Stat::Min),
            Some(Precision::Inexact(
                "Another string that's meant to be smaller than the previous valu".to_string()
            ))
        );
        assert_eq!(
            result.statistics().get_as::<String>(Stat::Max),
            Some(Precision::Inexact(
                "Long value to test that the statistics are actually truncated, j".to_string()
            ))
        );
    }

    #[test]
    fn struct_array_round_trip() {
        let mut validity_builder = BooleanBufferBuilder::new(2);
        validity_builder.append(true);
        validity_builder.append(false);
        let validity_boolean_buffer = validity_builder.finish();
        let validity = Validity::Array(
            BoolArray::new(validity_boolean_buffer.clone(), Validity::NonNullable).into_array(),
        );
        let array = StructArray::try_new(
            FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
            vec![
                buffer![1_u64, 2].into_array(),
                buffer![3_u64, 4].into_array(),
            ],
            2,
            validity,
        )
        .unwrap();

        // We should be able to read the array we just wrote.
        let result: ArrayRef = block_on(write_then_read(array.to_array()));

        assert_eq!(
            result.validity_mask().boolean_buffer(),
            AllOr::Some(&validity_boolean_buffer)
        );
        let as_struct = result.to_struct();
        assert_eq!(
            as_struct
                .field_by_name("a")
                .unwrap()
                .to_primitive()
                .as_slice::<u64>(),
            &[1, 2]
        );
        assert_eq!(
            as_struct
                .field_by_name("b")
                .unwrap()
                .to_primitive()
                .as_slice::<u64>(),
            &[3, 4]
        );
    }
}