tract_nnef/ops/core/
topk.rs1use crate::internal::*;
2use crate::ser::*;
3use tract_core::ops::array::Topk;
4
5pub fn register(registry: &mut Registry) {
6 registry.register_dumper(ser_topk);
7 registry.register_primitive(
8 "tract_core_topk",
9 &[
10 TypeName::Scalar.tensor().named("input"),
11 TypeName::Integer.tensor().named("k"),
12 TypeName::Integer.named("axis"),
13 TypeName::Logical.named("largest"),
14 ],
15 &[("values", TypeName::Scalar.tensor()), ("indices", TypeName::Integer.tensor())],
16 de_topk,
17 );
18}
19
20fn ser_topk(ast: &mut IntoAst, node: &TypedNode, op: &Topk) -> TractResult<Option<Arc<RValue>>> {
21 let input = ast.mapping[&node.inputs[0]].clone();
22 let k = ast.mapping[&node.inputs[1]].clone();
23 Ok(Some(invocation(
24 "tract_core_topk",
25 &[input, k],
26 &[("largest", logical(op.largest)), ("axis", numeric(op.axis))],
27 )))
28}
29
30fn de_topk(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
31 let input = invocation.named_arg_as(builder, "input")?;
32 let k = invocation.named_arg_as(builder, "k")?;
33 let axis = invocation.named_arg_as(builder, "axis")?;
34 let largest = invocation.named_arg_as(builder, "largest")?;
35 let fallback_k = builder.model.symbols.new_with_prefix("k").into();
36 builder.wire(Topk { largest, fallback_k, axis }, &[input, k])
37}