tract_onnx/ops/
multinomial.rs

1use crate::model::ParsingContext;
2use crate::pb::*;
3use tract_hir::internal::*;
4
5pub fn multinomial(
6    _ctx: &ParsingContext,
7    node: &NodeProto,
8) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
9    let dtype = match node.get_attr_opt("dtype")?.unwrap_or(6) {
10        6 => DatumType::I32,
11        7 => DatumType::I64,
12        i => bail!("Unsupported datum type {} for ONNX Multinomial", i),
13    };
14    let sample_size = node.get_attr_opt("sample_size")?.unwrap_or(1);
15    let seed = node.get_attr::<f32>("seed").ok();
16
17    Ok((expand(Multinomial { dtype, sample_size, seed }), vec![]))
18}
19
20#[derive(Clone, Debug)]
21pub struct Multinomial {
22    dtype: DatumType,
23    sample_size: i32,
24    pub seed: Option<f32>,
25}
26
27impl Expansion for Multinomial {
28    fn name(&self) -> StaticName {
29        "Multinomial".into()
30    }
31
32
33    fn rules<'r, 'p: 'r, 's: 'r>(
34        &'s self,
35        s: &mut Solver<'r>,
36        inputs: &'p [TensorProxy],
37        outputs: &'p [TensorProxy],
38    ) -> InferenceResult {
39        check_output_arity(outputs, 1)?;
40        check_input_arity(inputs, 1)?;
41
42        // inputs[0]: tensor(float16), tensor(float), tensor(double) ; [batch_size, class_size]
43        // outputs[0]: tensor(int32), tensor(int64) {depending on self.datum_type} ; [batch_size, sample_size]
44
45        s.equals(&inputs[0].rank, 2)?;
46        s.equals(&outputs[0].rank, 2)?;
47        s.equals(&outputs[0].datum_type, self.dtype)?;
48        s.equals(&inputs[0].shape[0], &outputs[0].shape[0])?; // batch_size
49        s.equals(&outputs[0].shape[1], self.sample_size.to_dim())?; // sample_size
50
51        Ok(())
52    }
53
54    fn wire(
55        &self,
56        name: &str,
57        model: &mut TypedModel,
58        inputs: &[OutletId],
59    ) -> TractResult<TVec<OutletId>> {
60        model.wire_node(
61            name,
62            tract_onnx_opl::multinomial::Multinomial {
63                dtype: self.dtype,
64                sample_size: self.sample_size,
65                seed: self.seed,
66            },
67            &[inputs[0]],
68        )
69    }
70}