tract_tensorflow/
model.rs1use crate::tfpb::tensorflow::{GraphDef, NodeDef, SavedModel};
2use prost::Message;
3use std::{fs, path};
4use tract_hir::internal::*;
5
/// State made available to op builders while a TensorFlow graph is parsed.
#[derive(Default)]
pub struct ParsingContext {
    /// Number of outputs consumed for each referenced node name: one more than
    /// the highest output slot any input string of the graph points at, and at
    /// least 1. Nodes never referenced as an input have no entry.
    pub node_output_arities: HashMap<String, usize>,
}
10
11type OpBuilder = fn(&ParsingContext, node: &NodeDef) -> TractResult<Box<dyn InferenceOp>>;
12
13#[derive(Clone, Default)]
14pub struct TfOpRegister(pub HashMap<String, OpBuilder>);
15
16impl TfOpRegister {
17 pub fn insert(&mut self, s: &'static str, builder: OpBuilder) {
18 self.0.insert(s.into(), builder);
19 }
20}
21
22pub struct Tensorflow {
23 pub op_register: TfOpRegister,
24}
25
26pub struct TfModelExtensions {
27 pub control_inputs: Vec<(usize, usize)>,
28 pub initializing_nodes: Vec<usize>,
29}
30
31impl TfModelExtensions {
32 pub fn preproc(&self, mut original: InferenceModel) -> TractResult<InferenceModel> {
33 if self.initializing_nodes.len() > 0 {
34 let as_outlets =
35 self.initializing_nodes.iter().map(|n| OutletId::new(*n, 0)).collect::<Vec<_>>();
36 let plan = SimplePlan::build(
37 &original,
38 &as_outlets,
39 &self.control_inputs,
40 &PlanOptions::default(),
41 )?;
42 let mut state = SimpleState::new(plan)?;
43 state.exec()?;
44 let tensors = state.session_state.tensors;
45 for node in &mut original.nodes {
46 if let Some(var) = node.op_as_mut::<crate::ops::vars::VariableV2>() {
47 if let Some(value) = tensors.get(&var.id) {
48 var.initializer = Some(value.clone().into_arc_tensor());
49 }
50 }
51 }
52 }
53 Ok(original)
54 }
55}
56
57pub struct TfModelAndExtensions(pub InferenceModel, pub TfModelExtensions);
58
59impl Tensorflow {
60 fn parse_input(i: &str) -> TractResult<(&str, usize)> {
66 let pair = if let Some(stripped) = i.strip_prefix('^') {
67 (stripped, 0)
68 } else {
69 let splits: Vec<_> = i.splitn(2, ':').collect();
70 (splits[0], if splits.len() > 1 { splits[1].parse::<usize>()? } else { 0 })
71 };
72 Ok(pair)
73 }
74
75 pub fn determinize(model: &mut GraphDef) -> TractResult<()> {
76 for pbnode in &mut model.node {
77 if pbnode.op == "RandomUniform"
78 && pbnode.get_attr_int::<i64>("seed")? == 0
79 && pbnode.get_attr_int::<i64>("seed2")? == 0
80 {
81 pbnode.attr.insert("seed".to_string(), 1.into());
82 pbnode.attr.insert("seed2".to_string(), 1.into());
83 }
84 }
85 Ok(())
86 }
87
88 #[cfg(target_family = "wasm")]
89 pub fn read_frozen_from_path(&self, p: impl AsRef<path::Path>) -> TractResult<GraphDef> {
90 use std::io::Read;
91 let mut file = fs::File::open(p)?;
92 let mut v = Vec::with_capacity(file.metadata()?.len() as usize);
93 file.read_to_end(&mut v)?;
94 let b = bytes::Bytes::from(v);
95 Ok(GraphDef::decode(b)?)
96 }
97
98 #[cfg(all(any(windows, unix), not(target_os = "emscripten")))]
99 pub fn read_frozen_from_path(&self, p: impl AsRef<path::Path>) -> TractResult<GraphDef> {
100 let map = unsafe { memmap2::Mmap::map(&fs::File::open(p)?)? };
101 Ok(GraphDef::decode(&*map)?)
102 }
103
104 pub fn read_frozen_model(&self, r: &mut dyn std::io::Read) -> TractResult<GraphDef> {
105 let mut v = vec![];
106 r.read_to_end(&mut v)?;
107 let b = bytes::Bytes::from(v);
108 Ok(GraphDef::decode(b)?)
109 }
110
111 pub fn open_saved_model(&self, r: &mut dyn std::io::Read) -> TractResult<SavedModel> {
112 let mut v = vec![];
113 r.read_to_end(&mut v)?;
114 let b = bytes::Bytes::from(v);
115 Ok(SavedModel::decode(b)?)
116 }
117
118 pub fn read_saved_model(&self, r: &mut dyn std::io::Read) -> TractResult<GraphDef> {
121 let mut saved = self.open_saved_model(r)?;
122 Ok(saved.meta_graphs.remove(0).graph_def.unwrap())
123 }
124
125 pub fn parse_graph(&self, graph: &GraphDef) -> TractResult<TfModelAndExtensions> {
126 self.parse_graph_with_template(graph, Default::default())
127 }
128
129 pub fn parse_graph_with_template(
130 &self,
131 graph: &GraphDef,
132 mut model: InferenceModel
133 ) -> TractResult<TfModelAndExtensions> {
134 use crate::ops::control_flow as cf;
135
136 let mut inputs = tvec!();
137 let mut context = ParsingContext::default();
138 let mut control_inputs = vec![];
139
140 for pbnode in &graph.node {
142 for i in &pbnode.input {
143 let (node, slot) = Self::parse_input(i)?;
144 let arity = context.node_output_arities.entry(node.to_string()).or_insert(1);
145 *arity = (*arity).max(slot + 1);
146 }
147 }
148
149 for pbnode in &graph.node {
150 let name = &pbnode.name;
151
152 if pbnode.op == "NextIteration" {
153 let source_op = cf::NextIteration::new(name.clone(), cf::NextIterationRole::Source);
154 let sink_op = cf::NextIteration::new(name.clone(), cf::NextIterationRole::Sink);
155 let _source =
156 model.add_node(name.clone(), source_op, tvec!(InferenceFact::default()))?;
157 let _sink = model.add_node(format!("{name}-Sink"), sink_op, tvec!())?;
158 continue;
159 }
160
161 let op = match self.op_register.0.get(&pbnode.op) {
162 Some(builder) => (builder)(&context, pbnode)?,
163 None => tract_hir::ops::unimpl::UnimplementedOp::new(
164 context.node_output_arities.get(name).cloned().unwrap_or(1),
165 &pbnode.op,
166 format!("{pbnode:?}"),
167 )
168 .into(),
169 };
170
171 let noutputs =
172 op.nboutputs()?.max(context.node_output_arities.get(name).cloned().unwrap_or(1));
173 let facts = tvec!(InferenceFact::default(); noutputs);
174
175 let node_id = model.add_node(name.clone(), op, facts)?;
176 if pbnode.op == "Placeholder" {
177 let dt = pbnode.get_attr_datum_type("dtype")?;
178 let mut fact = InferenceFact::dt(dt);
179 if let Some(shape) = pbnode.get_attr_opt_shape("shape")? {
180 let shape_factoid = ShapeFactoid::closed(
181 shape
182 .iter()
183 .map(|d| {
184 if *d == -1 {
185 GenericFactoid::Any
186 } else {
187 GenericFactoid::Only(d.to_dim())
188 }
189 })
190 .collect(),
191 );
192 fact = fact.with_shape(shape_factoid);
193 }
194 inputs.push(OutletId::new(node_id, 0));
195 model.set_outlet_fact(OutletId::new(node_id, 0), fact)?;
196 }
197 }
198
199 for pbnode in &graph.node {
200 let node_id = if pbnode.op == "NextIteration" {
201 model.node_by_name(&*format!("{}-Sink", &pbnode.name))?.id
202 } else {
203 model.node_by_name(&pbnode.name)?.id
204 };
205 for (ix, i) in pbnode.input.iter().filter(|n| !n.starts_with('^')).enumerate() {
206 let input = Self::parse_input(i)?;
207 let prec = model.node_by_name(input.0)?.id;
208 let outlet = OutletId::new(prec, input.1);
209 let inlet = InletId::new(node_id, ix);
210 model.add_edge(outlet, inlet)?;
211 model.set_outlet_label(outlet, i.to_string())?;
212 }
213 for i in pbnode.input.iter().filter(|n| n.starts_with('^')) {
214 let input = Self::parse_input(i)?;
215 let prec = model.node_by_name(input.0)?.id;
216 control_inputs.push((model.node_id_by_name(&pbnode.name)?, prec));
217 }
218 }
219
220 for id in 0..model.nodes().len() {
230 use crate::ops::vars::*;
231 if model.node(id).op_is::<Assign>() {
232 let prec = model.node(id).inputs[0];
233 let var_id = model.node(prec.node).op_as::<VariableV2>().map(|v| v.id.clone());
234 if let (Some(var_id), Some(assign)) =
235 (var_id, model.node_mut(id).op_as_mut::<Assign>())
236 {
237 assign.var_id = Some(var_id);
238 } else {
239 bail!("Model contains unlinked Assign/Variable2");
240 }
241 }
242 }
243 model.set_input_outlets(&inputs)?;
244 model.auto_outputs()?;
245 let extensions = TfModelExtensions { control_inputs, initializing_nodes: vec![] };
246 Ok(TfModelAndExtensions(model, extensions))
247 }
248}
249
250impl Framework<GraphDef, InferenceModel> for Tensorflow {
251 fn proto_model_for_path(&self, r: impl AsRef<path::Path>) -> TractResult<GraphDef> {
253 self.read_frozen_model(&mut fs::File::open(r.as_ref())?)
254 .or_else(|_| self.read_saved_model(&mut fs::File::open(r.as_ref())?))
255 }
256
257 fn proto_model_for_read(&self, r: &mut dyn std::io::Read) -> TractResult<GraphDef> {
260 self.read_frozen_model(r)
261 }
262
263 fn model_for_proto_model_with_model_template(
264 &self,
265 proto: &GraphDef,
266 template: InferenceModel,
267 ) -> TractResult<InferenceModel> {
268 Ok(self.parse_graph_with_template(proto, template)?.0)
269 }
270}