tract_tensorflow/
model.rs1use crate::tfpb::tensorflow::{GraphDef, NodeDef, SavedModel};
2use prost::Message;
3use std::{fs, path};
4use tract_hir::internal::*;
5
/// Graph-wide information handed to every op builder while a TensorFlow graph is translated.
#[derive(Clone, Debug, Default)]
pub struct ParsingContext {
    /// Number of outputs of each node, keyed by node name.
    ///
    /// As filled by `Tensorflow::parse_graph_with_template`: for every node referenced as an
    /// input, `1 + ` the highest output slot any consumer reads from it (so at least 1).
    pub node_output_arities: HashMap<String, usize>,
}
10
11type OpBuilder = fn(&ParsingContext, node: &NodeDef) -> TractResult<Box<dyn InferenceOp>>;
12
13#[derive(Clone, Default)]
14pub struct TfOpRegister(pub HashMap<String, OpBuilder>);
15
16impl TfOpRegister {
17 pub fn insert(&mut self, s: &'static str, builder: OpBuilder) {
18 self.0.insert(s.into(), builder);
19 }
20}
21
22pub struct Tensorflow {
23 pub op_register: TfOpRegister,
24}
25
/// TensorFlow-specific information collected while parsing that has no place in the
/// `InferenceModel` itself.
#[derive(Clone, Debug, Default)]
pub struct TfModelExtensions {
    /// Control dependencies (`^name` inputs) as `(consumer_node_id, producer_node_id)` pairs,
    /// using node ids of the parsed model.
    pub control_inputs: Vec<(usize, usize)>,
}
29
30impl TfModelExtensions {
31 pub fn preproc(&self, original: InferenceModel) -> TractResult<InferenceModel> {
32 Ok(original)
33 }
34}
35
36pub struct TfModelAndExtensions(pub InferenceModel, pub TfModelExtensions);
37
38impl Tensorflow {
39 fn parse_input(i: &str) -> TractResult<(&str, usize)> {
45 let pair = if let Some(stripped) = i.strip_prefix('^') {
46 (stripped, 0)
47 } else {
48 let splits: Vec<_> = i.splitn(2, ':').collect();
49 (splits[0], if splits.len() > 1 { splits[1].parse::<usize>()? } else { 0 })
50 };
51 Ok(pair)
52 }
53
54 pub fn determinize(model: &mut GraphDef) -> TractResult<()> {
55 for pbnode in &mut model.node {
56 if pbnode.op == "RandomUniform"
57 && pbnode.get_attr_int::<i64>("seed")? == 0
58 && pbnode.get_attr_int::<i64>("seed2")? == 0
59 {
60 pbnode.attr.insert("seed".to_string(), 1.into());
61 pbnode.attr.insert("seed2".to_string(), 1.into());
62 }
63 }
64 Ok(())
65 }
66
67 #[cfg(target_family = "wasm")]
68 pub fn read_frozen_from_path(&self, p: impl AsRef<path::Path>) -> TractResult<GraphDef> {
69 use std::io::Read;
70 let mut file = fs::File::open(p)?;
71 let mut v = Vec::with_capacity(file.metadata()?.len() as usize);
72 file.read_to_end(&mut v)?;
73 let b = bytes::Bytes::from(v);
74 Ok(GraphDef::decode(b)?)
75 }
76
77 #[cfg(all(any(windows, unix), not(target_os = "emscripten")))]
78 pub fn read_frozen_from_path(&self, p: impl AsRef<path::Path>) -> TractResult<GraphDef> {
79 let map = unsafe { memmap2::Mmap::map(&fs::File::open(p)?)? };
80 Ok(GraphDef::decode(&*map)?)
81 }
82
83 pub fn read_frozen_model(&self, r: &mut dyn std::io::Read) -> TractResult<GraphDef> {
84 let mut v = vec![];
85 r.read_to_end(&mut v)?;
86 let b = bytes::Bytes::from(v);
87 Ok(GraphDef::decode(b)?)
88 }
89
90 pub fn open_saved_model(&self, r: &mut dyn std::io::Read) -> TractResult<SavedModel> {
91 let mut v = vec![];
92 r.read_to_end(&mut v)?;
93 let b = bytes::Bytes::from(v);
94 Ok(SavedModel::decode(b)?)
95 }
96
97 pub fn read_saved_model(&self, r: &mut dyn std::io::Read) -> TractResult<GraphDef> {
100 let mut saved = self.open_saved_model(r)?;
101 Ok(saved.meta_graphs.remove(0).graph_def.unwrap())
102 }
103
104 pub fn parse_graph(&self, graph: &GraphDef) -> TractResult<TfModelAndExtensions> {
105 self.parse_graph_with_template(graph, Default::default())
106 }
107
108 pub fn parse_graph_with_template(
109 &self,
110 graph: &GraphDef,
111 mut model: InferenceModel,
112 ) -> TractResult<TfModelAndExtensions> {
113 use crate::ops::control_flow as cf;
114
115 let mut inputs = tvec!();
116 let mut context = ParsingContext::default();
117 let mut control_inputs = vec![];
118
119 for pbnode in &graph.node {
121 for i in &pbnode.input {
122 let (node, slot) = Self::parse_input(i)?;
123 let arity = context.node_output_arities.entry(node.to_string()).or_insert(1);
124 *arity = (*arity).max(slot + 1);
125 }
126 }
127
128 for pbnode in &graph.node {
129 let name = &pbnode.name;
130
131 if pbnode.op == "NextIteration" {
132 let source_op = cf::NextIteration::new(name.clone(), cf::NextIterationRole::Source);
133 let sink_op = cf::NextIteration::new(name.clone(), cf::NextIterationRole::Sink);
134 let _source =
135 model.add_node(name.clone(), source_op, tvec!(InferenceFact::default()))?;
136 let _sink = model.add_node(format!("{name}-Sink"), sink_op, tvec!())?;
137 continue;
138 }
139
140 let op = match self.op_register.0.get(&pbnode.op) {
141 Some(builder) => (builder)(&context, pbnode)?,
142 None => tract_hir::ops::unimpl::UnimplementedOp::new(
143 context.node_output_arities.get(name).cloned().unwrap_or(1),
144 &pbnode.op,
145 format!("{pbnode:?}"),
146 )
147 .into(),
148 };
149
150 let noutputs =
151 op.nboutputs()?.max(context.node_output_arities.get(name).cloned().unwrap_or(1));
152 let facts = tvec!(InferenceFact::default(); noutputs);
153
154 let node_id = model.add_node(name.clone(), op, facts)?;
155 if pbnode.op == "Placeholder" {
156 let dt = pbnode.get_attr_datum_type("dtype")?;
157 let mut fact = InferenceFact::dt(dt);
158 if let Some(shape) = pbnode.get_attr_opt_shape("shape")? {
159 let shape_factoid = ShapeFactoid::closed(
160 shape
161 .iter()
162 .map(|d| {
163 if *d == -1 {
164 GenericFactoid::Any
165 } else {
166 GenericFactoid::Only(d.to_dim())
167 }
168 })
169 .collect(),
170 );
171 fact = fact.with_shape(shape_factoid);
172 }
173 inputs.push(OutletId::new(node_id, 0));
174 model.set_outlet_fact(OutletId::new(node_id, 0), fact)?;
175 }
176 }
177
178 for pbnode in &graph.node {
179 let node_id = if pbnode.op == "NextIteration" {
180 model.node_by_name(&*format!("{}-Sink", &pbnode.name))?.id
181 } else {
182 model.node_by_name(&pbnode.name)?.id
183 };
184 for (ix, i) in pbnode.input.iter().filter(|n| !n.starts_with('^')).enumerate() {
185 let input = Self::parse_input(i)?;
186 let prec = model.node_by_name(input.0)?.id;
187 let outlet = OutletId::new(prec, input.1);
188 let inlet = InletId::new(node_id, ix);
189 model.add_edge(outlet, inlet)?;
190 model.set_outlet_label(outlet, i.to_string())?;
191 }
192 for i in pbnode.input.iter().filter(|n| n.starts_with('^')) {
193 let input = Self::parse_input(i)?;
194 let prec = model.node_by_name(input.0)?.id;
195 control_inputs.push((model.node_id_by_name(&pbnode.name)?, prec));
196 }
197 }
198
199 model.set_input_outlets(&inputs)?;
226 model.auto_outputs()?;
227 let extensions = TfModelExtensions { control_inputs };
228 Ok(TfModelAndExtensions(model, extensions))
229 }
230}
231
232impl Framework<GraphDef, InferenceModel> for Tensorflow {
233 fn proto_model_for_path(&self, r: impl AsRef<path::Path>) -> TractResult<GraphDef> {
235 self.read_frozen_model(&mut fs::File::open(r.as_ref())?)
236 .or_else(|_| self.read_saved_model(&mut fs::File::open(r.as_ref())?))
237 }
238
239 fn proto_model_for_read(&self, r: &mut dyn std::io::Read) -> TractResult<GraphDef> {
242 self.read_frozen_model(r)
243 }
244
245 fn model_for_proto_model_with_model_template(
246 &self,
247 proto: &GraphDef,
248 template: InferenceModel,
249 ) -> TractResult<InferenceModel> {
250 Ok(self.parse_graph_with_template(proto, template)?.0)
251 }
252}