Skip to main content

stormdl_segment/
multi_source.rs

1use parking_lot::RwLock;
2use std::collections::HashMap;
3use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
4use stormdl_core::{ByteRange, MirrorSet, MirrorStats};
5
6pub struct MultiSourceManager {
7    mirrors: RwLock<MirrorSet>,
8    segment_assignments: RwLock<HashMap<usize, usize>>,
9    source_stats: RwLock<HashMap<usize, SourceStats>>,
10    #[allow(dead_code)]
11    total_size: u64,
12}
13
14struct SourceStats {
15    bytes_downloaded: AtomicU64,
16    errors: AtomicUsize,
17    active_segments: AtomicUsize,
18    speed_samples: RwLock<Vec<f64>>,
19}
20
21impl SourceStats {
22    fn new() -> Self {
23        Self {
24            bytes_downloaded: AtomicU64::new(0),
25            errors: AtomicUsize::new(0),
26            active_segments: AtomicUsize::new(0),
27            speed_samples: RwLock::new(Vec::with_capacity(10)),
28        }
29    }
30
31    fn avg_speed(&self) -> f64 {
32        let samples = self.speed_samples.read();
33        if samples.is_empty() {
34            return 0.0;
35        }
36        samples.iter().sum::<f64>() / samples.len() as f64
37    }
38}
39
40impl MultiSourceManager {
41    pub fn new(mirrors: MirrorSet, total_size: u64) -> Self {
42        Self {
43            mirrors: RwLock::new(mirrors),
44            segment_assignments: RwLock::new(HashMap::new()),
45            source_stats: RwLock::new(HashMap::new()),
46            total_size,
47        }
48    }
49
50    pub fn assign_segment(&self, segment_idx: usize, _range: ByteRange) -> usize {
51        let mirrors = self.mirrors.read();
52        let source_idx = mirrors.select_for_segment(segment_idx);
53        drop(mirrors);
54
55        self.segment_assignments
56            .write()
57            .insert(segment_idx, source_idx);
58
59        let mut stats = self.source_stats.write();
60        stats
61            .entry(source_idx)
62            .or_insert_with(SourceStats::new)
63            .active_segments
64            .fetch_add(1, Ordering::Relaxed);
65
66        source_idx
67    }
68
69    pub fn get_assignment(&self, segment_idx: usize) -> Option<usize> {
70        self.segment_assignments.read().get(&segment_idx).copied()
71    }
72
73    pub fn reassign_segment(&self, segment_idx: usize) -> Option<usize> {
74        let old_source = {
75            let assignments = self.segment_assignments.read();
76            assignments.get(&segment_idx).copied()
77        };
78
79        if let Some(old_idx) = old_source {
80            let stats = self.source_stats.read();
81            if let Some(source_stats) = stats.get(&old_idx) {
82                source_stats.active_segments.fetch_sub(1, Ordering::Relaxed);
83            }
84        }
85
86        let mirrors = self.mirrors.read();
87        let mirror_count = mirrors.len();
88
89        if mirror_count <= 1 {
90            return None;
91        }
92
93        let excluded = old_source.unwrap_or(usize::MAX);
94        let mut best_idx = None;
95        let mut best_score = f64::NEG_INFINITY;
96
97        for idx in 0..mirror_count {
98            if idx == excluded {
99                continue;
100            }
101
102            let stats_guard = self.source_stats.read();
103            let stats = stats_guard.get(&idx);
104
105            let speed = stats.map(|s| s.avg_speed()).unwrap_or(0.0);
106            let errors = stats.map(|s| s.errors.load(Ordering::Relaxed)).unwrap_or(0);
107            let active = stats
108                .map(|s| s.active_segments.load(Ordering::Relaxed))
109                .unwrap_or(0);
110
111            let error_penalty = 1.0 / (1.0 + errors as f64 * 0.5);
112            let load_factor = 1.0 / (1.0 + active as f64 * 0.1);
113            let score = (speed + 1.0) * error_penalty * load_factor;
114
115            if score > best_score {
116                best_score = score;
117                best_idx = Some(idx);
118            }
119        }
120
121        if let Some(new_idx) = best_idx {
122            self.segment_assignments
123                .write()
124                .insert(segment_idx, new_idx);
125
126            let mut stats = self.source_stats.write();
127            stats
128                .entry(new_idx)
129                .or_insert_with(SourceStats::new)
130                .active_segments
131                .fetch_add(1, Ordering::Relaxed);
132        }
133
134        best_idx
135    }
136
137    pub fn record_progress(&self, source_idx: usize, bytes: u64, speed: f64) {
138        let mut stats = self.source_stats.write();
139        let source_stats = stats.entry(source_idx).or_insert_with(SourceStats::new);
140
141        source_stats
142            .bytes_downloaded
143            .fetch_add(bytes, Ordering::Relaxed);
144
145        let mut samples = source_stats.speed_samples.write();
146        samples.push(speed);
147        if samples.len() > 10 {
148            samples.remove(0);
149        }
150    }
151
152    pub fn record_error(&self, source_idx: usize) {
153        let mut stats = self.source_stats.write();
154        stats
155            .entry(source_idx)
156            .or_insert_with(SourceStats::new)
157            .errors
158            .fetch_add(1, Ordering::Relaxed);
159    }
160
161    pub fn complete_segment(&self, segment_idx: usize) {
162        let source_idx = {
163            let assignments = self.segment_assignments.read();
164            assignments.get(&segment_idx).copied()
165        };
166
167        if let Some(idx) = source_idx {
168            let stats = self.source_stats.read();
169            if let Some(source_stats) = stats.get(&idx) {
170                source_stats.active_segments.fetch_sub(1, Ordering::Relaxed);
171            }
172        }
173    }
174
175    pub fn get_mirror_url(&self, source_idx: usize) -> Option<url::Url> {
176        self.mirrors.read().get(source_idx).map(|m| m.url.clone())
177    }
178
179    pub fn mirror_count(&self) -> usize {
180        self.mirrors.read().len()
181    }
182
183    pub fn sync_mirror_stats(&self) {
184        let stats_guard = self.source_stats.read();
185        let mut mirrors = self.mirrors.write();
186
187        for (idx, stats) in stats_guard.iter() {
188            let mirror_stats = MirrorStats {
189                bytes_downloaded: stats.bytes_downloaded.load(Ordering::Relaxed),
190                errors: stats.errors.load(Ordering::Relaxed),
191                avg_speed: stats.avg_speed(),
192                active_segments: stats.active_segments.load(Ordering::Relaxed),
193            };
194            mirrors.update_stats(*idx, mirror_stats);
195        }
196    }
197
198    pub fn get_source_summary(&self) -> Vec<(usize, u64, f64, usize)> {
199        let stats = self.source_stats.read();
200        stats
201            .iter()
202            .map(|(idx, s)| {
203                (
204                    *idx,
205                    s.bytes_downloaded.load(Ordering::Relaxed),
206                    s.avg_speed(),
207                    s.errors.load(Ordering::Relaxed),
208                )
209            })
210            .collect()
211    }
212}