tract_hir/ops/array/
split.rs1use crate::infer::*;
2use crate::internal::*;
3
4#[derive(Debug, Clone, new, Default, Hash)]
5pub struct Split {
6 axis: isize,
7 outputs: usize,
8 split: Option<Vec<usize>>,
9}
10
11
12
13impl Split {
14 fn split_dims<D: DimLike>(&self, input: &D) -> TractResult<TVec<D>> {
15 if let Some(split) = self.split.as_ref() {
16 Ok(split.iter().map(|&d| D::from(d)).collect())
17 } else {
18 let bigs = input.clone().divceil(self.outputs);
19 let last = input.clone() - (bigs.clone() * (self.outputs - 1));
20 let mut splits = tvec!(bigs ; self.outputs - 1);
21 splits.push(last);
22 Ok(splits)
23 }
24 }
25}
26
27impl Expansion for Split {
28 fn name(&self) -> StaticName {
29 "Split".into()
30 }
31
32 fn rules<'r, 'p: 'r, 's: 'r>(
33 &'s self,
34 s: &mut Solver<'r>,
35 inputs: &'p [TensorProxy],
36 outputs: &'p [TensorProxy],
37 ) -> InferenceResult {
38 check_input_arity(inputs, 1)?;
39 check_output_arity(outputs, self.outputs)?;
40 (0..self.outputs).try_for_each(|i| {
41 s.equals(&inputs[0].datum_type, &outputs[i].datum_type)?;
42 s.equals(&inputs[0].rank, &outputs[i].rank)
43 })?;
44 s.given(&inputs[0].shape, move |s, shape| {
45 let axis =
46 if self.axis < 0 { self.axis + shape.len() as isize } else { self.axis } as usize;
47 let dims = self.split_dims(&shape[axis])?;
48 for i in 0..self.outputs {
49 let mut shape = shape.clone();
50 shape[axis] = dims[i].clone();
51 s.equals(&outputs[i].shape, shape)?;
52 }
53 Ok(())
54 })?;
55 Ok(())
56 }
57
58 fn nboutputs(&self) -> TractResult<usize> {
59 Ok(self.outputs)
60 }
61
62 fn wire(
63 &self,
64 prefix: &str,
65 target: &mut TypedModel,
66 inputs: &[OutletId],
67 ) -> TractResult<TVec<OutletId>> {
68 let input = target.outlet_fact(inputs[0])?.clone();
69 let mut outputs = tvec!();
70 let mut current = 0.to_dim();
71 let axis =
72 if self.axis < 0 { self.axis + input.rank() as isize } else { self.axis } as usize;
73 for (ix, len) in self.split_dims(&input.shape[axis])?.into_iter().enumerate() {
74 let end = current.clone() + len;
75 outputs.push(
76 target.wire_node(
77 format!("{prefix}.axis{axis}_slice{ix}_{current}..{end}"),
78 crate::ops::array::Slice::new(axis, current, end.clone()),
79 inputs,
80 )?[0],
81 );
82 current = end;
83 }
84 Ok(outputs)
85 }
86}