use std::collections::HashMap;
use std::path::PathBuf;
use std::{fs, path};

use tract_hir::internal::*;
use tract_hir::prelude::tract_itertools::Itertools;

use crate::data_resolver::{self, ModelDataResolver};
use crate::pb::type_proto::Value;
use crate::pb::{self, TensorProto, TypeProto};
use crate::tensor::{load_tensor, translate_inference_fact};
use prost::Message;

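/// Iterator mapping ONNX optional input slots to dense input indices.
///
/// ONNX encodes an omitted optional input as an empty string in
/// `NodeProto.input`. For each slot this yields `Some(dense_index)` when the
/// slot is populated, and `None` when it is omitted or past the end of the
/// list. Sketch, for a hypothetical node whose inputs are `["x", "", "y"]`:
///
/// ```ignore
/// let slots: Vec<Option<usize>> = optional_inputs(&node).take(4).collect();
/// assert_eq!(slots, vec![Some(0), None, Some(1), None]);
/// ```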
pub fn optional_inputs(pb: &pb::NodeProto) -> impl Iterator<Item = Option<usize>> + '_ {
    let mut real_input = 0;
    (0..).map(move |i| {
        if pb.input.get(i).filter(|s| !s.is_empty()).is_some() {
            real_input += 1;
            Some(real_input - 1)
        } else {
            None
        }
    })
}

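/// Iterator mapping ONNX optional output slots to dense output indices,
/// following the same empty-string convention as [`optional_inputs`].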
pub fn optional_outputs(pb: &pb::NodeProto) -> impl Iterator<Item = Option<usize>> + '_ {
    let mut real_output = 0;
    (0..).map(move |i| {
        if pb.output.get(i).filter(|s| !s.is_empty()).is_some() {
            real_output += 1;
            Some(real_output - 1)
        } else {
            None
        }
    })
}

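/// Everything needed while translating one ONNX graph (or subgraph) into an
/// `InferenceModel`: the operator set in force, the owning framework and
/// model protos, the stack of enclosing graphs, the directory used to
/// resolve external tensor data, and the model template to build upon.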
#[derive(Clone)]
pub struct ParsingContext<'a> {
    pub onnx_operator_set_version: i64,
    pub framework: &'a Onnx,
    pub model: &'a pb::ModelProto,
    pub parent_graphs: Vec<&'a pb::GraphProto>,
    pub model_dir: Option<&'a str>,
    pub template: InferenceModel,
}

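/// Outcome of parsing one graph: the built model, the names the graph reads
/// but does not define locally (they must be wired from an enclosing graph),
/// and a map from ONNX tensor names to model outlets.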
#[derive(Clone, Debug)]
pub struct ParseResult {
    pub model: InferenceModel,
    pub unresolved_inputs: Vec<String>,
    pub outlets_by_name: HashMap<String, OutletId>,
}

impl ParsingContext<'_> {
    pub fn load_tensor(&self, proto: &TensorProto) -> TractResult<Tensor> {
        load_tensor(&*self.framework.provider, proto, self.model_dir)
    }

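    /// Translate an ONNX `GraphProto` into an `InferenceModel`.
    ///
    /// Proceeds in passes: load initializers, declare graph inputs (as
    /// constants or sources), create one node per ONNX node, then wire
    /// regular inputs, subgraph closures, outputs, and `value_info` hints.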
    pub fn parse_graph(&self, graph: &pb::GraphProto) -> TractResult<ParseResult> {
        let mut ctx = self.clone();
        ctx.parent_graphs.push(graph);
        let mut model = self.template.clone();
        let mut unresolved_inputs = vec![];
        let mut closures_to_wire = vec![];
        trace!("Loading initializers");
        #[allow(unused_assignments)]
        let mut initializers: HashMap<&str, Tensor> = graph
            .initializer
            .iter()
            .map(|init| {
                let t = self.load_tensor(init)?;
                Ok((&*init.name, t))
            })
            .collect::<TractResult<_>>()?;
        for (k, v) in initializers.iter().sorted_by_key(|kv| kv.0) {
            trace!("Initializer: {k} {v:?}");
        }
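        // Declare graph inputs: an input backed by an initializer becomes a
        // constant (the initializer is its default value), the rest become
        // sources with the declared tensor type.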
        let mut outlets_by_name = HashMap::<String, OutletId>::new();
        for input in graph.input.iter() {
            if let Some(init) = initializers.remove(&*input.name) {
                trace!("Input: {} initialized by {:?}", input.name, init);
                let id = model.add_const(input.name.to_owned(), init)?;
                outlets_by_name.insert(input.name.to_owned(), id);
            } else {
                let fact = input
                    .r#type
                    .as_ref()
                    .and_then(|t| t.value.as_ref())
                    .ok_or_else(|| anyhow!("graph input {} has no type", input.name))?;
                #[allow(irrefutable_let_patterns)]
                let fact: InferenceFact = if let pb::type_proto::Value::TensorType(fact) = fact {
                    translate_inference_fact(&ctx, fact, true)
                        .with_context(|| format!("translating to fact: {fact:?}"))?
                } else {
                    bail!("Cannot parse tensor type");
                };
                trace!("Input: {} is a source ({:?})", input.name, fact);
                let id = model.add_source(&*input.name, fact)?;
                outlets_by_name.insert(input.name.to_owned(), id);
            }
        }
        for output in graph.output.iter() {
            trace!("Model output: {output:?}");
        }
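        // Initializers that do not shadow a graph input become plain constants.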
        for (name, t) in initializers.into_iter().sorted_by_key(|kv| kv.0) {
            let id = model.add_const(name, t)?;
            outlets_by_name.insert(name.to_string(), id);
        }
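        // Every node so far is a source or a constant: the ONNX node at
        // protobuf index i will land at id i + consts, which the
        // input-wiring pass below relies on.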
        let consts = model.nodes().len();
        for pbnode in graph.node.iter() {
            let name = if !pbnode.name.is_empty() {
                pbnode.name.to_string()
            } else if !pbnode.output.is_empty() && !pbnode.output[0].is_empty() {
                pbnode.output[0].to_owned()
            } else {
                format!("{}-{}", model.nodes().len(), pbnode.op_type)
            };
            trace!("Creating node {name}");
            let facts = pbnode
                .output
                .iter()
                .filter(|s| !s.is_empty())
                .map(|_| InferenceFact::default())
                .collect();
            trace!("  outputs {:?}", pbnode.output);
            let (op, closures) = match self.framework.op_register.0.get(&pbnode.op_type) {
                Some(builder) => (builder)(&ctx, pbnode).with_context(|| {
                    format!("Building node {} ({})", pbnode.name, pbnode.op_type)
                })?,
                None => (
                    tract_hir::ops::unimpl::UnimplementedOp::new(
                        pbnode.output.len(),
                        &*pbnode.op_type,
                        format!("{pbnode:?}"),
                    )
                    .into(),
                    vec![],
                ),
            };
            let id = model.add_node(name, op, facts)?;
            for (ix, output) in pbnode.output.iter().filter(|s| !s.is_empty()).enumerate() {
                outlets_by_name.insert(output.to_owned(), OutletId::new(id, ix));
                model.set_outlet_label(OutletId::new(id, ix), output.to_owned())?;
            }
            for closure in closures {
                trace!("Node {} closes on {}", model.nodes()[id], closure);
                closures_to_wire.push((id, closure))
            }
        }
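        // Wire regular inputs. A name that no local node or constant defines
        // is assumed to come from an enclosing graph: it becomes a source
        // here and is reported in `unresolved_inputs`.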
        for (id, pbnode) in graph.node.iter().enumerate() {
            for (ix, input) in pbnode.input.iter().filter(|s| !s.is_empty()).enumerate() {
                if !outlets_by_name.contains_key(input) {
                    let id = model.add_source(input.clone(), InferenceFact::default())?;
                    unresolved_inputs.push(input.to_string());
                    outlets_by_name.insert(input.to_string(), id);
                }
                let outlet = outlets_by_name[input];
                model.add_edge(outlet, InletId::new(id + consts, ix))?;
            }
        }
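        // Wire closures: extra inputs requested by operators that carry
        // subgraphs (such as Loop, Scan or If), appended after the regular
        // inputs.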
        for (id, closure) in closures_to_wire {
            if !outlets_by_name.contains_key(&*closure) {
                let id = model.add_source(closure.clone(), InferenceFact::default())?;
                unresolved_inputs.push(closure.to_string());
                outlets_by_name.insert(closure.to_string(), id);
            }
            let outlet = outlets_by_name[&*closure];
            let ix = model.nodes()[id].inputs.len();
            model.add_edge(outlet, InletId::new(id, ix))?;
        }
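        // Label and type the graph outputs, honoring the framework's
        // use_output_shapes and ignore_output_types switches.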
        let mut outputs = vec![];
        for output in graph.output.iter() {
            let mut fact = InferenceFact::default();
            if self.framework.use_output_shapes {
                if let Some(f) = output.r#type.as_ref().and_then(|t| t.value.as_ref()) {
                    let pb::type_proto::Value::TensorType(f) = f;
                    fact = translate_inference_fact(&ctx, f, false)?
                };
            }
            if self.framework.ignore_output_types {
                fact = fact.without_datum_type();
            }
            let outlet = outlets_by_name[&*output.name];
            outputs.push(outlet);
            model.set_outlet_label(outlet, output.name.clone())?;
            model.set_outlet_fact(outlet, fact)?;
        }
        model.set_output_outlets(&outputs)?;
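        // Apply value_info annotations as hints on intermediate outlets,
        // keeping the shape but dropping the datum type of i64 annotations.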
        for info in &graph.value_info {
            if let Some(TypeProto { value: Some(Value::TensorType(t)), .. }) = &info.r#type {
                if let Some(outlet) = outlets_by_name.get(&info.name) {
                    let mut pbfact = translate_inference_fact(&ctx, t, false)?;
                    if pbfact.datum_type() == Some(i64::datum_type()) {
                        pbfact = pbfact.without_datum_type();
                    }
                    model.set_outlet_fact(*outlet, pbfact)?;
                }
            }
        }
        let result = ParseResult { model, unresolved_inputs, outlets_by_name };
        Ok(result)
    }
}

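/// Signature of an operator builder: given the parsing context and an ONNX
/// node, produce the op plus the names of outer-scope tensors it closes on.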
type OpBuilder =
    fn(&ParsingContext, node: &pb::NodeProto) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)>;

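/// Registry mapping ONNX operator names (e.g. "Conv") to their builders.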
#[derive(Clone, Default)]
pub struct OnnxOpRegister(pub HashMap<String, OpBuilder>);

impl OnnxOpRegister {
    pub fn insert(&mut self, s: &'static str, builder: OpBuilder) {
        self.0.insert(s.into(), builder);
    }
}

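/// The ONNX framework entry point. A minimal usage sketch, assuming the
/// crate's usual `tract_onnx::onnx()` constructor and a `model.onnx` file on
/// disk:
///
/// ```ignore
/// let model = tract_onnx::onnx()
///     .model_for_path("model.onnx")?
///     .into_optimized()?
///     .into_runnable()?;
/// ```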
#[derive(Clone)]
pub struct Onnx {
    pub op_register: OnnxOpRegister,
    pub use_output_shapes: bool,
    pub ignore_output_types: bool,
    pub provider: Arc<dyn ModelDataResolver + Send + Sync>,
}

impl Default for Onnx {
    fn default() -> Self {
        Onnx {
            op_register: Default::default(),
            use_output_shapes: Default::default(),
            ignore_output_types: Default::default(),
            provider: Arc::new(data_resolver::MmapDataResolver),
        }
    }
}

impl Onnx {
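    /// Parse a model proto into a `ParseResult`, starting from a default
    /// template. `path` is the model directory, used to resolve external
    /// tensor data.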
    pub fn parse(&self, proto: &pb::ModelProto, path: Option<&str>) -> TractResult<ParseResult> {
        self.parse_with_template(proto, path, Default::default())
    }
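
    /// Parse a model proto, building on top of `template`. Detects the
    /// model's ONNX operator set version (defaulting to 0 when no relevant
    /// `opset_import` entry is present) before parsing the main graph.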
    pub fn parse_with_template(
        &self,
        proto: &pb::ModelProto,
        model_dir: Option<&str>,
        template: InferenceModel,
    ) -> TractResult<ParseResult> {
        let onnx_operator_set_version = proto
            .opset_import
            .iter()
            .find(|import| import.domain.is_empty() || import.domain == "ai.onnx")
            .map(|op| op.version)
            .unwrap_or(0);
        let graph =
            proto.graph.as_ref().ok_or_else(|| anyhow!("model proto does not contain a graph"))?;
        debug!("ONNX operator set version: {onnx_operator_set_version:?}");
        if onnx_operator_set_version != 0 && !(9..=18).contains(&onnx_operator_set_version) {
            warn!(
                "The operator set of your model is {onnx_operator_set_version}; tract is only \
                 tested against operator sets 9 through 18 (inclusive). Your model may still \
                 work, so this is not a hard fail."
            );
        }
        let ctx = ParsingContext {
            framework: self,
            model: proto,
            parent_graphs: vec![],
            onnx_operator_set_version,
            model_dir,
            template,
        };
        trace!("created ParsingContext");
        ctx.parse_graph(graph)
    }

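    /// Control whether declared output shapes are used as facts. Note the
    /// polarity flip: this stores `use_output_shapes: !ignore`.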
    pub fn with_ignore_output_shapes(self, ignore: bool) -> Onnx {
        Self { use_output_shapes: !ignore, ..self }
    }

    pub fn with_ignore_output_types(self, ignore: bool) -> Onnx {
        Self { ignore_output_types: ignore, ..self }
    }

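    /// Make a model deterministic by forcing a fixed seed (1.0) on any
    /// Multinomial op that does not already carry one.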
    pub fn determinize(model: &mut InferenceModel) -> TractResult<()> {
        use crate::ops::multinomial::Multinomial;
        for node in model.nodes_mut() {
            if let Some(op) = node.op_as_mut::<Box<dyn Expansion>>() {
                if let Some(op) = op.as_any_mut().downcast_mut::<Multinomial>() {
                    op.seed.get_or_insert(1.0);
                }
            }
        }
        Ok(())
    }
}

impl Framework<pb::ModelProto, InferenceModel> for Onnx {
    fn model_for_path(&self, p: impl AsRef<path::Path>) -> TractResult<InferenceModel> {
        let path = p.as_ref().to_path_buf();
        let dir = path.parent().and_then(|d| d.to_str());
        let proto = self.proto_model_for_path(p)?;
        let ParseResult { model, unresolved_inputs, .. } = self.parse(&proto, dir)?;
        if !unresolved_inputs.is_empty() {
            bail!("Could not resolve inputs at top-level: {unresolved_inputs:?}")
        }
        Ok(model)
    }

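    // Two flavors of proto_model_for_path: a plain read on wasm (no mmap
    // support) and a memory map on other targets to avoid copying weights.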
    #[cfg(target_family = "wasm")]
    fn proto_model_for_path(&self, p: impl AsRef<path::Path>) -> TractResult<pb::ModelProto> {
        let p = p.as_ref();
        let mut file = fs::File::open(p).with_context(|| format!("Opening {p:?}"))?;
        Ok(self.proto_model_for_read(&mut file)?)
    }

    #[cfg(not(target_family = "wasm"))]
    fn proto_model_for_path(&self, p: impl AsRef<path::Path>) -> TractResult<pb::ModelProto> {
        let p = p.as_ref();
        let map = unsafe {
            memmap2::Mmap::map(&fs::File::open(p).with_context(|| format!("Opening {p:?}"))?)?
        };
        Ok(crate::pb::ModelProto::decode(&*map)?)
    }

    fn proto_model_for_read(&self, r: &mut dyn std::io::Read) -> TractResult<pb::ModelProto> {
        let mut v = vec![];
        r.read_to_end(&mut v)?;
        let b = bytes::Bytes::from(v);
        Ok(crate::pb::ModelProto::decode(b)?)
    }

    fn model_for_proto_model_with_model_template(
        &self,
        proto: &pb::ModelProto,
        template: InferenceModel,
    ) -> TractResult<InferenceModel> {
        let ParseResult { model, unresolved_inputs, .. } =
            self.parse_with_template(proto, None, template)?;
        if !unresolved_inputs.is_empty() {
            bail!("Could not resolve inputs at top-level: {unresolved_inputs:?}")
        }
        Ok(model)
    }

    fn model_for_read(&self, r: &mut dyn std::io::Read) -> TractResult<InferenceModel> {
        let proto_model = self.proto_model_for_read(r).context("Reading proto model")?;
        self.model_for_proto_model(&proto_model).context("Translating proto model to model")
    }
}