1mod transpose_codec;
30mod transpose_codec_partial;
31
32use std::sync::Arc;
33
34pub use transpose_codec::TransposeCodec;
35use zarrs_metadata::v3::MetadataV3;
36use zarrs_plugin::ExtensionAliasesV3;
37
38use crate::array::{
39 ArrayBytes, ArrayBytesRaw, ArraySubset, ArraySubsetTraits, DataType, Indexer, IndexerError,
40};
41use zarrs_codec::{ArrayBytesOffsets, Codec, CodecError, CodecPluginV3, CodecTraitsV3};
42use zarrs_metadata::DataTypeSize;
43pub use zarrs_metadata_ext::codec::transpose::{
44 TransposeCodecConfiguration, TransposeCodecConfigurationV1, TransposeOrder, TransposeOrderError,
45};
46use zarrs_plugin::PluginCreateError;
47
// Declare the Zarr V3 extension name `"transpose"` for `TransposeCodec`
// (`apply_permutation` reads it back via `TransposeCodec::aliases_v3().default_name`).
zarrs_plugin::impl_extension_aliases!(TransposeCodec, v3: "transpose");

// Register `TransposeCodec` as a V3 codec plugin with the `inventory` registry.
// NOTE(review): presumably the registry dispatches to `CodecTraitsV3::create` below when
// metadata names this codec — confirm in `CodecPluginV3`.
inventory::submit! {
    CodecPluginV3::new::<TransposeCodec>()
}
54
55impl CodecTraitsV3 for TransposeCodec {
56 fn create(metadata: &MetadataV3) -> Result<Codec, PluginCreateError> {
57 let configuration: TransposeCodecConfiguration = metadata.to_typed_configuration()?;
58 let codec = Arc::new(TransposeCodec::new_with_configuration(&configuration)?);
59 Ok(Codec::ArrayToArray(codec))
60 }
61}
62
/// Returns the inverse of the axis permutation `order`.
///
/// `order` is expected to be a permutation of `0..order.len()`. The result `inv` then
/// satisfies `inv[order[i]] == i` for every `i`, so permuting by `order` and then by `inv`
/// restores the original axis order.
///
/// # Panics
/// Panics if any entry of `order` is `>= order.len()`.
pub(crate) fn inverse_permutation(order: &[usize]) -> Vec<usize> {
    let mut inverse = vec![0; order.len()];
    order
        .iter()
        .copied()
        .enumerate()
        .for_each(|(position, axis)| inverse[axis] = position);
    inverse
}
74
75fn transpose_array(
76 transpose_order: &[usize],
77 untransposed_shape: &[u64],
78 bytes_per_element: usize,
79 data: &[u8],
80) -> Result<Vec<u8>, ndarray::ShapeError> {
81 let mut shape_n = Vec::with_capacity(untransposed_shape.len() + 1);
83 for size in untransposed_shape {
84 shape_n.push(usize::try_from(*size).unwrap());
85 }
86 shape_n.push(bytes_per_element);
87 let array = ndarray::ArrayViewD::<u8>::from_shape(shape_n, data)?;
88
89 let array_transposed = array.permuted_axes(transpose_order);
91 if array_transposed.is_standard_layout() {
92 Ok(array_transposed.to_owned().into_raw_vec_and_offset().0)
93 } else {
94 Ok(array_transposed
95 .as_standard_layout()
96 .into_owned()
97 .into_raw_vec_and_offset()
98 .0)
99 }
100}
101
/// Reorders `v` so that entry `i` of the result is `v[order[i]]`.
///
/// Returns `None` when `v` and `order` have different lengths.
///
/// # Panics
/// Panics if an entry of `order` is out of bounds for `v`.
fn permute<T: Copy>(v: &[T], order: &[usize]) -> Option<Vec<T>> {
    (v.len() == order.len()).then(|| order.iter().map(|&axis| v[axis]).collect())
}
113
/// Reorders variable-length elements to match an axis permutation.
///
/// `bytes`/`offsets` hold the elements of an array of shape `shape` in C (row-major) order, with
/// element `i` occupying `bytes[offsets[i]..offsets[i + 1]]`. The result holds the same elements
/// reordered as by `ndarray`'s `permuted_axes(order)` (output axis `j` is input axis `order[j]`),
/// with freshly built offsets that start at zero.
///
/// Callers are expected to have validated `bytes`/`offsets` against `shape` beforehand
/// (`apply_permutation` calls `validate` first). `permuted_axes` panics if `order` is not a
/// permutation of `0..shape.len()`.
fn transpose_vlen<'a>(
    bytes: &ArrayBytesRaw,
    offsets: &ArrayBytesOffsets,
    shape: &[usize],
    order: Vec<usize>,
) -> ArrayBytes<'a> {
    debug_assert_eq!(shape.len(), order.len());

    // Array of each element's flat row-major index; iterating its permuted view yields the
    // source element index for each output position, in output order.
    let ndarray_indices =
        ndarray::ArrayD::from_shape_vec(shape, (0..shape.iter().product()).collect()).unwrap();
    let ndarray_indices_transposed = ndarray_indices.permuted_axes(order);

    // Total byte length is unchanged by reordering, and there is one offset per element plus one.
    let mut bytes_new = Vec::with_capacity(bytes.len());
    let mut offsets_new = Vec::with_capacity(offsets.len());
    for idx in &ndarray_indices_transposed {
        // Record where this element starts in the output, then copy its bytes.
        offsets_new.push(bytes_new.len());
        let curr = offsets[*idx];
        let next = offsets[idx + 1];
        bytes_new.extend_from_slice(&bytes[curr..next]);
    }
    // Trailing offset marks the end of the last element.
    offsets_new.push(bytes_new.len());
    // SAFETY: `offsets_new` starts at 0, is non-decreasing (each entry is the running length of
    // `bytes_new`), and has one more entry than there are elements.
    // NOTE(review): confirm this matches the contract of `ArrayBytesOffsets::new_unchecked`.
    let offsets_new = unsafe { ArrayBytesOffsets::new_unchecked(offsets_new) };
    // SAFETY: the final offset equals `bytes_new.len()`, so every element range lies within
    // `bytes_new`.
    // NOTE(review): confirm this matches the contract of `ArrayBytes::new_vlen_unchecked`.
    unsafe { ArrayBytes::new_vlen_unchecked(bytes_new, offsets_new) }
}
146
147fn get_transposed_array_subset(
148 order: &[usize],
149 decoded_region: &dyn ArraySubsetTraits,
150) -> Result<ArraySubset, CodecError> {
151 if decoded_region.dimensionality() != order.len() {
152 return Err(IndexerError::new_incompatible_dimensionality(
153 decoded_region.dimensionality(),
154 order.len(),
155 )
156 .into());
157 }
158
159 let start = permute(&decoded_region.start(), order).expect("matching dimensionality");
160 let size = permute(&decoded_region.shape(), order).expect("matching dimensionality");
161 let ranges = start.iter().zip(size).map(|(&st, si)| st..(st + si));
162 Ok(ArraySubset::from(ranges))
163}
164
165fn get_transposed_indexer(
166 order: &[usize],
167 indexer: &dyn Indexer,
168) -> Result<impl Indexer, CodecError> {
169 indexer
170 .iter_indices()
171 .map(|indices| permute(&indices, order))
172 .collect::<Option<Vec<_>>>()
173 .ok_or_else(|| {
174 IndexerError::new_incompatible_dimensionality(indexer.dimensionality(), order.len())
175 .into()
176 })
177}
178
179pub(crate) fn apply_permutation<'a>(
189 bytes: &ArrayBytes<'a>,
190 input_shape: &[u64],
191 permutation: &[usize],
192 data_type: &DataType,
193) -> Result<ArrayBytes<'a>, CodecError> {
194 if input_shape.len() != permutation.len() {
195 return Err(IndexerError::new_incompatible_dimensionality(
196 input_shape.len(),
197 permutation.len(),
198 )
199 .into());
200 }
201
202 let num_elements = input_shape.iter().product();
203 bytes.validate(num_elements, data_type)?;
204
205 match (bytes, data_type.size()) {
206 (ArrayBytes::Variable(vlen_bytes), DataTypeSize::Variable) => {
207 let bytes = vlen_bytes.bytes();
208 let offsets = vlen_bytes.offsets();
209 let shape: Vec<usize> = input_shape
210 .iter()
211 .map(|s| usize::try_from(*s).unwrap())
212 .collect();
213 Ok(transpose_vlen(bytes, offsets, &shape, permutation.to_vec()))
214 }
215 (ArrayBytes::Fixed(bytes), DataTypeSize::Fixed(data_type_size)) => {
216 let mut order_with_bytes = permutation.to_vec();
218 order_with_bytes.push(permutation.len());
219 let bytes = transpose_array(&order_with_bytes, input_shape, data_type_size, bytes)
220 .map_err(|_| CodecError::Other("transpose_array error".to_string()))?;
221 Ok(ArrayBytes::from(bytes))
222 }
223 (ArrayBytes::Optional(..), _) => Err(CodecError::UnsupportedDataType(
224 data_type.clone(),
225 TransposeCodec::aliases_v3().default_name.to_string(),
226 )),
227 (_, _) => Err(CodecError::Other(
228 "dev error: transpose data type mismatch".to_string(),
229 )),
230 }
231}
232
#[cfg(test)]
mod tests {
    //! Unit tests for the transpose codec: encode/decode round trips for fixed-size and
    //! variable-length data, `apply_permutation` on strings, and sync/async partial decoding.

    use std::num::NonZeroU64;
    use std::sync::Arc;

    use super::*;
    use crate::array::codec::BytesCodec;
    use crate::array::{ArrayBytes, ArraySubset, ChunkShapeTraits, DataType, FillValue, data_type};
    use zarrs_codec::{ArrayToArrayCodecTraits, ArrayToBytesCodecTraits, CodecOptions};

    /// Encodes then decodes a 2x2x3 array of sequential bytes with the transpose codec whose
    /// configuration is `json`, asserting the round trip reproduces the input exactly.
    fn codec_transpose_round_trip_impl(
        json: &str,
        data_type: DataType,
        fill_value: impl Into<FillValue>,
    ) {
        let shape = vec![
            NonZeroU64::new(2).unwrap(),
            NonZeroU64::new(2).unwrap(),
            NonZeroU64::new(3).unwrap(),
        ];
        let fill_value = fill_value.into();
        // Sequential byte values, so any misplaced byte makes the round trip fail.
        let size = shape.num_elements_usize() * data_type.fixed_size().unwrap();
        let bytes: Vec<u8> = (0..size).map(|s| s as u8).collect();
        let bytes: ArrayBytes = bytes.into();

        let configuration: TransposeCodecConfiguration = serde_json::from_str(json).unwrap();
        let codec = TransposeCodec::new_with_configuration(&configuration).unwrap();

        let encoded = codec
            .encode(
                bytes.clone(),
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        let decoded = codec
            .decode(
                encoded,
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        assert_eq!(bytes, decoded);
    }

    #[test]
    fn codec_transpose_round_trip_array1() {
        const JSON: &str = r#"{
            "order": [0, 2, 1]
        }"#;
        codec_transpose_round_trip_impl(JSON, data_type::uint8(), 0u8);
    }

    #[test]
    fn codec_transpose_round_trip_array2() {
        // Multi-byte elements exercise keeping each element's bytes together.
        const JSON: &str = r#"{
            "order": [2, 1, 0]
        }"#;
        codec_transpose_round_trip_impl(JSON, data_type::uint16(), 0u16);
    }

    #[test]
    fn codec_transpose_round_trip_vlen_string() {
        use crate::array::Element;

        let shape = vec![NonZeroU64::new(2).unwrap(), NonZeroU64::new(3).unwrap()];
        let data_type = data_type::string();
        let fill_value = FillValue::from("");

        // Strings of differing lengths so offsets must be rebuilt correctly.
        let strings: Vec<&str> = vec!["s00", "s01a", "s02ab", "s10abc", "s11abcd", "s12abcde"];
        let bytes = Element::into_array_bytes(&data_type::string(), strings).unwrap();

        let codec = TransposeCodec::new(TransposeOrder::new(&[1, 0]).unwrap());

        let encoded = codec
            .encode(
                bytes.clone(),
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        let decoded = codec
            .decode(
                encoded,
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();

        assert_eq!(bytes, decoded);
    }

    #[test]
    fn apply_permutation_vlen_string() {
        use crate::array::Element;

        let order = TransposeOrder::new(&[1, 0]).unwrap();

        // 2x3 array in row-major order.
        let strings: Vec<&str> = vec!["s00", "s01a", "s02ab", "s10abc", "s11abcd", "s12abcde"];
        let original = Element::into_array_bytes(&data_type::string(), strings).unwrap();

        // The same data as a 3x2 array (rows and columns swapped), in row-major order.
        let transposed_strings: Vec<&str> =
            vec!["s00", "s10abc", "s01a", "s11abcd", "s02ab", "s12abcde"];
        let expected_transposed =
            Element::into_array_bytes(&data_type::string(), transposed_strings).unwrap();

        let encoded =
            apply_permutation(&original, &[2, 3], &order.0, &data_type::string()).unwrap();
        assert_eq!(encoded, expected_transposed);

        // Swapping the two axes of the 3x2 result again restores the original 2x3 array.
        let order_decode = [1, 0];
        let decoded =
            apply_permutation(&encoded, &[3, 2], &order_decode, &data_type::string()).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn codec_transpose_partial_decode() {
        let codec = Arc::new(TransposeCodec::new(TransposeOrder::new(&[1, 0]).unwrap()));

        let elements: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let shape = vec![NonZeroU64::new(4).unwrap(), NonZeroU64::new(4).unwrap()];
        let data_type = data_type::float32();
        let fill_value = FillValue::from(0.0f32);
        let bytes = crate::array::transmute_to_bytes_vec(elements);
        let bytes: ArrayBytes = bytes.into();

        let encoded = codec
            .encode(
                bytes,
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        // Stack the transpose partial decoder on top of a bytes-codec partial decoder over the
        // encoded chunk.
        let input_handle = Arc::new(encoded.into_fixed().unwrap());
        let bytes_codec = Arc::new(BytesCodec::default());
        let input_handle = bytes_codec
            .partial_decoder(
                input_handle,
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        let partial_decoder = codec
            .partial_decoder(
                input_handle.clone(),
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        // The transpose partial decoder reports the same held size as the decoder it wraps.
        assert_eq!(partial_decoder.size_held(), input_handle.size_held());
        // Regions are given in decoded (original) coordinates, so expected values are read
        // directly from the row-major 4x4 input `0..16`.
        let decoded_regions = [
            ArraySubset::new_with_ranges(&[0..4, 0..4]),
            ArraySubset::new_with_ranges(&[1..3, 1..4]),
            ArraySubset::new_with_ranges(&[2..4, 0..2]),
        ];
        let answer: &[Vec<f32>] = &[
            vec![
                0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                15.0,
            ],
            vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0],
            vec![8.0, 9.0, 12.0, 13.0],
        ];
        for (decoded_region, expected) in decoded_regions.into_iter().zip(answer.iter()) {
            let decoded_partial_chunk = partial_decoder
                .partial_decode(&decoded_region, &CodecOptions::default())
                .unwrap();
            let decoded_partial_chunk = crate::array::convert_from_bytes_slice::<f32>(
                &decoded_partial_chunk.into_fixed().unwrap(),
            );
            assert_eq!(expected, &decoded_partial_chunk);
        }
    }

    /// Async counterpart of `codec_transpose_partial_decode`, with the same regions and answers.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn codec_transpose_async_partial_decode() {
        let codec = Arc::new(TransposeCodec::new(TransposeOrder::new(&[1, 0]).unwrap()));

        let elements: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let shape = vec![NonZeroU64::new(4).unwrap(), NonZeroU64::new(4).unwrap()];
        let data_type = data_type::float32();
        let fill_value = FillValue::from(0.0f32);
        let bytes = crate::array::transmute_to_bytes_vec(elements);
        let bytes: ArrayBytes = bytes.into();

        let encoded = codec
            .encode(
                bytes.clone(),
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        let input_handle = Arc::new(encoded.into_fixed().unwrap());
        let bytes_codec = Arc::new(BytesCodec::default());
        let input_handle = bytes_codec
            .async_partial_decoder(
                input_handle,
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .await
            .unwrap();
        let partial_decoder = codec
            .async_partial_decoder(
                input_handle,
                &shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .await
            .unwrap();
        // Regions are in decoded (original) coordinates; see the sync test above.
        let decoded_regions = [
            ArraySubset::new_with_ranges(&[0..4, 0..4]),
            ArraySubset::new_with_ranges(&[1..3, 1..4]),
            ArraySubset::new_with_ranges(&[2..4, 0..2]),
        ];
        let answer: &[Vec<f32>] = &[
            vec![
                0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                15.0,
            ],
            vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0],
            vec![8.0, 9.0, 12.0, 13.0],
        ];
        for (decoded_region, answer) in decoded_regions.into_iter().zip(answer.iter()) {
            let decoded_partial_chunk = partial_decoder
                .partial_decode(&decoded_region, &CodecOptions::default())
                .await
                .unwrap();
            let decoded_partial_chunk = crate::array::convert_from_bytes_slice::<f32>(
                &decoded_partial_chunk.into_fixed().unwrap(),
            );
            assert_eq!(answer, &decoded_partial_chunk);
        }
    }
}