standard_dist/lib.rs
1/*!
2`standard-dist` is a library for automatically deriving a `rand` standard
3distribution for your types via a derive macro.
4
5# Usage examples
6
7```
8use rand::distributions::Uniform;
9use standard_dist::StandardDist;
10
11// Select heads or tails with equal probability
12#[derive(Debug, Clone, Copy, PartialEq, Eq, StandardDist)]
13enum Coin {
14 Heads,
15 Tails,
16}
17
18// Flip 3 coins, independently
19#[derive(Debug, Clone, Copy, PartialEq, Eq, StandardDist)]
20struct Coins {
21 first: Coin,
22 second: Coin,
23 third: Coin,
24}
25
26// Use the `#[distribution]` attribute to customize the distribution used on
27// a field
28#[derive(Debug, Clone, Copy, PartialEq, Eq, StandardDist)]
29struct Die {
30 #[distribution(Uniform::from(1..=6))]
31 value: u8
32}
33
34// Use the `#[weight]` attribute to customize the relative probabilities of
35// enum variants
36#[derive(Debug, Clone, Copy, PartialEq, Eq, StandardDist)]
37enum D20 {
38 #[weight(18)]
39 Normal,
40
41 Critical,
42 CriticalFail,
43}
44```
45
46[`rand`] generates typed random values via the [`Distribution`] trait, which
47uses a [source of randomness] to produce values of the given type. Of particular
48note is the [`Standard`] distribution, which is the stateless "default" way to
49produce random values of a particular type. For instance:
50- For ints, this randomly chooses from all possible values for that int type
51- For bools, it chooses true or false with 50/50 probability
52- For `Option<T>`, it chooses `None` or `Some` with 50/50 probability, and uses
53 [`Standard`] to randomly populate the inner `Some` value.
54
55# Structs
56
57When you derive `StandardDist` for one of your own structs, it creates an
58`impl Distribution<YourStruct> for Standard` implementation, allowing you to
59create randomized instances of the struct via [`Rng::gen`]. This implementation
60will in turn use the `Standard` distribution to populate all the fields of
61your type.
62
63```rust
64use standard_dist::StandardDist;
65
66#[derive(StandardDist)]
67struct SimpleStruct {
68 coin: bool,
69 percent: f64,
70}
71
72let mut heads = 0;
73
74for _ in 0..2000 {
75 let s: SimpleStruct = rand::random();
76 assert!(0.0 <= s.percent);
77 assert!(s.percent < 1.0);
78 if s.coin {
79 heads += 1;
80 }
81}
82
83assert!(900 < heads, "heads: {}", heads);
84assert!(heads < 1100, "heads: {}", heads);
85```
86
87## Custom Distributions
88
89You can customize the distribution used for any field with the `#[distribution]`
90attribute:
91
92```rust
93use std::collections::HashMap;
94use standard_dist::StandardDist;
95use rand::distributions::Uniform;
96
97#[derive(StandardDist)]
98struct Die {
99 #[distribution(Uniform::from(1..=6))]
100 value: u8
101}
102
103let mut counter: HashMap<u8, u32> = HashMap::new();
104
105for _ in 0..6000 {
106 let die: Die = rand::random();
107 *counter.entry(die.value).or_insert(0) += 1;
108}
109
110assert_eq!(counter.len(), 6);
111
112for i in 1..=6 {
113 let count = counter[&i];
114 assert!(900 < count, "{}: {}", i, count);
115 assert!(count < 1100, "{}: {}", i, count);
116}
117```
118
119# Enums
120
121When applied to an enum type, the implementation will randomly select a variant
122(where each variant has an equal probability) and then populate all the fields
123of that variant in the same manner as with a struct. Enum variant fields may
124have custom distributions applied via `#[distribution]`, just like struct
125fields.
126
127```rust
128use standard_dist::StandardDist;
129
130#[derive(PartialEq, Eq, StandardDist)]
131enum Coin {
132 Heads,
133 Tails,
134}
135
136let mut heads = 0;
137
138for _ in 0..2000 {
139 let coin: Coin = rand::random();
140 if coin == Coin::Heads {
141 heads += 1;
142 }
143}
144
145assert!(900 < heads, "heads: {}", heads);
146assert!(heads < 1100, "heads: {}", heads);
147```
148
149## Weights
150
151Enum variants may be weighted with the `#[weight]` attribute to make them
152relatively more or less likely to be randomly selected. A weight of 0 means
153that the variant will never be selected. Any untagged variants will have a
154weight of 1.
155
156```rust
157use standard_dist::StandardDist;
158
159#[derive(StandardDist)]
160enum D20 {
161 #[weight(18)]
162 Normal,
163
164 CriticalHit,
165 CriticalMiss,
166}
167
168let mut crits = 0;
169
170for _ in 0..20000 {
171 let roll: D20 = rand::random();
172 if matches!(roll, D20::CriticalHit) {
173 crits += 1;
174 }
175}
176
177assert!(900 < crits, "crits: {}", crits);
178assert!(crits < 1100, "crits: {}", crits);
179```
180
181# Advanced custom distributions
182
183## Distribution types
184
185You may optionally explicitly specify a type for your distributions; this can
186sometimes be necessary when using generic types.
187
188```rust
189use std::collections::HashMap;
190use standard_dist::StandardDist;
191use rand::distributions::Uniform;
192
193#[derive(StandardDist)]
194struct Die {
195 #[distribution(Uniform<u8> = Uniform::from(1..=6))]
196 value: u8
197}
198
199let mut counter: HashMap<u8, u32> = HashMap::new();
200
201for _ in 0..6000 {
202 let die: Die = rand::random();
203 *counter.entry(die.value).or_insert(0) += 1;
204}
205
206assert_eq!(counter.len(), 6);
207
208for i in 1..=6 {
209 let count = counter[&i];
210 assert!(900 < count, "{}: {}", i, count);
211 assert!(count < 1100, "{}: {}", i, count);
212}
213```
214
215## Distribution caching
216
217In some cases, you may wish to cache a `Distribution` instance for reuse. Many
218distributions perform some initial calculations when constructed, and it can
219help performance to reuse existing distributions rather than recreate them
220every time a value is generated. `standard-dist` provides two ways to cache
221distributions: `static` and `once`. A `static` distribution is stored as a
222global static variable; this is the preferable option, but it requires the
223initializer to be usable in a `const` context. A `once` distribution is stored
224in a `once_cell::sync::OnceCell`; it is initialized the first time it's used,
225and then reused on subsequent invocations.
226
227In either case, a cache policy is specified by prefixing the type with `once` or
228`static`. The type must be specified in order to use a cache policy.
229
230```rust
231use std::collections::HashMap;
232use std::time::{Instant, Duration};
233use standard_dist::StandardDist;
234use rand::prelude::*;
235use rand::distributions::Uniform;
236
237#[derive(StandardDist)]
238struct Die {
239 #[distribution(Uniform::from(1..=6))]
240 value: u8
241}
242
243#[derive(StandardDist)]
244struct CachedDie {
245 #[distribution(once Uniform<u8> = Uniform::from(1..=6))]
246 value: u8
247}
248
249fn timed<T>(task: impl FnOnce() -> T) -> (T, Duration) {
250 let start = Instant::now();
251 (task(), start.elapsed())
252}
253
254// Count the 6s
255let mut rng = StdRng::from_entropy();
256
257let (count, plain_die_duration) = timed(|| (0..600000)
258 .map(|_| rng.gen())
259 .filter(|&Die{ value }| value == 6)
260 .count()
261);
262
263assert!(90000 < count);
264assert!(count < 110000);
265
266let (count, cache_die_duration) = timed(|| (0..600000)
267 .map(|_| rng.gen())
268 .filter(|&CachedDie{ value }| value == 6)
269 .count()
270);
271
272assert!(90000 < count);
273assert!(count < 110000);
274
275assert!(
276 cache_die_duration < plain_die_duration,
277 "cache: {:?}, plain: {:?}",
278 cache_die_duration,
279 plain_die_duration,
280);
281```
282
Note that, unless you're generating a huge quantity of random objects, using
`once` is likely a pessimization because of the upfront cost of initializing
the cell. Make sure to benchmark your specific use case if performance is a
concern.
287
288
289[`rand`]: https://docs.rs/rand/
290[`Distribution`]: https://docs.rs/rand/latest/rand/distributions/trait.Distribution.html
291[`Standard`]: https://docs.rs/rand/latest/rand/distributions/struct.Standard.html
292[source of randomness]: https://docs.rs/rand/latest/rand/trait.Rng.html
293[`Rng::gen`]: https://docs.rs/rand/latest/rand/trait.Rng.html#method.gen
294*/
295use std::{collections::HashSet, iter};
296
297use itertools::Itertools;
298use parse::ParseStream;
299use proc_macro::TokenStream;
300use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
301use quote::{quote, ToTokens};
302use syn::{
303 parse,
304 parse::{discouraged::Speculative, Parse},
305 parse_quote,
306 spanned::Spanned,
307 DeriveInput, Error, Expr, Field, Fields, LitInt, Token, Type, Variant,
308};
309
/// A particular field type, paired with the type of the distribution used
/// to produce it. Used to create `where` bindings of the form
/// `distribution_type: Distribution<field_type>`.
///
/// Derives `Eq` and `Hash` so that duplicate bindings can be detected with a
/// `HashSet`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct FieldDistributionBinding<'a> {
    /// The declared type of the field, borrowed from the parsed input.
    field_type: &'a Type,
    /// The type of the distribution that is sampled to produce the field.
    distribution_type: Type,
}
317
318/// Given a list of fields (as from a struct or enum variant), return a list
319/// of all the types of those fields, paired with the associated distribution
320/// types.
321fn fields_types(fields: &Fields) -> impl Iterator<Item = syn::Result<FieldDistributionBinding>> {
322 fields.iter().filter_map(|field| {
323 field_distribution(field)
324 .map(|spec| {
325 spec.container.map(|container| FieldDistributionBinding {
326 field_type: &field.ty,
327 distribution_type: container.ty,
328 })
329 })
330 .transpose()
331 })
332}
333
334/// Given a type definition- a struct or enum- return an iterator over
335/// all the types of all the fields in that type, paired with the associated
336/// distribution types.
337fn item_subtypes(
338 input: &DeriveInput,
339) -> Box<dyn Iterator<Item = syn::Result<FieldDistributionBinding<'_>>> + '_> {
340 match &input.data {
341 syn::Data::Struct(data) => Box::new(fields_types(&data.fields)),
342 syn::Data::Enum(data) => Box::new(
343 data.variants
344 .iter()
345 .flat_map(|variant| fields_types(&variant.fields)),
346 ),
347 syn::Data::Union(_) => Box::new(iter::empty()),
348 }
349}
350
/// The cache policy for a field's distribution, i.e. where the distribution
/// value lives in the generated code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FieldDistributionStorage {
    /// No caching: the initializer expression is passed directly to
    /// `Rng::sample` each time a value is generated.
    Local,
    /// The distribution is stored in a `static` `once_cell::sync::OnceCell`
    /// and created by the initializer the first time it is needed.
    Once,
    /// The distribution is stored directly in a `static` item, initialized
    /// with the initializer expression.
    Static,
}
357
358impl Parse for FieldDistributionStorage {
359 fn parse(input: ParseStream) -> syn::Result<Self> {
360 use FieldDistributionStorage::*;
361
362 input.step(|cursor| match cursor.ident() {
363 Some((ident, tail)) if ident == "static" => Ok((Static, tail)),
364 Some((ident, tail)) if ident == "once" => Ok((Once, tail)),
365 _ => Ok((Local, *cursor)),
366 })
367 }
368}
369
/// The explicitly known type of a field's distribution, together with the
/// cache policy under which the distribution value is stored.
#[derive(Debug, Clone)]
struct FieldDistributionContainer {
    /// The type of the distribution.
    ty: Type,
    /// Where the distribution value lives in the generated code.
    storage: FieldDistributionStorage,
}
375
/// Everything known about the distribution used to populate one field, as
/// parsed from the arguments of a `#[distribution(...)]` attribute (or
/// defaulted when the attribute is absent).
#[derive(Debug, Clone)]
struct FieldDistributionSpec {
    /// The expression that constructs the distribution.
    init: Expr,
    /// The distribution's type and cache policy. `None` when the attribute
    /// supplied only an expression, in which case the distribution type is
    /// left to inference in the generated code.
    container: Option<FieldDistributionContainer>,
}
381
382impl Parse for FieldDistributionSpec {
383 fn parse(input: ParseStream) -> syn::Result<Self> {
384 let storage: FieldDistributionStorage = input.parse()?;
385
386 if storage == FieldDistributionStorage::Local {
387 // There was no storage specifier. Try to parse `type =`, but
388 // fall back to just an expression.
389 let input_with_type = input.fork();
390
391 if let Ok(ty) = input_with_type.parse() {
392 if let Ok(_eq) = input_with_type.parse::<Token![=]>() {
393 // We got "type =", so proceed unconditionally this way
394 input.advance_to(&input_with_type);
395 let original = input.fork();
396 let init = input.parse().map_err(|_| {
397 Error::new(original.span(), "expected a distribution expression")
398 })?;
399 return Ok(FieldDistributionSpec {
400 init,
401 container: Some(FieldDistributionContainer { ty, storage }),
402 });
403 }
404 }
405
406 let original = input.fork();
407
408 // Failed to parse "type =". Attempt to just parse the expression.
409 input
410 .parse()
411 .map(|init| FieldDistributionSpec {
412 init,
413 container: None,
414 })
415 .map_err(|_| Error::new(original.span(), "expected a distribution expression"))
416 } else {
417 // If we had a storage specifier, we now must have a type
418 let ty = input
419 .parse()
420 .map_err(|_| Error::new(input.span(), "expected a distribution type"))?;
421 let _equals: Token![=] = input.parse()?;
422 let init = input
423 .parse()
424 .map_err(|_| Error::new(input.span(), "expected a distribution expression"))?;
425 Ok(FieldDistributionSpec {
426 init,
427 container: Some(FieldDistributionContainer { ty, storage }),
428 })
429 }
430 }
431}
432
433/// Given a field, look at the #[distribution] attribute of the field to
434/// determine what distribution should be used. Returns the Standard
435/// distribution if there is no such attribute. The returned token stream
436/// should be an expression which can be passed to rng.sample.
437fn field_distribution(field: &Field) -> syn::Result<FieldDistributionSpec> {
438 match field
439 .attrs
440 .iter()
441 .find(|attr| attr.path.is_ident("distribution"))
442 {
443 None => Ok(FieldDistributionSpec {
444 init: parse_quote! {::rand::distributions::Standard},
445 container: Some(FieldDistributionContainer {
446 ty: parse_quote! {::rand::distributions::Standard},
447 storage: FieldDistributionStorage::Local,
448 }),
449 }),
450 Some(attr) => attr.parse_args(),
451 }
452}
453
454/// Given a list of fields, create a comma-separated series of initializers
455/// suited for initializing a type containing those fields. Return something
456/// resembling "field1: value1, field2: value2," for fields with names, and
457/// "value1, value2," for fields without names.
458///
459/// The initializers are specifically the invocations of
460/// `rng.sample(distribution)`.
461fn field_inits<'a>(
462 rng: &Ident,
463 fields: impl Iterator<Item = &'a Field>,
464) -> syn::Result<TokenStream2> {
465 fields
466 .map(|field| {
467 let field_type = &field.ty;
468 let distribution = field_distribution(&field)?;
469 let (dist_ty, dist_init) = match distribution.container {
470 None => (parse_quote! {_}, distribution.init),
471 Some(container) => {
472 let ty = container.ty;
473 let init = distribution.init;
474
475 match container.storage {
476 FieldDistributionStorage::Local => (ty, init),
477 FieldDistributionStorage::Once => (
478 parse_quote! {&'static #ty},
479 parse_quote! {{
480 static DISTRIBUTION: ::once_cell::sync::OnceCell<#ty> =
481 ::once_cell::sync::OnceCell::new();
482
483 DISTRIBUTION.get_or_init(move || #init)
484 }},
485 ),
486 FieldDistributionStorage::Static => (
487 parse_quote! {&'static #ty},
488 parse_quote! {{
489 static DISTRIBUTION: #ty = #init;
490
491 &DISTRIBUTION
492 }},
493 ),
494 }
495 }
496 };
497
498 let init = quote! { ::rand::Rng::sample::<#field_type, #dist_ty>(#rng, #dist_init), };
499 Ok(match &field.ident {
500 Some(field_ident) => quote! { #field_ident: #init },
501 None => init,
502 })
503 })
504 .collect()
505}
506
507/// Create a literal expression initializing a value of the given `type`
508/// consisting of the given fields. Used to create expressions to initialize
509/// structs and enum variants.
510fn init_value_of_type(
511 type_path: TokenStream2,
512 rng: &Ident,
513 fields: &Fields,
514) -> syn::Result<TokenStream2> {
515 match fields {
516 Fields::Named(fields) => {
517 let field_inits = field_inits(rng, fields.named.iter())?;
518
519 Ok(quote! {
520 #type_path {
521 #field_inits
522 }
523 })
524 }
525 Fields::Unnamed(fields) => {
526 let field_inits = field_inits(rng, fields.unnamed.iter())?;
527
528 Ok(quote! {
529 #type_path (
530 #field_inits
531 )
532 })
533 }
534 Fields::Unit => Ok(type_path),
535 }
536}
537
538/// Look at the #[weight] attribute of an enum variant to determine what weight
539/// it should be given in random generation. Returns 1 if there is no such
540/// attribute, or an error if the attribute is malformed.
541fn enum_variant_weight(variant: &Variant) -> syn::Result<u64> {
542 match variant
543 .attrs
544 .iter()
545 .find(|attr| attr.path.is_ident("weight"))
546 {
547 None => Ok(1),
548 Some(attr) => attr.parse_args::<LitInt>()?.base10_parse(),
549 }
550}
551
/// Unwrap a `syn::Result` inside a function that returns a token stream.
///
/// Evaluates to the `Ok` value. On `Err`, the error is turned into a
/// `compile_error!` invocation which is immediately returned from the
/// enclosing function, much like the old `try!` macro.
macro_rules! syn_unwrap {
    ($input:expr) => {
        match ($input) {
            Ok(value) => value,
            Err(err) => {
                // Pin the error type so inference never has to guess it.
                let err: syn::Error = err;
                return err.into_compile_error().into();
            }
        }
    };
}
562
563#[proc_macro_derive(StandardDist, attributes(weight, distribution))]
564pub fn standard_dist(item: TokenStream) -> TokenStream {
565 let input: DeriveInput = match parse(item) {
566 Ok(input) => input,
567 Err(err) => return err.into_compile_error().into(),
568 };
569
570 let type_ident = &input.ident;
571 let rng = Ident::new("rng", Span::mixed_site());
572
573 let sample_body = match &input.data {
574 syn::Data::Struct(data) => syn_unwrap!(init_value_of_type(
575 type_ident.to_token_stream(),
576 &rng,
577 &data.fields
578 )),
579 syn::Data::Enum(data) => {
580 // The total weights that have been accumulated for all variants.
581 let mut cumulative_weight = Some(0u64);
582
583 // TODO: There's enough weird control flow and statefulness here
584 // that it should probably be a plain for loop. The problem,
585 // ironically, is that it's actually easier to use an iterator
586 // chain, because we can use `?`. This should all be refactored
587 // into a function returning a syn::Result.
588 let match_arms = data
589 .variants
590 .iter()
591 // For each variant, compute the weight. The weight is given
592 // via a #[weight(10)] annotation, defaulting to 1. May return
593 // an error for a malformed annotation.
594 .map(|variant| enum_variant_weight(variant).map(|weight| (variant, weight)))
595 // Skip variants with a weight of 0.
596 .filter_ok(|&(_, weight)| weight != 0)
597 // Create a match arm for each variant
598 .map(|state| {
599 let (variant, weight) = state?;
600
601 // Process the cumulative weights. Compute the inclusive lower
602 // and upper bounds for this variant, and update the cumulative
603 // weight.
604 let lower_bound = cumulative_weight.ok_or_else(|| {
605 Error::new(variant.span(), "enum variant weight overflow")
606 })?;
607 let upper_bound = lower_bound.checked_add(weight - 1).ok_or_else(|| {
608 Error::new(variant.span(), "enum variant weight overflow")
609 })?;
610 cumulative_weight = upper_bound.checked_add(1);
611
612 // Create a match arm for each variant
613 let variant_ident = &variant.ident;
614 let variant_path = quote! {#type_ident::#variant_ident};
615 let gen_variant = init_value_of_type(variant_path, &rng, &variant.fields)?;
616 let pattern = quote! {#lower_bound ..= #upper_bound};
617 Ok(quote! {#pattern => #gen_variant,})
618 })
619 .collect();
620
621 let match_arms: TokenStream2 = syn_unwrap!(match_arms);
622
623 // In the likely event that we didn't use an entire u64's worth of
624 // weights, create a trailing catch-all arm with an `unreachable`
625 let trailing_arm = cumulative_weight.map(|cumulative_weight| {
626 quote! {
627 n => ::std::unreachable!(
628 "The enum {} only has {} total weight, but the rng returned {}",
629 ::std::stringify!(#type_ident),
630 #cumulative_weight,
631 n
632 ),
633 }
634 });
635
636 // Create the expression that actually produces a random integer
637 // which is used to randomly select a variant.
638 let gen_variant_selector = match cumulative_weight {
639 None => quote! { ::rand::Rng::gen(#rng) },
640 Some(0) => {
641 return Error::new(
642 input.span(),
643 match data.variants.len() {
644 0 => "cannot derive StandardDist for empty enums",
645 _ => "must have at least one variant with a nonzero weight",
646 },
647 )
648 .into_compile_error()
649 .into()
650 }
651 Some(upper_bound) => quote! { ::rand::Rng::gen_range(#rng, 0u64..#upper_bound) },
652 };
653
654 quote! {
655 match #gen_variant_selector {
656 #match_arms
657 #trailing_arm
658 }
659 }
660 }
661 syn::Data::Union(..) => {
662 return Error::new(input.span(), "cannot derive `StandardDist` on a union")
663 .into_compile_error()
664 .into()
665 }
666 };
667
668 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
669
670 let where_clause = if !input.generics.params.is_empty() {
671 let type_bindings: HashSet<FieldDistributionBinding> =
672 syn_unwrap!(item_subtypes(&input).collect());
673
674 let type_bindings = type_bindings.iter().map(
675 |FieldDistributionBinding {
676 field_type,
677 distribution_type,
678 }| quote!( #distribution_type: ::rand::distributions::Distribution<#field_type> ),
679 );
680
681 let type_bindings = type_bindings.chain(
682 where_clause
683 .into_iter()
684 .flat_map(|clause| clause.predicates.iter().map(|pred| pred.to_token_stream())),
685 );
686
687 quote! {where #(#type_bindings),*}
688 } else {
689 quote! {#where_clause}
690 };
691
692 let distribution_impl = quote! {
693 impl #impl_generics ::rand::distributions::Distribution<#type_ident #ty_generics> for ::rand::distributions::Standard
694 #where_clause
695 {
696 fn sample<R: ::rand::Rng + ?::std::marker::Sized>(&self, #rng: &mut R) -> #type_ident #ty_generics {
697 #sample_body
698 }
699 }
700 };
701
702 distribution_impl.into()
703}