standard_dist/
lib.rs

1/*!
2`standard-dist` is a library for automatically deriving a `rand` standard
3distribution for your types via a derive macro.
4
5# Usage examples
6
7```
8use rand::distributions::Uniform;
9use standard_dist::StandardDist;
10
11// Select heads or tails with equal probability
12#[derive(Debug, Clone, Copy, PartialEq, Eq, StandardDist)]
13enum Coin {
14    Heads,
15    Tails,
16}
17
18// Flip 3 coins, independently
19#[derive(Debug, Clone, Copy, PartialEq, Eq, StandardDist)]
20struct Coins {
21    first: Coin,
22    second: Coin,
23    third: Coin,
24}
25
26// Use the `#[distribution]` attribute to customize the distribution used on
27// a field
28#[derive(Debug, Clone, Copy, PartialEq, Eq, StandardDist)]
29struct Die {
30    #[distribution(Uniform::from(1..=6))]
31    value: u8
32}
33
34// Use the `#[weight]` attribute to customize the relative probabilities of
35// enum variants
36#[derive(Debug, Clone, Copy, PartialEq, Eq, StandardDist)]
37enum D20 {
38    #[weight(18)]
39    Normal,
40
41    Critical,
42    CriticalFail,
43}
44```
45
46[`rand`] generates typed random values via the [`Distribution`] trait, which
47uses a [source of randomness] to produce values of the given type. Of particular
48note is the [`Standard`] distribution, which is the stateless "default" way to
49produce random values of a particular type. For instance:
50- For ints, this randomly chooses from all possible values for that int type
51- For bools, it chooses true or false with 50/50 probability
52- For `Option<T>`, it chooses `None` or `Some` with 50/50 probability, and uses
53  [`Standard`] to randomly populate the inner `Some` value.
54
55# Structs
56
57When you derive `StandardDist` for one of your own structs, it creates an
58`impl Distribution<YourStruct> for Standard` implementation, allowing you to
59create randomized instances of the struct via [`Rng::gen`]. This implementation
60will in turn use the `Standard` distribution to populate all the fields of
61your type.
62
63```rust
64use standard_dist::StandardDist;
65
66#[derive(StandardDist)]
67struct SimpleStruct {
68    coin: bool,
69    percent: f64,
70}
71
72let mut heads = 0;
73
74for _ in 0..2000 {
75    let s: SimpleStruct = rand::random();
76    assert!(0.0 <= s.percent);
77    assert!(s.percent < 1.0);
78    if s.coin {
79        heads += 1;
80    }
81}
82
83assert!(900 < heads, "heads: {}", heads);
84assert!(heads < 1100, "heads: {}", heads);
85```
86
87## Custom Distributions
88
89You can customize the distribution used for any field with the `#[distribution]`
90attribute:
91
92```rust
93use std::collections::HashMap;
94use standard_dist::StandardDist;
95use rand::distributions::Uniform;
96
97#[derive(StandardDist)]
98struct Die {
99    #[distribution(Uniform::from(1..=6))]
100    value: u8
101}
102
103let mut counter: HashMap<u8, u32> = HashMap::new();
104
105for _ in 0..6000 {
106    let die: Die = rand::random();
107    *counter.entry(die.value).or_insert(0) += 1;
108}
109
110assert_eq!(counter.len(), 6);
111
112for i in 1..=6 {
113    let count = counter[&i];
114    assert!(900 < count, "{}: {}", i, count);
115    assert!(count < 1100, "{}: {}", i, count);
116}
117```
118
119# Enums
120
121When applied to an enum type, the implementation will randomly select a variant
122(where each variant has an equal probability) and then populate all the fields
123of that variant in the same manner as with a struct. Enum variant fields may
124have custom distributions applied via `#[distribution]`, just like struct
125fields.
126
127```rust
128use standard_dist::StandardDist;
129
130#[derive(PartialEq, Eq, StandardDist)]
131enum Coin {
132    Heads,
133    Tails,
134}
135
136let mut heads = 0;
137
138for _ in 0..2000 {
139    let coin: Coin = rand::random();
140    if coin == Coin::Heads {
141        heads += 1;
142    }
143}
144
145assert!(900 < heads, "heads: {}", heads);
146assert!(heads < 1100, "heads: {}", heads);
147```
148
149## Weights
150
151Enum variants may be weighted with the `#[weight]` attribute to make them
152relatively more or less likely to be randomly selected. A weight of 0 means
153that the variant will never be selected. Any untagged variants will have a
154weight of 1.
155
156```rust
157use standard_dist::StandardDist;
158
159#[derive(StandardDist)]
160enum D20 {
161    #[weight(18)]
162    Normal,
163
164    CriticalHit,
165    CriticalMiss,
166}
167
168let mut crits = 0;
169
170for _ in 0..20000 {
171    let roll: D20 = rand::random();
172    if matches!(roll, D20::CriticalHit) {
173        crits += 1;
174    }
175}
176
177assert!(900 < crits, "crits: {}", crits);
178assert!(crits < 1100, "crits: {}", crits);
179```
180
181# Advanced custom distributions
182
183## Distribution types
184
185You may optionally explicitly specify a type for your distributions; this can
186sometimes be necessary when using generic types.
187
188```rust
189use std::collections::HashMap;
190use standard_dist::StandardDist;
191use rand::distributions::Uniform;
192
193#[derive(StandardDist)]
194struct Die {
195    #[distribution(Uniform<u8> = Uniform::from(1..=6))]
196    value: u8
197}
198
199let mut counter: HashMap<u8, u32> = HashMap::new();
200
201for _ in 0..6000 {
202    let die: Die = rand::random();
203    *counter.entry(die.value).or_insert(0) += 1;
204}
205
206assert_eq!(counter.len(), 6);
207
208for i in 1..=6 {
209    let count = counter[&i];
210    assert!(900 < count, "{}: {}", i, count);
211    assert!(count < 1100, "{}: {}", i, count);
212}
213```
214
215## Distribution caching
216
217In some cases, you may wish to cache a `Distribution` instance for reuse. Many
218distributions perform some initial calculations when constructed, and it can
219help performance to reuse existing distributions rather than recreate them
220every time a value is generated. `standard-dist` provides two ways to cache
221distributions: `static` and `once`. A `static` distribution is stored as a
222global static variable; this is the preferable option, but it requires the
223initializer to be usable in a `const` context. A `once` distribution is stored
224in a `once_cell::sync::OnceCell`; it is initialized the first time it's used,
225and then reused on subsequent invocations.
226
227In either case, a cache policy is specified by prefixing the type with `once` or
228`static`. The type must be specified in order to use a cache policy.
229
230```rust
231use std::collections::HashMap;
232use std::time::{Instant, Duration};
233use standard_dist::StandardDist;
234use rand::prelude::*;
235use rand::distributions::Uniform;
236
237#[derive(StandardDist)]
238struct Die {
239    #[distribution(Uniform::from(1..=6))]
240    value: u8
241}
242
243#[derive(StandardDist)]
244struct CachedDie {
245    #[distribution(once Uniform<u8> = Uniform::from(1..=6))]
246    value: u8
247}
248
249fn timed<T>(task: impl FnOnce() -> T) -> (T, Duration) {
250    let start = Instant::now();
251    (task(), start.elapsed())
252}
253
254// Count the 6s
255let mut rng = StdRng::from_entropy();
256
257let (count, plain_die_duration) = timed(|| (0..600000)
258    .map(|_| rng.gen())
259    .filter(|&Die{ value }| value == 6)
260    .count()
261);
262
263assert!(90000 < count);
264assert!(count < 110000);
265
266let (count, cache_die_duration) = timed(|| (0..600000)
267    .map(|_| rng.gen())
268    .filter(|&CachedDie{ value }| value == 6)
269    .count()
270);
271
272assert!(90000 < count);
273assert!(count < 110000);
274
275assert!(
276    cache_die_duration < plain_die_duration,
277    "cache: {:?}, plain: {:?}",
278    cache_die_duration,
279    plain_die_duration,
280);
281```
282
Note that, unless you're generating a huge quantity of random objects, using
`once` is likely a pessimization because of the upfront cost of initializing
the cell. Make sure to benchmark your specific use case if performance is a
concern.
287
288
289[`rand`]: https://docs.rs/rand/
290[`Distribution`]: https://docs.rs/rand/latest/rand/distributions/trait.Distribution.html
291[`Standard`]: https://docs.rs/rand/latest/rand/distributions/struct.Standard.html
292[source of randomness]: https://docs.rs/rand/latest/rand/trait.Rng.html
293[`Rng::gen`]: https://docs.rs/rand/latest/rand/trait.Rng.html#method.gen
294*/
295use std::{collections::HashSet, iter};
296
297use itertools::Itertools;
298use parse::ParseStream;
299use proc_macro::TokenStream;
300use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
301use quote::{quote, ToTokens};
302use syn::{
303    parse,
304    parse::{discouraged::Speculative, Parse},
305    parse_quote,
306    spanned::Spanned,
307    DeriveInput, Error, Expr, Field, Fields, LitInt, Token, Type, Variant,
308};
309
310/// A particular field type, paired with the type of the distribution used
311/// to produce it. Used to create `where` bindings.
312#[derive(Debug, Clone, PartialEq, Eq, Hash)]
313struct FieldDistributionBinding<'a> {
314    field_type: &'a Type,
315    distribution_type: Type,
316}
317
318/// Given a list of fields (as from a struct or enum variant), return a list
319/// of all the types of those fields, paired with the associated distribution
320/// types.
321fn fields_types(fields: &Fields) -> impl Iterator<Item = syn::Result<FieldDistributionBinding>> {
322    fields.iter().filter_map(|field| {
323        field_distribution(field)
324            .map(|spec| {
325                spec.container.map(|container| FieldDistributionBinding {
326                    field_type: &field.ty,
327                    distribution_type: container.ty,
328                })
329            })
330            .transpose()
331    })
332}
333
334/// Given a type definition- a struct or enum- return an iterator over
335/// all the types of all the fields in that type, paired with the associated
336/// distribution types.
337fn item_subtypes(
338    input: &DeriveInput,
339) -> Box<dyn Iterator<Item = syn::Result<FieldDistributionBinding<'_>>> + '_> {
340    match &input.data {
341        syn::Data::Struct(data) => Box::new(fields_types(&data.fields)),
342        syn::Data::Enum(data) => Box::new(
343            data.variants
344                .iter()
345                .flat_map(|variant| fields_types(&variant.fields)),
346        ),
347        syn::Data::Union(_) => Box::new(iter::empty()),
348    }
349}
350
/// How the distribution instance for a field is stored between samples, as
/// selected by the optional keyword at the start of a `#[distribution]`
/// attribute.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum FieldDistributionStorage {
    /// No storage keyword was given.
    Local,
    /// The `once` keyword was given.
    Once,
    /// The `static` keyword was given.
    Static,
}
357
358impl Parse for FieldDistributionStorage {
359    fn parse(input: ParseStream) -> syn::Result<Self> {
360        use FieldDistributionStorage::*;
361
362        input.step(|cursor| match cursor.ident() {
363            Some((ident, tail)) if ident == "static" => Ok((Static, tail)),
364            Some((ident, tail)) if ident == "once" => Ok((Once, tail)),
365            _ => Ok((Local, *cursor)),
366        })
367    }
368}
369
370#[derive(Debug, Clone)]
371struct FieldDistributionContainer {
372    ty: Type,
373    storage: FieldDistributionStorage,
374}
375
376#[derive(Debug, Clone)]
377struct FieldDistributionSpec {
378    init: Expr,
379    container: Option<FieldDistributionContainer>,
380}
381
382impl Parse for FieldDistributionSpec {
383    fn parse(input: ParseStream) -> syn::Result<Self> {
384        let storage: FieldDistributionStorage = input.parse()?;
385
386        if storage == FieldDistributionStorage::Local {
387            // There was no storage specifier. Try to parse `type =`, but
388            // fall back to just an expression.
389            let input_with_type = input.fork();
390
391            if let Ok(ty) = input_with_type.parse() {
392                if let Ok(_eq) = input_with_type.parse::<Token![=]>() {
393                    // We got "type =", so proceed unconditionally this way
394                    input.advance_to(&input_with_type);
395                    let original = input.fork();
396                    let init = input.parse().map_err(|_| {
397                        Error::new(original.span(), "expected a distribution expression")
398                    })?;
399                    return Ok(FieldDistributionSpec {
400                        init,
401                        container: Some(FieldDistributionContainer { ty, storage }),
402                    });
403                }
404            }
405
406            let original = input.fork();
407
408            // Failed to parse "type =". Attempt to just parse the expression.
409            input
410                .parse()
411                .map(|init| FieldDistributionSpec {
412                    init,
413                    container: None,
414                })
415                .map_err(|_| Error::new(original.span(), "expected a distribution expression"))
416        } else {
417            // If we had a storage specifier, we now must have a type
418            let ty = input
419                .parse()
420                .map_err(|_| Error::new(input.span(), "expected a distribution type"))?;
421            let _equals: Token![=] = input.parse()?;
422            let init = input
423                .parse()
424                .map_err(|_| Error::new(input.span(), "expected a distribution expression"))?;
425            Ok(FieldDistributionSpec {
426                init,
427                container: Some(FieldDistributionContainer { ty, storage }),
428            })
429        }
430    }
431}
432
433/// Given a field, look at the #[distribution] attribute of the field to
434/// determine what distribution should be used. Returns the Standard
435/// distribution if there is no such attribute. The returned token stream
436/// should be an expression which can be passed to rng.sample.
437fn field_distribution(field: &Field) -> syn::Result<FieldDistributionSpec> {
438    match field
439        .attrs
440        .iter()
441        .find(|attr| attr.path.is_ident("distribution"))
442    {
443        None => Ok(FieldDistributionSpec {
444            init: parse_quote! {::rand::distributions::Standard},
445            container: Some(FieldDistributionContainer {
446                ty: parse_quote! {::rand::distributions::Standard},
447                storage: FieldDistributionStorage::Local,
448            }),
449        }),
450        Some(attr) => attr.parse_args(),
451    }
452}
453
454/// Given a list of fields, create a comma-separated series of initializers
455/// suited for initializing a type containing those fields. Return something
456/// resembling "field1: value1, field2: value2," for fields with names, and
457/// "value1, value2," for fields without names.
458///
459/// The initializers are specifically the invocations of
460/// `rng.sample(distribution)`.
461fn field_inits<'a>(
462    rng: &Ident,
463    fields: impl Iterator<Item = &'a Field>,
464) -> syn::Result<TokenStream2> {
465    fields
466        .map(|field| {
467            let field_type = &field.ty;
468            let distribution = field_distribution(&field)?;
469            let (dist_ty, dist_init) = match distribution.container {
470                None => (parse_quote! {_}, distribution.init),
471                Some(container) => {
472                    let ty = container.ty;
473                    let init = distribution.init;
474
475                    match container.storage {
476                        FieldDistributionStorage::Local => (ty, init),
477                        FieldDistributionStorage::Once => (
478                            parse_quote! {&'static #ty},
479                            parse_quote! {{
480                                static DISTRIBUTION: ::once_cell::sync::OnceCell<#ty> =
481                                    ::once_cell::sync::OnceCell::new();
482
483                                DISTRIBUTION.get_or_init(move || #init)
484                            }},
485                        ),
486                        FieldDistributionStorage::Static => (
487                            parse_quote! {&'static #ty},
488                            parse_quote! {{
489                                static DISTRIBUTION: #ty = #init;
490
491                                &DISTRIBUTION
492                            }},
493                        ),
494                    }
495                }
496            };
497
498            let init = quote! { ::rand::Rng::sample::<#field_type, #dist_ty>(#rng, #dist_init), };
499            Ok(match &field.ident {
500                Some(field_ident) => quote! { #field_ident: #init },
501                None => init,
502            })
503        })
504        .collect()
505}
506
507/// Create a literal expression initializing a value of the given `type`
508/// consisting of the given fields. Used to create expressions to initialize
509/// structs and enum variants.
510fn init_value_of_type(
511    type_path: TokenStream2,
512    rng: &Ident,
513    fields: &Fields,
514) -> syn::Result<TokenStream2> {
515    match fields {
516        Fields::Named(fields) => {
517            let field_inits = field_inits(rng, fields.named.iter())?;
518
519            Ok(quote! {
520                #type_path {
521                    #field_inits
522                }
523            })
524        }
525        Fields::Unnamed(fields) => {
526            let field_inits = field_inits(rng, fields.unnamed.iter())?;
527
528            Ok(quote! {
529                #type_path (
530                    #field_inits
531                )
532            })
533        }
534        Fields::Unit => Ok(type_path),
535    }
536}
537
538/// Look at the #[weight] attribute of an enum variant to determine what weight
539/// it should be given in random generation. Returns 1 if there is no such
540/// attribute, or an error if the attribute is malformed.
541fn enum_variant_weight(variant: &Variant) -> syn::Result<u64> {
542    match variant
543        .attrs
544        .iter()
545        .find(|attr| attr.path.is_ident("weight"))
546    {
547        None => Ok(1),
548        Some(attr) => attr.parse_args::<LitInt>()?.base10_parse(),
549    }
550}
551
/// Similar to `try!`, but for use inside the proc-macro entry point: unwraps
/// a `syn::Result`, and on error returns early from the enclosing function
/// with the error rendered as a compile error token stream.
macro_rules! syn_unwrap {
    ($input:expr) => {{
        // The annotation pins the error type to `syn::Error`.
        let outcome: syn::Result<_> = $input;

        match outcome {
            Err(err) => return err.into_compile_error().into(),
            Ok(value) => value,
        }
    }};
}
562
563#[proc_macro_derive(StandardDist, attributes(weight, distribution))]
564pub fn standard_dist(item: TokenStream) -> TokenStream {
565    let input: DeriveInput = match parse(item) {
566        Ok(input) => input,
567        Err(err) => return err.into_compile_error().into(),
568    };
569
570    let type_ident = &input.ident;
571    let rng = Ident::new("rng", Span::mixed_site());
572
573    let sample_body = match &input.data {
574        syn::Data::Struct(data) => syn_unwrap!(init_value_of_type(
575            type_ident.to_token_stream(),
576            &rng,
577            &data.fields
578        )),
579        syn::Data::Enum(data) => {
580            // The total weights that have been accumulated for all variants.
581            let mut cumulative_weight = Some(0u64);
582
583            // TODO: There's enough weird control flow and statefulness here
584            // that it should probably be a plain for loop. The problem,
585            // ironically, is that it's actually easier to use an iterator
586            // chain, because we can use `?`. This should all be refactored
587            // into a function returning a syn::Result.
588            let match_arms = data
589                .variants
590                .iter()
591                // For each variant, compute the weight. The weight is given
592                // via a #[weight(10)] annotation, defaulting to 1. May return
593                // an error for a malformed annotation.
594                .map(|variant| enum_variant_weight(variant).map(|weight| (variant, weight)))
595                // Skip variants with a weight of 0.
596                .filter_ok(|&(_, weight)| weight != 0)
597                // Create a match arm for each variant
598                .map(|state| {
599                    let (variant, weight) = state?;
600
601                    // Process the cumulative weights. Compute the inclusive lower
602                    // and upper bounds for this variant, and update the cumulative
603                    // weight.
604                    let lower_bound = cumulative_weight.ok_or_else(|| {
605                        Error::new(variant.span(), "enum variant weight overflow")
606                    })?;
607                    let upper_bound = lower_bound.checked_add(weight - 1).ok_or_else(|| {
608                        Error::new(variant.span(), "enum variant weight overflow")
609                    })?;
610                    cumulative_weight = upper_bound.checked_add(1);
611
612                    // Create a match arm for each variant
613                    let variant_ident = &variant.ident;
614                    let variant_path = quote! {#type_ident::#variant_ident};
615                    let gen_variant = init_value_of_type(variant_path, &rng, &variant.fields)?;
616                    let pattern = quote! {#lower_bound ..= #upper_bound};
617                    Ok(quote! {#pattern => #gen_variant,})
618                })
619                .collect();
620
621            let match_arms: TokenStream2 = syn_unwrap!(match_arms);
622
623            // In the likely event that we didn't use an entire u64's worth of
624            // weights, create a trailing catch-all arm with an `unreachable`
625            let trailing_arm = cumulative_weight.map(|cumulative_weight| {
626                quote! {
627                    n => ::std::unreachable!(
628                        "The enum {} only has {} total weight, but the rng returned {}",
629                        ::std::stringify!(#type_ident),
630                        #cumulative_weight,
631                        n
632                    ),
633                }
634            });
635
636            // Create the expression that actually produces a random integer
637            // which is used to randomly select a variant.
638            let gen_variant_selector = match cumulative_weight {
639                None => quote! { ::rand::Rng::gen(#rng) },
640                Some(0) => {
641                    return Error::new(
642                        input.span(),
643                        match data.variants.len() {
644                            0 => "cannot derive StandardDist for empty enums",
645                            _ => "must have at least one variant with a nonzero weight",
646                        },
647                    )
648                    .into_compile_error()
649                    .into()
650                }
651                Some(upper_bound) => quote! { ::rand::Rng::gen_range(#rng, 0u64..#upper_bound) },
652            };
653
654            quote! {
655                match #gen_variant_selector {
656                    #match_arms
657                    #trailing_arm
658                }
659            }
660        }
661        syn::Data::Union(..) => {
662            return Error::new(input.span(), "cannot derive `StandardDist` on a union")
663                .into_compile_error()
664                .into()
665        }
666    };
667
668    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
669
670    let where_clause = if !input.generics.params.is_empty() {
671        let type_bindings: HashSet<FieldDistributionBinding> =
672            syn_unwrap!(item_subtypes(&input).collect());
673
674        let type_bindings = type_bindings.iter().map(
675            |FieldDistributionBinding {
676                 field_type,
677                 distribution_type,
678             }| quote!( #distribution_type: ::rand::distributions::Distribution<#field_type> ),
679        );
680
681        let type_bindings = type_bindings.chain(
682            where_clause
683                .into_iter()
684                .flat_map(|clause| clause.predicates.iter().map(|pred| pred.to_token_stream())),
685        );
686
687        quote! {where #(#type_bindings),*}
688    } else {
689        quote! {#where_clause}
690    };
691
692    let distribution_impl = quote! {
693        impl #impl_generics ::rand::distributions::Distribution<#type_ident #ty_generics> for ::rand::distributions::Standard
694            #where_clause
695        {
696            fn sample<R: ::rand::Rng + ?::std::marker::Sized>(&self, #rng: &mut R) -> #type_ident #ty_generics {
697                #sample_body
698            }
699        }
700    };
701
702    distribution_impl.into()
703}