1use std::path::PathBuf;
2use std::{fs, path};
3
4use std::collections::HashMap;
5
6use tract_hir::internal::*;
7use tract_hir::prelude::tract_itertools::Itertools;
8
9use crate::data_resolver::{self, ModelDataResolver};
10use crate::pb::type_proto::Value;
11use crate::pb::{self, TensorProto, TypeProto};
12use crate::tensor::{load_tensor, translate_inference_fact};
13use prost::Message;
14
15pub fn optional_inputs(pb: &pb::NodeProto) -> impl Iterator<Item = Option<usize>> + '_ {
16    let mut real_input = 0;
17    (0..).map(move |i| {
18        if pb.input.get(i).filter(|s| !s.is_empty()).is_some() {
19            real_input += 1;
20            Some(real_input - 1)
21        } else {
22            None
23        }
24    })
25}
26
27pub fn optional_outputs(pb: &pb::NodeProto) -> impl Iterator<Item = Option<usize>> + '_ {
28    let mut real_output = 0;
29    (0..).map(move |i| {
30        if pb.output.get(i).filter(|s| !s.is_empty()).is_some() {
31            real_output += 1;
32            Some(real_output - 1)
33        } else {
34            None
35        }
36    })
37}
38
/// Everything needed while translating one (possibly nested) ONNX graph
/// into a tract `InferenceModel`.
#[derive(Clone)]
pub struct ParsingContext<'a> {
    /// Version of the core ("ai.onnx" / empty-domain) operator set declared
    /// by the model, or 0 when the model declares none.
    pub onnx_operator_set_version: i64,
    /// The framework: op registry, parsing options and external-data provider.
    pub framework: &'a Onnx,
    /// The whole model proto this graph belongs to.
    pub model: &'a pb::ModelProto,
    /// Stack of enclosing graphs, innermost last; `parse_graph` pushes the
    /// current graph here so op builders can resolve outer-scope names.
    pub parent_graphs: Vec<&'a pb::GraphProto>,
    /// Directory containing the model file, used to resolve tensors stored
    /// in external data files.
    pub model_dir: Option<&'a str>,
    /// Model used as the starting point when building the graph.
    pub template: InferenceModel,
}
48
/// Outcome of translating one ONNX graph.
#[derive(Clone, Debug)]
pub struct ParseResult {
    /// The translated model.
    pub model: InferenceModel,
    /// Tensor names the graph consumed but did not produce; each one was
    /// materialized as an extra source node. For subgraphs these are wired
    /// by the enclosing graph; at top level they are an error.
    pub unresolved_inputs: Vec<String>,
    /// Maps every ONNX tensor name to the model outlet producing it.
    pub outlets_by_name: HashMap<String, OutletId>,
}
55
impl ParsingContext<'_> {
    /// Load a tensor from its protobuf form, resolving any external data
    /// through the framework's provider (relative to `model_dir`).
    pub fn load_tensor(&self, proto: &TensorProto) -> TractResult<Tensor> {
        load_tensor(&*self.framework.provider, proto, self.model_dir)
    }

    /// Translate an ONNX `GraphProto` into an `InferenceModel`.
    ///
    /// Phases, in order:
    /// 1. load all initializers;
    /// 2. create graph inputs — a constant when an initializer matches the
    ///    input name, otherwise a source with the declared fact;
    /// 3. add remaining initializers as constants;
    /// 4. create one node per ONNX node (falling back to `UnimplementedOp`
    ///    for unknown op types);
    /// 5. wire node inputs, then subgraph closures, creating extra sources
    ///    for names not produced here (collected in `unresolved_inputs`);
    /// 6. label outputs and apply `value_info` facts.
    pub fn parse_graph(&self, graph: &pb::GraphProto) -> TractResult<ParseResult> {
        // Child context so op builders (and nested graphs) can see this graph.
        let mut ctx = self.clone();
        ctx.parent_graphs.push(graph);
        let mut model = self.template.clone();
        let mut unresolved_inputs = vec![];
        // (node id, outer tensor name) pairs, wired after all nodes exist.
        let mut closures_to_wire = vec![];
        trace!("trying to initialize initializers hashmap...");
        #[allow(unused_assignments)]
        let mut initializers: HashMap<&str, Tensor> = graph
            .initializer
            .iter()
            .map(|name| {
                let t = self.load_tensor(name)?;
                Ok((&*name.name, t))
            })
            .collect::<TractResult<_>>()?;
        // Sorted iteration keeps trace output deterministic.
        for (k, v) in initializers.iter().sorted_by_key(|kv| kv.0) {
            trace!("Initializer: {k} {v:?}");
        }
        let mut outlets_by_name = HashMap::<String, OutletId>::new();
        for input in graph.input.iter() {
            // An input backed by an initializer becomes a constant (the
            // initializer acts as a default value); removal also prevents it
            // from being added twice below.
            if let Some(init) = initializers.remove(&*input.name) {
                trace!("Input: {} initialized by {:?}", input.name, init);
                let id = model.add_const(input.name.to_owned(), init)?;
                outlets_by_name.insert(input.name.to_owned(), id);
            } else {
                // NOTE(review): unwraps assume every non-initialized input
                // declares a type — a malformed model would panic here.
                let fact = input.r#type.as_ref().unwrap().value.as_ref().unwrap();
                #[allow(irrefutable_let_patterns)]
                let fact: InferenceFact = if let pb::type_proto::Value::TensorType(fact) = fact {
                    translate_inference_fact(&ctx, fact, true)
                        .with_context(|| format!("translating to fact: {fact:?}"))?
                } else {
                    bail!("Can not parse tensor type");
                };
                trace!("Input: {} is a source ({:?})", input.name, fact);
                let id = model.add_source(&*input.name, fact)?;
                outlets_by_name.insert(input.name.to_owned(), id);
            }
        }
        for output in graph.output.iter() {
            trace!("Model output: {output:?}");
        }
        // Initializers that are not graph inputs become plain constants.
        // Sorted insertion keeps node ids deterministic across runs.
        for (name, t) in initializers.into_iter().sorted_by_key(|kv| kv.0) {
            let id = model.add_const(name, t)?;
            outlets_by_name.insert(name.to_string(), id);
        }
        // Nodes created so far (sources + consts); ONNX node i becomes tract
        // node `i + consts`, which the input-wiring loop below relies on.
        let consts = model.nodes().len();
        for pbnode in graph.node.iter() {
            // Node naming fallback: explicit name, else first non-empty
            // output name, else "<index>-<op_type>".
            let name = if !pbnode.name.is_empty() {
                pbnode.name.to_string()
            } else if pbnode.output.len() > 0 && !pbnode.output[0].is_empty() {
                pbnode.output[0].to_owned()
            } else {
                format!("{}-{}", model.nodes().len(), pbnode.op_type)
            };
            trace!("Creating node {name}");
            // One default fact per actual (non-empty) output slot.
            let facts = pbnode
                .output
                .iter()
                .filter(|s| !s.is_empty())
                .map(|_| InferenceFact::default())
                .collect();
            trace!("  outputs {:?}", pbnode.output);
            // Unknown op types become UnimplementedOp placeholders rather
            // than failing the whole parse.
            let (op, closures) = match self.framework.op_register.0.get(&pbnode.op_type) {
                Some(builder) => (builder)(&ctx, pbnode).with_context(|| {
                    format!("Building node {} ({})", pbnode.name, pbnode.op_type)
                })?,
                None => (
                    tract_hir::ops::unimpl::UnimplementedOp::new(
                        pbnode.output.len(),
                        &*pbnode.op_type,
                        format!("{pbnode:?}"),
                    )
                    .into(),
                    vec![],
                ),
            };
            let id = model.add_node(name, op, facts)?;
            for (ix, output) in pbnode.output.iter().filter(|s| !s.is_empty()).enumerate() {
                outlets_by_name.insert(output.to_owned(), OutletId::new(id, ix));
                model.set_outlet_label(OutletId::new(id, ix), output.to_owned())?;
            }
            // Outer-scope names captured by this node's subgraphs; wired
            // after all nodes exist so producers are known.
            for closure in closures {
                trace!("Node {} closes on {}", model.nodes()[id], closure);
                closures_to_wire.push((id, closure))
            }
        }
        // Wire regular inputs. Names with no producer in this graph get a
        // fresh source and are reported as unresolved.
        for (id, pbnode) in graph.node.iter().enumerate() {
            for (ix, input) in pbnode.input.iter().filter(|s| !s.is_empty()).enumerate() {
                if !outlets_by_name.contains_key(input) {
                    let id = model.add_source(input.clone(), InferenceFact::default())?;
                    unresolved_inputs.push(input.to_string());
                    outlets_by_name.insert(input.to_string(), id);
                }
                let outlet = outlets_by_name[input];
                model.add_edge(outlet, InletId::new(id + consts, ix))?;
            }
        }
        // Wire closures as extra trailing inputs on their nodes.
        for (id, closure) in closures_to_wire {
            if !outlets_by_name.contains_key(&*closure) {
                let id = model.add_source(closure.clone(), InferenceFact::default())?;
                unresolved_inputs.push(closure.to_string());
                outlets_by_name.insert(closure.to_string(), id);
            }
            let outlet = outlets_by_name[&*closure];
            let ix = model.nodes()[id].inputs.len();
            model.add_edge(outlet, InletId::new(id, ix))?;
        }
        let mut outputs = vec![];
        for output in graph.output.iter() {
            let mut fact = InferenceFact::default();
            // Optionally seed output facts with the shapes declared on the
            // graph outputs (subject to framework options).
            if self.framework.use_output_shapes {
                if let Some(f) = output.r#type.as_ref().and_then(|t| t.value.as_ref()) {
                    let pb::type_proto::Value::TensorType(f) = f;
                    fact = translate_inference_fact(&ctx, f, false)?
                };
            }
            if self.framework.ignore_output_types {
                fact = fact.without_datum_type();
            }
            let outlet = outlets_by_name[&*output.name];
            outputs.push(outlet);
            model.set_outlet_label(outlet, output.name.clone())?;
            model.set_outlet_fact(outlet, fact)?;
        }
        model.set_output_outlets(&outputs)?;
        // Apply value_info annotations to intermediate outlets when present.
        for info in &graph.value_info {
            if let Some(TypeProto { value: Some(Value::TensorType(t)), .. }) = &info.r#type {
                if let Some(outlet) = outlets_by_name.get(&info.name) {
                    let mut pbfact = translate_inference_fact(&ctx, t, false)?;
                    // i64 annotations are dropped — presumably because tract
                    // may substitute other integer/dim types where ONNX
                    // declares i64; TODO confirm rationale.
                    if pbfact.datum_type() == Some(i64::datum_type()) {
                        pbfact = pbfact.without_datum_type();
                    }
                    model.set_outlet_fact(*outlet, pbfact)?;
                }
            }
        }
        let result = ParseResult { model, unresolved_inputs, outlets_by_name };
        Ok(result)
    }
}
203
/// Factory for one ONNX operator type: builds the tract op plus the list of
/// outer-scope tensor names its subgraphs capture (wired later as closures).
type OpBuilder =
    fn(&ParsingContext, node: &pb::NodeProto) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)>;
206
/// Registry mapping ONNX operator type names (e.g. "Conv") to their builders.
#[derive(Clone, Default)]
pub struct OnnxOpRegister(pub HashMap<String, OpBuilder>);
209
210impl OnnxOpRegister {
211    pub fn insert(&mut self, s: &'static str, builder: OpBuilder) {
212        self.0.insert(s.into(), builder);
213    }
214}
215
/// The ONNX framework front-end: operator registry plus parsing options.
#[derive(Clone)]
pub struct Onnx {
    /// Builders for the supported ONNX operators.
    pub op_register: OnnxOpRegister,
    /// When true, shapes declared on graph outputs are used as inference hints.
    pub use_output_shapes: bool,
    /// When true, datum types declared on graph outputs are discarded.
    pub ignore_output_types: bool,
    /// Resolver for tensor data stored outside the model file.
    pub provider: Arc<dyn ModelDataResolver + Send + Sync>,
}
223
224impl Default for Onnx {
225    fn default() -> Self {
226        Onnx {
227            op_register: Default::default(),
228            use_output_shapes: Default::default(),
229            ignore_output_types: Default::default(),
230            provider: Arc::new(data_resolver::MmapDataResolver),
231        }
232    }
233}
234
235impl Onnx {
236    pub fn parse(&self, proto: &pb::ModelProto, path: Option<&str>) -> TractResult<ParseResult> {
237        self.parse_with_template(proto, path, Default::default())
238    }
239    pub fn parse_with_template(
240        &self,
241        proto: &pb::ModelProto,
242        model_dir: Option<&str>,
243        template: InferenceModel,
244    ) -> TractResult<ParseResult> {
245        let onnx_operator_set_version = proto
246            .opset_import
247            .iter()
248            .find(|import| import.domain.is_empty() || import.domain == "ai.onnx")
249            .map(|op| op.version)
250            .unwrap_or(0);
251        let graph =
252            proto.graph.as_ref().ok_or_else(|| anyhow!("model proto does not contain a graph"))?;
253        debug!("ONNX operator set version: {onnx_operator_set_version:?}");
254        if onnx_operator_set_version != 0 && !(9..19).contains(&onnx_operator_set_version) {
255            warn!("ONNX operator for your model is {onnx_operator_set_version}, tract is only tested against \
256                  operator set 9 to 18 (included). Your model may still work so this is not a hard fail.");
257        }
258        let ctx = ParsingContext {
259            framework: self,
260            model: proto,
261            parent_graphs: vec![],
262            onnx_operator_set_version,
263            model_dir,
264            template,
265        };
266        trace!("created ParsingContext");
267        ctx.parse_graph(graph)
268    }
269
270    pub fn with_ignore_output_shapes(self, ignore: bool) -> Onnx {
271        Self { use_output_shapes: !ignore, ..self }
272    }
273
274    pub fn with_ignore_output_types(self, ignore: bool) -> Onnx {
275        Self { ignore_output_types: ignore, ..self }
276    }
277
278    pub fn determinize(model: &mut InferenceModel) -> TractResult<()> {
279        use crate::ops::multinomial::Multinomial;
280        for node in model.nodes_mut() {
281            if let Some(op) = node.op_as_mut::<Box<dyn Expansion>>() {
282                if let Some(op) = op.as_any_mut().downcast_mut::<Multinomial>() {
283                    op.seed.get_or_insert(1.0);
284                }
285            }
286        }
287        Ok(())
288    }
289}
290
291impl Framework<pb::ModelProto, InferenceModel> for Onnx {
292    fn model_for_path(&self, p: impl AsRef<path::Path>) -> TractResult<InferenceModel> {
293        let mut path = PathBuf::new();
294        path.push(&p);
295        let mut dir: Option<&str> = None;
296        if let Some(dir_opt) = path.parent() {
297            dir = dir_opt.to_str();
298        }
299        let proto = self.proto_model_for_path(p)?;
300        let ParseResult { model, unresolved_inputs, .. } = self.parse(&proto, dir)?;
301        if unresolved_inputs.len() > 0 {
302            bail!("Could not resolve inputs at top-level: {:?}", unresolved_inputs)
303        }
304        Ok(model)
305    }
306
307    #[cfg(target_family = "wasm")]
308    fn proto_model_for_path(&self, p: impl AsRef<path::Path>) -> TractResult<pb::ModelProto> {
309        let p = p.as_ref();
310        let mut file = fs::File::open(p).with_context(|| format!("Opening {p:?}"))?;
311        Ok(self.proto_model_for_read(&mut file)?)
312    }
313
314    #[cfg(not(target_family = "wasm"))]
315    fn proto_model_for_path(&self, p: impl AsRef<path::Path>) -> TractResult<pb::ModelProto> {
316        let p = p.as_ref();
317        let map = unsafe {
318            memmap2::Mmap::map(&fs::File::open(p).with_context(|| format!("Opening {p:?}"))?)?
319        };
320        Ok(crate::pb::ModelProto::decode(&*map)?)
321    }
322
323    fn proto_model_for_read(&self, r: &mut dyn std::io::Read) -> TractResult<pb::ModelProto> {
324        let mut v = vec![];
325        r.read_to_end(&mut v)?;
326        let b = bytes::Bytes::from(v);
327        Ok(crate::pb::ModelProto::decode(b)?)
328    }
329
330    fn model_for_proto_model_with_model_template(
331        &self,
332        proto: &pb::ModelProto,
333        template: InferenceModel,
334    ) -> TractResult<InferenceModel> {
335        let ParseResult { model, unresolved_inputs, .. } =
336            self.parse_with_template(proto, None, template)?;
337        if unresolved_inputs.len() > 0 {
338            bail!("Could not resolve inputs at top-level: {:?}", unresolved_inputs)
339        }
340        Ok(model)
341    }
342
343    fn model_for_read(&self, r: &mut dyn std::io::Read) -> TractResult<InferenceModel> {
344        let proto_model = self.proto_model_for_read(r).context("Reading proto model")?;
345        self.model_for_proto_model(&proto_model).context("Translating proto model to model")
346    }
347}