1use rand::distributions::uniform::SampleUniform;
2use rand::prelude::Distribution;
3use rand::rngs::SmallRng;
4use rand::SeedableRng;
5use rand_distr::num_traits::Float;
6use rand_distr::StandardNormal;
7use tract_nnef::internal::*;
8use tract_nnef::ser::{array, tdims};
9use tract_nnef::tract_core::trivial_op_state_freeeze;
10
11pub fn register(registry: &mut Registry) {
12 registry.register_primitive(
13 "tract_onnx_random",
14 &[
15 TypeName::String.named("datum_type"),
16 TypeName::Integer.array().named("shape"),
17 TypeName::String.named("dist"),
18 TypeName::Scalar.array().named("parameters"),
19 TypeName::Integer.named("seed"),
20 ],
21 &[("output", TypeName::Scalar.tensor())],
22 load,
23 );
24 registry.register_dumper(dump);
25}
26
27fn load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
28 let dt: DatumType = invocation.named_arg_as::<String>(builder, "datum_type")?.parse()?;
29 let shape: TVec<TDim> = invocation.named_arg_as(builder, "shape")?;
30 let fact = dt.fact(&shape);
31 let dist: String = invocation.named_arg_as(builder, "dist")?;
32 let parameters: TVec<Arc<Tensor>> = invocation.named_arg_as(builder, "parameters")?;
33 let [p1, p2] = &*parameters else { bail!("Random expect two parameters") };
34 let dist = match &*dist {
35 "normal" => Dist::Normal { mean: p1.clone(), dev: p2.clone() },
36 "uniform" => Dist::Uniform { low: p1.clone(), high: p2.clone() },
37 _ => bail!("Unexpected distribution {}", dist),
38 };
39 let seed = invocation.get_named_arg_as(builder, "seed")?;
40 let op = Random { fact, dist, seed };
41 builder.wire(op, &[])
42}
43
44fn dump(_ast: &mut IntoAst, _node: &TypedNode, op: &Random) -> TractResult<Option<Arc<RValue>>> {
45 let mut named = vec![
46 ("datum_type", string(format!("{:?}", op.fact.datum_type))),
47 ("shape", tdims(&op.fact.shape)),
48 ];
49 if let Some(seed) = op.seed {
50 named.push(("seed", numeric(seed)));
51 }
52 match &op.dist {
53 Dist::Uniform { low, high } => {
54 named.push(("dist", string("uniform")));
55 named.push((
56 "parameters",
57 array(&[
58 numeric(low.cast_to_scalar::<f32>()?),
59 numeric(high.cast_to_scalar::<f32>()?),
60 ]),
61 ));
62 }
63 Dist::Normal { mean, dev } => {
64 named.push(("dist", string("normal")));
65 named.push((
66 "parameters",
67 array(&[
68 numeric(mean.cast_to_scalar::<f32>()?),
69 numeric(dev.cast_to_scalar::<f32>()?),
70 ]),
71 ));
72 }
73 }
74 Ok(Some(invocation("tract_onnx_random", &[], &named)))
75}
76
77#[derive(Debug, Clone, Hash)]
78pub enum Dist {
79 Uniform { low: Arc<Tensor>, high: Arc<Tensor> },
80 Normal { mean: Arc<Tensor>, dev: Arc<Tensor> },
81}
82
83#[derive(Debug, Clone, Hash)]
84pub struct Random {
85 pub fact: TypedFact,
86 pub dist: Dist,
87 pub seed: Option<u64>,
88}
89
impl Op for Random {
    /// Display name of the op.
    fn name(&self) -> StaticName {
        "Random".into()
    }

    // Exposes this op through its `TypedOp` implementation.
    op_as_typed_op!();
}
97
98impl TypedOp for Random {
99 fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
100 Ok(tvec!(self.fact.clone()))
101 }
102
103 as_op!();
104}
105
106impl EvalOp for Random {
107 fn is_stateless(&self) -> bool {
108 false
109 }
110
111 fn state(
112 &self,
113 _session: &mut SessionState,
114 _node_id: usize,
115 ) -> TractResult<Option<Box<dyn OpState>>> {
116 let rng = self.seed.map(SmallRng::seed_from_u64).unwrap_or_else(SmallRng::from_entropy);
117 Ok(Some(Box::new(RandomState(rng))))
118 }
119}
120
121#[derive(Clone, Debug)]
122struct RandomState(SmallRng);
123
124impl OpState for RandomState {
125 fn eval(
126 &mut self,
127 session: &mut SessionState,
128 op: &dyn Op,
129 _inputs: TVec<TValue>,
130 ) -> TractResult<TVec<TValue>> {
131 let op = op.downcast_ref::<Random>().context("op and state mismatch")?;
132 let mut tensor = unsafe {
133 Tensor::uninitialized_dt(
134 op.fact.datum_type,
135 &op.fact.shape.eval_to_usize(&session.resolved_symbols)?,
136 )?
137 };
138 match &op.dist {
139 Dist::Uniform { low, high } => match op.fact.datum_type {
140 DatumType::F32 => sample_uniform::<f32>(&mut tensor, &mut self.0, low, high)?,
141 DatumType::F64 => sample_uniform::<f64>(&mut tensor, &mut self.0, low, high)?,
142 DatumType::F16 => {
143 sample_uniform::<f32>(&mut tensor, &mut self.0, low, high)?;
144 tensor = tensor.cast_to::<f16>()?.into_owned();
145 }
146 _ => bail!("Random only support float types"),
147 },
148 Dist::Normal { mean, dev } => match op.fact.datum_type {
149 DatumType::F32 => sample_normal::<f32>(&mut tensor, &mut self.0, mean, dev)?,
150 DatumType::F64 => sample_normal::<f64>(&mut tensor, &mut self.0, mean, dev)?,
151 DatumType::F16 => {
152 sample_uniform::<f32>(&mut tensor, &mut self.0, mean, dev)?;
153 tensor = tensor.cast_to::<f16>()?.into_owned();
154 }
155 _ => bail!("Random only support float types"),
156 },
157 }
158 Ok(tvec!(tensor.into_tvalue()))
159 }
160}
161
// Generates the op-state freeze boilerplate for `RandomState` via tract's
// `trivial_op_state_freeeze!` helper (the macro's name is spelled with three e's).
// NOTE(review): presumably relies on `RandomState: Clone` — confirm against the macro definition.
trivial_op_state_freeeze!(RandomState);
163
164fn sample_uniform<T: Datum + SampleUniform + Copy>(
165 t: &mut Tensor,
166 r: &mut SmallRng,
167 low: &Tensor,
168 high: &Tensor,
169) -> TractResult<()> {
170 let dist =
171 rand::distributions::Uniform::new(low.cast_to_scalar::<T>()?, high.cast_to_scalar::<T>()?);
172 t.as_slice_mut::<T>()?.iter_mut().zip(dist.sample_iter(r)).for_each(|(v, r)| *v = r);
173 Ok(())
174}
175
176fn sample_normal<T: Datum + Float + Copy>(
177 t: &mut Tensor,
178 r: &mut SmallRng,
179 mean: &Tensor,
180 dev: &Tensor,
181) -> TractResult<()>
182where
183 StandardNormal: Distribution<T>,
184{
185 let dist =
186 rand_distr::Normal::<T>::new(mean.cast_to_scalar::<T>()?, dev.cast_to_scalar::<T>()?)?;
187 t.as_slice_mut::<T>()?.iter_mut().zip(dist.sample_iter(r)).for_each(|(v, r)| *v = r);
188 Ok(())
189}