Skip to main content

tract_core/ops/
change_axes.rs

1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::model::{TypedModel, TypedNode};
6use crate::ops::identity::Identity;
7use AxisOp::*;
8use tract_itertools::Itertools;
9use tract_linalg::block_quant::{BlockQuantFact, BlockQuantStorage};
10use tract_ndarray::{ArrayViewD, ArrayViewMutD};
11
/// Designates one slot of a node, on either its input side or its output side.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InOut {
    /// The node output slot with the given index.
    Out(usize),
    /// The node input slot with the given index.
    In(usize),
}
17
18impl InOut {
19    pub fn as_outlet<F: Clone + Fact, O: Clone>(&self, node: &Node<F, O>) -> OutletId {
20        match self {
21            InOut::In(ix) => node.inputs[*ix],
22            InOut::Out(ix) => OutletId::new(node.id, *ix),
23        }
24    }
25
26    pub fn is_input(&self) -> bool {
27        matches!(self, InOut::In(_))
28    }
29
30    pub fn is_output(&self) -> bool {
31        matches!(self, InOut::Out(_))
32    }
33
34    pub fn slot(&self) -> usize {
35        match self {
36            InOut::Out(o) => *o,
37            InOut::In(i) => *i,
38        }
39    }
40}
41
/// A change applied to the axes of a shape, tensor or view.
///
/// Equality is implemented by hand (two no-ops are equal, and a `Move`
/// between adjacent axes equals its mirror) while `Hash` is derived, hence
/// the clippy allowance below.
#[derive(Clone, Hash, Eq)]
#[allow(clippy::large_enum_variant)] // FIXME ?
#[allow(clippy::derived_hash_with_manual_eq)] // FIXME. this one may be pretty bad. how about a.canonical() == b.canonical() ? need proper canonicalization of Reshape
pub enum AxisOp {
    /// Insert an axis of dimension 1 at the given position.
    Add(usize),
    /// Remove the axis at the given position.
    Rm(usize),
    /// Move the axis at the first position to the second position.
    Move(usize, usize),
    /// Starting at the given axis, replace the dimensions listed second by
    /// the dimensions listed third.
    Reshape(usize, TVec<TDim>, TVec<TDim>),
}
51
52impl Debug for AxisOp {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        match self {
55            AxisOp::Add(a) => write!(f, "Add({a})"),
56            AxisOp::Rm(a) => write!(f, "Rm({a})"),
57            AxisOp::Move(from, to) => write!(f, "Move({from},{to})"),
58            AxisOp::Reshape(at, from, to) => {
59                write!(f, "Reshape({at}, [{}], [{}])", from.iter().join(","), to.iter().join(","))
60            }
61        }
62    }
63}
64
65impl PartialEq for AxisOp {
66    fn eq(&self, other: &AxisOp) -> bool {
67        if self.is_noop() && other.is_noop() {
68            true
69        } else if self.is_noop() != other.is_noop() {
70            false
71        } else {
72            match (self, other) {
73                (Add(a), Add(b)) | (Rm(a), Rm(b)) => a == b,
74                (Move(f1, t1), Move(f2, t2)) => {
75                    (f1 == f2 && t1 == t2)
76                        || ((*t1 == f1 + 1 || *f1 == t1 + 1) && t2 == f1 && t1 == f2)
77                }
78                (Reshape(at1, f1, t1), Reshape(at2, f2, t2)) => at1 == at2 && f1 == f2 && t1 == t2,
79                _ => false,
80            }
81        }
82    }
83}
84
85impl AxisOp {
86    pub fn canonical(&self) -> Cow<'_, AxisOp> {
87        match self {
88            Move(from, to) if *from == to + 1 => Cow::Owned(Move(*to, *from)),
89            Reshape(at, from, to)
90                if from.len() == 1 && to.len() == 2 && from[0] == to[0] && to[1].is_one() =>
91            {
92                Cow::Owned(Add(*at + 1))
93            }
94            Reshape(at, from, to)
95                if from.len() == 1 && to.len() == 2 && from[0] == to[1] && to[0].is_one() =>
96            {
97                Cow::Owned(Add(*at))
98            }
99            Reshape(at, from, to)
100                if from.len() == 2 && to.len() == 1 && from[0] == to[0] && from[1].is_one() =>
101            {
102                Cow::Owned(Rm(*at + 1))
103            }
104            Reshape(at, from, to)
105                if from.len() == 2 && to.len() == 1 && from[1] == to[0] && from[0].is_one() =>
106            {
107                Cow::Owned(Rm(*at))
108            }
109            other => Cow::Borrowed(other),
110        }
111    }
112
    /// Decomposes this op into a list of simpler ops.
    ///
    /// The op is first put in canonical form. Anything that is not a
    /// `Reshape` comes back as a single-element list. A `Reshape` is peeled
    /// recursively: dims that are identical at either end of `from` and `to`
    /// are dropped from the reshape, and unit dims at either end are turned
    /// into explicit `Rm` / `Add` ops. The arms are tried in order, so the
    /// empty-list arms must stay ahead of the ones indexing `from[0]` /
    /// `to[0]`.
    pub fn simplify(&self) -> TVec<AxisOp> {
        match self.canonical().borrow() {
            // Nothing changes: no op at all.
            Reshape(_, from, to) if from == to => tvec!(),
            // Only removals left: one Rm at `at` per remaining input dim.
            Reshape(at, from, to) if to.len() == 0 => tvec!(Rm(*at); from.len()),
            // Only insertions left: one Add at `at` per remaining output dim.
            Reshape(at, from, to) if from.len() == 0 => tvec!(Add(*at); to.len()),
            // Same leading dim on both sides: skip it and start one axis later.
            Reshape(at, from, to) if from[0] == to[0] => {
                Reshape(at + 1, from[1..].into(), to[1..].into()).simplify()
            }
            // Same trailing dim on both sides: drop it from the reshape.
            Reshape(at, from, to) if from[from.len() - 1] == to[to.len() - 1] => {
                Reshape(*at, from[..from.len() - 1].into(), to[..to.len() - 1].into()).simplify()
            }
            // Leading unit dim on the input side: remove it first, then go on.
            Reshape(at, from, to) if from[0] == 1.to_dim() => std::iter::once(Rm(*at))
                .chain(Reshape(*at, from[1..].into(), to.clone()).simplify())
                .collect(),
            // Leading unit dim on the output side: handle the rest, then add it.
            Reshape(at, from, to) if to[0] == 1.to_dim() => {
                Reshape(*at, from.clone(), to[1..].into())
                    .simplify()
                    .into_iter()
                    .chain(std::iter::once(Add(*at)))
                    .collect()
            }
            // Trailing unit dim on the input side: remove it first, then go on.
            Reshape(at, from, to) if from[from.len() - 1] == 1.to_dim() => {
                std::iter::once(Rm(at + from.len() - 1))
                    .chain(Reshape(*at, from[..from.len() - 1].into(), to.clone()).simplify())
                    .collect()
            }
            // Trailing unit dim on the output side: add it right after the
            // reshaped input dims, then handle the rest.
            Reshape(at, from, to) if to[to.len() - 1] == 1.to_dim() => {
                std::iter::once(Add(at + from.len()))
                    .chain(Reshape(*at, from.clone(), to[..to.len() - 1].into()).simplify())
                    .collect()
            }
            // Not reducible further (or not a Reshape): keep the canonical op.
            other => tvec!(other.clone()),
        }
    }
147
148    pub fn transform_axis(&self, axis: usize) -> Option<usize> {
149        match self.canonical().as_ref() {
150            Add(ix) => Some(axis + (axis >= *ix) as usize),
151            Rm(ix) => {
152                if axis == *ix {
153                    None
154                } else {
155                    Some(axis - (axis > *ix) as usize)
156                }
157            }
158            Move(from, to) if from < to => {
159                if axis < *from || axis > *to {
160                    Some(axis)
161                } else if axis == *from {
162                    Some(*to)
163                } else {
164                    Some(axis - 1)
165                }
166            }
167            Move(from, to) => {
168                if axis < *to || axis > *from {
169                    Some(axis)
170                } else if axis == *from {
171                    Some(*to)
172                } else {
173                    Some(axis + 1)
174                }
175            }
176            Reshape(at, _, _) if axis < *at => Some(axis),
177            Reshape(at, from, to) if axis >= at + from.len() => Some(axis + to.len() - from.len()),
178            Reshape(_, _, _) => None,
179        }
180    }
181
    /// Tries to swap this op with an axis `change` arriving on its input.
    ///
    /// Both ops are considered in canonical form. On success returns
    /// `Some((replacement, propagated))`:
    ///
    /// * `replacement` is the op we want to be replaced by; if `None`, we are
    ///   now identity;
    /// * `propagated` is the change to propagate; if `None`, the output is
    ///   not changed.
    ///
    /// Returns `None` when the combination is not handled.
    pub fn merge_incoming_change(
        &self,
        change: &AxisOp,
    ) -> Option<(Option<AxisOp>, Option<AxisOp>)> {
        match (self.canonical().as_ref(), change.canonical().as_ref()) {
            // Add/Rm against Add/Rm: each index is shifted by one depending on
            // which side of the other op's axis it sits.
            (Add(op), Add(c)) => {
                Some((Some(Add(op + (c < op) as usize)), Some(Add(c + (c >= op) as usize))))
            }
            (Add(op), Rm(c)) => {
                Some((Some(Add(op - (c < op) as usize)), Some(Rm(c + (c >= op) as usize))))
            }
            (Rm(op), Add(c)) => {
                Some((Some(Rm(op + (c <= op) as usize)), Some(Add(c - (op < c) as usize))))
            }
            (Rm(op), Rm(c)) => {
                Some((Some(Rm(op - (c < op) as usize)), Some(Rm(c - (op <= c) as usize))))
            }

            // Add against Move: only handled when the added axis is entirely
            // on one side of the span covered by the move.
            (Add(x), Move(from, to)) => {
                if x <= from.min(to) {
                    Some((Some(self.clone()), Some(Move(from + 1, to + 1))))
                } else if x > from.max(to) {
                    Some((Some(self.clone()), Some(change.clone())))
                } else {
                    None
                }
            }

            (Move(from, to), Add(x)) => {
                if x <= from.min(to) {
                    Some((Some(Move(from + 1, to + 1)), Some(Add(*x))))
                } else if x > from.max(to) {
                    Some((Some(Move(*from, *to)), Some(Add(*x))))
                } else {
                    None
                }
            }

            // Rm against Move: every relative position is handled; in two of
            // them the move vanishes from the output (second item is None).
            (Rm(x), Move(from, to)) => {
                if x == from {
                    Some((Some(Rm(*to)), None))
                } else if x < from.min(to) {
                    Some((Some(self.clone()), Some(Move(from - 1, to - 1))))
                } else if x > from.max(to) {
                    Some((Some(self.clone()), Some(change.clone())))
                } else if from + 1 == *to && x == to {
                    Some((Some(Rm(*from)), None))
                } else if from < to && x <= to {
                    Some((Some(Rm(x - 1)), Some(Move(*from, *to - 1))))
                } else {
                    Some((Some(Rm(x + 1)), Some(Move(*from - 1, *to))))
                }
            }

            (Move(from, to), Rm(x)) => {
                if x < from.min(to) {
                    Some((Some(Move(from - 1, to - 1)), Some(Rm(*x))))
                } else if x > from.max(to) {
                    Some((Some(Move(*from, *to)), Some(Rm(*x))))
                } else {
                    None
                }
            }

            // Add/Rm against Reshape (either order): only handled when the
            // single axis is before `at` or strictly past `at + from.len()`.
            (Add(op), Reshape(at, from, to)) => {
                if op <= at {
                    Some((Some(Add(*op)), Some(Reshape(at + 1, from.clone(), to.clone()))))
                } else if *op > at + from.len() {
                    Some((
                        Some(Add(*op + to.len() - from.len())),
                        Some(Reshape(*at, from.clone(), to.clone())),
                    ))
                } else {
                    None
                }
            }
            (Rm(op), Reshape(at, from, to)) => {
                if op < at {
                    Some((Some(Rm(*op)), Some(Reshape(at - 1, from.clone(), to.clone()))))
                } else if *op > at + from.len() {
                    Some((
                        Some(Rm(*op + to.len() - from.len())),
                        Some(Reshape(*at, from.clone(), to.clone())),
                    ))
                } else {
                    None
                }
            }
            (Reshape(at, from, to), Add(change)) => {
                if change < at {
                    Some((Some(Reshape(at + 1, from.clone(), to.clone())), Some(Add(*change))))
                } else if *change > *at + from.len() {
                    Some((
                        Some(Reshape(*at, from.clone(), to.clone())),
                        Some(Add(change + to.len() - from.len())),
                    ))
                } else {
                    None
                }
            }
            (Reshape(at, from, to), Rm(change)) => {
                if change < at {
                    Some((Some(Reshape(at - 1, from.clone(), to.clone())), Some(Rm(*change))))
                } else if *change > *at + from.len() {
                    Some((
                        Some(Reshape(*at, from.clone(), to.clone())),
                        Some(Rm(change + to.len() - from.len())),
                    ))
                } else {
                    None
                }
            }
            (Reshape(_, _, _), Move(_, _)) => None, // todo, some are manageable
            (Move(_, _), Reshape(_, _, _)) => None, // todo, some are manageable
            (Reshape(_, _, _), Reshape(_, _, _)) => None, // todo, some are manageable
            // Remaining combination (Move against Move) is not handled.
            _ => None,
        }
    }
304
    /// Applies this op (in canonical form) to `shape`, in place.
    ///
    /// * `Add` inserts a dimension of one, `Rm` removes the dimension
    ///   (without checking its value), `Move` relocates it.
    /// * `Reshape` first requires the products of its `from` and `to` dims to
    ///   be equal, then replaces the `from` dims found at `at` by the `to`
    ///   dims. If the dims found in `shape` do not match `from` but
    ///   `broadcasting` is set and they are all one, they are replaced by
    ///   `to.len()` dims of one instead.
    ///
    /// # Errors
    ///
    /// Fails when an axis index is out of range for `shape`, when the
    /// reshape volumes differ, when a `to` dim does not convert to `D`, or
    /// when `shape` is incompatible with the reshape.
    pub fn change_shape_array<D: DimLike>(
        &self,
        shape: &mut TVec<D>,
        broadcasting: bool,
    ) -> TractResult<()> {
        match self.canonical().as_ref() {
            Add(ix) => {
                ensure!(*ix <= shape.len());
                shape.insert(*ix, D::one());
            }
            Rm(ix) => {
                ensure!(*ix < shape.len());
                shape.remove(*ix);
            }
            Move(from, to) => {
                ensure!(*from < shape.len());
                ensure!(*to < shape.len());
                let axis = shape.remove(*from);
                shape.insert(*to, axis);
            }
            Reshape(at, from, to) => {
                let from_volume = from.iter().product::<TDim>();
                let to_volume = to.iter().product::<TDim>();
                // Two algebraically equal volumes can land in different
                // factored forms when the same dimension is built two ways
                // (e.g. (B+2BY)·(1+Y) vs B·(1+Y)·(1+2Y) on Conformer-style
                // streaming attention).  Compare polynomial expansions so
                // structural mismatch on factor ordering doesn't fail the
                // check.
                ensure!(
                    from_volume.clone().expand_polynomial()
                        == to_volume.clone().expand_polynomial(),
                    "{from_volume} should be equal to {to_volume}"
                );
                ensure!(*at + from.len() <= shape.len());
                // Regular case: the dims present in `shape` are exactly `from`.
                if shape.len() >= from.len() + *at
                    && tract_itertools::izip!(shape.iter().skip(*at), from)
                        .all(|(shape, spec)| shape.to_dim() == *spec)
                {
                    for _ in from {
                        shape.remove(*at);
                    }
                    // Insert in reverse so that `to` ends up in order at `at`.
                    for d in to.iter().rev() {
                        shape.insert(*at, d.try_into()?);
                    }
                } else if broadcasting
                    && shape.iter().skip(*at).take(from.len()).all(|d| d.to_dim() == 1.to_dim())
                {
                    // Broadcasting case: only the number of unit dims changes.
                    for _ in from {
                        shape.remove(*at);
                    }
                    for _ in to.iter().rev() {
                        shape.insert(*at, 1.into());
                    }
                } else {
                    bail!("Incompatible reshape for shape {:?} and {:?}", shape, self);
                }
            }
        }
        Ok(())
    }
366
367    pub fn change_shape(&self, shape: &mut ShapeFact, broadcasting: bool) -> TractResult<()> {
368        match self.canonical().as_ref() {
369            Add(ix) => shape.insert_axis(*ix),
370            Rm(ix) => {
371                if shape.rank() <= *ix {
372                    bail!("Attempt to remove axis #{} on shape {:?}", ix, shape);
373                }
374                if shape[*ix] != 1.to_dim() {
375                    bail!("Removing non-trivial axis #{} of dim: {:?}", ix, shape);
376                }
377                shape.remove_axis(*ix)
378            }
379            _ => {
380                let mut array = shape.to_tvec();
381                self.change_shape_array(&mut array, broadcasting)?;
382                let mut new_shape = ShapeFact::from_dims(array);
383                std::mem::swap(shape, &mut new_shape);
384                Ok(())
385            }
386        }
387    }
388
    /// Applies this op to `tensor`, in place.
    ///
    /// A tensor backed by `BlockQuantStorage` is rebuilt from a clone of its
    /// storage with the transformed shape (computed without broadcasting).
    /// For any other tensor the op is applied in canonical form; a `Reshape`
    /// whose new shape cannot be set falls back, when `broadcasting` is
    /// allowed and the affected dims are all one, to adding or removing unit
    /// axes at `at`.
    ///
    /// # Errors
    ///
    /// Fails when the tensor rank is below [`AxisOp::required_rank`], when
    /// the underlying shape or tensor operation fails, or when the reshape
    /// cannot be applied.
    pub fn change_tensor(&self, tensor: &mut Tensor, broadcasting: bool) -> TractResult<()> {
        // Block-quantized storage: only the shape attached to the storage changes.
        if tensor.storage_as::<BlockQuantStorage>().is_some() {
            let bqs = tensor.try_storage_as::<BlockQuantStorage>()?.clone();
            let mut new_shape: TVec<usize> = tensor.shape().into();
            self.change_shape_array(&mut new_shape, false)?;
            let mut new_tensor = bqs.into_tensor_with_shape(tensor.datum_type(), &new_shape);
            std::mem::swap(tensor, &mut new_tensor);
            return Ok(());
        }
        ensure!(self.required_rank() <= tensor.rank());
        match self.canonical().as_ref() {
            Add(ix) => tensor.insert_axis(*ix),
            Rm(ix) => tensor.remove_axis(*ix),
            Move(from, to) => {
                // move_axis consumes its receiver, so work on a clone and swap it in.
                let mut tmp = tensor.clone().move_axis(*from, *to)?;
                std::mem::swap(tensor, &mut tmp);
                Ok(())
            }
            Reshape(at, from, to) => {
                let mut shape: TVec<usize> = tensor.shape().into();
                // Shape computation always tolerates broadcasting here; the
                // caller's `broadcasting` flag only gates the fallback below.
                self.change_shape_array(&mut shape, true)?;
                if tensor.set_shape(&shape).is_ok() {
                    Ok(())
                } else if broadcasting
                    && tensor.shape().iter().skip(*at).take(from.len()).all(|d| *d == 1)
                {
                    // All affected dims are one: adjust the rank with unit axes.
                    if from.len() > to.len() {
                        for _ in to.len()..from.len() {
                            tensor.remove_axis(*at)?;
                        }
                    }
                    if to.len() > from.len() {
                        for _ in from.len()..to.len() {
                            tensor.insert_axis(*at)?;
                        }
                    }
                    Ok(())
                } else {
                    bail!(
                        "Invalid reshaping: {:?} on tensor {:?} (broadcasting allowed: {:?})",
                        self,
                        tensor,
                        broadcasting
                    )
                }
            }
        }
    }
437
438    pub fn change_view<D>(&self, view: &mut ArrayViewD<D>) -> TractResult<()> {
439        use tract_ndarray::Axis;
440        match *self {
441            AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
442            AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
443            AxisOp::Move(from, to) if from < to => {
444                for left in from..to {
445                    view.swap_axes(left, left + 1);
446                }
447            }
448            AxisOp::Move(from, to) => {
449                for left in (to..from).rev() {
450                    view.swap_axes(left, left + 1);
451                }
452            }
453            AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
454        }
455        Ok(())
456    }
457
458    pub fn change_view_mut<D>(&self, view: &mut ArrayViewMutD<D>) -> TractResult<()> {
459        use tract_ndarray::Axis;
460        match *self {
461            AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
462            AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
463            AxisOp::Move(from, to) if from < to => {
464                for left in from..to {
465                    view.swap_axes(left, left + 1);
466                }
467            }
468            AxisOp::Move(from, to) => {
469                for left in (to..from).rev() {
470                    view.swap_axes(left, left + 1);
471                }
472            }
473            AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
474        }
475        Ok(())
476    }
477
478    pub fn recip(&self) -> AxisOp {
479        match self.canonical().as_ref() {
480            Add(ix) => Rm(*ix),
481            Rm(ix) => Add(*ix),
482            Move(from, to) if from == to => self.clone(),
483            Move(from, to) if *from + 1 == *to => self.clone(),
484            Move(from, to) if *from == *to + 1 => {
485                unreachable!();
486            }
487            Move(from, to) => Move(*to, *from),
488            Reshape(at, from, to) => Reshape(*at, to.clone(), from.clone()),
489        }
490    }
491
492    pub fn is_noop(&self) -> bool {
493        match self {
494            Move(f, t) if f == t => true,
495            Reshape(_, f, t) if f == t => true,
496            _ => false,
497        }
498    }
499
500    pub fn only_shape(&self) -> bool {
501        if self.is_noop() {
502            return true;
503        }
504        !matches!(self, Move(_, _))
505    }
506
507    pub fn wire_split_axis(
508        model: &mut TypedModel,
509        name: impl ToString,
510        outlet: OutletId,
511        axis: usize,
512        outer_dim: usize,
513    ) -> TractResult<TVec<OutletId>> {
514        let fact = model.outlet_fact(outlet)?;
515        let dim: TDim = fact.shape[axis].clone();
516        let inner_dim = dim.clone() / outer_dim;
517        let op = Self::Reshape(axis, tvec!(dim.clone()), tvec!(outer_dim.to_dim(), inner_dim));
518        model.wire_node(name.to_string(), op, &[outlet])
519    }
520
521    pub fn wire_collapse_axis(
522        model: &mut TypedModel,
523        name: impl ToString,
524        outlet: OutletId,
525        axis: usize,
526    ) -> TractResult<TVec<OutletId>> {
527        let fact = model.outlet_fact(outlet)?;
528        let dim: TDim = fact.shape[axis].clone();
529        let next_dim: TDim = fact.shape[axis + 1].clone();
530        let op = Self::Reshape(axis, tvec!(dim.clone(), next_dim.clone()), tvec!(dim * next_dim));
531        model.wire_node(name.to_string(), op, &[outlet])
532    }
533
534    #[inline]
535    pub fn required_rank(&self) -> usize {
536        match self {
537            Rm(r) => r + 1,
538            Add(a) => *a,
539            Reshape(at, from, _to) => at + from.len(),
540            Move(from, to) => *from.max(to),
541        }
542    }
543
544    pub fn trim_left(&self, prefix: usize) -> TractResult<AxisOp> {
545        Ok(match self {
546            Rm(r) if *r >= prefix => Rm(r - prefix),
547            Add(a) if *a >= prefix => Add(a - prefix),
548            Reshape(at, from, to) if *at >= prefix => {
549                Reshape(at - prefix, from.clone(), to.clone())
550            }
551            Move(from, to) if *from >= prefix && *to >= prefix => Move(from - prefix, to - prefix),
552            _ => bail!("Can no trim left {self:?} by {prefix}"),
553        })
554    }
555}
556
557pub fn wire_rank_broadcast(
558    prefix: impl AsRef<str>,
559    target: &mut TypedModel,
560    inputs: &[OutletId],
561) -> TractResult<TVec<OutletId>> {
562    let facts =
563        inputs.iter().map(|o| target.outlet_fact(*o).cloned()).collect::<TractResult<TVec<_>>>()?;
564    let max_rank = facts.iter().map(|f| f.rank()).max().unwrap();
565    let mut wires = tvec!();
566    for i in 0..inputs.len() {
567        let mut wire = inputs[i];
568        for _ in facts[i].rank()..max_rank {
569            let name = target.unique_name(prefix.as_ref().to_string() + ".fix-rank");
570            wire = target.wire_node(name, AxisOp::Add(0), &[wire])?[0];
571        }
572        wires.push(wire);
573    }
574    Ok(wires)
575}
576
577pub fn wire_with_rank_broadcast(
578    prefix: impl AsRef<str>,
579    target: &mut TypedModel,
580    op: impl Into<Box<dyn TypedOp>>,
581    inputs: &[OutletId],
582) -> TractResult<TVec<OutletId>> {
583    let prefix = prefix.as_ref();
584    let wires = wire_rank_broadcast(prefix, target, inputs)?;
585    target.wire_node(prefix, op.into(), &wires)
586}
587
/// An axis operation paired with the outlet it refers to.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AxisChange {
    /// The outlet concerned by the change.
    pub outlet: OutletId,
    /// The axis operation itself.
    pub op: AxisOp,
}
593
/// What an axis change entails for a node: an optional replacement operator
/// and a list of axis operations attached to the node's input/output slots.
#[derive(Clone, Default, Debug)]
pub struct AxisChangeConsequence {
    /// Operator to substitute for the node's current one, if any.
    pub substitute_op: Option<Box<dyn TypedOp>>,
    /// Axis operation associated with each listed input or output slot.
    pub wire_changes: TVec<(InOut, AxisOp)>,
}
599
600impl AxisChangeConsequence {
601    pub fn new(
602        _model: &TypedModel,
603        node: &TypedNode,
604        op: Option<Box<dyn TypedOp>>,
605        axis_op: &AxisOp,
606    ) -> AxisChangeConsequence {
607        let mut wire_changes = tvec!();
608        for i in 0..node.inputs.len() {
609            wire_changes.push((InOut::In(i), axis_op.clone()));
610        }
611        for i in 0..node.outputs.len() {
612            wire_changes.push((InOut::Out(i), axis_op.clone()));
613        }
614        AxisChangeConsequence { wire_changes, substitute_op: op }
615    }
616}
617
618impl Op for AxisOp {
619    fn name(&self) -> StaticName {
620        match self {
621            Add(_) => "AddAxis".into(),
622            Rm(_) => "RmAxis".into(),
623            Move(_, _) => "MoveAxis".into(),
624            Reshape(_, _, _) => "Reshape".into(),
625        }
626    }
627
628    fn info(&self) -> TractResult<Vec<String>> {
629        match self {
630            Add(axis) | Rm(axis) => Ok(vec![format!("Axis: {axis}")]),
631            Move(from, to) => Ok(vec![format!("Axis {from} to {to}")]),
632            Reshape(at, from, to) => Ok(vec![format!(
633                "Axes starting at {}: {:?} to {:?}",
634                at,
635                from.iter().join(","),
636                to.iter().join(",")
637            )]),
638        }
639    }
640
641    op_as_typed_op!();
642}
643
644impl EvalOp for AxisOp {
645    fn is_stateless(&self) -> bool {
646        true
647    }
648
649    fn eval_with_session(
650        &self,
651        _node_id: usize,
652        session: &TurnState,
653        inputs: TVec<TValue>,
654    ) -> TractResult<TVec<TValue>> {
655        let mut input = args_1!(inputs).into_tensor();
656        match self {
657            AxisOp::Reshape(skip, from, to) => {
658                let from = from.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
659                let to = to.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
660                AxisOp::Reshape(*skip, from, to).change_tensor(&mut input, false)?
661            }
662            _ => self.change_tensor(&mut input, false)?,
663        }
664        Ok(tvec!(input.into_tvalue()))
665    }
666}
667
668/// Remap coordinate symbols in a TDim expression according to an AxisOp.
669/// Returns None if the remapping cannot be determined (e.g. general reshape
670/// with both ends > 1).
671fn remap_uniform_tdim(expr: &TDim, axis_op: &AxisOp) -> Option<TDim> {
672    let syms = expr.symbols();
673    let coord_syms: Vec<(usize, Symbol)> = syms
674        .into_iter()
675        .filter_map(|s| {
676            let name = format!("{s}");
677            name.strip_prefix("🎯").and_then(|rest| rest.parse::<usize>().ok()).map(|k| (k, s))
678        })
679        .collect();
680
681    if coord_syms.is_empty() {
682        // No coordinate symbols – the value is uniform across all positions; propagate as-is.
683        return Some(expr.clone());
684    }
685
686    if let AxisOp::Reshape(at, from_dims, to_dims) = axis_op.canonical().as_ref() {
687        // Trivial all-ones case: shape change is purely cosmetic, value is unaffected.
688        if from_dims.iter().all(|d| d.is_one()) && to_dims.iter().all(|d| d.is_one()) {
689            return Some(expr.clone());
690        }
691        // Pure split: from = [D], to = [d_0, …, d_{k-1}], Π = D.  The input
692        // axis-`at` position decomposes as
693        //     pos[at] = Σ_i pos[at+i]_new · stride_i
694        // with `stride_i = Π_{j>i} to_dims[j]` (last stride is 1).  Other
695        // input axes shift right by `k-1` (the net rank change).
696        if from_dims.len() == 1 {
697            let from_dim = from_dims[0].clone();
698            let to_product: TDim = to_dims.iter().fold(TDim::Val(1), |acc, d| acc * d.clone());
699            if to_product == from_dim {
700                let k_to = to_dims.len();
701                let mut map: HashMap<Symbol, TDim> = HashMap::default();
702                for (k, sym) in &coord_syms {
703                    let scope = sym.scope()?;
704                    let new_expr = if *k < *at {
705                        TDim::Sym(sym.clone())
706                    } else if *k == *at {
707                        let mut sum = TDim::Val(0);
708                        let mut stride = TDim::Val(1);
709                        for i in (0..k_to).rev() {
710                            let new_sym = scope.coord_sym(*at + i);
711                            sum = sum + TDim::Sym(new_sym) * stride.clone();
712                            stride = stride * to_dims[i].clone();
713                        }
714                        sum
715                    } else {
716                        TDim::Sym(scope.coord_sym(*k + k_to - 1))
717                    };
718                    map.insert(sym.clone(), new_expr);
719                }
720                return expr.substitute_all(&map).ok().map(|e| e.reduce());
721            }
722        }
723        // Pure merge: from = [d_0, …, d_{k-1}], to = [D].  We can express
724        // `pos[at+i]_old` from `pos[at]_new` only via integer division and
725        // modulo, which TDim doesn't carry.  Special-case the easy form
726        // where all but one of the merged dims is 1 — then the lone
727        // non-trivial sub-axis just maps to the new merged axis.
728        if to_dims.len() == 1 {
729            let to_dim = to_dims[0].clone();
730            let from_product: TDim = from_dims.iter().fold(TDim::Val(1), |acc, d| acc * d.clone());
731            if from_product == to_dim {
732                let k_from = from_dims.len();
733                let mut map: HashMap<Symbol, TDim> = HashMap::default();
734                for (k, sym) in &coord_syms {
735                    let scope = sym.scope()?;
736                    let new_expr = if *k < *at {
737                        TDim::Sym(sym.clone())
738                    } else if *k < *at + k_from {
739                        let i = *k - *at;
740                        let only_nontrivial =
741                            from_dims.iter().enumerate().all(|(j, d)| j == i || d.is_one());
742                        if only_nontrivial {
743                            TDim::Sym(scope.coord_sym(*at))
744                        } else {
745                            return None;
746                        }
747                    } else {
748                        TDim::Sym(scope.coord_sym(*k - (k_from - 1)))
749                    };
750                    map.insert(sym.clone(), new_expr);
751                }
752                return expr.substitute_all(&map).ok().map(|e| e.reduce());
753            }
754        }
755        return None;
756    }
757
758    // For Add/Rm/Move: use transform_axis and substitute all at once to avoid
759    // double-substitution when two axes swap positions (e.g. Move).
760    let map: HashMap<Symbol, TDim> = coord_syms
761        .into_iter()
762        .filter_map(|(k, sym)| {
763            let new_k = axis_op.transform_axis(k)?;
764            if new_k == k {
765                return None;
766            }
767            let scope = sym.scope()?;
768            Some((sym, TDim::Sym(scope.coord_sym(new_k))))
769        })
770        .collect();
771    if map.is_empty() {
772        return Some(expr.clone());
773    }
774    expr.substitute_all(&map).ok()
775}
776
777impl TypedOp for AxisOp {
778    as_op!();
779
    /// Computes the fact of the single output from the fact of the first input.
    ///
    /// When the input carries a `BlockQuantFact`, the op is applied to the
    /// block-quant shape, a new `BlockQuantFact` with the same format is
    /// attached to the output, and a constant input value, if any, is
    /// transformed along. Otherwise the op is applied to the input shape
    /// (without broadcasting), the exotic fact is copied over and the
    /// `uniform_tdim`, if any, is remapped through `remap_uniform_tdim`.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        if let Some(bqf) =
            inputs[0].exotic_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
        {
            let mut new_shape: TVec<usize> = bqf.shape().into();
            self.change_shape_array(&mut new_shape, false)?;
            let new_bqf = BlockQuantFact::new(bqf.format.clone(), new_shape.clone());
            let shape: TVec<TDim> = new_shape.iter().map(|d| d.to_dim()).collect();
            let mut new_fact = inputs[0].datum_type.fact(&*shape).with_exotic_fact(new_bqf);
            // Keep a constant input constant: apply the same op to its value.
            if let Some(k) = &inputs[0].konst {
                let mut new = k.clone().into_tensor();
                self.change_tensor(&mut new, false)?;
                new_fact.konst = Some(new.into());
            }
            return Ok(tvec!(new_fact));
        }
        let mut shape = inputs[0].shape.clone();
        self.change_shape(&mut shape, false)?;
        let mut fact = inputs[0].datum_type.fact(shape);
        fact.exotic_fact.clone_from(&inputs[0].exotic_fact);
        // The remapping may fail, in which case the output has no uniform_tdim.
        if let Some(tdim) = &inputs[0].uniform_tdim {
            fact.uniform_tdim = remap_uniform_tdim(tdim, self);
        }
        Ok(tvec!(fact))
    }
805
806    fn input_roi(
807        &self,
808        model: &TypedModel,
809        node: &TypedNode,
810    ) -> TractResult<Option<TVec<Option<TDim>>>> {
811        crate::optim::propagate_roi::bubble_roi(model, node)
812    }
813
814    fn axes_mapping(
815        &self,
816        inputs: &[&TypedFact],
817        outputs: &[&TypedFact],
818    ) -> TractResult<AxesMapping> {
819        let mut axes: Vec<Axis> = (0..inputs[0].rank())
820            .zip('a'..)
821            .map(|(axis_id, repr)| {
822                let mut axis = Axis::new(repr, inputs.len(), outputs.len()).input(0, axis_id);
823                if let Some(out) = self.transform_axis(axis_id) {
824                    axis = axis.output(0, out);
825                }
826                axis
827            })
828            .collect();
829        for (axis, letter) in (0..outputs[0].rank()).zip('A'..) {
830            if self.recip().transform_axis(axis).is_none() {
831                axes.push(Axis::new(letter, inputs.len(), outputs.len()).output(0, axis));
832            }
833        }
834        AxesMapping::new(inputs.len(), outputs.len(), axes)
835    }
836
837    fn declutter(
838        &self,
839        model: &TypedModel,
840        node: &TypedNode,
841    ) -> TractResult<Option<TypedModelPatch>> {
842        if self.is_noop()
843            && let Some(p) = TypedModelPatch::shunt_one_op(model, node)?
844        {
845            return Ok(Some(p));
846        }
847        let simplified = self.simplify();
848        if simplified.len() != 1 || &simplified[0] != self {
849            let mut patch = TypedModelPatch::default();
850            let mut wire = patch.tap_model(model, node.inputs[0])?;
851            for (ix, op) in simplified.into_iter().enumerate() {
852                wire = patch.wire_node(format!("{}.{}", node.name, ix), op, &[wire])?[0];
853            }
854            patch.shunt_outside(model, node.id.into(), wire)?;
855            Ok(Some(patch))
856        } else {
857            Ok(None)
858        }
859    }
860
861    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
862        Ok(tvec!((InOut::Out(0), self.recip()), (InOut::In(0), self.clone())))
863    }
864
865    fn change_axes(
866        &self,
867        _model: &TypedModel,
868        _node: &TypedNode,
869        io: InOut,
870        change: &AxisOp,
871    ) -> TractResult<Option<AxisChangeConsequence>> {
872        let op = if let InOut::Out(0) = io {
873            rule_if_some!(more = self.recip().change_axes(_model, _node, InOut::In(0), change)?);
874            AxisChangeConsequence {
875                substitute_op: more.substitute_op.map(|op| {
876                    if let Some(op) = op.as_op().downcast_ref::<AxisOp>() {
877                        Box::new(op.recip())
878                    } else {
879                        op // have to be identity
880                    }
881                }),
882                wire_changes: more
883                    .wire_changes
884                    .into_iter()
885                    .map(|wc| {
886                        (if wc.0 == InOut::In(0) { InOut::Out(0) } else { InOut::In(0) }, wc.1)
887                    })
888                    .collect(),
889            }
890        } else if change == self {
891            AxisChangeConsequence { substitute_op: Some(Box::new(Identity)), wire_changes: tvec!() }
892        } else {
893            rule_if_some!((new_op, new_change) = self.merge_incoming_change(change));
894            trace!("  Change:{change:?} self:{self:?} -> change:{new_change:?} op:{new_op:?}");
895            let substitute_op: Box<dyn TypedOp> =
896                if let Some(o) = new_op { Box::new(o) as _ } else { Box::new(Identity) };
897            let mut wire_changes = tvec!();
898            if !change.is_noop() {
899                wire_changes.push((InOut::In(0), change.clone()))
900            }
901            if let Some(new_change) = new_change {
902                wire_changes.push((InOut::Out(0), new_change))
903            }
904            AxisChangeConsequence { substitute_op: Some(substitute_op), wire_changes }
905        };
906        Ok(Some(op))
907    }
908
909    fn set_symbols(
910        &self,
911        _source: &TypedModel,
912        node: &TypedNode,
913        target: &mut TypedModel,
914        mapping: &HashMap<OutletId, OutletId>,
915        subs: &HashMap<Symbol, TDim>,
916    ) -> TractResult<TVec<OutletId>> {
917        let op = if let AxisOp::Reshape(axis, from, to) = self {
918            AxisOp::Reshape(
919                *axis,
920                from.iter().map(|d| d.substitute_all(subs)).collect::<TractResult<_>>()?,
921                to.iter().map(|d| d.substitute_all(subs)).collect::<TractResult<_>>()?,
922            )
923        } else {
924            self.clone()
925        };
926        target.wire_node(&node.name, op, &[mapping[&node.inputs[0]]])
927    }
928
929    fn slice(
930        &self,
931        patch: &mut TypedModelPatch,
932        _model: &TypedModel,
933        node: &TypedNode,
934        _prefix: &str,
935        inputs: &[OutletId],
936        output_axis: usize,
937        _start: &TDim,
938        _end: &TDim,
939    ) -> TractResult<Option<TVec<OutletId>>> {
940        // is this test really useful ? or axis mapping preempt this ?
941        if let Reshape(pos, _from, to) = self
942            && output_axis >= *pos
943            && output_axis < pos + to.len()
944        {
945            return Ok(None);
946        }
947        patch.wire_node(&node.name, &node.op, inputs).map(Some)
948    }
949
950    fn codegen(
951        &self,
952        model: &TypedModel,
953        node: &TypedNode,
954    ) -> TractResult<Option<TypedModelPatch>> {
955        rule_if!(node.outputs[0].fact.exotic_fact.is_none());
956        if let Some(shape) = node.outputs[0].fact.shape.as_concrete()
957            && !matches!(self, AxisOp::Move(_, _))
958        {
959            let (inputs, outputs) = model.node_facts(node.id)?;
960            let mapping = self.axes_mapping(&inputs, &outputs)?;
961            let op = IntoShape {
962                mapping,
963                len: shape.iter().product(),
964                strides: Tensor::natural_strides(shape),
965                dims: shape.into(),
966            };
967            return Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, op)?));
968        }
969        Ok(None)
970    }
971}
972
973// a, b, c is a <- b, b <- c, c <- a
974fn perm_to_cycles(perm: &[usize]) -> TVec<TVec<usize>> {
975    let mut cycles: TVec<TVec<usize>> = tvec!();
976    let mut done = 0;
977    while done < perm.len() {
978        if perm[done] == done || cycles.iter().any(|c| c.contains(&done)) {
979            done += 1;
980            continue;
981        }
982        let mut cycle = tvec!();
983        let mut current = done;
984        loop {
985            cycle.push(current);
986            current = perm[current];
987            if current == done {
988                break;
989            }
990        }
991        cycles.push(cycle)
992    }
993    cycles
994}
995
/// Tells whether `cycle` only involves consecutive indices, and if so returns the pair
/// `(from, to)` describing it.
///
/// Two shapes are recognized:
/// * ascending `[a, a+1, ..., b]` gives `Some((a, b))`;
/// * `[a, b, b-1, ..., a+1]` gives `Some((b, a))`.
///
/// Anything else gives `None`, including an empty slice and cycles where stepping down would
/// go below zero (these used to panic).
fn is_rotation_cycle(cycle: &[usize]) -> Option<(usize, usize)> {
    let (&first, rest) = cycle.split_first()?;
    let &last = cycle.last()?;
    if cycle.windows(2).all(|w| w[0].checked_add(1) == Some(w[1])) {
        return Some((first, last));
    }
    // At least one window failed above, so `cycle` has two elements or more and `rest` is
    // not empty.
    let descending_tail = rest.windows(2).all(|w| w[0].checked_sub(1) == Some(w[1]));
    if descending_tail && last.checked_sub(1) == Some(first) {
        Some((rest[0], first))
    } else {
        None
    }
}
1007
1008fn perm_to_atoms(input: &[usize]) -> TVec<(usize, usize)> {
1009    let mut changes: TVec<(usize, usize)> = tvec!();
1010    'top: loop {
1011        let mut reached: TVec<usize> = (0..input.len()).collect();
1012        changes.iter().for_each(|(f, t)| {
1013            let axis = reached.remove(*f);
1014            reached.insert(*t, axis);
1015        });
1016        if &*reached == input {
1017            return changes;
1018        }
1019        let remaining: TVec<usize> =
1020            input.iter().map(|x| reached.iter().position(|y| y == x).unwrap()).collect();
1021        let cycles = perm_to_cycles(&remaining);
1022        for cycle in &cycles {
1023            if let Some(rot) = is_rotation_cycle(cycle) {
1024                changes.push(rot);
1025                continue 'top;
1026            }
1027        }
1028        changes.push((cycles[0][1], cycles[0][0]));
1029    }
1030}
1031
1032pub fn perm_to_ops(input: &[usize]) -> TVec<AxisOp> {
1033    perm_to_atoms(input).into_iter().map(|pair| AxisOp::Move(pair.0, pair.1)).collect()
1034}
1035
1036pub fn compute_shape_with_tf_rules(input: &[TDim], shape_spec: &[TDim]) -> TractResult<TVec<TDim>> {
1037    let mut shape: TVec<TDim> = shape_spec.into();
1038    // Replace 0s with corresponding input dims (positional, per ONNX/TF spec)
1039    for (i, s) in shape.iter_mut().enumerate() {
1040        if *s == 0.into() {
1041            *s = input
1042                .get(i)
1043                .with_context(|| {
1044                    format!("Reshape: 0 at position {i} but input only has {} dims", input.len())
1045                })?
1046                .clone();
1047        }
1048    }
1049    let input_vol: TDim = input.iter().product();
1050    if let Some(pos) = shape.iter().position(|d| *d == (-1).into()) {
1051        let shape_vol: TDim = shape.iter().filter(|d| **d != (-1).into()).product();
1052        let div = input_vol.maybe_div(&shape_vol)?;
1053        if div.1 != 1 {
1054            bail!("invalid")
1055        }
1056        shape[pos] = div.0;
1057    } else {
1058        let shape_vol: TDim = shape.iter().product();
1059        if input_vol != shape_vol {
1060            bail!(
1061                "Reshape volume mismatch: input {input:?} (vol={input_vol}) vs shape {shape:?} (vol={shape_vol})"
1062            );
1063        }
1064    }
1065    Ok(shape)
1066}
1067
1068pub fn to_axis_ops_with_tf_rules(
1069    input_orig: &[TDim],
1070    output_spec: &[TDim],
1071) -> TractResult<TVec<AxisOp>> {
1072    let final_output = compute_shape_with_tf_rules(input_orig, output_spec)?;
1073    let mut stack: TVec<AxisOp> = tvec!();
1074    'top: loop {
1075        let current_input =
1076            stack.iter().try_fold(TVec::from(input_orig), |mut shape, op| -> TractResult<_> {
1077                op.change_shape_array(&mut shape, false)?;
1078                Ok(shape)
1079            })?;
1080        if current_input == final_output {
1081            return Ok(stack);
1082        }
1083        if let Some(common) =
1084            current_input.iter().zip(final_output.iter()).position(|(a, b)| a != b)
1085        {
1086            if current_input[common].is_one() {
1087                stack.push(AxisOp::Rm(common));
1088            } else if final_output[common].is_one() {
1089                stack.push(AxisOp::Add(common));
1090            } else {
1091                // actual regrouping. search for a match. this is quadratic, but
1092                // rank is expected to be somewhat reasonable
1093                for i in common..current_input.len() {
1094                    let i_group = &current_input[common..i + 1];
1095                    let i_volume: TDim = i_group.iter().product();
1096                    for o in common..final_output.len() {
1097                        let o_group = &final_output[common..o + 1];
1098                        let o_volume: TDim = o_group.iter().product();
1099                        if i_volume == o_volume {
1100                            stack.push(AxisOp::Reshape(common, i_group.into(), o_group.into()));
1101                            continue 'top;
1102                        }
1103                    }
1104                }
1105                bail!(
1106                    "Could not find matching reshape grouping: current_input={current_input:?} final_output={final_output:?} common={common}"
1107                )
1108            }
1109        } else if final_output.len() > current_input.len() {
1110            stack.push(AxisOp::Add(current_input.len()));
1111        } else {
1112            stack.push(AxisOp::Rm(current_input.len() - 1));
1113        }
1114    }
1115}
1116
1117#[derive(Clone, Debug, PartialEq, Eq, Hash)]
1118pub struct IntoShape {
1119    pub mapping: AxesMapping,
1120    pub len: usize,
1121    pub dims: TVec<usize>,
1122    pub strides: TVec<isize>,
1123}
1124
1125impl Op for IntoShape {
1126    fn name(&self) -> StaticName {
1127        "IntoShape".into()
1128    }
1129
1130    fn info(&self) -> TractResult<Vec<String>> {
1131        Ok(vec![format!("{}", self.mapping)])
1132    }
1133
1134    op_as_typed_op!();
1135}
1136
1137impl EvalOp for IntoShape {
1138    fn is_stateless(&self) -> bool {
1139        true
1140    }
1141
1142    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
1143        let mut input = args_1!(inputs).into_tensor();
1144        ensure!(input.len() == self.len);
1145        unsafe { input.set_geometry_unchecked(&self.dims, &self.strides) };
1146        Ok(tvec!(input.into_tvalue()))
1147    }
1148}
1149
1150impl TypedOp for IntoShape {
1151    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
1152        let mut fact = inputs[0].datum_type.fact(&self.dims);
1153        if let Some(of) = &inputs[0].exotic_fact {
1154            fact = fact.with_exotic_fact(of.clone());
1155        }
1156        Ok(tvec!(fact))
1157    }
1158
1159    fn declutter(
1160        &self,
1161        model: &TypedModel,
1162        node: &TypedNode,
1163    ) -> TractResult<Option<TypedModelPatch>> {
1164        let input = model.outlet_fact(node.inputs[0])?;
1165        if input.shape.as_concrete().is_some_and(|shape| shape == &*self.dims) {
1166            return TypedModelPatch::shunt_one_op(model, node);
1167        }
1168        if let Some(succ) = model.single_succ(node.id)?
1169            && let Some(into_shape) = succ.op_as::<IntoShape>()
1170        {
1171            let op =
1172                Self { mapping: self.mapping.compose(&into_shape.mapping)?, ..into_shape.clone() };
1173            return Ok(Some(TypedModelPatch::fuse_with_next(model, node, op)?));
1174        }
1175        Ok(None)
1176    }
1177
1178    as_op!();
1179}
1180
#[cfg(test)]
mod test {
    use super::*;

    /// Asserts that pushing `change` on the input of `op` gives `new_op` as the replacement
    /// op and `new_change` as the change to propagate on the output.
    #[track_caller]
    fn assert_merge(change: AxisOp, op: AxisOp, new_op: AxisOp, new_change: AxisOp) {
        assert_eq!(op.merge_incoming_change(&change), Some((Some(new_op), Some(new_change))));
    }

    #[test]
    fn test_perm_to_cycles() {
        let cases: [(&[usize], TVec<TVec<usize>>); 5] = [
            (&[1, 2, 0], tvec!(tvec!(0, 1, 2))),
            (&[2, 0, 1], tvec!(tvec!(0, 2, 1))),
            (&[1, 2, 3, 0], tvec!(tvec!(0, 1, 2, 3))),
            (&[3, 0, 1, 2], tvec!(tvec!(0, 3, 2, 1))),
            (&[3, 1, 2, 0, 4], tvec!(tvec!(0, 3))),
        ];
        for (perm, expected) in cases {
            assert_eq!(perm_to_cycles(perm), expected, "perm {perm:?}");
        }
    }

    #[test]
    fn is_rotation() {
        let ascending = is_rotation_cycle(&[0, 1, 2]);
        assert_eq!(ascending, Some((0, 2)));
        let descending = is_rotation_cycle(&[0, 2, 1]);
        assert_eq!(descending, Some((2, 0)));
    }

    #[test]
    fn test_perm_one_rotation() {
        let atoms = perm_to_atoms(&[1, 2, 0, 3, 4]);
        assert_eq!(atoms, tvec!((0, 2)));
    }

    #[test]
    fn test_perm_two_rotations() {
        let atoms = perm_to_atoms(&[1, 2, 0, 4, 3]);
        assert_eq!(atoms, tvec!((0, 2), (3, 4)));
    }

    #[test]
    fn test_perm_complex() {
        let atoms = perm_to_atoms(&[3, 1, 2, 0, 4]);
        assert_eq!(atoms, tvec!((3, 0), (1, 3)));
    }

    // In the diagrams below, the top row is the original op, the left column is the incoming
    // change, the bottom row is the replacement op and the right column is the outgoing change.

    // ADD - ADD

    //                          Op
    //           b,c   ------|Add(0)|----->        n,b,c
    //   Add(0)                                            Add(1)
    //         a,b,c   ------|Add(0)|----->        n,a,b,c
    #[test]
    pub fn transform_op_add_0_add_0() {
        assert_merge(Add(0), Add(0), Add(0), Add(1));
    }

    //                          Op
    //           b,c   ------|Add(1)|----->        b,n,c
    //   Add(0)                                                 Add(0)
    //         a,b,c   ------|Add(2)|----->        a,b,n,c
    #[test]
    pub fn transform_op_add_0_add_1() {
        assert_merge(Add(0), Add(1), Add(2), Add(0));
    }

    //                          Op
    //           a,c   ------|Add(0)|----->        n,a,c
    //   Add(1)                                                 Add(2)
    //         a,b,c   ------|Add(0)|----->        n,a,b,c
    #[test]
    pub fn transform_op_add_1_add_0() {
        assert_merge(Add(1), Add(0), Add(0), Add(2));
    }

    // RM - RM

    //                          Op
    //         a,b,c   ------|Rm(1)|----->         a,c
    //   Rm(0)                                             Rm(0)
    //           b,c   ------|Rm(0)|----->         c
    #[test]
    pub fn transform_op_rm_0_rm_1() {
        assert_merge(Rm(0), Rm(1), Rm(0), Rm(0));
    }

    //                          Op
    //         a,b,c   ------|Rm(0)|----->         b,c
    //   Rm(1)                                             Rm(0)
    //           a,c   ------|Rm(0)|----->         c
    #[test]
    pub fn transform_op_rm_1_rm_0() {
        assert_merge(Rm(1), Rm(0), Rm(0), Rm(0));
    }

    // ADD - RM

    //                          Op
    //          b,c     ------|Rm(0)|------>        c
    //   Add(0)                                                 Add(0)
    //          a,b,c   ------|Rm(1)|----->         a,c
    #[test]
    pub fn transform_op_add_0_rm_0() {
        assert_merge(Add(0), Rm(0), Rm(1), Add(0));
    }

    //                          Op
    //          b,c     ------|Rm(1)|------>        b
    //   Add(0)                                                 Add(0)
    //          a,b,c   ------|Rm(2)|----->         a,b
    #[test]
    pub fn transform_op_add_0_rm_1() {
        assert_merge(Add(0), Rm(1), Rm(2), Add(0));
    }

    //                          Op
    //          a,c     ------|Rm(0)|------>        c
    //   Add(1)                                                 Add(0)
    //          a,b,c   ------|Rm(0)|----->         b,c
    #[test]
    pub fn transform_op_add_1_rm_0() {
        assert_merge(Add(1), Rm(0), Rm(0), Add(0));
    }

    // RM - ADD

    //                          Op
    //         a,b,c   ------|Add(0)|----->        X,a,b,c
    //   Rm(1)                                                 Rm(2)
    //           a,c   ------|Add(0)|----->        X,a,c
    #[test]
    pub fn transform_op_rm_1_add_0() {
        assert_merge(Rm(1), Add(0), Add(0), Rm(2));
    }

    //                          Op
    //         a,b,c   ------|Add(1)|----->        a,X,b,c
    //   Rm(0)                                                 Rm(0)
    //           b,c   ------|Add(0)|----->        X,b,c
    #[test]
    pub fn transform_op_rm_0_add_1() {
        assert_merge(Rm(0), Add(1), Add(0), Rm(0));
    }

    // MOVE - RM

    //                          Op
    //         a,b,c   ------|Rm(2)|----->        a,b
    //   Move(0, 2)                                           Move(0,1)
    //         b,c,a   ------|Rm(1)|----->        b,a
    #[test]
    pub fn transform_op_mv_02_rm_2() {
        assert_merge(Move(0, 2), Rm(2), Rm(1), Move(0, 1));
    }
}
1342
1343#[cfg(test)]
1344mod proptests {
1345    use super::*;
1346    use proptest::prelude::*;
1347
1348    #[derive(Debug)]
1349    struct ComposeProblem {
1350        input: TVec<usize>,
1351        ops: TVec<AxisOp>,
1352    }
1353
1354    impl Arbitrary for AxisOp {
1355        type Parameters = TVec<usize>;
1356        type Strategy = BoxedStrategy<AxisOp>;
1357        fn arbitrary_with(shape: TVec<usize>) -> Self::Strategy {
1358            let mut ops: BoxedStrategy<AxisOp> = (0usize..shape.len() + 1).prop_map(Add).boxed();
1359            if shape.len() > 1 {
1360                ops = ops
1361                    .prop_union(
1362                        (0..shape.len(), 0..shape.len() - 1)
1363                            .prop_map(|(a, b)| Move(a, b + (b >= a) as usize))
1364                            .boxed(),
1365                    )
1366                    .boxed()
1367            }
1368            let rms = (0..shape.len()).filter(|&ax| shape[ax] == 1).map(Rm).collect::<Vec<_>>();
1369            if rms.len() > 0 {
1370                ops = ops
1371                    .prop_union((0..rms.len()).prop_map(move |rm| rms[rm].clone()).boxed())
1372                    .boxed()
1373            }
1374            let mergeable: Vec<AxisOp> = shape
1375                .windows(2)
1376                .enumerate()
1377                .filter(|(_, w)| w[0] > 1 && w[1] > 1)
1378                .map(|(ix, w)| {
1379                    Reshape(ix, tvec!(w[0].to_dim(), w[1].to_dim()), tvec!((w[0] * w[1]).to_dim()))
1380                })
1381                .collect();
1382            if mergeable.len() > 1 {
1383                ops = ops
1384                    .prop_union(
1385                        (0..mergeable.len()).prop_map(move |ix| mergeable[ix].clone()).boxed(),
1386                    )
1387                    .boxed()
1388            }
1389            ops
1390        }
1391    }
1392
1393    impl Arbitrary for ComposeProblem {
1394        type Parameters = ();
1395        type Strategy = BoxedStrategy<ComposeProblem>;
1396        fn arbitrary_with(_args: ()) -> Self::Strategy {
1397            let input = proptest::collection::vec(1usize..4, 1usize..4);
1398            fn tail(len: usize, shape: TVec<usize>) -> BoxedStrategy<TVec<AxisOp>> {
1399                if len == 0 {
1400                    Just(tvec!()).boxed()
1401                } else {
1402                    AxisOp::arbitrary_with(shape.clone())
1403                        .prop_flat_map(move |op| {
1404                            let mut shape = shape.clone();
1405                            op.change_shape_array(&mut shape, false).unwrap();
1406                            tail(len - 1, shape.clone()).prop_map(move |mut t| {
1407                                t.insert(0, op.clone());
1408                                t
1409                            })
1410                        })
1411                        .boxed()
1412                }
1413            }
1414            (input, 1usize..=5)
1415                .prop_flat_map(|(input, len)| (Just(input.clone()), tail(len, input.into())))
1416                .prop_map(|(input, ops)| ComposeProblem { input: input.into(), ops })
1417                .boxed()
1418        }
1419    }
1420
1421    impl ComposeProblem {
1422        pub fn model(&self) -> TractResult<TypedModel> {
1423            let mut model = TypedModel::default();
1424            let mut wire = model.add_source("source", i64::fact(&self.input))?;
1425            for (ix, op) in self.ops.iter().enumerate() {
1426                wire = model.wire_node(format!("op_{ix}"), op.clone(), &[wire])?[0];
1427            }
1428            model.select_output_outlets(&[wire])?;
1429            Ok(model)
1430        }
1431
1432        fn input(&self) -> TractResult<Tensor> {
1433            unsafe {
1434                let mut t = Tensor::uninitialized::<i64>(&self.input)?;
1435                for i in 0..t.len() {
1436                    t.try_as_plain_mut().unwrap().as_slice_mut().unwrap()[i] = i as i64;
1437                }
1438                Ok(t)
1439            }
1440        }
1441
1442        fn check(&self) -> TractResult<()> {
1443            crate::setup_test_logger();
1444            let input = self.input()?;
1445            let model = self.model()?;
1446            let raw = model.into_runnable()?.run(tvec!(input.clone().into_tvalue()))?;
1447            let optimized = self.model()?.into_decluttered()?;
1448            let opt = optimized.into_runnable()?.run(tvec!(input.into_tvalue()))?;
1449            opt[0].close_enough(&raw[0], false)
1450        }
1451    }
1452
1453    proptest! {
1454        #[test]
1455        fn recip(pb in any::<AxisOp>()) {
1456            assert_eq!(pb.recip().recip(), pb);
1457        }
1458
1459        #[test]
1460        fn axis_ops(pb in any::<ComposeProblem>()) {
1461            pb.check().unwrap()
1462        }
1463    }
1464
1465    #[test]
1466    fn add_0_rm_0() {
1467        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Rm(0)] };
1468        pb.check().unwrap();
1469    }
1470
1471    #[test]
1472    fn add_0_move_01() {
1473        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1)] };
1474        pb.check().unwrap();
1475    }
1476
1477    #[test]
1478    fn add_0_move_01_add_1() {
1479        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1), Add(1)] };
1480        pb.check().unwrap();
1481    }
1482
1483    #[test]
1484    fn recip_move_01() {
1485        let op = Move(1, 0);
1486        assert_eq!(op.recip().recip(), op);
1487    }
1488
1489    #[test]
1490    fn recip_move_20() {
1491        let op = Move(2, 0);
1492        assert_eq!(op.recip().recip(), op);
1493    }
1494
1495    #[test]
1496    fn recip_move_02() {
1497        let op = Move(0, 2);
1498        assert_eq!(op.recip().recip(), op);
1499    }
1500
1501    #[test]
1502    fn add_0_add_1_move_02() {
1503        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(1), Move(0, 2)] };
1504        pb.check().unwrap();
1505    }
1506
1507    #[test]
1508    fn add_0_add_0() {
1509        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0)] };
1510        pb.check().unwrap();
1511    }
1512
1513    #[test]
1514    fn add_0_add_0_move_02() {
1515        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(0), Move(0, 2)] };
1516        pb.check().unwrap();
1517    }
1518
1519    #[test]
1520    fn add_0_add_2_move_12() {
1521        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(2), Move(1, 2)] };
1522        pb.check().unwrap();
1523    }
1524
1525    #[test]
1526    fn add_0_add_0_move_02_rm_0() {
1527        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0), Move(0, 2), Rm(0)] };
1528        pb.check().unwrap();
1529    }
1530
1531    #[test]
1532    fn add_0_add_0_move_20_move_20() {
1533        let pb =
1534            ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(0), Move(2, 0), Move(2, 0)] };
1535        pb.check().unwrap();
1536    }
1537
1538    #[test]
1539    fn move_01_add_0() {
1540        let pb = ComposeProblem { input: tvec![1, 1], ops: tvec![Move(0, 1), Add(0)] };
1541        pb.check().unwrap();
1542    }
1543
1544    #[test]
1545    fn add_0_move_02_move_02() {
1546        let pb = ComposeProblem { input: tvec![1, 1], ops: tvec![Add(0), Move(0, 2), Move(0, 2),] };
1547        pb.check().unwrap();
1548    }
1549
1550    #[test]
1551    fn add_0_add_2_move_20_move_12_rm_2() {
1552        let pb = ComposeProblem {
1553            input: tvec![3],
1554            ops: tvec![Add(0), Add(2), Move(2, 0), Move(1, 2), Rm(2)],
1555        };
1556        pb.check().unwrap();
1557    }
1558
1559    #[test]
1560    fn move_02_move_02() {
1561        let pb = ComposeProblem { input: tvec![2, 1, 1], ops: tvec![Move(0, 2), Move(0, 2)] };
1562        pb.check().unwrap();
1563    }
1564
1565    #[test]
1566    fn rm_1_perm_10_add_0() {
1567        let pb = ComposeProblem { input: tvec![1, 1, 2], ops: tvec![Rm(1), Move(0, 1), Add(0)] };
1568        pb.check().unwrap();
1569    }
1570
1571    #[test]
1572    fn add_2_move_02_move_02() {
1573        let pb = ComposeProblem { input: tvec![3, 2], ops: tvec![Add(2), Move(0, 2), Move(0, 2)] };
1574        pb.check().unwrap();
1575    }
1576
1577    #[test]
1578    fn move_01_move_20_move_20() {
1579        let pb = ComposeProblem {
1580            input: tvec![2, 3, 2],
1581            ops: tvec![Move(0, 1), Move(2, 0), Move(2, 0)],
1582        };
1583        pb.check().unwrap();
1584    }
1585
1586    #[test]
1587    fn reshape_axes_tracking() {
1588        let pb = ComposeProblem {
1589            input: tvec![2, 2, 2],
1590            ops: tvec![Reshape(0, tvec!(2.to_dim(), 2.to_dim()), tvec!(4.to_dim()))],
1591        };
1592        pb.check().unwrap();
1593    }
1594
1595    #[test]
1596    fn simplify_reshape() {
1597        macro_rules! d {
1598            ($($dim: expr),*) =>  { tvec!($($dim.to_dim()),*) }
1599        }
1600        assert_eq!(Reshape(3, d!(), d!()).simplify(), tvec!());
1601        assert_eq!(Reshape(3, d!(2, 3), d!(2, 3)).simplify(), tvec!());
1602        assert_eq!(Reshape(3, d!(1), d!()).simplify(), tvec!(Rm(3)));
1603        assert_eq!(Reshape(3, d!(), d!(1)).simplify(), tvec!(Add(3)));
1604        assert_eq!(
1605            Reshape(3, d!(2, 3, 4), d!(2, 4, 3)).simplify(),
1606            tvec!(Reshape(4, d!(3, 4), d!(4, 3)))
1607        );
1608        assert_eq!(
1609            Reshape(3, d!(3, 4, 2), d!(4, 3, 2)).simplify(),
1610            tvec!(Reshape(3, d!(3, 4), d!(4, 3)))
1611        );
1612        assert_eq!(
1613            Reshape(3, d!(1, 2, 3), d!(3, 2)).simplify(),
1614            tvec!(Rm(3), Reshape(3, d!(2, 3), d!(3, 2)))
1615        );
1616        assert_eq!(
1617            Reshape(3, d!(2, 3), d!(1, 3, 2)).simplify(),
1618            tvec!(Reshape(3, d!(2, 3), d!(3, 2)), Add(3))
1619        );
1620        assert_eq!(
1621            Reshape(3, d!(2, 3, 1), d!(3, 2)).simplify(),
1622            tvec!(Rm(5), Reshape(3, d!(2, 3), d!(3, 2)))
1623        );
1624        assert_eq!(
1625            Reshape(3, d!(2, 3), d!(3, 2, 1)).simplify(),
1626            tvec!(Add(5), Reshape(3, d!(2, 3), d!(3, 2)))
1627        );
1628        assert_eq!(
1629            Reshape(2, d!(2, 2, 1), d!(4)).simplify(),
1630            tvec!(Rm(4), Reshape(2, d!(2, 2), d!(4)))
1631        );
1632        assert_eq!(Reshape(1, d!(1, 2), d!(2)).simplify(), tvec!(Rm(1)));
1633    }
1634
1635    macro_rules! s {
1636        ($($a:expr),*) => {&[ $($a.clone().into()),* ]}
1637    }
1638
1639    macro_rules! r {
1640        ($at: expr ; $($from:expr),* => $($to:expr),*) => {
1641            AxisOp::Reshape($at, tvec!($($from.into()),*),  tvec!($($to.into()),*))
1642        }
1643    }
1644
1645    #[test]
1646    fn compute_invalid() {
1647        assert!(compute_shape_with_tf_rules(s![3, 4, 5], s!(100)).is_err());
1648    }
1649
1650    #[test]
1651    fn compute_with_leading_zero() {
1652        assert_eq!(&*compute_shape_with_tf_rules(s![3, 4, 5], s!(0, 0, 5)).unwrap(), s![3, 4, 5])
1653    }
1654
1655    #[test]
1656    fn compute_with_leading_zero_with_flatten() {
1657        assert_eq!(
1658            &*compute_shape_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap(),
1659            s![2, 3, 35]
1660        )
1661    }
1662
1663    #[test]
1664    fn compute_with_trailing_zero() {
1665        assert_eq!(&*compute_shape_with_tf_rules(s![3, 4, 5], s!(3, -1, 0)).unwrap(), s![3, 4, 5])
1666    }
1667
1668    #[test]
1669    fn compute_bug_1() {
1670        let table = SymbolScope::default();
1671        let s = table.new_with_prefix("S");
1672        assert_eq!(
1673            &*compute_shape_with_tf_rules(s![s, 1, 2, 128], s!(0, 0, -1)).unwrap(),
1674            s![s, 1, 256]
1675        )
1676    }
1677
1678    #[test]
1679    fn compute_bug_2() {
1680        let table = SymbolScope::default();
1681        let b = table.new_with_prefix("B");
1682        let s = table.new_with_prefix("S");
1683        assert_eq!(
1684            &*compute_shape_with_tf_rules(s![s, b, 2, 128], s!(0, 0, -1)).unwrap(),
1685            s![s, b, 256]
1686        )
1687    }
1688
1689    #[test]
1690    fn compute_zero_with_rank_change() {
1691        // Moonshine RoPE: input rank 4, output rank 5, two leading 0s
1692        assert_eq!(
1693            &*compute_shape_with_tf_rules(s![1, 52, 8, 32], s!(0, 0, 8, 16, 2)).unwrap(),
1694            s![1, 52, 8, 16, 2]
1695        )
1696    }
1697
1698    #[test]
1699    fn axis_op_rm_begin() {
1700        assert_eq!(&*to_axis_ops_with_tf_rules(s![1, 2, 3], s!(2, 3)).unwrap(), &[Rm(0)])
1701    }
1702
1703    #[test]
1704    fn axis_op_rm_end() {
1705        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3, 1], s!(2, 3)).unwrap(), &[Rm(2)])
1706    }
1707
1708    #[test]
1709    fn axis_op_insert_begin() {
1710        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3], s!(1, 2, 3)).unwrap(), &[Add(0)])
1711    }
1712
1713    #[test]
1714    fn axis_op_insert_end() {
1715        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3], s!(2, 3, 1)).unwrap(), &[Add(2)])
1716    }
1717
1718    #[test]
1719    fn axis_op_merge() {
1720        assert_eq!(
1721            &*to_axis_ops_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap(),
1722            &[r!(2 ; 5,7 => 35 )]
1723        )
1724    }
1725
1726    #[test]
1727    fn axis_op_complex() {
1728        assert_eq!(
1729            &*to_axis_ops_with_tf_rules(s![1, 2, 3, 5, 7], s!(2, 1, 3, 35, 1)).unwrap(),
1730            &[Rm(0), Add(1), r!(3 ; 5,7 => 35 ), Add(4)]
1731        )
1732    }
1733}