tract_tensorflow/ops/
vars.rs1use tract_hir::internal::*;
2use tract_hir::tract_core::trivial_op_state_freeeze;
3
4use crate::model::{ParsingContext, TfOpRegister};
5use crate::tfpb::tensorflow::NodeDef;
6
7pub fn register_all_ops(reg: &mut TfOpRegister) {
8 reg.insert("Assign", |_, _| Ok(Box::<Assign>::default()));
9 reg.insert("VariableV2", variable_v2);
10}
11
12fn variable_v2(_ctx: &ParsingContext, node: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
13 let shared_name = node.get_attr_str("shared_name")?;
14 let shared_name = if !shared_name.is_empty() { Some(shared_name) } else { None };
15 let container = node.get_attr_str("container")?;
16 let container = if !container.is_empty() { Some(container) } else { None };
17 let name = node.name.to_string();
18 let id = format!("{container:?}#{shared_name:?}#{name}");
19 let shape = node.get_attr_shape("shape")?;
20 let dt = node.get_attr_datum_type("dtype")?;
21 let shape = shape
22 .into_iter()
23 .map(|d| {
24 if d > 0 {
25 Ok(d as usize)
26 } else {
27 bail!("VariableV2 shape contains forbidden negative dim.")
28 }
29 })
30 .collect::<TractResult<TVec<usize>>>()?;
31 Ok(Box::new(VariableV2::new(container, shared_name, name, id, shape, dt, None)))
32}
33
34#[derive(Clone, Debug, new)]
35struct VariableV2State;
36trivial_op_state_freeeze!(VariableV2State);
37
38impl OpState for VariableV2State {
39 fn eval(
40 &mut self,
41 session: &mut SessionState,
42 op: &dyn Op,
43 _inputs: TVec<TValue>,
44 ) -> TractResult<TVec<TValue>> {
45 let op = op.downcast_ref::<VariableV2>().context("wrong op for variable state")?;
46 let tensor = session
47 .tensors
48 .get(&op.id)
49 .with_context(|| format!("Could not find state for variable {}", op.id))?;
50 Ok(tvec!(tensor.clone().into()))
51 }
52}
53
54#[derive(Clone, Debug, new, Hash)]
55pub struct VariableV2 {
56 container: Option<String>,
57 shared_name: Option<String>,
58 name: String,
59 pub id: String,
60 shape: TVec<usize>,
61 dt: DatumType,
62 pub initializer: Option<Arc<Tensor>>,
63}
64
65
66
67impl Op for VariableV2 {
68 fn name(&self) -> StaticName {
69 "VariableV2".into()
70 }
71
72 fn info(&self) -> TractResult<Vec<String>> {
73 if let Some(init) = &self.initializer {
74 Ok(vec![format!("Initialized to {init:?}")])
75 } else {
76 Ok(vec![format!("Uninitialized")])
77 }
78 }
79
80 op_as_typed_op!();
81}
82
83impl EvalOp for VariableV2 {
84 fn is_stateless(&self) -> bool {
85 false
86 }
87
88 fn state(
89 &self,
90 state: &mut SessionState,
91 _node_id: usize,
92 ) -> TractResult<Option<Box<dyn OpState>>> {
93 let tensor = if let Some(init) = &self.initializer {
94 init.clone().into_tensor()
95 } else {
96 unsafe { Tensor::uninitialized_dt(self.dt, &self.shape)? }
97 };
98 state.tensors.insert(self.id.clone(), tensor);
99 Ok(Some(Box::new(VariableV2State)))
100 }
101}
102
103impl InferenceRulesOp for VariableV2 {
104 fn rules<'r, 'p: 'r, 's: 'r>(
105 &'s self,
106 s: &mut Solver<'r>,
107 inputs: &'p [TensorProxy],
108 outputs: &'p [TensorProxy],
109 ) -> InferenceResult {
110 check_input_arity(inputs, 0)?;
111 check_output_arity(outputs, 1)?;
112 s.equals(&outputs[0].datum_type, self.dt)?;
113 s.equals(&outputs[0].shape, ShapeFactoid::from(&*self.shape))?;
114 Ok(())
115 }
116
117 as_op!();
118 to_typed!();
119}
120
121impl TypedOp for VariableV2 {
122 as_op!();
123
124 fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
125 Ok(tvec!(self.dt.fact(&self.shape)))
126 }
127}
128
129#[derive(Clone, Debug, new)]
132struct AssignState;
133trivial_op_state_freeeze!(AssignState);
134
135#[derive(Clone, Debug, new, Default, Hash)]
136pub struct Assign {
137 pub var_id: Option<String>,
138}
139
140
141
142impl Op for Assign {
143 fn name(&self) -> StaticName {
144 "Assign".into()
145 }
146
147 op_as_typed_op!();
148}
149
150impl OpState for AssignState {
151 fn eval(
152 &mut self,
153 session: &mut SessionState,
154 op: &dyn Op,
155 inputs: TVec<TValue>,
156 ) -> TractResult<TVec<TValue>> {
157 let (_current, new) = args_2!(inputs);
158 let op = op.downcast_ref::<Assign>().context("wrong op for variable state")?;
159 let var_id = if let Some(ref var_id) = op.var_id {
160 var_id
161 } else {
162 bail!("Assign has not been linked to var")
163 };
164 let store = session.tensors.get_mut(var_id).unwrap();
165 if cfg!(debug_assertions)
166 && (store.shape() != new.shape() && store.datum_type() != new.datum_type())
167 {
168 bail!(
169 "Invalid assignment to variable. Store is {:?}, assigned value is {:?}",
170 store,
171 new
172 );
173 }
174 *store = new.clone().into_tensor();
175 Ok(tvec!(new))
176 }
177}
178
179impl EvalOp for Assign {
180 fn is_stateless(&self) -> bool {
181 false
182 }
183
184 fn state(
185 &self,
186 _state: &mut SessionState,
187 _node_id: usize,
188 ) -> TractResult<Option<Box<dyn OpState>>> {
189 Ok(Some(Box::new(AssignState)))
190 }
191}
192
193impl InferenceRulesOp for Assign {
194 fn rules<'r, 'p: 'r, 's: 'r>(
195 &'s self,
196 s: &mut Solver<'r>,
197 inputs: &'p [TensorProxy],
198 outputs: &'p [TensorProxy],
199 ) -> InferenceResult {
200 check_input_arity(inputs, 2)?;
201 check_output_arity(outputs, 1)?;
202 s.equals(&inputs[0].datum_type, &inputs[1].datum_type)?;
203 s.equals(&inputs[0].shape, &inputs[1].shape)?;
204 s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
205 s.equals(&inputs[1].shape, &inputs[0].shape)?;
206 s.equals(&outputs[0].shape, &inputs[0].shape)?;
207 s.equals(&outputs[0].value, &inputs[1].value)?;
208 Ok(())
209 }
210
211 as_op!();
212 to_typed!();
213}
214
215impl TypedOp for Assign {
216 as_op!();
217
218 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
219 if inputs[0].datum_type != inputs[1].datum_type || inputs[0].shape != inputs[1].shape {
220 bail!("Invalid assignement {:?}", inputs);
221 }
222 Ok(tvec!(inputs[0].clone()))
223 }
224}