1use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15
16use parking_lot::RwLock;
17use serde::{Deserialize, Serialize};
18
19use super::throughput::{EffectiveThroughput, ThroughputConfig};
20use super::Uplink;
21
/// Strategy used to decide which uplink(s) should carry the next packet.
///
/// Serialized in `snake_case` (e.g. `weighted_round_robin`, `adaptive`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SchedulingStrategy {
    /// Rotate through uplinks, favoring those with a higher configured weight.
    WeightedRoundRobin,
    /// Always pick the sendable uplink with the lowest measured RTT.
    LowestLatency,
    /// Always pick the sendable uplink with the lowest loss ratio.
    LowestLoss,
    /// Composite score of RTT, loss, bandwidth and NAT penalty
    /// (optionally plus effective throughput). This is the default.
    #[default]
    Adaptive,
    /// Return every sendable uplink (packet is carried on all paths).
    Redundant,
    /// Up to two uplinks: the highest `priority_score` one plus one backup.
    PrimaryBackup,
    /// Random pick with probability proportional to uplink bandwidth.
    BandwidthProportional,
    /// Deterministically map a flow id onto an uplink (ECMP-style hashing)
    /// so packets of one flow stay on one path.
    EcmpAware,
    /// Pick the uplink with the best effective-throughput score.
    EffectiveThroughput,
    /// Minimize the estimated transfer time for the payload size.
    LatencyAware,
    /// Small payloads go to the lowest-latency path, large payloads to the
    /// best effective-throughput path (threshold: `size_threshold_bytes`).
    SizeBased,
}
54
/// Configuration for the [`Scheduler`].
///
/// Every field carries a serde default, so a partial (or empty) config
/// document deserializes to the same values as [`SchedulerConfig::default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerConfig {
    /// Path-selection strategy (default: `Adaptive`).
    #[serde(default)]
    pub strategy: SchedulingStrategy,

    /// RTT threshold in milliseconds (default: 10).
    /// NOTE(review): not referenced in this file — confirm usage elsewhere.
    #[serde(default = "default_rtt_threshold")]
    pub rtt_threshold_ms: u32,

    /// Loss threshold in percent (default: 2.0).
    /// NOTE(review): not referenced in this file — confirm usage elsewhere.
    #[serde(default = "default_loss_threshold")]
    pub loss_threshold_percent: f32,

    /// Weight of the RTT score in the adaptive strategy (default: 0.25).
    #[serde(default = "default_rtt_weight")]
    pub rtt_weight: f32,

    /// Weight of the loss score in the adaptive strategy (default: 0.25).
    #[serde(default = "default_loss_weight")]
    pub loss_weight: f32,

    /// Weight of the bandwidth score in the adaptive strategy (default: 0.35).
    #[serde(default = "default_bw_weight")]
    pub bandwidth_weight: f32,

    /// Weight of the NAT-penalty score in the adaptive strategy (default: 0.05).
    #[serde(default = "default_nat_weight")]
    pub nat_penalty_weight: f32,

    /// Pin a flow to the uplink it last used, while the mapping is fresh
    /// (default: true).
    #[serde(default = "default_sticky")]
    pub sticky_paths: bool,

    /// How long a sticky flow -> uplink mapping stays valid (default: 5s).
    #[serde(default = "default_sticky_timeout", with = "humantime_serde")]
    pub sticky_timeout: Duration,

    /// Periodically probe non-selected paths so their metrics stay fresh
    /// (default: true).
    #[serde(default = "default_probe")]
    pub probe_backup_paths: bool,

    /// Minimum interval between probes of the same uplink (default: 1s).
    #[serde(default = "default_probe_interval", with = "humantime_serde")]
    pub probe_interval: Duration,

    /// Parameters for the effective-throughput model.
    #[serde(default)]
    pub throughput: ThroughputConfig,

    /// Maximum acceptable latency in milliseconds (default: 500).
    /// NOTE(review): not referenced in this file — confirm usage elsewhere.
    #[serde(default = "default_max_latency")]
    pub max_acceptable_latency_ms: u32,

    /// Payload-size boundary for the `SizeBased` strategy; also the assumed
    /// size when none is given (default: 64 KiB).
    #[serde(default = "default_size_threshold")]
    pub size_threshold_bytes: u64,

    /// Mix the effective-throughput score into the adaptive score
    /// (default: true).
    #[serde(default = "default_throughput_aware")]
    pub throughput_aware: bool,

    /// Weight of the effective-throughput score in the adaptive strategy
    /// (default: 0.10).
    #[serde(default = "default_effective_throughput_weight")]
    pub effective_throughput_weight: f32,

    /// Exclude uplinks whose RTT is far worse than the best path
    /// (default: true). See `latency_blocking_ratio`.
    #[serde(default = "default_prevent_latency_blocking")]
    pub prevent_latency_blocking: bool,

    /// An uplink is considered latency-blocked when its RTT exceeds
    /// `latency_blocking_ratio` times the lowest RTT (default: 10.0).
    #[serde(default = "default_latency_blocking_ratio")]
    pub latency_blocking_ratio: f32,
}
137
// Serde default providers for `SchedulerConfig`; `SchedulerConfig::default`
// delegates to the same functions so both construction paths agree.

/// Default `rtt_threshold_ms`: 10 ms.
fn default_rtt_threshold() -> u32 {
    10
}

/// Default `loss_threshold_percent`: 2%.
fn default_loss_threshold() -> f32 {
    2.0
}

/// Default `rtt_weight` for the adaptive score.
fn default_rtt_weight() -> f32 {
    0.25
}

/// Default `loss_weight` for the adaptive score.
fn default_loss_weight() -> f32 {
    0.25
}

/// Default `bandwidth_weight` for the adaptive score.
fn default_bw_weight() -> f32 {
    0.35
}

/// Default `nat_penalty_weight` for the adaptive score.
fn default_nat_weight() -> f32 {
    0.05
}

/// Sticky paths are enabled by default.
fn default_sticky() -> bool {
    true
}

/// Default sticky-mapping lifetime: 5 seconds.
fn default_sticky_timeout() -> Duration {
    Duration::from_secs(5)
}

/// Backup-path probing is enabled by default.
fn default_probe() -> bool {
    true
}

/// Default probe interval: 1 second.
fn default_probe_interval() -> Duration {
    Duration::from_secs(1)
}

/// Default `max_acceptable_latency_ms`: 500 ms.
fn default_max_latency() -> u32 {
    500
}

/// Default `size_threshold_bytes`: 64 KiB.
fn default_size_threshold() -> u64 {
    64 * 1024
}

/// Throughput-aware adaptive scoring is enabled by default.
fn default_throughput_aware() -> bool {
    true
}

/// Default `effective_throughput_weight` for the adaptive score.
fn default_effective_throughput_weight() -> f32 {
    0.10
}

/// Latency-blocking prevention is enabled by default.
fn default_prevent_latency_blocking() -> bool {
    true
}

/// Default `latency_blocking_ratio`: 10x the best RTT.
fn default_latency_blocking_ratio() -> f32 {
    10.0
}
186
/// Mirrors the serde `default = "..."` provider functions so that a config
/// constructed in code is identical to one deserialized from an empty
/// document.
impl Default for SchedulerConfig {
    fn default() -> Self {
        Self {
            strategy: SchedulingStrategy::default(),
            rtt_threshold_ms: default_rtt_threshold(),
            loss_threshold_percent: default_loss_threshold(),
            rtt_weight: default_rtt_weight(),
            loss_weight: default_loss_weight(),
            bandwidth_weight: default_bw_weight(),
            nat_penalty_weight: default_nat_weight(),
            sticky_paths: default_sticky(),
            sticky_timeout: default_sticky_timeout(),
            probe_backup_paths: default_probe(),
            probe_interval: default_probe_interval(),
            throughput: ThroughputConfig::default(),
            max_acceptable_latency_ms: default_max_latency(),
            size_threshold_bytes: default_size_threshold(),
            throughput_aware: default_throughput_aware(),
            effective_throughput_weight: default_effective_throughput_weight(),
            prevent_latency_blocking: default_prevent_latency_blocking(),
            latency_blocking_ratio: default_latency_blocking_ratio(),
        }
    }
}
211
/// Mutable cursor for the weighted-round-robin strategy.
#[derive(Debug, Default)]
struct WrrState {
    /// Index of the uplink considered on the previous pick.
    current_index: usize,
    /// Weight threshold in effect; lowered once per full rotation so that
    /// higher-weight uplinks are eligible on more rotations.
    current_weight: u32,
}
218
/// Remembers which uplink each flow last used, so that packets belonging to
/// the same flow keep riding one path until the mapping goes stale.
#[derive(Debug)]
struct PathStickiness {
    /// flow id -> (uplink id, time the mapping was last refreshed)
    flows: HashMap<u64, (u16, Instant)>,
}

impl PathStickiness {
    /// Creates an empty stickiness table.
    fn new() -> Self {
        Self {
            flows: HashMap::new(),
        }
    }

    /// Returns the uplink pinned to `flow_id`, unless the entry is older than
    /// `timeout`. Stale entries are treated as absent (they are only removed
    /// by `cleanup`).
    fn get(&self, flow_id: u64, timeout: Duration) -> Option<u16> {
        match self.flows.get(&flow_id) {
            Some(&(uplink, last)) if last.elapsed() < timeout => Some(uplink),
            _ => None,
        }
    }

    /// Pins `flow_id` to `uplink_id`, refreshing the entry's timestamp.
    fn set(&mut self, flow_id: u64, uplink_id: u16) {
        self.flows.insert(flow_id, (uplink_id, Instant::now()));
    }

    /// Drops every mapping older than `timeout`.
    fn cleanup(&mut self, timeout: Duration) {
        self.flows
            .retain(|_, &mut (_, last)| last.elapsed() < timeout);
    }
}
251
/// Path scheduler: picks which uplink(s) should carry traffic, according to
/// the configured [`SchedulingStrategy`] and the uplinks' live metrics.
///
/// All mutable state is behind `parking_lot::RwLock`s so selection can be
/// driven through a shared reference.
pub struct Scheduler {
    /// Scheduling configuration (strategy, weights, thresholds).
    config: SchedulerConfig,
    /// Rotating cursor for the weighted-round-robin strategy.
    wrr_state: RwLock<WrrState>,
    /// Flow -> uplink pinning, used when `config.sticky_paths` is set.
    stickiness: RwLock<PathStickiness>,
    /// Per-uplink timestamp of the most recent probe.
    last_probe: RwLock<HashMap<u16, Instant>>,
    /// Per-uplink cached effective-throughput result, tagged with the time
    /// it was computed.
    throughput_cache: RwLock<HashMap<u16, (Instant, EffectiveThroughput)>>,
    /// Lifetime of a throughput cache entry.
    cache_ttl: Duration,
}
263
impl Scheduler {
    /// Creates a scheduler with the given configuration and empty runtime
    /// state (WRR cursor, stickiness table, probe times, throughput cache).
    pub fn new(config: SchedulerConfig) -> Self {
        Self {
            config,
            wrr_state: RwLock::new(WrrState::default()),
            stickiness: RwLock::new(PathStickiness::new()),
            last_probe: RwLock::new(HashMap::new()),
            throughput_cache: RwLock::new(HashMap::new()),
            // Effective throughput is recomputed at most every 100 ms per uplink.
            cache_ttl: Duration::from_millis(100),
        }
    }

    /// Returns the scheduler's configuration.
    pub fn config(&self) -> &SchedulerConfig {
        &self.config
    }

    /// Selects uplink ids for a packet of unknown size.
    /// Shorthand for [`Self::select_for_size`] with `size_bytes = None`.
    pub fn select(&self, uplinks: &[Arc<Uplink>], flow_id: Option<u64>) -> Vec<u16> {
        self.select_for_size(uplinks, flow_id, None)
    }

    /// Selects the uplink ids that should carry a packet of `size_bytes`
    /// (if known) belonging to `flow_id` (if known).
    ///
    /// Steps:
    /// 1. Keep only usable uplinks, optionally dropping latency-blocked ones.
    /// 2. If stickiness is enabled and the flow has a fresh mapping to a
    ///    still-usable uplink, reuse it unconditionally.
    /// 3. Otherwise apply the configured strategy.
    /// 4. Record the first selected uplink as the flow's sticky path.
    ///
    /// Returns an empty vector when no uplink is usable.
    pub fn select_for_size(
        &self,
        uplinks: &[Arc<Uplink>],
        flow_id: Option<u64>,
        size_bytes: Option<u64>,
    ) -> Vec<u16> {
        let usable: Vec<_> = if self.config.prevent_latency_blocking {
            self.filter_latency_blocked(uplinks)
        } else {
            uplinks.iter().filter(|u| u.is_usable()).collect()
        };

        if usable.is_empty() {
            return vec![];
        }

        if self.config.sticky_paths {
            if let Some(flow) = flow_id {
                // Read lock is released before the strategy runs.
                let sticky = self.stickiness.read().get(flow, self.config.sticky_timeout);
                if let Some(sticky_uplink) = sticky {
                    // Only honor the pinned path while it is still usable.
                    if usable.iter().any(|u| u.numeric_id() == sticky_uplink) {
                        return vec![sticky_uplink];
                    }
                }
            }
        }

        let selected = match self.config.strategy {
            SchedulingStrategy::WeightedRoundRobin => self.select_wrr(&usable),
            SchedulingStrategy::LowestLatency => Self::select_lowest_latency(&usable),
            SchedulingStrategy::LowestLoss => Self::select_lowest_loss(&usable),
            SchedulingStrategy::Adaptive => self.select_adaptive(&usable),
            SchedulingStrategy::Redundant => Self::select_redundant(&usable),
            SchedulingStrategy::PrimaryBackup => Self::select_primary_backup(&usable),
            SchedulingStrategy::BandwidthProportional => {
                self.select_bandwidth_proportional(&usable)
            }
            SchedulingStrategy::EcmpAware => Self::select_ecmp_aware(&usable, flow_id),
            SchedulingStrategy::EffectiveThroughput => self.select_effective_throughput(&usable),
            SchedulingStrategy::LatencyAware => self.select_latency_aware(&usable, size_bytes),
            SchedulingStrategy::SizeBased => self.select_size_based(&usable, size_bytes),
        };

        // Remember the primary path for this flow so later packets stick to it.
        if self.config.sticky_paths && !selected.is_empty() {
            if let Some(flow) = flow_id {
                self.stickiness.write().set(flow, selected[0]);
            }
        }

        selected
    }

    /// Returns the usable uplinks, excluding any whose RTT exceeds
    /// `latency_blocking_ratio` times the lowest RTT among them.
    ///
    /// When the lowest RTT is still zero (no measurement yet) no filtering
    /// is applied, since the ratio would be meaningless.
    fn filter_latency_blocked<'a>(&self, uplinks: &'a [Arc<Uplink>]) -> Vec<&'a Arc<Uplink>> {
        let usable: Vec<_> = uplinks.iter().filter(|u| u.is_usable()).collect();

        if usable.is_empty() {
            return usable;
        }

        let min_rtt = usable
            .iter()
            .map(|u| u.rtt())
            .min()
            .unwrap_or(Duration::ZERO);

        if min_rtt == Duration::ZERO {
            return usable;
        }

        let threshold = Duration::from_secs_f64(
            min_rtt.as_secs_f64() * self.config.latency_blocking_ratio as f64,
        );

        usable
            .into_iter()
            .filter(|u| u.rtt() <= threshold)
            .collect()
    }

    /// Interleaved weighted round-robin.
    ///
    /// Advances a shared cursor across `uplinks`; every full rotation the
    /// weight threshold is lowered (wrapping back to the maximum configured
    /// weight), so uplinks with higher weight are eligible on more rotations.
    /// Picks the first eligible uplink that `can_send()`. If a whole cycle
    /// finds none, falls back to the first uplink in the slice.
    fn select_wrr(&self, uplinks: &[&Arc<Uplink>]) -> Vec<u16> {
        if uplinks.is_empty() {
            return vec![];
        }

        let mut state = self.wrr_state.write();

        let max_weight: u32 = uplinks.iter().map(|u| u.config().weight).max().unwrap_or(1);

        loop {
            state.current_index = (state.current_index + 1) % uplinks.len();

            // A wrap back to index 0 means one full rotation is done:
            // lower the threshold, or reset it once it reaches zero.
            if state.current_index == 0 {
                if state.current_weight == 0 {
                    state.current_weight = max_weight;
                } else {
                    state.current_weight -= 1;
                }
            }

            let uplink = &uplinks[state.current_index];
            if uplink.config().weight >= state.current_weight && uplink.can_send() {
                return vec![uplink.numeric_id()];
            }

            // Threshold exhausted and we are back at the start: give up
            // rather than loop forever when nothing can send.
            if state.current_weight == 0 && state.current_index == 0 {
                break;
            }
        }

        // Last resort: first uplink, even though it failed the checks above.
        uplinks
            .first()
            .map(|u| vec![u.numeric_id()])
            .unwrap_or_default()
    }

    /// Single sendable uplink with the lowest measured RTT.
    fn select_lowest_latency(uplinks: &[&Arc<Uplink>]) -> Vec<u16> {
        uplinks
            .iter()
            .filter(|u| u.can_send())
            .min_by_key(|u| u.rtt())
            .map(|u| vec![u.numeric_id()])
            .unwrap_or_default()
    }

    /// Single sendable uplink with the lowest loss ratio.
    fn select_lowest_loss(uplinks: &[&Arc<Uplink>]) -> Vec<u16> {
        uplinks
            .iter()
            .filter(|u| u.can_send())
            .min_by(|a, b| {
                // Loss ratio is a float; ties (or NaN) compare equal.
                a.loss_ratio()
                    .partial_cmp(&b.loss_ratio())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|u| vec![u.numeric_id()])
            .unwrap_or_default()
    }

    /// Adaptive strategy: scores each sendable uplink as a weighted sum of
    /// RTT, loss, bandwidth and NAT scores — plus the effective-throughput
    /// score when `throughput_aware` is set — and returns the single best.
    fn select_adaptive(&self, uplinks: &[&Arc<Uplink>]) -> Vec<u16> {
        let mut scored: Vec<_> = uplinks
            .iter()
            .filter(|u| u.can_send())
            .map(|u| {
                let rtt_score = Self::rtt_score(u);
                let loss_score = Self::loss_score(u);
                let bw_score = Self::bandwidth_score(u, uplinks);
                let nat_score = Self::nat_score(u);

                let mut total_score = rtt_score * self.config.rtt_weight
                    + loss_score * self.config.loss_weight
                    + bw_score * self.config.bandwidth_weight
                    + nat_score * self.config.nat_penalty_weight;

                if self.config.throughput_aware {
                    let eff_throughput = self.calculate_effective_throughput(u);
                    total_score +=
                        eff_throughput.score as f32 * self.config.effective_throughput_weight;
                }

                (u.numeric_id(), total_score)
            })
            .collect();

        // Highest score first; NaN scores compare equal.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        scored.into_iter().map(|(id, _)| id).take(1).collect()
    }

    /// Single sendable uplink with the highest effective-throughput score.
    fn select_effective_throughput(&self, uplinks: &[&Arc<Uplink>]) -> Vec<u16> {
        let mut scored: Vec<_> = uplinks
            .iter()
            .filter(|u| u.can_send())
            .map(|u| {
                let throughput = self.calculate_effective_throughput(u);
                (u.numeric_id(), throughput.score)
            })
            .collect();

        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        scored.into_iter().map(|(id, _)| id).take(1).collect()
    }

    /// Single sendable uplink with the smallest estimated transfer time for
    /// `size_bytes` (falls back to `size_threshold_bytes` when unknown).
    fn select_latency_aware(&self, uplinks: &[&Arc<Uplink>], size_bytes: Option<u64>) -> Vec<u16> {
        let size = size_bytes.unwrap_or(self.config.size_threshold_bytes);

        let mut scored: Vec<_> = uplinks
            .iter()
            .filter(|u| u.can_send())
            .map(|u| {
                let throughput = self.calculate_effective_throughput(u);
                let transfer_time = throughput.transfer_time(size);
                (u.numeric_id(), transfer_time)
            })
            .collect();

        // Shortest transfer time first.
        scored.sort_by(|a, b| a.1.cmp(&b.1));

        scored.into_iter().map(|(id, _)| id).take(1).collect()
    }

    /// Small payloads (< `size_threshold_bytes`) take the lowest-latency
    /// path; larger payloads take the best effective-throughput path.
    fn select_size_based(&self, uplinks: &[&Arc<Uplink>], size_bytes: Option<u64>) -> Vec<u16> {
        let size = size_bytes.unwrap_or(self.config.size_threshold_bytes);

        if size < self.config.size_threshold_bytes {
            Self::select_lowest_latency(uplinks)
        } else {
            self.select_effective_throughput(uplinks)
        }
    }

    /// Returns the uplink's effective throughput, served from the cache when
    /// the entry is younger than `cache_ttl`, otherwise recomputed from the
    /// current bandwidth / RTT / loss / jitter and written back to the cache.
    fn calculate_effective_throughput(&self, uplink: &Uplink) -> EffectiveThroughput {
        let uplink_id = uplink.numeric_id();

        // Fast path: fresh cache entry. Scoped so the read lock is dropped
        // before recomputation.
        {
            let cache = self.throughput_cache.read();
            if let Some((cached_at, throughput)) = cache.get(&uplink_id) {
                if cached_at.elapsed() < self.cache_ttl {
                    return *throughput;
                }
            }
        }

        let metrics = uplink.quality_metrics();
        let throughput = EffectiveThroughput::calculate(
            uplink.bandwidth().bytes_per_sec,
            uplink.rtt(),
            uplink.loss_ratio(),
            metrics.jitter,
            &self.config.throughput,
        );

        {
            let mut cache = self.throughput_cache.write();
            cache.insert(uplink_id, (Instant::now(), throughput));
        }

        throughput
    }

    /// Every sendable uplink (the caller transmits on all of them).
    fn select_redundant(uplinks: &[&Arc<Uplink>]) -> Vec<u16> {
        uplinks
            .iter()
            .filter(|u| u.can_send())
            .map(|u| u.numeric_id())
            .collect()
    }

    /// Up to two sendable uplinks, highest `priority_score` first
    /// (primary plus one backup).
    fn select_primary_backup(uplinks: &[&Arc<Uplink>]) -> Vec<u16> {
        let mut sorted: Vec<_> = uplinks.iter().collect();
        sorted.sort_by_key(|u| std::cmp::Reverse(u.priority_score()));

        let mut result = Vec::new();
        for uplink in sorted {
            if uplink.can_send() {
                result.push(uplink.numeric_id());
                if result.len() >= 2 {
                    break;
                }
            }
        }
        result
    }

    /// Random pick with probability proportional to each sendable uplink's
    /// bandwidth. Falls back to WRR when total bandwidth is zero.
    ///
    /// NOTE(review): the final fallback (reached only if float rounding
    /// keeps `cumulative` below `r`) returns the first uplink without
    /// checking `can_send()` — presumably an intentional last resort.
    fn select_bandwidth_proportional(&self, uplinks: &[&Arc<Uplink>]) -> Vec<u16> {
        let total_bw: f64 = uplinks
            .iter()
            .filter(|u| u.can_send())
            .map(|u| u.bandwidth().bytes_per_sec)
            .sum();

        if total_bw == 0.0 {
            return self.select_wrr(uplinks);
        }

        // Roulette-wheel selection over the bandwidth shares.
        let r: f64 = rand::random();
        let mut cumulative = 0.0;

        for uplink in uplinks.iter().filter(|u| u.can_send()) {
            cumulative += uplink.bandwidth().bytes_per_sec / total_bw;
            if r <= cumulative {
                return vec![uplink.numeric_id()];
            }
        }

        uplinks
            .first()
            .map(|u| vec![u.numeric_id()])
            .unwrap_or_default()
    }

    /// ECMP-style selection: with a flow id, deterministically map the flow
    /// onto a sendable uplink via modulo so the flow stays on one path;
    /// without one, pick the uplink with the highest `priority_score`.
    fn select_ecmp_aware(uplinks: &[&Arc<Uplink>], flow_id: Option<u64>) -> Vec<u16> {
        let sendable: Vec<_> = uplinks.iter().filter(|u| u.can_send()).collect();

        if sendable.is_empty() {
            return vec![];
        }

        if let Some(flow) = flow_id {
            let index = (flow as usize) % sendable.len();
            return vec![sendable[index].numeric_id()];
        }

        sendable
            .iter()
            .max_by(|a, b| {
                let score_a = a.priority_score();
                let score_b = b.priority_score();
                score_a.cmp(&score_b)
            })
            .map(|u| vec| vec![u.numeric_id()])
            .unwrap_or_default()
    }

    /// RTT score in (0, 1]: 1.0 at zero RTT, 0.5 at 50 ms, decaying toward 0.
    fn rtt_score(uplink: &Uplink) -> f32 {
        let rtt = uplink.rtt().as_secs_f32() * 1000.0;
        1.0 / (1.0 + rtt / 50.0)
    }

    /// Loss score in [0, 1]: 1.0 for no loss, 0.0 at total loss.
    fn loss_score(uplink: &Uplink) -> f32 {
        let loss = uplink.loss_ratio() as f32;
        1.0 - loss.min(1.0)
    }

    /// Bandwidth score in [0, 1]: this uplink's bandwidth relative to the
    /// highest bandwidth among `all`; 0.5 when no bandwidth is known.
    fn bandwidth_score(uplink: &Uplink, all: &[&Arc<Uplink>]) -> f32 {
        let bw = uplink.bandwidth().bytes_per_sec;
        let max_bw: f64 = all
            .iter()
            .map(|u| u.bandwidth().bytes_per_sec)
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or(1.0);

        if max_bw == 0.0 {
            0.5
        } else {
            (bw / max_bw) as f32
        }
    }

    /// NAT penalty score in [0.2, 1.0]: non-NATted paths score 1.0, and
    /// more restrictive NAT types score progressively lower.
    fn nat_score(uplink: &Uplink) -> f32 {
        if !uplink.is_natted() {
            return 1.0;
        }

        match uplink.nat_type() {
            super::nat::NatType::None => 1.0,
            super::nat::NatType::Unknown => 0.7,
            super::nat::NatType::FullCone => 0.8,
            super::nat::NatType::RestrictedCone => 0.6,
            super::nat::NatType::PortRestrictedCone => 0.4,
            super::nat::NatType::Symmetric => 0.2,
        }
    }

    /// Whether `uplink` is due for a probe: probing must be enabled and the
    /// uplink must not have been probed within `probe_interval` (never-probed
    /// uplinks are always due).
    pub fn needs_probe(&self, uplink: &Uplink) -> bool {
        if !self.config.probe_backup_paths {
            return false;
        }

        let probes = self.last_probe.read();
        match probes.get(&uplink.numeric_id()) {
            Some(last) => last.elapsed() >= self.config.probe_interval,
            None => true,
        }
    }

    /// Records that `uplink_id` was probed just now.
    pub fn record_probe(&self, uplink_id: u16) {
        self.last_probe.write().insert(uplink_id, Instant::now());
    }

    /// Evicts stale runtime state: expired sticky mappings, probe records
    /// older than 10x the probe interval, and throughput cache entries older
    /// than 10x the cache TTL.
    pub fn cleanup(&self) {
        self.stickiness.write().cleanup(self.config.sticky_timeout);

        let timeout = self.config.probe_interval * 10;
        self.last_probe
            .write()
            .retain(|_, last| last.elapsed() < timeout);

        self.throughput_cache
            .write()
            .retain(|_, (cached_at, _)| cached_at.elapsed() < self.cache_ttl * 10);
    }

    /// Returns the ids of all usable uplinks currently due for a probe.
    pub fn uplinks_to_probe(&self, uplinks: &[Arc<Uplink>]) -> Vec<u16> {
        uplinks
            .iter()
            .filter(|u| u.is_usable() && self.needs_probe(u))
            .map(|u| u.numeric_id())
            .collect()
    }
}
740
// Manual Debug impl: the lock-guarded interior state is noisy and not useful
// in logs, so only the active strategy is printed (hence the allow).
#[allow(clippy::missing_fields_in_debug)]
impl std::fmt::Debug for Scheduler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Scheduler")
            .field("strategy", &self.config.strategy)
            .finish()
    }
}
750
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-configured scheduler uses the `Adaptive` strategy.
    #[test]
    fn test_scheduler_creation() {
        let scheduler = Scheduler::new(SchedulerConfig::default());
        assert_eq!(scheduler.config.strategy, SchedulingStrategy::Adaptive);
    }

    /// Selecting from an empty uplink list yields no paths.
    #[test]
    fn test_empty_uplinks() {
        let scheduler = Scheduler::new(SchedulerConfig::default());
        let result = scheduler.select(&[], None);
        assert!(result.is_empty());
    }
}