Skip to main content

tract_core/ops/
change_axes.rs

1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::model::{TypedModel, TypedNode};
6use crate::ops::identity::Identity;
7use AxisOp::*;
8use num_traits::One;
9use tract_itertools::Itertools;
10use tract_linalg::block_quant::{BlockQuantFact, BlockQuantStorage};
11use tract_ndarray::{ArrayViewD, ArrayViewMutD};
12
/// Designates one slot of a node's interface: either an input or an output,
/// identified by its index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InOut {
    /// Output slot with the given index.
    Out(usize),
    /// Input slot with the given index.
    In(usize),
}
18
19impl InOut {
20    pub fn as_outlet<F: Clone + Fact, O: Clone>(&self, node: &Node<F, O>) -> OutletId {
21        match self {
22            InOut::In(ix) => node.inputs[*ix],
23            InOut::Out(ix) => OutletId::new(node.id, *ix),
24        }
25    }
26
27    pub fn is_input(&self) -> bool {
28        matches!(self, InOut::In(_))
29    }
30
31    pub fn is_output(&self) -> bool {
32        matches!(self, InOut::Out(_))
33    }
34
35    pub fn slot(&self) -> usize {
36        match self {
37            InOut::Out(o) => *o,
38            InOut::In(i) => *i,
39        }
40    }
41}
42
43#[derive(Clone, Hash, Eq)]
44#[allow(clippy::large_enum_variant)] // FIXME ?
45#[allow(clippy::derived_hash_with_manual_eq)] // FIXME. this one may be pretty bad. how about a.canonical() == b.canonical() ? need proper canonicalizeation of Reshape
46pub enum AxisOp {
47    Add(usize),
48    Rm(usize),
49    Move(usize, usize),
50    Reshape(usize, TVec<TDim>, TVec<TDim>),
51}
52
53impl Debug for AxisOp {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        match self {
56            AxisOp::Add(a) => write!(f, "Add({a})"),
57            AxisOp::Rm(a) => write!(f, "Rm({a})"),
58            AxisOp::Move(from, to) => write!(f, "Move({from},{to})"),
59            AxisOp::Reshape(at, from, to) => {
60                write!(f, "Reshape({at}, [{}], [{}])", from.iter().join(","), to.iter().join(","))
61            }
62        }
63    }
64}
65
66impl PartialEq for AxisOp {
67    fn eq(&self, other: &AxisOp) -> bool {
68        if self.is_noop() && other.is_noop() {
69            true
70        } else if self.is_noop() != other.is_noop() {
71            false
72        } else {
73            match (self, other) {
74                (Add(a), Add(b)) | (Rm(a), Rm(b)) => a == b,
75                (Move(f1, t1), Move(f2, t2)) => {
76                    (f1 == f2 && t1 == t2)
77                        || ((*t1 == f1 + 1 || *f1 == t1 + 1) && t2 == f1 && t1 == f2)
78                }
79                (Reshape(at1, f1, t1), Reshape(at2, f2, t2)) => at1 == at2 && f1 == f2 && t1 == t2,
80                _ => false,
81            }
82        }
83    }
84}
85
86impl AxisOp {
87    pub fn canonical(&self) -> Cow<'_, AxisOp> {
88        match self {
89            Move(from, to) if *from == to + 1 => Cow::Owned(Move(*to, *from)),
90            Reshape(at, from, to)
91                if from.len() == 1 && to.len() == 2 && from[0] == to[0] && to[1].is_one() =>
92            {
93                Cow::Owned(Add(*at + 1))
94            }
95            Reshape(at, from, to)
96                if from.len() == 1 && to.len() == 2 && from[0] == to[1] && to[0].is_one() =>
97            {
98                Cow::Owned(Add(*at))
99            }
100            Reshape(at, from, to)
101                if from.len() == 2 && to.len() == 1 && from[0] == to[0] && from[1].is_one() =>
102            {
103                Cow::Owned(Rm(*at + 1))
104            }
105            Reshape(at, from, to)
106                if from.len() == 2 && to.len() == 1 && from[1] == to[0] && from[0].is_one() =>
107            {
108                Cow::Owned(Rm(*at))
109            }
110            other => Cow::Borrowed(other),
111        }
112    }
113
114    pub fn simplify(&self) -> TVec<AxisOp> {
115        match self.canonical().borrow() {
116            Reshape(_, from, to) if from == to => tvec!(),
117            Reshape(at, from, to) if to.len() == 0 => tvec!(Rm(*at); from.len()),
118            Reshape(at, from, to) if from.len() == 0 => tvec!(Add(*at); to.len()),
119            Reshape(at, from, to) if from[0] == to[0] => {
120                Reshape(at + 1, from[1..].into(), to[1..].into()).simplify()
121            }
122            Reshape(at, from, to) if from[from.len() - 1] == to[to.len() - 1] => {
123                Reshape(*at, from[..from.len() - 1].into(), to[..to.len() - 1].into()).simplify()
124            }
125            Reshape(at, from, to) if from[0] == 1.to_dim() => std::iter::once(Rm(*at))
126                .chain(Reshape(*at, from[1..].into(), to.clone()).simplify())
127                .collect(),
128            Reshape(at, from, to) if to[0] == 1.to_dim() => {
129                Reshape(*at, from.clone(), to[1..].into())
130                    .simplify()
131                    .into_iter()
132                    .chain(std::iter::once(Add(*at)))
133                    .collect()
134            }
135            Reshape(at, from, to) if from[from.len() - 1] == 1.to_dim() => {
136                std::iter::once(Rm(at + from.len() - 1))
137                    .chain(Reshape(*at, from[..from.len() - 1].into(), to.clone()).simplify())
138                    .collect()
139            }
140            Reshape(at, from, to) if to[to.len() - 1] == 1.to_dim() => {
141                std::iter::once(Add(at + from.len()))
142                    .chain(Reshape(*at, from.clone(), to[..to.len() - 1].into()).simplify())
143                    .collect()
144            }
145            other => tvec!(other.clone()),
146        }
147    }
148
149    pub fn transform_axis(&self, axis: usize) -> Option<usize> {
150        match self.canonical().as_ref() {
151            Add(ix) => Some(axis + (axis >= *ix) as usize),
152            Rm(ix) => {
153                if axis == *ix {
154                    None
155                } else {
156                    Some(axis - (axis > *ix) as usize)
157                }
158            }
159            Move(from, to) if from < to => {
160                if axis < *from || axis > *to {
161                    Some(axis)
162                } else if axis == *from {
163                    Some(*to)
164                } else {
165                    Some(axis - 1)
166                }
167            }
168            Move(from, to) => {
169                if axis < *to || axis > *from {
170                    Some(axis)
171                } else if axis == *from {
172                    Some(*to)
173                } else {
174                    Some(axis + 1)
175                }
176            }
177            Reshape(at, _, _) if axis < *at => Some(axis),
178            Reshape(at, from, to) if axis >= at + from.len() => Some(axis + to.len() - from.len()),
179            Reshape(_, _, _) => None,
180        }
181    }
182
183    // if sucessful return Some()
184    // first item is the Op we want to be replaced by. if none, we are now identity.
185    // second item is the change to propagate. if none, the output is not
186    // changed
187    pub fn merge_incoming_change(
188        &self,
189        change: &AxisOp,
190    ) -> Option<(Option<AxisOp>, Option<AxisOp>)> {
191        match (self.canonical().as_ref(), change.canonical().as_ref()) {
192            (Add(op), Add(c)) => {
193                Some((Some(Add(op + (c < op) as usize)), Some(Add(c + (c >= op) as usize))))
194            }
195            (Add(op), Rm(c)) => {
196                Some((Some(Add(op - (c < op) as usize)), Some(Rm(c + (c >= op) as usize))))
197            }
198            (Rm(op), Add(c)) => {
199                Some((Some(Rm(op + (c <= op) as usize)), Some(Add(c - (op < c) as usize))))
200            }
201            (Rm(op), Rm(c)) => {
202                Some((Some(Rm(op - (c < op) as usize)), Some(Rm(c - (op <= c) as usize))))
203            }
204
205            (Add(x), Move(from, to)) => {
206                if x <= from.min(to) {
207                    Some((Some(self.clone()), Some(Move(from + 1, to + 1))))
208                } else if x > from.max(to) {
209                    Some((Some(self.clone()), Some(change.clone())))
210                } else {
211                    None
212                }
213            }
214
215            (Move(from, to), Add(x)) => {
216                if x <= from.min(to) {
217                    Some((Some(Move(from + 1, to + 1)), Some(Add(*x))))
218                } else if x > from.max(to) {
219                    Some((Some(Move(*from, *to)), Some(Add(*x))))
220                } else {
221                    None
222                }
223            }
224
225            (Rm(x), Move(from, to)) => {
226                if x == from {
227                    Some((Some(Rm(*to)), None))
228                } else if x < from.min(to) {
229                    Some((Some(self.clone()), Some(Move(from - 1, to - 1))))
230                } else if x > from.max(to) {
231                    Some((Some(self.clone()), Some(change.clone())))
232                } else if from + 1 == *to && x == to {
233                    Some((Some(Rm(*from)), None))
234                } else if from < to && x <= to {
235                    Some((Some(Rm(x - 1)), Some(Move(*from, *to - 1))))
236                } else {
237                    Some((Some(Rm(x + 1)), Some(Move(*from - 1, *to))))
238                }
239            }
240
241            (Move(from, to), Rm(x)) => {
242                if x < from.min(to) {
243                    Some((Some(Move(from - 1, to - 1)), Some(Rm(*x))))
244                } else if x > from.max(to) {
245                    Some((Some(Move(*from, *to)), Some(Rm(*x))))
246                } else {
247                    None
248                }
249            }
250
251            (Add(op), Reshape(at, from, to)) => {
252                if op <= at {
253                    Some((Some(Add(*op)), Some(Reshape(at + 1, from.clone(), to.clone()))))
254                } else if *op > at + from.len() {
255                    Some((
256                        Some(Add(*op + to.len() - from.len())),
257                        Some(Reshape(*at, from.clone(), to.clone())),
258                    ))
259                } else {
260                    None
261                }
262            }
263            (Rm(op), Reshape(at, from, to)) => {
264                if op < at {
265                    Some((Some(Rm(*op)), Some(Reshape(at - 1, from.clone(), to.clone()))))
266                } else if *op > at + from.len() {
267                    Some((
268                        Some(Rm(*op + to.len() - from.len())),
269                        Some(Reshape(*at, from.clone(), to.clone())),
270                    ))
271                } else {
272                    None
273                }
274            }
275            (Reshape(at, from, to), Add(change)) => {
276                if change < at {
277                    Some((Some(Reshape(at + 1, from.clone(), to.clone())), Some(Add(*change))))
278                } else if *change > *at + from.len() {
279                    Some((
280                        Some(Reshape(*at, from.clone(), to.clone())),
281                        Some(Add(change + to.len() - from.len())),
282                    ))
283                } else {
284                    None
285                }
286            }
287            (Reshape(at, from, to), Rm(change)) => {
288                if change < at {
289                    Some((Some(Reshape(at - 1, from.clone(), to.clone())), Some(Rm(*change))))
290                } else if *change > *at + from.len() {
291                    Some((
292                        Some(Reshape(*at, from.clone(), to.clone())),
293                        Some(Rm(change + to.len() - from.len())),
294                    ))
295                } else {
296                    None
297                }
298            }
299            (Reshape(_, _, _), Move(_, _)) => None, // todo, some are manageable
300            (Move(_, _), Reshape(_, _, _)) => None, // todo, some are manageable
301            (Reshape(_, _, _), Reshape(_, _, _)) => None, // todo, some are manageable
302            _ => None,
303        }
304    }
305
306    pub fn change_shape_array<D: DimLike>(
307        &self,
308        shape: &mut TVec<D>,
309        broadcasting: bool,
310    ) -> TractResult<()> {
311        match self.canonical().as_ref() {
312            Add(ix) => {
313                ensure!(*ix <= shape.len());
314                shape.insert(*ix, D::one());
315            }
316            Rm(ix) => {
317                ensure!(*ix < shape.len());
318                shape.remove(*ix);
319            }
320            Move(from, to) => {
321                ensure!(*from < shape.len());
322                ensure!(*to < shape.len());
323                let axis = shape.remove(*from);
324                shape.insert(*to, axis);
325            }
326            Reshape(at, from, to) => {
327                let from_volume = from.iter().product::<TDim>();
328                let to_volume = to.iter().product::<TDim>();
329                ensure!(from_volume == to_volume, "{from_volume} should be equal to {to_volume}");
330                ensure!(*at + from.len() <= shape.len());
331                if shape.len() >= from.len() + *at
332                    && tract_itertools::izip!(shape.iter().skip(*at), from)
333                        .all(|(shape, spec)| shape.to_dim() == *spec)
334                {
335                    for _ in from {
336                        shape.remove(*at);
337                    }
338                    for d in to.iter().rev() {
339                        shape.insert(*at, d.try_into()?);
340                    }
341                } else if broadcasting
342                    && shape.iter().skip(*at).take(from.len()).all(|d| d.to_dim() == 1.to_dim())
343                {
344                    for _ in from {
345                        shape.remove(*at);
346                    }
347                    for _ in to.iter().rev() {
348                        shape.insert(*at, 1.into());
349                    }
350                } else {
351                    bail!("Incompatible reshape for shape {:?} and {:?}", shape, self);
352                }
353            }
354        }
355        Ok(())
356    }
357
358    pub fn change_shape(&self, shape: &mut ShapeFact, broadcasting: bool) -> TractResult<()> {
359        match self.canonical().as_ref() {
360            Add(ix) => shape.insert_axis(*ix),
361            Rm(ix) => {
362                if shape.rank() <= *ix {
363                    bail!("Attempt to remove axis #{} on shape {:?}", ix, shape);
364                }
365                if shape[*ix] != 1.to_dim() {
366                    bail!("Removing non-trivial axis #{} of dim: {:?}", ix, shape);
367                }
368                shape.remove_axis(*ix)
369            }
370            _ => {
371                let mut array = shape.to_tvec();
372                self.change_shape_array(&mut array, broadcasting)?;
373                let mut new_shape = ShapeFact::from_dims(array);
374                std::mem::swap(shape, &mut new_shape);
375                Ok(())
376            }
377        }
378    }
379
380    pub fn change_tensor(&self, tensor: &mut Tensor, broadcasting: bool) -> TractResult<()> {
381        if tensor.storage_as::<BlockQuantStorage>().is_some() {
382            let bqs = tensor.try_storage_as::<BlockQuantStorage>()?.clone();
383            let mut new_shape: TVec<usize> = tensor.shape().into();
384            self.change_shape_array(&mut new_shape, false)?;
385            let mut new_tensor = bqs.into_tensor_with_shape(tensor.datum_type(), &new_shape);
386            std::mem::swap(tensor, &mut new_tensor);
387            return Ok(());
388        }
389        ensure!(self.required_rank() <= tensor.rank());
390        match self.canonical().as_ref() {
391            Add(ix) => tensor.insert_axis(*ix),
392            Rm(ix) => tensor.remove_axis(*ix),
393            Move(from, to) => {
394                let mut tmp = tensor.clone().move_axis(*from, *to)?;
395                std::mem::swap(tensor, &mut tmp);
396                Ok(())
397            }
398            Reshape(at, from, to) => {
399                let mut shape: TVec<usize> = tensor.shape().into();
400                self.change_shape_array(&mut shape, true)?;
401                if tensor.set_shape(&shape).is_ok() {
402                    Ok(())
403                } else if broadcasting
404                    && tensor.shape().iter().skip(*at).take(from.len()).all(|d| *d == 1)
405                {
406                    if from.len() > to.len() {
407                        for _ in to.len()..from.len() {
408                            tensor.remove_axis(*at)?;
409                        }
410                    }
411                    if to.len() > from.len() {
412                        for _ in from.len()..to.len() {
413                            tensor.insert_axis(*at)?;
414                        }
415                    }
416                    Ok(())
417                } else {
418                    bail!(
419                        "Invalid reshaping: {:?} on tensor {:?} (broadcasting allowed: {:?})",
420                        self,
421                        tensor,
422                        broadcasting
423                    )
424                }
425            }
426        }
427    }
428
429    pub fn change_view<D>(&self, view: &mut ArrayViewD<D>) -> TractResult<()> {
430        use tract_ndarray::Axis;
431        match *self {
432            AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
433            AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
434            AxisOp::Move(from, to) if from < to => {
435                for left in from..to {
436                    view.swap_axes(left, left + 1);
437                }
438            }
439            AxisOp::Move(from, to) => {
440                for left in (to..from).rev() {
441                    view.swap_axes(left, left + 1);
442                }
443            }
444            AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
445        }
446        Ok(())
447    }
448
449    pub fn change_view_mut<D>(&self, view: &mut ArrayViewMutD<D>) -> TractResult<()> {
450        use tract_ndarray::Axis;
451        match *self {
452            AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
453            AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
454            AxisOp::Move(from, to) if from < to => {
455                for left in from..to {
456                    view.swap_axes(left, left + 1);
457                }
458            }
459            AxisOp::Move(from, to) => {
460                for left in (to..from).rev() {
461                    view.swap_axes(left, left + 1);
462                }
463            }
464            AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
465        }
466        Ok(())
467    }
468
469    pub fn recip(&self) -> AxisOp {
470        match self.canonical().as_ref() {
471            Add(ix) => Rm(*ix),
472            Rm(ix) => Add(*ix),
473            Move(from, to) if from == to => self.clone(),
474            Move(from, to) if *from + 1 == *to => self.clone(),
475            Move(from, to) if *from == *to + 1 => {
476                unreachable!();
477            }
478            Move(from, to) => Move(*to, *from),
479            Reshape(at, from, to) => Reshape(*at, to.clone(), from.clone()),
480        }
481    }
482
483    pub fn is_noop(&self) -> bool {
484        match self {
485            Move(f, t) if f == t => true,
486            Reshape(_, f, t) if f == t => true,
487            _ => false,
488        }
489    }
490
491    pub fn only_shape(&self) -> bool {
492        if self.is_noop() {
493            return true;
494        }
495        !matches!(self, Move(_, _))
496    }
497
498    pub fn wire_split_axis(
499        model: &mut TypedModel,
500        name: impl ToString,
501        outlet: OutletId,
502        axis: usize,
503        outer_dim: usize,
504    ) -> TractResult<TVec<OutletId>> {
505        let fact = model.outlet_fact(outlet)?;
506        let dim: TDim = fact.shape[axis].clone();
507        let inner_dim = dim.clone() / outer_dim;
508        let op = Self::Reshape(axis, tvec!(dim.clone()), tvec!(outer_dim.to_dim(), inner_dim));
509        model.wire_node(name.to_string(), op, &[outlet])
510    }
511
512    pub fn wire_collapse_axis(
513        model: &mut TypedModel,
514        name: impl ToString,
515        outlet: OutletId,
516        axis: usize,
517    ) -> TractResult<TVec<OutletId>> {
518        let fact = model.outlet_fact(outlet)?;
519        let dim: TDim = fact.shape[axis].clone();
520        let next_dim: TDim = fact.shape[axis + 1].clone();
521        let op = Self::Reshape(axis, tvec!(dim.clone(), next_dim.clone()), tvec!(dim * next_dim));
522        model.wire_node(name.to_string(), op, &[outlet])
523    }
524
525    #[inline]
526    pub fn required_rank(&self) -> usize {
527        match self {
528            Rm(r) => r + 1,
529            Add(a) => *a,
530            Reshape(at, from, _to) => at + from.len(),
531            Move(from, to) => *from.max(to),
532        }
533    }
534
535    pub fn trim_left(&self, prefix: usize) -> TractResult<AxisOp> {
536        Ok(match self {
537            Rm(r) if *r >= prefix => Rm(r - prefix),
538            Add(a) if *a >= prefix => Add(a - prefix),
539            Reshape(at, from, to) if *at >= prefix => {
540                Reshape(at - prefix, from.clone(), to.clone())
541            }
542            Move(from, to) if *from >= prefix && *to >= prefix => Move(from - prefix, to - prefix),
543            _ => bail!("Can no trim left {self:?} by {prefix}"),
544        })
545    }
546}
547
548pub fn wire_rank_broadcast(
549    prefix: impl AsRef<str>,
550    target: &mut TypedModel,
551    inputs: &[OutletId],
552) -> TractResult<TVec<OutletId>> {
553    let facts =
554        inputs.iter().map(|o| target.outlet_fact(*o).cloned()).collect::<TractResult<TVec<_>>>()?;
555    let max_rank = facts.iter().map(|f| f.rank()).max().unwrap();
556    let mut wires = tvec!();
557    for i in 0..inputs.len() {
558        let mut wire = inputs[i];
559        for _ in facts[i].rank()..max_rank {
560            let name = target.unique_name(prefix.as_ref().to_string() + ".fix-rank");
561            wire = target.wire_node(name, AxisOp::Add(0), &[wire])?[0];
562        }
563        wires.push(wire);
564    }
565    Ok(wires)
566}
567
568pub fn wire_with_rank_broadcast(
569    prefix: impl AsRef<str>,
570    target: &mut TypedModel,
571    op: impl Into<Box<dyn TypedOp>>,
572    inputs: &[OutletId],
573) -> TractResult<TVec<OutletId>> {
574    let prefix = prefix.as_ref();
575    let wires = wire_rank_broadcast(prefix, target, inputs)?;
576    target.wire_node(prefix, op.into(), &wires)
577}
578
579#[derive(Clone, Debug, PartialEq, Eq, Hash)]
580pub struct AxisChange {
581    pub outlet: OutletId,
582    pub op: AxisOp,
583}
584
585#[derive(Clone, Default, Debug)]
586pub struct AxisChangeConsequence {
587    pub substitute_op: Option<Box<dyn TypedOp>>,
588    pub wire_changes: TVec<(InOut, AxisOp)>,
589}
590
591impl AxisChangeConsequence {
592    pub fn new(
593        _model: &TypedModel,
594        node: &TypedNode,
595        op: Option<Box<dyn TypedOp>>,
596        axis_op: &AxisOp,
597    ) -> AxisChangeConsequence {
598        let mut wire_changes = tvec!();
599        for i in 0..node.inputs.len() {
600            wire_changes.push((InOut::In(i), axis_op.clone()));
601        }
602        for i in 0..node.outputs.len() {
603            wire_changes.push((InOut::Out(i), axis_op.clone()));
604        }
605        AxisChangeConsequence { wire_changes, substitute_op: op }
606    }
607}
608
609impl Op for AxisOp {
610    fn name(&self) -> StaticName {
611        match self {
612            Add(_) => "AddAxis".into(),
613            Rm(_) => "RmAxis".into(),
614            Move(_, _) => "MoveAxis".into(),
615            Reshape(_, _, _) => "Reshape".into(),
616        }
617    }
618
619    fn info(&self) -> TractResult<Vec<String>> {
620        match self {
621            Add(axis) | Rm(axis) => Ok(vec![format!("Axis: {axis}")]),
622            Move(from, to) => Ok(vec![format!("Axis {from} to {to}")]),
623            Reshape(at, from, to) => Ok(vec![format!(
624                "Axes starting at {}: {:?} to {:?}",
625                at,
626                from.iter().join(","),
627                to.iter().join(",")
628            )]),
629        }
630    }
631
632    op_as_typed_op!();
633}
634
635impl EvalOp for AxisOp {
636    fn is_stateless(&self) -> bool {
637        true
638    }
639
640    fn eval_with_session(
641        &self,
642        _node_id: usize,
643        session: &TurnState,
644        inputs: TVec<TValue>,
645    ) -> TractResult<TVec<TValue>> {
646        let mut input = args_1!(inputs).into_tensor();
647        match self {
648            AxisOp::Reshape(skip, from, to) => {
649                let from = from.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
650                let to = to.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
651                AxisOp::Reshape(*skip, from, to).change_tensor(&mut input, false)?
652            }
653            _ => self.change_tensor(&mut input, false)?,
654        }
655        Ok(tvec!(input.into_tvalue()))
656    }
657}
658
659/// Remap coordinate symbols in a TDim expression according to an AxisOp.
660/// Returns None if the remapping cannot be determined (e.g. general reshape).
661fn remap_uniform_tdim(expr: &TDim, axis_op: &AxisOp) -> Option<TDim> {
662    let syms = expr.symbols();
663    let coord_syms: Vec<(usize, Symbol)> = syms
664        .into_iter()
665        .filter_map(|s| {
666            let name = format!("{s}");
667            name.strip_prefix("🎯").and_then(|rest| rest.parse::<usize>().ok()).map(|k| (k, s))
668        })
669        .collect();
670
671    if coord_syms.is_empty() {
672        // No coordinate symbols – the value is uniform across all positions; propagate as-is.
673        return Some(expr.clone());
674    }
675
676    // Reshape: only handle trivial all-ones case.
677    if let AxisOp::Reshape(_, from_dims, to_dims) = axis_op.canonical().as_ref() {
678        return if from_dims.iter().all(|d| d.is_one()) && to_dims.iter().all(|d| d.is_one()) {
679            Some(expr.clone())
680        } else {
681            None
682        };
683    }
684
685    // For Add/Rm/Move: use transform_axis and substitute all at once to avoid
686    // double-substitution when two axes swap positions (e.g. Move).
687    let map: HashMap<Symbol, TDim> = coord_syms
688        .into_iter()
689        .filter_map(|(k, sym)| {
690            let new_k = axis_op.transform_axis(k)?;
691            if new_k == k {
692                return None;
693            }
694            let scope = sym.scope()?;
695            Some((sym, TDim::Sym(scope.coord_sym(new_k))))
696        })
697        .collect();
698    if map.is_empty() {
699        return Some(expr.clone());
700    }
701    expr.substitute_all(&map).ok()
702}
703
704impl TypedOp for AxisOp {
705    as_op!();
706
707    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
708        if let Some(bqf) =
709            inputs[0].exotic_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
710        {
711            let mut new_shape: TVec<usize> = bqf.shape().into();
712            self.change_shape_array(&mut new_shape, false)?;
713            let new_bqf = BlockQuantFact::new(bqf.format.clone(), new_shape.clone());
714            let shape: TVec<TDim> = new_shape.iter().map(|d| d.to_dim()).collect();
715            let mut new_fact = inputs[0].datum_type.fact(&*shape).with_exotic_fact(new_bqf);
716            if let Some(k) = &inputs[0].konst {
717                let mut new = k.clone().into_tensor();
718                self.change_tensor(&mut new, false)?;
719                new_fact.konst = Some(new.into());
720            }
721            return Ok(tvec!(new_fact));
722        }
723        let mut shape = inputs[0].shape.clone();
724        self.change_shape(&mut shape, false)?;
725        let mut fact = inputs[0].datum_type.fact(shape);
726        fact.exotic_fact.clone_from(&inputs[0].exotic_fact);
727        if let Some(tdim) = &inputs[0].uniform_tdim {
728            fact.uniform_tdim = remap_uniform_tdim(tdim, self);
729        }
730        Ok(tvec!(fact))
731    }
732
733    fn input_roi(
734        &self,
735        model: &TypedModel,
736        node: &TypedNode,
737    ) -> TractResult<Option<TVec<Option<TDim>>>> {
738        crate::optim::propagate_roi::bubble_roi(model, node)
739    }
740
741    fn axes_mapping(
742        &self,
743        inputs: &[&TypedFact],
744        outputs: &[&TypedFact],
745    ) -> TractResult<AxesMapping> {
746        let mut axes: Vec<Axis> = (0..inputs[0].rank())
747            .zip('a'..)
748            .map(|(axis_id, repr)| {
749                let mut axis = Axis::new(repr, inputs.len(), outputs.len()).input(0, axis_id);
750                if let Some(out) = self.transform_axis(axis_id) {
751                    axis = axis.output(0, out);
752                }
753                axis
754            })
755            .collect();
756        for (axis, letter) in (0..outputs[0].rank()).zip('A'..) {
757            if self.recip().transform_axis(axis).is_none() {
758                axes.push(Axis::new(letter, inputs.len(), outputs.len()).output(0, axis));
759            }
760        }
761        AxesMapping::new(inputs.len(), outputs.len(), axes)
762    }
763
764    fn declutter(
765        &self,
766        model: &TypedModel,
767        node: &TypedNode,
768    ) -> TractResult<Option<TypedModelPatch>> {
769        if self.is_noop()
770            && let Some(p) = TypedModelPatch::shunt_one_op(model, node)?
771        {
772            return Ok(Some(p));
773        }
774        let simplified = self.simplify();
775        if simplified.len() != 1 || &simplified[0] != self {
776            let mut patch = TypedModelPatch::default();
777            let mut wire = patch.tap_model(model, node.inputs[0])?;
778            for (ix, op) in simplified.into_iter().enumerate() {
779                wire = patch.wire_node(format!("{}.{}", node.name, ix), op, &[wire])?[0];
780            }
781            patch.shunt_outside(model, node.id.into(), wire)?;
782            Ok(Some(patch))
783        } else {
784            Ok(None)
785        }
786    }
787
788    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
789        Ok(tvec!((InOut::Out(0), self.recip()), (InOut::In(0), self.clone())))
790    }
791
792    fn change_axes(
793        &self,
794        _model: &TypedModel,
795        _node: &TypedNode,
796        io: InOut,
797        change: &AxisOp,
798    ) -> TractResult<Option<AxisChangeConsequence>> {
799        let op = if let InOut::Out(0) = io {
800            rule_if_some!(more = self.recip().change_axes(_model, _node, InOut::In(0), change)?);
801            AxisChangeConsequence {
802                substitute_op: more.substitute_op.map(|op| {
803                    if let Some(op) = op.as_op().downcast_ref::<AxisOp>() {
804                        Box::new(op.recip())
805                    } else {
806                        op // have to be identity
807                    }
808                }),
809                wire_changes: more
810                    .wire_changes
811                    .into_iter()
812                    .map(|wc| {
813                        (if wc.0 == InOut::In(0) { InOut::Out(0) } else { InOut::In(0) }, wc.1)
814                    })
815                    .collect(),
816            }
817        } else if change == self {
818            AxisChangeConsequence { substitute_op: Some(Box::new(Identity)), wire_changes: tvec!() }
819        } else {
820            rule_if_some!((new_op, new_change) = self.merge_incoming_change(change));
821            trace!("  Change:{change:?} self:{self:?} -> change:{new_change:?} op:{new_op:?}");
822            let substitute_op: Box<dyn TypedOp> =
823                if let Some(o) = new_op { Box::new(o) as _ } else { Box::new(Identity) };
824            let mut wire_changes = tvec!();
825            if !change.is_noop() {
826                wire_changes.push((InOut::In(0), change.clone()))
827            }
828            if let Some(new_change) = new_change {
829                wire_changes.push((InOut::Out(0), new_change))
830            }
831            AxisChangeConsequence { substitute_op: Some(substitute_op), wire_changes }
832        };
833        Ok(Some(op))
834    }
835
836    fn concretize_dims(
837        &self,
838        _source: &TypedModel,
839        node: &TypedNode,
840        target: &mut TypedModel,
841        mapping: &HashMap<OutletId, OutletId>,
842        values: &SymbolValues,
843    ) -> TractResult<TVec<OutletId>> {
844        let op = if let AxisOp::Reshape(axis, from, to) = self {
845            AxisOp::Reshape(
846                *axis,
847                from.iter().map(|d| d.eval(values)).collect(),
848                to.iter().map(|d| d.eval(values)).collect(),
849            )
850        } else {
851            self.clone()
852        };
853        target.wire_node(&node.name, op, &[mapping[&node.inputs[0]]])
854    }
855
856    fn slice(
857        &self,
858        patch: &mut TypedModelPatch,
859        _model: &TypedModel,
860        node: &TypedNode,
861        _prefix: &str,
862        inputs: &[OutletId],
863        output_axis: usize,
864        _start: &TDim,
865        _end: &TDim,
866    ) -> TractResult<Option<TVec<OutletId>>> {
867        // is this test really useful ? or axis mapping preempt this ?
868        if let Reshape(pos, _from, to) = self
869            && output_axis >= *pos
870            && output_axis < pos + to.len()
871        {
872            return Ok(None);
873        }
874        patch.wire_node(&node.name, &node.op, inputs).map(Some)
875    }
876
877    fn codegen(
878        &self,
879        model: &TypedModel,
880        node: &TypedNode,
881    ) -> TractResult<Option<TypedModelPatch>> {
882        if node.outputs[0].fact.exotic_fact.is_some() {
883            return Ok(None);
884        }
885        if let Some(shape) = node.outputs[0].fact.shape.as_concrete()
886            && !matches!(self, AxisOp::Move(_, _))
887        {
888            let (inputs, outputs) = model.node_facts(node.id)?;
889            let mapping = self.axes_mapping(&inputs, &outputs)?;
890            let op = IntoShape {
891                mapping,
892                len: shape.iter().product(),
893                strides: Tensor::natural_strides(shape),
894                dims: shape.into(),
895            };
896            return Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, op)?));
897        }
898        Ok(None)
899    }
900}
901
902// a, b, c is a <- b, b <- c, c <- a
903fn perm_to_cycles(perm: &[usize]) -> TVec<TVec<usize>> {
904    let mut cycles: TVec<TVec<usize>> = tvec!();
905    let mut done = 0;
906    while done < perm.len() {
907        if perm[done] == done || cycles.iter().any(|c| c.contains(&done)) {
908            done += 1;
909            continue;
910        }
911        let mut cycle = tvec!();
912        let mut current = done;
913        loop {
914            cycle.push(current);
915            current = perm[current];
916            if current == done {
917                break;
918            }
919        }
920        cycles.push(cycle)
921    }
922    cycles
923}
924
/// Detects whether `cycle` can be realised by moving a single axis.
///
/// Two shapes are recognised:
/// * a strictly consecutive ascending run `[a, a+1, ..., b]`, reported as `Some((a, b))`;
/// * a first element followed by a consecutive descending run that ends right above it,
///   `[a, b, b-1, ..., a+1]`, reported as `Some((b, a))`.
///
/// Anything else yields `None`, including an empty slice. A single-element slice `[a]`
/// is reported as `Some((a, a))`.
///
/// All neighbour comparisons use checked arithmetic, so arbitrary input (for example a
/// cycle containing `0` after its first element) can never underflow or panic.
fn is_rotation_cycle(cycle: &[usize]) -> Option<(usize, usize)> {
    let (&first, rest) = cycle.split_first()?;
    let last = rest.last().copied().unwrap_or(first);
    if cycle.windows(2).all(|w| w[0].checked_add(1) == Some(w[1])) {
        Some((first, last))
    } else if rest.windows(2).all(|w| w[0].checked_sub(1) == Some(w[1]))
        && last.checked_sub(1) == Some(first)
    {
        // `rest` is non-empty here: a single-element cycle always takes the first branch.
        Some((rest[0], first))
    } else {
        None
    }
}
936
937fn perm_to_atoms(input: &[usize]) -> TVec<(usize, usize)> {
938    let mut changes: TVec<(usize, usize)> = tvec!();
939    'top: loop {
940        let mut reached: TVec<usize> = (0..input.len()).collect();
941        changes.iter().for_each(|(f, t)| {
942            let axis = reached.remove(*f);
943            reached.insert(*t, axis);
944        });
945        if &*reached == input {
946            return changes;
947        }
948        let remaining: TVec<usize> =
949            input.iter().map(|x| reached.iter().position(|y| y == x).unwrap()).collect();
950        let cycles = perm_to_cycles(&remaining);
951        for cycle in &cycles {
952            if let Some(rot) = is_rotation_cycle(cycle) {
953                changes.push(rot);
954                continue 'top;
955            }
956        }
957        changes.push((cycles[0][1], cycles[0][0]));
958    }
959}
960
961pub fn perm_to_ops(input: &[usize]) -> TVec<AxisOp> {
962    perm_to_atoms(input).into_iter().map(|pair| AxisOp::Move(pair.0, pair.1)).collect()
963}
964
965pub fn compute_shape_with_tf_rules(input: &[TDim], shape_spec: &[TDim]) -> TractResult<TVec<TDim>> {
966    let mut shape: TVec<TDim> = shape_spec.into();
967    fn deal_with_zero<'a>(
968        mut input_dims: std::iter::Peekable<impl Iterator<Item = &'a TDim>>,
969        shape: &mut [TDim],
970    ) -> TractResult<()> {
971        let mut remaining_dim_input = 1.to_dim();
972        for slot in shape.iter_mut() {
973            if *slot == (-1).into() {
974                break;
975            }
976            if *slot == 0.into() {
977                if remaining_dim_input != TDim::one() {
978                    bail!("Invalid remaining dim");
979                }
980                *slot = (*input_dims.peek().context("Invalid")?).clone();
981            }
982            loop {
983                let quotient = remaining_dim_input.maybe_div(slot);
984                if quotient.is_err() || quotient.as_ref().unwrap().1 != 1 {
985                    remaining_dim_input *= input_dims.next().context("Invalid")?;
986                } else {
987                    break;
988                }
989            }
990            remaining_dim_input = remaining_dim_input.maybe_div(slot)?.0;
991        }
992        Ok(())
993    }
994
995    deal_with_zero(input.iter().peekable(), &mut shape)?;
996    shape.reverse();
997    deal_with_zero(input.iter().rev().peekable(), &mut shape)?;
998    shape.reverse();
999
1000    if let Some(pos) = shape.iter().position(|d| *d == (-1).into()) {
1001        let input_vol: TDim = input.iter().product();
1002        let shape_vol: TDim = shape.iter().filter(|d| **d != (-1).into()).product();
1003        let div = input_vol.maybe_div(&shape_vol)?;
1004        if div.1 != 1 {
1005            bail!("invalid")
1006        }
1007        shape[pos] = div.0;
1008    }
1009    Ok(shape)
1010}
1011
1012pub fn to_axis_ops_with_tf_rules(
1013    input_orig: &[TDim],
1014    output_spec: &[TDim],
1015) -> TractResult<TVec<AxisOp>> {
1016    let final_output = compute_shape_with_tf_rules(input_orig, output_spec)?;
1017    let mut stack: TVec<AxisOp> = tvec!();
1018    'top: loop {
1019        let current_input =
1020            stack.iter().try_fold(TVec::from(input_orig), |mut shape, op| -> TractResult<_> {
1021                op.change_shape_array(&mut shape, false)?;
1022                Ok(shape)
1023            })?;
1024        if current_input == final_output {
1025            return Ok(stack);
1026        }
1027        if let Some(common) =
1028            current_input.iter().zip(final_output.iter()).position(|(a, b)| a != b)
1029        {
1030            if current_input[common].is_one() {
1031                stack.push(AxisOp::Rm(common));
1032            } else if final_output[common].is_one() {
1033                stack.push(AxisOp::Add(common));
1034            } else {
1035                // actual regrouping. search for a match. this is quadratic, but
1036                // rank is expected to be somewhat reasonable
1037                for i in common..current_input.len() {
1038                    let i_group = &current_input[common..i + 1];
1039                    let i_volume: TDim = i_group.iter().product();
1040                    for o in common..final_output.len() {
1041                        let o_group = &final_output[common..o + 1];
1042                        let o_volume: TDim = o_group.iter().product();
1043                        if i_volume == o_volume {
1044                            stack.push(AxisOp::Reshape(common, i_group.into(), o_group.into()));
1045                            continue 'top;
1046                        }
1047                    }
1048                }
1049                todo!()
1050            }
1051        } else if final_output.len() > current_input.len() {
1052            stack.push(AxisOp::Add(current_input.len()));
1053        } else {
1054            stack.push(AxisOp::Rm(current_input.len() - 1));
1055        }
1056    }
1057}
1058
/// Operator that gives its input tensor a fixed, precomputed geometry.
///
/// `AxisOp::codegen` builds it from a concrete output shape.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct IntoShape {
    /// Axes mapping of the operation this op stands for (shown by `info`).
    pub mapping: AxesMapping,
    /// Expected number of elements of the input tensor (checked in `eval`).
    pub len: usize,
    /// Dimensions given to the tensor.
    pub dims: TVec<usize>,
    /// Strides given to the tensor; `AxisOp::codegen` fills in the natural strides of `dims`.
    pub strides: TVec<isize>,
}
1066
1067impl Op for IntoShape {
1068    fn name(&self) -> StaticName {
1069        "IntoShape".into()
1070    }
1071
1072    fn info(&self) -> TractResult<Vec<String>> {
1073        Ok(vec![format!("{}", self.mapping)])
1074    }
1075
1076    op_as_typed_op!();
1077}
1078
impl EvalOp for IntoShape {
    /// This op keeps no state between evaluations.
    fn is_stateless(&self) -> bool {
        true
    }

    /// Takes the single input tensor and overwrites its geometry with `dims` and `strides`.
    ///
    /// Fails if the tensor's element count differs from `self.len`.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut input = args_1!(inputs).into_tensor();
        // Guard for the unchecked call below: the element count must be the expected one.
        ensure!(input.len() == self.len);
        // SAFETY (presumed — TODO confirm against `set_geometry_unchecked`'s contract):
        // `dims` and `strides` are built together from one concrete shape whose product
        // is `len`, and the length was checked just above.
        unsafe { input.set_geometry_unchecked(&self.dims, &self.strides) };
        Ok(tvec!(input.into_tvalue()))
    }
}
1091
1092impl TypedOp for IntoShape {
1093    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
1094        let mut fact = inputs[0].datum_type.fact(&self.dims);
1095        if let Some(of) = &inputs[0].exotic_fact {
1096            fact = fact.with_exotic_fact(of.clone());
1097        }
1098        Ok(tvec!(fact))
1099    }
1100
1101    fn declutter(
1102        &self,
1103        model: &TypedModel,
1104        node: &TypedNode,
1105    ) -> TractResult<Option<TypedModelPatch>> {
1106        let input = model.outlet_fact(node.inputs[0])?;
1107        if input.shape.as_concrete().is_some_and(|shape| shape == &*self.dims) {
1108            return TypedModelPatch::shunt_one_op(model, node);
1109        }
1110        if let Some(succ) = model.single_succ(node.id)?
1111            && let Some(into_shape) = succ.op_as::<IntoShape>()
1112        {
1113            let op =
1114                Self { mapping: self.mapping.compose(&into_shape.mapping)?, ..into_shape.clone() };
1115            return Ok(Some(TypedModelPatch::fuse_with_next(model, node, op)?));
1116        }
1117        Ok(None)
1118    }
1119
1120    as_op!();
1121}
1122
#[cfg(test)]
mod test {
    use super::*;

    // Permutation helpers.

    #[test]
    fn test_perm_to_cycles() {
        assert_eq!(perm_to_cycles(&[1, 2, 0]), tvec!(tvec!(0, 1, 2)));
        assert_eq!(perm_to_cycles(&[2, 0, 1]), tvec!(tvec!(0, 2, 1)));
        assert_eq!(perm_to_cycles(&[1, 2, 3, 0]), tvec!(tvec!(0, 1, 2, 3)));
        assert_eq!(perm_to_cycles(&[3, 0, 1, 2]), tvec!(tvec!(0, 3, 2, 1)));
        // fixed points (1, 2 and 4) are not reported
        assert_eq!(perm_to_cycles(&[3, 1, 2, 0, 4]), tvec!(tvec!(0, 3)));
    }

    #[test]
    fn is_rotation() {
        assert_eq!(is_rotation_cycle(&[0, 1, 2]), Some((0, 2)));
        assert_eq!(is_rotation_cycle(&[0, 2, 1]), Some((2, 0)));
    }

    #[test]
    fn test_perm_one_rotation() {
        assert_eq!(perm_to_atoms(&[1, 2, 0, 3, 4]), tvec!((0, 2)));
    }

    #[test]
    fn test_perm_two_rotations() {
        assert_eq!(perm_to_atoms(&[1, 2, 0, 4, 3]), tvec!((0, 2), (3, 4)));
    }

    #[test]
    fn test_perm_complex() {
        assert_eq!(perm_to_atoms(&[3, 1, 2, 0, 4]), tvec!((3, 0), (1, 3)));
    }

    // Each test below checks `op.merge_incoming_change(&change)`. The diagrams read as:
    // top row = the op as it is, left = the incoming change, bottom row = the op once
    // rewritten, right = the change coming out on the other side.

    // ADD-ADD

    //                          Op
    //           b,c   ------|Add(0)|----->        n,b,c
    //   Add(0)                                            Add(1)
    //         a,b,c   ------|Add(0)|----->        a,n,b,c
    #[test]
    pub fn transform_op_add_0_add_0() {
        assert_eq!(Add(0).merge_incoming_change(&Add(0)), Some((Some(Add(0)), Some(Add(1)))));
    }

    //                          Op
    //           b,c   ------|Add(1)|----->        b,n,c
    //   Add(0)                                                 Add(0)
    //         a,b,c   ------|Add(2)|----->        a,b,n,c
    #[test]
    pub fn transform_op_add_0_add_1() {
        assert_eq!(Add(1).merge_incoming_change(&Add(0)), Some((Some(Add(2)), Some(Add(0)))));
    }

    //                          Op
    //           a,c   ------|Add(0)|----->        n,a,c
    //   Add(1)                                                 Add(2)
    //         a,b,c   ------|Add(0)|----->        n,a,b,c
    #[test]
    pub fn transform_op_add_1_add_0() {
        assert_eq!(Add(0).merge_incoming_change(&Add(1)), Some((Some(Add(0)), Some(Add(2)))));
    }

    // RM-RM

    //                          Op
    //         a,b,c   ------|Rm(1)|----->         a,c
    //   Rm(0)                                             Rm(0)
    //           b,c   ------|Rm(0)|----->         c
    #[test]
    pub fn transform_op_rm_0_rm_1() {
        assert_eq!(Rm(1).merge_incoming_change(&Rm(0)), Some((Some(Rm(0)), Some(Rm(0)))));
    }

    //                          Op
    //         a,b,c   ------|Rm(0)|----->         b,c
    //   Rm(1)                                             Rm(0)
    //           a,c   ------|Rm(0)|----->         c
    #[test]
    pub fn transform_op_rm_1_rm_0() {
        assert_eq!(Rm(0).merge_incoming_change(&Rm(1)), Some((Some(Rm(0)), Some(Rm(0)))));
    }

    // ADD - RM

    //                          Op
    //          b,c     ------|Rm(0)|------>        c
    //   Add(0)                                                 Add(0)
    //          a,b,c   ------|Rm(1)|----->         a,c
    #[test]
    pub fn transform_op_add_0_rm_0() {
        assert_eq!(Rm(0).merge_incoming_change(&Add(0)), Some((Some(Rm(1)), Some(Add(0)))));
    }

    //                          Op
    //          b,c     ------|Rm(1)|------>        b
    //   Add(0)                                                 Add(0)
    //          a,b,c   ------|Rm(2)|----->         a,b
    #[test]
    pub fn transform_op_add_0_rm_1() {
        assert_eq!(Rm(1).merge_incoming_change(&Add(0)), Some((Some(Rm(2)), Some(Add(0)))));
    }

    //                          Op
    //          a,c     ------|Rm(0)|------>        c
    //   Add(1)                                                 Add(0)
    //          a,b,c   ------|Rm(0)|----->         b,c
    #[test]
    pub fn transform_op_add_1_rm_0() {
        assert_eq!(Rm(0).merge_incoming_change(&Add(1)), Some((Some(Rm(0)), Some(Add(0)))));
    }

    // RM - ADD

    //                          Op
    //         a,b,c   ------|Add(0)|----->        X,a,b,c
    //   Rm(1)                                                 Rm(2)
    //           a,c   ------|Add(0)|----->        X,a,c
    #[test]
    pub fn transform_op_rm_1_add_0() {
        assert_eq!(Add(0).merge_incoming_change(&Rm(1)), Some((Some(Add(0)), Some(Rm(2)))));
    }

    //                          Op
    //         a,b,c   ------|Add(1)|----->        a,X,b,c
    //   Rm(0)                                                 Rm(0)
    //           b,c   ------|Add(0)|----->        X,b,c
    #[test]
    pub fn transform_op_rm_0_add_1() {
        assert_eq!(Add(1).merge_incoming_change(&Rm(0)), Some((Some(Add(0)), Some(Rm(0)))));
    }

    // MOVE - RM

    //                          Op
    //         a,b,c   ------|Rm(2)|----->        a,b
    //   Move(0, 2)                                           Move(0,1)
    //         b,c,a   ------|Rm(1)|----->        b,a
    #[test]
    pub fn transform_op_mv_02_rm_2() {
        assert_eq!(
            Rm(2).merge_incoming_change(&Move(0, 2)),
            Some((Some(Rm(1)), Some(Move(0, 1))))
        );
    }
}
1284
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// A chain of axis operations applied to an `i64` tensor of shape `input`.
    #[derive(Debug)]
    struct ComposeProblem {
        input: TVec<usize>,
        ops: TVec<AxisOp>,
    }

    impl Arbitrary for AxisOp {
        /// The shape the generated op must be applicable to.
        type Parameters = TVec<usize>;
        type Strategy = BoxedStrategy<AxisOp>;

        /// Builds a strategy producing ops that are valid for `shape`: `Add` anywhere,
        /// `Move` between distinct axes, `Rm` on axes of size one, and `Reshape` merging
        /// two neighbouring axes that are both bigger than one.
        fn arbitrary_with(shape: TVec<usize>) -> Self::Strategy {
            let rank = shape.len();
            let mut ops: BoxedStrategy<AxisOp> = (0usize..rank + 1).prop_map(Add).boxed();

            if rank > 1 {
                // `to` is drawn from one value less, then shifted past `from`, so that
                // the two axes always differ.
                let moves = (0..rank, 0..rank - 1)
                    .prop_map(|(from, to)| Move(from, to + (to >= from) as usize))
                    .boxed();
                ops = ops.prop_union(moves).boxed()
            }

            let rms: Vec<AxisOp> = shape
                .iter()
                .enumerate()
                .filter(|(_, dim)| **dim == 1)
                .map(|(axis, _)| Rm(axis))
                .collect();
            if !rms.is_empty() {
                let removals = (0..rms.len()).prop_map(move |ix| rms[ix].clone()).boxed();
                ops = ops.prop_union(removals).boxed()
            }

            let mergeable: Vec<AxisOp> = shape
                .windows(2)
                .enumerate()
                .filter(|(_, pair)| pair[0] > 1 && pair[1] > 1)
                .map(|(ix, pair)| {
                    Reshape(
                        ix,
                        tvec!(pair[0].to_dim(), pair[1].to_dim()),
                        tvec!((pair[0] * pair[1]).to_dim()),
                    )
                })
                .collect();
            // NOTE(review): `> 1` means a shape with exactly one mergeable pair never
            // gets a Reshape; `> 0` (as for `rms`) may have been intended — confirm.
            if mergeable.len() > 1 {
                let merges =
                    (0..mergeable.len()).prop_map(move |ix| mergeable[ix].clone()).boxed();
                ops = ops.prop_union(merges).boxed()
            }
            ops
        }
    }

    impl Arbitrary for ComposeProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<ComposeProblem>;

        /// Draws an input shape (rank 1 to 3, dims 1 to 3), then a chain of 1 to 5 ops in
        /// which each op is generated for the shape left by the ops before it.
        fn arbitrary_with(_args: ()) -> Self::Strategy {
            /// Strategy for `len` ops chained from `shape`.
            fn tail(len: usize, shape: TVec<usize>) -> BoxedStrategy<TVec<AxisOp>> {
                if len == 0 {
                    return Just(tvec!()).boxed();
                }
                AxisOp::arbitrary_with(shape.clone())
                    .prop_flat_map(move |op| {
                        let mut next_shape = shape.clone();
                        op.change_shape_array(&mut next_shape, false).unwrap();
                        tail(len - 1, next_shape).prop_map(move |mut following| {
                            following.insert(0, op.clone());
                            following
                        })
                    })
                    .boxed()
            }

            let input = proptest::collection::vec(1usize..4, 1usize..4);
            (input, 1usize..=5)
                .prop_flat_map(|(input, len)| (Just(input.clone()), tail(len, input.into())))
                .prop_map(|(input, ops)| ComposeProblem { input: input.into(), ops })
                .boxed()
        }
    }

    impl ComposeProblem {
        /// Builds a model made of one source followed by the ops, chained in order.
        pub fn model(&self) -> TractResult<TypedModel> {
            let mut model = TypedModel::default();
            let mut wire = model.add_source("source", i64::fact(&self.input))?;
            for (ix, op) in self.ops.iter().enumerate() {
                wire = model.wire_node(format!("op_{ix}"), op.clone(), &[wire])?[0];
            }
            model.select_output_outlets(&[wire])?;
            Ok(model)
        }

        /// Builds the input tensor: shape `self.input`, each element set to its index.
        fn input(&self) -> TractResult<Tensor> {
            unsafe {
                let mut t = Tensor::uninitialized::<i64>(&self.input)?;
                for i in 0..t.len() {
                    t.try_as_plain_mut().unwrap().as_slice_mut().unwrap()[i] = i as i64;
                }
                Ok(t)
            }
        }

        /// Runs the model as built and once decluttered, and compares both outputs.
        fn check(&self) -> TractResult<()> {
            crate::setup_test_logger();
            let input = self.input()?;
            let reference =
                self.model()?.into_runnable()?.run(tvec!(input.clone().into_tvalue()))?;
            let decluttered = self.model()?.into_decluttered()?;
            let found = decluttered.into_runnable()?.run(tvec!(input.into_tvalue()))?;
            found[0].close_enough(&reference[0], false)
        }
    }

    /// Checks that decluttering the chain `ops` on an input of shape `input` leaves the
    /// result unchanged.
    fn check_ops(input: TVec<usize>, ops: TVec<AxisOp>) {
        ComposeProblem { input, ops }.check().unwrap();
    }

    proptest! {
        #[test]
        fn recip(pb in any::<AxisOp>()) {
            assert_eq!(pb.recip().recip(), pb);
        }

        #[test]
        fn axis_ops(pb in any::<ComposeProblem>()) {
            pb.check().unwrap()
        }
    }

    // Fixed chains, each checked for raw/decluttered agreement.

    #[test]
    fn add_0_rm_0() {
        check_ops(tvec![1], tvec![Add(0), Rm(0)]);
    }

    #[test]
    fn add_0_move_01() {
        check_ops(tvec![2], tvec![Add(0), Move(0, 1)]);
    }

    #[test]
    fn add_0_move_01_add_1() {
        check_ops(tvec![2], tvec![Add(0), Move(0, 1), Add(1)]);
    }

    #[test]
    fn recip_move_01() {
        assert_eq!(Move(1, 0).recip().recip(), Move(1, 0));
    }

    #[test]
    fn recip_move_20() {
        assert_eq!(Move(2, 0).recip().recip(), Move(2, 0));
    }

    #[test]
    fn recip_move_02() {
        assert_eq!(Move(0, 2).recip().recip(), Move(0, 2));
    }

    #[test]
    fn add_0_add_1_move_02() {
        check_ops(tvec![2], tvec![Add(0), Add(1), Move(0, 2)]);
    }

    #[test]
    fn add_0_add_0() {
        check_ops(tvec![1], tvec![Add(0), Add(0)]);
    }

    #[test]
    fn add_0_add_0_move_02() {
        check_ops(tvec![2], tvec![Add(0), Add(0), Move(0, 2)]);
    }

    #[test]
    fn add_0_add_2_move_12() {
        check_ops(tvec![2], tvec![Add(0), Add(2), Move(1, 2)]);
    }

    #[test]
    fn add_0_add_0_move_02_rm_0() {
        check_ops(tvec![1], tvec![Add(0), Add(0), Move(0, 2), Rm(0)]);
    }

    #[test]
    fn add_0_add_0_move_20_move_20() {
        check_ops(tvec![2], tvec![Add(0), Add(0), Move(2, 0), Move(2, 0)]);
    }

    #[test]
    fn move_01_add_0() {
        check_ops(tvec![1, 1], tvec![Move(0, 1), Add(0)]);
    }

    #[test]
    fn add_0_move_02_move_02() {
        check_ops(tvec![1, 1], tvec![Add(0), Move(0, 2), Move(0, 2)]);
    }

    #[test]
    fn add_0_add_2_move_20_move_12_rm_2() {
        check_ops(tvec![3], tvec![Add(0), Add(2), Move(2, 0), Move(1, 2), Rm(2)]);
    }

    #[test]
    fn move_02_move_02() {
        check_ops(tvec![2, 1, 1], tvec![Move(0, 2), Move(0, 2)]);
    }

    #[test]
    fn rm_1_perm_10_add_0() {
        check_ops(tvec![1, 1, 2], tvec![Rm(1), Move(0, 1), Add(0)]);
    }

    #[test]
    fn add_2_move_02_move_02() {
        check_ops(tvec![3, 2], tvec![Add(2), Move(0, 2), Move(0, 2)]);
    }

    #[test]
    fn move_01_move_20_move_20() {
        check_ops(tvec![2, 3, 2], tvec![Move(0, 1), Move(2, 0), Move(2, 0)]);
    }

    #[test]
    fn reshape_axes_tracking() {
        check_ops(
            tvec![2, 2, 2],
            tvec![Reshape(0, tvec!(2.to_dim(), 2.to_dim()), tvec!(4.to_dim()))],
        );
    }

    #[test]
    fn simplify_reshape() {
        // Shorthand for a list of dimensions.
        macro_rules! d {
            ($($dim: expr),*) =>  { tvec!($($dim.to_dim()),*) }
        }
        // Nothing to do.
        assert_eq!(Reshape(3, d!(), d!()).simplify(), tvec!());
        assert_eq!(Reshape(3, d!(2, 3), d!(2, 3)).simplify(), tvec!());
        // Pure axis insertion or removal.
        assert_eq!(Reshape(3, d!(1), d!()).simplify(), tvec!(Rm(3)));
        assert_eq!(Reshape(3, d!(), d!(1)).simplify(), tvec!(Add(3)));
        // A common leading or trailing dimension is trimmed away.
        assert_eq!(
            Reshape(3, d!(2, 3, 4), d!(2, 4, 3)).simplify(),
            tvec!(Reshape(4, d!(3, 4), d!(4, 3)))
        );
        assert_eq!(
            Reshape(3, d!(3, 4, 2), d!(4, 3, 2)).simplify(),
            tvec!(Reshape(3, d!(3, 4), d!(4, 3)))
        );
        // Dimensions of one at either end are split out as Rm / Add.
        assert_eq!(
            Reshape(3, d!(1, 2, 3), d!(3, 2)).simplify(),
            tvec!(Rm(3), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(3, d!(2, 3), d!(1, 3, 2)).simplify(),
            tvec!(Reshape(3, d!(2, 3), d!(3, 2)), Add(3))
        );
        assert_eq!(
            Reshape(3, d!(2, 3, 1), d!(3, 2)).simplify(),
            tvec!(Rm(5), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(3, d!(2, 3), d!(3, 2, 1)).simplify(),
            tvec!(Add(5), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(2, d!(2, 2, 1), d!(4)).simplify(),
            tvec!(Rm(4), Reshape(2, d!(2, 2), d!(4)))
        );
        assert_eq!(Reshape(1, d!(1, 2), d!(2)).simplify(), tvec!(Rm(1)));
    }

    // Shorthand for a shape given as a slice of dimensions.
    macro_rules! s {
        ($($a:expr),*) => {&[ $($a.clone().into()),* ]}
    }

    // Shorthand for a Reshape: `r!(at ; from... => to...)`.
    macro_rules! r {
        ($at: expr ; $($from:expr),* => $($to:expr),*) => {
            AxisOp::Reshape($at, tvec!($($from.into()),*),  tvec!($($to.into()),*))
        }
    }

    // Shape resolution.

    #[test]
    fn compute_invalid() {
        let resolved = compute_shape_with_tf_rules(s![3, 4, 5], s!(100));
        assert!(resolved.is_err());
    }

    #[test]
    fn compute_with_leading_zero() {
        let resolved = compute_shape_with_tf_rules(s![3, 4, 5], s!(0, 0, 5)).unwrap();
        assert_eq!(&*resolved, s![3, 4, 5])
    }

    #[test]
    fn compute_with_leading_zero_with_flatten() {
        let resolved = compute_shape_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap();
        assert_eq!(&*resolved, s![2, 3, 35])
    }

    #[test]
    fn compute_with_trailing_zero() {
        let resolved = compute_shape_with_tf_rules(s![3, 4, 5], s!(3, -1, 0)).unwrap();
        assert_eq!(&*resolved, s![3, 4, 5])
    }

    #[test]
    fn compute_bug_1() {
        let table = SymbolScope::default();
        let s = table.new_with_prefix("S");
        let resolved = compute_shape_with_tf_rules(s![s, 1, 2, 128], s!(0, 0, -1)).unwrap();
        assert_eq!(&*resolved, s![s, 1, 256])
    }

    #[test]
    fn compute_bug_2() {
        let table = SymbolScope::default();
        let b = table.new_with_prefix("B");
        let s = table.new_with_prefix("S");
        let resolved = compute_shape_with_tf_rules(s![s, b, 2, 128], s!(0, 0, -1)).unwrap();
        assert_eq!(&*resolved, s![s, b, 256])
    }

    // Translation of reshapes into axis ops.

    #[test]
    fn axis_op_rm_begin() {
        let ops = to_axis_ops_with_tf_rules(s![1, 2, 3], s!(2, 3)).unwrap();
        assert_eq!(&*ops, &[Rm(0)])
    }

    #[test]
    fn axis_op_rm_end() {
        let ops = to_axis_ops_with_tf_rules(s![2, 3, 1], s!(2, 3)).unwrap();
        assert_eq!(&*ops, &[Rm(2)])
    }

    #[test]
    fn axis_op_insert_begin() {
        let ops = to_axis_ops_with_tf_rules(s![2, 3], s!(1, 2, 3)).unwrap();
        assert_eq!(&*ops, &[Add(0)])
    }

    #[test]
    fn axis_op_insert_end() {
        let ops = to_axis_ops_with_tf_rules(s![2, 3], s!(2, 3, 1)).unwrap();
        assert_eq!(&*ops, &[Add(2)])
    }

    #[test]
    fn axis_op_merge() {
        let ops = to_axis_ops_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap();
        assert_eq!(&*ops, &[r!(2 ; 5,7 => 35 )])
    }

    #[test]
    fn axis_op_complex() {
        let ops = to_axis_ops_with_tf_rules(s![1, 2, 3, 5, 7], s!(2, 1, 3, 35, 1)).unwrap();
        assert_eq!(&*ops, &[Rm(0), Add(1), r!(3 ; 5,7 => 35 ), Add(4)])
    }
}
1662            &*to_axis_ops_with_tf_rules(s![1, 2, 3, 5, 7], s!(2, 1, 3, 35, 1)).unwrap(),
1663            &[Rm(0), Add(1), r!(3 ; 5,7 => 35 ), Add(4)]
1664        )
1665    }
1666}