1use tract_hir::internal::*;
2use tract_ndarray::prelude::*;
3use tract_num_traits::Zero;
4
5use crate::model::ParsingContext;
6use crate::tfpb::tensorflow::NodeDef;
7
8pub mod raw;
9pub mod unary;
10
11pub fn space_to_batch_nd(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
12 let datum_type = pb.get_attr_datum_type("T")?;
13 Ok(Box::new(raw::SpaceToBatch::new(datum_type)))
14}
15
16pub fn batch_to_space_nd(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
17 let datum_type = pb.get_attr_datum_type("T")?;
18 Ok(Box::new(raw::BatchToSpace::new(datum_type)))
19}
20
21fn space_to_batch<T: Copy + Datum + Zero>(
22 input: TValue,
23 block_shape: &ArrayView1<i32>,
24 paddings: &ArrayView2<i32>,
25) -> TractResult<TValue> {
26 let mut data = input.into_tensor();
27
28 for (ix, pad) in paddings.view().outer_iter().enumerate() {
29 if pad[0] == 0 && pad[1] == 0 {
30 continue;
31 }
32 let mut stack = tvec!();
33 let mut pad_shape = data.shape().to_vec();
34 if pad[0] != 0 {
35 pad_shape[ix + 1] = pad[0] as usize;
36 stack.push(Tensor::zero::<T>(&pad_shape)?);
37 }
38 stack.push(data);
39 if pad[1] != 0 {
40 pad_shape[ix + 1] = pad[1] as usize;
41 stack.push(Tensor::zero::<T>(&pad_shape)?);
42 }
43 data = Tensor::stack_tensors(ix + 1, &stack)?;
44 }
45
46 let mut reshaped = vec![data.shape()[0]];
47 let block_size = block_shape.iter().map(|a| *a as usize).product::<usize>();
48 let mut final_shape = vec![block_size * data.shape()[0]];
49 for (m, &block_shape_dim) in block_shape.iter().enumerate() {
50 reshaped.push(data.shape()[m + 1] / block_shape_dim as usize);
51 reshaped.push(block_shape_dim as usize);
52 final_shape.push(data.shape()[m + 1] / block_shape_dim as usize);
53 }
54 reshaped.extend(&data.shape()[block_shape.len() + 1..]);
55 final_shape.extend(&data.shape()[block_shape.len() + 1..]);
56 let data = data.into_shape(&reshaped)?;
57
58 let mut permuted_axes: Vec<_> = (0..block_shape.len()).map(|x| 2 * x + 2).collect();
59 permuted_axes.push(0);
60 permuted_axes.extend((0..block_shape.len()).map(|x| 2 * x + 1));
61 permuted_axes.extend((block_shape.len() * 2 + 1)..data.rank());
62 let data = data.permute_axes(&permuted_axes)?;
63 let data = data.into_shape(&final_shape)?;
64
65 Ok(data.into_tvalue())
66}
67
68fn batch_to_space<T: Copy + Datum + Zero>(
69 input: TValue,
70 block_shape: &ArrayView1<i32>,
71 crops: &ArrayView2<i32>,
72) -> TractResult<TValue> {
73 let data = input.into_tensor().into_array()?;
74 let input_shape = data.shape().to_vec();
75 let crops: ArrayView2<i32> = crops.view().into_dimensionality()?;
76
77 let block_size = block_shape.iter().map(|a| *a as usize).product::<usize>();
78
79 let mut unflatten_blocked_shape = vec![];
81 unflatten_blocked_shape.extend(block_shape.iter().map(|a| *a as usize));
82 let batches = data.shape()[0] / block_size;
83 unflatten_blocked_shape.push(batches);
84 unflatten_blocked_shape.extend(&data.shape()[1..]);
85 let data = data.into_shape_with_order(&*unflatten_blocked_shape)?;
86 let mut permuted_axes = vec![block_shape.len()];
87 let mut padded_shape = vec![batches];
88 for i in 0..block_shape.shape()[0] {
89 permuted_axes.push(block_shape.len() + 1 + i);
90 permuted_axes.push(i);
91 padded_shape.push(block_shape[i] as usize * input_shape[i + 1]);
92 }
93 permuted_axes.extend((1 + block_shape.len() * 2)..data.ndim());
94 padded_shape.extend(&input_shape[1 + block_shape.len()..]);
95 let data = data.permuted_axes(permuted_axes);
96 let data: Vec<T> = data.iter().copied().collect();
97 let data = tract_ndarray::ArrayD::from_shape_vec(padded_shape, data)?;
98 let mut data = data;
99 for (i, crop) in crops.outer_iter().enumerate() {
100 if crop[0] != 0 || crop[1] != 0 {
101 let end = data.shape()[1 + i];
102 let range = (crop[0] as usize)..(end - crop[1] as usize);
103 data = data.slice_axis(Axis(i + 1), range.into()).map(|x| *x).to_owned();
104 }
105 }
106 Ok(data.into_tvalue())
107}
108
#[cfg(test)]
mod tests {
    #![allow(non_snake_case)]
    use super::raw::{BatchToSpace, SpaceToBatch};
    use super::*;

    /// Evaluates an `i32` `SpaceToBatch` op on `(input, block_shape, paddings)`; panics on error.
    fn run_space_to_batch(input: Tensor, block_shape: Tensor, paddings: Tensor) -> TVec<TValue> {
        let op = SpaceToBatch::new(i32::datum_type());
        op.eval(tvec![input.into(), block_shape.into(), paddings.into()]).unwrap()
    }

    /// Evaluates an `i32` `BatchToSpace` op on `(input, block_shape, crops)`; panics on error.
    fn run_batch_to_space(input: Tensor, block_shape: Tensor, crops: Tensor) -> TVec<TValue> {
        let op = BatchToSpace::new(i32::datum_type());
        op.eval(tvec![input.into(), block_shape.into(), crops.into()]).unwrap()
    }

    /// Runs `f32` `SpaceToBatch` fact inference with an unconstrained output and returns the
    /// first inferred output fact.
    fn infer_space_to_batch(
        data: InferenceFact,
        block_shape: Tensor,
        paddings: Tensor,
    ) -> InferenceFact {
        let mut op = SpaceToBatch::new(f32::datum_type());
        let block_shape = InferenceFact::from(block_shape);
        let paddings = InferenceFact::from(paddings);
        let any = InferenceFact::default();
        let (_, mut outputs, _) = op
            .infer_facts(tvec!(&data, &block_shape, &paddings), tvec!(&any), tvec!())
            .unwrap();
        outputs.remove(0)
    }

    #[test]
    fn space_to_batch_nd_1() {
        let out = run_space_to_batch(
            tensor4(&[[[[1i32], [2]], [[3], [4]]]]),
            tensor1(&[2, 2]),
            tensor2(&[[0, 0], [0, 0]]),
        );
        assert_eq!(out, tvec![tensor4(&[[[[1i32]]], [[[2]]], [[[3]]], [[[4]]]]).into()]);
    }

    #[test]
    fn space_to_batch_nd_2() {
        let out = run_space_to_batch(
            tensor4(&[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]]),
            tensor1(&[2, 2]),
            tensor2(&[[0, 0], [0, 0]]),
        );
        let expected =
            tensor4(&[[[[1i32, 2, 3]]], [[[4, 5, 6]]], [[[7, 8, 9]]], [[[10, 11, 12]]]]);
        assert_eq!(out, tvec![expected.into()]);
    }

    #[test]
    fn space_to_batch_nd_3() {
        let input = tensor4(&[[
            [[1], [2], [3], [4]],
            [[5], [6], [7], [8]],
            [[9], [10], [11], [12]],
            [[13], [14], [15], [16]],
        ]]);
        let expected = tensor4(&[
            [[[1], [3]], [[9], [11]]],
            [[[2], [4]], [[10], [12]]],
            [[[5], [7]], [[13], [15]]],
            [[[6], [8]], [[14], [16]]],
        ]);
        let out = run_space_to_batch(input, tensor1(&[2, 2]), tensor2(&[[0, 0], [0, 0]]));
        assert_eq!(out, tvec![expected.into()]);
    }

    #[test]
    fn space_to_batch_nd_4() {
        let input = tensor4(&[
            [[[1], [2], [3], [4]], [[5], [6], [7], [8]]],
            [[[9], [10], [11], [12]], [[13], [14], [15], [16]]],
        ]);
        let expected = tensor4(&[
            [[[0], [1], [3]]],
            [[[0], [9], [11]]],
            [[[0], [2], [4]]],
            [[[0], [10], [12]]],
            [[[0], [5], [7]]],
            [[[0], [13], [15]]],
            [[[0], [6], [8]]],
            [[[0], [14], [16]]],
        ]);
        let out = run_space_to_batch(input, tensor1(&[2, 2]), tensor2(&[[0, 0], [2, 0]]));
        assert_eq!(out, tvec![expected.into()]);
    }

    #[test]
    fn space_to_batch_nd_infer_1() {
        let output = infer_space_to_batch(
            f32::fact([1, 4, 16]).into(),
            Tensor::from(arr1(&[2])),
            Tensor::from(arr2(&[[0.to_dim(), 0.to_dim()]])),
        );
        assert_eq!(output, f32::fact([2, 2, 16]).into())
    }

    #[test]
    fn space_to_batch_nd_infer_2() {
        let table = SymbolScope::default();
        let s = table.sym("S");
        let output = infer_space_to_batch(
            f32::fact(dims!(1, s.to_dim() - 4, 16)).into(),
            Tensor::from(arr1(&[2])),
            Tensor::from(arr2(&[[0.to_dim(), (s.to_dim() % 2)]])),
        );
        assert_eq!(
            output,
            f32::fact(dims!(2, (s.to_dim() + s.to_dim() % 2 - 4) / 2, 16)).into()
        );
    }

    #[test]
    fn batch_to_space_nd_1() {
        let out = run_batch_to_space(
            tensor4(&[[[[1]]], [[[2]]], [[[3]]], [[[4]]]]),
            tensor1(&[2, 2]),
            tensor2(&[[0, 0], [0, 0]]),
        );
        assert_eq!(out, tvec![tensor4(&[[[[1], [2]], [[3], [4]]]]).into()])
    }

    #[test]
    fn batch_to_space_nd_2() {
        let input = tensor4(&[[[[1i32, 2, 3]]], [[[4, 5, 6]]], [[[7, 8, 9]]], [[[10, 11, 12]]]]);
        let out = run_batch_to_space(input, tensor1(&[2, 2]), tensor2(&[[0, 0], [0, 0]]));
        let expected = tensor4(&[[[[1i32, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]]);
        assert_eq!(out, tvec![expected.into()])
    }

    #[test]
    fn batch_to_space_nd_3() {
        let input = tensor4(&[
            [[[1i32], [3]], [[9], [11]]],
            [[[2], [4]], [[10], [12]]],
            [[[5], [7]], [[13], [15]]],
            [[[6], [8]], [[14], [16]]],
        ]);
        let expected = tensor4(&[[
            [[1i32], [2], [3], [4]],
            [[5], [6], [7], [8]],
            [[9], [10], [11], [12]],
            [[13], [14], [15], [16]],
        ]]);
        let out = run_batch_to_space(input, tensor1(&[2, 2]), tensor2(&[[0, 0], [0, 0]]));
        assert_eq!(out, tvec![expected.into()])
    }

    #[test]
    fn batch_to_space_nd_4() {
        let input = tensor4(&[
            [[[0i32], [1], [3]]],
            [[[0], [9], [11]]],
            [[[0], [2], [4]]],
            [[[0], [10], [12]]],
            [[[0], [5], [7]]],
            [[[0], [13], [15]]],
            [[[0], [6], [8]]],
            [[[0], [14], [16]]],
        ]);
        let expected = tensor4(&[
            [[[1], [2], [3], [4]], [[5], [6], [7], [8]]],
            [[[9], [10], [11], [12]], [[13], [14], [15], [16]]],
        ]);
        let out = run_batch_to_space(input, tensor1(&[2, 2]), tensor2(&[[0, 0], [2, 0]]));
        assert_eq!(out, tvec![expected.into()])
    }
}