tract_core/axes/mapping.rs

1use std::fmt::Display;
2use std::str::FromStr;
3
4use tract_data::itertools::izip;
5use tract_ndarray::{ArrayViewD, ArrayViewMutD};
6
7use crate::internal::*;
8use crate::prelude::tract_itertools::Itertools;
9
10use super::Axis;
11
12pub trait AxisPattern: std::fmt::Debug {
13    fn search(&self, mapping: &AxesMapping) -> Option<usize>;
14}
15
16impl AxisPattern for char {
17    fn search(&self, mapping: &AxesMapping) -> Option<usize> {
18        mapping.axes.iter().position(|axis| axis.repr == *self)
19    }
20}
21
22impl AxisPattern for (InOut, usize) {
23    fn search(&self, mapping: &AxesMapping) -> Option<usize> {
24        match self.0 {
25            InOut::In(i) => mapping.axes.iter().position(|axis| axis.inputs[i].contains(&self.1)),
26            InOut::Out(o) => mapping.axes.iter().position(|axis| axis.outputs[o].contains(&self.1)),
27        }
28    }
29}
30
31impl AxisPattern for &Axis {
32    fn search(&self, mapping: &AxesMapping) -> Option<usize> {
33        mapping.axes.iter().position(|ax| self == &ax)
34    }
35}
36
37#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub struct AxesMapping {
39    input_count: usize,
40    output_count: usize,
41    axes: TVec<Axis>,
42}
43
44impl AxesMapping {
45    pub fn new(
46        input_count: usize,
47        output_count: usize,
48        it: impl AsRef<[Axis]>,
49    ) -> TractResult<AxesMapping> {
50        let axes: TVec<_> = it.as_ref().into();
51        AxesMapping { axes, output_count, input_count }.sorted().check()
52    }
53
54    pub fn for_numpy_matmul(
55        rank: usize,
56        transposing_a: bool,
57        transposing_b: bool,
58        transposing_c: bool,
59    ) -> TractResult<AxesMapping> {
60        let mut axes: TVec<Axis> = ('a'..)
61            .take(rank - 2)
62            .enumerate()
63            .map(|(ix, repr)| Axis {
64                repr,
65                inputs: tvec!(tvec!(ix), tvec!(ix)),
66                outputs: tvec!(tvec!(ix)),
67            })
68            .collect();
69        axes.push(Axis {
70            repr: 'm',
71            inputs: tvec!(tvec!(rank - 2 + transposing_a as usize), tvec!()),
72            outputs: tvec!(tvec!(rank - 2 + transposing_c as usize)),
73        });
74        axes.push(Axis {
75            repr: 'k',
76            inputs: tvec!(
77                tvec!(rank - 1 - transposing_a as usize),
78                tvec!(rank - 2 + transposing_b as usize)
79            ),
80            outputs: tvec!(tvec!()),
81        });
82        axes.push(Axis {
83            repr: 'n',
84            inputs: tvec!(tvec!(), tvec!(rank - 1 - transposing_b as usize),),
85            outputs: tvec!(tvec!(rank - 1 - transposing_c as usize)),
86        });
87        AxesMapping::new(2, 1, axes)
88    }
89
90    pub fn disconnected(inputs: &[&TypedFact], outputs: &[&TypedFact]) -> TractResult<AxesMapping> {
91        let input_ranks: TVec<usize> = inputs.iter().map(|i| i.rank()).collect();
92        let output_ranks: TVec<usize> = outputs.iter().map(|i| i.rank()).collect();
93        Self::disconnected_for_ranks(&input_ranks, &output_ranks)
94    }
95
96    pub fn disconnected_for_ranks(inputs: &[usize], outputs: &[usize]) -> TractResult<AxesMapping> {
97        let mut axes = tvec!();
98        let mut alphabet = 'a'..;
99        for (ix, &rank) in inputs.iter().enumerate() {
100            for a in 0..rank {
101                axes.push(
102                    Axis::new(alphabet.next().unwrap(), inputs.len(), outputs.len()).input(ix, a),
103                );
104            }
105        }
106        for (ix, &rank) in outputs.iter().enumerate() {
107            for a in 0..rank {
108                axes.push(
109                    Axis::new(alphabet.next().unwrap(), inputs.len(), outputs.len()).output(ix, a),
110                );
111            }
112        }
113        AxesMapping::new(inputs.len(), outputs.len(), axes)
114    }
115
116    pub fn natural(inputs: &[&TypedFact], outputs: &[&TypedFact]) -> TractResult<AxesMapping> {
117        let rank = inputs[0].rank();
118        let axes = (0..rank)
119            .zip('a'..)
120            .map(|(axis_id, repr)| Axis::natural(inputs.len(), outputs.len(), repr, axis_id))
121            .collect::<TVec<_>>();
122        AxesMapping::new(inputs.len(), outputs.len(), axes)
123    }
124
125    pub fn natural_for_rank(
126        inputs: usize,
127        outputs: usize,
128        rank: usize,
129    ) -> TractResult<AxesMapping> {
130        let axes = (0..rank)
131            .zip('a'..)
132            .map(|(axis_id, repr)| Axis::natural(inputs, outputs, repr, axis_id))
133            .collect::<TVec<_>>();
134        AxesMapping::new(inputs, outputs, axes)
135    }
136
137    pub fn iter_all_axes(&self) -> impl Iterator<Item = &Axis> {
138        self.axes.iter()
139    }
140
141    pub fn iter_all_axes_mut(&mut self) -> impl Iterator<Item = &mut Axis> {
142        self.axes.iter_mut()
143    }
144
145    pub fn input_count(&self) -> usize {
146        self.input_count
147    }
148
149    pub fn output_count(&self) -> usize {
150        self.output_count
151    }
152
153    pub fn axis_positions(&self, io: InOut, p: impl AxisPattern) -> TractResult<&[usize]> {
154        let axis = self.axis(p)?;
155        Ok(match io {
156            InOut::In(i) => &*axis.inputs[i],
157            InOut::Out(o) => &*axis.outputs[o],
158        })
159    }
160
161    pub fn rank(&self, io: InOut) -> usize {
162        match io {
163            InOut::In(i) => self.iter_all_axes().map(|axis| axis.inputs[i].len()).sum(),
164            InOut::Out(o) => self.iter_all_axes().map(|axis| axis.outputs[o].len()).sum(),
165        }
166    }
167
168    fn search(&self, p: impl AxisPattern) -> TractResult<usize> {
169        p.search(self).with_context(|| format!("Axis {p:?} not found in {self}"))
170    }
171
172    pub fn axis(&self, p: impl AxisPattern) -> TractResult<&Axis> {
173        Ok(&self.axes[self.search(p)?])
174    }
175
176    fn axis_mut(&mut self, p: impl AxisPattern) -> TractResult<&mut Axis> {
177        let ix = self.search(p)?;
178        Ok(&mut self.axes[ix])
179    }
180
181    pub fn axes(&self, io: InOut) -> impl Iterator<Item = &Axis> {
182        (0..self.rank(io)).map(move |ix| self.axis((io, ix)).unwrap())
183    }
184
185    pub fn track_axis(&self, from: impl AxisPattern, to: InOut) -> TractResult<Option<usize>> {
186        let axis = self.axis(from)?;
187        let positions = axis.interface(to);
188        Ok(if positions.len() == 1 { Some(positions[0]) } else { None })
189    }
190
191    pub fn renaming(mut self, axis: impl AxisPattern, name: char) -> TractResult<AxesMapping> {
192        let position = self.search(axis)?;
193        let old_label = self.axes[position].repr;
194        if let Ok(conflict) = self.axis_mut(name) {
195            conflict.repr = old_label
196        }
197        self.axes[position].repr = name;
198        self.sort();
199        self.check()
200    }
201
202    pub fn linking(
203        mut self,
204        target: impl AxisPattern,
205        axis: impl AxisPattern,
206    ) -> TractResult<AxesMapping> {
207        let axis = self.axis(axis)?;
208        let axis_ix = self.axes.iter().position(|a| a == axis).unwrap();
209        let axis = self.axes.remove(axis_ix);
210        let target = self.axis_mut(target)?;
211        for (ia, ib) in target.inputs.iter_mut().zip(axis.inputs.iter()) {
212            ia.extend(ib.into_iter().cloned())
213        }
214        for (ia, ib) in target.outputs.iter_mut().zip(axis.outputs.iter()) {
215            ia.extend(ib.into_iter().cloned())
216        }
217        self.sort();
218        self.check()
219    }
220
221    fn sort(&mut self) {
222        let order: Vec<(usize, usize, usize, char)> = self
223            .axes
224            .iter()
225            .flat_map(|axis| {
226                axis.inputs
227                    .iter()
228                    .enumerate()
229                    .flat_map(move |(slot, input)| {
230                        input.iter().map(move |p| (1, slot, *p, axis.repr))
231                    })
232                    .chain(axis.outputs.iter().enumerate().flat_map(move |(slot, output)| {
233                        output.iter().map(move |p| (0, slot, *p, axis.repr))
234                    }))
235            })
236            .sorted()
237            .dedup()
238            .collect_vec();
239        self.axes.sort_by_key(|axis| order.iter().position(|tuple| tuple.3 == axis.repr).unwrap());
240    }
241
242    fn sorted(mut self) -> AxesMapping {
243        self.sort();
244        self
245    }
246
247    fn do_check(&self) -> TractResult<()> {
248        for axis in &self.axes {
249            ensure!(axis.inputs.len() == self.input_count);
250            ensure!(axis.outputs.len() == self.output_count);
251            ensure!(
252                axis.inputs.iter().map(|i| i.len()).sum::<usize>()
253                    + axis.outputs.iter().map(|o| o.len()).sum::<usize>()
254                    > 0
255            );
256        }
257        for input_ix in 0..self.input_count() {
258            for axis in 0..self.rank(InOut::In(input_ix)) {
259                ensure!(self.axis((InOut::In(input_ix), axis)).is_ok());
260            }
261        }
262        for output_ix in 0..self.output_count() {
263            for axis in 0..self.rank(InOut::Out(output_ix)) {
264                ensure!(self.axis((InOut::Out(output_ix), axis)).is_ok());
265            }
266        }
267        ensure!(self.axes.iter().map(|ax| ax.repr).duplicates().count() == 0);
268        ensure!(
269            self == &{
270                let mut x = self.clone();
271                x.sort();
272                x
273            }
274        );
275        Ok(())
276    }
277
278    pub fn check(self) -> TractResult<AxesMapping> {
279        self.do_check().with_context(|| format!("Checking {:?}", self.axes))?;
280        Ok(self)
281    }
282
283    pub fn available_label(&self) -> char {
284        self.available_labels().next().unwrap()
285    }
286
287    pub fn available_labels(&self) -> impl Iterator<Item = char> + '_ {
288        ('a'..).filter(|c| self.iter_all_axes().all(|axis| axis.repr != *c))
289    }
290
291    pub fn is_element_wise_unary(&self) -> bool {
292        self.input_count == 1
293            && self.output_count == 1
294            && self
295                .iter_all_axes()
296                .all(|axis| axis.inputs[0].len() == 1 && axis.outputs[0] == axis.inputs[0])
297    }
298
299    pub fn extract_sub_mapping(
300        &self,
301        inputs: &[usize],
302        outputs: &[usize],
303    ) -> TractResult<AxesMapping> {
304        let axes: Vec<_> = self
305            .iter_all_axes()
306            .filter(|axis| {
307                inputs.iter().any(|i| axis.inputs[*i].len() > 0)
308                    || outputs.iter().any(|o| axis.outputs[*o].len() > 0)
309            })
310            .map(|axis| Axis {
311                inputs: axis
312                    .inputs
313                    .iter()
314                    .enumerate()
315                    .filter(|(ix, _)| inputs.contains(ix))
316                    .map(|(_, it)| it.clone())
317                    .collect(),
318                outputs: axis
319                    .outputs
320                    .iter()
321                    .enumerate()
322                    .filter(|(ix, _)| outputs.contains(ix))
323                    .map(|(_, it)| it.clone())
324                    .collect(),
325                repr: axis.repr,
326            })
327            .collect();
328        AxesMapping::new(inputs.len(), outputs.len(), axes)
329    }
330
331    pub fn relabel(mut self) -> TractResult<AxesMapping> {
332        for (ax, repr) in self.axes.iter_mut().zip('a'..) {
333            ax.repr = repr;
334        }
335        Ok(self)
336    }
337
338    pub fn remove_axis(&self, repr: char) -> TractResult<AxesMapping> {
339        let mut axes: TVec<Axis> =
340            self.axes.iter().filter(|axis| axis.repr != repr).cloned().collect();
341        let removed = self.axis(repr).context("Axis not found")?;
342        for input in 0..self.input_count {
343            for &position in &removed.inputs[input] {
344                for other in &mut axes {
345                    other.inputs[input]
346                        .iter_mut()
347                        .for_each(|other_pos| *other_pos -= (*other_pos > position) as usize);
348                }
349            }
350        }
351        for output in 0..self.output_count {
352            for &position in &removed.outputs[output] {
353                for other in &mut axes {
354                    other.outputs[output]
355                        .iter_mut()
356                        .for_each(|other_pos| *other_pos -= (*other_pos > position) as usize);
357                }
358            }
359        }
360        AxesMapping::new(self.input_count, self.output_count, axes)
361    }
362
363    pub fn remove_axis_occurency(&self, slot: InOut, position: usize) -> TractResult<AxesMapping> {
364        let axis = self.axis((slot, position))?;
365        if axis.inputs.iter().map(|i| i.len()).sum::<usize>()
366            + axis.outputs.iter().map(|i| i.len()).sum::<usize>()
367            == 1
368        {
369            return self.remove_axis(axis.repr);
370        }
371        let mut axes = self.axes.clone();
372        match slot {
373            InOut::In(slot) => {
374                for axis in &mut axes {
375                    axis.inputs[slot].retain(|pos| *pos != position);
376                    axis.inputs[slot].iter_mut().for_each(|pos| *pos -= (*pos > position) as usize);
377                }
378            }
379            InOut::Out(slot) => {
380                for axis in &mut axes {
381                    axis.outputs[slot].retain(|pos| *pos != position);
382                    axis.outputs[slot]
383                        .iter_mut()
384                        .for_each(|pos| *pos -= (*pos > position) as usize);
385                }
386            }
387        }
388        AxesMapping::new(self.input_count, self.output_count, axes)
389    }
390
391    pub fn remove_slot(&self, slot: InOut) -> TractResult<AxesMapping> {
392        let mut axes = self.clone();
393        while axes.rank(slot) > 0 {
394            axes = axes.remove_axis_occurency(slot, 0)?
395        }
396        match slot {
397            InOut::In(slot) => {
398                for axis in &mut axes.axes {
399                    axis.inputs.remove(slot);
400                }
401                axes.input_count -= 1;
402            }
403            InOut::Out(slot) => {
404                for axis in &mut axes.axes {
405                    axis.outputs.remove(slot);
406                }
407                axes.output_count -= 1;
408            }
409        }
410        axes.sorted().check()
411    }
412
413    pub fn with_extra_input(self, slot: usize) -> TractResult<AxesMapping> {
414        let axes: TVec<Axis> = self
415            .iter_all_axes()
416            .map(|axis| {
417                let mut axis = axis.clone();
418                axis.inputs.insert(slot, tvec!());
419                axis
420            })
421            .collect();
422        AxesMapping::new(self.input_count + 1, self.output_count, axes)
423    }
424
425    pub fn with_extra_axis(
426        mut self,
427        repr: char,
428        io: InOut,
429        position: usize,
430    ) -> TractResult<AxesMapping> {
431        let axis = Axis::new(repr, self.input_count, self.output_count);
432        self.axes.push(axis);
433        self.with_extra_axis_occurency(repr, io, position)
434    }
435
436    pub fn with_extra_axis_occurency(
437        mut self,
438        axis: impl AxisPattern,
439        io: InOut,
440        position: usize,
441    ) -> TractResult<AxesMapping> {
442        match io {
443            InOut::In(slot) => {
444                self.axes.iter_mut().for_each(|axis| {
445                    axis.inputs[slot].iter_mut().for_each(|pos| *pos += (*pos >= position) as usize)
446                });
447                self.axis_mut(axis)?.inputs[slot].push(position);
448            }
449            InOut::Out(slot) => {
450                self.axes.iter_mut().for_each(|axis| {
451                    axis.outputs[slot]
452                        .iter_mut()
453                        .for_each(|pos| *pos += (*pos >= position) as usize)
454                });
455                self.axis_mut(axis)?.outputs[slot].push(position);
456            }
457        }
458        self.sort();
459        self.check()
460    }
461
462    pub fn translate_to_axis_ops(&self) -> TractResult<Vec<AxisOp>> {
463        ensure!(self.input_count() == 1);
464        ensure!(self.output_count() == 1);
465        ensure!(self.iter_all_axes().all(|axis| axis.inputs[0].len() <= 1));
466        let rms = self
467            .iter_all_axes()
468            .filter(|a| a.outputs[0].len() == 0)
469            .sorted_by_key(|axis| -(axis.inputs[0][0] as isize))
470            .collect_vec();
471        let adds = self
472            .iter_all_axes()
473            .filter(|a| a.inputs[0].len() == 0)
474            .sorted_by_key(|axis| axis.outputs[0][0] as isize)
475            .collect_vec();
476        let permutation = rms
477            .iter()
478            .chain(adds.iter())
479            .try_fold(self.clone(), |mapping, axis| mapping.remove_axis(axis.repr))?;
480        let permutation = permutation
481            .iter_all_axes()
482            .sorted_by_key(|axis| axis.outputs[0][0])
483            .map(|axis| axis.inputs[0][0])
484            .collect_vec();
485        let permutation = perm_to_ops(&permutation);
486        let rms = rms.iter().map(|axis| AxisOp::Rm(axis.inputs[0][0]));
487        let adds = adds.iter().map(|axis| AxisOp::Add(axis.outputs[0][0]));
488        Ok(rms.chain(permutation).chain(adds).collect())
489    }
490
491    pub fn from_strs(
492        inputs: &[impl AsRef<str>],
493        outputs: &[impl AsRef<str>],
494    ) -> TractResult<AxesMapping> {
495        let mut axes = HashMap::<char, Axis>::default();
496        for (input_ix, input) in inputs.iter().enumerate() {
497            for (ix, axis) in input.as_ref().chars().enumerate() {
498                axes.entry(axis)
499                    .or_insert_with(|| Axis::new(axis, inputs.len(), outputs.len().max(1)))
500                    .add_input(input_ix, ix);
501            }
502        }
503        for (output_ix, output) in outputs.iter().enumerate() {
504            for (ix, axis) in output.as_ref().chars().enumerate() {
505                axes.entry(axis)
506                    .or_insert_with(|| Axis::new(axis, inputs.len(), outputs.len().max(1)))
507                    .add_output(output_ix, ix);
508            }
509        }
510        if outputs.len() == 0 {
511            axes.iter_mut()
512                .sorted_by_key(|(k, _)| *k)
513                .filter(|(_, v)| v.inputs.iter().map(|input| input.len()).sum::<usize>() == 1)
514                .enumerate()
515                .for_each(|(ix, (_, v))| v.add_output(0, ix))
516        }
517        Self::new(
518            inputs.len(),
519            outputs.len().max(1),
520            axes.into_iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(),
521        )
522    }
523
524    pub fn to_strs(&self) -> (TVec<String>, TVec<String>) {
525        let mut inputs = tvec![];
526        let mut outputs = tvec![];
527        for input in 0..self.input_count() {
528            let s = self
529                .iter_all_axes()
530                .flat_map(|axis| {
531                    axis.inputs[input].iter().map(move |position| (position, axis.repr))
532                })
533                .sorted()
534                .map(|(_, r)| r)
535                .collect();
536            inputs.push(s);
537        }
538        for output in 0..self.output_count() {
539            let s = self
540                .iter_all_axes()
541                .flat_map(|axis| {
542                    axis.outputs[output].iter().map(move |position| (position, axis.repr))
543                })
544                .sorted()
545                .map(|(_, r)| r)
546                .collect();
547            outputs.push(s);
548        }
549        (inputs, outputs)
550    }
551
552    pub fn change_axis_sink(&self, io: InOut, change: &AxisOp) -> TractResult<Option<AxesMapping>> {
553        let (mut inputs, mut outputs) = self.to_strs();
554        let interface: &mut String = match io {
555            InOut::In(i) => &mut inputs[i],
556            InOut::Out(o) => &mut outputs[o],
557        };
558        let mut axes: Vec<char> = interface.chars().collect();
559        match change {
560            AxisOp::Rm(rm) => {
561                axes.remove(*rm);
562            }
563            AxisOp::Add(add) => axes.insert(*add, self.available_label()),
564            AxisOp::Move(from, to) => {
565                let c = axes.remove(*from);
566                axes.insert(*to, c);
567            }
568            _ => return Ok(None),
569        };
570        *interface = axes.into_iter().collect();
571        Ok(Some(AxesMapping::from_strs(&inputs, &outputs)?))
572    }
573
574    pub fn direct(&self, a: InOut, b: InOut) -> bool {
575        self.axes.iter().all(|axis| axis.interface(a) == axis.interface(b))
576    }
577
578    pub fn same_layout<D: DimLike>(
579        &self,
580        a: InOut,
581        b: InOut,
582        shape_a: impl AsRef<[D]>,
583        shape_b: impl AsRef<[D]>,
584    ) -> bool {
585        let shape_a = shape_a.as_ref();
586        let shape_b = shape_b.as_ref();
587        shape_a.iter().cloned().product::<D>() == shape_b.iter().cloned().product()
588            && izip!(
589                self.axes(a).zip(shape_a.iter()).filter(|(_axis, d)| **d != D::one()),
590                self.axes(b).zip(shape_b.iter()).filter(|(_axis, d)| **d != D::one())
591            )
592            .all(|(a, b)| a == b)
593    }
594
595    pub fn axis_ops_to_canonical(&self, io: InOut) -> TractResult<Vec<AxisOp>> {
596        let rank = self.rank(io);
597        let target_rank = self.axes.len();
598        let mut next_insert_axis = 0;
599        let mut permutation = tvec!();
600        for axis in &self.axes {
601            let spec = match io {
602                InOut::In(i) => axis.inputs[i].first(),
603                InOut::Out(o) => axis.outputs[o].first(),
604            };
605            if let Some(pos_in_a) = spec {
606                permutation.push(pos_in_a + target_rank - rank)
607            } else {
608                permutation.push(next_insert_axis);
609                next_insert_axis += 1;
610            }
611        }
612        let mut ops = vec![AxisOp::Add(0); target_rank - rank];
613        ops.extend(crate::ops::change_axes::perm_to_ops(&permutation));
614        Ok(ops)
615    }
616
617    pub fn view_to_canonical<D>(&self, io: InOut, view: &mut ArrayViewD<D>) -> TractResult<()> {
618        for op in self.axis_ops_to_canonical(io)? {
619            op.change_view(view)?;
620        }
621        Ok(())
622    }
623
624    pub fn view_to_canonical_mut<D>(
625        &self,
626        io: InOut,
627        view: &mut ArrayViewMutD<D>,
628    ) -> TractResult<()> {
629        for op in self.axis_ops_to_canonical(io)? {
630            op.change_view_mut(view)?;
631        }
632        Ok(())
633    }
634
635    pub fn compose(&self, other: &AxesMapping) -> TractResult<AxesMapping> {
636        ensure!(self.input_count() == 1 && self.output_count() == 1);
637        ensure!(other.input_count() == 1 && other.output_count() == 1);
638        let mut result = AxesMapping::disconnected_for_ranks(
639            &[self.rank(InOut::In(0))],
640            &[other.rank(InOut::Out(0))],
641        )?;
642        for ix in 0..result.rank(InOut::In(0)) {
643            let Some(inter) = self.track_axis((InOut::In(0), ix), InOut::Out(0))? else { continue };
644            let Some(out) = other.track_axis((InOut::In(0), inter), InOut::Out(0))? else {
645                continue;
646            };
647            result = result.linking((InOut::Out(0), out), (InOut::In(0), ix))?;
648        }
649        Ok(result)
650    }
651}
652
653impl FromStr for AxesMapping {
654    type Err = TractError;
655    fn from_str(s: &str) -> Result<Self, Self::Err> {
656        assert!(!s.contains("..."));
657        let s = s.replace(' ', "");
658        let (inputs, outputs) =
659            if let Some((i, r)) = s.split_once("->") { (i, r) } else { (&*s, "") };
660        let inputs: TVec<&str> = inputs.split(',').collect();
661        let outputs: TVec<&str> = outputs.split(',').filter(|s| s.len() > 0).collect();
662        AxesMapping::from_strs(&inputs, &outputs)
663    }
664}
665
666impl Display for AxesMapping {
667    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
668        let (inputs, outputs) = self.to_strs();
669        write!(f, "{}->{}", inputs.iter().join(","), outputs.iter().join(","))
670    }
671}
672
#[cfg(test)]
mod test {
    //! Parsing, printing and transformation tests for `AxesMapping`.
    use super::*;

    /// Parses an einsum-like expression, panicking if it is rejected.
    fn parse(expr: &str) -> AxesMapping {
        expr.parse().unwrap()
    }

    /// Asserts that `expr` parses to the mapping built from `axes`.
    fn assert_parses_to(expr: &str, inputs: usize, outputs: usize, axes: TVec<Axis>) {
        let expected = AxesMapping::new(inputs, outputs, axes).unwrap();
        assert_eq!(parse(expr), expected);
    }

    /// Axis ops equivalent to the one-input, one-output mapping `expr`.
    fn ops(expr: &str) -> Vec<AxisOp> {
        parse(expr).translate_to_axis_ops().unwrap()
    }

    #[test]
    fn test_parse_transpose() {
        assert_parses_to(
            "ij->ji",
            1,
            1,
            tvec![
                Axis::new('i', 1, 1).output(0, 1).input(0, 0),
                Axis::new('j', 1, 1).output(0, 0).input(0, 1),
            ],
        );
    }

    #[test]
    fn test_parse_diag() {
        assert_parses_to(
            "ii->i",
            1,
            1,
            tvec![Axis::new('i', 1, 1).output(0, 0).input(0, 0).input(0, 1)],
        );
    }

    #[test]
    fn test_parse_adamar_product_explicit() {
        assert_parses_to(
            "i,i->i",
            2,
            1,
            tvec![Axis::new('i', 2, 1).output(0, 0).input(0, 0).input(1, 0)],
        );
    }

    #[test]
    fn test_parse_inner_product_implicit() {
        assert_eq!(parse("i,i"), parse("i,i->"));
    }

    #[test]
    fn test_parse_batch_matmul() {
        assert_parses_to(
            "bij , bjk -> bik ",
            2,
            1,
            tvec![
                Axis::new('b', 2, 1).output(0, 0).input(0, 0).input(1, 0),
                Axis::new('i', 2, 1).output(0, 1).input(0, 1),
                Axis::new('j', 2, 1).input(0, 2).input(1, 1),
                Axis::new('k', 2, 1).output(0, 2).input(1, 2),
            ],
        );
    }

    #[test]
    fn test_parse_outer_product() {
        assert_parses_to(
            "i,j->ij",
            2,
            1,
            tvec![
                Axis::new('i', 2, 1).output(0, 0).input(0, 0),
                Axis::new('j', 2, 1).output(0, 1).input(1, 0),
            ],
        );
    }

    #[test]
    fn test_parse_bilinear() {
        assert_parses_to(
            "ik,jkl,il->ij",
            3,
            1,
            tvec![
                Axis::new('i', 3, 1).output(0, 0).input(0, 0).input(2, 0),
                Axis::new('j', 3, 1).output(0, 1).input(1, 0),
                Axis::new('k', 3, 1).input(0, 1).input(1, 1),
                Axis::new('l', 3, 1).input(1, 2).input(2, 1),
            ],
        );
    }

    #[test]
    fn test_parse_complex_tensor_contraction() {
        assert_parses_to(
            "pqrs,tuqvr->pstuv",
            2,
            1,
            tvec![
                Axis::new('p', 2, 1).output(0, 0).input(0, 0),
                Axis::new('q', 2, 1).input(0, 1).input(1, 2),
                Axis::new('r', 2, 1).input(0, 2).input(1, 4),
                Axis::new('s', 2, 1).output(0, 1).input(0, 3),
                Axis::new('t', 2, 1).output(0, 2).input(1, 0),
                Axis::new('u', 2, 1).output(0, 3).input(1, 1),
                Axis::new('v', 2, 1).output(0, 4).input(1, 3),
            ],
        );
    }

    #[test]
    fn test_parse_complex_tensor_contraction_implicit() {
        assert_eq!(parse("pqrs,tuqvr"), parse("pqrs,tuqvr->pstuv"));
    }

    #[test]
    fn test_display_expr() {
        let printed = parse("pqrs,tuqvr->pstuv").to_string();
        assert_eq!(printed, "pqrs,tuqvr->pstuv");
    }

    #[test]
    fn test_parse_pulsed_matmul() {
        assert_parses_to(
            "sij,ijk->sik",
            2,
            1,
            tvec![
                Axis::new('i', 2, 1).output(0, 1).input(0, 1).input(1, 0),
                Axis::new('j', 2, 1).input(0, 2).input(1, 1),
                Axis::new('k', 2, 1).output(0, 2).input(1, 2),
                Axis::new('s', 2, 1).output(0, 0).input(0, 0),
            ],
        );
    }

    #[test]
    fn test_parse_pulsed_batch_matmul() {
        assert_parses_to(
            "bsij,ijk->bsik",
            2,
            1,
            tvec![
                Axis::new('b', 2, 1).output(0, 0).input(0, 0),
                Axis::new('i', 2, 1).output(0, 2).input(0, 2).input(1, 0),
                Axis::new('j', 2, 1).input(0, 3).input(1, 1),
                Axis::new('k', 2, 1).output(0, 3).input(1, 2),
                Axis::new('s', 2, 1).output(0, 1).input(0, 1),
            ],
        );
    }

    #[test]
    fn test_extract_sub_mapping() {
        let cases: [(&str, &[usize], &str); 3] = [
            ("bsij,ijk->bsik", &[0], "bsij->bsik"),
            ("bsij,ijk->bsik", &[1], "ijk->bsik"),
            ("bsij,ijk->ij", &[1], "ijk->ij"),
        ];
        for &(full, inputs, sub) in &cases {
            assert_eq!(parse(full).extract_sub_mapping(inputs, &[0]).unwrap(), parse(sub));
        }
    }

    #[test]
    fn test_remove_axis_0() {
        let cases = [
            ("ab->a", "a->a"),
            ("ba->a", "a->a"),
            ("a->ba", "a->a"),
            ("a->ab", "a->a"),
            ("ab,a->a", "a,a->a"),
            ("ba,a->a", "a,a->a"),
            ("a,ab->a", "a,a->a"),
            ("a,ba->a", "a,a->a"),
            ("a,a->ab", "a,a->a"),
            ("a,a->ba", "a,a->a"),
        ];
        for &(before, after) in &cases {
            assert_eq!(parse(before).remove_axis('b').unwrap(), parse(after), "{before}");
        }
        assert_eq!(parse("bsij,ijk->bsik").remove_axis('i').unwrap(), parse("bsj,jk->bsk"));
    }

    #[test]
    fn test_translate_to_ops_rm_add() {
        assert_eq!(ops("ab->a"), vec![AxisOp::Rm(1)]);
        assert_eq!(ops("ba->a"), vec![AxisOp::Rm(0)]);
        assert_eq!(ops("ab->c"), vec![AxisOp::Rm(1), AxisOp::Rm(0), AxisOp::Add(0)]);
    }

    #[test]
    fn test_translate_to_ops_add_0() {
        assert_eq!(ops("bacmn->bmn"), vec![AxisOp::Rm(2), AxisOp::Rm(1)]);
    }

    #[test]
    fn test_translate_to_ops_move() {
        assert_eq!(ops("ab->ba"), vec![AxisOp::Move(1, 0)]);
    }

    #[test]
    fn test_translate_to_ops_move_20() {
        assert_eq!(ops("abc->cab"), vec![AxisOp::Move(2, 0)]);
    }

    #[test]
    fn test_translate_to_ops_complex() {
        assert_eq!(ops("anbck->backn"), vec![AxisOp::Move(2, 0), AxisOp::Move(2, 4)]);
    }
}
905}