vortex_tensor/fixed_shape/
vtable.rs1use vortex::dtype::DType;
5use vortex::dtype::extension::ExtDType;
6use vortex::dtype::extension::ExtId;
7use vortex::dtype::extension::ExtVTable;
8use vortex::error::VortexResult;
9use vortex::error::vortex_bail;
10use vortex::error::vortex_ensure;
11use vortex::error::vortex_ensure_eq;
12use vortex::scalar::ScalarValue;
13
14use crate::fixed_shape::FixedShapeTensor;
15use crate::fixed_shape::FixedShapeTensorMetadata;
16use crate::fixed_shape::proto;
17
18impl ExtVTable for FixedShapeTensor {
19 type Metadata = FixedShapeTensorMetadata;
20
21 type NativeValue<'a> = &'a ScalarValue;
23
24 fn id(&self) -> ExtId {
25 ExtId::new_ref("vortex.tensor.fixed_shape_tensor")
26 }
27
28 fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult<Vec<u8>> {
29 Ok(proto::serialize(metadata))
30 }
31
32 fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult<Self::Metadata> {
33 proto::deserialize(metadata)
34 }
35
36 fn validate_dtype(ext_dtype: &ExtDType<Self>) -> VortexResult<()> {
37 let storage_dtype = ext_dtype.storage_dtype();
38 let DType::FixedSizeList(element_dtype, list_size, _nullability) = storage_dtype else {
39 vortex_bail!(
40 "FixedShapeTensor storage dtype must be a FixedSizeList, got {storage_dtype}"
41 );
42 };
43
44 vortex_ensure!(
46 element_dtype.is_primitive(),
47 "FixedShapeTensor element dtype must be primitive, got {element_dtype} \
48 (may change in the future)"
49 );
50 vortex_ensure!(
51 !element_dtype.is_nullable(),
52 "FixedShapeTensor element dtype must be non-nullable (may change in the future)"
53 );
54
55 let element_count: usize = ext_dtype.metadata().logical_shape().iter().product();
56 vortex_ensure_eq!(
57 element_count,
58 *list_size as usize,
59 "FixedShapeTensor logical shape product ({element_count}) does not match \
60 FixedSizeList size ({list_size})"
61 );
62
63 Ok(())
64 }
65
66 fn unpack_native<'a>(
67 _ext_dtype: &'a ExtDType<Self>,
68 storage_value: &'a ScalarValue,
69 ) -> VortexResult<Self::NativeValue<'a>> {
70 Ok(storage_value)
74 }
75}
76
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex::dtype::extension::ExtVTable;
    use vortex::error::VortexResult;

    use crate::fixed_shape::FixedShapeTensor;
    use crate::fixed_shape::FixedShapeTensorMetadata;

    /// Encodes `expected` through the vtable, decodes the resulting bytes, and asserts the
    /// decoded metadata is equal to the input.
    fn check_roundtrip(expected: &FixedShapeTensorMetadata) -> VortexResult<()> {
        let encoded = FixedShapeTensor.serialize_metadata(expected)?;
        let decoded = FixedShapeTensor.deserialize_metadata(&encoded)?;
        assert_eq!(&decoded, expected);
        Ok(())
    }

    /// Metadata carrying only a logical shape, including the 0-d (scalar) shape.
    #[rstest]
    #[case::scalar_0d(FixedShapeTensorMetadata::new(vec![]))]
    #[case::shape_only(FixedShapeTensorMetadata::new(vec![2, 3, 4]))]
    fn roundtrip_simple(#[case] metadata: FixedShapeTensorMetadata) -> VortexResult<()> {
        check_roundtrip(&metadata)
    }

    /// Metadata with optional permutation and/or dimension names. The builder methods are
    /// fallible, so each case arrives as a `VortexResult` that is unwrapped before checking.
    #[rstest]
    #[case::with_permutation(
        FixedShapeTensorMetadata::new(vec![2, 3, 4])
            .with_permutation(vec![2, 0, 1])
    )]
    #[case::with_dim_names(
        FixedShapeTensorMetadata::new(vec![3, 4])
            .with_dim_names(vec!["rows".into(), "cols".into()])
    )]
    #[case::all_fields(
        FixedShapeTensorMetadata::new(vec![2, 3, 4])
            .with_dim_names(vec!["x".into(), "y".into(), "z".into()])
            .and_then(|m| m.with_permutation(vec![1, 2, 0]))
    )]
    fn roundtrip_with_options(
        #[case] metadata: VortexResult<FixedShapeTensorMetadata>,
    ) -> VortexResult<()> {
        let built = metadata?;
        check_roundtrip(&built)
    }
}