tract_core/axes/
mapping.rs

1use std::fmt::Display;
2use std::str::FromStr;
3
4use tract_data::itertools::izip;
5use tract_ndarray::{ArrayViewD, ArrayViewMutD};
6
7use crate::internal::*;
8use crate::prelude::tract_itertools::Itertools;
9
10use super::Axis;
11
/// A way of designating one axis inside an [`AxesMapping`].
///
/// This file implements it for `char` (compared against `Axis::repr`), for `(InOut, usize)`
/// (an interface slot plus a position inside it) and for `&Axis` (compared by equality).
pub trait AxisPattern: std::fmt::Debug {
    /// Returns the index, in `mapping.axes`, of the first axis matching this pattern, if any.
    fn search(&self, mapping: &AxesMapping) -> Option<usize>;
}
15
16impl AxisPattern for char {
17    fn search(&self, mapping: &AxesMapping) -> Option<usize> {
18        mapping.axes.iter().position(|axis| axis.repr == *self)
19    }
20}
21
22impl AxisPattern for (InOut, usize) {
23    fn search(&self, mapping: &AxesMapping) -> Option<usize> {
24        match self.0 {
25            InOut::In(i) => mapping.axes.iter().position(|axis| axis.inputs[i].contains(&self.1)),
26            InOut::Out(o) => mapping.axes.iter().position(|axis| axis.outputs[o].contains(&self.1)),
27        }
28    }
29}
30
31impl AxisPattern for &Axis {
32    fn search(&self, mapping: &AxesMapping) -> Option<usize> {
33        mapping.axes.iter().position(|ax| self == &ax)
34    }
35}
36
/// Records, for a set of input slots and output slots, at which positions each [`Axis`]
/// appears.
///
/// Constructors go through `sort` and `check`, so a successfully built value holds its axes
/// in the order defined by `sort` and satisfies the conditions verified by `do_check`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AxesMapping {
    // Number of input slots; `do_check` requires `axis.inputs.len()` to match it for all axes.
    input_count: usize,
    // Number of output slots; `do_check` requires `axis.outputs.len()` to match it.
    output_count: usize,
    // The axes themselves, kept in the order established by `sort`.
    axes: TVec<Axis>,
}
43
44impl AxesMapping {
45    pub fn new(
46        input_count: usize,
47        output_count: usize,
48        it: impl AsRef<[Axis]>,
49    ) -> TractResult<AxesMapping> {
50        let axes: TVec<_> = it.as_ref().into();
51        AxesMapping { axes, output_count, input_count }.sorted().check()
52    }
53
54    pub fn for_numpy_matmul(
55        rank: usize,
56        transposing_a: bool,
57        transposing_b: bool,
58        transposing_c: bool,
59    ) -> TractResult<AxesMapping> {
60        let mut axes: TVec<Axis> = ('a'..)
61            .take(rank - 2)
62            .enumerate()
63            .map(|(ix, repr)| Axis {
64                repr,
65                inputs: tvec!(tvec!(ix), tvec!(ix)),
66                outputs: tvec!(tvec!(ix)),
67            })
68            .collect();
69        axes.push(Axis {
70            repr: 'm',
71            inputs: tvec!(tvec!(rank - 2 + transposing_a as usize), tvec!()),
72            outputs: tvec!(tvec!(rank - 2 + transposing_c as usize)),
73        });
74        axes.push(Axis {
75            repr: 'k',
76            inputs: tvec!(
77                tvec!(rank - 1 - transposing_a as usize),
78                tvec!(rank - 2 + transposing_b as usize)
79            ),
80            outputs: tvec!(tvec!()),
81        });
82        axes.push(Axis {
83            repr: 'n',
84            inputs: tvec!(tvec!(), tvec!(rank - 1 - transposing_b as usize),),
85            outputs: tvec!(tvec!(rank - 1 - transposing_c as usize)),
86        });
87        AxesMapping::new(2, 1, axes)
88    }
89
90    pub fn disconnected(inputs: &[&TypedFact], outputs: &[&TypedFact]) -> TractResult<AxesMapping> {
91        let input_ranks: TVec<usize> = inputs.iter().map(|i| i.rank()).collect();
92        let output_ranks: TVec<usize> = outputs.iter().map(|i| i.rank()).collect();
93        Self::disconnected_for_ranks(&input_ranks, &output_ranks)
94    }
95
96    pub fn disconnected_for_ranks(inputs: &[usize], outputs: &[usize]) -> TractResult<AxesMapping> {
97        let mut axes = tvec!();
98        let mut alphabet = 'a'..;
99        for (ix, &rank) in inputs.iter().enumerate() {
100            for a in 0..rank {
101                axes.push(
102                    Axis::new(alphabet.next().unwrap(), inputs.len(), outputs.len()).input(ix, a),
103                );
104            }
105        }
106        for (ix, &rank) in outputs.iter().enumerate() {
107            for a in 0..rank {
108                axes.push(
109                    Axis::new(alphabet.next().unwrap(), inputs.len(), outputs.len()).output(ix, a),
110                );
111            }
112        }
113        AxesMapping::new(inputs.len(), outputs.len(), axes)
114    }
115
116    pub fn natural(inputs: &[&TypedFact], outputs: &[&TypedFact]) -> TractResult<AxesMapping> {
117        let rank = inputs[0].rank();
118        let axes = (0..rank)
119            .zip('a'..)
120            .map(|(axis_id, repr)| Axis::natural(inputs.len(), outputs.len(), repr, axis_id))
121            .collect::<TVec<_>>();
122        AxesMapping::new(inputs.len(), outputs.len(), axes)
123    }
124
125    pub fn natural_for_rank(
126        inputs: usize,
127        outputs: usize,
128        rank: usize,
129    ) -> TractResult<AxesMapping> {
130        let axes = (0..rank)
131            .zip('a'..)
132            .map(|(axis_id, repr)| Axis::natural(inputs, outputs, repr, axis_id))
133            .collect::<TVec<_>>();
134        AxesMapping::new(inputs, outputs, axes)
135    }
136
    /// Iterates over every axis of the mapping, in storage order.
    pub fn iter_all_axes(&self) -> impl Iterator<Item = &Axis> {
        self.axes.iter()
    }
140
    /// Number of input slots of the mapping.
    pub fn input_count(&self) -> usize {
        self.input_count
    }
144
    /// Number of output slots of the mapping.
    pub fn output_count(&self) -> usize {
        self.output_count
    }
148
149    pub fn axis_positions(&self, io: InOut, p: impl AxisPattern) -> TractResult<&[usize]> {
150        let axis = self.axis(p)?;
151        Ok(match io {
152            InOut::In(i) => &*axis.inputs[i],
153            InOut::Out(o) => &*axis.outputs[o],
154        })
155    }
156
157    pub fn rank(&self, io: InOut) -> usize {
158        match io {
159            InOut::In(i) => self.iter_all_axes().map(|axis| axis.inputs[i].len()).sum(),
160            InOut::Out(o) => self.iter_all_axes().map(|axis| axis.outputs[o].len()).sum(),
161        }
162    }
163
164    fn search(&self, p: impl AxisPattern) -> TractResult<usize> {
165        p.search(self).with_context(|| format!("Axis {p:?} not found in {self}"))
166    }
167
168    pub fn axis(&self, p: impl AxisPattern) -> TractResult<&Axis> {
169        Ok(&self.axes[self.search(p)?])
170    }
171
172    fn axis_mut(&mut self, p: impl AxisPattern) -> TractResult<&mut Axis> {
173        let ix = self.search(p)?;
174        Ok(&mut self.axes[ix])
175    }
176
177    pub fn axes(&self, io: InOut) -> impl Iterator<Item = &Axis> {
178        (0..self.rank(io)).map(move |ix| self.axis((io, ix)).unwrap())
179    }
180
181    pub fn track_axis(&self, from: impl AxisPattern, to: InOut) -> TractResult<Option<usize>> {
182        let axis = self.axis(from)?;
183        let positions = axis.interface(to);
184        Ok(if positions.len() == 1 { Some(positions[0]) } else { None })
185    }
186
187    pub fn renaming(mut self, axis: impl AxisPattern, name: char) -> TractResult<AxesMapping> {
188        let position = self.search(axis)?;
189        let old_label = self.axes[position].repr;
190        if let Ok(conflict) = self.axis_mut(name) {
191            conflict.repr = old_label
192        }
193        self.axes[position].repr = name;
194        self.sort();
195        self.check()
196    }
197
198    pub fn linking(
199        mut self,
200        target: impl AxisPattern,
201        axis: impl AxisPattern,
202    ) -> TractResult<AxesMapping> {
203        let axis = self.axis(axis)?;
204        let axis_ix = self.axes.iter().position(|a| a == axis).unwrap();
205        let axis = self.axes.remove(axis_ix);
206        let target = self.axis_mut(target)?;
207        for (ia, ib) in target.inputs.iter_mut().zip(axis.inputs.iter()) {
208            ia.extend(ib.into_iter().cloned())
209        }
210        for (ia, ib) in target.outputs.iter_mut().zip(axis.outputs.iter()) {
211            ia.extend(ib.into_iter().cloned())
212        }
213        self.sort();
214        self.check()
215    }
216
217    fn sort(&mut self) {
218        let order: Vec<(usize, usize, usize, char)> = self
219            .axes
220            .iter()
221            .flat_map(|axis| {
222                axis.inputs
223                    .iter()
224                    .enumerate()
225                    .flat_map(move |(slot, input)| {
226                        input.iter().map(move |p| (1, slot, *p, axis.repr))
227                    })
228                    .chain(axis.outputs.iter().enumerate().flat_map(move |(slot, output)| {
229                        output.iter().map(move |p| (0, slot, *p, axis.repr))
230                    }))
231            })
232            .sorted()
233            .dedup()
234            .collect_vec();
235        self.axes.sort_by_key(|axis| order.iter().position(|tuple| tuple.3 == axis.repr).unwrap());
236    }
237
    /// By-value flavour of `sort`: consumes the mapping and returns it with its axes sorted.
    fn sorted(mut self) -> AxesMapping {
        self.sort();
        self
    }
242
    /// Verifies the internal consistency of the mapping, failing on the first violation.
    ///
    /// Conditions checked, in order:
    /// - every axis has one entry per input slot and one per output slot,
    /// - every axis appears at least once somewhere,
    /// - in every slot, each position below the slot rank is held by some axis,
    /// - no two axes share a label,
    /// - the axes are already in the order `sort` would give them.
    fn do_check(&self) -> TractResult<()> {
        for axis in &self.axes {
            ensure!(axis.inputs.len() == self.input_count);
            ensure!(axis.outputs.len() == self.output_count);
            // an axis must occur somewhere
            ensure!(
                axis.inputs.iter().map(|i| i.len()).sum::<usize>()
                    + axis.outputs.iter().map(|o| o.len()).sum::<usize>()
                    > 0
            );
        }
        // no hole in input positions
        for input_ix in 0..self.input_count() {
            for axis in 0..self.rank(InOut::In(input_ix)) {
                ensure!(self.axis((InOut::In(input_ix), axis)).is_ok());
            }
        }
        // no hole in output positions
        for output_ix in 0..self.output_count() {
            for axis in 0..self.rank(InOut::Out(output_ix)) {
                ensure!(self.axis((InOut::Out(output_ix), axis)).is_ok());
            }
        }
        // labels are unique
        ensure!(self.axes.iter().map(|ax| ax.repr).duplicates().count() == 0);
        // sorting must be a no-op
        ensure!(
            self == &{
                let mut x = self.clone();
                x.sort();
                x
            }
        );
        Ok(())
    }
273
274    pub fn check(self) -> TractResult<AxesMapping> {
275        self.do_check().with_context(|| format!("Checking {:?}", self.axes))?;
276        Ok(self)
277    }
278
279    pub fn available_label(&self) -> char {
280        ('a'..).find(|c| self.iter_all_axes().all(|axis| axis.repr != *c)).unwrap()
281    }
282
283    pub fn is_element_wise_unary(&self) -> bool {
284        self.input_count == 1
285            && self.output_count == 1
286            && self
287                .iter_all_axes()
288                .all(|axis| axis.inputs[0].len() == 1 && axis.outputs[0] == axis.inputs[0])
289    }
290
291    pub fn extract_sub_mapping(
292        &self,
293        inputs: &[usize],
294        outputs: &[usize],
295    ) -> TractResult<AxesMapping> {
296        let axes: Vec<_> = self
297            .iter_all_axes()
298            .filter(|axis| {
299                inputs.iter().any(|i| axis.inputs[*i].len() > 0)
300                    || outputs.iter().any(|o| axis.outputs[*o].len() > 0)
301            })
302            .map(|axis| Axis {
303                inputs: axis
304                    .inputs
305                    .iter()
306                    .enumerate()
307                    .filter(|(ix, _)| inputs.contains(ix))
308                    .map(|(_, it)| it.clone())
309                    .collect(),
310                outputs: axis
311                    .outputs
312                    .iter()
313                    .enumerate()
314                    .filter(|(ix, _)| outputs.contains(ix))
315                    .map(|(_, it)| it.clone())
316                    .collect(),
317                repr: axis.repr,
318            })
319            .collect();
320        AxesMapping::new(inputs.len(), outputs.len(), axes)
321    }
322
323    pub fn relabel(mut self) -> TractResult<AxesMapping> {
324        for (ax, repr) in self.axes.iter_mut().zip('a'..) {
325            ax.repr = repr;
326        }
327        Ok(self)
328    }
329
330    pub fn remove_axis(&self, repr: char) -> TractResult<AxesMapping> {
331        let mut axes: TVec<Axis> =
332            self.axes.iter().filter(|axis| axis.repr != repr).cloned().collect();
333        let removed = self.axis(repr).context("Axis not found")?;
334        for input in 0..self.input_count {
335            for &position in &removed.inputs[input] {
336                for other in &mut axes {
337                    other.inputs[input]
338                        .iter_mut()
339                        .for_each(|other_pos| *other_pos -= (*other_pos > position) as usize);
340                }
341            }
342        }
343        for output in 0..self.output_count {
344            for &position in &removed.outputs[output] {
345                for other in &mut axes {
346                    other.outputs[output]
347                        .iter_mut()
348                        .for_each(|other_pos| *other_pos -= (*other_pos > position) as usize);
349                }
350            }
351        }
352        AxesMapping::new(self.input_count, self.output_count, axes)
353    }
354
355    pub fn remove_axis_occurency(&self, slot: InOut, position: usize) -> TractResult<AxesMapping> {
356        let axis = self.axis((slot, position))?;
357        if axis.inputs.iter().map(|i| i.len()).sum::<usize>()
358            + axis.outputs.iter().map(|i| i.len()).sum::<usize>()
359            == 1
360        {
361            return self.remove_axis(axis.repr);
362        }
363        let mut axes = self.axes.clone();
364        match slot {
365            InOut::In(slot) => {
366                for axis in &mut axes {
367                    axis.inputs[slot].retain(|pos| *pos != position);
368                    axis.inputs[slot].iter_mut().for_each(|pos| *pos -= (*pos > position) as usize);
369                }
370            }
371            InOut::Out(slot) => {
372                for axis in &mut axes {
373                    axis.outputs[slot].retain(|pos| *pos != position);
374                    axis.outputs[slot]
375                        .iter_mut()
376                        .for_each(|pos| *pos -= (*pos > position) as usize);
377                }
378            }
379        }
380        AxesMapping::new(self.input_count, self.output_count, axes)
381    }
382
383    pub fn remove_slot(&self, slot: InOut) -> TractResult<AxesMapping> {
384        let mut axes = self.clone();
385        while axes.rank(slot) > 0 {
386            axes = axes.remove_axis_occurency(slot, 0)?
387        }
388        match slot {
389            InOut::In(slot) => {
390                for axis in &mut axes.axes {
391                    axis.inputs.remove(slot);
392                }
393                axes.input_count -= 1;
394            }
395            InOut::Out(slot) => {
396                for axis in &mut axes.axes {
397                    axis.outputs.remove(slot);
398                }
399                axes.output_count -= 1;
400            }
401        }
402        axes.sorted().check()
403    }
404
405    pub fn with_extra_input(self, slot: usize) -> TractResult<AxesMapping> {
406        let axes: TVec<Axis> = self
407            .iter_all_axes()
408            .map(|axis| {
409                let mut axis = axis.clone();
410                axis.inputs.insert(slot, tvec!());
411                axis
412            })
413            .collect();
414        AxesMapping::new(self.input_count + 1, self.output_count, axes)
415    }
416
417    pub fn with_extra_axis(
418        mut self,
419        repr: char,
420        io: InOut,
421        position: usize,
422    ) -> TractResult<AxesMapping> {
423        let axis = Axis::new(repr, self.input_count, self.output_count);
424        self.axes.push(axis);
425        self.with_extra_axis_occurency(repr, io, position)
426    }
427
428    pub fn with_extra_axis_occurency(
429        mut self,
430        axis: impl AxisPattern,
431        io: InOut,
432        position: usize,
433    ) -> TractResult<AxesMapping> {
434        match io {
435            InOut::In(slot) => {
436                self.axes.iter_mut().for_each(|axis| {
437                    axis.inputs[slot].iter_mut().for_each(|pos| *pos += (*pos >= position) as usize)
438                });
439                self.axis_mut(axis)?.inputs[slot].push(position);
440            }
441            InOut::Out(slot) => {
442                self.axes.iter_mut().for_each(|axis| {
443                    axis.outputs[slot]
444                        .iter_mut()
445                        .for_each(|pos| *pos += (*pos >= position) as usize)
446                });
447                self.axis_mut(axis)?.outputs[slot].push(position);
448            }
449        }
450        self.sort();
451        self.check()
452    }
453
    /// Expresses a one-input, one-output mapping as a sequence of `AxisOp`: removals first,
    /// then the moves computed by `perm_to_ops`, then additions.
    ///
    /// Fails unless the mapping has exactly one input and one output, and no axis appears
    /// more than once in the input.
    pub fn translate_to_axis_ops(&self) -> TractResult<Vec<AxisOp>> {
        ensure!(self.input_count() == 1);
        ensure!(self.output_count() == 1);
        ensure!(self.iter_all_axes().all(|axis| axis.inputs[0].len() <= 1));
        // Axes absent from the output are removed, highest input position first, so that
        // each removal leaves the positions still to be removed untouched.
        let rms = self
            .iter_all_axes()
            .filter(|a| a.outputs[0].len() == 0)
            .sorted_by_key(|axis| -(axis.inputs[0][0] as isize))
            .collect_vec();
        // Axes absent from the input are added, lowest output position first.
        let adds = self
            .iter_all_axes()
            .filter(|a| a.inputs[0].len() == 0)
            .sorted_by_key(|axis| axis.outputs[0][0] as isize)
            .collect_vec();
        // What is left once removed and added axes are taken out of the mapping.
        let permutation = rms
            .iter()
            .chain(adds.iter())
            .try_fold(self.clone(), |mapping, axis| mapping.remove_axis(axis.repr))?;
        // For each output position, in order, the input position it comes from.
        let permutation = permutation
            .iter_all_axes()
            .sorted_by_key(|axis| axis.outputs[0][0])
            .map(|axis| axis.inputs[0][0])
            .collect_vec();
        let permutation = perm_to_ops(&permutation);
        let rms = rms.iter().map(|axis| AxisOp::Rm(axis.inputs[0][0]));
        let adds = adds.iter().map(|axis| AxisOp::Add(axis.outputs[0][0]));
        Ok(rms.chain(permutation).chain(adds).collect())
    }
482
483    pub fn from_strs(
484        inputs: &[impl AsRef<str>],
485        outputs: &[impl AsRef<str>],
486    ) -> TractResult<AxesMapping> {
487        let mut axes = HashMap::<char, Axis>::default();
488        for (input_ix, input) in inputs.iter().enumerate() {
489            for (ix, axis) in input.as_ref().chars().enumerate() {
490                axes.entry(axis)
491                    .or_insert_with(|| Axis::new(axis, inputs.len(), outputs.len().max(1)))
492                    .add_input(input_ix, ix);
493            }
494        }
495        for (output_ix, output) in outputs.iter().enumerate() {
496            for (ix, axis) in output.as_ref().chars().enumerate() {
497                axes.entry(axis)
498                    .or_insert_with(|| Axis::new(axis, inputs.len(), outputs.len().max(1)))
499                    .add_output(output_ix, ix);
500            }
501        }
502        if outputs.len() == 0 {
503            axes.iter_mut()
504                .sorted_by_key(|(k, _)| *k)
505                .filter(|(_, v)| v.inputs.iter().map(|input| input.len()).sum::<usize>() == 1)
506                .enumerate()
507                .for_each(|(ix, (_, v))| v.add_output(0, ix))
508        }
509        Self::new(
510            inputs.len(),
511            outputs.len().max(1),
512            axes.into_iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(),
513        )
514    }
515
516    pub fn to_strs(&self) -> (TVec<String>, TVec<String>) {
517        let mut inputs = tvec![];
518        let mut outputs = tvec![];
519        for input in 0..self.input_count() {
520            let s = self
521                .iter_all_axes()
522                .flat_map(|axis| {
523                    axis.inputs[input].iter().map(move |position| (position, axis.repr))
524                })
525                .sorted()
526                .map(|(_, r)| r)
527                .collect();
528            inputs.push(s);
529        }
530        for output in 0..self.output_count() {
531            let s = self
532                .iter_all_axes()
533                .flat_map(|axis| {
534                    axis.outputs[output].iter().map(move |position| (position, axis.repr))
535                })
536                .sorted()
537                .map(|(_, r)| r)
538                .collect();
539            outputs.push(s);
540        }
541        (inputs, outputs)
542    }
543
544    pub fn change_axis_sink(&self, io: InOut, change: &AxisOp) -> TractResult<Option<AxesMapping>> {
545        let (mut inputs, mut outputs) = self.to_strs();
546        let interface: &mut String = match io {
547            InOut::In(i) => &mut inputs[i],
548            InOut::Out(o) => &mut outputs[o],
549        };
550        let mut axes: Vec<char> = interface.chars().collect();
551        match change {
552            AxisOp::Rm(rm) => {
553                axes.remove(*rm);
554            }
555            AxisOp::Add(add) => axes.insert(*add, self.available_label()),
556            AxisOp::Move(from, to) => {
557                let c = axes.remove(*from);
558                axes.insert(*to, c);
559            }
560            _ => return Ok(None),
561        };
562        *interface = axes.into_iter().collect();
563        Ok(Some(AxesMapping::from_strs(&inputs, &outputs)?))
564    }
565
566    pub fn direct(&self, a: InOut, b: InOut) -> bool {
567        self.axes.iter().all(|axis| axis.interface(a) == axis.interface(b))
568    }
569
570    pub fn same_layout<D: DimLike>(
571        &self,
572        a: InOut,
573        b: InOut,
574        shape_a: impl AsRef<[D]>,
575        shape_b: impl AsRef<[D]>,
576    ) -> bool {
577        let shape_a = shape_a.as_ref();
578        let shape_b = shape_b.as_ref();
579        shape_a.iter().cloned().product::<D>() == shape_b.iter().cloned().product()
580            && izip!(
581                self.axes(a).zip(shape_a.iter()).filter(|(_axis, d)| **d != D::one()),
582                self.axes(b).zip(shape_b.iter()).filter(|(_axis, d)| **d != D::one())
583            )
584            .all(|(a, b)| a == b)
585    }
586
    /// Computes the `AxisOp` sequence bringing slot `io` to a layout with one axis per axis
    /// of the mapping: `Add(0)` operations first, then the output of `perm_to_ops`.
    pub fn axis_ops_to_canonical(&self, io: InOut) -> TractResult<Vec<AxisOp>> {
        let rank = self.rank(io);
        let target_rank = self.axes.len();
        // Index of the next added axis to hand out to an axis missing from `io`.
        let mut next_insert_axis = 0;
        let mut permutation = tvec!();
        for axis in &self.axes {
            // Only the first position of the axis in `io` is considered.
            let spec = match io {
                InOut::In(i) => axis.inputs[i].first(),
                InOut::Out(o) => axis.outputs[o].first(),
            };
            if let Some(pos_in_a) = spec {
                // Existing positions are offset by the number of `Add(0)` emitted below.
                permutation.push(pos_in_a + target_rank - rank)
            } else {
                permutation.push(next_insert_axis);
                next_insert_axis += 1;
            }
        }
        // NOTE(review): `target_rank - rank` presumably assumes no axis appears more than
        // once in `io` (otherwise rank could exceed the number of axes) — confirm.
        let mut ops = vec![AxisOp::Add(0); target_rank - rank];
        ops.extend(crate::ops::change_axes::perm_to_ops(&permutation));
        Ok(ops)
    }
608
609    pub fn view_to_canonical<D>(&self, io: InOut, view: &mut ArrayViewD<D>) -> TractResult<()> {
610        for op in self.axis_ops_to_canonical(io)? {
611            op.change_view(view)?;
612        }
613        Ok(())
614    }
615
616    pub fn view_to_canonical_mut<D>(
617        &self,
618        io: InOut,
619        view: &mut ArrayViewMutD<D>,
620    ) -> TractResult<()> {
621        for op in self.axis_ops_to_canonical(io)? {
622            op.change_view_mut(view)?;
623        }
624        Ok(())
625    }
626
627    pub fn compose(&self, other: &AxesMapping) -> TractResult<AxesMapping> {
628        ensure!(self.input_count() == 1 && self.output_count() == 1);
629        ensure!(other.input_count() == 1 && other.output_count() == 1);
630        let mut result = AxesMapping::disconnected_for_ranks(
631            &[self.rank(InOut::In(0))],
632            &[other.rank(InOut::Out(0))],
633        )?;
634        for ix in 0..result.rank(InOut::In(0)) {
635            let Some(inter) = self.track_axis((InOut::In(0), ix), InOut::Out(0))? else { continue };
636            let Some(out) = other.track_axis((InOut::In(0), inter), InOut::Out(0))? else {
637                continue;
638            };
639            result = result.linking((InOut::Out(0), out), (InOut::In(0), ix))?;
640        }
641        Ok(result)
642    }
643}
644
645impl FromStr for AxesMapping {
646    type Err = TractError;
647    fn from_str(s: &str) -> Result<Self, Self::Err> {
648        assert!(!s.contains("..."));
649        let s = s.replace(' ', "");
650        let (inputs, outputs) =
651            if let Some((i, r)) = s.split_once("->") { (i, r) } else { (&*s, "") };
652        let inputs: TVec<&str> = inputs.split(',').collect();
653        let outputs: TVec<&str> = outputs.split(',').filter(|s| s.len() > 0).collect();
654        AxesMapping::from_strs(&inputs, &outputs)
655    }
656}
657
658impl Display for AxesMapping {
659    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
660        let (inputs, outputs) = self.to_strs();
661        write!(f, "{}->{}", inputs.iter().join(","), outputs.iter().join(","))
662    }
663}
664
#[cfg(test)]
mod test {
    use super::*;

    // Parses an expression, panicking if it is invalid.
    fn m(s: &str) -> AxesMapping {
        s.parse().unwrap()
    }

    // Two axes swapping places between input and output.
    #[test]
    fn test_parse_transpose() {
        assert_eq!(
            m("ij->ji"),
            AxesMapping::new(
                1,
                1,
                tvec![
                    Axis::new('i', 1, 1).output(0, 1).input(0, 0),
                    Axis::new('j', 1, 1).output(0, 0).input(0, 1)
                ]
            )
            .unwrap(),
        )
    }

    // One axis appearing twice in the same input.
    #[test]
    fn test_parse_diag() {
        assert_eq!(
            m("ii->i"),
            AxesMapping::new(
                1,
                1,
                tvec![Axis::new('i', 1, 1).output(0, 0).input(0, 0).input(0, 1)]
            )
            .unwrap(),
        )
    }

    // One axis shared by two inputs and the output.
    #[test]
    fn test_parse_adamar_product_explicit() {
        assert_eq!(
            m("i,i->i"),
            AxesMapping::new(
                2,
                1,
                tvec![Axis::new('i', 2, 1).output(0, 0).input(0, 0).input(1, 0)]
            )
            .unwrap(),
        )
    }

    // Omitting the arrow is the same as giving an empty output here.
    #[test]
    fn test_parse_inner_product_implicit() {
        assert_eq!(m("i,i"), m("i,i->"))
    }

    // Spaces in the expression are ignored.
    #[test]
    fn test_parse_batch_matmul() {
        assert_eq!(
            m("bij , bjk -> bik "),
            AxesMapping::new(
                2,
                1,
                tvec![
                    Axis::new('b', 2, 1).output(0, 0).input(0, 0).input(1, 0),
                    Axis::new('i', 2, 1).output(0, 1).input(0, 1),
                    Axis::new('j', 2, 1).input(0, 2).input(1, 1),
                    Axis::new('k', 2, 1).output(0, 2).input(1, 2)
                ]
            )
            .unwrap()
        )
    }

    #[test]
    fn test_parse_outer_product() {
        assert_eq!(
            m("i,j->ij"),
            AxesMapping::new(
                2,
                1,
                tvec![
                    Axis::new('i', 2, 1).output(0, 0).input(0, 0),
                    Axis::new('j', 2, 1).output(0, 1).input(1, 0)
                ]
            )
            .unwrap(),
        )
    }

    // Three inputs.
    #[test]
    fn test_parse_bilinear() {
        assert_eq!(
            m("ik,jkl,il->ij"),
            AxesMapping::new(
                3,
                1,
                tvec![
                    Axis::new('i', 3, 1).output(0, 0).input(0, 0).input(2, 0),
                    Axis::new('j', 3, 1).output(0, 1).input(1, 0),
                    Axis::new('k', 3, 1).input(0, 1).input(1, 1),
                    Axis::new('l', 3, 1).input(1, 2).input(2, 1)
                ]
            )
            .unwrap(),
        )
    }

    #[test]
    fn test_parse_complex_tensor_contraction() {
        assert_eq!(
            m("pqrs,tuqvr->pstuv"),
            AxesMapping::new(
                2,
                1,
                tvec![
                    Axis::new('p', 2, 1).output(0, 0).input(0, 0),
                    Axis::new('q', 2, 1).input(0, 1).input(1, 2),
                    Axis::new('r', 2, 1).input(0, 2).input(1, 4),
                    Axis::new('s', 2, 1).output(0, 1).input(0, 3),
                    Axis::new('t', 2, 1).output(0, 2).input(1, 0),
                    Axis::new('u', 2, 1).output(0, 3).input(1, 1),
                    Axis::new('v', 2, 1).output(0, 4).input(1, 3),
                ]
            )
            .unwrap(),
        )
    }

    // The implicit output keeps the axes seen once in the inputs, in label order.
    #[test]
    fn test_parse_complex_tensor_contraction_implicit() {
        assert_eq!(m("pqrs,tuqvr"), m("pqrs,tuqvr->pstuv"))
    }

    // Display gives back the parsed expression.
    #[test]
    fn test_display_expr() {
        assert_eq!(m("pqrs,tuqvr->pstuv").to_string(), "pqrs,tuqvr->pstuv");
    }

    #[test]
    fn test_parse_pulsed_matmul() {
        assert_eq!(
            m("sij,ijk->sik"),
            AxesMapping::new(
                2,
                1,
                tvec![
                    Axis::new('i', 2, 1).output(0, 1).input(0, 1).input(1, 0),
                    Axis::new('j', 2, 1).input(0, 2).input(1, 1),
                    Axis::new('k', 2, 1).output(0, 2).input(1, 2),
                    Axis::new('s', 2, 1).output(0, 0).input(0, 0),
                ]
            )
            .unwrap()
        )
    }

    #[test]
    fn test_parse_pulsed_batch_matmul() {
        assert_eq!(
            m("bsij,ijk->bsik"),
            AxesMapping::new(
                2,
                1,
                tvec![
                    Axis::new('b', 2, 1).output(0, 0).input(0, 0),
                    Axis::new('i', 2, 1).output(0, 2).input(0, 2).input(1, 0),
                    Axis::new('j', 2, 1).input(0, 3).input(1, 1),
                    Axis::new('k', 2, 1).output(0, 3).input(1, 2),
                    Axis::new('s', 2, 1).output(0, 1).input(0, 1),
                ]
            )
            .unwrap()
        )
    }

    #[test]
    fn test_extract_sub_mapping() {
        assert_eq!(m("bsij,ijk->bsik").extract_sub_mapping(&[0], &[0]).unwrap(), m("bsij->bsik"));
        assert_eq!(m("bsij,ijk->bsik").extract_sub_mapping(&[1], &[0]).unwrap(), m("ijk->bsik"));
        assert_eq!(m("bsij,ijk->ij").extract_sub_mapping(&[1], &[0]).unwrap(), m("ijk->ij"));
    }

    // Removing an axis from every place it can sit in: either input, output, first or last.
    #[test]
    fn test_remove_axis_0() {
        assert_eq!(m("ab->a").remove_axis('b').unwrap(), m("a->a"));
        assert_eq!(m("ba->a").remove_axis('b').unwrap(), m("a->a"));
        assert_eq!(m("a->ba").remove_axis('b').unwrap(), m("a->a"));
        assert_eq!(m("a->ab").remove_axis('b').unwrap(), m("a->a"));
        assert_eq!(m("ab,a->a").remove_axis('b').unwrap(), m("a,a->a"));
        assert_eq!(m("ba,a->a").remove_axis('b').unwrap(), m("a,a->a"));
        assert_eq!(m("a,ab->a").remove_axis('b').unwrap(), m("a,a->a"));
        assert_eq!(m("a,ba->a").remove_axis('b').unwrap(), m("a,a->a"));
        assert_eq!(m("a,a->ab").remove_axis('b').unwrap(), m("a,a->a"));
        assert_eq!(m("a,a->ba").remove_axis('b').unwrap(), m("a,a->a"));
        assert_eq!(m("bsij,ijk->bsik").remove_axis('i').unwrap(), m("bsj,jk->bsk"),);
    }

    #[test]
    fn test_translate_to_ops_rm_add() {
        assert_eq!(m("ab->a").translate_to_axis_ops().unwrap(), vec!(AxisOp::Rm(1)));
        assert_eq!(m("ba->a").translate_to_axis_ops().unwrap(), vec!(AxisOp::Rm(0)));
        assert_eq!(
            m("ab->c").translate_to_axis_ops().unwrap(),
            vec!(AxisOp::Rm(1), AxisOp::Rm(0), AxisOp::Add(0))
        );
    }

    #[test]
    fn test_translate_to_ops_add_0() {
        assert_eq!(
            m("bacmn->bmn").translate_to_axis_ops().unwrap(),
            vec!(AxisOp::Rm(2), AxisOp::Rm(1))
        );
    }

    #[test]
    fn test_translate_to_ops_move() {
        assert_eq!(m("ab->ba").translate_to_axis_ops().unwrap(), vec!(AxisOp::Move(1, 0)));
    }

    #[test]
    fn test_translate_to_ops_move_20() {
        assert_eq!(m("abc->cab").translate_to_axis_ops().unwrap(), vec!(AxisOp::Move(2, 0)));
    }

    #[test]
    fn test_translate_to_ops_complex() {
        assert_eq!(
            m("anbck->backn").translate_to_axis_ops().unwrap(),
            vec!(AxisOp::Move(2, 0), AxisOp::Move(2, 4))
        );
    }
}