solana_epoch_rewards_hasher/
lib.rs1#![cfg_attr(docsrs, feature(doc_cfg))]
2use {siphasher::sip::SipHasher13, solana_address::Address, solana_hash::Hash, std::hash::Hasher};
3
4#[derive(Debug, Clone)]
5pub struct EpochRewardsHasher {
6 hasher: SipHasher13,
7 partitions: usize,
8}
9
10impl EpochRewardsHasher {
11 pub fn new(partitions: usize, seed: &Hash) -> Self {
13 let mut hasher = SipHasher13::new();
14 hasher.write(seed.as_ref());
15 Self { hasher, partitions }
16 }
17
18 pub fn hash_address_to_partition(self, address: &Address) -> usize {
20 let Self {
21 mut hasher,
22 partitions,
23 } = self;
24 hasher.write(address.as_ref());
25 let hash64 = hasher.finish();
26
27 hash_to_partition(hash64, partitions)
28 }
29}
30
/// Maps a 64-bit `hash` to a partition index in `0..partitions`.
///
/// Computes `floor(hash * partitions / 2^64)`. Each partition therefore owns a
/// contiguous run of hash values, and the run lengths differ by at most one.
/// Returns `0` when `partitions` is `0`.
// `arithmetic_side_effects` is allowed for this function. Note that none of the
// u128 operations below can actually saturate: (2^64 - 1) * (2^64 - 1) < 2^128,
// and 2^64 itself fits in a u128.
#[allow(clippy::arithmetic_side_effects)]
fn hash_to_partition(hash: u64, partitions: usize) -> usize {
    // One past the largest possible `hash`, i.e. 2^64.
    let hash_space = u128::from(u64::MAX).saturating_add(1);
    let scaled = u128::from(hash).saturating_mul(partitions as u128);
    // `scaled / 2^64` is strictly less than `partitions`, so narrowing to usize
    // cannot truncate.
    scaled.saturating_div(hash_space) as usize
}
41
#[cfg(test)]
mod tests {
    // Silences clippy's `arithmetic_side_effects` lint for this test module, which
    // uses plain `+`, `-`, `*` and `/` operators.
    #![allow(clippy::arithmetic_side_effects)]
    use {super::*, std::ops::RangeInclusive};

    #[test]
    fn test_get_equal_partition_range() {
        // Two equal partitions split the hash space as 0..=(MAX / 2) and
        // (MAX / 2 + 1)..=MAX. Both ends of each range are inclusive.
        let range = get_equal_partition_range(0, 2);
        assert_eq!(*range.start(), 0);
        assert_eq!(*range.end(), u64::MAX / 2);
        let range = get_equal_partition_range(1, 2);
        assert_eq!(*range.start(), u64::MAX / 2 + 1);
        assert_eq!(*range.end(), u64::MAX);
    }

    #[test]
    fn test_hash_to_partitions() {
        // With 16 partitions, each bucket covers 2^64 / 16 consecutive hash values.
        // u64::MAX / 16 is the last hash in bucket 0.
        let partitions = 16;
        assert_eq!(hash_to_partition(0, partitions), 0);
        assert_eq!(hash_to_partition(u64::MAX / 16, partitions), 0);
        assert_eq!(hash_to_partition(u64::MAX / 16 + 1, partitions), 1);
        assert_eq!(hash_to_partition(u64::MAX / 16 * 2, partitions), 1);
        assert_eq!(hash_to_partition(u64::MAX / 16 * 2 + 1, partitions), 1);
        assert_eq!(hash_to_partition(u64::MAX - 1, partitions), partitions - 1);
        assert_eq!(hash_to_partition(u64::MAX, partitions), partitions - 1);
    }

    /// Checks the boundaries of `partition`'s expected hash range.
    ///
    /// `hash_to_partition` must map both inclusive endpoints to `partition`. The hashes
    /// just outside the range must map to the neighbouring partitions. At the edges,
    /// the range must instead start at 0 or end at `u64::MAX`. `partition` is first
    /// clamped to `partitions - 1`.
    fn test_partitions(partition: usize, partitions: usize) {
        let partition = partition.min(partitions - 1);
        let range = get_equal_partition_range(partition, partitions);
        // Both inclusive endpoints belong to this partition.
        assert_eq!(hash_to_partition(*range.start(), partitions), partition);
        assert_eq!(hash_to_partition(*range.end(), partitions), partition);
        if partition < partitions - 1 {
            // One past the end falls into the next partition.
            assert_eq!(
                hash_to_partition(*range.end() + 1, partitions),
                partition + 1
            );
        } else {
            // The last partition extends to the top of the hash space.
            assert_eq!(*range.end(), u64::MAX);
        }
        if partition > 0 {
            // One before the start falls into the previous partition.
            assert_eq!(
                hash_to_partition(*range.start() - 1, partitions),
                partition - 1
            );
        } else {
            // The first partition starts at the bottom of the hash space.
            assert_eq!(*range.start(), 0);
        }
    }

    #[test]
    fn test_hash_to_partitions_equal_ranges() {
        // Power-of-two partition counts divide 2^64 evenly, so every range has
        // exactly the same length.
        for partitions in [2, 4, 8, 16, 4096] {
            assert_eq!(hash_to_partition(0, partitions), 0);
            for partition in [0, 1, 2, partitions - 1] {
                test_partitions(partition, partitions);
            }

            let range = get_equal_partition_range(0, partitions);
            for partition in 1..partitions {
                let this_range = get_equal_partition_range(partition, partitions);
                assert_eq!(
                    this_range.end() - this_range.start(),
                    range.end() - range.start()
                );
            }
        }
        // Counts that do not divide 2^64 produce ranges whose sizes differ by at
        // most one.
        for partitions in [3, 19, 1019, 4095] {
            for partition in [0, 1, 2, partitions - 1] {
                test_partitions(partition, partitions);
            }
            let expected_len_of_partition =
                ((u128::from(u64::MAX) + 1) / partitions as u128) as u64;
            for partition in 0..partitions {
                let this_range = get_equal_partition_range(partition, partitions);
                // `end - start` is one less than the number of hashes in the range.
                // The assert therefore accepts ranges holding either
                // floor(2^64 / partitions) hashes or one more.
                let len = this_range.end() - this_range.start();
                assert!(
                    len == expected_len_of_partition || len + 1 == expected_len_of_partition,
                    "{expected_len_of_partition}, {len}, {partition}, {partitions}",
                );
            }
        }
    }

    /// Reference helper: computes the inclusive range of `u64` hashes expected to map to
    /// `partition` out of `partitions`.
    ///
    /// It uses u128 arithmetic and does not call `hash_to_partition`, so the tests can
    /// compare the two.
    fn get_equal_partition_range(partition: usize, partitions: usize) -> RangeInclusive<u64> {
        let max_inclusive = u128::from(u64::MAX);
        let max_plus_1 = max_inclusive + 1;
        let partition = partition as u128;
        let partitions = partitions as u128;
        // Start from floor(2^64 * partition / partitions). When that division is
        // inexact, the floored value still maps to the previous partition, so step
        // forward to the first hash of this one.
        let mut start = max_plus_1 * partition / partitions;
        if partition > 0 && start * partitions / max_plus_1 == partition - 1 {
            start += 1;
        }

        // The candidate end is start + floor(2^64 / partitions) - 1. Extend it by one
        // when the following hash still belongs to this partition (uneven split).
        let mut end_inclusive = start + max_plus_1 / partitions - 1;
        if partition < partitions.saturating_sub(1) {
            let next = end_inclusive + 1;
            if next * partitions / max_plus_1 == partition {
                end_inclusive += 1;
            }
        } else {
            // The last partition always ends at u64::MAX.
            end_inclusive = max_inclusive;
        }
        RangeInclusive::new(start as u64, end_inclusive as u64)
    }

    #[test]
    fn test_hasher_copy() {
        // A cloned hasher and the original must give the same partition. Both must
        // also match a SipHasher13 that is fed the seed and then the address by hand.
        let seed = Hash::new_unique();
        let partitions = 10;
        let hasher = EpochRewardsHasher::new(partitions, &seed);

        let pk = Address::new_unique();

        let b1 = hasher.clone().hash_address_to_partition(&pk);
        let b2 = hasher.hash_address_to_partition(&pk);
        assert_eq!(b1, b2);

        // Recompute the expected partition directly, bypassing EpochRewardsHasher.
        let mut hasher = SipHasher13::new();
        hasher.write(seed.as_ref());
        hasher.write(pk.as_ref());
        let partition = hash_to_partition(hasher.finish(), partitions);
        assert_eq!(partition, b1);
    }
}