1use std::marker::PhantomData;
7
8use crate::{
9 device::DeviceBase,
10 dim::DimTrait,
11 matrix::{Matrix, Owned, Repr},
12 num::Num,
13};
14use rand::prelude::*;
15use rand_distr::{num_traits::Float, uniform::SampleUniform, Normal, StandardNormal, Uniform};
16
17pub fn normal<T: Num, S: DimTrait, D: DeviceBase>(
28 mean: T,
29 std_dev: T,
30 shape: S,
31 seed: Option<u64>,
32) -> Matrix<Owned<T>, S, D>
33where
34 StandardNormal: Distribution<T>,
35{
36 let mut rng: Box<dyn RngCore> = if let Some(seed) = seed {
37 Box::new(StdRng::seed_from_u64(seed))
38 } else {
39 Box::new(thread_rng())
40 };
41 let normal = Normal::new(mean, std_dev).unwrap();
42 let mut data = Vec::with_capacity(shape.num_elm());
43 for _ in 0..shape.num_elm() {
44 data.push(normal.sample(&mut *rng));
45 }
46 Matrix::from_vec(data, shape)
47}
48
49pub fn normal_like<T: Num, S: DimTrait, D: DeviceBase>(
58 mean: T,
59 std_dev: T,
60 a: &Matrix<Owned<T>, S, D>,
61 seed: Option<u64>,
62) -> Matrix<Owned<T>, S, D>
63where
64 StandardNormal: Distribution<T>,
65{
66 normal(mean, std_dev, a.shape(), seed)
67}
68
69pub fn uniform<T, S: DimTrait, D: DeviceBase>(
78 low: T,
79 high: T,
80 shape: S,
81 seed: Option<u64>,
82) -> Matrix<Owned<T>, S, D>
83where
84 T: Num,
85 Uniform<T>: Distribution<T>,
86{
87 let mut rng: Box<dyn RngCore> = if let Some(seed) = seed {
88 Box::new(StdRng::seed_from_u64(seed))
89 } else {
90 Box::new(thread_rng())
91 };
92 let uniform = Uniform::new(low, high);
93 let mut data = Vec::with_capacity(shape.num_elm());
94 for _ in 0..shape.num_elm() {
95 data.push(uniform.sample(&mut *rng));
96 }
97 Matrix::from_vec(data, shape)
98}
99
100pub fn uniform_like<T, S: DimTrait, D: DeviceBase>(
109 low: T,
110 high: T,
111 a: &Matrix<Owned<T>, S, D>,
112 seed: Option<u64>,
113) -> Matrix<Owned<T>, S, D>
114where
115 T: Num,
116 Uniform<T>: Distribution<T>,
117{
118 uniform(low, high, a.shape(), seed)
119}
120
/// Builder for a matrix filled with normally distributed samples.
///
/// `mean`, `std_dev` and `shape` must be set before building; `seed` is
/// optional. Trait bounds live on the `impl` blocks rather than on the struct,
/// so the type can be named without repeating them.
#[derive(Debug, Clone, Default)]
pub struct NormalBuilder<T, S, D> {
    /// Mean of the distribution, once set.
    mean: Option<T>,
    /// Standard deviation of the distribution, once set.
    std_dev: Option<T>,
    /// Shape of the matrix to generate, once set.
    shape: Option<S>,
    /// Optional RNG seed.
    seed: Option<u64>,
    /// Ties the builder to a device type `D` without storing a value of it.
    _marker: PhantomData<D>,
}
130
131impl<T, S, D> NormalBuilder<T, S, D>
132where
133 T: Num,
134 S: DimTrait,
135 D: DeviceBase,
136{
137 #[must_use]
139 pub fn new() -> Self {
140 NormalBuilder {
141 mean: None,
142 std_dev: None,
143 shape: None,
144 seed: None,
145 _marker: PhantomData,
146 }
147 }
148
149 #[must_use]
151 pub fn mean(mut self, mean: T) -> Self {
152 self.mean = Some(mean);
153 self
154 }
155
156 #[must_use]
158 pub fn std_dev(mut self, std_dev: T) -> Self {
159 self.std_dev = Some(std_dev);
160 self
161 }
162
163 #[must_use]
165 pub fn shape(mut self, shape: S) -> Self {
166 self.shape = Some(shape);
167 self
168 }
169
170 #[must_use]
172 pub fn seed(mut self, seed: u64) -> Self {
173 self.seed = Some(seed);
174 self
175 }
176
177 #[must_use]
179 pub fn from_matrx<R2: Repr<Item = T>>(mut self, a: &Matrix<R2, S, D>) -> Self {
180 self.shape = Some(a.shape());
181 self
182 }
183
184 #[must_use]
188 pub fn build(self) -> Matrix<Owned<T>, S, D>
189 where
190 StandardNormal: Distribution<T>,
191 {
192 assert!(self.mean.is_some(), "mean must be set");
193 assert!(self.std_dev.is_some(), "std_dev must be set");
194 assert!(self.shape.is_some(), "shape must be set");
195
196 normal(
197 self.mean.unwrap(),
198 self.std_dev.unwrap(),
199 self.shape.unwrap(),
200 self.seed,
201 )
202 }
203}
204
/// Builder for a matrix filled with uniformly distributed samples.
///
/// Every field starts out unset (`None`); the builder's setter methods fill
/// them in before the matrix is generated.
#[derive(Debug, Clone, Default)]
pub struct UniformBuilder<T, S, D> {
    /// Lower bound of the sampling range, once set.
    low: Option<T>,
    /// Upper bound of the sampling range, once set.
    high: Option<T>,
    /// Shape of the matrix to generate, once set.
    shape: Option<S>,
    /// Optional RNG seed.
    seed: Option<u64>,
    /// Ties the builder to a device type `D` without storing a value of it.
    _marker: PhantomData<D>,
}
214
215impl<T, S, D> UniformBuilder<T, S, D>
216where
217 T: Num + Float + SampleUniform,
218 Uniform<T>: Distribution<T>,
219 S: DimTrait,
220 D: DeviceBase,
221{
222 #[must_use]
224 pub fn new() -> Self {
225 UniformBuilder {
226 low: None,
227 high: None,
228 shape: None,
229 seed: None,
230 _marker: PhantomData,
231 }
232 }
233
234 #[must_use]
236 pub fn low(mut self, low: T) -> Self {
237 self.low = Some(low);
238 self
239 }
240
241 #[must_use]
243 pub fn high(mut self, high: T) -> Self {
244 self.high = Some(high);
245 self
246 }
247
248 #[must_use]
250 pub fn shape(mut self, shape: S) -> Self {
251 self.shape = Some(shape);
252 self
253 }
254
255 #[must_use]
257 pub fn seed(mut self, seed: u64) -> Self {
258 self.seed = Some(seed);
259 self
260 }
261
262 #[must_use]
264 pub fn from_matrx<R2: Repr<Item = T>>(mut self, a: &Matrix<R2, S, D>) -> Self {
265 self.shape = Some(a.shape());
266 self
267 }
268
269 pub fn build(self) -> Matrix<Owned<T>, S, D> {
273 assert!(self.low.is_some(), "low must be set");
277 assert!(self.high.is_some(), "high must be set");
278 assert!(self.shape.is_some(), "shape must be set");
279
280 uniform(
281 self.low.unwrap(),
282 self.high.unwrap(),
283 self.shape.unwrap(),
284 self.seed,
285 )
286 }
287}