1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::model::{TypedModel, TypedNode};
6use crate::ops::identity::Identity;
7use AxisOp::*;
8use num_traits::One;
9use tract_itertools::Itertools;
10use tract_linalg::block_quant::{BlockQuantFact, BlockQuantStorage};
11use tract_ndarray::{ArrayViewD, ArrayViewMutD};
12
/// Designates one slot of a node: either an input port or an output port,
/// identified by its index.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InOut {
    /// Output slot, by index.
    Out(usize),
    /// Input slot, by index.
    In(usize),
}
18
19impl InOut {
20 pub fn as_outlet<F: Clone + Fact, O: Clone>(&self, node: &Node<F, O>) -> OutletId {
21 match self {
22 InOut::In(ix) => node.inputs[*ix],
23 InOut::Out(ix) => OutletId::new(node.id, *ix),
24 }
25 }
26
27 pub fn is_input(&self) -> bool {
28 matches!(self, InOut::In(_))
29 }
30
31 pub fn is_output(&self) -> bool {
32 matches!(self, InOut::Out(_))
33 }
34
35 pub fn slot(&self) -> usize {
36 match self {
37 InOut::Out(o) => *o,
38 InOut::In(i) => *i,
39 }
40 }
41}
42
/// An elementary axis manipulation applied to a tensor or a shape.
#[derive(Clone, Hash, Eq)]
#[allow(clippy::large_enum_variant)]
#[allow(clippy::derived_hash_with_manual_eq)]
pub enum AxisOp {
    /// Insert a size-1 axis at the given position.
    Add(usize),
    /// Remove the axis at the given position (it must be of size 1).
    Rm(usize),
    /// Move the axis at the first position to the second position.
    Move(usize, usize),
    /// Starting at the given position, replace the `from` dims by the `to`
    /// dims (same total volume).
    Reshape(usize, TVec<TDim>, TVec<TDim>),
}
52
// Hand-written Debug: renders dims as comma-joined lists rather than the
// derived `TVec` representation.
impl Debug for AxisOp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AxisOp::Add(a) => write!(f, "Add({a})"),
            AxisOp::Rm(a) => write!(f, "Rm({a})"),
            AxisOp::Move(from, to) => write!(f, "Move({from},{to})"),
            AxisOp::Reshape(at, from, to) => {
                write!(f, "Reshape({at}, [{}], [{}])", from.iter().join(","), to.iter().join(","))
            }
        }
    }
}
65
// Manual PartialEq: ops are compared up to semantic equivalence, not
// structurally. All no-ops are equal to each other, and the two spellings of
// an adjacent transposition (`Move(f, f+1)` vs `Move(f+1, f)`) are equal.
// Note: Hash is derived structurally, hence the clippy allow on the enum.
impl PartialEq for AxisOp {
    fn eq(&self, other: &AxisOp) -> bool {
        if self.is_noop() && other.is_noop() {
            true
        } else if self.is_noop() != other.is_noop() {
            false
        } else {
            match (self, other) {
                (Add(a), Add(b)) | (Rm(a), Rm(b)) => a == b,
                (Move(f1, t1), Move(f2, t2)) => {
                    // identical, or mirrored adjacent swap (same permutation)
                    (f1 == f2 && t1 == t2)
                        || ((*t1 == f1 + 1 || *f1 == t1 + 1) && t2 == f1 && t1 == f2)
                }
                (Reshape(at1, f1, t1), Reshape(at2, f2, t2)) => at1 == at2 && f1 == f2 && t1 == t2,
                _ => false,
            }
        }
    }
}
85
86impl AxisOp {
    /// Normalize equivalent forms to a single representative: a descending
    /// adjacent swap becomes its ascending mirror, and reshapes that merely
    /// insert or drop a unit axis become `Add`/`Rm`.
    ///
    /// Returns `Cow::Borrowed(self)` when the op is already canonical.
    pub fn canonical(&self) -> Cow<'_, AxisOp> {
        match self {
            // Move(t+1, t) is the same transposition as Move(t, t+1).
            Move(from, to) if *from == to + 1 => Cow::Owned(Move(*to, *from)),
            // [d] -> [d, 1]: inserting a trailing unit axis.
            Reshape(at, from, to)
                if from.len() == 1 && to.len() == 2 && from[0] == to[0] && to[1].is_one() =>
            {
                Cow::Owned(Add(*at + 1))
            }
            // [d] -> [1, d]: inserting a unit axis at `at`.
            Reshape(at, from, to)
                if from.len() == 1 && to.len() == 2 && from[0] == to[1] && to[0].is_one() =>
            {
                Cow::Owned(Add(*at))
            }
            // [d, 1] -> [d]: removing the trailing unit axis.
            Reshape(at, from, to)
                if from.len() == 2 && to.len() == 1 && from[0] == to[0] && from[1].is_one() =>
            {
                Cow::Owned(Rm(*at + 1))
            }
            // [1, d] -> [d]: removing the unit axis at `at`.
            Reshape(at, from, to)
                if from.len() == 2 && to.len() == 1 && from[1] == to[0] && from[0].is_one() =>
            {
                Cow::Owned(Rm(*at))
            }
            other => Cow::Borrowed(other),
        }
    }
113
    /// Decompose the (canonicalized) op into a minimal sequence of elementary
    /// ops, peeling trivial or shared dims off both ends of a `Reshape`.
    /// The returned ops are meant to be applied in order.
    pub fn simplify(&self) -> TVec<AxisOp> {
        match self.canonical().borrow() {
            // identity reshape: nothing to do
            Reshape(_, from, to) if from == to => tvec!(),
            // all source dims vanish: they must be removable one by one
            Reshape(at, from, to) if to.len() == 0 => tvec!(Rm(*at); from.len()),
            // no source dims: pure axis insertions
            Reshape(at, from, to) if from.len() == 0 => tvec!(Add(*at); to.len()),
            // identical leading dim: recurse past it
            Reshape(at, from, to) if from[0] == to[0] => {
                Reshape(at + 1, from[1..].into(), to[1..].into()).simplify()
            }
            // identical trailing dim: drop it and recurse
            Reshape(at, from, to) if from[from.len() - 1] == to[to.len() - 1] => {
                Reshape(*at, from[..from.len() - 1].into(), to[..to.len() - 1].into()).simplify()
            }
            // unit leading source dim: remove it first, then reshape the rest
            Reshape(at, from, to) if from[0] == 1.to_dim() => std::iter::once(Rm(*at))
                .chain(Reshape(*at, from[1..].into(), to.clone()).simplify())
                .collect(),
            // unit leading target dim: reshape first, then add the unit axis
            Reshape(at, from, to) if to[0] == 1.to_dim() => {
                Reshape(*at, from.clone(), to[1..].into())
                    .simplify()
                    .into_iter()
                    .chain(std::iter::once(Add(*at)))
                    .collect()
            }
            // unit trailing source dim: remove it first
            Reshape(at, from, to) if from[from.len() - 1] == 1.to_dim() => {
                std::iter::once(Rm(at + from.len() - 1))
                    .chain(Reshape(*at, from[..from.len() - 1].into(), to.clone()).simplify())
                    .collect()
            }
            // unit trailing target dim: add it first, then reshape before it
            Reshape(at, from, to) if to[to.len() - 1] == 1.to_dim() => {
                std::iter::once(Add(at + from.len()))
                    .chain(Reshape(*at, from.clone(), to[..to.len() - 1].into()).simplify())
                    .collect()
            }
            // irreducible: keep as-is
            other => tvec!(other.clone()),
        }
    }
148
    /// Track where the axis at position `axis` ends up after applying this op.
    /// Returns `None` when the axis disappears (it is removed, or absorbed by
    /// a non-trivial reshape).
    pub fn transform_axis(&self, axis: usize) -> Option<usize> {
        match self.canonical().as_ref() {
            // axes at or after the insertion point shift right by one
            Add(ix) => Some(axis + (axis >= *ix) as usize),
            Rm(ix) => {
                if axis == *ix {
                    None
                } else {
                    // axes after the removed one shift left by one
                    Some(axis - (axis > *ix) as usize)
                }
            }
            // forward move: axes inside (from, to] slide left by one
            Move(from, to) if from < to => {
                if axis < *from || axis > *to {
                    Some(axis)
                } else if axis == *from {
                    Some(*to)
                } else {
                    Some(axis - 1)
                }
            }
            // backward move: axes inside [to, from) slide right by one
            Move(from, to) => {
                if axis < *to || axis > *from {
                    Some(axis)
                } else if axis == *from {
                    Some(*to)
                } else {
                    Some(axis + 1)
                }
            }
            // axes before the reshaped span are untouched
            Reshape(at, _, _) if axis < *at => Some(axis),
            // axes after the span shift by the rank delta
            Reshape(at, from, to) if axis >= at + from.len() => Some(axis + to.len() - from.len()),
            // axes inside the reshaped span lose their identity
            Reshape(_, _, _) => None,
        }
    }
182
    /// Commute an incoming axis `change` (applied upstream of this op) with
    /// this op.
    ///
    /// Returns `Some((new_self, outgoing_change))` when the two can be
    /// reconciled: `new_self` is what this op becomes (or `None` if it
    /// vanishes), `outgoing_change` is the change to propagate downstream
    /// (or `None` if it is absorbed). Returns `None` when the two ops
    /// interfere and cannot be merged.
    pub fn merge_incoming_change(
        &self,
        change: &AxisOp,
    ) -> Option<(Option<AxisOp>, Option<AxisOp>)> {
        match (self.canonical().as_ref(), change.canonical().as_ref()) {
            // Add/Rm pairs: shift each index by one depending on relative order.
            (Add(op), Add(c)) => {
                Some((Some(Add(op + (c < op) as usize)), Some(Add(c + (c >= op) as usize))))
            }
            (Add(op), Rm(c)) => {
                Some((Some(Add(op - (c < op) as usize)), Some(Rm(c + (c >= op) as usize))))
            }
            (Rm(op), Add(c)) => {
                Some((Some(Rm(op + (c <= op) as usize)), Some(Add(c - (op < c) as usize))))
            }
            (Rm(op), Rm(c)) => {
                Some((Some(Rm(op - (c < op) as usize)), Some(Rm(c - (op <= c) as usize))))
            }

            // A Move only commutes with an Add/Rm that falls entirely outside
            // the [from, to] span.
            (Add(x), Move(from, to)) => {
                if x <= from.min(to) {
                    Some((Some(self.clone()), Some(Move(from + 1, to + 1))))
                } else if x > from.max(to) {
                    Some((Some(self.clone()), Some(change.clone())))
                } else {
                    None
                }
            }

            (Move(from, to), Add(x)) => {
                if x <= from.min(to) {
                    Some((Some(Move(from + 1, to + 1)), Some(Add(*x))))
                } else if x > from.max(to) {
                    Some((Some(Move(*from, *to)), Some(Add(*x))))
                } else {
                    None
                }
            }

            (Rm(x), Move(from, to)) => {
                if x == from {
                    // removing the moved axis: the move degenerates away
                    Some((Some(Rm(*to)), None))
                } else if x < from.min(to) {
                    Some((Some(self.clone()), Some(Move(from - 1, to - 1))))
                } else if x > from.max(to) {
                    Some((Some(self.clone()), Some(change.clone())))
                } else if from + 1 == *to && x == to {
                    // removing the other side of an adjacent swap
                    Some((Some(Rm(*from)), None))
                } else if from < to && x <= to {
                    Some((Some(Rm(x - 1)), Some(Move(*from, *to - 1))))
                } else {
                    Some((Some(Rm(x + 1)), Some(Move(*from - 1, *to))))
                }
            }

            (Move(from, to), Rm(x)) => {
                if x < from.min(to) {
                    Some((Some(Move(from - 1, to - 1)), Some(Rm(*x))))
                } else if x > from.max(to) {
                    Some((Some(Move(*from, *to)), Some(Rm(*x))))
                } else {
                    None
                }
            }

            // Add/Rm commute with a Reshape only when they fall strictly
            // outside the reshaped span; indices are shifted accordingly.
            (Add(op), Reshape(at, from, to)) => {
                if op <= at {
                    Some((Some(Add(*op)), Some(Reshape(at + 1, from.clone(), to.clone()))))
                } else if *op > at + from.len() {
                    Some((
                        Some(Add(*op + to.len() - from.len())),
                        Some(Reshape(*at, from.clone(), to.clone())),
                    ))
                } else {
                    None
                }
            }
            (Rm(op), Reshape(at, from, to)) => {
                if op < at {
                    Some((Some(Rm(*op)), Some(Reshape(at - 1, from.clone(), to.clone()))))
                } else if *op > at + from.len() {
                    Some((
                        Some(Rm(*op + to.len() - from.len())),
                        Some(Reshape(*at, from.clone(), to.clone())),
                    ))
                } else {
                    None
                }
            }
            (Reshape(at, from, to), Add(change)) => {
                if change < at {
                    Some((Some(Reshape(at + 1, from.clone(), to.clone())), Some(Add(*change))))
                } else if *change > *at + from.len() {
                    Some((
                        Some(Reshape(*at, from.clone(), to.clone())),
                        Some(Add(change + to.len() - from.len())),
                    ))
                } else {
                    None
                }
            }
            (Reshape(at, from, to), Rm(change)) => {
                if change < at {
                    Some((Some(Reshape(at - 1, from.clone(), to.clone())), Some(Rm(*change))))
                } else if *change > *at + from.len() {
                    Some((
                        Some(Reshape(*at, from.clone(), to.clone())),
                        Some(Rm(change + to.len() - from.len())),
                    ))
                } else {
                    None
                }
            }
            // Reshape/Move combinations are never merged.
            (Reshape(_, _, _), Move(_, _)) => None,
            (Move(_, _), Reshape(_, _, _)) => None,
            (Reshape(_, _, _), Reshape(_, _, _)) => None,
            _ => None,
        }
    }
305
306 pub fn change_shape_array<D: DimLike>(
307 &self,
308 shape: &mut TVec<D>,
309 broadcasting: bool,
310 ) -> TractResult<()> {
311 match self.canonical().as_ref() {
312 Add(ix) => {
313 ensure!(*ix <= shape.len());
314 shape.insert(*ix, D::one());
315 }
316 Rm(ix) => {
317 ensure!(*ix < shape.len());
318 shape.remove(*ix);
319 }
320 Move(from, to) => {
321 ensure!(*from < shape.len());
322 ensure!(*to < shape.len());
323 let axis = shape.remove(*from);
324 shape.insert(*to, axis);
325 }
326 Reshape(at, from, to) => {
327 let from_volume = from.iter().product::<TDim>();
328 let to_volume = to.iter().product::<TDim>();
329 ensure!(from_volume == to_volume, "{from_volume} should be equal to {to_volume}");
330 ensure!(*at + from.len() <= shape.len());
331 if shape.len() >= from.len() + *at
332 && tract_itertools::izip!(shape.iter().skip(*at), from)
333 .all(|(shape, spec)| shape.to_dim() == *spec)
334 {
335 for _ in from {
336 shape.remove(*at);
337 }
338 for d in to.iter().rev() {
339 shape.insert(*at, d.try_into()?);
340 }
341 } else if broadcasting
342 && shape.iter().skip(*at).take(from.len()).all(|d| d.to_dim() == 1.to_dim())
343 {
344 for _ in from {
345 shape.remove(*at);
346 }
347 for _ in to.iter().rev() {
348 shape.insert(*at, 1.into());
349 }
350 } else {
351 bail!("Incompatible reshape for shape {:?} and {:?}", shape, self);
352 }
353 }
354 }
355 Ok(())
356 }
357
    /// Apply this op to a `ShapeFact` in place. `Add`/`Rm` use the fact's own
    /// axis editing (with a sanity check that only trivial axes are removed);
    /// other ops go through `change_shape_array` on a dim vector.
    pub fn change_shape(&self, shape: &mut ShapeFact, broadcasting: bool) -> TractResult<()> {
        match self.canonical().as_ref() {
            Add(ix) => shape.insert_axis(*ix),
            Rm(ix) => {
                if shape.rank() <= *ix {
                    bail!("Attempt to remove axis #{} on shape {:?}", ix, shape);
                }
                if shape[*ix] != 1.to_dim() {
                    bail!("Removing non-trivial axis #{} of dim: {:?}", ix, shape);
                }
                shape.remove_axis(*ix)
            }
            _ => {
                // Move/Reshape: rebuild the fact from the transformed dim array.
                let mut array = shape.to_tvec();
                self.change_shape_array(&mut array, broadcasting)?;
                let mut new_shape = ShapeFact::from_dims(array);
                std::mem::swap(shape, &mut new_shape);
                Ok(())
            }
        }
    }
379
    /// Apply this op to a concrete tensor in place.
    ///
    /// Block-quantized tensors are handled first: only their shape metadata is
    /// rewritten, the quantized payload is untouched.
    pub fn change_tensor(&self, tensor: &mut Tensor, broadcasting: bool) -> TractResult<()> {
        if tensor.storage_as::<BlockQuantStorage>().is_some() {
            let bqs = tensor.try_storage_as::<BlockQuantStorage>()?.clone();
            let mut new_shape: TVec<usize> = tensor.shape().into();
            self.change_shape_array(&mut new_shape, false)?;
            let mut new_tensor = bqs.into_tensor_with_shape(tensor.datum_type(), &new_shape);
            std::mem::swap(tensor, &mut new_tensor);
            return Ok(());
        }
        ensure!(self.required_rank() <= tensor.rank());
        match self.canonical().as_ref() {
            Add(ix) => tensor.insert_axis(*ix),
            Rm(ix) => tensor.remove_axis(*ix),
            Move(from, to) => {
                // move_axis works on an owned tensor, so clone then swap back
                let mut tmp = tensor.clone().move_axis(*from, *to)?;
                std::mem::swap(tensor, &mut tmp);
                Ok(())
            }
            Reshape(at, from, to) => {
                let mut shape: TVec<usize> = tensor.shape().into();
                self.change_shape_array(&mut shape, true)?;
                if tensor.set_shape(&shape).is_ok() {
                    Ok(())
                } else if broadcasting
                    && tensor.shape().iter().skip(*at).take(from.len()).all(|d| *d == 1)
                {
                    // broadcasting stub: adjust the rank with unit axes only
                    if from.len() > to.len() {
                        for _ in to.len()..from.len() {
                            tensor.remove_axis(*at)?;
                        }
                    }
                    if to.len() > from.len() {
                        for _ in from.len()..to.len() {
                            tensor.insert_axis(*at)?;
                        }
                    }
                    Ok(())
                } else {
                    bail!(
                        "Invalid reshaping: {:?} on tensor {:?} (broadcasting allowed: {:?})",
                        self,
                        tensor,
                        broadcasting
                    )
                }
            }
        }
    }
428
    /// Apply this op to an ndarray view in place. `Rm` selects index 0 on the
    /// removed axis; `Move` is realized as a chain of adjacent axis swaps.
    /// `Reshape` is rejected: a view's layout cannot be reshaped in place.
    pub fn change_view<D>(&self, view: &mut ArrayViewD<D>) -> TractResult<()> {
        use tract_ndarray::Axis;
        match *self {
            AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
            AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
            AxisOp::Move(from, to) if from < to => {
                // forward move: bubble the axis rightwards
                for left in from..to {
                    view.swap_axes(left, left + 1);
                }
            }
            AxisOp::Move(from, to) => {
                // backward move: bubble the axis leftwards
                for left in (to..from).rev() {
                    view.swap_axes(left, left + 1);
                }
            }
            AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
        }
        Ok(())
    }
448
449 pub fn change_view_mut<D>(&self, view: &mut ArrayViewMutD<D>) -> TractResult<()> {
450 use tract_ndarray::Axis;
451 match *self {
452 AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
453 AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
454 AxisOp::Move(from, to) if from < to => {
455 for left in from..to {
456 view.swap_axes(left, left + 1);
457 }
458 }
459 AxisOp::Move(from, to) => {
460 for left in (to..from).rev() {
461 view.swap_axes(left, left + 1);
462 }
463 }
464 AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
465 }
466 Ok(())
467 }
468
    /// The inverse operation: applying `self` then `self.recip()` restores the
    /// original axes.
    pub fn recip(&self) -> AxisOp {
        match self.canonical().as_ref() {
            Add(ix) => Rm(*ix),
            Rm(ix) => Add(*ix),
            // a no-op move is its own inverse
            Move(from, to) if from == to => self.clone(),
            // an adjacent swap is its own inverse
            Move(from, to) if *from + 1 == *to => self.clone(),
            // canonical() rewrites descending adjacent swaps, so this form
            // cannot reach us here
            Move(from, to) if *from == *to + 1 => {
                unreachable!();
            }
            Move(from, to) => Move(*to, *from),
            Reshape(at, from, to) => Reshape(*at, to.clone(), from.clone()),
        }
    }
482
483 pub fn is_noop(&self) -> bool {
484 match self {
485 Move(f, t) if f == t => true,
486 Reshape(_, f, t) if f == t => true,
487 _ => false,
488 }
489 }
490
491 pub fn only_shape(&self) -> bool {
492 if self.is_noop() {
493 return true;
494 }
495 !matches!(self, Move(_, _))
496 }
497
    /// Wire a `Reshape` splitting `axis` into `(outer_dim, dim / outer_dim)`.
    /// The symbolic division must be exact for the resulting model to be valid.
    pub fn wire_split_axis(
        model: &mut TypedModel,
        name: impl ToString,
        outlet: OutletId,
        axis: usize,
        outer_dim: usize,
    ) -> TractResult<TVec<OutletId>> {
        let fact = model.outlet_fact(outlet)?;
        let dim: TDim = fact.shape[axis].clone();
        let inner_dim = dim.clone() / outer_dim;
        let op = Self::Reshape(axis, tvec!(dim.clone()), tvec!(outer_dim.to_dim(), inner_dim));
        model.wire_node(name.to_string(), op, &[outlet])
    }
511
    /// Wire a `Reshape` merging `axis` and `axis + 1` into a single axis of
    /// dim `shape[axis] * shape[axis + 1]`.
    pub fn wire_collapse_axis(
        model: &mut TypedModel,
        name: impl ToString,
        outlet: OutletId,
        axis: usize,
    ) -> TractResult<TVec<OutletId>> {
        let fact = model.outlet_fact(outlet)?;
        let dim: TDim = fact.shape[axis].clone();
        let next_dim: TDim = fact.shape[axis + 1].clone();
        let op = Self::Reshape(axis, tvec!(dim.clone(), next_dim.clone()), tvec!(dim * next_dim));
        model.wire_node(name.to_string(), op, &[outlet])
    }
524
    /// Minimal rank a tensor must have for this op to be applicable.
    // NOTE(review): for Move this returns max(from, to) while move_axis needs
    // both indices strictly below the rank (max + 1) — confirm intended.
    #[inline]
    pub fn required_rank(&self) -> usize {
        match self {
            Rm(r) => r + 1,
            Add(a) => *a,
            Reshape(at, from, _to) => at + from.len(),
            Move(from, to) => *from.max(to),
        }
    }
534
535 pub fn trim_left(&self, prefix: usize) -> TractResult<AxisOp> {
536 Ok(match self {
537 Rm(r) if *r >= prefix => Rm(r - prefix),
538 Add(a) if *a >= prefix => Add(a - prefix),
539 Reshape(at, from, to) if *at >= prefix => {
540 Reshape(at - prefix, from.clone(), to.clone())
541 }
542 Move(from, to) if *from >= prefix && *to >= prefix => Move(from - prefix, to - prefix),
543 _ => bail!("Can no trim left {self:?} by {prefix}"),
544 })
545 }
546}
547
/// Align all `inputs` to the same rank (numpy-style) by prepending unit axes
/// (`Add(0)`) to the lower-rank ones. Returns the adjusted outlets.
pub fn wire_rank_broadcast(
    prefix: impl AsRef<str>,
    target: &mut TypedModel,
    inputs: &[OutletId],
) -> TractResult<TVec<OutletId>> {
    let facts =
        inputs.iter().map(|o| target.outlet_fact(*o).cloned()).collect::<TractResult<TVec<_>>>()?;
    let max_rank = facts.iter().map(|f| f.rank()).max().unwrap();
    let mut wires = tvec!();
    for i in 0..inputs.len() {
        let mut wire = inputs[i];
        // add one leading unit axis per missing rank
        for _ in facts[i].rank()..max_rank {
            let name = target.unique_name(prefix.as_ref().to_string() + ".fix-rank");
            wire = target.wire_node(name, AxisOp::Add(0), &[wire])?[0];
        }
        wires.push(wire);
    }
    Ok(wires)
}
567
/// Convenience wrapper: rank-broadcast `inputs` (see `wire_rank_broadcast`),
/// then wire `op` over the aligned wires.
pub fn wire_with_rank_broadcast(
    prefix: impl AsRef<str>,
    target: &mut TypedModel,
    op: impl Into<Box<dyn TypedOp>>,
    inputs: &[OutletId],
) -> TractResult<TVec<OutletId>> {
    let prefix = prefix.as_ref();
    let wires = wire_rank_broadcast(prefix, target, inputs)?;
    target.wire_node(prefix, op.into(), &wires)
}
578
/// An axis operation requested at a specific outlet of the model.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AxisChange {
    // the outlet the change applies to
    pub outlet: OutletId,
    // the axis manipulation to perform there
    pub op: AxisOp,
}
584
/// What a node must do to accommodate an axis change: optionally replace its
/// op, and propagate per-wire axis ops to its neighbours.
#[derive(Clone, Default, Debug)]
pub struct AxisChangeConsequence {
    // replacement op, if the node itself needs rewriting
    pub substitute_op: Option<Box<dyn TypedOp>>,
    // axis ops to apply on each affected input/output wire
    pub wire_changes: TVec<(InOut, AxisOp)>,
}
590
591impl AxisChangeConsequence {
592 pub fn new(
593 _model: &TypedModel,
594 node: &TypedNode,
595 op: Option<Box<dyn TypedOp>>,
596 axis_op: &AxisOp,
597 ) -> AxisChangeConsequence {
598 let mut wire_changes = tvec!();
599 for i in 0..node.inputs.len() {
600 wire_changes.push((InOut::In(i), axis_op.clone()));
601 }
602 for i in 0..node.outputs.len() {
603 wire_changes.push((InOut::Out(i), axis_op.clone()));
604 }
605 AxisChangeConsequence { wire_changes, substitute_op: op }
606 }
607}
608
impl Op for AxisOp {
    // Display name reflects the variant, not the enum.
    fn name(&self) -> StaticName {
        match self {
            Add(_) => "AddAxis".into(),
            Rm(_) => "RmAxis".into(),
            Move(_, _) => "MoveAxis".into(),
            Reshape(_, _, _) => "Reshape".into(),
        }
    }

    // One-line human-readable summary for model dumps.
    fn info(&self) -> TractResult<Vec<String>> {
        match self {
            Add(axis) | Rm(axis) => Ok(vec![format!("Axis: {axis}")]),
            Move(from, to) => Ok(vec![format!("Axis {from} to {to}")]),
            Reshape(at, from, to) => Ok(vec![format!(
                "Axes starting at {}: {:?} to {:?}",
                at,
                from.iter().join(","),
                to.iter().join(",")
            )]),
        }
    }

    op_as_typed_op!();
}
634
impl EvalOp for AxisOp {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval_with_session(
        &self,
        _node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let mut input = args_1!(inputs).into_tensor();
        match self {
            AxisOp::Reshape(skip, from, to) => {
                // resolve symbolic dims against the session's symbol values
                // before applying the reshape to the concrete tensor
                let from = from.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
                let to = to.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
                AxisOp::Reshape(*skip, from, to).change_tensor(&mut input, false)?
            }
            _ => self.change_tensor(&mut input, false)?,
        }
        Ok(tvec!(input.into_tvalue()))
    }
}
658
/// Rewrite per-axis coordinate symbols (named "🎯<axis>") inside `expr` to
/// follow the axis renumbering induced by `axis_op`.
///
/// Returns `None` when the expression cannot be carried across the op (a
/// non-trivial reshape touching real axes).
fn remap_uniform_tdim(expr: &TDim, axis_op: &AxisOp) -> Option<TDim> {
    let syms = expr.symbols();
    // collect (axis index, symbol) pairs for every "🎯<k>" coordinate symbol
    let coord_syms: Vec<(usize, Symbol)> = syms
        .into_iter()
        .filter_map(|s| {
            let name = format!("{s}");
            name.strip_prefix("🎯").and_then(|rest| rest.parse::<usize>().ok()).map(|k| (k, s))
        })
        .collect();

    // no coordinate symbols: the expression is axis-independent
    if coord_syms.is_empty() {
        return Some(expr.clone());
    }

    // a reshape scrambles axis identity unless it only juggles unit axes
    if let AxisOp::Reshape(_, from_dims, to_dims) = axis_op.canonical().as_ref() {
        return if from_dims.iter().all(|d| d.is_one()) && to_dims.iter().all(|d| d.is_one()) {
            Some(expr.clone())
        } else {
            None
        };
    }

    // substitute each moved coordinate symbol by the symbol of its new axis
    let map: HashMap<Symbol, TDim> = coord_syms
        .into_iter()
        .filter_map(|(k, sym)| {
            let new_k = axis_op.transform_axis(k)?;
            if new_k == k {
                return None;
            }
            let scope = sym.scope()?;
            Some((sym, TDim::Sym(scope.coord_sym(new_k))))
        })
        .collect();
    if map.is_empty() {
        return Some(expr.clone());
    }
    expr.substitute_all(&map).ok()
}
703
704impl TypedOp for AxisOp {
705 as_op!();
706
    // Compute the output fact: special-cased for block-quantized inputs
    // (shape metadata rewritten, konst transformed eagerly), otherwise derive
    // the shape and carry the uniform_tdim annotation across the axis change.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        if let Some(bqf) =
            inputs[0].exotic_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
        {
            let mut new_shape: TVec<usize> = bqf.shape().into();
            self.change_shape_array(&mut new_shape, false)?;
            let new_bqf = BlockQuantFact::new(bqf.format.clone(), new_shape.clone());
            let shape: TVec<TDim> = new_shape.iter().map(|d| d.to_dim()).collect();
            let mut new_fact = inputs[0].datum_type.fact(&*shape).with_exotic_fact(new_bqf);
            if let Some(k) = &inputs[0].konst {
                // constant input: apply the op to the value right away
                let mut new = k.clone().into_tensor();
                self.change_tensor(&mut new, false)?;
                new_fact.konst = Some(new.into());
            }
            return Ok(tvec!(new_fact));
        }
        let mut shape = inputs[0].shape.clone();
        self.change_shape(&mut shape, false)?;
        let mut fact = inputs[0].datum_type.fact(shape);
        fact.exotic_fact.clone_from(&inputs[0].exotic_fact);
        if let Some(tdim) = &inputs[0].uniform_tdim {
            // renumber per-axis symbols; drops the annotation if not possible
            fact.uniform_tdim = remap_uniform_tdim(tdim, self);
        }
        Ok(tvec!(fact))
    }
732
    // Region-of-interest propagation is delegated to the generic bubbling
    // helper from the optimizer.
    fn input_roi(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TVec<Option<TDim>>>> {
        crate::optim::propagate_roi::bubble_roi(model, node)
    }
740
    // Build the axes mapping: each input axis gets a lowercase letter and is
    // linked to its transformed output position when it survives the op;
    // output axes with no antecedent (checked via the reciprocal op) get
    // fresh uppercase letters.
    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let mut axes: Vec<Axis> = (0..inputs[0].rank())
            .zip('a'..)
            .map(|(axis_id, repr)| {
                let mut axis = Axis::new(repr, inputs.len(), outputs.len()).input(0, axis_id);
                if let Some(out) = self.transform_axis(axis_id) {
                    axis = axis.output(0, out);
                }
                axis
            })
            .collect();
        for (axis, letter) in (0..outputs[0].rank()).zip('A'..) {
            if self.recip().transform_axis(axis).is_none() {
                axes.push(Axis::new(letter, inputs.len(), outputs.len()).output(0, axis));
            }
        }
        AxesMapping::new(inputs.len(), outputs.len(), axes)
    }
763
    // Declutter: shunt a no-op entirely, or replace the node by the
    // elementary op sequence produced by `simplify` when it differs.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.is_noop()
            && let Some(p) = TypedModelPatch::shunt_one_op(model, node)?
        {
            return Ok(Some(p));
        }
        let simplified = self.simplify();
        if simplified.len() != 1 || &simplified[0] != self {
            // wire the simplified chain and bypass the original node
            let mut patch = TypedModelPatch::default();
            let mut wire = patch.tap_model(model, node.inputs[0])?;
            for (ix, op) in simplified.into_iter().enumerate() {
                wire = patch.wire_node(format!("{}.{}", node.name, ix), op, &[wire])?[0];
            }
            patch.shunt_outside(model, node.id.into(), wire)?;
            Ok(Some(patch))
        } else {
            Ok(None)
        }
    }
787
    // Suggest cancelling this op from either side: apply the reciprocal on
    // the output or the op itself on the input.
    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
        Ok(tvec!((InOut::Out(0), self.recip()), (InOut::In(0), self.clone())))
    }
791
    // React to an axis change arriving on one of the node's wires.
    // - change on the output: solve the mirrored problem on the reciprocal op
    //   and flip the result back;
    // - change identical to self: the two cancel, replace by Identity;
    // - otherwise: try to commute the change through via
    //   merge_incoming_change.
    fn change_axes(
        &self,
        _model: &TypedModel,
        _node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let op = if let InOut::Out(0) = io {
            rule_if_some!(more = self.recip().change_axes(_model, _node, InOut::In(0), change)?);
            AxisChangeConsequence {
                // if the substitute is itself an AxisOp, flip it back too
                substitute_op: more.substitute_op.map(|op| {
                    if let Some(op) = op.as_op().downcast_ref::<AxisOp>() {
                        Box::new(op.recip())
                    } else {
                        op
                    }
                }),
                // swap input/output roles in the propagated wire changes
                wire_changes: more
                    .wire_changes
                    .into_iter()
                    .map(|wc| {
                        (if wc.0 == InOut::In(0) { InOut::Out(0) } else { InOut::In(0) }, wc.1)
                    })
                    .collect(),
            }
        } else if change == self {
            AxisChangeConsequence { substitute_op: Some(Box::new(Identity)), wire_changes: tvec!() }
        } else {
            rule_if_some!((new_op, new_change) = self.merge_incoming_change(change));
            trace!(" Change:{change:?} self:{self:?} -> change:{new_change:?} op:{new_op:?}");
            let substitute_op: Box<dyn TypedOp> =
                if let Some(o) = new_op { Box::new(o) as _ } else { Box::new(Identity) };
            let mut wire_changes = tvec!();
            if !change.is_noop() {
                wire_changes.push((InOut::In(0), change.clone()))
            }
            if let Some(new_change) = new_change {
                wire_changes.push((InOut::Out(0), new_change))
            }
            AxisChangeConsequence { substitute_op: Some(substitute_op), wire_changes }
        };
        Ok(Some(op))
    }
835
    // Substitute symbol values into the dims of a Reshape when concretizing
    // the model; other variants carry no symbolic dims and are cloned as-is.
    fn concretize_dims(
        &self,
        _source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
        values: &SymbolValues,
    ) -> TractResult<TVec<OutletId>> {
        let op = if let AxisOp::Reshape(axis, from, to) = self {
            AxisOp::Reshape(
                *axis,
                from.iter().map(|d| d.eval(values)).collect(),
                to.iter().map(|d| d.eval(values)).collect(),
            )
        } else {
            self.clone()
        };
        target.wire_node(&node.name, op, &[mapping[&node.inputs[0]]])
    }
855
    // Slicing commutes with this op except when the sliced output axis falls
    // inside the target span of a Reshape (the slice cannot be pushed through
    // the regrouped axes).
    fn slice(
        &self,
        patch: &mut TypedModelPatch,
        _model: &TypedModel,
        node: &TypedNode,
        _prefix: &str,
        inputs: &[OutletId],
        output_axis: usize,
        _start: &TDim,
        _end: &TDim,
    ) -> TractResult<Option<TVec<OutletId>>> {
        if let Reshape(pos, _from, to) = self
            && output_axis >= *pos
            && output_axis < pos + to.len()
        {
            return Ok(None);
        }
        patch.wire_node(&node.name, &node.op, inputs).map(Some)
    }
876
    // Codegen: when the output shape is fully concrete and the op is not a
    // Move (i.e. no data permutation — presumably only the geometry changes),
    // replace it by the cheaper IntoShape that just resets dims and strides.
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if node.outputs[0].fact.exotic_fact.is_some() {
            return Ok(None);
        }
        if let Some(shape) = node.outputs[0].fact.shape.as_concrete()
            && !matches!(self, AxisOp::Move(_, _))
        {
            let (inputs, outputs) = model.node_facts(node.id)?;
            let mapping = self.axes_mapping(&inputs, &outputs)?;
            let op = IntoShape {
                mapping,
                len: shape.iter().product(),
                strides: Tensor::natural_strides(shape),
                dims: shape.into(),
            };
            return Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, op)?));
        }
        Ok(None)
    }
900}
901
902fn perm_to_cycles(perm: &[usize]) -> TVec<TVec<usize>> {
904 let mut cycles: TVec<TVec<usize>> = tvec!();
905 let mut done = 0;
906 while done < perm.len() {
907 if perm[done] == done || cycles.iter().any(|c| c.contains(&done)) {
908 done += 1;
909 continue;
910 }
911 let mut cycle = tvec!();
912 let mut current = done;
913 loop {
914 cycle.push(current);
915 current = perm[current];
916 if current == done {
917 break;
918 }
919 }
920 cycles.push(cycle)
921 }
922 cycles
923}
924
/// Detect whether `cycle` is a "rotation": a cycle expressible as a single
/// `Move(from, to)` across a contiguous range of axes.
///
/// Returns `Some((from, to))` for a contiguous ascending run
/// `[k, k+1, …, k+n]`, or for a head followed by a contiguous descending run
/// closing one above it, `[k, k+n, …, k+1]`; `None` otherwise.
fn is_rotation_cycle(cycle: &[usize]) -> Option<(usize, usize)> {
    // ascending run: [k, k+1, ..., k+n] => Move(k, k+n)
    if cycle.windows(2).all(|w| w[0] + 1 == w[1]) {
        Some((cycle[0], cycle[cycle.len() - 1]))
    // descending tail closing just above the head: [k, k+n, ..., k+1] => Move(k+n, k)
    // (comparisons are written addition-side so a 0 in the tail cannot
    // underflow usize, unlike the previous `w[0] - 1 == w[1]` form)
    } else if cycle[1..cycle.len()].windows(2).all(|w| w[0] == w[1] + 1)
        && cycle[cycle.len() - 1] == cycle[0] + 1
    {
        Some((cycle[1], cycle[0]))
    } else {
        None
    }
}
936
/// Decompose the permutation `input` into a sequence of single-axis moves
/// `(from, to)` that, applied in order to the identity, reproduce it.
/// Greedy fixpoint: replay the moves found so far, then either consume a
/// whole rotation cycle in one move or chip one transposition off the first
/// remaining cycle.
fn perm_to_atoms(input: &[usize]) -> TVec<(usize, usize)> {
    let mut changes: TVec<(usize, usize)> = tvec!();
    'top: loop {
        // replay accumulated moves on the identity permutation
        let mut reached: TVec<usize> = (0..input.len()).collect();
        changes.iter().for_each(|(f, t)| {
            let axis = reached.remove(*f);
            reached.insert(*t, axis);
        });
        if &*reached == input {
            return changes;
        }
        // what is left to do, expressed as a permutation of `reached`
        let remaining: TVec<usize> =
            input.iter().map(|x| reached.iter().position(|y| y == x).unwrap()).collect();
        let cycles = perm_to_cycles(&remaining);
        for cycle in &cycles {
            // a rotation cycle collapses into a single move
            if let Some(rot) = is_rotation_cycle(cycle) {
                changes.push(rot);
                continue 'top;
            }
        }
        // otherwise perform one transposition of the first cycle and retry
        changes.push((cycles[0][1], cycles[0][0]));
    }
}
960
961pub fn perm_to_ops(input: &[usize]) -> TVec<AxisOp> {
962 perm_to_atoms(input).into_iter().map(|pair| AxisOp::Move(pair.0, pair.1)).collect()
963}
964
/// Resolve a TensorFlow-style reshape spec against `input`: a `0` in the spec
/// copies the corresponding input dim, a single `-1` absorbs whatever volume
/// is left. Returns the fully resolved shape.
pub fn compute_shape_with_tf_rules(input: &[TDim], shape_spec: &[TDim]) -> TractResult<TVec<TDim>> {
    let mut shape: TVec<TDim> = shape_spec.into();
    // Walk the spec left-to-right, consuming input dims by volume, replacing
    // each `0` slot by the input dim aligned with it. Stops at `-1`.
    fn deal_with_zero<'a>(
        mut input_dims: std::iter::Peekable<impl Iterator<Item = &'a TDim>>,
        shape: &mut [TDim],
    ) -> TractResult<()> {
        // volume of input dims consumed but not yet matched by spec slots
        let mut remaining_dim_input = 1.to_dim();
        for slot in shape.iter_mut() {
            if *slot == (-1).into() {
                break;
            }
            if *slot == 0.into() {
                // a 0 must align exactly with the next unconsumed input dim
                if remaining_dim_input != TDim::one() {
                    bail!("Invalid remaining dim");
                }
                *slot = (*input_dims.peek().context("Invalid")?).clone();
            }
            // consume input dims until the slot divides the gathered volume
            loop {
                let quotient = remaining_dim_input.maybe_div(slot);
                if quotient.is_err() || quotient.as_ref().unwrap().1 != 1 {
                    remaining_dim_input *= input_dims.next().context("Invalid")?;
                } else {
                    break;
                }
            }
            remaining_dim_input = remaining_dim_input.maybe_div(slot)?.0;
        }
        Ok(())
    }

    // resolve 0s from the left, then from the right (around a possible -1)
    deal_with_zero(input.iter().peekable(), &mut shape)?;
    shape.reverse();
    deal_with_zero(input.iter().rev().peekable(), &mut shape)?;
    shape.reverse();

    // a single -1 slot absorbs the remaining volume
    if let Some(pos) = shape.iter().position(|d| *d == (-1).into()) {
        let input_vol: TDim = input.iter().product();
        let shape_vol: TDim = shape.iter().filter(|d| **d != (-1).into()).product();
        let div = input_vol.maybe_div(&shape_vol)?;
        if div.1 != 1 {
            bail!("invalid")
        }
        shape[pos] = div.0;
    }
    Ok(shape)
}
1011
1012pub fn to_axis_ops_with_tf_rules(
1013 input_orig: &[TDim],
1014 output_spec: &[TDim],
1015) -> TractResult<TVec<AxisOp>> {
1016 let final_output = compute_shape_with_tf_rules(input_orig, output_spec)?;
1017 let mut stack: TVec<AxisOp> = tvec!();
1018 'top: loop {
1019 let current_input =
1020 stack.iter().try_fold(TVec::from(input_orig), |mut shape, op| -> TractResult<_> {
1021 op.change_shape_array(&mut shape, false)?;
1022 Ok(shape)
1023 })?;
1024 if current_input == final_output {
1025 return Ok(stack);
1026 }
1027 if let Some(common) =
1028 current_input.iter().zip(final_output.iter()).position(|(a, b)| a != b)
1029 {
1030 if current_input[common].is_one() {
1031 stack.push(AxisOp::Rm(common));
1032 } else if final_output[common].is_one() {
1033 stack.push(AxisOp::Add(common));
1034 } else {
1035 for i in common..current_input.len() {
1038 let i_group = ¤t_input[common..i + 1];
1039 let i_volume: TDim = i_group.iter().product();
1040 for o in common..final_output.len() {
1041 let o_group = &final_output[common..o + 1];
1042 let o_volume: TDim = o_group.iter().product();
1043 if i_volume == o_volume {
1044 stack.push(AxisOp::Reshape(common, i_group.into(), o_group.into()));
1045 continue 'top;
1046 }
1047 }
1048 }
1049 todo!()
1050 }
1051 } else if final_output.len() > current_input.len() {
1052 stack.push(AxisOp::Add(current_input.len()));
1053 } else {
1054 stack.push(AxisOp::Rm(current_input.len() - 1));
1055 }
1056 }
1057}
1058
/// Codegen form of a shape-only axis op: reinterpret the input buffer with
/// new dims and strides, without touching the data.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct IntoShape {
    // axes correspondence between input and output, kept for composition
    pub mapping: AxesMapping,
    // expected element count of the input (checked at eval time)
    pub len: usize,
    // target dims
    pub dims: TVec<usize>,
    // target strides
    pub strides: TVec<isize>,
}
1066
impl Op for IntoShape {
    fn name(&self) -> StaticName {
        "IntoShape".into()
    }

    // Show the axes mapping in model dumps.
    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!("{}", self.mapping)])
    }

    op_as_typed_op!();
}
1078
impl EvalOp for IntoShape {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut input = args_1!(inputs).into_tensor();
        ensure!(input.len() == self.len);
        // SAFETY: element count matches (checked above); dims/strides were
        // built together from a concrete shape at codegen time, so they are
        // presumed consistent with the buffer — no data is moved.
        unsafe { input.set_geometry_unchecked(&self.dims, &self.strides) };
        Ok(tvec!(input.into_tvalue()))
    }
}
1091
impl TypedOp for IntoShape {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let mut fact = inputs[0].datum_type.fact(&self.dims);
        // preserve any exotic fact carried by the input
        if let Some(of) = &inputs[0].exotic_fact {
            fact = fact.with_exotic_fact(of.clone());
        }
        Ok(tvec!(fact))
    }

    // Declutter: drop the op when the input already has the target shape,
    // and fuse consecutive IntoShape nodes by composing their mappings.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let input = model.outlet_fact(node.inputs[0])?;
        if input.shape.as_concrete().is_some_and(|shape| shape == &*self.dims) {
            return TypedModelPatch::shunt_one_op(model, node);
        }
        if let Some(succ) = model.single_succ(node.id)?
            && let Some(into_shape) = succ.op_as::<IntoShape>()
        {
            let op =
                Self { mapping: self.mapping.compose(&into_shape.mapping)?, ..into_shape.clone() };
            return Ok(Some(TypedModelPatch::fuse_with_next(model, node, op)?));
        }
        Ok(None)
    }

    as_op!();
}
1122
// Unit tests for the permutation-decomposition helpers (perm_to_cycles,
// is_rotation_cycle, perm_to_atoms) and for AxisOp::merge_incoming_change.
#[cfg(test)]
mod test {
    use super::*;

    // A permutation decomposes into disjoint cycles; trivial (length-1) cycles
    // are omitted, as shown by the last case below.
    #[test]
    fn test_perm_to_cycles() {
        assert_eq!(perm_to_cycles(&[1, 2, 0]), tvec!(tvec!(0, 1, 2)));
        assert_eq!(perm_to_cycles(&[2, 0, 1]), tvec!(tvec!(0, 2, 1)));
        assert_eq!(perm_to_cycles(&[1, 2, 3, 0]), tvec!(tvec!(0, 1, 2, 3)));
        assert_eq!(perm_to_cycles(&[3, 0, 1, 2]), tvec!(tvec!(0, 3, 2, 1)));
        assert_eq!(perm_to_cycles(&[3, 1, 2, 0, 4]), tvec!(tvec!(0, 3)));
    }

    // A cycle that is a rotation maps to a single Move(from, to) pair.
    #[test]
    fn is_rotation() {
        assert_eq!(is_rotation_cycle(&[0, 1, 2]), Some((0, 2)));
        assert_eq!(is_rotation_cycle(&[0, 2, 1]), Some((2, 0)));
    }

    #[test]
    fn test_perm_one_rotation() {
        assert_eq!(perm_to_atoms(&[1, 2, 0, 3, 4]), tvec!((0, 2)));
    }

    #[test]
    fn test_perm_two_rotations() {
        assert_eq!(perm_to_atoms(&[1, 2, 0, 4, 3]), tvec!((0, 2), (3, 4)));
    }

    #[test]
    fn test_perm_complex() {
        assert_eq!(perm_to_atoms(&[3, 1, 2, 0, 4]), tvec!((3, 0), (1, 3)));
    }

    // merge_incoming_change: how `op` absorbs an axis `change` arriving from
    // upstream. NOTE(review): the pair appears to be (change to propagate
    // downstream, replacement for `op`) — inferred from the cases below;
    // confirm against merge_incoming_change's documentation.
    #[test]
    pub fn transform_op_add_0_add_0() {
        let change = Add(0);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Add(1)))));
    }

    #[test]
    pub fn transform_op_add_0_add_1() {
        let change = Add(0);
        let op = Add(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(2)), Some(Add(0)))));
    }

    #[test]
    pub fn transform_op_add_1_add_0() {
        let change = Add(1);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Add(2)))));
    }

    #[test]
    pub fn transform_op_rm_0_rm_1() {
        let change = Rm(0);
        let op = Rm(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Rm(0)))));
    }

    #[test]
    pub fn transform_op_rm_1_rm_0() {
        let change = Rm(1);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Rm(0)))));
    }

    // Mixed Add/Rm interactions: the surviving op's axis index shifts to
    // account for the axis the other one inserted or removed.
    #[test]
    pub fn transform_op_add_0_rm_0() {
        let change = Add(0);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(1)), Some(Add(0)))));
    }

    #[test]
    pub fn transform_op_add_0_rm_1() {
        let change = Add(0);
        let op = Rm(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(2)), Some(Add(0)))));
    }

    #[test]
    pub fn transform_op_add_1_rm_0() {
        let change = Add(1);
        let op = Rm(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(0)), Some(Add(0)))));
    }

    #[test]
    pub fn transform_op_rm_1_add_0() {
        let change = Rm(1);
        let op = Add(0);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Rm(2)))));
    }

    #[test]
    pub fn transform_op_rm_0_add_1() {
        let change = Rm(0);
        let op = Add(1);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Add(0)), Some(Rm(0)))));
    }

    // A Move crossing the removed axis: the Rm retargets and the Move shrinks.
    #[test]
    pub fn transform_op_mv_02_rm_2() {
        let change = Move(0, 2);
        let op = Rm(2);
        assert_eq!(op.merge_incoming_change(&change), Some((Some(Rm(1)), Some(Move(0, 1)))));
    }
}
1284
// Property-based tests: random chains of AxisOps are built over random input
// shapes, then a decluttered model is checked to produce the same output as
// the raw one. Also unit tests for recip/simplify and the TF-style reshape
// shape-computation helpers.
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    // A random input shape together with a chain of AxisOps that is valid
    // for that shape (each op was generated against the shape produced by
    // its predecessors).
    #[derive(Debug)]
    struct ComposeProblem {
        input: TVec<usize>,
        ops: TVec<AxisOp>,
    }

    impl Arbitrary for AxisOp {
        type Parameters = TVec<usize>;
        type Strategy = BoxedStrategy<AxisOp>;
        // Generate one AxisOp that is valid for `shape`.
        fn arbitrary_with(shape: TVec<usize>) -> Self::Strategy {
            // Adding an axis is always legal, at any position 0..=rank.
            let mut ops: BoxedStrategy<AxisOp> = (0usize..shape.len() + 1).prop_map(Add).boxed();
            if shape.len() > 1 {
                // Moves need at least two axes; `b` is drawn over rank-1 values
                // and shifted past `a` so that from != to.
                ops = ops
                    .prop_union(
                        (0..shape.len(), 0..shape.len() - 1)
                            .prop_map(|(a, b)| Move(a, b + (b >= a) as usize))
                            .boxed(),
                    )
                    .boxed()
            }
            // Only axes of size 1 can be removed.
            // NOTE(review): clippy::len_zero — `!rms.is_empty()` would be the
            // idiomatic guard here.
            let rms = (0..shape.len()).filter(|&ax| shape[ax] == 1).map(Rm).collect::<Vec<_>>();
            if rms.len() > 0 {
                ops = ops
                    .prop_union((0..rms.len()).prop_map(move |rm| rms[rm].clone()).boxed())
                    .boxed()
            }
            // Reshapes that merge two adjacent non-trivial axes into one.
            let mergeable: Vec<AxisOp> = shape
                .windows(2)
                .enumerate()
                .filter(|(_, w)| w[0] > 1 && w[1] > 1)
                .map(|(ix, w)| {
                    Reshape(ix, tvec!(w[0].to_dim(), w[1].to_dim()), tvec!((w[0] * w[1]).to_dim()))
                })
                .collect();
            // NOTE(review): guarded by `> 1`, so a *single* merge candidate is
            // never offered to the strategy — inconsistent with the `> 0` guard
            // used for `rms` above. Possibly intended to be `> 0`; confirm.
            if mergeable.len() > 1 {
                ops = ops
                    .prop_union(
                        (0..mergeable.len()).prop_map(move |ix| mergeable[ix].clone()).boxed(),
                    )
                    .boxed()
            }
            ops
        }
    }

    impl Arbitrary for ComposeProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<ComposeProblem>;
        fn arbitrary_with(_args: ()) -> Self::Strategy {
            // Small shapes (rank 1..4, dims 1..4) keep the search space tractable.
            let input = proptest::collection::vec(1usize..4, 1usize..4);
            // Recursively build a chain of `len` ops, threading the shape through
            // change_shape_array so every op is valid for its actual input.
            fn tail(len: usize, shape: TVec<usize>) -> BoxedStrategy<TVec<AxisOp>> {
                if len == 0 {
                    Just(tvec!()).boxed()
                } else {
                    AxisOp::arbitrary_with(shape.clone())
                        .prop_flat_map(move |op| {
                            let mut shape = shape.clone();
                            op.change_shape_array(&mut shape, false).unwrap();
                            tail(len - 1, shape.clone()).prop_map(move |mut t| {
                                t.insert(0, op.clone());
                                t
                            })
                        })
                        .boxed()
                }
            }
            (input, 1usize..=5)
                .prop_flat_map(|(input, len)| (Just(input.clone()), tail(len, input.into())))
                .prop_map(|(input, ops)| ComposeProblem { input: input.into(), ops })
                .boxed()
        }
    }

    impl ComposeProblem {
        // Build a TypedModel wiring the op chain after an i64 source.
        pub fn model(&self) -> TractResult<TypedModel> {
            let mut model = TypedModel::default();
            let mut wire = model.add_source("source", i64::fact(&self.input))?;
            for (ix, op) in self.ops.iter().enumerate() {
                wire = model.wire_node(format!("op_{ix}"), op.clone(), &[wire])?[0];
            }
            model.select_output_outlets(&[wire])?;
            Ok(model)
        }

        // Tensor filled with 0..len so every element is distinguishable when
        // tracked through permutations and reshapes.
        fn input(&self) -> TractResult<Tensor> {
            unsafe {
                let mut t = Tensor::uninitialized::<i64>(&self.input)?;
                // SAFETY: every element is written before the tensor is returned.
                for i in 0..t.len() {
                    t.try_as_plain_mut().unwrap().as_slice_mut().unwrap()[i] = i as i64;
                }
                Ok(t)
            }
        }

        // The invariant under test: decluttering must preserve semantics, so
        // the raw and decluttered models must produce identical outputs.
        fn check(&self) -> TractResult<()> {
            crate::setup_test_logger();
            let input = self.input()?;
            let model = self.model()?;
            let raw = model.into_runnable()?.run(tvec!(input.clone().into_tvalue()))?;
            let optimized = self.model()?.into_decluttered()?;
            let opt = optimized.into_runnable()?.run(tvec!(input.into_tvalue()))?;
            opt[0].close_enough(&raw[0], false)
        }
    }

    proptest! {
        // recip is an involution: applying it twice yields the original op.
        #[test]
        fn recip(pb in any::<AxisOp>()) {
            assert_eq!(pb.recip().recip(), pb);
        }

        #[test]
        fn axis_ops(pb in any::<ComposeProblem>()) {
            pb.check().unwrap()
        }
    }

    // The following fixed cases are regressions previously found by the
    // proptest above, pinned here so they stay covered deterministically.
    #[test]
    fn add_0_rm_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Rm(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_01() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_01_add_1() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1), Add(1)] };
        pb.check().unwrap();
    }

    #[test]
    fn recip_move_01() {
        let op = Move(1, 0);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn recip_move_20() {
        let op = Move(2, 0);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn recip_move_02() {
        let op = Move(0, 2);
        assert_eq!(op.recip().recip(), op);
    }

    #[test]
    fn add_0_add_1_move_02() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(1), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_02() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(0), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_2_move_12() {
        let pb = ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(2), Move(1, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_02_rm_0() {
        let pb = ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0), Move(0, 2), Rm(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_0_move_20_move_20() {
        let pb =
            ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(0), Move(2, 0), Move(2, 0)] };
        pb.check().unwrap();
    }

    #[test]
    fn move_01_add_0() {
        let pb = ComposeProblem { input: tvec![1, 1], ops: tvec![Move(0, 1), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_move_02_move_02() {
        let pb = ComposeProblem { input: tvec![1, 1], ops: tvec![Add(0), Move(0, 2), Move(0, 2),] };
        pb.check().unwrap();
    }

    #[test]
    fn add_0_add_2_move_20_move_12_rm_2() {
        let pb = ComposeProblem {
            input: tvec![3],
            ops: tvec![Add(0), Add(2), Move(2, 0), Move(1, 2), Rm(2)],
        };
        pb.check().unwrap();
    }

    #[test]
    fn move_02_move_02() {
        let pb = ComposeProblem { input: tvec![2, 1, 1], ops: tvec![Move(0, 2), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn rm_1_perm_10_add_0() {
        let pb = ComposeProblem { input: tvec![1, 1, 2], ops: tvec![Rm(1), Move(0, 1), Add(0)] };
        pb.check().unwrap();
    }

    #[test]
    fn add_2_move_02_move_02() {
        let pb = ComposeProblem { input: tvec![3, 2], ops: tvec![Add(2), Move(0, 2), Move(0, 2)] };
        pb.check().unwrap();
    }

    #[test]
    fn move_01_move_20_move_20() {
        let pb = ComposeProblem {
            input: tvec![2, 3, 2],
            ops: tvec![Move(0, 1), Move(2, 0), Move(2, 0)],
        };
        pb.check().unwrap();
    }

    #[test]
    fn reshape_axes_tracking() {
        let pb = ComposeProblem {
            input: tvec![2, 2, 2],
            ops: tvec![Reshape(0, tvec!(2.to_dim(), 2.to_dim()), tvec!(4.to_dim()))],
        };
        pb.check().unwrap();
    }

    // simplify() decomposes a Reshape into elementary Add/Rm/Reshape atoms,
    // peeling trivial (size-1) axes and shared prefixes/suffixes.
    #[test]
    fn simplify_reshape() {
        macro_rules! d {
            ($($dim: expr),*) => { tvec!($($dim.to_dim()),*) }
        }
        assert_eq!(Reshape(3, d!(), d!()).simplify(), tvec!());
        assert_eq!(Reshape(3, d!(2, 3), d!(2, 3)).simplify(), tvec!());
        assert_eq!(Reshape(3, d!(1), d!()).simplify(), tvec!(Rm(3)));
        assert_eq!(Reshape(3, d!(), d!(1)).simplify(), tvec!(Add(3)));
        assert_eq!(
            Reshape(3, d!(2, 3, 4), d!(2, 4, 3)).simplify(),
            tvec!(Reshape(4, d!(3, 4), d!(4, 3)))
        );
        assert_eq!(
            Reshape(3, d!(3, 4, 2), d!(4, 3, 2)).simplify(),
            tvec!(Reshape(3, d!(3, 4), d!(4, 3)))
        );
        assert_eq!(
            Reshape(3, d!(1, 2, 3), d!(3, 2)).simplify(),
            tvec!(Rm(3), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(3, d!(2, 3), d!(1, 3, 2)).simplify(),
            tvec!(Reshape(3, d!(2, 3), d!(3, 2)), Add(3))
        );
        assert_eq!(
            Reshape(3, d!(2, 3, 1), d!(3, 2)).simplify(),
            tvec!(Rm(5), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(3, d!(2, 3), d!(3, 2, 1)).simplify(),
            tvec!(Add(5), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(2, d!(2, 2, 1), d!(4)).simplify(),
            tvec!(Rm(4), Reshape(2, d!(2, 2), d!(4)))
        );
        assert_eq!(Reshape(1, d!(1, 2), d!(2)).simplify(), tvec!(Rm(1)));
    }

    // Build a &[TDim] shape literal.
    macro_rules! s {
        ($($a:expr),*) => {&[ $($a.clone().into()),* ]}
    }

    // Build an AxisOp::Reshape literal: r!(at ; from-dims => to-dims).
    macro_rules! r {
        ($at: expr ; $($from:expr),* => $($to:expr),*) => {
            AxisOp::Reshape($at, tvec!($($from.into()),*), tvec!($($to.into()),*))
        }
    }

    // TF reshape rules: 0 copies the input dim at that position, -1 infers
    // the remaining product. Incompatible volumes must error out.
    #[test]
    fn compute_invalid() {
        assert!(compute_shape_with_tf_rules(s![3, 4, 5], s!(100)).is_err());
    }

    #[test]
    fn compute_with_leading_zero() {
        assert_eq!(&*compute_shape_with_tf_rules(s![3, 4, 5], s!(0, 0, 5)).unwrap(), s![3, 4, 5])
    }

    #[test]
    fn compute_with_leading_zero_with_flatten() {
        assert_eq!(
            &*compute_shape_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap(),
            s![2, 3, 35]
        )
    }

    #[test]
    fn compute_with_trailing_zero() {
        assert_eq!(&*compute_shape_with_tf_rules(s![3, 4, 5], s!(3, -1, 0)).unwrap(), s![3, 4, 5])
    }

    // -1 inference must also work with symbolic dims in the input shape.
    #[test]
    fn compute_bug_1() {
        let table = SymbolScope::default();
        let s = table.new_with_prefix("S");
        assert_eq!(
            &*compute_shape_with_tf_rules(s![s, 1, 2, 128], s!(0, 0, -1)).unwrap(),
            s![s, 1, 256]
        )
    }

    #[test]
    fn compute_bug_2() {
        let table = SymbolScope::default();
        let b = table.new_with_prefix("B");
        let s = table.new_with_prefix("S");
        assert_eq!(
            &*compute_shape_with_tf_rules(s![s, b, 2, 128], s!(0, 0, -1)).unwrap(),
            s![s, b, 256]
        )
    }

    // to_axis_ops_with_tf_rules: translate a TF-style reshape into AxisOps.
    #[test]
    fn axis_op_rm_begin() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![1, 2, 3], s!(2, 3)).unwrap(), &[Rm(0)])
    }

    #[test]
    fn axis_op_rm_end() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3, 1], s!(2, 3)).unwrap(), &[Rm(2)])
    }

    #[test]
    fn axis_op_insert_begin() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3], s!(1, 2, 3)).unwrap(), &[Add(0)])
    }

    #[test]
    fn axis_op_insert_end() {
        assert_eq!(&*to_axis_ops_with_tf_rules(s![2, 3], s!(2, 3, 1)).unwrap(), &[Add(2)])
    }

    #[test]
    fn axis_op_merge() {
        assert_eq!(
            &*to_axis_ops_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap(),
            &[r!(2 ; 5,7 => 35 )]
        )
    }

    #[test]
    fn axis_op_complex() {
        assert_eq!(
            &*to_axis_ops_with_tf_rules(s![1, 2, 3, 5, 7], s!(2, 1, 3, 35, 1)).unwrap(),
            &[Rm(0), Add(1), r!(3 ; 5,7 => 35 ), Add(4)]
        )
    }
}