tract_onnx/ops/
multinomial.rs1use crate::model::ParsingContext;
2use crate::pb::*;
3use tract_hir::internal::*;
4
5pub fn multinomial(
6 _ctx: &ParsingContext,
7 node: &NodeProto,
8) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
9 let dtype = match node.get_attr_opt("dtype")?.unwrap_or(6) {
10 6 => DatumType::I32,
11 7 => DatumType::I64,
12 i => bail!("Unsupported datum type {} for ONNX Multinomial", i),
13 };
14 let sample_size = node.get_attr_opt("sample_size")?.unwrap_or(1);
15 let seed = node.get_attr::<f32>("seed").ok();
16
17 Ok((expand(Multinomial { dtype, sample_size, seed }), vec![]))
18}
19
20#[derive(Clone, Debug)]
21pub struct Multinomial {
22 dtype: DatumType,
23 sample_size: i32,
24 pub seed: Option<f32>,
25}
26
27impl Expansion for Multinomial {
28 fn name(&self) -> StaticName {
29 "Multinomial".into()
30 }
31
32
33 fn rules<'r, 'p: 'r, 's: 'r>(
34 &'s self,
35 s: &mut Solver<'r>,
36 inputs: &'p [TensorProxy],
37 outputs: &'p [TensorProxy],
38 ) -> InferenceResult {
39 check_output_arity(outputs, 1)?;
40 check_input_arity(inputs, 1)?;
41
42 s.equals(&inputs[0].rank, 2)?;
46 s.equals(&outputs[0].rank, 2)?;
47 s.equals(&outputs[0].datum_type, self.dtype)?;
48 s.equals(&inputs[0].shape[0], &outputs[0].shape[0])?; s.equals(&outputs[0].shape[1], self.sample_size.to_dim())?; Ok(())
52 }
53
54 fn wire(
55 &self,
56 name: &str,
57 model: &mut TypedModel,
58 inputs: &[OutletId],
59 ) -> TractResult<TVec<OutletId>> {
60 model.wire_node(
61 name,
62 tract_onnx_opl::multinomial::Multinomial {
63 dtype: self.dtype,
64 sample_size: self.sample_size,
65 seed: self.seed,
66 },
67 &[inputs[0]],
68 )
69 }
70}