vortex_layout/layouts/flat/
writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use async_trait::async_trait;
5use futures::StreamExt;
6use vortex_array::serde::SerializeOptions;
7use vortex_array::stats::{Precision, Stat, StatsProvider};
8use vortex_array::{Array, ArrayContext};
9use vortex_dtype::DType;
10use vortex_error::{VortexResult, vortex_bail};
11use vortex_io::runtime::Handle;
12
13use crate::layouts::flat::{FLAT_LAYOUT_INLINE_ARRAY_NODE, FlatLayout};
14use crate::layouts::zoned::{lower_bound, upper_bound};
15use crate::segments::SegmentSinkRef;
16use crate::sequence::{SendableSequentialStream, SequencePointer};
17use crate::{IntoLayout, LayoutRef, LayoutStrategy};
18
/// Options controlling how [`FlatLayoutStrategy`] writes its single chunk.
///
/// `Debug` is derived so the configuration can be logged and shown in assertion failures.
#[derive(Clone, Debug)]
pub struct FlatLayoutStrategy {
    /// Whether to include padding for memory-mapped reads.
    pub include_padding: bool,
    /// Maximum length of variable-length (`Utf8` / `Binary`) `Min` / `Max` statistics.
    pub max_variable_length_statistics_size: usize,
}
26
27impl Default for FlatLayoutStrategy {
28    fn default() -> Self {
29        Self {
30            include_padding: true,
31            max_variable_length_statistics_size: 64,
32        }
33    }
34}
35
36#[async_trait]
37impl LayoutStrategy for FlatLayoutStrategy {
38    async fn write_stream(
39        &self,
40        ctx: ArrayContext,
41        segment_sink: SegmentSinkRef,
42        mut stream: SendableSequentialStream,
43        _eof: SequencePointer,
44        _handle: Handle,
45    ) -> VortexResult<LayoutRef> {
46        let ctx = ctx.clone();
47        let options = self.clone();
48        let Some(chunk) = stream.next().await else {
49            vortex_bail!("flat layout needs a single chunk");
50        };
51        let (sequence_id, chunk) = chunk?;
52
53        let row_count = chunk.len() as u64;
54
55        match chunk.dtype() {
56            DType::Utf8(_) => {
57                if let Some(sv) = chunk.statistics().get(Stat::Min) {
58                    let (value, truncated) = lower_bound(
59                        sv.into_inner().as_utf8(),
60                        options.max_variable_length_statistics_size,
61                    );
62                    if truncated {
63                        chunk
64                            .statistics()
65                            .set(Stat::Min, Precision::Inexact(value.into_value()));
66                    }
67                }
68
69                if let Some(sv) = chunk.statistics().get(Stat::Max) {
70                    let (value, truncated) = upper_bound(
71                        sv.into_inner().as_utf8(),
72                        options.max_variable_length_statistics_size,
73                    );
74                    if let Some(upper_bound) = value {
75                        if truncated {
76                            chunk
77                                .statistics()
78                                .set(Stat::Max, Precision::Inexact(upper_bound.into_value()));
79                        }
80                    } else {
81                        chunk.statistics().clear(Stat::Max)
82                    }
83                }
84            }
85            DType::Binary(_) => {
86                if let Some(sv) = chunk.statistics().get(Stat::Min) {
87                    let (value, truncated) = lower_bound(
88                        sv.into_inner().as_binary(),
89                        options.max_variable_length_statistics_size,
90                    );
91                    if truncated {
92                        chunk
93                            .statistics()
94                            .set(Stat::Min, Precision::Inexact(value.into_value()));
95                    }
96                }
97
98                if let Some(sv) = chunk.statistics().get(Stat::Max) {
99                    let (value, truncated) = upper_bound(
100                        sv.into_inner().as_binary(),
101                        options.max_variable_length_statistics_size,
102                    );
103                    if let Some(upper_bound) = value {
104                        if truncated {
105                            chunk
106                                .statistics()
107                                .set(Stat::Max, Precision::Inexact(upper_bound.into_value()));
108                        }
109                    } else {
110                        chunk.statistics().clear(Stat::Max)
111                    }
112                }
113            }
114            _ => {}
115        }
116
117        // TODO(os): spawn serialization
118        let buffers = chunk.serialize(
119            &ctx,
120            &SerializeOptions {
121                offset: 0,
122                include_padding: options.include_padding,
123            },
124        )?;
125        // there is at least the flatbuffer and the length
126        assert!(buffers.len() >= 2);
127        let array_node =
128            (*FLAT_LAYOUT_INLINE_ARRAY_NODE).then(|| buffers[buffers.len() - 2].clone());
129        let segment_id = segment_sink.write(sequence_id, buffers).await?;
130
131        let None = stream.next().await else {
132            vortex_bail!("flat layout received stream with more than a single chunk");
133        };
134        Ok(FlatLayout::new_with_metadata(
135            row_count,
136            stream.dtype().clone(),
137            segment_id,
138            ctx.clone(),
139            array_node,
140        )
141        .into_layout())
142    }
143
144    fn buffered_bytes(&self) -> u64 {
145        // FlatLayoutStrategy is a leaf strategy with no child strategies and no buffering
146        0
147    }
148}
149
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::{BoolArray, PrimitiveArray, StructArray};
    use vortex_array::builders::{ArrayBuilder, VarBinViewBuilder};
    use vortex_array::expr::root;
    use vortex_array::stats::{Precision, Stat, StatsProviderExt};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayContext, ArrayRef, IntoArray, MaskFuture, ToCanonical};
    use vortex_buffer::{BitBufferMut, buffer};
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};
    use vortex_error::VortexUnwrap;
    use vortex_io::runtime::single::block_on;
    use vortex_mask::AllOr;

    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::{SequenceId, SequentialArrayStreamExt};

    // Flat layouts currently do not force statistics to be computed while writing; only
    // statistics that were already computed are retained. Nothing is pre-computed in this
    // test, which is why it is marked `should_panic`.
    #[should_panic]
    #[test]
    fn flat_stats() {
        block_on(|handle| async {
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let input = PrimitiveArray::new(buffer![1, 2, 3, 4, 5], Validity::AllValid);
            let chunks = input.to_array_stream().sequenced(ptr);

            let strategy = FlatLayoutStrategy::default();
            let layout = strategy
                .write_stream(ArrayContext::empty(), segments.clone(), chunks, eof, handle)
                .await
                .unwrap();

            // Read every row back through the layout reader.
            let rows = 0..layout.row_count();
            let expr = root();
            let mask = MaskFuture::new_true(layout.row_count().try_into().unwrap());
            let reader = layout.new_reader("".into(), segments).unwrap();
            let result = reader
                .projection_evaluation(&rows, &expr, mask)
                .unwrap()
                .await
                .unwrap();

            let is_sorted = result.statistics().get_as::<bool>(Stat::IsSorted);
            assert_eq!(is_sorted, Some(Precision::Exact(true)));
        })
    }

    // Min/Max of a Utf8 chunk longer than the default limit must come back truncated and
    // marked inexact.
    #[test]
    fn truncates_variable_size_stats() {
        block_on(|handle| async {
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();

            let mut strings =
                VarBinViewBuilder::with_capacity(DType::Utf8(Nullability::NonNullable), 2);
            strings.append_value("Long value to test that the statistics are actually truncated, it needs a bit of extra padding though");
            strings.append_value("Another string that's meant to be smaller than the previous value, though still need extra padding");
            let input = strings.finish();

            // Pre-compute every statistic so that the writer has something to truncate.
            let every_stat: Vec<_> = Stat::all().collect();
            let computed = input
                .statistics()
                .compute_all(&every_stat)
                .vortex_unwrap();
            input.statistics().set_iter(computed.into_iter());

            let chunks = input.to_array_stream().sequenced(ptr);
            let strategy = FlatLayoutStrategy::default();
            let layout = strategy
                .write_stream(ArrayContext::empty(), segments.clone(), chunks, eof, handle)
                .await
                .unwrap();

            // Read every row back through the layout reader.
            let rows = 0..layout.row_count();
            let expr = root();
            let mask = MaskFuture::new_true(layout.row_count().try_into().unwrap());
            let reader = layout.new_reader("".into(), segments).unwrap();
            let result = reader
                .projection_evaluation(&rows, &expr, mask)
                .unwrap()
                .await
                .unwrap();

            let min = result.statistics().get_as::<String>(Stat::Min);
            assert_eq!(
                min,
                // The typo is correct, we need this to be truncated.
                Some(Precision::Inexact(
                    // spellchecker:ignore-next-line
                    "Another string that's meant to be smaller than the previous valu".to_string()
                ))
            );

            let max = result.statistics().get_as::<String>(Stat::Max);
            assert_eq!(
                max,
                Some(Precision::Inexact(
                    "Long value to test that the statistics are actually truncated, j".to_string()
                ))
            );
        })
    }

    // A nullable struct array must survive a write followed by a read unchanged.
    #[test]
    fn struct_array_round_trip() {
        block_on(|handle| async {
            // Validity bits: first row valid, second row null.
            let mut bits = BitBufferMut::with_capacity(2);
            bits.append(true);
            bits.append(false);
            let validity_bits = bits.freeze();

            let validity_array =
                BoolArray::from_bit_buffer(validity_bits.clone(), Validity::NonNullable)
                    .into_array();
            let field_names = FieldNames::from([FieldName::from("a"), FieldName::from("b")]);
            let fields = vec![
                buffer![1_u64, 2].into_array(),
                buffer![3_u64, 4].into_array(),
            ];
            let input =
                StructArray::try_new(field_names, fields, 2, Validity::Array(validity_array))
                    .unwrap();

            // Write the array into the in-memory test segments.
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let chunks = input.to_array_stream().sequenced(ptr);
            let strategy = FlatLayoutStrategy::default();
            let layout = strategy
                .write_stream(ArrayContext::empty(), segments.clone(), chunks, eof, handle)
                .await
                .unwrap();

            // We should be able to read the array we just wrote.
            let rows = 0..layout.row_count();
            let expr = root();
            let mask = MaskFuture::new_true(layout.row_count().try_into().unwrap());
            let reader = layout.new_reader("".into(), segments).unwrap();
            let result: ArrayRef = reader
                .projection_evaluation(&rows, &expr, mask)
                .unwrap()
                .await
                .unwrap();

            assert_eq!(
                result.validity_mask().bit_buffer(),
                AllOr::Some(&validity_bits)
            );

            for (name, expected) in [("a", [1_u64, 2]), ("b", [3, 4])] {
                let column = result
                    .to_struct()
                    .field_by_name(name)
                    .unwrap()
                    .to_primitive();
                assert_eq!(column.as_slice::<u64>(), &expected);
            }
        })
    }
}