vortex_tensor/fixed_shape/
vtable.rs1use vortex::dtype::DType;
5use vortex::dtype::extension::ExtDType;
6use vortex::dtype::extension::ExtId;
7use vortex::dtype::extension::ExtVTable;
8use vortex::error::VortexResult;
9use vortex::error::vortex_bail;
10use vortex::error::vortex_ensure;
11use vortex::error::vortex_ensure_eq;
12use vortex::scalar::ScalarValue;
13
14use crate::fixed_shape::FixedShapeTensor;
15use crate::fixed_shape::FixedShapeTensorMetadata;
16use crate::fixed_shape::proto;
17
18impl ExtVTable for FixedShapeTensor {
19 type Metadata = FixedShapeTensorMetadata;
20
21 type NativeValue<'a> = &'a ScalarValue;
23
24 fn id(&self) -> ExtId {
25 ExtId::new_ref("vortex.fixed_shape_tensor")
26 }
27
28 fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult<Vec<u8>> {
29 Ok(proto::serialize(metadata))
30 }
31
32 fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult<Self::Metadata> {
33 proto::deserialize(metadata)
34 }
35
36 fn validate_dtype(&self, ext_dtype: &ExtDType<Self>) -> VortexResult<()> {
37 let storage_dtype = ext_dtype.storage_dtype();
38 let DType::FixedSizeList(element_dtype, list_size, _nullability) = storage_dtype else {
39 vortex_bail!(
40 "FixedShapeTensor storage dtype must be a FixedSizeList, got {storage_dtype}"
41 );
42 };
43
44 vortex_ensure!(
46 element_dtype.is_primitive(),
47 "FixedShapeTensor element dtype must be primitive, got {element_dtype} \
48 (may change in the future)"
49 );
50 vortex_ensure!(
51 !element_dtype.is_nullable(),
52 "FixedShapeTensor element dtype must be non-nullable (may change in the future)"
53 );
54
55 let element_count: usize = ext_dtype.metadata().logical_shape().iter().product();
56 vortex_ensure_eq!(
57 element_count,
58 *list_size as usize,
59 "FixedShapeTensor logical shape product ({element_count}) does not match \
60 FixedSizeList size ({list_size})"
61 );
62
63 Ok(())
64 }
65
66 fn unpack_native<'a>(
67 &self,
68 _ext_dtype: &'a ExtDType<Self>,
69 storage_value: &'a ScalarValue,
70 ) -> VortexResult<Self::NativeValue<'a>> {
71 Ok(storage_value)
75 }
76}
77
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex::dtype::extension::ExtVTable;
    use vortex::error::VortexResult;

    use crate::fixed_shape::FixedShapeTensor;
    use crate::fixed_shape::FixedShapeTensorMetadata;

    /// Encodes `expected` through the extension vtable, decodes the bytes again, and asserts
    /// the decoded metadata equals the input.
    fn check_roundtrip(expected: &FixedShapeTensorMetadata) -> VortexResult<()> {
        let ext = FixedShapeTensor;
        let encoded = ext.serialize_metadata(expected)?;
        let decoded = ext.deserialize_metadata(&encoded)?;
        assert_eq!(expected, &decoded);
        Ok(())
    }

    /// Metadata that carries only a shape (including the zero-dimensional case).
    #[rstest]
    #[case::scalar_0d(FixedShapeTensorMetadata::new(vec![]))]
    #[case::shape_only(FixedShapeTensorMetadata::new(vec![2, 3, 4]))]
    fn roundtrip_simple(#[case] metadata: FixedShapeTensorMetadata) -> VortexResult<()> {
        check_roundtrip(&metadata)
    }

    /// Metadata built with the fallible builders; a builder error fails the test.
    #[rstest]
    #[case::with_permutation(
        FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![2, 0, 1])
    )]
    #[case::with_dim_names(
        FixedShapeTensorMetadata::new(vec![3, 4]).with_dim_names(vec!["rows".into(), "cols".into()])
    )]
    #[case::all_fields(
        FixedShapeTensorMetadata::new(vec![2, 3, 4])
            .with_dim_names(vec!["x".into(), "y".into(), "z".into()])
            .and_then(|m| m.with_permutation(vec![1, 2, 0]))
    )]
    fn roundtrip_with_options(
        #[case] metadata: VortexResult<FixedShapeTensorMetadata>,
    ) -> VortexResult<()> {
        metadata.and_then(|built| check_roundtrip(&built))
    }
}