Skip to main content

tract_nnef/
deser.rs

1use std::ops::ControlFlow;
2
3use tract_core::num_traits::Zero;
4use tract_core::tract_data::itertools::Itertools;
5
6use crate::ast::*;
7use crate::internal::*;
8
9pub struct ModelBuilder<'a> {
10    pub framework: &'a Nnef,
11    pub registries: Vec<Identifier>,
12    pub model: TypedModel,
13    pub naming_scopes: Vec<Identifier>,
14    pub scopes: Vec<HashMap<Identifier, Value>>,
15    pub proto_model: &'a ProtoModel,
16    pub symbols: Vec<Symbol>,
17    allow_new_symbol: bool,
18}
19
20impl<'mb> ModelBuilder<'mb> {
21    pub fn new(
22        framework: &'mb Nnef,
23        proto_model: &'mb ProtoModel,
24        template: TypedModel,
25    ) -> ModelBuilder<'mb> {
26        ModelBuilder {
27            registries: vec!["tract_nnef".into()],
28            framework,
29            model: template,
30            naming_scopes: vec![],
31            scopes: vec![],
32            proto_model,
33            symbols: vec![],
34            allow_new_symbol: false,
35        }
36    }
37
38    pub fn allowing_new_symbols<R>(&mut self, closure: impl Fn(&mut Self) -> R) -> R {
39        self.allow_new_symbol = true;
40        let r = closure(self);
41        self.allow_new_symbol = false;
42        r
43    }
44
45    fn translate(&mut self) -> TractResult<()> {
46        let mut scenario_specs = vec![];
47        'ext: for ext in &self.proto_model.doc.extension {
48            match &*ext.0 .0 {
49                "tract_registry" => {
50                    let registry = Identifier(ext.1.trim().to_owned());
51                    if self.framework.registries.iter().any(|reg| reg.id == registry) {
52                        self.registries.push(registry.clone())
53                    } else if let Some(reg) =
54                        self.framework.registries.iter().find(|reg| reg.aliases.contains(&registry))
55                    {
56                        self.registries.push(reg.id.clone())
57                    } else {
58                        bail!("Registry not found {:?}", registry)
59                    }
60                }
61                "tract_symbol" => {
62                    let symbol = self.model.symbols.new_with_prefix(ext.1.trim());
63                    self.symbols.push(symbol);
64                }
65                "tract_assert" => {
66                    if let Some(pair) = ext.1.split_once(':') {
67                        scenario_specs.push(pair);
68                    } else {
69                        self.model.symbols.add_assertion(&ext.1)?;
70                    }
71                }
72                "KHR_enable_fragment_definitions" | "KHR_enable_operator_expressions" => (),
73                _ => {
74                    for reg in &self.framework.registries {
75                        for reg_ext in &reg.extensions {
76                            match reg_ext(self, &ext.0, &ext.1)? {
77                                ControlFlow::Continue(_) => (),
78                                ControlFlow::Break(_) => continue 'ext,
79                            }
80                        }
81                    }
82                    warn!("Ignore unknown extension {:?}", ext.0);
83                }
84            };
85        }
86        for (scen, rule) in scenario_specs {
87            self.model.symbols.add_scenario_assertion(scen, rule)?;
88        }
89        self.scopes.push(HashMap::new());
90        self.wire_body(&self.proto_model.doc.graph_def.body).context("Wiring root graph body")?;
91        let vars = self.scopes.pop().unwrap();
92
93        let outputs = self
94            .proto_model
95            .doc
96            .graph_def
97            .results
98            .iter()
99            .map(|s| {
100                vars.get(s)
101                    .with_context(|| format!("Could not find variable for output named {s:?}"))
102            })
103            .collect::<TractResult<TVec<&Value>>>()?;
104
105        let outputs = outputs
106            .into_iter()
107            .map(|s| s.to::<OutletId>(self))
108            .collect::<TractResult<TVec<OutletId>>>()?;
109        self.model.set_output_outlets(&outputs)?;
110
111        self.parse_properties().context("Parsing properties")?;
112
113        for (ix, name) in self.proto_model.doc.graph_def.results.iter().enumerate() {
114            self.model.set_outlet_label(outputs[ix], name.0.to_string())?;
115        }
116
117        Ok(())
118    }
119
120    #[allow(clippy::result_large_err)]
121    pub fn into_typed_model(mut self) -> Result<TypedModel, (TypedModel, TractError)> {
122        match self.translate().context("In ModelBuilder::translate") {
123            Ok(()) => Ok(self.model),
124            Err(e) => Err((self.model, e)),
125        }
126    }
127
128    fn parse_properties(&mut self) -> TractResult<()> {
129        if let Some(properties) = self
130            .proto_model
131            .doc
132            .fragments
133            .iter()
134            .find(|f| &f.decl.id.0 == "tract_core_properties")
135            .and_then(|f| f.body.as_ref())
136            .and_then(|body| body.first())
137        {
138            let properties: TVec<(String, Arc<Tensor>)> =
139                properties.right.resolve(self, &[])?.to(self)?;
140            self.model.properties = properties.into_iter().collect();
141        }
142        Ok(())
143    }
144
145    pub fn wire_body(&mut self, body: &[Assignment]) -> TractResult<()> {
146        // todo: can i relax the outlet id constraint ?
147        for assignment in body {
148            let identifiers = assignment.left.to_identifiers()?;
149            trace!("Wiring identifiers {identifiers:?}");
150            let datum_types = identifiers
151                .iter()
152                .map(|s| {
153                    self.proto_model
154                        .quantization
155                        .as_ref()
156                        .and_then(|qm| qm.get(*s).map(|q| q.datum_type()))
157                })
158                .collect::<Vec<_>>();
159            self.naming_scopes.push(identifiers[0].clone());
160            let mut values = if identifiers.len() == 1 {
161                let value: OutletId = assignment
162                    .right
163                    .resolve(self, &datum_types)
164                    .and_then(|v| v.to(self))
165                    .with_context(|| {
166                        format!(
167                            "Plugging in assignement for {:?}",
168                            identifiers.iter().map(|i| &i.0).join(", ")
169                        )
170                    })?;
171                tvec!(value)
172            } else {
173                let values: TVec<OutletId> = assignment
174                    .right
175                    .resolve(self, &datum_types)
176                    .and_then(|v| v.to(self))
177                    .with_context(|| {
178                        format!(
179                            "Plugging in assignement for {:?}",
180                            identifiers.iter().map(|i| &i.0).join(", ")
181                        )
182                    })?;
183                if values.len() != identifiers.len() {
184                    bail!(
185                        "Assignement for {} received {} value(s).",
186                        identifiers.iter().map(|i| &i.0).join(","),
187                        values.len()
188                    )
189                }
190                values
191            };
192            for (qparam, value) in datum_types.into_iter().zip(values.iter_mut()) {
193                if let Some(qparam) = qparam {
194                    if qparam != self.model.outlet_fact(*value)?.datum_type {
195                        self.model.node_mut(value.node).name =
196                            format!("{}_raw", self.naming_scopes.iter().map(|i| &i.0).join("_"));
197                        if self.model.outlet_fact(*value)?.datum_type == TDim::datum_type() {
198                            *value = self.model.wire_node(
199                                format!(
200                                    "{}_cast_to_f32",
201                                    self.naming_scopes.iter().map(|i| &i.0).join("_")
202                                ),
203                                tract_core::ops::cast::cast(f32::datum_type()),
204                                &[*value],
205                            )?[0];
206                        }
207                        *value = self.model.wire_node(
208                            format!(
209                                "{}_cast_to_q",
210                                self.naming_scopes.iter().map(|i| &i.0).join("_")
211                            ),
212                            tract_core::ops::cast::cast(qparam),
213                            &[*value],
214                        )?[0];
215                    }
216                }
217            }
218            for (id, outlet) in identifiers.iter().zip(values.iter()) {
219                self.scopes.last_mut().unwrap().insert((*id).clone(), Value::Wire(*outlet));
220            }
221            self.naming_scopes.pop();
222            for (value, identifier) in values.iter().zip(identifiers) {
223                if self.model.node_mut(value.node).name.is_empty() {
224                    self.naming_scopes.push(identifier.clone());
225                    self.model.node_mut(value.node).name = self.generate_node_name();
226                    self.naming_scopes.pop();
227                }
228            }
229        }
230        Ok(())
231    }
232
233    pub fn wire_invocation(
234        &mut self,
235        invocation: &Invocation,
236        dt: &[Option<DatumType>],
237    ) -> TractResult<Value> {
238        for frag in &self.proto_model.doc.fragments {
239            if frag.decl.id == invocation.id && frag.body.is_some() {
240                let resolved = ResolvedInvocation {
241                    invocation,
242                    dt_from_quant_file: dt,
243                    default_params: &frag.decl.parameters,
244                };
245                return self.wire_fragment_invocation(
246                    &resolved,
247                    &frag.decl,
248                    frag.body.as_deref().unwrap(),
249                );
250            }
251        }
252
253        // We start with the registry that has been added last
254        for registry in self.framework.registries.iter().rev() {
255            if self.registries.contains(&registry.id) {
256                if let Some(outputs) = registry
257                    .deserialize(self, invocation, dt)
258                    .with_context(|| format!("Interrogating registry {:?}", registry.id))?
259                {
260                    return Ok(outputs);
261                }
262            }
263        }
264        bail!("No definition for operator {:?}", invocation.id);
265    }
266
267    pub fn wire_fragment_invocation(
268        &mut self,
269        invocation: &ResolvedInvocation,
270        decl: &FragmentDecl,
271        body: &[Assignment],
272    ) -> TractResult<Value> {
273        let mut inner_scope = HashMap::new();
274        for par in invocation.default_params.iter() {
275            inner_scope.insert(par.id.clone(), invocation.named_arg_as::<Value>(self, &par.id.0)?);
276        }
277        self.scopes.push(inner_scope);
278        self.with_extra_naming_scope(invocation.invocation.id.clone(), |b| b.wire_body(body))?;
279        let inner_scope = self.scopes.pop().unwrap();
280        Ok(Value::Tuple(
281            decl.results.iter().map(|res| inner_scope.get(&res.id).unwrap()).cloned().collect(),
282        ))
283    }
284
285    fn with_extra_naming_scope<F: FnOnce(&mut Self) -> R, R>(
286        &mut self,
287        name: Identifier,
288        f: F,
289    ) -> R {
290        self.naming_scopes.push(name);
291        let r = f(self);
292        self.naming_scopes.pop();
293        r
294    }
295
296    pub fn generate_node_name(&self) -> String {
297        let name = self.naming_scopes.iter().map(|n| &n.0).join("_");
298        if self.model.nodes().iter().any(|n| n.name == name) {
299            for i in 0.. {
300                let candidate = format!("{name}_{i}");
301                if !self.model.nodes().iter().any(|n| n.name.starts_with(&candidate)) {
302                    return candidate;
303                }
304            }
305        }
306        name
307    }
308
309    pub fn wire_as_outlets(
310        &mut self,
311        op: impl Into<Box<dyn TypedOp>>,
312        inputs: &[OutletId],
313    ) -> TractResult<TVec<OutletId>> {
314        let op = op.into();
315        let name = self.generate_node_name();
316        self.model.wire_node(name, op, inputs).with_context(|| format!("inputs are {inputs:?}"))
317    }
318
319    pub fn add_const(&mut self, v: impl IntoArcTensor) -> TractResult<OutletId> {
320        self.model.add_const(self.generate_node_name(), v)
321    }
322
323    pub fn wire(
324        &mut self,
325        op: impl Into<Box<dyn TypedOp>>,
326        inputs: &[OutletId],
327    ) -> TractResult<Value> {
328        self.wire_as_outlets(op, inputs).map(Value::from)
329    }
330}
331
332#[derive(Clone, Debug)]
333pub struct ResolvedInvocation<'a> {
334    pub invocation: &'a Invocation,
335    pub dt_from_quant_file: &'a [Option<DatumType>],
336    pub default_params: &'a [Parameter],
337}
338
339impl ResolvedInvocation<'_> {
340    pub fn named_arg_as<T>(&self, builder: &mut ModelBuilder, name: &str) -> TractResult<T>
341    where
342        T: CoerceFrom<Value>,
343    {
344        let rv = self.named_arg(name)?;
345        builder.with_extra_naming_scope(Identifier(name.into()), |builder| {
346            let v = rv
347                .resolve(builder, &[])
348                .with_context(|| format!("Resolving argument `{name}' ({rv:?})"))?;
349            v.to::<T>(builder).with_context(|| format!("Converting argument `{name}' from {v:?}"))
350        })
351    }
352
353    pub fn optional_named_arg_as<T>(
354        &self,
355        builder: &mut ModelBuilder,
356        name: &str,
357    ) -> TractResult<Option<T>>
358    where
359        T: CoerceFrom<Value>,
360    {
361        let Some(rv) = self.get_named_arg(name) else { return Ok(None) };
362        let v = rv
363            .resolve(builder, &[])
364            .with_context(|| format!("Resolving argument `{name}' ({rv:?})"))?;
365        match v {
366            Value::Bool(b) => {
367                if !b {
368                    Ok(None)
369                } else {
370                    bail!("Bool(true) not expected for optional values, you might want to access a boolean direclty.")
371                }
372            }
373            _ => v
374                .to::<T>(builder)
375                .map(Option::Some)
376                .with_context(|| format!("Converting argument `{name}' from {v:?}")),
377        }
378    }
379
380    pub fn named_arg(&self, name: &str) -> TractResult<Cow<'_, RValue>> {
381        self.get_named_arg(name).ok_or_else(|| format_err!("expected argument {}", name))
382    }
383
384    pub fn get_named_arg(&self, name: &str) -> Option<Cow<'_, RValue>> {
385        // first look explicit name in invocation arguments
386        if let Some(arg) = self
387            .invocation
388            .arguments
389            .iter()
390            .find(|arg| arg.id.as_ref().map(|i| &*i.0) == Some(name))
391        {
392            return Some(Cow::Borrowed(&arg.rvalue));
393        }
394        // then use fragment prototype:
395        if let Some((ix, param)) =
396            self.default_params.iter().enumerate().find(|(_ix, param)| &*param.id.0 == name)
397        {
398            // check that all previous (and our) arguments are positional (todo:
399            // valid args when building augmented_invocation)
400            if self.invocation.arguments.len() > ix
401                && self.invocation.arguments.iter().take(ix + 1).all(|arg| arg.id.is_none())
402            {
403                return Some(Cow::Borrowed(&self.invocation.arguments[ix].rvalue));
404            }
405            if let Some(rv) = &param.lit {
406                return Some(Cow::Owned(RValue::Literal(rv.clone())));
407            }
408        }
409        None
410    }
411
412    pub fn get_named_arg_as<T>(
413        &self,
414        builder: &mut ModelBuilder,
415        name: &str,
416    ) -> TractResult<Option<T>>
417    where
418        T: CoerceFrom<Value>,
419    {
420        let Some(rv) = self.get_named_arg(name) else { return Ok(None) };
421        let v = rv
422            .resolve(builder, &[])
423            .with_context(|| format!("Resolving argument `{name}' ({rv:?})"))?;
424        v.to::<T>(builder)
425            .with_context(|| format!("Converting argument `{name}' from {v:?}"))
426            .map(Some)
427    }
428}
429
430impl ModelBuilder<'_> {}
431
432impl LValue {
433    fn to_identifier(&self) -> TractResult<&Identifier> {
434        match self {
435            LValue::Identifier(id) => Ok(id),
436            _ => bail!("Expected an identifier, found a tuple: {:?}", self),
437        }
438    }
439
440    #[allow(dead_code)]
441    fn to_identifiers(&self) -> TractResult<TVec<&Identifier>> {
442        match self {
443            LValue::Identifier(_) => Ok(tvec!(self.to_identifier()?)),
444            LValue::Tuple(ids) => ids.iter().map(|id| id.to_identifier()).collect(),
445            LValue::Array(ids) => ids.iter().map(|id| id.to_identifier()).collect(),
446        }
447    }
448}
449
450impl Invocation {}
451
452impl RValue {
453    pub fn resolve(
454        &self,
455        builder: &mut ModelBuilder,
456        dt: &[Option<DatumType>],
457    ) -> TractResult<Value> {
458        match self {
459            RValue::Identifier(id) => {
460                if let Some(mut outlet) = builder.scopes.last().unwrap().get(id).cloned() {
461                    if let Value::Wire(outlet_id) = outlet {
462                        let out_dt = builder.model.node(outlet_id.node).outputs[outlet_id.slot]
463                            .fact
464                            .datum_type;
465                        if let Some(Some(dt)) = dt.first() {
466                            if out_dt.unquantized() != dt.unquantized() {
467                                return Err(format_err!(
468                                    "Mismatched types expected {:?}, got {:?}",
469                                    dt,
470                                    out_dt
471                                ));
472                            }
473                            if out_dt != *dt {
474                                outlet =
475                                    builder.wire(tract_core::ops::cast::cast(*dt), &[outlet_id])?;
476                            }
477                        }
478                    }
479                    Ok(outlet)
480                } else if let Some(sym) = builder.model.symbols.get(&id.0) {
481                    Ok(Value::Dim(sym.into()))
482                } else if builder.allow_new_symbol {
483                    warn!("Introducing symbol {id:?} without forward declaration (\"extension tract_symbol ...\"). May be deprecated soon.");
484                    let sym = builder.model.symbols.sym(&id.0);
485                    Ok(Value::Dim(sym.into()))
486                } else {
487                    bail!("Can not resolve {:?}. Not a known identifier, and symbol introduction is forbidden out of \"external\" shape field", id);
488                }
489            }
490            RValue::Invocation(inv) => builder
491                .wire_invocation(inv, dt)
492                .with_context(|| format!("Resolving invocation {:?}", inv.id)),
493            RValue::Binary(left, op, right) => {
494                let op = match &**op {
495                    "+" => "add",
496                    "-" => "sub",
497                    "*" => "mul",
498                    "/" => "div",
499                    "^" => "pow",
500                    ">" => "gt",
501                    "<" => "lt",
502                    "==" => "eq",
503                    "!=" => "ne",
504                    ">=" => "ge",
505                    "<=" => "le",
506                    op => bail!("Unknown binary operator: {}", op),
507                };
508                let inv = Invocation {
509                    id: op.into(),
510                    generic_type_name: None,
511                    arguments: vec![
512                        Argument { id: None, rvalue: left.as_ref().clone() },
513                        Argument { id: None, rvalue: right.as_ref().clone() },
514                    ],
515                };
516                builder
517                    .wire_invocation(&inv, dt)
518                    .with_context(|| format!("Resolving invocation {:?}", &inv.id))
519            }
520            RValue::Array(array) => Ok(Value::Array(
521                array
522                    .iter()
523                    .zip(std::iter::repeat(&dt.first().copied().flatten()))
524                    .map(|(i, dt)| i.resolve(builder, &[*dt]))
525                    .collect::<TractResult<_>>()?,
526            )),
527            RValue::Tuple(array) => {
528                let dt_iter: Box<dyn Iterator<Item = &Option<DatumType>>> =
529                    if dt.len() == 0 || dt.len() == 1 && dt[0].is_none() {
530                        Box::new(std::iter::repeat(&None))
531                    } else if dt.len() == array.len() {
532                        Box::new(dt.iter())
533                    } else {
534                        bail!("Wrong number of types for a tuple, got {:?} for {:?}", dt, array)
535                    };
536                Ok(Value::Tuple(
537                    array
538                        .iter()
539                        .zip(dt_iter)
540                        .map(|(i, dt)| {
541                            if dt.is_none() {
542                                i.resolve(builder, &[])
543                            } else {
544                                i.resolve(builder, &[*dt])
545                            }
546                        })
547                        .collect::<TractResult<_>>()?,
548                ))
549            }
550            RValue::Literal(Literal::Numeric(f)) => {
551                if f.contains('.') || f.contains('e') || f == "inf" || f == "-inf" {
552                    f.parse::<f32>()
553                        .map(Value::Scalar)
554                        .with_context(|| format!("Can not parse {f} as f32"))
555                } else if let Ok(i) = f.parse::<i64>() {
556                    Ok(Value::Dim(i.into()))
557                } else if let Some(s) = builder.model.symbols.get(f) {
558                    Ok(Value::Dim(s.into()))
559                } else {
560                    bail!("Can not parse {}", f)
561                }
562            }
563            RValue::Literal(Literal::String(s)) => Ok(Value::String(s.clone())),
564            RValue::Literal(Literal::Logical(s)) => Ok(Value::Bool(*s)),
565            RValue::Literal(Literal::Array(array)) => Ok(Value::Array(
566                array
567                    .iter()
568                    .zip(std::iter::repeat(&dt.first().copied().flatten()))
569                    .map(|(i, dt)| RValue::Literal(i.clone()).resolve(builder, &[*dt]))
570                    .collect::<TractResult<_>>()?,
571            )),
572            _ => panic!("{self:?}"),
573        }
574    }
575}
576
577#[derive(Clone, Debug, PartialEq)]
578pub enum Value {
579    Tensor(Arc<Tensor>),
580    Wire(OutletId),
581    Array(Vec<Value>),
582    Tuple(Vec<Value>),
583    String(String),
584    Bool(bool),
585    Scalar(f32),
586    Dim(TDim),
587}
588
589impl Value {
590    pub fn to<T>(&self, builder: &mut ModelBuilder) -> TractResult<T>
591    where
592        T: CoerceFrom<Value>,
593    {
594        T::coerce(builder, self)
595    }
596}
597
598impl From<TVec<OutletId>> for Value {
599    fn from(outled_ids: TVec<OutletId>) -> Self {
600        Self::Tuple(outled_ids.into_iter().map(Self::Wire).collect())
601    }
602}
603
604pub trait CoerceFrom<F> {
605    fn coerce(builder: &mut ModelBuilder, from: &F) -> TractResult<Self>
606    where
607        Self: Sized;
608}
609
610impl CoerceFrom<Value> for Value {
611    fn coerce(_builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
612        Ok(from.clone())
613    }
614}
615
616impl CoerceFrom<Value> for Arc<Tensor> {
617    fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
618        match from {
619            Value::Dim(t) => Ok(rctensor0(t.to_i32()?)),
620            Value::Tensor(t) => Ok(t.clone()),
621            Value::Tuple(t) if t.len() == 1 => t[0].to(builder),
622            Value::Scalar(f) => Ok(rctensor0(*f)),
623            Value::String(f) => Ok(rctensor0(f.clone())),
624            Value::Bool(b) => Ok(rctensor0(*b)),
625            Value::Wire(o) => builder
626                .model
627                .outlet_fact(*o)?
628                .konst
629                .clone()
630                .ok_or_else(|| format_err!("Not a const")),
631            Value::Array(items) => {
632                let mut tensors = vec![];
633                for item in items {
634                    let tensor = Arc::<Tensor>::coerce(builder, item)?;
635                    let mut tensor = tensor.into_tensor();
636                    tensor.insert_axis(0)?;
637                    tensors.push(tensor);
638                }
639                let tensor = Tensor::stack_tensors(0, &tensors)?;
640                Ok(tensor.into_arc_tensor())
641            }
642            _ => bail!("Can not build a tensor from {:?}", from),
643        }
644    }
645}
646
647impl CoerceFrom<Value> for (Arc<Tensor>, DatumType) {
648    fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
649        match from {
650            Value::Tensor(t) => Ok((t.clone(), t.datum_type())),
651            Value::Scalar(f) => Ok((rctensor0(*f), DatumType::F32)),
652            Value::String(f) => Ok((rctensor0(f.clone()), DatumType::String)),
653            Value::Bool(b) => Ok((rctensor0(*b), DatumType::Bool)),
654            Value::Wire(o) => {
655                let outlet_fact = builder.model.outlet_fact(*o)?;
656                Ok((
657                    outlet_fact.konst.clone().ok_or_else(|| format_err!("Not a const"))?,
658                    outlet_fact.datum_type,
659                ))
660            }
661            _ => bail!("Can not build a tensor from {:?}", from),
662        }
663    }
664}
665
666impl CoerceFrom<Value> for OutletId {
667    fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
668        match from {
669            Value::Tensor(t) => builder.add_const(t.clone()),
670            Value::Scalar(f) => builder.add_const(rctensor0(*f)),
671            Value::Dim(i) => builder.add_const(rctensor0(i.clone())),
672            Value::Wire(outlet) => Ok(*outlet),
673            Value::Tuple(tuple) if tuple.len() == 1 => OutletId::coerce(builder, &tuple[0]),
674            Value::Array(inputs) => {
675                if let Ok(c) = from.to::<Arc<Tensor>>(builder) {
676                    return builder.add_const(c);
677                }
678                let mut outlets = tvec!();
679                for i in inputs {
680                    let outlet = OutletId::coerce(builder, i)?;
681                    outlets.push(builder.wire_as_outlets(AxisOp::Add(0), &[outlet])?[0]);
682                }
683                builder
684                    .wire_as_outlets(tract_core::ops::array::TypedConcat::new(0), &outlets)
685                    .map(|o| o[0])
686            }
687            Value::String(s) => builder.add_const(rctensor0(s.clone())),
688            Value::Bool(b) => builder.add_const(rctensor0(*b)),
689            _ => bail!("Can not build an outletid from {:?}", from),
690        }
691    }
692}
693
694impl CoerceFrom<Value> for u64 {
695    fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
696        match from {
697            Value::Dim(d) => Ok(d.to_i64()? as u64),
698            Value::Tensor(t) => Ok(t.cast_to_scalar::<u64>()?),
699            Value::Wire(_) => Ok(from.to::<Arc<Tensor>>(builder)?.cast_to_scalar::<u64>()?),
700            _ => bail!("Can not build a u64 from {:?}", from),
701        }
702    }
703}
704
705impl CoerceFrom<Value> for i64 {
706    fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
707        match from {
708            Value::Dim(d) => d.to_i64(),
709            Value::Tensor(t) => Ok(*t.to_scalar::<i64>()?),
710            Value::Wire(_) => Ok(from.to::<Arc<Tensor>>(builder)?.cast_to_scalar::<i64>()?),
711            _ => bail!("Can not build a i64 from {:?}", from),
712        }
713    }
714}
715
716impl CoerceFrom<Value> for TDim {
717    fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
718        match from {
719            Value::Dim(d) => Ok(d.clone()),
720            Value::Tensor(t) => Ok(t.to_scalar::<TDim>()?.clone()),
721            Value::Wire(_) => {
722                Ok(from.to::<Arc<Tensor>>(builder)?.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone())
723            }
724            _ => bail!("Can not build a TDim from {:?}", from),
725        }
726    }
727}
728
729impl CoerceFrom<Value> for String {
730    fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
731        match from {
732            Value::String(s) => Ok(s.to_string()),
733            Value::Tensor(t) => Ok(t.to_scalar::<String>()?.clone()),
734            Value::Wire(_) => Ok(from
735                .to::<Arc<Tensor>>(builder)?
736                .cast_to::<String>()?
737                .to_scalar::<String>()?
738                .clone()),
739            _ => bail!("Can not build a String from {:?}", from),
740        }
741    }
742}
743
744impl CoerceFrom<Value> for bool {
745    fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
746        match from {
747            Value::Bool(b) => Ok(*b),
748            Value::Tensor(t) => Ok(*t.to_scalar::<bool>()?),
749            Value::Wire(_) => {
750                Ok(*from.to::<Arc<Tensor>>(builder)?.cast_to::<bool>()?.to_scalar::<bool>()?)
751            }
752            Value::Dim(n) => Ok(!n.is_zero()),
753            _ => bail!("Can not build a boolean from {:?}", from),
754        }
755    }
756}
757
758impl CoerceFrom<Value> for usize {
759    fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
760        Ok(i64::coerce(builder, from)? as usize)
761    }
762}
763
764impl CoerceFrom<Value> for isize {
765    fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
766        Ok(i64::coerce(builder, from)? as isize)
767    }
768}
769
770impl CoerceFrom<Value> for f32 {
771    fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
772        match from {
773            Value::Scalar(f) => Ok(*f),
774            Value::Dim(d) => Ok(d.to_i64()? as f32),
775            Value::Tensor(t) => Ok(*t.to_scalar::<f32>()?),
776            Value::Wire(_) => {
777                Ok(*from.to::<Arc<Tensor>>(builder)?.cast_to::<f32>()?.to_scalar::<f32>()?)
778            }
779            _ => bail!("Can not build a f32 from {:?}", from),
780        }
781    }
782}
783
784impl<D: CoerceFrom<Value>> CoerceFrom<Value> for TVec<D> {
785    fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
786        match from {
787            Value::Array(vec) => vec.iter().map(|item| D::coerce(builder, item)).collect(),
788            Value::Tuple(vec) => vec.iter().map(|item| D::coerce(builder, item)).collect(),
789            any => Ok(tvec!(D::coerce(builder, any)?)),
790        }
791    }
792}
793
794impl CoerceFrom<Value> for ShapeFact {
795    fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
796        match from {
797            Value::Array(vec) => vec.iter().map(|item| TDim::coerce(builder, item)).collect(),
798            Value::Tuple(vec) => vec.iter().map(|item| TDim::coerce(builder, item)).collect(),
799            _ => {
800                let t = from.to::<Arc<Tensor>>(builder)?;
801                Ok(t.cast_to::<TDim>()?.as_slice::<TDim>()?.into())
802            }
803        }
804    }
805}
806
/// Implements `CoerceFrom<Value>` for a tuple of the given type parameters.
///
/// Only `Value::Tuple` is accepted. Items are coerced in order, and a tuple with too few
/// items is an error.
macro_rules! tuple {
    ($($d: ident),*) => {
        impl<$($d),*> CoerceFrom<Value> for ($($d),*)
        where
            $($d: CoerceFrom<Value>),*
        {
            fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
                let Value::Tuple(items) = from else {
                    bail!("Can not build a tuple from {:?}", from)
                };
                let mut items = items.iter();
                Ok(($(
                    $d::coerce(builder, items.next().context("Too small a tuple")?)?
                ),*))
            }
        }
    };
}
827
828tuple!(D1, D2);
829tuple!(D1, D2, D3);
830tuple!(D1, D2, D3, D4);