1use std::fmt::Display;
2use std::str::FromStr;
3
4use tract_data::itertools::izip;
5use tract_ndarray::{ArrayViewD, ArrayViewMutD};
6
7use crate::internal::*;
8use crate::prelude::tract_itertools::Itertools;
9
10use super::Axis;
11
/// A way of designating one axis of an [`AxesMapping`].
///
/// Implemented below for a label (`char`), for a `(slot, position)` pair, and for a
/// reference to an `Axis` (matched by equality).
pub trait AxisPattern: std::fmt::Debug {
    /// Returns the index of the matching axis in `mapping`'s axis list, or `None` if no
    /// axis matches.
    fn search(&self, mapping: &AxesMapping) -> Option<usize>;
}
15
16impl AxisPattern for char {
17 fn search(&self, mapping: &AxesMapping) -> Option<usize> {
18 mapping.axes.iter().position(|axis| axis.repr == *self)
19 }
20}
21
22impl AxisPattern for (InOut, usize) {
23 fn search(&self, mapping: &AxesMapping) -> Option<usize> {
24 match self.0 {
25 InOut::In(i) => mapping.axes.iter().position(|axis| axis.inputs[i].contains(&self.1)),
26 InOut::Out(o) => mapping.axes.iter().position(|axis| axis.outputs[o].contains(&self.1)),
27 }
28 }
29}
30
31impl AxisPattern for &Axis {
32 fn search(&self, mapping: &AxesMapping) -> Option<usize> {
33 mapping.axes.iter().position(|ax| self == &ax)
34 }
35}
36
/// Describes how the axes of a set of input tensors relate to the axes of a set of output
/// tensors, in the spirit of an einsum expression such as `ij,jk->ik`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AxesMapping {
    // Number of input slots every axis is expected to describe.
    input_count: usize,
    // Number of output slots every axis is expected to describe.
    output_count: usize,
    // The axes, kept in the canonical order produced by `sort` (validated by `check`).
    axes: TVec<Axis>,
}
43
44impl AxesMapping {
45 pub fn new(
46 input_count: usize,
47 output_count: usize,
48 it: impl AsRef<[Axis]>,
49 ) -> TractResult<AxesMapping> {
50 let axes: TVec<_> = it.as_ref().into();
51 AxesMapping { axes, output_count, input_count }.sorted().check()
52 }
53
54 pub fn for_numpy_matmul(
55 rank: usize,
56 transposing_a: bool,
57 transposing_b: bool,
58 transposing_c: bool,
59 ) -> TractResult<AxesMapping> {
60 let mut axes: TVec<Axis> = ('a'..)
61 .take(rank - 2)
62 .enumerate()
63 .map(|(ix, repr)| Axis {
64 repr,
65 inputs: tvec!(tvec!(ix), tvec!(ix)),
66 outputs: tvec!(tvec!(ix)),
67 })
68 .collect();
69 axes.push(Axis {
70 repr: 'm',
71 inputs: tvec!(tvec!(rank - 2 + transposing_a as usize), tvec!()),
72 outputs: tvec!(tvec!(rank - 2 + transposing_c as usize)),
73 });
74 axes.push(Axis {
75 repr: 'k',
76 inputs: tvec!(
77 tvec!(rank - 1 - transposing_a as usize),
78 tvec!(rank - 2 + transposing_b as usize)
79 ),
80 outputs: tvec!(tvec!()),
81 });
82 axes.push(Axis {
83 repr: 'n',
84 inputs: tvec!(tvec!(), tvec!(rank - 1 - transposing_b as usize),),
85 outputs: tvec!(tvec!(rank - 1 - transposing_c as usize)),
86 });
87 AxesMapping::new(2, 1, axes)
88 }
89
90 pub fn disconnected(inputs: &[&TypedFact], outputs: &[&TypedFact]) -> TractResult<AxesMapping> {
91 let input_ranks: TVec<usize> = inputs.iter().map(|i| i.rank()).collect();
92 let output_ranks: TVec<usize> = outputs.iter().map(|i| i.rank()).collect();
93 Self::disconnected_for_ranks(&input_ranks, &output_ranks)
94 }
95
96 pub fn disconnected_for_ranks(inputs: &[usize], outputs: &[usize]) -> TractResult<AxesMapping> {
97 let mut axes = tvec!();
98 let mut alphabet = 'a'..;
99 for (ix, &rank) in inputs.iter().enumerate() {
100 for a in 0..rank {
101 axes.push(
102 Axis::new(alphabet.next().unwrap(), inputs.len(), outputs.len()).input(ix, a),
103 );
104 }
105 }
106 for (ix, &rank) in outputs.iter().enumerate() {
107 for a in 0..rank {
108 axes.push(
109 Axis::new(alphabet.next().unwrap(), inputs.len(), outputs.len()).output(ix, a),
110 );
111 }
112 }
113 AxesMapping::new(inputs.len(), outputs.len(), axes)
114 }
115
116 pub fn natural(inputs: &[&TypedFact], outputs: &[&TypedFact]) -> TractResult<AxesMapping> {
117 let rank = inputs[0].rank();
118 let axes = (0..rank)
119 .zip('a'..)
120 .map(|(axis_id, repr)| Axis::natural(inputs.len(), outputs.len(), repr, axis_id))
121 .collect::<TVec<_>>();
122 AxesMapping::new(inputs.len(), outputs.len(), axes)
123 }
124
125 pub fn natural_for_rank(
126 inputs: usize,
127 outputs: usize,
128 rank: usize,
129 ) -> TractResult<AxesMapping> {
130 let axes = (0..rank)
131 .zip('a'..)
132 .map(|(axis_id, repr)| Axis::natural(inputs, outputs, repr, axis_id))
133 .collect::<TVec<_>>();
134 AxesMapping::new(inputs, outputs, axes)
135 }
136
137 pub fn iter_all_axes(&self) -> impl Iterator<Item = &Axis> {
138 self.axes.iter()
139 }
140
141 pub fn input_count(&self) -> usize {
142 self.input_count
143 }
144
145 pub fn output_count(&self) -> usize {
146 self.output_count
147 }
148
149 pub fn axis_positions(&self, io: InOut, p: impl AxisPattern) -> TractResult<&[usize]> {
150 let axis = self.axis(p)?;
151 Ok(match io {
152 InOut::In(i) => &*axis.inputs[i],
153 InOut::Out(o) => &*axis.outputs[o],
154 })
155 }
156
157 pub fn rank(&self, io: InOut) -> usize {
158 match io {
159 InOut::In(i) => self.iter_all_axes().map(|axis| axis.inputs[i].len()).sum(),
160 InOut::Out(o) => self.iter_all_axes().map(|axis| axis.outputs[o].len()).sum(),
161 }
162 }
163
164 fn search(&self, p: impl AxisPattern) -> TractResult<usize> {
165 p.search(self).with_context(|| format!("Axis {p:?} not found in {self}"))
166 }
167
168 pub fn axis(&self, p: impl AxisPattern) -> TractResult<&Axis> {
169 Ok(&self.axes[self.search(p)?])
170 }
171
172 fn axis_mut(&mut self, p: impl AxisPattern) -> TractResult<&mut Axis> {
173 let ix = self.search(p)?;
174 Ok(&mut self.axes[ix])
175 }
176
177 pub fn axes(&self, io: InOut) -> impl Iterator<Item = &Axis> {
178 (0..self.rank(io)).map(move |ix| self.axis((io, ix)).unwrap())
179 }
180
181 pub fn track_axis(&self, from: impl AxisPattern, to: InOut) -> TractResult<Option<usize>> {
182 let axis = self.axis(from)?;
183 let positions = axis.interface(to);
184 Ok(if positions.len() == 1 { Some(positions[0]) } else { None })
185 }
186
187 pub fn renaming(mut self, axis: impl AxisPattern, name: char) -> TractResult<AxesMapping> {
188 let position = self.search(axis)?;
189 let old_label = self.axes[position].repr;
190 if let Ok(conflict) = self.axis_mut(name) {
191 conflict.repr = old_label
192 }
193 self.axes[position].repr = name;
194 self.sort();
195 self.check()
196 }
197
198 pub fn linking(
199 mut self,
200 target: impl AxisPattern,
201 axis: impl AxisPattern,
202 ) -> TractResult<AxesMapping> {
203 let axis = self.axis(axis)?;
204 let axis_ix = self.axes.iter().position(|a| a == axis).unwrap();
205 let axis = self.axes.remove(axis_ix);
206 let target = self.axis_mut(target)?;
207 for (ia, ib) in target.inputs.iter_mut().zip(axis.inputs.iter()) {
208 ia.extend(ib.into_iter().cloned())
209 }
210 for (ia, ib) in target.outputs.iter_mut().zip(axis.outputs.iter()) {
211 ia.extend(ib.into_iter().cloned())
212 }
213 self.sort();
214 self.check()
215 }
216
217 fn sort(&mut self) {
218 let order: Vec<(usize, usize, usize, char)> = self
219 .axes
220 .iter()
221 .flat_map(|axis| {
222 axis.inputs
223 .iter()
224 .enumerate()
225 .flat_map(move |(slot, input)| {
226 input.iter().map(move |p| (1, slot, *p, axis.repr))
227 })
228 .chain(axis.outputs.iter().enumerate().flat_map(move |(slot, output)| {
229 output.iter().map(move |p| (0, slot, *p, axis.repr))
230 }))
231 })
232 .sorted()
233 .dedup()
234 .collect_vec();
235 self.axes.sort_by_key(|axis| order.iter().position(|tuple| tuple.3 == axis.repr).unwrap());
236 }
237
238 fn sorted(mut self) -> AxesMapping {
239 self.sort();
240 self
241 }
242
243 fn do_check(&self) -> TractResult<()> {
244 for axis in &self.axes {
245 ensure!(axis.inputs.len() == self.input_count);
246 ensure!(axis.outputs.len() == self.output_count);
247 ensure!(
248 axis.inputs.iter().map(|i| i.len()).sum::<usize>()
249 + axis.outputs.iter().map(|o| o.len()).sum::<usize>()
250 > 0
251 );
252 }
253 for input_ix in 0..self.input_count() {
254 for axis in 0..self.rank(InOut::In(input_ix)) {
255 ensure!(self.axis((InOut::In(input_ix), axis)).is_ok());
256 }
257 }
258 for output_ix in 0..self.output_count() {
259 for axis in 0..self.rank(InOut::Out(output_ix)) {
260 ensure!(self.axis((InOut::Out(output_ix), axis)).is_ok());
261 }
262 }
263 ensure!(self.axes.iter().map(|ax| ax.repr).duplicates().count() == 0);
264 ensure!(
265 self == &{
266 let mut x = self.clone();
267 x.sort();
268 x
269 }
270 );
271 Ok(())
272 }
273
274 pub fn check(self) -> TractResult<AxesMapping> {
275 self.do_check().with_context(|| format!("Checking {:?}", self.axes))?;
276 Ok(self)
277 }
278
279 pub fn available_label(&self) -> char {
280 ('a'..).find(|c| self.iter_all_axes().all(|axis| axis.repr != *c)).unwrap()
281 }
282
283 pub fn is_element_wise_unary(&self) -> bool {
284 self.input_count == 1
285 && self.output_count == 1
286 && self
287 .iter_all_axes()
288 .all(|axis| axis.inputs[0].len() == 1 && axis.outputs[0] == axis.inputs[0])
289 }
290
291 pub fn extract_sub_mapping(
292 &self,
293 inputs: &[usize],
294 outputs: &[usize],
295 ) -> TractResult<AxesMapping> {
296 let axes: Vec<_> = self
297 .iter_all_axes()
298 .filter(|axis| {
299 inputs.iter().any(|i| axis.inputs[*i].len() > 0)
300 || outputs.iter().any(|o| axis.outputs[*o].len() > 0)
301 })
302 .map(|axis| Axis {
303 inputs: axis
304 .inputs
305 .iter()
306 .enumerate()
307 .filter(|(ix, _)| inputs.contains(ix))
308 .map(|(_, it)| it.clone())
309 .collect(),
310 outputs: axis
311 .outputs
312 .iter()
313 .enumerate()
314 .filter(|(ix, _)| outputs.contains(ix))
315 .map(|(_, it)| it.clone())
316 .collect(),
317 repr: axis.repr,
318 })
319 .collect();
320 AxesMapping::new(inputs.len(), outputs.len(), axes)
321 }
322
323 pub fn relabel(mut self) -> TractResult<AxesMapping> {
324 for (ax, repr) in self.axes.iter_mut().zip('a'..) {
325 ax.repr = repr;
326 }
327 Ok(self)
328 }
329
330 pub fn remove_axis(&self, repr: char) -> TractResult<AxesMapping> {
331 let mut axes: TVec<Axis> =
332 self.axes.iter().filter(|axis| axis.repr != repr).cloned().collect();
333 let removed = self.axis(repr).context("Axis not found")?;
334 for input in 0..self.input_count {
335 for &position in &removed.inputs[input] {
336 for other in &mut axes {
337 other.inputs[input]
338 .iter_mut()
339 .for_each(|other_pos| *other_pos -= (*other_pos > position) as usize);
340 }
341 }
342 }
343 for output in 0..self.output_count {
344 for &position in &removed.outputs[output] {
345 for other in &mut axes {
346 other.outputs[output]
347 .iter_mut()
348 .for_each(|other_pos| *other_pos -= (*other_pos > position) as usize);
349 }
350 }
351 }
352 AxesMapping::new(self.input_count, self.output_count, axes)
353 }
354
355 pub fn remove_axis_occurency(&self, slot: InOut, position: usize) -> TractResult<AxesMapping> {
356 let axis = self.axis((slot, position))?;
357 if axis.inputs.iter().map(|i| i.len()).sum::<usize>()
358 + axis.outputs.iter().map(|i| i.len()).sum::<usize>()
359 == 1
360 {
361 return self.remove_axis(axis.repr);
362 }
363 let mut axes = self.axes.clone();
364 match slot {
365 InOut::In(slot) => {
366 for axis in &mut axes {
367 axis.inputs[slot].retain(|pos| *pos != position);
368 axis.inputs[slot].iter_mut().for_each(|pos| *pos -= (*pos > position) as usize);
369 }
370 }
371 InOut::Out(slot) => {
372 for axis in &mut axes {
373 axis.outputs[slot].retain(|pos| *pos != position);
374 axis.outputs[slot]
375 .iter_mut()
376 .for_each(|pos| *pos -= (*pos > position) as usize);
377 }
378 }
379 }
380 AxesMapping::new(self.input_count, self.output_count, axes)
381 }
382
383 pub fn remove_slot(&self, slot: InOut) -> TractResult<AxesMapping> {
384 let mut axes = self.clone();
385 while axes.rank(slot) > 0 {
386 axes = axes.remove_axis_occurency(slot, 0)?
387 }
388 match slot {
389 InOut::In(slot) => {
390 for axis in &mut axes.axes {
391 axis.inputs.remove(slot);
392 }
393 axes.input_count -= 1;
394 }
395 InOut::Out(slot) => {
396 for axis in &mut axes.axes {
397 axis.outputs.remove(slot);
398 }
399 axes.output_count -= 1;
400 }
401 }
402 axes.sorted().check()
403 }
404
405 pub fn with_extra_input(self, slot: usize) -> TractResult<AxesMapping> {
406 let axes: TVec<Axis> = self
407 .iter_all_axes()
408 .map(|axis| {
409 let mut axis = axis.clone();
410 axis.inputs.insert(slot, tvec!());
411 axis
412 })
413 .collect();
414 AxesMapping::new(self.input_count + 1, self.output_count, axes)
415 }
416
417 pub fn with_extra_axis(
418 mut self,
419 repr: char,
420 io: InOut,
421 position: usize,
422 ) -> TractResult<AxesMapping> {
423 let axis = Axis::new(repr, self.input_count, self.output_count);
424 self.axes.push(axis);
425 self.with_extra_axis_occurency(repr, io, position)
426 }
427
428 pub fn with_extra_axis_occurency(
429 mut self,
430 axis: impl AxisPattern,
431 io: InOut,
432 position: usize,
433 ) -> TractResult<AxesMapping> {
434 match io {
435 InOut::In(slot) => {
436 self.axes.iter_mut().for_each(|axis| {
437 axis.inputs[slot].iter_mut().for_each(|pos| *pos += (*pos >= position) as usize)
438 });
439 self.axis_mut(axis)?.inputs[slot].push(position);
440 }
441 InOut::Out(slot) => {
442 self.axes.iter_mut().for_each(|axis| {
443 axis.outputs[slot]
444 .iter_mut()
445 .for_each(|pos| *pos += (*pos >= position) as usize)
446 });
447 self.axis_mut(axis)?.outputs[slot].push(position);
448 }
449 }
450 self.sort();
451 self.check()
452 }
453
454 pub fn translate_to_axis_ops(&self) -> TractResult<Vec<AxisOp>> {
455 ensure!(self.input_count() == 1);
456 ensure!(self.output_count() == 1);
457 ensure!(self.iter_all_axes().all(|axis| axis.inputs[0].len() <= 1));
458 let rms = self
459 .iter_all_axes()
460 .filter(|a| a.outputs[0].len() == 0)
461 .sorted_by_key(|axis| -(axis.inputs[0][0] as isize))
462 .collect_vec();
463 let adds = self
464 .iter_all_axes()
465 .filter(|a| a.inputs[0].len() == 0)
466 .sorted_by_key(|axis| axis.outputs[0][0] as isize)
467 .collect_vec();
468 let permutation = rms
469 .iter()
470 .chain(adds.iter())
471 .try_fold(self.clone(), |mapping, axis| mapping.remove_axis(axis.repr))?;
472 let permutation = permutation
473 .iter_all_axes()
474 .sorted_by_key(|axis| axis.outputs[0][0])
475 .map(|axis| axis.inputs[0][0])
476 .collect_vec();
477 let permutation = perm_to_ops(&permutation);
478 let rms = rms.iter().map(|axis| AxisOp::Rm(axis.inputs[0][0]));
479 let adds = adds.iter().map(|axis| AxisOp::Add(axis.outputs[0][0]));
480 Ok(rms.chain(permutation).chain(adds).collect())
481 }
482
483 pub fn from_strs(
484 inputs: &[impl AsRef<str>],
485 outputs: &[impl AsRef<str>],
486 ) -> TractResult<AxesMapping> {
487 let mut axes = HashMap::<char, Axis>::default();
488 for (input_ix, input) in inputs.iter().enumerate() {
489 for (ix, axis) in input.as_ref().chars().enumerate() {
490 axes.entry(axis)
491 .or_insert_with(|| Axis::new(axis, inputs.len(), outputs.len().max(1)))
492 .add_input(input_ix, ix);
493 }
494 }
495 for (output_ix, output) in outputs.iter().enumerate() {
496 for (ix, axis) in output.as_ref().chars().enumerate() {
497 axes.entry(axis)
498 .or_insert_with(|| Axis::new(axis, inputs.len(), outputs.len().max(1)))
499 .add_output(output_ix, ix);
500 }
501 }
502 if outputs.len() == 0 {
503 axes.iter_mut()
504 .sorted_by_key(|(k, _)| *k)
505 .filter(|(_, v)| v.inputs.iter().map(|input| input.len()).sum::<usize>() == 1)
506 .enumerate()
507 .for_each(|(ix, (_, v))| v.add_output(0, ix))
508 }
509 Self::new(
510 inputs.len(),
511 outputs.len().max(1),
512 axes.into_iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(),
513 )
514 }
515
516 pub fn to_strs(&self) -> (TVec<String>, TVec<String>) {
517 let mut inputs = tvec![];
518 let mut outputs = tvec![];
519 for input in 0..self.input_count() {
520 let s = self
521 .iter_all_axes()
522 .flat_map(|axis| {
523 axis.inputs[input].iter().map(move |position| (position, axis.repr))
524 })
525 .sorted()
526 .map(|(_, r)| r)
527 .collect();
528 inputs.push(s);
529 }
530 for output in 0..self.output_count() {
531 let s = self
532 .iter_all_axes()
533 .flat_map(|axis| {
534 axis.outputs[output].iter().map(move |position| (position, axis.repr))
535 })
536 .sorted()
537 .map(|(_, r)| r)
538 .collect();
539 outputs.push(s);
540 }
541 (inputs, outputs)
542 }
543
544 pub fn change_axis_sink(&self, io: InOut, change: &AxisOp) -> TractResult<Option<AxesMapping>> {
545 let (mut inputs, mut outputs) = self.to_strs();
546 let interface: &mut String = match io {
547 InOut::In(i) => &mut inputs[i],
548 InOut::Out(o) => &mut outputs[o],
549 };
550 let mut axes: Vec<char> = interface.chars().collect();
551 match change {
552 AxisOp::Rm(rm) => {
553 axes.remove(*rm);
554 }
555 AxisOp::Add(add) => axes.insert(*add, self.available_label()),
556 AxisOp::Move(from, to) => {
557 let c = axes.remove(*from);
558 axes.insert(*to, c);
559 }
560 _ => return Ok(None),
561 };
562 *interface = axes.into_iter().collect();
563 Ok(Some(AxesMapping::from_strs(&inputs, &outputs)?))
564 }
565
566 pub fn direct(&self, a: InOut, b: InOut) -> bool {
567 self.axes.iter().all(|axis| axis.interface(a) == axis.interface(b))
568 }
569
570 pub fn same_layout<D: DimLike>(
571 &self,
572 a: InOut,
573 b: InOut,
574 shape_a: impl AsRef<[D]>,
575 shape_b: impl AsRef<[D]>,
576 ) -> bool {
577 let shape_a = shape_a.as_ref();
578 let shape_b = shape_b.as_ref();
579 shape_a.iter().cloned().product::<D>() == shape_b.iter().cloned().product()
580 && izip!(
581 self.axes(a).zip(shape_a.iter()).filter(|(_axis, d)| **d != D::one()),
582 self.axes(b).zip(shape_b.iter()).filter(|(_axis, d)| **d != D::one())
583 )
584 .all(|(a, b)| a == b)
585 }
586
587 pub fn axis_ops_to_canonical(&self, io: InOut) -> TractResult<Vec<AxisOp>> {
588 let rank = self.rank(io);
589 let target_rank = self.axes.len();
590 let mut next_insert_axis = 0;
591 let mut permutation = tvec!();
592 for axis in &self.axes {
593 let spec = match io {
594 InOut::In(i) => axis.inputs[i].first(),
595 InOut::Out(o) => axis.outputs[o].first(),
596 };
597 if let Some(pos_in_a) = spec {
598 permutation.push(pos_in_a + target_rank - rank)
599 } else {
600 permutation.push(next_insert_axis);
601 next_insert_axis += 1;
602 }
603 }
604 let mut ops = vec![AxisOp::Add(0); target_rank - rank];
605 ops.extend(crate::ops::change_axes::perm_to_ops(&permutation));
606 Ok(ops)
607 }
608
609 pub fn view_to_canonical<D>(&self, io: InOut, view: &mut ArrayViewD<D>) -> TractResult<()> {
610 for op in self.axis_ops_to_canonical(io)? {
611 op.change_view(view)?;
612 }
613 Ok(())
614 }
615
616 pub fn view_to_canonical_mut<D>(
617 &self,
618 io: InOut,
619 view: &mut ArrayViewMutD<D>,
620 ) -> TractResult<()> {
621 for op in self.axis_ops_to_canonical(io)? {
622 op.change_view_mut(view)?;
623 }
624 Ok(())
625 }
626
627 pub fn compose(&self, other: &AxesMapping) -> TractResult<AxesMapping> {
628 ensure!(self.input_count() == 1 && self.output_count() == 1);
629 ensure!(other.input_count() == 1 && other.output_count() == 1);
630 let mut result = AxesMapping::disconnected_for_ranks(
631 &[self.rank(InOut::In(0))],
632 &[other.rank(InOut::Out(0))],
633 )?;
634 for ix in 0..result.rank(InOut::In(0)) {
635 let Some(inter) = self.track_axis((InOut::In(0), ix), InOut::Out(0))? else { continue };
636 let Some(out) = other.track_axis((InOut::In(0), inter), InOut::Out(0))? else {
637 continue;
638 };
639 result = result.linking((InOut::Out(0), out), (InOut::In(0), ix))?;
640 }
641 Ok(result)
642 }
643}
644
645impl FromStr for AxesMapping {
646 type Err = TractError;
647 fn from_str(s: &str) -> Result<Self, Self::Err> {
648 assert!(!s.contains("..."));
649 let s = s.replace(' ', "");
650 let (inputs, outputs) =
651 if let Some((i, r)) = s.split_once("->") { (i, r) } else { (&*s, "") };
652 let inputs: TVec<&str> = inputs.split(',').collect();
653 let outputs: TVec<&str> = outputs.split(',').filter(|s| s.len() > 0).collect();
654 AxesMapping::from_strs(&inputs, &outputs)
655 }
656}
657
658impl Display for AxesMapping {
659 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
660 let (inputs, outputs) = self.to_strs();
661 write!(f, "{}->{}", inputs.iter().join(","), outputs.iter().join(","))
662 }
663}
664
#[cfg(test)]
mod test {
    use super::*;

    /// Parses an einsum-like expression, panicking if it is malformed.
    fn parse(expr: &str) -> AxesMapping {
        expr.parse::<AxesMapping>().unwrap()
    }

    /// Builds an expected mapping from explicit axes, panicking if it is inconsistent.
    fn build(inputs: usize, outputs: usize, axes: TVec<Axis>) -> AxesMapping {
        AxesMapping::new(inputs, outputs, axes).unwrap()
    }

    /// Translates an expression into axis ops, panicking on failure.
    fn ops(expr: &str) -> Vec<AxisOp> {
        parse(expr).translate_to_axis_ops().unwrap()
    }

    #[test]
    fn test_parse_transpose() {
        let expected = build(
            1,
            1,
            tvec![
                Axis::new('i', 1, 1).output(0, 1).input(0, 0),
                Axis::new('j', 1, 1).output(0, 0).input(0, 1)
            ],
        );
        assert_eq!(parse("ij->ji"), expected);
    }

    #[test]
    fn test_parse_diag() {
        let expected =
            build(1, 1, tvec![Axis::new('i', 1, 1).output(0, 0).input(0, 0).input(0, 1)]);
        assert_eq!(parse("ii->i"), expected);
    }

    #[test]
    fn test_parse_adamar_product_explicit() {
        let expected =
            build(2, 1, tvec![Axis::new('i', 2, 1).output(0, 0).input(0, 0).input(1, 0)]);
        assert_eq!(parse("i,i->i"), expected);
    }

    #[test]
    fn test_parse_inner_product_implicit() {
        let explicit = parse("i,i->");
        assert_eq!(parse("i,i"), explicit);
    }

    #[test]
    fn test_parse_batch_matmul() {
        let expected = build(
            2,
            1,
            tvec![
                Axis::new('b', 2, 1).output(0, 0).input(0, 0).input(1, 0),
                Axis::new('i', 2, 1).output(0, 1).input(0, 1),
                Axis::new('j', 2, 1).input(0, 2).input(1, 1),
                Axis::new('k', 2, 1).output(0, 2).input(1, 2)
            ],
        );
        assert_eq!(parse("bij , bjk -> bik "), expected);
    }

    #[test]
    fn test_parse_outer_product() {
        let expected = build(
            2,
            1,
            tvec![
                Axis::new('i', 2, 1).output(0, 0).input(0, 0),
                Axis::new('j', 2, 1).output(0, 1).input(1, 0)
            ],
        );
        assert_eq!(parse("i,j->ij"), expected);
    }

    #[test]
    fn test_parse_bilinear() {
        let expected = build(
            3,
            1,
            tvec![
                Axis::new('i', 3, 1).output(0, 0).input(0, 0).input(2, 0),
                Axis::new('j', 3, 1).output(0, 1).input(1, 0),
                Axis::new('k', 3, 1).input(0, 1).input(1, 1),
                Axis::new('l', 3, 1).input(1, 2).input(2, 1)
            ],
        );
        assert_eq!(parse("ik,jkl,il->ij"), expected);
    }

    #[test]
    fn test_parse_complex_tensor_contraction() {
        let expected = build(
            2,
            1,
            tvec![
                Axis::new('p', 2, 1).output(0, 0).input(0, 0),
                Axis::new('q', 2, 1).input(0, 1).input(1, 2),
                Axis::new('r', 2, 1).input(0, 2).input(1, 4),
                Axis::new('s', 2, 1).output(0, 1).input(0, 3),
                Axis::new('t', 2, 1).output(0, 2).input(1, 0),
                Axis::new('u', 2, 1).output(0, 3).input(1, 1),
                Axis::new('v', 2, 1).output(0, 4).input(1, 3),
            ],
        );
        assert_eq!(parse("pqrs,tuqvr->pstuv"), expected);
    }

    #[test]
    fn test_parse_complex_tensor_contraction_implicit() {
        let explicit = parse("pqrs,tuqvr->pstuv");
        assert_eq!(parse("pqrs,tuqvr"), explicit);
    }

    #[test]
    fn test_display_expr() {
        let expr = "pqrs,tuqvr->pstuv";
        assert_eq!(parse(expr).to_string(), expr);
    }

    #[test]
    fn test_parse_pulsed_matmul() {
        let expected = build(
            2,
            1,
            tvec![
                Axis::new('i', 2, 1).output(0, 1).input(0, 1).input(1, 0),
                Axis::new('j', 2, 1).input(0, 2).input(1, 1),
                Axis::new('k', 2, 1).output(0, 2).input(1, 2),
                Axis::new('s', 2, 1).output(0, 0).input(0, 0),
            ],
        );
        assert_eq!(parse("sij,ijk->sik"), expected);
    }

    #[test]
    fn test_parse_pulsed_batch_matmul() {
        let expected = build(
            2,
            1,
            tvec![
                Axis::new('b', 2, 1).output(0, 0).input(0, 0),
                Axis::new('i', 2, 1).output(0, 2).input(0, 2).input(1, 0),
                Axis::new('j', 2, 1).input(0, 3).input(1, 1),
                Axis::new('k', 2, 1).output(0, 3).input(1, 2),
                Axis::new('s', 2, 1).output(0, 1).input(0, 1),
            ],
        );
        assert_eq!(parse("bsij,ijk->bsik"), expected);
    }

    #[test]
    fn test_extract_sub_mapping() {
        let batched = parse("bsij,ijk->bsik");
        assert_eq!(batched.extract_sub_mapping(&[0], &[0]).unwrap(), parse("bsij->bsik"));
        assert_eq!(batched.extract_sub_mapping(&[1], &[0]).unwrap(), parse("ijk->bsik"));
        let reduced = parse("bsij,ijk->ij");
        assert_eq!(reduced.extract_sub_mapping(&[1], &[0]).unwrap(), parse("ijk->ij"));
    }

    #[test]
    fn test_remove_axis_0() {
        let cases = [
            ("ab->a", "a->a"),
            ("ba->a", "a->a"),
            ("a->ba", "a->a"),
            ("a->ab", "a->a"),
            ("ab,a->a", "a,a->a"),
            ("ba,a->a", "a,a->a"),
            ("a,ab->a", "a,a->a"),
            ("a,ba->a", "a,a->a"),
            ("a,a->ab", "a,a->a"),
            ("a,a->ba", "a,a->a"),
        ];
        for (before, after) in cases {
            assert_eq!(parse(before).remove_axis('b').unwrap(), parse(after));
        }
        assert_eq!(parse("bsij,ijk->bsik").remove_axis('i').unwrap(), parse("bsj,jk->bsk"));
    }

    #[test]
    fn test_translate_to_ops_rm_add() {
        assert_eq!(ops("ab->a"), vec![AxisOp::Rm(1)]);
        assert_eq!(ops("ba->a"), vec![AxisOp::Rm(0)]);
        assert_eq!(ops("ab->c"), vec![AxisOp::Rm(1), AxisOp::Rm(0), AxisOp::Add(0)]);
    }

    #[test]
    fn test_translate_to_ops_add_0() {
        assert_eq!(ops("bacmn->bmn"), vec![AxisOp::Rm(2), AxisOp::Rm(1)]);
    }

    #[test]
    fn test_translate_to_ops_move() {
        assert_eq!(ops("ab->ba"), vec![AxisOp::Move(1, 0)]);
    }

    #[test]
    fn test_translate_to_ops_move_20() {
        assert_eq!(ops("abc->cab"), vec![AxisOp::Move(2, 0)]);
    }

    #[test]
    fn test_translate_to_ops_complex() {
        assert_eq!(ops("anbck->backn"), vec![AxisOp::Move(2, 0), AxisOp::Move(2, 4)]);
    }
}