1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::model::{TypedModel, TypedNode};
6use crate::ops::identity::Identity;
7use num_traits::One;
8use tract_itertools::Itertools;
9use tract_linalg::block_quant::{BlockQuantFact, BlockQuantValue};
10use tract_ndarray::{ArrayViewD, ArrayViewMutD};
11use AxisOp::*;
12
/// Designates one slot of a node: either one of its outputs or one of its inputs.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InOut {
    /// The node output at the given slot index.
    Out(usize),
    /// The node input at the given slot index.
    In(usize),
}
18
19impl InOut {
20 pub fn as_outlet<F: Clone + Fact, O: Clone>(&self, node: &Node<F, O>) -> OutletId {
21 match self {
22 InOut::In(ix) => node.inputs[*ix],
23 InOut::Out(ix) => OutletId::new(node.id, *ix),
24 }
25 }
26
27 pub fn is_input(&self) -> bool {
28 matches!(self, InOut::In(_))
29 }
30
31 pub fn is_output(&self) -> bool {
32 matches!(self, InOut::Out(_))
33 }
34
35 pub fn slot(&self) -> usize {
36 match self {
37 InOut::Out(o) => *o,
38 InOut::In(i) => *i,
39 }
40 }
41}
42
/// An elementary operation on the axes of a shape or tensor.
///
/// `PartialEq` is implemented manually in this file (no-ops compare equal, and
/// a `Move` between adjacent axes equals its swapped form), while `Hash` is
/// derived, hence the `derived_hash_with_manual_eq` allowance.
// NOTE(review): with the derived `Hash`, two values that compare equal (e.g.
// `Move(0, 1)` and `Move(1, 0)`) hash differently — confirm no hashed
// collection relies on that equality.
#[derive(Clone, Hash, Eq)]
#[allow(clippy::large_enum_variant)]
#[allow(clippy::derived_hash_with_manual_eq)]
pub enum AxisOp {
    /// Insert a new axis (of dimension one) at the given position.
    Add(usize),
    /// Remove the axis at the given position.
    Rm(usize),
    /// Move the axis at the first position to the second position.
    Move(usize, usize),
    /// Starting at the given axis, replace the group of dimensions in the first
    /// list by the group in the second list.
    Reshape(usize, TVec<TDim>, TVec<TDim>),
}
52
53impl Debug for AxisOp {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 match self {
56 AxisOp::Add(a) => write!(f, "Add({a})"),
57 AxisOp::Rm(a) => write!(f, "Rm({a})"),
58 AxisOp::Move(from, to) => write!(f, "Move({from},{to})"),
59 AxisOp::Reshape(at, from, to) => {
60 write!(f, "Reshape({at}, [{}], [{}])", from.iter().join(","), to.iter().join(","))
61 }
62 }
63 }
64}
65
66impl PartialEq for AxisOp {
67 fn eq(&self, other: &AxisOp) -> bool {
68 if self.is_noop() && other.is_noop() {
69 true
70 } else if self.is_noop() != other.is_noop() {
71 false
72 } else {
73 match (self, other) {
74 (Add(a), Add(b)) | (Rm(a), Rm(b)) => a == b,
75 (Move(f1, t1), Move(f2, t2)) => {
76 (f1 == f2 && t1 == t2)
77 || ((*t1 == f1 + 1 || *f1 == t1 + 1) && t2 == f1 && t1 == f2)
78 }
79 (Reshape(at1, f1, t1), Reshape(at2, f2, t2)) => at1 == at2 && f1 == f2 && t1 == t2,
80 _ => false,
81 }
82 }
83 }
84}
85
86impl AxisOp {
87 pub fn canonical(&self) -> Cow<'_, AxisOp> {
88 match self {
89 Move(from, to) if *from == to + 1 => Cow::Owned(Move(*to, *from)),
90 Reshape(at, from, to) if from.len() == 1 && to.len() == 2 && from[0] == to[0] => {
91 Cow::Owned(Add(*at + 1))
92 }
93 Reshape(at, from, to) if from.len() == 1 && to.len() == 2 && from[0] == to[1] => {
94 Cow::Owned(Add(*at))
95 }
96 Reshape(at, from, to) if from.len() == 2 && to.len() == 1 && from[0] == to[0] => {
97 Cow::Owned(Rm(*at + 1))
98 }
99 Reshape(at, from, to) if from.len() == 2 && to.len() == 1 && from[1] == to[0] => {
100 Cow::Owned(Rm(*at))
101 }
102 other => Cow::Borrowed(other),
103 }
104 }
105
    /// Decomposes this op into a list of simpler ops.
    ///
    /// Only `Reshape` is actually decomposed; any other op (after
    /// canonicalization) is returned as a single-element list. The guards below
    /// are order-dependent: the emptiness checks come first so that the indexed
    /// accesses in later arms are only reached with non-empty `from` and `to`.
    pub fn simplify(&self) -> TVec<AxisOp> {
        match self.canonical().borrow() {
            // Identical groups: nothing to do.
            Reshape(_, from, to) if from == to => tvec!(),
            // Empty target group: one `Rm(at)` per source axis.
            Reshape(at, from, to) if to.len() == 0 => tvec!(Rm(*at); from.len()),
            // Empty source group: one `Add(at)` per target axis.
            Reshape(at, from, to) if from.len() == 0 => tvec!(Add(*at); to.len()),
            // Same leading dimension on both sides: skip it and recurse one axis further.
            Reshape(at, from, to) if from[0] == to[0] => {
                Reshape(at + 1, from[1..].into(), to[1..].into()).simplify()
            }
            // Same trailing dimension on both sides: drop it and recurse.
            Reshape(at, from, to) if from[from.len() - 1] == to[to.len() - 1] => {
                Reshape(*at, from[..from.len() - 1].into(), to[..to.len() - 1].into()).simplify()
            }
            // Leading 1 in the source group: remove it first, then handle the rest.
            Reshape(at, from, to) if from[0] == 1.to_dim() => std::iter::once(Rm(*at))
                .chain(Reshape(*at, from[1..].into(), to.clone()).simplify())
                .collect(),
            // Leading 1 in the target group: handle the rest, then add the axis.
            Reshape(at, from, to) if to[0] == 1.to_dim() => {
                Reshape(*at, from.clone(), to[1..].into())
                    .simplify()
                    .into_iter()
                    .chain(std::iter::once(Add(*at)))
                    .collect()
            }
            // Trailing 1 in the source group: remove it first, then handle the rest.
            Reshape(at, from, to) if from[from.len() - 1] == 1.to_dim() => {
                std::iter::once(Rm(at + from.len() - 1))
                    .chain(Reshape(*at, from[..from.len() - 1].into(), to.clone()).simplify())
                    .collect()
            }
            // Trailing 1 in the target group: add it past the source group, then handle the rest.
            Reshape(at, from, to) if to[to.len() - 1] == 1.to_dim() => {
                std::iter::once(Add(at + from.len()))
                    .chain(Reshape(*at, from.clone(), to[..to.len() - 1].into()).simplify())
                    .collect()
            }
            // Nothing to simplify: keep the (canonical) op as is.
            other => tvec!(other.clone()),
        }
    }
140
141 pub fn transform_axis(&self, axis: usize) -> Option<usize> {
142 match self.canonical().as_ref() {
143 Add(ix) => Some(axis + (axis >= *ix) as usize),
144 Rm(ix) => {
145 if axis == *ix {
146 None
147 } else {
148 Some(axis - (axis > *ix) as usize)
149 }
150 }
151 Move(from, to) if from < to => {
152 if axis < *from || axis > *to {
153 Some(axis)
154 } else if axis == *from {
155 Some(*to)
156 } else {
157 Some(axis - 1)
158 }
159 }
160 Move(from, to) => {
161 if axis < *to || axis > *from {
162 Some(axis)
163 } else if axis == *from {
164 Some(*to)
165 } else {
166 Some(axis + 1)
167 }
168 }
169 Reshape(at, _, _) if axis < *at => Some(axis),
170 Reshape(at, from, to) if axis >= at + from.len() => Some(axis + to.len() - from.len()),
171 Reshape(_, _, _) => None,
172 }
173 }
174
    /// Computes how this op and an axis change arriving on its input can be
    /// reordered.
    ///
    /// On success returns `Some((new_op, new_change))`: `new_op` is the op that
    /// replaces `self` (`None` meaning no op is needed any more) and
    /// `new_change` is the change to propagate on the output side (`None`
    /// meaning nothing is left to propagate). Returns `None` when the two
    /// operations cannot be merged by the rules below.
    ///
    /// Both `self` and `change` are canonicalized before matching.
    pub fn merge_incoming_change(
        &self,
        change: &AxisOp,
    ) -> Option<(Option<AxisOp>, Option<AxisOp>)> {
        match (self.canonical().as_ref(), change.canonical().as_ref()) {
            // Add/Rm pairs: each index is shifted by one depending on the
            // relative position of the other operation's axis.
            (Add(op), Add(c)) => {
                Some((Some(Add(op + (c < op) as usize)), Some(Add(c + (c >= op) as usize))))
            }
            (Add(op), Rm(c)) => {
                Some((Some(Add(op - (c < op) as usize)), Some(Rm(c + (c >= op) as usize))))
            }
            (Rm(op), Add(c)) => {
                Some((Some(Rm(op + (c <= op) as usize)), Some(Add(c - (op < c) as usize))))
            }
            (Rm(op), Rm(c)) => {
                Some((Some(Rm(op - (c < op) as usize)), Some(Rm(c - (op <= c) as usize))))
            }

            // Add vs Move: only handled when the added axis lies entirely
            // before or entirely after the span of the move.
            (Add(x), Move(from, to)) => {
                if x <= from.min(to) {
                    Some((Some(self.clone()), Some(Move(from + 1, to + 1))))
                } else if x > from.max(to) {
                    Some((Some(self.clone()), Some(change.clone())))
                } else {
                    None
                }
            }

            (Move(from, to), Add(x)) => {
                if x <= from.min(to) {
                    Some((Some(Move(from + 1, to + 1)), Some(Add(*x))))
                } else if x > from.max(to) {
                    Some((Some(Move(*from, *to)), Some(Add(*x))))
                } else {
                    None
                }
            }

            // Rm vs Move: every relative position is handled; when the removed
            // axis is the moved one, the move disappears altogether.
            (Rm(x), Move(from, to)) => {
                if x == from {
                    Some((Some(Rm(*to)), None))
                } else if x < from.min(to) {
                    Some((Some(self.clone()), Some(Move(from - 1, to - 1))))
                } else if x > from.max(to) {
                    Some((Some(self.clone()), Some(change.clone())))
                } else if from + 1 == *to && x == to {
                    Some((Some(Rm(*from)), None))
                } else if from < to && x <= to {
                    Some((Some(Rm(x - 1)), Some(Move(*from, *to - 1))))
                } else {
                    Some((Some(Rm(x + 1)), Some(Move(*from - 1, *to))))
                }
            }

            (Move(from, to), Rm(x)) => {
                if x < from.min(to) {
                    Some((Some(Move(from - 1, to - 1)), Some(Rm(*x))))
                } else if x > from.max(to) {
                    Some((Some(Move(*from, *to)), Some(Rm(*x))))
                } else {
                    None
                }
            }

            // Add/Rm vs Reshape: only handled when the axis is outside the
            // reshaped group; indices past the group are shifted by the
            // difference in group lengths.
            (Add(op), Reshape(at, from, to)) => {
                if op <= at {
                    Some((Some(Add(*op)), Some(Reshape(at + 1, from.clone(), to.clone()))))
                } else if *op > at + from.len() {
                    Some((
                        Some(Add(*op + to.len() - from.len())),
                        Some(Reshape(*at, from.clone(), to.clone())),
                    ))
                } else {
                    None
                }
            }
            (Rm(op), Reshape(at, from, to)) => {
                if op < at {
                    Some((Some(Rm(*op)), Some(Reshape(at - 1, from.clone(), to.clone()))))
                } else if *op > at + from.len() {
                    Some((
                        Some(Rm(*op + to.len() - from.len())),
                        Some(Reshape(*at, from.clone(), to.clone())),
                    ))
                } else {
                    None
                }
            }
            (Reshape(at, from, to), Add(change)) => {
                if change < at {
                    Some((Some(Reshape(at + 1, from.clone(), to.clone())), Some(Add(*change))))
                } else if *change > *at + from.len() {
                    Some((
                        Some(Reshape(*at, from.clone(), to.clone())),
                        Some(Add(change + to.len() - from.len())),
                    ))
                } else {
                    None
                }
            }
            (Reshape(at, from, to), Rm(change)) => {
                if change < at {
                    Some((Some(Reshape(at - 1, from.clone(), to.clone())), Some(Rm(*change))))
                } else if *change > *at + from.len() {
                    Some((
                        Some(Reshape(*at, from.clone(), to.clone())),
                        Some(Rm(change + to.len() - from.len())),
                    ))
                } else {
                    None
                }
            }
            // Combinations involving a Reshape and a Move, or two Reshapes,
            // are never merged.
            (Reshape(_, _, _), Move(_, _)) => None,
            (Move(_, _), Reshape(_, _, _)) => None,
            (Reshape(_, _, _), Reshape(_, _, _)) => None,
            _ => None,
        }
    }
297
298 pub fn change_shape_array<D: DimLike>(
299 &self,
300 shape: &mut TVec<D>,
301 broadcasting: bool,
302 ) -> TractResult<()> {
303 match self.canonical().as_ref() {
304 Add(ix) => {
305 ensure!(*ix <= shape.len());
306 shape.insert(*ix, D::one());
307 }
308 Rm(ix) => {
309 ensure!(*ix < shape.len());
310 shape.remove(*ix);
311 }
312 Move(from, to) => {
313 ensure!(*from < shape.len());
314 ensure!(*to < shape.len());
315 let axis = shape.remove(*from);
316 shape.insert(*to, axis);
317 }
318 Reshape(at, from, to) => {
319 let from_volume = from.iter().product::<TDim>();
320 let to_volume = to.iter().product::<TDim>();
321 ensure!(from_volume == to_volume, "{from_volume} should be equal to {to_volume}");
322 ensure!(*at + from.len() <= shape.len());
323 if shape.len() >= from.len() + *at
324 && tract_itertools::izip!(shape.iter().skip(*at), from)
325 .all(|(shape, spec)| shape.to_dim() == *spec)
326 {
327 for _ in from {
328 shape.remove(*at);
329 }
330 for d in to.iter().rev() {
331 shape.insert(*at, d.try_into()?);
332 }
333 } else if broadcasting
334 && shape.iter().skip(*at).take(from.len()).all(|d| d.to_dim() == 1.to_dim())
335 {
336 for _ in from {
337 shape.remove(*at);
338 }
339 for _ in to.iter().rev() {
340 shape.insert(*at, 1.into());
341 }
342 } else {
343 bail!("Incompatible reshape for shape {:?} and {:?}", shape, self);
344 }
345 }
346 }
347 Ok(())
348 }
349
350 pub fn change_shape(&self, shape: &mut ShapeFact, broadcasting: bool) -> TractResult<()> {
351 match self.canonical().as_ref() {
352 Add(ix) => shape.insert_axis(*ix),
353 Rm(ix) => {
354 if shape.rank() <= *ix {
355 bail!("Attempt to remove axis #{} on shape {:?}", ix, shape);
356 }
357 if shape[*ix] != 1.to_dim() {
358 bail!("Removing non-trivial axis #{} of dim: {:?}", ix, shape);
359 }
360 shape.remove_axis(*ix)
361 }
362 _ => {
363 let mut array = shape.to_tvec();
364 self.change_shape_array(&mut array, broadcasting)?;
365 let mut new_shape = ShapeFact::from_dims(array);
366 std::mem::swap(shape, &mut new_shape);
367 Ok(())
368 }
369 }
370 }
371
    /// Applies this op to `tensor`, in place.
    ///
    /// When the op needs a higher rank than the tensor has and the tensor holds
    /// opaque values, the op (shifted left by the tensor rank) is applied to
    /// the shape recorded in each `BlockQuantValue` element instead of to the
    /// tensor itself; any other kind of opaque element is an error.
    ///
    /// Otherwise the op is applied to the tensor geometry. For a `Reshape`
    /// whose new shape cannot be set, and if `broadcasting` is allowed and the
    /// reshaped window only contains ones, axes are removed or inserted at `at`
    /// to reach the rank of the `to` group.
    pub fn change_tensor(&self, tensor: &mut Tensor, broadcasting: bool) -> TractResult<()> {
        if self.required_rank() > tensor.rank() && tensor.datum_type().is_opaque() {
            // The op targets axes beyond the tensor's own: apply it to the
            // shape carried by each opaque element.
            let inner_change = self.trim_left(tensor.rank())?;
            for opaque in tensor.as_slice_mut::<Opaque>()? {
                if let Some(bqv) = opaque.downcast_ref::<BlockQuantValue>() {
                    let mut new_shape: TVec<usize> = bqv.fact.shape().into();
                    inner_change.change_shape_array(&mut new_shape, false)?;
                    // The value is shared; only the fact is rebuilt with the new shape.
                    let new_bqv = BlockQuantValue {
                        value: Arc::clone(&bqv.value),
                        fact: BlockQuantFact::new(bqv.fact.format.clone(), new_shape),
                    };
                    *opaque = Opaque(Arc::new(new_bqv));
                } else {
                    bail!("Can't apply {self:?} to opaque tensor {tensor:?}");
                }
            }
            return Ok(());
        }
        ensure!(self.required_rank() <= tensor.rank());
        match self.canonical().as_ref() {
            Add(ix) => tensor.insert_axis(*ix),
            Rm(ix) => tensor.remove_axis(*ix),
            Move(from, to) => {
                let mut tmp = tensor.clone().move_axis(*from, *to)?;
                std::mem::swap(tensor, &mut tmp);
                Ok(())
            }
            Reshape(at, from, to) => {
                let mut shape: TVec<usize> = tensor.shape().into();
                // NOTE(review): the target shape is always computed with
                // broadcasting enabled, whatever the `broadcasting` argument
                // says — confirm this is intended.
                self.change_shape_array(&mut shape, true)?;
                if tensor.set_shape(&shape).is_ok() {
                    Ok(())
                } else if broadcasting
                    && tensor.shape().iter().skip(*at).take(from.len()).all(|d| *d == 1)
                {
                    // Window made of ones only: adjust the rank by removing or
                    // inserting axes at `at`.
                    if from.len() > to.len() {
                        for _ in to.len()..from.len() {
                            tensor.remove_axis(*at)?;
                        }
                    }
                    if to.len() > from.len() {
                        for _ in from.len()..to.len() {
                            tensor.insert_axis(*at)?;
                        }
                    }
                    Ok(())
                } else {
                    bail!(
                        "Invalid reshaping: {:?} on tensor {:?} (broadcasting allowed: {:?})",
                        self,
                        tensor,
                        broadcasting
                    )
                }
            }
        }
    }
429
430 pub fn change_view<D>(&self, view: &mut ArrayViewD<D>) -> TractResult<()> {
431 use tract_ndarray::Axis;
432 match *self {
433 AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
434 AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
435 AxisOp::Move(from, to) if from < to => {
436 for left in from..to {
437 view.swap_axes(left, left + 1);
438 }
439 }
440 AxisOp::Move(from, to) => {
441 for left in (to..from).rev() {
442 view.swap_axes(left, left + 1);
443 }
444 }
445 AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
446 }
447 Ok(())
448 }
449
450 pub fn change_view_mut<D>(&self, view: &mut ArrayViewMutD<D>) -> TractResult<()> {
451 use tract_ndarray::Axis;
452 match *self {
453 AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
454 AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
455 AxisOp::Move(from, to) if from < to => {
456 for left in from..to {
457 view.swap_axes(left, left + 1);
458 }
459 }
460 AxisOp::Move(from, to) => {
461 for left in (to..from).rev() {
462 view.swap_axes(left, left + 1);
463 }
464 }
465 AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
466 }
467 Ok(())
468 }
469
470 pub fn recip(&self) -> AxisOp {
471 match self.canonical().as_ref() {
472 Add(ix) => Rm(*ix),
473 Rm(ix) => Add(*ix),
474 Move(from, to) if from == to => self.clone(),
475 Move(from, to) if *from + 1 == *to => self.clone(),
476 Move(from, to) if *from == *to + 1 => {
477 unreachable!();
478 }
479 Move(from, to) => Move(*to, *from),
480 Reshape(at, from, to) => Reshape(*at, to.clone(), from.clone()),
481 }
482 }
483
484 pub fn is_noop(&self) -> bool {
485 match self {
486 Move(f, t) if f == t => true,
487 Reshape(_, f, t) if f == t => true,
488 _ => false,
489 }
490 }
491
492 pub fn only_shape(&self) -> bool {
493 if self.is_noop() {
494 return true;
495 }
496 !matches!(self, Move(_, _))
497 }
498
499 pub fn wire_split_axis(
500 model: &mut TypedModel,
501 name: impl ToString,
502 outlet: OutletId,
503 axis: usize,
504 outer_dim: usize,
505 ) -> TractResult<TVec<OutletId>> {
506 let fact = model.outlet_fact(outlet)?;
507 let dim: TDim = fact.shape[axis].clone();
508 let inner_dim = dim.clone() / outer_dim;
509 let op = Self::Reshape(axis, tvec!(dim.clone()), tvec!(outer_dim.to_dim(), inner_dim));
510 model.wire_node(name.to_string(), op, &[outlet])
511 }
512
513 pub fn wire_collapse_axis(
514 model: &mut TypedModel,
515 name: impl ToString,
516 outlet: OutletId,
517 axis: usize,
518 ) -> TractResult<TVec<OutletId>> {
519 let fact = model.outlet_fact(outlet)?;
520 let dim: TDim = fact.shape[axis].clone();
521 let next_dim: TDim = fact.shape[axis + 1].clone();
522 let op = Self::Reshape(axis, tvec!(dim.clone(), next_dim.clone()), tvec!(dim * next_dim));
523 model.wire_node(name.to_string(), op, &[outlet])
524 }
525
526 #[inline]
527 pub fn required_rank(&self) -> usize {
528 match self {
529 Rm(r) => r + 1,
530 Add(a) => *a,
531 Reshape(at, from, _to) => at + from.len(),
532 Move(from, to) => *from.max(to),
533 }
534 }
535
536 pub fn trim_left(&self, prefix: usize) -> TractResult<AxisOp> {
537 Ok(match self {
538 Rm(r) if *r >= prefix => Rm(r - prefix),
539 Add(a) if *a >= prefix => Add(a - prefix),
540 Reshape(at, from, to) if *at >= prefix => {
541 Reshape(at - prefix, from.clone(), to.clone())
542 }
543 Move(from, to) if *from >= prefix && *to >= prefix => Move(from - prefix, to - prefix),
544 _ => bail!("Can no trim left {self:?} by {prefix}"),
545 })
546 }
547}
548
549pub fn wire_rank_broadcast(
550 prefix: impl AsRef<str>,
551 target: &mut TypedModel,
552 inputs: &[OutletId],
553) -> TractResult<TVec<OutletId>> {
554 let facts =
555 inputs.iter().map(|o| target.outlet_fact(*o).cloned()).collect::<TractResult<TVec<_>>>()?;
556 let max_rank = facts.iter().map(|f| f.rank()).max().unwrap();
557 let mut wires = tvec!();
558 let prefix = prefix.as_ref();
559 for i in 0..inputs.len() {
560 let mut wire = inputs[i];
561 for j in facts[i].rank()..max_rank {
562 wire =
563 target.wire_node(format!("{prefix}.fix-rank-{i}-{j}"), AxisOp::Add(0), &[wire])?[0];
564 }
565 wires.push(wire);
566 }
567 Ok(wires)
568}
569
570pub fn wire_with_rank_broadcast(
571 prefix: impl AsRef<str>,
572 target: &mut TypedModel,
573 op: impl Into<Box<dyn TypedOp>>,
574 inputs: &[OutletId],
575) -> TractResult<TVec<OutletId>> {
576 let prefix = prefix.as_ref();
577 let wires = wire_rank_broadcast(prefix, target, inputs)?;
578 target.wire_node(prefix, op.into(), &wires)
579}
580
/// An axis operation attached to a specific outlet of a model.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AxisChange {
    /// The outlet the change applies to.
    pub outlet: OutletId,
    /// The axis operation itself.
    pub op: AxisOp,
}
586
/// What a node reports when an axis change is applied to one of its slots.
#[derive(Clone, Default, Debug)]
pub struct AxisChangeConsequence {
    /// Replacement operator for the node, if any.
    pub substitute_op: Option<Box<dyn TypedOp>>,
    /// Axis operations associated with the node's input and output slots.
    pub wire_changes: TVec<(InOut, AxisOp)>,
}
592
593impl AxisChangeConsequence {
594 pub fn new(
595 _model: &TypedModel,
596 node: &TypedNode,
597 op: Option<Box<dyn TypedOp>>,
598 axis_op: &AxisOp,
599 ) -> AxisChangeConsequence {
600 let mut wire_changes = tvec!();
601 for i in 0..node.inputs.len() {
602 wire_changes.push((InOut::In(i), axis_op.clone()));
603 }
604 for i in 0..node.outputs.len() {
605 wire_changes.push((InOut::Out(i), axis_op.clone()));
606 }
607 AxisChangeConsequence { wire_changes, substitute_op: op }
608 }
609}
610
611impl Op for AxisOp {
612 fn name(&self) -> StaticName {
613 match self {
614 Add(_) => "AddAxis".into(),
615 Rm(_) => "RmAxis".into(),
616 Move(_, _) => "MoveAxis".into(),
617 Reshape(_, _, _) => "Reshape".into(),
618 }
619 }
620
621 fn info(&self) -> TractResult<Vec<String>> {
622 match self {
623 Add(axis) | Rm(axis) => Ok(vec![format!("Axis: {axis}")]),
624 Move(from, to) => Ok(vec![format!("Axis {from} to {to}")]),
625 Reshape(at, from, to) => Ok(vec![format!(
626 "Axes starting at {}: {:?} to {:?}",
627 at,
628 from.iter().join(","),
629 to.iter().join(",")
630 )]),
631 }
632 }
633
634 op_as_typed_op!();
635}
636
637impl EvalOp for AxisOp {
638 fn is_stateless(&self) -> bool {
639 true
640 }
641
642 fn eval_with_session(
643 &self,
644 _node_id: usize,
645 session: &SessionState,
646 inputs: TVec<TValue>,
647 ) -> TractResult<TVec<TValue>> {
648 let mut input = args_1!(inputs).into_tensor();
649 match self {
650 AxisOp::Reshape(skip, from, to) => {
651 let from = from.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
652 let to = to.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
653 AxisOp::Reshape(*skip, from, to).change_tensor(&mut input, false)?
654 }
655 _ => self.change_tensor(&mut input, false)?,
656 }
657 Ok(tvec!(input.into_tvalue()))
658 }
659}
660
661impl TypedOp for AxisOp {
662 as_op!();
663
    /// Computes the output fact from the single input fact.
    ///
    /// When the op needs a higher rank than the input has and the input carries
    /// a `BlockQuantFact`, the op (shifted left by the input rank) is applied
    /// to the shape held by that opaque fact, the outer shape being kept as is;
    /// a constant input is transformed accordingly. Otherwise the op is applied
    /// to the input shape and the opaque fact, if any, is carried over.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        if self.required_rank() > inputs[0].rank() {
            if let Some(bqf) =
                inputs[0].opaque_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
            {
                let mut new_inner_shape: TVec<usize> = bqf.shape().into();
                self.trim_left(inputs[0].rank())?
                    .change_shape_array(&mut new_inner_shape, false)?;
                let new_bqf = BlockQuantFact::new(bqf.format.clone(), new_inner_shape);
                let mut new_fact = Opaque::fact(inputs[0].shape.clone()).with_opaque_fact(new_bqf);
                if let Some(k) = &inputs[0].konst {
                    // Keep the constant in sync with the transformed fact.
                    let mut new = k.clone().into_tensor();
                    self.change_tensor(&mut new, false)?;
                    new_fact.konst = Some(new.into());
                }
                return Ok(tvec!(new_fact));
            }
        }
        let mut shape = inputs[0].shape.clone();
        self.change_shape(&mut shape, false)?;
        let mut fact = inputs[0].datum_type.fact(shape);
        fact.opaque_fact.clone_from(&inputs[0].opaque_fact);
        Ok(tvec!(fact))
    }
688
689 fn axes_mapping(
690 &self,
691 inputs: &[&TypedFact],
692 outputs: &[&TypedFact],
693 ) -> TractResult<AxesMapping> {
694 let mut axes: Vec<Axis> = (0..inputs[0].rank())
695 .zip('a'..)
696 .map(|(axis_id, repr)| {
697 let mut axis = Axis::new(repr, inputs.len(), outputs.len()).input(0, axis_id);
698 if let Some(out) = self.transform_axis(axis_id) {
699 axis = axis.output(0, out);
700 }
701 axis
702 })
703 .collect();
704 for (axis, letter) in (0..outputs[0].rank()).zip('A'..) {
705 if self.recip().transform_axis(axis).is_none() {
706 axes.push(Axis::new(letter, inputs.len(), outputs.len()).output(0, axis));
707 }
708 }
709 AxesMapping::new(inputs.len(), outputs.len(), axes)
710 }
711
712 fn declutter(
713 &self,
714 model: &TypedModel,
715 node: &TypedNode,
716 ) -> TractResult<Option<TypedModelPatch>> {
717 if self.is_noop() {
718 if let Some(p) = TypedModelPatch::shunt_one_op(model, node)? {
719 return Ok(Some(p));
720 }
721 }
722 let simplified = self.simplify();
723 if simplified.len() != 1 || &simplified[0] != self {
724 let mut patch = TypedModelPatch::default();
725 let mut wire = patch.tap_model(model, node.inputs[0])?;
726 for (ix, op) in simplified.into_iter().enumerate() {
727 wire = patch.wire_node(format!("{}.{}", node.name, ix), op, &[wire])?[0];
728 }
729 patch.shunt_outside(model, node.id.into(), wire)?;
730 Ok(Some(patch))
731 } else {
732 Ok(None)
733 }
734 }
735
736 fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
737 Ok(tvec!((InOut::Out(0), self.recip()), (InOut::In(0), self.clone())))
738 }
739
    /// Works out what becomes of this op when `change` is applied to slot `io`.
    ///
    /// * For a change on output 0, the question is delegated to the reciprocal
    ///   op seen from its input; the answer is then mirrored back (substitute
    ///   `AxisOp`s are reciprocated, input and output slots are swapped).
    /// * When `change` equals this op, the op is replaced by `Identity` and no
    ///   wire change is reported.
    /// * Otherwise `merge_incoming_change` decides: the op is replaced by the
    ///   merged op (or `Identity`), `change` is reported on the input and the
    ///   transformed change, if any, on the output.
    ///
    /// Returns `Ok(None)` when the change cannot be accommodated.
    fn change_axes(
        &self,
        _model: &TypedModel,
        _node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let op = if let InOut::Out(0) = io {
            let more = if let Some(more) =
                self.recip().change_axes(_model, _node, InOut::In(0), change)?
            {
                more
            } else {
                return Ok(None);
            };
            AxisChangeConsequence {
                substitute_op: more.substitute_op.map(|op| {
                    // The substitute was computed for the reciprocal op: flip it back.
                    if let Some(op) = op.as_op().downcast_ref::<AxisOp>() {
                        Box::new(op.recip())
                    } else {
                        op
                    }
                }),
                // Slots were expressed from the reciprocal's point of view: swap them.
                wire_changes: more
                    .wire_changes
                    .into_iter()
                    .map(|wc| {
                        (if wc.0 == InOut::In(0) { InOut::Out(0) } else { InOut::In(0) }, wc.1)
                    })
                    .collect(),
            }
        } else if change == self {
            AxisChangeConsequence { substitute_op: Some(Box::new(Identity)), wire_changes: tvec!() }
        } else {
            let (new_op, new_change) = if let Some(pair) = self.merge_incoming_change(change) {
                pair
            } else {
                return Ok(None);
            };
            trace!("  Change:{change:?} self:{self:?} -> change:{new_change:?} op:{new_op:?}");
            let substitute_op: Box<dyn TypedOp> =
                if let Some(o) = new_op { Box::new(o) as _ } else { Box::new(Identity) };
            let mut wire_changes = tvec!();
            if !change.is_noop() {
                wire_changes.push((InOut::In(0), change.clone()))
            }
            if let Some(new_change) = new_change {
                wire_changes.push((InOut::Out(0), new_change))
            }
            AxisChangeConsequence { substitute_op: Some(substitute_op), wire_changes }
        };
        Ok(Some(op))
    }
793
794 fn concretize_dims(
795 &self,
796 _source: &TypedModel,
797 node: &TypedNode,
798 target: &mut TypedModel,
799 mapping: &HashMap<OutletId, OutletId>,
800 values: &SymbolValues,
801 ) -> TractResult<TVec<OutletId>> {
802 let op = if let AxisOp::Reshape(axis, from, to) = self {
803 AxisOp::Reshape(
804 *axis,
805 from.iter().map(|d| d.eval(values)).collect(),
806 to.iter().map(|d| d.eval(values)).collect(),
807 )
808 } else {
809 self.clone()
810 };
811 target.wire_node(&node.name, op, &[mapping[&node.inputs[0]]])
812 }
813
814 fn slice(
815 &self,
816 patch: &mut TypedModelPatch,
817 _model: &TypedModel,
818 node: &TypedNode,
819 _prefix: &str,
820 inputs: &[OutletId],
821 output_axis: usize,
822 _start: &TDim,
823 _end: &TDim,
824 ) -> TractResult<Option<TVec<OutletId>>> {
825 if let Reshape(pos, _from, to) = self {
827 if output_axis >= *pos && output_axis < pos + to.len() {
828 return Ok(None);
829 }
830 }
831 patch.wire_node(&node.name, &node.op, inputs).map(Some)
832 }
833
834 fn codegen(
835 &self,
836 model: &TypedModel,
837 node: &TypedNode,
838 ) -> TractResult<Option<TypedModelPatch>> {
839 if node.outputs[0].fact.opaque_fact.is_some() {
840 return Ok(None);
841 }
842 if let Some(shape) = node.outputs[0].fact.shape.as_concrete() {
843 if !matches!(self, AxisOp::Move(_, _)) {
844 let (inputs, outputs) = model.node_facts(node.id)?;
845 let mapping = self.axes_mapping(&inputs, &outputs)?;
846 let op = IntoShape {
847 mapping,
848 len: shape.iter().product(),
849 strides: Tensor::natural_strides(shape),
850 dims: shape.into(),
851 };
852 return Ok(Some(TypedModelPatch::replace_single_op(
853 model,
854 node,
855 &node.inputs,
856 op,
857 )?));
858 }
859 }
860 Ok(None)
861 }
862}
863
864fn perm_to_cycles(perm: &[usize]) -> TVec<TVec<usize>> {
866 let mut cycles: TVec<TVec<usize>> = tvec!();
867 let mut done = 0;
868 while done < perm.len() {
869 if perm[done] == done || cycles.iter().any(|c| c.contains(&done)) {
870 done += 1;
871 continue;
872 }
873 let mut cycle = tvec!();
874 let mut current = done;
875 loop {
876 cycle.push(current);
877 current = perm[current];
878 if current == done {
879 break;
880 }
881 }
882 cycles.push(cycle)
883 }
884 cycles
885}
886
/// Recognizes a cycle that amounts to moving a single axis.
///
/// * If every element is the previous one plus one, returns
///   `Some((first, last))`.
/// * If the elements after the first decrease by one at each step and the last
///   one is `first + 1`, returns `Some((second, first))`.
/// * Otherwise returns `None`.
///
/// Comparisons use checked subtraction, so a cycle such as `[1, 0, 2]` yields
/// `None` instead of underflowing, and an empty slice yields `None` instead of
/// panicking on indexing.
fn is_rotation_cycle(cycle: &[usize]) -> Option<(usize, usize)> {
    let (&first, &last) = (cycle.first()?, cycle.last()?);
    let steps_up = |w: &[usize]| w[1].checked_sub(w[0]) == Some(1);
    let steps_down = |w: &[usize]| w[0].checked_sub(w[1]) == Some(1);
    if cycle.windows(2).all(steps_up) {
        Some((first, last))
    } else if cycle[1..].windows(2).all(steps_down) && last.checked_sub(1) == Some(first) {
        // Reaching this branch implies at least two elements: a single-element
        // cycle is always accepted by the first branch.
        Some((cycle[1], first))
    } else {
        None
    }
}
898
899fn perm_to_atoms(input: &[usize]) -> TVec<(usize, usize)> {
900 let mut changes: TVec<(usize, usize)> = tvec!();
901 'top: loop {
902 let mut reached: TVec<usize> = (0..input.len()).collect();
903 changes.iter().for_each(|(f, t)| {
904 let axis = reached.remove(*f);
905 reached.insert(*t, axis);
906 });
907 if &*reached == input {
908 return changes;
909 }
910 let remaining: TVec<usize> =
911 input.iter().map(|x| reached.iter().position(|y| y == x).unwrap()).collect();
912 let cycles = perm_to_cycles(&remaining);
913 for cycle in &cycles {
914 if let Some(rot) = is_rotation_cycle(cycle) {
915 changes.push(rot);
916 continue 'top;
917 }
918 }
919 changes.push((cycles[0][1], cycles[0][0]));
920 }
921}
922
923pub fn perm_to_ops(input: &[usize]) -> TVec<AxisOp> {
924 perm_to_atoms(input).into_iter().map(|pair| AxisOp::Move(pair.0, pair.1)).collect()
925}
926
/// Resolves a reshape specification against an input shape.
///
/// In `shape_spec`, a `0` is replaced by a dimension taken from `input`, and a
/// `-1` is replaced by whatever is needed for the output volume to match the
/// input volume. Zeros are resolved in two passes, one from the left and one
/// from the right, each pass stopping at the first `-1` it meets.
///
/// # Errors
/// Fails when a `0` cannot be matched with an input dimension, when the input
/// dimensions run out, or when the volume division required for `-1` is not
/// exact.
pub fn compute_shape_with_tf_rules(input: &[TDim], shape_spec: &[TDim]) -> TractResult<TVec<TDim>> {
    let mut shape: TVec<TDim> = shape_spec.into();
    // Walks `shape` while consuming `input_dims`, replacing each `0` slot by
    // the input dimension currently at the front of the iterator.
    fn deal_with_zero<'a>(
        mut input_dims: std::iter::Peekable<impl Iterator<Item = &'a TDim>>,
        shape: &mut [TDim],
    ) -> TractResult<()> {
        // Product of the input dims consumed so far, divided by the output
        // slots already accounted for.
        let mut remaining_dim_input = 1.to_dim();
        for slot in shape.iter_mut() {
            if *slot == (-1).into() {
                break;
            }
            if *slot == 0.into() {
                // A zero can only be resolved when nothing is left over from
                // the previous slots.
                if remaining_dim_input != TDim::one() {
                    bail!("Invalid remaining dim");
                }
                *slot = (*input_dims.peek().context("Invalid")?).clone();
            }
            // Consume input dims until the accumulated volume is divisible by `slot`.
            loop {
                let quotient = remaining_dim_input.maybe_div(slot);
                if quotient.is_err() || quotient.as_ref().unwrap().1 != 1 {
                    remaining_dim_input *= input_dims.next().context("Invalid")?;
                } else {
                    break;
                }
            }
            remaining_dim_input = remaining_dim_input.maybe_div(slot)?.0;
        }
        Ok(())
    }

    // Left-to-right pass, then the same pass on the reversed shape and input.
    deal_with_zero(input.iter().peekable(), &mut shape)?;
    shape.reverse();
    deal_with_zero(input.iter().rev().peekable(), &mut shape)?;
    shape.reverse();

    // Resolve the first `-1` from the volumes.
    if let Some(pos) = shape.iter().position(|d| *d == (-1).into()) {
        let input_vol: TDim = input.iter().product();
        let shape_vol: TDim = shape.iter().filter(|d| **d != (-1).into()).product();
        let div = input_vol.maybe_div(&shape_vol)?;
        if div.1 != 1 {
            bail!("invalid")
        }
        shape[pos] = div.0;
    }
    Ok(shape)
}
973
974pub fn to_axis_ops_with_tf_rules(
975 input_orig: &[TDim],
976 output_spec: &[TDim],
977) -> TractResult<TVec<AxisOp>> {
978 let final_output = compute_shape_with_tf_rules(input_orig, output_spec)?;
979 let mut stack: TVec<AxisOp> = tvec!();
980 'top: loop {
981 let current_input =
982 stack.iter().try_fold(TVec::from(input_orig), |mut shape, op| -> TractResult<_> {
983 op.change_shape_array(&mut shape, false)?;
984 Ok(shape)
985 })?;
986 if current_input == final_output {
987 return Ok(stack);
988 }
989 if let Some(common) =
990 current_input.iter().zip(final_output.iter()).position(|(a, b)| a != b)
991 {
992 if current_input[common].is_one() {
993 stack.push(AxisOp::Rm(common));
994 } else if final_output[common].is_one() {
995 stack.push(AxisOp::Add(common));
996 } else {
997 for i in common..current_input.len() {
1000 let i_group = ¤t_input[common..i + 1];
1001 let i_volume: TDim = i_group.iter().product();
1002 for o in common..final_output.len() {
1003 let o_group = &final_output[common..o + 1];
1004 let o_volume: TDim = o_group.iter().product();
1005 if i_volume == o_volume {
1006 stack.push(AxisOp::Reshape(common, i_group.into(), o_group.into()));
1007 continue 'top;
1008 }
1009 }
1010 }
1011 todo!()
1012 }
1013 } else if final_output.len() > current_input.len() {
1014 stack.push(AxisOp::Add(current_input.len()));
1015 } else {
1016 stack.push(AxisOp::Rm(current_input.len() - 1));
1017 }
1018 }
1019}
1020
/// Operator that gives its input tensor a fixed, precomputed geometry.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct IntoShape {
    /// Axes mapping between the input and the output.
    pub mapping: AxesMapping,
    /// Number of elements the input tensor must have.
    pub len: usize,
    /// Dimensions given to the output tensor.
    pub dims: TVec<usize>,
    /// Strides given to the output tensor.
    pub strides: TVec<isize>,
}
1028
1029impl Op for IntoShape {
1030 fn name(&self) -> StaticName {
1031 "IntoShape".into()
1032 }
1033
1034 fn info(&self) -> TractResult<Vec<String>> {
1035 Ok(vec![format!("{}", self.mapping)])
1036 }
1037
1038 op_as_typed_op!();
1039 impl_op_same_as!();
1040}
1041
impl EvalOp for IntoShape {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Sets the precomputed dimensions and strides on the single input tensor,
    /// after checking that it holds exactly `self.len` elements.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut input = args_1!(inputs).into_tensor();
        ensure!(input.len() == self.len);
        // SAFETY: the element count was checked against `self.len` just above.
        // NOTE(review): this presumes `dims` and `strides` are consistent with
        // `len` (as they are when built by `AxisOp::codegen`) — the fields are
        // public, confirm no other construction path breaks this.
        unsafe { input.set_geometry_unchecked(&self.dims, &self.strides) };
        Ok(tvec!(input.into_tvalue()))
    }
}
1054
1055impl TypedOp for IntoShape {
1056 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
1057 let mut fact = inputs[0].datum_type.fact(&self.dims);
1058 if let Some(of) = &inputs[0].opaque_fact {
1059 fact = fact.with_opaque_fact(of.clone());
1060 }
1061 Ok(tvec!(fact))
1062 }
1063
1064 fn declutter(
1065 &self,
1066 model: &TypedModel,
1067 node: &TypedNode,
1068 ) -> TractResult<Option<TypedModelPatch>> {
1069 let input = model.outlet_fact(node.inputs[0])?;
1070 if input.shape.as_concrete().is_some_and(|shape| shape == &*self.dims) {
1071 return TypedModelPatch::shunt_one_op(model, node);
1072 }
1073 if let Some(succ) = model.single_succ(node.id)? {
1074 if let Some(into_shape) = succ.op_as::<IntoShape>() {
1075 let op = Self {
1076 mapping: self.mapping.compose(&into_shape.mapping)?,
1077 ..into_shape.clone()
1078 };
1079 return Ok(Some(TypedModelPatch::fuse_with_next(model, node, op)?));
1080 }
1081 }
1082 Ok(None)
1083 }
1084
1085 as_op!();
1086}
1087
#[cfg(test)]
mod test {
    use super::*;

    /// Pins the cycle decomposition of a few permutations. In the last case the
    /// fixed points (1, 2 and 4) are expected to be absent from the result.
    #[test]
    fn test_perm_to_cycles() {
        // 3-element rotations, both directions
        assert_eq!(perm_to_cycles(&[1, 2, 0]), tvec!(tvec!(0, 1, 2)));
        assert_eq!(perm_to_cycles(&[2, 0, 1]), tvec!(tvec!(0, 2, 1)));
        // 4-element rotations, both directions
        assert_eq!(perm_to_cycles(&[1, 2, 3, 0]), tvec!(tvec!(0, 1, 2, 3)));
        assert_eq!(perm_to_cycles(&[3, 0, 1, 2]), tvec!(tvec!(0, 3, 2, 1)));
        // a single swap among fixed points
        assert_eq!(perm_to_cycles(&[3, 1, 2, 0, 4]), tvec!(tvec!(0, 3)));
    }

    /// Pins the `(from, to)` pair reported for two 3-element cycles.
    #[test]
    fn is_rotation() {
        let forward = is_rotation_cycle(&[0, 1, 2]);
        assert_eq!(forward, Some((0, 2)));
        let backward = is_rotation_cycle(&[0, 2, 1]);
        assert_eq!(backward, Some((2, 0)));
    }

    /// A permutation made of one rotation is expected to yield one atom.
    #[test]
    fn test_perm_one_rotation() {
        let atoms = perm_to_atoms(&[1, 2, 0, 3, 4]);
        assert_eq!(atoms, tvec!((0, 2)));
    }

    /// A permutation made of two disjoint rotations is expected to yield two atoms.
    #[test]
    fn test_perm_two_rotations() {
        let atoms = perm_to_atoms(&[1, 2, 0, 4, 3]);
        assert_eq!(atoms, tvec!((0, 2), (3, 4)));
    }

    /// A swap of two non-adjacent axes is expected to decompose into two atoms.
    #[test]
    fn test_perm_complex() {
        let atoms = perm_to_atoms(&[3, 1, 2, 0, 4]);
        assert_eq!(atoms, tvec!((3, 0), (1, 3)));
    }

    // The tests below are named `transform_op_<change>_<op>`: each one calls
    // `<op>.merge_incoming_change(&<change>)` and pins the returned pair.

    /// Incoming `Add(0)` merged into op `Add(0)`.
    #[test]
    pub fn transform_op_add_0_add_0() {
        let merged = Add(0).merge_incoming_change(&Add(0));
        assert_eq!(merged, Some((Some(Add(0)), Some(Add(1)))));
    }

    /// Incoming `Add(0)` merged into op `Add(1)`.
    #[test]
    pub fn transform_op_add_0_add_1() {
        let merged = Add(1).merge_incoming_change(&Add(0));
        assert_eq!(merged, Some((Some(Add(2)), Some(Add(0)))));
    }

    /// Incoming `Add(1)` merged into op `Add(0)`.
    #[test]
    pub fn transform_op_add_1_add_0() {
        let merged = Add(0).merge_incoming_change(&Add(1));
        assert_eq!(merged, Some((Some(Add(0)), Some(Add(2)))));
    }

    /// Incoming `Rm(0)` merged into op `Rm(1)`.
    #[test]
    pub fn transform_op_rm_0_rm_1() {
        let merged = Rm(1).merge_incoming_change(&Rm(0));
        assert_eq!(merged, Some((Some(Rm(0)), Some(Rm(0)))));
    }

    /// Incoming `Rm(1)` merged into op `Rm(0)`.
    #[test]
    pub fn transform_op_rm_1_rm_0() {
        let merged = Rm(0).merge_incoming_change(&Rm(1));
        assert_eq!(merged, Some((Some(Rm(0)), Some(Rm(0)))));
    }

    /// Incoming `Add(0)` merged into op `Rm(0)`.
    #[test]
    pub fn transform_op_add_0_rm_0() {
        let merged = Rm(0).merge_incoming_change(&Add(0));
        assert_eq!(merged, Some((Some(Rm(1)), Some(Add(0)))));
    }

    /// Incoming `Add(0)` merged into op `Rm(1)`.
    #[test]
    pub fn transform_op_add_0_rm_1() {
        let merged = Rm(1).merge_incoming_change(&Add(0));
        assert_eq!(merged, Some((Some(Rm(2)), Some(Add(0)))));
    }

    /// Incoming `Add(1)` merged into op `Rm(0)`.
    #[test]
    pub fn transform_op_add_1_rm_0() {
        let merged = Rm(0).merge_incoming_change(&Add(1));
        assert_eq!(merged, Some((Some(Rm(0)), Some(Add(0)))));
    }

    /// Incoming `Rm(1)` merged into op `Add(0)`.
    #[test]
    pub fn transform_op_rm_1_add_0() {
        let merged = Add(0).merge_incoming_change(&Rm(1));
        assert_eq!(merged, Some((Some(Add(0)), Some(Rm(2)))));
    }

    /// Incoming `Rm(0)` merged into op `Add(1)`.
    #[test]
    pub fn transform_op_rm_0_add_1() {
        let merged = Add(1).merge_incoming_change(&Rm(0));
        assert_eq!(merged, Some((Some(Add(0)), Some(Rm(0)))));
    }

    /// Incoming `Move(0, 2)` merged into op `Rm(2)`.
    #[test]
    pub fn transform_op_mv_02_rm_2() {
        let merged = Rm(2).merge_incoming_change(&Move(0, 2));
        assert_eq!(merged, Some((Some(Rm(1)), Some(Move(0, 1)))));
    }
}
1249
1250#[cfg(test)]
1251mod proptests {
1252 use super::*;
1253 use proptest::prelude::*;
1254
/// Test case: a source shape plus a chain of axis operations applied to it.
#[derive(Debug)]
struct ComposeProblem {
    /// Shape of the source tensor the chain starts from.
    input: TVec<usize>,
    /// Axis operations, wired one after the other in this order.
    ops: TVec<AxisOp>,
}
1260
1261 impl Arbitrary for AxisOp {
1262 type Parameters = TVec<usize>;
1263 type Strategy = BoxedStrategy<AxisOp>;
1264 fn arbitrary_with(shape: TVec<usize>) -> Self::Strategy {
1265 let mut ops: BoxedStrategy<AxisOp> = (0usize..shape.len() + 1).prop_map(Add).boxed();
1266 if shape.len() > 1 {
1267 ops = ops
1268 .prop_union(
1269 (0..shape.len(), 0..shape.len() - 1)
1270 .prop_map(|(a, b)| Move(a, b + (b >= a) as usize))
1271 .boxed(),
1272 )
1273 .boxed()
1274 }
1275 let rms = (0..shape.len()).filter(|&ax| shape[ax] == 1).map(Rm).collect::<Vec<_>>();
1276 if rms.len() > 0 {
1277 ops = ops
1278 .prop_union((0..rms.len()).prop_map(move |rm| rms[rm].clone()).boxed())
1279 .boxed()
1280 }
1281 let mergeable: Vec<AxisOp> = shape
1282 .windows(2)
1283 .enumerate()
1284 .filter(|(_, w)| w[0] > 1 && w[1] > 1)
1285 .map(|(ix, w)| {
1286 Reshape(ix, tvec!(w[0].to_dim(), w[1].to_dim()), tvec!((w[0] * w[1]).to_dim()))
1287 })
1288 .collect();
1289 if mergeable.len() > 1 {
1290 ops = ops
1291 .prop_union(
1292 (0..mergeable.len()).prop_map(move |ix| mergeable[ix].clone()).boxed(),
1293 )
1294 .boxed()
1295 }
1296 ops
1297 }
1298 }
1299
1300 impl Arbitrary for ComposeProblem {
1301 type Parameters = ();
1302 type Strategy = BoxedStrategy<ComposeProblem>;
1303 fn arbitrary_with(_args: ()) -> Self::Strategy {
1304 let input = proptest::collection::vec(1usize..4, 1usize..4);
1305 fn tail(len: usize, shape: TVec<usize>) -> BoxedStrategy<TVec<AxisOp>> {
1306 if len == 0 {
1307 Just(tvec!()).boxed()
1308 } else {
1309 AxisOp::arbitrary_with(shape.clone())
1310 .prop_flat_map(move |op| {
1311 let mut shape = shape.clone();
1312 op.change_shape_array(&mut shape, false).unwrap();
1313 tail(len - 1, shape.clone()).prop_map(move |mut t| {
1314 t.insert(0, op.clone());
1315 t
1316 })
1317 })
1318 .boxed()
1319 }
1320 }
1321 (input, 1usize..=5)
1322 .prop_flat_map(|(input, len)| (Just(input.clone()), tail(len, input.into())))
1323 .prop_map(|(input, ops)| ComposeProblem { input: input.into(), ops })
1324 .boxed()
1325 }
1326 }
1327
1328 impl ComposeProblem {
1329 pub fn model(&self) -> TractResult<TypedModel> {
1330 let mut model = TypedModel::default();
1331 let mut wire = model.add_source("source", i64::fact(&self.input))?;
1332 for (ix, op) in self.ops.iter().enumerate() {
1333 wire = model.wire_node(format!("op_{ix}"), op.clone(), &[wire])?[0];
1334 }
1335 model.set_output_outlets(&[wire])?;
1336 Ok(model)
1337 }
1338
1339 fn input(&self) -> TractResult<Tensor> {
1340 unsafe {
1341 let mut t = Tensor::uninitialized::<i64>(&self.input)?;
1342 for i in 0..t.len() {
1343 t.as_slice_mut().unwrap()[i] = i as i64;
1344 }
1345 Ok(t)
1346 }
1347 }
1348
1349 fn check(&self) -> TractResult<()> {
1350 crate::setup_test_logger();
1351 let input = self.input()?;
1352 let model = self.model()?;
1353 let raw = model.into_runnable()?.run(tvec!(input.clone().into_tvalue()))?;
1354 let optimized = self.model()?.into_decluttered()?;
1355 let opt = optimized.into_runnable()?.run(tvec!(input.into_tvalue()))?;
1356 opt[0].close_enough(&raw[0], false)
1357 }
1358 }
1359
1360 proptest! {
1361 #[test]
1362 fn recip(pb in any::<AxisOp>()) {
1363 assert_eq!(pb.recip().recip(), pb);
1364 }
1365
1366 #[test]
1367 fn axis_ops(pb in any::<ComposeProblem>()) {
1368 pb.check().unwrap()
1369 }
1370 }
1371
/// `Add(0)` then `Rm(0)` on shape `[1]`: raw and decluttered models must agree.
#[test]
fn add_0_rm_0() {
    let problem = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Rm(0)] };
    problem.check().unwrap();
}
1377
/// `Add(0)` then `Move(0, 1)` on shape `[2]`: raw and decluttered models must agree.
#[test]
fn add_0_move_01() {
    let problem = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1)] };
    problem.check().unwrap();
}
1383
/// `Add(0)`, `Move(0, 1)`, `Add(1)` on shape `[2]`: raw and decluttered models must agree.
#[test]
fn add_0_move_01_add_1() {
    let ops = tvec![Add(0), Move(0, 1), Add(1)];
    let problem = ComposeProblem { input: tvec![2], ops };
    problem.check().unwrap();
}
1389
/// The reciprocal of the reciprocal of `Move(1, 0)` must compare equal to `Move(1, 0)`.
#[test]
fn recip_move_01() {
    assert_eq!(Move(1, 0).recip().recip(), Move(1, 0));
}
1395
/// The reciprocal of the reciprocal of `Move(2, 0)` must compare equal to `Move(2, 0)`.
#[test]
fn recip_move_20() {
    assert_eq!(Move(2, 0).recip().recip(), Move(2, 0));
}
1401
/// The reciprocal of the reciprocal of `Move(0, 2)` must compare equal to `Move(0, 2)`.
#[test]
fn recip_move_02() {
    assert_eq!(Move(0, 2).recip().recip(), Move(0, 2));
}
1407
/// `Add(0)`, `Add(1)`, `Move(0, 2)` on shape `[2]`: raw and decluttered models must agree.
#[test]
fn add_0_add_1_move_02() {
    let ops = tvec![Add(0), Add(1), Move(0, 2)];
    let problem = ComposeProblem { input: tvec![2], ops };
    problem.check().unwrap();
}
1413
/// `Add(0)` twice on shape `[1]`: raw and decluttered models must agree.
#[test]
fn add_0_add_0() {
    let problem = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0)] };
    problem.check().unwrap();
}
1419
/// `Add(0)`, `Add(0)`, `Move(0, 2)` on shape `[2]`: raw and decluttered models must agree.
#[test]
fn add_0_add_0_move_02() {
    let ops = tvec![Add(0), Add(0), Move(0, 2)];
    let problem = ComposeProblem { input: tvec![2], ops };
    problem.check().unwrap();
}
1425
/// `Add(0)`, `Add(2)`, `Move(1, 2)` on shape `[2]`: raw and decluttered models must agree.
#[test]
fn add_0_add_2_move_12() {
    let ops = tvec![Add(0), Add(2), Move(1, 2)];
    let problem = ComposeProblem { input: tvec![2], ops };
    problem.check().unwrap();
}
1431
/// `Add(0)`, `Add(0)`, `Move(0, 2)`, `Rm(0)` on shape `[1]`: raw and decluttered
/// models must agree.
#[test]
fn add_0_add_0_move_02_rm_0() {
    let ops = tvec![Add(0), Add(0), Move(0, 2), Rm(0)];
    let problem = ComposeProblem { input: tvec![1], ops };
    problem.check().unwrap();
}
1437
/// `Add(0)`, `Add(0)`, then `Move(2, 0)` twice on shape `[2]`: raw and decluttered
/// models must agree.
#[test]
fn add_0_add_0_move_20_move_20() {
    let ops = tvec![Add(0), Add(0), Move(2, 0), Move(2, 0)];
    let problem = ComposeProblem { input: tvec![2], ops };
    problem.check().unwrap();
}
1444
/// `Move(0, 1)` then `Add(0)` on shape `[1, 1]`: raw and decluttered models must agree.
#[test]
fn move_01_add_0() {
    let problem = ComposeProblem { input: tvec![1, 1], ops: tvec![Move(0, 1), Add(0)] };
    problem.check().unwrap();
}
1450
/// `Add(0)` then `Move(0, 2)` twice on shape `[1, 1]`: raw and decluttered models
/// must agree.
#[test]
fn add_0_move_02_move_02() {
    let ops = tvec![Add(0), Move(0, 2), Move(0, 2)];
    let problem = ComposeProblem { input: tvec![1, 1], ops };
    problem.check().unwrap();
}
1456
/// `Add(0)`, `Add(2)`, `Move(2, 0)`, `Move(1, 2)`, `Rm(2)` on shape `[3]`: raw and
/// decluttered models must agree.
#[test]
fn add_0_add_2_move_20_move_12_rm_2() {
    let ops = tvec![Add(0), Add(2), Move(2, 0), Move(1, 2), Rm(2)];
    let problem = ComposeProblem { input: tvec![3], ops };
    problem.check().unwrap();
}
1465
/// `Move(0, 2)` twice on shape `[2, 1, 1]`: raw and decluttered models must agree.
#[test]
fn move_02_move_02() {
    let ops = tvec![Move(0, 2), Move(0, 2)];
    let problem = ComposeProblem { input: tvec![2, 1, 1], ops };
    problem.check().unwrap();
}
1471
/// `Rm(1)`, `Move(0, 1)`, `Add(0)` on shape `[1, 1, 2]`: raw and decluttered models
/// must agree.
#[test]
fn rm_1_perm_10_add_0() {
    let ops = tvec![Rm(1), Move(0, 1), Add(0)];
    let problem = ComposeProblem { input: tvec![1, 1, 2], ops };
    problem.check().unwrap();
}
1477
/// `Add(2)` then `Move(0, 2)` twice on shape `[3, 2]`: raw and decluttered models
/// must agree.
#[test]
fn add_2_move_02_move_02() {
    let ops = tvec![Add(2), Move(0, 2), Move(0, 2)];
    let problem = ComposeProblem { input: tvec![3, 2], ops };
    problem.check().unwrap();
}
1483
/// `Move(0, 1)` then `Move(2, 0)` twice on shape `[2, 3, 2]`: raw and decluttered
/// models must agree.
#[test]
fn move_01_move_20_move_20() {
    let ops = tvec![Move(0, 1), Move(2, 0), Move(2, 0)];
    let problem = ComposeProblem { input: tvec![2, 3, 2], ops };
    problem.check().unwrap();
}
1492
/// A single `Reshape` merging the first two axes of a `[2, 2, 2]` input: raw and
/// decluttered models must agree.
#[test]
fn reshape_axes_tracking() {
    let merge_first_two = Reshape(0, tvec!(2.to_dim(), 2.to_dim()), tvec!(4.to_dim()));
    let problem = ComposeProblem { input: tvec![2, 2, 2], ops: tvec![merge_first_two] };
    problem.check().unwrap();
}
1501
/// Pins the sequence of ops returned by `simplify()` for a series of `Reshape` ops.
#[test]
fn simplify_reshape() {
    // Builds a `tvec!` of dims out of integer literals.
    macro_rules! dims {
        ($($d: expr),*) => { tvec!($($d.to_dim()),*) }
    }
    // Asserts that simplifying the op on the left yields the ops on the right.
    macro_rules! simplifies {
        ($op: expr => $expected: expr) => {
            assert_eq!($op.simplify(), $expected)
        };
    }

    // Nothing to do: empty or identical from/to.
    simplifies!(Reshape(3, dims!(), dims!()) => tvec!());
    simplifies!(Reshape(3, dims!(2, 3), dims!(2, 3)) => tvec!());

    // Pure removal / insertion of a unit axis.
    simplifies!(Reshape(3, dims!(1), dims!()) => tvec!(Rm(3)));
    simplifies!(Reshape(3, dims!(), dims!(1)) => tvec!(Add(3)));

    // Common leading / trailing dims are trimmed off the reshape.
    simplifies!(
        Reshape(3, dims!(2, 3, 4), dims!(2, 4, 3)) => tvec!(Reshape(4, dims!(3, 4), dims!(4, 3)))
    );
    simplifies!(
        Reshape(3, dims!(3, 4, 2), dims!(4, 3, 2)) => tvec!(Reshape(3, dims!(3, 4), dims!(4, 3)))
    );

    // Unit axes at either end are split out as Rm / Add.
    simplifies!(
        Reshape(3, dims!(1, 2, 3), dims!(3, 2))
            => tvec!(Rm(3), Reshape(3, dims!(2, 3), dims!(3, 2)))
    );
    simplifies!(
        Reshape(3, dims!(2, 3), dims!(1, 3, 2))
            => tvec!(Reshape(3, dims!(2, 3), dims!(3, 2)), Add(3))
    );
    simplifies!(
        Reshape(3, dims!(2, 3, 1), dims!(3, 2))
            => tvec!(Rm(5), Reshape(3, dims!(2, 3), dims!(3, 2)))
    );
    simplifies!(
        Reshape(3, dims!(2, 3), dims!(3, 2, 1))
            => tvec!(Add(5), Reshape(3, dims!(2, 3), dims!(3, 2)))
    );
    simplifies!(
        Reshape(2, dims!(2, 2, 1), dims!(4)) => tvec!(Rm(4), Reshape(2, dims!(2, 2), dims!(4)))
    );
    simplifies!(Reshape(1, dims!(1, 2), dims!(2)) => tvec!(Rm(1)));
}
1541
// Builds a borrowed array from a list of expressions, cloning each one and
// converting it with `.into()` to whatever element type the call site expects.
macro_rules! s {
    ($($dim:expr),*) => {
        &[$($dim.clone().into()),*]
    };
}
1545
// Shorthand for `AxisOp::Reshape`: `r!(axis ; from dims => to dims)`, each dim
// being converted with `.into()`.
macro_rules! r {
    ($axis: expr ; $($src:expr),* => $($dst:expr),*) => {
        AxisOp::Reshape($axis, tvec!($($src.into()),*), tvec!($($dst.into()),*))
    };
}
1551
/// A spec of `[100]` for a `[3, 4, 5]` input is expected to be rejected.
#[test]
fn compute_invalid() {
    let outcome = compute_shape_with_tf_rules(s![3, 4, 5], s!(100));
    assert!(outcome.is_err());
}
1556
/// Leading zeros in the spec are expected to resolve to the matching input dims.
#[test]
fn compute_with_leading_zero() {
    let shape = compute_shape_with_tf_rules(s![3, 4, 5], s!(0, 0, 5)).unwrap();
    assert_eq!(&*shape, s![3, 4, 5]);
}
1561
/// A zero in the spec combined with a merge of the two trailing axes (5 * 7 = 35).
#[test]
fn compute_with_leading_zero_with_flatten() {
    let shape = compute_shape_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap();
    assert_eq!(&*shape, s![2, 3, 35]);
}
1569
/// A `-1` followed by a trailing zero in the spec is expected to give back `[3, 4, 5]`.
#[test]
fn compute_with_trailing_zero() {
    let shape = compute_shape_with_tf_rules(s![3, 4, 5], s!(3, -1, 0)).unwrap();
    assert_eq!(&*shape, s![3, 4, 5]);
}
1574
/// Regression case with one symbolic dim: `[S, 1, 2, 128]` with spec `[0, 0, -1]`
/// is expected to give `[S, 1, 256]`.
#[test]
fn compute_bug_1() {
    let scope = SymbolScope::default();
    let sym = scope.new_with_prefix("S");
    let shape = compute_shape_with_tf_rules(s![sym, 1, 2, 128], s!(0, 0, -1)).unwrap();
    assert_eq!(&*shape, s![sym, 1, 256]);
}
1584
/// Regression case with two symbolic dims: `[S, B, 2, 128]` with spec `[0, 0, -1]`
/// is expected to give `[S, B, 256]`.
#[test]
fn compute_bug_2() {
    let scope = SymbolScope::default();
    // Symbols are created in the same order as in the original test: B first, then S.
    let batch = scope.new_with_prefix("B");
    let seq = scope.new_with_prefix("S");
    let shape = compute_shape_with_tf_rules(s![seq, batch, 2, 128], s!(0, 0, -1)).unwrap();
    assert_eq!(&*shape, s![seq, batch, 256]);
}
1595
/// Dropping a leading unit axis is expected to translate to a single `Rm(0)`.
#[test]
fn axis_op_rm_begin() {
    let ops = to_axis_ops_with_tf_rules(s![1, 2, 3], s!(2, 3)).unwrap();
    assert_eq!(&*ops, &[Rm(0)]);
}
1600
/// Dropping a trailing unit axis is expected to translate to a single `Rm(2)`.
#[test]
fn axis_op_rm_end() {
    let ops = to_axis_ops_with_tf_rules(s![2, 3, 1], s!(2, 3)).unwrap();
    assert_eq!(&*ops, &[Rm(2)]);
}
1605
/// Prepending a unit axis is expected to translate to a single `Add(0)`.
#[test]
fn axis_op_insert_begin() {
    let ops = to_axis_ops_with_tf_rules(s![2, 3], s!(1, 2, 3)).unwrap();
    assert_eq!(&*ops, &[Add(0)]);
}
1610
/// Appending a unit axis is expected to translate to a single `Add(2)`.
#[test]
fn axis_op_insert_end() {
    let ops = to_axis_ops_with_tf_rules(s![2, 3], s!(2, 3, 1)).unwrap();
    assert_eq!(&*ops, &[Add(2)]);
}
1615
/// Merging the two trailing axes is expected to translate to a single `Reshape` at axis 2.
#[test]
fn axis_op_merge() {
    let ops = to_axis_ops_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap();
    assert_eq!(&*ops, &[r!(2 ; 5,7 => 35 )]);
}
1623
/// A reshape mixing unit-axis removal, unit-axis insertion and a merge is expected
/// to translate to `Rm(0)`, `Add(1)`, a `Reshape` at axis 3, then `Add(4)`.
#[test]
fn axis_op_complex() {
    let ops = to_axis_ops_with_tf_rules(s![1, 2, 3, 5, 7], s!(2, 1, 3, 35, 1)).unwrap();
    assert_eq!(&*ops, &[Rm(0), Add(1), r!(3 ; 5,7 => 35 ), Add(4)]);
}
1631}