zenu_matrix/constructor/
rand.rs

1//! Constructors for creating random matrices.
2//!
//! This module provides functions and builders for creating matrices filled with random values
//! drawn from the normal and uniform distributions.
5
6use std::marker::PhantomData;
7
8use crate::{
9    device::DeviceBase,
10    dim::DimTrait,
11    matrix::{Matrix, Owned, Repr},
12    num::Num,
13};
14use rand::prelude::*;
15use rand_distr::{num_traits::Float, uniform::SampleUniform, Normal, StandardNormal, Uniform};
16
17/// Creates a matrix filled with random values from a normal distribution.
18///
19/// # Arguments
20///
21/// * `mean` - The mean of the normal distribution.
22/// * `std_dev` - The standard deviation of the normal distribution.
23/// * `shape` - The shape of the matrix.
24/// * `seed` - An optional seed for the random number generator.
25/// # Panics
26/// Normal distribution may fail to create if the standard deviation is negative.
27pub fn normal<T: Num, S: DimTrait, D: DeviceBase>(
28    mean: T,
29    std_dev: T,
30    shape: S,
31    seed: Option<u64>,
32) -> Matrix<Owned<T>, S, D>
33where
34    StandardNormal: Distribution<T>,
35{
36    let mut rng: Box<dyn RngCore> = if let Some(seed) = seed {
37        Box::new(StdRng::seed_from_u64(seed))
38    } else {
39        Box::new(thread_rng())
40    };
41    let normal = Normal::new(mean, std_dev).unwrap();
42    let mut data = Vec::with_capacity(shape.num_elm());
43    for _ in 0..shape.num_elm() {
44        data.push(normal.sample(&mut *rng));
45    }
46    Matrix::from_vec(data, shape)
47}
48
49/// Creates a matrix filled with random values from a normal distribution with the same shape as another matrix.
50///
51/// # Arguments
52///
53/// * `mean` - The mean of the normal distribution.
54/// * `std_dev` - The standard deviation of the normal distribution.
55/// * `a` - The matrix whose shape is used.
56/// * `seed` - An optional seed for the random number generator.
57pub fn normal_like<T: Num, S: DimTrait, D: DeviceBase>(
58    mean: T,
59    std_dev: T,
60    a: &Matrix<Owned<T>, S, D>,
61    seed: Option<u64>,
62) -> Matrix<Owned<T>, S, D>
63where
64    StandardNormal: Distribution<T>,
65{
66    normal(mean, std_dev, a.shape(), seed)
67}
68
69/// Creates a matrix filled with random values from a uniform distribution.
70///
71/// # Arguments
72///
73/// * `low` - The lower bound of the uniform distribution.
74/// * `high` - The upper bound of the uniform distribution.
75/// * `shape` - The shape of the matrix.
76/// * `seed` - An optional seed for the random number generator.
77pub fn uniform<T, S: DimTrait, D: DeviceBase>(
78    low: T,
79    high: T,
80    shape: S,
81    seed: Option<u64>,
82) -> Matrix<Owned<T>, S, D>
83where
84    T: Num,
85    Uniform<T>: Distribution<T>,
86{
87    let mut rng: Box<dyn RngCore> = if let Some(seed) = seed {
88        Box::new(StdRng::seed_from_u64(seed))
89    } else {
90        Box::new(thread_rng())
91    };
92    let uniform = Uniform::new(low, high);
93    let mut data = Vec::with_capacity(shape.num_elm());
94    for _ in 0..shape.num_elm() {
95        data.push(uniform.sample(&mut *rng));
96    }
97    Matrix::from_vec(data, shape)
98}
99
100/// Creates a matrix filled with random values from a uniform distribution with the same shape as another matrix.
101///
102/// # Arguments
103///
104/// * `low` - The lower bound of the uniform distribution.
105/// * `high` - The upper bound of the uniform distribution.
106/// * `a` - The matrix whose shape is used.
107/// * `seed` - An optional seed for the random number generator.
108pub fn uniform_like<T, S: DimTrait, D: DeviceBase>(
109    low: T,
110    high: T,
111    a: &Matrix<Owned<T>, S, D>,
112    seed: Option<u64>,
113) -> Matrix<Owned<T>, S, D>
114where
115    T: Num,
116    Uniform<T>: Distribution<T>,
117{
118    uniform(low, high, a.shape(), seed)
119}
120
121/// A builder for creating matrices filled with random values from a normal distribution.
122#[derive(Debug, Clone, Default)]
123pub struct NormalBuilder<T: Num + Float, S: DimTrait, D: DeviceBase> {
124    mean: Option<T>,
125    std_dev: Option<T>,
126    shape: Option<S>,
127    seed: Option<u64>,
128    _marker: PhantomData<D>,
129}
130
131impl<T, S, D> NormalBuilder<T, S, D>
132where
133    T: Num,
134    S: DimTrait,
135    D: DeviceBase,
136{
137    /// Creates a new `NormalBuilder`.
138    #[must_use]
139    pub fn new() -> Self {
140        NormalBuilder {
141            mean: None,
142            std_dev: None,
143            shape: None,
144            seed: None,
145            _marker: PhantomData,
146        }
147    }
148
149    /// Sets the mean of the normal distribution.
150    #[must_use]
151    pub fn mean(mut self, mean: T) -> Self {
152        self.mean = Some(mean);
153        self
154    }
155
156    /// Sets the standard deviation of the normal distribution.
157    #[must_use]
158    pub fn std_dev(mut self, std_dev: T) -> Self {
159        self.std_dev = Some(std_dev);
160        self
161    }
162
163    /// Sets the shape of the matrix.
164    #[must_use]
165    pub fn shape(mut self, shape: S) -> Self {
166        self.shape = Some(shape);
167        self
168    }
169
170    /// Sets the seed for the random number generator.
171    #[must_use]
172    pub fn seed(mut self, seed: u64) -> Self {
173        self.seed = Some(seed);
174        self
175    }
176
177    /// Sets the shape of the matrix to be the same as another matrix.
178    #[must_use]
179    pub fn from_matrx<R2: Repr<Item = T>>(mut self, a: &Matrix<R2, S, D>) -> Self {
180        self.shape = Some(a.shape());
181        self
182    }
183
184    /// Builds the matrix.
185    /// # Panics
186    /// `mean` and `std_dev` and `shape` is not set.
187    #[must_use]
188    pub fn build(self) -> Matrix<Owned<T>, S, D>
189    where
190        StandardNormal: Distribution<T>,
191    {
192        assert!(self.mean.is_some(), "mean must be set");
193        assert!(self.std_dev.is_some(), "std_dev must be set");
194        assert!(self.shape.is_some(), "shape must be set");
195
196        normal(
197            self.mean.unwrap(),
198            self.std_dev.unwrap(),
199            self.shape.unwrap(),
200            self.seed,
201        )
202    }
203}
204
205/// A builder for creating matrices filled with random values from a uniform distribution.
206#[derive(Debug, Clone, Default)]
207pub struct UniformBuilder<T, S, D> {
208    low: Option<T>,
209    high: Option<T>,
210    shape: Option<S>,
211    seed: Option<u64>,
212    _marker: PhantomData<D>,
213}
214
215impl<T, S, D> UniformBuilder<T, S, D>
216where
217    T: Num + Float + SampleUniform,
218    Uniform<T>: Distribution<T>,
219    S: DimTrait,
220    D: DeviceBase,
221{
222    /// Creates a new `UniformBuilder`.
223    #[must_use]
224    pub fn new() -> Self {
225        UniformBuilder {
226            low: None,
227            high: None,
228            shape: None,
229            seed: None,
230            _marker: PhantomData,
231        }
232    }
233
234    /// Sets the lower bound of the uniform distribution.
235    #[must_use]
236    pub fn low(mut self, low: T) -> Self {
237        self.low = Some(low);
238        self
239    }
240
241    /// Sets the upper bound of the uniform distribution.
242    #[must_use]
243    pub fn high(mut self, high: T) -> Self {
244        self.high = Some(high);
245        self
246    }
247
248    /// Sets the shape of the matrix.
249    #[must_use]
250    pub fn shape(mut self, shape: S) -> Self {
251        self.shape = Some(shape);
252        self
253    }
254
255    /// Sets the seed for the random number generator.
256    #[must_use]
257    pub fn seed(mut self, seed: u64) -> Self {
258        self.seed = Some(seed);
259        self
260    }
261
262    /// Sets the shape of the matrix to be the same as another matrix.
263    #[must_use]
264    pub fn from_matrx<R2: Repr<Item = T>>(mut self, a: &Matrix<R2, S, D>) -> Self {
265        self.shape = Some(a.shape());
266        self
267    }
268
269    /// Builds the matrix.
270    /// # Panics
271    /// `low`, `high`, and `shape` is not set.
272    pub fn build(self) -> Matrix<Owned<T>, S, D> {
273        // if self.low.is_none() || self.high.is_none() || self.shape.is_none() {
274        //     panic!("low, high, and shape must be set");
275        // }
276        assert!(self.low.is_some(), "low must be set");
277        assert!(self.high.is_some(), "high must be set");
278        assert!(self.shape.is_some(), "shape must be set");
279
280        uniform(
281            self.low.unwrap(),
282            self.high.unwrap(),
283            self.shape.unwrap(),
284            self.seed,
285        )
286    }
287}