1use crate::ConsensusRelays;
16use crate::params::NetParameters;
17use bitflags::bitflags;
18use tor_netdoc::doc::netstatus::{self, MdConsensus, MdRouterStatus, NetParams};
19
20fn pick_bandwidth_fn<'a, I>(mut weights: I) -> BandwidthFn
23where
24 I: Clone + Iterator<Item = &'a netstatus::RelayWeight>,
25{
26 let has_measured = weights.clone().any(|w| w.is_measured());
27 let has_nonzero = weights.clone().any(|w| w.is_nonzero());
28 let has_nonzero_measured = weights.any(|w| w.is_measured() && w.is_nonzero());
29
30 if !has_nonzero {
31 BandwidthFn::Uniform
34 } else if !has_measured {
35 BandwidthFn::IncludeUnmeasured
38 } else if has_nonzero_measured {
39 BandwidthFn::MeasuredOnly
42 } else {
43 BandwidthFn::Uniform
47 }
48}
49
/// Strategy for turning a relay's consensus weight into the bandwidth value
/// used when weighting it.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum BandwidthFn {
    /// Ignore the consensus weight entirely: every relay counts as 1.
    Uniform,
    /// Use the consensus weight whether it was measured or not.
    IncludeUnmeasured,
    /// Use only measured weights; unmeasured relays count as 0.
    MeasuredOnly,
}
64
65impl BandwidthFn {
66 fn apply(&self, w: &netstatus::RelayWeight) -> u32 {
69 use BandwidthFn::*;
70 use netstatus::RelayWeight::*;
71 match (self, w) {
72 (Uniform, _) => 1,
73 (IncludeUnmeasured, Unmeasured(u)) => *u,
74 (IncludeUnmeasured, Measured(m)) => *m,
75 (MeasuredOnly, Unmeasured(_)) => 0,
76 (MeasuredOnly, Measured(m)) => *m,
77 (_, _) => 0,
78 }
79 }
80}
81
/// The position a relay is being selected for.
///
/// Each role picks a different per-position weight, so the same relay can be
/// more or less likely to be chosen depending on what it will be used for.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub enum WeightRole {
    /// Selecting a relay to use as a guard.
    Guard,
    /// Selecting a relay for the middle of a circuit.
    Middle,
    /// Selecting a relay to use as an exit.
    Exit,
    /// Selecting a relay for a BEGIN_DIR directory request.
    BeginDir,
    /// Selecting a relay by bandwidth alone, with no per-position weight.
    Unweighted,
    /// Selecting a relay to use as an onion-service introduction point.
    HsIntro,
    /// Selecting a relay to use as an onion-service rendezvous point.
    HsRend,
}
106
/// Per-position multipliers applied to a relay's bandwidth.
#[derive(Clone, Debug, Copy)]
struct RelayWeight {
    /// Multiplier used when choosing the relay as a guard.
    as_guard: u32,
    /// Multiplier used when choosing the relay as a middle hop.
    as_middle: u32,
    /// Multiplier used when choosing the relay as an exit.
    as_exit: u32,
    /// Multiplier used when choosing the relay for a directory request.
    as_dir: u32,
}

impl std::ops::Mul<u32> for RelayWeight {
    type Output = Self;
    /// Multiply every component by `rhs`.
    ///
    /// Each product saturates at `u32::MAX` rather than overflowing. The
    /// factors come from consensus weight parameters, which `w_param` can
    /// return as large as `i32::MAX`. A plain `*` would therefore panic in
    /// debug builds, or silently wrap to a tiny weight in release builds.
    fn mul(self, rhs: u32) -> Self {
        RelayWeight {
            as_guard: self.as_guard.saturating_mul(rhs),
            as_middle: self.as_middle.saturating_mul(rhs),
            as_exit: self.as_exit.saturating_mul(rhs),
            as_dir: self.as_dir.saturating_mul(rhs),
        }
    }
}

impl std::ops::Div<u32> for RelayWeight {
    type Output = Self;
    /// Divide every component by `rhs`, rounding toward zero.
    ///
    /// # Panics
    ///
    /// Panics if `rhs` is zero.
    fn div(self, rhs: u32) -> Self {
        RelayWeight {
            as_guard: self.as_guard / rhs,
            as_middle: self.as_middle / rhs,
            as_exit: self.as_exit / rhs,
            as_dir: self.as_dir / rhs,
        }
    }
}
142
143impl RelayWeight {
144 #[allow(clippy::unwrap_used)]
147 fn max_weight(&self) -> u32 {
148 [self.as_guard, self.as_middle, self.as_exit, self.as_dir]
149 .iter()
150 .max()
151 .copied()
152 .unwrap()
153 }
154 fn for_role(&self, role: WeightRole) -> u32 {
157 match role {
158 WeightRole::Guard => self.as_guard,
159 WeightRole::Middle => self.as_middle,
160 WeightRole::Exit => self.as_exit,
161 WeightRole::BeginDir => self.as_dir,
162 WeightRole::HsIntro => self.as_middle, WeightRole::HsRend => self.as_middle, WeightRole::Unweighted => 1,
165 }
166 }
167}
168
169bitflags! {
170 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
175 struct WeightKind: u8 {
176 const GUARD = 1 << 0;
178 const EXIT = 1 << 1;
180 const DIR = 1 << 2;
182 }
183}
184
185impl WeightKind {
186 fn for_rs(rs: &MdRouterStatus) -> Self {
188 let mut r = WeightKind::empty();
189 if rs.is_flagged_guard() {
190 r |= WeightKind::GUARD;
191 }
192 if rs.is_flagged_exit() {
193 r |= WeightKind::EXIT;
194 }
195 if rs.is_flagged_v2dir() {
196 r |= WeightKind::DIR;
197 }
198 r
199 }
200 fn idx(self) -> usize {
202 self.bits() as usize
203 }
204}
205
206#[derive(Debug, Clone)]
209pub(crate) struct WeightSet {
210 bandwidth_fn: BandwidthFn,
216 shift: u8,
227 w: [RelayWeight; 8],
230}
231
232impl WeightSet {
233 pub(crate) fn weight_rs_for_role(&self, rs: &MdRouterStatus, role: WeightRole) -> u64 {
240 self.weight_bw_for_role(WeightKind::for_rs(rs), rs.weight(), role)
241 }
242
243 fn weight_bw_for_role(
246 &self,
247 kind: WeightKind,
248 relay_weight: &netstatus::RelayWeight,
249 role: WeightRole,
250 ) -> u64 {
251 let ws = &self.w[kind.idx()];
252
253 let router_bw = self.bandwidth_fn.apply(relay_weight);
254 let router_weight = u64::from(router_bw) * u64::from(ws.for_role(role));
259 router_weight >> self.shift
260 }
261
262 pub(crate) fn from_consensus(consensus: &MdConsensus, params: &NetParameters) -> Self {
264 let bandwidth_fn = pick_bandwidth_fn(consensus.c_relays().iter().map(|rs| rs.weight()));
265 let weight_scale = params.bw_weight_scale.into();
266
267 let total_bw = consensus
268 .c_relays()
269 .iter()
270 .map(|rs| u64::from(bandwidth_fn.apply(rs.weight())))
271 .sum();
272 let p = consensus.bandwidth_weights();
273
274 Self::from_parts(bandwidth_fn, total_bw, weight_scale, p).validate(consensus)
275 }
276
277 fn from_parts(
281 bandwidth_fn: BandwidthFn,
282 total_bw: u64,
283 weight_scale: u32,
284 p: &NetParams<i32>,
285 ) -> Self {
286 #[allow(clippy::many_single_char_names)]
292 fn single(p: &NetParams<i32>, g: &str, m: &str, e: &str, d: &str) -> RelayWeight {
293 RelayWeight {
294 as_guard: w_param(p, g),
295 as_middle: w_param(p, m),
296 as_exit: w_param(p, e),
297 as_dir: w_param(p, d),
298 }
299 }
300
301 let weight_scale = weight_scale.max(1);
304
305 let w_none = single(p, "Wgm", "Wmm", "Wem", "Wbm");
311 let w_guard = single(p, "Wgg", "Wmg", "Weg", "Wbg");
312 let w_exit = single(p, "---", "Wme", "Wee", "Wbe");
313 let w_both = single(p, "Wgd", "Wmd", "Wed", "Wbd");
314
315 let w = [
318 w_none,
319 w_guard,
320 w_exit,
321 w_both,
322 (w_none * w_param(p, "Wmb")) / weight_scale,
328 (w_guard * w_param(p, "Wgb")) / weight_scale,
329 (w_exit * w_param(p, "Web")) / weight_scale,
330 (w_both * w_param(p, "Wdb")) / weight_scale,
331 ];
332
333 #[allow(clippy::unwrap_used)]
336 let w_max = w.iter().map(RelayWeight::max_weight).max().unwrap();
337
338 let shift = calculate_shift(total_bw, u64::from(w_max)) as u8;
340
341 WeightSet {
342 bandwidth_fn,
343 shift,
344 w,
345 }
346 }
347
348 fn validate(self, consensus: &MdConsensus) -> Self {
351 use WeightRole::*;
352 for role in [Guard, Middle, Exit, BeginDir, Unweighted] {
353 let _: u64 = consensus
354 .c_relays()
355 .iter()
356 .map(|rs| self.weight_rs_for_role(rs, role))
357 .fold(0_u64, |a, b| {
358 a.checked_add(b)
359 .expect("Incorrect relay weight calculation: total exceeded u64::MAX!")
360 });
361 }
362 self
363 }
364}
365
/// Value used by `w_param` for a bandwidth-weight parameter that is absent
/// from the consensus.
const DFLT_WEIGHT: i32 = 1;
373
374fn w_param(p: &NetParams<i32>, kwd: &str) -> u32 {
379 if kwd == "---" {
380 0
381 } else {
382 clamp_to_pos(*p.get(kwd).unwrap_or(&DFLT_WEIGHT))
383 }
384}
385
/// Convert `inp` to a `u32`, mapping every negative value to 0.
fn clamp_to_pos(inp: i32) -> u32 {
    u32::try_from(inp).unwrap_or(0)
}
392
/// Return a right-shift amount `s` such that `bitlen(a) + bitlen(b) - s <= 64`.
///
/// Any product of a value below `2^bitlen(a)` and a value below `2^bitlen(b)`
/// therefore fits in a `u64` once it is shifted right by `s`.
fn calculate_shift(a: u64, b: u64) -> u32 {
    let needed_bits = log2_upper(a) + log2_upper(b);
    needed_bits.saturating_sub(u64::BITS)
}

/// Return the number of bits needed to represent `n` (0 when `n` is 0).
fn log2_upper(n: u64) -> u32 {
    u64::BITS - n.leading_zeros()
}
407
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use netstatus::RelayWeight as RW;
    use std::net::SocketAddr;
    use std::time::{Duration, SystemTime};
    use tor_basic_utils::test_rng::testing_rng;
    use tor_netdoc::doc::netstatus::{Lifetime, MdRouterStatusBuilder, RelayFlags};

    #[test]
    fn t_clamp() {
        assert_eq!(clamp_to_pos(32), 32);
        assert_eq!(clamp_to_pos(i32::MAX), i32::MAX as u32);
        assert_eq!(clamp_to_pos(0), 0);
        assert_eq!(clamp_to_pos(-1), 0);
        assert_eq!(clamp_to_pos(i32::MIN), 0);
    }

    #[test]
    fn t_log2() {
        assert_eq!(log2_upper(u64::MAX), 64);
        assert_eq!(log2_upper(0), 0);
        assert_eq!(log2_upper(1), 1);
        assert_eq!(log2_upper(63), 6);
        assert_eq!(log2_upper(64), 7);
    }

    #[test]
    fn t_calc_shift() {
        assert_eq!(calculate_shift(1 << 20, 1 << 20), 0);
        assert_eq!(calculate_shift(1 << 50, 1 << 10), 0);
        assert_eq!(calculate_shift(1 << 32, 1 << 33), 3);
        // After shifting, the product must fit in a u64.
        assert!(((1_u64 << 32) >> 3).checked_mul(1_u64 << 33).is_some());
        assert_eq!(calculate_shift(432 << 40, 7777 << 40), 38);
        assert!(
            ((432_u64 << 40) >> 38)
                .checked_mul(7777_u64 << 40)
                .is_some()
        );
    }

    #[test]
    fn t_pick_bwfunc() {
        let empty = [];
        assert_eq!(pick_bandwidth_fn(empty.iter()), BandwidthFn::Uniform);

        let all_zero = [RW::Unmeasured(0), RW::Measured(0), RW::Unmeasured(0)];
        assert_eq!(pick_bandwidth_fn(all_zero.iter()), BandwidthFn::Uniform);

        let all_unmeasured = [RW::Unmeasured(9), RW::Unmeasured(2222)];
        assert_eq!(
            pick_bandwidth_fn(all_unmeasured.iter()),
            BandwidthFn::IncludeUnmeasured
        );

        let some_measured = [
            RW::Unmeasured(10),
            RW::Measured(7),
            RW::Measured(4),
            RW::Unmeasured(0),
        ];
        assert_eq!(
            pick_bandwidth_fn(some_measured.iter()),
            BandwidthFn::MeasuredOnly
        );

        // Measured values exist but are all zero: fall back to uniform.
        let measured_all_zero = [RW::Unmeasured(10), RW::Measured(0)];
        assert_eq!(
            pick_bandwidth_fn(measured_all_zero.iter()),
            BandwidthFn::Uniform
        );
    }

    #[test]
    fn t_apply_bwfn() {
        use BandwidthFn::*;
        use netstatus::RelayWeight::*;

        assert_eq!(Uniform.apply(&Measured(7)), 1);
        assert_eq!(Uniform.apply(&Unmeasured(0)), 1);

        assert_eq!(IncludeUnmeasured.apply(&Measured(7)), 7);
        assert_eq!(IncludeUnmeasured.apply(&Unmeasured(8)), 8);

        assert_eq!(MeasuredOnly.apply(&Measured(9)), 9);
        assert_eq!(MeasuredOnly.apply(&Unmeasured(10)), 0);
    }

    /// A fixed set of bandwidth-weight parameters used by several tests.
    const TESTVEC_PARAMS: &str = "Wbd=0 Wbe=0 Wbg=4096 Wbm=10000 Wdb=10000 Web=10000 Wed=10000 Wee=10000 Weg=10000 Wem=10000 Wgb=10000 Wgd=0 Wgg=5904 Wgm=5904 Wmb=10000 Wmd=0 Wme=0 Wmg=4096 Wmm=10000";

    #[test]
    fn t_weightset_basic() {
        let total_bandwidth = 1_000_000_000;
        let params = TESTVEC_PARAMS.parse().unwrap();
        let ws = WeightSet::from_parts(BandwidthFn::MeasuredOnly, total_bandwidth, 10000, &params);

        assert_eq!(ws.bandwidth_fn, BandwidthFn::MeasuredOnly);
        // 1e9 needs 30 bits and the largest weight (10000) needs 14: no shift needed.
        assert_eq!(ws.shift, 0);

        assert_eq!(ws.w[0].as_guard, 5904);
        assert_eq!(ws.w[(WeightKind::GUARD.bits()) as usize].as_guard, 5904);
        assert_eq!(ws.w[(WeightKind::EXIT.bits()) as usize].as_exit, 10000);
        assert_eq!(
            ws.w[(WeightKind::EXIT | WeightKind::GUARD).bits() as usize].as_dir,
            0
        );
        assert_eq!(
            ws.w[(WeightKind::GUARD | WeightKind::DIR).bits() as usize].as_dir,
            4096
        );
        // Wgb == weight_scale, so the guard row is unchanged for V2Dir guards.
        assert_eq!(
            ws.w[(WeightKind::GUARD | WeightKind::DIR).bits() as usize].as_guard,
            5904
        );

        // Unmeasured bandwidth counts as zero under MeasuredOnly.
        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Unmeasured(7777),
                WeightRole::Guard
            ),
            0
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Guard
            ),
            7777 * 5904
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Middle
            ),
            7777 * 4096
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Exit
            ),
            7777 * 10000
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::BeginDir
            ),
            7777 * 4096
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Unweighted
            ),
            7777
        );

        // The same results must come out when starting from a router status.
        let rs = rs_builder()
            .set_flags(RelayFlags::GUARD | RelayFlags::V2DIR)
            .weight(RW::Measured(7777))
            .build()
            .unwrap();
        assert_eq!(ws.weight_rs_for_role(&rs, WeightRole::Exit), 7777 * 10000);
        assert_eq!(
            ws.weight_rs_for_role(&rs, WeightRole::BeginDir),
            7777 * 4096
        );
        assert_eq!(ws.weight_rs_for_role(&rs, WeightRole::Unweighted), 7777);
    }

    /// Return a router-status builder with its required fields set to fixed
    /// test values.
    fn rs_builder() -> MdRouterStatusBuilder {
        MdConsensus::builder()
            .rs()
            .identity([9; 20].into())
            .add_or_port(SocketAddr::from(([127, 0, 0, 1], 9001)))
            .doc_digest([9; 32])
            .protos("".parse().unwrap())
            .clone()
    }

    #[test]
    fn weight_flags() {
        let rs1 = rs_builder().set_flags(RelayFlags::EXIT).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::EXIT);

        let rs1 = rs_builder().set_flags(RelayFlags::GUARD).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::GUARD);

        let rs1 = rs_builder().set_flags(RelayFlags::V2DIR).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::DIR);

        let rs1 = rs_builder().build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::empty());

        let rs1 = rs_builder().set_flags(RelayFlags::all()).build().unwrap();
        assert_eq!(
            WeightKind::for_rs(&rs1),
            WeightKind::EXIT | WeightKind::GUARD | WeightKind::DIR
        );
    }

    #[test]
    fn weightset_from_consensus() {
        use rand::Rng;
        let now = SystemTime::now();
        let one_hour = Duration::new(3600, 0);
        let mut rng = testing_rng();
        let mut bld = MdConsensus::builder();
        bld.consensus_method(34)
            .lifetime(Lifetime::new(now, now + one_hour, now + 2 * one_hour).unwrap())
            .weights(TESTVEC_PARAMS.parse().unwrap());

        // Relays with large unmeasured weights; these are ignored once any
        // measured weight is nonzero.
        for _ in 0..10 {
            rs_builder()
                .identity(rng.random::<[u8; 20]>().into())
                .weight(RW::Unmeasured(1_000_000))
                .set_flags(RelayFlags::GUARD | RelayFlags::EXIT)
                .build_into(&mut bld)
                .unwrap();
        }
        // Relays with measured weights 0, 1000, ..., 29000.
        for n in 0..30 {
            rs_builder()
                .identity(rng.random::<[u8; 20]>().into())
                .weight(RW::Measured(1_000 * n))
                .set_flags(RelayFlags::GUARD | RelayFlags::EXIT)
                .build_into(&mut bld)
                .unwrap();
        }

        let consensus = bld.testing_consensus().unwrap();
        let params = NetParameters::default();
        let ws = WeightSet::from_consensus(&consensus, &params);

        assert_eq!(ws.bandwidth_fn, BandwidthFn::MeasuredOnly);
        assert_eq!(ws.shift, 0);
        assert_eq!(ws.w[0].as_guard, 5904);
        assert_eq!(ws.w[5].as_guard, 5904);
        assert_eq!(ws.w[5].as_middle, 4096);
    }
}