1use std::hash::*;
2use tract_itertools::Itertools;
3use tract_nnef::internal::*;
4use tract_smallvec::SmallVec;
5
6pub fn register(registry: &mut Registry) {
7 registry.register_primitive(
8 "tract_onnx_ml_direct_lookup",
9 ¶meters_direct_lookup(),
10 &[("output", TypeName::Scalar.tensor())],
11 load_direct_lookup,
12 );
13 registry.register_primitive(
14 "tract_onnx_ml_reverse_lookup",
15 ¶meters_reverse_lookup(),
16 &[("output", TypeName::Scalar.tensor())],
17 load_reverse_lookup,
18 );
19 registry.register_dumper(dump_direct_lookup);
20 registry.register_dumper(dump_reverse_lookup);
21}
22
23#[derive(Clone, Debug, Hash)]
24pub struct DirectLookup {
25 values: Arc<Tensor>,
26 fallback_value: Arc<Tensor>,
27}
28
29impl DirectLookup {
30 pub fn new(values: Arc<Tensor>, fallback_value: Arc<Tensor>) -> TractResult<DirectLookup> {
31 Ok(DirectLookup { values, fallback_value })
32 }
33
34 fn eval_t<T: Datum>(&self, input: &Tensor) -> TractResult<Tensor> {
35 let values = self.values.as_slice::<T>()?;
36 let fallback_value = self.fallback_value.to_scalar::<T>()?;
37 Ok(input
38 .to_array_view::<i32>()?
39 .mapv(|ix| values.get(ix as usize).unwrap_or(fallback_value).clone())
40 .into_tensor())
41 }
42}
43
44impl Op for DirectLookup {
45 fn name(&self) -> StaticName {
46 "DirectLookup".into()
47 }
48
49 op_as_typed_op!();
50}
51
52impl EvalOp for DirectLookup {
53 fn is_stateless(&self) -> bool {
54 true
55 }
56
57 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
58 let input = args_1!(inputs);
59 let output = dispatch_hash!(Self::eval_t(self.values.datum_type())(self, &input))?;
60 Ok(tvec!(output.into_tvalue()))
61 }
62}
63
64impl TypedOp for DirectLookup {
65 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
66 if self.values.datum_type() != self.fallback_value.datum_type() {
67 bail!(
68 "values and fallback value should be of the same type, got {:?}, {:?}",
69 self.values,
70 self.fallback_value
71 )
72 }
73 Ok(tvec!(self.values.datum_type().fact(inputs[0].shape.iter())))
74 }
75
76 fn axes_mapping(
77 &self,
78 inputs: &[&TypedFact],
79 outputs: &[&TypedFact],
80 ) -> TractResult<AxesMapping> {
81 AxesMapping::natural(inputs, outputs)
82 }
83
84 fn change_axes(
85 &self,
86 model: &TypedModel,
87 node: &TypedNode,
88 _io: InOut,
89 change: &AxisOp,
90 ) -> TractResult<Option<AxisChangeConsequence>> {
91 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
92 }
93
94 as_op!();
95}
96
97#[derive(Clone, Debug)]
98pub struct ReverseLookup {
99 keys: Arc<Tensor>,
100 index: HashMap<u64, SmallVec<[i32; 1]>>,
101 fallback_value: i32,
102}
103
104#[allow(clippy::manual_hash_one)]
105impl ReverseLookup {
106 pub fn new(keys: Arc<Tensor>, fallback_value: i32) -> TractResult<ReverseLookup> {
107 unsafe fn new_t<T: Datum + Hash>(keys: &Tensor) -> HashMap<u64, SmallVec<[i32; 1]>> {
108 let keys = unsafe { keys.as_slice_unchecked::<T>() };
109 let mut hashmap = HashMap::<u64, SmallVec<[i32; 1]>>::default();
110 for (ix, k) in keys.iter().enumerate() {
111 let mut hasher = hashmap.hasher().build_hasher();
112 k.hash(&mut hasher);
113 let u = hasher.finish();
114 hashmap.entry(u).or_default().push(ix as i32);
115 }
116 hashmap
117 }
118 let index = unsafe { dispatch_hash!(new_t(keys.datum_type())(&keys)) };
119 Ok(ReverseLookup { index, keys, fallback_value })
120 }
121
122 unsafe fn search_t<T: Datum + Hash>(&self, needle: &T) -> Option<i32> {
123 let keys = unsafe { self.keys.as_slice_unchecked::<T>() };
124 let mut hasher = self.index.hasher().build_hasher();
125 needle.hash(&mut hasher);
126 let u = hasher.finish();
127 if let Some(candidates) = self.index.get(&u) {
128 for candidate in candidates {
129 if &keys[*candidate as usize] == needle {
130 return Some(*candidate);
131 }
132 }
133 }
134 None
135 }
136
137 fn eval_t<T: Datum + Hash>(&self, input: &Tensor) -> TractResult<Tensor> {
138 unsafe {
139 let mut output = Tensor::uninitialized_dt(i32::datum_type(), input.shape())?;
140 for (i, o) in
141 input.as_slice::<T>()?.iter().zip(output.as_slice_mut_unchecked::<i32>().iter_mut())
142 {
143 *o = self.search_t(i).unwrap_or(self.fallback_value);
144 }
145 Ok(output)
146 }
147 }
148}
149
150impl Hash for ReverseLookup {
151 fn hash<H: Hasher>(&self, state: &mut H) {
152 self.keys.hash(state);
153 self.fallback_value.hash(state);
154 self.index.iter().sorted().for_each(|v| Hash::hash(&v, state));
155 }
156}
157
158impl Op for ReverseLookup {
159 fn name(&self) -> StaticName {
160 "ReverseLookup".into()
161 }
162
163 op_as_typed_op!();
164}
165
166impl EvalOp for ReverseLookup {
167 fn is_stateless(&self) -> bool {
168 true
169 }
170
171 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
172 let input = args_1!(inputs);
173 let output = dispatch_hash!(Self::eval_t(self.keys.datum_type())(self, &input))?;
174 Ok(tvec!(output.into_tvalue()))
175 }
176}
177
178impl TypedOp for ReverseLookup {
179 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
180 Ok(tvec!(i32::fact(inputs[0].shape.iter())))
181 }
182
183 fn axes_mapping(
184 &self,
185 inputs: &[&TypedFact],
186 outputs: &[&TypedFact],
187 ) -> TractResult<AxesMapping> {
188 AxesMapping::natural(inputs, outputs)
189 }
190
191 fn change_axes(
192 &self,
193 model: &TypedModel,
194 node: &TypedNode,
195 _io: InOut,
196 change: &AxisOp,
197 ) -> TractResult<Option<AxisChangeConsequence>> {
198 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
199 }
200
201 as_op!();
202}
203
204fn parameters_direct_lookup() -> Vec<Parameter> {
205 vec![
206 TypeName::String.tensor().named("input"),
207 TypeName::Scalar.tensor().named("values"),
208 TypeName::Scalar.tensor().named("fallback"),
209 ]
210}
211
212fn parameters_reverse_lookup() -> Vec<Parameter> {
213 vec![
214 TypeName::Scalar.tensor().named("input"),
215 TypeName::Scalar.tensor().named("keys"),
216 TypeName::Scalar.named("fallback"),
217 ]
218}
219
220fn dump_direct_lookup(
221 ast: &mut IntoAst,
222 node: &TypedNode,
223 op: &DirectLookup,
224) -> TractResult<Option<Arc<RValue>>> {
225 let input = ast.mapping[&node.inputs[0]].clone();
226 let keys = ast.konst_variable(format!("{}.values", node.name), &op.values)?;
227 let fallback = ast.konst_variable(format!("{}.fallback", node.name), &op.fallback_value)?;
228 Ok(Some(invocation("tract_onnx_ml_direct_lookup", &[input, keys, fallback], &[])))
229}
230
231fn dump_reverse_lookup(
232 ast: &mut IntoAst,
233 node: &TypedNode,
234 op: &ReverseLookup,
235) -> TractResult<Option<Arc<RValue>>> {
236 let input = ast.mapping[&node.inputs[0]].clone();
237 let values = ast.konst_variable(format!("{}.keys", node.name), &op.keys)?;
238 Ok(Some(invocation(
239 "tract_onnx_ml_reverse_lookup",
240 &[input, values],
241 &[("fallback", numeric(op.fallback_value))],
242 )))
243}
244
245fn load_direct_lookup(
246 builder: &mut ModelBuilder,
247 invocation: &ResolvedInvocation,
248) -> TractResult<Value> {
249 let input = invocation.named_arg_as(builder, "input")?;
250 let values: Arc<Tensor> = invocation.named_arg_as(builder, "values")?;
251 let fallback_value = invocation.named_arg_as(builder, "fallback")?;
252 let op = DirectLookup { fallback_value, values };
253 builder.wire(op, &[input])
254}
255
256fn load_reverse_lookup(
257 builder: &mut ModelBuilder,
258 invocation: &ResolvedInvocation,
259) -> TractResult<Value> {
260 let input = invocation.named_arg_as(builder, "input")?;
261 let keys: isize = invocation.named_arg_as(builder, "keys")?;
262 let fallback_value = invocation.named_arg_as(builder, "fallback")?;
263 let op = ReverseLookup::new(fallback_value, keys as i32)?;
264 builder.wire(op, &[input])
265}