1use crate::ast::*;
2use crate::internal::*;
3use tract_core::ndarray::ArrayViewD;
4use tract_core::ndarray::Axis;
5use tract_itertools::Itertools;
6use tract_linalg::block_quant::BlockQuantValue;
7
8pub fn rewrite_model(model: &mut TypedModel) -> TractResult<()> {
9 model.prop_consts()?;
10 tract_core::ops::einsum::prefix_matmul::rewrite_einsum_to_prefix_matmul(model)?;
11 Rewriter::default()
12 .with_rule_for(
13 "rewrite_block_quant_const_to_scalar",
14 crate::ops::nnef::ser::rewrite_block_quant_const_to_scalar,
15 )
16 .with_rule_for(
17 "rewrite_matmul_to_same_rank",
18 crate::ops::nnef::ser::rewrite_matmul_to_same_rank,
19 )
20 .with_rule_for("rewrite_conv_with_n_axis", tract_core::ops::cnn::rewrite_conv_with_n_axis)
21 .with_rule_for(
22 "rewrite_deconv_with_n_axis",
23 tract_core::ops::cnn::rewrite_deconv_with_n_axis,
24 )
25 .with_rule_for(
26 "rewrite_kernel_conv_in_oihw",
27 crate::ops::nnef::ser::rewrite_kernel_conv_in_oihw,
28 )
29 .with_rule_for(
30 "rewrite_kernel_deconv_in_oihw",
31 crate::ops::nnef::ser::rewrite_kernel_deconv_in_oihw,
32 )
33 .with_rule_for(
34 "rewrite_consistent_quantized_conv",
35 crate::ops::nnef::ser::rewrite_consistent_quantized_conv,
36 )
37 .with_rule_for("expand_mean_of_square", tract_core::ops::nn::expand_mean_of_squares)
38 .rewrite(&(), model)
39}
40
41pub fn to_proto_model(framework: &Nnef, model: &TypedModel) -> TractResult<ProtoModel> {
42 let mut fixed_model = model.clone();
43 rewrite_model(&mut fixed_model)?;
44 let mut into_ast = IntoAst::new(framework, &fixed_model);
45 into_ast.translate().context("Translating model to AST")?;
46 into_ast.into_proto_model().context("Translating AST to proto model")
47}
48
49pub fn to_fragment_def(
50 parent: &IntoAst,
51 model: &TypedModel,
52) -> TractResult<(FragmentDef, Vec<RequiredTensorParameter>)> {
53 let mut into_ast = IntoAst::new(parent.framework, model);
54 into_ast.parent = Some(parent);
55 into_ast.translate()?;
56 into_ast.into_fragment()
57}
58
/// Working state of the translation of a `TypedModel` to an NNEF AST.
pub struct IntoAst<'a> {
    /// NNEF framework whose registries are tried to serialize each node.
    pub framework: &'a Nnef,
    /// Enclosing translator, set when this one translates a sub-model as a fragment.
    pub parent: Option<&'a IntoAst<'a>>,
    /// Ids of the registries that were actually required during translation.
    pub registries: Vec<Identifier>,
    /// Model being translated.
    pub model: &'a TypedModel,
    /// Identifiers of the graph parameters, one per model input.
    pub parameters: Vec<Identifier>,
    /// Identifiers of the graph results, one per model output.
    pub results: Vec<Identifier>,
    /// Expression producing each outlet that has been translated so far.
    pub mapping: HashMap<OutletId, Arc<RValue>>,
    /// Constant tensors emitted as `variable`, keyed by their label.
    pub tensors: HashMap<Identifier, Arc<Tensor>>,
    /// Quantization format recorded for identifiers with a quantized datum type.
    pub quantization: HashMap<Identifier, QuantFormat>,
    /// Resources handed over as-is to the resulting proto model.
    pub resources: HashMap<String, Arc<dyn Resource>>,
    /// Fragment definitions to emit in the document, keyed by fragment id.
    pub fragments: HashMap<Identifier, FragmentDef>,
    /// Assignments making up the graph (or fragment) body, in emission order.
    pub body: Vec<Assignment>,
}
73
/// A tensor that a fragment expects to receive as an additional parameter.
pub struct RequiredTensorParameter {
    /// Identifier of the fragment parameter carrying the tensor.
    pub parameter_id: Identifier,
    /// Label under which the tensor was registered during translation.
    pub label: Identifier,
    /// The tensor itself.
    pub value: Arc<Tensor>,
}
79
80impl<'a> IntoAst<'a> {
81 pub fn new(framework: &'a Nnef, model: &'a TypedModel) -> IntoAst<'a> {
82 IntoAst {
83 framework,
84 registries: Default::default(),
85 model,
86 parameters: Default::default(),
87 results: Default::default(),
88 mapping: Default::default(),
89 tensors: Default::default(),
90 quantization: Default::default(),
91 resources: Default::default(),
92 fragments: Default::default(),
93 body: Default::default(),
94 parent: None,
95 }
96 }
97
98 fn ensure_registry(&mut self, id: &Identifier) -> TractResult<()> {
99 if !self.framework.registries.iter().any(|r| &r.id == id) {
100 bail!("Registry {} required, consider allowing it on the NNEF framework.", id.0);
101 }
102 if !self.registries.iter().any(|r| r == id) {
103 self.registries.push(id.clone());
104 }
105 Ok(())
106 }
107
108 fn translate(&mut self) -> TractResult<()> {
109 for input in self.model.input_outlets()? {
110 let left = self.scoped_id(&self.model.node(input.node).name);
111 self.parameters.push(left.clone());
112 self.node(self.model.node(input.node))?;
113 self.mapping.insert(*input, RValue::Identifier(left).into());
114 }
115 for node in self.model.eval_order()? {
116 if self.model.input_outlets()?.iter().any(|io| io.node == node) {
117 continue;
118 }
119 self.node(self.model.node(node))
120 .with_context(|| format!("translating node {}", self.model.node(node)))?;
121 }
122 let outlets: Vec<OutletId> = self.model.output_outlets()?.to_vec();
123 for (ix, o) in outlets.into_iter().enumerate() {
124 let rv = if let Some(label) = self.model.outlet_label(o) {
125 self.force_variable_and_name(label, &self.mapping[&o].clone())
126 } else {
127 self.force_variable(format!("output_{ix}"), &self.mapping[&o].clone())
128 };
129 if let RValue::Identifier(name) = rv.as_ref() {
130 self.results.push(name.clone());
131 } else {
132 unreachable!()
133 };
134 }
135 Ok(())
136 }
137
138 pub fn into_fragment(self) -> TractResult<(FragmentDef, Vec<RequiredTensorParameter>)> {
139 let mut tensor_params = vec![];
140 for (name, t) in &self.tensors {
141 tensor_params.push(RequiredTensorParameter {
142 parameter_id: self.scoped_id(name),
143 label: name.clone(),
144 value: t.clone(),
145 })
146 }
147 let IntoAst { body, mut parameters, results, .. } = self;
148 parameters.extend(tensor_params.iter().map(|rtp| rtp.parameter_id.clone()).sorted());
149 let body = body
150 .into_iter()
151 .filter(|assign| match &assign.left {
152 LValue::Identifier(id) => !parameters.contains(id),
153 _ => true,
154 })
155 .collect();
156 Ok((
157 FragmentDef {
158 decl: FragmentDecl {
159 id: Identifier("network".into()),
160 generic_decl: None,
161 parameters: parameters
162 .into_iter()
163 .map(|s| TypeName::Scalar.tensor().named(s))
164 .collect(),
165 results: results
166 .into_iter()
167 .map(|s| Result_ { id: s, spec: TypeName::Scalar.tensor() })
168 .collect(),
169 },
170 body: Some(body),
171 },
172 tensor_params,
173 ))
174 }
175
176 pub fn into_proto_model(mut self) -> TractResult<ProtoModel> {
177 let mut properties = self
178 .model
179 .properties
180 .iter()
181 .sorted_by_key(|(k, _v)| k.to_owned())
182 .map(|(k, v)| Ok(tuple_2(string(k), self.konst(k, v)?.as_ref().clone())))
183 .collect::<TractResult<Vec<_>>>()?;
184 let version = env!("CARGO_PKG_VERSION");
185 properties.push(tuple_2(
186 string("tract_nnef_ser_version"),
187 self.konst("tract_nnef_ser_version", &rctensor0(version.to_string()))?.as_ref().clone(),
188 ));
189 properties.push(tuple_2(
190 string("tract_nnef_format_version"),
191 self.konst("tract_nnef_format_version", &rctensor0("beta1".to_string()))?
192 .as_ref()
193 .clone(),
194 ));
195 let properties: Assignment = assignment("properties", Arc::new(array(properties)));
196 let IntoAst { mut fragments, body, tensors, parameters, results, .. } = self;
197 let mut extension = vec![];
198 self.registries.sort();
199 for reg in self.registries {
200 if reg.0 != "tract_nnef" {
201 extension.push(("tract_registry".into(), reg.0));
202 }
203 }
204 for sym in self.model.symbols.all_symbols() {
205 extension.push(("tract_symbol".into(), sym.to_string()));
206 }
207 let locked = self.model.symbols.0.lock();
208 for assert in locked.borrow().all_assertions() {
209 extension.push(("tract_assert".into(), assert.to_string()));
210 }
211 for scenario in locked.borrow().scenarios() {
212 for assert in locked.borrow().scenario(scenario) {
213 extension.push(("tract_assert".into(), format!("{scenario}: {assert}")));
214 }
215 }
216 let properties = FragmentDef {
217 decl: FragmentDecl {
218 id: Identifier("tract_core_properties".to_string()),
219 generic_decl: None,
220 parameters: vec![],
221 results: vec![Result_ {
222 id: Identifier("properties".to_string()),
223 spec: TypeSpec::Tuple(vec![TypeName::String.spec(), TypeName::Scalar.tensor()])
224 .array(),
225 }],
226 },
227 body: Some(vec![properties]),
228 };
229 fragments.insert(properties.decl.id.clone(), properties);
230 let doc = Document {
231 version: "1.0".into(),
232 extension,
233 fragments: fragments.into_values().collect(),
234 graph_def: GraphDef { id: Identifier("network".into()), parameters, results, body },
235 };
236 let quantization = if self.quantization.len() > 0 { Some(self.quantization) } else { None };
237 Ok(ProtoModel { doc, tensors, quantization, resources: self.resources })
238 }
239
240 fn node(&mut self, node: &TypedNode) -> TractResult<TVec<Arc<RValue>>> {
241 let mut required_registries = Vec::new();
242 for reg in &self.framework.registries {
243 if let Some(outputs) = reg.serialize(self, node).context("Serializing op")? {
244 if self.ensure_registry(®.id).is_err() {
245 required_registries.push(®.id);
246 continue;
247 };
248 let scoped = self.scoped_id(&node.name);
249 let names: Vec<_> = (0..node.outputs.len())
250 .map(|ix| {
251 if ix > 0 {
252 Identifier(format!("{}_{}", scoped.0, ix))
253 } else {
254 scoped.clone()
255 }
256 })
257 .collect();
258 if node.outputs.len() > 1 {
259 self.body.push(Assignment {
260 left: LValue::Tuple(
261 names.iter().map(|n| LValue::Identifier(n.clone())).collect(),
262 ),
263 right: outputs.as_ref().clone(),
264 });
265 } else {
266 self.assignment(names[0].clone(), outputs);
267 };
268
269 for (outlet, name) in node.outputs.iter().zip(names.iter()) {
270 if let Some(qf) = QuantFormat::from_dt(outlet.fact.datum_type) {
271 self.quantization.insert(name.clone(), qf);
272 }
273 }
274
275 let mut outputs = tvec!();
276 for (ix, o) in names.into_iter().enumerate() {
277 let rv = Arc::new(ident(o));
278 self.mapping.insert((node.id, ix).into(), rv.clone());
279 outputs.push(rv);
280 }
281
282 return Ok(outputs);
283 }
284 }
285 if required_registries.is_empty() {
286 bail!("No serializer found for node {}", node);
287 } else if required_registries.len() == 1 {
288 bail!(
289 "Registry {} required, consider allowing it on the NNEF framework.",
290 required_registries[0].0
291 );
292 } else {
293 bail!("One of the following registries is required: {:?}, consider allowing one on the NNEF framework.", required_registries);
294 }
295 }
296
297 pub fn scoped_id(&self, name: impl AsRef<str>) -> Identifier {
298 let name = name.as_ref().to_string();
299 Identifier(name)
300 }
301
302 pub fn force_variable(&mut self, name: impl AsRef<str>, exp: &Arc<RValue>) -> Arc<RValue> {
303 if let RValue::Identifier(_) = exp.as_ref() {
304 exp.clone()
305 } else {
306 let name = self.scoped_id(name);
307 self.assignment(name.clone(), exp.clone());
308 ident(name).into()
309 }
310 }
311
312 pub fn force_variable_and_name(
313 &mut self,
314 name: impl Into<String>,
315 exp: &Arc<RValue>,
316 ) -> Arc<RValue> {
317 let name = name.into();
318 if let RValue::Identifier(id) = exp.as_ref() {
319 if name == id.0 {
320 return exp.clone();
321 }
322 }
323 let name = self.scoped_id(name);
324 self.assignment(name.clone(), exp.clone());
325 ident(name).into()
326 }
327
328 pub fn konst(
329 &mut self,
330 name: impl AsRef<str>,
331 tensor: &Arc<Tensor>,
332 ) -> TractResult<Arc<RValue>> {
333 self.do_konst(name, tensor, false)
334 }
335
336 pub fn konst_variable(
337 &mut self,
338 name: impl AsRef<str>,
339 tensor: &Arc<Tensor>,
340 ) -> TractResult<Arc<RValue>> {
341 self.do_konst(name, tensor, true)
342 }
343
344 fn dump_rec_tensor<T: Datum>(
345 t: &ArrayViewD<T>,
346 el: impl for<'t> Fn(&'t T) -> RValue + Copy,
347 ) -> RValue {
348 if t.ndim() == 0 {
349 el(&t.as_slice().unwrap()[0])
350 } else {
351 let values: TVec<RValue> = (0..t.shape()[0])
352 .map(|i| Self::dump_rec_tensor(&t.index_axis(Axis(0), i), el))
353 .collect();
354 array(values)
355 }
356 }
357
358 fn do_konst(
359 &mut self,
360 name: impl AsRef<str>,
361 tensor: &Arc<Tensor>,
362 force_variable: bool,
363 ) -> TractResult<Arc<RValue>> {
364 let mut name: Identifier = name.as_ref().into();
365 let have_tract_core = self.ensure_registry(&"tract_core".into()).is_ok();
366 if tensor.datum_type() == TDim::datum_type() {
367 return Ok(Self::dump_rec_tensor(&tensor.to_array_view::<TDim>()?, tdim).into());
368 }
369 if !force_variable && tensor.len() <= 8 {
370 if tensor.datum_type() == String::datum_type() {
371 return Ok(Self::dump_rec_tensor(&tensor.to_array_view::<String>()?, |f| {
372 string(f)
373 })
374 .into());
375 } else if tensor.datum_type() == DatumType::F32 {
376 return Ok(
377 Self::dump_rec_tensor(&tensor.to_array_view::<f32>()?, |f| numeric(f)).into()
378 );
379 } else if have_tract_core && tensor.datum_type() == DatumType::F16 {
380 let array =
381 Self::dump_rec_tensor(&tensor.to_array_view::<f16>()?, |f| numeric(f)).into();
382 return Ok(invocation("tract_core_cast", &[array], &[("to", string("f16"))]));
383 } else if have_tract_core && tensor.datum_type().is_integer() {
384 if let Ok(value) = tensor.cast_to::<i64>() {
385 let value =
386 Self::dump_rec_tensor(&value.to_array_view::<i64>().unwrap(), |i| {
387 numeric(i)
388 });
389 let to = string(format!("{:?}", tensor.datum_type()).to_lowercase());
390 return Ok(invocation("tract_core_cast", &[value.into()], &[("to", to)]));
391 }
392 };
393 }
394
395 if self.tensors.contains_key(&name) {
396 name = (0..)
397 .map(|it| Identifier::from(&*format!("{}_{}", name.0, it)))
398 .find(|it| !self.tensors.contains_key(it))
399 .unwrap();
400 }
401
402 self.tensors.insert(name.clone(), tensor.clone());
403 let id = self.scoped_id(&name);
404 let shape = if tensor.datum_type().is_opaque() {
405 if let Some(bqv) = tensor.to_scalar::<Opaque>()?.downcast_ref::<BlockQuantValue>() {
406 bqv.fact.shape()
407 } else {
408 bail!("Unexpected opaque tensor in serialization {tensor:?}");
409 }
410 } else {
411 tensor.shape()
412 };
413 self.assignment(
414 id.clone(),
415 RValue::Invocation(Invocation {
416 id: "variable".into(),
417 generic_type_name: Some(TypeName::Scalar),
418 arguments: vec![
419 named_arg("label", string(name.0)),
420 named_arg("shape", ints(shape)),
421 ],
422 })
423 .into(),
424 );
425 if let Some(qp) = QuantFormat::from_dt(tensor.datum_type()) {
426 self.quantization.insert(id.clone(), qp);
427 }
428 Ok(ident(id).into())
429 }
430
431 fn assignment(&mut self, name: impl AsRef<str>, right: Arc<RValue>) {
432 let name = name.as_ref();
433 if *right == ident(name) {
434 return;
435 }
436 self.body.push(assignment(name, right))
437 }
438}
439
440pub fn assignment(name: impl AsRef<str>, right: Arc<RValue>) -> Assignment {
441 Assignment { left: LValue::Identifier(name.as_ref().into()), right: right.as_ref().to_owned() }
442}
443
444pub fn ints(shape: &[usize]) -> RValue {
445 RValue::Array(shape.iter().map(|s| RValue::Literal(Literal::Numeric(s.to_string()))).collect())
446}
447
448pub fn tdims(shape: &[TDim]) -> RValue {
449 RValue::Array(shape.iter().map(tdim).collect())
450}
451
452pub fn tdim(dim: &TDim) -> RValue {
453 match dim {
454 TDim::Val(x) => numeric(x),
455 TDim::Sym(s) => ident(s.to_string()),
456 TDim::Add(terms) => terms
457 .iter()
458 .map(tdim)
459 .reduce(|x, y| RValue::Binary(x.boxed(), "+".to_string(), y.boxed()))
460 .unwrap(),
461 TDim::Mul(terms) => terms
462 .iter()
463 .map(tdim)
464 .reduce(|x, y| RValue::Binary(x.boxed(), "*".to_string(), y.boxed()))
465 .unwrap(),
466 TDim::MulInt(x, y) => RValue::Binary(numeric(x).boxed(), "*".to_string(), tdim(y).boxed()),
467 TDim::Div(x, y) => RValue::Binary(tdim(x).boxed(), "/".to_string(), numeric(y).boxed()),
468 TDim::Broadcast(_) => todo!(),
469 TDim::Min(_) | TDim::Max(_) => todo!(),
470 }
471}
472
473pub fn string(s: impl AsRef<str>) -> RValue {
474 RValue::Literal(Literal::String(s.as_ref().into()))
475}
476
477pub fn datum_type(dt: DatumType) -> RValue {
478 string(format!("{:?}", dt.unquantized()).to_lowercase())
479}
480
481pub fn logical(b: bool) -> RValue {
482 RValue::Literal(Literal::Logical(b))
483}
484
485pub fn lident(s: impl AsRef<str>) -> LValue {
486 LValue::Identifier(s.as_ref().into())
487}
488
489pub fn ident(s: impl AsRef<str>) -> RValue {
490 RValue::Identifier(s.as_ref().into())
491}
492
493pub fn array(items: impl AsRef<[RValue]>) -> RValue {
494 RValue::Array(items.as_ref().to_vec())
495}
496
497pub fn tuple_2(a: RValue, b: RValue) -> RValue {
498 RValue::Tuple(vec![a, b])
499}
500
501pub fn tuple_3(a: RValue, b: RValue, c: RValue) -> RValue {
502 RValue::Tuple(vec![a, b, c])
503}
504
505pub fn tuple_4(a: RValue, b: RValue, c: RValue, d: RValue) -> RValue {
506 RValue::Tuple(vec![a, b, c, d])
507}
508
509pub fn numeric<D: std::fmt::Debug>(num: D) -> RValue {
510 RValue::Literal(Literal::Numeric(format!("{num:?}")))
511}
512
513pub fn named_arg(id: &str, rv: RValue) -> Argument {
514 Argument { id: Some(id.into()), rvalue: rv }
515}
516
517pub fn invocation(
518 id: impl AsRef<str>,
519 positional: &[Arc<RValue>],
520 named: &[(&str, RValue)],
521) -> Arc<RValue> {
522 let arguments = positional
523 .iter()
524 .map(|rv| Argument { id: None, rvalue: rv.as_ref().clone() })
525 .chain(named.iter().map(|(n, v)| named_arg(n, v.clone())))
526 .collect();
527 RValue::Invocation(Invocation { id: id.as_ref().into(), generic_type_name: None, arguments })
528 .into()
529}