Skip to main content

zarrs/array/codec/array_to_array/
reshape.rs

1//! The `reshape` array to array codec (Experimental).
2//!
//! Reshapes a chunk to a new shape with the same number of elements, preserving the C (row-major) order of its elements.
4//!
5//! <div class="warning">
6//! This codec is experimental and may be incompatible with other Zarr V3 implementations.
7//! </div>
8//!
9//! ### Compatible Implementations
10//! None
11//!
12//! ### Specification
13//! - <https://github.com/zarr-developers/zarr-extensions/blob/7295bf1ec15c978f1a63b90d55891712b950c797/codecs/reshape/README.md>
14//!
15//! ### Codec `name` Aliases (Zarr V3)
16//! - `reshape`
17//!
18//! ### Codec `id` Aliases (Zarr V2)
19//! None
20//!
21//! ### Codec `configuration` Example - [`ReshapeCodecConfiguration`]:
22//! ```rust
23//! # let JSON = r#"
24//! {
25//!     "shape": [[0, 1], -1, [3], 10]
26//! }
27//! # "#;
28//! # use zarrs::metadata_ext::codec::reshape::ReshapeCodecConfiguration;
29//! # let configuration: ReshapeCodecConfiguration = serde_json::from_str(JSON).unwrap();
30//! ```
31
32mod reshape_codec;
33mod reshape_codec_partial;
34
35use std::num::NonZeroU64;
36use std::sync::Arc;
37
38use num::Integer;
39pub use reshape_codec::ReshapeCodec;
40use zarrs_metadata::v3::MetadataV3;
41
42// use itertools::Itertools;
43use crate::array::{ArrayIndices, ChunkShape, Indexer, IndexerError, unravel_index};
44use zarrs_codec::{Codec, CodecError, CodecPluginV3, CodecTraitsV3};
45pub use zarrs_metadata_ext::codec::reshape::{
46    ReshapeCodecConfiguration, ReshapeCodecConfigurationV1, ReshapeDim, ReshapeShape,
47};
48use zarrs_plugin::PluginCreateError;
49
50fn get_encoded_shape(
51    reshape_shape: &ReshapeShape,
52    decoded_shape: &[NonZeroU64],
53) -> Result<ChunkShape, CodecError> {
54    let mut encoded_shape = Vec::with_capacity(reshape_shape.0.len());
55    let mut fill_index = None;
56    for output_dim in &reshape_shape.0 {
57        match output_dim {
58            ReshapeDim::Size(size) => encoded_shape.push(*size),
59            ReshapeDim::InputDims(input_dims) => {
60                let mut product = NonZeroU64::new(1).unwrap();
61                for input_dim in input_dims {
62                    let input_shape = *decoded_shape
63                        .get(usize::try_from(*input_dim).unwrap())
64                        .ok_or_else(|| {
65                            CodecError::Other(
66                                format!("reshape codec shape references a dimension ({input_dim}) larger than the chunk dimensionality ({})", decoded_shape.len()),
67                            )
68                        })?;
69                    product = product.checked_mul(input_shape).unwrap();
70                }
71                encoded_shape.push(product);
72            }
73            ReshapeDim::Auto(_) => {
74                fill_index = Some(encoded_shape.len());
75                encoded_shape.push(NonZeroU64::new(1).unwrap());
76            }
77        }
78    }
79
80    let num_elements_input = decoded_shape.iter().map(|u| u.get()).product::<u64>();
81    let num_elements_output = encoded_shape.iter().map(|u| u.get()).product::<u64>();
82    if let Some(fill_index) = fill_index {
83        let (quot, rem) = num_elements_input.div_rem(&num_elements_output);
84        if rem == 0 {
85            encoded_shape[fill_index] = NonZeroU64::new(quot).unwrap();
86        } else {
87            return Err(CodecError::Other(format!(
88                "reshape codec no substitution for dim {fill_index} can satisfy decoded_shape {decoded_shape:?} == encoded_shape {encoded_shape:?}."
89            )));
90        }
91    } else if num_elements_input != num_elements_output {
92        return Err(CodecError::Other(format!(
93            "reshape codec encoded/decoded number of elements differ: decoded_shape {decoded_shape:?} ({num_elements_input}) encoded_shape {encoded_shape:?} ({num_elements_output})."
94        )));
95    }
96
97    Ok(encoded_shape)
98}
99
100fn get_reshaped_indexer(
101    indexer: &dyn Indexer,
102    decoded_shape: &[NonZeroU64],
103    encoded_shape: &[NonZeroU64],
104) -> Result<impl Indexer, CodecError> {
105    if indexer.dimensionality() != decoded_shape.len() {
106        return Err(IndexerError::new_incompatible_dimensionality(
107            indexer.dimensionality(),
108            decoded_shape.len(),
109        )
110        .into());
111    }
112
113    let decoded_shape = bytemuck::must_cast_slice(decoded_shape);
114    let encoded_shape = bytemuck::must_cast_slice(encoded_shape);
115    let indices = indexer
116        .iter_linearised_indices(decoded_shape)?
117        .map(|linear_index| {
118            unravel_index(linear_index, encoded_shape).ok_or_else(|| {
119                CodecError::Other(
120                    "reshape codec encoded/decoded number of elements differ".to_string(),
121                )
122            })
123        })
124        .collect::<Result<Vec<ArrayIndices>, _>>()?;
125
126    Ok(indices)
127}
128
// Declare `reshape` as the Zarr V3 codec `name` alias of `ReshapeCodec`.
zarrs_plugin::impl_extension_aliases!(ReshapeCodec, v3: "reshape");

// Register the V3 codec.
// NOTE(review): `inventory::submit!` presumably adds this plugin to a registry collected at
// runtime, so no explicit registration call is needed — confirm against the `inventory` docs.
inventory::submit! {
    CodecPluginV3::new::<ReshapeCodec>()
}
135
136impl CodecTraitsV3 for ReshapeCodec {
137    fn create(metadata: &MetadataV3) -> Result<Codec, PluginCreateError> {
138        let configuration: ReshapeCodecConfiguration = metadata.to_typed_configuration()?;
139        let codec = Arc::new(ReshapeCodec::new_with_configuration(&configuration)?);
140        Ok(Codec::ArrayToArray(codec))
141    }
142}
143
#[cfg(test)]
mod tests {
    use std::num::NonZeroU64;
    use std::sync::Mutex;

    use super::*;
    use crate::array::codec::BytesCodec;
    use crate::array::{ArrayBytes, ArraySubset, ChunkShapeTraits, DataType, FillValue, data_type};
    use zarrs_codec::{
        ArrayPartialDecoderTraits, ArrayToArrayCodecTraits, ArrayToBytesCodecTraits, CodecOptions,
    };

    /// Result of a fallible test; an `Err` fails the test and reports the underlying error.
    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Builds a chunk shape from plain dimension sizes.
    ///
    /// # Panics
    /// Panics if any size is zero.
    fn shape_of(dims: &[u64]) -> Vec<NonZeroU64> {
        dims.iter()
            .map(|&dim| NonZeroU64::new(dim).expect("test dimension sizes must be non-zero"))
            .collect()
    }

    /// Reshape codec `[[2], [0, 1]]`, mapping a decoded `[2, 3, 4]` chunk to an encoded `[4, 6]`
    /// chunk.
    fn codec_2x3x4_to_4x6() -> Arc<ReshapeCodec> {
        Arc::new(ReshapeCodec::new(ReshapeShape(vec![
            ReshapeDim::InputDims(vec![2]),
            ReshapeDim::InputDims(vec![0, 1]),
        ])))
    }

    /// Encodes then decodes a `[5, 4, 4, 3]` chunk with the reshape codec configured by `json`.
    ///
    /// Panics if the encoded shape differs from `output_shape`. Returns an error if parsing the
    /// configuration, computing the encoded shape, encoding, or decoding fails.
    fn codec_reshape_round_trip_impl(
        json: &str,
        data_type: DataType,
        fill_value: FillValue,
        output_shape: Vec<NonZeroU64>,
    ) -> TestResult {
        let shape = shape_of(&[5, 4, 4, 3]);
        let size = shape.num_elements_usize() * data_type.fixed_size().unwrap();
        // Deliberately wrapping byte pattern so that every element is distinguishable locally.
        let bytes: Vec<u8> = (0..size).map(|s| s as u8).collect();
        let bytes: ArrayBytes = bytes.into();

        let configuration: ReshapeCodecConfiguration = serde_json::from_str(json)?;
        let codec = ReshapeCodec::new_with_configuration(&configuration)?;
        assert_eq!(codec.encoded_shape(&shape)?, output_shape);

        let encoded = codec.encode(
            bytes.clone(),
            &shape,
            &data_type,
            &fill_value,
            &CodecOptions::default(),
        )?;
        let decoded = codec.decode(
            encoded,
            &shape,
            &data_type,
            &fill_value,
            &CodecOptions::default(),
        )?;
        assert_eq!(bytes, decoded);
        Ok(())
    }

    /// [`codec_reshape_round_trip_impl`] for a `uint32` chunk with a zero fill value.
    fn round_trip_u32(json: &str, output_shape: &[u64]) -> TestResult {
        codec_reshape_round_trip_impl(
            json,
            data_type::uint32(),
            FillValue::from(0u32),
            shape_of(output_shape),
        )
    }

    #[test]
    fn codec_reshape_round_trip_array1() -> TestResult {
        round_trip_u32(r#"{"shape": [[0, 1], [2], 3]}"#, &[20, 4, 3])
    }

    #[test]
    fn codec_reshape_round_trip_array2() -> TestResult {
        round_trip_u32(r#"{"shape": [[0, 1], [2], -1]}"#, &[20, 4, 3])
    }

    #[test]
    fn codec_reshape_round_trip_array3() -> TestResult {
        round_trip_u32(r#"{"shape": [[0, 1, 2], 3]}"#, &[80, 3])
    }

    #[test]
    fn codec_reshape_round_trip_array4() -> TestResult {
        round_trip_u32(r#"{"shape": [[0], -1, [2, 3]]}"#, &[5, 4, 12])
    }

    #[test]
    fn codec_reshape_round_trip_array5() -> TestResult {
        round_trip_u32(r#"{"shape": [[0], -1, [3]]}"#, &[5, 16, 3])
    }

    #[test]
    fn codec_reshape_round_trip_array6() -> TestResult {
        round_trip_u32(r#"{"shape": [-1, 2, 2, [3]]}"#, &[20, 2, 2, 3])
    }

    #[test]
    fn codec_reshape_invalid1() {
        // Input dimension 4 is out of bounds for the 4-dimensional decoded chunk.
        assert!(round_trip_u32(r#"{"shape": [-1, 2, 2, [4]]}"#, &[20, 2, 2, 3]).is_err());
    }

    #[test]
    fn codec_reshape_invalid2() {
        // The encoded shape has 8 elements, but the decoded chunk has 240.
        assert!(round_trip_u32(r#"{"shape": [2, 2, 2]}"#, &[20, 2, 2, 3]).is_err());
    }

    /// Encodes `elements` with `codec` and the bytes codec, and returns a reshape partial decoder
    /// over the result.
    fn partial_decoder_u16(
        codec: Arc<ReshapeCodec>,
        shape: &[NonZeroU64],
        elements: Vec<u16>,
    ) -> Arc<dyn ArrayPartialDecoderTraits> {
        let data_type = data_type::uint16();
        let fill_value = FillValue::from(0u16);
        let bytes = crate::array::transmute_to_bytes_vec(elements);
        let bytes: ArrayBytes = bytes.into();
        let encoded = codec
            .encode(
                bytes,
                shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        let input_handle = Arc::new(encoded.into_fixed().unwrap());
        let bytes_codec = Arc::new(BytesCodec::default());
        let (encoded_shape, encoded_data_type, encoded_fill_value) = codec
            .encoded_representation(shape, &data_type, &fill_value)
            .unwrap();
        let input_handle = bytes_codec
            .partial_decoder(
                input_handle,
                &encoded_shape,
                &encoded_data_type,
                &encoded_fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        codec
            .partial_decoder(
                input_handle,
                shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap()
    }

    /// Partially decodes `indexer` and returns the selected elements as `u16`.
    fn partial_decode_u16(
        partial_decoder: &dyn ArrayPartialDecoderTraits,
        indexer: &dyn Indexer,
    ) -> Vec<u16> {
        let decoded_partial_chunk = partial_decoder
            .partial_decode(indexer, &CodecOptions::default())
            .unwrap();
        crate::array::convert_from_bytes_slice::<u16>(&decoded_partial_chunk.into_fixed().unwrap())
            .to_vec()
    }

    /// Encodes `elements`, then partially encodes `elements_partial_encode` at `indexer` through
    /// the reshape partial encoder, and returns the fully decoded chunk.
    fn partial_encode_u16(
        codec: Arc<ReshapeCodec>,
        shape: &[NonZeroU64],
        elements: Vec<u16>,
        indexer: &dyn Indexer,
        elements_partial_encode: Vec<u16>,
    ) -> Vec<u16> {
        let data_type = data_type::uint16();
        let fill_value = FillValue::from(0u16);
        let bytes = crate::array::transmute_to_bytes_vec(elements);
        let bytes: ArrayBytes = bytes.into();
        let encoded = codec
            .encode(
                bytes,
                shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();

        let bytes_codec = Arc::new(BytesCodec::default());
        let (encoded_shape, encoded_data_type, encoded_fill_value) = codec
            .encoded_representation(shape, &data_type, &fill_value)
            .unwrap();
        let encoded_chunk = bytes_codec
            .encode(
                encoded,
                &encoded_shape,
                &encoded_data_type,
                &encoded_fill_value,
                &CodecOptions::default(),
            )
            .unwrap()
            .into_owned();
        let output = Arc::new(Mutex::new(Some(encoded_chunk)));
        let input_output_handle = bytes_codec
            .clone()
            .partial_encoder(
                output.clone(),
                &encoded_shape,
                &encoded_data_type,
                &encoded_fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        let partial_encoder = codec
            .clone()
            .partial_encoder(
                input_output_handle,
                shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        assert!(partial_encoder.supports_partial_encode());

        let bytes = crate::array::transmute_to_bytes_vec(elements_partial_encode);
        partial_encoder
            .partial_encode(indexer, &ArrayBytes::from(bytes), &CodecOptions::default())
            .unwrap();

        let output = output.lock().unwrap().clone().unwrap();
        let decoded_encoded = bytes_codec
            .decode(
                output.into(),
                &encoded_shape,
                &encoded_data_type,
                &encoded_fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        let decoded = codec
            .decode(
                decoded_encoded,
                shape,
                &data_type,
                &fill_value,
                &CodecOptions::default(),
            )
            .unwrap();
        crate::array::convert_from_bytes_slice::<u16>(&decoded.into_fixed().unwrap()).to_vec()
    }

    #[test]
    fn codec_reshape_partial_decode_array_subset() {
        // Decoded shape [2, 3, 4]:
        //
        //   decoded[0, :, :]          decoded[1, :, :]
        //   00 01 02 03              12 13 14 15
        //   04 05 06 07              16 17 18 19  <- select cols 1..4
        //   08 09 10 11              20 21 22 23  <- select cols 1..4
        //
        // Encoded shape [4, 6] after [[2], [0, 1]]:
        //
        //   00 01 02 03 04 05
        //   06 07 08 09 10 11
        //   12 13 14 15 16 17
        //   18 19 20 21 22 23
        let shape = shape_of(&[2, 3, 4]);
        let partial_decoder = partial_decoder_u16(codec_2x3x4_to_4x6(), &shape, (0..24).collect());

        let decoded_region = ArraySubset::new_with_ranges(&[1..2, 1..3, 1..4]);
        assert_eq!(
            partial_decode_u16(partial_decoder.as_ref(), &decoded_region),
            [17, 18, 19, 21, 22, 23]
        );
    }

    #[test]
    fn codec_reshape_partial_decode_indexer() {
        // Decoded shape [2, 3, 4]:
        //
        //   decoded[0, :, :]          decoded[1, :, :]
        //   00 01 02 03              12 13 14 15
        //   04 05 06 07              16 17 18 19
        //   08 09 10 11              20 21 22 23
        //
        // Encoded shape [4, 6] after [[2], [0, 1]]:
        //
        //   00 01 02 03 04 05
        //   06 07 08 09 10 11
        //   12 13 14 15 16 17
        //   18 19 20 21 22 23
        //
        // Points:
        //   decoded[1, 2, 3] -> encoded[3, 5] -> 23
        //   decoded[0, 0, 1] -> encoded[0, 1] -> 01
        //   decoded[1, 0, 2] -> encoded[2, 2] -> 14
        let shape = shape_of(&[2, 3, 4]);
        let partial_decoder = partial_decoder_u16(codec_2x3x4_to_4x6(), &shape, (0..24).collect());

        let indexer = vec![vec![1, 2, 3], vec![0, 0, 1], vec![1, 0, 2]];
        assert_eq!(
            partial_decode_u16(partial_decoder.as_ref(), &indexer),
            [23, 1, 14]
        );
    }

    #[test]
    fn codec_reshape_partial_decode_flatten_array_subset() {
        // Decoded shape [2, 3, 4] flattened to encoded shape [24]:
        //
        //   decoded[0, :, :]          decoded[1, :, :]
        //   00 01 02 03              12 13 14 15
        //   04 05 06 07  <- select   16 17 18 19  <- select
        //   08 09 10 11  <- select   20 21 22 23  <- select
        //
        //   encoded:
        //   00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15 16 17 18 19 20 21 22 23
        let codec = Arc::new(ReshapeCodec::new(ReshapeShape(vec![
            ReshapeDim::InputDims(vec![0, 1, 2]),
        ])));
        let shape = shape_of(&[2, 3, 4]);
        let partial_decoder = partial_decoder_u16(codec, &shape, (0..24).collect());

        let decoded_region = ArraySubset::new_with_ranges(&[0..2, 1..3, 2..4]);
        assert_eq!(
            partial_decode_u16(partial_decoder.as_ref(), &decoded_region),
            [6, 7, 10, 11, 18, 19, 22, 23]
        );
    }

    #[test]
    fn codec_reshape_partial_decode_auto_dimension() {
        // Decoded shape [2, 3, 4]:
        //
        //   decoded[0, :, :]          decoded[1, :, :]
        //   00 01 02 03              12 13 14 15
        //   04 05 06 07              16 17 18 19
        //   08 09 10 11              20 21 22 23
        //
        // Encoded shape [4, 6] after [4, -1]:
        //
        //   00 01 02 03 04 05
        //   06 07 08 09 10 11
        //   12 13 14 15 16 17
        //   18 19 20 21 22 23
        //
        // Points:
        //   decoded[0, 2, 3] -> encoded[1, 5] -> 11
        //   decoded[1, 0, 0] -> encoded[2, 0] -> 12
        //   decoded[1, 2, 2] -> encoded[3, 4] -> 22
        let codec = Arc::new(ReshapeCodec::new(ReshapeShape(vec![
            ReshapeDim::Size(NonZeroU64::new(4).unwrap()),
            ReshapeDim::auto(),
        ])));
        let shape = shape_of(&[2, 3, 4]);
        let partial_decoder = partial_decoder_u16(codec, &shape, (0..24).collect());

        let indexer = vec![vec![0, 2, 3], vec![1, 0, 0], vec![1, 2, 2]];
        assert_eq!(
            partial_decode_u16(partial_decoder.as_ref(), &indexer),
            [11, 12, 22]
        );
    }

    #[test]
    fn codec_reshape_partial_decode_1d_to_nd() {
        // Decoded shape [12]:
        //
        //   00 01 02 03 04 05 06 07 08 09 10 11
        //
        // Encoded shape [2, 3, 2]:
        //
        //   encoded[0, :, :]          encoded[1, :, :]
        //   00 01                    06 07
        //   02 03                    08 09
        //   04 05                    10 11
        let codec = Arc::new(ReshapeCodec::new(ReshapeShape(vec![
            ReshapeDim::Size(NonZeroU64::new(2).unwrap()),
            ReshapeDim::Size(NonZeroU64::new(3).unwrap()),
            ReshapeDim::Size(NonZeroU64::new(2).unwrap()),
        ])));
        let shape = shape_of(&[12]);
        let partial_decoder = partial_decoder_u16(codec, &shape, (0..12).collect());

        #[expect(clippy::single_range_in_vec_init)]
        let decoded_region = ArraySubset::new_with_ranges(&[3..10]);
        assert_eq!(
            partial_decode_u16(partial_decoder.as_ref(), &decoded_region),
            [3, 4, 5, 6, 7, 8, 9]
        );
    }

    #[test]
    fn codec_reshape_partial_decode_composite_indexer() {
        // Decoded shape [2, 3, 4]:
        //
        //   decoded[0, :, :]          decoded[1, :, :]
        //   00 01 02 03  <- select   12 13 14 15
        //   04 05 06 07              16 17 18 19
        //   08 09 10 11              20 21 22 23  <- select
        //
        // The composite indexer requests two disjoint decoded regions:
        //   [0..1, 0..1, 2..4] -> 02 03
        //   [1..2, 2..3, 0..2] -> 20 21
        let shape = shape_of(&[2, 3, 4]);
        let partial_decoder = partial_decoder_u16(codec_2x3x4_to_4x6(), &shape, (0..24).collect());

        let decoded_regions = [
            ArraySubset::new_with_ranges(&[0..1, 0..1, 2..4]),
            ArraySubset::new_with_ranges(&[1..2, 2..3, 0..2]),
        ];
        assert_eq!(
            partial_decode_u16(partial_decoder.as_ref(), &decoded_regions),
            [2, 3, 20, 21]
        );
    }

    #[test]
    fn codec_reshape_partial_decode_invalid_indexers() {
        let shape = shape_of(&[2, 3, 4]);
        let partial_decoder = partial_decoder_u16(codec_2x3x4_to_4x6(), &shape, (0..24).collect());

        // A 2D region cannot index a 3D chunk.
        let wrong_dimensionality = ArraySubset::new_with_ranges(&[0..1, 0..1]);
        assert!(
            partial_decoder
                .partial_decode(&wrong_dimensionality, &CodecOptions::default())
                .is_err()
        );

        // Index 2 is past the end of dimension 0 (size 2).
        let out_of_bounds = vec![vec![2, 0, 0]];
        assert!(
            partial_decoder
                .partial_decode(&out_of_bounds, &CodecOptions::default())
                .is_err()
        );
    }

    #[test]
    fn codec_reshape_partial_encode_array_subset() {
        // Decoded shape [2, 3, 4]:
        //
        //   decoded[0, :, :]          decoded[1, :, :]
        //   00 01 02 03              12 13 14 15
        //   04 05 06 07              16 17 18 19  <- write 100 101 102
        //   08 09 10 11              20 21 22 23  <- write 103 104 105
        //
        // Encoded shape [4, 6] after [[2], [0, 1]]:
        //
        //   00 01 02 03 04 05
        //   06 07 08 09 10 11
        //   12 13 14 15 16 17
        //   18 19 20 21 22 23
        let shape = shape_of(&[2, 3, 4]);
        let decoded_region = ArraySubset::new_with_ranges(&[1..2, 1..3, 1..4]);

        assert_eq!(
            partial_encode_u16(
                codec_2x3x4_to_4x6(),
                &shape,
                (0..24).collect(),
                &decoded_region,
                vec![100, 101, 102, 103, 104, 105],
            ),
            [
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 100, 101, 102, 20, 103,
                104, 105,
            ]
        );
    }

    #[test]
    fn codec_reshape_partial_encode_indexer() {
        // Decoded shape [2, 3, 4]:
        //
        //   decoded[0, :, :]          decoded[1, :, :]
        //   00 01 02 03              12 13 14 15
        //   04 05 06 07              16 17 18 19
        //   08 09 10 11              20 21 22 23
        //
        // Encoded shape [4, 6] after [[2], [0, 1]]:
        //
        //   00 01 02 03 04 05
        //   06 07 08 09 10 11
        //   12 13 14 15 16 17
        //   18 19 20 21 22 23
        //
        // Writes:
        //   decoded[1, 2, 3] -> encoded[3, 5] <- 100
        //   decoded[0, 0, 1] -> encoded[0, 1] <- 101
        //   decoded[1, 0, 2] -> encoded[2, 2] <- 102
        let shape = shape_of(&[2, 3, 4]);
        let indexer = vec![vec![1, 2, 3], vec![0, 0, 1], vec![1, 0, 2]];

        assert_eq!(
            partial_encode_u16(
                codec_2x3x4_to_4x6(),
                &shape,
                (0..24).collect(),
                &indexer,
                vec![100, 101, 102],
            ),
            [
                0, 101, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 102, 15, 16, 17, 18, 19, 20, 21,
                22, 100,
            ]
        );
    }
}