tract_onnx_opl/
random.rs

1use rand::distributions::uniform::SampleUniform;
2use rand::prelude::Distribution;
3use rand::rngs::SmallRng;
4use rand::SeedableRng;
5use rand_distr::num_traits::Float;
6use rand_distr::StandardNormal;
7use tract_nnef::internal::*;
8use tract_nnef::ser::{array, tdims};
9use tract_nnef::tract_core::trivial_op_state_freeeze;
10
11pub fn register(registry: &mut Registry) {
12    registry.register_primitive(
13        "tract_onnx_random",
14        &[
15            TypeName::String.named("datum_type"),
16            TypeName::Integer.array().named("shape"),
17            TypeName::String.named("dist"),
18            TypeName::Scalar.array().named("parameters"),
19            TypeName::Integer.named("seed"),
20        ],
21        &[("output", TypeName::Scalar.tensor())],
22        load,
23    );
24    registry.register_dumper(dump);
25}
26
27fn load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
28    let dt: DatumType = invocation.named_arg_as::<String>(builder, "datum_type")?.parse()?;
29    let shape: TVec<TDim> = invocation.named_arg_as(builder, "shape")?;
30    let fact = dt.fact(&shape);
31    let dist: String = invocation.named_arg_as(builder, "dist")?;
32    let parameters: TVec<Arc<Tensor>> = invocation.named_arg_as(builder, "parameters")?;
33    let [p1, p2] = &*parameters else { bail!("Random expect two parameters") };
34    let dist = match &*dist {
35        "normal" => Dist::Normal { mean: p1.clone(), dev: p2.clone() },
36        "uniform" => Dist::Uniform { low: p1.clone(), high: p2.clone() },
37        _ => bail!("Unexpected distribution {}", dist),
38    };
39    let seed = invocation.get_named_arg_as(builder, "seed")?;
40    let op = Random { fact, dist, seed };
41    builder.wire(op, &[])
42}
43
44fn dump(_ast: &mut IntoAst, _node: &TypedNode, op: &Random) -> TractResult<Option<Arc<RValue>>> {
45    let mut named = vec![
46        ("datum_type", string(format!("{:?}", op.fact.datum_type))),
47        ("shape", tdims(&op.fact.shape)),
48    ];
49    if let Some(seed) = op.seed {
50        named.push(("seed", numeric(seed)));
51    }
52    match &op.dist {
53        Dist::Uniform { low, high } => {
54            named.push(("dist", string("uniform")));
55            named.push((
56                "parameters",
57                array(&[
58                    numeric(low.cast_to_scalar::<f32>()?),
59                    numeric(high.cast_to_scalar::<f32>()?),
60                ]),
61            ));
62        }
63        Dist::Normal { mean, dev } => {
64            named.push(("dist", string("normal")));
65            named.push((
66                "parameters",
67                array(&[
68                    numeric(mean.cast_to_scalar::<f32>()?),
69                    numeric(dev.cast_to_scalar::<f32>()?),
70                ]),
71            ));
72        }
73    }
74    Ok(Some(invocation("tract_onnx_random", &[], &named)))
75}
76
77#[derive(Debug, Clone, Hash)]
78pub enum Dist {
79    Uniform { low: Arc<Tensor>, high: Arc<Tensor> },
80    Normal { mean: Arc<Tensor>, dev: Arc<Tensor> },
81}
82
83#[derive(Debug, Clone, Hash)]
84pub struct Random {
85    pub fact: TypedFact,
86    pub dist: Dist,
87    pub seed: Option<u64>,
88}
89
90impl Op for Random {
91    fn name(&self) -> StaticName {
92        "Random".into()
93    }
94
95    op_as_typed_op!();
96}
97
98impl TypedOp for Random {
99    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
100        Ok(tvec!(self.fact.clone()))
101    }
102
103    as_op!();
104}
105
106impl EvalOp for Random {
107    fn is_stateless(&self) -> bool {
108        false
109    }
110
111    fn state(
112        &self,
113        _session: &mut SessionState,
114        _node_id: usize,
115    ) -> TractResult<Option<Box<dyn OpState>>> {
116        let rng = self.seed.map(SmallRng::seed_from_u64).unwrap_or_else(SmallRng::from_entropy);
117        Ok(Some(Box::new(RandomState(rng))))
118    }
119}
120
121#[derive(Clone, Debug)]
122struct RandomState(SmallRng);
123
124impl OpState for RandomState {
125    fn eval(
126        &mut self,
127        session: &mut SessionState,
128        op: &dyn Op,
129        _inputs: TVec<TValue>,
130    ) -> TractResult<TVec<TValue>> {
131        let op = op.downcast_ref::<Random>().context("op and state mismatch")?;
132        let mut tensor = unsafe {
133            Tensor::uninitialized_dt(
134                op.fact.datum_type,
135                &op.fact.shape.eval_to_usize(&session.resolved_symbols)?,
136            )?
137        };
138        match &op.dist {
139            Dist::Uniform { low, high } => match op.fact.datum_type {
140                DatumType::F32 => sample_uniform::<f32>(&mut tensor, &mut self.0, low, high)?,
141                DatumType::F64 => sample_uniform::<f64>(&mut tensor, &mut self.0, low, high)?,
142                DatumType::F16 => {
143                    sample_uniform::<f32>(&mut tensor, &mut self.0, low, high)?;
144                    tensor = tensor.cast_to::<f16>()?.into_owned();
145                }
146                _ => bail!("Random only support float types"),
147            },
148            Dist::Normal { mean, dev } => match op.fact.datum_type {
149                DatumType::F32 => sample_normal::<f32>(&mut tensor, &mut self.0, mean, dev)?,
150                DatumType::F64 => sample_normal::<f64>(&mut tensor, &mut self.0, mean, dev)?,
151                DatumType::F16 => {
152                    sample_uniform::<f32>(&mut tensor, &mut self.0, mean, dev)?;
153                    tensor = tensor.cast_to::<f16>()?.into_owned();
154                }
155                _ => bail!("Random only support float types"),
156            },
157        }
158        Ok(tvec!(tensor.into_tvalue()))
159    }
160}
161
162trivial_op_state_freeeze!(RandomState);
163
164fn sample_uniform<T: Datum + SampleUniform + Copy>(
165    t: &mut Tensor,
166    r: &mut SmallRng,
167    low: &Tensor,
168    high: &Tensor,
169) -> TractResult<()> {
170    let dist =
171        rand::distributions::Uniform::new(low.cast_to_scalar::<T>()?, high.cast_to_scalar::<T>()?);
172    t.as_slice_mut::<T>()?.iter_mut().zip(dist.sample_iter(r)).for_each(|(v, r)| *v = r);
173    Ok(())
174}
175
176fn sample_normal<T: Datum + Float + Copy>(
177    t: &mut Tensor,
178    r: &mut SmallRng,
179    mean: &Tensor,
180    dev: &Tensor,
181) -> TractResult<()>
182where
183    StandardNormal: Distribution<T>,
184{
185    let dist =
186        rand_distr::Normal::<T>::new(mean.cast_to_scalar::<T>()?, dev.cast_to_scalar::<T>()?)?;
187    t.as_slice_mut::<T>()?.iter_mut().zip(dist.sample_iter(r)).for_each(|(v, r)| *v = r);
188    Ok(())
189}