sbrd_gen/generator/distribution/normal_generator.rs

1use crate::builder::GeneratorBuilder;
2use crate::error::{BuildError, GenerateError};
3use crate::generator::{GeneratorBase, Randomizer};
4use crate::value::{DataValue, DataValueMap, SbrdReal};
5use crate::GeneratorType;
6use rand::distributions::Distribution;
7use rand_distr::Normal;
8
9/// The generator with generate [`DataValue::Real`] from normal distribution
10///
11/// [`DataValue::Real`]: ../../value/enum.DataValue.html#variant.Real
12#[derive(Debug, Clone, Copy)]
13pub struct NormalGenerator {
14    nullable: bool,
15    distribution: Normal<SbrdReal>,
16}
17
18impl<R: Randomizer + ?Sized> GeneratorBase<R> for NormalGenerator {
19    fn create(builder: GeneratorBuilder) -> Result<Self, BuildError>
20    where
21        Self: Sized,
22    {
23        let GeneratorBuilder {
24            generator_type,
25            nullable,
26            parameters,
27            ..
28        } = builder;
29
30        if generator_type != GeneratorType::DistNormal {
31            return Err(BuildError::InvalidType(generator_type));
32        }
33
34        let (mean, std_dev): (SbrdReal, SbrdReal) = match parameters {
35            None => Err(BuildError::NotExistValueOf("parameters".to_string())),
36            Some(parameters) => {
37                let _mean = parameters
38                    .get(Self::MEAN)
39                    .map(|v| {
40                        v.to_parse_string().parse::<SbrdReal>().map_err(|e| {
41                            BuildError::FailParseValue(
42                                v.to_parse_string(),
43                                "Real".to_string(),
44                                e.to_string(),
45                            )
46                        })
47                    })
48                    .unwrap_or_else(|| Ok(0.0))?;
49
50                let _std_dev = parameters
51                    .get(Self::STD_DEV)
52                    .map(|v| {
53                        v.to_parse_string().parse::<SbrdReal>().map_err(|e| {
54                            BuildError::FailParseValue(
55                                v.to_parse_string(),
56                                "Real".to_string(),
57                                e.to_string(),
58                            )
59                        })
60                    })
61                    .unwrap_or_else(|| Ok(1.0))?;
62                if _std_dev < 0.0 {
63                    return Err(BuildError::InvalidValue(format!(
64                        "std_dev {} is less than 0.0",
65                        _std_dev
66                    )));
67                }
68
69                Ok((_mean, _std_dev))
70            }
71        }?;
72
73        Ok(Self {
74            nullable,
75            distribution: Normal::new(mean, std_dev).map_err(|e| {
76                BuildError::FailBuildDistribution("Normal".to_string(), e.to_string())
77            })?,
78        })
79    }
80
81    fn is_nullable(&self) -> bool {
82        self.nullable
83    }
84
85    fn generate_without_null(
86        &self,
87        rng: &mut R,
88        _context: &DataValueMap<&str>,
89    ) -> Result<DataValue, GenerateError> {
90        Ok(DataValue::Real(self.distribution.sample(rng)))
91    }
92}
93
94impl NormalGenerator {
95    /// mean
96    pub const MEAN: &'static str = "mean";
97    /// standard deviation
98    pub const STD_DEV: &'static str = "std_dev";
99}