// tract_core/ops/memory/store.rs

1use crate::internal::*;
2
/// Memory-write operator.
///
/// At evaluation time the second input is recorded in the session's named
/// tensor storage under `id`. The first input is forwarded unchanged as the
/// op's only output.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Store {
    /// Name of the slot in the session's tensor storage that this op writes to.
    pub id: String,
}

impl Store {
    /// Builds a `Store` op targeting the slot named `id`.
    ///
    /// The name is copied into an owned `String`.
    pub fn new(id: &str) -> Store {
        let id = String::from(id);
        Self { id }
    }
}
13
14impl Op for Store {
15    fn name(&self) -> StaticName {
16        "Store".into()
17    }
18
19    fn info(&self) -> TractResult<Vec<String>> {
20        Ok(vec![format!("id: {:?}", self.id)])
21    }
22
23    impl_op_same_as!();
24    op_as_typed_op!();
25}
26
27impl EvalOp for Store {
28    fn is_stateless(&self) -> bool {
29        false
30    }
31
32    fn state(
33        &self,
34        _session: &mut SessionState,
35        _node_id: usize,
36    ) -> TractResult<Option<Box<dyn OpState>>> {
37        Ok(Some(Box::new(self.clone())))
38    }
39}
40
41impl TypedOp for Store {
42    as_op!();
43
44    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
45        ensure!(
46            inputs.len() == 2,
47            "Expected two inputs (input to propagate and state to store) for Store op"
48        );
49        Ok(tvec![inputs[0].clone()])
50    }
51}
52
53impl OpState for Store {
54    fn eval(
55        &mut self,
56        session: &mut SessionState,
57        _op: &dyn Op,
58        inputs: TVec<TValue>,
59    ) -> TractResult<TVec<TValue>> {
60        let (input, state) = args_2!(inputs);
61        session.tensors.insert(self.id.clone(), state.into_tensor());
62        Ok(tvec![input])
63    }
64}
65
66trivial_op_state_freeeze!(Store);