1use std::fmt::Display;
2use std::str::FromStr;
3
4use tract_data::itertools::izip;
5use tract_ndarray::{ArrayViewD, ArrayViewMutD};
6
7use crate::internal::*;
8use crate::prelude::tract_itertools::Itertools;
9
10use super::Axis;
11
12pub trait AxisPattern: std::fmt::Debug {
13 fn search(&self, mapping: &AxesMapping) -> Option<usize>;
14}
15
16impl AxisPattern for char {
17 fn search(&self, mapping: &AxesMapping) -> Option<usize> {
18 mapping.axes.iter().position(|axis| axis.repr == *self)
19 }
20}
21
22impl AxisPattern for (InOut, usize) {
23 fn search(&self, mapping: &AxesMapping) -> Option<usize> {
24 match self.0 {
25 InOut::In(i) => mapping.axes.iter().position(|axis| axis.inputs[i].contains(&self.1)),
26 InOut::Out(o) => mapping.axes.iter().position(|axis| axis.outputs[o].contains(&self.1)),
27 }
28 }
29}
30
31impl AxisPattern for &Axis {
32 fn search(&self, mapping: &AxesMapping) -> Option<usize> {
33 mapping.axes.iter().position(|ax| self == &ax)
34 }
35}
36
37#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub struct AxesMapping {
39 input_count: usize,
40 output_count: usize,
41 axes: TVec<Axis>,
42}
43
44impl AxesMapping {
45 pub fn new(
46 input_count: usize,
47 output_count: usize,
48 it: impl AsRef<[Axis]>,
49 ) -> TractResult<AxesMapping> {
50 let axes: TVec<_> = it.as_ref().into();
51 AxesMapping { axes, output_count, input_count }.sorted().check()
52 }
53
54 pub fn for_numpy_matmul(
55 rank: usize,
56 transposing_a: bool,
57 transposing_b: bool,
58 transposing_c: bool,
59 ) -> TractResult<AxesMapping> {
60 let mut axes: TVec<Axis> = ('a'..)
61 .take(rank - 2)
62 .enumerate()
63 .map(|(ix, repr)| Axis {
64 repr,
65 inputs: tvec!(tvec!(ix), tvec!(ix)),
66 outputs: tvec!(tvec!(ix)),
67 })
68 .collect();
69 axes.push(Axis {
70 repr: 'm',
71 inputs: tvec!(tvec!(rank - 2 + transposing_a as usize), tvec!()),
72 outputs: tvec!(tvec!(rank - 2 + transposing_c as usize)),
73 });
74 axes.push(Axis {
75 repr: 'k',
76 inputs: tvec!(
77 tvec!(rank - 1 - transposing_a as usize),
78 tvec!(rank - 2 + transposing_b as usize)
79 ),
80 outputs: tvec!(tvec!()),
81 });
82 axes.push(Axis {
83 repr: 'n',
84 inputs: tvec!(tvec!(), tvec!(rank - 1 - transposing_b as usize),),
85 outputs: tvec!(tvec!(rank - 1 - transposing_c as usize)),
86 });
87 AxesMapping::new(2, 1, axes)
88 }
89
90 pub fn disconnected(inputs: &[&TypedFact], outputs: &[&TypedFact]) -> TractResult<AxesMapping> {
91 let input_ranks: TVec<usize> = inputs.iter().map(|i| i.rank()).collect();
92 let output_ranks: TVec<usize> = outputs.iter().map(|i| i.rank()).collect();
93 Self::disconnected_for_ranks(&input_ranks, &output_ranks)
94 }
95
96 pub fn disconnected_for_ranks(inputs: &[usize], outputs: &[usize]) -> TractResult<AxesMapping> {
97 let mut axes = tvec!();
98 let mut alphabet = 'a'..;
99 for (ix, &rank) in inputs.iter().enumerate() {
100 for a in 0..rank {
101 axes.push(
102 Axis::new(alphabet.next().unwrap(), inputs.len(), outputs.len()).input(ix, a),
103 );
104 }
105 }
106 for (ix, &rank) in outputs.iter().enumerate() {
107 for a in 0..rank {
108 axes.push(
109 Axis::new(alphabet.next().unwrap(), inputs.len(), outputs.len()).output(ix, a),
110 );
111 }
112 }
113 AxesMapping::new(inputs.len(), outputs.len(), axes)
114 }
115
116 pub fn natural(inputs: &[&TypedFact], outputs: &[&TypedFact]) -> TractResult<AxesMapping> {
117 let rank = inputs[0].rank();
118 let axes = (0..rank)
119 .zip('a'..)
120 .map(|(axis_id, repr)| Axis::natural(inputs.len(), outputs.len(), repr, axis_id))
121 .collect::<TVec<_>>();
122 AxesMapping::new(inputs.len(), outputs.len(), axes)
123 }
124
125 pub fn natural_for_rank(
126 inputs: usize,
127 outputs: usize,
128 rank: usize,
129 ) -> TractResult<AxesMapping> {
130 let axes = (0..rank)
131 .zip('a'..)
132 .map(|(axis_id, repr)| Axis::natural(inputs, outputs, repr, axis_id))
133 .collect::<TVec<_>>();
134 AxesMapping::new(inputs, outputs, axes)
135 }
136
137 pub fn iter_all_axes(&self) -> impl Iterator<Item = &Axis> {
138 self.axes.iter()
139 }
140
141 pub fn iter_all_axes_mut(&mut self) -> impl Iterator<Item = &mut Axis> {
142 self.axes.iter_mut()
143 }
144
145 pub fn input_count(&self) -> usize {
146 self.input_count
147 }
148
149 pub fn output_count(&self) -> usize {
150 self.output_count
151 }
152
153 pub fn axis_positions(&self, io: InOut, p: impl AxisPattern) -> TractResult<&[usize]> {
154 let axis = self.axis(p)?;
155 Ok(match io {
156 InOut::In(i) => &*axis.inputs[i],
157 InOut::Out(o) => &*axis.outputs[o],
158 })
159 }
160
161 pub fn rank(&self, io: InOut) -> usize {
162 match io {
163 InOut::In(i) => self.iter_all_axes().map(|axis| axis.inputs[i].len()).sum(),
164 InOut::Out(o) => self.iter_all_axes().map(|axis| axis.outputs[o].len()).sum(),
165 }
166 }
167
168 fn search(&self, p: impl AxisPattern) -> TractResult<usize> {
169 p.search(self).with_context(|| format!("Axis {p:?} not found in {self}"))
170 }
171
172 pub fn axis(&self, p: impl AxisPattern) -> TractResult<&Axis> {
173 Ok(&self.axes[self.search(p)?])
174 }
175
176 fn axis_mut(&mut self, p: impl AxisPattern) -> TractResult<&mut Axis> {
177 let ix = self.search(p)?;
178 Ok(&mut self.axes[ix])
179 }
180
181 pub fn axes(&self, io: InOut) -> impl Iterator<Item = &Axis> {
182 (0..self.rank(io)).map(move |ix| self.axis((io, ix)).unwrap())
183 }
184
185 pub fn track_axis(&self, from: impl AxisPattern, to: InOut) -> TractResult<Option<usize>> {
186 let axis = self.axis(from)?;
187 let positions = axis.interface(to);
188 Ok(if positions.len() == 1 { Some(positions[0]) } else { None })
189 }
190
191 pub fn renaming(mut self, axis: impl AxisPattern, name: char) -> TractResult<AxesMapping> {
192 let position = self.search(axis)?;
193 let old_label = self.axes[position].repr;
194 if let Ok(conflict) = self.axis_mut(name) {
195 conflict.repr = old_label
196 }
197 self.axes[position].repr = name;
198 self.sort();
199 self.check()
200 }
201
202 pub fn linking(
203 mut self,
204 target: impl AxisPattern,
205 axis: impl AxisPattern,
206 ) -> TractResult<AxesMapping> {
207 let axis = self.axis(axis)?;
208 let axis_ix = self.axes.iter().position(|a| a == axis).unwrap();
209 let axis = self.axes.remove(axis_ix);
210 let target = self.axis_mut(target)?;
211 for (ia, ib) in target.inputs.iter_mut().zip(axis.inputs.iter()) {
212 ia.extend(ib.into_iter().cloned())
213 }
214 for (ia, ib) in target.outputs.iter_mut().zip(axis.outputs.iter()) {
215 ia.extend(ib.into_iter().cloned())
216 }
217 self.sort();
218 self.check()
219 }
220
221 fn sort(&mut self) {
222 let order: Vec<(usize, usize, usize, char)> = self
223 .axes
224 .iter()
225 .flat_map(|axis| {
226 axis.inputs
227 .iter()
228 .enumerate()
229 .flat_map(move |(slot, input)| {
230 input.iter().map(move |p| (1, slot, *p, axis.repr))
231 })
232 .chain(axis.outputs.iter().enumerate().flat_map(move |(slot, output)| {
233 output.iter().map(move |p| (0, slot, *p, axis.repr))
234 }))
235 })
236 .sorted()
237 .dedup()
238 .collect_vec();
239 self.axes.sort_by_key(|axis| order.iter().position(|tuple| tuple.3 == axis.repr).unwrap());
240 }
241
242 fn sorted(mut self) -> AxesMapping {
243 self.sort();
244 self
245 }
246
247 fn do_check(&self) -> TractResult<()> {
248 for axis in &self.axes {
249 ensure!(axis.inputs.len() == self.input_count);
250 ensure!(axis.outputs.len() == self.output_count);
251 ensure!(
252 axis.inputs.iter().map(|i| i.len()).sum::<usize>()
253 + axis.outputs.iter().map(|o| o.len()).sum::<usize>()
254 > 0
255 );
256 }
257 for input_ix in 0..self.input_count() {
258 for axis in 0..self.rank(InOut::In(input_ix)) {
259 ensure!(self.axis((InOut::In(input_ix), axis)).is_ok());
260 }
261 }
262 for output_ix in 0..self.output_count() {
263 for axis in 0..self.rank(InOut::Out(output_ix)) {
264 ensure!(self.axis((InOut::Out(output_ix), axis)).is_ok());
265 }
266 }
267 ensure!(self.axes.iter().map(|ax| ax.repr).duplicates().count() == 0);
268 ensure!(
269 self == &{
270 let mut x = self.clone();
271 x.sort();
272 x
273 }
274 );
275 Ok(())
276 }
277
278 pub fn check(self) -> TractResult<AxesMapping> {
279 self.do_check().with_context(|| format!("Checking {:?}", self.axes))?;
280 Ok(self)
281 }
282
283 pub fn available_label(&self) -> char {
284 self.available_labels().next().unwrap()
285 }
286
287 pub fn available_labels(&self) -> impl Iterator<Item = char> + '_ {
288 ('a'..).filter(|c| self.iter_all_axes().all(|axis| axis.repr != *c))
289 }
290
291 pub fn is_element_wise_unary(&self) -> bool {
292 self.input_count == 1
293 && self.output_count == 1
294 && self
295 .iter_all_axes()
296 .all(|axis| axis.inputs[0].len() == 1 && axis.outputs[0] == axis.inputs[0])
297 }
298
299 pub fn extract_sub_mapping(
300 &self,
301 inputs: &[usize],
302 outputs: &[usize],
303 ) -> TractResult<AxesMapping> {
304 let axes: Vec<_> = self
305 .iter_all_axes()
306 .filter(|axis| {
307 inputs.iter().any(|i| axis.inputs[*i].len() > 0)
308 || outputs.iter().any(|o| axis.outputs[*o].len() > 0)
309 })
310 .map(|axis| Axis {
311 inputs: axis
312 .inputs
313 .iter()
314 .enumerate()
315 .filter(|(ix, _)| inputs.contains(ix))
316 .map(|(_, it)| it.clone())
317 .collect(),
318 outputs: axis
319 .outputs
320 .iter()
321 .enumerate()
322 .filter(|(ix, _)| outputs.contains(ix))
323 .map(|(_, it)| it.clone())
324 .collect(),
325 repr: axis.repr,
326 })
327 .collect();
328 AxesMapping::new(inputs.len(), outputs.len(), axes)
329 }
330
331 pub fn relabel(mut self) -> TractResult<AxesMapping> {
332 for (ax, repr) in self.axes.iter_mut().zip('a'..) {
333 ax.repr = repr;
334 }
335 Ok(self)
336 }
337
338 pub fn remove_axis(&self, repr: char) -> TractResult<AxesMapping> {
339 let mut axes: TVec<Axis> =
340 self.axes.iter().filter(|axis| axis.repr != repr).cloned().collect();
341 let removed = self.axis(repr).context("Axis not found")?;
342 for input in 0..self.input_count {
343 for &position in &removed.inputs[input] {
344 for other in &mut axes {
345 other.inputs[input]
346 .iter_mut()
347 .for_each(|other_pos| *other_pos -= (*other_pos > position) as usize);
348 }
349 }
350 }
351 for output in 0..self.output_count {
352 for &position in &removed.outputs[output] {
353 for other in &mut axes {
354 other.outputs[output]
355 .iter_mut()
356 .for_each(|other_pos| *other_pos -= (*other_pos > position) as usize);
357 }
358 }
359 }
360 AxesMapping::new(self.input_count, self.output_count, axes)
361 }
362
363 pub fn remove_axis_occurency(&self, slot: InOut, position: usize) -> TractResult<AxesMapping> {
364 let axis = self.axis((slot, position))?;
365 if axis.inputs.iter().map(|i| i.len()).sum::<usize>()
366 + axis.outputs.iter().map(|i| i.len()).sum::<usize>()
367 == 1
368 {
369 return self.remove_axis(axis.repr);
370 }
371 let mut axes = self.axes.clone();
372 match slot {
373 InOut::In(slot) => {
374 for axis in &mut axes {
375 axis.inputs[slot].retain(|pos| *pos != position);
376 axis.inputs[slot].iter_mut().for_each(|pos| *pos -= (*pos > position) as usize);
377 }
378 }
379 InOut::Out(slot) => {
380 for axis in &mut axes {
381 axis.outputs[slot].retain(|pos| *pos != position);
382 axis.outputs[slot]
383 .iter_mut()
384 .for_each(|pos| *pos -= (*pos > position) as usize);
385 }
386 }
387 }
388 AxesMapping::new(self.input_count, self.output_count, axes)
389 }
390
391 pub fn remove_slot(&self, slot: InOut) -> TractResult<AxesMapping> {
392 let mut axes = self.clone();
393 while axes.rank(slot) > 0 {
394 axes = axes.remove_axis_occurency(slot, 0)?
395 }
396 match slot {
397 InOut::In(slot) => {
398 for axis in &mut axes.axes {
399 axis.inputs.remove(slot);
400 }
401 axes.input_count -= 1;
402 }
403 InOut::Out(slot) => {
404 for axis in &mut axes.axes {
405 axis.outputs.remove(slot);
406 }
407 axes.output_count -= 1;
408 }
409 }
410 axes.sorted().check()
411 }
412
413 pub fn with_extra_input(self, slot: usize) -> TractResult<AxesMapping> {
414 let axes: TVec<Axis> = self
415 .iter_all_axes()
416 .map(|axis| {
417 let mut axis = axis.clone();
418 axis.inputs.insert(slot, tvec!());
419 axis
420 })
421 .collect();
422 AxesMapping::new(self.input_count + 1, self.output_count, axes)
423 }
424
425 pub fn with_extra_axis(
426 mut self,
427 repr: char,
428 io: InOut,
429 position: usize,
430 ) -> TractResult<AxesMapping> {
431 let axis = Axis::new(repr, self.input_count, self.output_count);
432 self.axes.push(axis);
433 self.with_extra_axis_occurency(repr, io, position)
434 }
435
436 pub fn with_extra_axis_occurency(
437 mut self,
438 axis: impl AxisPattern,
439 io: InOut,
440 position: usize,
441 ) -> TractResult<AxesMapping> {
442 match io {
443 InOut::In(slot) => {
444 self.axes.iter_mut().for_each(|axis| {
445 axis.inputs[slot].iter_mut().for_each(|pos| *pos += (*pos >= position) as usize)
446 });
447 self.axis_mut(axis)?.inputs[slot].push(position);
448 }
449 InOut::Out(slot) => {
450 self.axes.iter_mut().for_each(|axis| {
451 axis.outputs[slot]
452 .iter_mut()
453 .for_each(|pos| *pos += (*pos >= position) as usize)
454 });
455 self.axis_mut(axis)?.outputs[slot].push(position);
456 }
457 }
458 self.sort();
459 self.check()
460 }
461
462 pub fn translate_to_axis_ops(&self) -> TractResult<Vec<AxisOp>> {
463 ensure!(self.input_count() == 1);
464 ensure!(self.output_count() == 1);
465 ensure!(self.iter_all_axes().all(|axis| axis.inputs[0].len() <= 1));
466 let rms = self
467 .iter_all_axes()
468 .filter(|a| a.outputs[0].len() == 0)
469 .sorted_by_key(|axis| -(axis.inputs[0][0] as isize))
470 .collect_vec();
471 let adds = self
472 .iter_all_axes()
473 .filter(|a| a.inputs[0].len() == 0)
474 .sorted_by_key(|axis| axis.outputs[0][0] as isize)
475 .collect_vec();
476 let permutation = rms
477 .iter()
478 .chain(adds.iter())
479 .try_fold(self.clone(), |mapping, axis| mapping.remove_axis(axis.repr))?;
480 let permutation = permutation
481 .iter_all_axes()
482 .sorted_by_key(|axis| axis.outputs[0][0])
483 .map(|axis| axis.inputs[0][0])
484 .collect_vec();
485 let permutation = perm_to_ops(&permutation);
486 let rms = rms.iter().map(|axis| AxisOp::Rm(axis.inputs[0][0]));
487 let adds = adds.iter().map(|axis| AxisOp::Add(axis.outputs[0][0]));
488 Ok(rms.chain(permutation).chain(adds).collect())
489 }
490
491 pub fn from_strs(
492 inputs: &[impl AsRef<str>],
493 outputs: &[impl AsRef<str>],
494 ) -> TractResult<AxesMapping> {
495 let mut axes = HashMap::<char, Axis>::default();
496 for (input_ix, input) in inputs.iter().enumerate() {
497 for (ix, axis) in input.as_ref().chars().enumerate() {
498 axes.entry(axis)
499 .or_insert_with(|| Axis::new(axis, inputs.len(), outputs.len().max(1)))
500 .add_input(input_ix, ix);
501 }
502 }
503 for (output_ix, output) in outputs.iter().enumerate() {
504 for (ix, axis) in output.as_ref().chars().enumerate() {
505 axes.entry(axis)
506 .or_insert_with(|| Axis::new(axis, inputs.len(), outputs.len().max(1)))
507 .add_output(output_ix, ix);
508 }
509 }
510 if outputs.len() == 0 {
511 axes.iter_mut()
512 .sorted_by_key(|(k, _)| *k)
513 .filter(|(_, v)| v.inputs.iter().map(|input| input.len()).sum::<usize>() == 1)
514 .enumerate()
515 .for_each(|(ix, (_, v))| v.add_output(0, ix))
516 }
517 Self::new(
518 inputs.len(),
519 outputs.len().max(1),
520 axes.into_iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(),
521 )
522 }
523
524 pub fn to_strs(&self) -> (TVec<String>, TVec<String>) {
525 let mut inputs = tvec![];
526 let mut outputs = tvec![];
527 for input in 0..self.input_count() {
528 let s = self
529 .iter_all_axes()
530 .flat_map(|axis| {
531 axis.inputs[input].iter().map(move |position| (position, axis.repr))
532 })
533 .sorted()
534 .map(|(_, r)| r)
535 .collect();
536 inputs.push(s);
537 }
538 for output in 0..self.output_count() {
539 let s = self
540 .iter_all_axes()
541 .flat_map(|axis| {
542 axis.outputs[output].iter().map(move |position| (position, axis.repr))
543 })
544 .sorted()
545 .map(|(_, r)| r)
546 .collect();
547 outputs.push(s);
548 }
549 (inputs, outputs)
550 }
551
552 pub fn change_axis_sink(&self, io: InOut, change: &AxisOp) -> TractResult<Option<AxesMapping>> {
553 let (mut inputs, mut outputs) = self.to_strs();
554 let interface: &mut String = match io {
555 InOut::In(i) => &mut inputs[i],
556 InOut::Out(o) => &mut outputs[o],
557 };
558 let mut axes: Vec<char> = interface.chars().collect();
559 match change {
560 AxisOp::Rm(rm) => {
561 axes.remove(*rm);
562 }
563 AxisOp::Add(add) => axes.insert(*add, self.available_label()),
564 AxisOp::Move(from, to) => {
565 let c = axes.remove(*from);
566 axes.insert(*to, c);
567 }
568 _ => return Ok(None),
569 };
570 *interface = axes.into_iter().collect();
571 Ok(Some(AxesMapping::from_strs(&inputs, &outputs)?))
572 }
573
574 pub fn direct(&self, a: InOut, b: InOut) -> bool {
575 self.axes.iter().all(|axis| axis.interface(a) == axis.interface(b))
576 }
577
578 pub fn same_layout<D: DimLike>(
579 &self,
580 a: InOut,
581 b: InOut,
582 shape_a: impl AsRef<[D]>,
583 shape_b: impl AsRef<[D]>,
584 ) -> bool {
585 let shape_a = shape_a.as_ref();
586 let shape_b = shape_b.as_ref();
587 shape_a.iter().cloned().product::<D>() == shape_b.iter().cloned().product()
588 && izip!(
589 self.axes(a).zip(shape_a.iter()).filter(|(_axis, d)| **d != D::one()),
590 self.axes(b).zip(shape_b.iter()).filter(|(_axis, d)| **d != D::one())
591 )
592 .all(|(a, b)| a == b)
593 }
594
595 pub fn axis_ops_to_canonical(&self, io: InOut) -> TractResult<Vec<AxisOp>> {
596 let rank = self.rank(io);
597 let target_rank = self.axes.len();
598 let mut next_insert_axis = 0;
599 let mut permutation = tvec!();
600 for axis in &self.axes {
601 let spec = match io {
602 InOut::In(i) => axis.inputs[i].first(),
603 InOut::Out(o) => axis.outputs[o].first(),
604 };
605 if let Some(pos_in_a) = spec {
606 permutation.push(pos_in_a + target_rank - rank)
607 } else {
608 permutation.push(next_insert_axis);
609 next_insert_axis += 1;
610 }
611 }
612 let mut ops = vec![AxisOp::Add(0); target_rank - rank];
613 ops.extend(crate::ops::change_axes::perm_to_ops(&permutation));
614 Ok(ops)
615 }
616
617 pub fn view_to_canonical<D>(&self, io: InOut, view: &mut ArrayViewD<D>) -> TractResult<()> {
618 for op in self.axis_ops_to_canonical(io)? {
619 op.change_view(view)?;
620 }
621 Ok(())
622 }
623
624 pub fn view_to_canonical_mut<D>(
625 &self,
626 io: InOut,
627 view: &mut ArrayViewMutD<D>,
628 ) -> TractResult<()> {
629 for op in self.axis_ops_to_canonical(io)? {
630 op.change_view_mut(view)?;
631 }
632 Ok(())
633 }
634
635 pub fn compose(&self, other: &AxesMapping) -> TractResult<AxesMapping> {
636 ensure!(self.input_count() == 1 && self.output_count() == 1);
637 ensure!(other.input_count() == 1 && other.output_count() == 1);
638 let mut result = AxesMapping::disconnected_for_ranks(
639 &[self.rank(InOut::In(0))],
640 &[other.rank(InOut::Out(0))],
641 )?;
642 for ix in 0..result.rank(InOut::In(0)) {
643 let Some(inter) = self.track_axis((InOut::In(0), ix), InOut::Out(0))? else { continue };
644 let Some(out) = other.track_axis((InOut::In(0), inter), InOut::Out(0))? else {
645 continue;
646 };
647 result = result.linking((InOut::Out(0), out), (InOut::In(0), ix))?;
648 }
649 Ok(result)
650 }
651}
652
653impl FromStr for AxesMapping {
654 type Err = TractError;
655 fn from_str(s: &str) -> Result<Self, Self::Err> {
656 assert!(!s.contains("..."));
657 let s = s.replace(' ', "");
658 let (inputs, outputs) =
659 if let Some((i, r)) = s.split_once("->") { (i, r) } else { (&*s, "") };
660 let inputs: TVec<&str> = inputs.split(',').collect();
661 let outputs: TVec<&str> = outputs.split(',').filter(|s| s.len() > 0).collect();
662 AxesMapping::from_strs(&inputs, &outputs)
663 }
664}
665
666impl Display for AxesMapping {
667 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
668 let (inputs, outputs) = self.to_strs();
669 write!(f, "{}->{}", inputs.iter().join(","), outputs.iter().join(","))
670 }
671}
672
#[cfg(test)]
mod test {
    use super::*;

    /// Parses `s` as an axes mapping expression, panicking when it is invalid.
    fn m(s: &str) -> AxesMapping {
        s.parse().unwrap()
    }

    /// Builds the reference mapping a parsed expression is compared with.
    fn built(inputs: usize, outputs: usize, axes: TVec<Axis>) -> AxesMapping {
        AxesMapping::new(inputs, outputs, axes).unwrap()
    }

    /// Translates the parsed expression `s` to axis operations.
    fn ops(s: &str) -> Vec<AxisOp> {
        m(s).translate_to_axis_ops().unwrap()
    }

    #[test]
    fn test_parse_transpose() {
        let expected = built(
            1,
            1,
            tvec![
                Axis::new('i', 1, 1).output(0, 1).input(0, 0),
                Axis::new('j', 1, 1).output(0, 0).input(0, 1)
            ],
        );
        assert_eq!(m("ij->ji"), expected)
    }

    #[test]
    fn test_parse_diag() {
        let expected =
            built(1, 1, tvec![Axis::new('i', 1, 1).output(0, 0).input(0, 0).input(0, 1)]);
        assert_eq!(m("ii->i"), expected)
    }

    #[test]
    fn test_parse_adamar_product_explicit() {
        let expected =
            built(2, 1, tvec![Axis::new('i', 2, 1).output(0, 0).input(0, 0).input(1, 0)]);
        assert_eq!(m("i,i->i"), expected)
    }

    #[test]
    fn test_parse_inner_product_implicit() {
        assert_eq!(m("i,i"), m("i,i->"))
    }

    #[test]
    fn test_parse_batch_matmul() {
        let expected = built(
            2,
            1,
            tvec![
                Axis::new('b', 2, 1).output(0, 0).input(0, 0).input(1, 0),
                Axis::new('i', 2, 1).output(0, 1).input(0, 1),
                Axis::new('j', 2, 1).input(0, 2).input(1, 1),
                Axis::new('k', 2, 1).output(0, 2).input(1, 2)
            ],
        );
        assert_eq!(m("bij , bjk -> bik "), expected)
    }

    #[test]
    fn test_parse_outer_product() {
        let expected = built(
            2,
            1,
            tvec![
                Axis::new('i', 2, 1).output(0, 0).input(0, 0),
                Axis::new('j', 2, 1).output(0, 1).input(1, 0)
            ],
        );
        assert_eq!(m("i,j->ij"), expected)
    }

    #[test]
    fn test_parse_bilinear() {
        let expected = built(
            3,
            1,
            tvec![
                Axis::new('i', 3, 1).output(0, 0).input(0, 0).input(2, 0),
                Axis::new('j', 3, 1).output(0, 1).input(1, 0),
                Axis::new('k', 3, 1).input(0, 1).input(1, 1),
                Axis::new('l', 3, 1).input(1, 2).input(2, 1)
            ],
        );
        assert_eq!(m("ik,jkl,il->ij"), expected)
    }

    #[test]
    fn test_parse_complex_tensor_contraction() {
        let expected = built(
            2,
            1,
            tvec![
                Axis::new('p', 2, 1).output(0, 0).input(0, 0),
                Axis::new('q', 2, 1).input(0, 1).input(1, 2),
                Axis::new('r', 2, 1).input(0, 2).input(1, 4),
                Axis::new('s', 2, 1).output(0, 1).input(0, 3),
                Axis::new('t', 2, 1).output(0, 2).input(1, 0),
                Axis::new('u', 2, 1).output(0, 3).input(1, 1),
                Axis::new('v', 2, 1).output(0, 4).input(1, 3),
            ],
        );
        assert_eq!(m("pqrs,tuqvr->pstuv"), expected)
    }

    #[test]
    fn test_parse_complex_tensor_contraction_implicit() {
        assert_eq!(m("pqrs,tuqvr"), m("pqrs,tuqvr->pstuv"))
    }

    #[test]
    fn test_display_expr() {
        assert_eq!(m("pqrs,tuqvr->pstuv").to_string(), "pqrs,tuqvr->pstuv");
    }

    #[test]
    fn test_parse_pulsed_matmul() {
        let expected = built(
            2,
            1,
            tvec![
                Axis::new('i', 2, 1).output(0, 1).input(0, 1).input(1, 0),
                Axis::new('j', 2, 1).input(0, 2).input(1, 1),
                Axis::new('k', 2, 1).output(0, 2).input(1, 2),
                Axis::new('s', 2, 1).output(0, 0).input(0, 0),
            ],
        );
        assert_eq!(m("sij,ijk->sik"), expected)
    }

    #[test]
    fn test_parse_pulsed_batch_matmul() {
        let expected = built(
            2,
            1,
            tvec![
                Axis::new('b', 2, 1).output(0, 0).input(0, 0),
                Axis::new('i', 2, 1).output(0, 2).input(0, 2).input(1, 0),
                Axis::new('j', 2, 1).input(0, 3).input(1, 1),
                Axis::new('k', 2, 1).output(0, 3).input(1, 2),
                Axis::new('s', 2, 1).output(0, 1).input(0, 1),
            ],
        );
        assert_eq!(m("bsij,ijk->bsik"), expected)
    }

    #[test]
    fn test_extract_sub_mapping() {
        let matmul = m("bsij,ijk->bsik");
        assert_eq!(matmul.extract_sub_mapping(&[0], &[0]).unwrap(), m("bsij->bsik"));
        assert_eq!(matmul.extract_sub_mapping(&[1], &[0]).unwrap(), m("ijk->bsik"));
        assert_eq!(m("bsij,ijk->ij").extract_sub_mapping(&[1], &[0]).unwrap(), m("ijk->ij"));
    }

    #[test]
    fn test_remove_axis_0() {
        // Removing `b`, wherever it sits, must leave the same mapping over `a` alone.
        let cases = [
            ("ab->a", "a->a"),
            ("ba->a", "a->a"),
            ("a->ba", "a->a"),
            ("a->ab", "a->a"),
            ("ab,a->a", "a,a->a"),
            ("ba,a->a", "a,a->a"),
            ("a,ab->a", "a,a->a"),
            ("a,ba->a", "a,a->a"),
            ("a,a->ab", "a,a->a"),
            ("a,a->ba", "a,a->a"),
        ];
        for (before, after) in cases {
            assert_eq!(m(before).remove_axis('b').unwrap(), m(after));
        }
        assert_eq!(m("bsij,ijk->bsik").remove_axis('i').unwrap(), m("bsj,jk->bsk"));
    }

    #[test]
    fn test_translate_to_ops_rm_add() {
        assert_eq!(ops("ab->a"), vec!(AxisOp::Rm(1)));
        assert_eq!(ops("ba->a"), vec!(AxisOp::Rm(0)));
        assert_eq!(ops("ab->c"), vec!(AxisOp::Rm(1), AxisOp::Rm(0), AxisOp::Add(0)));
    }

    #[test]
    fn test_translate_to_ops_add_0() {
        assert_eq!(ops("bacmn->bmn"), vec!(AxisOp::Rm(2), AxisOp::Rm(1)));
    }

    #[test]
    fn test_translate_to_ops_move() {
        assert_eq!(ops("ab->ba"), vec!(AxisOp::Move(1, 0)));
    }

    #[test]
    fn test_translate_to_ops_move_20() {
        assert_eq!(ops("abc->cab"), vec!(AxisOp::Move(2, 0)));
    }

    #[test]
    fn test_translate_to_ops_complex() {
        assert_eq!(ops("anbck->backn"), vec!(AxisOp::Move(2, 0), AxisOp::Move(2, 4)));
    }
}