1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
use crate::internal::*;

#[derive(Debug, Clone, PartialEq, Eq, Hash, new)]
pub struct DynSlice {
    pub axis: usize,
    pub len: TDim,
}

impl DynSlice {
    pub fn suffix(&self) -> String {
        format!("axis{}", self.axis)
    }
}

impl Op for DynSlice {
    fn name(&self) -> Cow<str> {
        "DynSlice".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!("axis: {}", self.axis)])
    }

    op_as_typed_op!();

    fn same_as(&self, other: &dyn Op) -> bool {
        if let Some(other) = other.downcast_ref::<Self>() {
            other == self
        } else {
            false
        }
    }
}

impl EvalOp for DynSlice {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval_with_session(
        &self,
        session: &SessionState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let start = inputs[1]
            .cast_to::<TDim>()?
            .to_scalar::<TDim>()?
            .eval(&session.resolved_symbols)
            .to_usize()?;
        let end = inputs[2]
            .cast_to::<TDim>()?
            .to_scalar::<TDim>()?
            .eval(&session.resolved_symbols)
            .to_usize()?;
        ensure!(start <= end);
        if let Ok(len) = self.len.eval(&session.resolved_symbols).to_usize() {
            ensure!(start + len == end);
        }
        let slice = inputs[0].slice(self.axis, start, end)?;
        Ok(tvec!(slice.into()))
    }
}

impl TypedOp for DynSlice {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(inputs.len() == 3);
        let mut fact = inputs[0].without_value();
        fact.shape.set(self.axis, self.len.clone());
        Ok(tvec!(fact))
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        _outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        AxesMapping::natural_for_rank(1, 1, inputs[0].rank())?
            .with_extra_input(1)?
            .with_extra_input(2)
    }

    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        if io == InOut::In(1) || io == InOut::In(2) {
            return Ok(None);
        }
        if let Some(axis) = change.transform_axis(self.axis) {
            if axis != self.axis {
                Ok(Some(AxisChangeConsequence::new(
                    model,
                    node,
                    Some(Box::new(DynSlice { axis, ..self.clone() }) as _),
                    change,
                )))
            } else {
                Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
            }
        } else {
            Ok(None)
        }
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let inputs = model.node_input_facts(node.id)?;
        if let (Some(start), Some(end)) = (&inputs[1].konst, &inputs[2].konst) {
            let start = start.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone();
            let end = end.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone();

            return Ok(Some(TypedModelPatch::replace_single_op(
                model,
                node,
                &[node.inputs[0]],
                crate::ops::array::Slice { axis: self.axis, start, end },
            )?));
        }
        Ok(None)
    }

    as_op!();
}