tract_core/ops/
change_axes.rs

1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::model::{TypedModel, TypedNode};
6use crate::ops::identity::Identity;
7use tract_itertools::Itertools;
8use tract_ndarray::{ArrayViewD, ArrayViewMutD};
9use AxisOp::*;
10
/// Designates one wire of a node: either one of its output slots or one of its input slots,
/// by index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InOut {
    Out(usize),
    In(usize),
}
16
17impl InOut {
18    pub fn as_outlet<F: Clone + Fact, O: Clone>(&self, node: &Node<F, O>) -> OutletId {
19        match self {
20            InOut::In(ix) => node.inputs[*ix],
21            InOut::Out(ix) => OutletId::new(node.id, *ix),
22        }
23    }
24}
25
26#[derive(Clone, Hash, Eq)]
27#[allow(clippy::large_enum_variant)] // FIXME ?
28#[allow(clippy::derived_hash_with_manual_eq)] // FIXME. this one may be pretty bad. how about a.canonical() == b.canonical() ? need proper canonicalizeation of Reshape
29pub enum AxisOp {
30    Add(usize),
31    Rm(usize),
32    Move(usize, usize),
33    Reshape(usize, TVec<TDim>, TVec<TDim>),
34}
35
36impl Debug for AxisOp {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        match self {
39            AxisOp::Add(a) => write!(f, "Add({a})"),
40            AxisOp::Rm(a) => write!(f, "Rm({a})"),
41            AxisOp::Move(from, to) => write!(f, "Move({from},{to})"),
42            AxisOp::Reshape(at, from, to) => {
43                write!(f, "Reshape({at}, [{}], [{}])", from.iter().join(","), to.iter().join(","))
44            }
45        }
46    }
47}
48
49impl PartialEq for AxisOp {
50    fn eq(&self, other: &AxisOp) -> bool {
51        if self.is_noop() && other.is_noop() {
52            true
53        } else if self.is_noop() != other.is_noop() {
54            false
55        } else {
56            match (self, other) {
57                (Add(a), Add(b)) | (Rm(a), Rm(b)) => a == b,
58                (Move(f1, t1), Move(f2, t2)) => {
59                    (f1 == f2 && t1 == t2)
60                        || ((*t1 == f1 + 1 || *f1 == t1 + 1) && t2 == f1 && t1 == f2)
61                }
62                (Reshape(at1, f1, t1), Reshape(at2, f2, t2)) => at1 == at2 && f1 == f2 && t1 == t2,
63                _ => false,
64            }
65        }
66    }
67}
68
69impl AxisOp {
70    pub fn canonical(&self) -> Cow<AxisOp> {
71        match self {
72            Move(from, to) if *from == to + 1 => Cow::Owned(Move(*to, *from)),
73            Reshape(at, from, to) if from.len() == 1 && to.len() == 2 && from[0] == to[0] => {
74                Cow::Owned(Add(*at + 1))
75            }
76            Reshape(at, from, to) if from.len() == 1 && to.len() == 2 && from[0] == to[1] => {
77                Cow::Owned(Add(*at))
78            }
79            Reshape(at, from, to) if from.len() == 2 && to.len() == 1 && from[0] == to[0] => {
80                Cow::Owned(Rm(*at + 1))
81            }
82            Reshape(at, from, to) if from.len() == 2 && to.len() == 1 && from[1] == to[0] => {
83                Cow::Owned(Rm(*at))
84            }
85            other => Cow::Borrowed(other),
86        }
87    }
88
89    pub fn simplify(&self) -> TVec<AxisOp> {
90        match self.canonical().borrow() {
91            Reshape(_, from, to) if from == to => tvec!(),
92            Reshape(at, from, to) if to.len() == 0 => tvec!(Rm(*at); from.len()),
93            Reshape(at, from, to) if from.len() == 0 => tvec!(Add(*at); to.len()),
94            Reshape(at, from, to) if from[0] == to[0] => {
95                Reshape(at + 1, from[1..].into(), to[1..].into()).simplify()
96            }
97            Reshape(at, from, to) if from[from.len() - 1] == to[to.len() - 1] => {
98                Reshape(*at, from[..from.len() - 1].into(), to[..to.len() - 1].into()).simplify()
99            }
100            Reshape(at, from, to) if from[0] == 1.to_dim() => std::iter::once(Rm(*at))
101                .chain(Reshape(*at, from[1..].into(), to.clone()).simplify())
102                .collect(),
103            Reshape(at, from, to) if to[0] == 1.to_dim() => {
104                Reshape(*at, from.clone(), to[1..].into())
105                    .simplify()
106                    .into_iter()
107                    .chain(std::iter::once(Add(*at)))
108                    .collect()
109            }
110            Reshape(at, from, to) if from[from.len() - 1] == 1.to_dim() => {
111                std::iter::once(Rm(at + from.len() - 1))
112                    .chain(Reshape(*at, from[..from.len() - 1].into(), to.clone()).simplify())
113                    .collect()
114            }
115            Reshape(at, from, to) if to[to.len() - 1] == 1.to_dim() => {
116                std::iter::once(Add(at + from.len()))
117                    .chain(Reshape(*at, from.clone(), to[..to.len() - 1].into()).simplify())
118                    .collect()
119            }
120            other => tvec!(other.clone()),
121        }
122    }
123
124    pub fn transform_axis(&self, axis: usize) -> Option<usize> {
125        match self.canonical().as_ref() {
126            Add(ix) => Some(axis + (axis >= *ix) as usize),
127            Rm(ix) => {
128                if axis == *ix {
129                    None
130                } else {
131                    Some(axis - (axis > *ix) as usize)
132                }
133            }
134            Move(from, to) if from < to => {
135                if axis < *from || axis > *to {
136                    Some(axis)
137                } else if axis == *from {
138                    Some(*to)
139                } else {
140                    Some(axis - 1)
141                }
142            }
143            Move(from, to) => {
144                if axis < *to || axis > *from {
145                    Some(axis)
146                } else if axis == *from {
147                    Some(*to)
148                } else {
149                    Some(axis + 1)
150                }
151            }
152            Reshape(at, _, _) if axis < *at => Some(axis),
153            Reshape(at, from, to) if axis >= at + from.len() => Some(axis + to.len() - from.len()),
154            Reshape(_, _, _) => None,
155        }
156    }
157
158    // if sucessful return Some()
159    // first item is the Op we want to be replaced by. if none, we are now identity.
160    // second item is the change to propagate. if none, the output is not
161    // changed
162    pub fn merge_incoming_change(
163        &self,
164        change: &AxisOp,
165    ) -> Option<(Option<AxisOp>, Option<AxisOp>)> {
166        match (self.canonical().as_ref(), change.canonical().as_ref()) {
167            (Add(op), Add(c)) => {
168                Some((Some(Add(op + (c < op) as usize)), Some(Add(c + (c >= op) as usize))))
169            }
170            (Add(op), Rm(c)) => {
171                Some((Some(Add(op - (c < op) as usize)), Some(Rm(c + (c >= op) as usize))))
172            }
173            (Rm(op), Add(c)) => {
174                Some((Some(Rm(op + (c <= op) as usize)), Some(Add(c - (op < c) as usize))))
175            }
176            (Rm(op), Rm(c)) => {
177                Some((Some(Rm(op - (c < op) as usize)), Some(Rm(c - (op <= c) as usize))))
178            }
179
180            (Add(x), Move(from, to)) => {
181                if x <= from.min(to) {
182                    Some((Some(self.clone()), Some(Move(from + 1, to + 1))))
183                } else if x > from.max(to) {
184                    Some((Some(self.clone()), Some(change.clone())))
185                } else {
186                    None
187                }
188            }
189
190            (Move(from, to), Add(x)) => {
191                if x <= from.min(to) {
192                    Some((Some(Move(from + 1, to + 1)), Some(Add(*x))))
193                } else if x > from.max(to) {
194                    Some((Some(Move(*from, *to)), Some(Add(*x))))
195                } else {
196                    None
197                }
198            }
199
200            (Rm(x), Move(from, to)) => {
201                if x == from {
202                    Some((Some(Rm(*to)), None))
203                } else if x < from.min(to) {
204                    Some((Some(self.clone()), Some(Move(from - 1, to - 1))))
205                } else if x > from.max(to) {
206                    Some((Some(self.clone()), Some(change.clone())))
207                } else if from + 1 == *to && x == to {
208                    Some((Some(Rm(*from)), None))
209                } else if from < to && x <= to {
210                    Some((Some(Rm(x - 1)), Some(Move(*from, *to - 1))))
211                } else {
212                    Some((Some(Rm(x + 1)), Some(Move(*from - 1, *to))))
213                }
214            }
215
216            (Move(from, to), Rm(x)) => {
217                if x < from.min(to) {
218                    Some((Some(Move(from - 1, to - 1)), Some(Rm(*x))))
219                } else if x > from.max(to) {
220                    Some((Some(Move(*from, *to)), Some(Rm(*x))))
221                } else {
222                    None
223                }
224            }
225
226            (Add(op), Reshape(at, from, to)) => {
227                if op <= at {
228                    Some((Some(Add(*op)), Some(Reshape(at + 1, from.clone(), to.clone()))))
229                } else if *op > at + from.len() {
230                    Some((
231                        Some(Add(*op + to.len() - from.len())),
232                        Some(Reshape(*at, from.clone(), to.clone())),
233                    ))
234                } else {
235                    None
236                }
237            }
238            (Rm(op), Reshape(at, from, to)) => {
239                if op < at {
240                    Some((Some(Rm(*op)), Some(Reshape(at - 1, from.clone(), to.clone()))))
241                } else if *op > at + from.len() {
242                    Some((
243                        Some(Rm(*op + to.len() - from.len())),
244                        Some(Reshape(*at, from.clone(), to.clone())),
245                    ))
246                } else {
247                    None
248                }
249            }
250            (Reshape(at, from, to), Add(change)) => {
251                if change < at {
252                    Some((Some(Reshape(at + 1, from.clone(), to.clone())), Some(Add(*change))))
253                } else if *change > *at + from.len() {
254                    Some((
255                        Some(Reshape(*at, from.clone(), to.clone())),
256                        Some(Add(change + to.len() - from.len())),
257                    ))
258                } else {
259                    None
260                }
261            }
262            (Reshape(at, from, to), Rm(change)) => {
263                if change < at {
264                    Some((Some(Reshape(at - 1, from.clone(), to.clone())), Some(Rm(*change))))
265                } else if *change > *at + from.len() {
266                    Some((
267                        Some(Reshape(*at, from.clone(), to.clone())),
268                        Some(Rm(change + to.len() - from.len())),
269                    ))
270                } else {
271                    None
272                }
273            }
274            (Reshape(_, _, _), Move(_, _)) => None, // todo, some are manageable
275            (Move(_, _), Reshape(_, _, _)) => None, // todo, some are manageable
276            (Reshape(_, _, _), Reshape(_, _, _)) => None, // todo, some are manageable
277            _ => None,
278        }
279    }
280
281    pub fn change_shape_array<D: DimLike>(
282        &self,
283        shape: &mut TVec<D>,
284        broadcasting: bool,
285    ) -> TractResult<()> {
286        match self.canonical().as_ref() {
287            Add(ix) => {
288                ensure!(*ix <= shape.len());
289                shape.insert(*ix, D::one());
290            }
291            Rm(ix) => {
292                ensure!(*ix < shape.len());
293                shape.remove(*ix);
294            }
295            Move(from, to) => {
296                ensure!(*from < shape.len());
297                ensure!(*to < shape.len());
298                let axis = shape.remove(*from);
299                shape.insert(*to, axis);
300            }
301            Reshape(at, from, to) => {
302                let from_volume = from.iter().product::<TDim>();
303                let to_volume = to.iter().product::<TDim>();
304                ensure!(from_volume == to_volume, "{from_volume} should be equal to {to_volume}");
305                ensure!(*at + from.len() <= shape.len());
306                if shape.len() >= from.len() + *at
307                    && tract_itertools::izip!(shape.iter().skip(*at), from)
308                        .all(|(shape, spec)| shape.to_dim() == *spec)
309                {
310                    for _ in from {
311                        shape.remove(*at);
312                    }
313                    for d in to.iter().rev() {
314                        shape.insert(*at, d.try_into()?);
315                    }
316                } else if broadcasting
317                    && shape.iter().skip(*at).take(from.len()).all(|d| d.to_dim() == 1.to_dim())
318                {
319                    for _ in from {
320                        shape.remove(*at);
321                    }
322                    for _ in to.iter().rev() {
323                        shape.insert(*at, 1.into());
324                    }
325                } else {
326                    bail!("Incompatible reshape for shape {:?} and {:?}", shape, self);
327                }
328            }
329        }
330        Ok(())
331    }
332
333    pub fn change_shape(&self, shape: &mut ShapeFact, broadcasting: bool) -> TractResult<()> {
334        match self.canonical().as_ref() {
335            Add(ix) => shape.insert_axis(*ix),
336            Rm(ix) => {
337                if shape.rank() <= *ix {
338                    bail!("Attempt to remove axis #{} on shape {:?}", ix, shape);
339                }
340                if shape[*ix] != 1.to_dim() {
341                    bail!("Removing non-trivial axis #{} of dim: {:?}", ix, shape);
342                }
343                shape.remove_axis(*ix)
344            }
345            _ => {
346                let mut array = shape.to_tvec();
347                self.change_shape_array(&mut array, broadcasting)?;
348                let mut new_shape = ShapeFact::from_dims(array);
349                std::mem::swap(shape, &mut new_shape);
350                Ok(())
351            }
352        }
353    }
354
355    pub fn change_tensor(&self, tensor: &mut Tensor, broadcasting: bool) -> TractResult<()> {
356        match self.canonical().as_ref() {
357            Add(ix) => {
358                ensure!(*ix <= tensor.rank());
359                tensor.insert_axis(*ix)
360            }
361            Rm(ix) => {
362                ensure!(*ix < tensor.rank());
363                tensor.remove_axis(*ix)
364            }
365            Move(from, to) => {
366                ensure!(*from < tensor.rank());
367                ensure!(*to < tensor.rank());
368                let mut tmp = tensor.clone().move_axis(*from, *to)?;
369                std::mem::swap(tensor, &mut tmp);
370                Ok(())
371            }
372            Reshape(at, from, to) => {
373                let mut shape: TVec<usize> = tensor.shape().into();
374                self.change_shape_array(&mut shape, false)?;
375                if tensor.set_shape(&shape).is_ok() {
376                    Ok(())
377                } else if broadcasting
378                    && tensor.shape().iter().skip(*at).take(from.len()).all(|d| *d == 1)
379                {
380                    if from.len() > to.len() {
381                        for _ in to.len()..from.len() {
382                            tensor.remove_axis(*at)?;
383                        }
384                    }
385                    if to.len() > from.len() {
386                        for _ in from.len()..to.len() {
387                            tensor.insert_axis(*at)?;
388                        }
389                    }
390                    Ok(())
391                } else {
392                    bail!(
393                        "Invalid reshaping: {:?} on tensor {:?} (broadcasting allowed: {:?})",
394                        self,
395                        tensor,
396                        broadcasting
397                    )
398                }
399            }
400        }
401    }
402
403    pub fn change_view<D>(&self, view: &mut ArrayViewD<D>) -> TractResult<()> {
404        use tract_ndarray::Axis;
405        match *self {
406            AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
407            AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
408            AxisOp::Move(from, to) if from < to => {
409                for left in from..to {
410                    view.swap_axes(left, left + 1);
411                }
412            }
413            AxisOp::Move(from, to) => {
414                for left in (to..from).rev() {
415                    view.swap_axes(left, left + 1);
416                }
417            }
418            AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
419        }
420        Ok(())
421    }
422
423    pub fn change_view_mut<D>(&self, view: &mut ArrayViewMutD<D>) -> TractResult<()> {
424        use tract_ndarray::Axis;
425        match *self {
426            AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
427            AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
428            AxisOp::Move(from, to) if from < to => {
429                for left in from..to {
430                    view.swap_axes(left, left + 1);
431                }
432            }
433            AxisOp::Move(from, to) => {
434                for left in (to..from).rev() {
435                    view.swap_axes(left, left + 1);
436                }
437            }
438            AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
439        }
440        Ok(())
441    }
442
443    pub fn recip(&self) -> AxisOp {
444        match self.canonical().as_ref() {
445            Add(ix) => Rm(*ix),
446            Rm(ix) => Add(*ix),
447            Move(from, to) if from == to => self.clone(),
448            Move(from, to) if *from + 1 == *to => self.clone(),
449            Move(from, to) if *from == *to + 1 => {
450                unreachable!();
451            }
452            Move(from, to) => Move(*to, *from),
453            Reshape(at, from, to) => Reshape(*at, to.clone(), from.clone()),
454        }
455    }
456
457    pub fn is_noop(&self) -> bool {
458        match self {
459            Move(f, t) if f == t => true,
460            Reshape(_, f, t) if f == t => true,
461            _ => false,
462        }
463    }
464
465    pub fn only_shape(&self) -> bool {
466        if self.is_noop() {
467            return true;
468        }
469        !matches!(self, Move(_, _))
470    }
471
472    pub fn wire_split_axis(
473        model: &mut TypedModel,
474        name: impl ToString,
475        outlet: OutletId,
476        axis: usize,
477        outer_dim: usize,
478    ) -> TractResult<TVec<OutletId>> {
479        let fact = model.outlet_fact(outlet)?;
480        let dim: TDim = fact.shape[axis].clone();
481        let inner_dim = dim.clone() / outer_dim;
482        let op = Self::Reshape(axis, tvec!(dim.clone()), tvec!(outer_dim.to_dim(), inner_dim));
483        model.wire_node(name.to_string(), op, &[outlet])
484    }
485
486    pub fn wire_collapse_axis(
487        model: &mut TypedModel,
488        name: impl ToString,
489        outlet: OutletId,
490        axis: usize,
491    ) -> TractResult<TVec<OutletId>> {
492        let fact = model.outlet_fact(outlet)?;
493        let dim: TDim = fact.shape[axis].clone();
494        let next_dim: TDim = fact.shape[axis + 1].clone();
495        let op = Self::Reshape(axis, tvec!(dim.clone(), next_dim.clone()), tvec!(dim * next_dim));
496        model.wire_node(name.to_string(), op, &[outlet])
497    }
498}
499
500pub fn wire_rank_broadcast(
501    prefix: impl AsRef<str>,
502    target: &mut TypedModel,
503    inputs: &[OutletId],
504) -> TractResult<TVec<OutletId>> {
505    let facts =
506        inputs.iter().map(|o| target.outlet_fact(*o).cloned()).collect::<TractResult<TVec<_>>>()?;
507    let max_rank = facts.iter().map(|f| f.rank()).max().unwrap();
508    let mut wires = tvec!();
509    let prefix = prefix.as_ref();
510    for i in 0..inputs.len() {
511        let mut wire = inputs[i];
512        for j in facts[i].rank()..max_rank {
513            wire =
514                target.wire_node(format!("{prefix}.fix-rank-{i}-{j}"), AxisOp::Add(0), &[wire])?[0];
515        }
516        wires.push(wire);
517    }
518    Ok(wires)
519}
520
521pub fn wire_with_rank_broadcast(
522    prefix: impl AsRef<str>,
523    target: &mut TypedModel,
524    op: impl Into<Box<dyn TypedOp>>,
525    inputs: &[OutletId],
526) -> TractResult<TVec<OutletId>> {
527    let prefix = prefix.as_ref();
528    let wires = wire_rank_broadcast(prefix, target, inputs)?;
529    target.wire_node(prefix, op.into(), &wires)
530}
531
532#[derive(Clone, Debug, PartialEq, Eq, Hash)]
533pub struct AxisChange {
534    pub outlet: OutletId,
535    pub op: AxisOp,
536}
537
538#[derive(Clone, Default, Debug)]
539pub struct AxisChangeConsequence {
540    pub substitute_op: Option<Box<dyn TypedOp>>,
541    pub wire_changes: TVec<(InOut, AxisOp)>,
542}
543
544impl AxisChangeConsequence {
545    pub fn new(
546        _model: &TypedModel,
547        node: &TypedNode,
548        op: Option<Box<dyn TypedOp>>,
549        axis_op: &AxisOp,
550    ) -> AxisChangeConsequence {
551        let mut wire_changes = tvec!();
552        for i in 0..node.inputs.len() {
553            wire_changes.push((InOut::In(i), axis_op.clone()));
554        }
555        for i in 0..node.outputs.len() {
556            wire_changes.push((InOut::Out(i), axis_op.clone()));
557        }
558        AxisChangeConsequence { wire_changes, substitute_op: op }
559    }
560}
561
562impl Op for AxisOp {
563    fn name(&self) -> Cow<str> {
564        match self {
565            Add(_) => "AddAxis".into(),
566            Rm(_) => "RmAxis".into(),
567            Move(_, _) => "MoveAxis".into(),
568            Reshape(_, _, _) => "Reshape".into(),
569        }
570    }
571
572    fn info(&self) -> TractResult<Vec<String>> {
573        match self {
574            Add(axis) | Rm(axis) => Ok(vec![format!("Axis: {axis}")]),
575            Move(from, to) => Ok(vec![format!("Axis {from} to {to}")]),
576            Reshape(at, from, to) => Ok(vec![format!(
577                "Axes starting at {}: {:?} to {:?}",
578                at,
579                from.iter().join(","),
580                to.iter().join(",")
581            )]),
582        }
583    }
584
585    op_as_typed_op!();
586}
587
588impl EvalOp for AxisOp {
589    fn is_stateless(&self) -> bool {
590        true
591    }
592
593    fn eval_with_session(
594        &self,
595        session: &SessionState,
596        inputs: TVec<TValue>,
597    ) -> TractResult<TVec<TValue>> {
598        let mut input = args_1!(inputs).into_tensor();
599        match self {
600            AxisOp::Reshape(skip, from, to) => {
601                let from = from.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
602                let to = to.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
603                AxisOp::Reshape(*skip, from, to).change_tensor(&mut input, false)?
604            }
605            _ => self.change_tensor(&mut input, false)?,
606        }
607        Ok(tvec!(input.into_tvalue()))
608    }
609}
610
611impl TypedOp for AxisOp {
612    as_op!();
613
614    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
615        let mut shape = inputs[0].shape.clone();
616        self.change_shape(&mut shape, false)
617            .with_context(|| format!("Applying {self:?} to {:?}", inputs[0]))?;
618        let mut fact = inputs[0].datum_type.fact(shape);
619        fact.opaque_fact.clone_from(&inputs[0].opaque_fact);
620        Ok(tvec!(fact))
621    }
622
623    fn axes_mapping(
624        &self,
625        inputs: &[&TypedFact],
626        outputs: &[&TypedFact],
627    ) -> TractResult<AxesMapping> {
628        let mut axes: Vec<Axis> = (0..inputs[0].rank())
629            .zip('a'..)
630            .map(|(axis_id, repr)| {
631                let mut axis = Axis::new(repr, inputs.len(), outputs.len()).input(0, axis_id);
632                if let Some(out) = self.transform_axis(axis_id) {
633                    axis = axis.output(0, out);
634                }
635                axis
636            })
637            .collect();
638        for (axis, letter) in (0..outputs[0].rank()).zip('A'..) {
639            if self.recip().transform_axis(axis).is_none() {
640                axes.push(Axis::new(letter, inputs.len(), outputs.len()).output(0, axis));
641            }
642        }
643        AxesMapping::new(inputs.len(), outputs.len(), axes)
644    }
645
646    fn declutter(
647        &self,
648        model: &TypedModel,
649        node: &TypedNode,
650    ) -> TractResult<Option<TypedModelPatch>> {
651        if self.is_noop() {
652            if let Some(p) = TypedModelPatch::shunt_one_op(model, node)? {
653                return Ok(Some(p));
654            }
655        }
656        let simplified = self.simplify();
657        if simplified.len() != 1 || &simplified[0] != self {
658            let mut patch = TypedModelPatch::default();
659            let mut wire = patch.tap_model(model, node.inputs[0])?;
660            for (ix, op) in simplified.into_iter().enumerate() {
661                wire = patch.wire_node(format!("{}.{}", node.name, ix), op, &[wire])?[0];
662            }
663            patch.shunt_outside(model, node.id.into(), wire)?;
664            Ok(Some(patch))
665        } else {
666            Ok(None)
667        }
668    }
669
670    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
671        Ok(tvec!((InOut::Out(0), self.recip()), (InOut::In(0), self.clone())))
672    }
673
674    fn change_axes(
675        &self,
676        _model: &TypedModel,
677        _node: &TypedNode,
678        io: InOut,
679        change: &AxisOp,
680    ) -> TractResult<Option<AxisChangeConsequence>> {
681        let op = if let InOut::Out(0) = io {
682            let more = if let Some(more) =
683                self.recip().change_axes(_model, _node, InOut::In(0), change)?
684            {
685                more
686            } else {
687                return Ok(None);
688            };
689            AxisChangeConsequence {
690                substitute_op: more.substitute_op.map(|op| {
691                    if let Some(op) = op.as_op().downcast_ref::<AxisOp>() {
692                        Box::new(op.recip())
693                    } else {
694                        op // have to be identity
695                    }
696                }),
697                wire_changes: more
698                    .wire_changes
699                    .into_iter()
700                    .map(|wc| {
701                        (if wc.0 == InOut::In(0) { InOut::Out(0) } else { InOut::In(0) }, wc.1)
702                    })
703                    .collect(),
704            }
705        } else if change == self {
706            AxisChangeConsequence { substitute_op: Some(Box::new(Identity)), wire_changes: tvec!() }
707        } else {
708            let (new_op, new_change) = if let Some(pair) = self.merge_incoming_change(change) {
709                pair
710            } else {
711                return Ok(None);
712            };
713            trace!(
714                "  Change:{:?} self:{:?} -> change:{:?} op:{:?}",
715                change,
716                self,
717                new_change,
718                new_op
719            );
720            let substitute_op: Box<dyn TypedOp> =
721                if let Some(o) = new_op { Box::new(o) as _ } else { Box::new(Identity) };
722            let mut wire_changes = tvec!();
723            if !change.is_noop() {
724                wire_changes.push((InOut::In(0), change.clone()))
725            }
726            if let Some(new_change) = new_change {
727                wire_changes.push((InOut::Out(0), new_change))
728            }
729            AxisChangeConsequence { substitute_op: Some(substitute_op), wire_changes }
730        };
731        Ok(Some(op))
732    }
733
734    fn concretize_dims(
735        &self,
736        _source: &TypedModel,
737        node: &TypedNode,
738        target: &mut TypedModel,
739        mapping: &HashMap<OutletId, OutletId>,
740        values: &SymbolValues,
741    ) -> TractResult<TVec<OutletId>> {
742        let op = if let AxisOp::Reshape(axis, from, to) = self {
743            AxisOp::Reshape(
744                *axis,
745                from.iter().map(|d| d.eval(values)).collect(),
746                to.iter().map(|d| d.eval(values)).collect(),
747            )
748        } else {
749            self.clone()
750        };
751        target.wire_node(&node.name, op, &[mapping[&node.inputs[0]]])
752    }
753
754    fn codegen(
755        &self,
756        model: &TypedModel,
757        node: &TypedNode,
758    ) -> TractResult<Option<TypedModelPatch>> {
759        if let Some(shape) = node.outputs[0].fact.shape.as_concrete() {
760            if !matches!(self, AxisOp::Move(_, _)) {
761                let (inputs, outputs) = model.node_facts(node.id)?;
762                let mapping = self.axes_mapping(&inputs, &outputs)?;
763                let op = IntoShape {
764                    mapping,
765                    len: shape.iter().product(),
766                    strides: Tensor::natural_strides(shape),
767                    dims: shape.into(),
768                };
769                return Ok(Some(TypedModelPatch::replace_single_op(
770                    model,
771                    node,
772                    &node.inputs,
773                    op,
774                )?));
775            }
776        }
777        Ok(None)
778    }
779}
780
781// a, b, c is a <- b, b <- c, c <- a
782fn perm_to_cycles(perm: &[usize]) -> TVec<TVec<usize>> {
783    let mut cycles: TVec<TVec<usize>> = tvec!();
784    let mut done = 0;
785    while done < perm.len() {
786        if perm[done] == done || cycles.iter().any(|c| c.contains(&done)) {
787            done += 1;
788            continue;
789        }
790        let mut cycle = tvec!();
791        let mut current = done;
792        loop {
793            cycle.push(current);
794            current = perm[current];
795            if current == done {
796                break;
797            }
798        }
799        cycles.push(cycle)
800    }
801    cycles
802}
803
/// Recognizes a cycle that a single axis move can realize.
///
/// * `[k, k + 1, ..., m]` (consecutive ascending) yields `Some((k, m))`;
/// * `[k, m, m - 1, ..., k + 1]` (a descending run closing back onto `k`) yields `Some((m, k))`;
/// * anything else yields `None`.
fn is_rotation_cycle(cycle: &[usize]) -> Option<(usize, usize)> {
    let first = cycle[0];
    let last = cycle[cycle.len() - 1];
    if cycle.windows(2).all(|pair| pair[1] == pair[0] + 1) {
        return Some((first, last));
    }
    let descending_tail = cycle[1..].windows(2).all(|pair| pair[0] - 1 == pair[1]);
    if descending_tail && last - 1 == first {
        Some((cycle[1], first))
    } else {
        None
    }
}
815
816fn perm_to_atoms(input: &[usize]) -> TVec<(usize, usize)> {
817    let mut changes: TVec<(usize, usize)> = tvec!();
818    'top: loop {
819        let mut reached: TVec<usize> = (0..input.len()).collect();
820        changes.iter().for_each(|(f, t)| {
821            let axis = reached.remove(*f);
822            reached.insert(*t, axis);
823        });
824        if &*reached == input {
825            return changes;
826        }
827        let remaining: TVec<usize> =
828            input.iter().map(|x| reached.iter().position(|y| y == x).unwrap()).collect();
829        let cycles = perm_to_cycles(&remaining);
830        for cycle in &cycles {
831            if let Some(rot) = is_rotation_cycle(cycle) {
832                changes.push(rot);
833                continue 'top;
834            }
835        }
836        changes.push((cycles[0][1], cycles[0][0]));
837    }
838}
839
840pub fn perm_to_ops(input: &[usize]) -> TVec<AxisOp> {
841    perm_to_atoms(input).into_iter().map(|pair| AxisOp::Move(pair.0, pair.1)).collect()
842}
843
844pub fn compute_shape_with_tf_rules(input: &[TDim], shape_spec: &[TDim]) -> TractResult<TVec<TDim>> {
845    let mut shape: TVec<TDim> = shape_spec.into();
846    fn deal_with_zero<'a>(
847        mut input_dims: std::iter::Peekable<impl Iterator<Item = &'a TDim>>,
848        shape: &mut [TDim],
849    ) -> TractResult<()> {
850        let mut remaining_dim_input = 1.to_dim();
851        for slot in shape.iter_mut() {
852            if *slot == (-1).into() {
853                break;
854            }
855            if *slot == 0.into() {
856                if remaining_dim_input != TDim::one() {
857                    bail!("Invalid remaining dim");
858                }
859                *slot = (*input_dims.peek().context("Invalid")?).clone();
860            }
861            loop {
862                let quotient = remaining_dim_input.maybe_div(slot);
863                if quotient.is_err() || quotient.as_ref().unwrap().1 != 1 {
864                    remaining_dim_input *= input_dims.next().context("Invalid")?;
865                } else {
866                    break;
867                }
868            }
869            remaining_dim_input = remaining_dim_input.maybe_div(slot)?.0;
870        }
871        Ok(())
872    }
873
874    deal_with_zero(input.iter().peekable(), &mut shape)?;
875    shape.reverse();
876    deal_with_zero(input.iter().rev().peekable(), &mut shape)?;
877    shape.reverse();
878
879    if let Some(pos) = shape.iter().position(|d| *d == (-1).into()) {
880        let input_vol: TDim = input.iter().product();
881        let shape_vol: TDim = shape.iter().filter(|d| **d != (-1).into()).product();
882        let div = input_vol.maybe_div(&shape_vol)?;
883        if div.1 != 1 {
884            bail!("invalid")
885        }
886        shape[pos] = div.0;
887    }
888    Ok(shape)
889}
890
891pub fn to_axis_ops_with_tf_rules(
892    input_orig: &[TDim],
893    output_spec: &[TDim],
894) -> TractResult<TVec<AxisOp>> {
895    let final_output = compute_shape_with_tf_rules(input_orig, output_spec)?;
896    let mut stack: TVec<AxisOp> = tvec!();
897    'top: loop {
898        let current_input =
899            stack.iter().try_fold(TVec::from(input_orig), |mut shape, op| -> TractResult<_> {
900                op.change_shape_array(&mut shape, false)?;
901                Ok(shape)
902            })?;
903        if current_input == final_output {
904            return Ok(stack);
905        }
906        if let Some(common) =
907            current_input.iter().zip(final_output.iter()).position(|(a, b)| a != b)
908        {
909            if current_input[common].is_one() {
910                stack.push(AxisOp::Rm(common));
911            } else if final_output[common].is_one() {
912                stack.push(AxisOp::Add(common));
913            } else {
914                // actual regrouping. search for a match. this is quadratic, but
915                // rank is expected to be somewhat reasonable
916                for i in common..current_input.len() {
917                    let i_group = &current_input[common..i + 1];
918                    let i_volume: TDim = i_group.iter().product();
919                    for o in common..final_output.len() {
920                        let o_group = &final_output[common..o + 1];
921                        let o_volume: TDim = o_group.iter().product();
922                        if i_volume == o_volume {
923                            stack.push(AxisOp::Reshape(common, i_group.into(), o_group.into()));
924                            continue 'top;
925                        }
926                    }
927                }
928                todo!()
929            }
930        } else if final_output.len() > current_input.len() {
931            stack.push(AxisOp::Add(current_input.len()));
932        } else {
933            stack.push(AxisOp::Rm(current_input.len() - 1));
934        }
935    }
936}
937
/// Gives its single input a new geometry (`dims` / `strides`).
///
/// `eval` takes the input tensor (`into_tensor`) and calls `set_geometry_unchecked` on it;
/// no element-wise work is visible in this op. NOTE(review): presumably built from an
/// axis-op sequence summarized by `mapping`; confirm against the constructor.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct IntoShape {
    /// Axes mapping between input and output. It is printed by `info` and composed when two
    /// consecutive `IntoShape` ops are fused in `declutter`.
    pub mapping: AxesMapping,
    /// Number of elements the input tensor must hold; checked in `eval`.
    pub len: usize,
    /// Output shape. Used for the output fact and applied to the tensor in `eval`.
    pub dims: TVec<usize>,
    /// Strides applied together with `dims` in `eval`.
    pub strides: TVec<isize>,
}
945
946impl Op for IntoShape {
947    fn name(&self) -> Cow<str> {
948        "IntoShape".into()
949    }
950
951    fn info(&self) -> TractResult<Vec<String>> {
952        Ok(vec![format!("{}", self.mapping)])
953    }
954
955    op_as_typed_op!();
956    impl_op_same_as!();
957}
958
959impl EvalOp for IntoShape {
960    fn is_stateless(&self) -> bool {
961        true
962    }
963
964    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
965        let mut input = args_1!(inputs).into_tensor();
966        ensure!(input.len() == self.len);
967        unsafe { input.set_geometry_unchecked(&self.dims, &self.strides) };
968        Ok(tvec!(input.into_tvalue()))
969    }
970}
971
972impl TypedOp for IntoShape {
973    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
974        Ok(tvec!(inputs[0].datum_type.fact(&self.dims)))
975    }
976
977    fn declutter(
978        &self,
979        model: &TypedModel,
980        node: &TypedNode,
981    ) -> TractResult<Option<TypedModelPatch>> {
982        let input = model.outlet_fact(node.inputs[0])?;
983        if input.shape.as_concrete().is_some_and(|shape| shape == &*self.dims) {
984            return TypedModelPatch::shunt_one_op(model, node);
985        }
986        if let Some(succ) = model.single_succ(node.id)? {
987            if let Some(into_shape) = succ.op_as::<IntoShape>() {
988                let op = Self {
989                    mapping: self.mapping.compose(&into_shape.mapping)?,
990                    ..into_shape.clone()
991                };
992                return Ok(Some(TypedModelPatch::fuse_with_next(model, node, op)?));
993            }
994        }
995        Ok(None)
996    }
997
998    as_op!();
999}
1000
// Unit tests for permutation decomposition and for `AxisOp::merge_incoming_change`.
#[cfg(test)]
mod test {
    use super::*;

    // Cycles start at their smallest index and follow `perm` (0 -> perm[0] -> ...).
    #[test]
    fn test_perm_to_cycles() {
        assert_eq!(perm_to_cycles(&[1, 2, 0]), tvec!(tvec!(0, 1, 2)));
        assert_eq!(perm_to_cycles(&[2, 0, 1]), tvec!(tvec!(0, 2, 1)));
        assert_eq!(perm_to_cycles(&[1, 2, 3, 0]), tvec!(tvec!(0, 1, 2, 3)));
        assert_eq!(perm_to_cycles(&[3, 0, 1, 2]), tvec!(tvec!(0, 3, 2, 1)));
        assert_eq!(perm_to_cycles(&[3, 1, 2, 0, 4]), tvec!(tvec!(0, 3)));
    }

    // Increasing runs and "first, then decreasing" cycles map to a single (from, to) move.
    #[test]
    fn is_rotation() {
        assert_eq!(is_rotation_cycle(&[0, 1, 2]), Some((0, 2)));
        assert_eq!(is_rotation_cycle(&[0, 2, 1]), Some((2, 0)));
    }

    #[test]
    fn test_perm_one_rotation() {
        assert_eq!(perm_to_atoms(&[1, 2, 0, 3, 4]), tvec!((0, 2)));
    }

    #[test]
    fn test_perm_two_rotations() {
        assert_eq!(perm_to_atoms(&[1, 2, 0, 4, 3]), tvec!((0, 2), (3, 4)));
    }

    // No rotation-shaped cycle here: the fallback move is used first.
    #[test]
    fn test_perm_complex() {
        assert_eq!(perm_to_atoms(&[3, 1, 2, 0, 4]), tvec!((3, 0), (1, 3)));
    }

    // How to read the diagrams below:
    // - top row: `op` applied to the original input;
    // - left label: the incoming `change` turning the top input into the bottom input;
    // - bottom row: the rewritten op applied to the changed input;
    // - right label: the change turning the top output into the bottom output.
    // Each test expects `op.merge_incoming_change(&change)` to return
    // `Some((Some(rewritten op), Some(output change)))`.

    // ADD-ADD

    //                          Op
    //           b,c   ------|Add(0)|----->        n,b,c
    //   Add(0)                                            Add(1)
    //         a,b,c   ------|Add(0)|----->        n,a,b,c
    #[test]
    pub fn transform_op_add_0_add_0() {
        let change = Add(0);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Add(1)))));
    }

    //                          Op
    //           b,c   ------|Add(1)|----->        b,n,c
    //   Add(0)                                                 Add(0)
    //         a,b,c   ------|Add(2)|----->        a,b,n,c
    #[test]
    pub fn transform_op_add_0_add_1() {
        let change = Add(0);
        let op = Add(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(2)), Some(Add(0)))));
    }

    //                          Op
    //           a,c   ------|Add(0)|----->        n,a,c
    //   Add(1)                                                 Add(2)
    //         a,b,c   ------|Add(0)|----->        n,a,b,c
    #[test]
    pub fn transform_op_add_1_add_0() {
        let change = Add(1);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Add(2)))));
    }

    // RM-RM

    //                          Op
    //         a,b,c   ------|Rm(1)|----->         a,c
    //   Rm(0)                                             Rm(0)
    //           b,c   ------|Rm(0)|----->         c
    #[test]
    pub fn transform_op_rm_0_rm_1() {
        let change = Rm(0);
        let op = Rm(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Rm(0)))));
    }

    //                          Op
    //         a,b,c   ------|Rm(0)|----->         b,c
    //   Rm(1)                                             Rm(0)
    //           a,c   ------|Rm(0)|----->         c
    #[test]
    pub fn transform_op_rm_1_rm_0() {
        let change = Rm(1);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Rm(0)))));
    }

    // ADD - RM

    //                          Op
    //          b,c     ------|Rm(0)|------>        c
    //   Add(0)                                                 Add(0)
    //          a,b,c   ------|Rm(1)|----->         a,c
    #[test]
    pub fn transform_op_add_0_rm_0() {
        let change = Add(0);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(1)), Some(Add(0)))));
    }

    //                          Op
    //          b,c     ------|Rm(1)|------>        b
    //   Add(0)                                                 Add(0)
    //          a,b,c   ------|Rm(2)|----->         a,b
    #[test]
    pub fn transform_op_add_0_rm_1() {
        let change = Add(0);
        let op = Rm(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(2)), Some(Add(0)))));
    }

    //                          Op
    //          a,c     ------|Rm(0)|------>        c
    //   Add(1)                                                 Add(0)
    //          a,b,c   ------|Rm(0)|----->         b,c
    #[test]
    pub fn transform_op_add_1_rm_0() {
        let change = Add(1);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Add(0)))));
    }

    // RM - ADD

    //                          Op
    //         a,b,c   ------|Add(0)|----->        X,a,b,c
    //   Rm(1)                                                 Rm(2)
    //           a,c   ------|Add(0)|----->        X,a,c
    #[test]
    pub fn transform_op_rm_1_add_0() {
        let change = Rm(1);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Rm(2)))));
    }

    //                          Op
    //         a,b,c   ------|Add(1)|----->        a,X,b,c
    //   Rm(0)                                                 Rm(0)
    //           b,c   ------|Add(0)|----->        X,b,c
    #[test]
    pub fn transform_op_rm_0_add_1() {
        let change = Rm(0);
        let op = Add(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Rm(0)))));
    }

    // MOVE - RM

    //                          Op
    //         a,b,c   ------|Rm(2)|----->        a,b
    //   Move(0, 2)                                           Move(0,1)
    //         b,c,a   ------|Rm(1)|----->        b,a
    #[test]
    pub fn transform_op_mv_02_rm_2() {
        let change = Move(0, 2);
        let op = Rm(2);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(1)), Some(Move(0, 1)))));
    }
}
1162
// Property and regression tests: a random chain of axis ops must give the same result before
// and after decluttering, and reshape-spec helpers must follow TF/ONNX rules.
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// A source shape plus a chain of axis ops, each valid on the shape produced so far.
    #[derive(Debug)]
    struct ComposeProblem {
        input: TVec<usize>,
        ops: TVec<AxisOp>,
    }

    /// Generates an `AxisOp` that is valid on the shape given as parameter.
    impl Arbitrary for AxisOp {
        type Parameters = TVec<usize>;
        type Strategy = BoxedStrategy<AxisOp>;
        fn arbitrary_with(shape: TVec<usize>) -> Self::Strategy {
            // An axis can always be inserted, anywhere including past the last axis.
            let mut ops: BoxedStrategy<AxisOp> = (0usize..shape.len() + 1).prop_map(Add).boxed();
            if shape.len() > 1 {
                // `b + (b >= a)` skips `a`, so generated moves are never no-ops.
                ops = ops
                    .prop_union(
                        (0..shape.len(), 0..shape.len() - 1)
                            .prop_map(|(a, b)| Move(a, b + (b >= a) as usize))
                            .boxed(),
                    )
                    .boxed()
            }
            // Only size-1 axes can be removed.
            let rms = (0..shape.len()).filter(|&ax| shape[ax] == 1).map(Rm).collect::<Vec<_>>();
            if !rms.is_empty() {
                ops = ops
                    .prop_union((0..rms.len()).prop_map(move |rm| rms[rm].clone()).boxed())
                    .boxed()
            }
            // Merges of two adjacent axes that are both larger than 1.
            let mergeable: Vec<AxisOp> = shape
                .windows(2)
                .enumerate()
                .filter(|(_, w)| w[0] > 1 && w[1] > 1)
                .map(|(ix, w)| {
                    Reshape(ix, tvec!(w[0].to_dim(), w[1].to_dim()), tvec!((w[0] * w[1]).to_dim()))
                })
                .collect();
            // A single candidate is enough: shapes such as [2, 3] must be able to produce a merge.
            if !mergeable.is_empty() {
                ops = ops
                    .prop_union(
                        (0..mergeable.len()).prop_map(move |ix| mergeable[ix].clone()).boxed(),
                    )
                    .boxed()
            }
            ops
        }
    }

    /// Random source shape (1 to 3 axes of size 1 to 3) followed by 1 to 5 valid ops.
    impl Arbitrary for ComposeProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<ComposeProblem>;
        fn arbitrary_with(_args: ()) -> Self::Strategy {
            let input = proptest::collection::vec(1usize..4, 1usize..4);
            // Generates `len` ops, threading the shape through each generated op.
            fn tail(len: usize, shape: TVec<usize>) -> BoxedStrategy<TVec<AxisOp>> {
                if len == 0 {
                    Just(tvec!()).boxed()
                } else {
                    AxisOp::arbitrary_with(shape.clone())
                        .prop_flat_map(move |op| {
                            let mut shape = shape.clone();
                            op.change_shape_array(&mut shape, false).unwrap();
                            tail(len - 1, shape.clone()).prop_map(move |mut t| {
                                t.insert(0, op.clone());
                                t
                            })
                        })
                        .boxed()
                }
            }
            (input, 1usize..=5)
                .prop_flat_map(|(input, len)| (Just(input.clone()), tail(len, input.into())))
                .prop_map(|(input, ops)| ComposeProblem { input: input.into(), ops })
                .boxed()
        }
    }

    impl ComposeProblem {
        /// Builds a linear model: an i64 source followed by every op of the chain.
        pub fn model(&self) -> TractResult<TypedModel> {
            let mut model = TypedModel::default();
            let mut wire = model.add_source("source", i64::fact(&self.input))?;
            for (ix, op) in self.ops.iter().enumerate() {
                wire = model.wire_node(format!("op_{ix}"), op.clone(), &[wire])?[0];
            }
            model.set_output_outlets(&[wire])?;
            Ok(model)
        }

        /// Input tensor filled with 0, 1, 2, ... so that any misplaced element is visible.
        fn input(&self) -> TractResult<Tensor> {
            unsafe {
                let mut t = Tensor::uninitialized::<i64>(&self.input)?;
                for i in 0..t.len() {
                    t.as_slice_mut().unwrap()[i] = i as i64;
                }
                Ok(t)
            }
        }

        /// Runs the raw and the decluttered model on the same input and compares outputs.
        fn check(&self) -> TractResult<()> {
            crate::setup_test_logger();
            let input = self.input()?;
            let model = self.model()?;
            let raw = model.into_runnable()?.run(tvec!(input.clone().into_tvalue()))?;
            let optimized = self.model()?.into_decluttered()?;
            let opt = optimized.into_runnable()?.run(tvec!(input.into_tvalue()))?;
            opt[0].close_enough(&raw[0], false)
        }
    }

    proptest! {
        // NOTE(review): `any::<AxisOp>()` uses the default (empty) shape parameter, so this
        // only ever generates `Add(0)`.
        #[test]
        fn recip(pb in any::<AxisOp>()) {
            assert_eq!(pb.recip().recip(), pb);
        }

        #[test]
        fn axis_ops(pb in any::<ComposeProblem>()) {
            pb.check().unwrap()
        }
    }

    #[test]
    fn add_0_rm_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Rm(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_01() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_01_add_1() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1), Add(1)] };
        pb.check().unwrap();
    }

    #[test]
    fn recip_move_01() {
        let op = Move(1, 0);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn recip_move_20() {
        let op = Move(2, 0);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn recip_move_02() {
        let op = Move(0, 2);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn add_0_add_1_move_02() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(1), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_02() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(0), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_2_move_12() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(2), Move(1, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_02_rm_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0), Move(0, 2), Rm(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_20_move_20() {
        let pb =
            ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(0), Move(2, 0), Move(2, 0)] };
        pb.check().unwrap();
    }

    #[test]
    fn move_01_add_0() {
        let pb = ComposeProblem { input: tvec![1, 1], ops: tvec![Move(0, 1), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_02_move_02() {
        let pb = ComposeProblem { input: tvec![1, 1], ops: tvec![Add(0), Move(0, 2), Move(0, 2),] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_2_move_20_move_12_rm_2() {
        let pb = ComposeProblem {
            input: tvec![3],
            ops: tvec![Add(0), Add(2), Move(2, 0), Move(1, 2), Rm(2)],
        };
        pb.check().unwrap();
    }

    #[test]
    fn move_02_move_02() {
        let pb = ComposeProblem { input: tvec![2, 1, 1], ops: tvec![Move(0, 2), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn rm_1_perm_10_add_0() {
        let pb = ComposeProblem { input: tvec![1, 1, 2], ops: tvec![Rm(1), Move(0, 1), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_2_move_02_move_02() {
        let pb = ComposeProblem { input: tvec![3, 2], ops: tvec![Add(2), Move(0, 2), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn move_01_move_20_move_20() {
        let pb = ComposeProblem {
            input: tvec![2, 3, 2],
            ops: tvec![Move(0, 1), Move(2, 0), Move(2, 0)],
        };
        pb.check().unwrap();
    }

    #[test]
    fn reshape_axes_tracking() {
        let pb = ComposeProblem {
            input: tvec![2, 2, 2],
            ops: tvec![Reshape(0, tvec!(2.to_dim(), 2.to_dim()), tvec!(4.to_dim()))],
        };
        pb.check().unwrap();
    }

    #[test]
    fn simplify_reshape() {
        macro_rules! d {
            ($($dim: expr),*) =>  { tvec!($($dim.to_dim()),*) }
        }
        assert_eq!(Reshape(3, d!(), d!()).simplify(), tvec!());
        assert_eq!(Reshape(3, d!(2, 3), d!(2, 3)).simplify(), tvec!());
        assert_eq!(Reshape(3, d!(1), d!()).simplify(), tvec!(Rm(3)));
        assert_eq!(Reshape(3, d!(), d!(1)).simplify(), tvec!(Add(3)));
        assert_eq!(
            Reshape(3, d!(2, 3, 4), d!(2, 4, 3)).simplify(),
            tvec!(Reshape(4, d!(3, 4), d!(4, 3)))
        );
        assert_eq!(
            Reshape(3, d!(3, 4, 2), d!(4, 3, 2)).simplify(),
            tvec!(Reshape(3, d!(3, 4), d!(4, 3)))
        );
        assert_eq!(
            Reshape(3, d!(1, 2, 3), d!(3, 2)).simplify(),
            tvec!(Rm(3), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(3, d!(2, 3), d!(1, 3, 2)).simplify(),
            tvec!(Reshape(3, d!(2, 3), d!(3, 2)), Add(3))
        );
        assert_eq!(
            Reshape(3, d!(2, 3, 1), d!(3, 2)).simplify(),
            tvec!(Rm(5), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(3, d!(2, 3), d!(3, 2, 1)).simplify(),
            tvec!(Add(5), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(2, d!(2, 2, 1), d!(4)).simplify(),
            tvec!(Rm(4), Reshape(2, d!(2, 2), d!(4)))
        );
        assert_eq!(Reshape(1, d!(1, 2), d!(2)).simplify(), tvec!(Rm(1)));
    }

    // Builds a `&[TDim]`-like slice literal from heterogeneous dim expressions.
    macro_rules! s {
        ($($a:expr),*) => {&[ $($a.clone().into()),* ]}
    }

    // `r!(at ; from.. => to..)` builds an `AxisOp::Reshape`.
    macro_rules! r {
        ($at: expr ; $($from:expr),* => $($to:expr),*) => {
            AxisOp::Reshape($at, tvec!($($from.into()),*),  tvec!($($to.into()),*))
        }
    }

    #[test]
    fn compute_invalid() {
        assert!(compute_shape_with_tf_rules(s![3, 4, 5], s!(100)).is_err());
    }

    #[test]
    fn compute_with_leading_zero() {
        assert_eq!(&*compute_shape_with_tf_rules(s![3, 4, 5], s!(0, 0, 5)).unwrap(), s![3, 4, 5])
    }

    #[test]
    fn compute_with_leading_zero_with_flatten() {
        assert_eq!(
            &*compute_shape_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap(),
            s![2, 3, 35]
        )
    }

    #[test]
    fn compute_with_trailing_zero() {
        assert_eq!(&*compute_shape_with_tf_rules(s![3, 4, 5], s!(3, -1, 0)).unwrap(), s![3, 4, 5])
    }

    #[test]
    fn compute_bug_1() {
        let table = SymbolScope::default();
        let s = table.new_with_prefix("S");
        assert_eq!(
            &*compute_shape_with_tf_rules(s![s, 1, 2, 128], s!(0, 0, -1)).unwrap(),
            s![s, 1, 256]
        )
    }

    #[test]
    fn compute_bug_2() {
        let table = SymbolScope::default();
        let b = table.new_with_prefix("B");
        let s = table.new_with_prefix("S");
        assert_eq!(
            &*compute_shape_with_tf_rules(s![s, b, 2, 128], s!(0, 0, -1)).unwrap(),
            s![s, b, 256]
        )
    }

    #[test]
    fn axis_op_rm_begin() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![1, 2, 3], s!(2, 3)).unwrap(), &[Rm(0)])
    }

    #[test]
    fn axis_op_rm_end() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3, 1], s!(2, 3)).unwrap(), &[Rm(2)])
    }

    #[test]
    fn axis_op_insert_begin() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3], s!(1, 2, 3)).unwrap(), &[Add(0)])
    }

    #[test]
    fn axis_op_insert_end() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3], s!(2, 3, 1)).unwrap(), &[Add(2)])
    }

    #[test]
    fn axis_op_merge() {
        assert_eq!(
            &*to_axis_ops_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap(),
            &[r!(2 ; 5,7 => 35 )]
        )
    }

    #[test]
    fn axis_op_complex() {
        assert_eq!(
            &*to_axis_ops_with_tf_rules(s![1, 2, 3, 5, 7], s!(2, 1, 3, 35, 1)).unwrap(),
            &[Rm(0), Add(1), r!(3 ; 5,7 => 35 ), Add(4)]
        )
    }
}