Skip to main content

vortex_tensor/fixed_shape/
vtable.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex::dtype::DType;
5use vortex::dtype::extension::ExtDType;
6use vortex::dtype::extension::ExtId;
7use vortex::dtype::extension::ExtVTable;
8use vortex::error::VortexResult;
9use vortex::error::vortex_bail;
10use vortex::error::vortex_ensure;
11use vortex::error::vortex_ensure_eq;
12use vortex::scalar::ScalarValue;
13
14use crate::fixed_shape::FixedShapeTensor;
15use crate::fixed_shape::FixedShapeTensorMetadata;
16use crate::fixed_shape::proto;
17
18impl ExtVTable for FixedShapeTensor {
19    type Metadata = FixedShapeTensorMetadata;
20
21    // TODO(connor): This is just a placeholder for now!!!
22    type NativeValue<'a> = &'a ScalarValue;
23
24    fn id(&self) -> ExtId {
25        ExtId::new_ref("vortex.tensor.fixed_shape_tensor")
26    }
27
28    fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult<Vec<u8>> {
29        Ok(proto::serialize(metadata))
30    }
31
32    fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult<Self::Metadata> {
33        proto::deserialize(metadata)
34    }
35
36    fn validate_dtype(ext_dtype: &ExtDType<Self>) -> VortexResult<()> {
37        let storage_dtype = ext_dtype.storage_dtype();
38        let DType::FixedSizeList(element_dtype, list_size, _nullability) = storage_dtype else {
39            vortex_bail!(
40                "FixedShapeTensor storage dtype must be a FixedSizeList, got {storage_dtype}"
41            );
42        };
43
44        // Note that these constraints may be relaxed in the future.
45        vortex_ensure!(
46            element_dtype.is_primitive(),
47            "FixedShapeTensor element dtype must be primitive, got {element_dtype} \
48             (may change in the future)"
49        );
50        vortex_ensure!(
51            !element_dtype.is_nullable(),
52            "FixedShapeTensor element dtype must be non-nullable (may change in the future)"
53        );
54
55        let element_count: usize = ext_dtype.metadata().logical_shape().iter().product();
56        vortex_ensure_eq!(
57            element_count,
58            *list_size as usize,
59            "FixedShapeTensor logical shape product ({element_count}) does not match \
60             FixedSizeList size ({list_size})"
61        );
62
63        Ok(())
64    }
65
66    fn unpack_native<'a>(
67        _ext_dtype: &'a ExtDType<Self>,
68        storage_value: &'a ScalarValue,
69    ) -> VortexResult<Self::NativeValue<'a>> {
70        // TODO(connor): This is just a placeholder. However, even if we have a dedicated native
71        // type for a singular tensor, we do not need to validate anything as any backing memory
72        // should be valid for a given tensor.
73        Ok(storage_value)
74    }
75}
76
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex::dtype::extension::ExtVTable;
    use vortex::error::VortexResult;

    use crate::fixed_shape::FixedShapeTensor;
    use crate::fixed_shape::FixedShapeTensorMetadata;

    /// Encodes `expected` with the vtable's protobuf metadata codec, decodes it again, and
    /// asserts the decoded value equals the input.
    fn assert_roundtrip(expected: &FixedShapeTensorMetadata) -> VortexResult<()> {
        let encoded = FixedShapeTensor.serialize_metadata(expected)?;
        let decoded = FixedShapeTensor.deserialize_metadata(&encoded)?;
        assert_eq!(expected, &decoded);
        Ok(())
    }

    #[rstest]
    #[case::scalar_0d(FixedShapeTensorMetadata::new(vec![]))]
    #[case::shape_only(FixedShapeTensorMetadata::new(vec![2, 3, 4]))]
    fn roundtrip_simple(#[case] input: FixedShapeTensorMetadata) -> VortexResult<()> {
        assert_roundtrip(&input)
    }

    #[rstest]
    #[case::with_permutation(
        FixedShapeTensorMetadata::new(vec![2, 3, 4])
            .with_permutation(vec![2, 0, 1])
    )]
    #[case::with_dim_names(
        FixedShapeTensorMetadata::new(vec![3, 4])
            .with_dim_names(vec!["rows".into(), "cols".into()])
    )]
    #[case::all_fields(
        FixedShapeTensorMetadata::new(vec![2, 3, 4])
            .with_dim_names(vec!["x".into(), "y".into(), "z".into()])
            .and_then(|m| m.with_permutation(vec![1, 2, 0]))
    )]
    fn roundtrip_with_options(
        #[case] built: VortexResult<FixedShapeTensorMetadata>,
    ) -> VortexResult<()> {
        // The builder methods are fallible; surface their error before round-tripping.
        let input = built?;
        assert_roundtrip(&input)
    }
}