vortex_layout/layouts/flat/
writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use async_trait::async_trait;
5use futures::StreamExt;
6use vortex_array::serde::SerializeOptions;
7use vortex_array::stats::{Precision, Stat, StatsProvider};
8use vortex_array::{Array, ArrayContext};
9use vortex_dtype::DType;
10use vortex_error::{VortexResult, vortex_bail};
11use vortex_scalar::{BinaryScalar, Utf8Scalar};
12
13use crate::layouts::flat::FlatLayout;
14use crate::layouts::zoned::{lower_bound, upper_bound};
15use crate::segments::SequenceWriter;
16use crate::{IntoLayout, LayoutRef, LayoutStrategy, SendableSequentialStream};
17
/// [`LayoutStrategy`] that writes a stream holding exactly one chunk as a single flat,
/// serialized segment.
#[derive(Clone, Debug)]
pub struct FlatLayoutStrategy {
    /// Whether to include padding for memory-mapped reads.
    ///
    /// Forwarded to [`SerializeOptions`] when the chunk is serialized. Defaults to `true`.
    pub include_padding: bool,
    /// Maximum length that `Min`/`Max` statistics of `Utf8`/`Binary` chunks are truncated to
    /// before the chunk is serialized. Defaults to `64`.
    pub max_variable_length_statistics_size: usize,
}

impl Default for FlatLayoutStrategy {
    /// Pads segments for memory-mapped reads and caps variable-length statistics at 64.
    fn default() -> Self {
        Self {
            include_padding: true,
            max_variable_length_statistics_size: 64,
        }
    }
}
34
35#[async_trait]
36impl LayoutStrategy for FlatLayoutStrategy {
37    async fn write_stream(
38        &self,
39        ctx: &ArrayContext,
40        sequence_writer: SequenceWriter,
41        mut stream: SendableSequentialStream,
42    ) -> VortexResult<LayoutRef> {
43        let ctx = ctx.clone();
44        let options = self.clone();
45        let Some(chunk) = stream.next().await else {
46            vortex_bail!("flat layout needs a single chunk");
47        };
48        let (sequence_id, chunk) = chunk?;
49
50        let row_count = chunk.len() as u64;
51
52        match chunk.dtype() {
53            DType::Utf8(_) => {
54                if let Some(sv) = chunk.statistics().get(Stat::Min) {
55                    let (value, truncated) = lower_bound::<Utf8Scalar>(
56                        sv.into_inner(),
57                        options.max_variable_length_statistics_size,
58                    )?;
59                    if truncated {
60                        chunk
61                            .statistics()
62                            .set(Stat::Min, Precision::Inexact(value.into_value()));
63                    }
64                }
65
66                if let Some(sv) = chunk.statistics().get(Stat::Max) {
67                    let (value, truncated) = upper_bound::<Utf8Scalar>(
68                        sv.into_inner(),
69                        options.max_variable_length_statistics_size,
70                    )?;
71                    if let Some(upper_bound) = value {
72                        if truncated {
73                            chunk
74                                .statistics()
75                                .set(Stat::Max, Precision::Inexact(upper_bound.into_value()));
76                        }
77                    } else {
78                        chunk.statistics().clear(Stat::Max)
79                    }
80                }
81            }
82            DType::Binary(_) => {
83                if let Some(sv) = chunk.statistics().get(Stat::Min) {
84                    let (value, truncated) = lower_bound::<BinaryScalar>(
85                        sv.into_inner(),
86                        options.max_variable_length_statistics_size,
87                    )?;
88                    if truncated {
89                        chunk
90                            .statistics()
91                            .set(Stat::Min, Precision::Inexact(value.into_value()));
92                    }
93                }
94
95                if let Some(sv) = chunk.statistics().get(Stat::Max) {
96                    let (value, truncated) = upper_bound::<BinaryScalar>(
97                        sv.into_inner(),
98                        options.max_variable_length_statistics_size,
99                    )?;
100                    if let Some(upper_bound) = value {
101                        if truncated {
102                            chunk
103                                .statistics()
104                                .set(Stat::Max, Precision::Inexact(upper_bound.into_value()));
105                        }
106                    } else {
107                        chunk.statistics().clear(Stat::Max)
108                    }
109                }
110            }
111            _ => {}
112        }
113
114        // TODO(os): spawn serialization
115        let buffers = chunk.serialize(
116            &ctx,
117            &SerializeOptions {
118                offset: 0,
119                include_padding: options.include_padding,
120            },
121        )?;
122        let segment_id = sequence_writer.put(sequence_id, buffers).await?;
123
124        let None = stream.next().await else {
125            vortex_bail!("flat layout received stream with more than a single chunk");
126        };
127        Ok(
128            FlatLayout::new(row_count, stream.dtype().clone(), segment_id, ctx.clone())
129                .into_layout(),
130        )
131    }
132}
133
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_buffer::BooleanBufferBuilder;
    use futures::executor::block_on;
    use futures::stream;
    use vortex_array::arrays::{BoolArray, PrimitiveArray, StructArray};
    use vortex_array::builders::{ArrayBuilder, VarBinViewBuilder};
    use vortex_array::stats::{Precision, Stat};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayContext, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};
    use vortex_error::VortexUnwrap;
    use vortex_expr::root;
    use vortex_mask::{AllOr, Mask};

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::segments::{SegmentSource, SequenceWriter, TestSegments};
    use crate::sequence::SequenceId;
    use crate::{
        LayoutRef, LayoutStrategy, SendableSequentialStream, SequentialStreamAdapter,
        SequentialStreamExt as _,
    };

    /// Wraps `array` in a one-element sequential stream tagged with the root sequence id.
    fn stream_only(array: ArrayRef) -> SendableSequentialStream {
        let dtype = array.dtype().clone();
        let single = stream::once(async move { Ok((SequenceId::root().downgrade(), array)) });
        SequentialStreamAdapter::new(dtype, single).sendable()
    }

    /// Writes `array` with the default flat strategy into fresh in-memory segments and hands
    /// back the layout plus the segments it can be read from.
    async fn write_flat(array: ArrayRef) -> (LayoutRef, Arc<dyn SegmentSource>) {
        let ctx = ArrayContext::empty();
        let segments = TestSegments::default();
        let writer = SequenceWriter::new(Box::new(segments.clone()));
        let layout = FlatLayoutStrategy::default()
            .write_stream(&ctx, writer, stream_only(array))
            .await
            .unwrap();
        (layout, Arc::new(segments) as Arc<dyn SegmentSource>)
    }

    /// Evaluates the root projection over every row of `layout`.
    async fn read_all(layout: &LayoutRef, segments: Arc<dyn SegmentSource>) -> ArrayRef {
        layout
            .new_reader("".into(), segments)
            .unwrap()
            .projection_evaluation(&(0..layout.row_count()), &root())
            .unwrap()
            .invoke(Mask::new_true(layout.row_count().try_into().unwrap()))
            .await
            .unwrap()
    }

    // Currently, flat layouts do not force compute stats during write, they only retain
    // pre-computed stats.
    #[should_panic]
    #[test]
    fn flat_stats() {
        block_on(async {
            let array = PrimitiveArray::new(buffer![1, 2, 3, 4, 5], Validity::AllValid);
            let (layout, segments) = write_flat(array.to_array()).await;
            let result = read_all(&layout, segments).await;

            assert_eq!(
                result.statistics().get_as::<bool>(Stat::IsSorted),
                Some(Precision::Exact(true))
            );
        })
    }

    #[test]
    fn truncates_variable_size_stats() {
        block_on(async {
            let mut builder =
                VarBinViewBuilder::with_capacity(DType::Utf8(Nullability::NonNullable), 2);
            builder.append_value("Long value to test that the statistics are actually truncated, it needs a bit of extra padding though");
            builder.append_value("Another string that's meant to be smaller than the previous value, though still need extra padding");
            let array = builder.finish();

            // Pre-compute every statistic so the writer has exact min/max values to truncate.
            let every_stat: Vec<_> = Stat::all().collect();
            let computed = array
                .statistics()
                .compute_all(&every_stat)
                .vortex_unwrap();
            array.statistics().set_iter(computed.into_iter());

            let (layout, segments) = write_flat(array.to_array()).await;
            let result = read_all(&layout, segments).await;

            assert_eq!(
                result.statistics().get_as::<String>(Stat::Min),
                Some(Precision::Inexact(
                    "Another string that's meant to be smaller than the previous valu".to_string()
                ))
            );
            assert_eq!(
                result.statistics().get_as::<String>(Stat::Max),
                Some(Precision::Inexact(
                    "Long value to test that the statistics are actually truncated, j".to_string()
                ))
            );
        })
    }

    #[test]
    fn struct_array_round_trip() {
        block_on(async {
            let mut validity_bits = BooleanBufferBuilder::new(2);
            validity_bits.append(true);
            validity_bits.append(false);
            let expected_validity = validity_bits.finish();

            let validity = Validity::Array(
                BoolArray::new(expected_validity.clone(), Validity::NonNullable).into_array(),
            );
            let array = StructArray::try_new(
                FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
                vec![
                    buffer![1_u64, 2].into_array(),
                    buffer![3_u64, 4].into_array(),
                ],
                2,
                validity,
            )
            .unwrap();

            // Write the array, then read back exactly what was written.
            let (layout, segments) = write_flat(array.to_array()).await;
            let result: ArrayRef = read_all(&layout, segments).await;

            assert_eq!(
                result.validity_mask().unwrap().boolean_buffer(),
                AllOr::Some(&expected_validity)
            );

            let as_struct = result.to_struct().unwrap();
            assert_eq!(
                as_struct
                    .field_by_name("a")
                    .unwrap()
                    .to_primitive()
                    .unwrap()
                    .as_slice::<u64>(),
                &[1, 2]
            );
            assert_eq!(
                as_struct
                    .field_by_name("b")
                    .unwrap()
                    .to_primitive()
                    .unwrap()
                    .as_slice::<u64>(),
                &[3, 4]
            );
        })
    }
}