tract_tensorflow/ops/
vars.rs

1use tract_hir::internal::*;
2use tract_hir::tract_core::trivial_op_state_freeeze;
3
4use crate::model::{ParsingContext, TfOpRegister};
5use crate::tfpb::tensorflow::NodeDef;
6
7pub fn register_all_ops(reg: &mut TfOpRegister) {
8    reg.insert("Assign", |_, _| Ok(Box::<Assign>::default()));
9    reg.insert("VariableV2", variable_v2);
10}
11
12fn variable_v2(_ctx: &ParsingContext, node: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
13    let shared_name = node.get_attr_str("shared_name")?;
14    let shared_name = if !shared_name.is_empty() { Some(shared_name) } else { None };
15    let container = node.get_attr_str("container")?;
16    let container = if !container.is_empty() { Some(container) } else { None };
17    let name = node.name.to_string();
18    let id = format!("{container:?}#{shared_name:?}#{name}");
19    let shape = node.get_attr_shape("shape")?;
20    let dt = node.get_attr_datum_type("dtype")?;
21    let shape = shape
22        .into_iter()
23        .map(|d| {
24            if d > 0 {
25                Ok(d as usize)
26            } else {
27                bail!("VariableV2 shape contains forbidden negative dim.")
28            }
29        })
30        .collect::<TractResult<TVec<usize>>>()?;
31    Ok(Box::new(VariableV2::new(container, shared_name, name, id, shape, dt, None)))
32}
33
34#[derive(Clone, Debug, new)]
35struct VariableV2State;
36trivial_op_state_freeeze!(VariableV2State);
37
38impl OpState for VariableV2State {
39    fn eval(
40        &mut self,
41        session: &mut SessionState,
42        op: &dyn Op,
43        _inputs: TVec<TValue>,
44    ) -> TractResult<TVec<TValue>> {
45        let op = op.downcast_ref::<VariableV2>().context("wrong op for variable state")?;
46        let tensor = session
47            .tensors
48            .get(&op.id)
49            .with_context(|| format!("Could not find state for variable {}", op.id))?;
50        Ok(tvec!(tensor.clone().into()))
51    }
52}
53
54#[derive(Clone, Debug, new, Hash)]
55pub struct VariableV2 {
56    container: Option<String>,
57    shared_name: Option<String>,
58    name: String,
59    pub id: String,
60    shape: TVec<usize>,
61    dt: DatumType,
62    pub initializer: Option<Arc<Tensor>>,
63}
64
65
66
67impl Op for VariableV2 {
68    fn name(&self) -> StaticName {
69        "VariableV2".into()
70    }
71
72    fn info(&self) -> TractResult<Vec<String>> {
73        if let Some(init) = &self.initializer {
74            Ok(vec![format!("Initialized to {init:?}")])
75        } else {
76            Ok(vec![format!("Uninitialized")])
77        }
78    }
79
80    op_as_typed_op!();
81}
82
83impl EvalOp for VariableV2 {
84    fn is_stateless(&self) -> bool {
85        false
86    }
87
88    fn state(
89        &self,
90        state: &mut SessionState,
91        _node_id: usize,
92    ) -> TractResult<Option<Box<dyn OpState>>> {
93        let tensor = if let Some(init) = &self.initializer {
94            init.clone().into_tensor()
95        } else {
96            unsafe { Tensor::uninitialized_dt(self.dt, &self.shape)? }
97        };
98        state.tensors.insert(self.id.clone(), tensor);
99        Ok(Some(Box::new(VariableV2State)))
100    }
101}
102
103impl InferenceRulesOp for VariableV2 {
104    fn rules<'r, 'p: 'r, 's: 'r>(
105        &'s self,
106        s: &mut Solver<'r>,
107        inputs: &'p [TensorProxy],
108        outputs: &'p [TensorProxy],
109    ) -> InferenceResult {
110        check_input_arity(inputs, 0)?;
111        check_output_arity(outputs, 1)?;
112        s.equals(&outputs[0].datum_type, self.dt)?;
113        s.equals(&outputs[0].shape, ShapeFactoid::from(&*self.shape))?;
114        Ok(())
115    }
116
117    as_op!();
118    to_typed!();
119}
120
121impl TypedOp for VariableV2 {
122    as_op!();
123
124    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
125        Ok(tvec!(self.dt.fact(&self.shape)))
126    }
127}
128
129// need some dummy state to make sure Assign is a EvalOp, and will not be
130// eval-ed() in Stateless context
131#[derive(Clone, Debug, new)]
132struct AssignState;
133trivial_op_state_freeeze!(AssignState);
134
135#[derive(Clone, Debug, new, Default, Hash)]
136pub struct Assign {
137    pub var_id: Option<String>,
138}
139
140
141
142impl Op for Assign {
143    fn name(&self) -> StaticName {
144        "Assign".into()
145    }
146
147    op_as_typed_op!();
148}
149
150impl OpState for AssignState {
151    fn eval(
152        &mut self,
153        session: &mut SessionState,
154        op: &dyn Op,
155        inputs: TVec<TValue>,
156    ) -> TractResult<TVec<TValue>> {
157        let (_current, new) = args_2!(inputs);
158        let op = op.downcast_ref::<Assign>().context("wrong op for variable state")?;
159        let var_id = if let Some(ref var_id) = op.var_id {
160            var_id
161        } else {
162            bail!("Assign has not been linked to var")
163        };
164        let store = session.tensors.get_mut(var_id).unwrap();
165        if cfg!(debug_assertions)
166            && (store.shape() != new.shape() && store.datum_type() != new.datum_type())
167        {
168            bail!(
169                "Invalid assignment to variable. Store is {:?}, assigned value is {:?}",
170                store,
171                new
172            );
173        }
174        *store = new.clone().into_tensor();
175        Ok(tvec!(new))
176    }
177}
178
179impl EvalOp for Assign {
180    fn is_stateless(&self) -> bool {
181        false
182    }
183
184    fn state(
185        &self,
186        _state: &mut SessionState,
187        _node_id: usize,
188    ) -> TractResult<Option<Box<dyn OpState>>> {
189        Ok(Some(Box::new(AssignState)))
190    }
191}
192
193impl InferenceRulesOp for Assign {
194    fn rules<'r, 'p: 'r, 's: 'r>(
195        &'s self,
196        s: &mut Solver<'r>,
197        inputs: &'p [TensorProxy],
198        outputs: &'p [TensorProxy],
199    ) -> InferenceResult {
200        check_input_arity(inputs, 2)?;
201        check_output_arity(outputs, 1)?;
202        s.equals(&inputs[0].datum_type, &inputs[1].datum_type)?;
203        s.equals(&inputs[0].shape, &inputs[1].shape)?;
204        s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
205        s.equals(&inputs[1].shape, &inputs[0].shape)?;
206        s.equals(&outputs[0].shape, &inputs[0].shape)?;
207        s.equals(&outputs[0].value, &inputs[1].value)?;
208        Ok(())
209    }
210
211    as_op!();
212    to_typed!();
213}
214
215impl TypedOp for Assign {
216    as_op!();
217
218    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
219        if inputs[0].datum_type != inputs[1].datum_type || inputs[0].shape != inputs[1].shape {
220            bail!("Invalid assignement {:?}", inputs);
221        }
222        Ok(tvec!(inputs[0].clone()))
223    }
224}