1mod reshape_codec;
33mod reshape_codec_partial;
34
35use std::num::NonZeroU64;
36use std::sync::Arc;
37
38use num::Integer;
39pub use reshape_codec::ReshapeCodec;
40use zarrs_metadata::v3::MetadataV3;
41
42use crate::array::{ArrayIndices, ChunkShape, Indexer, IndexerError, unravel_index};
44use zarrs_codec::{Codec, CodecError, CodecPluginV3, CodecTraitsV3};
45pub use zarrs_metadata_ext::codec::reshape::{
46 ReshapeCodecConfiguration, ReshapeCodecConfigurationV1, ReshapeDim, ReshapeShape,
47};
48use zarrs_plugin::PluginCreateError;
49
50fn get_encoded_shape(
51 reshape_shape: &ReshapeShape,
52 decoded_shape: &[NonZeroU64],
53) -> Result<ChunkShape, CodecError> {
54 let mut encoded_shape = Vec::with_capacity(reshape_shape.0.len());
55 let mut fill_index = None;
56 for output_dim in &reshape_shape.0 {
57 match output_dim {
58 ReshapeDim::Size(size) => encoded_shape.push(*size),
59 ReshapeDim::InputDims(input_dims) => {
60 let mut product = NonZeroU64::new(1).unwrap();
61 for input_dim in input_dims {
62 let input_shape = *decoded_shape
63 .get(usize::try_from(*input_dim).unwrap())
64 .ok_or_else(|| {
65 CodecError::Other(
66 format!("reshape codec shape references a dimension ({input_dim}) larger than the chunk dimensionality ({})", decoded_shape.len()),
67 )
68 })?;
69 product = product.checked_mul(input_shape).unwrap();
70 }
71 encoded_shape.push(product);
72 }
73 ReshapeDim::Auto(_) => {
74 fill_index = Some(encoded_shape.len());
75 encoded_shape.push(NonZeroU64::new(1).unwrap());
76 }
77 }
78 }
79
80 let num_elements_input = decoded_shape.iter().map(|u| u.get()).product::<u64>();
81 let num_elements_output = encoded_shape.iter().map(|u| u.get()).product::<u64>();
82 if let Some(fill_index) = fill_index {
83 let (quot, rem) = num_elements_input.div_rem(&num_elements_output);
84 if rem == 0 {
85 encoded_shape[fill_index] = NonZeroU64::new(quot).unwrap();
86 } else {
87 return Err(CodecError::Other(format!(
88 "reshape codec no substitution for dim {fill_index} can satisfy decoded_shape {decoded_shape:?} == encoded_shape {encoded_shape:?}."
89 )));
90 }
91 } else if num_elements_input != num_elements_output {
92 return Err(CodecError::Other(format!(
93 "reshape codec encoded/decoded number of elements differ: decoded_shape {decoded_shape:?} ({num_elements_input}) encoded_shape {encoded_shape:?} ({num_elements_output})."
94 )));
95 }
96
97 Ok(encoded_shape)
98}
99
100fn get_reshaped_indexer(
101 indexer: &dyn Indexer,
102 decoded_shape: &[NonZeroU64],
103 encoded_shape: &[NonZeroU64],
104) -> Result<impl Indexer, CodecError> {
105 if indexer.dimensionality() != decoded_shape.len() {
106 return Err(IndexerError::new_incompatible_dimensionality(
107 indexer.dimensionality(),
108 decoded_shape.len(),
109 )
110 .into());
111 }
112
113 let decoded_shape = bytemuck::must_cast_slice(decoded_shape);
114 let encoded_shape = bytemuck::must_cast_slice(encoded_shape);
115 let indices = indexer
116 .iter_linearised_indices(decoded_shape)?
117 .map(|linear_index| {
118 unravel_index(linear_index, encoded_shape).ok_or_else(|| {
119 CodecError::Other(
120 "reshape codec encoded/decoded number of elements differ".to_string(),
121 )
122 })
123 })
124 .collect::<Result<Vec<ArrayIndices>, _>>()?;
125
126 Ok(indices)
127}
128
// Declares the extension aliases of `ReshapeCodec`: the name "reshape" for Zarr V3.
zarrs_plugin::impl_extension_aliases!(ReshapeCodec, v3: "reshape");

// Submits a V3 codec plugin for `ReshapeCodec` to the `inventory` registry.
// NOTE(review): presumably this is how the codec is found by name at runtime — confirm against
// the plugin lookup code.
inventory::submit! {
    CodecPluginV3::new::<ReshapeCodec>()
}
135
136impl CodecTraitsV3 for ReshapeCodec {
137 fn create(metadata: &MetadataV3) -> Result<Codec, PluginCreateError> {
138 let configuration: ReshapeCodecConfiguration = metadata.to_typed_configuration()?;
139 let codec = Arc::new(ReshapeCodec::new_with_configuration(&configuration)?);
140 Ok(Codec::ArrayToArray(codec))
141 }
142}
143
144#[cfg(test)]
145mod tests {
146 use std::num::NonZeroU64;
147 use std::sync::Mutex;
148
149 use super::*;
150 use crate::array::codec::BytesCodec;
151 use crate::array::{ArrayBytes, ArraySubset, ChunkShapeTraits, DataType, FillValue, data_type};
152 use zarrs_codec::{
153 ArrayPartialDecoderTraits, ArrayToArrayCodecTraits, ArrayToBytesCodecTraits, CodecOptions,
154 };
155
156 fn codec_reshape_round_trip_impl(
157 json: &str,
158 data_type: DataType,
159 fill_value: FillValue,
160 output_shape: Vec<NonZeroU64>,
161 ) -> Result<(), Box<dyn std::error::Error>> {
162 let shape = vec![
163 NonZeroU64::new(5).unwrap(),
164 NonZeroU64::new(4).unwrap(),
165 NonZeroU64::new(4).unwrap(),
166 NonZeroU64::new(3).unwrap(),
167 ];
168 let size = shape.num_elements_usize() * data_type.fixed_size().unwrap();
169 let bytes: Vec<u8> = (0..size).map(|s| s as u8).collect();
170 let bytes: ArrayBytes = bytes.into();
171
172 let configuration: ReshapeCodecConfiguration = serde_json::from_str(json)?;
173 let codec = ReshapeCodec::new_with_configuration(&configuration)?;
174 assert_eq!(codec.encoded_shape(&shape)?, output_shape);
175
176 let encoded = codec.encode(
177 bytes.clone(),
178 &shape,
179 &data_type,
180 &fill_value,
181 &CodecOptions::default(),
182 )?;
183 let decoded = codec.decode(
184 encoded,
185 &shape,
186 &data_type,
187 &fill_value,
188 &CodecOptions::default(),
189 )?;
190 assert_eq!(bytes, decoded);
191 Ok(())
192 }
193
194 #[test]
195 fn codec_reshape_round_trip_array1() {
196 const JSON: &str = r#"{
197 "shape": [[0, 1], [2], 3]
198 }"#;
199 let output_shape = vec![
200 NonZeroU64::new(20).unwrap(),
201 NonZeroU64::new(4).unwrap(),
202 NonZeroU64::new(3).unwrap(),
203 ];
204 assert!(
205 codec_reshape_round_trip_impl(
206 JSON,
207 data_type::uint32(),
208 FillValue::from(0u32),
209 output_shape
210 )
211 .is_ok()
212 );
213 }
214
215 #[test]
216 fn codec_reshape_round_trip_array2() {
217 const JSON: &str = r#"{
218 "shape": [[0, 1], [2], -1]
219 }"#;
220 let output_shape = vec![
221 NonZeroU64::new(20).unwrap(),
222 NonZeroU64::new(4).unwrap(),
223 NonZeroU64::new(3).unwrap(),
224 ];
225 assert!(
226 codec_reshape_round_trip_impl(
227 JSON,
228 data_type::uint32(),
229 FillValue::from(0u32),
230 output_shape
231 )
232 .is_ok()
233 );
234 }
235
236 #[test]
237 fn codec_reshape_round_trip_array3() {
238 const JSON: &str = r#"{
239 "shape": [[0, 1, 2], 3]
240 }"#;
241 let output_shape = vec![NonZeroU64::new(80).unwrap(), NonZeroU64::new(3).unwrap()];
242 assert!(
243 codec_reshape_round_trip_impl(
244 JSON,
245 data_type::uint32(),
246 FillValue::from(0u32),
247 output_shape
248 )
249 .is_ok()
250 );
251 }
252
253 #[test]
254 fn codec_reshape_round_trip_array4() {
255 const JSON: &str = r#"{
256 "shape": [[0], -1, [2, 3]]
257 }"#;
258 let output_shape = vec![
259 NonZeroU64::new(5).unwrap(),
260 NonZeroU64::new(4).unwrap(),
261 NonZeroU64::new(12).unwrap(),
262 ];
263 assert!(
264 codec_reshape_round_trip_impl(
265 JSON,
266 data_type::uint32(),
267 FillValue::from(0u32),
268 output_shape
269 )
270 .is_ok()
271 );
272 }
273
274 #[test]
275 fn codec_reshape_round_trip_array5() {
276 const JSON: &str = r#"{
277 "shape": [[0], -1, [3]]
278 }"#;
279 let output_shape = vec![
280 NonZeroU64::new(5).unwrap(),
281 NonZeroU64::new(16).unwrap(),
282 NonZeroU64::new(3).unwrap(),
283 ];
284 assert!(
285 codec_reshape_round_trip_impl(
286 JSON,
287 data_type::uint32(),
288 FillValue::from(0u32),
289 output_shape
290 )
291 .is_ok()
292 );
293 }
294
295 #[test]
296 fn codec_reshape_round_trip_array6() {
297 const JSON: &str = r#"{
298 "shape": [-1, 2, 2, [3]]
299 }"#;
300 let output_shape = vec![
301 NonZeroU64::new(20).unwrap(),
302 NonZeroU64::new(2).unwrap(),
303 NonZeroU64::new(2).unwrap(),
304 NonZeroU64::new(3).unwrap(),
305 ];
306 assert!(
307 codec_reshape_round_trip_impl(
308 JSON,
309 data_type::uint32(),
310 FillValue::from(0u32),
311 output_shape
312 )
313 .is_ok()
314 );
315 }
316
317 #[test]
318 fn codec_reshape_invalid1() {
319 const JSON: &str = r#"{
320 "shape": [-1, 2, 2, [4]]
321 }"#;
322 let output_shape = vec![
323 NonZeroU64::new(20).unwrap(),
324 NonZeroU64::new(2).unwrap(),
325 NonZeroU64::new(2).unwrap(),
326 NonZeroU64::new(3).unwrap(),
327 ];
328 assert!(
329 codec_reshape_round_trip_impl(
330 JSON,
331 data_type::uint32(),
332 FillValue::from(0u32),
333 output_shape
334 )
335 .is_err()
336 );
337 }
338
339 #[test]
340 fn codec_reshape_invalid2() {
341 const JSON: &str = r#"{
342 "shape": [2, 2, 2]
343 }"#;
344 let output_shape = vec![
345 NonZeroU64::new(20).unwrap(),
346 NonZeroU64::new(2).unwrap(),
347 NonZeroU64::new(2).unwrap(),
348 NonZeroU64::new(3).unwrap(),
349 ];
350 assert!(
351 codec_reshape_round_trip_impl(
352 JSON,
353 data_type::uint32(),
354 FillValue::from(0u32),
355 output_shape
356 )
357 .is_err()
358 );
359 }
360
361 fn partial_decoder_u16(
362 codec: Arc<ReshapeCodec>,
363 shape: &[NonZeroU64],
364 elements: Vec<u16>,
365 ) -> Arc<dyn ArrayPartialDecoderTraits> {
366 let data_type = data_type::uint16();
367 let fill_value = FillValue::from(0u16);
368 let bytes = crate::array::transmute_to_bytes_vec(elements);
369 let bytes: ArrayBytes = bytes.into();
370 let encoded = codec
371 .encode(
372 bytes,
373 shape,
374 &data_type,
375 &fill_value,
376 &CodecOptions::default(),
377 )
378 .unwrap();
379 let input_handle = Arc::new(encoded.into_fixed().unwrap());
380 let bytes_codec = Arc::new(BytesCodec::default());
381 let (encoded_shape, encoded_data_type, encoded_fill_value) = codec
382 .encoded_representation(shape, &data_type, &fill_value)
383 .unwrap();
384 let input_handle = bytes_codec
385 .partial_decoder(
386 input_handle,
387 &encoded_shape,
388 &encoded_data_type,
389 &encoded_fill_value,
390 &CodecOptions::default(),
391 )
392 .unwrap();
393 codec
394 .partial_decoder(
395 input_handle,
396 shape,
397 &data_type,
398 &fill_value,
399 &CodecOptions::default(),
400 )
401 .unwrap()
402 }
403
404 fn partial_decode_u16(
405 partial_decoder: &dyn ArrayPartialDecoderTraits,
406 indexer: &dyn Indexer,
407 ) -> Vec<u16> {
408 let decoded_partial_chunk = partial_decoder
409 .partial_decode(indexer, &CodecOptions::default())
410 .unwrap();
411 crate::array::convert_from_bytes_slice::<u16>(&decoded_partial_chunk.into_fixed().unwrap())
412 .to_vec()
413 }
414
415 fn partial_encode_u16(
416 codec: Arc<ReshapeCodec>,
417 shape: &[NonZeroU64],
418 elements: Vec<u16>,
419 indexer: &dyn Indexer,
420 elements_partial_encode: Vec<u16>,
421 ) -> Vec<u16> {
422 let data_type = data_type::uint16();
423 let fill_value = FillValue::from(0u16);
424 let bytes = crate::array::transmute_to_bytes_vec(elements);
425 let bytes: ArrayBytes = bytes.into();
426 let encoded = codec
427 .encode(
428 bytes,
429 shape,
430 &data_type,
431 &fill_value,
432 &CodecOptions::default(),
433 )
434 .unwrap();
435
436 let bytes_codec = Arc::new(BytesCodec::default());
437 let (encoded_shape, encoded_data_type, encoded_fill_value) = codec
438 .encoded_representation(shape, &data_type, &fill_value)
439 .unwrap();
440 let encoded_chunk = bytes_codec
441 .encode(
442 encoded,
443 &encoded_shape,
444 &encoded_data_type,
445 &encoded_fill_value,
446 &CodecOptions::default(),
447 )
448 .unwrap()
449 .into_owned();
450 let output = Arc::new(Mutex::new(Some(encoded_chunk)));
451 let input_output_handle = bytes_codec
452 .clone()
453 .partial_encoder(
454 output.clone(),
455 &encoded_shape,
456 &encoded_data_type,
457 &encoded_fill_value,
458 &CodecOptions::default(),
459 )
460 .unwrap();
461 let partial_encoder = codec
462 .clone()
463 .partial_encoder(
464 input_output_handle,
465 shape,
466 &data_type,
467 &fill_value,
468 &CodecOptions::default(),
469 )
470 .unwrap();
471 assert!(partial_encoder.supports_partial_encode());
472
473 let bytes = crate::array::transmute_to_bytes_vec(elements_partial_encode);
474 partial_encoder
475 .partial_encode(indexer, &ArrayBytes::from(bytes), &CodecOptions::default())
476 .unwrap();
477
478 let output = output.lock().unwrap().clone().unwrap();
479 let decoded_encoded = bytes_codec
480 .decode(
481 output.into(),
482 &encoded_shape,
483 &encoded_data_type,
484 &encoded_fill_value,
485 &CodecOptions::default(),
486 )
487 .unwrap();
488 let decoded = codec
489 .decode(
490 decoded_encoded,
491 shape,
492 &data_type,
493 &fill_value,
494 &CodecOptions::default(),
495 )
496 .unwrap();
497 crate::array::convert_from_bytes_slice::<u16>(&decoded.into_fixed().unwrap()).to_vec()
498 }
499
500 #[test]
501 fn codec_reshape_partial_decode_array_subset() {
502 let codec = Arc::new(ReshapeCodec::new(ReshapeShape(vec![
516 ReshapeDim::InputDims(vec![2]),
517 ReshapeDim::InputDims(vec![0, 1]),
518 ])));
519 let shape = vec![
520 NonZeroU64::new(2).unwrap(),
521 NonZeroU64::new(3).unwrap(),
522 NonZeroU64::new(4).unwrap(),
523 ];
524 let partial_decoder = partial_decoder_u16(codec, &shape, (0..24).collect());
525
526 let decoded_region = ArraySubset::new_with_ranges(&[1..2, 1..3, 1..4]);
527 assert_eq!(
528 partial_decode_u16(partial_decoder.as_ref(), &decoded_region),
529 [17, 18, 19, 21, 22, 23]
530 );
531 }
532
533 #[test]
534 fn codec_reshape_partial_decode_indexer() {
535 let codec = Arc::new(ReshapeCodec::new(ReshapeShape(vec![
554 ReshapeDim::InputDims(vec![2]),
555 ReshapeDim::InputDims(vec![0, 1]),
556 ])));
557 let shape = vec![
558 NonZeroU64::new(2).unwrap(),
559 NonZeroU64::new(3).unwrap(),
560 NonZeroU64::new(4).unwrap(),
561 ];
562 let partial_decoder = partial_decoder_u16(codec, &shape, (0..24).collect());
563
564 let indexer = vec![vec![1, 2, 3], vec![0, 0, 1], vec![1, 0, 2]];
565 assert_eq!(
566 partial_decode_u16(partial_decoder.as_ref(), &indexer),
567 [23, 1, 14]
568 );
569 }
570
571 #[test]
572 fn codec_reshape_partial_decode_flatten_array_subset() {
573 let codec = Arc::new(ReshapeCodec::new(ReshapeShape(vec![
583 ReshapeDim::InputDims(vec![0, 1, 2]),
584 ])));
585 let shape = vec![
586 NonZeroU64::new(2).unwrap(),
587 NonZeroU64::new(3).unwrap(),
588 NonZeroU64::new(4).unwrap(),
589 ];
590 let partial_decoder = partial_decoder_u16(codec, &shape, (0..24).collect());
591
592 let decoded_region = ArraySubset::new_with_ranges(&[0..2, 1..3, 2..4]);
593 assert_eq!(
594 partial_decode_u16(partial_decoder.as_ref(), &decoded_region),
595 [6, 7, 10, 11, 18, 19, 22, 23]
596 );
597 }
598
599 #[test]
600 fn codec_reshape_partial_decode_auto_dimension() {
601 let codec = Arc::new(ReshapeCodec::new(ReshapeShape(vec![
620 ReshapeDim::Size(NonZeroU64::new(4).unwrap()),
621 ReshapeDim::auto(),
622 ])));
623 let shape = vec![
624 NonZeroU64::new(2).unwrap(),
625 NonZeroU64::new(3).unwrap(),
626 NonZeroU64::new(4).unwrap(),
627 ];
628 let partial_decoder = partial_decoder_u16(codec, &shape, (0..24).collect());
629
630 let indexer = vec![vec![0, 2, 3], vec![1, 0, 0], vec![1, 2, 2]];
631 assert_eq!(
632 partial_decode_u16(partial_decoder.as_ref(), &indexer),
633 [11, 12, 22]
634 );
635 }
636
637 #[test]
638 fn codec_reshape_partial_decode_1d_to_nd() {
639 let codec = Arc::new(ReshapeCodec::new(ReshapeShape(vec![
650 ReshapeDim::Size(NonZeroU64::new(2).unwrap()),
651 ReshapeDim::Size(NonZeroU64::new(3).unwrap()),
652 ReshapeDim::Size(NonZeroU64::new(2).unwrap()),
653 ])));
654 let shape = vec![NonZeroU64::new(12).unwrap()];
655 let partial_decoder = partial_decoder_u16(codec, &shape, (0..12).collect());
656
657 #[expect(clippy::single_range_in_vec_init)]
658 let decoded_region = ArraySubset::new_with_ranges(&[3..10]);
659 assert_eq!(
660 partial_decode_u16(partial_decoder.as_ref(), &decoded_region),
661 [3, 4, 5, 6, 7, 8, 9]
662 );
663 }
664
665 #[test]
666 fn codec_reshape_partial_decode_composite_indexer() {
667 let codec = Arc::new(ReshapeCodec::new(ReshapeShape(vec![
678 ReshapeDim::InputDims(vec![2]),
679 ReshapeDim::InputDims(vec![0, 1]),
680 ])));
681 let shape = vec![
682 NonZeroU64::new(2).unwrap(),
683 NonZeroU64::new(3).unwrap(),
684 NonZeroU64::new(4).unwrap(),
685 ];
686 let partial_decoder = partial_decoder_u16(codec, &shape, (0..24).collect());
687
688 let decoded_regions = [
689 ArraySubset::new_with_ranges(&[0..1, 0..1, 2..4]),
690 ArraySubset::new_with_ranges(&[1..2, 2..3, 0..2]),
691 ];
692 assert_eq!(
693 partial_decode_u16(partial_decoder.as_ref(), &decoded_regions),
694 [2, 3, 20, 21]
695 );
696 }
697
698 #[test]
699 fn codec_reshape_partial_decode_invalid_indexers() {
700 let codec = Arc::new(ReshapeCodec::new(ReshapeShape(vec![
701 ReshapeDim::InputDims(vec![2]),
702 ReshapeDim::InputDims(vec![0, 1]),
703 ])));
704 let shape = vec![
705 NonZeroU64::new(2).unwrap(),
706 NonZeroU64::new(3).unwrap(),
707 NonZeroU64::new(4).unwrap(),
708 ];
709 let partial_decoder = partial_decoder_u16(codec, &shape, (0..24).collect());
710
711 let wrong_dimensionality = ArraySubset::new_with_ranges(&[0..1, 0..1]);
712 assert!(
713 partial_decoder
714 .partial_decode(&wrong_dimensionality, &CodecOptions::default())
715 .is_err()
716 );
717
718 let out_of_bounds = vec![vec![2, 0, 0]];
719 assert!(
720 partial_decoder
721 .partial_decode(&out_of_bounds, &CodecOptions::default())
722 .is_err()
723 );
724 }
725
726 #[test]
727 fn codec_reshape_partial_encode_array_subset() {
728 let codec = Arc::new(ReshapeCodec::new(ReshapeShape(vec![
742 ReshapeDim::InputDims(vec![2]),
743 ReshapeDim::InputDims(vec![0, 1]),
744 ])));
745 let shape = vec![
746 NonZeroU64::new(2).unwrap(),
747 NonZeroU64::new(3).unwrap(),
748 NonZeroU64::new(4).unwrap(),
749 ];
750 let decoded_region = ArraySubset::new_with_ranges(&[1..2, 1..3, 1..4]);
751
752 assert_eq!(
753 partial_encode_u16(
754 codec,
755 &shape,
756 (0..24).collect(),
757 &decoded_region,
758 vec![100, 101, 102, 103, 104, 105],
759 ),
760 [
761 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 100, 101, 102, 20, 103,
762 104, 105,
763 ]
764 );
765 }
766
767 #[test]
768 fn codec_reshape_partial_encode_indexer() {
769 let codec = Arc::new(ReshapeCodec::new(ReshapeShape(vec![
788 ReshapeDim::InputDims(vec![2]),
789 ReshapeDim::InputDims(vec![0, 1]),
790 ])));
791 let shape = vec![
792 NonZeroU64::new(2).unwrap(),
793 NonZeroU64::new(3).unwrap(),
794 NonZeroU64::new(4).unwrap(),
795 ];
796 let indexer = vec![vec![1, 2, 3], vec![0, 0, 1], vec![1, 0, 2]];
797
798 assert_eq!(
799 partial_encode_u16(
800 codec,
801 &shape,
802 (0..24).collect(),
803 &indexer,
804 vec![100, 101, 102],
805 ),
806 [
807 0, 101, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 102, 15, 16, 17, 18, 19, 20, 21,
808 22, 100,
809 ]
810 );
811 }
812}