1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
use rand::distributions::Standard;
use rand::prelude::Distribution;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use tract_nnef::internal::*;
use tract_nnef::tract_ndarray::s;
use tract_nnef::tract_num_traits::{AsPrimitive, Float, Zero};

/// Registers the `tract_onnx_multinomial` primitive (loader side) and the
/// matching dumper for [`Multinomial`] nodes on the given registry.
pub fn register(registry: &mut Registry) {
    let params = parameters();
    registry.register_primitive("tract_onnx_multinomial", &params, load);
    registry.register_dumper(TypeId::of::<Multinomial>(), dump);
}

/// ONNX `Multinomial` operator: for every row of the input, draws
/// `sample_size` class indices at random.
///
/// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Multinomial
#[derive(Clone, Debug, Educe)]
#[educe(Hash)]
pub struct Multinomial {
    /// Datum type of the output tensor. Evaluation only supports
    /// `DatumType::I32` and `DatumType::I64`.
    pub dtype: DatumType,
    /// Number of samples drawn per input row (second output dimension).
    pub sample_size: i32,
    /// Optional seed. When present its bit pattern seeds the random number
    /// generator; when absent the generator is seeded from entropy.
    #[educe(Hash(method = "hash_opt_f32"))]
    pub seed: Option<f32>,
}

impl_dyn_hash!(Multinomial);

impl Multinomial {
    fn eval_t0<T1>(&self, input: Arc<Tensor>) -> TractResult<Arc<Tensor>>
    where
        T1: Datum + std::ops::SubAssign + Float + std::iter::Sum,
        Standard: Distribution<T1>,
    {
        match self.dtype {
            DatumType::I32 => self.eval_t::<T1, i32>(input),
            DatumType::I64 => self.eval_t::<T1, i64>(input),
            dt => bail!("Unsupported output datum type for Multinomial: {:?}", dt),
        }
    }
    fn eval_t<T1, T2>(&self, input: Arc<Tensor>) -> TractResult<Arc<Tensor>>
    where
        T1: Datum + std::ops::SubAssign + Float + std::iter::Sum,
        Standard: Distribution<T1>,
        T2: Datum + Zero + Copy,
        usize: AsPrimitive<T2>,
    {
        let batch_size = input.shape()[0];
        let class_size = input.shape()[1];

        let mut rng = self.seed.map_or_else(SmallRng::from_entropy, |seed| {
            SmallRng::seed_from_u64(seed.to_bits() as _)
        });

        // shape: [batch_size, class_size]
        let input = input.to_array_view::<T1>()?;

        // ONNX Multinomial inputs are "unnormalized log probabilities".
        // This means that we need to compute the maximum for each batch beforehand,
        //  and we also need to exp every value.

        let maximums: TVec<_> =
            input.rows().into_iter().map(|r| r.iter().map(|e| e.exp()).sum::<T1>()).collect();

        // shape: [batch_size, sample_size]
        let out_shape: &[_] = &[batch_size, self.sample_size as usize];
        let output = tract_ndarray::ArrayD::from_shape_fn(out_shape, |co_o| -> T2 {
            let batch = co_o[0];

            let mut rand = rng.gen::<T1>() * maximums[batch];
            let mut ret: T2 = usize::as_(class_size - 1);

            for (i, prob) in input.slice(s![batch, ..]).iter().enumerate() {
                let prob = prob.exp();
                if rand < prob {
                    ret = usize::as_(i);
                    break;
                }
                rand -= prob;
            }

            ret
        });

        Ok(output.into_arc_tensor())
    }
}

impl Op for Multinomial {
    fn name(&self) -> Cow<str> {
        "Multinomial".into()
    }

    fn op_families(&self) -> &'static [&'static str] {
        &["onnx"]
    }

    op_as_typed_op!();
}

impl EvalOp for Multinomial {
    /// The operator is reported as stateless.
    fn is_stateless(&self) -> bool {
        true
    }

    /// Takes the single input tensor and samples from it, dispatching on the
    /// input element type. Only `F32` and `F64` inputs are supported.
    fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        let logits = args_1!(inputs);

        // TODO: support DatumType::F16 once random generation is implemented for f16.
        let sampled = match logits.datum_type() {
            DatumType::F32 => self.eval_t0::<f32>(logits)?,
            DatumType::F64 => self.eval_t0::<f64>(logits)?,
            dt => bail!("Unsupported input datum type for Multinomial: {:?}", dt),
        };

        Ok(tvec![sampled])
    }
}

impl TypedOp for Multinomial {
    /// Computes the output fact: a `[batch_size, sample_size]` tensor of `self.dtype`.
    ///
    /// # Errors
    /// Fails if the input shape is not concrete, if the input is not rank 2, or if
    /// `sample_size` is negative.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let input_shape = inputs[0]
            .shape
            .as_concrete()
            .context("Only constant input shape are supported in Multinomial")?;

        // Evaluation reads both the batch and the class dimension, so reject
        // anything that is not [batch_size, class_size] here instead of panicking.
        if input_shape.len() != 2 {
            bail!(
                "Multinomial expects a rank 2 input [batch_size, class_size], got shape {:?}",
                input_shape
            );
        }
        // A negative value would wrap around to a huge dimension once cast to usize.
        if self.sample_size < 0 {
            bail!("Invalid sample_size {} for Multinomial", self.sample_size);
        }

        let batch_size = input_shape[0];
        Ok(tvec!(self.dtype.fact(&[batch_size, self.sample_size as usize])))
    }

    as_op!();
}

/// Declares the arguments of the `tract_onnx_multinomial` primitive.
///
/// * `input`: tensor of unnormalized log probabilities. Evaluation only accepts
///   floating point tensors, so it is declared as a scalar tensor.
/// * `dtype`: integer code of the output datum type (6 for `I32`, 7 for `I64`),
///   defaults to 6.
/// * `sample_size`: number of samples drawn per input row, defaults to 1.
/// * `seed`: optional seed, no default. The operator stores it as an `f32` and
///   dumps it as a floating point literal, so it is declared as a scalar.
fn parameters() -> Vec<Parameter> {
    vec![
        TypeName::Scalar.tensor().named("input"),
        TypeName::Integer.named("dtype").default(6),
        TypeName::Integer.named("sample_size").default(1),
        TypeName::Scalar.named("seed"),
    ]
}

/// Serializes a [`Multinomial`] node as a `tract_onnx_multinomial` invocation.
///
/// `dtype` is written as its integer code (6 for `I32`, 7 for `I64`); `seed` is
/// only written when the operator has one.
///
/// # Errors
/// Fails if the node does not hold a `Multinomial` operator or if its output
/// datum type is neither `I32` nor `I64`.
fn dump(ast: &mut IntoAst, node: &TypedNode) -> TractResult<Option<Arc<RValue>>> {
    let op = node.op_as::<Multinomial>().context("wrong op")?;
    let input = ast.mapping[&node.inputs[0]].clone();

    let dtype = match op.dtype {
        DatumType::I32 => 6,
        DatumType::I64 => 7,
        dt => bail!("Unsupported datum type {:?} for ONNX Multinomial", dt),
    };

    // Mandatory named arguments first, the optional seed last.
    let mut named_args =
        vec![("dtype", numeric(dtype)), ("sample_size", numeric(op.sample_size))];
    if let Some(seed) = op.seed {
        named_args.push(("seed", numeric(seed)));
    }

    Ok(Some(invocation("tract_onnx_multinomial", &[input], &named_args)))
}

/// Builds a [`Multinomial`] operator from a `tract_onnx_multinomial` invocation
/// and wires it to its input.
///
/// # Errors
/// Fails if `input`, `dtype` or `sample_size` cannot be resolved, if `dtype` is
/// not one of the supported codes (6 for `I32`, 7 for `I64`), or if `sample_size`
/// is negative or does not fit in an `i32`.
fn load(
    builder: &mut ModelBuilder,
    invocation: &ResolvedInvocation,
) -> TractResult<Value> {
    let input = invocation.named_arg_as(builder, "input")?;
    let dtype = match invocation.named_arg_as::<i64>(builder, "dtype")? {
        6 => DatumType::I32,
        7 => DatumType::I64,
        i => bail!("Unsupported datum type {} for ONNX Multinomial", i),
    };

    // Check the range explicitly: a plain `as` cast would silently truncate large
    // values, and a negative count can not be turned into an output dimension.
    let raw_sample_size = invocation.named_arg_as::<i64>(builder, "sample_size")?;
    if raw_sample_size < 0 || raw_sample_size > i64::from(i32::MAX) {
        bail!("Invalid sample_size {} for ONNX Multinomial", raw_sample_size);
    }
    let sample_size = raw_sample_size as i32;

    // The seed is optional: any failure to resolve it is treated as "no seed".
    let seed = invocation.named_arg_as(builder, "seed").ok();

    let op = Multinomial { dtype, sample_size, seed };
    builder.wire(op, &[input])
}