Skip to main content

vortex_tensor/fixed_shape/vtable.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex::dtype::DType;
5use vortex::dtype::extension::ExtDType;
6use vortex::dtype::extension::ExtId;
7use vortex::dtype::extension::ExtVTable;
8use vortex::error::VortexResult;
9use vortex::error::vortex_bail;
10use vortex::error::vortex_ensure;
11use vortex::error::vortex_ensure_eq;
12use vortex::scalar::ScalarValue;
13
14use crate::fixed_shape::FixedShapeTensor;
15use crate::fixed_shape::FixedShapeTensorMetadata;
16use crate::fixed_shape::proto;
17
18impl ExtVTable for FixedShapeTensor {
19    type Metadata = FixedShapeTensorMetadata;
20
21    // TODO(connor): This is just a placeholder for now!!!
22    type NativeValue<'a> = &'a ScalarValue;
23
24    fn id(&self) -> ExtId {
25        ExtId::new_ref("vortex.fixed_shape_tensor")
26    }
27
28    fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult<Vec<u8>> {
29        Ok(proto::serialize(metadata))
30    }
31
32    fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult<Self::Metadata> {
33        proto::deserialize(metadata)
34    }
35
36    fn validate_dtype(&self, ext_dtype: &ExtDType<Self>) -> VortexResult<()> {
37        let storage_dtype = ext_dtype.storage_dtype();
38        let DType::FixedSizeList(element_dtype, list_size, _nullability) = storage_dtype else {
39            vortex_bail!(
40                "FixedShapeTensor storage dtype must be a FixedSizeList, got {storage_dtype}"
41            );
42        };
43
44        // Note that these constraints may be relaxed in the future.
45        vortex_ensure!(
46            element_dtype.is_primitive(),
47            "FixedShapeTensor element dtype must be primitive, got {element_dtype} \
48             (may change in the future)"
49        );
50        vortex_ensure!(
51            !element_dtype.is_nullable(),
52            "FixedShapeTensor element dtype must be non-nullable (may change in the future)"
53        );
54
55        let element_count: usize = ext_dtype.metadata().logical_shape().iter().product();
56        vortex_ensure_eq!(
57            element_count,
58            *list_size as usize,
59            "FixedShapeTensor logical shape product ({element_count}) does not match \
60             FixedSizeList size ({list_size})"
61        );
62
63        Ok(())
64    }
65
66    fn unpack_native<'a>(
67        &self,
68        _ext_dtype: &'a ExtDType<Self>,
69        storage_value: &'a ScalarValue,
70    ) -> VortexResult<Self::NativeValue<'a>> {
71        // TODO(connor): This is just a placeholder. However, even if we have a dedicated native
72        // type for a singular tensor, we do not need to validate anything as any backing memory
73        // should be valid for a given tensor.
74        Ok(storage_value)
75    }
76}
77
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex::dtype::extension::ExtVTable;
    use vortex::error::VortexResult;

    use crate::fixed_shape::FixedShapeTensor;
    use crate::fixed_shape::FixedShapeTensorMetadata;

    /// Encodes `original` with the vtable's metadata serializer, decodes the bytes again, and
    /// asserts that the decoded value equals the input.
    fn check_proto_roundtrip(original: &FixedShapeTensorMetadata) -> VortexResult<()> {
        let codec = FixedShapeTensor;
        let encoded = codec.serialize_metadata(original)?;
        let decoded = codec.deserialize_metadata(&encoded)?;
        assert_eq!(&decoded, original);
        Ok(())
    }

    /// Metadata built from a shape alone (including the 0-d case) survives a round trip.
    #[rstest]
    #[case::scalar_0d(FixedShapeTensorMetadata::new(vec![]))]
    #[case::shape_only(FixedShapeTensorMetadata::new(vec![2, 3, 4]))]
    fn roundtrip_simple(#[case] metadata: FixedShapeTensorMetadata) -> VortexResult<()> {
        check_proto_roundtrip(&metadata)
    }

    /// Metadata carrying a permutation and/or dimension names survives a round trip.
    #[rstest]
    #[case::with_permutation(
        FixedShapeTensorMetadata::new(vec![2, 3, 4])
            .with_permutation(vec![2, 0, 1])
    )]
    #[case::with_dim_names(
        FixedShapeTensorMetadata::new(vec![3, 4])
            .with_dim_names(vec!["rows".into(), "cols".into()])
    )]
    #[case::all_fields(
        FixedShapeTensorMetadata::new(vec![2, 3, 4])
            .with_dim_names(vec!["x".into(), "y".into(), "z".into()])
            .and_then(|m| m.with_permutation(vec![1, 2, 0]))
    )]
    fn roundtrip_with_options(
        #[case] metadata: VortexResult<FixedShapeTensorMetadata>,
    ) -> VortexResult<()> {
        // Building the metadata is fallible, so surface construction errors first.
        let built = metadata?;
        check_proto_roundtrip(&built)
    }
}
122}