vortex_sequence/
serde.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::serde::ArrayChildren;
5use vortex_array::vtable::{EncodeVTable, SerdeVTable};
6use vortex_array::{Canonical, DeserializeMetadata, ProstMetadata};
7use vortex_buffer::ByteBuffer;
8use vortex_dtype::DType;
9use vortex_dtype::Nullability::NonNullable;
10use vortex_error::{VortexExpect, VortexResult, vortex_err};
11use vortex_proto::scalar::ScalarValue;
12use vortex_scalar::Scalar;
13
14use crate::array::{SequenceArray, SequenceEncoding, SequenceVTable};
15
16#[derive(Clone, prost::Message)]
17pub struct SequenceMetadata {
18    #[prost(message, tag = "1")]
19    base: Option<ScalarValue>,
20    #[prost(message, tag = "2")]
21    multiplier: Option<ScalarValue>,
22}
23
24impl EncodeVTable<SequenceVTable> for SequenceVTable {
25    fn encode(
26        _encoding: &SequenceEncoding,
27        _canonical: &Canonical,
28        _like: Option<&SequenceArray>,
29    ) -> VortexResult<Option<SequenceArray>> {
30        // TODO(joe): hook up compressor
31        Ok(None)
32    }
33}
34
35impl SerdeVTable<SequenceVTable> for SequenceVTable {
36    type Metadata = ProstMetadata<SequenceMetadata>;
37
38    fn metadata(array: &SequenceArray) -> VortexResult<Option<Self::Metadata>> {
39        Ok(Some(ProstMetadata(SequenceMetadata {
40            base: Some((&array.base()).into()),
41            multiplier: Some((&array.multiplier()).into()),
42        })))
43    }
44
45    fn build(
46        _encoding: &SequenceEncoding,
47        dtype: &DType,
48        len: usize,
49        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
50        _buffers: &[ByteBuffer],
51        _children: &dyn ArrayChildren,
52    ) -> VortexResult<SequenceArray> {
53        let ptype = dtype.as_ptype();
54
55        // We go via scalar to cast the scalar values into the correct PType
56        let base = Scalar::new(
57            DType::Primitive(ptype, NonNullable),
58            metadata
59                .base
60                .as_ref()
61                .ok_or_else(|| vortex_err!("base required"))?
62                .try_into()?,
63        )
64        .as_primitive()
65        .pvalue()
66        .vortex_expect("non-nullable primitive");
67
68        let multiplier = Scalar::new(
69            DType::Primitive(ptype, NonNullable),
70            metadata
71                .multiplier
72                .as_ref()
73                .ok_or_else(|| vortex_err!("base required"))?
74                .try_into()?,
75        )
76        .as_primitive()
77        .pvalue()
78        .vortex_expect("non-nullable primitive");
79
80        Ok(SequenceArray::unchecked_new(
81            base,
82            multiplier,
83            ptype,
84            dtype.nullability(),
85            len,
86        ))
87    }
88}
89
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arcref::ArcRef;
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::iter::ArrayIteratorExt;
    use vortex_dtype::Nullability;
    use vortex_expr::{get_item, root};
    use vortex_file::{VortexOpenOptions, VortexWriteOptions};
    use vortex_layout::layouts::flat::writer::FlatLayoutStrategy;

    use crate::SequenceArray;

    /// Writes a struct containing a sequence column to a Vortex file, reads the
    /// column back, and checks the canonical values match the expected sequence.
    #[tokio::test]
    async fn round_trip_seq() {
        let seq = SequenceArray::typed_new(2i8, 3, Nullability::NonNullable, 4).unwrap();
        let st = StructArray::from_fields(&[("a", seq.to_array())]).unwrap();

        // Use the platform temp directory with a per-process file name instead of
        // a fixed "/tmp/abc.vx": portable, and avoids clashing with other runs.
        let path = std::env::temp_dir()
            .join(format!("vortex_sequence_round_trip_{}.vx", std::process::id()));
        let path = path
            .to_str()
            .expect("temp file path should be valid UTF-8")
            .to_owned();

        let file = tokio::fs::File::create(path.as_str()).await.unwrap();
        VortexWriteOptions::default()
            .with_strategy(ArcRef::new_arc(Arc::new(FlatLayoutStrategy::default())))
            .write(file, st.to_array_stream())
            .await
            .unwrap();

        let file = VortexOpenOptions::file()
            .open(path.as_str())
            .await
            .unwrap();
        let array = file
            .scan()
            .unwrap()
            .with_projection(get_item("a", root()))
            .into_array_iter()
            .unwrap()
            .read_all()
            .unwrap();

        // Best-effort cleanup before asserting so a failed assertion does not
        // leave the file behind; a removal error is deliberately ignored.
        let _ = tokio::fs::remove_file(path.as_str()).await;

        let canon = PrimitiveArray::from_iter((0..4).map(|i| 2i8 + i * 3));

        assert_eq!(
            array
                .to_canonical()
                .unwrap()
                .into_primitive()
                .unwrap()
                .as_slice::<i8>(),
            canon.as_slice::<i8>()
        )
    }
}