vortex_layout/layouts/flat/
writer.rs1use async_trait::async_trait;
5use futures::StreamExt;
6use vortex_array::serde::SerializeOptions;
7use vortex_array::stats::{Precision, Stat, StatsProvider};
8use vortex_array::{Array, ArrayContext};
9use vortex_dtype::DType;
10use vortex_error::{VortexResult, vortex_bail};
11use vortex_scalar::{BinaryScalar, Utf8Scalar};
12
13use crate::layouts::flat::FlatLayout;
14use crate::layouts::zoned::{lower_bound, upper_bound};
15use crate::segments::SequenceWriter;
16use crate::{IntoLayout, LayoutRef, LayoutStrategy, SendableSequentialStream};
17
/// Configuration for writing a single chunk as a flat layout.
///
/// Both fields are read by `write_stream`: one is forwarded to the serializer, the other
/// limits how long the Min/Max statistics of variable-length chunks may be.
#[derive(Debug, Clone)]
pub struct FlatLayoutStrategy {
    /// Forwarded as `SerializeOptions::include_padding` when the chunk is serialized.
    pub include_padding: bool,
    /// Maximum length handed to `lower_bound` / `upper_bound` when bounding the `Min` and
    /// `Max` statistics of `Utf8` and `Binary` chunks.
    pub max_variable_length_statistics_size: usize,
}
25
26impl Default for FlatLayoutStrategy {
27 fn default() -> Self {
28 Self {
29 include_padding: true,
30 max_variable_length_statistics_size: 64,
31 }
32 }
33}
34
35#[async_trait]
36impl LayoutStrategy for FlatLayoutStrategy {
37 async fn write_stream(
38 &self,
39 ctx: &ArrayContext,
40 sequence_writer: SequenceWriter,
41 mut stream: SendableSequentialStream,
42 ) -> VortexResult<LayoutRef> {
43 let ctx = ctx.clone();
44 let options = self.clone();
45 let Some(chunk) = stream.next().await else {
46 vortex_bail!("flat layout needs a single chunk");
47 };
48 let (sequence_id, chunk) = chunk?;
49
50 let row_count = chunk.len() as u64;
51
52 match chunk.dtype() {
53 DType::Utf8(_) => {
54 if let Some(sv) = chunk.statistics().get(Stat::Min) {
55 let (value, truncated) = lower_bound::<Utf8Scalar>(
56 sv.into_inner(),
57 options.max_variable_length_statistics_size,
58 )?;
59 if truncated {
60 chunk
61 .statistics()
62 .set(Stat::Min, Precision::Inexact(value.into_value()));
63 }
64 }
65
66 if let Some(sv) = chunk.statistics().get(Stat::Max) {
67 let (value, truncated) = upper_bound::<Utf8Scalar>(
68 sv.into_inner(),
69 options.max_variable_length_statistics_size,
70 )?;
71 if let Some(upper_bound) = value {
72 if truncated {
73 chunk
74 .statistics()
75 .set(Stat::Max, Precision::Inexact(upper_bound.into_value()));
76 }
77 } else {
78 chunk.statistics().clear(Stat::Max)
79 }
80 }
81 }
82 DType::Binary(_) => {
83 if let Some(sv) = chunk.statistics().get(Stat::Min) {
84 let (value, truncated) = lower_bound::<BinaryScalar>(
85 sv.into_inner(),
86 options.max_variable_length_statistics_size,
87 )?;
88 if truncated {
89 chunk
90 .statistics()
91 .set(Stat::Min, Precision::Inexact(value.into_value()));
92 }
93 }
94
95 if let Some(sv) = chunk.statistics().get(Stat::Max) {
96 let (value, truncated) = upper_bound::<BinaryScalar>(
97 sv.into_inner(),
98 options.max_variable_length_statistics_size,
99 )?;
100 if let Some(upper_bound) = value {
101 if truncated {
102 chunk
103 .statistics()
104 .set(Stat::Max, Precision::Inexact(upper_bound.into_value()));
105 }
106 } else {
107 chunk.statistics().clear(Stat::Max)
108 }
109 }
110 }
111 _ => {}
112 }
113
114 let buffers = chunk.serialize(
116 &ctx,
117 &SerializeOptions {
118 offset: 0,
119 include_padding: options.include_padding,
120 },
121 )?;
122 let segment_id = sequence_writer.put(sequence_id, buffers).await?;
123
124 let None = stream.next().await else {
125 vortex_bail!("flat layout received stream with more than a single chunk");
126 };
127 Ok(
128 FlatLayout::new(row_count, stream.dtype().clone(), segment_id, ctx.clone())
129 .into_layout(),
130 )
131 }
132}
133
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_buffer::BooleanBufferBuilder;
    use futures::executor::block_on;
    use futures::stream;
    use vortex_array::arrays::{BoolArray, PrimitiveArray, StructArray};
    use vortex_array::builders::{ArrayBuilder, VarBinViewBuilder};
    use vortex_array::stats::{Precision, Stat, StatsProviderExt};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayContext, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};
    use vortex_error::VortexUnwrap;
    use vortex_expr::root;
    use vortex_mask::AllOr;

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::segments::{SegmentSource, SequenceWriter, TestSegments};
    use crate::sequence::SequenceId;
    use crate::{
        LayoutStrategy, MaskFuture, SendableSequentialStream, SequentialStreamAdapter,
        SequentialStreamExt as _,
    };

    /// Wraps `array` in a sequential stream that yields exactly that one chunk.
    fn stream_only(array: ArrayRef) -> SendableSequentialStream {
        let dtype = array.dtype().clone();
        SequentialStreamAdapter::new(
            dtype,
            stream::once(async move { Ok((SequenceId::root().downgrade(), array)) }),
        )
        .sendable()
    }

    /// Writes `array` with the default [`FlatLayoutStrategy`] into in-memory test segments,
    /// then reads every row back through the resulting layout with the root expression.
    async fn write_then_read(array: ArrayRef) -> ArrayRef {
        let ctx = ArrayContext::empty();
        let store = TestSegments::default();
        let writer = SequenceWriter::new(Box::new(store.clone()));
        let layout = FlatLayoutStrategy::default()
            .write_stream(&ctx, writer, stream_only(array))
            .await
            .unwrap();
        let source: Arc<dyn SegmentSource> = Arc::new(store);

        // Bound to a local so the temporaries of the reader chain are dropped before
        // `layout` goes out of scope.
        let read_back: ArrayRef = layout
            .new_reader("".into(), source)
            .unwrap()
            .projection_evaluation(&(0..layout.row_count()), &root())
            .unwrap()
            .invoke(MaskFuture::new_true(layout.row_count().try_into().unwrap()))
            .await
            .unwrap();
        read_back
    }

    // NOTE(review): this test is marked `should_panic`, so it only passes while the body
    // panics — presumably the `IsSorted` statistic is not present on the array read back
    // from a flat layout; confirm before removing the attribute.
    #[should_panic]
    #[test]
    fn flat_stats() {
        block_on(async {
            let input = PrimitiveArray::new(buffer![1, 2, 3, 4, 5], Validity::AllValid);
            let read_back = write_then_read(input.to_array()).await;

            let is_sorted = read_back.statistics().get_as::<bool>(Stat::IsSorted);
            assert_eq!(is_sorted, Some(Precision::Exact(true)));
        })
    }

    /// Long string statistics must come back truncated and marked inexact.
    #[test]
    fn truncates_variable_size_stats() {
        block_on(async {
            let mut strings =
                VarBinViewBuilder::with_capacity(DType::Utf8(Nullability::NonNullable), 2);
            strings.append_value("Long value to test that the statistics are actually truncated, it needs a bit of extra padding though");
            strings.append_value("Another string that's meant to be smaller than the previous value, though still need extra padding");
            let input = strings.finish();

            // Compute every statistic up front and store it on the array before writing.
            let every_stat = Stat::all().collect::<Vec<_>>();
            let computed = input
                .statistics()
                .compute_all(&every_stat)
                .vortex_unwrap();
            input.statistics().set_iter(computed.into_iter());

            let read_back = write_then_read(input.to_array()).await;

            let expected_min =
                "Another string that's meant to be smaller than the previous valu".to_string();
            let expected_max =
                "Long value to test that the statistics are actually truncated, j".to_string();
            assert_eq!(
                read_back.statistics().get_as::<String>(Stat::Min),
                Some(Precision::Inexact(expected_min))
            );
            assert_eq!(
                read_back.statistics().get_as::<String>(Stat::Max),
                Some(Precision::Inexact(expected_max))
            );
        })
    }

    /// A nullable struct array keeps its validity and field values across write and read.
    #[test]
    fn struct_array_round_trip() {
        block_on(async {
            let mut validity_bits = BooleanBufferBuilder::new(2);
            for bit in [true, false] {
                validity_bits.append(bit);
            }
            let validity_boolean_buffer = validity_bits.finish();
            let validity = Validity::Array(
                BoolArray::new(validity_boolean_buffer.clone(), Validity::NonNullable).into_array(),
            );

            let names = FieldNames::from([FieldName::from("a"), FieldName::from("b")]);
            let fields = vec![
                buffer![1_u64, 2].into_array(),
                buffer![3_u64, 4].into_array(),
            ];
            let input = StructArray::try_new(names, fields, 2, validity).unwrap();

            let read_back: ArrayRef = write_then_read(input.to_array()).await;

            assert_eq!(
                read_back.validity_mask().boolean_buffer(),
                AllOr::Some(&validity_boolean_buffer)
            );
            for (field, expected) in [("a", [1_u64, 2]), ("b", [3_u64, 4])] {
                assert_eq!(
                    read_back
                        .to_struct()
                        .field_by_name(field)
                        .unwrap()
                        .to_primitive()
                        .as_slice::<u64>(),
                    &expected
                );
            }
        })
    }
}