tract_tensorflow/model.rs

use crate::tfpb::tensorflow::{GraphDef, NodeDef, SavedModel};
use prost::Message;
use std::{fs, path};
use tract_hir::internal::*;

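/// Parsing state shared with the op builders: for each node name, the
/// minimum number of outputs the node must expose, inferred from the output
/// slots its consumers reference.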
#[derive(Default)]
pub struct ParsingContext {
    pub node_output_arities: HashMap<String, usize>,
}

type OpBuilder = fn(&ParsingContext, node: &NodeDef) -> TractResult<Box<dyn InferenceOp>>;

#[derive(Clone, Default)]
pub struct TfOpRegister(pub HashMap<String, OpBuilder>);

impl TfOpRegister {
    pub fn insert(&mut self, s: &'static str, builder: OpBuilder) {
        self.0.insert(s.into(), builder);
    }
}

pub struct Tensorflow {
    pub op_register: TfOpRegister,
}

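/// TensorFlow-specific information carried alongside the inference model:
/// (node, predecessor) node-id pairs for the control inputs ("^node") found
/// in the graph.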
pub struct TfModelExtensions {
    pub control_inputs: Vec<(usize, usize)>,
}

impl TfModelExtensions {
    pub fn preproc(&self, original: InferenceModel) -> TractResult<InferenceModel> {
        Ok(original)
    }
}

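/// The parsed inference model along with the TfModelExtensions gathered
/// while walking the protobuf graph.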
pub struct TfModelAndExtensions(pub InferenceModel, pub TfModelExtensions);

impl Tensorflow {
    // From the node_def.proto documentation:
    // Each input is "node:src_output" with "node" being a string name and
    // "src_output" indicating which output tensor to use from "node". If
    // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs may
    // optionally be followed by control inputs that have the format "^node".
    // See the test module at the end of this file for examples.
    fn parse_input(i: &str) -> TractResult<(&str, usize)> {
        let pair = if let Some(stripped) = i.strip_prefix('^') {
            (stripped, 0)
        } else {
            let splits: Vec<_> = i.splitn(2, ':').collect();
            (splits[0], if splits.len() > 1 { splits[1].parse::<usize>()? } else { 0 })
        };
        Ok(pair)
    }

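    /// "RandomUniform" nodes with both "seed" and "seed2" left at 0 ask
    /// TensorFlow to pick a seed at random. Pinning both attributes to 1
    /// makes repeated runs of the graph produce the same pseudo-random
    /// values.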
    pub fn determinize(model: &mut GraphDef) -> TractResult<()> {
        for pbnode in &mut model.node {
            if pbnode.op == "RandomUniform"
                && pbnode.get_attr_int::<i64>("seed")? == 0
                && pbnode.get_attr_int::<i64>("seed2")? == 0
            {
                pbnode.attr.insert("seed".to_string(), 1.into());
                pbnode.attr.insert("seed2".to_string(), 1.into());
            }
        }
        Ok(())
    }

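    // Two cfg variants of the same entry point: wasm targets have no memory
    // mapping, so the file is read into a buffer there, while unix and
    // windows builds memory-map the file instead.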
    #[cfg(target_family = "wasm")]
    pub fn read_frozen_from_path(&self, p: impl AsRef<path::Path>) -> TractResult<GraphDef> {
        use std::io::Read;
        let mut file = fs::File::open(p)?;
        let mut v = Vec::with_capacity(file.metadata()?.len() as usize);
        file.read_to_end(&mut v)?;
        let b = bytes::Bytes::from(v);
        Ok(GraphDef::decode(b)?)
    }

    #[cfg(all(any(windows, unix), not(target_os = "emscripten")))]
    pub fn read_frozen_from_path(&self, p: impl AsRef<path::Path>) -> TractResult<GraphDef> {
        let map = unsafe { memmap2::Mmap::map(&fs::File::open(p)?)? };
        Ok(GraphDef::decode(&*map)?)
    }

    pub fn read_frozen_model(&self, r: &mut dyn std::io::Read) -> TractResult<GraphDef> {
        let mut v = vec![];
        r.read_to_end(&mut v)?;
        let b = bytes::Bytes::from(v);
        Ok(GraphDef::decode(b)?)
    }

    pub fn open_saved_model(&self, r: &mut dyn std::io::Read) -> TractResult<SavedModel> {
        let mut v = vec![];
        r.read_to_end(&mut v)?;
        let b = bytes::Bytes::from(v);
        Ok(SavedModel::decode(b)?)
    }

    /// Convenience method: reads the first model in the saved model
    /// container. Use open_saved_model for more control.
    pub fn read_saved_model(&self, r: &mut dyn std::io::Read) -> TractResult<GraphDef> {
        let mut saved = self.open_saved_model(r)?;
        Ok(saved.meta_graphs.remove(0).graph_def.unwrap())
    }
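
    // Typical loading sequence, sketched. The `tensorflow()` constructor from
    // the crate root (which registers the standard op builders) is assumed:
    //
    //     let tf = crate::tensorflow();
    //     let graph = tf.read_frozen_model(&mut fs::File::open("model.pb")?)?;
    //     let TfModelAndExtensions(model, _extensions) = tf.parse_graph(&graph)?;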

    pub fn parse_graph(&self, graph: &GraphDef) -> TractResult<TfModelAndExtensions> {
        self.parse_graph_with_template(graph, Default::default())
    }

    pub fn parse_graph_with_template(
        &self,
        graph: &GraphDef,
        mut model: InferenceModel,
    ) -> TractResult<TfModelAndExtensions> {
        use crate::ops::control_flow as cf;

        let mut inputs = tvec!();
        let mut context = ParsingContext::default();
        let mut control_inputs = vec![];

        // compute the minimal output arity of each node: one more than the
        // highest output slot referenced by its consumers
        for pbnode in &graph.node {
            for i in &pbnode.input {
                let (node, slot) = Self::parse_input(i)?;
                let arity = context.node_output_arities.entry(node.to_string()).or_insert(1);
                *arity = (*arity).max(slot + 1);
            }
        }

        for pbnode in &graph.node {
            let name = &pbnode.name;

            if pbnode.op == "NextIteration" {
                let source_op = cf::NextIteration::new(name.clone(), cf::NextIterationRole::Source);
                let sink_op = cf::NextIteration::new(name.clone(), cf::NextIterationRole::Sink);
                let _source =
                    model.add_node(name.clone(), source_op, tvec!(InferenceFact::default()))?;
                let _sink = model.add_node(format!("{name}-Sink"), sink_op, tvec!())?;
                continue;
            }

            let op = match self.op_register.0.get(&pbnode.op) {
                Some(builder) => (builder)(&context, pbnode)?,
                None => tract_hir::ops::unimpl::UnimplementedOp::new(
                    context.node_output_arities.get(name).cloned().unwrap_or(1),
                    &pbnode.op,
                    format!("{pbnode:?}"),
                )
                .into(),
            };

            let noutputs =
                op.nboutputs()?.max(context.node_output_arities.get(name).cloned().unwrap_or(1));
            let facts = tvec!(InferenceFact::default(); noutputs);

            let node_id = model.add_node(name.clone(), op, facts)?;
            if pbnode.op == "Placeholder" {
                // Placeholder nodes become model inputs: their dtype, and
                // shape when present, seed the inference fact; -1 dimensions
                // are left unknown.
                let dt = pbnode.get_attr_datum_type("dtype")?;
                let mut fact = InferenceFact::dt(dt);
                if let Some(shape) = pbnode.get_attr_opt_shape("shape")? {
                    let shape_factoid = ShapeFactoid::closed(
                        shape
                            .iter()
                            .map(|d| {
                                if *d == -1 {
                                    GenericFactoid::Any
                                } else {
                                    GenericFactoid::Only(d.to_dim())
                                }
                            })
                            .collect(),
                    );
                    fact = fact.with_shape(shape_factoid);
                }
                inputs.push(OutletId::new(node_id, 0));
                model.set_outlet_fact(OutletId::new(node_id, 0), fact)?;
            }
        }

        for pbnode in &graph.node {
            let node_id = if pbnode.op == "NextIteration" {
                model.node_by_name(&*format!("{}-Sink", &pbnode.name))?.id
            } else {
                model.node_by_name(&pbnode.name)?.id
            };
            for (ix, i) in pbnode.input.iter().filter(|n| !n.starts_with('^')).enumerate() {
                let input = Self::parse_input(i)?;
                let prec = model.node_by_name(input.0)?.id;
                let outlet = OutletId::new(prec, input.1);
                let inlet = InletId::new(node_id, ix);
                model.add_edge(outlet, inlet)?;
                model.set_outlet_label(outlet, i.to_string())?;
            }
            for i in pbnode.input.iter().filter(|n| n.starts_with('^')) {
                let input = Self::parse_input(i)?;
                let prec = model.node_by_name(input.0)?.id;
                control_inputs.push((model.node_id_by_name(&pbnode.name)?, prec));
            }
        }

        // variable -> assign rewire
        //  * VariableV2 outputs a reference tensor on #0
        //  * Assign consumes this by_ref tensor on #0 and somehow performs
        //      updates on it (it has a second input on #1 for the value to
        //      assign)
        //
        // in tract:
        //  * VariableV2 outputs a regular tensor stored in the session state
        //  * Assign has the same inputs, but does not use #0, updating the
        //      session state instead
        //
        // 2026-01: remove vars support in tract core
        //
        // for id in 0..model.nodes().len() {
        //     use crate::ops::vars::*;
        //     if model.node(id).op_is::<Assign>() {
        //         let prec = model.node(id).inputs[0];
        //         let var_id = model.node(prec.node).op_as::<VariableV2>().map(|v| v.id.clone());
        //         if let (Some(var_id), Some(assign)) =
        //             (var_id, model.node_mut(id).op_as_mut::<Assign>())
        //         {
        //             assign.var_id = Some(var_id);
        //         } else {
        //             bail!("Model contains unlinked Assign/Variable2");
        //         }
        //     }
        // }
        model.set_input_outlets(&inputs)?;
        model.auto_outputs()?;
        let extensions = TfModelExtensions { control_inputs };
        Ok(TfModelAndExtensions(model, extensions))
    }
}

impl Framework<GraphDef, InferenceModel> for Tensorflow {
    /// This method will try to read the input as a frozen model first, then
    /// as a saved model.
    fn proto_model_for_path(&self, r: impl AsRef<path::Path>) -> TractResult<GraphDef> {
        self.read_frozen_model(&mut fs::File::open(r.as_ref())?)
            .or_else(|_| self.read_saved_model(&mut fs::File::open(r.as_ref())?))
    }

    /// This method expects a frozen model; use open_saved_model for the TF2
    /// saved model format.
    fn proto_model_for_read(&self, r: &mut dyn std::io::Read) -> TractResult<GraphDef> {
        self.read_frozen_model(r)
    }

    fn model_for_proto_model_with_model_template(
        &self,
        proto: &GraphDef,
        template: InferenceModel,
    ) -> TractResult<InferenceModel> {
        Ok(self.parse_graph_with_template(proto, template)?.0)
    }
}
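
// A minimal sanity check for the input-name grammar documented on
// `parse_input`: bare names default to output slot 0, ":n" selects slot n,
// and a leading caret marks a control input.
#[cfg(test)]
mod parse_input_tests {
    use super::*;

    #[test]
    fn node_slot_and_control_inputs() -> TractResult<()> {
        // Bare node name: the ":0" suffix is implied.
        assert_eq!(Tensorflow::parse_input("add")?, ("add", 0));
        // Explicit output slot after the colon.
        assert_eq!(Tensorflow::parse_input("add:1")?, ("add", 1));
        // Control input: the caret is stripped and the slot defaults to 0.
        assert_eq!(Tensorflow::parse_input("^add")?, ("add", 0));
        Ok(())
    }
}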