use std::borrow::Borrow;
use std::fmt::Debug;

use crate::internal::*;
use crate::model::{TypedModel, TypedNode};
use crate::ops::identity::Identity;
use tract_itertools::Itertools;
use tract_ndarray::{ArrayViewD, ArrayViewMutD};
use AxisOp::*;

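/// Identifies one connection of a node: either its `ix`-th input or its
/// `ix`-th output.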
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InOut {
    Out(usize),
    In(usize),
}

impl InOut {
    pub fn as_outlet<F: Clone + Fact, O: Clone>(&self, node: &Node<F, O>) -> OutletId {
        match self {
            InOut::In(ix) => node.inputs[*ix],
            InOut::Out(ix) => OutletId::new(node.id, *ix),
        }
    }
}

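/// An operation on the axes of a tensor.
///
/// * `Add(axis)`: insert a dimension of size 1 at position `axis`.
/// * `Rm(axis)`: remove the dimension at position `axis` (it must be of size 1).
/// * `Move(from, to)`: extract the axis at `from` and re-insert it at `to`.
/// * `Reshape(at, from, to)`: starting at axis `at`, replace the group of
///   dimensions matching `from` by the dimensions `to` (volumes must match).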
#[derive(Clone, Hash, Eq)]
#[allow(clippy::large_enum_variant)]
#[allow(clippy::derived_hash_with_manual_eq)]
pub enum AxisOp {
    Add(usize),
    Rm(usize),
    Move(usize, usize),
    Reshape(usize, TVec<TDim>, TVec<TDim>),
}

impl Debug for AxisOp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AxisOp::Add(a) => write!(f, "Add({a})"),
            AxisOp::Rm(a) => write!(f, "Rm({a})"),
            AxisOp::Move(from, to) => write!(f, "Move({from},{to})"),
            AxisOp::Reshape(at, from, to) => {
                write!(f, "Reshape({at}, [{}], [{}])", from.iter().join(","), to.iter().join(","))
            }
        }
    }
}

impl PartialEq for AxisOp {
    fn eq(&self, other: &AxisOp) -> bool {
        if self.is_noop() && other.is_noop() {
            true
        } else if self.is_noop() != other.is_noop() {
            false
        } else {
            match (self, other) {
                (Add(a), Add(b)) | (Rm(a), Rm(b)) => a == b,
                (Move(f1, t1), Move(f2, t2)) => {
                    (f1 == f2 && t1 == t2)
                        || ((*t1 == f1 + 1 || *f1 == t1 + 1) && t2 == f1 && t1 == f2)
                }
                (Reshape(at1, f1, t1), Reshape(at2, f2, t2)) => at1 == at2 && f1 == f2 && t1 == t2,
                _ => false,
            }
        }
    }
}

impl AxisOp {
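    /// Normalize `self` to a canonical equivalent: a descending adjacent swap
    /// `Move(to + 1, to)` becomes the ascending `Move(to, to + 1)`, and a
    /// `Reshape` that merely inserts or drops a unit dimension becomes the
    /// corresponding `Add` or `Rm`.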
    pub fn canonical(&self) -> Cow<AxisOp> {
        match self {
            Move(from, to) if *from == to + 1 => Cow::Owned(Move(*to, *from)),
            Reshape(at, from, to) if from.len() == 1 && to.len() == 2 && from[0] == to[0] => {
                Cow::Owned(Add(*at + 1))
            }
            Reshape(at, from, to) if from.len() == 1 && to.len() == 2 && from[0] == to[1] => {
                Cow::Owned(Add(*at))
            }
            Reshape(at, from, to) if from.len() == 2 && to.len() == 1 && from[0] == to[0] => {
                Cow::Owned(Rm(*at + 1))
            }
            Reshape(at, from, to) if from.len() == 2 && to.len() == 1 && from[1] == to[0] => {
                Cow::Owned(Rm(*at))
            }
            other => Cow::Borrowed(other),
        }
    }

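    /// Decompose `self` into an equivalent sequence of simpler ops: no-op
    /// reshapes vanish, unit dimensions on either end of a reshape turn into
    /// `Add`/`Rm`, and dimensions shared by `from` and `to` are peeled off.
    /// For instance `Reshape(3, [1, 2, 3], [3, 2])` simplifies to
    /// `[Rm(3), Reshape(3, [2, 3], [3, 2])]`.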
    pub fn simplify(&self) -> TVec<AxisOp> {
        match self.canonical().borrow() {
            Reshape(_, from, to) if from == to => tvec!(),
            Reshape(at, from, to) if to.len() == 0 => tvec!(Rm(*at); from.len()),
            Reshape(at, from, to) if from.len() == 0 => tvec!(Add(*at); to.len()),
            Reshape(at, from, to) if from[0] == to[0] => {
                Reshape(at + 1, from[1..].into(), to[1..].into()).simplify()
            }
            Reshape(at, from, to) if from[from.len() - 1] == to[to.len() - 1] => {
                Reshape(*at, from[..from.len() - 1].into(), to[..to.len() - 1].into()).simplify()
            }
            Reshape(at, from, to) if from[0] == 1.to_dim() => std::iter::once(Rm(*at))
                .chain(Reshape(*at, from[1..].into(), to.clone()).simplify())
                .collect(),
            Reshape(at, from, to) if to[0] == 1.to_dim() => {
                Reshape(*at, from.clone(), to[1..].into())
                    .simplify()
                    .into_iter()
                    .chain(std::iter::once(Add(*at)))
                    .collect()
            }
            Reshape(at, from, to) if from[from.len() - 1] == 1.to_dim() => {
                std::iter::once(Rm(at + from.len() - 1))
                    .chain(Reshape(*at, from[..from.len() - 1].into(), to.clone()).simplify())
                    .collect()
            }
            Reshape(at, from, to) if to[to.len() - 1] == 1.to_dim() => {
                std::iter::once(Add(at + from.len()))
                    .chain(Reshape(*at, from.clone(), to[..to.len() - 1].into()).simplify())
                    .collect()
            }
            other => tvec!(other.clone()),
        }
    }

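    /// Track one input axis through this op: returns its position in the
    /// output, or `None` if the axis does not survive (it is removed or
    /// absorbed by a reshape).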
    pub fn transform_axis(&self, axis: usize) -> Option<usize> {
        match self.canonical().as_ref() {
            Add(ix) => Some(axis + (axis >= *ix) as usize),
            Rm(ix) => {
                if axis == *ix {
                    None
                } else {
                    Some(axis - (axis > *ix) as usize)
                }
            }
            Move(from, to) if from < to => {
                if axis < *from || axis > *to {
                    Some(axis)
                } else if axis == *from {
                    Some(*to)
                } else {
                    Some(axis - 1)
                }
            }
            Move(from, to) => {
                if axis < *to || axis > *from {
                    Some(axis)
                } else if axis == *from {
                    Some(*to)
                } else {
                    Some(axis + 1)
                }
            }
            Reshape(at, _, _) if axis < *at => Some(axis),
            Reshape(at, from, to) if axis >= at + from.len() => Some(axis + to.len() - from.len()),
            Reshape(_, _, _) => None,
        }
    }

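    /// Commute `self` with an incoming `change` applied to its input.
    ///
    /// On success, returns `(new_op, outgoing_change)`: `new_op` is the
    /// replacement for `self` once its input has been transformed by `change`
    /// (`None` meaning `self` becomes a no-op), and `outgoing_change` is the
    /// change that must then be propagated downstream (`None` meaning the
    /// output is unaffected). Returns `None` if the two cannot be reordered.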
    pub fn merge_incoming_change(
        &self,
        change: &AxisOp,
    ) -> Option<(Option<AxisOp>, Option<AxisOp>)> {
        match (self.canonical().as_ref(), change.canonical().as_ref()) {
            (Add(op), Add(c)) => {
                Some((Some(Add(op + (c < op) as usize)), Some(Add(c + (c >= op) as usize))))
            }
            (Add(op), Rm(c)) => {
                Some((Some(Add(op - (c < op) as usize)), Some(Rm(c + (c >= op) as usize))))
            }
            (Rm(op), Add(c)) => {
                Some((Some(Rm(op + (c <= op) as usize)), Some(Add(c - (op < c) as usize))))
            }
            (Rm(op), Rm(c)) => {
                Some((Some(Rm(op - (c < op) as usize)), Some(Rm(c - (op <= c) as usize))))
            }

            (Add(x), Move(from, to)) => {
                if x <= from.min(to) {
                    Some((Some(self.clone()), Some(Move(from + 1, to + 1))))
                } else if x > from.max(to) {
                    Some((Some(self.clone()), Some(change.clone())))
                } else {
                    None
                }
            }

            (Move(from, to), Add(x)) => {
                if x <= from.min(to) {
                    Some((Some(Move(from + 1, to + 1)), Some(Add(*x))))
                } else if x > from.max(to) {
                    Some((Some(Move(*from, *to)), Some(Add(*x))))
                } else {
                    None
                }
            }

            (Rm(x), Move(from, to)) => {
                if x == from {
                    Some((Some(Rm(*to)), None))
                } else if x < from.min(to) {
                    Some((Some(self.clone()), Some(Move(from - 1, to - 1))))
                } else if x > from.max(to) {
                    Some((Some(self.clone()), Some(change.clone())))
                } else if from + 1 == *to && x == to {
                    Some((Some(Rm(*from)), None))
                } else if from < to && x <= to {
                    Some((Some(Rm(x - 1)), Some(Move(*from, *to - 1))))
                } else {
                    Some((Some(Rm(x + 1)), Some(Move(*from - 1, *to))))
                }
            }

            (Move(from, to), Rm(x)) => {
                if x < from.min(to) {
                    Some((Some(Move(from - 1, to - 1)), Some(Rm(*x))))
                } else if x > from.max(to) {
                    Some((Some(Move(*from, *to)), Some(Rm(*x))))
                } else {
                    None
                }
            }

            (Add(op), Reshape(at, from, to)) => {
                if op <= at {
                    Some((Some(Add(*op)), Some(Reshape(at + 1, from.clone(), to.clone()))))
                } else if *op > at + from.len() {
                    Some((
                        Some(Add(*op + to.len() - from.len())),
                        Some(Reshape(*at, from.clone(), to.clone())),
                    ))
                } else {
                    None
                }
            }
            (Rm(op), Reshape(at, from, to)) => {
                if op < at {
                    Some((Some(Rm(*op)), Some(Reshape(at - 1, from.clone(), to.clone()))))
                } else if *op > at + from.len() {
                    Some((
                        Some(Rm(*op + to.len() - from.len())),
                        Some(Reshape(*at, from.clone(), to.clone())),
                    ))
                } else {
                    None
                }
            }
            (Reshape(at, from, to), Add(change)) => {
                if change < at {
                    Some((Some(Reshape(at + 1, from.clone(), to.clone())), Some(Add(*change))))
                } else if *change > *at + from.len() {
                    Some((
                        Some(Reshape(*at, from.clone(), to.clone())),
                        Some(Add(change + to.len() - from.len())),
                    ))
                } else {
                    None
                }
            }
            (Reshape(at, from, to), Rm(change)) => {
                if change < at {
                    Some((Some(Reshape(at - 1, from.clone(), to.clone())), Some(Rm(*change))))
                } else if *change > *at + from.len() {
                    Some((
                        Some(Reshape(*at, from.clone(), to.clone())),
                        Some(Rm(change + to.len() - from.len())),
                    ))
                } else {
                    None
                }
            }
            (Reshape(_, _, _), Move(_, _)) => None,
            (Move(_, _), Reshape(_, _, _)) => None,
            (Reshape(_, _, _), Reshape(_, _, _)) => None,
            _ => None,
        }
    }

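    /// Apply this op to a shape given as a mutable array of dims. E.g.
    /// `Add(1)` turns `[2, 3]` into `[2, 1, 3]` and `Move(0, 2)` turns
    /// `[2, 3, 4]` into `[3, 4, 2]`. With `broadcasting` set, a reshape also
    /// accepts a group of size-1 dims and maps it to size-1 dims of the
    /// target arity.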
    pub fn change_shape_array<D: DimLike>(
        &self,
        shape: &mut TVec<D>,
        broadcasting: bool,
    ) -> TractResult<()> {
        match self.canonical().as_ref() {
            Add(ix) => {
                ensure!(*ix <= shape.len());
                shape.insert(*ix, D::one());
            }
            Rm(ix) => {
                ensure!(*ix < shape.len());
                shape.remove(*ix);
            }
            Move(from, to) => {
                ensure!(*from < shape.len());
                ensure!(*to < shape.len());
                let axis = shape.remove(*from);
                shape.insert(*to, axis);
            }
            Reshape(at, from, to) => {
                let from_volume = from.iter().product::<TDim>();
                let to_volume = to.iter().product::<TDim>();
                ensure!(from_volume == to_volume, "{from_volume} should be equal to {to_volume}");
                ensure!(*at + from.len() <= shape.len());
                if shape.len() >= from.len() + *at
                    && tract_itertools::izip!(shape.iter().skip(*at), from)
                        .all(|(shape, spec)| shape.to_dim() == *spec)
                {
                    for _ in from {
                        shape.remove(*at);
                    }
                    for d in to.iter().rev() {
                        shape.insert(*at, d.try_into()?);
                    }
                } else if broadcasting
                    && shape.iter().skip(*at).take(from.len()).all(|d| d.to_dim() == 1.to_dim())
                {
                    for _ in from {
                        shape.remove(*at);
                    }
                    for _ in to.iter().rev() {
                        shape.insert(*at, 1.into());
                    }
                } else {
                    bail!("Incompatible reshape for shape {:?} and {:?}", shape, self);
                }
            }
        }
        Ok(())
    }

    pub fn change_shape(&self, shape: &mut ShapeFact, broadcasting: bool) -> TractResult<()> {
        match self.canonical().as_ref() {
            Add(ix) => shape.insert_axis(*ix),
            Rm(ix) => {
                if shape.rank() <= *ix {
                    bail!("Attempt to remove axis #{} on shape {:?}", ix, shape);
                }
                if shape[*ix] != 1.to_dim() {
                    bail!("Removing non-trivial axis #{} of dim: {:?}", ix, shape);
                }
                shape.remove_axis(*ix)
            }
            _ => {
                let mut array = shape.to_tvec();
                self.change_shape_array(&mut array, broadcasting)?;
                let mut new_shape = ShapeFact::from_dims(array);
                std::mem::swap(shape, &mut new_shape);
                Ok(())
            }
        }
    }

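    /// Apply this op to a tensor in place. `Add`, `Rm` and `Reshape` only
    /// relabel the geometry; `Move` relocates data. With `broadcasting` set,
    /// a reshape over size-1 axes is adjusted by inserting or removing unit
    /// axes instead.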
    pub fn change_tensor(&self, tensor: &mut Tensor, broadcasting: bool) -> TractResult<()> {
        match self.canonical().as_ref() {
            Add(ix) => {
                ensure!(*ix <= tensor.rank());
                tensor.insert_axis(*ix)
            }
            Rm(ix) => {
                ensure!(*ix < tensor.rank());
                tensor.remove_axis(*ix)
            }
            Move(from, to) => {
                ensure!(*from < tensor.rank());
                ensure!(*to < tensor.rank());
                let mut tmp = tensor.clone().move_axis(*from, *to)?;
                std::mem::swap(tensor, &mut tmp);
                Ok(())
            }
            Reshape(at, from, to) => {
                let mut shape: TVec<usize> = tensor.shape().into();
                self.change_shape_array(&mut shape, false)?;
                if tensor.set_shape(&shape).is_ok() {
                    Ok(())
                } else if broadcasting
                    && tensor.shape().iter().skip(*at).take(from.len()).all(|d| *d == 1)
                {
                    if from.len() > to.len() {
                        for _ in to.len()..from.len() {
                            tensor.remove_axis(*at)?;
                        }
                    }
                    if to.len() > from.len() {
                        for _ in from.len()..to.len() {
                            tensor.insert_axis(*at)?;
                        }
                    }
                    Ok(())
                } else {
                    bail!(
                        "Invalid reshaping: {:?} on tensor {:?} (broadcasting allowed: {:?})",
                        self,
                        tensor,
                        broadcasting
                    )
                }
            }
        }
    }

    pub fn change_view<D>(&self, view: &mut ArrayViewD<D>) -> TractResult<()> {
        use tract_ndarray::Axis;
        match *self {
            AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
            AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
            AxisOp::Move(from, to) if from < to => {
                for left in from..to {
                    view.swap_axes(left, left + 1);
                }
            }
            AxisOp::Move(from, to) => {
                for left in (to..from).rev() {
                    view.swap_axes(left, left + 1);
                }
            }
            AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
        }
        Ok(())
    }

    pub fn change_view_mut<D>(&self, view: &mut ArrayViewMutD<D>) -> TractResult<()> {
        use tract_ndarray::Axis;
        match *self {
            AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
            AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
            AxisOp::Move(from, to) if from < to => {
                for left in from..to {
                    view.swap_axes(left, left + 1);
                }
            }
            AxisOp::Move(from, to) => {
                for left in (to..from).rev() {
                    view.swap_axes(left, left + 1);
                }
            }
            AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
        }
        Ok(())
    }

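    /// The inverse operation: `Add`/`Rm` swap roles, a move runs backwards,
    /// and a reshape swaps `from` and `to`. Adjacent swaps are their own
    /// inverse.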
    pub fn recip(&self) -> AxisOp {
        match self.canonical().as_ref() {
            Add(ix) => Rm(*ix),
            Rm(ix) => Add(*ix),
            Move(from, to) if from == to => self.clone(),
            Move(from, to) if *from + 1 == *to => self.clone(),
            Move(from, to) if *from == *to + 1 => {
                unreachable!();
            }
            Move(from, to) => Move(*to, *from),
            Reshape(at, from, to) => Reshape(*at, to.clone(), from.clone()),
        }
    }

    pub fn is_noop(&self) -> bool {
        match self {
            Move(f, t) if f == t => true,
            Reshape(_, f, t) if f == t => true,
            _ => false,
        }
    }

    pub fn only_shape(&self) -> bool {
        if self.is_noop() {
            return true;
        }
        !matches!(self, Move(_, _))
    }

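    /// Wire a `Reshape` that splits `axis` (of dim `d`) into two axes
    /// `(outer_dim, d / outer_dim)`.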
    pub fn wire_split_axis(
        model: &mut TypedModel,
        name: impl ToString,
        outlet: OutletId,
        axis: usize,
        outer_dim: usize,
    ) -> TractResult<TVec<OutletId>> {
        let fact = model.outlet_fact(outlet)?;
        let dim: TDim = fact.shape[axis].clone();
        let inner_dim = dim.clone() / outer_dim;
        let op = Self::Reshape(axis, tvec!(dim.clone()), tvec!(outer_dim.to_dim(), inner_dim));
        model.wire_node(name.to_string(), op, &[outlet])
    }

    pub fn wire_collapse_axis(
        model: &mut TypedModel,
        name: impl ToString,
        outlet: OutletId,
        axis: usize,
    ) -> TractResult<TVec<OutletId>> {
        let fact = model.outlet_fact(outlet)?;
        let dim: TDim = fact.shape[axis].clone();
        let next_dim: TDim = fact.shape[axis + 1].clone();
        let op = Self::Reshape(axis, tvec!(dim.clone(), next_dim.clone()), tvec!(dim * next_dim));
        model.wire_node(name.to_string(), op, &[outlet])
    }
}

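/// Align the ranks of `inputs` by prepending unit axes (`Add(0)`) to the
/// lower-rank wires until all of them match the highest input rank.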
pub fn wire_rank_broadcast(
    prefix: impl AsRef<str>,
    target: &mut TypedModel,
    inputs: &[OutletId],
) -> TractResult<TVec<OutletId>> {
    let facts =
        inputs.iter().map(|o| target.outlet_fact(*o).cloned()).collect::<TractResult<TVec<_>>>()?;
    let max_rank = facts.iter().map(|f| f.rank()).max().unwrap();
    let mut wires = tvec!();
    let prefix = prefix.as_ref();
    for i in 0..inputs.len() {
        let mut wire = inputs[i];
        for j in facts[i].rank()..max_rank {
            wire =
                target.wire_node(format!("{prefix}.fix-rank-{i}-{j}"), AxisOp::Add(0), &[wire])?[0];
        }
        wires.push(wire);
    }
    Ok(wires)
}

pub fn wire_with_rank_broadcast(
    prefix: impl AsRef<str>,
    target: &mut TypedModel,
    op: impl Into<Box<dyn TypedOp>>,
    inputs: &[OutletId],
) -> TractResult<TVec<OutletId>> {
    let prefix = prefix.as_ref();
    let wires = wire_rank_broadcast(prefix, target, inputs)?;
    target.wire_node(prefix, op.into(), &wires)
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AxisChange {
    pub outlet: OutletId,
    pub op: AxisOp,
}

#[derive(Clone, Default, Debug)]
pub struct AxisChangeConsequence {
    pub substitute_op: Option<Box<dyn TypedOp>>,
    pub wire_changes: TVec<(InOut, AxisOp)>,
}

impl AxisChangeConsequence {
    pub fn new(
        _model: &TypedModel,
        node: &TypedNode,
        op: Option<Box<dyn TypedOp>>,
        axis_op: &AxisOp,
    ) -> AxisChangeConsequence {
        let mut wire_changes = tvec!();
        for i in 0..node.inputs.len() {
            wire_changes.push((InOut::In(i), axis_op.clone()));
        }
        for i in 0..node.outputs.len() {
            wire_changes.push((InOut::Out(i), axis_op.clone()));
        }
        AxisChangeConsequence { wire_changes, substitute_op: op }
    }
}

impl Op for AxisOp {
    fn name(&self) -> Cow<str> {
        match self {
            Add(_) => "AddAxis".into(),
            Rm(_) => "RmAxis".into(),
            Move(_, _) => "MoveAxis".into(),
            Reshape(_, _, _) => "Reshape".into(),
        }
    }

    fn info(&self) -> TractResult<Vec<String>> {
        match self {
            Add(axis) | Rm(axis) => Ok(vec![format!("Axis: {axis}")]),
            Move(from, to) => Ok(vec![format!("Axis {from} to {to}")]),
            Reshape(at, from, to) => Ok(vec![format!(
                "Axes starting at {}: {:?} to {:?}",
                at,
                from.iter().join(","),
                to.iter().join(",")
            )]),
        }
    }

    op_as_typed_op!();
}

impl EvalOp for AxisOp {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval_with_session(
        &self,
        session: &SessionState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let mut input = args_1!(inputs).into_tensor();
        match self {
            AxisOp::Reshape(skip, from, to) => {
                let from = from.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
                let to = to.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
                AxisOp::Reshape(*skip, from, to).change_tensor(&mut input, false)?
            }
            _ => self.change_tensor(&mut input, false)?,
        }
        Ok(tvec!(input.into_tvalue()))
    }
}

impl TypedOp for AxisOp {
    as_op!();

    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let mut shape = inputs[0].shape.clone();
        self.change_shape(&mut shape, false)
            .with_context(|| format!("Applying {self:?} to {:?}", inputs[0]))?;
        let mut fact = inputs[0].datum_type.fact(shape);
        fact.opaque_fact.clone_from(&inputs[0].opaque_fact);
        Ok(tvec!(fact))
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let mut axes: Vec<Axis> = (0..inputs[0].rank())
            .zip('a'..)
            .map(|(axis_id, repr)| {
                let mut axis = Axis::new(repr, inputs.len(), outputs.len()).input(0, axis_id);
                if let Some(out) = self.transform_axis(axis_id) {
                    axis = axis.output(0, out);
                }
                axis
            })
            .collect();
        for (axis, letter) in (0..outputs[0].rank()).zip('A'..) {
            if self.recip().transform_axis(axis).is_none() {
                axes.push(Axis::new(letter, inputs.len(), outputs.len()).output(0, axis));
            }
        }
        AxesMapping::new(inputs.len(), outputs.len(), axes)
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.is_noop() {
            if let Some(p) = TypedModelPatch::shunt_one_op(model, node)? {
                return Ok(Some(p));
            }
        }
        let simplified = self.simplify();
        if simplified.len() != 1 || &simplified[0] != self {
            let mut patch = TypedModelPatch::default();
            let mut wire = patch.tap_model(model, node.inputs[0])?;
            for (ix, op) in simplified.into_iter().enumerate() {
                wire = patch.wire_node(format!("{}.{}", node.name, ix), op, &[wire])?[0];
            }
            patch.shunt_outside(model, node.id.into(), wire)?;
            Ok(Some(patch))
        } else {
            Ok(None)
        }
    }

    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
        Ok(tvec!((InOut::Out(0), self.recip()), (InOut::In(0), self.clone())))
    }

    fn change_axes(
        &self,
        _model: &TypedModel,
        _node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let op = if let InOut::Out(0) = io {
            let more = if let Some(more) =
                self.recip().change_axes(_model, _node, InOut::In(0), change)?
            {
                more
            } else {
                return Ok(None);
            };
            AxisChangeConsequence {
                substitute_op: more.substitute_op.map(|op| {
                    if let Some(op) = op.as_op().downcast_ref::<AxisOp>() {
                        Box::new(op.recip())
                    } else {
                        op
                    }
                }),
                wire_changes: more
                    .wire_changes
                    .into_iter()
                    .map(|wc| {
                        (if wc.0 == InOut::In(0) { InOut::Out(0) } else { InOut::In(0) }, wc.1)
                    })
                    .collect(),
            }
        } else if change == self {
            AxisChangeConsequence { substitute_op: Some(Box::new(Identity)), wire_changes: tvec!() }
        } else {
            let (new_op, new_change) = if let Some(pair) = self.merge_incoming_change(change) {
                pair
            } else {
                return Ok(None);
            };
            trace!(
                " Change:{:?} self:{:?} -> change:{:?} op:{:?}",
                change,
                self,
                new_change,
                new_op
            );
            let substitute_op: Box<dyn TypedOp> =
                if let Some(o) = new_op { Box::new(o) as _ } else { Box::new(Identity) };
            let mut wire_changes = tvec!();
            if !change.is_noop() {
                wire_changes.push((InOut::In(0), change.clone()))
            }
            if let Some(new_change) = new_change {
                wire_changes.push((InOut::Out(0), new_change))
            }
            AxisChangeConsequence { substitute_op: Some(substitute_op), wire_changes }
        };
        Ok(Some(op))
    }

    fn concretize_dims(
        &self,
        _source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
        values: &SymbolValues,
    ) -> TractResult<TVec<OutletId>> {
        let op = if let AxisOp::Reshape(axis, from, to) = self {
            AxisOp::Reshape(
                *axis,
                from.iter().map(|d| d.eval(values)).collect(),
                to.iter().map(|d| d.eval(values)).collect(),
            )
        } else {
            self.clone()
        };
        target.wire_node(&node.name, op, &[mapping[&node.inputs[0]]])
    }

    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if let Some(shape) = node.outputs[0].fact.shape.as_concrete() {
            if !matches!(self, AxisOp::Move(_, _)) {
                let (inputs, outputs) = model.node_facts(node.id)?;
                let mapping = self.axes_mapping(&inputs, &outputs)?;
                let op = IntoShape {
                    mapping,
                    len: shape.iter().product(),
                    strides: Tensor::natural_strides(shape),
                    dims: shape.into(),
                };
                return Ok(Some(TypedModelPatch::replace_single_op(
                    model,
                    node,
                    &node.inputs,
                    op,
                )?));
            }
        }
        Ok(None)
    }
}

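/// Decompose a permutation into its non-trivial cycles. E.g. `[1, 2, 0]` is
/// the single cycle `[0, 1, 2]`.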
fn perm_to_cycles(perm: &[usize]) -> TVec<TVec<usize>> {
    let mut cycles: TVec<TVec<usize>> = tvec!();
    let mut done = 0;
    while done < perm.len() {
        if perm[done] == done || cycles.iter().any(|c| c.contains(&done)) {
            done += 1;
            continue;
        }
        let mut cycle = tvec!();
        let mut current = done;
        loop {
            cycle.push(current);
            current = perm[current];
            if current == done {
                break;
            }
        }
        cycles.push(cycle)
    }
    cycles
}

fn is_rotation_cycle(cycle: &[usize]) -> Option<(usize, usize)> {
    if cycle.windows(2).all(|w| w[0] + 1 == w[1]) {
        Some((cycle[0], cycle[cycle.len() - 1]))
    } else if cycle[1..cycle.len()].windows(2).all(|w| w[0] - 1 == w[1])
        && cycle[cycle.len() - 1] - 1 == cycle[0]
    {
        Some((cycle[1], cycle[0]))
    } else {
        None
    }
}

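/// Express a permutation as a sequence of single-axis moves `(from, to)`,
/// mapping each rotation cycle to a single move when possible.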
fn perm_to_atoms(input: &[usize]) -> TVec<(usize, usize)> {
    let mut changes: TVec<(usize, usize)> = tvec!();
    'top: loop {
        let mut reached: TVec<usize> = (0..input.len()).collect();
        changes.iter().for_each(|(f, t)| {
            let axis = reached.remove(*f);
            reached.insert(*t, axis);
        });
        if &*reached == input {
            return changes;
        }
        let remaining: TVec<usize> =
            input.iter().map(|x| reached.iter().position(|y| y == x).unwrap()).collect();
        let cycles = perm_to_cycles(&remaining);
        for cycle in &cycles {
            if let Some(rot) = is_rotation_cycle(cycle) {
                changes.push(rot);
                continue 'top;
            }
        }
        changes.push((cycles[0][1], cycles[0][0]));
    }
}

pub fn perm_to_ops(input: &[usize]) -> TVec<AxisOp> {
    perm_to_atoms(input).into_iter().map(|pair| AxisOp::Move(pair.0, pair.1)).collect()
}

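/// Resolve a TensorFlow-style reshape spec against an input shape: `0` copies
/// the corresponding input dimension and `-1` stands for the one dimension
/// inferred from the remaining volume. E.g. input `[2, 3, 5, 7]` with spec
/// `[2, 0, 35]` resolves to `[2, 3, 35]`.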
pub fn compute_shape_with_tf_rules(input: &[TDim], shape_spec: &[TDim]) -> TractResult<TVec<TDim>> {
    let mut shape: TVec<TDim> = shape_spec.into();
    fn deal_with_zero<'a>(
        mut input_dims: std::iter::Peekable<impl Iterator<Item = &'a TDim>>,
        shape: &mut [TDim],
    ) -> TractResult<()> {
        let mut remaining_dim_input = 1.to_dim();
        for slot in shape.iter_mut() {
            if *slot == (-1).into() {
                break;
            }
            if *slot == 0.into() {
                if remaining_dim_input != TDim::one() {
                    bail!("Invalid remaining dim");
                }
                *slot = (*input_dims.peek().context("Invalid")?).clone();
            }
            loop {
                let quotient = remaining_dim_input.maybe_div(slot);
                if quotient.is_err() || quotient.as_ref().unwrap().1 != 1 {
                    remaining_dim_input *= input_dims.next().context("Invalid")?;
                } else {
                    break;
                }
            }
            remaining_dim_input = remaining_dim_input.maybe_div(slot)?.0;
        }
        Ok(())
    }

    deal_with_zero(input.iter().peekable(), &mut shape)?;
    shape.reverse();
    deal_with_zero(input.iter().rev().peekable(), &mut shape)?;
    shape.reverse();

    if let Some(pos) = shape.iter().position(|d| *d == (-1).into()) {
        let input_vol: TDim = input.iter().product();
        let shape_vol: TDim = shape.iter().filter(|d| **d != (-1).into()).product();
        let div = input_vol.maybe_div(&shape_vol)?;
        if div.1 != 1 {
            bail!("invalid")
        }
        shape[pos] = div.0;
    }
    Ok(shape)
}

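/// Translate a TensorFlow-style reshape into a sequence of `AxisOp`s by
/// walking the input shape against the resolved output shape: unit axes
/// become `Add`/`Rm`, and volume-preserving groups become `Reshape`.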
pub fn to_axis_ops_with_tf_rules(
    input_orig: &[TDim],
    output_spec: &[TDim],
) -> TractResult<TVec<AxisOp>> {
    let final_output = compute_shape_with_tf_rules(input_orig, output_spec)?;
    let mut stack: TVec<AxisOp> = tvec!();
    'top: loop {
        let current_input =
            stack.iter().try_fold(TVec::from(input_orig), |mut shape, op| -> TractResult<_> {
                op.change_shape_array(&mut shape, false)?;
                Ok(shape)
            })?;
        if current_input == final_output {
            return Ok(stack);
        }
        if let Some(common) =
            current_input.iter().zip(final_output.iter()).position(|(a, b)| a != b)
        {
            if current_input[common].is_one() {
                stack.push(AxisOp::Rm(common));
            } else if final_output[common].is_one() {
                stack.push(AxisOp::Add(common));
            } else {
                for i in common..current_input.len() {
                    let i_group = &current_input[common..i + 1];
                    let i_volume: TDim = i_group.iter().product();
                    for o in common..final_output.len() {
                        let o_group = &final_output[common..o + 1];
                        let o_volume: TDim = o_group.iter().product();
                        if i_volume == o_volume {
                            stack.push(AxisOp::Reshape(common, i_group.into(), o_group.into()));
                            continue 'top;
                        }
                    }
                }
                todo!()
            }
        } else if final_output.len() > current_input.len() {
            stack.push(AxisOp::Add(current_input.len()));
        } else {
            stack.push(AxisOp::Rm(current_input.len() - 1));
        }
    }
}

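/// Codegen form of the shape-only `AxisOp`s: a pure relabeling of the tensor
/// geometry (dims and strides) that leaves the underlying data buffer
/// untouched.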
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct IntoShape {
    pub mapping: AxesMapping,
    pub len: usize,
    pub dims: TVec<usize>,
    pub strides: TVec<isize>,
}

impl Op for IntoShape {
    fn name(&self) -> Cow<str> {
        "IntoShape".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!("{}", self.mapping)])
    }

    op_as_typed_op!();
    impl_op_same_as!();
}

impl EvalOp for IntoShape {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut input = args_1!(inputs).into_tensor();
        ensure!(input.len() == self.len);
        unsafe { input.set_geometry_unchecked(&self.dims, &self.strides) };
        Ok(tvec!(input.into_tvalue()))
    }
}

impl TypedOp for IntoShape {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        Ok(tvec!(inputs[0].datum_type.fact(&self.dims)))
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let input = model.outlet_fact(node.inputs[0])?;
        if input.shape.as_concrete().is_some_and(|shape| shape == &*self.dims) {
            return TypedModelPatch::shunt_one_op(model, node);
        }
        if let Some(succ) = model.single_succ(node.id)? {
            if let Some(into_shape) = succ.op_as::<IntoShape>() {
                let op = Self {
                    mapping: self.mapping.compose(&into_shape.mapping)?,
                    ..into_shape.clone()
                };
                return Ok(Some(TypedModelPatch::fuse_with_next(model, node, op)?));
            }
        }
        Ok(None)
    }

    as_op!();
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_perm_to_cycles() {
        assert_eq!(perm_to_cycles(&[1, 2, 0]), tvec!(tvec!(0, 1, 2)));
        assert_eq!(perm_to_cycles(&[2, 0, 1]), tvec!(tvec!(0, 2, 1)));
        assert_eq!(perm_to_cycles(&[1, 2, 3, 0]), tvec!(tvec!(0, 1, 2, 3)));
        assert_eq!(perm_to_cycles(&[3, 0, 1, 2]), tvec!(tvec!(0, 3, 2, 1)));
        assert_eq!(perm_to_cycles(&[3, 1, 2, 0, 4]), tvec!(tvec!(0, 3)));
    }

    #[test]
    fn is_rotation() {
        assert_eq!(is_rotation_cycle(&[0, 1, 2]), Some((0, 2)));
        assert_eq!(is_rotation_cycle(&[0, 2, 1]), Some((2, 0)));
    }

    #[test]
    fn test_perm_one_rotation() {
        assert_eq!(perm_to_atoms(&[1, 2, 0, 3, 4]), tvec!((0, 2)));
    }

    #[test]
    fn test_perm_two_rotations() {
        assert_eq!(perm_to_atoms(&[1, 2, 0, 4, 3]), tvec!((0, 2), (3, 4)));
    }

    #[test]
    fn test_perm_complex() {
        assert_eq!(perm_to_atoms(&[3, 1, 2, 0, 4]), tvec!((3, 0), (1, 3)));
    }

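    // Illustrative sketch (not part of the original suite): walks Add and
    // Move through change_shape_array on a concrete shape, matching the
    // examples in the doc comment above.
    #[test]
    fn change_shape_array_example() {
        let mut shape: TVec<usize> = tvec![2, 3];
        // Add(1) inserts a unit dim: [2, 3] -> [2, 1, 3]
        Add(1).change_shape_array(&mut shape, false).unwrap();
        assert_eq!(&*shape, &[2, 1, 3]);
        // Move(0, 2) relocates the first axis to the end: [2, 1, 3] -> [1, 3, 2]
        Move(0, 2).change_shape_array(&mut shape, false).unwrap();
        assert_eq!(&*shape, &[1, 3, 2]);
    }
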
    #[test]
    pub fn transform_op_add_0_add_0() {
        let change = Add(0);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Add(1)))));
    }

    #[test]
    pub fn transform_op_add_0_add_1() {
        let change = Add(0);
        let op = Add(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(2)), Some(Add(0)))));
    }

    #[test]
    pub fn transform_op_add_1_add_0() {
        let change = Add(1);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Add(2)))));
    }

    #[test]
    pub fn transform_op_rm_0_rm_1() {
        let change = Rm(0);
        let op = Rm(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Rm(0)))));
    }

    #[test]
    pub fn transform_op_rm_1_rm_0() {
        let change = Rm(1);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Rm(0)))));
    }

    #[test]
    pub fn transform_op_add_0_rm_0() {
        let change = Add(0);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(1)), Some(Add(0)))));
    }

    #[test]
    pub fn transform_op_add_0_rm_1() {
        let change = Add(0);
        let op = Rm(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(2)), Some(Add(0)))));
    }

    #[test]
    pub fn transform_op_add_1_rm_0() {
        let change = Add(1);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Add(0)))));
    }

    #[test]
    pub fn transform_op_rm_1_add_0() {
        let change = Rm(1);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Rm(2)))));
    }

    #[test]
    pub fn transform_op_rm_0_add_1() {
        let change = Rm(0);
        let op = Add(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Rm(0)))));
    }

    #[test]
    pub fn transform_op_mv_02_rm_2() {
        let change = Move(0, 2);
        let op = Rm(2);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(1)), Some(Move(0, 1)))));
    }
}

#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

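    /// A random input shape and a sequence of `AxisOp`s to apply to it; used
    /// to check that decluttering a chain of axis ops preserves its output.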
    #[derive(Debug)]
    struct ComposeProblem {
        input: TVec<usize>,
        ops: TVec<AxisOp>,
    }

    impl Arbitrary for AxisOp {
        type Parameters = TVec<usize>;
        type Strategy = BoxedStrategy<AxisOp>;
        fn arbitrary_with(shape: TVec<usize>) -> Self::Strategy {
            let mut ops: BoxedStrategy<AxisOp> = (0usize..shape.len() + 1).prop_map(Add).boxed();
            if shape.len() > 1 {
                ops = ops
                    .prop_union(
                        (0..shape.len(), 0..shape.len() - 1)
                            .prop_map(|(a, b)| Move(a, b + (b >= a) as usize))
                            .boxed(),
                    )
                    .boxed()
            }
            let rms = (0..shape.len()).filter(|&ax| shape[ax] == 1).map(Rm).collect::<Vec<_>>();
            if rms.len() > 0 {
                ops = ops
                    .prop_union((0..rms.len()).prop_map(move |rm| rms[rm].clone()).boxed())
                    .boxed()
            }
            let mergeable: Vec<AxisOp> = shape
                .windows(2)
                .enumerate()
                .filter(|(_, w)| w[0] > 1 && w[1] > 1)
                .map(|(ix, w)| {
                    Reshape(ix, tvec!(w[0].to_dim(), w[1].to_dim()), tvec!((w[0] * w[1]).to_dim()))
                })
                .collect();
            if mergeable.len() > 1 {
                ops = ops
                    .prop_union(
                        (0..mergeable.len()).prop_map(move |ix| mergeable[ix].clone()).boxed(),
                    )
                    .boxed()
            }
            ops
        }
    }

    impl Arbitrary for ComposeProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<ComposeProblem>;
        fn arbitrary_with(_args: ()) -> Self::Strategy {
            let input = proptest::collection::vec(1usize..4, 1usize..4);
            fn tail(len: usize, shape: TVec<usize>) -> BoxedStrategy<TVec<AxisOp>> {
                if len == 0 {
                    Just(tvec!()).boxed()
                } else {
                    AxisOp::arbitrary_with(shape.clone())
                        .prop_flat_map(move |op| {
                            let mut shape = shape.clone();
                            op.change_shape_array(&mut shape, false).unwrap();
                            tail(len - 1, shape.clone()).prop_map(move |mut t| {
                                t.insert(0, op.clone());
                                t
                            })
                        })
                        .boxed()
                }
            }
            (input, 1usize..=5)
                .prop_flat_map(|(input, len)| (Just(input.clone()), tail(len, input.into())))
                .prop_map(|(input, ops)| ComposeProblem { input: input.into(), ops })
                .boxed()
        }
    }

    impl ComposeProblem {
        pub fn model(&self) -> TractResult<TypedModel> {
            let mut model = TypedModel::default();
            let mut wire = model.add_source("source", i64::fact(&self.input))?;
            for (ix, op) in self.ops.iter().enumerate() {
                wire = model.wire_node(format!("op_{ix}"), op.clone(), &[wire])?[0];
            }
            model.set_output_outlets(&[wire])?;
            Ok(model)
        }

        fn input(&self) -> TractResult<Tensor> {
            unsafe {
                let mut t = Tensor::uninitialized::<i64>(&self.input)?;
                for i in 0..t.len() {
                    t.as_slice_mut().unwrap()[i] = i as i64;
                }
                Ok(t)
            }
        }

        fn check(&self) -> TractResult<()> {
            crate::setup_test_logger();
            let input = self.input()?;
            let model = self.model()?;
            let raw = model.into_runnable()?.run(tvec!(input.clone().into_tvalue()))?;
            let optimized = self.model()?.into_decluttered()?;
            let opt = optimized.into_runnable()?.run(tvec!(input.into_tvalue()))?;
            opt[0].close_enough(&raw[0], false)
        }
    }

    proptest! {
        #[test]
        fn recip(pb in any::<AxisOp>()) {
            assert_eq!(pb.recip().recip(), pb);
        }

        #[test]
        fn axis_ops(pb in any::<ComposeProblem>()) {
            pb.check().unwrap()
        }
    }

    #[test]
    fn add_0_rm_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Rm(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_01() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_01_add_1() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1), Add(1)] };
        pb.check().unwrap();
    }

    #[test]
    fn recip_move_01() {
        let op = Move(1, 0);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn recip_move_20() {
        let op = Move(2, 0);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn recip_move_02() {
        let op = Move(0, 2);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn add_0_add_1_move_02() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(1), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_02() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(0), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_2_move_12() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(2), Move(1, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_02_rm_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0), Move(0, 2), Rm(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_20_move_20() {
        let pb =
            ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(0), Move(2, 0), Move(2, 0)] };
        pb.check().unwrap();
    }

    #[test]
    fn move_01_add_0() {
        let pb = ComposeProblem { input: tvec![1, 1], ops: tvec![Move(0, 1), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_02_move_02() {
        let pb = ComposeProblem { input: tvec![1, 1], ops: tvec![Add(0), Move(0, 2), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_2_move_20_move_12_rm_2() {
        let pb = ComposeProblem {
            input: tvec![3],
            ops: tvec![Add(0), Add(2), Move(2, 0), Move(1, 2), Rm(2)],
        };
        pb.check().unwrap();
    }

    #[test]
    fn move_02_move_02() {
        let pb = ComposeProblem { input: tvec![2, 1, 1], ops: tvec![Move(0, 2), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn rm_1_perm_10_add_0() {
        let pb = ComposeProblem { input: tvec![1, 1, 2], ops: tvec![Rm(1), Move(0, 1), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_2_move_02_move_02() {
        let pb = ComposeProblem { input: tvec![3, 2], ops: tvec![Add(2), Move(0, 2), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn move_01_move_20_move_20() {
        let pb = ComposeProblem {
            input: tvec![2, 3, 2],
            ops: tvec![Move(0, 1), Move(2, 0), Move(2, 0)],
        };
        pb.check().unwrap();
    }

    #[test]
    fn reshape_axes_tracking() {
        let pb = ComposeProblem {
            input: tvec![2, 2, 2],
            ops: tvec![Reshape(0, tvec!(2.to_dim(), 2.to_dim()), tvec!(4.to_dim()))],
        };
        pb.check().unwrap();
    }

    #[test]
    fn simplify_reshape() {
        macro_rules! d {
            ($($dim: expr),*) => { tvec!($($dim.to_dim()),*) }
        }
        assert_eq!(Reshape(3, d!(), d!()).simplify(), tvec!());
        assert_eq!(Reshape(3, d!(2, 3), d!(2, 3)).simplify(), tvec!());
        assert_eq!(Reshape(3, d!(1), d!()).simplify(), tvec!(Rm(3)));
        assert_eq!(Reshape(3, d!(), d!(1)).simplify(), tvec!(Add(3)));
        assert_eq!(
            Reshape(3, d!(2, 3, 4), d!(2, 4, 3)).simplify(),
            tvec!(Reshape(4, d!(3, 4), d!(4, 3)))
        );
        assert_eq!(
            Reshape(3, d!(3, 4, 2), d!(4, 3, 2)).simplify(),
            tvec!(Reshape(3, d!(3, 4), d!(4, 3)))
        );
        assert_eq!(
            Reshape(3, d!(1, 2, 3), d!(3, 2)).simplify(),
            tvec!(Rm(3), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(3, d!(2, 3), d!(1, 3, 2)).simplify(),
            tvec!(Reshape(3, d!(2, 3), d!(3, 2)), Add(3))
        );
        assert_eq!(
            Reshape(3, d!(2, 3, 1), d!(3, 2)).simplify(),
            tvec!(Rm(5), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(3, d!(2, 3), d!(3, 2, 1)).simplify(),
            tvec!(Add(5), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(2, d!(2, 2, 1), d!(4)).simplify(),
            tvec!(Rm(4), Reshape(2, d!(2, 2), d!(4)))
        );
        assert_eq!(Reshape(1, d!(1, 2), d!(2)).simplify(), tvec!(Rm(1)));
    }

    macro_rules! s {
        ($($a:expr),*) => {&[ $($a.clone().into()),* ]}
    }

    macro_rules! r {
        ($at: expr ; $($from:expr),* => $($to:expr),*) => {
            AxisOp::Reshape($at, tvec!($($from.into()),*), tvec!($($to.into()),*))
        }
    }

    #[test]
    fn compute_invalid() {
        assert!(compute_shape_with_tf_rules(s![3, 4, 5], s!(100)).is_err());
    }

    #[test]
    fn compute_with_leading_zero() {
        assert_eq!(&*compute_shape_with_tf_rules(s![3, 4, 5], s!(0, 0, 5)).unwrap(), s![3, 4, 5])
    }

    #[test]
    fn compute_with_leading_zero_with_flatten() {
        assert_eq!(
            &*compute_shape_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap(),
            s![2, 3, 35]
        )
    }

    #[test]
    fn compute_with_trailing_zero() {
        assert_eq!(&*compute_shape_with_tf_rules(s![3, 4, 5], s!(3, -1, 0)).unwrap(), s![3, 4, 5])
    }

    #[test]
    fn compute_bug_1() {
        let table = SymbolScope::default();
        let s = table.new_with_prefix("S");
        assert_eq!(
            &*compute_shape_with_tf_rules(s![s, 1, 2, 128], s!(0, 0, -1)).unwrap(),
            s![s, 1, 256]
        )
    }

    #[test]
    fn compute_bug_2() {
        let table = SymbolScope::default();
        let b = table.new_with_prefix("B");
        let s = table.new_with_prefix("S");
        assert_eq!(
            &*compute_shape_with_tf_rules(s![s, b, 2, 128], s!(0, 0, -1)).unwrap(),
            s![s, b, 256]
        )
    }

    #[test]
    fn axis_op_rm_begin() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![1, 2, 3], s!(2, 3)).unwrap(), &[Rm(0)])
    }

    #[test]
    fn axis_op_rm_end() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3, 1], s!(2, 3)).unwrap(), &[Rm(2)])
    }

    #[test]
    fn axis_op_insert_begin() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3], s!(1, 2, 3)).unwrap(), &[Add(0)])
    }

    #[test]
    fn axis_op_insert_end() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3], s!(2, 3, 1)).unwrap(), &[Add(2)])
    }

    #[test]
    fn axis_op_merge() {
        assert_eq!(
            &*to_axis_ops_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap(),
            &[r!(2 ; 5,7 => 35)]
        )
    }

    #[test]
    fn axis_op_complex() {
        assert_eq!(
            &*to_axis_ops_with_tf_rules(s![1, 2, 3, 5, 7], s!(2, 1, 3, 35, 1)).unwrap(),
            &[Rm(0), Add(1), r!(3 ; 5,7 => 35), Add(4)]
        )
    }
}