1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::model::{TypedModel, TypedNode};
6use crate::ops::identity::Identity;
7use AxisOp::*;
8use tract_itertools::Itertools;
9use tract_linalg::block_quant::{BlockQuantFact, BlockQuantStorage};
10use tract_ndarray::{ArrayViewD, ArrayViewMutD};
11
/// Designates one slot of a node: either one of its outputs or one of its inputs.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InOut {
    /// An output slot of the node, identified by its index.
    Out(usize),
    /// An input slot of the node, identified by its index.
    In(usize),
}
17
18impl InOut {
19 pub fn as_outlet<F: Clone + Fact, O: Clone>(&self, node: &Node<F, O>) -> OutletId {
20 match self {
21 InOut::In(ix) => node.inputs[*ix],
22 InOut::Out(ix) => OutletId::new(node.id, *ix),
23 }
24 }
25
26 pub fn is_input(&self) -> bool {
27 matches!(self, InOut::In(_))
28 }
29
30 pub fn is_output(&self) -> bool {
31 matches!(self, InOut::Out(_))
32 }
33
34 pub fn slot(&self) -> usize {
35 match self {
36 InOut::Out(o) => *o,
37 InOut::In(i) => *i,
38 }
39 }
40}
41
/// An elementary manipulation of the axes of a tensor shape.
///
/// Equality on this type is hand-written and treats some distinct
/// representations as equal, while `Hash` is derived.
/// NOTE(review): equal values may therefore hash differently — confirm that no
/// hash-based container relies on both.
#[derive(Clone, Hash, Eq)]
// NOTE(review): presumably silenced because `Reshape` is much larger than the
// other variants — confirm.
#[allow(clippy::large_enum_variant)]
// `PartialEq` is implemented manually for this type while `Hash` is derived.
#[allow(clippy::derived_hash_with_manual_eq)]
pub enum AxisOp {
    /// Insert an axis at the given position.
    Add(usize),
    /// Remove the axis at the given position.
    Rm(usize),
    /// Move an axis: `(from, to)`.
    Move(usize, usize),
    /// Replace a group of consecutive axes: `(start position, dims before, dims after)`.
    Reshape(usize, TVec<TDim>, TVec<TDim>),
}
51
52impl Debug for AxisOp {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match self {
55 AxisOp::Add(a) => write!(f, "Add({a})"),
56 AxisOp::Rm(a) => write!(f, "Rm({a})"),
57 AxisOp::Move(from, to) => write!(f, "Move({from},{to})"),
58 AxisOp::Reshape(at, from, to) => {
59 write!(f, "Reshape({at}, [{}], [{}])", from.iter().join(","), to.iter().join(","))
60 }
61 }
62 }
63}
64
65impl PartialEq for AxisOp {
66 fn eq(&self, other: &AxisOp) -> bool {
67 if self.is_noop() && other.is_noop() {
68 true
69 } else if self.is_noop() != other.is_noop() {
70 false
71 } else {
72 match (self, other) {
73 (Add(a), Add(b)) | (Rm(a), Rm(b)) => a == b,
74 (Move(f1, t1), Move(f2, t2)) => {
75 (f1 == f2 && t1 == t2)
76 || ((*t1 == f1 + 1 || *f1 == t1 + 1) && t2 == f1 && t1 == f2)
77 }
78 (Reshape(at1, f1, t1), Reshape(at2, f2, t2)) => at1 == at2 && f1 == f2 && t1 == t2,
79 _ => false,
80 }
81 }
82 }
83}
84
85impl AxisOp {
86 pub fn canonical(&self) -> Cow<'_, AxisOp> {
87 match self {
88 Move(from, to) if *from == to + 1 => Cow::Owned(Move(*to, *from)),
89 Reshape(at, from, to)
90 if from.len() == 1 && to.len() == 2 && from[0] == to[0] && to[1].is_one() =>
91 {
92 Cow::Owned(Add(*at + 1))
93 }
94 Reshape(at, from, to)
95 if from.len() == 1 && to.len() == 2 && from[0] == to[1] && to[0].is_one() =>
96 {
97 Cow::Owned(Add(*at))
98 }
99 Reshape(at, from, to)
100 if from.len() == 2 && to.len() == 1 && from[0] == to[0] && from[1].is_one() =>
101 {
102 Cow::Owned(Rm(*at + 1))
103 }
104 Reshape(at, from, to)
105 if from.len() == 2 && to.len() == 1 && from[1] == to[0] && from[0].is_one() =>
106 {
107 Cow::Owned(Rm(*at))
108 }
109 other => Cow::Borrowed(other),
110 }
111 }
112
113 pub fn simplify(&self) -> TVec<AxisOp> {
114 match self.canonical().borrow() {
115 Reshape(_, from, to) if from == to => tvec!(),
116 Reshape(at, from, to) if to.len() == 0 => tvec!(Rm(*at); from.len()),
117 Reshape(at, from, to) if from.len() == 0 => tvec!(Add(*at); to.len()),
118 Reshape(at, from, to) if from[0] == to[0] => {
119 Reshape(at + 1, from[1..].into(), to[1..].into()).simplify()
120 }
121 Reshape(at, from, to) if from[from.len() - 1] == to[to.len() - 1] => {
122 Reshape(*at, from[..from.len() - 1].into(), to[..to.len() - 1].into()).simplify()
123 }
124 Reshape(at, from, to) if from[0] == 1.to_dim() => std::iter::once(Rm(*at))
125 .chain(Reshape(*at, from[1..].into(), to.clone()).simplify())
126 .collect(),
127 Reshape(at, from, to) if to[0] == 1.to_dim() => {
128 Reshape(*at, from.clone(), to[1..].into())
129 .simplify()
130 .into_iter()
131 .chain(std::iter::once(Add(*at)))
132 .collect()
133 }
134 Reshape(at, from, to) if from[from.len() - 1] == 1.to_dim() => {
135 std::iter::once(Rm(at + from.len() - 1))
136 .chain(Reshape(*at, from[..from.len() - 1].into(), to.clone()).simplify())
137 .collect()
138 }
139 Reshape(at, from, to) if to[to.len() - 1] == 1.to_dim() => {
140 std::iter::once(Add(at + from.len()))
141 .chain(Reshape(*at, from.clone(), to[..to.len() - 1].into()).simplify())
142 .collect()
143 }
144 other => tvec!(other.clone()),
145 }
146 }
147
148 pub fn transform_axis(&self, axis: usize) -> Option<usize> {
149 match self.canonical().as_ref() {
150 Add(ix) => Some(axis + (axis >= *ix) as usize),
151 Rm(ix) => {
152 if axis == *ix {
153 None
154 } else {
155 Some(axis - (axis > *ix) as usize)
156 }
157 }
158 Move(from, to) if from < to => {
159 if axis < *from || axis > *to {
160 Some(axis)
161 } else if axis == *from {
162 Some(*to)
163 } else {
164 Some(axis - 1)
165 }
166 }
167 Move(from, to) => {
168 if axis < *to || axis > *from {
169 Some(axis)
170 } else if axis == *from {
171 Some(*to)
172 } else {
173 Some(axis + 1)
174 }
175 }
176 Reshape(at, _, _) if axis < *at => Some(axis),
177 Reshape(at, from, to) if axis >= at + from.len() => Some(axis + to.len() - from.len()),
178 Reshape(_, _, _) => None,
179 }
180 }
181
/// Tries to reconcile this op with an axis `change` arriving on its input.
///
/// Both `self` and `change` are canonicalized before being compared.
///
/// Returns `None` when the pair can not be reconciled. Otherwise returns a
/// pair of optional ops. In this file, `change_axes` uses the first element
/// as the op substituted for `self` (an `Identity` op when it is `None`) and
/// the second as the change to apply on the output wire (nothing when it is
/// `None`).
pub fn merge_incoming_change(
    &self,
    change: &AxisOp,
) -> Option<(Option<AxisOp>, Option<AxisOp>)> {
    match (self.canonical().as_ref(), change.canonical().as_ref()) {
        // Add / Rm against Add / Rm: each index is shifted by one depending
        // on which side of the other index it sits.
        (Add(op), Add(c)) => {
            Some((Some(Add(op + (c < op) as usize)), Some(Add(c + (c >= op) as usize))))
        }
        (Add(op), Rm(c)) => {
            Some((Some(Add(op - (c < op) as usize)), Some(Rm(c + (c >= op) as usize))))
        }
        (Rm(op), Add(c)) => {
            Some((Some(Rm(op + (c <= op) as usize)), Some(Add(c - (op < c) as usize))))
        }
        (Rm(op), Rm(c)) => {
            Some((Some(Rm(op - (c < op) as usize)), Some(Rm(c - (op <= c) as usize))))
        }

        // Add against Move: only handled when the added axis is entirely
        // before or entirely after the moved range.
        (Add(x), Move(from, to)) => {
            if x <= from.min(to) {
                Some((Some(self.clone()), Some(Move(from + 1, to + 1))))
            } else if x > from.max(to) {
                Some((Some(self.clone()), Some(change.clone())))
            } else {
                None
            }
        }

        (Move(from, to), Add(x)) => {
            if x <= from.min(to) {
                Some((Some(Move(from + 1, to + 1)), Some(Add(*x))))
            } else if x > from.max(to) {
                Some((Some(Move(*from, *to)), Some(Add(*x))))
            } else {
                None
            }
        }

        // Rm against Move: every relative position is handled; when the
        // removed axis is the moved one, the move disappears altogether.
        (Rm(x), Move(from, to)) => {
            if x == from {
                Some((Some(Rm(*to)), None))
            } else if x < from.min(to) {
                Some((Some(self.clone()), Some(Move(from - 1, to - 1))))
            } else if x > from.max(to) {
                Some((Some(self.clone()), Some(change.clone())))
            } else if from + 1 == *to && x == to {
                Some((Some(Rm(*from)), None))
            } else if from < to && x <= to {
                Some((Some(Rm(x - 1)), Some(Move(*from, *to - 1))))
            } else {
                Some((Some(Rm(x + 1)), Some(Move(*from - 1, *to))))
            }
        }

        (Move(from, to), Rm(x)) => {
            if x < from.min(to) {
                Some((Some(Move(from - 1, to - 1)), Some(Rm(*x))))
            } else if x > from.max(to) {
                Some((Some(Move(*from, *to)), Some(Rm(*x))))
            } else {
                None
            }
        }

        // Add / Rm against Reshape: only handled when the single axis is
        // outside of the reshaped group.
        (Add(op), Reshape(at, from, to)) => {
            if op <= at {
                Some((Some(Add(*op)), Some(Reshape(at + 1, from.clone(), to.clone()))))
            } else if *op > at + from.len() {
                Some((
                    Some(Add(*op + to.len() - from.len())),
                    Some(Reshape(*at, from.clone(), to.clone())),
                ))
            } else {
                None
            }
        }
        (Rm(op), Reshape(at, from, to)) => {
            if op < at {
                Some((Some(Rm(*op)), Some(Reshape(at - 1, from.clone(), to.clone()))))
            } else if *op > at + from.len() {
                Some((
                    Some(Rm(*op + to.len() - from.len())),
                    Some(Reshape(*at, from.clone(), to.clone())),
                ))
            } else {
                None
            }
        }
        (Reshape(at, from, to), Add(change)) => {
            if change < at {
                Some((Some(Reshape(at + 1, from.clone(), to.clone())), Some(Add(*change))))
            } else if *change > *at + from.len() {
                Some((
                    Some(Reshape(*at, from.clone(), to.clone())),
                    Some(Add(change + to.len() - from.len())),
                ))
            } else {
                None
            }
        }
        (Reshape(at, from, to), Rm(change)) => {
            if change < at {
                Some((Some(Reshape(at - 1, from.clone(), to.clone())), Some(Rm(*change))))
            } else if *change > *at + from.len() {
                Some((
                    Some(Reshape(*at, from.clone(), to.clone())),
                    Some(Rm(change + to.len() - from.len())),
                ))
            } else {
                None
            }
        }
        // Combinations that are not supported.
        (Reshape(_, _, _), Move(_, _)) => None,
        (Move(_, _), Reshape(_, _, _)) => None,
        (Reshape(_, _, _), Reshape(_, _, _)) => None,
        _ => None,
    }
}
304
305 pub fn change_shape_array<D: DimLike>(
306 &self,
307 shape: &mut TVec<D>,
308 broadcasting: bool,
309 ) -> TractResult<()> {
310 match self.canonical().as_ref() {
311 Add(ix) => {
312 ensure!(*ix <= shape.len());
313 shape.insert(*ix, D::one());
314 }
315 Rm(ix) => {
316 ensure!(*ix < shape.len());
317 shape.remove(*ix);
318 }
319 Move(from, to) => {
320 ensure!(*from < shape.len());
321 ensure!(*to < shape.len());
322 let axis = shape.remove(*from);
323 shape.insert(*to, axis);
324 }
325 Reshape(at, from, to) => {
326 let from_volume = from.iter().product::<TDim>();
327 let to_volume = to.iter().product::<TDim>();
328 ensure!(
335 from_volume.clone().expand_polynomial()
336 == to_volume.clone().expand_polynomial(),
337 "{from_volume} should be equal to {to_volume}"
338 );
339 ensure!(*at + from.len() <= shape.len());
340 if shape.len() >= from.len() + *at
341 && tract_itertools::izip!(shape.iter().skip(*at), from)
342 .all(|(shape, spec)| shape.to_dim() == *spec)
343 {
344 for _ in from {
345 shape.remove(*at);
346 }
347 for d in to.iter().rev() {
348 shape.insert(*at, d.try_into()?);
349 }
350 } else if broadcasting
351 && shape.iter().skip(*at).take(from.len()).all(|d| d.to_dim() == 1.to_dim())
352 {
353 for _ in from {
354 shape.remove(*at);
355 }
356 for _ in to.iter().rev() {
357 shape.insert(*at, 1.into());
358 }
359 } else {
360 bail!("Incompatible reshape for shape {:?} and {:?}", shape, self);
361 }
362 }
363 }
364 Ok(())
365 }
366
367 pub fn change_shape(&self, shape: &mut ShapeFact, broadcasting: bool) -> TractResult<()> {
368 match self.canonical().as_ref() {
369 Add(ix) => shape.insert_axis(*ix),
370 Rm(ix) => {
371 if shape.rank() <= *ix {
372 bail!("Attempt to remove axis #{} on shape {:?}", ix, shape);
373 }
374 if shape[*ix] != 1.to_dim() {
375 bail!("Removing non-trivial axis #{} of dim: {:?}", ix, shape);
376 }
377 shape.remove_axis(*ix)
378 }
379 _ => {
380 let mut array = shape.to_tvec();
381 self.change_shape_array(&mut array, broadcasting)?;
382 let mut new_shape = ShapeFact::from_dims(array);
383 std::mem::swap(shape, &mut new_shape);
384 Ok(())
385 }
386 }
387 }
388
/// Applies this op in place to a tensor.
///
/// Tensors backed by `BlockQuantStorage` are handled separately: only their
/// shape is recomputed and the tensor is rebuilt around a clone of the
/// storage.
///
/// # Errors
/// Fails when the tensor rank is lower than [`AxisOp::required_rank`], when
/// the underlying tensor operation fails, or when a `Reshape` can be applied
/// neither directly nor (if `broadcasting` is set) on a group of unit axes.
pub fn change_tensor(&self, tensor: &mut Tensor, broadcasting: bool) -> TractResult<()> {
    if tensor.storage_as::<BlockQuantStorage>().is_some() {
        let bqs = tensor.try_storage_as::<BlockQuantStorage>()?.clone();
        let mut new_shape: TVec<usize> = tensor.shape().into();
        // Broadcasting is never allowed on this path.
        self.change_shape_array(&mut new_shape, false)?;
        let mut new_tensor = bqs.into_tensor_with_shape(tensor.datum_type(), &new_shape);
        std::mem::swap(tensor, &mut new_tensor);
        return Ok(());
    }
    ensure!(self.required_rank() <= tensor.rank());
    match self.canonical().as_ref() {
        Add(ix) => tensor.insert_axis(*ix),
        Rm(ix) => tensor.remove_axis(*ix),
        Move(from, to) => {
            // `move_axis` consumes its receiver, hence the clone and swap.
            let mut tmp = tensor.clone().move_axis(*from, *to)?;
            std::mem::swap(tensor, &mut tmp);
            Ok(())
        }
        Reshape(at, from, to) => {
            let mut shape: TVec<usize> = tensor.shape().into();
            // NOTE(review): the target shape is computed with broadcasting
            // always enabled, whatever the `broadcasting` argument is; the
            // argument only gates the fallback below — confirm this is intended.
            self.change_shape_array(&mut shape, true)?;
            if tensor.set_shape(&shape).is_ok() {
                Ok(())
            } else if broadcasting
                && tensor.shape().iter().skip(*at).take(from.len()).all(|d| *d == 1)
            {
                // The reshaped group only holds unit axes: adjust the rank by
                // removing or inserting axes at `at`.
                if from.len() > to.len() {
                    for _ in to.len()..from.len() {
                        tensor.remove_axis(*at)?;
                    }
                }
                if to.len() > from.len() {
                    for _ in from.len()..to.len() {
                        tensor.insert_axis(*at)?;
                    }
                }
                Ok(())
            } else {
                bail!(
                    "Invalid reshaping: {:?} on tensor {:?} (broadcasting allowed: {:?})",
                    self,
                    tensor,
                    broadcasting
                )
            }
        }
    }
}
437
438 pub fn change_view<D>(&self, view: &mut ArrayViewD<D>) -> TractResult<()> {
439 use tract_ndarray::Axis;
440 match *self {
441 AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
442 AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
443 AxisOp::Move(from, to) if from < to => {
444 for left in from..to {
445 view.swap_axes(left, left + 1);
446 }
447 }
448 AxisOp::Move(from, to) => {
449 for left in (to..from).rev() {
450 view.swap_axes(left, left + 1);
451 }
452 }
453 AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
454 }
455 Ok(())
456 }
457
458 pub fn change_view_mut<D>(&self, view: &mut ArrayViewMutD<D>) -> TractResult<()> {
459 use tract_ndarray::Axis;
460 match *self {
461 AxisOp::Rm(axis) => view.index_axis_inplace(Axis(axis), 0),
462 AxisOp::Add(axis) => view.insert_axis_inplace(Axis(axis)),
463 AxisOp::Move(from, to) if from < to => {
464 for left in from..to {
465 view.swap_axes(left, left + 1);
466 }
467 }
468 AxisOp::Move(from, to) => {
469 for left in (to..from).rev() {
470 view.swap_axes(left, left + 1);
471 }
472 }
473 AxisOp::Reshape(_, _, _) => bail!("Reshape can not change views in place"),
474 }
475 Ok(())
476 }
477
478 pub fn recip(&self) -> AxisOp {
479 match self.canonical().as_ref() {
480 Add(ix) => Rm(*ix),
481 Rm(ix) => Add(*ix),
482 Move(from, to) if from == to => self.clone(),
483 Move(from, to) if *from + 1 == *to => self.clone(),
484 Move(from, to) if *from == *to + 1 => {
485 unreachable!();
486 }
487 Move(from, to) => Move(*to, *from),
488 Reshape(at, from, to) => Reshape(*at, to.clone(), from.clone()),
489 }
490 }
491
492 pub fn is_noop(&self) -> bool {
493 match self {
494 Move(f, t) if f == t => true,
495 Reshape(_, f, t) if f == t => true,
496 _ => false,
497 }
498 }
499
500 pub fn only_shape(&self) -> bool {
501 if self.is_noop() {
502 return true;
503 }
504 !matches!(self, Move(_, _))
505 }
506
507 pub fn wire_split_axis(
508 model: &mut TypedModel,
509 name: impl ToString,
510 outlet: OutletId,
511 axis: usize,
512 outer_dim: usize,
513 ) -> TractResult<TVec<OutletId>> {
514 let fact = model.outlet_fact(outlet)?;
515 let dim: TDim = fact.shape[axis].clone();
516 let inner_dim = dim.clone() / outer_dim;
517 let op = Self::Reshape(axis, tvec!(dim.clone()), tvec!(outer_dim.to_dim(), inner_dim));
518 model.wire_node(name.to_string(), op, &[outlet])
519 }
520
521 pub fn wire_collapse_axis(
522 model: &mut TypedModel,
523 name: impl ToString,
524 outlet: OutletId,
525 axis: usize,
526 ) -> TractResult<TVec<OutletId>> {
527 let fact = model.outlet_fact(outlet)?;
528 let dim: TDim = fact.shape[axis].clone();
529 let next_dim: TDim = fact.shape[axis + 1].clone();
530 let op = Self::Reshape(axis, tvec!(dim.clone(), next_dim.clone()), tvec!(dim * next_dim));
531 model.wire_node(name.to_string(), op, &[outlet])
532 }
533
534 #[inline]
535 pub fn required_rank(&self) -> usize {
536 match self {
537 Rm(r) => r + 1,
538 Add(a) => *a,
539 Reshape(at, from, _to) => at + from.len(),
540 Move(from, to) => *from.max(to),
541 }
542 }
543
544 pub fn trim_left(&self, prefix: usize) -> TractResult<AxisOp> {
545 Ok(match self {
546 Rm(r) if *r >= prefix => Rm(r - prefix),
547 Add(a) if *a >= prefix => Add(a - prefix),
548 Reshape(at, from, to) if *at >= prefix => {
549 Reshape(at - prefix, from.clone(), to.clone())
550 }
551 Move(from, to) if *from >= prefix && *to >= prefix => Move(from - prefix, to - prefix),
552 _ => bail!("Can no trim left {self:?} by {prefix}"),
553 })
554 }
555}
556
557pub fn wire_rank_broadcast(
558 prefix: impl AsRef<str>,
559 target: &mut TypedModel,
560 inputs: &[OutletId],
561) -> TractResult<TVec<OutletId>> {
562 let facts =
563 inputs.iter().map(|o| target.outlet_fact(*o).cloned()).collect::<TractResult<TVec<_>>>()?;
564 let max_rank = facts.iter().map(|f| f.rank()).max().unwrap();
565 let mut wires = tvec!();
566 for i in 0..inputs.len() {
567 let mut wire = inputs[i];
568 for _ in facts[i].rank()..max_rank {
569 let name = target.unique_name(prefix.as_ref().to_string() + ".fix-rank");
570 wire = target.wire_node(name, AxisOp::Add(0), &[wire])?[0];
571 }
572 wires.push(wire);
573 }
574 Ok(wires)
575}
576
577pub fn wire_with_rank_broadcast(
578 prefix: impl AsRef<str>,
579 target: &mut TypedModel,
580 op: impl Into<Box<dyn TypedOp>>,
581 inputs: &[OutletId],
582) -> TractResult<TVec<OutletId>> {
583 let prefix = prefix.as_ref();
584 let wires = wire_rank_broadcast(prefix, target, inputs)?;
585 target.wire_node(prefix, op.into(), &wires)
586}
587
/// An axis operation attached to a specific outlet of a model.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AxisChange {
    /// The outlet the change applies to.
    pub outlet: OutletId,
    /// The axis operation to apply.
    pub op: AxisOp,
}
593
/// What a node reports in response to an axis change on one of its wires.
#[derive(Clone, Default, Debug)]
pub struct AxisChangeConsequence {
    /// Replacement for the node's op, if any.
    pub substitute_op: Option<Box<dyn TypedOp>>,
    /// Axis changes on the node's inputs and outputs, one entry per affected slot.
    pub wire_changes: TVec<(InOut, AxisOp)>,
}
599
600impl AxisChangeConsequence {
601 pub fn new(
602 _model: &TypedModel,
603 node: &TypedNode,
604 op: Option<Box<dyn TypedOp>>,
605 axis_op: &AxisOp,
606 ) -> AxisChangeConsequence {
607 let mut wire_changes = tvec!();
608 for i in 0..node.inputs.len() {
609 wire_changes.push((InOut::In(i), axis_op.clone()));
610 }
611 for i in 0..node.outputs.len() {
612 wire_changes.push((InOut::Out(i), axis_op.clone()));
613 }
614 AxisChangeConsequence { wire_changes, substitute_op: op }
615 }
616}
617
618impl Op for AxisOp {
619 fn name(&self) -> StaticName {
620 match self {
621 Add(_) => "AddAxis".into(),
622 Rm(_) => "RmAxis".into(),
623 Move(_, _) => "MoveAxis".into(),
624 Reshape(_, _, _) => "Reshape".into(),
625 }
626 }
627
628 fn info(&self) -> TractResult<Vec<String>> {
629 match self {
630 Add(axis) | Rm(axis) => Ok(vec![format!("Axis: {axis}")]),
631 Move(from, to) => Ok(vec![format!("Axis {from} to {to}")]),
632 Reshape(at, from, to) => Ok(vec![format!(
633 "Axes starting at {}: {:?} to {:?}",
634 at,
635 from.iter().join(","),
636 to.iter().join(",")
637 )]),
638 }
639 }
640
641 op_as_typed_op!();
642}
643
644impl EvalOp for AxisOp {
645 fn is_stateless(&self) -> bool {
646 true
647 }
648
649 fn eval_with_session(
650 &self,
651 _node_id: usize,
652 session: &TurnState,
653 inputs: TVec<TValue>,
654 ) -> TractResult<TVec<TValue>> {
655 let mut input = args_1!(inputs).into_tensor();
656 match self {
657 AxisOp::Reshape(skip, from, to) => {
658 let from = from.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
659 let to = to.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
660 AxisOp::Reshape(*skip, from, to).change_tensor(&mut input, false)?
661 }
662 _ => self.change_tensor(&mut input, false)?,
663 }
664 Ok(tvec!(input.into_tvalue()))
665 }
666}
667
/// Rewrites the coordinate symbols of `expr` so that it stays meaningful after
/// `axis_op` has been applied.
///
/// Coordinate symbols are recognized by name: the prefix `🎯` followed by an
/// integer `k`. NOTE(review): presumably `k` is the index of the axis the
/// symbol stands for — confirm against the code creating these symbols.
///
/// Returns `None` when the expression can not be remapped, `Some` with the
/// rewritten (or untouched) expression otherwise.
fn remap_uniform_tdim(expr: &TDim, axis_op: &AxisOp) -> Option<TDim> {
    let syms = expr.symbols();
    // Keep only the coordinate symbols, along with their parsed index.
    let coord_syms: Vec<(usize, Symbol)> = syms
        .into_iter()
        .filter_map(|s| {
            let name = format!("{s}");
            name.strip_prefix("🎯").and_then(|rest| rest.parse::<usize>().ok()).map(|k| (k, s))
        })
        .collect();

    // Nothing depends on coordinates: the expression is valid as is.
    if coord_syms.is_empty() {
        return Some(expr.clone());
    }

    if let AxisOp::Reshape(at, from_dims, to_dims) = axis_op.canonical().as_ref() {
        // Reshaping unit axes into unit axes: the expression is kept as is.
        if from_dims.iter().all(|d| d.is_one()) && to_dims.iter().all(|d| d.is_one()) {
            return Some(expr.clone());
        }
        // Split: one axis becomes several.
        if from_dims.len() == 1 {
            let from_dim = from_dims[0].clone();
            let to_product: TDim = to_dims.iter().fold(TDim::Val(1), |acc, d| acc * d.clone());
            if to_product == from_dim {
                let k_to = to_dims.len();
                let mut map: HashMap<Symbol, TDim> = HashMap::default();
                for (k, sym) in &coord_syms {
                    let scope = sym.scope()?;
                    let new_expr = if *k < *at {
                        // Before the reshaped axis: unchanged.
                        TDim::Sym(sym.clone())
                    } else if *k == *at {
                        // The split axis: rebuilt as the sum of the new
                        // coordinates, each multiplied by the product of the
                        // dimensions that follow it.
                        let mut sum = TDim::Val(0);
                        let mut stride = TDim::Val(1);
                        for i in (0..k_to).rev() {
                            let new_sym = scope.coord_sym(*at + i);
                            sum = sum + TDim::Sym(new_sym) * stride.clone();
                            stride = stride * to_dims[i].clone();
                        }
                        sum
                    } else {
                        // After the reshaped axis: shifted by the number of added axes.
                        TDim::Sym(scope.coord_sym(*k + k_to - 1))
                    };
                    map.insert(sym.clone(), new_expr);
                }
                return expr.substitute_all(&map).ok().map(|e| e.reduce());
            }
        }
        // Merge: several axes become one.
        if to_dims.len() == 1 {
            let to_dim = to_dims[0].clone();
            let from_product: TDim = from_dims.iter().fold(TDim::Val(1), |acc, d| acc * d.clone());
            if from_product == to_dim {
                let k_from = from_dims.len();
                let mut map: HashMap<Symbol, TDim> = HashMap::default();
                for (k, sym) in &coord_syms {
                    let scope = sym.scope()?;
                    let new_expr = if *k < *at {
                        // Before the reshaped group: unchanged.
                        TDim::Sym(sym.clone())
                    } else if *k < *at + k_from {
                        // Inside the merged group: only supported when every
                        // other axis of the group has dimension one.
                        let i = *k - *at;
                        let only_nontrivial =
                            from_dims.iter().enumerate().all(|(j, d)| j == i || d.is_one());
                        if only_nontrivial {
                            TDim::Sym(scope.coord_sym(*at))
                        } else {
                            return None;
                        }
                    } else {
                        // After the reshaped group: shifted by the number of removed axes.
                        TDim::Sym(scope.coord_sym(*k - (k_from - 1)))
                    };
                    map.insert(sym.clone(), new_expr);
                }
                return expr.substitute_all(&map).ok().map(|e| e.reduce());
            }
        }
        // Any other reshape is not supported.
        return None;
    }

    // Add / Rm / Move: rename each coordinate symbol according to the new
    // position of its axis. Symbols whose axis has no image, or does not
    // move, are left untouched.
    let map: HashMap<Symbol, TDim> = coord_syms
        .into_iter()
        .filter_map(|(k, sym)| {
            let new_k = axis_op.transform_axis(k)?;
            if new_k == k {
                return None;
            }
            let scope = sym.scope()?;
            Some((sym, TDim::Sym(scope.coord_sym(new_k))))
        })
        .collect();
    if map.is_empty() {
        return Some(expr.clone());
    }
    expr.substitute_all(&map).ok()
}
776
777impl TypedOp for AxisOp {
778 as_op!();
779
780 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
781 if let Some(bqf) =
782 inputs[0].exotic_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
783 {
784 let mut new_shape: TVec<usize> = bqf.shape().into();
785 self.change_shape_array(&mut new_shape, false)?;
786 let new_bqf = BlockQuantFact::new(bqf.format.clone(), new_shape.clone());
787 let shape: TVec<TDim> = new_shape.iter().map(|d| d.to_dim()).collect();
788 let mut new_fact = inputs[0].datum_type.fact(&*shape).with_exotic_fact(new_bqf);
789 if let Some(k) = &inputs[0].konst {
790 let mut new = k.clone().into_tensor();
791 self.change_tensor(&mut new, false)?;
792 new_fact.konst = Some(new.into());
793 }
794 return Ok(tvec!(new_fact));
795 }
796 let mut shape = inputs[0].shape.clone();
797 self.change_shape(&mut shape, false)?;
798 let mut fact = inputs[0].datum_type.fact(shape);
799 fact.exotic_fact.clone_from(&inputs[0].exotic_fact);
800 if let Some(tdim) = &inputs[0].uniform_tdim {
801 fact.uniform_tdim = remap_uniform_tdim(tdim, self);
802 }
803 Ok(tvec!(fact))
804 }
805
/// Delegates to `bubble_roi` from the ROI propagation pass and returns its
/// result unchanged.
fn input_roi(
    &self,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TVec<Option<TDim>>>> {
    crate::optim::propagate_roi::bubble_roi(model, node)
}
813
814 fn axes_mapping(
815 &self,
816 inputs: &[&TypedFact],
817 outputs: &[&TypedFact],
818 ) -> TractResult<AxesMapping> {
819 let mut axes: Vec<Axis> = (0..inputs[0].rank())
820 .zip('a'..)
821 .map(|(axis_id, repr)| {
822 let mut axis = Axis::new(repr, inputs.len(), outputs.len()).input(0, axis_id);
823 if let Some(out) = self.transform_axis(axis_id) {
824 axis = axis.output(0, out);
825 }
826 axis
827 })
828 .collect();
829 for (axis, letter) in (0..outputs[0].rank()).zip('A'..) {
830 if self.recip().transform_axis(axis).is_none() {
831 axes.push(Axis::new(letter, inputs.len(), outputs.len()).output(0, axis));
832 }
833 }
834 AxesMapping::new(inputs.len(), outputs.len(), axes)
835 }
836
837 fn declutter(
838 &self,
839 model: &TypedModel,
840 node: &TypedNode,
841 ) -> TractResult<Option<TypedModelPatch>> {
842 if self.is_noop()
843 && let Some(p) = TypedModelPatch::shunt_one_op(model, node)?
844 {
845 return Ok(Some(p));
846 }
847 let simplified = self.simplify();
848 if simplified.len() != 1 || &simplified[0] != self {
849 let mut patch = TypedModelPatch::default();
850 let mut wire = patch.tap_model(model, node.inputs[0])?;
851 for (ix, op) in simplified.into_iter().enumerate() {
852 wire = patch.wire_node(format!("{}.{}", node.name, ix), op, &[wire])?[0];
853 }
854 patch.shunt_outside(model, node.id.into(), wire)?;
855 Ok(Some(patch))
856 } else {
857 Ok(None)
858 }
859 }
860
/// Suggests two axis changes: the reciprocal of this op on the output, and
/// this op itself on the input.
fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
    Ok(tvec!((InOut::Out(0), self.recip()), (InOut::In(0), self.clone())))
}
864
/// Reacts to an axis `change` applied on one of the wires of this node.
///
/// * A change on the output is handled by solving the problem for the
///   reciprocal op with the change on its input, then swapping input and
///   output in the result.
/// * A change equal to this op on the input replaces the node by `Identity`
///   without any wire change.
/// * Any other change on the input goes through
///   [`AxisOp::merge_incoming_change`].
///
/// NOTE(review): `rule_if_some!` presumably returns `Ok(None)` from this
/// function when its right-hand side is `None` — confirm against the macro
/// definition.
fn change_axes(
    &self,
    _model: &TypedModel,
    _node: &TypedNode,
    io: InOut,
    change: &AxisOp,
) -> TractResult<Option<AxisChangeConsequence>> {
    let op = if let InOut::Out(0) = io {
        rule_if_some!(more = self.recip().change_axes(_model, _node, InOut::In(0), change)?);
        AxisChangeConsequence {
            // The substitute was computed for the reciprocal op: flip it back.
            substitute_op: more.substitute_op.map(|op| {
                if let Some(op) = op.as_op().downcast_ref::<AxisOp>() {
                    Box::new(op.recip())
                } else {
                    op
                }
            }),
            // Input and output are swapped with respect to the reciprocal op.
            wire_changes: more
                .wire_changes
                .into_iter()
                .map(|wc| {
                    (if wc.0 == InOut::In(0) { InOut::Out(0) } else { InOut::In(0) }, wc.1)
                })
                .collect(),
        }
    } else if change == self {
        AxisChangeConsequence { substitute_op: Some(Box::new(Identity)), wire_changes: tvec!() }
    } else {
        rule_if_some!((new_op, new_change) = self.merge_incoming_change(change));
        trace!("  Change:{change:?} self:{self:?} -> change:{new_change:?} op:{new_op:?}");
        let substitute_op: Box<dyn TypedOp> =
            if let Some(o) = new_op { Box::new(o) as _ } else { Box::new(Identity) };
        let mut wire_changes = tvec!();
        if !change.is_noop() {
            wire_changes.push((InOut::In(0), change.clone()))
        }
        if let Some(new_change) = new_change {
            wire_changes.push((InOut::Out(0), new_change))
        }
        AxisChangeConsequence { substitute_op: Some(substitute_op), wire_changes }
    };
    Ok(Some(op))
}
908
909 fn set_symbols(
910 &self,
911 _source: &TypedModel,
912 node: &TypedNode,
913 target: &mut TypedModel,
914 mapping: &HashMap<OutletId, OutletId>,
915 subs: &HashMap<Symbol, TDim>,
916 ) -> TractResult<TVec<OutletId>> {
917 let op = if let AxisOp::Reshape(axis, from, to) = self {
918 AxisOp::Reshape(
919 *axis,
920 from.iter().map(|d| d.substitute_all(subs)).collect::<TractResult<_>>()?,
921 to.iter().map(|d| d.substitute_all(subs)).collect::<TractResult<_>>()?,
922 )
923 } else {
924 self.clone()
925 };
926 target.wire_node(&node.name, op, &[mapping[&node.inputs[0]]])
927 }
928
929 fn slice(
930 &self,
931 patch: &mut TypedModelPatch,
932 _model: &TypedModel,
933 node: &TypedNode,
934 _prefix: &str,
935 inputs: &[OutletId],
936 output_axis: usize,
937 _start: &TDim,
938 _end: &TDim,
939 ) -> TractResult<Option<TVec<OutletId>>> {
940 if let Reshape(pos, _from, to) = self
942 && output_axis >= *pos
943 && output_axis < pos + to.len()
944 {
945 return Ok(None);
946 }
947 patch.wire_node(&node.name, &node.op, inputs).map(Some)
948 }
949
950 fn codegen(
951 &self,
952 model: &TypedModel,
953 node: &TypedNode,
954 ) -> TractResult<Option<TypedModelPatch>> {
955 rule_if!(node.outputs[0].fact.exotic_fact.is_none());
956 if let Some(shape) = node.outputs[0].fact.shape.as_concrete()
957 && !matches!(self, AxisOp::Move(_, _))
958 {
959 let (inputs, outputs) = model.node_facts(node.id)?;
960 let mapping = self.axes_mapping(&inputs, &outputs)?;
961 let op = IntoShape {
962 mapping,
963 len: shape.iter().product(),
964 strides: Tensor::natural_strides(shape),
965 dims: shape.into(),
966 };
967 return Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, op)?));
968 }
969 Ok(None)
970 }
971}
972
973fn perm_to_cycles(perm: &[usize]) -> TVec<TVec<usize>> {
975 let mut cycles: TVec<TVec<usize>> = tvec!();
976 let mut done = 0;
977 while done < perm.len() {
978 if perm[done] == done || cycles.iter().any(|c| c.contains(&done)) {
979 done += 1;
980 continue;
981 }
982 let mut cycle = tvec!();
983 let mut current = done;
984 loop {
985 cycle.push(current);
986 current = perm[current];
987 if current == done {
988 break;
989 }
990 }
991 cycles.push(cycle)
992 }
993 cycles
994}
995
/// Recognizes a cycle that is a rotation of consecutive axes, and returns it
/// as a single `(from, to)` axis move.
///
/// * A strictly consecutive ascending cycle `[a, a+1, ..., b]` yields
///   `Some((a, b))`.
/// * A cycle `[a, b, b-1, ..., a+1]` yields `Some((b, a))`.
/// * Anything else, including an empty cycle, yields `None`.
///
/// Comparisons are written with additions only, so a cycle that does not
/// start at its smallest element can not trigger an integer underflow.
fn is_rotation_cycle(cycle: &[usize]) -> Option<(usize, usize)> {
    let (&first, rest) = cycle.split_first()?;
    let last = *cycle.last()?;
    if cycle.windows(2).all(|w| w[0] + 1 == w[1]) {
        return Some((first, last));
    }
    // The ascending test failed on at least one window, so `rest` is not empty.
    let descending = rest.windows(2).all(|w| w[0] == w[1] + 1);
    if descending && last == first + 1 {
        Some((rest[0], first))
    } else {
        None
    }
}
1007
1008fn perm_to_atoms(input: &[usize]) -> TVec<(usize, usize)> {
1009 let mut changes: TVec<(usize, usize)> = tvec!();
1010 'top: loop {
1011 let mut reached: TVec<usize> = (0..input.len()).collect();
1012 changes.iter().for_each(|(f, t)| {
1013 let axis = reached.remove(*f);
1014 reached.insert(*t, axis);
1015 });
1016 if &*reached == input {
1017 return changes;
1018 }
1019 let remaining: TVec<usize> =
1020 input.iter().map(|x| reached.iter().position(|y| y == x).unwrap()).collect();
1021 let cycles = perm_to_cycles(&remaining);
1022 for cycle in &cycles {
1023 if let Some(rot) = is_rotation_cycle(cycle) {
1024 changes.push(rot);
1025 continue 'top;
1026 }
1027 }
1028 changes.push((cycles[0][1], cycles[0][0]));
1029 }
1030}
1031
1032pub fn perm_to_ops(input: &[usize]) -> TVec<AxisOp> {
1033 perm_to_atoms(input).into_iter().map(|pair| AxisOp::Move(pair.0, pair.1)).collect()
1034}
1035
1036pub fn compute_shape_with_tf_rules(input: &[TDim], shape_spec: &[TDim]) -> TractResult<TVec<TDim>> {
1037 let mut shape: TVec<TDim> = shape_spec.into();
1038 for (i, s) in shape.iter_mut().enumerate() {
1040 if *s == 0.into() {
1041 *s = input
1042 .get(i)
1043 .with_context(|| {
1044 format!("Reshape: 0 at position {i} but input only has {} dims", input.len())
1045 })?
1046 .clone();
1047 }
1048 }
1049 let input_vol: TDim = input.iter().product();
1050 if let Some(pos) = shape.iter().position(|d| *d == (-1).into()) {
1051 let shape_vol: TDim = shape.iter().filter(|d| **d != (-1).into()).product();
1052 let div = input_vol.maybe_div(&shape_vol)?;
1053 if div.1 != 1 {
1054 bail!("invalid")
1055 }
1056 shape[pos] = div.0;
1057 } else {
1058 let shape_vol: TDim = shape.iter().product();
1059 if input_vol != shape_vol {
1060 bail!(
1061 "Reshape volume mismatch: input {input:?} (vol={input_vol}) vs shape {shape:?} (vol={shape_vol})"
1062 );
1063 }
1064 }
1065 Ok(shape)
1066}
1067
1068pub fn to_axis_ops_with_tf_rules(
1069 input_orig: &[TDim],
1070 output_spec: &[TDim],
1071) -> TractResult<TVec<AxisOp>> {
1072 let final_output = compute_shape_with_tf_rules(input_orig, output_spec)?;
1073 let mut stack: TVec<AxisOp> = tvec!();
1074 'top: loop {
1075 let current_input =
1076 stack.iter().try_fold(TVec::from(input_orig), |mut shape, op| -> TractResult<_> {
1077 op.change_shape_array(&mut shape, false)?;
1078 Ok(shape)
1079 })?;
1080 if current_input == final_output {
1081 return Ok(stack);
1082 }
1083 if let Some(common) =
1084 current_input.iter().zip(final_output.iter()).position(|(a, b)| a != b)
1085 {
1086 if current_input[common].is_one() {
1087 stack.push(AxisOp::Rm(common));
1088 } else if final_output[common].is_one() {
1089 stack.push(AxisOp::Add(common));
1090 } else {
1091 for i in common..current_input.len() {
1094 let i_group = ¤t_input[common..i + 1];
1095 let i_volume: TDim = i_group.iter().product();
1096 for o in common..final_output.len() {
1097 let o_group = &final_output[common..o + 1];
1098 let o_volume: TDim = o_group.iter().product();
1099 if i_volume == o_volume {
1100 stack.push(AxisOp::Reshape(common, i_group.into(), o_group.into()));
1101 continue 'top;
1102 }
1103 }
1104 }
1105 bail!(
1106 "Could not find matching reshape grouping: current_input={current_input:?} final_output={final_output:?} common={common}"
1107 )
1108 }
1109 } else if final_output.len() > current_input.len() {
1110 stack.push(AxisOp::Add(current_input.len()));
1111 } else {
1112 stack.push(AxisOp::Rm(current_input.len() - 1));
1113 }
1114 }
1115}
1116
/// Operator giving its input tensor a fixed, precomputed geometry.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct IntoShape {
    /// Mapping between the input axes and the output axes.
    pub mapping: AxesMapping,
    /// Expected number of elements of the input tensor.
    pub len: usize,
    /// Dimensions given to the output tensor.
    pub dims: TVec<usize>,
    /// Strides given to the output tensor.
    pub strides: TVec<isize>,
}
1124
1125impl Op for IntoShape {
1126 fn name(&self) -> StaticName {
1127 "IntoShape".into()
1128 }
1129
1130 fn info(&self) -> TractResult<Vec<String>> {
1131 Ok(vec![format!("{}", self.mapping)])
1132 }
1133
1134 op_as_typed_op!();
1135}
1136
impl EvalOp for IntoShape {
    /// The op is reported as stateless.
    fn is_stateless(&self) -> bool {
        true
    }

    /// Takes the single input as an owned tensor, checks its element count
    /// against `self.len`, then overwrites its geometry with `self.dims` and
    /// `self.strides` and returns it as the only output.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut input = args_1!(inputs).into_tensor();
        // The element count is the only property validated before the
        // unchecked call below.
        ensure!(input.len() == self.len);
        // NOTE(review): no safety justification is recorded for this call.
        // Presumably whoever builds the op guarantees that `dims` and
        // `strides` describe exactly `len` elements — confirm, since all the
        // fields of `IntoShape` are public.
        unsafe { input.set_geometry_unchecked(&self.dims, &self.strides) };
        Ok(tvec!(input.into_tvalue()))
    }
}
1149
1150impl TypedOp for IntoShape {
1151 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
1152 let mut fact = inputs[0].datum_type.fact(&self.dims);
1153 if let Some(of) = &inputs[0].exotic_fact {
1154 fact = fact.with_exotic_fact(of.clone());
1155 }
1156 Ok(tvec!(fact))
1157 }
1158
1159 fn declutter(
1160 &self,
1161 model: &TypedModel,
1162 node: &TypedNode,
1163 ) -> TractResult<Option<TypedModelPatch>> {
1164 let input = model.outlet_fact(node.inputs[0])?;
1165 if input.shape.as_concrete().is_some_and(|shape| shape == &*self.dims) {
1166 return TypedModelPatch::shunt_one_op(model, node);
1167 }
1168 if let Some(succ) = model.single_succ(node.id)?
1169 && let Some(into_shape) = succ.op_as::<IntoShape>()
1170 {
1171 let op =
1172 Self { mapping: self.mapping.compose(&into_shape.mapping)?, ..into_shape.clone() };
1173 return Ok(Some(TypedModelPatch::fuse_with_next(model, node, op)?));
1174 }
1175 Ok(None)
1176 }
1177
1178 as_op!();
1179}
1180
#[cfg(test)]
mod test {
    use super::*;

    /// Cycle decomposition of a few permutations. In the last case the
    /// positions left in place (1, 2 and 4) do not show up in the result.
    #[test]
    fn test_perm_to_cycles() {
        assert_eq!(perm_to_cycles(&[1, 2, 0]), tvec![tvec![0, 1, 2]]);
        assert_eq!(perm_to_cycles(&[2, 0, 1]), tvec![tvec![0, 2, 1]]);
        assert_eq!(perm_to_cycles(&[1, 2, 3, 0]), tvec![tvec![0, 1, 2, 3]]);
        assert_eq!(perm_to_cycles(&[3, 0, 1, 2]), tvec![tvec![0, 3, 2, 1]]);
        assert_eq!(perm_to_cycles(&[3, 1, 2, 0, 4]), tvec![tvec![0, 3]]);
    }

    /// A three-element cycle is recognised as a rotation in either direction.
    #[test]
    fn is_rotation() {
        let forward = is_rotation_cycle(&[0, 1, 2]);
        let backward = is_rotation_cycle(&[0, 2, 1]);
        assert_eq!(forward, Some((0, 2)));
        assert_eq!(backward, Some((2, 0)));
    }

    /// A permutation made of one rotation maps to a single atom.
    #[test]
    fn test_perm_one_rotation() {
        let atoms = perm_to_atoms(&[1, 2, 0, 3, 4]);
        assert_eq!(atoms, tvec![(0, 2)]);
    }

    /// Two disjoint rotations map to two atoms.
    #[test]
    fn test_perm_two_rotations() {
        let atoms = perm_to_atoms(&[1, 2, 0, 4, 3]);
        assert_eq!(atoms, tvec![(0, 2), (3, 4)]);
    }

    /// Swapping two non-adjacent positions is expected to take two atoms.
    #[test]
    fn test_perm_complex() {
        let atoms = perm_to_atoms(&[3, 1, 2, 0, 4]);
        assert_eq!(atoms, tvec![(3, 0), (1, 3)]);
    }

    // In every test below `change` is the incoming change handed to
    // `op.merge_incoming_change`, and `expected` is the pair it must return.

    // Add meeting Add.

    /// Incoming `Add(0)` against `Add(0)`.
    #[test]
    pub fn transform_op_add_0_add_0() {
        let (op, change) = (Add(0), Add(0));
        let expected = Some((Some(Add(0)), Some(Add(1))));
        assert_eq!(op.merge_incoming_change(&change), expected);
    }

    /// Incoming `Add(0)` against `Add(1)`.
    #[test]
    pub fn transform_op_add_0_add_1() {
        let (op, change) = (Add(1), Add(0));
        let expected = Some((Some(Add(2)), Some(Add(0))));
        assert_eq!(op.merge_incoming_change(&change), expected);
    }

    /// Incoming `Add(1)` against `Add(0)`.
    #[test]
    pub fn transform_op_add_1_add_0() {
        let (op, change) = (Add(0), Add(1));
        let expected = Some((Some(Add(0)), Some(Add(2))));
        assert_eq!(op.merge_incoming_change(&change), expected);
    }

    // Rm meeting Rm.

    /// Incoming `Rm(0)` against `Rm(1)`.
    #[test]
    pub fn transform_op_rm_0_rm_1() {
        let (op, change) = (Rm(1), Rm(0));
        let expected = Some((Some(Rm(0)), Some(Rm(0))));
        assert_eq!(op.merge_incoming_change(&change), expected);
    }

    /// Incoming `Rm(1)` against `Rm(0)`.
    #[test]
    pub fn transform_op_rm_1_rm_0() {
        let (op, change) = (Rm(0), Rm(1));
        let expected = Some((Some(Rm(0)), Some(Rm(0))));
        assert_eq!(op.merge_incoming_change(&change), expected);
    }

    // Add meeting Rm.

    /// Incoming `Add(0)` against `Rm(0)`.
    #[test]
    pub fn transform_op_add_0_rm_0() {
        let (op, change) = (Rm(0), Add(0));
        let expected = Some((Some(Rm(1)), Some(Add(0))));
        assert_eq!(op.merge_incoming_change(&change), expected);
    }

    /// Incoming `Add(0)` against `Rm(1)`.
    #[test]
    pub fn transform_op_add_0_rm_1() {
        let (op, change) = (Rm(1), Add(0));
        let expected = Some((Some(Rm(2)), Some(Add(0))));
        assert_eq!(op.merge_incoming_change(&change), expected);
    }

    /// Incoming `Add(1)` against `Rm(0)`.
    #[test]
    pub fn transform_op_add_1_rm_0() {
        let (op, change) = (Rm(0), Add(1));
        let expected = Some((Some(Rm(0)), Some(Add(0))));
        assert_eq!(op.merge_incoming_change(&change), expected);
    }

    // Rm meeting Add.

    /// Incoming `Rm(1)` against `Add(0)`.
    #[test]
    pub fn transform_op_rm_1_add_0() {
        let (op, change) = (Add(0), Rm(1));
        let expected = Some((Some(Add(0)), Some(Rm(2))));
        assert_eq!(op.merge_incoming_change(&change), expected);
    }

    /// Incoming `Rm(0)` against `Add(1)`.
    #[test]
    pub fn transform_op_rm_0_add_1() {
        let (op, change) = (Add(1), Rm(0));
        let expected = Some((Some(Add(0)), Some(Rm(0))));
        assert_eq!(op.merge_incoming_change(&change), expected);
    }

    // Move meeting Rm.

    /// Incoming `Move(0, 2)` against `Rm(2)`.
    #[test]
    pub fn transform_op_mv_02_rm_2() {
        let (op, change) = (Rm(2), Move(0, 2));
        let expected = Some((Some(Rm(1)), Some(Move(0, 1))));
        assert_eq!(op.merge_incoming_change(&change), expected);
    }
}
1342
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// A concrete input shape plus the chain of axis ops wired on top of it.
    #[derive(Debug)]
    struct ComposeProblem {
        /// Shape of the `i64` source tensor.
        input: TVec<usize>,
        /// Ops applied in order, each one fed by the output of the previous one.
        ops: TVec<AxisOp>,
    }

    impl Arbitrary for AxisOp {
        /// Shape of the tensor the generated op has to be applicable to.
        type Parameters = TVec<usize>;
        type Strategy = BoxedStrategy<AxisOp>;

        /// Builds a strategy over ops that are valid for a tensor of shape `shape`:
        ///
        /// * `Add` at any position from 0 to the rank, included;
        /// * `Move(a, b)` with `a != b`, when the rank is at least 2;
        /// * `Rm` of any axis whose dimension is 1;
        /// * `Reshape` merging two adjacent axes that are both larger than 1.
        fn arbitrary_with(shape: TVec<usize>) -> Self::Strategy {
            let rank = shape.len();
            let mut ops: BoxedStrategy<AxisOp> = (0usize..rank + 1).prop_map(Add).boxed();
            if rank > 1 {
                // `b` is drawn from one value less than `a`, then shifted past
                // `a`, so that source and destination never coincide.
                let moves = (0..rank, 0..rank - 1)
                    .prop_map(|(a, b)| Move(a, b + (b >= a) as usize))
                    .boxed();
                ops = ops.prop_union(moves).boxed();
            }
            let rms = (0..rank).filter(|&ax| shape[ax] == 1).map(Rm).collect::<Vec<_>>();
            if !rms.is_empty() {
                ops = ops
                    .prop_union((0..rms.len()).prop_map(move |rm| rms[rm].clone()).boxed())
                    .boxed();
            }
            let mergeable: Vec<AxisOp> = shape
                .windows(2)
                .enumerate()
                .filter(|(_, w)| w[0] > 1 && w[1] > 1)
                .map(|(ix, w)| {
                    Reshape(ix, tvec!(w[0].to_dim(), w[1].to_dim()), tvec!((w[0] * w[1]).to_dim()))
                })
                .collect();
            // A single candidate is enough to draw from. The previous guard
            // (`len() > 1`) silently dropped the reshape when exactly one pair
            // of axes was mergeable.
            if !mergeable.is_empty() {
                ops = ops
                    .prop_union(
                        (0..mergeable.len()).prop_map(move |ix| mergeable[ix].clone()).boxed(),
                    )
                    .boxed();
            }
            ops
        }
    }

    impl Arbitrary for ComposeProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<ComposeProblem>;

        /// Draws an input of rank 1 to 3 with dimensions in 1..4, then a chain
        /// of 1 to 5 ops, each generated for the shape left by the previous ones.
        fn arbitrary_with(_args: ()) -> Self::Strategy {
            /// Strategy for `remaining` ops applicable in sequence from `shape`.
            fn op_chain(remaining: usize, shape: TVec<usize>) -> BoxedStrategy<TVec<AxisOp>> {
                if remaining == 0 {
                    return Just(tvec!()).boxed();
                }
                AxisOp::arbitrary_with(shape.clone())
                    .prop_flat_map(move |head| {
                        let mut next_shape = shape.clone();
                        head.change_shape_array(&mut next_shape, false).unwrap();
                        op_chain(remaining - 1, next_shape).prop_map(move |mut rest| {
                            rest.insert(0, head.clone());
                            rest
                        })
                    })
                    .boxed()
            }

            let shapes = proptest::collection::vec(1usize..4, 1usize..4);
            (shapes, 1usize..=5)
                .prop_flat_map(|(input, len)| {
                    let ops = op_chain(len, input.clone().into());
                    (Just(input), ops)
                })
                .prop_map(|(input, ops)| ComposeProblem { input: input.into(), ops })
                .boxed()
        }
    }

    impl ComposeProblem {
        /// Builds a model with one `i64` source of shape `input`, followed by
        /// every op of `ops` wired in sequence; the last wire is the output.
        pub fn model(&self) -> TractResult<TypedModel> {
            let mut model = TypedModel::default();
            let mut wire = model.add_source("source", i64::fact(&self.input))?;
            for (ix, op) in self.ops.iter().enumerate() {
                let outputs = model.wire_node(format!("op_{ix}"), op.clone(), &[wire])?;
                wire = outputs[0];
            }
            model.select_output_outlets(&[wire])?;
            Ok(model)
        }

        /// Builds the input tensor: element number `i` holds the value `i`.
        fn input(&self) -> TractResult<Tensor> {
            // The tensor starts uninitialized; the loop below writes all of
            // the `t.len()` elements before the tensor is returned.
            unsafe {
                let mut t = Tensor::uninitialized::<i64>(&self.input)?;
                for i in 0..t.len() {
                    t.try_as_plain_mut().unwrap().as_slice_mut().unwrap()[i] = i as i64;
                }
                Ok(t)
            }
        }

        /// Runs the model as built, then its decluttered version, on the same
        /// input and compares the first output of each with `close_enough`.
        fn check(&self) -> TractResult<()> {
            crate::setup_test_logger();
            let input = self.input()?;
            let raw = self.model()?.into_runnable()?.run(tvec!(input.clone().into_tvalue()))?;
            let decluttered = self.model()?.into_decluttered()?;
            let opt = decluttered.into_runnable()?.run(tvec!(input.into_tvalue()))?;
            opt[0].close_enough(&raw[0], false)
        }
    }

    proptest! {
        // Taking the reciprocal twice must give back the op.
        // NOTE(review): `any::<AxisOp>()` goes through the default
        // `Parameters`, presumably an empty shape, for which the strategy
        // above can only yield `Add(0)` — confirm this is the intended coverage.
        #[test]
        fn recip(pb in any::<AxisOp>()) {
            assert_eq!(pb.recip().recip(), pb);
        }

        // Decluttering an arbitrary chain of axis ops must not change its output.
        #[test]
        fn axis_ops(pb in any::<ComposeProblem>()) {
            pb.check().unwrap()
        }
    }

    // Fixed cases for `ComposeProblem::check` and for `recip`.

    /// `Add(0)` then `Rm(0)` on shape `[1]`.
    #[test]
    fn add_0_rm_0() {
        ComposeProblem { input: tvec![1], ops: tvec![Add(0), Rm(0)] }.check().unwrap();
    }

    /// `Add(0)` then `Move(0, 1)` on shape `[2]`.
    #[test]
    fn add_0_move_01() {
        ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1)] }.check().unwrap();
    }

    /// `Add(0)`, `Move(0, 1)`, `Add(1)` on shape `[2]`.
    #[test]
    fn add_0_move_01_add_1() {
        ComposeProblem { input: tvec![2], ops: tvec![Add(0), Move(0, 1), Add(1)] }
            .check()
            .unwrap();
    }

    /// Double reciprocal of `Move(1, 0)`.
    #[test]
    fn recip_move_01() {
        let mv = Move(1, 0);
        let round_trip = mv.recip().recip();
        assert_eq!(round_trip, mv);
    }

    /// Double reciprocal of `Move(2, 0)`.
    #[test]
    fn recip_move_20() {
        let mv = Move(2, 0);
        let round_trip = mv.recip().recip();
        assert_eq!(round_trip, mv);
    }

    /// Double reciprocal of `Move(0, 2)`.
    #[test]
    fn recip_move_02() {
        let mv = Move(0, 2);
        let round_trip = mv.recip().recip();
        assert_eq!(round_trip, mv);
    }

    /// `Add(0)`, `Add(1)`, `Move(0, 2)` on shape `[2]`.
    #[test]
    fn add_0_add_1_move_02() {
        ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(1), Move(0, 2)] }
            .check()
            .unwrap();
    }

    /// `Add(0)` twice on shape `[1]`.
    #[test]
    fn add_0_add_0() {
        ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0)] }.check().unwrap();
    }

    /// `Add(0)`, `Add(0)`, `Move(0, 2)` on shape `[2]`.
    #[test]
    fn add_0_add_0_move_02() {
        ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(0), Move(0, 2)] }
            .check()
            .unwrap();
    }

    /// `Add(0)`, `Add(2)`, `Move(1, 2)` on shape `[2]`.
    #[test]
    fn add_0_add_2_move_12() {
        ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(2), Move(1, 2)] }
            .check()
            .unwrap();
    }

    /// `Add(0)`, `Add(0)`, `Move(0, 2)`, `Rm(0)` on shape `[1]`.
    #[test]
    fn add_0_add_0_move_02_rm_0() {
        ComposeProblem { input: tvec![1], ops: tvec![Add(0), Add(0), Move(0, 2), Rm(0)] }
            .check()
            .unwrap();
    }

    /// `Add(0)`, `Add(0)`, then `Move(2, 0)` twice on shape `[2]`.
    #[test]
    fn add_0_add_0_move_20_move_20() {
        ComposeProblem { input: tvec![2], ops: tvec![Add(0), Add(0), Move(2, 0), Move(2, 0)] }
            .check()
            .unwrap();
    }

    /// `Move(0, 1)` then `Add(0)` on shape `[1, 1]`.
    #[test]
    fn move_01_add_0() {
        ComposeProblem { input: tvec![1, 1], ops: tvec![Move(0, 1), Add(0)] }.check().unwrap();
    }

    /// `Add(0)`, then `Move(0, 2)` twice on shape `[1, 1]`.
    #[test]
    fn add_0_move_02_move_02() {
        ComposeProblem { input: tvec![1, 1], ops: tvec![Add(0), Move(0, 2), Move(0, 2)] }
            .check()
            .unwrap();
    }

    /// `Add(0)`, `Add(2)`, `Move(2, 0)`, `Move(1, 2)`, `Rm(2)` on shape `[3]`.
    #[test]
    fn add_0_add_2_move_20_move_12_rm_2() {
        ComposeProblem {
            input: tvec![3],
            ops: tvec![Add(0), Add(2), Move(2, 0), Move(1, 2), Rm(2)],
        }
        .check()
        .unwrap();
    }

    /// `Move(0, 2)` twice on shape `[2, 1, 1]`.
    #[test]
    fn move_02_move_02() {
        ComposeProblem { input: tvec![2, 1, 1], ops: tvec![Move(0, 2), Move(0, 2)] }
            .check()
            .unwrap();
    }

    /// `Rm(1)`, `Move(0, 1)`, `Add(0)` on shape `[1, 1, 2]`.
    #[test]
    fn rm_1_perm_10_add_0() {
        ComposeProblem { input: tvec![1, 1, 2], ops: tvec![Rm(1), Move(0, 1), Add(0)] }
            .check()
            .unwrap();
    }

    /// `Add(2)`, then `Move(0, 2)` twice on shape `[3, 2]`.
    #[test]
    fn add_2_move_02_move_02() {
        ComposeProblem { input: tvec![3, 2], ops: tvec![Add(2), Move(0, 2), Move(0, 2)] }
            .check()
            .unwrap();
    }

    /// `Move(0, 1)`, then `Move(2, 0)` twice on shape `[2, 3, 2]`.
    #[test]
    fn move_01_move_20_move_20() {
        ComposeProblem {
            input: tvec![2, 3, 2],
            ops: tvec![Move(0, 1), Move(2, 0), Move(2, 0)],
        }
        .check()
        .unwrap();
    }

    /// A single reshape merging the first two axes of `[2, 2, 2]`.
    #[test]
    fn reshape_axes_tracking() {
        ComposeProblem {
            input: tvec![2, 2, 2],
            ops: tvec![Reshape(0, tvec!(2.to_dim(), 2.to_dim()), tvec!(4.to_dim()))],
        }
        .check()
        .unwrap();
    }

    /// Expected results of `AxisOp::simplify` on a selection of reshapes.
    #[test]
    fn simplify_reshape() {
        // Builds a `TVec<TDim>` out of integer literals.
        macro_rules! d {
            ($($dim: expr),*) => { tvec!($($dim.to_dim()),*) }
        }
        // Reshapes that change nothing simplify to no op at all.
        assert_eq!(Reshape(3, d!(), d!()).simplify(), tvec!());
        assert_eq!(Reshape(3, d!(2, 3), d!(2, 3)).simplify(), tvec!());
        // Adding or removing a lone unit axis becomes a plain `Add` or `Rm`.
        assert_eq!(Reshape(3, d!(1), d!()).simplify(), tvec!(Rm(3)));
        assert_eq!(Reshape(3, d!(), d!(1)).simplify(), tvec!(Add(3)));
        // Axes shared at the start or at the end are trimmed from the reshape.
        assert_eq!(
            Reshape(3, d!(2, 3, 4), d!(2, 4, 3)).simplify(),
            tvec!(Reshape(4, d!(3, 4), d!(4, 3)))
        );
        assert_eq!(
            Reshape(3, d!(3, 4, 2), d!(4, 3, 2)).simplify(),
            tvec!(Reshape(3, d!(3, 4), d!(4, 3)))
        );
        // Unit axes at either end are split out as separate `Rm` or `Add` ops.
        assert_eq!(
            Reshape(3, d!(1, 2, 3), d!(3, 2)).simplify(),
            tvec!(Rm(3), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(3, d!(2, 3), d!(1, 3, 2)).simplify(),
            tvec!(Reshape(3, d!(2, 3), d!(3, 2)), Add(3))
        );
        assert_eq!(
            Reshape(3, d!(2, 3, 1), d!(3, 2)).simplify(),
            tvec!(Rm(5), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(3, d!(2, 3), d!(3, 2, 1)).simplify(),
            tvec!(Add(5), Reshape(3, d!(2, 3), d!(3, 2)))
        );
        assert_eq!(
            Reshape(2, d!(2, 2, 1), d!(4)).simplify(),
            tvec!(Rm(4), Reshape(2, d!(2, 2), d!(4)))
        );
        assert_eq!(Reshape(1, d!(1, 2), d!(2)).simplify(), tvec!(Rm(1)));
    }

    // Builds a shape slice, converting every item with `.clone().into()`.
    macro_rules! s {
        ($($a:expr),*) => {&[ $($a.clone().into()),* ]}
    }

    // Builds `AxisOp::Reshape(at, from, to)` from `at ; from... => to...`.
    macro_rules! r {
        ($at: expr ; $($from:expr),* => $($to:expr),*) => {
            AxisOp::Reshape($at, tvec!($($from.into()),*), tvec!($($to.into()),*))
        }
    }

    /// A target whose volume differs from the input's is rejected.
    #[test]
    fn compute_invalid() {
        let resolved = compute_shape_with_tf_rules(s![3, 4, 5], s!(100));
        assert!(resolved.is_err());
    }

    /// Leading `0` entries resolve to the input dimension at the same position.
    #[test]
    fn compute_with_leading_zero() {
        let resolved = compute_shape_with_tf_rules(s![3, 4, 5], s!(0, 0, 5)).unwrap();
        assert_eq!(&*resolved, s![3, 4, 5])
    }

    /// A `0` entry combined with a flattening of the two trailing axes.
    #[test]
    fn compute_with_leading_zero_with_flatten() {
        let resolved = compute_shape_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap();
        assert_eq!(&*resolved, s![2, 3, 35])
    }

    /// A `-1` entry followed by a trailing `0`.
    #[test]
    fn compute_with_trailing_zero() {
        let resolved = compute_shape_with_tf_rules(s![3, 4, 5], s!(3, -1, 0)).unwrap();
        assert_eq!(&*resolved, s![3, 4, 5])
    }

    /// A symbolic leading dimension, with `-1` absorbing the two trailing axes.
    #[test]
    fn compute_bug_1() {
        let table = SymbolScope::default();
        let s = table.new_with_prefix("S");
        let resolved = compute_shape_with_tf_rules(s![s, 1, 2, 128], s!(0, 0, -1)).unwrap();
        assert_eq!(&*resolved, s![s, 1, 256])
    }

    /// Same as `compute_bug_1`, with two symbolic leading dimensions.
    #[test]
    fn compute_bug_2() {
        let table = SymbolScope::default();
        let b = table.new_with_prefix("B");
        let s = table.new_with_prefix("S");
        let resolved = compute_shape_with_tf_rules(s![s, b, 2, 128], s!(0, 0, -1)).unwrap();
        assert_eq!(&*resolved, s![s, b, 256])
    }

    /// `0` entries used with a target of higher rank than the input.
    #[test]
    fn compute_zero_with_rank_change() {
        let resolved =
            compute_shape_with_tf_rules(s![1, 52, 8, 32], s!(0, 0, 8, 16, 2)).unwrap();
        assert_eq!(&*resolved, s![1, 52, 8, 16, 2])
    }

    /// Dropping a leading unit axis is a single `Rm(0)`.
    #[test]
    fn axis_op_rm_begin() {
        let ops = to_axis_ops_with_tf_rules(s![1, 2, 3], s!(2, 3)).unwrap();
        assert_eq!(&*ops, &[Rm(0)])
    }

    /// Dropping a trailing unit axis is a single `Rm(2)`.
    #[test]
    fn axis_op_rm_end() {
        let ops = to_axis_ops_with_tf_rules(s![2, 3, 1], s!(2, 3)).unwrap();
        assert_eq!(&*ops, &[Rm(2)])
    }

    /// Inserting a leading unit axis is a single `Add(0)`.
    #[test]
    fn axis_op_insert_begin() {
        let ops = to_axis_ops_with_tf_rules(s![2, 3], s!(1, 2, 3)).unwrap();
        assert_eq!(&*ops, &[Add(0)])
    }

    /// Inserting a trailing unit axis is a single `Add(2)`.
    #[test]
    fn axis_op_insert_end() {
        let ops = to_axis_ops_with_tf_rules(s![2, 3], s!(2, 3, 1)).unwrap();
        assert_eq!(&*ops, &[Add(2)])
    }

    /// Flattening the two trailing axes is a single reshape at axis 2.
    #[test]
    fn axis_op_merge() {
        let ops = to_axis_ops_with_tf_rules(s![2, 3, 5, 7], s!(2, 0, 35)).unwrap();
        assert_eq!(&*ops, &[r!(2 ; 5,7 => 35 )])
    }

    /// Axis removal, insertions and one merge combined in a single reshape.
    #[test]
    fn axis_op_complex() {
        let ops = to_axis_ops_with_tf_rules(s![1, 2, 3, 5, 7], s!(2, 1, 3, 35, 1)).unwrap();
        assert_eq!(&*ops, &[Rm(0), Add(1), r!(3 ; 5,7 => 35 ), Add(4)])
    }
}