saorsa_node/upgrade/
rollout.rs1use sha2::{Digest, Sha256};
27use std::time::Duration;
28use tracing::debug;
29
/// Per-node configuration for a staged (time-staggered) upgrade rollout.
///
/// The node identifier itself is not stored; only its SHA-256 digest is kept,
/// and the rollout delay is derived from that digest.
#[derive(Debug, Clone)]
pub struct StagedRollout {
    /// Length of the rollout window in hours. A value of `0` disables staging.
    max_delay_hours: u64,
    /// SHA-256 digest of the node identifier supplied at construction time.
    node_id_hash: [u8; 32],
}
38
39impl StagedRollout {
40 #[must_use]
47 pub fn new(node_id: &[u8], max_delay_hours: u64) -> Self {
48 let mut hasher = Sha256::new();
49 hasher.update(node_id);
50 let hash_result = hasher.finalize();
51
52 let mut node_id_hash = [0u8; 32];
53 node_id_hash.copy_from_slice(&hash_result);
54
55 Self {
56 max_delay_hours,
57 node_id_hash,
58 }
59 }
60
61 #[must_use]
68 pub fn calculate_delay(&self) -> Duration {
69 if self.max_delay_hours == 0 {
70 return Duration::ZERO;
71 }
72
73 let hash_value = u64::from_le_bytes([
75 self.node_id_hash[0],
76 self.node_id_hash[1],
77 self.node_id_hash[2],
78 self.node_id_hash[3],
79 self.node_id_hash[4],
80 self.node_id_hash[5],
81 self.node_id_hash[6],
82 self.node_id_hash[7],
83 ]);
84
85 let max_delay_secs = self.max_delay_hours * 3600;
88
89 #[allow(clippy::cast_precision_loss)]
91 let delay_fraction = (hash_value as f64) / (u64::MAX as f64);
92
93 #[allow(
94 clippy::cast_possible_truncation,
95 clippy::cast_sign_loss,
96 clippy::cast_precision_loss
97 )]
98 let delay_secs = (delay_fraction * max_delay_secs as f64) as u64;
99
100 let delay = Duration::from_secs(delay_secs);
101
102 debug!(
103 "Calculated staged rollout delay: {}h {}m {}s",
104 delay.as_secs() / 3600,
105 (delay.as_secs() % 3600) / 60,
106 delay.as_secs() % 60
107 );
108
109 delay
110 }
111
112 #[must_use]
114 pub fn max_delay_hours(&self) -> u64 {
115 self.max_delay_hours
116 }
117
118 #[must_use]
120 pub fn is_enabled(&self) -> bool {
121 self.max_delay_hours > 0
122 }
123
124 #[must_use]
130 pub fn calculate_delay_for_version(&self, version: &semver::Version) -> Duration {
131 if self.max_delay_hours == 0 {
132 return Duration::ZERO;
133 }
134
135 let mut hasher = Sha256::new();
137 hasher.update(self.node_id_hash);
138 hasher.update(version.to_string().as_bytes());
139 let hash_result = hasher.finalize();
140
141 let hash_value = u64::from_le_bytes([
142 hash_result[0],
143 hash_result[1],
144 hash_result[2],
145 hash_result[3],
146 hash_result[4],
147 hash_result[5],
148 hash_result[6],
149 hash_result[7],
150 ]);
151
152 let max_delay_secs = self.max_delay_hours * 3600;
153
154 #[allow(clippy::cast_precision_loss)]
155 let delay_fraction = (hash_value as f64) / (u64::MAX as f64);
156
157 #[allow(
158 clippy::cast_possible_truncation,
159 clippy::cast_sign_loss,
160 clippy::cast_precision_loss
161 )]
162 let delay_secs = (delay_fraction * max_delay_secs as f64) as u64;
163
164 Duration::from_secs(delay_secs)
165 }
166}
167
#[cfg(test)]
// Panicking helpers are acceptable in tests: a panic is simply a test failure.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Length of an `hours`-long rollout window as a [`Duration`].
    fn window(hours: u64) -> Duration {
        Duration::from_secs(hours * 3600)
    }

    /// A zero-hour window disables staging and yields no delay.
    #[test]
    fn test_zero_delay_when_disabled() {
        let rollout = StagedRollout::new(b"node-1", 0);
        assert_eq!(rollout.calculate_delay(), Duration::ZERO);
        assert!(!rollout.is_enabled());
    }

    /// The delay never exceeds the configured window.
    #[test]
    fn test_delay_within_range() {
        let rollout = StagedRollout::new(b"node-1", 24);
        let delay = rollout.calculate_delay();

        assert!(delay <= window(24));
        assert!(rollout.is_enabled());
    }

    /// Two calculators built from the same inputs agree on the delay.
    #[test]
    fn test_deterministic_delay() {
        let rollout1 = StagedRollout::new(b"node-1", 24);
        let rollout2 = StagedRollout::new(b"node-1", 24);

        assert_eq!(rollout1.calculate_delay(), rollout2.calculate_delay());
    }

    /// Distinct node ids are expected to land at distinct points of the window.
    #[test]
    fn test_different_nodes_different_delays() {
        let rollout1 = StagedRollout::new(b"node-1", 24);
        let rollout2 = StagedRollout::new(b"node-2", 24);

        assert_ne!(rollout1.calculate_delay(), rollout2.calculate_delay());
    }

    /// Doubling the window doubles the delay, up to truncation to whole seconds.
    #[test]
    fn test_delay_scales_with_max_hours() {
        let node_id = b"consistent-node";
        let delay_12h = StagedRollout::new(node_id, 12).calculate_delay().as_secs();
        let delay_24h = StagedRollout::new(node_id, 24).calculate_delay().as_secs();

        // The un-truncated 24h delay is exactly twice the un-truncated 12h
        // delay, so after truncation it is either `2 * d12` or `2 * d12 + 1`.
        // This holds for every node id, including ones whose 12h delay is 0.
        let doubled = delay_12h * 2;
        assert!(
            delay_24h == doubled || delay_24h == doubled + 1,
            "24h delay ({delay_24h}s) should be twice the 12h delay ({delay_12h}s) \
             give or take one second of truncation"
        );
    }

    /// Different target versions are expected to give the same node different delays.
    #[test]
    fn test_version_specific_delays() {
        let rollout = StagedRollout::new(b"node-1", 24);
        let v1 = semver::Version::new(1, 0, 0);
        let v2 = semver::Version::new(2, 0, 0);

        let delay_v1 = rollout.calculate_delay_for_version(&v1);
        let delay_v2 = rollout.calculate_delay_for_version(&v2);

        assert_ne!(delay_v1, delay_v2);
    }

    /// A zero-hour window also disables the per-version delay.
    #[test]
    fn test_version_delay_zero_when_disabled() {
        let rollout = StagedRollout::new(b"node-1", 0);
        let version = semver::Version::new(1, 2, 3);

        assert_eq!(rollout.calculate_delay_for_version(&version), Duration::ZERO);
    }

    /// The per-version delay never exceeds the configured window.
    #[test]
    fn test_version_delay_within_range() {
        let rollout = StagedRollout::new(b"node-1", 24);
        let version = semver::Version::new(1, 2, 3);

        assert!(rollout.calculate_delay_for_version(&version) <= window(24));
    }

    /// The per-version delay is stable across calls and across equal calculators.
    #[test]
    fn test_version_delay_deterministic() {
        let version = semver::Version::new(1, 2, 3);
        let rollout1 = StagedRollout::new(b"node-1", 24);
        let rollout2 = StagedRollout::new(b"node-1", 24);

        let first = rollout1.calculate_delay_for_version(&version);
        assert_eq!(first, rollout1.calculate_delay_for_version(&version));
        assert_eq!(first, rollout2.calculate_delay_for_version(&version));
    }

    /// The getter reports the value passed to the constructor.
    #[test]
    fn test_max_delay_hours_getter() {
        let rollout = StagedRollout::new(b"node", 48);
        assert_eq!(rollout.max_delay_hours(), 48);
    }

    /// Long node ids are accepted and still produce an in-range delay.
    #[test]
    fn test_large_node_id() {
        let large_id = vec![0xABu8; 1000];
        let rollout = StagedRollout::new(&large_id, 24);
        let delay = rollout.calculate_delay();

        assert!(delay <= window(24));
    }

    /// An empty node id is accepted and still produces an in-range delay.
    #[test]
    fn test_empty_node_id() {
        let rollout = StagedRollout::new(&[], 24);
        let delay = rollout.calculate_delay();

        assert!(delay <= window(24));
    }

    /// Across many node ids the delays reach both ends of the window.
    #[test]
    fn test_delay_distribution() {
        let max_hours = 24u64;
        let max_secs = max_hours * 3600;

        let delays: Vec<u64> = (0..100)
            .map(|i| {
                let node_id = format!("node-{i}");
                StagedRollout::new(node_id.as_bytes(), max_hours)
                    .calculate_delay()
                    .as_secs()
            })
            .collect();

        let min = *delays.iter().min().unwrap();
        let max = *delays.iter().max().unwrap();

        assert!(min < max_secs / 4, "Should have some early delays");
        assert!(max > 3 * max_secs / 4, "Should have some late delays");
    }
}