tract_nnef/ops/core/submodel.rs

1use tract_core::ops::submodel::SubmodelOp;
2
3use crate::internal::*;
4
5pub fn register(registry: &mut Registry) {
6    registry.register_dumper(ser_submodel);
7    registry.register_primitive(
8        "tract_core_submodel",
9        &[TypeName::Scalar.tensor().array().named("input"), TypeName::String.named("label")],
10        &[("outputs", TypeName::Any.tensor().array())],
11        de_submodel,
12    );
13}
14
15fn de_submodel(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
16    let wires: TVec<OutletId> = invocation.named_arg_as(builder, "input")?;
17    let label: String = invocation.named_arg_as(builder, "label")?;
18    let model: TypedModel = builder
19        .proto_model
20        .resources
21        .get(label.as_str())
22        .with_context(|| anyhow!("{} not found in model builder loaded resources", label.as_str()))?
23        .clone()
24        .downcast_arc::<TypedModelResource>()
25        .map_err(|_| anyhow!("Error while downcasting typed model resource"))
26        .map(|r| r.0.clone())
27        .with_context(|| anyhow!("Error while loading typed model resource"))?;
28
29    let op: Box<dyn TypedOp> = Box::new(SubmodelOp::new(Box::new(model), &label)?);
30
31    builder.model.wire_node(label, op, &wires).map(Value::from)
32}
33
34fn ser_submodel(
35    ast: &mut IntoAst,
36    node: &TypedNode,
37    op: &SubmodelOp,
38) -> TractResult<Option<Arc<RValue>>> {
39    let input = tvec![ast.mapping[&node.inputs[0]].clone()];
40    let invoke = invocation("tract_core_submodel", &input, &[("label", string(op.label()))]);
41    ast.resources.insert(op.label().to_string(), Arc::new(TypedModelResource(op.model().clone())));
42    Ok(Some(invoke))
43}