1use crate::traits::Parameterized;
2use crate::traits::{
3 Cdf, ContinuousDistr, Entropy, HasDensity, HasSuffStat, InverseCdf,
4 Kurtosis, Mean, Median, Mode, Sampleable, Scalable, Skewness, Support,
5 Variance,
6};
7use rand::Rng;
8#[cfg(feature = "serde1")]
9use serde::{Deserialize, Serialize};
10use std::fmt;
11
/// A distribution whose random variable is the parent's variable multiplied
/// by a constant `scale`.
///
/// Besides the scale itself, the reciprocal (`rate`) and the log-Jacobian
/// (`logjac`) are stored so that density, support and CDF evaluations can
/// reuse them instead of recomputing them on every call.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
    feature = "serde1",
    derive(Serialize, Deserialize),
    serde(rename_all = "snake_case")
)]
pub struct Scaled<D> {
    /// The distribution being scaled.
    parent: D,
    /// Multiplicative factor applied to values of `parent`.
    scale: f64,
    /// Cached reciprocal of `scale`.
    rate: f64,
    /// Cached log-Jacobian (`ln|scale|` when built by `new`/`new_unchecked`).
    logjac: f64,
}
33
/// Reasons a scale factor can be rejected when building a [`Scaled`].
#[derive(Clone, Debug, PartialEq)]
pub enum ScaledError {
    /// The scale is not a normal float: zero, subnormal, infinite, or NaN.
    NonNormalScale(f64),
    /// The scale is a normal float but negative.
    NegativeScale(f64),
}
41
42impl std::error::Error for ScaledError {}
43
44#[cfg_attr(coverage_nightly, coverage(off))]
45impl fmt::Display for ScaledError {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 match self {
48 Self::NonNormalScale(scale) => {
49 write!(f, "non-normal scale: {scale}")
50 }
51 Self::NegativeScale(scale) => {
52 write!(f, "negative scale: {scale}")
53 }
54 }
55 }
56}
57
58impl<D> Scaled<D> {
59 pub fn new(parent: D, scale: f64) -> Result<Self, ScaledError> {
86 if !scale.is_normal() {
87 Err(ScaledError::NonNormalScale(scale))
88 } else if scale <= 0.0 {
89 Err(ScaledError::NegativeScale(scale))
90 } else {
91 Ok(Scaled {
92 parent,
93 scale,
94 rate: scale.recip(),
95 logjac: scale.abs().ln(),
96 })
97 }
98 }
99
100 pub fn new_unchecked(parent: D, scale: f64) -> Self {
107 Scaled {
108 parent,
109 scale,
110 rate: scale.recip(),
111 logjac: scale.abs().ln(),
112 }
113 }
114
115 pub fn from_parts_unchecked(
129 parent: D,
130 scale: f64,
131 rate: f64,
132 logjac: f64,
133 ) -> Self {
134 Scaled {
135 parent,
136 scale,
137 rate,
138 logjac,
139 }
140 }
141
142 pub fn parent(&self) -> &D {
154 &self.parent
155 }
156
157 pub fn parent_mut(&mut self) -> &mut D {
169 &mut self.parent
170 }
171
172 pub fn scale(&self) -> f64 {
182 self.scale
183 }
184
185 pub fn rate(&self) -> f64 {
195 self.rate
196 }
197
198 pub fn logjac(&self) -> f64 {
208 self.logjac
209 }
210
211 pub fn map_parent_params(
231 &self,
232 f: impl Fn(D::Parameters) -> D::Parameters,
233 ) -> Self
234 where
235 D: Parameterized,
236 {
237 let parent = self.parent.map_params(f);
238 Self::from_parts_unchecked(parent, self.scale, self.rate, self.logjac)
239 }
240}
241
242pub struct ScaledParameters<D: Parameterized> {
243 parent: D::Parameters,
244 scale: f64,
245}
246
247impl<D> Parameterized for Scaled<D>
248where
249 D: Parameterized,
250{
251 type Parameters = ScaledParameters<D>;
252
253 fn emit_params(&self) -> Self::Parameters {
254 ScaledParameters {
255 parent: self.parent.emit_params(),
256 scale: self.scale,
257 }
258 }
259
260 fn from_params(params: Self::Parameters) -> Self {
261 let parent = D::from_params(params.parent);
262 Self::new_unchecked(parent, params.scale)
263 }
264}
265
266use crate::data::ScaledSuffStat;
267
268impl<D> HasSuffStat<f64> for Scaled<D>
269where
270 D: HasSuffStat<f64>,
271{
272 type Stat = ScaledSuffStat<D::Stat>;
273
274 fn empty_suffstat(&self) -> Self::Stat {
275 ScaledSuffStat::new(self.parent.empty_suffstat(), self.scale)
276 }
277
278 fn ln_f_stat(&self, stat: &Self::Stat) -> f64 {
279 self.parent.ln_f_stat(stat.parent())
280 }
281}
282
283impl<D> Sampleable<f64> for Scaled<D>
284where
285 D: Sampleable<f64>,
286{
287 fn draw<R: Rng>(&self, rng: &mut R) -> f64 {
288 self.parent.draw(rng) * self.scale
289 }
290}
291
292impl<D> HasDensity<f64> for Scaled<D>
293where
294 D: HasDensity<f64>,
295{
296 fn ln_f(&self, x: &f64) -> f64 {
297 self.parent.ln_f(&(x * self.rate)) - self.logjac()
298 }
299}
300
301impl<D> Support<f64> for Scaled<D>
302where
303 D: Support<f64>,
304{
305 fn supports(&self, x: &f64) -> bool {
306 self.parent.supports(&(x * self.rate))
307 }
308}
309
310impl<D> ContinuousDistr<f64> for Scaled<D> where D: ContinuousDistr<f64> {}
311
312impl<D> Cdf<f64> for Scaled<D>
313where
314 D: Cdf<f64>,
315{
316 fn cdf(&self, x: &f64) -> f64 {
317 self.parent.cdf(&(x * self.rate))
318 }
319
320 fn sf(&self, x: &f64) -> f64 {
321 self.parent.sf(&(x * self.rate))
322 }
323}
324
325impl<D> InverseCdf<f64> for Scaled<D>
326where
327 D: InverseCdf<f64>,
328{
329 fn invcdf(&self, p: f64) -> f64 {
330 self.parent.invcdf(p) * self.scale
331 }
332
333 fn interval(&self, p: f64) -> (f64, f64) {
334 let (l, r) = self.parent.interval(p);
335 (l * self.scale, r * self.scale)
336 }
337}
338
339impl<D> Skewness for Scaled<D>
340where
341 D: Skewness,
342{
343 fn skewness(&self) -> Option<f64> {
344 self.parent.skewness()
345 }
346}
347
348impl<D> Kurtosis for Scaled<D>
349where
350 D: Kurtosis,
351{
352 fn kurtosis(&self) -> Option<f64> {
353 self.parent.kurtosis()
354 }
355}
356
357impl<D> Mean<f64> for Scaled<D>
358where
359 D: Mean<f64>,
360{
361 fn mean(&self) -> Option<f64> {
362 self.parent.mean().map(|m| m * self.scale)
363 }
364}
365
366impl<D> Median<f64> for Scaled<D>
367where
368 D: Median<f64>,
369{
370 fn median(&self) -> Option<f64> {
371 self.parent.median().map(|m| m * self.scale)
372 }
373}
374
375impl<D> Mode<f64> for Scaled<D>
376where
377 D: Mode<f64>,
378{
379 fn mode(&self) -> Option<f64> {
380 self.parent.mode().map(|m| m * self.scale)
381 }
382}
383
384impl<D> Variance<f64> for Scaled<D>
385where
386 D: Variance<f64>,
387{
388 fn variance(&self) -> Option<f64> {
389 self.parent.variance().map(|v| v * self.scale * self.scale)
390 }
391}
392
393impl<D> Entropy for Scaled<D>
394where
395 D: Entropy,
396{
397 fn entropy(&self) -> f64 {
398 self.parent.entropy() + self.logjac()
399 }
400}
401
402impl<D> Scalable for Scaled<D>
403where
404 D: Scalable,
405{
406 type Output = Self;
407 type Error = ScaledError;
408
409 fn scaled(self, scale: f64) -> Result<Self::Output, Self::Error>
410 where
411 Self: Sized,
412 {
413 Scaled::new(self.parent, self.scale * scale)
414 }
415
416 fn scaled_unchecked(self, scale: f64) -> Self::Output
417 where
418 Self: Sized,
419 {
420 let new_scale = self.scale * scale;
421 Scaled {
422 parent: self.parent,
423 scale: new_scale,
424 rate: new_scale.recip(),
425 logjac: new_scale.ln(),
426 }
427 }
428}
429
#[cfg(test)]
mod tests {
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::ScaledError;
    use crate::prelude::*;
    use crate::test_scalable_cdf;
    use crate::test_scalable_density;
    use crate::test_scalable_entropy;
    use crate::test_scalable_invcdf;
    use crate::test_scalable_method;

    #[test]
    fn symmetric_parameters() {
        let original = Scaled::new_unchecked(Gaussian::standard(), 3.0);
        let rebuilt = Scaled::from_params(original.emit_params());

        assert_eq!(original, rebuilt);
    }

    #[test]
    fn support_is_scaled() {
        let dist = Scaled::new_unchecked(Uniform::new_unchecked(0.0, 1.0), 3.0);

        assert!(dist.supports(&2.5));
        assert!(!dist.supports(&3.5));
    }

    #[test]
    fn draws_are_scaled() {
        let dist = Scaled::new_unchecked(Shifted::new_unchecked(Delta, 1.0), 3.0);
        let mut rng = SmallRng::seed_from_u64(0x1234);
        let x: f64 = dist.draw(&mut rng);

        assert_eq!(x, 3.0);
    }

    // Every non-normal float (zero of either sign, subnormal, infinite, NaN)
    // must hit the first validation branch of `Scaled::new`.
    #[test]
    fn new_rejects_non_normal_scale() {
        let bad_scales = [
            0.0,
            -0.0,
            f64::MIN_POSITIVE / 2.0,
            f64::INFINITY,
            f64::NEG_INFINITY,
            f64::NAN,
        ];
        for scale in bad_scales {
            let result = Scaled::new(Gaussian::standard(), scale);
            assert!(
                matches!(result, Err(ScaledError::NonNormalScale(_))),
                "scale {scale} should be rejected as non-normal"
            );
        }
    }

    #[test]
    fn new_rejects_negative_scale() {
        assert_eq!(
            Scaled::new(Gaussian::standard(), -2.0),
            Err(ScaledError::NegativeScale(-2.0))
        );
    }

    #[test]
    fn new_caches_rate_and_logjac() {
        let dist = Scaled::new(Gaussian::standard(), 4.0).unwrap();

        assert_eq!(dist.scale(), 4.0);
        assert_eq!(dist.rate(), 0.25);
        assert_eq!(dist.logjac(), 4.0_f64.ln());
    }

    test_scalable_method!(
        Scaled::new(Gaussian::new(2.0, 4.0).unwrap(), 3.0).unwrap(),
        mean
    );
    test_scalable_method!(
        Scaled::new(Gaussian::new(2.0, 4.0).unwrap(), 3.0).unwrap(),
        median
    );
    test_scalable_method!(
        Scaled::new(Gaussian::new(2.0, 4.0).unwrap(), 3.0).unwrap(),
        variance
    );
    test_scalable_method!(
        Scaled::new(Gaussian::new(2.0, 4.0).unwrap(), 3.0).unwrap(),
        skewness
    );
    test_scalable_method!(
        Scaled::new(Gaussian::new(2.0, 4.0).unwrap(), 3.0).unwrap(),
        kurtosis
    );
    test_scalable_density!(
        Scaled::new(Gaussian::new(2.0, 4.0).unwrap(), 3.0).unwrap()
    );
    test_scalable_entropy!(
        Scaled::new(Gaussian::new(2.0, 4.0).unwrap(), 3.0).unwrap()
    );
    test_scalable_cdf!(
        Scaled::new(Gaussian::new(2.0, 4.0).unwrap(), 3.0).unwrap()
    );
    test_scalable_invcdf!(
        Scaled::new(Gaussian::new(2.0, 4.0).unwrap(), 3.0).unwrap()
    );
}