vortex_layout/layouts/flat/
writer.rs1use async_trait::async_trait;
5use futures::StreamExt;
6use vortex_array::serde::SerializeOptions;
7use vortex_array::stats::{Precision, Stat, StatsProvider};
8use vortex_array::{Array, ArrayContext};
9use vortex_dtype::DType;
10use vortex_error::{VortexResult, vortex_bail};
11use vortex_scalar::{BinaryScalar, Utf8Scalar};
12
13use crate::layouts::flat::FlatLayout;
14use crate::layouts::zoned::{lower_bound, upper_bound};
15use crate::segments::SequenceWriter;
16use crate::{IntoLayout, LayoutRef, LayoutStrategy, SendableSequentialStream};
17
/// Layout strategy that writes a stream consisting of exactly one chunk as a single flat
/// segment.
#[derive(Clone, Debug)]
pub struct FlatLayoutStrategy {
    /// Forwarded to [`SerializeOptions::include_padding`] when the chunk is serialized.
    pub include_padding: bool,
    /// Maximum length handed to the `lower_bound` / `upper_bound` helpers when truncating the
    /// `Min` / `Max` statistics of `Utf8` and `Binary` chunks (presumably measured in bytes —
    /// confirm against the helpers).
    pub max_variable_length_statistics_size: usize,
}
25
26impl Default for FlatLayoutStrategy {
27 fn default() -> Self {
28 Self {
29 include_padding: true,
30 max_variable_length_statistics_size: 64,
31 }
32 }
33}
34
35#[async_trait]
36impl LayoutStrategy for FlatLayoutStrategy {
37 async fn write_stream(
38 &self,
39 ctx: &ArrayContext,
40 sequence_writer: SequenceWriter,
41 mut stream: SendableSequentialStream,
42 ) -> VortexResult<LayoutRef> {
43 let ctx = ctx.clone();
44 let options = self.clone();
45 let Some(chunk) = stream.next().await else {
46 vortex_bail!("flat layout needs a single chunk");
47 };
48 let (sequence_id, chunk) = chunk?;
49
50 let row_count = chunk.len() as u64;
51
52 match chunk.dtype() {
53 DType::Utf8(_) => {
54 if let Some(sv) = chunk.statistics().get(Stat::Min) {
55 let (value, truncated) = lower_bound::<Utf8Scalar>(
56 sv.into_inner(),
57 options.max_variable_length_statistics_size,
58 )?;
59 if truncated {
60 chunk
61 .statistics()
62 .set(Stat::Min, Precision::Inexact(value.into_value()));
63 }
64 }
65
66 if let Some(sv) = chunk.statistics().get(Stat::Max) {
67 let (value, truncated) = upper_bound::<Utf8Scalar>(
68 sv.into_inner(),
69 options.max_variable_length_statistics_size,
70 )?;
71 if let Some(upper_bound) = value {
72 if truncated {
73 chunk
74 .statistics()
75 .set(Stat::Max, Precision::Inexact(upper_bound.into_value()));
76 }
77 } else {
78 chunk.statistics().clear(Stat::Max)
79 }
80 }
81 }
82 DType::Binary(_) => {
83 if let Some(sv) = chunk.statistics().get(Stat::Min) {
84 let (value, truncated) = lower_bound::<BinaryScalar>(
85 sv.into_inner(),
86 options.max_variable_length_statistics_size,
87 )?;
88 if truncated {
89 chunk
90 .statistics()
91 .set(Stat::Min, Precision::Inexact(value.into_value()));
92 }
93 }
94
95 if let Some(sv) = chunk.statistics().get(Stat::Max) {
96 let (value, truncated) = upper_bound::<BinaryScalar>(
97 sv.into_inner(),
98 options.max_variable_length_statistics_size,
99 )?;
100 if let Some(upper_bound) = value {
101 if truncated {
102 chunk
103 .statistics()
104 .set(Stat::Max, Precision::Inexact(upper_bound.into_value()));
105 }
106 } else {
107 chunk.statistics().clear(Stat::Max)
108 }
109 }
110 }
111 _ => {}
112 }
113
114 let buffers = chunk.serialize(
116 &ctx,
117 &SerializeOptions {
118 offset: 0,
119 include_padding: options.include_padding,
120 },
121 )?;
122 let segment_id = sequence_writer.put(sequence_id, buffers).await?;
123
124 let None = stream.next().await else {
125 vortex_bail!("flat layout received stream with more than a single chunk");
126 };
127 Ok(
128 FlatLayout::new(row_count, stream.dtype().clone(), segment_id, ctx.clone())
129 .into_layout(),
130 )
131 }
132}
133
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_buffer::BooleanBufferBuilder;
    use futures::executor::block_on;
    use futures::stream;
    use vortex_array::arrays::{BoolArray, PrimitiveArray, StructArray};
    use vortex_array::builders::{ArrayBuilder, VarBinViewBuilder};
    use vortex_array::stats::{Precision, Stat, StatsProviderExt};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayContext, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};
    use vortex_error::VortexUnwrap;
    use vortex_expr::root;
    use vortex_mask::{AllOr, Mask};

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::segments::{SegmentSource, SequenceWriter, TestSegments};
    use crate::sequence::SequenceId;
    use crate::{
        LayoutStrategy, SendableSequentialStream, SequentialStreamAdapter, SequentialStreamExt as _,
    };

    /// Wraps `array` in a one-item sequential stream tagged with the root sequence id.
    fn stream_only(array: ArrayRef) -> SendableSequentialStream {
        SequentialStreamAdapter::new(
            array.dtype().clone(),
            stream::once(async move { Ok((SequenceId::root().downgrade(), array)) }),
        )
        .sendable()
    }

    /// Writes `array` with the default [`FlatLayoutStrategy`] into in-memory test segments.
    /// It then reads every row back through an identity (`root()`) projection.
    async fn write_and_read(array: ArrayRef) -> ArrayRef {
        let ctx = ArrayContext::empty();
        let segments = TestSegments::default();
        let sequence_writer = SequenceWriter::new(Box::new(segments.clone()));
        let layout = FlatLayoutStrategy::default()
            .write_stream(&ctx, sequence_writer, stream_only(array))
            .await
            .unwrap();
        let segments: Arc<dyn SegmentSource> = Arc::new(segments);

        // Bound to a local (not returned as the tail expression) so the reader temporaries,
        // which may borrow `layout`, are dropped before `layout` goes out of scope.
        let result = layout
            .new_reader("".into(), segments)
            .unwrap()
            .projection_evaluation(&(0..layout.row_count()), &root())
            .unwrap()
            .invoke(Mask::new_true(layout.row_count().try_into().unwrap()))
            .await
            .unwrap();
        result
    }

    // NOTE(review): this test is marked `should_panic`, so it passes only if the round trip does
    // NOT report `IsSorted == Exact(true)` (or if something else panics). The original
    // explanatory comment was lost. Confirm whether this documents a known gap in stats
    // propagation.
    #[should_panic]
    #[test]
    fn flat_stats() {
        block_on(async {
            let array = PrimitiveArray::new(buffer![1, 2, 3, 4, 5], Validity::AllValid);
            let result = write_and_read(array.to_array()).await;

            assert_eq!(
                result.statistics().get_as::<bool>(Stat::IsSorted),
                Some(Precision::Exact(true))
            );
        })
    }

    #[test]
    fn truncates_variable_size_stats() {
        block_on(async {
            let mut builder =
                VarBinViewBuilder::with_capacity(DType::Utf8(Nullability::NonNullable), 2);
            builder.append_value("Long value to test that the statistics are actually truncated, it needs a bit of extra padding though");
            builder.append_value("Another string that's meant to be smaller than the previous value, though still need extra padding");
            let array = builder.finish();
            array.statistics().set_iter(
                array
                    .statistics()
                    .compute_all(&Stat::all().collect::<Vec<_>>())
                    .vortex_unwrap()
                    .into_iter(),
            );

            let result = write_and_read(array.to_array()).await;

            assert_eq!(
                result.statistics().get_as::<String>(Stat::Min),
                Some(Precision::Inexact(
                    "Another string that's meant to be smaller than the previous valu".to_string()
                ))
            );
            assert_eq!(
                result.statistics().get_as::<String>(Stat::Max),
                Some(Precision::Inexact(
                    "Long value to test that the statistics are actually truncated, j".to_string()
                ))
            );
        })
    }

    #[test]
    fn struct_array_round_trip() {
        block_on(async {
            let mut validity_builder = BooleanBufferBuilder::new(2);
            validity_builder.append(true);
            validity_builder.append(false);
            let validity_boolean_buffer = validity_builder.finish();
            let validity = Validity::Array(
                BoolArray::new(validity_boolean_buffer.clone(), Validity::NonNullable).into_array(),
            );
            let array = StructArray::try_new(
                FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
                vec![
                    buffer![1_u64, 2].into_array(),
                    buffer![3_u64, 4].into_array(),
                ],
                2,
                validity,
            )
            .unwrap();

            let result = write_and_read(array.to_array()).await;

            assert_eq!(
                result.validity_mask().unwrap().boolean_buffer(),
                AllOr::Some(&validity_boolean_buffer)
            );

            let struct_array = result.to_struct().unwrap();
            assert_eq!(
                struct_array
                    .field_by_name("a")
                    .unwrap()
                    .to_primitive()
                    .unwrap()
                    .as_slice::<u64>(),
                &[1, 2]
            );
            assert_eq!(
                struct_array
                    .field_by_name("b")
                    .unwrap()
                    .to_primitive()
                    .unwrap()
                    .as_slice::<u64>(),
                &[3, 4]
            );
        })
    }
}