//! tract_core/ops/change_axes.rs — axis-manipulation operators (AxisOp).

1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::model::{TypedModel, TypedNode};
6use crate::ops::identity::Identity;
7use AxisOp::*;
8use num_traits::One;
9use tract_itertools::Itertools;
10use tract_linalg::block_quant::BlockQuantFact;
11use tract_ndarray::{ArrayViewD, ArrayViewMutD};
12
/// Designates one of a node's wires: either an input slot or an output slot.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InOut {
    /// An output slot of the node, by index.
    Out(usize),
    /// An input slot of the node, by index.
    In(usize),
}
18
19impl InOut {
20    pub fn as_outlet<F: Clone + Fact, O: Clone>(&self, node: &Node<F, O>) -> OutletId {
21        match self {
22            InOut::In(ix) => node.inputs[*ix],
23            InOut::Out(ix) => OutletId::new(node.id, *ix),
24        }
25    }
26
27    pub fn is_input(&self) -> bool {
28        matches!(self, InOut::In(_))
29    }
30
31    pub fn is_output(&self) -> bool {
32        matches!(self, InOut::Out(_))
33    }
34
35    pub fn slot(&self) -> usize {
36        match self {
37            InOut::Out(o) => *o,
38            InOut::In(i) => *i,
39        }
40    }
41}
42
/// An operation rearranging the axes of a tensor without touching the
/// underlying data order (except `Reshape`, which regroups dimensions).
#[derive(Clone, Hash, Eq)]
#[allow(clippy::large_enum_variant)] // FIXME ?
#[allow(clippy::derived_hash_with_manual_eq)] // FIXME. this one may be pretty bad. how about a.canonical() == b.canonical() ? need proper canonicalizeation of Reshape
pub enum AxisOp {
    /// Insert a 1-sized axis at the given position.
    Add(usize),
    /// Remove the (expected 1-sized) axis at the given position.
    Rm(usize),
    /// Move the axis at the first position to the second position.
    Move(usize, usize),
    /// Reinterpret the run of `from` dims starting at the given position as `to`.
    Reshape(usize, TVec<TDim>, TVec<TDim>),
}
52
53impl Debug for AxisOp {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        match self {
56            AxisOp::Add(a) => write!(f, "Add({a})"),
57            AxisOp::Rm(a) => write!(f, "Rm({a})"),
58            AxisOp::Move(from, to) => write!(f, "Move({from},{to})"),
59            AxisOp::Reshape(at, from, to) => {
60                write!(f, "Reshape({at}, [{}], [{}])", from.iter().join(","), to.iter().join(","))
61            }
62        }
63    }
64}
65
impl PartialEq for AxisOp {
    fn eq(&self, other: &AxisOp) -> bool {
        // All no-ops are equal to each other, and only to other no-ops.
        if self.is_noop() && other.is_noop() {
            true
        } else if self.is_noop() != other.is_noop() {
            false
        } else {
            match (self, other) {
                (Add(a), Add(b)) | (Rm(a), Rm(b)) => a == b,
                (Move(f1, t1), Move(f2, t2)) => {
                    // Moves of two adjacent axes are the same transposition in
                    // both directions: Move(i, i+1) == Move(i+1, i).
                    (f1 == f2 && t1 == t2)
                        || ((*t1 == f1 + 1 || *f1 == t1 + 1) && t2 == f1 && t1 == f2)
                }
                (Reshape(at1, f1, t1), Reshape(at2, f2, t2)) => at1 == at2 && f1 == f2 && t1 == t2,
                _ => false,
            }
        }
    }
}
85
86impl AxisOp {
87    pub fn canonical(&self) -> Cow<'_, AxisOp> {
88        match self {
89            Move(from, to) if *from == to + 1 => Cow::Owned(Move(*to, *from)),
90            Reshape(at, from, to)
91                if from.len() == 1 && to.len() == 2 && from[0] == to[0] && to[1].is_one() =>
92            {
93                Cow::Owned(Add(*at + 1))
94            }
95            Reshape(at, from, to)
96                if from.len() == 1 && to.len() == 2 && from[0] == to[1] && to[0].is_one() =>
97            {
98                Cow::Owned(Add(*at))
99            }
100            Reshape(at, from, to)
101                if from.len() == 2 && to.len() == 1 && from[0] == to[0] && from[1].is_one() =>
102            {
103                Cow::Owned(Rm(*at + 1))
104            }
105            Reshape(at, from, to)
106                if from.len() == 2 && to.len() == 1 && from[1] == to[0] && from[0].is_one() =>
107            {
108                Cow::Owned(Rm(*at))
109            }
110            other => Cow::Borrowed(other),
111        }
112    }
113
114    pub fn simplify(&self) -> TVec<AxisOp> {
115        match self.canonical().borrow() {
116            Reshape(_, from, to) if from == to => tvec!(),
117            Reshape(at, from, to) if to.len() == 0 => tvec!(Rm(*at); from.len()),
118            Reshape(at, from, to) if from.len() == 0 => tvec!(Add(*at); to.len()),
119            Reshape(at, from, to) if from[0] == to[0] => {
120                Reshape(at + 1, from[1..].into(), to[1..].into()).simplify()
121            }
122            Reshape(at, from, to) if from[from.len() - 1] == to[to.len() - 1] => {
123                Reshape(*at, from[..from.len() - 1].into(), to[..to.len() - 1].into()).simplify()
124            }
125            Reshape(at, from, to) if from[0] == 1.to_dim() => std::iter::once(Rm(*at))
126                .chain(Reshape(*at, from[1..].into(), to.clone()).simplify())
127                .collect(),
128            Reshape(at, from, to) if to[0] == 1.to_dim() => {
129                Reshape(*at, from.clone(), to[1..].into())
130                    .simplify()
131                    .into_iter()
132                    .chain(std::iter::once(Add(*at)))
133                    .collect()
134            }
135            Reshape(at, from, to) if from[from.len() - 1] == 1.to_dim() => {
136                std::iter::once(Rm(at + from.len() - 1))
137                    .chain(Reshape(*at, from[..from.len() - 1].into(), to.clone()).simplify())
138                    .collect()
139            }
140            Reshape(at, from, to) if to[to.len() - 1] == 1.to_dim() => {
141                std::iter::once(Add(at + from.len()))
142                    .chain(Reshape(*at, from.clone(), to[..to.len() - 1].into()).simplify())
143                    .collect()
144            }
145            other => tvec!(other.clone()),
146        }
147    }
148
149    pub fn transform_axis(&self, axis: usize) -> Option<usize> {
150        match self.canonical().as_ref() {
151            Add(ix) => Some(axis + (axis >= *ix) as usize),
152            Rm(ix) => {
153                if axis == *ix {
154                    None
155                } else {
156                    Some(axis - (axis > *ix) as usize)
157                }
158            }
159            Move(from, to) if from < to => {
160                if axis < *from || axis > *to {
161                    Some(axis)
162                } else if axis == *from {
163                    Some(*to)
164                } else {
165                    Some(axis - 1)
166                }
167            }
168            Move(from, to) => {
169                if axis < *to || axis > *from {
170                    Some(axis)
171                } else if axis == *from {
172                    Some(*to)
173                } else {
174                    Some(axis + 1)
175                }
176            }
177            Reshape(at, _, _) if axis < *at => Some(axis),
178            Reshape(at, from, to) if axis >= at + from.len() => Some(axis + to.len() - from.len()),
179            Reshape(_, _, _) => None,
180        }
181    }
182
183    // if sucessful return Some()
184    // first item is the Op we want to be replaced by. if none, we are now identity.
185    // second item is the change to propagate. if none, the output is not
186    // changed
187    pub fn merge_incoming_change(
188        &self,
189        change: &AxisOp,
190    ) -> Option<(Option<AxisOp>, Option<AxisOp>)> {
191        match (self.canonical().as_ref(), change.canonical().as_ref()) {
192            (Add(op), Add(c)) => {
193                Some((Some(Add(op + (c < op) as usize)), Some(Add(c + (c >= op) as usize))))
194            }
195            (Add(op), Rm(c)) => {
196                Some((Some(Add(op - (c < op) as usize)), Some(Rm(c + (c >= op) as usize))))
197            }
198            (Rm(op), Add(c)) => {
199                Some((Some(Rm(op + (c <= op) as usize)), Some(Add(c - (op < c) as usize))))
200            }
201            (Rm(op), Rm(c)) => {
202                Some((Some(Rm(op - (c < op) as usize)), Some(Rm(c - (op <= c) as usize))))
203            }
204
205            (Add(x), Move(from, to)) => {
206                if x <= from.min(to) {
207                    Some((Some(self.clone()), Some(Move(from + 1, to + 1))))
208                } else if x > from.max(to) {
209                    Some((Some(self.clone()), Some(change.clone())))
210                } else {
211                    None
212                }
213            }
214
215            (Move(from, to), Add(x)) => {
216                if x <= from.min(to) {
217                    Some((Some(Move(from + 1, to + 1)), Some(Add(*x))))
218                } else if x > from.max(to) {
219                    Some((Some(Move(*from, *to)), Some(Add(*x))))
220                } else {
221                    None
222                }
223            }
224
225            (Rm(x), Move(from, to)) => {
226                if x == from {
227                    Some((Some(Rm(*to)), None))
228                } else if x < from.min(to) {
229                    Some((Some(self.clone()), Some(Move(from - 1, to - 1))))
230                } else if x > from.max(to) {
231                    Some((Some(self.clone()), Some(change.clone())))
232                } else if from + 1 == *to && x == to {
233                    Some((Some(Rm(*from)), None))
234                } else if from < to && x <= to {
235                    Some((Some(Rm(x - 1)), Some(Move(*from, *to - 1))))
236                } else {
237                    Some((Some(Rm(x + 1)), Some(Move(*from - 1, *to))))
238                }
239            }
240
241            (Move(from, to), Rm(x)) => {
242                if x < from.min(to) {
243                    Some((Some(Move(from - 1, to - 1)), Some(Rm(*x))))
244                } else if x > from.max(to) {
245                    Some((Some(Move(*from, *to)), Some(Rm(*x))))
246                } else {
247                    None
248                }
249            }
250
251            (Add(op), Reshape(at, from, to)) => {
252                if op <= at {
253                    Some((Some(Add(*op)), Some(Reshape(at + 1, from.clone(), to.clone()))))
254                } else if *op > at + from.len() {
255                    Some((
256                        Some(Add(*op + to.len() - from.len())),
257                        Some(Reshape(*at, from.clone(), to.clone())),
258                    ))
259                } else {
260                    None
261                }
262            }
263            (Rm(op), Reshape(at, from, to)) => {
264                if op < at {
265                    Some((Some(Rm(*op)), Some(Reshape(at - 1, from.clone(), to.clone()))))
266                } else if *op > at + from.len() {
267                    Some((
268                        Some(Rm(*op + to.len() - from.len())),
269                        Some(Reshape(*at, from.clone(), to.clone())),
270                    ))
271                } else {
272                    None
273                }
274            }
275            (Reshape(at, from, to), Add(change)) => {
276                if change < at {
277                    Some((Some(Reshape(at + 1, from.clone(), to.clone())), Some(Add(*change))))
278                } else if *change > *at + from.len() {
279                    Some((
280                        Some(Reshape(*at, from.clone(), to.clone())),
281                        Some(Add(change + to.len() - from.len())),
282                    ))
283                } else {
284                    None
285                }
286            }
287            (Reshape(at, from, to), Rm(change)) => {
288                if change < at {
289                    Some((Some(Reshape(at - 1, from.clone(), to.clone())), Some(Rm(*change))))
290                } else if *change > *at + from.len() {
291                    Some((
292                        Some(Reshape(*at, from.clone(), to.clone())),
293                        Some(Rm(change + to.len() - from.len())),
294                    ))
295                } else {
296                    None
297                }
298            }
299            (Reshape(_, _, _), Move(_, _)) => None, // todo, some are manageable
300            (Move(_, _), Reshape(_, _, _)) => None, // todo, some are manageable
301            (Reshape(_, _, _), Reshape(_, _, _)) => None, // todo, some are manageable
302            _ => None,
303        }
304    }
305
306    pub fn change_shape_array<D: DimLike>(
307        &self,
308        shape: &mut TVec<D>,
309        broadcasting: bool,
310    ) -> TractResult<()> {
311        match self.canonical().as_ref() {
312            Add(ix) => {
313                ensure!(*ix <= shape.len());
314                shape.insert(*ix, D::one());
315            }
316            Rm(ix) => {
317                ensure!(*ix < shape.len());
318                shape.remove(*ix);
319            }
320            Move(from, to) => {
321                ensure!(*from < shape.len());
322                ensure!(*to < shape.len());
323                let axis = shape.remove(*from);
324                shape.insert(*to, axis);
325            }
326            Reshape(at, from, to) => {
327                let from_volume = from.iter().product::<TDim>();
328                let to_volume = to.iter().product::<TDim>();
329                ensure!(from_volume == to_volume, "{from_volume} should be equal to {to_volume}");
330                ensure!(*at + from.len() <= shape.len());
331                if shape.len() >= from.len() + *at
332                    && tract_itertools::izip!(shape.iter().skip(*at), from)
333                        .all(|(shape, spec)| shape.to_dim() == *spec)
334                {
335                    for _ in from {
336                        shape.remove(*at);
337                    }
338                    for d in to.iter().rev() {
339                        shape.insert(*at, d.try_into()?);
340                    }
341                } else if broadcasting
342                    && shape.iter().skip(*at).take(from.len()).all(|d| d.to_dim() == 1.to_dim())
343                {
344                    for _ in from {
345                        shape.remove(*at);
346                    }
347                    for _ in to.iter().rev() {
348                        shape.insert(*at, 1.into());
349                    }
350                } else {
351                    bail!("Incompatible reshape for shape {:?} and {:?}", shape, self);
352                }
353            }
354        }
355        Ok(())
356    }
357
358    pub fn change_shape(&self, shape: &mut ShapeFact, broadcasting: bool) -> TractResult<()> {
359        match self.canonical().as_ref() {
360            Add(ix) => shape.insert_axis(*ix),
361            Rm(ix) => {
362                if shape.rank() <= *ix {
363                    bail!("Attempt to remove axis #{} on shape {:?}", ix, shape);
364                }
365                if shape[*ix] != 1.to_dim() {
366                    bail!("Removing non-trivial axis #{} of dim: {:?}", ix, shape);
367                }
368                shape.remove_axis(*ix)
369            }
370            _ => {
371                let mut array = shape.to_tvec();
372                self.change_shape_array(&mut array, broadcasting)?;
373                let mut new_shape = ShapeFact::from_dims(array);
374                std::mem::swap(shape, &mut new_shape);
375                Ok(())
376            }
377        }
378    }
379
380    pub fn change_tensor(&self, tensor: &mut Tensor, broadcasting: bool) -> TractResult<()> {
381        if self.required_rank() > tensor.rank() && tensor.datum_type().is_opaque() {
382            let inner_change = self.trim_left(tensor.rank())?;
383            for opaque in tensor.as_slice_mut::<Opaque>()? {
384                if let Some(bwf) = opaque.downcast_ref::<BlobWithFact>() {
385                    let bqf = bwf
386                        .fact
387                        .downcast_ref::<BlockQuantFact>()
388                        .context("Expected BlockQuantFact")?;
389                    let mut new_shape: TVec<usize> = bqf.shape().into();
390                    inner_change.change_shape_array(&mut new_shape, false)?;
391                    let new_bqv = BlobWithFact {
392                        value: Arc::clone(&bwf.value),
393                        fact: Box::new(BlockQuantFact::new(bqf.format.clone(), new_shape)),
394                    };
395                    *opaque = Opaque(Arc::new(new_bqv));
396                } else {
397                    bail!("Can't apply {self:?} to opaque tensor {tensor:?}");
398                }
399            }
400            return Ok(());
401        }
402        ensure!(self.required_rank() <= tensor.rank());
403        match self.canonical().as_ref() {
404            Add(ix) => tensor.insert_axis(*ix),
405            Rm(ix) => tensor.remove_axis(*ix),
406            Move(from, to) => {
407                let mut tmp = tensor.clone().move_axis(*from, *to)?;
408                std::mem::swap(tensor, &mut tmp);
409                Ok(())
410            }
411            Reshape(at, from, to) => {
412                let mut shape: TVec<usize> = tensor.shape().into();
413                self.change_shape_array(&mut shape, true)?;
414                if tensor.set_shape(&shape).is_ok() {
415                    Ok(())
416                } else if broadcasting
417                    && tensor.shape().iter().skip(*at).take(from.len()).all(|d| *d == 1)
418                {
419                    if from.len() > to.len() {
420                        for _ in to.len()..from.len() {
421                            tensor.remove_axis(*at)?;
422                        }
423                    }
424                    if to.len() > from.len() {
425                        for _ in from.len()..to.len() {
426                            tensor.insert_axis(*at)?;
427                        }
428                    }
429                    Ok(())
430                } else {
431                    bail!(
432                        "Invalid reshaping: {:?} on tensor {:?} (broadcasting allowed: {:?})",
433                        self,
434                        tensor,
435                        broadcasting
436                    )
437                }
438            }
439        }
440    }
441
442    pub fn change_view<D>(&self, view: &mut ArrayViewD<D>) -> TractResult<()> {
443        use tract_ndarray::Axis;
444        match *self {
445            AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
446            AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
447            AxisOp::Move(from, to) if from < to => {
448                for left in from..to {
449                    view.swap_axes(left, left + 1);
450                }
451            }
452            AxisOp::Move(from, to) => {
453                for left in (to..from).rev() {
454                    view.swap_axes(left, left + 1);
455                }
456            }
457            AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
458        }
459        Ok(())
460    }
461
462    pub fn change_view_mut<D>(&self, view: &mut ArrayViewMutD<D>) -> TractResult<()> {
463        use tract_ndarray::Axis;
464        match *self {
465            AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
466            AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
467            AxisOp::Move(from, to) if from < to => {
468                for left in from..to {
469                    view.swap_axes(left, left + 1);
470                }
471            }
472            AxisOp::Move(from, to) => {
473                for left in (to..from).rev() {
474                    view.swap_axes(left, left + 1);
475                }
476            }
477            AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
478        }
479        Ok(())
480    }
481
482    pub fn recip(&self) -> AxisOp {
483        match self.canonical().as_ref() {
484            Add(ix) => Rm(*ix),
485            Rm(ix) => Add(*ix),
486            Move(from, to) if from == to => self.clone(),
487            Move(from, to) if *from + 1 == *to => self.clone(),
488            Move(from, to) if *from == *to + 1 => {
489                unreachable!();
490            }
491            Move(from, to) => Move(*to, *from),
492            Reshape(at, from, to) => Reshape(*at, to.clone(), from.clone()),
493        }
494    }
495
496    pub fn is_noop(&self) -> bool {
497        match self {
498            Move(f, t) if f == t => true,
499            Reshape(_, f, t) if f == t => true,
500            _ => false,
501        }
502    }
503
504    pub fn only_shape(&self) -> bool {
505        if self.is_noop() {
506            return true;
507        }
508        !matches!(self, Move(_, _))
509    }
510
511    pub fn wire_split_axis(
512        model: &mut TypedModel,
513        name: impl ToString,
514        outlet: OutletId,
515        axis: usize,
516        outer_dim: usize,
517    ) -> TractResult<TVec<OutletId>> {
518        let fact = model.outlet_fact(outlet)?;
519        let dim: TDim = fact.shape[axis].clone();
520        let inner_dim = dim.clone() / outer_dim;
521        let op = Self::Reshape(axis, tvec!(dim.clone()), tvec!(outer_dim.to_dim(), inner_dim));
522        model.wire_node(name.to_string(), op, &[outlet])
523    }
524
525    pub fn wire_collapse_axis(
526        model: &mut TypedModel,
527        name: impl ToString,
528        outlet: OutletId,
529        axis: usize,
530    ) -> TractResult<TVec<OutletId>> {
531        let fact = model.outlet_fact(outlet)?;
532        let dim: TDim = fact.shape[axis].clone();
533        let next_dim: TDim = fact.shape[axis + 1].clone();
534        let op = Self::Reshape(axis, tvec!(dim.clone(), next_dim.clone()), tvec!(dim * next_dim));
535        model.wire_node(name.to_string(), op, &[outlet])
536    }
537
538    #[inline]
539    pub fn required_rank(&self) -> usize {
540        match self {
541            Rm(r) => r + 1,
542            Add(a) => *a,
543            Reshape(at, from, _to) => at + from.len(),
544            Move(from, to) => *from.max(to),
545        }
546    }
547
548    pub fn trim_left(&self, prefix: usize) -> TractResult<AxisOp> {
549        Ok(match self {
550            Rm(r) if *r >= prefix => Rm(r - prefix),
551            Add(a) if *a >= prefix => Add(a - prefix),
552            Reshape(at, from, to) if *at >= prefix => {
553                Reshape(at - prefix, from.clone(), to.clone())
554            }
555            Move(from, to) if *from >= prefix && *to >= prefix => Move(from - prefix, to - prefix),
556            _ => bail!("Can no trim left {self:?} by {prefix}"),
557        })
558    }
559}
560
561pub fn wire_rank_broadcast(
562    prefix: impl AsRef<str>,
563    target: &mut TypedModel,
564    inputs: &[OutletId],
565) -> TractResult<TVec<OutletId>> {
566    let facts =
567        inputs.iter().map(|o| target.outlet_fact(*o).cloned()).collect::<TractResult<TVec<_>>>()?;
568    let max_rank = facts.iter().map(|f| f.rank()).max().unwrap();
569    let mut wires = tvec!();
570    for i in 0..inputs.len() {
571        let mut wire = inputs[i];
572        for _ in facts[i].rank()..max_rank {
573            let name = target.unique_name(prefix.as_ref().to_string() + ".fix-rank");
574            wire = target.wire_node(name, AxisOp::Add(0), &[wire])?[0];
575        }
576        wires.push(wire);
577    }
578    Ok(wires)
579}
580
581pub fn wire_with_rank_broadcast(
582    prefix: impl AsRef<str>,
583    target: &mut TypedModel,
584    op: impl Into<Box<dyn TypedOp>>,
585    inputs: &[OutletId],
586) -> TractResult<TVec<OutletId>> {
587    let prefix = prefix.as_ref();
588    let wires = wire_rank_broadcast(prefix, target, inputs)?;
589    target.wire_node(prefix, op.into(), &wires)
590}
591
/// An axis operation requested on a specific outlet of the graph.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AxisChange {
    /// The outlet whose wire the change applies to.
    pub outlet: OutletId,
    /// The axis operation to apply.
    pub op: AxisOp,
}
597
/// What a node must do to absorb an incoming axis change.
#[derive(Clone, Default, Debug)]
pub struct AxisChangeConsequence {
    /// Replacement op for the node; `None` keeps the current op.
    pub substitute_op: Option<Box<dyn TypedOp>>,
    /// Axis changes to propagate on each of the node's wires.
    pub wire_changes: TVec<(InOut, AxisOp)>,
}
603
604impl AxisChangeConsequence {
605    pub fn new(
606        _model: &TypedModel,
607        node: &TypedNode,
608        op: Option<Box<dyn TypedOp>>,
609        axis_op: &AxisOp,
610    ) -> AxisChangeConsequence {
611        let mut wire_changes = tvec!();
612        for i in 0..node.inputs.len() {
613            wire_changes.push((InOut::In(i), axis_op.clone()));
614        }
615        for i in 0..node.outputs.len() {
616            wire_changes.push((InOut::Out(i), axis_op.clone()));
617        }
618        AxisChangeConsequence { wire_changes, substitute_op: op }
619    }
620}
621
622impl Op for AxisOp {
623    fn name(&self) -> StaticName {
624        match self {
625            Add(_) => "AddAxis".into(),
626            Rm(_) => "RmAxis".into(),
627            Move(_, _) => "MoveAxis".into(),
628            Reshape(_, _, _) => "Reshape".into(),
629        }
630    }
631
632    fn info(&self) -> TractResult<Vec<String>> {
633        match self {
634            Add(axis) | Rm(axis) => Ok(vec![format!("Axis: {axis}")]),
635            Move(from, to) => Ok(vec![format!("Axis {from} to {to}")]),
636            Reshape(at, from, to) => Ok(vec![format!(
637                "Axes starting at {}: {:?} to {:?}",
638                at,
639                from.iter().join(","),
640                to.iter().join(",")
641            )]),
642        }
643    }
644
645    op_as_typed_op!();
646}
647
impl EvalOp for AxisOp {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval_with_session(
        &self,
        _node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let mut input = args_1!(inputs).into_tensor();
        match self {
            AxisOp::Reshape(skip, from, to) => {
                // Resolve symbolic dims against the session's symbol values
                // before applying the reshape to the concrete tensor.
                let from = from.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
                let to = to.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
                AxisOp::Reshape(*skip, from, to).change_tensor(&mut input, false)?
            }
            _ => self.change_tensor(&mut input, false)?,
        }
        Ok(tvec!(input.into_tvalue()))
    }
}
671
672impl TypedOp for AxisOp {
673    as_op!();
674
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        // If the op targets more axes than the outer tensor has and the input
        // carries a BlockQuantFact, the change applies to the *inner* (payload)
        // shape: keep the outer shape, rewrite the opaque fact.
        if self.required_rank() > inputs[0].rank() {
            if let Some(bqf) =
                inputs[0].opaque_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
            {
                let mut new_inner_shape: TVec<usize> = bqf.shape().into();
                // shift the op past the outer rank, then apply to the inner shape
                self.trim_left(inputs[0].rank())?
                    .change_shape_array(&mut new_inner_shape, false)?;
                let new_bqf = BlockQuantFact::new(bqf.format.clone(), new_inner_shape);
                let mut new_fact = Opaque::fact(inputs[0].shape.clone()).with_opaque_fact(new_bqf);
                if let Some(k) = &inputs[0].konst {
                    let mut new = k.clone().into_tensor(); // cloning bqv is cheap
                    self.change_tensor(&mut new, false)?;
                    new_fact.konst = Some(new.into());
                }
                return Ok(tvec!(new_fact));
            }
        }
        // Regular case: transform the outer shape, preserve dt and opaque fact.
        let mut shape = inputs[0].shape.clone();
        self.change_shape(&mut shape, false)?;
        let mut fact = inputs[0].datum_type.fact(shape);
        fact.opaque_fact.clone_from(&inputs[0].opaque_fact);
        Ok(tvec!(fact))
    }
699
    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        // One lowercase axis per input axis, linked to its output position when
        // transform_axis says it survives.
        let mut axes: Vec<Axis> = (0..inputs[0].rank())
            .zip('a'..)
            .map(|(axis_id, repr)| {
                let mut axis = Axis::new(repr, inputs.len(), outputs.len()).input(0, axis_id);
                if let Some(out) = self.transform_axis(axis_id) {
                    axis = axis.output(0, out);
                }
                axis
            })
            .collect();
        // Uppercase axes for output axes with no input counterpart (created by
        // the op, i.e. not reachable through the reciprocal mapping).
        for (axis, letter) in (0..outputs[0].rank()).zip('A'..) {
            if self.recip().transform_axis(axis).is_none() {
                axes.push(Axis::new(letter, inputs.len(), outputs.len()).output(0, axis));
            }
        }
        AxesMapping::new(inputs.len(), outputs.len(), axes)
    }
722
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // A no-op node can simply be shunted away.
        if self.is_noop() {
            if let Some(p) = TypedModelPatch::shunt_one_op(model, node)? {
                return Ok(Some(p));
            }
        }
        // If simplify() yields anything other than this very op, replace the
        // node by the chain of simpler ops.
        let simplified = self.simplify();
        if simplified.len() != 1 || &simplified[0] != self {
            let mut patch = TypedModelPatch::default();
            let mut wire = patch.tap_model(model, node.inputs[0])?;
            for (ix, op) in simplified.into_iter().enumerate() {
                wire = patch.wire_node(format!("{}.{}", node.name, ix), op, &[wire])?[0];
            }
            patch.shunt_outside(model, node.id.into(), wire)?;
            Ok(Some(patch))
        } else {
            Ok(None)
        }
    }
746
    /// Suggest neutralizing this node: apply the inverse downstream, or the op itself upstream.
    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
        Ok(tvec!((InOut::Out(0), self.recip()), (InOut::In(0), self.clone())))
    }
750
    fn change_axes(
        &self,
        _model: &TypedModel,
        _node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let op = if let InOut::Out(0) = io {
            // A change arriving from the output side is handled by solving the
            // mirrored problem on the reciprocal op, then flipping the result
            // back (substitute op reciprocated, wire sides swapped).
            rule_if_some!(more = self.recip().change_axes(_model, _node, InOut::In(0), change)?);
            AxisChangeConsequence {
                substitute_op: more.substitute_op.map(|op| {
                    if let Some(op) = op.as_op().downcast_ref::<AxisOp>() {
                        Box::new(op.recip())
                    } else {
                        op // have to be identity
                    }
                }),
                wire_changes: more
                    .wire_changes
                    .into_iter()
                    .map(|wc| {
                        (if wc.0 == InOut::In(0) { InOut::Out(0) } else { InOut::In(0) }, wc.1)
                    })
                    .collect(),
            }
        } else if change == self {
            // The incoming change is exactly this op: the node becomes identity.
            AxisChangeConsequence { substitute_op: Some(Box::new(Identity)), wire_changes: tvec!() }
        } else {
            // General case: try to commute the incoming change through this op.
            rule_if_some!((new_op, new_change) = self.merge_incoming_change(change));
            trace!("  Change:{change:?} self:{self:?} -> change:{new_change:?} op:{new_op:?}");
            let substitute_op: Box<dyn TypedOp> =
                if let Some(o) = new_op { Box::new(o) as _ } else { Box::new(Identity) };
            let mut wire_changes = tvec!();
            if !change.is_noop() {
                wire_changes.push((InOut::In(0), change.clone()))
            }
            if let Some(new_change) = new_change {
                wire_changes.push((InOut::Out(0), new_change))
            }
            AxisChangeConsequence { substitute_op: Some(substitute_op), wire_changes }
        };
        Ok(Some(op))
    }
794
795    fn concretize_dims(
796        &self,
797        _source: &TypedModel,
798        node: &TypedNode,
799        target: &mut TypedModel,
800        mapping: &HashMap<OutletId, OutletId>,
801        values: &SymbolValues,
802    ) -> TractResult<TVec<OutletId>> {
803        let op = if let AxisOp::Reshape(axis, from, to) = self {
804            AxisOp::Reshape(
805                *axis,
806                from.iter().map(|d| d.eval(values)).collect(),
807                to.iter().map(|d| d.eval(values)).collect(),
808            )
809        } else {
810            self.clone()
811        };
812        target.wire_node(&node.name, op, &[mapping[&node.inputs[0]]])
813    }
814
815    fn slice(
816        &self,
817        patch: &mut TypedModelPatch,
818        _model: &TypedModel,
819        node: &TypedNode,
820        _prefix: &str,
821        inputs: &[OutletId],
822        output_axis: usize,
823        _start: &TDim,
824        _end: &TDim,
825    ) -> TractResult<Option<TVec<OutletId>>> {
826        // is this test really useful ? or axis mapping preempt this ?
827        if let Reshape(pos, _from, to) = self {
828            if output_axis >= *pos && output_axis < pos + to.len() {
829                return Ok(None);
830            }
831        }
832        patch.wire_node(&node.name, &node.op, inputs).map(Some)
833    }
834
835    fn codegen(
836        &self,
837        model: &TypedModel,
838        node: &TypedNode,
839    ) -> TractResult<Option<TypedModelPatch>> {
840        if node.outputs[0].fact.opaque_fact.is_some() {
841            return Ok(None);
842        }
843        if let Some(shape) = node.outputs[0].fact.shape.as_concrete() {
844            if !matches!(self, AxisOp::Move(_, _)) {
845                let (inputs, outputs) = model.node_facts(node.id)?;
846                let mapping = self.axes_mapping(&inputs, &outputs)?;
847                let op = IntoShape {
848                    mapping,
849                    len: shape.iter().product(),
850                    strides: Tensor::natural_strides(shape),
851                    dims: shape.into(),
852                };
853                return Ok(Some(TypedModelPatch::replace_single_op(
854                    model,
855                    node,
856                    &node.inputs,
857                    op,
858                )?));
859            }
860        }
861        Ok(None)
862    }
863}
864
865// a, b, c is a <- b, b <- c, c <- a
866fn perm_to_cycles(perm: &[usize]) -> TVec<TVec<usize>> {
867    let mut cycles: TVec<TVec<usize>> = tvec!();
868    let mut done = 0;
869    while done < perm.len() {
870        if perm[done] == done || cycles.iter().any(|c| c.contains(&done)) {
871            done += 1;
872            continue;
873        }
874        let mut cycle = tvec!();
875        let mut current = done;
876        loop {
877            cycle.push(current);
878            current = perm[current];
879            if current == done {
880                break;
881            }
882        }
883        cycles.push(cycle)
884    }
885    cycles
886}
887
/// Recognizes a cycle that is equivalent to a single axis rotation, i.e. one
/// `Move(from, to)`. Returns `Some((from, to))` if so, `None` otherwise.
fn is_rotation_cycle(cycle: &[usize]) -> Option<(usize, usize)> {
    // Ascending run, e.g. [2,3,4]: rotate first element to the last position.
    if cycle.windows(2).all(|w| w[0] + 1 == w[1]) {
        return Some((cycle[0], cycle[cycle.len() - 1]));
    }
    // Head followed by a descending run wrapping back to head + 1, e.g.
    // [0,3,2,1]. Comparisons are written in addition form (`w[1] + 1 == w[0]`
    // rather than `w[0] - 1 == w[1]`) so a 0 in the compared position yields
    // `false` instead of a usize underflow panic in debug builds.
    if cycle[1..].windows(2).all(|w| w[1] + 1 == w[0])
        && cycle[cycle.len() - 1] == cycle[0] + 1
    {
        return Some((cycle[1], cycle[0]));
    }
    None
}
899
// Decomposes a permutation into a sequence of single-axis moves (from, to),
// applied left to right, by fixpoint iteration: replay the moves found so far,
// compare with the target, and emit one more move until they match.
fn perm_to_atoms(input: &[usize]) -> TVec<(usize, usize)> {
    let mut changes: TVec<(usize, usize)> = tvec!();
    'top: loop {
        // replay the moves accumulated so far on the identity permutation
        let mut reached: TVec<usize> = (0..input.len()).collect();
        changes.iter().for_each(|(f, t)| {
            let axis = reached.remove(*f);
            reached.insert(*t, axis);
        });
        if &*reached == input {
            return changes;
        }
        // permutation still needed to go from `reached` to `input`
        let remaining: TVec<usize> =
            input.iter().map(|x| reached.iter().position(|y| y == x).unwrap()).collect();
        let cycles = perm_to_cycles(&remaining);
        // a rotation cycle collapses to a single move
        for cycle in &cycles {
            if let Some(rot) = is_rotation_cycle(cycle) {
                changes.push(rot);
                continue 'top;
            }
        }
        // otherwise peel one transposition off the first cycle and iterate
        changes.push((cycles[0][1], cycles[0][0]));
    }
}
923
924pub fn perm_to_ops(input: &[usize]) -> TVec<AxisOp> {
925    perm_to_atoms(input).into_iter().map(|pair| AxisOp::Move(pair.0, pair.1)).collect()
926}
927
/// Computes the concrete output shape of a reshape following TensorFlow-style
/// conventions: a `0` in `shape_spec` copies the aligned input dimension, and a
/// single `-1` is inferred so that input and output volumes match.
pub fn compute_shape_with_tf_rules(input: &[TDim], shape_spec: &[TDim]) -> TractResult<TVec<TDim>> {
    let mut shape: TVec<TDim> = shape_spec.into();
    // Resolves `0` slots left-to-right, stopping at the first `-1`, while
    // consuming input dimensions as their volume gets accounted for.
    fn deal_with_zero<'a>(
        mut input_dims: std::iter::Peekable<impl Iterator<Item = &'a TDim>>,
        shape: &mut [TDim],
    ) -> TractResult<()> {
        // volume of input dims consumed but not yet matched by output slots
        let mut remaining_dim_input = 1.to_dim();
        for slot in shape.iter_mut() {
            // stop at -1: the reversed second pass handles the other side
            if *slot == (-1).into() {
                break;
            }
            if *slot == 0.into() {
                // a `0` must land on an input dim boundary to copy it verbatim
                if remaining_dim_input != TDim::one() {
                    bail!("Invalid remaining dim");
                }
                *slot = (*input_dims.peek().context("Invalid")?).clone();
            }
            // pull input dims until the accumulated volume divides by the slot
            loop {
                let quotient = remaining_dim_input.maybe_div(slot);
                if quotient.is_err() || quotient.as_ref().unwrap().1 != 1 {
                    remaining_dim_input *= input_dims.next().context("Invalid")?;
                } else {
                    break;
                }
            }
            remaining_dim_input = remaining_dim_input.maybe_div(slot)?.0;
        }
        Ok(())
    }

    // one forward and one backward pass so `0`s on both sides of a `-1` resolve
    deal_with_zero(input.iter().peekable(), &mut shape)?;
    shape.reverse();
    deal_with_zero(input.iter().rev().peekable(), &mut shape)?;
    shape.reverse();

    // infer the single remaining `-1` so the output volume matches the input's
    if let Some(pos) = shape.iter().position(|d| *d == (-1).into()) {
        let input_vol: TDim = input.iter().product();
        let shape_vol: TDim = shape.iter().filter(|d| **d != (-1).into()).product();
        let div = input_vol.maybe_div(&shape_vol)?;
        if div.1 != 1 {
            bail!("invalid")
        }
        shape[pos] = div.0;
    }
    Ok(shape)
}
974
975pub fn to_axis_ops_with_tf_rules(
976    input_orig: &[TDim],
977    output_spec: &[TDim],
978) -> TractResult<TVec<AxisOp>> {
979    let final_output = compute_shape_with_tf_rules(input_orig, output_spec)?;
980    let mut stack: TVec<AxisOp> = tvec!();
981    'top: loop {
982        let current_input =
983            stack.iter().try_fold(TVec::from(input_orig), |mut shape, op| -> TractResult<_> {
984                op.change_shape_array(&mut shape, false)?;
985                Ok(shape)
986            })?;
987        if current_input == final_output {
988            return Ok(stack);
989        }
990        if let Some(common) =
991            current_input.iter().zip(final_output.iter()).position(|(a, b)| a != b)
992        {
993            if current_input[common].is_one() {
994                stack.push(AxisOp::Rm(common));
995            } else if final_output[common].is_one() {
996                stack.push(AxisOp::Add(common));
997            } else {
998                // actual regrouping. search for a match. this is quadratic, but
999                // rank is expected to be somewhat reasonable
1000                for i in common..current_input.len() {
1001                    let i_group = &current_input[common..i + 1];
1002                    let i_volume: TDim = i_group.iter().product();
1003                    for o in common..final_output.len() {
1004                        let o_group = &final_output[common..o + 1];
1005                        let o_volume: TDim = o_group.iter().product();
1006                        if i_volume == o_volume {
1007                            stack.push(AxisOp::Reshape(common, i_group.into(), o_group.into()));
1008                            continue 'top;
1009                        }
1010                    }
1011                }
1012                todo!()
1013            }
1014        } else if final_output.len() > current_input.len() {
1015            stack.push(AxisOp::Add(current_input.len()));
1016        } else {
1017            stack.push(AxisOp::Rm(current_input.len() - 1));
1018        }
1019    }
1020}
1021
/// A "reinterpret geometry" op: rewrites the shape and strides of its input
/// tensor in place, without moving any data (see `eval` in `impl EvalOp`).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct IntoShape {
    /// Input-to-output axes correspondence, kept for axis tracking and fusion.
    pub mapping: AxesMapping,
    /// Expected element count of the input, checked at eval time.
    pub len: usize,
    /// Target shape.
    pub dims: TVec<usize>,
    /// Target strides (natural strides of `dims` when built by `codegen`).
    pub strides: TVec<isize>,
}
1029
1030impl Op for IntoShape {
1031    fn name(&self) -> StaticName {
1032        "IntoShape".into()
1033    }
1034
1035    fn info(&self) -> TractResult<Vec<String>> {
1036        Ok(vec![format!("{}", self.mapping)])
1037    }
1038
1039    op_as_typed_op!();
1040    impl_op_same_as!();
1041}
1042
impl EvalOp for IntoShape {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Relabels the input tensor's geometry in place: no data is copied or moved.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut input = args_1!(inputs).into_tensor();
        // reinterpretation is only valid when element counts agree
        ensure!(input.len() == self.len);
        // SAFETY: the element count was checked just above; dims and strides
        // were derived from a concrete shape of this very volume at codegen
        // time (natural strides of `dims`).
        unsafe { input.set_geometry_unchecked(&self.dims, &self.strides) };
        Ok(tvec!(input.into_tvalue()))
    }
}
1055
1056impl TypedOp for IntoShape {
1057    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
1058        let mut fact = inputs[0].datum_type.fact(&self.dims);
1059        if let Some(of) = &inputs[0].opaque_fact {
1060            fact = fact.with_opaque_fact(of.clone());
1061        }
1062        Ok(tvec!(fact))
1063    }
1064
1065    fn declutter(
1066        &self,
1067        model: &TypedModel,
1068        node: &TypedNode,
1069    ) -> TractResult<Option<TypedModelPatch>> {
1070        let input = model.outlet_fact(node.inputs[0])?;
1071        if input.shape.as_concrete().is_some_and(|shape| shape == &*self.dims) {
1072            return TypedModelPatch::shunt_one_op(model, node);
1073        }
1074        if let Some(succ) = model.single_succ(node.id)? {
1075            if let Some(into_shape) = succ.op_as::<IntoShape>() {
1076                let op = Self {
1077                    mapping: self.mapping.compose(&into_shape.mapping)?,
1078                    ..into_shape.clone()
1079                };
1080                return Ok(Some(TypedModelPatch::fuse_with_next(model, node, op)?));
1081            }
1082        }
1083        Ok(None)
1084    }
1085
1086    as_op!();
1087}
1088
// Unit tests for the permutation helpers and for AxisOp::merge_incoming_change.
// The ASCII diagrams read: top row is the original op, left column is the
// incoming change, and the bottom/right show the expected rewritten op and the
// change propagated to the output.
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_perm_to_cycles() {
        assert_eq!(perm_to_cycles(&[1, 2, 0]), tvec!(tvec!(0, 1, 2)));
        assert_eq!(perm_to_cycles(&[2, 0, 1]), tvec!(tvec!(0, 2, 1)));
        assert_eq!(perm_to_cycles(&[1, 2, 3, 0]), tvec!(tvec!(0, 1, 2, 3)));
        assert_eq!(perm_to_cycles(&[3, 0, 1, 2]), tvec!(tvec!(0, 3, 2, 1)));
        assert_eq!(perm_to_cycles(&[3, 1, 2, 0, 4]), tvec!(tvec!(0, 3)));
    }

    #[test]
    fn is_rotation() {
        assert_eq!(is_rotation_cycle(&[0, 1, 2]), Some((0, 2)));
        assert_eq!(is_rotation_cycle(&[0, 2, 1]), Some((2, 0)));
    }

    #[test]
    fn test_perm_one_rotation() {
        assert_eq!(perm_to_atoms(&[1, 2, 0, 3, 4]), tvec!((0, 2)));
    }

    #[test]
    fn test_perm_two_rotations() {
        assert_eq!(perm_to_atoms(&[1, 2, 0, 4, 3]), tvec!((0, 2), (3, 4)));
    }

    #[test]
    fn test_perm_complex() {
        assert_eq!(perm_to_atoms(&[3, 1, 2, 0, 4]), tvec!((3, 0), (1, 3)));
    }

    // ADD-ADD

    //                          Op
    //           b,c   ------|Add(0)|----->        n,b,c
    //   Add(0)                                            Add(1)
    //         a,b,c   ------|Add(0)|----->        a,n,b,c
    #[test]
    pub fn transform_op_add_0_add_0() {
        let change = Add(0);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Add(1)))));
    }

    //                          Op
    //           b,c   ------|Add(1)|----->        b,n,c
    //   Add(0)                                                 Add(0)
    //         a,b,c   ------|Add(2)|----->        a,b,n,c
    #[test]
    pub fn transform_op_add_0_add_1() {
        let change = Add(0);
        let op = Add(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(2)), Some(Add(0)))));
    }

    //                          Op
    //           a,c   ------|Add(0)|----->        n,a,c
    //   Add(1)                                                 Add(2)
    //         a,b,c   ------|Add(0)|----->        n,a,b,c
    #[test]
    pub fn transform_op_add_1_add_0() {
        let change = Add(1);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Add(2)))));
    }

    //                          Op
    //         a,b,c   ------|Rm(1)|----->         a,c
    //   Rm(0)                                             Rm(0)
    //           b,c   ------|Rm(0)|----->         c
    #[test]
    pub fn transform_op_rm_0_rm_1() {
        let change = Rm(0);
        let op = Rm(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Rm(0)))));
    }

    //                          Op
    //         a,b,c   ------|Rm(0)|----->         b,c
    //   Rm(1)                                             Rm(0)
    //           a,c   ------|Rm(0)|----->         c
    #[test]
    pub fn transform_op_rm_1_rm_0() {
        let change = Rm(1);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Rm(0)))));
    }

    // ADD - RM

    //                          Op
    //          b,c     ------|Rm(0)|------>        c
    //   Add(0)                                                 Add(0)
    //          a,b,c   ------|Rm(1)|----->         a,c
    #[test]
    pub fn transform_op_add_0_rm_0() {
        let change = Add(0);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(1)), Some(Add(0)))));
    }

    //                          Op
    //          b,c     ------|Rm(1)|------>        b
    //   Add(0)                                                 Add(0)
    //          a,b,c   ------|Rm(2)|----->         a,b
    #[test]
    pub fn transform_op_add_0_rm_1() {
        let change = Add(0);
        let op = Rm(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(2)), Some(Add(0)))));
    }

    //                          Op
    //          a,c     ------|Rm(0)|------>        c
    //   Add(1)                                                 Add(0)
    //          a,b,c   ------|Rm(0)|----->         b,c
    #[test]
    pub fn transform_op_add_1_rm_0() {
        let change = Add(1);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Add(0)))));
    }

    // RM - ADD

    //                          Op
    //         a,b,c   ------|Add(0)|----->        X,a,b,c
    //   Rm(1)                                                 Rm(2)
    //           a,c   ------|Add(0)|----->        X,a,c
    #[test]
    pub fn transform_op_rm_1_add_0() {
        let change = Rm(1);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Rm(2)))));
    }

    //                          Op
    //         a,b,c   ------|Add(1)|----->        a,X,b,c
    //   Rm(0)                                                 Rm(0)
    //           b,c   ------|Add(0)|----->        X,b,c
    #[test]
    pub fn transform_op_rm_0_add_1() {
        let change = Rm(0);
        let op = Add(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Rm(0)))));
    }

    //                          Op
    //         a,b,c   ------|Rm(2)|----->        a,b
    //   Move(0, 2)                                           Move(0,1)
    //         b,c,a   ------|Rm(1)|----->        b,a
    #[test]
    pub fn transform_op_mv_02_rm_2() {
        let change = Move(0, 2);
        let op = Rm(2);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(1)), Some(Move(0, 1)))));
    }
}
1250
// Property tests: random chains of axis ops are built on random shapes, then
// the decluttered model is checked against the raw one on a distinguishable
// input. Named regression tests below pin down past proptest failures.
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// A random chain of axis ops applied to a random concrete input shape.
    #[derive(Debug)]
    struct ComposeProblem {
        input: TVec<usize>,
        ops: TVec<AxisOp>,
    }

    impl Arbitrary for AxisOp {
        type Parameters = TVec<usize>;
        type Strategy = BoxedStrategy<AxisOp>;
        /// Draws an AxisOp that is valid for the given input shape.
        fn arbitrary_with(shape: TVec<usize>) -> Self::Strategy {
            // Add is always legal, at any position including one past the end.
            let mut ops: BoxedStrategy<AxisOp> = (0usize..shape.len() + 1).prop_map(Add).boxed();
            if shape.len() > 1 {
                ops = ops
                    .prop_union(
                        // b is drawn in a reduced range then shifted past a, so a != b
                        (0..shape.len(), 0..shape.len() - 1)
                            .prop_map(|(a, b)| Move(a, b + (b >= a) as usize))
                            .boxed(),
                    )
                    .boxed()
            }
            // Rm is only legal on axes of dimension 1.
            let rms = (0..shape.len()).filter(|&ax| shape[ax] == 1).map(Rm).collect::<Vec<_>>();
            if !rms.is_empty() {
                ops = ops
                    .prop_union((0..rms.len()).prop_map(move |rm| rms[rm].clone()).boxed())
                    .boxed()
            }
            // Reshape candidates: merge two adjacent non-trivial axes.
            let mergeable: Vec<AxisOp> = shape
                .windows(2)
                .enumerate()
                .filter(|(_, w)| w[0] > 1 && w[1] > 1)
                .map(|(ix, w)| {
                    Reshape(ix, tvec!(w[0].to_dim(), w[1].to_dim()), tvec!((w[0] * w[1]).to_dim()))
                })
                .collect();
            // Fix: this guard used to be `mergeable.len() > 1`, which silently
            // dropped the Reshape strategy whenever there was exactly one
            // candidate (the sibling `rms` guard above tests for non-empty).
            if !mergeable.is_empty() {
                ops = ops
                    .prop_union(
                        (0..mergeable.len()).prop_map(move |ix| mergeable[ix].clone()).boxed(),
                    )
                    .boxed()
            }
            ops
        }
    }

    impl Arbitrary for ComposeProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<ComposeProblem>;
        fn arbitrary_with(_args: ()) -> Self::Strategy {
            let input = proptest::collection::vec(1usize..4, 1usize..4);
            // Recursively draws `len` ops, threading the evolving shape so that
            // each op is valid at the point where it applies.
            fn tail(len: usize, shape: TVec<usize>) -> BoxedStrategy<TVec<AxisOp>> {
                if len == 0 {
                    Just(tvec!()).boxed()
                } else {
                    AxisOp::arbitrary_with(shape.clone())
                        .prop_flat_map(move |op| {
                            let mut shape = shape.clone();
                            op.change_shape_array(&mut shape, false).unwrap();
                            tail(len - 1, shape.clone()).prop_map(move |mut t| {
                                t.insert(0, op.clone());
                                t
                            })
                        })
                        .boxed()
                }
            }
            (input, 1usize..=5)
                .prop_flat_map(|(input, len)| (Just(input.clone()), tail(len, input.into())))
                .prop_map(|(input, ops)| ComposeProblem { input: input.into(), ops })
                .boxed()
        }
    }

    impl ComposeProblem {
        /// Builds a single-input model chaining all the ops in order.
        pub fn model(&self) -> TractResult<TypedModel> {
            let mut model = TypedModel::default();
            let mut wire = model.add_source("source", i64::fact(&self.input))?;
            for (ix, op) in self.ops.iter().enumerate() {
                wire = model.wire_node(format!("op_{ix}"), op.clone(), &[wire])?[0];
            }
            model.set_output_outlets(&[wire])?;
            Ok(model)
        }

        /// An i64 tensor filled with 0..n so any data misplacement is observable.
        fn input(&self) -> TractResult<Tensor> {
            unsafe {
                let mut t = Tensor::uninitialized::<i64>(&self.input)?;
                for i in 0..t.len() {
                    t.as_slice_mut().unwrap()[i] = i as i64;
                }
                Ok(t)
            }
        }

        /// Runs the raw and the decluttered models and compares their outputs.
        fn check(&self) -> TractResult<()> {
            crate::setup_test_logger();
            let input = self.input()?;
            let model = self.model()?;
            let raw = model.into_runnable()?.run(tvec!(input.clone().into_tvalue()))?;
            let optimized = self.model()?.into_decluttered()?;
            let opt = optimized.into_runnable()?.run(tvec!(input.into_tvalue()))?;
            opt[0].close_enough(&raw[0], false)
        }
    }

    proptest! {
        #[test]
        fn recip(pb in any::<AxisOp>()) {
            assert_eq!(pb.recip().recip(), pb);
        }

        #[test]
        fn axis_ops(pb in any::<ComposeProblem>()) {
            pb.check().unwrap()
        }
    }

    #[test]
    fn add_0_rm_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Rm(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_01() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_01_add_1() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1), Add(1)] };
        pb.check().unwrap();
    }

    #[test]
    fn recip_move_01() {
        let op = Move(1, 0);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn recip_move_20() {
        let op = Move(2, 0);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn recip_move_02() {
        let op = Move(0, 2);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn add_0_add_1_move_02() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(1), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_02() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(0), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_2_move_12() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(2), Move(1, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_02_rm_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0), Move(0, 2), Rm(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_20_move_20() {
        let pb =
            ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(0), Move(2, 0), Move(2, 0)] };
        pb.check().unwrap();
    }

    #[test]
    fn move_01_add_0() {
        let pb = ComposeProblem { input: tvec![1, 1], ops: tvec![Move(0, 1), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_02_move_02() {
        let pb = ComposeProblem { input: tvec![1, 1], ops: tvec![Add(0), Move(0, 2), Move(0, 2),] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_2_move_20_move_12_rm_2() {
        let pb = ComposeProblem {
            input: tvec![3],
            ops: tvec![Add(0), Add(2), Move(2, 0), Move(1, 2), Rm(2)],
        };
        pb.check().unwrap();
    }

    #[test]
    fn move_02_move_02() {
        let pb = ComposeProblem { input: tvec![2, 1, 1], ops: tvec![Move(0, 2), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn rm_1_perm_10_add_0() {
        let pb = ComposeProblem { input: tvec![1, 1, 2], ops: tvec![Rm(1), Move(0, 1), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_2_move_02_move_02() {
        let pb = ComposeProblem { input: tvec![3, 2], ops: tvec![Add(2), Move(0, 2), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn move_01_move_20_move_20() {
        let pb = ComposeProblem {
            input: tvec![2, 3, 2],
            ops: tvec![Move(0, 1), Move(2, 0), Move(2, 0)],
        };
        pb.check().unwrap();
    }

    #[test]
    fn reshape_axes_tracking() {
        let pb = ComposeProblem {
            input: tvec![2, 2, 2],
            ops: tvec![Reshape(0, tvec!(2.to_dim(), 2.to_dim()), tvec!(4.to_dim()))],
        };
        pb.check().unwrap();
    }

    #[test]
    fn simplify_reshape() {
        macro_rules! d {
            ($($dim: expr),*) =>  { tvec!($($dim.to_dim()),*) }
        }
        assert_eq!(Reshape(3, d!(), d!()).simplify(), tvec!());
        assert_eq!(Reshape(3, d!(2, 3), d!(2, 3)).simplify(), tvec!());
        assert_eq!(Reshape(3, d!(1), d!()).simplify(), tvec!(Rm(3)));
        assert_eq!(Reshape(3, d!(), d!(1)).simplify(), tvec!(Add(3)));
        assert_eq!(
            Reshape(3, d!(2, 3, 4), d!(2, 4, 3)).simplify(),
            tvec!(Reshape(4, d!(3, 4), d!(4, 3)))
        );
        assert_eq!(
            Reshape(3, d!(3, 4, 2), d!(4, 3, 2)).simplify(),
            tvec!(Reshape(3, d!(3, 4), d!(4, 3)))
        );
        assert_eq!(
            Reshape(3, d!(1, 2, 3), d!(3, 2)).simplify(),
            tvec!(Rm(3), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(3, d!(2, 3), d!(1, 3, 2)).simplify(),
            tvec!(Reshape(3, d!(2, 3), d!(3, 2)), Add(3))
        );
        assert_eq!(
            Reshape(3, d!(2, 3, 1), d!(3, 2)).simplify(),
            tvec!(Rm(5), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(3, d!(2, 3), d!(3, 2, 1)).simplify(),
            tvec!(Add(5), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(2, d!(2, 2, 1), d!(4)).simplify(),
            tvec!(Rm(4), Reshape(2, d!(2, 2), d!(4)))
        );
        assert_eq!(Reshape(1, d!(1, 2), d!(2)).simplify(), tvec!(Rm(1)));
    }

    // shorthand for a &[TDim] shape literal
    macro_rules! s {
        ($($a:expr),*) => {&[ $($a.clone().into()),* ]}
    }

    // shorthand for a Reshape op literal
    macro_rules! r {
        ($at: expr ; $($from:expr),* => $($to:expr),*) => {
            AxisOp::Reshape($at, tvec!($($from.into()),*),  tvec!($($to.into()),*))
        }
    }

    #[test]
    fn compute_invalid() {
        assert!(compute_shape_with_tf_rules(s![3, 4, 5], s!(100)).is_err());
    }

    #[test]
    fn compute_with_leading_zero() {
        assert_eq!(&*compute_shape_with_tf_rules(s![3, 4, 5], s!(0, 0, 5)).unwrap(), s![3, 4, 5])
    }

    #[test]
    fn compute_with_leading_zero_with_flatten() {
        assert_eq!(
            &*compute_shape_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap(),
            s![2, 3, 35]
        )
    }

    #[test]
    fn compute_with_trailing_zero() {
        assert_eq!(&*compute_shape_with_tf_rules(s![3, 4, 5], s!(3, -1, 0)).unwrap(), s![3, 4, 5])
    }

    #[test]
    fn compute_bug_1() {
        let table = SymbolScope::default();
        let s = table.new_with_prefix("S");
        assert_eq!(
            &*compute_shape_with_tf_rules(s![s, 1, 2, 128], s!(0, 0, -1)).unwrap(),
            s![s, 1, 256]
        )
    }

    #[test]
    fn compute_bug_2() {
        let table = SymbolScope::default();
        let b = table.new_with_prefix("B");
        let s = table.new_with_prefix("S");
        assert_eq!(
            &*compute_shape_with_tf_rules(s![s, b, 2, 128], s!(0, 0, -1)).unwrap(),
            s![s, b, 256]
        )
    }

    #[test]
    fn axis_op_rm_begin() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![1, 2, 3], s!(2, 3)).unwrap(), &[Rm(0)])
    }

    #[test]
    fn axis_op_rm_end() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3, 1], s!(2, 3)).unwrap(), &[Rm(2)])
    }

    #[test]
    fn axis_op_insert_begin() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3], s!(1, 2, 3)).unwrap(), &[Add(0)])
    }

    #[test]
    fn axis_op_insert_end() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3], s!(2, 3, 1)).unwrap(), &[Add(2)])
    }

    #[test]
    fn axis_op_merge() {
        assert_eq!(
            &*to_axis_ops_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap(),
            &[r!(2 ; 5,7 => 35 )]
        )
    }

    #[test]
    fn axis_op_complex() {
        assert_eq!(
            &*to_axis_ops_with_tf_rules(s![1, 2, 3, 5, 7], s!(2, 1, 3, 35, 1)).unwrap(),
            &[Rm(0), Add(1), r!(3 ; 5,7 => 35 ), Add(4)]
        )
    }
}