tract_hir/ops/array/
permute_axes.rs

1use crate::infer::*;
2use crate::internal::*;
3
4#[derive(Debug, Clone, new, Hash)]
5pub struct PermuteAxes {
6    pub axes: Option<TVec<usize>>,
7}
8
9
10
11impl PermuteAxes {
12    fn compute_shape<D: DimLike>(&self, input: &[D]) -> TractResult<TVec<D>> {
13        if let Some(ref axes) = self.axes {
14            if input.len() != axes.len() {
15                bail!(
16                    "Op expects tensor of rank {}, input is actually of rank {}.",
17                    axes.len(),
18                    input.len()
19                );
20            }
21            let mut new_shape = tvec![D::zero(); input.len()];
22            for (ix, &d) in axes.iter().enumerate() {
23                new_shape[ix] = input[d].clone();
24            }
25            Ok(new_shape)
26        } else {
27            let mut new_shape: TVec<D> = input.iter().cloned().collect();
28            new_shape.reverse();
29            Ok(new_shape)
30        }
31    }
32}
33
34impl Expansion for PermuteAxes {
35    fn name(&self) -> StaticName {
36        "PermuteAxes".into()
37    }
38
39
40    fn info(&self) -> TractResult<Vec<String>> {
41        Ok(vec![format!("{:?}", self.axes)])
42    }
43
44    fn rules<'r, 'p: 'r, 's: 'r>(
45        &'s self,
46        s: &mut Solver<'r>,
47        inputs: &'p [TensorProxy],
48        outputs: &'p [TensorProxy],
49    ) -> InferenceResult {
50        check_output_arity(outputs, 1)?;
51        s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
52        s.equals(&outputs[0].rank, &inputs[0].rank)?;
53        s.given(&inputs[0].shape, move |s, shape| {
54            let output_shape = self.compute_shape(&shape)?;
55            s.equals(&outputs[0].shape, output_shape)
56        })?;
57        if let Some(axes) = &self.axes {
58            s.equals(&outputs[0].rank, axes.len() as i64)?;
59        }
60        Ok(())
61    }
62
63    fn wire(
64        &self,
65        prefix: &str,
66        target: &mut TypedModel,
67        inputs: &[OutletId],
68    ) -> TractResult<TVec<OutletId>> {
69        let fact = target.outlet_fact(inputs[0])?;
70        let axes = if let Some(axes) = &self.axes {
71            if fact.rank() != axes.len() {
72                bail!(
73                    "Op expects tensor of rank {}, input is actually of rank {}.",
74                    axes.len(),
75                    fact.rank()
76                );
77            }
78            axes.clone()
79        } else {
80            (0..fact.rank()).rev().collect()
81        };
82        let mut wire: TVec<OutletId> = inputs.into();
83        for (ix, op) in perm_to_ops(&axes).into_iter().enumerate() {
84            wire = target.wire_node(format!("{}.{}-{}", prefix, op.name(), ix), op, &wire)?;
85        }
86        Ok(wire)
87    }
88}