tract_onnx_opl/ml/
category_mapper.rs

1use std::hash::*;
2use tract_itertools::Itertools;
3use tract_nnef::internal::*;
4use tract_smallvec::SmallVec;
5
6pub fn register(registry: &mut Registry) {
7    registry.register_primitive(
8        "tract_onnx_ml_direct_lookup",
9        &parameters_direct_lookup(),
10        &[("output", TypeName::Scalar.tensor())],
11        load_direct_lookup,
12    );
13    registry.register_primitive(
14        "tract_onnx_ml_reverse_lookup",
15        &parameters_reverse_lookup(),
16        &[("output", TypeName::Scalar.tensor())],
17        load_reverse_lookup,
18    );
19    registry.register_dumper(dump_direct_lookup);
20    registry.register_dumper(dump_reverse_lookup);
21}
22
23#[derive(Clone, Debug, Hash)]
24pub struct DirectLookup {
25    values: Arc<Tensor>,
26    fallback_value: Arc<Tensor>,
27}
28
29impl DirectLookup {
30    pub fn new(values: Arc<Tensor>, fallback_value: Arc<Tensor>) -> TractResult<DirectLookup> {
31        Ok(DirectLookup { values, fallback_value })
32    }
33
34    fn eval_t<T: Datum>(&self, input: &Tensor) -> TractResult<Tensor> {
35        let values = self.values.as_slice::<T>()?;
36        let fallback_value = self.fallback_value.to_scalar::<T>()?;
37        Ok(input
38            .to_array_view::<i32>()?
39            .mapv(|ix| values.get(ix as usize).unwrap_or(fallback_value).clone())
40            .into_tensor())
41    }
42}
43
44impl Op for DirectLookup {
45    fn name(&self) -> StaticName {
46        "DirectLookup".into()
47    }
48
49    op_as_typed_op!();
50}
51
52impl EvalOp for DirectLookup {
53    fn is_stateless(&self) -> bool {
54        true
55    }
56
57    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
58        let input = args_1!(inputs);
59        let output = dispatch_hash!(Self::eval_t(self.values.datum_type())(self, &input))?;
60        Ok(tvec!(output.into_tvalue()))
61    }
62}
63
64impl TypedOp for DirectLookup {
65    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
66        if self.values.datum_type() != self.fallback_value.datum_type() {
67            bail!(
68                "values and fallback value should be of the same type, got {:?}, {:?}",
69                self.values,
70                self.fallback_value
71            )
72        }
73        Ok(tvec!(self.values.datum_type().fact(inputs[0].shape.iter())))
74    }
75
76    fn axes_mapping(
77        &self,
78        inputs: &[&TypedFact],
79        outputs: &[&TypedFact],
80    ) -> TractResult<AxesMapping> {
81        AxesMapping::natural(inputs, outputs)
82    }
83
84    fn change_axes(
85        &self,
86        model: &TypedModel,
87        node: &TypedNode,
88        _io: InOut,
89        change: &AxisOp,
90    ) -> TractResult<Option<AxisChangeConsequence>> {
91        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
92    }
93
94    as_op!();
95}
96
97#[derive(Clone, Debug)]
98pub struct ReverseLookup {
99    keys: Arc<Tensor>,
100    index: HashMap<u64, SmallVec<[i32; 1]>>,
101    fallback_value: i32,
102}
103
104#[allow(clippy::manual_hash_one)]
105impl ReverseLookup {
106    pub fn new(keys: Arc<Tensor>, fallback_value: i32) -> TractResult<ReverseLookup> {
107        unsafe fn new_t<T: Datum + Hash>(keys: &Tensor) -> HashMap<u64, SmallVec<[i32; 1]>> {
108            let keys = unsafe { keys.as_slice_unchecked::<T>() };
109            let mut hashmap = HashMap::<u64, SmallVec<[i32; 1]>>::default();
110            for (ix, k) in keys.iter().enumerate() {
111                let mut hasher = hashmap.hasher().build_hasher();
112                k.hash(&mut hasher);
113                let u = hasher.finish();
114                hashmap.entry(u).or_default().push(ix as i32);
115            }
116            hashmap
117        }
118        let index = unsafe { dispatch_hash!(new_t(keys.datum_type())(&keys)) };
119        Ok(ReverseLookup { index, keys, fallback_value })
120    }
121
122    unsafe fn search_t<T: Datum + Hash>(&self, needle: &T) -> Option<i32> {
123        let keys = unsafe { self.keys.as_slice_unchecked::<T>() };
124        let mut hasher = self.index.hasher().build_hasher();
125        needle.hash(&mut hasher);
126        let u = hasher.finish();
127        if let Some(candidates) = self.index.get(&u) {
128            for candidate in candidates {
129                if &keys[*candidate as usize] == needle {
130                    return Some(*candidate);
131                }
132            }
133        }
134        None
135    }
136
137    fn eval_t<T: Datum + Hash>(&self, input: &Tensor) -> TractResult<Tensor> {
138        unsafe {
139            let mut output = Tensor::uninitialized_dt(i32::datum_type(), input.shape())?;
140            for (i, o) in
141                input.as_slice::<T>()?.iter().zip(output.as_slice_mut_unchecked::<i32>().iter_mut())
142            {
143                *o = self.search_t(i).unwrap_or(self.fallback_value);
144            }
145            Ok(output)
146        }
147    }
148}
149
150impl Hash for ReverseLookup {
151    fn hash<H: Hasher>(&self, state: &mut H) {
152        self.keys.hash(state);
153        self.fallback_value.hash(state);
154        self.index.iter().sorted().for_each(|v| Hash::hash(&v, state));
155    }
156}
157
158impl Op for ReverseLookup {
159    fn name(&self) -> StaticName {
160        "ReverseLookup".into()
161    }
162
163    op_as_typed_op!();
164}
165
166impl EvalOp for ReverseLookup {
167    fn is_stateless(&self) -> bool {
168        true
169    }
170
171    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
172        let input = args_1!(inputs);
173        let output = dispatch_hash!(Self::eval_t(self.keys.datum_type())(self, &input))?;
174        Ok(tvec!(output.into_tvalue()))
175    }
176}
177
178impl TypedOp for ReverseLookup {
179    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
180        Ok(tvec!(i32::fact(inputs[0].shape.iter())))
181    }
182
183    fn axes_mapping(
184        &self,
185        inputs: &[&TypedFact],
186        outputs: &[&TypedFact],
187    ) -> TractResult<AxesMapping> {
188        AxesMapping::natural(inputs, outputs)
189    }
190
191    fn change_axes(
192        &self,
193        model: &TypedModel,
194        node: &TypedNode,
195        _io: InOut,
196        change: &AxisOp,
197    ) -> TractResult<Option<AxisChangeConsequence>> {
198        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
199    }
200
201    as_op!();
202}
203
204fn parameters_direct_lookup() -> Vec<Parameter> {
205    vec![
206        TypeName::String.tensor().named("input"),
207        TypeName::Scalar.tensor().named("values"),
208        TypeName::Scalar.tensor().named("fallback"),
209    ]
210}
211
212fn parameters_reverse_lookup() -> Vec<Parameter> {
213    vec![
214        TypeName::Scalar.tensor().named("input"),
215        TypeName::Scalar.tensor().named("keys"),
216        TypeName::Scalar.named("fallback"),
217    ]
218}
219
220fn dump_direct_lookup(
221    ast: &mut IntoAst,
222    node: &TypedNode,
223    op: &DirectLookup,
224) -> TractResult<Option<Arc<RValue>>> {
225    let input = ast.mapping[&node.inputs[0]].clone();
226    let keys = ast.konst_variable(format!("{}.values", node.name), &op.values)?;
227    let fallback = ast.konst_variable(format!("{}.fallback", node.name), &op.fallback_value)?;
228    Ok(Some(invocation("tract_onnx_ml_direct_lookup", &[input, keys, fallback], &[])))
229}
230
231fn dump_reverse_lookup(
232    ast: &mut IntoAst,
233    node: &TypedNode,
234    op: &ReverseLookup,
235) -> TractResult<Option<Arc<RValue>>> {
236    let input = ast.mapping[&node.inputs[0]].clone();
237    let values = ast.konst_variable(format!("{}.keys", node.name), &op.keys)?;
238    Ok(Some(invocation(
239        "tract_onnx_ml_reverse_lookup",
240        &[input, values],
241        &[("fallback", numeric(op.fallback_value))],
242    )))
243}
244
245fn load_direct_lookup(
246    builder: &mut ModelBuilder,
247    invocation: &ResolvedInvocation,
248) -> TractResult<Value> {
249    let input = invocation.named_arg_as(builder, "input")?;
250    let values: Arc<Tensor> = invocation.named_arg_as(builder, "values")?;
251    let fallback_value = invocation.named_arg_as(builder, "fallback")?;
252    let op = DirectLookup { fallback_value, values };
253    builder.wire(op, &[input])
254}
255
256fn load_reverse_lookup(
257    builder: &mut ModelBuilder,
258    invocation: &ResolvedInvocation,
259) -> TractResult<Value> {
260    let input = invocation.named_arg_as(builder, "input")?;
261    let keys: isize = invocation.named_arg_as(builder, "keys")?;
262    let fallback_value = invocation.named_arg_as(builder, "fallback")?;
263    let op = ReverseLookup::new(fallback_value, keys as i32)?;
264    builder.wire(op, &[input])
265}