stormdl_segment/
multi_source.rs1use parking_lot::RwLock;
2use std::collections::HashMap;
3use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
4use stormdl_core::{ByteRange, MirrorSet, MirrorStats};
5
6pub struct MultiSourceManager {
7 mirrors: RwLock<MirrorSet>,
8 segment_assignments: RwLock<HashMap<usize, usize>>,
9 source_stats: RwLock<HashMap<usize, SourceStats>>,
10 #[allow(dead_code)]
11 total_size: u64,
12}
13
14struct SourceStats {
15 bytes_downloaded: AtomicU64,
16 errors: AtomicUsize,
17 active_segments: AtomicUsize,
18 speed_samples: RwLock<Vec<f64>>,
19}
20
21impl SourceStats {
22 fn new() -> Self {
23 Self {
24 bytes_downloaded: AtomicU64::new(0),
25 errors: AtomicUsize::new(0),
26 active_segments: AtomicUsize::new(0),
27 speed_samples: RwLock::new(Vec::with_capacity(10)),
28 }
29 }
30
31 fn avg_speed(&self) -> f64 {
32 let samples = self.speed_samples.read();
33 if samples.is_empty() {
34 return 0.0;
35 }
36 samples.iter().sum::<f64>() / samples.len() as f64
37 }
38}
39
40impl MultiSourceManager {
41 pub fn new(mirrors: MirrorSet, total_size: u64) -> Self {
42 Self {
43 mirrors: RwLock::new(mirrors),
44 segment_assignments: RwLock::new(HashMap::new()),
45 source_stats: RwLock::new(HashMap::new()),
46 total_size,
47 }
48 }
49
50 pub fn assign_segment(&self, segment_idx: usize, _range: ByteRange) -> usize {
51 let mirrors = self.mirrors.read();
52 let source_idx = mirrors.select_for_segment(segment_idx);
53 drop(mirrors);
54
55 self.segment_assignments
56 .write()
57 .insert(segment_idx, source_idx);
58
59 let mut stats = self.source_stats.write();
60 stats
61 .entry(source_idx)
62 .or_insert_with(SourceStats::new)
63 .active_segments
64 .fetch_add(1, Ordering::Relaxed);
65
66 source_idx
67 }
68
69 pub fn get_assignment(&self, segment_idx: usize) -> Option<usize> {
70 self.segment_assignments.read().get(&segment_idx).copied()
71 }
72
73 pub fn reassign_segment(&self, segment_idx: usize) -> Option<usize> {
74 let old_source = {
75 let assignments = self.segment_assignments.read();
76 assignments.get(&segment_idx).copied()
77 };
78
79 if let Some(old_idx) = old_source {
80 let stats = self.source_stats.read();
81 if let Some(source_stats) = stats.get(&old_idx) {
82 source_stats.active_segments.fetch_sub(1, Ordering::Relaxed);
83 }
84 }
85
86 let mirrors = self.mirrors.read();
87 let mirror_count = mirrors.len();
88
89 if mirror_count <= 1 {
90 return None;
91 }
92
93 let excluded = old_source.unwrap_or(usize::MAX);
94 let mut best_idx = None;
95 let mut best_score = f64::NEG_INFINITY;
96
97 for idx in 0..mirror_count {
98 if idx == excluded {
99 continue;
100 }
101
102 let stats_guard = self.source_stats.read();
103 let stats = stats_guard.get(&idx);
104
105 let speed = stats.map(|s| s.avg_speed()).unwrap_or(0.0);
106 let errors = stats.map(|s| s.errors.load(Ordering::Relaxed)).unwrap_or(0);
107 let active = stats
108 .map(|s| s.active_segments.load(Ordering::Relaxed))
109 .unwrap_or(0);
110
111 let error_penalty = 1.0 / (1.0 + errors as f64 * 0.5);
112 let load_factor = 1.0 / (1.0 + active as f64 * 0.1);
113 let score = (speed + 1.0) * error_penalty * load_factor;
114
115 if score > best_score {
116 best_score = score;
117 best_idx = Some(idx);
118 }
119 }
120
121 if let Some(new_idx) = best_idx {
122 self.segment_assignments
123 .write()
124 .insert(segment_idx, new_idx);
125
126 let mut stats = self.source_stats.write();
127 stats
128 .entry(new_idx)
129 .or_insert_with(SourceStats::new)
130 .active_segments
131 .fetch_add(1, Ordering::Relaxed);
132 }
133
134 best_idx
135 }
136
137 pub fn record_progress(&self, source_idx: usize, bytes: u64, speed: f64) {
138 let mut stats = self.source_stats.write();
139 let source_stats = stats.entry(source_idx).or_insert_with(SourceStats::new);
140
141 source_stats
142 .bytes_downloaded
143 .fetch_add(bytes, Ordering::Relaxed);
144
145 let mut samples = source_stats.speed_samples.write();
146 samples.push(speed);
147 if samples.len() > 10 {
148 samples.remove(0);
149 }
150 }
151
152 pub fn record_error(&self, source_idx: usize) {
153 let mut stats = self.source_stats.write();
154 stats
155 .entry(source_idx)
156 .or_insert_with(SourceStats::new)
157 .errors
158 .fetch_add(1, Ordering::Relaxed);
159 }
160
161 pub fn complete_segment(&self, segment_idx: usize) {
162 let source_idx = {
163 let assignments = self.segment_assignments.read();
164 assignments.get(&segment_idx).copied()
165 };
166
167 if let Some(idx) = source_idx {
168 let stats = self.source_stats.read();
169 if let Some(source_stats) = stats.get(&idx) {
170 source_stats.active_segments.fetch_sub(1, Ordering::Relaxed);
171 }
172 }
173 }
174
175 pub fn get_mirror_url(&self, source_idx: usize) -> Option<url::Url> {
176 self.mirrors.read().get(source_idx).map(|m| m.url.clone())
177 }
178
179 pub fn mirror_count(&self) -> usize {
180 self.mirrors.read().len()
181 }
182
183 pub fn sync_mirror_stats(&self) {
184 let stats_guard = self.source_stats.read();
185 let mut mirrors = self.mirrors.write();
186
187 for (idx, stats) in stats_guard.iter() {
188 let mirror_stats = MirrorStats {
189 bytes_downloaded: stats.bytes_downloaded.load(Ordering::Relaxed),
190 errors: stats.errors.load(Ordering::Relaxed),
191 avg_speed: stats.avg_speed(),
192 active_segments: stats.active_segments.load(Ordering::Relaxed),
193 };
194 mirrors.update_stats(*idx, mirror_stats);
195 }
196 }
197
198 pub fn get_source_summary(&self) -> Vec<(usize, u64, f64, usize)> {
199 let stats = self.source_stats.read();
200 stats
201 .iter()
202 .map(|(idx, s)| {
203 (
204 *idx,
205 s.bytes_downloaded.load(Ordering::Relaxed),
206 s.avg_speed(),
207 s.errors.load(Ordering::Relaxed),
208 )
209 })
210 .collect()
211 }
212}