// solana_epoch_rewards_hasher/lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2use {siphasher::sip::SipHasher13, solana_address::Address, solana_hash::Hash, std::hash::Hasher};
3
4#[derive(Debug, Clone)]
5pub struct EpochRewardsHasher {
6    hasher: SipHasher13,
7    partitions: usize,
8}
9
10impl EpochRewardsHasher {
11    /// Use SipHasher13 keyed on the `seed` for calculating epoch reward partition
12    pub fn new(partitions: usize, seed: &Hash) -> Self {
13        let mut hasher = SipHasher13::new();
14        hasher.write(seed.as_ref());
15        Self { hasher, partitions }
16    }
17
18    /// Return partition index (0..partitions) by hashing `address` with the `hasher`
19    pub fn hash_address_to_partition(self, address: &Address) -> usize {
20        let Self {
21            mut hasher,
22            partitions,
23        } = self;
24        hasher.write(address.as_ref());
25        let hash64 = hasher.finish();
26
27        hash_to_partition(hash64, partitions)
28    }
29}
30
/// Map a 64-bit `hash` onto a partition index in `0..partitions`.
///
/// The hash space `[0, 2^64)` is cut into `partitions` contiguous ranges whose sizes differ by
/// at most one, and the index of the range containing `hash` is returned. The result is
/// `floor(hash * partitions / 2^64)`, i.e. `(rand_int * DESIRED_RANGE_MAX) / (RAND_MAX + 1)`.
/// This keeps the high-order bits of the hash rather than reducing it with a modulo.
///
/// If `partitions` is 0, the result is 0.
// The arithmetic is done in u128 and cannot overflow: both factors are below 2^64, so their
// product is below 2^128, and the quotient is below `partitions`, so the final cast to `usize`
// never truncates. Clippy's `arithmetic_side_effects` lint still objects, hence the allow.
#[allow(clippy::arithmetic_side_effects)]
fn hash_to_partition(hash: u64, partitions: usize) -> usize {
    // 2^64: one past the largest possible hash value.
    let hash_space = u128::from(u64::MAX).saturating_add(1);
    let scaled = u128::from(hash).saturating_mul(partitions as u128);
    scaled.saturating_div(hash_space) as usize
}
41
#[cfg(test)]
mod tests {
    #![allow(clippy::arithmetic_side_effects)]
    use {super::*, std::ops::RangeInclusive};

    #[test]
    fn test_get_equal_partition_range() {
        // show how 2 equal partition ranges are 0..=(max/2), (max/2+1)..=max
        // the inclusive is tricky to think about
        let range = get_equal_partition_range(0, 2);
        assert_eq!(*range.start(), 0);
        assert_eq!(*range.end(), u64::MAX / 2);
        let range = get_equal_partition_range(1, 2);
        assert_eq!(*range.start(), u64::MAX / 2 + 1);
        assert_eq!(*range.end(), u64::MAX);
    }

    #[test]
    fn test_hash_to_partitions() {
        // With 16 partitions, partition `k` covers hashes k * 2^60 ..= (k + 1) * 2^60 - 1.
        // Note that u64::MAX / 16 == 2^60 - 1, the last hash of partition 0.
        let partitions = 16;
        assert_eq!(hash_to_partition(0, partitions), 0);
        assert_eq!(hash_to_partition(u64::MAX / 16, partitions), 0);
        assert_eq!(hash_to_partition(u64::MAX / 16 + 1, partitions), 1);
        assert_eq!(hash_to_partition(u64::MAX / 16 * 2, partitions), 1);
        // u64::MAX / 16 * 2 + 1 == 2^61 - 1 is the last hash of partition 1 ...
        assert_eq!(hash_to_partition(u64::MAX / 16 * 2 + 1, partitions), 1);
        // ... and 2^61 is the first hash of partition 2.
        assert_eq!(hash_to_partition(u64::MAX / 16 * 2 + 2, partitions), 2);
        assert_eq!(hash_to_partition(u64::MAX - 1, partitions), partitions - 1);
        assert_eq!(hash_to_partition(u64::MAX, partitions), partitions - 1);
    }

    #[test]
    fn test_hash_to_single_partition() {
        // With a single partition every hash, including the extremes, lands in partition 0.
        for hash in [0, 1, u64::MAX / 2, u64::MAX - 1, u64::MAX] {
            assert_eq!(hash_to_partition(hash, 1), 0);
        }
    }

    /// Check that both ends of `partition`'s hash range map to `partition`, and that the hashes
    /// just outside the range map to the neighbouring partitions (or that the range touches
    /// 0 / u64::MAX for the first / last partition). `partition` is clamped to `partitions - 1`.
    fn test_partitions(partition: usize, partitions: usize) {
        let partition = partition.min(partitions - 1);
        let range = get_equal_partition_range(partition, partitions);
        // beginning and end of this partition
        assert_eq!(hash_to_partition(*range.start(), partitions), partition);
        assert_eq!(hash_to_partition(*range.end(), partitions), partition);
        if partition < partitions - 1 {
            // first index in next partition
            assert_eq!(
                hash_to_partition(*range.end() + 1, partitions),
                partition + 1
            );
        } else {
            assert_eq!(*range.end(), u64::MAX);
        }
        if partition > 0 {
            // last index in previous partition
            assert_eq!(
                hash_to_partition(*range.start() - 1, partitions),
                partition - 1
            );
        } else {
            assert_eq!(*range.start(), 0);
        }
    }

    #[test]
    fn test_hash_to_partitions_equal_ranges() {
        for partitions in [2, 4, 8, 16, 4096] {
            assert_eq!(hash_to_partition(0, partitions), 0);
            for partition in [0, 1, 2, partitions - 1] {
                test_partitions(partition, partitions);
            }

            let range = get_equal_partition_range(0, partitions);
            for partition in 1..partitions {
                let this_range = get_equal_partition_range(partition, partitions);
                assert_eq!(
                    this_range.end() - this_range.start(),
                    range.end() - range.start()
                );
            }
        }
        // verify non-evenly divisible partitions (partitions will be different sizes by at most 1 from any other partition)
        for partitions in [3, 19, 1019, 4095] {
            for partition in [0, 1, 2, partitions - 1] {
                test_partitions(partition, partitions);
            }
            let expected_len_of_partition =
                ((u128::from(u64::MAX) + 1) / partitions as u128) as u64;
            for partition in 0..partitions {
                let this_range = get_equal_partition_range(partition, partitions);
                // `len` is the inclusive span end - start, i.e. (partition size - 1). Partition
                // sizes are floor(2^64 / partitions) or that plus one, so `len` is either
                // expected_len_of_partition or one less.
                let len = this_range.end() - this_range.start();
                assert!(
                    len == expected_len_of_partition || len + 1 == expected_len_of_partition,
                    "{expected_len_of_partition}, {len}, {partition}, {partitions}",
                );
            }
        }
    }

    /// Return the inclusive hash range [start, end] covered by `partition` when the
    /// u64::MAX + 1 possible hash values are split into `partitions` contiguous ranges.
    /// All ranges have the same size when (u64::MAX + 1) is evenly divisible by `partitions`;
    /// otherwise sizes differ by at most one.
    fn get_equal_partition_range(partition: usize, partitions: usize) -> RangeInclusive<u64> {
        let max_inclusive = u128::from(u64::MAX);
        let max_plus_1 = max_inclusive + 1;
        let partition = partition as u128;
        let partitions = partitions as u128;
        let mut start = max_plus_1 * partition / partitions;
        if partition > 0 && start * partitions / max_plus_1 == partition - 1 {
            // partitions don't evenly divide and the start of this partition needs to be 1 greater
            start += 1;
        }

        let mut end_inclusive = start + max_plus_1 / partitions - 1;
        if partition < partitions.saturating_sub(1) {
            let next = end_inclusive + 1;
            if next * partitions / max_plus_1 == partition {
                // this partition is far enough into partitions such that the len of this partition is 1 larger than expected
                end_inclusive += 1;
            }
        } else {
            end_inclusive = max_inclusive;
        }
        RangeInclusive::new(start as u64, end_inclusive as u64)
    }

    /// Make sure that each time hash_address_to_partition is called, it uses the initial seed state and that clone correctly copies the initial hasher state.
    #[test]
    fn test_hasher_copy() {
        let seed = Hash::new_unique();
        let partitions = 10;
        let hasher = EpochRewardsHasher::new(partitions, &seed);

        let pk = Address::new_unique();

        let b1 = hasher.clone().hash_address_to_partition(&pk);
        let b2 = hasher.hash_address_to_partition(&pk);
        assert_eq!(b1, b2);

        // make sure b1 equals hashing the seed bytes followed by the address bytes
        let mut hasher = SipHasher13::new();
        hasher.write(seed.as_ref());
        hasher.write(pk.as_ref());
        let partition = hash_to_partition(hasher.finish(), partitions);
        assert_eq!(partition, b1);
    }
}