Skip to main content

tract_core/ops/
change_axes.rs

1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::model::{TypedModel, TypedNode};
6use crate::ops::identity::Identity;
7use num_traits::One;
8use tract_itertools::Itertools;
9use tract_linalg::block_quant::{BlockQuantFact, BlockQuantValue};
10use tract_ndarray::{ArrayViewD, ArrayViewMutD};
11use AxisOp::*;
12
/// Designates one attachment point of a node: an input slot or an output
/// slot, identified by its index.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InOut {
    /// Output slot of a node, by index.
    Out(usize),
    /// Input slot of a node, by index.
    In(usize),
}
18
19impl InOut {
20    pub fn as_outlet<F: Clone + Fact, O: Clone>(&self, node: &Node<F, O>) -> OutletId {
21        match self {
22            InOut::In(ix) => node.inputs[*ix],
23            InOut::Out(ix) => OutletId::new(node.id, *ix),
24        }
25    }
26
27    pub fn is_input(&self) -> bool {
28        matches!(self, InOut::In(_))
29    }
30
31    pub fn is_output(&self) -> bool {
32        matches!(self, InOut::Out(_))
33    }
34
35    pub fn slot(&self) -> usize {
36        match self {
37            InOut::Out(o) => *o,
38            InOut::In(i) => *i,
39        }
40    }
41}
42
/// An operation on the axes of a tensor: insertion or removal of size-1
/// axes, axis displacement, and volume-preserving reshapes of a contiguous
/// run of axes.
#[derive(Clone, Hash, Eq)]
#[allow(clippy::large_enum_variant)] // FIXME ?
#[allow(clippy::derived_hash_with_manual_eq)] // FIXME. this one may be pretty bad. how about a.canonical() == b.canonical() ? need proper canonicalizeation of Reshape
pub enum AxisOp {
    /// Insert a size-1 axis at the given position.
    Add(usize),
    /// Remove the axis at the given position (it is expected to be of size 1).
    Rm(usize),
    /// Move the axis from position .0 to position .1, shifting the axes in between.
    Move(usize, usize),
    /// Reshape the run of axes starting at .0 from dims .1 to dims .2
    /// (both dim lists must have the same volume).
    Reshape(usize, TVec<TDim>, TVec<TDim>),
}
52
53impl Debug for AxisOp {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        match self {
56            AxisOp::Add(a) => write!(f, "Add({a})"),
57            AxisOp::Rm(a) => write!(f, "Rm({a})"),
58            AxisOp::Move(from, to) => write!(f, "Move({from},{to})"),
59            AxisOp::Reshape(at, from, to) => {
60                write!(f, "Reshape({at}, [{}], [{}])", from.iter().join(","), to.iter().join(","))
61            }
62        }
63    }
64}
65
// Structural equality with two refinements: all no-ops compare equal, and an
// adjacent swap compares equal to its mirrored form (Move(a, a+1) ==
// Move(a+1, a)), since both perform the same transposition.
impl PartialEq for AxisOp {
    fn eq(&self, other: &AxisOp) -> bool {
        if self.is_noop() && other.is_noop() {
            true
        } else if self.is_noop() != other.is_noop() {
            false
        } else {
            match (self, other) {
                (Add(a), Add(b)) | (Rm(a), Rm(b)) => a == b,
                (Move(f1, t1), Move(f2, t2)) => {
                    // strictly equal, or an adjacent swap and its mirror
                    (f1 == f2 && t1 == t2)
                        || ((*t1 == f1 + 1 || *f1 == t1 + 1) && t2 == f1 && t1 == f2)
                }
                (Reshape(at1, f1, t1), Reshape(at2, f2, t2)) => at1 == at2 && f1 == f2 && t1 == t2,
                _ => false,
            }
        }
    }
}
85
86impl AxisOp {
    /// Returns a normalized form of the op, used before comparison and
    /// pattern matching:
    /// * a backward adjacent move `Move(to+1, to)` becomes the equivalent
    ///   forward adjacent move `Move(to, to+1)`;
    /// * 1-to-2 axis reshapes that merely insert a size-1 axis become `Add`;
    /// * 2-to-1 axis reshapes that merely drop a size-1 axis become `Rm`.
    pub fn canonical(&self) -> Cow<'_, AxisOp> {
        match self {
            Move(from, to) if *from == to + 1 => Cow::Owned(Move(*to, *from)),
            // [d] -> [d, 1]: a unit axis appears after `at`
            Reshape(at, from, to) if from.len() == 1 && to.len() == 2 && from[0] == to[0] => {
                Cow::Owned(Add(*at + 1))
            }
            // [d] -> [1, d]: a unit axis appears at `at`
            Reshape(at, from, to) if from.len() == 1 && to.len() == 2 && from[0] == to[1] => {
                Cow::Owned(Add(*at))
            }
            // [d, 1] -> [d]: the unit axis after `at` disappears
            Reshape(at, from, to) if from.len() == 2 && to.len() == 1 && from[0] == to[0] => {
                Cow::Owned(Rm(*at + 1))
            }
            // [1, d] -> [d]: the unit axis at `at` disappears
            Reshape(at, from, to) if from.len() == 2 && to.len() == 1 && from[1] == to[0] => {
                Cow::Owned(Rm(*at))
            }
            other => Cow::Borrowed(other),
        }
    }
105
    /// Decomposes the (canonicalized) op into a sequence of simpler ops, to
    /// be applied in order. Reshapes are peeled recursively: identical
    /// leading/trailing dims are skipped, and unit dims turn into `Add`/`Rm`.
    pub fn simplify(&self) -> TVec<AxisOp> {
        match self.canonical().borrow() {
            // identity reshape: nothing to do
            Reshape(_, from, to) if from == to => tvec!(),
            // no target dims: remove each source axis at `at`
            Reshape(at, from, to) if to.len() == 0 => tvec!(Rm(*at); from.len()),
            // no source dims: insert each target axis at `at`
            Reshape(at, from, to) if from.len() == 0 => tvec!(Add(*at); to.len()),
            // identical leading dim: keep it, recurse on the rest
            Reshape(at, from, to) if from[0] == to[0] => {
                Reshape(at + 1, from[1..].into(), to[1..].into()).simplify()
            }
            // identical trailing dim: keep it, recurse on the rest
            Reshape(at, from, to) if from[from.len() - 1] == to[to.len() - 1] => {
                Reshape(*at, from[..from.len() - 1].into(), to[..to.len() - 1].into()).simplify()
            }
            // leading unit dim on the source side: emit Rm first, then the rest
            Reshape(at, from, to) if from[0] == 1.to_dim() => std::iter::once(Rm(*at))
                .chain(Reshape(*at, from[1..].into(), to.clone()).simplify())
                .collect(),
            // leading unit dim on the target side: reshape first, then Add it
            Reshape(at, from, to) if to[0] == 1.to_dim() => {
                Reshape(*at, from.clone(), to[1..].into())
                    .simplify()
                    .into_iter()
                    .chain(std::iter::once(Add(*at)))
                    .collect()
            }
            // trailing unit dim on the source side: emit Rm first, then the rest
            Reshape(at, from, to) if from[from.len() - 1] == 1.to_dim() => {
                std::iter::once(Rm(at + from.len() - 1))
                    .chain(Reshape(*at, from[..from.len() - 1].into(), to.clone()).simplify())
                    .collect()
            }
            // trailing unit dim on the target side: Add it first, then reshape
            Reshape(at, from, to) if to[to.len() - 1] == 1.to_dim() => {
                std::iter::once(Add(at + from.len()))
                    .chain(Reshape(*at, from.clone(), to[..to.len() - 1].into()).simplify())
                    .collect()
            }
            other => tvec!(other.clone()),
        }
    }
140
    /// Tracks one input axis through the op: returns its position in the
    /// output, or `None` when the axis does not survive (it is removed, or
    /// absorbed inside a reshaped block).
    pub fn transform_axis(&self, axis: usize) -> Option<usize> {
        match self.canonical().as_ref() {
            // axes at or after the insertion point shift right by one
            Add(ix) => Some(axis + (axis >= *ix) as usize),
            Rm(ix) => {
                if axis == *ix {
                    None
                } else {
                    // axes after the removed one shift left by one
                    Some(axis - (axis > *ix) as usize)
                }
            }
            // forward move: axes in (from, to] shift left, `from` lands on `to`
            Move(from, to) if from < to => {
                if axis < *from || axis > *to {
                    Some(axis)
                } else if axis == *from {
                    Some(*to)
                } else {
                    Some(axis - 1)
                }
            }
            // backward move: axes in [to, from) shift right, `from` lands on `to`
            Move(from, to) => {
                if axis < *to || axis > *from {
                    Some(axis)
                } else if axis == *from {
                    Some(*to)
                } else {
                    Some(axis + 1)
                }
            }
            // axes before the reshaped block are untouched
            Reshape(at, _, _) if axis < *at => Some(axis),
            // axes after the block shift by the rank difference
            Reshape(at, from, to) if axis >= at + from.len() => Some(axis + to.len() - from.len()),
            // axes inside the block lose their identity
            Reshape(_, _, _) => None,
        }
    }
174
    /// Reconciles this op with an axis `change` applied to its input.
    ///
    /// If successful, returns `Some((replacement, propagated))`:
    /// * `replacement` is the op we want to be replaced by; `None` means we
    ///   become an identity;
    /// * `propagated` is the change to push past us; `None` means our output
    ///   is not changed.
    ///
    /// Both sides are canonicalized before matching.
    pub fn merge_incoming_change(
        &self,
        change: &AxisOp,
    ) -> Option<(Option<AxisOp>, Option<AxisOp>)> {
        match (self.canonical().as_ref(), change.canonical().as_ref()) {
            // Add/Rm vs Add/Rm: each index shifts by one when the other
            // operand acts on an earlier (or equal) axis.
            (Add(op), Add(c)) => {
                Some((Some(Add(op + (c < op) as usize)), Some(Add(c + (c >= op) as usize))))
            }
            (Add(op), Rm(c)) => {
                Some((Some(Add(op - (c < op) as usize)), Some(Rm(c + (c >= op) as usize))))
            }
            (Rm(op), Add(c)) => {
                Some((Some(Rm(op + (c <= op) as usize)), Some(Add(c - (op < c) as usize))))
            }
            (Rm(op), Rm(c)) => {
                Some((Some(Rm(op - (c < op) as usize)), Some(Rm(c - (op <= c) as usize))))
            }

            // self is Add, incoming change is a Move: mergeable only when the
            // added axis falls entirely before or after the moved span.
            (Add(x), Move(from, to)) => {
                if x <= from.min(to) {
                    Some((Some(self.clone()), Some(Move(from + 1, to + 1))))
                } else if x > from.max(to) {
                    Some((Some(self.clone()), Some(change.clone())))
                } else {
                    None
                }
            }

            // self is Move, incoming change is an Add: same outside-the-span rule.
            (Move(from, to), Add(x)) => {
                if x <= from.min(to) {
                    Some((Some(Move(from + 1, to + 1)), Some(Add(*x))))
                } else if x > from.max(to) {
                    Some((Some(Move(*from, *to)), Some(Add(*x))))
                } else {
                    None
                }
            }

            // self is Rm, incoming change is a Move.
            (Rm(x), Move(from, to)) => {
                if x == from {
                    // the axis we remove is the one being moved: remove it at
                    // its destination instead, nothing left to propagate
                    Some((Some(Rm(*to)), None))
                } else if x < from.min(to) {
                    Some((Some(self.clone()), Some(Move(from - 1, to - 1))))
                } else if x > from.max(to) {
                    Some((Some(self.clone()), Some(change.clone())))
                } else if from + 1 == *to && x == to {
                    // adjacent swap whose other member we remove
                    Some((Some(Rm(*from)), None))
                } else if from < to && x <= to {
                    Some((Some(Rm(x - 1)), Some(Move(*from, *to - 1))))
                } else {
                    Some((Some(Rm(x + 1)), Some(Move(*from - 1, *to))))
                }
            }

            // self is Move, incoming change is a Rm: mergeable only outside the span.
            (Move(from, to), Rm(x)) => {
                if x < from.min(to) {
                    Some((Some(Move(from - 1, to - 1)), Some(Rm(*x))))
                } else if x > from.max(to) {
                    Some((Some(Move(*from, *to)), Some(Rm(*x))))
                } else {
                    None
                }
            }

            // Add/Rm vs Reshape: mergeable when the altered axis falls
            // outside the reshaped block.
            (Add(op), Reshape(at, from, to)) => {
                if op <= at {
                    Some((Some(Add(*op)), Some(Reshape(at + 1, from.clone(), to.clone()))))
                } else if *op > at + from.len() {
                    Some((
                        Some(Add(*op + to.len() - from.len())),
                        Some(Reshape(*at, from.clone(), to.clone())),
                    ))
                } else {
                    None
                }
            }
            (Rm(op), Reshape(at, from, to)) => {
                if op < at {
                    Some((Some(Rm(*op)), Some(Reshape(at - 1, from.clone(), to.clone()))))
                } else if *op > at + from.len() {
                    Some((
                        Some(Rm(*op + to.len() - from.len())),
                        Some(Reshape(*at, from.clone(), to.clone())),
                    ))
                } else {
                    None
                }
            }
            (Reshape(at, from, to), Add(change)) => {
                if change < at {
                    Some((Some(Reshape(at + 1, from.clone(), to.clone())), Some(Add(*change))))
                } else if *change > *at + from.len() {
                    Some((
                        Some(Reshape(*at, from.clone(), to.clone())),
                        Some(Add(change + to.len() - from.len())),
                    ))
                } else {
                    None
                }
            }
            (Reshape(at, from, to), Rm(change)) => {
                if change < at {
                    Some((Some(Reshape(at - 1, from.clone(), to.clone())), Some(Rm(*change))))
                } else if *change > *at + from.len() {
                    Some((
                        Some(Reshape(*at, from.clone(), to.clone())),
                        Some(Rm(change + to.len() - from.len())),
                    ))
                } else {
                    None
                }
            }
            (Reshape(_, _, _), Move(_, _)) => None, // todo, some are manageable
            (Move(_, _), Reshape(_, _, _)) => None, // todo, some are manageable
            (Reshape(_, _, _), Reshape(_, _, _)) => None, // todo, some are manageable
            _ => None,
        }
    }
297
    /// Applies this op to a mutable shape vector.
    ///
    /// With `broadcasting` set, a run of unit dims may stand in for the
    /// `from` dims of a `Reshape`; it is then replaced by unit dims of the
    /// target arity.
    pub fn change_shape_array<D: DimLike>(
        &self,
        shape: &mut TVec<D>,
        broadcasting: bool,
    ) -> TractResult<()> {
        match self.canonical().as_ref() {
            Add(ix) => {
                ensure!(*ix <= shape.len());
                shape.insert(*ix, D::one());
            }
            Rm(ix) => {
                ensure!(*ix < shape.len());
                shape.remove(*ix);
            }
            Move(from, to) => {
                ensure!(*from < shape.len());
                ensure!(*to < shape.len());
                let axis = shape.remove(*from);
                shape.insert(*to, axis);
            }
            Reshape(at, from, to) => {
                // a reshape must preserve the element count
                let from_volume = from.iter().product::<TDim>();
                let to_volume = to.iter().product::<TDim>();
                ensure!(from_volume == to_volume, "{from_volume} should be equal to {to_volume}");
                ensure!(*at + from.len() <= shape.len());
                // nominal case: the shape matches the `from` spec exactly
                if shape.len() >= from.len() + *at
                    && tract_itertools::izip!(shape.iter().skip(*at), from)
                        .all(|(shape, spec)| shape.to_dim() == *spec)
                {
                    for _ in from {
                        shape.remove(*at);
                    }
                    for d in to.iter().rev() {
                        shape.insert(*at, d.try_into()?);
                    }
                } else if broadcasting
                    && shape.iter().skip(*at).take(from.len()).all(|d| d.to_dim() == 1.to_dim())
                {
                    // broadcasting case: all source axes are 1, substitute
                    // `to.len()` unit axes
                    for _ in from {
                        shape.remove(*at);
                    }
                    for _ in to.iter().rev() {
                        shape.insert(*at, 1.into());
                    }
                } else {
                    bail!("Incompatible reshape for shape {:?} and {:?}", shape, self);
                }
            }
        }
        Ok(())
    }
349
350    pub fn change_shape(&self, shape: &mut ShapeFact, broadcasting: bool) -> TractResult<()> {
351        match self.canonical().as_ref() {
352            Add(ix) => shape.insert_axis(*ix),
353            Rm(ix) => {
354                if shape.rank() <= *ix {
355                    bail!("Attempt to remove axis #{} on shape {:?}", ix, shape);
356                }
357                if shape[*ix] != 1.to_dim() {
358                    bail!("Removing non-trivial axis #{} of dim: {:?}", ix, shape);
359                }
360                shape.remove_axis(*ix)
361            }
362            _ => {
363                let mut array = shape.to_tvec();
364                self.change_shape_array(&mut array, broadcasting)?;
365                let mut new_shape = ShapeFact::from_dims(array);
366                std::mem::swap(shape, &mut new_shape);
367                Ok(())
368            }
369        }
370    }
371
    /// Applies this op to a concrete tensor in place.
    ///
    /// If the op addresses more axes than the tensor has and the tensor
    /// holds opaque `BlockQuantValue`s, the left-trimmed op is applied to
    /// the inner block-quant shapes instead. With `broadcasting` set, a
    /// `Reshape` over all-unit axes is emulated by adding/removing unit axes.
    pub fn change_tensor(&self, tensor: &mut Tensor, broadcasting: bool) -> TractResult<()> {
        if self.required_rank() > tensor.rank() && tensor.datum_type().is_opaque() {
            // shift the op so it addresses inner (payload) axes only
            let inner_change = self.trim_left(tensor.rank())?;
            for opaque in tensor.as_slice_mut::<Opaque>()? {
                if let Some(bqv) = opaque.downcast_ref::<BlockQuantValue>() {
                    let mut new_shape: TVec<usize> = bqv.fact.shape().into();
                    inner_change.change_shape_array(&mut new_shape, false)?;
                    // data blob is shared, only the fact's shape changes
                    let new_bqv = BlockQuantValue {
                        value: Arc::clone(&bqv.value),
                        fact: BlockQuantFact::new(bqv.fact.format.clone(), new_shape),
                    };
                    *opaque = Opaque(Arc::new(new_bqv));
                } else {
                    bail!("Can't apply {self:?} to opaque tensor {tensor:?}");
                }
            }
            return Ok(());
        }
        ensure!(self.required_rank() <= tensor.rank());
        match self.canonical().as_ref() {
            Add(ix) => tensor.insert_axis(*ix),
            Rm(ix) => tensor.remove_axis(*ix),
            Move(from, to) => {
                let mut tmp = tensor.clone().move_axis(*from, *to)?;
                std::mem::swap(tensor, &mut tmp);
                Ok(())
            }
            Reshape(at, from, to) => {
                let mut shape: TVec<usize> = tensor.shape().into();
                self.change_shape_array(&mut shape, true)?;
                if tensor.set_shape(&shape).is_ok() {
                    Ok(())
                } else if broadcasting
                    && tensor.shape().iter().skip(*at).take(from.len()).all(|d| *d == 1)
                {
                    // all reshaped axes are 1: just fix the number of unit axes
                    if from.len() > to.len() {
                        for _ in to.len()..from.len() {
                            tensor.remove_axis(*at)?;
                        }
                    }
                    if to.len() > from.len() {
                        for _ in from.len()..to.len() {
                            tensor.insert_axis(*at)?;
                        }
                    }
                    Ok(())
                } else {
                    bail!(
                        "Invalid reshaping: {:?} on tensor {:?} (broadcasting allowed: {:?})",
                        self,
                        tensor,
                        broadcasting
                    )
                }
            }
        }
    }
429
430    pub fn change_view<D>(&self, view: &mut ArrayViewD<D>) -> TractResult<()> {
431        use tract_ndarray::Axis;
432        match *self {
433            AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
434            AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
435            AxisOp::Move(from, to) if from < to => {
436                for left in from..to {
437                    view.swap_axes(left, left + 1);
438                }
439            }
440            AxisOp::Move(from, to) => {
441                for left in (to..from).rev() {
442                    view.swap_axes(left, left + 1);
443                }
444            }
445            AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
446        }
447        Ok(())
448    }
449
450    pub fn change_view_mut<D>(&self, view: &mut ArrayViewMutD<D>) -> TractResult<()> {
451        use tract_ndarray::Axis;
452        match *self {
453            AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
454            AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
455            AxisOp::Move(from, to) if from < to => {
456                for left in from..to {
457                    view.swap_axes(left, left + 1);
458                }
459            }
460            AxisOp::Move(from, to) => {
461                for left in (to..from).rev() {
462                    view.swap_axes(left, left + 1);
463                }
464            }
465            AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
466        }
467        Ok(())
468    }
469
    /// Returns the inverse op: applying `self` then `self.recip()` restores
    /// the original axes.
    pub fn recip(&self) -> AxisOp {
        match self.canonical().as_ref() {
            Add(ix) => Rm(*ix),
            Rm(ix) => Add(*ix),
            Move(from, to) if from == to => self.clone(),
            // an adjacent swap is its own inverse
            Move(from, to) if *from + 1 == *to => self.clone(),
            Move(from, to) if *from == *to + 1 => {
                // canonical() rewrites backward adjacent moves to forward ones
                unreachable!();
            }
            Move(from, to) => Move(*to, *from),
            Reshape(at, from, to) => Reshape(*at, to.clone(), from.clone()),
        }
    }
483
484    pub fn is_noop(&self) -> bool {
485        match self {
486            Move(f, t) if f == t => true,
487            Reshape(_, f, t) if f == t => true,
488            _ => false,
489        }
490    }
491
492    pub fn only_shape(&self) -> bool {
493        if self.is_noop() {
494            return true;
495        }
496        !matches!(self, Move(_, _))
497    }
498
499    pub fn wire_split_axis(
500        model: &mut TypedModel,
501        name: impl ToString,
502        outlet: OutletId,
503        axis: usize,
504        outer_dim: usize,
505    ) -> TractResult<TVec<OutletId>> {
506        let fact = model.outlet_fact(outlet)?;
507        let dim: TDim = fact.shape[axis].clone();
508        let inner_dim = dim.clone() / outer_dim;
509        let op = Self::Reshape(axis, tvec!(dim.clone()), tvec!(outer_dim.to_dim(), inner_dim));
510        model.wire_node(name.to_string(), op, &[outlet])
511    }
512
513    pub fn wire_collapse_axis(
514        model: &mut TypedModel,
515        name: impl ToString,
516        outlet: OutletId,
517        axis: usize,
518    ) -> TractResult<TVec<OutletId>> {
519        let fact = model.outlet_fact(outlet)?;
520        let dim: TDim = fact.shape[axis].clone();
521        let next_dim: TDim = fact.shape[axis + 1].clone();
522        let op = Self::Reshape(axis, tvec!(dim.clone(), next_dim.clone()), tvec!(dim * next_dim));
523        model.wire_node(name.to_string(), op, &[outlet])
524    }
525
526    #[inline]
527    pub fn required_rank(&self) -> usize {
528        match self {
529            Rm(r) => r + 1,
530            Add(a) => *a,
531            Reshape(at, from, _to) => at + from.len(),
532            Move(from, to) => *from.max(to),
533        }
534    }
535
536    pub fn trim_left(&self, prefix: usize) -> TractResult<AxisOp> {
537        Ok(match self {
538            Rm(r) if *r >= prefix => Rm(r - prefix),
539            Add(a) if *a >= prefix => Add(a - prefix),
540            Reshape(at, from, to) if *at >= prefix => {
541                Reshape(at - prefix, from.clone(), to.clone())
542            }
543            Move(from, to) if *from >= prefix && *to >= prefix => Move(from - prefix, to - prefix),
544            _ => bail!("Can no trim left {self:?} by {prefix}"),
545        })
546    }
547}
548
549pub fn wire_rank_broadcast(
550    prefix: impl AsRef<str>,
551    target: &mut TypedModel,
552    inputs: &[OutletId],
553) -> TractResult<TVec<OutletId>> {
554    let facts =
555        inputs.iter().map(|o| target.outlet_fact(*o).cloned()).collect::<TractResult<TVec<_>>>()?;
556    let max_rank = facts.iter().map(|f| f.rank()).max().unwrap();
557    let mut wires = tvec!();
558    let prefix = prefix.as_ref();
559    for i in 0..inputs.len() {
560        let mut wire = inputs[i];
561        for j in facts[i].rank()..max_rank {
562            wire =
563                target.wire_node(format!("{prefix}.fix-rank-{i}-{j}"), AxisOp::Add(0), &[wire])?[0];
564        }
565        wires.push(wire);
566    }
567    Ok(wires)
568}
569
570pub fn wire_with_rank_broadcast(
571    prefix: impl AsRef<str>,
572    target: &mut TypedModel,
573    op: impl Into<Box<dyn TypedOp>>,
574    inputs: &[OutletId],
575) -> TractResult<TVec<OutletId>> {
576    let prefix = prefix.as_ref();
577    let wires = wire_rank_broadcast(prefix, target, inputs)?;
578    target.wire_node(prefix, op.into(), &wires)
579}
580
/// A candidate axis change: an `AxisOp` to apply at a given outlet of the
/// model, used as the starting point of axis-change propagation.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AxisChange {
    /// The outlet whose wire the change applies to.
    pub outlet: OutletId,
    /// The axis manipulation to perform there.
    pub op: AxisOp,
}
586
/// What happens to a node when an axis change reaches it: an optional
/// replacement operator and the changes to apply to its wires.
#[derive(Clone, Default, Debug)]
pub struct AxisChangeConsequence {
    /// Operator replacing the node's current one; `None` keeps it as is.
    pub substitute_op: Option<Box<dyn TypedOp>>,
    /// Axis ops to apply to the listed input/output wires.
    pub wire_changes: TVec<(InOut, AxisOp)>,
}
592
593impl AxisChangeConsequence {
594    pub fn new(
595        _model: &TypedModel,
596        node: &TypedNode,
597        op: Option<Box<dyn TypedOp>>,
598        axis_op: &AxisOp,
599    ) -> AxisChangeConsequence {
600        let mut wire_changes = tvec!();
601        for i in 0..node.inputs.len() {
602            wire_changes.push((InOut::In(i), axis_op.clone()));
603        }
604        for i in 0..node.outputs.len() {
605            wire_changes.push((InOut::Out(i), axis_op.clone()));
606        }
607        AxisChangeConsequence { wire_changes, substitute_op: op }
608    }
609}
610
611impl Op for AxisOp {
612    fn name(&self) -> StaticName {
613        match self {
614            Add(_) => "AddAxis".into(),
615            Rm(_) => "RmAxis".into(),
616            Move(_, _) => "MoveAxis".into(),
617            Reshape(_, _, _) => "Reshape".into(),
618        }
619    }
620
621    fn info(&self) -> TractResult<Vec<String>> {
622        match self {
623            Add(axis) | Rm(axis) => Ok(vec![format!("Axis: {axis}")]),
624            Move(from, to) => Ok(vec![format!("Axis {from} to {to}")]),
625            Reshape(at, from, to) => Ok(vec![format!(
626                "Axes starting at {}: {:?} to {:?}",
627                at,
628                from.iter().join(","),
629                to.iter().join(",")
630            )]),
631        }
632    }
633
634    op_as_typed_op!();
635}
636
637impl EvalOp for AxisOp {
638    fn is_stateless(&self) -> bool {
639        true
640    }
641
642    fn eval_with_session(
643        &self,
644        _node_id: usize,
645        session: &SessionState,
646        inputs: TVec<TValue>,
647    ) -> TractResult<TVec<TValue>> {
648        let mut input = args_1!(inputs).into_tensor();
649        match self {
650            AxisOp::Reshape(skip, from, to) => {
651                let from = from.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
652                let to = to.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
653                AxisOp::Reshape(*skip, from, to).change_tensor(&mut input, false)?
654            }
655            _ => self.change_tensor(&mut input, false)?,
656        }
657        Ok(tvec!(input.into_tvalue()))
658    }
659}
660
661impl TypedOp for AxisOp {
662    as_op!();
663
    /// Applies the axis change to the input fact's shape. When the op needs
    /// more axes than the outer fact has and the input carries a
    /// `BlockQuantFact`, the trimmed-down op is applied to the inner
    /// block-quant shape instead (and to the constant payload if present).
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        if self.required_rank() > inputs[0].rank() {
            if let Some(bqf) =
                inputs[0].opaque_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
            {
                let mut new_inner_shape: TVec<usize> = bqf.shape().into();
                // shift the op so it addresses the inner shape's axes only
                self.trim_left(inputs[0].rank())?
                    .change_shape_array(&mut new_inner_shape, false)?;
                let new_bqf = BlockQuantFact::new(bqf.format.clone(), new_inner_shape);
                let mut new_fact = Opaque::fact(inputs[0].shape.clone()).with_opaque_fact(new_bqf);
                if let Some(k) = &inputs[0].konst {
                    let mut new = k.clone().into_tensor(); // cloning bqv is cheap
                    self.change_tensor(&mut new, false)?;
                    new_fact.konst = Some(new.into());
                }
                return Ok(tvec!(new_fact));
            }
        }
        // nominal path: change the outer shape, carry over any opaque fact
        let mut shape = inputs[0].shape.clone();
        self.change_shape(&mut shape, false)?;
        let mut fact = inputs[0].datum_type.fact(shape);
        fact.opaque_fact.clone_from(&inputs[0].opaque_fact);
        Ok(tvec!(fact))
    }
688
    /// Builds the input/output axes mapping: each input axis surviving the
    /// op (per `transform_axis`) is linked to its output position; output
    /// axes with no antecedent (detected through the reciprocal op) get
    /// their own fresh label.
    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let mut axes: Vec<Axis> = (0..inputs[0].rank())
            .zip('a'..)
            .map(|(axis_id, repr)| {
                let mut axis = Axis::new(repr, inputs.len(), outputs.len()).input(0, axis_id);
                if let Some(out) = self.transform_axis(axis_id) {
                    axis = axis.output(0, out);
                }
                axis
            })
            .collect();
        // output-only axes: those the inverse op maps to nothing
        for (axis, letter) in (0..outputs[0].rank()).zip('A'..) {
            if self.recip().transform_axis(axis).is_none() {
                axes.push(Axis::new(letter, inputs.len(), outputs.len()).output(0, axis));
            }
        }
        AxesMapping::new(inputs.len(), outputs.len(), axes)
    }
711
    /// Declutters the node: shunts it away if the op is a no-op, or replaces
    /// it by the chain of simpler ops from `simplify()` when that chain
    /// differs from the op itself.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.is_noop() {
            if let Some(p) = TypedModelPatch::shunt_one_op(model, node)? {
                return Ok(Some(p));
            }
        }
        let simplified = self.simplify();
        // only patch if simplification actually changed something
        if simplified.len() != 1 || &simplified[0] != self {
            let mut patch = TypedModelPatch::default();
            let mut wire = patch.tap_model(model, node.inputs[0])?;
            for (ix, op) in simplified.into_iter().enumerate() {
                wire = patch.wire_node(format!("{}.{}", node.name, ix), op, &[wire])?[0];
            }
            patch.shunt_outside(model, node.id.into(), wire)?;
            Ok(Some(patch))
        } else {
            Ok(None)
        }
    }
735
736    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
737        Ok(tvec!((InOut::Out(0), self.recip()), (InOut::In(0), self.clone())))
738    }
739
    /// Handles an axis change reaching this node.
    ///
    /// * change on the output: solved by recursing on the reciprocal op with
    ///   the input/output roles swapped, then mirroring the result back;
    /// * incoming change identical to self: the two cancel, node becomes
    ///   `Identity`;
    /// * otherwise: delegate to `merge_incoming_change`.
    fn change_axes(
        &self,
        _model: &TypedModel,
        _node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let op = if let InOut::Out(0) = io {
            let more = if let Some(more) =
                self.recip().change_axes(_model, _node, InOut::In(0), change)?
            {
                more
            } else {
                return Ok(None);
            };
            AxisChangeConsequence {
                substitute_op: more.substitute_op.map(|op| {
                    if let Some(op) = op.as_op().downcast_ref::<AxisOp>() {
                        // mirror the substitution back to the original orientation
                        Box::new(op.recip())
                    } else {
                        op // have to be identity
                    }
                }),
                wire_changes: more
                    .wire_changes
                    .into_iter()
                    .map(|wc| {
                        // swap the input/output roles back
                        (if wc.0 == InOut::In(0) { InOut::Out(0) } else { InOut::In(0) }, wc.1)
                    })
                    .collect(),
            }
        } else if change == self {
            // incoming change is exactly this op: node becomes a no-op
            AxisChangeConsequence { substitute_op: Some(Box::new(Identity)), wire_changes: tvec!() }
        } else {
            let (new_op, new_change) = if let Some(pair) = self.merge_incoming_change(change) {
                pair
            } else {
                return Ok(None);
            };
            trace!("  Change:{change:?} self:{self:?} -> change:{new_change:?} op:{new_op:?}");
            let substitute_op: Box<dyn TypedOp> =
                if let Some(o) = new_op { Box::new(o) as _ } else { Box::new(Identity) };
            let mut wire_changes = tvec!();
            if !change.is_noop() {
                wire_changes.push((InOut::In(0), change.clone()))
            }
            if let Some(new_change) = new_change {
                wire_changes.push((InOut::Out(0), new_change))
            }
            AxisChangeConsequence { substitute_op: Some(substitute_op), wire_changes }
        };
        Ok(Some(op))
    }
793
794    fn concretize_dims(
795        &self,
796        _source: &TypedModel,
797        node: &TypedNode,
798        target: &mut TypedModel,
799        mapping: &HashMap<OutletId, OutletId>,
800        values: &SymbolValues,
801    ) -> TractResult<TVec<OutletId>> {
802        let op = if let AxisOp::Reshape(axis, from, to) = self {
803            AxisOp::Reshape(
804                *axis,
805                from.iter().map(|d| d.eval(values)).collect(),
806                to.iter().map(|d| d.eval(values)).collect(),
807            )
808        } else {
809            self.clone()
810        };
811        target.wire_node(&node.name, op, &[mapping[&node.inputs[0]]])
812    }
813
814    fn slice(
815        &self,
816        patch: &mut TypedModelPatch,
817        _model: &TypedModel,
818        node: &TypedNode,
819        _prefix: &str,
820        inputs: &[OutletId],
821        output_axis: usize,
822        _start: &TDim,
823        _end: &TDim,
824    ) -> TractResult<Option<TVec<OutletId>>> {
825        // is this test really useful ? or axis mapping preempt this ?
826        if let Reshape(pos, _from, to) = self {
827            if output_axis >= *pos && output_axis < pos + to.len() {
828                return Ok(None);
829            }
830        }
831        patch.wire_node(&node.name, &node.op, inputs).map(Some)
832    }
833
834    fn codegen(
835        &self,
836        model: &TypedModel,
837        node: &TypedNode,
838    ) -> TractResult<Option<TypedModelPatch>> {
839        if node.outputs[0].fact.opaque_fact.is_some() {
840            return Ok(None);
841        }
842        if let Some(shape) = node.outputs[0].fact.shape.as_concrete() {
843            if !matches!(self, AxisOp::Move(_, _)) {
844                let (inputs, outputs) = model.node_facts(node.id)?;
845                let mapping = self.axes_mapping(&inputs, &outputs)?;
846                let op = IntoShape {
847                    mapping,
848                    len: shape.iter().product(),
849                    strides: Tensor::natural_strides(shape),
850                    dims: shape.into(),
851                };
852                return Ok(Some(TypedModelPatch::replace_single_op(
853                    model,
854                    node,
855                    &node.inputs,
856                    op,
857                )?));
858            }
859        }
860        Ok(None)
861    }
862}
863
864// a, b, c is a <- b, b <- c, c <- a
865fn perm_to_cycles(perm: &[usize]) -> TVec<TVec<usize>> {
866    let mut cycles: TVec<TVec<usize>> = tvec!();
867    let mut done = 0;
868    while done < perm.len() {
869        if perm[done] == done || cycles.iter().any(|c| c.contains(&done)) {
870            done += 1;
871            continue;
872        }
873        let mut cycle = tvec!();
874        let mut current = done;
875        loop {
876            cycle.push(current);
877            current = perm[current];
878            if current == done {
879                break;
880            }
881        }
882        cycles.push(cycle)
883    }
884    cycles
885}
886
/// Detect whether `cycle` encodes a rotation of a contiguous axis range.
///
/// * Ascending run `[a, a+1, …, b]` → `Some((a, b))`.
/// * Minimum first then a descending run `[a, b, b-1, …, a+1]` → `Some((b, a))`.
/// * Anything else → `None`.
///
/// Comparisons are written addition-side: the previous `w[0] - 1 == w[1]` and
/// `cycle[len-1] - 1 == cycle[0]` forms underflow (panicking in debug builds)
/// on non-canonical inputs such as `[1, 0]`.
fn is_rotation_cycle(cycle: &[usize]) -> Option<(usize, usize)> {
    if cycle.windows(2).all(|w| w[0] + 1 == w[1]) {
        Some((cycle[0], cycle[cycle.len() - 1]))
    } else if cycle[1..].windows(2).all(|w| w[0] == w[1] + 1)
        && cycle[cycle.len() - 1] == cycle[0] + 1
    {
        Some((cycle[1], cycle[0]))
    } else {
        None
    }
}
898
/// Express a permutation as a sequence of `(from, to)` axis moves.
///
/// Greedy fixpoint: replay the moves found so far on the identity and, while
/// the result still differs from `input`, emit one more move — preferring
/// moves that realize a whole rotation cycle in a single step.
fn perm_to_atoms(input: &[usize]) -> TVec<(usize, usize)> {
    let mut changes: TVec<(usize, usize)> = tvec!();
    'top: loop {
        // Apply the moves accumulated so far to the identity permutation.
        let mut reached: TVec<usize> = (0..input.len()).collect();
        changes.iter().for_each(|(f, t)| {
            let axis = reached.remove(*f);
            reached.insert(*t, axis);
        });
        if &*reached == input {
            return changes;
        }
        // Permutation still required to go from `reached` to `input`.
        let remaining: TVec<usize> =
            input.iter().map(|x| reached.iter().position(|y| y == x).unwrap()).collect();
        let cycles = perm_to_cycles(&remaining);
        // A rotation cycle collapses into a single move: take the first one found.
        for cycle in &cycles {
            if let Some(rot) = is_rotation_cycle(cycle) {
                changes.push(rot);
                continue 'top;
            }
        }
        // Otherwise chip away at the first cycle with one move and iterate.
        changes.push((cycles[0][1], cycles[0][0]));
    }
}
922
923pub fn perm_to_ops(input: &[usize]) -> TVec<AxisOp> {
924    perm_to_atoms(input).into_iter().map(|pair| AxisOp::Move(pair.0, pair.1)).collect()
925}
926
/// Resolve a TF-style reshape spec against an input shape.
///
/// Spec conventions handled here:
/// * `0` means "copy the matching input dimension";
/// * `-1` (at most one occurrence) means "whatever is left of the input volume".
///
/// `0` entries are resolved by two passes — left-to-right, then right-to-left
/// on the reversed shape — so zeros on either side of a `-1` can be matched.
/// The `-1` itself is then solved by dividing the input volume by the product
/// of the already-known output dims.
pub fn compute_shape_with_tf_rules(input: &[TDim], shape_spec: &[TDim]) -> TractResult<TVec<TDim>> {
    let mut shape: TVec<TDim> = shape_spec.into();
    // Resolve `0` entries up to the first `-1`, consuming input dims in order.
    fn deal_with_zero<'a>(
        mut input_dims: std::iter::Peekable<impl Iterator<Item = &'a TDim>>,
        shape: &mut [TDim],
    ) -> TractResult<()> {
        // Volume of input dims consumed but not yet accounted for by output dims.
        let mut remaining_dim_input = 1.to_dim();
        for slot in shape.iter_mut() {
            if *slot == (-1).into() {
                break;
            }
            if *slot == 0.into() {
                // `0` is only valid when input and output are aligned so far.
                if remaining_dim_input != TDim::one() {
                    bail!("Invalid remaining dim");
                }
                *slot = (*input_dims.peek().context("Invalid")?).clone();
            }
            // Pull input dims until the accumulated volume divides evenly by this slot.
            loop {
                let quotient = remaining_dim_input.maybe_div(slot);
                if quotient.is_err() || quotient.as_ref().unwrap().1 != 1 {
                    remaining_dim_input *= input_dims.next().context("Invalid")?;
                } else {
                    break;
                }
            }
            remaining_dim_input = remaining_dim_input.maybe_div(slot)?.0;
        }
        Ok(())
    }

    deal_with_zero(input.iter().peekable(), &mut shape)?;
    // Second pass from the right so `0`s after a `-1` are also resolved.
    shape.reverse();
    deal_with_zero(input.iter().rev().peekable(), &mut shape)?;
    shape.reverse();

    // Solve the optional `-1`: input volume divided by the known output volume.
    if let Some(pos) = shape.iter().position(|d| *d == (-1).into()) {
        let input_vol: TDim = input.iter().product();
        let shape_vol: TDim = shape.iter().filter(|d| **d != (-1).into()).product();
        let div = input_vol.maybe_div(&shape_vol)?;
        if div.1 != 1 {
            bail!("invalid")
        }
        shape[pos] = div.0;
    }
    Ok(shape)
}
973
974pub fn to_axis_ops_with_tf_rules(
975    input_orig: &[TDim],
976    output_spec: &[TDim],
977) -> TractResult<TVec<AxisOp>> {
978    let final_output = compute_shape_with_tf_rules(input_orig, output_spec)?;
979    let mut stack: TVec<AxisOp> = tvec!();
980    'top: loop {
981        let current_input =
982            stack.iter().try_fold(TVec::from(input_orig), |mut shape, op| -> TractResult<_> {
983                op.change_shape_array(&mut shape, false)?;
984                Ok(shape)
985            })?;
986        if current_input == final_output {
987            return Ok(stack);
988        }
989        if let Some(common) =
990            current_input.iter().zip(final_output.iter()).position(|(a, b)| a != b)
991        {
992            if current_input[common].is_one() {
993                stack.push(AxisOp::Rm(common));
994            } else if final_output[common].is_one() {
995                stack.push(AxisOp::Add(common));
996            } else {
997                // actual regrouping. search for a match. this is quadratic, but
998                // rank is expected to be somewhat reasonable
999                for i in common..current_input.len() {
1000                    let i_group = &current_input[common..i + 1];
1001                    let i_volume: TDim = i_group.iter().product();
1002                    for o in common..final_output.len() {
1003                        let o_group = &final_output[common..o + 1];
1004                        let o_volume: TDim = o_group.iter().product();
1005                        if i_volume == o_volume {
1006                            stack.push(AxisOp::Reshape(common, i_group.into(), o_group.into()));
1007                            continue 'top;
1008                        }
1009                    }
1010                }
1011                todo!()
1012            }
1013        } else if final_output.len() > current_input.len() {
1014            stack.push(AxisOp::Add(current_input.len()));
1015        } else {
1016            stack.push(AxisOp::Rm(current_input.len() - 1));
1017        }
1018    }
1019}
1020
/// Reinterprets a tensor's geometry in place: same data buffer, new
/// dims/strides. Built by `AxisOp::codegen` for concrete-shape
/// Add/Rm/Reshape chains.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct IntoShape {
    /// Axis tracking between input and output, used by declutter to fuse
    /// consecutive `IntoShape` ops.
    pub mapping: AxesMapping,
    /// Expected element count (product of `dims`), checked in `eval`.
    pub len: usize,
    /// Output shape.
    pub dims: TVec<usize>,
    /// Output strides (natural strides of `dims`, as built by codegen).
    pub strides: TVec<isize>,
}
1028
impl Op for IntoShape {
    fn name(&self) -> StaticName {
        "IntoShape".into()
    }

    // Surface the axes mapping in model dumps.
    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!("{}", self.mapping)])
    }

    op_as_typed_op!();
    impl_op_same_as!();
}
1041
impl EvalOp for IntoShape {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut input = args_1!(inputs).into_tensor();
        // The op only reinterprets geometry: the element count must match.
        ensure!(input.len() == self.len);
        // SAFETY: `len` is the element count implied by `dims`/`strides`
        // (codegen builds all three from the same concrete shape, with
        // natural strides), and the check above guarantees the buffer holds
        // exactly that many elements.
        unsafe { input.set_geometry_unchecked(&self.dims, &self.strides) };
        Ok(tvec!(input.into_tvalue()))
    }
}
1054
1055impl TypedOp for IntoShape {
1056    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
1057        let mut fact = inputs[0].datum_type.fact(&self.dims);
1058        if let Some(of) = &inputs[0].opaque_fact {
1059            fact = fact.with_opaque_fact(of.clone());
1060        }
1061        Ok(tvec!(fact))
1062    }
1063
1064    fn declutter(
1065        &self,
1066        model: &TypedModel,
1067        node: &TypedNode,
1068    ) -> TractResult<Option<TypedModelPatch>> {
1069        let input = model.outlet_fact(node.inputs[0])?;
1070        if input.shape.as_concrete().is_some_and(|shape| shape == &*self.dims) {
1071            return TypedModelPatch::shunt_one_op(model, node);
1072        }
1073        if let Some(succ) = model.single_succ(node.id)? {
1074            if let Some(into_shape) = succ.op_as::<IntoShape>() {
1075                let op = Self {
1076                    mapping: self.mapping.compose(&into_shape.mapping)?,
1077                    ..into_shape.clone()
1078                };
1079                return Ok(Some(TypedModelPatch::fuse_with_next(model, node, op)?));
1080            }
1081        }
1082        Ok(None)
1083    }
1084
1085    as_op!();
1086}
1087
#[cfg(test)]
mod test {
    //! Unit tests for cycle decomposition, rotation detection, and the
    //! pairwise composition rules of `merge_incoming_change` (each diagram
    //! shows the op before/after an incoming axis change).
    use super::*;

    #[test]
    fn test_perm_to_cycles() {
        assert_eq!(perm_to_cycles(&[1, 2, 0]), tvec!(tvec!(0, 1, 2)));
        assert_eq!(perm_to_cycles(&[2, 0, 1]), tvec!(tvec!(0, 2, 1)));
        assert_eq!(perm_to_cycles(&[1, 2, 3, 0]), tvec!(tvec!(0, 1, 2, 3)));
        assert_eq!(perm_to_cycles(&[3, 0, 1, 2]), tvec!(tvec!(0, 3, 2, 1)));
        assert_eq!(perm_to_cycles(&[3, 1, 2, 0, 4]), tvec!(tvec!(0, 3)));
    }

    #[test]
    fn is_rotation() {
        assert_eq!(is_rotation_cycle(&[0, 1, 2]), Some((0, 2)));
        assert_eq!(is_rotation_cycle(&[0, 2, 1]), Some((2, 0)));
    }

    #[test]
    fn test_perm_one_rotation() {
        assert_eq!(perm_to_atoms(&[1, 2, 0, 3, 4]), tvec!((0, 2)));
    }

    #[test]
    fn test_perm_two_rotations() {
        assert_eq!(perm_to_atoms(&[1, 2, 0, 4, 3]), tvec!((0, 2), (3, 4)));
    }

    #[test]
    fn test_perm_complex() {
        assert_eq!(perm_to_atoms(&[3, 1, 2, 0, 4]), tvec!((3, 0), (1, 3)));
    }

    // ADD-ADD

    //                          Op
    //           b,c   ------|Add(0)|----->        n,b,c
    //   Add(0)                                            Add(1)
    //         a,b,c   ------|Add(0)|----->        a,n,b,c
    #[test]
    pub fn transform_op_add_0_add_0() {
        let change = Add(0);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Add(1)))));
    }

    //                          Op
    //           b,c   ------|Add(1)|----->        b,n,c
    //   Add(0)                                                 Add(0)
    //         a,b,c   ------|Add(2)|----->        a,b,n,c
    #[test]
    pub fn transform_op_add_0_add_1() {
        let change = Add(0);
        let op = Add(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(2)), Some(Add(0)))));
    }

    //                          Op
    //           a,c   ------|Add(0)|----->        n,a,c
    //   Add(1)                                                 Add(2)
    //         a,b,c   ------|Add(0)|----->        n,a,b,c
    #[test]
    pub fn transform_op_add_1_add_0() {
        let change = Add(1);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Add(2)))));
    }

    //                          Op
    //         a,b,c   ------|Rm(1)|----->         a,c
    //   Rm(0)                                             Rm(0)
    //           b,c   ------|Rm(0)|----->         c
    #[test]
    pub fn transform_op_rm_0_rm_1() {
        let change = Rm(0);
        let op = Rm(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Rm(0)))));
    }

    //                          Op
    //         a,b,c   ------|Rm(0)|----->         b,c
    //   Rm(1)                                             Rm(0)
    //           a,c   ------|Rm(0)|----->         c
    #[test]
    pub fn transform_op_rm_1_rm_0() {
        let change = Rm(1);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Rm(0)))));
    }

    // ADD - RM

    //                          Op
    //          b,c     ------|Rm(0)|------>        c
    //   Add(0)                                                 Add(0)
    //          a,b,c   ------|Rm(1)|----->         a,c
    #[test]
    pub fn transform_op_add_0_rm_0() {
        let change = Add(0);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(1)), Some(Add(0)))));
    }

    //                          Op
    //          b,c     ------|Rm(1)|------>        b
    //   Add(0)                                                 Add(0)
    //          a,b,c   ------|Rm(2)|----->         a,b
    #[test]
    pub fn transform_op_add_0_rm_1() {
        let change = Add(0);
        let op = Rm(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(2)), Some(Add(0)))));
    }

    //                          Op
    //          a,c     ------|Rm(0)|------>        c
    //   Add(1)                                                 Add(0)
    //          a,b,c   ------|Rm(0)|----->         b,c
    #[test]
    pub fn transform_op_add_1_rm_0() {
        let change = Add(1);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Add(0)))));
    }

    // RM - ADD

    //                          Op
    //         a,b,c   ------|Add(0)|----->        X,a,b,c
    //   Rm(1)                                                 Rm(2)
    //           a,c   ------|Add(0)|----->        X,a,c
    #[test]
    pub fn transform_op_rm_1_add_0() {
        let change = Rm(1);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Rm(2)))));
    }

    //                          Op
    //         a,b,c   ------|Add(1)|----->        a,X,b,c
    //   Rm(0)                                                 Rm(0)
    //           b,c   ------|Add(0)|----->        X,b,c
    #[test]
    pub fn transform_op_rm_0_add_1() {
        let change = Rm(0);
        let op = Add(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Rm(0)))));
    }

    // MOVE - RM

    //                          Op
    //         a,b,c   ------|Rm(2)|----->        a,b
    //   Move(0, 2)                                           Move(0,1)
    //         b,c,a   ------|Rm(1)|----->        b,a
    #[test]
    pub fn transform_op_mv_02_rm_2() {
        let change = Move(0, 2);
        let op = Rm(2);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(1)), Some(Move(0, 1)))));
    }
}
1249
#[cfg(test)]
mod proptests {
    //! Property-based tests: random chains of AxisOps are built over random
    //! shapes, run as a model, then compared against the decluttered model to
    //! check the axis-change machinery preserves semantics.
    use super::*;
    use proptest::prelude::*;

    // A random starting shape plus a chain of ops valid on it.
    #[derive(Debug)]
    struct ComposeProblem {
        input: TVec<usize>,
        ops: TVec<AxisOp>,
    }

    // Strategy generating one AxisOp that is valid for the given shape.
    impl Arbitrary for AxisOp {
        type Parameters = TVec<usize>;
        type Strategy = BoxedStrategy<AxisOp>;
        fn arbitrary_with(shape: TVec<usize>) -> Self::Strategy {
            // Add is always valid at any position (including one past the end).
            let mut ops: BoxedStrategy<AxisOp> = (0usize..shape.len() + 1).prop_map(Add).boxed();
            if shape.len() > 1 {
                // `b + (b >= a)` skips `a`, so Move never maps an axis to itself.
                ops = ops
                    .prop_union(
                        (0..shape.len(), 0..shape.len() - 1)
                            .prop_map(|(a, b)| Move(a, b + (b >= a) as usize))
                            .boxed(),
                    )
                    .boxed()
            }
            // Rm is only proposed for axes of size 1.
            let rms = (0..shape.len()).filter(|&ax| shape[ax] == 1).map(Rm).collect::<Vec<_>>();
            if rms.len() > 0 {
                ops = ops
                    .prop_union((0..rms.len()).prop_map(move |rm| rms[rm].clone()).boxed())
                    .boxed()
            }
            // Reshapes merging two adjacent non-trivial axes.
            let mergeable: Vec<AxisOp> = shape
                .windows(2)
                .enumerate()
                .filter(|(_, w)| w[0] > 1 && w[1] > 1)
                .map(|(ix, w)| {
                    Reshape(ix, tvec!(w[0].to_dim(), w[1].to_dim()), tvec!((w[0] * w[1]).to_dim()))
                })
                .collect();
            // NOTE(review): `> 1` means a single mergeable pair is never
            // proposed; possibly meant `> 0` — confirm intent.
            if mergeable.len() > 1 {
                ops = ops
                    .prop_union(
                        (0..mergeable.len()).prop_map(move |ix| mergeable[ix].clone()).boxed(),
                    )
                    .boxed()
            }
            ops
        }
    }

    impl Arbitrary for ComposeProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<ComposeProblem>;
        fn arbitrary_with(_args: ()) -> Self::Strategy {
            let input = proptest::collection::vec(1usize..4, 1usize..4);
            // Recursively chain `len` ops, threading the evolving shape so
            // each op is generated against the shape it will actually see.
            fn tail(len: usize, shape: TVec<usize>) -> BoxedStrategy<TVec<AxisOp>> {
                if len == 0 {
                    Just(tvec!()).boxed()
                } else {
                    AxisOp::arbitrary_with(shape.clone())
                        .prop_flat_map(move |op| {
                            let mut shape = shape.clone();
                            op.change_shape_array(&mut shape, false).unwrap();
                            tail(len - 1, shape.clone()).prop_map(move |mut t| {
                                t.insert(0, op.clone());
                                t
                            })
                        })
                        .boxed()
                }
            }
            (input, 1usize..=5)
                .prop_flat_map(|(input, len)| (Just(input.clone()), tail(len, input.into())))
                .prop_map(|(input, ops)| ComposeProblem { input: input.into(), ops })
                .boxed()
        }
    }

    impl ComposeProblem {
        // Build a model wiring the ops in sequence from a single i64 source.
        pub fn model(&self) -> TractResult<TypedModel> {
            let mut model = TypedModel::default();
            let mut wire = model.add_source("source", i64::fact(&self.input))?;
            for (ix, op) in self.ops.iter().enumerate() {
                wire = model.wire_node(format!("op_{ix}"), op.clone(), &[wire])?[0];
            }
            model.set_output_outlets(&[wire])?;
            Ok(model)
        }

        // Input tensor filled with 0..len so element moves are observable.
        fn input(&self) -> TractResult<Tensor> {
            unsafe {
                let mut t = Tensor::uninitialized::<i64>(&self.input)?;
                for i in 0..t.len() {
                    t.as_slice_mut().unwrap()[i] = i as i64;
                }
                Ok(t)
            }
        }

        // Reference run and decluttered run must agree.
        fn check(&self) -> TractResult<()> {
            crate::setup_test_logger();
            let input = self.input()?;
            let model = self.model()?;
            let raw = model.into_runnable()?.run(tvec!(input.clone().into_tvalue()))?;
            let optimized = self.model()?.into_decluttered()?;
            let opt = optimized.into_runnable()?.run(tvec!(input.into_tvalue()))?;
            opt[0].close_enough(&raw[0], false)
        }
    }

    proptest! {
        #[test]
        fn recip(pb in any::<AxisOp>()) {
            assert_eq!(pb.recip().recip(), pb);
        }

        #[test]
        fn axis_ops(pb in any::<ComposeProblem>()) {
            pb.check().unwrap()
        }
    }

    // Regression cases previously found by the proptests above.

    #[test]
    fn add_0_rm_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Rm(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_01() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_01_add_1() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1), Add(1)] };
        pb.check().unwrap();
    }

    #[test]
    fn recip_move_01() {
        let op = Move(1, 0);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn recip_move_20() {
        let op = Move(2, 0);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn recip_move_02() {
        let op = Move(0, 2);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn add_0_add_1_move_02() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(1), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_02() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(0), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_2_move_12() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(2), Move(1, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_02_rm_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0), Move(0, 2), Rm(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_20_move_20() {
        let pb =
            ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(0), Move(2, 0), Move(2, 0)] };
        pb.check().unwrap();
    }

    #[test]
    fn move_01_add_0() {
        let pb = ComposeProblem { input: tvec![1, 1], ops: tvec![Move(0, 1), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_02_move_02() {
        let pb = ComposeProblem { input: tvec![1, 1], ops: tvec![Add(0), Move(0, 2), Move(0, 2),] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_2_move_20_move_12_rm_2() {
        let pb = ComposeProblem {
            input: tvec![3],
            ops: tvec![Add(0), Add(2), Move(2, 0), Move(1, 2), Rm(2)],
        };
        pb.check().unwrap();
    }

    #[test]
    fn move_02_move_02() {
        let pb = ComposeProblem { input: tvec![2, 1, 1], ops: tvec![Move(0, 2), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn rm_1_perm_10_add_0() {
        let pb = ComposeProblem { input: tvec![1, 1, 2], ops: tvec![Rm(1), Move(0, 1), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_2_move_02_move_02() {
        let pb = ComposeProblem { input: tvec![3, 2], ops: tvec![Add(2), Move(0, 2), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn move_01_move_20_move_20() {
        let pb = ComposeProblem {
            input: tvec![2, 3, 2],
            ops: tvec![Move(0, 1), Move(2, 0), Move(2, 0)],
        };
        pb.check().unwrap();
    }

    #[test]
    fn reshape_axes_tracking() {
        let pb = ComposeProblem {
            input: tvec![2, 2, 2],
            ops: tvec![Reshape(0, tvec!(2.to_dim(), 2.to_dim()), tvec!(4.to_dim()))],
        };
        pb.check().unwrap();
    }

    #[test]
    fn simplify_reshape() {
        macro_rules! d {
            ($($dim: expr),*) =>  { tvec!($($dim.to_dim()),*) }
        }
        assert_eq!(Reshape(3, d!(), d!()).simplify(), tvec!());
        assert_eq!(Reshape(3, d!(2, 3), d!(2, 3)).simplify(), tvec!());
        assert_eq!(Reshape(3, d!(1), d!()).simplify(), tvec!(Rm(3)));
        assert_eq!(Reshape(3, d!(), d!(1)).simplify(), tvec!(Add(3)));
        assert_eq!(
            Reshape(3, d!(2, 3, 4), d!(2, 4, 3)).simplify(),
            tvec!(Reshape(4, d!(3, 4), d!(4, 3)))
        );
        assert_eq!(
            Reshape(3, d!(3, 4, 2), d!(4, 3, 2)).simplify(),
            tvec!(Reshape(3, d!(3, 4), d!(4, 3)))
        );
        assert_eq!(
            Reshape(3, d!(1, 2, 3), d!(3, 2)).simplify(),
            tvec!(Rm(3), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(3, d!(2, 3), d!(1, 3, 2)).simplify(),
            tvec!(Reshape(3, d!(2, 3), d!(3, 2)), Add(3))
        );
        assert_eq!(
            Reshape(3, d!(2, 3, 1), d!(3, 2)).simplify(),
            tvec!(Rm(5), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(3, d!(2, 3), d!(3, 2, 1)).simplify(),
            tvec!(Add(5), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(2, d!(2, 2, 1), d!(4)).simplify(),
            tvec!(Rm(4), Reshape(2, d!(2, 2), d!(4)))
        );
        assert_eq!(Reshape(1, d!(1, 2), d!(2)).simplify(), tvec!(Rm(1)));
    }

    // Shorthand for a TDim slice.
    macro_rules! s {
        ($($a:expr),*) => {&[ $($a.clone().into()),* ]}
    }

    // Shorthand for a Reshape op.
    macro_rules! r {
        ($at: expr ; $($from:expr),* => $($to:expr),*) => {
            AxisOp::Reshape($at, tvec!($($from.into()),*),  tvec!($($to.into()),*))
        }
    }

    #[test]
    fn compute_invalid() {
        assert!(compute_shape_with_tf_rules(s![3, 4, 5], s!(100)).is_err());
    }

    #[test]
    fn compute_with_leading_zero() {
        assert_eq!(&*compute_shape_with_tf_rules(s![3, 4, 5], s!(0, 0, 5)).unwrap(), s![3, 4, 5])
    }

    #[test]
    fn compute_with_leading_zero_with_flatten() {
        assert_eq!(
            &*compute_shape_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap(),
            s![2, 3, 35]
        )
    }

    #[test]
    fn compute_with_trailing_zero() {
        assert_eq!(&*compute_shape_with_tf_rules(s![3, 4, 5], s!(3, -1, 0)).unwrap(), s![3, 4, 5])
    }

    #[test]
    fn compute_bug_1() {
        let table = SymbolScope::default();
        let s = table.new_with_prefix("S");
        assert_eq!(
            &*compute_shape_with_tf_rules(s![s, 1, 2, 128], s!(0, 0, -1)).unwrap(),
            s![s, 1, 256]
        )
    }

    #[test]
    fn compute_bug_2() {
        let table = SymbolScope::default();
        let b = table.new_with_prefix("B");
        let s = table.new_with_prefix("S");
        assert_eq!(
            &*compute_shape_with_tf_rules(s![s, b, 2, 128], s!(0, 0, -1)).unwrap(),
            s![s, b, 256]
        )
    }

    #[test]
    fn axis_op_rm_begin() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![1, 2, 3], s!(2, 3)).unwrap(), &[Rm(0)])
    }

    #[test]
    fn axis_op_rm_end() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3, 1], s!(2, 3)).unwrap(), &[Rm(2)])
    }

    #[test]
    fn axis_op_insert_begin() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3], s!(1, 2, 3)).unwrap(), &[Add(0)])
    }

    #[test]
    fn axis_op_insert_end() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3], s!(2, 3, 1)).unwrap(), &[Add(2)])
    }

    #[test]
    fn axis_op_merge() {
        assert_eq!(
            &*to_axis_ops_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap(),
            &[r!(2 ; 5,7 => 35 )]
        )
    }

    #[test]
    fn axis_op_complex() {
        assert_eq!(
            &*to_axis_ops_with_tf_rules(s![1, 2, 3, 5, 7], s!(2, 1, 3, 35, 1)).unwrap(),
            &[Rm(0), Add(1), r!(3 ; 5,7 => 35 ), Add(4)]
        )
    }
}