tract_tensorflow/ops/nn/s2b/mod.rs

1use tract_hir::internal::*;
2use tract_ndarray::prelude::*;
3use tract_num_traits::Zero;
4
5use crate::model::ParsingContext;
6use crate::tfpb::tensorflow::NodeDef;
7
8pub mod raw;
9pub mod unary;
10
11pub fn space_to_batch_nd(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
12    let datum_type = pb.get_attr_datum_type("T")?;
13    Ok(Box::new(raw::SpaceToBatch::new(datum_type)))
14}
15
16pub fn batch_to_space_nd(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
17    let datum_type = pb.get_attr_datum_type("T")?;
18    Ok(Box::new(raw::BatchToSpace::new(datum_type)))
19}
20
21fn space_to_batch<T: Copy + Datum + Zero>(
22    input: TValue,
23    block_shape: &ArrayView1<i32>,
24    paddings: &ArrayView2<i32>,
25) -> TractResult<TValue> {
26    let mut data = input.into_tensor();
27
28    for (ix, pad) in paddings.view().outer_iter().enumerate() {
29        if pad[0] == 0 && pad[1] == 0 {
30            continue;
31        }
32        let mut stack = tvec!();
33        let mut pad_shape = data.shape().to_vec();
34        if pad[0] != 0 {
35            pad_shape[ix + 1] = pad[0] as usize;
36            stack.push(Tensor::zero::<T>(&pad_shape)?);
37        }
38        stack.push(data);
39        if pad[1] != 0 {
40            pad_shape[ix + 1] = pad[1] as usize;
41            stack.push(Tensor::zero::<T>(&pad_shape)?);
42        }
43        data = Tensor::stack_tensors(ix + 1, &stack)?;
44    }
45
46    let mut reshaped = vec![data.shape()[0]];
47    let block_size = block_shape.iter().map(|a| *a as usize).product::<usize>();
48    let mut final_shape = vec![block_size * data.shape()[0]];
49    for (m, &block_shape_dim) in block_shape.iter().enumerate() {
50        reshaped.push(data.shape()[m + 1] / block_shape_dim as usize);
51        reshaped.push(block_shape_dim as usize);
52        final_shape.push(data.shape()[m + 1] / block_shape_dim as usize);
53    }
54    reshaped.extend(&data.shape()[block_shape.len() + 1..]);
55    final_shape.extend(&data.shape()[block_shape.len() + 1..]);
56    let data = data.into_shape(&reshaped)?;
57
58    let mut permuted_axes: Vec<_> = (0..block_shape.len()).map(|x| 2 * x + 2).collect();
59    permuted_axes.push(0);
60    permuted_axes.extend((0..block_shape.len()).map(|x| 2 * x + 1));
61    permuted_axes.extend((block_shape.len() * 2 + 1)..data.rank());
62    let data = data.permute_axes(&permuted_axes)?;
63    let data = data.into_shape(&final_shape)?;
64
65    Ok(data.into_tvalue())
66}
67
68fn batch_to_space<T: Copy + Datum + Zero>(
69    input: TValue,
70    block_shape: &ArrayView1<i32>,
71    crops: &ArrayView2<i32>,
72) -> TractResult<TValue> {
73    let data = input.into_tensor().into_array()?;
74    let input_shape = data.shape().to_vec();
75    let crops: ArrayView2<i32> = crops.view().into_dimensionality()?;
76
77    let block_size = block_shape.iter().map(|a| *a as usize).product::<usize>();
78
79    // block_dim_1 .. block_dim_n, batches/bloc_size, dim_1, .. dim_n, chan_1, .., chan_n
80    let mut unflatten_blocked_shape = vec![];
81    unflatten_blocked_shape.extend(block_shape.iter().map(|a| *a as usize));
82    let batches = data.shape()[0] / block_size;
83    unflatten_blocked_shape.push(batches);
84    unflatten_blocked_shape.extend(&data.shape()[1..]);
85    let data = data.into_shape_with_order(&*unflatten_blocked_shape)?;
86    let mut permuted_axes = vec![block_shape.len()];
87    let mut padded_shape = vec![batches];
88    for i in 0..block_shape.shape()[0] {
89        permuted_axes.push(block_shape.len() + 1 + i);
90        permuted_axes.push(i);
91        padded_shape.push(block_shape[i] as usize * input_shape[i + 1]);
92    }
93    permuted_axes.extend((1 + block_shape.len() * 2)..data.ndim());
94    padded_shape.extend(&input_shape[1 + block_shape.len()..]);
95    let data = data.permuted_axes(permuted_axes);
96    let data: Vec<T> = data.iter().copied().collect();
97    let data = tract_ndarray::ArrayD::from_shape_vec(padded_shape, data)?;
98    let mut data = data;
99    for (i, crop) in crops.outer_iter().enumerate() {
100        if crop[0] != 0 || crop[1] != 0 {
101            let end = data.shape()[1 + i];
102            let range = (crop[0] as usize)..(end - crop[1] as usize);
103            data = data.slice_axis(Axis(i + 1), range.into()).map(|x| *x).to_owned();
104        }
105    }
106    Ok(data.into_tvalue())
107}
108
#[cfg(test)]
mod tests {
    #![allow(non_snake_case)]
    use super::raw::{BatchToSpace, SpaceToBatch};
    use super::*;

    // Eval fixtures follow the examples of
    // https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd
    //
    // Each `case_n` returns `(space, paddings, batch)`. SpaceToBatch maps `space` to `batch`.
    // BatchToSpace, given the same values as crops, maps `batch` back to `space`.
    // Every eval case uses the block shape `[2, 2]`.

    fn case_1() -> (Tensor, Tensor, Tensor) {
        (
            tensor4(&[[[[1i32], [2]], [[3], [4]]]]),
            tensor2(&[[0, 0], [0, 0]]),
            tensor4(&[[[[1i32]]], [[[2]]], [[[3]]], [[[4]]]]),
        )
    }

    fn case_2() -> (Tensor, Tensor, Tensor) {
        (
            tensor4(&[[[[1i32, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]]),
            tensor2(&[[0, 0], [0, 0]]),
            tensor4(&[[[[1i32, 2, 3]]], [[[4, 5, 6]]], [[[7, 8, 9]]], [[[10, 11, 12]]]]),
        )
    }

    fn case_3() -> (Tensor, Tensor, Tensor) {
        (
            tensor4(&[[
                [[1i32], [2], [3], [4]],
                [[5], [6], [7], [8]],
                [[9], [10], [11], [12]],
                [[13], [14], [15], [16]],
            ]]),
            tensor2(&[[0, 0], [0, 0]]),
            tensor4(&[
                [[[1i32], [3]], [[9], [11]]],
                [[[2], [4]], [[10], [12]]],
                [[[5], [7]], [[13], [15]]],
                [[[6], [8]], [[14], [16]]],
            ]),
        )
    }

    fn case_4() -> (Tensor, Tensor, Tensor) {
        (
            tensor4(&[
                [[[1i32], [2], [3], [4]], [[5], [6], [7], [8]]],
                [[[9], [10], [11], [12]], [[13], [14], [15], [16]]],
            ]),
            tensor2(&[[0, 0], [2, 0]]),
            tensor4(&[
                [[[0i32], [1], [3]]],
                [[[0], [9], [11]]],
                [[[0], [2], [4]]],
                [[[0], [10], [12]]],
                [[[0], [5], [7]]],
                [[[0], [13], [15]]],
                [[[0], [6], [8]]],
                [[[0], [14], [16]]],
            ]),
        )
    }

    /// Evaluates an i32 `SpaceToBatch` with block shape `[2, 2]`.
    fn eval_space_to_batch(space: Tensor, paddings: Tensor) -> TVec<TValue> {
        SpaceToBatch::new(i32::datum_type())
            .eval(tvec![space.into(), tensor1(&[2, 2]).into(), paddings.into()])
            .unwrap()
    }

    /// Evaluates an i32 `BatchToSpace` with block shape `[2, 2]`.
    fn eval_batch_to_space(batch: Tensor, crops: Tensor) -> TVec<TValue> {
        BatchToSpace::new(i32::datum_type())
            .eval(tvec![batch.into(), tensor1(&[2, 2]).into(), crops.into()])
            .unwrap()
    }

    /// Runs f32 `SpaceToBatch` shape inference with block shape `[2]` and returns the
    /// first output fact.
    fn infer_space_to_batch(data: InferenceFact, paddings: Tensor) -> InferenceFact {
        let mut op = SpaceToBatch::new(f32::datum_type());
        let block_shape = InferenceFact::from(Tensor::from(arr1(&[2])));
        let paddings = InferenceFact::from(paddings);
        let any = InferenceFact::default();
        let (_, outputs, _) =
            op.infer_facts(tvec!(&data, &block_shape, &paddings), tvec!(&any), tvec!()).unwrap();
        outputs[0].clone()
    }

    #[test]
    fn space_to_batch_nd_1() {
        let (space, paddings, batch) = case_1();
        assert_eq!(eval_space_to_batch(space, paddings), tvec![batch.into_tvalue()]);
    }

    #[test]
    fn space_to_batch_nd_2() {
        let (space, paddings, batch) = case_2();
        assert_eq!(eval_space_to_batch(space, paddings), tvec![batch.into_tvalue()]);
    }

    #[test]
    fn space_to_batch_nd_3() {
        let (space, paddings, batch) = case_3();
        assert_eq!(eval_space_to_batch(space, paddings), tvec![batch.into_tvalue()]);
    }

    #[test]
    fn space_to_batch_nd_4() {
        let (space, paddings, batch) = case_4();
        assert_eq!(eval_space_to_batch(space, paddings), tvec![batch.into_tvalue()]);
    }

    #[test]
    fn space_to_batch_nd_infer_1() {
        let paddings = Tensor::from(arr2(&[[0.to_dim(), 0.to_dim()]]));
        let output = infer_space_to_batch(f32::fact([1, 4, 16]).into(), paddings);
        let expected: InferenceFact = f32::fact([2, 2, 16]).into();
        assert_eq!(output, expected);
    }

    #[test]
    fn space_to_batch_nd_infer_2() {
        let table = SymbolScope::default();
        let s = table.sym("S");
        let paddings = Tensor::from(arr2(&[[0.to_dim(), (s.to_dim() % 2)]]));
        let output =
            infer_space_to_batch(f32::fact(dims!(1, s.to_dim() - 4, 16)).into(), paddings);
        let expected: InferenceFact =
            f32::fact(dims!(2, (s.to_dim() + s.to_dim() % 2 - 4) / 2, 16)).into();
        assert_eq!(output, expected);
    }

    #[test]
    fn batch_to_space_nd_1() {
        let (space, crops, batch) = case_1();
        assert_eq!(eval_batch_to_space(batch, crops), tvec![space.into_tvalue()]);
    }

    #[test]
    fn batch_to_space_nd_2() {
        let (space, crops, batch) = case_2();
        assert_eq!(eval_batch_to_space(batch, crops), tvec![space.into_tvalue()]);
    }

    #[test]
    fn batch_to_space_nd_3() {
        let (space, crops, batch) = case_3();
        assert_eq!(eval_batch_to_space(batch, crops), tvec![space.into_tvalue()]);
    }

    #[test]
    fn batch_to_space_nd_4() {
        let (space, crops, batch) = case_4();
        assert_eq!(eval_batch_to_space(batch, crops), tvec![space.into_tvalue()]);
    }
}