// File: tract_onnx_opl/multinomial.rs
multinomial.rs1use rand::distributions::Standard;
2use rand::prelude::Distribution;
3use rand::rngs::SmallRng;
4use rand::{Rng, SeedableRng};
5
6use tract_nnef::internal::*;
7use tract_nnef::tract_ndarray::s;
8use tract_nnef::tract_num_traits::{AsPrimitive, Float, Zero};
9
10pub fn register(registry: &mut Registry) {
11 registry.register_primitive(
12 "tract_onnx_multinomial",
13 ¶meters(),
14 &[("output", TypeName::Scalar.tensor())],
15 load,
16 );
17 registry.register_dumper(dump);
18}
19
20#[derive(Clone, Debug)]
22pub struct Multinomial {
23 pub dtype: DatumType,
24 pub sample_size: i32,
25 pub seed: Option<f32>,
26}
27
28impl Multinomial {
29 fn eval_t0<T1>(&self, input: TValue) -> TractResult<TValue>
30 where
31 T1: Datum + std::ops::SubAssign + Float + std::iter::Sum,
32 Standard: Distribution<T1>,
33 {
34 match self.dtype {
35 DatumType::I32 => self.eval_t::<T1, i32>(input),
36 DatumType::I64 => self.eval_t::<T1, i64>(input),
37 dt => bail!("Unsupported output datum type for Multinomial: {:?}", dt),
38 }
39 }
40 fn eval_t<T1, T2>(&self, input: TValue) -> TractResult<TValue>
41 where
42 T1: Datum + std::ops::SubAssign + Float + std::iter::Sum,
43 Standard: Distribution<T1>,
44 T2: Datum + Zero + Copy,
45 usize: AsPrimitive<T2>,
46 {
47 let batch_size = input.shape()[0];
48 let class_size = input.shape()[1];
49
50 let mut rng = self.seed.map_or_else(SmallRng::from_entropy, |seed| {
51 SmallRng::seed_from_u64(seed.to_bits() as _)
52 });
53
54 let input = input.to_array_view::<T1>()?;
56
57 let maximums: TVec<_> =
62 input.rows().into_iter().map(|r| r.iter().map(|e| e.exp()).sum::<T1>()).collect();
63
64 let out_shape: &[_] = &[batch_size, self.sample_size as usize];
66 let output = tract_ndarray::ArrayD::from_shape_fn(out_shape, |co_o| -> T2 {
67 let batch = co_o[0];
68
69 let mut rand = rng.r#gen::<T1>() * maximums[batch];
70 let mut ret: T2 = usize::as_(class_size - 1);
71
72 for (i, prob) in input.slice(s![batch, ..]).iter().enumerate() {
73 let prob = prob.exp();
74 if rand < prob {
75 ret = usize::as_(i);
76 break;
77 }
78 rand -= prob;
79 }
80
81 ret
82 });
83
84 Ok(output.into_tvalue())
85 }
86}
87
88impl Op for Multinomial {
89 fn name(&self) -> StaticName {
90 "Multinomial".into()
91 }
92
93 op_as_typed_op!();
94}
95
96impl EvalOp for Multinomial {
97 fn is_stateless(&self) -> bool {
98 true
99 }
100
101 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
102 let input = args_1!(inputs);
103
104 let output = match input.datum_type() {
105 DatumType::F32 => self.eval_t0::<f32>(input),
107 DatumType::F64 => self.eval_t0::<f64>(input),
108 dt => bail!("Unsupported input datum type for Multinomial: {:?}", dt),
109 }?;
110
111 Ok(tvec![output])
112 }
113}
114
115impl TypedOp for Multinomial {
116 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
117 let input_shape = if let Some(s) = inputs[0].shape.as_concrete() {
118 s
119 } else {
120 bail!("Only constant input shape are supported in Multinomial")
121 };
122
123 let batch_size = input_shape[0];
124 Ok(tvec!(self.dtype.fact([batch_size, self.sample_size as usize])))
125 }
126
127 as_op!();
128}
129
130fn parameters() -> Vec<Parameter> {
131 vec![
132 TypeName::Integer.tensor().named("input"),
133 TypeName::Integer.named("dtype").default(6),
134 TypeName::Integer.named("sample_size").default(1),
135 TypeName::Integer.named("seed"),
136 ]
137}
138
139fn dump(ast: &mut IntoAst, node: &TypedNode, op: &Multinomial) -> TractResult<Option<Arc<RValue>>> {
140 let input = ast.mapping[&node.inputs[0]].clone();
141
142 let dtype = match op.dtype {
143 DatumType::I32 => 6,
144 DatumType::I64 => 7,
145 dt => bail!("Unsupported datum type {:?} for ONNX Multinomial", dt),
146 };
147
148 let inv = if let Some(seed) = op.seed {
149 invocation(
150 "tract_onnx_multinomial",
151 &[input],
152 &[
153 ("dtype", numeric(dtype)),
154 ("sample_size", numeric(op.sample_size)),
155 ("seed", numeric(seed)),
156 ],
157 )
158 } else {
159 invocation(
160 "tract_onnx_multinomial",
161 &[input],
162 &[("dtype", numeric(dtype)), ("sample_size", numeric(op.sample_size))],
163 )
164 };
165
166 Ok(Some(inv))
167}
168
169fn load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
170 let input = invocation.named_arg_as(builder, "input")?;
171 let dtype = match invocation.named_arg_as::<i64>(builder, "dtype")? {
172 6 => DatumType::I32,
173 7 => DatumType::I64,
174 i => bail!("Unsupported datum type {} for ONNX Multinomial", i),
175 };
176 let sample_size = invocation.named_arg_as::<i64>(builder, "sample_size")? as _;
177 let seed = invocation.named_arg_as(builder, "seed").ok();
178
179 let op = Multinomial { dtype, sample_size, seed };
180 builder.wire(op, &[input])
181}