vortex_layout/layouts/flat/
writer.rs1use futures::StreamExt;
5use vortex_array::serde::SerializeOptions;
6use vortex_array::stats::{Precision, Stat, StatsProvider};
7use vortex_array::{Array, ArrayContext};
8use vortex_dtype::DType;
9use vortex_error::vortex_bail;
10use vortex_scalar::{BinaryScalar, Utf8Scalar};
11
12use crate::layouts::flat::FlatLayout;
13use crate::layouts::zoned::{lower_bound, upper_bound};
14use crate::segments::SequenceWriter;
15use crate::{IntoLayout, LayoutStrategy, SendableLayoutFuture, SendableSequentialStream};
16
/// Layout strategy that writes a stream consisting of a single array chunk as one
/// [`FlatLayout`] segment.
#[derive(Clone, Debug)]
pub struct FlatLayoutStrategy {
    /// Forwarded as `SerializeOptions::include_padding` when the chunk is serialized.
    pub include_padding: bool,
    /// Length cap applied to the `Min`/`Max` statistics of `Utf8` and `Binary` chunks before
    /// they are written; a bound that exceeds it is truncated and stored as inexact.
    pub max_variable_length_statistics_size: usize,
}
24
25impl Default for FlatLayoutStrategy {
26 fn default() -> Self {
27 Self {
28 include_padding: true,
29 max_variable_length_statistics_size: 64,
30 }
31 }
32}
33
34impl LayoutStrategy for FlatLayoutStrategy {
35 fn write_stream(
36 &self,
37 ctx: &ArrayContext,
38 sequence_writer: SequenceWriter,
39 mut stream: SendableSequentialStream,
40 ) -> SendableLayoutFuture {
41 let ctx = ctx.clone();
42 let options = self.clone();
43 Box::pin(async move {
44 let Some(chunk) = stream.next().await else {
45 vortex_bail!("flat layout needs a single chunk");
46 };
47 let (sequence_id, chunk) = chunk?;
48
49 let row_count = chunk.len() as u64;
50
51 match chunk.dtype() {
52 DType::Utf8(_) => {
53 if let Some(sv) = chunk.statistics().get(Stat::Min) {
54 let (value, truncated) = lower_bound::<Utf8Scalar>(
55 chunk.dtype(),
56 sv.into_inner(),
57 options.max_variable_length_statistics_size,
58 )?;
59 if truncated {
60 chunk.statistics().set(Stat::Min, Precision::Inexact(value));
61 }
62 }
63
64 if let Some(sv) = chunk.statistics().get(Stat::Max) {
65 let (value, truncated) = upper_bound::<Utf8Scalar>(
66 chunk.dtype(),
67 sv.into_inner(),
68 options.max_variable_length_statistics_size,
69 )?;
70 if let Some(upper_bound) = value {
71 if truncated {
72 chunk
73 .statistics()
74 .set(Stat::Max, Precision::Inexact(upper_bound));
75 }
76 } else {
77 chunk.statistics().clear(Stat::Max)
78 }
79 }
80 }
81 DType::Binary(_) => {
82 if let Some(sv) = chunk.statistics().get(Stat::Min) {
83 let (value, truncated) = lower_bound::<BinaryScalar>(
84 chunk.dtype(),
85 sv.into_inner(),
86 options.max_variable_length_statistics_size,
87 )?;
88 if truncated {
89 chunk.statistics().set(Stat::Min, Precision::Inexact(value));
90 }
91 }
92
93 if let Some(sv) = chunk.statistics().get(Stat::Max) {
94 let (value, truncated) = upper_bound::<BinaryScalar>(
95 chunk.dtype(),
96 sv.into_inner(),
97 options.max_variable_length_statistics_size,
98 )?;
99 if let Some(upper_bound) = value {
100 if truncated {
101 chunk
102 .statistics()
103 .set(Stat::Max, Precision::Inexact(upper_bound));
104 }
105 } else {
106 chunk.statistics().clear(Stat::Max)
107 }
108 }
109 }
110 _ => {}
111 }
112
113 let buffers = chunk.serialize(
115 &ctx,
116 &SerializeOptions {
117 offset: 0,
118 include_padding: options.include_padding,
119 },
120 )?;
121 let segment_id = sequence_writer.put(sequence_id, buffers).await?;
122
123 let None = stream.next().await else {
124 vortex_bail!("flat layout received stream with more than a single chunk");
125 };
126 Ok(
127 FlatLayout::new(row_count, stream.dtype().clone(), segment_id, ctx.clone())
128 .into_layout(),
129 )
130 })
131 }
132}
133
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_buffer::BooleanBufferBuilder;
    use futures::executor::block_on;
    use futures::stream;
    use vortex_array::arrays::{BoolArray, PrimitiveArray, StructArray};
    use vortex_array::builders::{ArrayBuilder, VarBinViewBuilder};
    use vortex_array::stats::{Precision, Stat};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayContext, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};
    use vortex_error::VortexUnwrap;
    use vortex_expr::root;
    use vortex_mask::{AllOr, Mask};

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::segments::{SegmentSource, SequenceWriter, TestSegments};
    use crate::sequence::SequenceId;
    use crate::{
        LayoutStrategy, SendableSequentialStream, SequentialStreamAdapter, SequentialStreamExt as _,
    };

    /// Wraps `array` in a sequential stream that yields it once under the root sequence id.
    fn stream_only(array: ArrayRef) -> SendableSequentialStream {
        SequentialStreamAdapter::new(
            array.dtype().clone(),
            stream::once(async move { Ok((SequenceId::root().downgrade(), array)) }),
        )
        .sendable()
    }

    /// Writes `array` with the default [`FlatLayoutStrategy`] into in-memory test segments and
    /// evaluates the root projection over every row of the resulting layout.
    async fn write_and_read_back(array: ArrayRef) -> ArrayRef {
        let ctx = ArrayContext::empty();
        let segments = TestSegments::default();
        let sequence_writer = SequenceWriter::new(Box::new(segments.clone()));
        let layout = FlatLayoutStrategy::default()
            .write_stream(&ctx, sequence_writer, stream_only(array))
            .await
            .unwrap();
        let segments: Arc<dyn SegmentSource> = Arc::new(segments);

        // Bound to a local so the chain's temporaries are dropped before `layout` goes away.
        let result: ArrayRef = layout
            .new_reader("".into(), segments)
            .unwrap()
            .projection_evaluation(&(0..layout.row_count()), &root())
            .unwrap()
            .invoke(Mask::new_true(layout.row_count().try_into().unwrap()))
            .await
            .unwrap();
        result
    }

    // NOTE(review): `#[should_panic]` marks this test as an expected failure. The reason is not
    // visible in this file (presumably `IsSorted` is not preserved through the round trip).
    // Confirm, and consider pinning it with `expected = "..."`.
    #[should_panic]
    #[test]
    fn flat_stats() {
        block_on(async {
            let array = PrimitiveArray::new(buffer![1, 2, 3, 4, 5], Validity::AllValid);
            let result = write_and_read_back(array.to_array()).await;

            assert_eq!(
                result.statistics().get_as::<bool>(Stat::IsSorted),
                Some(Precision::Exact(true))
            );
        })
    }

    #[test]
    fn truncates_variable_size_stats() {
        block_on(async {
            let mut builder =
                VarBinViewBuilder::with_capacity(DType::Utf8(Nullability::NonNullable), 2);
            builder.append_value("Long value to test that the statistics are actually truncated, it needs a bit of extra padding though");
            builder.append_value("Another string that's meant to be smaller than the previous value, though still need extra padding");
            let array = builder.finish();
            // Compute every statistic up front so the writer has Min/Max values to truncate.
            array.statistics().set_iter(
                array
                    .statistics()
                    .compute_all(&Stat::all().collect::<Vec<_>>())
                    .vortex_unwrap()
                    .into_iter(),
            );

            let result = write_and_read_back(array.to_array()).await;

            assert_eq!(
                result.statistics().get_as::<String>(Stat::Min),
                Some(Precision::Inexact(
                    "Another string that's meant to be smaller than the previous valu".to_string()
                ))
            );
            assert_eq!(
                result.statistics().get_as::<String>(Stat::Max),
                Some(Precision::Inexact(
                    "Long value to test that the statistics are actually truncated, j".to_string()
                ))
            );
        })
    }

    #[test]
    fn struct_array_round_trip() {
        block_on(async {
            let mut validity_builder = BooleanBufferBuilder::new(2);
            validity_builder.append(true);
            validity_builder.append(false);
            let validity_boolean_buffer = validity_builder.finish();
            let validity = Validity::Array(
                BoolArray::new(validity_boolean_buffer.clone(), Validity::NonNullable).into_array(),
            );
            let array = StructArray::try_new(
                FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
                vec![
                    buffer![1_u64, 2].into_array(),
                    buffer![3_u64, 4].into_array(),
                ],
                2,
                validity,
            )
            .unwrap();

            let result = write_and_read_back(array.to_array()).await;

            assert_eq!(
                result.validity_mask().unwrap().boolean_buffer(),
                AllOr::Some(&validity_boolean_buffer)
            );
            assert_eq!(
                result
                    .to_struct()
                    .unwrap()
                    .field_by_name("a")
                    .unwrap()
                    .to_primitive()
                    .unwrap()
                    .as_slice::<u64>(),
                &[1, 2]
            );
            assert_eq!(
                result
                    .to_struct()
                    .unwrap()
                    .field_by_name("b")
                    .unwrap()
                    .to_primitive()
                    .unwrap()
                    .as_slice::<u64>(),
                &[3, 4]
            );
        })
    }
}