tract_core/ops/memory/
load.rs1use crate::internal::*;
2
/// Stateful operator that yields the tensor stored in the session under `id`,
/// or its single input (the default value) when the session holds nothing
/// under that key.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Load {
    /// Key looked up in the session's tensor map.
    pub id: String,
}

impl Load {
    /// Builds a `Load` op that reads the session slot named `id`.
    pub fn new(id: &str) -> Load {
        let owned_id = id.to_owned();
        Self { id: owned_id }
    }
}
12}
13
14impl Op for Load {
15 fn name(&self) -> StaticName {
16 "Load".into()
17 }
18
19 fn info(&self) -> TractResult<Vec<String>> {
20 Ok(vec![format!("id: {:?}", self.id)])
21 }
22
23 impl_op_same_as!();
24 op_as_typed_op!();
25}
26
27impl EvalOp for Load {
28 fn is_stateless(&self) -> bool {
29 false
30 }
31
32 fn state(
33 &self,
34 _session: &mut SessionState,
35 _node_id: usize,
36 ) -> TractResult<Option<Box<dyn OpState>>> {
37 Ok(Some(Box::new(self.clone())))
38 }
39}
40
41impl TypedOp for Load {
42 as_op!();
43
44 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
45 ensure!(inputs.len() == 1, "Expected one input (default value) for Load op");
46 let input_facts = inputs
48 .iter()
49 .map(|it| TypedFact::dt_shape(it.datum_type, it.shape.clone()))
50 .collect::<TVec<_>>();
51 Ok(input_facts)
52 }
53}
54
55impl OpState for Load {
56 fn eval(
57 &mut self,
58 session: &mut SessionState,
59 _op: &dyn Op,
60 inputs: TVec<TValue>,
61 ) -> TractResult<TVec<TValue>> {
62 let input = args_1!(inputs);
63 let tensor = session
64 .tensors
65 .get(&self.id)
66 .map_or_else(
67 || -> TractResult<TVec<TValue>> { Ok(tvec!(input.clone())) },
68 |it| {
69 ensure!(
71 it.datum_type() == input.datum_type(),
72 anyhow!(
73 "Expected datum {:?}, found {:?}",
74 input.datum_type(),
75 it.datum_type()
76 )
77 );
78 ensure!(
79 it.shape() == input.shape(),
80 anyhow!("Expected shape {:?}, found {:?}", input.shape(), it.shape())
81 );
82 Ok(tvec!(it.clone().into_tvalue()))
83 },
84 )
85 .with_context(|| anyhow!("While loading tensor from session"))?;
86
87 Ok(tensor)
88 }
89}
90
91trivial_op_state_freeeze!(Load);