// tract_onnx_opl/multinomial.rs

1use rand::distributions::Standard;
2use rand::prelude::Distribution;
3use rand::rngs::SmallRng;
4use rand::{Rng, SeedableRng};
5
6use tract_nnef::internal::*;
7use tract_nnef::tract_ndarray::s;
8use tract_nnef::tract_num_traits::{AsPrimitive, Float, Zero};
9
10pub fn register(registry: &mut Registry) {
11    registry.register_primitive(
12        "tract_onnx_multinomial",
13        &parameters(),
14        &[("output", TypeName::Scalar.tensor())],
15        load,
16    );
17    registry.register_dumper(dump);
18}
19
/// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Multinomial
///
/// Draws `sample_size` class indices per batch row from a multinomial
/// distribution described by a `[batch, class]` tensor of unnormalized
/// log-probabilities.
#[derive(Clone, Debug)]
pub struct Multinomial {
    // Output datum type; only I32 and I64 are accepted (see `eval_t0`).
    pub dtype: DatumType,
    // Number of samples drawn per batch row (second output dimension).
    pub sample_size: i32,
    // Optional RNG seed; when `None` the RNG is seeded from entropy.
    pub seed: Option<f32>,
}
27
28impl Multinomial {
29    fn eval_t0<T1>(&self, input: TValue) -> TractResult<TValue>
30    where
31        T1: Datum + std::ops::SubAssign + Float + std::iter::Sum,
32        Standard: Distribution<T1>,
33    {
34        match self.dtype {
35            DatumType::I32 => self.eval_t::<T1, i32>(input),
36            DatumType::I64 => self.eval_t::<T1, i64>(input),
37            dt => bail!("Unsupported output datum type for Multinomial: {:?}", dt),
38        }
39    }
40    fn eval_t<T1, T2>(&self, input: TValue) -> TractResult<TValue>
41    where
42        T1: Datum + std::ops::SubAssign + Float + std::iter::Sum,
43        Standard: Distribution<T1>,
44        T2: Datum + Zero + Copy,
45        usize: AsPrimitive<T2>,
46    {
47        let batch_size = input.shape()[0];
48        let class_size = input.shape()[1];
49
50        let mut rng = self.seed.map_or_else(SmallRng::from_entropy, |seed| {
51            SmallRng::seed_from_u64(seed.to_bits() as _)
52        });
53
54        // shape: [batch_size, class_size]
55        let input = input.to_array_view::<T1>()?;
56
57        // ONNX Multinomial inputs are "unnormalized log probabilities".
58        // This means that we need to compute the maximum for each batch beforehand,
59        //  and we also need to exp every value.
60
61        let maximums: TVec<_> =
62            input.rows().into_iter().map(|r| r.iter().map(|e| e.exp()).sum::<T1>()).collect();
63
64        // shape: [batch_size, sample_size]
65        let out_shape: &[_] = &[batch_size, self.sample_size as usize];
66        let output = tract_ndarray::ArrayD::from_shape_fn(out_shape, |co_o| -> T2 {
67            let batch = co_o[0];
68
69            let mut rand = rng.r#gen::<T1>() * maximums[batch];
70            let mut ret: T2 = usize::as_(class_size - 1);
71
72            for (i, prob) in input.slice(s![batch, ..]).iter().enumerate() {
73                let prob = prob.exp();
74                if rand < prob {
75                    ret = usize::as_(i);
76                    break;
77                }
78                rand -= prob;
79            }
80
81            ret
82        });
83
84        Ok(output.into_tvalue())
85    }
86}
87
88impl Op for Multinomial {
89    fn name(&self) -> StaticName {
90        "Multinomial".into()
91    }
92
93    op_as_typed_op!();
94}
95
96impl EvalOp for Multinomial {
97    fn is_stateless(&self) -> bool {
98        true
99    }
100
101    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
102        let input = args_1!(inputs);
103
104        let output = match input.datum_type() {
105            // DatumType::F16 => self.eval_t0::<f16>(input), // TODO: implement random for f16
106            DatumType::F32 => self.eval_t0::<f32>(input),
107            DatumType::F64 => self.eval_t0::<f64>(input),
108            dt => bail!("Unsupported input datum type for Multinomial: {:?}", dt),
109        }?;
110
111        Ok(tvec![output])
112    }
113}
114
115impl TypedOp for Multinomial {
116    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
117        let input_shape = if let Some(s) = inputs[0].shape.as_concrete() {
118            s
119        } else {
120            bail!("Only constant input shape are supported in Multinomial")
121        };
122
123        let batch_size = input_shape[0];
124        Ok(tvec!(self.dtype.fact([batch_size, self.sample_size as usize])))
125    }
126
127    as_op!();
128}
129
130fn parameters() -> Vec<Parameter> {
131    vec![
132        TypeName::Integer.tensor().named("input"),
133        TypeName::Integer.named("dtype").default(6),
134        TypeName::Integer.named("sample_size").default(1),
135        TypeName::Integer.named("seed"),
136    ]
137}
138
139fn dump(ast: &mut IntoAst, node: &TypedNode, op: &Multinomial) -> TractResult<Option<Arc<RValue>>> {
140    let input = ast.mapping[&node.inputs[0]].clone();
141
142    let dtype = match op.dtype {
143        DatumType::I32 => 6,
144        DatumType::I64 => 7,
145        dt => bail!("Unsupported datum type {:?} for ONNX Multinomial", dt),
146    };
147
148    let inv = if let Some(seed) = op.seed {
149        invocation(
150            "tract_onnx_multinomial",
151            &[input],
152            &[
153                ("dtype", numeric(dtype)),
154                ("sample_size", numeric(op.sample_size)),
155                ("seed", numeric(seed)),
156            ],
157        )
158    } else {
159        invocation(
160            "tract_onnx_multinomial",
161            &[input],
162            &[("dtype", numeric(dtype)), ("sample_size", numeric(op.sample_size))],
163        )
164    };
165
166    Ok(Some(inv))
167}
168
169fn load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
170    let input = invocation.named_arg_as(builder, "input")?;
171    let dtype = match invocation.named_arg_as::<i64>(builder, "dtype")? {
172        6 => DatumType::I32,
173        7 => DatumType::I64,
174        i => bail!("Unsupported datum type {} for ONNX Multinomial", i),
175    };
176    let sample_size = invocation.named_arg_as::<i64>(builder, "sample_size")? as _;
177    let seed = invocation.named_arg_as(builder, "seed").ok();
178
179    let op = Multinomial { dtype, sample_size, seed };
180    builder.wire(op, &[input])
181}