Skip to main content

zarrs/array/codec/array_to_array/
transpose.rs

1//! The `transpose` array to array codec (Core).
2//!
3//! Permutes the dimensions of arrays.
4//!
5//! ### Compatible Implementations
6//! This is a core codec and should be compatible with all Zarr V3 implementations that support it.
7//!
8//! ### Specification
9//! - <https://zarr-specs.readthedocs.io/en/latest/v3/codecs/transpose/index.html>
10//! - <https://github.com/zarr-developers/zarr-extensions/tree/main/codecs/transpose>
11//!
12//! ### Codec `name` Aliases (Zarr V3)
13//! - `transpose`
14//!
15//! ### Codec `id` Aliases (Zarr V2)
16//! None
17//!
18//! ### Codec `configuration` Example - [`TransposeCodecConfiguration`]:
19//! ```rust
20//! # let JSON = r#"
21//! {
22//!     "order": [2, 1, 0]
23//! }
24//! # "#;
25//! # use zarrs::metadata_ext::codec::transpose::TransposeCodecConfiguration;
26//! # let configuration: TransposeCodecConfiguration = serde_json::from_str(JSON).unwrap();
27//! ```
28
29mod transpose_codec;
30mod transpose_codec_partial;
31
32use std::sync::Arc;
33
34pub use transpose_codec::TransposeCodec;
35use zarrs_metadata::v3::MetadataV3;
36use zarrs_plugin::ExtensionAliasesV3;
37
38use crate::array::{
39    ArrayBytes, ArrayBytesRaw, ArraySubset, ArraySubsetTraits, DataType, Indexer, IndexerError,
40};
41use zarrs_codec::{ArrayBytesOffsets, Codec, CodecError, CodecPluginV3, CodecTraitsV3};
42use zarrs_metadata::DataTypeSize;
43pub use zarrs_metadata_ext::codec::transpose::{
44    TransposeCodecConfiguration, TransposeCodecConfigurationV1, TransposeOrder, TransposeOrderError,
45};
46use zarrs_plugin::PluginCreateError;
47
48zarrs_plugin::impl_extension_aliases!(TransposeCodec, v3: "transpose");
49
50// Register the V3 codec.
51inventory::submit! {
52    CodecPluginV3::new::<TransposeCodec>()
53}
54
55impl CodecTraitsV3 for TransposeCodec {
56    fn create(metadata: &MetadataV3) -> Result<Codec, PluginCreateError> {
57        let configuration: TransposeCodecConfiguration = metadata.to_typed_configuration()?;
58        let codec = Arc::new(TransposeCodec::new_with_configuration(&configuration)?);
59        Ok(Codec::ArrayToArray(codec))
60    }
61}
62
/// Invert a permutation.
///
/// Given `order` (a permutation of `0..order.len()`), returns `inverse` such that
/// `inverse[order[i]] == i` for every position `i`. Applying `order` and then `inverse`
/// therefore restores the original axis arrangement.
///
/// # Panics
/// Panics if an entry of `order` is `>= order.len()`.
pub(crate) fn inverse_permutation(order: &[usize]) -> Vec<usize> {
    let mut inv = vec![0usize; order.len()];
    order
        .iter()
        .enumerate()
        .for_each(|(position, &axis)| inv[axis] = position);
    inv
}
74
75fn transpose_array(
76    transpose_order: &[usize],
77    untransposed_shape: &[u64],
78    bytes_per_element: usize,
79    data: &[u8],
80) -> Result<Vec<u8>, ndarray::ShapeError> {
81    // Create an array view of the data
82    let mut shape_n = Vec::with_capacity(untransposed_shape.len() + 1);
83    for size in untransposed_shape {
84        shape_n.push(usize::try_from(*size).unwrap());
85    }
86    shape_n.push(bytes_per_element);
87    let array = ndarray::ArrayViewD::<u8>::from_shape(shape_n, data)?;
88
89    // Transpose the data
90    let array_transposed = array.permuted_axes(transpose_order);
91    if array_transposed.is_standard_layout() {
92        Ok(array_transposed.to_owned().into_raw_vec_and_offset().0)
93    } else {
94        Ok(array_transposed
95            .as_standard_layout()
96            .into_owned()
97            .into_raw_vec_and_offset()
98            .0)
99    }
100}
101
/// Reorder `v` according to `order`, so that element `i` of the result is `v[order[i]]`.
///
/// Returns `None` when `v` and `order` have different lengths.
///
/// # Panics
/// Panics if an entry of `order` is out of bounds for `v`.
fn permute<T: Copy>(v: &[T], order: &[usize]) -> Option<Vec<T>> {
    (v.len() == order.len()).then(|| order.iter().map(|&axis| v[axis]).collect())
}
113
114fn transpose_vlen<'a>(
115    bytes: &ArrayBytesRaw,
116    offsets: &ArrayBytesOffsets,
117    shape: &[usize],
118    order: Vec<usize>,
119) -> ArrayBytes<'a> {
120    debug_assert_eq!(shape.len(), order.len());
121
122    // Get the transposed element indices
123    let ndarray_indices =
124        ndarray::ArrayD::from_shape_vec(shape, (0..shape.iter().product()).collect()).unwrap();
125    let ndarray_indices_transposed = ndarray_indices.permuted_axes(order);
126
127    // Collect the new bytes/offsets
128    let mut bytes_new = Vec::with_capacity(bytes.len());
129    let mut offsets_new = Vec::with_capacity(offsets.len());
130    for idx in &ndarray_indices_transposed {
131        offsets_new.push(bytes_new.len());
132        let curr = offsets[*idx];
133        let next = offsets[idx + 1];
134        bytes_new.extend_from_slice(&bytes[curr..next]);
135    }
136    offsets_new.push(bytes_new.len());
137    let offsets_new = unsafe {
138        // SAFETY: The offsets are monotonically increasing.
139        ArrayBytesOffsets::new_unchecked(offsets_new)
140    };
141    unsafe {
142        // SAFETY: The last offset is equal to the length of the bytes
143        ArrayBytes::new_vlen_unchecked(bytes_new, offsets_new)
144    }
145}
146
147fn get_transposed_array_subset(
148    order: &[usize],
149    decoded_region: &dyn ArraySubsetTraits,
150) -> Result<ArraySubset, CodecError> {
151    if decoded_region.dimensionality() != order.len() {
152        return Err(IndexerError::new_incompatible_dimensionality(
153            decoded_region.dimensionality(),
154            order.len(),
155        )
156        .into());
157    }
158
159    let start = permute(&decoded_region.start(), order).expect("matching dimensionality");
160    let size = permute(&decoded_region.shape(), order).expect("matching dimensionality");
161    let ranges = start.iter().zip(size).map(|(&st, si)| st..(st + si));
162    Ok(ArraySubset::from(ranges))
163}
164
165fn get_transposed_indexer(
166    order: &[usize],
167    indexer: &dyn Indexer,
168) -> Result<impl Indexer, CodecError> {
169    indexer
170        .iter_indices()
171        .map(|indices| permute(&indices, order))
172        .collect::<Option<Vec<_>>>()
173        .ok_or_else(|| {
174            IndexerError::new_incompatible_dimensionality(indexer.dimensionality(), order.len())
175                .into()
176        })
177}
178
179/// Apply a transpose permutation to array bytes.
180///
181/// # Arguments
182/// * `bytes` - The input array bytes to transpose
183/// * `input_shape` - The shape of the input array
184/// * `permutation` - The permutation order to apply
185/// * `data_type` - The data type of the array elements
186///
187/// The output shape will be `permute(input_shape, permutation)`.
188pub(crate) fn apply_permutation<'a>(
189    bytes: &ArrayBytes<'a>,
190    input_shape: &[u64],
191    permutation: &[usize],
192    data_type: &DataType,
193) -> Result<ArrayBytes<'a>, CodecError> {
194    if input_shape.len() != permutation.len() {
195        return Err(IndexerError::new_incompatible_dimensionality(
196            input_shape.len(),
197            permutation.len(),
198        )
199        .into());
200    }
201
202    let num_elements = input_shape.iter().product();
203    bytes.validate(num_elements, data_type)?;
204
205    match (bytes, data_type.size()) {
206        (ArrayBytes::Variable(vlen_bytes), DataTypeSize::Variable) => {
207            let bytes = vlen_bytes.bytes();
208            let offsets = vlen_bytes.offsets();
209            let shape: Vec<usize> = input_shape
210                .iter()
211                .map(|s| usize::try_from(*s).unwrap())
212                .collect();
213            Ok(transpose_vlen(bytes, offsets, &shape, permutation.to_vec()))
214        }
215        (ArrayBytes::Fixed(bytes), DataTypeSize::Fixed(data_type_size)) => {
216            // For fixed-size types, add an extra dimension for the element bytes
217            let mut order_with_bytes = permutation.to_vec();
218            order_with_bytes.push(permutation.len());
219            let bytes = transpose_array(&order_with_bytes, input_shape, data_type_size, bytes)
220                .map_err(|_| CodecError::Other("transpose_array error".to_string()))?;
221            Ok(ArrayBytes::from(bytes))
222        }
223        (ArrayBytes::Optional(..), _) => Err(CodecError::UnsupportedDataType(
224            data_type.clone(),
225            TransposeCodec::aliases_v3().default_name.to_string(),
226        )),
227        (_, _) => Err(CodecError::Other(
228            "dev error: transpose data type mismatch".to_string(),
229        )),
230    }
231}
232
#[cfg(test)]
mod tests {
    use std::num::NonZeroU64;
    use std::sync::Arc;

    use super::*;
    use crate::array::codec::BytesCodec;
    use crate::array::{ArrayBytes, ArraySubset, ChunkShapeTraits, DataType, FillValue, data_type};
    use zarrs_codec::{ArrayToArrayCodecTraits, ArrayToBytesCodecTraits, CodecOptions};

    /// Encode then decode a 2x2x3 array of fixed-size elements through a transpose codec
    /// built from `json`, and check that the decoded bytes equal the input.
    fn codec_transpose_round_trip_impl(
        json: &str,
        data_type: DataType,
        fill_value: impl Into<FillValue>,
    ) {
        let shape = vec![
            NonZeroU64::new(2).unwrap(),
            NonZeroU64::new(2).unwrap(),
            NonZeroU64::new(3).unwrap(),
        ];
        let fill_value = fill_value.into();
        let size = shape.num_elements_usize() * data_type.fixed_size().unwrap();
        let bytes: Vec<u8> = (0..size).map(|s| s as u8).collect();
        let bytes: ArrayBytes = bytes.into();

        let configuration: TransposeCodecConfiguration = serde_json::from_str(json).unwrap();
        let codec = TransposeCodec::new_with_configuration(&configuration).unwrap();

        let encoded = codec
            .encode(
                bytes.clone(),
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        let decoded = codec
            .decode(
                encoded,
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        assert_eq!(bytes, decoded);
    }

    #[test]
    fn codec_transpose_round_trip_array1() {
        const JSON: &str = r#"{
            "order": [0, 2, 1]
        }"#;
        codec_transpose_round_trip_impl(JSON, data_type::uint8(), 0u8);
    }

    #[test]
    fn codec_transpose_round_trip_array2() {
        const JSON: &str = r#"{
            "order": [2, 1, 0]
        }"#;
        codec_transpose_round_trip_impl(JSON, data_type::uint16(), 0u16);
    }

    #[test]
    fn codec_transpose_round_trip_vlen_string() {
        use crate::array::Element;

        // Create a 2x3 array of strings
        let shape = vec![NonZeroU64::new(2).unwrap(), NonZeroU64::new(3).unwrap()];
        let data_type = data_type::string();
        let fill_value = FillValue::from("");

        // Create test data: 6 strings in row-major order
        let strings: Vec<&str> = vec!["s00", "s01a", "s02ab", "s10abc", "s11abcd", "s12abcde"];
        let bytes = Element::into_array_bytes(&data_type::string(), strings).unwrap();

        // Create transpose codec with order [1, 0] (swap axes)
        let codec = TransposeCodec::new(TransposeOrder::new(&[1, 0]).unwrap());

        let encoded = codec
            .encode(
                bytes.clone(),
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        let decoded = codec
            .decode(
                encoded,
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();

        assert_eq!(bytes, decoded);
    }

    #[test]
    fn apply_permutation_vlen_string() {
        use crate::array::Element;

        // Test apply_permutation with vlen data (used by partial encode/decode)
        // This tests a non-square shape to catch shape mismatch bugs
        // Original shape: 2x3, Transposed shape: 3x2
        let order = TransposeOrder::new(&[1, 0]).unwrap();

        // Create test data: 6 strings in row-major order for shape [2, 3]
        // [[s00, s01, s02], [s10, s11, s12]]
        let strings: Vec<&str> = vec!["s00", "s01a", "s02ab", "s10abc", "s11abcd", "s12abcde"];
        let original = Element::into_array_bytes(&data_type::string(), strings).unwrap();

        // Encode: apply transpose order [1, 0] to get shape [3, 2]
        // Transposed should be: [[s00, s10], [s01, s11], [s02, s12]]
        let transposed_strings: Vec<&str> =
            vec!["s00", "s10abc", "s01a", "s11abcd", "s02ab", "s12abcde"];
        let expected_transposed =
            Element::into_array_bytes(&data_type::string(), transposed_strings).unwrap();

        // Test encoding (forward permutation)
        let encoded =
            apply_permutation(&original, &[2, 3], &order.0, &data_type::string()).unwrap();
        assert_eq!(encoded, expected_transposed);

        // Test decoding (inverse permutation)
        // Inverse of [1, 0] is [1, 0]
        let order_decode = [1, 0];
        let decoded =
            apply_permutation(&encoded, &[3, 2], &order_decode, &data_type::string()).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn apply_permutation_fixed_non_square() {
        // Round-trip tests alone cannot detect a wrong permutation, so pin the exact
        // transposed byte layout for a multi-byte fixed-size type.
        // A 2x3 array of uint16 elements e0..e5, where element i is stored as the
        // bytes [2i, 2i + 1].
        let data_type = data_type::uint16();
        let original: ArrayBytes = (0u8..12).collect::<Vec<u8>>().into();

        // Order [1, 0] gives shape [3, 2]: [[e0, e3], [e1, e4], [e2, e5]].
        // Each element's two bytes must stay together and in order.
        let expected: ArrayBytes = vec![0u8, 1, 6, 7, 2, 3, 8, 9, 4, 5, 10, 11].into();
        let encoded = apply_permutation(&original, &[2, 3], &[1, 0], &data_type).unwrap();
        assert_eq!(encoded, expected);

        let decoded = apply_permutation(&encoded, &[3, 2], &[1, 0], &data_type).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn apply_permutation_inverse_round_trip() {
        // Applying `order` and then `inverse_permutation(order)` must restore the input.
        let order = [2, 0, 1];
        let inverse = inverse_permutation(&order);
        assert_eq!(inverse, vec![1, 2, 0]);

        // a[i][j][k] = 12i + 4j + k for shape [2, 3, 4].
        let data_type = data_type::uint8();
        let original: ArrayBytes = (0u8..24).collect::<Vec<u8>>().into();

        // Shape [2, 3, 4] permuted by [2, 0, 1] is [4, 2, 3], with out[k][i][j] = a[i][j][k].
        let encoded = apply_permutation(&original, &[2, 3, 4], &order, &data_type).unwrap();
        let expected: Vec<u8> = (0..4u8)
            .flat_map(|k| {
                (0..2u8).flat_map(move |i| (0..3u8).map(move |j| 12 * i + 4 * j + k))
            })
            .collect();
        assert_eq!(encoded, ArrayBytes::from(expected));

        let decoded = apply_permutation(&encoded, &[4, 2, 3], &inverse, &data_type).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn codec_transpose_partial_decode() {
        let codec = Arc::new(TransposeCodec::new(TransposeOrder::new(&[1, 0]).unwrap()));

        let elements: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let shape = vec![NonZeroU64::new(4).unwrap(), NonZeroU64::new(4).unwrap()];
        let data_type = data_type::float32();
        let fill_value = FillValue::from(0.0f32);
        let bytes = crate::array::transmute_to_bytes_vec(elements);
        let bytes: ArrayBytes = bytes.into();

        let encoded = codec
            .encode(
                bytes,
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        let input_handle = Arc::new(encoded.into_fixed().unwrap());
        let bytes_codec = Arc::new(BytesCodec::default());
        let input_handle = bytes_codec
            .partial_decoder(
                input_handle,
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        let partial_decoder = codec
            .partial_decoder(
                input_handle.clone(),
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        assert_eq!(partial_decoder.size_held(), input_handle.size_held()); // transpose partial decoder does not hold bytes
        let decoded_regions = [
            ArraySubset::new_with_ranges(&[0..4, 0..4]),
            ArraySubset::new_with_ranges(&[1..3, 1..4]),
            ArraySubset::new_with_ranges(&[2..4, 0..2]),
        ];
        let answer: &[Vec<f32>] = &[
            vec![
                0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                15.0,
            ],
            vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0],
            vec![8.0, 9.0, 12.0, 13.0],
        ];
        for (decoded_region, expected) in decoded_regions.into_iter().zip(answer.iter()) {
            let decoded_partial_chunk = partial_decoder
                .partial_decode(&decoded_region, &CodecOptions::default())
                .unwrap();
            let decoded_partial_chunk = crate::array::convert_from_bytes_slice::<f32>(
                &decoded_partial_chunk.into_fixed().unwrap(),
            );
            assert_eq!(expected, &decoded_partial_chunk);
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn codec_transpose_async_partial_decode() {
        let codec = Arc::new(TransposeCodec::new(TransposeOrder::new(&[1, 0]).unwrap()));

        let elements: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let shape = vec![NonZeroU64::new(4).unwrap(), NonZeroU64::new(4).unwrap()];
        let data_type = data_type::float32();
        let fill_value = FillValue::from(0.0f32);
        let bytes = crate::array::transmute_to_bytes_vec(elements);
        let bytes: ArrayBytes = bytes.into();

        let encoded = codec
            .encode(
                bytes.clone(),
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        let input_handle = Arc::new(encoded.into_fixed().unwrap());
        let bytes_codec = Arc::new(BytesCodec::default());
        let input_handle = bytes_codec
            .async_partial_decoder(
                input_handle,
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .await
            .unwrap();
        let partial_decoder = codec
            .async_partial_decoder(
                input_handle,
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .await
            .unwrap();
        let decoded_regions = [
            ArraySubset::new_with_ranges(&[0..4, 0..4]),
            ArraySubset::new_with_ranges(&[1..3, 1..4]),
            ArraySubset::new_with_ranges(&[2..4, 0..2]),
        ];
        let answer: &[Vec<f32>] = &[
            vec![
                0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                15.0,
            ],
            vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0],
            vec![8.0, 9.0, 12.0, 13.0],
        ];
        for (decoded_region, answer) in decoded_regions.into_iter().zip(answer.iter()) {
            let decoded_partial_chunk = partial_decoder
                .partial_decode(&decoded_region, &CodecOptions::default())
                .await
                .unwrap();
            let decoded_partial_chunk = crate::array::convert_from_bytes_slice::<f32>(
                &decoded_partial_chunk.into_fixed().unwrap(),
            );
            assert_eq!(answer, &decoded_partial_chunk);
        }
    }
}