// tract_core/ops/change_axes.rs

1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::model::{TypedModel, TypedNode};
6use crate::ops::identity::Identity;
7use num_traits::One;
8use tract_itertools::Itertools;
9use tract_linalg::block_quant::{BlockQuantFact, BlockQuantValue};
10use tract_ndarray::{ArrayViewD, ArrayViewMutD};
11use AxisOp::*;
12
/// Designates one of a node's connection slots: either an input or an
/// output, identified by its position index.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InOut {
    /// Output slot at the given index.
    Out(usize),
    /// Input slot at the given index.
    In(usize),
}
18
19impl InOut {
20    pub fn as_outlet<F: Clone + Fact, O: Clone>(&self, node: &Node<F, O>) -> OutletId {
21        match self {
22            InOut::In(ix) => node.inputs[*ix],
23            InOut::Out(ix) => OutletId::new(node.id, *ix),
24        }
25    }
26
27    pub fn is_input(&self) -> bool {
28        matches!(self, InOut::In(_))
29    }
30
31    pub fn is_output(&self) -> bool {
32        matches!(self, InOut::Out(_))
33    }
34
35    pub fn slot(&self) -> usize {
36        match self {
37            InOut::Out(o) => *o,
38            InOut::In(i) => *i,
39        }
40    }
41}
42
/// An elementary change of a tensor's axes.
#[derive(Clone, Hash, Eq)]
#[allow(clippy::large_enum_variant)] // FIXME ?
#[allow(clippy::derived_hash_with_manual_eq)] // FIXME. this one may be pretty bad. how about a.canonical() == b.canonical() ? need proper canonicalizeation of Reshape
pub enum AxisOp {
    /// Insert a new axis of dimension 1 at the given position.
    Add(usize),
    /// Remove the axis at the given position (it must have dimension 1).
    Rm(usize),
    /// Move the axis from position .0 to position .1.
    Move(usize, usize),
    /// Replace the run of axes starting at .0 with dims .1 by axes of
    /// dims .2 (the two volumes must match).
    Reshape(usize, TVec<TDim>, TVec<TDim>),
}
52
53impl Debug for AxisOp {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        match self {
56            AxisOp::Add(a) => write!(f, "Add({a})"),
57            AxisOp::Rm(a) => write!(f, "Rm({a})"),
58            AxisOp::Move(from, to) => write!(f, "Move({from},{to})"),
59            AxisOp::Reshape(at, from, to) => {
60                write!(f, "Reshape({at}, [{}], [{}])", from.iter().join(","), to.iter().join(","))
61            }
62        }
63    }
64}
65
impl PartialEq for AxisOp {
    /// Structural equality with two extra identifications:
    /// * all no-ops compare equal to each other,
    /// * the two encodings of an adjacent transposition, Move(f, f+1) and
    ///   Move(f+1, f), compare equal.
    fn eq(&self, other: &AxisOp) -> bool {
        if self.is_noop() && other.is_noop() {
            true
        } else if self.is_noop() != other.is_noop() {
            false
        } else {
            match (self, other) {
                (Add(a), Add(b)) | (Rm(a), Rm(b)) => a == b,
                (Move(f1, t1), Move(f2, t2)) => {
                    // Identical, or mirror forms of the same adjacent swap.
                    (f1 == f2 && t1 == t2)
                        || ((*t1 == f1 + 1 || *f1 == t1 + 1) && t2 == f1 && t1 == f2)
                }
                (Reshape(at1, f1, t1), Reshape(at2, f2, t2)) => at1 == at2 && f1 == f2 && t1 == t2,
                _ => false,
            }
        }
    }
}
85
86impl AxisOp {
    /// Normalize equivalent forms to a single representative:
    /// * Move(f+1, f) becomes Move(f, f+1) (same adjacent swap),
    /// * a Reshape that only inserts or removes a single unit axis becomes
    ///   the corresponding Add or Rm.
    pub fn canonical(&self) -> Cow<AxisOp> {
        match self {
            // Mirror form of an adjacent swap.
            Move(from, to) if *from == to + 1 => Cow::Owned(Move(*to, *from)),
            // [d] -> [d, 1]: insert a unit axis after `at`.
            Reshape(at, from, to) if from.len() == 1 && to.len() == 2 && from[0] == to[0] => {
                Cow::Owned(Add(*at + 1))
            }
            // [d] -> [1, d]: insert a unit axis at `at`.
            Reshape(at, from, to) if from.len() == 1 && to.len() == 2 && from[0] == to[1] => {
                Cow::Owned(Add(*at))
            }
            // [d, 1] -> [d]: remove the unit axis at `at + 1`.
            Reshape(at, from, to) if from.len() == 2 && to.len() == 1 && from[0] == to[0] => {
                Cow::Owned(Rm(*at + 1))
            }
            // [1, d] -> [d]: remove the unit axis at `at`.
            Reshape(at, from, to) if from.len() == 2 && to.len() == 1 && from[1] == to[0] => {
                Cow::Owned(Rm(*at))
            }
            other => Cow::Borrowed(other),
        }
    }
105
    /// Decompose this op into an equivalent sequence of simpler ops, peeling
    /// trivial and unit axes off a Reshape until only an irreducible core
    /// (possibly nothing) remains. The returned ops apply in order.
    pub fn simplify(&self) -> TVec<AxisOp> {
        match self.canonical().borrow() {
            // Identity reshape vanishes.
            Reshape(_, from, to) if from == to => tvec!(),
            // Reshape to zero axes: remove every source axis.
            Reshape(at, from, to) if to.len() == 0 => tvec!(Rm(*at); from.len()),
            // Reshape from zero axes: insert every target axis.
            Reshape(at, from, to) if from.len() == 0 => tvec!(Add(*at); to.len()),
            // Identical leading dim on both sides: recurse past it.
            Reshape(at, from, to) if from[0] == to[0] => {
                Reshape(at + 1, from[1..].into(), to[1..].into()).simplify()
            }
            // Identical trailing dim on both sides: drop it and recurse.
            Reshape(at, from, to) if from[from.len() - 1] == to[to.len() - 1] => {
                Reshape(*at, from[..from.len() - 1].into(), to[..to.len() - 1].into()).simplify()
            }
            // Leading unit dim on the source: remove it first, then recurse.
            Reshape(at, from, to) if from[0] == 1.to_dim() => std::iter::once(Rm(*at))
                .chain(Reshape(*at, from[1..].into(), to.clone()).simplify())
                .collect(),
            // Leading unit dim on the target: reshape first, then insert it.
            Reshape(at, from, to) if to[0] == 1.to_dim() => {
                Reshape(*at, from.clone(), to[1..].into())
                    .simplify()
                    .into_iter()
                    .chain(std::iter::once(Add(*at)))
                    .collect()
            }
            // Trailing unit dim on the source: remove it, then recurse.
            Reshape(at, from, to) if from[from.len() - 1] == 1.to_dim() => {
                std::iter::once(Rm(at + from.len() - 1))
                    .chain(Reshape(*at, from[..from.len() - 1].into(), to.clone()).simplify())
                    .collect()
            }
            // Trailing unit dim on the target: insert it after the source
            // segment, then reshape the rest.
            Reshape(at, from, to) if to[to.len() - 1] == 1.to_dim() => {
                std::iter::once(Add(at + from.len()))
                    .chain(Reshape(*at, from.clone(), to[..to.len() - 1].into()).simplify())
                    .collect()
            }
            // Add, Rm, Move and irreducible Reshapes stay as-is.
            other => tvec!(other.clone()),
        }
    }
140
    /// Track where input axis `axis` lands after applying this op. Returns
    /// None when the axis disappears (removed, or absorbed by a Reshape
    /// segment).
    pub fn transform_axis(&self, axis: usize) -> Option<usize> {
        match self.canonical().as_ref() {
            // Axes at or after the insertion point shift right by one.
            Add(ix) => Some(axis + (axis >= *ix) as usize),
            Rm(ix) => {
                if axis == *ix {
                    None
                } else {
                    // Axes after the removed one shift left by one.
                    Some(axis - (axis > *ix) as usize)
                }
            }
            // Rightward move: axes in (from, to] slide one step left.
            Move(from, to) if from < to => {
                if axis < *from || axis > *to {
                    Some(axis)
                } else if axis == *from {
                    Some(*to)
                } else {
                    Some(axis - 1)
                }
            }
            // Leftward move: axes in [to, from) slide one step right.
            Move(from, to) => {
                if axis < *to || axis > *from {
                    Some(axis)
                } else if axis == *from {
                    Some(*to)
                } else {
                    Some(axis + 1)
                }
            }
            // Axes strictly before the reshaped segment are untouched.
            Reshape(at, _, _) if axis < *at => Some(axis),
            // Axes after the segment shift by the change in axis count.
            Reshape(at, from, to) if axis >= at + from.len() => Some(axis + to.len() - from.len()),
            // Axes inside the reshaped segment lose their identity.
            Reshape(_, _, _) => None,
        }
    }
174
    /// Compose this op with an axis `change` arriving on its input, working
    /// on the canonical form of both.
    ///
    /// If successful, returns Some((replacement, outgoing)):
    /// * replacement is the op we want to be replaced by (None: identity),
    /// * outgoing is the change to propagate downstream (None: the output
    ///   is unchanged).
    pub fn merge_incoming_change(
        &self,
        change: &AxisOp,
    ) -> Option<(Option<AxisOp>, Option<AxisOp>)> {
        match (self.canonical().as_ref(), change.canonical().as_ref()) {
            // Add/Rm against Add/Rm: shift whichever index sits at or after
            // the other one.
            (Add(op), Add(c)) => {
                Some((Some(Add(op + (c < op) as usize)), Some(Add(c + (c >= op) as usize))))
            }
            (Add(op), Rm(c)) => {
                Some((Some(Add(op - (c < op) as usize)), Some(Rm(c + (c >= op) as usize))))
            }
            (Rm(op), Add(c)) => {
                Some((Some(Rm(op + (c <= op) as usize)), Some(Add(c - (op < c) as usize))))
            }
            (Rm(op), Rm(c)) => {
                Some((Some(Rm(op - (c < op) as usize)), Some(Rm(c - (op <= c) as usize))))
            }

            // Incoming Move entirely after our insertion point passes
            // through; before it, both its endpoints shift right.
            (Add(x), Move(from, to)) => {
                if x <= from.min(to) {
                    Some((Some(self.clone()), Some(Move(from + 1, to + 1))))
                } else if x > from.max(to) {
                    Some((Some(self.clone()), Some(change.clone())))
                } else {
                    None
                }
            }

            (Move(from, to), Add(x)) => {
                if x <= from.min(to) {
                    Some((Some(Move(from + 1, to + 1)), Some(Add(*x))))
                } else if x > from.max(to) {
                    Some((Some(Move(*from, *to)), Some(Add(*x))))
                } else {
                    None
                }
            }

            (Rm(x), Move(from, to)) => {
                if x == from {
                    // The axis we remove is exactly the moved one: the Move
                    // is absorbed, we remove it at its destination instead.
                    Some((Some(Rm(*to)), None))
                } else if x < from.min(to) {
                    Some((Some(self.clone()), Some(Move(from - 1, to - 1))))
                } else if x > from.max(to) {
                    Some((Some(self.clone()), Some(change.clone())))
                } else if from + 1 == *to && x == to {
                    // Adjacent swap whose other leg we remove.
                    Some((Some(Rm(*from)), None))
                } else if from < to && x <= to {
                    // Removed axis lies inside a rightward move's span.
                    Some((Some(Rm(x - 1)), Some(Move(*from, *to - 1))))
                } else {
                    // Removed axis lies inside a leftward move's span.
                    Some((Some(Rm(x + 1)), Some(Move(*from - 1, *to))))
                }
            }

            (Move(from, to), Rm(x)) => {
                if x < from.min(to) {
                    Some((Some(Move(from - 1, to - 1)), Some(Rm(*x))))
                } else if x > from.max(to) {
                    Some((Some(Move(*from, *to)), Some(Rm(*x))))
                } else {
                    None
                }
            }

            // Add/Rm against a Reshape: only manageable when the inserted /
            // removed axis lies entirely outside the reshaped segment.
            (Add(op), Reshape(at, from, to)) => {
                if op <= at {
                    Some((Some(Add(*op)), Some(Reshape(at + 1, from.clone(), to.clone()))))
                } else if *op > at + from.len() {
                    Some((
                        Some(Add(*op + to.len() - from.len())),
                        Some(Reshape(*at, from.clone(), to.clone())),
                    ))
                } else {
                    None
                }
            }
            (Rm(op), Reshape(at, from, to)) => {
                if op < at {
                    Some((Some(Rm(*op)), Some(Reshape(at - 1, from.clone(), to.clone()))))
                } else if *op > at + from.len() {
                    Some((
                        Some(Rm(*op + to.len() - from.len())),
                        Some(Reshape(*at, from.clone(), to.clone())),
                    ))
                } else {
                    None
                }
            }
            (Reshape(at, from, to), Add(change)) => {
                if change < at {
                    Some((Some(Reshape(at + 1, from.clone(), to.clone())), Some(Add(*change))))
                } else if *change > *at + from.len() {
                    Some((
                        Some(Reshape(*at, from.clone(), to.clone())),
                        Some(Add(change + to.len() - from.len())),
                    ))
                } else {
                    None
                }
            }
            (Reshape(at, from, to), Rm(change)) => {
                if change < at {
                    Some((Some(Reshape(at - 1, from.clone(), to.clone())), Some(Rm(*change))))
                } else if *change > *at + from.len() {
                    Some((
                        Some(Reshape(*at, from.clone(), to.clone())),
                        Some(Rm(change + to.len() - from.len())),
                    ))
                } else {
                    None
                }
            }
            (Reshape(_, _, _), Move(_, _)) => None, // todo, some are manageable
            (Move(_, _), Reshape(_, _, _)) => None, // todo, some are manageable
            (Reshape(_, _, _), Reshape(_, _, _)) => None, // todo, some are manageable
            _ => None,
        }
    }
297
    /// Apply this change to a plain dims array. With `broadcasting`, a run
    /// of unit dims may stand in for the expected Reshape source dims, and
    /// maps to all-ones on the target side.
    pub fn change_shape_array<D: DimLike>(
        &self,
        shape: &mut TVec<D>,
        broadcasting: bool,
    ) -> TractResult<()> {
        match self.canonical().as_ref() {
            Add(ix) => {
                ensure!(*ix <= shape.len());
                shape.insert(*ix, D::one());
            }
            Rm(ix) => {
                ensure!(*ix < shape.len());
                shape.remove(*ix);
            }
            Move(from, to) => {
                ensure!(*from < shape.len());
                ensure!(*to < shape.len());
                let axis = shape.remove(*from);
                shape.insert(*to, axis);
            }
            Reshape(at, from, to) => {
                // A reshape must conserve the number of elements.
                let from_volume = from.iter().product::<TDim>();
                let to_volume = to.iter().product::<TDim>();
                ensure!(from_volume == to_volume, "{from_volume} should be equal to {to_volume}");
                ensure!(*at + from.len() <= shape.len());
                if shape.len() >= from.len() + *at
                    && tract_itertools::izip!(shape.iter().skip(*at), from)
                        .all(|(shape, spec)| shape.to_dim() == *spec)
                {
                    // Shape matches the expected source dims: substitute the
                    // segment by the target dims.
                    for _ in from {
                        shape.remove(*at);
                    }
                    for d in to.iter().rev() {
                        shape.insert(*at, d.try_into()?);
                    }
                } else if broadcasting
                    && shape.iter().skip(*at).take(from.len()).all(|d| d.to_dim() == 1.to_dim())
                {
                    // Broadcasting shape (segment is all ones): keep it
                    // broadcasting with the new number of axes.
                    for _ in from {
                        shape.remove(*at);
                    }
                    for _ in to.iter().rev() {
                        shape.insert(*at, 1.into());
                    }
                } else {
                    bail!("Incompatible reshape for shape {:?} and {:?}", shape, self);
                }
            }
        }
        Ok(())
    }
349
    /// Apply this change to a ShapeFact. Add and Rm are handled natively
    /// (Rm insists the removed axis has dim 1); other ops go through
    /// change_shape_array on the materialized dims.
    pub fn change_shape(&self, shape: &mut ShapeFact, broadcasting: bool) -> TractResult<()> {
        match self.canonical().as_ref() {
            Add(ix) => shape.insert_axis(*ix),
            Rm(ix) => {
                if shape.rank() <= *ix {
                    bail!("Attempt to remove axis #{} on shape {:?}", ix, shape);
                }
                if shape[*ix] != 1.to_dim() {
                    bail!("Removing non-trivial axis #{} of dim: {:?}", ix, shape);
                }
                shape.remove_axis(*ix)
            }
            _ => {
                // Move / Reshape: rebuild the fact from the transformed dims.
                let mut array = shape.to_tvec();
                self.change_shape_array(&mut array, broadcasting)?;
                let mut new_shape = ShapeFact::from_dims(array);
                std::mem::swap(shape, &mut new_shape);
                Ok(())
            }
        }
    }
371
    /// Apply this change to a tensor in place.
    ///
    /// When the op addresses axes beyond the tensor's rank and the tensor is
    /// opaque, the change is rebased (trim_left) onto the inner shape of
    /// each BlockQuantValue payload instead.
    pub fn change_tensor(&self, tensor: &mut Tensor, broadcasting: bool) -> TractResult<()> {
        if self.required_rank() > tensor.rank() && tensor.datum_type().is_opaque() {
            let inner_change = self.trim_left(tensor.rank())?;
            for opaque in tensor.as_slice_mut::<Opaque>()? {
                if let Some(bqv) = opaque.downcast_ref::<BlockQuantValue>() {
                    let mut new_shape: TVec<usize> = bqv.fact.shape().into();
                    inner_change.change_shape_array(&mut new_shape, false)?;
                    // Payload data is shared; only the fact (shape) changes.
                    let new_bqv = BlockQuantValue {
                        value: Arc::clone(&bqv.value),
                        fact: BlockQuantFact::new(bqv.fact.format.clone(), new_shape),
                    };
                    *opaque = Opaque(Arc::new(new_bqv));
                } else {
                    bail!("Can't apply {self:?} to opaque tensor {tensor:?}");
                }
            }
            return Ok(());
        }
        ensure!(self.required_rank() <= tensor.rank());
        match self.canonical().as_ref() {
            Add(ix) => tensor.insert_axis(*ix),
            Rm(ix) => tensor.remove_axis(*ix),
            Move(from, to) => {
                let mut tmp = tensor.clone().move_axis(*from, *to)?;
                std::mem::swap(tensor, &mut tmp);
                Ok(())
            }
            Reshape(at, from, to) => {
                let mut shape: TVec<usize> = tensor.shape().into();
                self.change_shape_array(&mut shape, false)?;
                if tensor.set_shape(&shape).is_ok() {
                    Ok(())
                } else if broadcasting
                    && tensor.shape().iter().skip(*at).take(from.len()).all(|d| *d == 1)
                {
                    // Broadcasting tensor (segment is all ones): just adjust
                    // the number of unit axes.
                    if from.len() > to.len() {
                        for _ in to.len()..from.len() {
                            tensor.remove_axis(*at)?;
                        }
                    }
                    if to.len() > from.len() {
                        for _ in from.len()..to.len() {
                            tensor.insert_axis(*at)?;
                        }
                    }
                    Ok(())
                } else {
                    bail!(
                        "Invalid reshaping: {:?} on tensor {:?} (broadcasting allowed: {:?})",
                        self,
                        tensor,
                        broadcasting
                    )
                }
            }
        }
    }
429
430    pub fn change_view<D>(&self, view: &mut ArrayViewD<D>) -> TractResult<()> {
431        use tract_ndarray::Axis;
432        match *self {
433            AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
434            AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
435            AxisOp::Move(from, to) if from < to => {
436                for left in from..to {
437                    view.swap_axes(left, left + 1);
438                }
439            }
440            AxisOp::Move(from, to) => {
441                for left in (to..from).rev() {
442                    view.swap_axes(left, left + 1);
443                }
444            }
445            AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
446        }
447        Ok(())
448    }
449
450    pub fn change_view_mut<D>(&self, view: &mut ArrayViewMutD<D>) -> TractResult<()> {
451        use tract_ndarray::Axis;
452        match *self {
453            AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
454            AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
455            AxisOp::Move(from, to) if from < to => {
456                for left in from..to {
457                    view.swap_axes(left, left + 1);
458                }
459            }
460            AxisOp::Move(from, to) => {
461                for left in (to..from).rev() {
462                    view.swap_axes(left, left + 1);
463                }
464            }
465            AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
466        }
467        Ok(())
468    }
469
    /// The inverse operation: applying self then self.recip() restores the
    /// original shape.
    pub fn recip(&self) -> AxisOp {
        match self.canonical().as_ref() {
            Add(ix) => Rm(*ix),
            Rm(ix) => Add(*ix),
            Move(from, to) if from == to => self.clone(),
            // An adjacent swap is its own inverse.
            Move(from, to) if *from + 1 == *to => self.clone(),
            Move(from, to) if *from == *to + 1 => {
                // canonical() rewrites Move(t+1, t) into Move(t, t+1).
                unreachable!();
            }
            Move(from, to) => Move(*to, *from),
            Reshape(at, from, to) => Reshape(*at, to.clone(), from.clone()),
        }
    }
483
484    pub fn is_noop(&self) -> bool {
485        match self {
486            Move(f, t) if f == t => true,
487            Reshape(_, f, t) if f == t => true,
488            _ => false,
489        }
490    }
491
492    pub fn only_shape(&self) -> bool {
493        if self.is_noop() {
494            return true;
495        }
496        !matches!(self, Move(_, _))
497    }
498
499    pub fn wire_split_axis(
500        model: &mut TypedModel,
501        name: impl ToString,
502        outlet: OutletId,
503        axis: usize,
504        outer_dim: usize,
505    ) -> TractResult<TVec<OutletId>> {
506        let fact = model.outlet_fact(outlet)?;
507        let dim: TDim = fact.shape[axis].clone();
508        let inner_dim = dim.clone() / outer_dim;
509        let op = Self::Reshape(axis, tvec!(dim.clone()), tvec!(outer_dim.to_dim(), inner_dim));
510        model.wire_node(name.to_string(), op, &[outlet])
511    }
512
513    pub fn wire_collapse_axis(
514        model: &mut TypedModel,
515        name: impl ToString,
516        outlet: OutletId,
517        axis: usize,
518    ) -> TractResult<TVec<OutletId>> {
519        let fact = model.outlet_fact(outlet)?;
520        let dim: TDim = fact.shape[axis].clone();
521        let next_dim: TDim = fact.shape[axis + 1].clone();
522        let op = Self::Reshape(axis, tvec!(dim.clone(), next_dim.clone()), tvec!(dim * next_dim));
523        model.wire_node(name.to_string(), op, &[outlet])
524    }
525
    /// Minimal input rank this op can be applied to.
    #[inline]
    pub fn required_rank(&self) -> usize {
        match self {
            Rm(r) => r + 1,
            Add(a) => *a,
            Reshape(at, from, _to) => at + from.len(),
            // NOTE(review): this is max(from, to), not max + 1, whereas
            // change_shape_array requires both from and to < rank. A Move
            // touching the last axis thus reports one less than it needs —
            // confirm this is intentional before relying on it.
            Move(from, to) => *from.max(to),
        }
    }
535
536    pub fn trim_left(&self, prefix: usize) -> TractResult<AxisOp> {
537        Ok(match self {
538            Rm(r) if *r >= prefix => Rm(r - prefix),
539            Add(a) if *a >= prefix => Add(a - prefix),
540            Reshape(at, from, to) if *at >= prefix => {
541                Reshape(at - prefix, from.clone(), to.clone())
542            }
543            Move(from, to) if *from >= prefix && *to >= prefix => Move(from - prefix, to - prefix),
544            _ => bail!("Can no trim left {self:?} by {prefix}"),
545        })
546    }
547}
548
549pub fn wire_rank_broadcast(
550    prefix: impl AsRef<str>,
551    target: &mut TypedModel,
552    inputs: &[OutletId],
553) -> TractResult<TVec<OutletId>> {
554    let facts =
555        inputs.iter().map(|o| target.outlet_fact(*o).cloned()).collect::<TractResult<TVec<_>>>()?;
556    let max_rank = facts.iter().map(|f| f.rank()).max().unwrap();
557    let mut wires = tvec!();
558    let prefix = prefix.as_ref();
559    for i in 0..inputs.len() {
560        let mut wire = inputs[i];
561        for j in facts[i].rank()..max_rank {
562            wire =
563                target.wire_node(format!("{prefix}.fix-rank-{i}-{j}"), AxisOp::Add(0), &[wire])?[0];
564        }
565        wires.push(wire);
566    }
567    Ok(wires)
568}
569
570pub fn wire_with_rank_broadcast(
571    prefix: impl AsRef<str>,
572    target: &mut TypedModel,
573    op: impl Into<Box<dyn TypedOp>>,
574    inputs: &[OutletId],
575) -> TractResult<TVec<OutletId>> {
576    let prefix = prefix.as_ref();
577    let wires = wire_rank_broadcast(prefix, target, inputs)?;
578    target.wire_node(prefix, op.into(), &wires)
579}
580
/// An axis change requested at a given outlet of the model.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AxisChange {
    // Outlet whose downstream consumers see the change.
    pub outlet: OutletId,
    // The axis operation to apply there.
    pub op: AxisOp,
}
586
/// What applying an axis change to a node entails: an optional replacement
/// op, plus the changes to propagate along its wires.
#[derive(Clone, Default, Debug)]
pub struct AxisChangeConsequence {
    // Replacement for the node's op, if it must change (None: keep it).
    pub substitute_op: Option<Box<dyn TypedOp>>,
    // Per-slot axis changes to propagate to neighbours.
    pub wire_changes: TVec<(InOut, AxisOp)>,
}
592
593impl AxisChangeConsequence {
594    pub fn new(
595        _model: &TypedModel,
596        node: &TypedNode,
597        op: Option<Box<dyn TypedOp>>,
598        axis_op: &AxisOp,
599    ) -> AxisChangeConsequence {
600        let mut wire_changes = tvec!();
601        for i in 0..node.inputs.len() {
602            wire_changes.push((InOut::In(i), axis_op.clone()));
603        }
604        for i in 0..node.outputs.len() {
605            wire_changes.push((InOut::Out(i), axis_op.clone()));
606        }
607        AxisChangeConsequence { wire_changes, substitute_op: op }
608    }
609}
610
611impl Op for AxisOp {
612    fn name(&self) -> Cow<str> {
613        match self {
614            Add(_) => "AddAxis".into(),
615            Rm(_) => "RmAxis".into(),
616            Move(_, _) => "MoveAxis".into(),
617            Reshape(_, _, _) => "Reshape".into(),
618        }
619    }
620
621    fn info(&self) -> TractResult<Vec<String>> {
622        match self {
623            Add(axis) | Rm(axis) => Ok(vec![format!("Axis: {axis}")]),
624            Move(from, to) => Ok(vec![format!("Axis {from} to {to}")]),
625            Reshape(at, from, to) => Ok(vec![format!(
626                "Axes starting at {}: {:?} to {:?}",
627                at,
628                from.iter().join(","),
629                to.iter().join(",")
630            )]),
631        }
632    }
633
634    op_as_typed_op!();
635}
636
637impl EvalOp for AxisOp {
638    fn is_stateless(&self) -> bool {
639        true
640    }
641
642    fn eval_with_session(
643        &self,
644        session: &SessionState,
645        inputs: TVec<TValue>,
646    ) -> TractResult<TVec<TValue>> {
647        let mut input = args_1!(inputs).into_tensor();
648        match self {
649            AxisOp::Reshape(skip, from, to) => {
650                let from = from.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
651                let to = to.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
652                AxisOp::Reshape(*skip, from, to).change_tensor(&mut input, false)?
653            }
654            _ => self.change_tensor(&mut input, false)?,
655        }
656        Ok(tvec!(input.into_tvalue()))
657    }
658}
659
660impl TypedOp for AxisOp {
661    as_op!();
662
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        if self.required_rank() > inputs[0].rank() {
            // Op reaches beyond the outer rank: apply it to the inner shape
            // of a block-quant opaque fact instead (mirrors change_tensor).
            if let Some(bqf) =
                inputs[0].opaque_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
            {
                let mut new_inner_shape: TVec<usize> = bqf.shape().into();
                self.trim_left(inputs[0].rank())?
                    .change_shape_array(&mut new_inner_shape, false)?;
                let new_bqf = BlockQuantFact::new(bqf.format.clone(), new_inner_shape);
                let mut new_fact = Opaque::fact(inputs[0].shape.clone()).with_opaque_fact(new_bqf);
                // Propagate the constant through the change when present.
                if let Some(k) = &inputs[0].konst {
                    let mut new = k.clone().into_tensor(); // cloning bqv is cheap
                    self.change_tensor(&mut new, false)?;
                    new_fact.konst = Some(new.into());
                }
                return Ok(tvec!(new_fact));
            }
        }
        // Regular case: transform the shape, keep datum type and opaque fact.
        let mut shape = inputs[0].shape.clone();
        self.change_shape(&mut shape, false)?;
        let mut fact = inputs[0].datum_type.fact(shape);
        fact.opaque_fact.clone_from(&inputs[0].opaque_fact);
        Ok(tvec!(fact))
    }
687
    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        // One lowercase axis per input axis, linked to its output position
        // whenever the axis survives the op (transform_axis is Some).
        let mut axes: Vec<Axis> = (0..inputs[0].rank())
            .zip('a'..)
            .map(|(axis_id, repr)| {
                let mut axis = Axis::new(repr, inputs.len(), outputs.len()).input(0, axis_id);
                if let Some(out) = self.transform_axis(axis_id) {
                    axis = axis.output(0, out);
                }
                axis
            })
            .collect();
        // Output axes with no input ancestor (e.g. created by Add) get
        // uppercase names of their own.
        for (axis, letter) in (0..outputs[0].rank()).zip('A'..) {
            if self.recip().transform_axis(axis).is_none() {
                axes.push(Axis::new(letter, inputs.len(), outputs.len()).output(0, axis));
            }
        }
        AxesMapping::new(inputs.len(), outputs.len(), axes)
    }
710
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // A no-op node can be short-circuited entirely.
        if self.is_noop() {
            if let Some(p) = TypedModelPatch::shunt_one_op(model, node)? {
                return Ok(Some(p));
            }
        }
        // Otherwise replace the node by its simplified op sequence when it
        // actually differs from the current op.
        let simplified = self.simplify();
        if simplified.len() != 1 || &simplified[0] != self {
            let mut patch = TypedModelPatch::default();
            let mut wire = patch.tap_model(model, node.inputs[0])?;
            for (ix, op) in simplified.into_iter().enumerate() {
                wire = patch.wire_node(format!("{}.{}", node.name, ix), op, &[wire])?[0];
            }
            patch.shunt_outside(model, node.id.into(), wire)?;
            Ok(Some(patch))
        } else {
            Ok(None)
        }
    }
734
    /// Suggest pulling the op through its output (as its inverse) or
    /// through its input (as itself).
    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
        Ok(tvec!((InOut::Out(0), self.recip()), (InOut::In(0), self.clone())))
    }
738
    fn change_axes(
        &self,
        _model: &TypedModel,
        _node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let op = if let InOut::Out(0) = io {
            // A change arriving on the output is handled by asking our
            // inverse to absorb it on its input, then mirroring the result.
            let more = if let Some(more) =
                self.recip().change_axes(_model, _node, InOut::In(0), change)?
            {
                more
            } else {
                return Ok(None);
            };
            AxisChangeConsequence {
                // Un-invert the substitute op (unless it is not an AxisOp,
                // in which case it can only be Identity).
                substitute_op: more.substitute_op.map(|op| {
                    if let Some(op) = op.as_op().downcast_ref::<AxisOp>() {
                        Box::new(op.recip())
                    } else {
                        op // have to be identity
                    }
                }),
                // Swap the input/output roles of the propagated changes.
                wire_changes: more
                    .wire_changes
                    .into_iter()
                    .map(|wc| {
                        (if wc.0 == InOut::In(0) { InOut::Out(0) } else { InOut::In(0) }, wc.1)
                    })
                    .collect(),
            }
        } else if change == self {
            // The incoming change is exactly us: the node cancels out.
            AxisChangeConsequence { substitute_op: Some(Box::new(Identity)), wire_changes: tvec!() }
        } else {
            // General case: try to merge the incoming change into this op.
            let (new_op, new_change) = if let Some(pair) = self.merge_incoming_change(change) {
                pair
            } else {
                return Ok(None);
            };
            trace!(
                "  Change:{:?} self:{:?} -> change:{:?} op:{:?}",
                change,
                self,
                new_change,
                new_op
            );
            let substitute_op: Box<dyn TypedOp> =
                if let Some(o) = new_op { Box::new(o) as _ } else { Box::new(Identity) };
            let mut wire_changes = tvec!();
            if !change.is_noop() {
                wire_changes.push((InOut::In(0), change.clone()))
            }
            if let Some(new_change) = new_change {
                wire_changes.push((InOut::Out(0), new_change))
            }
            AxisChangeConsequence { substitute_op: Some(substitute_op), wire_changes }
        };
        Ok(Some(op))
    }
798
799    fn concretize_dims(
800        &self,
801        _source: &TypedModel,
802        node: &TypedNode,
803        target: &mut TypedModel,
804        mapping: &HashMap<OutletId, OutletId>,
805        values: &SymbolValues,
806    ) -> TractResult<TVec<OutletId>> {
807        let op = if let AxisOp::Reshape(axis, from, to) = self {
808            AxisOp::Reshape(
809                *axis,
810                from.iter().map(|d| d.eval(values)).collect(),
811                to.iter().map(|d| d.eval(values)).collect(),
812            )
813        } else {
814            self.clone()
815        };
816        target.wire_node(&node.name, op, &[mapping[&node.inputs[0]]])
817    }
818
819    fn slice(
820        &self,
821        patch: &mut TypedModelPatch,
822        _model: &TypedModel,
823        node: &TypedNode,
824        _prefix: &str,
825        inputs: &[OutletId],
826        output_axis: usize,
827        _start: &TDim,
828        _end: &TDim,
829    ) -> TractResult<Option<TVec<OutletId>>> {
830        // is this test really useful ? or axis mapping preempt this ?
831        if let Reshape(pos, _from, to) = self {
832            if output_axis >= *pos && output_axis < pos + to.len() {
833                return Ok(None);
834            }
835        }
836        patch.wire_node(&node.name, &node.op, inputs).map(Some)
837    }
838
839    fn codegen(
840        &self,
841        model: &TypedModel,
842        node: &TypedNode,
843    ) -> TractResult<Option<TypedModelPatch>> {
844        if let Some(shape) = node.outputs[0].fact.shape.as_concrete() {
845            if !matches!(self, AxisOp::Move(_, _)) {
846                let (inputs, outputs) = model.node_facts(node.id)?;
847                let mapping = self.axes_mapping(&inputs, &outputs)?;
848                let op = IntoShape {
849                    mapping,
850                    len: shape.iter().product(),
851                    strides: Tensor::natural_strides(shape),
852                    dims: shape.into(),
853                };
854                return Ok(Some(TypedModelPatch::replace_single_op(
855                    model,
856                    node,
857                    &node.inputs,
858                    op,
859                )?));
860            }
861        }
862        Ok(None)
863    }
864}
865
866// a, b, c is a <- b, b <- c, c <- a
867fn perm_to_cycles(perm: &[usize]) -> TVec<TVec<usize>> {
868    let mut cycles: TVec<TVec<usize>> = tvec!();
869    let mut done = 0;
870    while done < perm.len() {
871        if perm[done] == done || cycles.iter().any(|c| c.contains(&done)) {
872            done += 1;
873            continue;
874        }
875        let mut cycle = tvec!();
876        let mut current = done;
877        loop {
878            cycle.push(current);
879            current = perm[current];
880            if current == done {
881                break;
882            }
883        }
884        cycles.push(cycle)
885    }
886    cycles
887}
888
/// Detects whether `cycle` is a contiguous rotation of axes and, if so,
/// returns the `(from, to)` endpoints of the single equivalent `Move`.
///
/// * An ascending run `[a, a+1, .., b]` yields `Some((a, b))`.
/// * A run `[min, max, max-1, .., min+1]` (descending after the first
///   element) yields `Some((cycle[1], cycle[0]))`.
/// * Anything else — including an empty slice — yields `None`.
///
/// The comparisons are written additively (`x + 1 == y`) rather than
/// subtractively (`y - 1 == x`) so they cannot underflow on `usize` inputs
/// containing 0 (the original form panicked in debug builds on e.g.
/// `[1, 0, 2]`).
fn is_rotation_cycle(cycle: &[usize]) -> Option<(usize, usize)> {
    if cycle.is_empty() {
        return None;
    }
    if cycle.windows(2).all(|w| w[0] + 1 == w[1]) {
        Some((cycle[0], cycle[cycle.len() - 1]))
    } else if cycle[1..].windows(2).all(|w| w[1] + 1 == w[0])
        && cycle[0] + 1 == cycle[cycle.len() - 1]
    {
        Some((cycle[1], cycle[0]))
    } else {
        None
    }
}
900
901fn perm_to_atoms(input: &[usize]) -> TVec<(usize, usize)> {
902    let mut changes: TVec<(usize, usize)> = tvec!();
903    'top: loop {
904        let mut reached: TVec<usize> = (0..input.len()).collect();
905        changes.iter().for_each(|(f, t)| {
906            let axis = reached.remove(*f);
907            reached.insert(*t, axis);
908        });
909        if &*reached == input {
910            return changes;
911        }
912        let remaining: TVec<usize> =
913            input.iter().map(|x| reached.iter().position(|y| y == x).unwrap()).collect();
914        let cycles = perm_to_cycles(&remaining);
915        for cycle in &cycles {
916            if let Some(rot) = is_rotation_cycle(cycle) {
917                changes.push(rot);
918                continue 'top;
919            }
920        }
921        changes.push((cycles[0][1], cycles[0][0]));
922    }
923}
924
925pub fn perm_to_ops(input: &[usize]) -> TVec<AxisOp> {
926    perm_to_atoms(input).into_iter().map(|pair| AxisOp::Move(pair.0, pair.1)).collect()
927}
928
/// Computes the concrete output shape of a reshape following TensorFlow
/// conventions: in `shape_spec`, a `0` entry means "copy the matching input
/// dimension" and a single `-1` entry means "infer this dimension so that
/// volumes match". Errors if volumes cannot be reconciled.
pub fn compute_shape_with_tf_rules(input: &[TDim], shape_spec: &[TDim]) -> TractResult<TVec<TDim>> {
    let mut shape: TVec<TDim> = shape_spec.into();
    // Resolves `0` slots by walking input and output in lockstep, consuming
    // input dims until their accumulated volume is divisible by each resolved
    // output slot. Stops at the first `-1`, which is handled separately below.
    fn deal_with_zero<'a>(
        mut input_dims: std::iter::Peekable<impl Iterator<Item = &'a TDim>>,
        shape: &mut [TDim],
    ) -> TractResult<()> {
        // Product of input dims consumed but not yet matched to output slots.
        let mut remaining_dim_input = 1.to_dim();
        for slot in shape.iter_mut() {
            if *slot == (-1).into() {
                break;
            }
            if *slot == 0.into() {
                // A `0` slot must line up exactly with the next input dim.
                if remaining_dim_input != TDim::one() {
                    bail!("Invalid remaining dim");
                }
                *slot = (*input_dims.peek().context("Invalid")?).clone();
            }
            // Pull input dims until the accumulated volume divides evenly by
            // this output slot.
            loop {
                let quotient = remaining_dim_input.maybe_div(slot);
                if quotient.is_err() || quotient.as_ref().unwrap().1 != 1 {
                    remaining_dim_input *= input_dims.next().context("Invalid")?;
                } else {
                    break;
                }
            }
            remaining_dim_input = remaining_dim_input.maybe_div(slot)?.0;
        }
        Ok(())
    }

    // Resolve `0` slots left of a potential `-1`, then run the same pass from
    // the right-hand side to resolve `0` slots located after the `-1`.
    deal_with_zero(input.iter().peekable(), &mut shape)?;
    shape.reverse();
    deal_with_zero(input.iter().rev().peekable(), &mut shape)?;
    shape.reverse();

    // Finally infer the `-1` slot (if any) so that volumes match exactly.
    if let Some(pos) = shape.iter().position(|d| *d == (-1).into()) {
        let input_vol: TDim = input.iter().product();
        let shape_vol: TDim = shape.iter().filter(|d| **d != (-1).into()).product();
        let div = input_vol.maybe_div(&shape_vol)?;
        if div.1 != 1 {
            bail!("invalid")
        }
        shape[pos] = div.0;
    }
    Ok(shape)
}
975
976pub fn to_axis_ops_with_tf_rules(
977    input_orig: &[TDim],
978    output_spec: &[TDim],
979) -> TractResult<TVec<AxisOp>> {
980    let final_output = compute_shape_with_tf_rules(input_orig, output_spec)?;
981    let mut stack: TVec<AxisOp> = tvec!();
982    'top: loop {
983        let current_input =
984            stack.iter().try_fold(TVec::from(input_orig), |mut shape, op| -> TractResult<_> {
985                op.change_shape_array(&mut shape, false)?;
986                Ok(shape)
987            })?;
988        if current_input == final_output {
989            return Ok(stack);
990        }
991        if let Some(common) =
992            current_input.iter().zip(final_output.iter()).position(|(a, b)| a != b)
993        {
994            if current_input[common].is_one() {
995                stack.push(AxisOp::Rm(common));
996            } else if final_output[common].is_one() {
997                stack.push(AxisOp::Add(common));
998            } else {
999                // actual regrouping. search for a match. this is quadratic, but
1000                // rank is expected to be somewhat reasonable
1001                for i in common..current_input.len() {
1002                    let i_group = &current_input[common..i + 1];
1003                    let i_volume: TDim = i_group.iter().product();
1004                    for o in common..final_output.len() {
1005                        let o_group = &final_output[common..o + 1];
1006                        let o_volume: TDim = o_group.iter().product();
1007                        if i_volume == o_volume {
1008                            stack.push(AxisOp::Reshape(common, i_group.into(), o_group.into()));
1009                            continue 'top;
1010                        }
1011                    }
1012                }
1013                todo!()
1014            }
1015        } else if final_output.len() > current_input.len() {
1016            stack.push(AxisOp::Add(current_input.len()));
1017        } else {
1018            stack.push(AxisOp::Rm(current_input.len() - 1));
1019        }
1020    }
1021}
1022
/// A "free" reshape: reinterprets the input tensor with new dims and strides
/// without touching the data. Produced by `AxisOp::codegen` once the output
/// shape is concrete (for every variant except `Move`).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct IntoShape {
    /// Axis tracking between input and output shapes.
    pub mapping: AxesMapping,
    /// Expected element count of the input, sanity-checked in `eval`.
    pub len: usize,
    /// Output shape.
    pub dims: TVec<usize>,
    /// Output strides (natural strides of `dims` at construction time).
    pub strides: TVec<isize>,
}
1030
impl Op for IntoShape {
    fn name(&self) -> Cow<str> {
        "IntoShape".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        // Surface the axis mapping in model dumps / debug output.
        Ok(vec![format!("{}", self.mapping)])
    }

    op_as_typed_op!();
    impl_op_same_as!();
}
1043
impl EvalOp for IntoShape {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut input = args_1!(inputs).into_tensor();
        // A reshape must not change the number of elements.
        ensure!(input.len() == self.len);
        // SAFETY: element count was checked just above, so `dims`/`strides`
        // describe the same number of elements over the same buffer.
        // NOTE(review): assumes `strides` are the natural strides of `dims`
        // (as built by codegen) so no out-of-bounds view is possible — confirm
        // if other construction paths appear.
        unsafe { input.set_geometry_unchecked(&self.dims, &self.strides) };
        Ok(tvec!(input.into_tvalue()))
    }
}
1056
impl TypedOp for IntoShape {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        // Same datum type as the input with the target dims; carry over any
        // opaque fact attached to the input.
        let mut fact = inputs[0].datum_type.fact(&self.dims);
        if let Some(of) = &inputs[0].opaque_fact {
            fact = fact.with_opaque_fact(of.clone());
        }
        Ok(tvec!(fact))
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let input = model.outlet_fact(node.inputs[0])?;
        // No-op: input already has the target concrete shape, bypass the node.
        if input.shape.as_concrete().is_some_and(|shape| shape == &*self.dims) {
            return TypedModelPatch::shunt_one_op(model, node);
        }
        // Fuse two consecutive IntoShape nodes into one: compose the axis
        // mappings, keep the downstream geometry (len/dims/strides).
        if let Some(succ) = model.single_succ(node.id)? {
            if let Some(into_shape) = succ.op_as::<IntoShape>() {
                let op = Self {
                    mapping: self.mapping.compose(&into_shape.mapping)?,
                    ..into_shape.clone()
                };
                return Ok(Some(TypedModelPatch::fuse_with_next(model, node, op)?));
            }
        }
        Ok(None)
    }

    as_op!();
}
1089
#[cfg(test)]
mod test {
    use super::*;

    // perm_to_cycles decomposes a permutation into disjoint cycles, each
    // cycle starting at its smallest element; fixed points are dropped.
    #[test]
    fn test_perm_to_cycles() {
        assert_eq!(perm_to_cycles(&[1, 2, 0]), tvec!(tvec!(0, 1, 2)));
        assert_eq!(perm_to_cycles(&[2, 0, 1]), tvec!(tvec!(0, 2, 1)));
        assert_eq!(perm_to_cycles(&[1, 2, 3, 0]), tvec!(tvec!(0, 1, 2, 3)));
        assert_eq!(perm_to_cycles(&[3, 0, 1, 2]), tvec!(tvec!(0, 3, 2, 1)));
        assert_eq!(perm_to_cycles(&[3, 1, 2, 0, 4]), tvec!(tvec!(0, 3)));
    }

    #[test]
    fn is_rotation() {
        assert_eq!(is_rotation_cycle(&[0, 1, 2]), Some((0, 2)));
        assert_eq!(is_rotation_cycle(&[0, 2, 1]), Some((2, 0)));
    }

    #[test]
    fn test_perm_one_rotation() {
        assert_eq!(perm_to_atoms(&[1, 2, 0, 3, 4]), tvec!((0, 2)));
    }

    #[test]
    fn test_perm_two_rotations() {
        assert_eq!(perm_to_atoms(&[1, 2, 0, 4, 3]), tvec!((0, 2), (3, 4)));
    }

    #[test]
    fn test_perm_complex() {
        assert_eq!(perm_to_atoms(&[3, 1, 2, 0, 4]), tvec!((3, 0), (1, 3)));
    }

    // The diagrams below document merge_incoming_change. Reading key: the
    // left-hand label is the change arriving on the input wire, the boxed op
    // on the top row is self, the boxed op on the bottom row is the
    // substituted op, and the right-hand label is the change propagated to
    // the output wire (both taken from the returned pair).

    // ADD-ADD

    //                          Op
    //           b,c   ------|Add(0)|----->        n,b,c
    //   Add(0)                                            Add(1)
    //         a,b,c   ------|Add(0)|----->        a,n,b,c
    #[test]
    pub fn transform_op_add_0_add_0() {
        let change = Add(0);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Add(1)))));
    }

    //                          Op
    //           b,c   ------|Add(1)|----->        b,n,c
    //   Add(0)                                                 Add(0)
    //         a,b,c   ------|Add(2)|----->        a,b,n,c
    #[test]
    pub fn transform_op_add_0_add_1() {
        let change = Add(0);
        let op = Add(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(2)), Some(Add(0)))));
    }

    //                          Op
    //           a,c   ------|Add(0)|----->        n,a,c
    //   Add(1)                                                 Add(2)
    //         a,b,c   ------|Add(0)|----->        n,a,b,c
    #[test]
    pub fn transform_op_add_1_add_0() {
        let change = Add(1);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Add(2)))));
    }

    //                          Op
    //         a,b,c   ------|Rm(1)|----->         a,c
    //   Rm(0)                                             Rm(0)
    //           b,c   ------|Rm(0)|----->         c
    #[test]
    pub fn transform_op_rm_0_rm_1() {
        let change = Rm(0);
        let op = Rm(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Rm(0)))));
    }

    //                          Op
    //         a,b,c   ------|Rm(0)|----->         b,c
    //   Rm(1)                                             Rm(0)
    //           a,c   ------|Rm(0)|----->         c
    #[test]
    pub fn transform_op_rm_1_rm_0() {
        let change = Rm(1);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Rm(0)))));
    }

    // ADD - RM

    //                          Op
    //          b,c     ------|Rm(0)|------>        c
    //   Add(0)                                                 Add(0)
    //          a,b,c   ------|Rm(1)|----->         a,c
    #[test]
    pub fn transform_op_add_0_rm_0() {
        let change = Add(0);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(1)), Some(Add(0)))));
    }

    //                          Op
    //          b,c     ------|Rm(1)|------>        b
    //   Add(0)                                                 Add(0)
    //          a,b,c   ------|Rm(2)|----->         a,b
    #[test]
    pub fn transform_op_add_0_rm_1() {
        let change = Add(0);
        let op = Rm(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(2)), Some(Add(0)))));
    }

    //                          Op
    //          a,c     ------|Rm(0)|------>        c
    //   Add(1)                                                 Add(0)
    //          a,b,c   ------|Rm(0)|----->         b,c
    #[test]
    pub fn transform_op_add_1_rm_0() {
        let change = Add(1);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Add(0)))));
    }

    // RM - ADD

    //                          Op
    //         a,b,c   ------|Add(0)|----->        X,a,b,c
    //   Rm(1)                                                 Rm(2)
    //           a,c   ------|Add(0)|----->        X,a,c
    #[test]
    pub fn transform_op_rm_1_add_0() {
        let change = Rm(1);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Rm(2)))));
    }

    //                          Op
    //         a,b,c   ------|Add(1)|----->        a,X,b,c
    //   Rm(0)                                                 Rm(0)
    //           b,c   ------|Add(0)|----->        X,b,c
    #[test]
    pub fn transform_op_rm_0_add_1() {
        let change = Rm(0);
        let op = Add(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Rm(0)))));
    }

    //                          Op
    //         a,b,c   ------|Rm(2)|----->        a,b
    //   Move(0, 2)                                           Move(0,1)
    //         b,c,a   ------|Rm(1)|----->        b,a
    #[test]
    pub fn transform_op_mv_02_rm_2() {
        let change = Move(0, 2);
        let op = Rm(2);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(1)), Some(Move(0, 1)))));
    }
}
1251
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// A random chain of `AxisOp`s applied to a random concrete input shape,
    /// used to check that decluttering preserves the computed output.
    #[derive(Debug)]
    struct ComposeProblem {
        input: TVec<usize>,
        ops: TVec<AxisOp>,
    }

    impl Arbitrary for AxisOp {
        type Parameters = TVec<usize>;
        type Strategy = BoxedStrategy<AxisOp>;
        /// Generates an op valid for `shape`: `Add` anywhere, `Move` between
        /// two distinct positions, `Rm` of an axis of size 1, and `Reshape`
        /// merging two adjacent axes both greater than 1.
        fn arbitrary_with(shape: TVec<usize>) -> Self::Strategy {
            let mut ops: BoxedStrategy<AxisOp> = (0usize..shape.len() + 1).prop_map(Add).boxed();
            if shape.len() > 1 {
                ops = ops
                    .prop_union(
                        (0..shape.len(), 0..shape.len() - 1)
                            // shift b up when b >= a so Move never maps an axis
                            // onto itself
                            .prop_map(|(a, b)| Move(a, b + (b >= a) as usize))
                            .boxed(),
                    )
                    .boxed()
            }
            let rms = (0..shape.len()).filter(|&ax| shape[ax] == 1).map(Rm).collect::<Vec<_>>();
            if !rms.is_empty() {
                ops = ops
                    .prop_union((0..rms.len()).prop_map(move |rm| rms[rm].clone()).boxed())
                    .boxed()
            }
            let mergeable: Vec<AxisOp> = shape
                .windows(2)
                .enumerate()
                .filter(|(_, w)| w[0] > 1 && w[1] > 1)
                .map(|(ix, w)| {
                    Reshape(ix, tvec!(w[0].to_dim(), w[1].to_dim()), tvec!((w[0] * w[1]).to_dim()))
                })
                .collect();
            // FIX: was `mergeable.len() > 1`, which silently dropped the
            // Reshape strategy whenever there was exactly one merge candidate
            // (the parallel `rms` check above correctly tests for non-empty).
            if !mergeable.is_empty() {
                ops = ops
                    .prop_union(
                        (0..mergeable.len()).prop_map(move |ix| mergeable[ix].clone()).boxed(),
                    )
                    .boxed()
            }
            ops
        }
    }

    impl Arbitrary for ComposeProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<ComposeProblem>;
        fn arbitrary_with(_args: ()) -> Self::Strategy {
            let input = proptest::collection::vec(1usize..4, 1usize..4);
            // Recursively draws `len` ops, threading the evolving shape so
            // each op is valid for the shape its predecessors produce.
            fn tail(len: usize, shape: TVec<usize>) -> BoxedStrategy<TVec<AxisOp>> {
                if len == 0 {
                    Just(tvec!()).boxed()
                } else {
                    AxisOp::arbitrary_with(shape.clone())
                        .prop_flat_map(move |op| {
                            let mut shape = shape.clone();
                            op.change_shape_array(&mut shape, false).unwrap();
                            tail(len - 1, shape.clone()).prop_map(move |mut t| {
                                t.insert(0, op.clone());
                                t
                            })
                        })
                        .boxed()
                }
            }
            (input, 1usize..=5)
                .prop_flat_map(|(input, len)| (Just(input.clone()), tail(len, input.into())))
                .prop_map(|(input, ops)| ComposeProblem { input: input.into(), ops })
                .boxed()
        }
    }

    impl ComposeProblem {
        /// Builds the model: source -> op_0 -> ... -> op_n.
        pub fn model(&self) -> TractResult<TypedModel> {
            let mut model = TypedModel::default();
            let mut wire = model.add_source("source", i64::fact(&self.input))?;
            for (ix, op) in self.ops.iter().enumerate() {
                wire = model.wire_node(format!("op_{ix}"), op.clone(), &[wire])?[0];
            }
            model.set_output_outlets(&[wire])?;
            Ok(model)
        }

        /// An i64 tensor filled with 0..len so any element mis-shuffling by a
        /// bad axis transformation becomes observable.
        fn input(&self) -> TractResult<Tensor> {
            unsafe {
                let mut t = Tensor::uninitialized::<i64>(&self.input)?;
                for i in 0..t.len() {
                    t.as_slice_mut().unwrap()[i] = i as i64;
                }
                Ok(t)
            }
        }

        /// Runs both the raw and the decluttered model on the same input and
        /// checks they agree.
        fn check(&self) -> TractResult<()> {
            crate::setup_test_logger();
            let input = self.input()?;
            let model = self.model()?;
            let raw = model.into_runnable()?.run(tvec!(input.clone().into_tvalue()))?;
            let optimized = self.model()?.into_decluttered()?;
            let opt = optimized.into_runnable()?.run(tvec!(input.into_tvalue()))?;
            opt[0].close_enough(&raw[0], false)
        }
    }

    proptest! {
        #[test]
        fn recip(pb in any::<AxisOp>()) {
            assert_eq!(pb.recip().recip(), pb);
        }

        #[test]
        fn axis_ops(pb in any::<ComposeProblem>()) {
            pb.check().unwrap()
        }
    }

    // The named tests below are regression cases extracted from proptest
    // failures.

    #[test]
    fn add_0_rm_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Rm(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_01() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_01_add_1() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1), Add(1)] };
        pb.check().unwrap();
    }

    #[test]
    fn recip_move_01() {
        let op = Move(1, 0);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn recip_move_20() {
        let op = Move(2, 0);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn recip_move_02() {
        let op = Move(0, 2);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn add_0_add_1_move_02() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(1), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_02() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(0), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_2_move_12() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(2), Move(1, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_02_rm_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0), Move(0, 2), Rm(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_20_move_20() {
        let pb =
            ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(0), Move(2, 0), Move(2, 0)] };
        pb.check().unwrap();
    }

    #[test]
    fn move_01_add_0() {
        let pb = ComposeProblem { input: tvec![1, 1], ops: tvec![Move(0, 1), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_02_move_02() {
        let pb = ComposeProblem { input: tvec![1, 1], ops: tvec![Add(0), Move(0, 2), Move(0, 2),] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_2_move_20_move_12_rm_2() {
        let pb = ComposeProblem {
            input: tvec![3],
            ops: tvec![Add(0), Add(2), Move(2, 0), Move(1, 2), Rm(2)],
        };
        pb.check().unwrap();
    }

    #[test]
    fn move_02_move_02() {
        let pb = ComposeProblem { input: tvec![2, 1, 1], ops: tvec![Move(0, 2), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn rm_1_perm_10_add_0() {
        let pb = ComposeProblem { input: tvec![1, 1, 2], ops: tvec![Rm(1), Move(0, 1), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_2_move_02_move_02() {
        let pb = ComposeProblem { input: tvec![3, 2], ops: tvec![Add(2), Move(0, 2), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn move_01_move_20_move_20() {
        let pb = ComposeProblem {
            input: tvec![2, 3, 2],
            ops: tvec![Move(0, 1), Move(2, 0), Move(2, 0)],
        };
        pb.check().unwrap();
    }

    #[test]
    fn reshape_axes_tracking() {
        let pb = ComposeProblem {
            input: tvec![2, 2, 2],
            ops: tvec![Reshape(0, tvec!(2.to_dim(), 2.to_dim()), tvec!(4.to_dim()))],
        };
        pb.check().unwrap();
    }

    #[test]
    fn simplify_reshape() {
        // Shorthand: d!(2, 3) is a TVec<TDim> of concrete dims.
        macro_rules! d {
            ($($dim: expr),*) =>  { tvec!($($dim.to_dim()),*) }
        }
        assert_eq!(Reshape(3, d!(), d!()).simplify(), tvec!());
        assert_eq!(Reshape(3, d!(2, 3), d!(2, 3)).simplify(), tvec!());
        assert_eq!(Reshape(3, d!(1), d!()).simplify(), tvec!(Rm(3)));
        assert_eq!(Reshape(3, d!(), d!(1)).simplify(), tvec!(Add(3)));
        assert_eq!(
            Reshape(3, d!(2, 3, 4), d!(2, 4, 3)).simplify(),
            tvec!(Reshape(4, d!(3, 4), d!(4, 3)))
        );
        assert_eq!(
            Reshape(3, d!(3, 4, 2), d!(4, 3, 2)).simplify(),
            tvec!(Reshape(3, d!(3, 4), d!(4, 3)))
        );
        assert_eq!(
            Reshape(3, d!(1, 2, 3), d!(3, 2)).simplify(),
            tvec!(Rm(3), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(3, d!(2, 3), d!(1, 3, 2)).simplify(),
            tvec!(Reshape(3, d!(2, 3), d!(3, 2)), Add(3))
        );
        assert_eq!(
            Reshape(3, d!(2, 3, 1), d!(3, 2)).simplify(),
            tvec!(Rm(5), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(3, d!(2, 3), d!(3, 2, 1)).simplify(),
            tvec!(Add(5), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(2, d!(2, 2, 1), d!(4)).simplify(),
            tvec!(Rm(4), Reshape(2, d!(2, 2), d!(4)))
        );
        assert_eq!(Reshape(1, d!(1, 2), d!(2)).simplify(), tvec!(Rm(1)));
    }

    // Shorthand: s!(3, 4, 5) is a &[TDim] shape literal.
    macro_rules! s {
        ($($a:expr),*) => {&[ $($a.clone().into()),* ]}
    }

    // Shorthand: r!(at ; from.. => to..) is an AxisOp::Reshape.
    macro_rules! r {
        ($at: expr ; $($from:expr),* => $($to:expr),*) => {
            AxisOp::Reshape($at, tvec!($($from.into()),*),  tvec!($($to.into()),*))
        }
    }

    #[test]
    fn compute_invalid() {
        assert!(compute_shape_with_tf_rules(s![3, 4, 5], s!(100)).is_err());
    }

    #[test]
    fn compute_with_leading_zero() {
        assert_eq!(&*compute_shape_with_tf_rules(s![3, 4, 5], s!(0, 0, 5)).unwrap(), s![3, 4, 5])
    }

    #[test]
    fn compute_with_leading_zero_with_flatten() {
        assert_eq!(
            &*compute_shape_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap(),
            s![2, 3, 35]
        )
    }

    #[test]
    fn compute_with_trailing_zero() {
        assert_eq!(&*compute_shape_with_tf_rules(s![3, 4, 5], s!(3, -1, 0)).unwrap(), s![3, 4, 5])
    }

    #[test]
    fn compute_bug_1() {
        let table = SymbolScope::default();
        let s = table.new_with_prefix("S");
        assert_eq!(
            &*compute_shape_with_tf_rules(s![s, 1, 2, 128], s!(0, 0, -1)).unwrap(),
            s![s, 1, 256]
        )
    }

    #[test]
    fn compute_bug_2() {
        let table = SymbolScope::default();
        let b = table.new_with_prefix("B");
        let s = table.new_with_prefix("S");
        assert_eq!(
            &*compute_shape_with_tf_rules(s![s, b, 2, 128], s!(0, 0, -1)).unwrap(),
            s![s, b, 256]
        )
    }

    #[test]
    fn axis_op_rm_begin() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![1, 2, 3], s!(2, 3)).unwrap(), &[Rm(0)])
    }

    #[test]
    fn axis_op_rm_end() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3, 1], s!(2, 3)).unwrap(), &[Rm(2)])
    }

    #[test]
    fn axis_op_insert_begin() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3], s!(1, 2, 3)).unwrap(), &[Add(0)])
    }

    #[test]
    fn axis_op_insert_end() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3], s!(2, 3, 1)).unwrap(), &[Add(2)])
    }

    #[test]
    fn axis_op_merge() {
        assert_eq!(
            &*to_axis_ops_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap(),
            &[r!(2 ; 5,7 => 35 )]
        )
    }

    #[test]
    fn axis_op_complex() {
        assert_eq!(
            &*to_axis_ops_with_tf_rules(s![1, 2, 3, 5, 7], s!(2, 1, 3, 35, 1)).unwrap(),
            &[Rm(0), Add(1), r!(3 ; 5,7 => 35 ), Add(4)]
        )
    }
}