tract_hir/ops/array/
permute_axes.rs1use crate::infer::*;
2use crate::internal::*;
3
/// Inference-level op that permutes (transposes) the axes of its input tensor.
#[derive(Debug, Clone, new, Hash)]
pub struct PermuteAxes {
    /// Explicit permutation: output axis `i` takes input axis `axes[i]`.
    /// `None` means "reverse the order of all axes".
    pub axes: Option<TVec<usize>>,
}
8
9
10
11impl PermuteAxes {
12 fn compute_shape<D: DimLike>(&self, input: &[D]) -> TractResult<TVec<D>> {
13 if let Some(ref axes) = self.axes {
14 if input.len() != axes.len() {
15 bail!(
16 "Op expects tensor of rank {}, input is actually of rank {}.",
17 axes.len(),
18 input.len()
19 );
20 }
21 let mut new_shape = tvec![D::zero(); input.len()];
22 for (ix, &d) in axes.iter().enumerate() {
23 new_shape[ix] = input[d].clone();
24 }
25 Ok(new_shape)
26 } else {
27 let mut new_shape: TVec<D> = input.iter().cloned().collect();
28 new_shape.reverse();
29 Ok(new_shape)
30 }
31 }
32}
33
34impl Expansion for PermuteAxes {
35 fn name(&self) -> StaticName {
36 "PermuteAxes".into()
37 }
38
39
40 fn info(&self) -> TractResult<Vec<String>> {
41 Ok(vec![format!("{:?}", self.axes)])
42 }
43
44 fn rules<'r, 'p: 'r, 's: 'r>(
45 &'s self,
46 s: &mut Solver<'r>,
47 inputs: &'p [TensorProxy],
48 outputs: &'p [TensorProxy],
49 ) -> InferenceResult {
50 check_output_arity(outputs, 1)?;
51 s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
52 s.equals(&outputs[0].rank, &inputs[0].rank)?;
53 s.given(&inputs[0].shape, move |s, shape| {
54 let output_shape = self.compute_shape(&shape)?;
55 s.equals(&outputs[0].shape, output_shape)
56 })?;
57 if let Some(axes) = &self.axes {
58 s.equals(&outputs[0].rank, axes.len() as i64)?;
59 }
60 Ok(())
61 }
62
63 fn wire(
64 &self,
65 prefix: &str,
66 target: &mut TypedModel,
67 inputs: &[OutletId],
68 ) -> TractResult<TVec<OutletId>> {
69 let fact = target.outlet_fact(inputs[0])?;
70 let axes = if let Some(axes) = &self.axes {
71 if fact.rank() != axes.len() {
72 bail!(
73 "Op expects tensor of rank {}, input is actually of rank {}.",
74 axes.len(),
75 fact.rank()
76 );
77 }
78 axes.clone()
79 } else {
80 (0..fact.rank()).rev().collect()
81 };
82 let mut wire: TVec<OutletId> = inputs.into();
83 for (ix, op) in perm_to_ops(&axes).into_iter().enumerate() {
84 wire = target.wire_node(format!("{}.{}-{}", prefix, op.name(), ix), op, &wire)?;
85 }
86 Ok(wire)
87 }
88}