// tract_nnef/ops/core/topk.rs

1use crate::internal::*;
2use crate::ser::*;
3use tract_core::ops::array::Topk;
4
5pub fn register(registry: &mut Registry) {
6    registry.register_dumper(ser_topk);
7    registry.register_primitive(
8        "tract_core_topk",
9        &[
10            TypeName::Scalar.tensor().named("input"),
11            TypeName::Integer.tensor().named("k"),
12            TypeName::Integer.named("axis"),
13            TypeName::Logical.named("largest"),
14        ],
15        &[("values", TypeName::Scalar.tensor()), ("indices", TypeName::Integer.tensor())],
16        de_topk,
17    );
18}
19
20fn ser_topk(ast: &mut IntoAst, node: &TypedNode, op: &Topk) -> TractResult<Option<Arc<RValue>>> {
21    let input = ast.mapping[&node.inputs[0]].clone();
22    let k = ast.mapping[&node.inputs[1]].clone();
23    Ok(Some(invocation(
24        "tract_core_topk",
25        &[input, k],
26        &[("largest", logical(op.largest)), ("axis", numeric(op.axis))],
27    )))
28}
29
30fn de_topk(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
31    let input = invocation.named_arg_as(builder, "input")?;
32    let k = invocation.named_arg_as(builder, "k")?;
33    let axis = invocation.named_arg_as(builder, "axis")?;
34    let largest = invocation.named_arg_as(builder, "largest")?;
35    let fallback_k = builder.model.symbols.new_with_prefix("k").into();
36    builder.wire(Topk { largest, fallback_k, axis }, &[input, k])
37}