//! TensorFlow model loading and graph parsing for tract (tract_tensorflow/model.rs).

1use crate::tfpb::tensorflow::{GraphDef, NodeDef, SavedModel};
2use prost::Message;
3use std::{fs, path};
4use tract_hir::internal::*;
5
/// Shared state accumulated while translating a TensorFlow `GraphDef`
/// into a tract model; handed to every registered op builder.
#[derive(Default)]
pub struct ParsingContext {
    // Minimum number of outputs each named node must expose, inferred from
    // the "node:slot" input references seen across the whole graph.
    pub node_output_arities: HashMap<String, usize>,
}
10
/// Builds a tract inference op from a TensorFlow `NodeDef`.
type OpBuilder = fn(&ParsingContext, node: &NodeDef) -> TractResult<Box<dyn InferenceOp>>;
12
/// Registry mapping TensorFlow op type names to tract op builders.
#[derive(Clone, Default)]
pub struct TfOpRegister(pub HashMap<String, OpBuilder>);
15
16impl TfOpRegister {
17    pub fn insert(&mut self, s: &'static str, builder: OpBuilder) {
18        self.0.insert(s.into(), builder);
19    }
20}
21
/// The TensorFlow framework front-end: owns the op translation registry
/// used when parsing graphs.
pub struct Tensorflow {
    // Maps TensorFlow op names (e.g. "Placeholder") to tract op builders.
    pub op_register: TfOpRegister,
}
25
/// TensorFlow-specific graph information that does not fit into the
/// `InferenceModel` itself.
pub struct TfModelExtensions {
    // Execution-order constraints from "^node" control inputs, stored as
    // (consumer node id, controlling predecessor node id) pairs.
    pub control_inputs: Vec<(usize, usize)>,
    // Node ids to execute once at load time (see `preproc`) so that
    // VariableV2 ops can capture their initial values.
    pub initializing_nodes: Vec<usize>,
}
30
31impl TfModelExtensions {
32    pub fn preproc(&self, mut original: InferenceModel) -> TractResult<InferenceModel> {
33        if self.initializing_nodes.len() > 0 {
34            let as_outlets =
35                self.initializing_nodes.iter().map(|n| OutletId::new(*n, 0)).collect::<Vec<_>>();
36            let plan = SimplePlan::build(
37                &original,
38                &as_outlets,
39                &self.control_inputs,
40                &PlanOptions::default(),
41            )?;
42            let mut state = SimpleState::new(plan)?;
43            state.exec()?;
44            let tensors = state.session_state.tensors;
45            for node in &mut original.nodes {
46                if let Some(var) = node.op_as_mut::<crate::ops::vars::VariableV2>() {
47                    if let Some(value) = tensors.get(&var.id) {
48                        var.initializer = Some(value.clone().into_arc_tensor());
49                    }
50                }
51            }
52        }
53        Ok(original)
54    }
55}
56
/// A freshly parsed model together with its TensorFlow-specific extensions.
pub struct TfModelAndExtensions(pub InferenceModel, pub TfModelExtensions);
58
59impl Tensorflow {
60    // From the node_def.proto documentation:
61    // Each input is "node:src_output" with "node" being a string name and
62    // "src_output" indicating which output tensor to use from "node". If
63    // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs may
64    // optionally be followed by control inputs that have the format "^node".
65    fn parse_input(i: &str) -> TractResult<(&str, usize)> {
66        let pair = if let Some(stripped) = i.strip_prefix('^') {
67            (stripped, 0)
68        } else {
69            let splits: Vec<_> = i.splitn(2, ':').collect();
70            (splits[0], if splits.len() > 1 { splits[1].parse::<usize>()? } else { 0 })
71        };
72        Ok(pair)
73    }
74
75    pub fn determinize(model: &mut GraphDef) -> TractResult<()> {
76        for pbnode in &mut model.node {
77            if pbnode.op == "RandomUniform"
78                && pbnode.get_attr_int::<i64>("seed")? == 0
79                && pbnode.get_attr_int::<i64>("seed2")? == 0
80            {
81                pbnode.attr.insert("seed".to_string(), 1.into());
82                pbnode.attr.insert("seed2".to_string(), 1.into());
83            }
84        }
85        Ok(())
86    }
87
88    #[cfg(target_family = "wasm")]
89    pub fn read_frozen_from_path(&self, p: impl AsRef<path::Path>) -> TractResult<GraphDef> {
90        use std::io::Read;
91        let mut file = fs::File::open(p)?;
92        let mut v = Vec::with_capacity(file.metadata()?.len() as usize);
93        file.read_to_end(&mut v)?;
94        let b = bytes::Bytes::from(v);
95        Ok(GraphDef::decode(b)?)
96    }
97
    /// Reads a frozen `GraphDef` protobuf from the file at `p`, using a
    /// memory map to avoid copying the whole file into a buffer first.
    #[cfg(all(any(windows, unix), not(target_os = "emscripten")))]
    pub fn read_frozen_from_path(&self, p: impl AsRef<path::Path>) -> TractResult<GraphDef> {
        // SAFETY: mapping is read-only and dropped at end of scope.
        // NOTE(review): mmap is unsound if another process truncates or
        // rewrites the file while mapped — presumably accepted here for
        // load-time performance; confirm this trade-off is intended.
        let map = unsafe { memmap2::Mmap::map(&fs::File::open(p)?)? };
        Ok(GraphDef::decode(&*map)?)
    }
103
104    pub fn read_frozen_model(&self, r: &mut dyn std::io::Read) -> TractResult<GraphDef> {
105        let mut v = vec![];
106        r.read_to_end(&mut v)?;
107        let b = bytes::Bytes::from(v);
108        Ok(GraphDef::decode(b)?)
109    }
110
111    pub fn open_saved_model(&self, r: &mut dyn std::io::Read) -> TractResult<SavedModel> {
112        let mut v = vec![];
113        r.read_to_end(&mut v)?;
114        let b = bytes::Bytes::from(v);
115        Ok(SavedModel::decode(b)?)
116    }
117
118    /// Convenience method: will read the first model in the saved model
119    /// container. Use open_avec_model for more control.
120    pub fn read_saved_model(&self, r: &mut dyn std::io::Read) -> TractResult<GraphDef> {
121        let mut saved = self.open_saved_model(r)?;
122        Ok(saved.meta_graphs.remove(0).graph_def.unwrap())
123    }
124
125    pub fn parse_graph(&self, graph: &GraphDef) -> TractResult<TfModelAndExtensions> {
126        self.parse_graph_with_template(graph, Default::default())
127    }
128
    /// Translates `graph` into tract nodes appended to the given template
    /// `model`, in three passes: (1) infer per-node output arities from input
    /// references, (2) create nodes (splitting NextIteration into a
    /// source/sink pair and typing Placeholders from their attributes),
    /// (3) wire data edges and collect "^node" control inputs.
    ///
    /// Returns the populated model plus the extracted TF extensions.
    pub fn parse_graph_with_template(
        &self,
        graph: &GraphDef,
        mut model: InferenceModel
    ) -> TractResult<TfModelAndExtensions> {
        use crate::ops::control_flow as cf;

        let mut inputs = tvec!();
        let mut context = ParsingContext::default();
        let mut control_inputs = vec![];

        // Pass 1: compute min output arity for all nodes. A reference to
        // "node:3" anywhere proves "node" exposes at least 4 outputs.
        for pbnode in &graph.node {
            for i in &pbnode.input {
                let (node, slot) = Self::parse_input(i)?;
                let arity = context.node_output_arities.entry(node.to_string()).or_insert(1);
                *arity = (*arity).max(slot + 1);
            }
        }

        // Pass 2: create one tract node per pb node (two for NextIteration).
        for pbnode in &graph.node {
            let name = &pbnode.name;

            // NextIteration closes loops in the TF graph; tract has no cyclic
            // edges, so it becomes a source node (original name) plus a sink
            // node ("<name>-Sink") that are linked by the op's shared name.
            if pbnode.op == "NextIteration" {
                let source_op = cf::NextIteration::new(name.clone(), cf::NextIterationRole::Source);
                let sink_op = cf::NextIteration::new(name.clone(), cf::NextIterationRole::Sink);
                let _source =
                    model.add_node(name.clone(), source_op, tvec!(InferenceFact::default()))?;
                let _sink = model.add_node(format!("{name}-Sink"), sink_op, tvec!())?;
                continue;
            }

            // Unregistered ops become UnimplementedOp placeholders so the
            // rest of the graph can still be built and analyzed.
            let op = match self.op_register.0.get(&pbnode.op) {
                Some(builder) => (builder)(&context, pbnode)?,
                None => tract_hir::ops::unimpl::UnimplementedOp::new(
                    context.node_output_arities.get(name).cloned().unwrap_or(1),
                    &pbnode.op,
                    format!("{pbnode:?}"),
                )
                .into(),
            };

            // Allocate at least as many output slots as pass 1 observed.
            let noutputs =
                op.nboutputs()?.max(context.node_output_arities.get(name).cloned().unwrap_or(1));
            let facts = tvec!(InferenceFact::default(); noutputs);

            let node_id = model.add_node(name.clone(), op, facts)?;
            // Placeholders become model inputs, typed from their dtype/shape
            // attributes; a -1 dimension means "unknown" in TF.
            if pbnode.op == "Placeholder" {
                let dt = pbnode.get_attr_datum_type("dtype")?;
                let mut fact = InferenceFact::dt(dt);
                if let Some(shape) = pbnode.get_attr_opt_shape("shape")? {
                    let shape_factoid = ShapeFactoid::closed(
                        shape
                            .iter()
                            .map(|d| {
                                if *d == -1 {
                                    GenericFactoid::Any
                                } else {
                                    GenericFactoid::Only(d.to_dim())
                                }
                            })
                            .collect(),
                    );
                    fact = fact.with_shape(shape_factoid);
                }
                inputs.push(OutletId::new(node_id, 0));
                model.set_outlet_fact(OutletId::new(node_id, 0), fact)?;
            }
        }

        // Pass 3: wire edges. NextIteration consumers attach to the sink
        // half; control inputs ("^node") are collected separately instead of
        // becoming data edges.
        for pbnode in &graph.node {
            let node_id = if pbnode.op == "NextIteration" {
                model.node_by_name(&*format!("{}-Sink", &pbnode.name))?.id
            } else {
                model.node_by_name(&pbnode.name)?.id
            };
            for (ix, i) in pbnode.input.iter().filter(|n| !n.starts_with('^')).enumerate() {
                let input = Self::parse_input(i)?;
                let prec = model.node_by_name(input.0)?.id;
                let outlet = OutletId::new(prec, input.1);
                let inlet = InletId::new(node_id, ix);
                model.add_edge(outlet, inlet)?;
                model.set_outlet_label(outlet, i.to_string())?;
            }
            for i in pbnode.input.iter().filter(|n| n.starts_with('^')) {
                let input = Self::parse_input(i)?;
                let prec = model.node_by_name(input.0)?.id;
                control_inputs.push((model.node_id_by_name(&pbnode.name)?, prec));
            }
        }

        // variable -> assign rewire
        //  * Assign consumes this by_ref tensor on #0 and somehow performs
        //      updates on it (it has a second input on #1 for the value to
        //      assign)
        //
        // in tract:
        //  * VariableV2 outputs a regular tensor stored in the session state
        //  * Assign has the same inputs, but does not use #0, updating the
        //      session state instead
        for id in 0..model.nodes().len() {
            use crate::ops::vars::*;
            if model.node(id).op_is::<Assign>() {
                let prec = model.node(id).inputs[0];
                let var_id = model.node(prec.node).op_as::<VariableV2>().map(|v| v.id.clone());
                if let (Some(var_id), Some(assign)) =
                    (var_id, model.node_mut(id).op_as_mut::<Assign>())
                {
                    assign.var_id = Some(var_id);
                } else {
                    bail!("Model contains unlinked Assign/Variable2");
                }
            }
        }
        model.set_input_outlets(&inputs)?;
        // Any node whose outputs are never consumed becomes a model output.
        model.auto_outputs()?;
        let extensions = TfModelExtensions { control_inputs, initializing_nodes: vec![] };
        Ok(TfModelAndExtensions(model, extensions))
    }
248}
249
250impl Framework<GraphDef, InferenceModel> for Tensorflow {
251    /// This method will try to read as frozen model, then as a saved model.
252    fn proto_model_for_path(&self, r: impl AsRef<path::Path>) -> TractResult<GraphDef> {
253        self.read_frozen_model(&mut fs::File::open(r.as_ref())?)
254            .or_else(|_| self.read_saved_model(&mut fs::File::open(r.as_ref())?))
255    }
256
257    /// This method expects a frozen model, use open_saved_model for TF2 saved
258    /// model format.
259    fn proto_model_for_read(&self, r: &mut dyn std::io::Read) -> TractResult<GraphDef> {
260        self.read_frozen_model(r)
261    }
262
263    fn model_for_proto_model_with_model_template(
264            &self,
265            proto: &GraphDef,
266            template: InferenceModel,
267        ) -> TractResult<InferenceModel> {
268        Ok(self.parse_graph_with_template(proto, template)?.0)
269    }
270}