// Source file: tract_core/ops/identity.rs

1use crate::internal::*;
2
/// Pass-through operator: evaluation hands back its inputs untouched, and its
/// `declutter` / `fuse` hooks shunt the node out of the graph.
#[derive(Clone, Debug, Default)]
#[derive(PartialEq, Eq, Hash)]
pub struct Identity;
5
6impl Op for Identity {
7    fn name(&self) -> StaticName {
8        "Identity".into()
9    }
10
11    op_as_typed_op!();
12}
13
14impl EvalOp for Identity {
15    fn is_stateless(&self) -> bool {
16        true
17    }
18
19    /// Evaluates the operation given the input tensors.
20    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
21        Ok(inputs)
22    }
23}
24
25impl TypedOp for Identity {
26    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
27        Ok(tvec!(inputs[0].clone()))
28    }
29
30    fn input_roi(
31        &self,
32        model: &TypedModel,
33        node: &TypedNode,
34    ) -> TractResult<Option<TVec<Option<TDim>>>> {
35        crate::optim::propagate_roi::bubble_roi(model, node)
36    }
37
38    fn declutter(
39        &self,
40        model: &TypedModel,
41        node: &TypedNode,
42    ) -> TractResult<Option<TypedModelPatch>> {
43        TypedModelPatch::shunt_one_op(model, node)
44    }
45
46    fn fuse(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
47        TypedModelPatch::shunt_one_op(model, node)
48    }
49
50    fn axes_mapping(
51        &self,
52        inputs: &[&TypedFact],
53        outputs: &[&TypedFact],
54    ) -> TractResult<AxesMapping> {
55        AxesMapping::natural(inputs, outputs)
56    }
57
58    as_op!();
59}
60
/// Pass-through operator whose output fact has its known value stripped, and
/// which reports itself as stateful during evaluation.
#[derive(Clone, Debug, Default)]
#[derive(PartialEq, Eq, Hash)]
pub struct PinConst;
63
64impl Op for PinConst {
65    fn name(&self) -> StaticName {
66        "PinConst".into()
67    }
68
69    op_as_typed_op!();
70}
71
72impl EvalOp for PinConst {
73    fn is_stateless(&self) -> bool {
74        false
75    }
76
77    /// Evaluates the operation given the input tensors.
78    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
79        Ok(inputs)
80    }
81
82    fn state(
83        &self,
84        _session: &TurnState,
85        _node_id: usize,
86    ) -> TractResult<Option<Box<dyn OpState>>> {
87        Ok(Some(Box::new(self.clone())))
88    }
89}
90
91impl OpState for PinConst {
92    fn eval(
93        &mut self,
94        _session: &mut TurnState,
95        _op: &dyn Op,
96        inputs: TVec<TValue>,
97    ) -> TractResult<TVec<TValue>> {
98        Ok(inputs)
99    }
100}
101
102impl TypedOp for PinConst {
103    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
104        Ok(tvec!(inputs[0].without_value()))
105    }
106
107    as_op!();
108}
109
110trivial_op_state_freeze!(PinConst);