1use std::ops::ControlFlow;
2
3use tract_core::num_traits::Zero;
4use tract_core::tract_data::itertools::Itertools;
5
6use crate::ast::*;
7use crate::internal::*;
8
/// State used while translating a parsed NNEF document ([`ProtoModel`]) into a
/// [`TypedModel`].
pub struct ModelBuilder<'a> {
    /// Framework supplying the operator registries and their extension handlers.
    pub framework: &'a Nnef,
    /// Ids of the registries enabled for this document; only these are consulted
    /// when looking up an operator definition.
    pub registries: Vec<Identifier>,
    /// The model under construction.
    pub model: TypedModel,
    /// Stack of name components, joined with `_` to name the nodes being created.
    pub naming_scopes: Vec<Identifier>,
    /// Stack of variable bindings; identifiers are resolved against the innermost map.
    pub scopes: Vec<HashMap<Identifier, Value>>,
    /// The parsed document being translated.
    pub proto_model: &'a ProtoModel,
    /// Symbols declared through `tract_symbol` extensions.
    pub symbols: Vec<Symbol>,
    /// When set, unknown identifiers are introduced as new symbols instead of
    /// being rejected.
    allow_new_symbol: bool,
}
19
20impl<'mb> ModelBuilder<'mb> {
21 pub fn new(
22 framework: &'mb Nnef,
23 proto_model: &'mb ProtoModel,
24 template: TypedModel,
25 ) -> ModelBuilder<'mb> {
26 ModelBuilder {
27 registries: vec!["tract_nnef".into()],
28 framework,
29 model: template,
30 naming_scopes: vec![],
31 scopes: vec![],
32 proto_model,
33 symbols: vec![],
34 allow_new_symbol: false,
35 }
36 }
37
38 pub fn allowing_new_symbols<R>(&mut self, closure: impl Fn(&mut Self) -> R) -> R {
39 self.allow_new_symbol = true;
40 let r = closure(self);
41 self.allow_new_symbol = false;
42 r
43 }
44
45 fn translate(&mut self) -> TractResult<()> {
46 let mut scenario_specs = vec![];
47 'ext: for ext in &self.proto_model.doc.extension {
48 match &*ext.0 .0 {
49 "tract_registry" => {
50 let registry = Identifier(ext.1.trim().to_owned());
51 if self.framework.registries.iter().any(|reg| reg.id == registry) {
52 self.registries.push(registry.clone())
53 } else if let Some(reg) =
54 self.framework.registries.iter().find(|reg| reg.aliases.contains(®istry))
55 {
56 self.registries.push(reg.id.clone())
57 } else {
58 bail!("Registry not found {:?}", registry)
59 }
60 }
61 "tract_symbol" => {
62 let symbol = self.model.symbols.new_with_prefix(ext.1.trim());
63 self.symbols.push(symbol);
64 }
65 "tract_assert" => {
66 if let Some(pair) = ext.1.split_once(':') {
67 scenario_specs.push(pair);
68 } else {
69 self.model.symbols.add_assertion(&ext.1)?;
70 }
71 }
72 "KHR_enable_fragment_definitions" | "KHR_enable_operator_expressions" => (),
73 _ => {
74 for reg in &self.framework.registries {
75 for reg_ext in ®.extensions {
76 match reg_ext(self, &ext.0, &ext.1)? {
77 ControlFlow::Continue(_) => (),
78 ControlFlow::Break(_) => continue 'ext,
79 }
80 }
81 }
82 warn!("Ignore unknown extension {:?}", ext.0);
83 }
84 };
85 }
86 for (scen, rule) in scenario_specs {
87 self.model.symbols.add_scenario_assertion(scen, rule)?;
88 }
89 self.scopes.push(HashMap::new());
90 self.wire_body(&self.proto_model.doc.graph_def.body).context("Wiring root graph body")?;
91 let vars = self.scopes.pop().unwrap();
92
93 let outputs = self
94 .proto_model
95 .doc
96 .graph_def
97 .results
98 .iter()
99 .map(|s| {
100 vars.get(s)
101 .with_context(|| format!("Could not find variable for output named {s:?}"))
102 })
103 .collect::<TractResult<TVec<&Value>>>()?;
104
105 let outputs = outputs
106 .into_iter()
107 .map(|s| s.to::<OutletId>(self))
108 .collect::<TractResult<TVec<OutletId>>>()?;
109 self.model.set_output_outlets(&outputs)?;
110
111 self.parse_properties().context("Parsing properties")?;
112
113 for (ix, name) in self.proto_model.doc.graph_def.results.iter().enumerate() {
114 self.model.set_outlet_label(outputs[ix], name.0.to_string())?;
115 }
116
117 Ok(())
118 }
119
120 #[allow(clippy::result_large_err)]
121 pub fn into_typed_model(mut self) -> Result<TypedModel, (TypedModel, TractError)> {
122 match self.translate().context("In ModelBuilder::translate") {
123 Ok(()) => Ok(self.model),
124 Err(e) => Err((self.model, e)),
125 }
126 }
127
128 fn parse_properties(&mut self) -> TractResult<()> {
129 if let Some(properties) = self
130 .proto_model
131 .doc
132 .fragments
133 .iter()
134 .find(|f| &f.decl.id.0 == "tract_core_properties")
135 .and_then(|f| f.body.as_ref())
136 .and_then(|body| body.first())
137 {
138 let properties: TVec<(String, Arc<Tensor>)> =
139 properties.right.resolve(self, &[])?.to(self)?;
140 self.model.properties = properties.into_iter().collect();
141 }
142 Ok(())
143 }
144
145 pub fn wire_body(&mut self, body: &[Assignment]) -> TractResult<()> {
146 for assignment in body {
148 let identifiers = assignment.left.to_identifiers()?;
149 trace!("Wiring identifiers {identifiers:?}");
150 let datum_types = identifiers
151 .iter()
152 .map(|s| {
153 self.proto_model
154 .quantization
155 .as_ref()
156 .and_then(|qm| qm.get(*s).map(|q| q.datum_type()))
157 })
158 .collect::<Vec<_>>();
159 self.naming_scopes.push(identifiers[0].clone());
160 let mut values = if identifiers.len() == 1 {
161 let value: OutletId = assignment
162 .right
163 .resolve(self, &datum_types)
164 .and_then(|v| v.to(self))
165 .with_context(|| {
166 format!(
167 "Plugging in assignement for {:?}",
168 identifiers.iter().map(|i| &i.0).join(", ")
169 )
170 })?;
171 tvec!(value)
172 } else {
173 let values: TVec<OutletId> = assignment
174 .right
175 .resolve(self, &datum_types)
176 .and_then(|v| v.to(self))
177 .with_context(|| {
178 format!(
179 "Plugging in assignement for {:?}",
180 identifiers.iter().map(|i| &i.0).join(", ")
181 )
182 })?;
183 if values.len() != identifiers.len() {
184 bail!(
185 "Assignement for {} received {} value(s).",
186 identifiers.iter().map(|i| &i.0).join(","),
187 values.len()
188 )
189 }
190 values
191 };
192 for (qparam, value) in datum_types.into_iter().zip(values.iter_mut()) {
193 if let Some(qparam) = qparam {
194 if qparam != self.model.outlet_fact(*value)?.datum_type {
195 self.model.node_mut(value.node).name =
196 format!("{}_raw", self.naming_scopes.iter().map(|i| &i.0).join("_"));
197 if self.model.outlet_fact(*value)?.datum_type == TDim::datum_type() {
198 *value = self.model.wire_node(
199 format!(
200 "{}_cast_to_f32",
201 self.naming_scopes.iter().map(|i| &i.0).join("_")
202 ),
203 tract_core::ops::cast::cast(f32::datum_type()),
204 &[*value],
205 )?[0];
206 }
207 *value = self.model.wire_node(
208 format!(
209 "{}_cast_to_q",
210 self.naming_scopes.iter().map(|i| &i.0).join("_")
211 ),
212 tract_core::ops::cast::cast(qparam),
213 &[*value],
214 )?[0];
215 }
216 }
217 }
218 for (id, outlet) in identifiers.iter().zip(values.iter()) {
219 self.scopes.last_mut().unwrap().insert((*id).clone(), Value::Wire(*outlet));
220 }
221 self.naming_scopes.pop();
222 for (value, identifier) in values.iter().zip(identifiers) {
223 if self.model.node_mut(value.node).name.is_empty() {
224 self.naming_scopes.push(identifier.clone());
225 self.model.node_mut(value.node).name = self.generate_node_name();
226 self.naming_scopes.pop();
227 }
228 }
229 }
230 Ok(())
231 }
232
233 pub fn wire_invocation(
234 &mut self,
235 invocation: &Invocation,
236 dt: &[Option<DatumType>],
237 ) -> TractResult<Value> {
238 for frag in &self.proto_model.doc.fragments {
239 if frag.decl.id == invocation.id && frag.body.is_some() {
240 let resolved = ResolvedInvocation {
241 invocation,
242 dt_from_quant_file: dt,
243 default_params: &frag.decl.parameters,
244 };
245 return self.wire_fragment_invocation(
246 &resolved,
247 &frag.decl,
248 frag.body.as_deref().unwrap(),
249 );
250 }
251 }
252
253 for registry in self.framework.registries.iter().rev() {
255 if self.registries.contains(®istry.id) {
256 if let Some(outputs) = registry
257 .deserialize(self, invocation, dt)
258 .with_context(|| format!("Interrogating registry {:?}", registry.id))?
259 {
260 return Ok(outputs);
261 }
262 }
263 }
264 bail!("No definition for operator {:?}", invocation.id);
265 }
266
267 pub fn wire_fragment_invocation(
268 &mut self,
269 invocation: &ResolvedInvocation,
270 decl: &FragmentDecl,
271 body: &[Assignment],
272 ) -> TractResult<Value> {
273 let mut inner_scope = HashMap::new();
274 for par in invocation.default_params.iter() {
275 inner_scope.insert(par.id.clone(), invocation.named_arg_as::<Value>(self, &par.id.0)?);
276 }
277 self.scopes.push(inner_scope);
278 self.with_extra_naming_scope(invocation.invocation.id.clone(), |b| b.wire_body(body))?;
279 let inner_scope = self.scopes.pop().unwrap();
280 Ok(Value::Tuple(
281 decl.results.iter().map(|res| inner_scope.get(&res.id).unwrap()).cloned().collect(),
282 ))
283 }
284
285 fn with_extra_naming_scope<F: FnOnce(&mut Self) -> R, R>(
286 &mut self,
287 name: Identifier,
288 f: F,
289 ) -> R {
290 self.naming_scopes.push(name);
291 let r = f(self);
292 self.naming_scopes.pop();
293 r
294 }
295
296 pub fn generate_node_name(&self) -> String {
297 let name = self.naming_scopes.iter().map(|n| &n.0).join("_");
298 if self.model.nodes().iter().any(|n| n.name == name) {
299 for i in 0.. {
300 let candidate = format!("{name}_{i}");
301 if !self.model.nodes().iter().any(|n| n.name.starts_with(&candidate)) {
302 return candidate;
303 }
304 }
305 }
306 name
307 }
308
309 pub fn wire_as_outlets(
310 &mut self,
311 op: impl Into<Box<dyn TypedOp>>,
312 inputs: &[OutletId],
313 ) -> TractResult<TVec<OutletId>> {
314 let op = op.into();
315 let name = self.generate_node_name();
316 self.model.wire_node(name, op, inputs).with_context(|| format!("inputs are {inputs:?}"))
317 }
318
319 pub fn add_const(&mut self, v: impl IntoArcTensor) -> TractResult<OutletId> {
320 self.model.add_const(self.generate_node_name(), v)
321 }
322
323 pub fn wire(
324 &mut self,
325 op: impl Into<Box<dyn TypedOp>>,
326 inputs: &[OutletId],
327 ) -> TractResult<Value> {
328 self.wire_as_outlets(op, inputs).map(Value::from)
329 }
330}
331
/// An operator invocation, together with what is needed to look up its
/// arguments by name.
#[derive(Clone, Debug)]
pub struct ResolvedInvocation<'a> {
    /// The invocation as written in the document.
    pub invocation: &'a Invocation,
    /// Datum types requested for the invocation's result(s), e.g. from the
    /// quantization file; may be empty.
    pub dt_from_quant_file: &'a [Option<DatumType>],
    /// Declared parameters of the invoked operator, used to match positional
    /// arguments and to supply default values.
    pub default_params: &'a [Parameter],
}
338
339impl ResolvedInvocation<'_> {
340 pub fn named_arg_as<T>(&self, builder: &mut ModelBuilder, name: &str) -> TractResult<T>
341 where
342 T: CoerceFrom<Value>,
343 {
344 let rv = self.named_arg(name)?;
345 builder.with_extra_naming_scope(Identifier(name.into()), |builder| {
346 let v = rv
347 .resolve(builder, &[])
348 .with_context(|| format!("Resolving argument `{name}' ({rv:?})"))?;
349 v.to::<T>(builder).with_context(|| format!("Converting argument `{name}' from {v:?}"))
350 })
351 }
352
353 pub fn optional_named_arg_as<T>(
354 &self,
355 builder: &mut ModelBuilder,
356 name: &str,
357 ) -> TractResult<Option<T>>
358 where
359 T: CoerceFrom<Value>,
360 {
361 let Some(rv) = self.get_named_arg(name) else { return Ok(None) };
362 let v = rv
363 .resolve(builder, &[])
364 .with_context(|| format!("Resolving argument `{name}' ({rv:?})"))?;
365 match v {
366 Value::Bool(b) => {
367 if !b {
368 Ok(None)
369 } else {
370 bail!("Bool(true) not expected for optional values, you might want to access a boolean direclty.")
371 }
372 }
373 _ => v
374 .to::<T>(builder)
375 .map(Option::Some)
376 .with_context(|| format!("Converting argument `{name}' from {v:?}")),
377 }
378 }
379
380 pub fn named_arg(&self, name: &str) -> TractResult<Cow<'_, RValue>> {
381 self.get_named_arg(name).ok_or_else(|| format_err!("expected argument {}", name))
382 }
383
384 pub fn get_named_arg(&self, name: &str) -> Option<Cow<'_, RValue>> {
385 if let Some(arg) = self
387 .invocation
388 .arguments
389 .iter()
390 .find(|arg| arg.id.as_ref().map(|i| &*i.0) == Some(name))
391 {
392 return Some(Cow::Borrowed(&arg.rvalue));
393 }
394 if let Some((ix, param)) =
396 self.default_params.iter().enumerate().find(|(_ix, param)| &*param.id.0 == name)
397 {
398 if self.invocation.arguments.len() > ix
401 && self.invocation.arguments.iter().take(ix + 1).all(|arg| arg.id.is_none())
402 {
403 return Some(Cow::Borrowed(&self.invocation.arguments[ix].rvalue));
404 }
405 if let Some(rv) = ¶m.lit {
406 return Some(Cow::Owned(RValue::Literal(rv.clone())));
407 }
408 }
409 None
410 }
411
412 pub fn get_named_arg_as<T>(
413 &self,
414 builder: &mut ModelBuilder,
415 name: &str,
416 ) -> TractResult<Option<T>>
417 where
418 T: CoerceFrom<Value>,
419 {
420 let Some(rv) = self.get_named_arg(name) else { return Ok(None) };
421 let v = rv
422 .resolve(builder, &[])
423 .with_context(|| format!("Resolving argument `{name}' ({rv:?})"))?;
424 v.to::<T>(builder)
425 .with_context(|| format!("Converting argument `{name}' from {v:?}"))
426 .map(Some)
427 }
428}
429
430impl ModelBuilder<'_> {}
431
432impl LValue {
433 fn to_identifier(&self) -> TractResult<&Identifier> {
434 match self {
435 LValue::Identifier(id) => Ok(id),
436 _ => bail!("Expected an identifier, found a tuple: {:?}", self),
437 }
438 }
439
440 #[allow(dead_code)]
441 fn to_identifiers(&self) -> TractResult<TVec<&Identifier>> {
442 match self {
443 LValue::Identifier(_) => Ok(tvec!(self.to_identifier()?)),
444 LValue::Tuple(ids) => ids.iter().map(|id| id.to_identifier()).collect(),
445 LValue::Array(ids) => ids.iter().map(|id| id.to_identifier()).collect(),
446 }
447 }
448}
449
450impl Invocation {}
451
452impl RValue {
453 pub fn resolve(
454 &self,
455 builder: &mut ModelBuilder,
456 dt: &[Option<DatumType>],
457 ) -> TractResult<Value> {
458 match self {
459 RValue::Identifier(id) => {
460 if let Some(mut outlet) = builder.scopes.last().unwrap().get(id).cloned() {
461 if let Value::Wire(outlet_id) = outlet {
462 let out_dt = builder.model.node(outlet_id.node).outputs[outlet_id.slot]
463 .fact
464 .datum_type;
465 if let Some(Some(dt)) = dt.first() {
466 if out_dt.unquantized() != dt.unquantized() {
467 return Err(format_err!(
468 "Mismatched types expected {:?}, got {:?}",
469 dt,
470 out_dt
471 ));
472 }
473 if out_dt != *dt {
474 outlet =
475 builder.wire(tract_core::ops::cast::cast(*dt), &[outlet_id])?;
476 }
477 }
478 }
479 Ok(outlet)
480 } else if let Some(sym) = builder.model.symbols.get(&id.0) {
481 Ok(Value::Dim(sym.into()))
482 } else if builder.allow_new_symbol {
483 warn!("Introducing symbol {id:?} without forward declaration (\"extension tract_symbol ...\"). May be deprecated soon.");
484 let sym = builder.model.symbols.sym(&id.0);
485 Ok(Value::Dim(sym.into()))
486 } else {
487 bail!("Can not resolve {:?}. Not a known identifier, and symbol introduction is forbidden out of \"external\" shape field", id);
488 }
489 }
490 RValue::Invocation(inv) => builder
491 .wire_invocation(inv, dt)
492 .with_context(|| format!("Resolving invocation {:?}", inv.id)),
493 RValue::Binary(left, op, right) => {
494 let op = match &**op {
495 "+" => "add",
496 "-" => "sub",
497 "*" => "mul",
498 "/" => "div",
499 "^" => "pow",
500 ">" => "gt",
501 "<" => "lt",
502 "==" => "eq",
503 "!=" => "ne",
504 ">=" => "ge",
505 "<=" => "le",
506 op => bail!("Unknown binary operator: {}", op),
507 };
508 let inv = Invocation {
509 id: op.into(),
510 generic_type_name: None,
511 arguments: vec![
512 Argument { id: None, rvalue: left.as_ref().clone() },
513 Argument { id: None, rvalue: right.as_ref().clone() },
514 ],
515 };
516 builder
517 .wire_invocation(&inv, dt)
518 .with_context(|| format!("Resolving invocation {:?}", &inv.id))
519 }
520 RValue::Array(array) => Ok(Value::Array(
521 array
522 .iter()
523 .zip(std::iter::repeat(&dt.first().copied().flatten()))
524 .map(|(i, dt)| i.resolve(builder, &[*dt]))
525 .collect::<TractResult<_>>()?,
526 )),
527 RValue::Tuple(array) => {
528 let dt_iter: Box<dyn Iterator<Item = &Option<DatumType>>> =
529 if dt.len() == 0 || dt.len() == 1 && dt[0].is_none() {
530 Box::new(std::iter::repeat(&None))
531 } else if dt.len() == array.len() {
532 Box::new(dt.iter())
533 } else {
534 bail!("Wrong number of types for a tuple, got {:?} for {:?}", dt, array)
535 };
536 Ok(Value::Tuple(
537 array
538 .iter()
539 .zip(dt_iter)
540 .map(|(i, dt)| {
541 if dt.is_none() {
542 i.resolve(builder, &[])
543 } else {
544 i.resolve(builder, &[*dt])
545 }
546 })
547 .collect::<TractResult<_>>()?,
548 ))
549 }
550 RValue::Literal(Literal::Numeric(f)) => {
551 if f.contains('.') || f.contains('e') || f == "inf" || f == "-inf" {
552 f.parse::<f32>()
553 .map(Value::Scalar)
554 .with_context(|| format!("Can not parse {f} as f32"))
555 } else if let Ok(i) = f.parse::<i64>() {
556 Ok(Value::Dim(i.into()))
557 } else if let Some(s) = builder.model.symbols.get(f) {
558 Ok(Value::Dim(s.into()))
559 } else {
560 bail!("Can not parse {}", f)
561 }
562 }
563 RValue::Literal(Literal::String(s)) => Ok(Value::String(s.clone())),
564 RValue::Literal(Literal::Logical(s)) => Ok(Value::Bool(*s)),
565 RValue::Literal(Literal::Array(array)) => Ok(Value::Array(
566 array
567 .iter()
568 .zip(std::iter::repeat(&dt.first().copied().flatten()))
569 .map(|(i, dt)| RValue::Literal(i.clone()).resolve(builder, &[*dt]))
570 .collect::<TractResult<_>>()?,
571 )),
572 _ => panic!("{self:?}"),
573 }
574 }
575}
576
/// Result of resolving an NNEF expression while building a model.
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
    /// A constant tensor.
    Tensor(Arc<Tensor>),
    /// An outlet of a node in the model being built.
    Wire(OutletId),
    /// An NNEF array (`[a, b, ...]`).
    Array(Vec<Value>),
    /// An NNEF tuple (`(a, b, ...)`), also used for multiple results.
    Tuple(Vec<Value>),
    /// A string literal.
    String(String),
    /// A logical literal.
    Bool(bool),
    /// A floating-point numeric literal (parsed as `f32`).
    Scalar(f32),
    /// An integer literal or a symbolic dimension.
    Dim(TDim),
}
588
589impl Value {
590 pub fn to<T>(&self, builder: &mut ModelBuilder) -> TractResult<T>
591 where
592 T: CoerceFrom<Value>,
593 {
594 T::coerce(builder, self)
595 }
596}
597
598impl From<TVec<OutletId>> for Value {
599 fn from(outled_ids: TVec<OutletId>) -> Self {
600 Self::Tuple(outled_ids.into_iter().map(Self::Wire).collect())
601 }
602}
603
/// Conversion from a resolved expression (`F`; every implementation in this
/// module uses [`Value`]) into a concrete Rust type.
///
/// The builder is available because some conversions read the model (e.g. the
/// constant behind a wire) or add nodes to it (e.g. turning a constant into an
/// outlet).
pub trait CoerceFrom<F> {
    /// Convert `from`, failing when its kind does not fit `Self`.
    fn coerce(builder: &mut ModelBuilder, from: &F) -> TractResult<Self>
    where
        Self: Sized;
}
609
610impl CoerceFrom<Value> for Value {
611 fn coerce(_builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
612 Ok(from.clone())
613 }
614}
615
616impl CoerceFrom<Value> for Arc<Tensor> {
617 fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
618 match from {
619 Value::Dim(t) => Ok(rctensor0(t.to_i32()?)),
620 Value::Tensor(t) => Ok(t.clone()),
621 Value::Tuple(t) if t.len() == 1 => t[0].to(builder),
622 Value::Scalar(f) => Ok(rctensor0(*f)),
623 Value::String(f) => Ok(rctensor0(f.clone())),
624 Value::Bool(b) => Ok(rctensor0(*b)),
625 Value::Wire(o) => builder
626 .model
627 .outlet_fact(*o)?
628 .konst
629 .clone()
630 .ok_or_else(|| format_err!("Not a const")),
631 Value::Array(items) => {
632 let mut tensors = vec![];
633 for item in items {
634 let tensor = Arc::<Tensor>::coerce(builder, item)?;
635 let mut tensor = tensor.into_tensor();
636 tensor.insert_axis(0)?;
637 tensors.push(tensor);
638 }
639 let tensor = Tensor::stack_tensors(0, &tensors)?;
640 Ok(tensor.into_arc_tensor())
641 }
642 _ => bail!("Can not build a tensor from {:?}", from),
643 }
644 }
645}
646
647impl CoerceFrom<Value> for (Arc<Tensor>, DatumType) {
648 fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
649 match from {
650 Value::Tensor(t) => Ok((t.clone(), t.datum_type())),
651 Value::Scalar(f) => Ok((rctensor0(*f), DatumType::F32)),
652 Value::String(f) => Ok((rctensor0(f.clone()), DatumType::String)),
653 Value::Bool(b) => Ok((rctensor0(*b), DatumType::Bool)),
654 Value::Wire(o) => {
655 let outlet_fact = builder.model.outlet_fact(*o)?;
656 Ok((
657 outlet_fact.konst.clone().ok_or_else(|| format_err!("Not a const"))?,
658 outlet_fact.datum_type,
659 ))
660 }
661 _ => bail!("Can not build a tensor from {:?}", from),
662 }
663 }
664}
665
666impl CoerceFrom<Value> for OutletId {
667 fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
668 match from {
669 Value::Tensor(t) => builder.add_const(t.clone()),
670 Value::Scalar(f) => builder.add_const(rctensor0(*f)),
671 Value::Dim(i) => builder.add_const(rctensor0(i.clone())),
672 Value::Wire(outlet) => Ok(*outlet),
673 Value::Tuple(tuple) if tuple.len() == 1 => OutletId::coerce(builder, &tuple[0]),
674 Value::Array(inputs) => {
675 if let Ok(c) = from.to::<Arc<Tensor>>(builder) {
676 return builder.add_const(c);
677 }
678 let mut outlets = tvec!();
679 for i in inputs {
680 let outlet = OutletId::coerce(builder, i)?;
681 outlets.push(builder.wire_as_outlets(AxisOp::Add(0), &[outlet])?[0]);
682 }
683 builder
684 .wire_as_outlets(tract_core::ops::array::TypedConcat::new(0), &outlets)
685 .map(|o| o[0])
686 }
687 Value::String(s) => builder.add_const(rctensor0(s.clone())),
688 Value::Bool(b) => builder.add_const(rctensor0(*b)),
689 _ => bail!("Can not build an outletid from {:?}", from),
690 }
691 }
692}
693
694impl CoerceFrom<Value> for u64 {
695 fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
696 match from {
697 Value::Dim(d) => Ok(d.to_i64()? as u64),
698 Value::Tensor(t) => Ok(t.cast_to_scalar::<u64>()?),
699 Value::Wire(_) => Ok(from.to::<Arc<Tensor>>(builder)?.cast_to_scalar::<u64>()?),
700 _ => bail!("Can not build a u64 from {:?}", from),
701 }
702 }
703}
704
705impl CoerceFrom<Value> for i64 {
706 fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
707 match from {
708 Value::Dim(d) => d.to_i64(),
709 Value::Tensor(t) => Ok(*t.to_scalar::<i64>()?),
710 Value::Wire(_) => Ok(from.to::<Arc<Tensor>>(builder)?.cast_to_scalar::<i64>()?),
711 _ => bail!("Can not build a i64 from {:?}", from),
712 }
713 }
714}
715
716impl CoerceFrom<Value> for TDim {
717 fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
718 match from {
719 Value::Dim(d) => Ok(d.clone()),
720 Value::Tensor(t) => Ok(t.to_scalar::<TDim>()?.clone()),
721 Value::Wire(_) => {
722 Ok(from.to::<Arc<Tensor>>(builder)?.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone())
723 }
724 _ => bail!("Can not build a TDim from {:?}", from),
725 }
726 }
727}
728
729impl CoerceFrom<Value> for String {
730 fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
731 match from {
732 Value::String(s) => Ok(s.to_string()),
733 Value::Tensor(t) => Ok(t.to_scalar::<String>()?.clone()),
734 Value::Wire(_) => Ok(from
735 .to::<Arc<Tensor>>(builder)?
736 .cast_to::<String>()?
737 .to_scalar::<String>()?
738 .clone()),
739 _ => bail!("Can not build a String from {:?}", from),
740 }
741 }
742}
743
744impl CoerceFrom<Value> for bool {
745 fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
746 match from {
747 Value::Bool(b) => Ok(*b),
748 Value::Tensor(t) => Ok(*t.to_scalar::<bool>()?),
749 Value::Wire(_) => {
750 Ok(*from.to::<Arc<Tensor>>(builder)?.cast_to::<bool>()?.to_scalar::<bool>()?)
751 }
752 Value::Dim(n) => Ok(!n.is_zero()),
753 _ => bail!("Can not build a boolean from {:?}", from),
754 }
755 }
756}
757
758impl CoerceFrom<Value> for usize {
759 fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
760 Ok(i64::coerce(builder, from)? as usize)
761 }
762}
763
764impl CoerceFrom<Value> for isize {
765 fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
766 Ok(i64::coerce(builder, from)? as isize)
767 }
768}
769
770impl CoerceFrom<Value> for f32 {
771 fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
772 match from {
773 Value::Scalar(f) => Ok(*f),
774 Value::Dim(d) => Ok(d.to_i64()? as f32),
775 Value::Tensor(t) => Ok(*t.to_scalar::<f32>()?),
776 Value::Wire(_) => {
777 Ok(*from.to::<Arc<Tensor>>(builder)?.cast_to::<f32>()?.to_scalar::<f32>()?)
778 }
779 _ => bail!("Can not build a f32 from {:?}", from),
780 }
781 }
782}
783
784impl<D: CoerceFrom<Value>> CoerceFrom<Value> for TVec<D> {
785 fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
786 match from {
787 Value::Array(vec) => vec.iter().map(|item| D::coerce(builder, item)).collect(),
788 Value::Tuple(vec) => vec.iter().map(|item| D::coerce(builder, item)).collect(),
789 any => Ok(tvec!(D::coerce(builder, any)?)),
790 }
791 }
792}
793
794impl CoerceFrom<Value> for ShapeFact {
795 fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
796 match from {
797 Value::Array(vec) => vec.iter().map(|item| TDim::coerce(builder, item)).collect(),
798 Value::Tuple(vec) => vec.iter().map(|item| TDim::coerce(builder, item)).collect(),
799 _ => {
800 let t = from.to::<Arc<Tensor>>(builder)?;
801 Ok(t.cast_to::<TDim>()?.as_slice::<TDim>()?.into())
802 }
803 }
804 }
805}
806
807macro_rules! tuple {
808 ($($d: ident),*) => {
809 impl<$($d),*> CoerceFrom<Value> for ($($d),*)
810 where
811 $($d: CoerceFrom<Value>),*
812 {
813 fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
814 match from {
815 Value::Tuple(vec) => {
816 let mut vec = vec.iter();
817 Ok((
818 $($d::coerce(builder, vec.next().context("Too small a tuple")?)?),*
819 ))
820 }
821 _ => bail!("Can not build a tuple from {:?}", from),
822 }
823 }
824 }
825 }
826}
827
828tuple!(D1, D2);
829tuple!(D1, D2, D3);
830tuple!(D1, D2, D3, D4);