1use crate::ConsensusRelays;
16use crate::params::NetParameters;
17use bitflags::bitflags;
18use tor_netdoc::doc::netstatus::{self, MdConsensus, MdRouterStatus, NetParams};
19
20fn pick_bandwidth_fn<'a, I>(mut weights: I) -> BandwidthFn
23where
24 I: Clone + Iterator<Item = &'a netstatus::RelayWeight>,
25{
26 let has_measured = weights.clone().any(|w| w.is_measured());
27 let has_nonzero = weights.clone().any(|w| w.is_nonzero());
28 let has_nonzero_measured = weights.any(|w| w.is_measured() && w.is_nonzero());
29
30 if !has_nonzero {
31 BandwidthFn::Uniform
34 } else if !has_measured {
35 BandwidthFn::IncludeUnmeasured
38 } else if has_nonzero_measured {
39 BandwidthFn::MeasuredOnly
42 } else {
43 BandwidthFn::Uniform
47 }
48}
49
50#[derive(Copy, Clone, Debug, PartialEq, Eq)]
54enum BandwidthFn {
55 Uniform,
58 IncludeUnmeasured,
61 MeasuredOnly,
63}
64
65impl BandwidthFn {
66 fn apply(&self, w: &netstatus::RelayWeight) -> u32 {
69 use BandwidthFn::*;
70 use netstatus::RelayWeight::*;
71 match (self, w) {
72 (Uniform, _) => 1,
73 (IncludeUnmeasured, Unmeasured(u)) => *u,
74 (IncludeUnmeasured, Measured(m)) => *m,
75 (MeasuredOnly, Unmeasured(_)) => 0,
76 (MeasuredOnly, Measured(m)) => *m,
77 (_, _) => 0,
78 }
79 }
80}
81
/// The purpose for which a relay is being selected.
///
/// The role decides which of a relay's per-role multipliers is applied to
/// its bandwidth when computing a selection weight.
#[derive(Clone, Debug, Copy)]
#[non_exhaustive]
pub enum WeightRole {
    /// Selecting a relay to use as a guard (uses the guard multiplier).
    Guard,
    /// Selecting a relay to use as a middle hop (uses the middle multiplier).
    Middle,
    /// Selecting a relay to use as an exit (uses the exit multiplier).
    Exit,
    /// Selecting a relay for a directory request (uses the `Wb*`-derived
    /// directory multiplier).
    BeginDir,
    /// Selecting a relay by bandwidth alone: the multiplier is always 1.
    Unweighted,
    /// Selecting an onion-service introduction point; weighted exactly like
    /// a middle hop.
    HsIntro,
    /// Selecting an onion-service rendezvous point; weighted exactly like a
    /// middle hop.
    HsRend,
}
106
/// Per-role multipliers applied to a relay's bandwidth: one for each way the
/// relay might be used.
#[derive(Clone, Debug, Copy)]
struct RelayWeight {
    /// Multiplier when choosing the relay as a guard.
    as_guard: u32,
    /// Multiplier when choosing the relay as a middle hop (also used for
    /// onion-service introduction and rendezvous points).
    as_middle: u32,
    /// Multiplier when choosing the relay as an exit.
    as_exit: u32,
    /// Multiplier when choosing the relay for a directory request.
    as_dir: u32,
}

impl std::ops::Mul<u32> for RelayWeight {
    type Output = Self;

    /// Multiply every component by `rhs`.
    ///
    /// The multipliers come from consensus parameters that may be as large
    /// as `i32::MAX`, so the product is saturated at `u32::MAX`. Unchecked
    /// arithmetic would panic in debug builds and silently wrap in release
    /// builds.
    fn mul(self, rhs: u32) -> Self {
        RelayWeight {
            as_guard: self.as_guard.saturating_mul(rhs),
            as_middle: self.as_middle.saturating_mul(rhs),
            as_exit: self.as_exit.saturating_mul(rhs),
            as_dir: self.as_dir.saturating_mul(rhs),
        }
    }
}

impl std::ops::Div<u32> for RelayWeight {
    type Output = Self;

    /// Divide every component by `rhs`, rounding down.
    ///
    /// # Panics
    ///
    /// Panics if `rhs` is zero.
    fn div(self, rhs: u32) -> Self {
        let RelayWeight {
            as_guard,
            as_middle,
            as_exit,
            as_dir,
        } = self;
        RelayWeight {
            as_guard: as_guard / rhs,
            as_middle: as_middle / rhs,
            as_exit: as_exit / rhs,
            as_dir: as_dir / rhs,
        }
    }
}
142
143impl RelayWeight {
144 #[allow(clippy::unwrap_used)]
147 fn max_weight(&self) -> u32 {
148 [self.as_guard, self.as_middle, self.as_exit, self.as_dir]
149 .iter()
150 .max()
151 .copied()
152 .unwrap()
153 }
154 fn for_role(&self, role: WeightRole) -> u32 {
157 match role {
158 WeightRole::Guard => self.as_guard,
159 WeightRole::Middle => self.as_middle,
160 WeightRole::Exit => self.as_exit,
161 WeightRole::BeginDir => self.as_dir,
162 WeightRole::HsIntro => self.as_middle, WeightRole::HsRend => self.as_middle, WeightRole::Unweighted => 1,
165 }
166 }
167}
168
169bitflags! {
170 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
175 struct WeightKind: u8 {
176 const GUARD = 1 << 0;
178 const EXIT = 1 << 1;
180 const DIR = 1 << 2;
182 }
183}
184
185impl WeightKind {
186 fn for_rs(rs: &MdRouterStatus) -> Self {
188 let mut r = WeightKind::empty();
189 if rs.is_flagged_guard() {
190 r |= WeightKind::GUARD;
191 }
192 if rs.is_flagged_exit() {
193 r |= WeightKind::EXIT;
194 }
195 if rs.is_flagged_v2dir() {
196 r |= WeightKind::DIR;
197 }
198 r
199 }
200 fn idx(self) -> usize {
202 self.bits() as usize
203 }
204}
205
/// Precomputed weighting information derived from a consensus.
///
/// Used to turn a relay's bandwidth into a selection weight for a given role.
#[derive(Debug, Clone)]
pub(crate) struct WeightSet {
    /// How a relay's declared bandwidth is turned into a base value.
    bandwidth_fn: BandwidthFn,
    /// Number of bits each computed weight is shifted right by.
    ///
    /// It is chosen from the total bandwidth and the largest multiplier, so
    /// that their product fits in 64 bits.
    shift: u8,
    /// Per-role multipliers, indexed by `WeightKind::idx()` (Guard = bit 0,
    /// Exit = bit 1, V2Dir = bit 2).
    w: [RelayWeight; 8],
}
231
232impl WeightSet {
233 pub(crate) fn weight_rs_for_role(&self, rs: &MdRouterStatus, role: WeightRole) -> u64 {
240 self.weight_bw_for_role(WeightKind::for_rs(rs), rs.weight(), role)
241 }
242
243 fn weight_bw_for_role(
246 &self,
247 kind: WeightKind,
248 relay_weight: &netstatus::RelayWeight,
249 role: WeightRole,
250 ) -> u64 {
251 let ws = &self.w[kind.idx()];
252
253 let router_bw = self.bandwidth_fn.apply(relay_weight);
254 let router_weight = u64::from(router_bw) * u64::from(ws.for_role(role));
259 router_weight >> self.shift
260 }
261
262 pub(crate) fn from_consensus(consensus: &MdConsensus, params: &NetParameters) -> Self {
264 let bandwidth_fn = pick_bandwidth_fn(consensus.c_relays().iter().map(|rs| rs.weight()));
265 let weight_scale = params.bw_weight_scale.into();
266
267 let total_bw = consensus
268 .c_relays()
269 .iter()
270 .map(|rs| u64::from(bandwidth_fn.apply(rs.weight())))
271 .sum();
272 let p = consensus.bandwidth_weights();
273
274 Self::from_parts(bandwidth_fn, total_bw, weight_scale, p).validate(consensus)
275 }
276
277 fn from_parts(
281 bandwidth_fn: BandwidthFn,
282 total_bw: u64,
283 weight_scale: u32,
284 p: &NetParams<i32>,
285 ) -> Self {
286 #[allow(clippy::many_single_char_names)]
292 fn single(p: &NetParams<i32>, g: &str, m: &str, e: &str, d: &str) -> RelayWeight {
293 RelayWeight {
294 as_guard: w_param(p, g),
295 as_middle: w_param(p, m),
296 as_exit: w_param(p, e),
297 as_dir: w_param(p, d),
298 }
299 }
300
301 let weight_scale = weight_scale.max(1);
304
305 let w_none = single(p, "Wgm", "Wmm", "Wem", "Wbm");
311 let w_guard = single(p, "Wgg", "Wmg", "Weg", "Wbg");
312 let w_exit = single(p, "---", "Wme", "Wee", "Wbe");
313 let w_both = single(p, "Wgd", "Wmd", "Wed", "Wbd");
314
315 let w = [
318 w_none,
319 w_guard,
320 w_exit,
321 w_both,
322 (w_none * w_param(p, "Wmb")) / weight_scale,
328 (w_guard * w_param(p, "Wgb")) / weight_scale,
329 (w_exit * w_param(p, "Web")) / weight_scale,
330 (w_both * w_param(p, "Wdb")) / weight_scale,
331 ];
332
333 #[allow(clippy::unwrap_used)]
336 let w_max = w.iter().map(RelayWeight::max_weight).max().unwrap();
337
338 let shift = calculate_shift(total_bw, u64::from(w_max)) as u8;
340
341 WeightSet {
342 bandwidth_fn,
343 shift,
344 w,
345 }
346 }
347
348 fn validate(self, consensus: &MdConsensus) -> Self {
351 use WeightRole::*;
352 for role in [Guard, Middle, Exit, BeginDir, Unweighted] {
353 let _: u64 = consensus
354 .c_relays()
355 .iter()
356 .map(|rs| self.weight_rs_for_role(rs, role))
357 .fold(0_u64, |a, b| {
358 a.checked_add(b)
359 .expect("Incorrect relay weight calculation: total exceeded u64::MAX!")
360 });
361 }
362 self
363 }
364}
365
366const DFLT_WEIGHT: i32 = 1;
373
374fn w_param(p: &NetParams<i32>, kwd: &str) -> u32 {
379 if kwd == "---" {
380 0
381 } else {
382 clamp_to_pos(*p.get(kwd).unwrap_or(&DFLT_WEIGHT))
383 }
384}
385
/// Convert `inp` to a `u32`, mapping every negative value to 0.
fn clamp_to_pos(inp: i32) -> u32 {
    u32::try_from(inp).unwrap_or(0)
}
392
/// Return how many bits a value up to `a` must be shifted right so that its
/// product with a value up to `b` is guaranteed to fit in a `u64`.
fn calculate_shift(a: u64, b: u64) -> u32 {
    let product_bits = log2_upper(a) + log2_upper(b);
    if product_bits > u64::BITS {
        product_bits - u64::BITS
    } else {
        0
    }
}

/// Return the number of significant bits in `n`: 0 for 0, otherwise one
/// more than the position of the highest set bit.
fn log2_upper(n: u64) -> u32 {
    u64::BITS - n.leading_zeros()
}
407
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use netstatus::RelayWeight as RW;
    use std::net::SocketAddr;
    use std::time::{Duration, SystemTime};
    use tor_basic_utils::test_rng::testing_rng;
    use tor_netdoc::doc::netstatus::{Lifetime, MdRouterStatusBuilder};
    use tor_netdoc::types::relay_flags::{RelayFlag, RelayFlags};

    #[test]
    fn t_clamp() {
        assert_eq!(clamp_to_pos(32), 32);
        assert_eq!(clamp_to_pos(i32::MAX), i32::MAX as u32);
        assert_eq!(clamp_to_pos(0), 0);
        assert_eq!(clamp_to_pos(-1), 0);
        assert_eq!(clamp_to_pos(i32::MIN), 0);
    }

    #[test]
    fn t_log2() {
        assert_eq!(log2_upper(u64::MAX), 64);
        assert_eq!(log2_upper(0), 0);
        assert_eq!(log2_upper(1), 1);
        assert_eq!(log2_upper(63), 6);
        assert_eq!(log2_upper(64), 7);
    }

    #[test]
    fn t_calc_shift() {
        assert_eq!(calculate_shift(1 << 20, 1 << 20), 0);
        assert_eq!(calculate_shift(1 << 50, 1 << 10), 0);
        assert_eq!(calculate_shift(1 << 32, 1 << 33), 3);
        assert!(((1_u64 << 32) >> 3).checked_mul(1_u64 << 33).is_some());
        assert_eq!(calculate_shift(432 << 40, 7777 << 40), 38);
        assert!(
            ((432_u64 << 40) >> 38)
                .checked_mul(7777_u64 << 40)
                .is_some()
        );
    }

    #[test]
    fn t_pick_bwfunc() {
        let empty = [];
        assert_eq!(pick_bandwidth_fn(empty.iter()), BandwidthFn::Uniform);

        let all_zero = [RW::Unmeasured(0), RW::Measured(0), RW::Unmeasured(0)];
        assert_eq!(pick_bandwidth_fn(all_zero.iter()), BandwidthFn::Uniform);

        let all_unmeasured = [RW::Unmeasured(9), RW::Unmeasured(2222)];
        assert_eq!(
            pick_bandwidth_fn(all_unmeasured.iter()),
            BandwidthFn::IncludeUnmeasured
        );

        let some_measured = [
            RW::Unmeasured(10),
            RW::Measured(7),
            RW::Measured(4),
            RW::Unmeasured(0),
        ];
        assert_eq!(
            pick_bandwidth_fn(some_measured.iter()),
            BandwidthFn::MeasuredOnly
        );

        // Some relays are measured, but every measured value is zero: fall
        // back to uniform weighting rather than trusting unmeasured values.
        let measured_all_zero = [RW::Unmeasured(10), RW::Measured(0)];
        assert_eq!(
            pick_bandwidth_fn(measured_all_zero.iter()),
            BandwidthFn::Uniform
        );
    }

    #[test]
    fn t_apply_bwfn() {
        use BandwidthFn::*;
        use netstatus::RelayWeight::*;

        assert_eq!(Uniform.apply(&Measured(7)), 1);
        assert_eq!(Uniform.apply(&Unmeasured(0)), 1);

        assert_eq!(IncludeUnmeasured.apply(&Measured(7)), 7);
        assert_eq!(IncludeUnmeasured.apply(&Unmeasured(8)), 8);

        assert_eq!(MeasuredOnly.apply(&Measured(9)), 9);
        assert_eq!(MeasuredOnly.apply(&Unmeasured(10)), 0);
    }

    const TESTVEC_PARAMS: &str = "Wbd=0 Wbe=0 Wbg=4096 Wbm=10000 Wdb=10000 Web=10000 Wed=10000 Wee=10000 Weg=10000 Wem=10000 Wgb=10000 Wgd=0 Wgg=5904 Wgm=5904 Wmb=10000 Wmd=0 Wme=0 Wmg=4096 Wmm=10000";

    #[test]
    fn t_weightset_basic() {
        let total_bandwidth = 1_000_000_000;
        let params = TESTVEC_PARAMS.parse().unwrap();
        let ws = WeightSet::from_parts(BandwidthFn::MeasuredOnly, total_bandwidth, 10000, &params);

        assert_eq!(ws.bandwidth_fn, BandwidthFn::MeasuredOnly);
        assert_eq!(ws.shift, 0);

        assert_eq!(ws.w[0].as_guard, 5904);
        assert_eq!(ws.w[(WeightKind::GUARD.bits()) as usize].as_guard, 5904);
        assert_eq!(ws.w[(WeightKind::EXIT.bits()) as usize].as_exit, 10000);
        assert_eq!(
            ws.w[(WeightKind::EXIT | WeightKind::GUARD).bits() as usize].as_dir,
            0
        );
        assert_eq!(
            ws.w[(WeightKind::GUARD | WeightKind::DIR).bits() as usize].as_dir,
            4096
        );
        // Exit+V2Dir row: Wee scaled by Web / weight_scale = 10000 * 10000 / 10000.
        assert_eq!(
            ws.w[(WeightKind::EXIT | WeightKind::DIR).bits() as usize].as_exit,
            10000
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Unmeasured(7777),
                WeightRole::Guard
            ),
            0
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Guard
            ),
            7777 * 5904
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Middle
            ),
            7777 * 4096
        );

        // Onion-service roles are weighted exactly like middle hops.
        for role in [WeightRole::HsIntro, WeightRole::HsRend] {
            assert_eq!(
                ws.weight_bw_for_role(
                    WeightKind::GUARD | WeightKind::DIR,
                    &RW::Measured(7777),
                    role
                ),
                7777 * 4096
            );
        }

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Exit
            ),
            7777 * 10000
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::BeginDir
            ),
            7777 * 4096
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Unweighted
            ),
            7777
        );

        let rs = rs_builder()
            .set_flags(RelayFlag::Guard | RelayFlag::V2Dir)
            .weight(RW::Measured(7777))
            .build()
            .unwrap();
        assert_eq!(ws.weight_rs_for_role(&rs, WeightRole::Exit), 7777 * 10000);
        assert_eq!(
            ws.weight_rs_for_role(&rs, WeightRole::BeginDir),
            7777 * 4096
        );
        assert_eq!(ws.weight_rs_for_role(&rs, WeightRole::Unweighted), 7777);
    }

    /// Return a router status builder with every required field filled in.
    fn rs_builder() -> MdRouterStatusBuilder {
        MdConsensus::builder()
            .rs()
            .identity([9; 20].into())
            .add_or_port(SocketAddr::from(([127, 0, 0, 1], 9001)))
            .doc_digest([9; 32])
            .protos("".parse().unwrap())
            .clone()
    }

    #[test]
    fn weight_flags() {
        let rs1 = rs_builder().set_flags(RelayFlag::Exit).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::EXIT);

        let rs1 = rs_builder().set_flags(RelayFlag::Guard).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::GUARD);

        let rs1 = rs_builder().set_flags(RelayFlag::V2Dir).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::DIR);

        let rs1 = rs_builder().build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::empty());

        let rs1 = rs_builder().set_flags(RelayFlags::all()).build().unwrap();
        assert_eq!(
            WeightKind::for_rs(&rs1),
            WeightKind::EXIT | WeightKind::GUARD | WeightKind::DIR
        );
    }

    #[test]
    fn weightset_from_consensus() {
        use rand::Rng;
        let now = SystemTime::now();
        let one_hour = Duration::new(3600, 0);
        let mut rng = testing_rng();
        let mut bld = MdConsensus::builder();
        bld.consensus_method(34)
            .lifetime(Lifetime::new(now, now + one_hour, now + 2 * one_hour).unwrap())
            .weights(TESTVEC_PARAMS.parse().unwrap());

        // Unmeasured relays: ignored once any nonzero measured value exists.
        for _ in 0..10 {
            rs_builder()
                .identity(rng.random::<[u8; 20]>().into())
                .weight(RW::Unmeasured(1_000_000))
                .set_flags(RelayFlag::Guard | RelayFlag::Exit)
                .build_into(&mut bld)
                .unwrap();
        }
        // Measured relays with increasing bandwidth.
        for n in 0..30 {
            rs_builder()
                .identity(rng.random::<[u8; 20]>().into())
                .weight(RW::Measured(1_000 * n))
                .set_flags(RelayFlag::Guard | RelayFlag::Exit)
                .build_into(&mut bld)
                .unwrap();
        }

        let consensus = bld.testing_consensus().unwrap();
        let params = NetParameters::default();
        let ws = WeightSet::from_consensus(&consensus, &params);

        assert_eq!(ws.bandwidth_fn, BandwidthFn::MeasuredOnly);
        assert_eq!(ws.shift, 0);
        assert_eq!(ws.w[0].as_guard, 5904);
        assert_eq!(ws.w[5].as_guard, 5904);
        assert_eq!(ws.w[5].as_middle, 4096);
    }
}