1use async_trait::async_trait;
10use std::collections::HashMap;
11use std::hash::{Hash, Hasher};
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{debug, trace, warn};
15use xxhash_rust::xxh3::xxh3_64;
16
17use sentinel_common::errors::{SentinelError, SentinelResult};
18
19use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
20
21#[derive(Debug, Clone)]
23pub struct MaglevConfig {
24 pub table_size: usize,
26 pub key_source: MaglevKeySource,
28}
29
30impl Default for MaglevConfig {
31 fn default() -> Self {
32 Self {
33 table_size: 65537,
35 key_source: MaglevKeySource::ClientIp,
36 }
37 }
38}
39
/// Selects which part of a request is hashed to pick an upstream target.
///
/// Derives `PartialEq`, `Eq` and `Hash` so configurations can be compared
/// (for example to detect a changed setting) or used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MaglevKeySource {
    /// Hash the client's IP address (the port is ignored).
    ClientIp,
    /// Hash the value of the request header with the given name.
    Header(String),
    /// Hash the value of the cookie with the given name, parsed from the
    /// `cookie` request header.
    Cookie(String),
    /// Hash the request path.
    Path,
}
52
53pub struct MaglevBalancer {
55 targets: Vec<UpstreamTarget>,
57 lookup_table: Arc<RwLock<Vec<Option<usize>>>>,
59 health_status: Arc<RwLock<HashMap<String, bool>>>,
61 config: MaglevConfig,
63 generation: Arc<RwLock<u64>>,
65}
66
67impl MaglevBalancer {
68 pub fn new(targets: Vec<UpstreamTarget>, config: MaglevConfig) -> Self {
70 let mut health_status = HashMap::new();
71 for target in &targets {
72 health_status.insert(target.full_address(), true);
73 }
74
75 let table_size = config.table_size;
76 let balancer = Self {
77 targets,
78 lookup_table: Arc::new(RwLock::new(vec![None; table_size])),
79 health_status: Arc::new(RwLock::new(health_status)),
80 config,
81 generation: Arc::new(RwLock::new(0)),
82 };
83
84 let targets_clone = balancer.targets.clone();
87 let table_size = balancer.config.table_size;
88 let table = Self::build_lookup_table(&targets_clone, table_size);
89
90 if let Ok(mut lookup) = balancer.lookup_table.try_write() {
92 *lookup = table;
93 }
94
95 balancer
96 }
97
98 fn build_lookup_table(targets: &[UpstreamTarget], table_size: usize) -> Vec<Option<usize>> {
100 if targets.is_empty() {
101 return vec![None; table_size];
102 }
103
104 let n = targets.len();
105 let m = table_size;
106
107 let permutations: Vec<Vec<usize>> = targets
109 .iter()
110 .map(|target| Self::generate_permutation(&target.full_address(), m))
111 .collect();
112
113 let mut table = vec![None; m];
115 let mut next = vec![0usize; n]; let mut filled = 0;
117
118 while filled < m {
119 for i in 0..n {
120 loop {
122 let c = permutations[i][next[i]];
123 next[i] += 1;
124
125 if table[c].is_none() {
126 table[c] = Some(i);
127 filled += 1;
128 break;
129 }
130
131 if next[i] >= m {
133 next[i] = 0;
134 break;
135 }
136 }
137
138 if filled >= m {
139 break;
140 }
141 }
142 }
143
144 table
145 }
146
147 fn generate_permutation(name: &str, table_size: usize) -> Vec<usize> {
149 let m = table_size;
150
151 let h1 = xxh3_64(name.as_bytes()) as usize;
153 let h2 = {
154 let mut hasher = std::collections::hash_map::DefaultHasher::new();
155 name.hash(&mut hasher);
156 hasher.finish() as usize
157 };
158
159 let offset = h1 % m;
161 let skip = (h2 % (m - 1)) + 1; (0..m).map(|i| (offset + i * skip) % m).collect()
165 }
166
167 async fn rebuild_table_for_healthy(&self) {
169 let health = self.health_status.read().await;
170 let healthy_targets: Vec<_> = self
171 .targets
172 .iter()
173 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
174 .cloned()
175 .collect();
176 drop(health);
177
178 if healthy_targets.is_empty() {
179 return;
181 }
182
183 let table = Self::build_lookup_table(&healthy_targets, self.config.table_size);
184
185 let mut lookup = self.lookup_table.write().await;
186 *lookup = table;
187
188 let mut gen = self.generation.write().await;
189 *gen += 1;
190
191 debug!(
192 healthy_count = healthy_targets.len(),
193 total_count = self.targets.len(),
194 generation = *gen,
195 "Maglev lookup table rebuilt"
196 );
197 }
198
199 fn extract_key(&self, context: Option<&RequestContext>) -> String {
201 match &self.config.key_source {
202 MaglevKeySource::ClientIp => context
203 .and_then(|c| c.client_ip.map(|ip| ip.ip().to_string()))
204 .unwrap_or_else(|| "default".to_string()),
205 MaglevKeySource::Header(name) => context
206 .and_then(|c| c.headers.get(name).cloned())
207 .unwrap_or_else(|| "default".to_string()),
208 MaglevKeySource::Cookie(name) => context
209 .and_then(|c| {
210 c.headers.get("cookie").and_then(|cookies| {
211 cookies.split(';').find_map(|cookie| {
212 let mut parts = cookie.trim().splitn(2, '=');
213 let key = parts.next()?;
214 let value = parts.next()?;
215 if key == name {
216 Some(value.to_string())
217 } else {
218 None
219 }
220 })
221 })
222 })
223 .unwrap_or_else(|| "default".to_string()),
224 MaglevKeySource::Path => context
225 .map(|c| c.path.clone())
226 .unwrap_or_else(|| "/".to_string()),
227 }
228 }
229
230 async fn get_healthy_targets(&self) -> Vec<&UpstreamTarget> {
232 let health = self.health_status.read().await;
233 self.targets
234 .iter()
235 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
236 .collect()
237 }
238}
239
240#[async_trait]
241impl LoadBalancer for MaglevBalancer {
242 async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
243 trace!(
244 total_targets = self.targets.len(),
245 algorithm = "maglev",
246 "Selecting upstream target"
247 );
248
249 let health = self.health_status.read().await;
251 let healthy_targets: Vec<_> = self
252 .targets
253 .iter()
254 .enumerate()
255 .filter(|(_, t)| *health.get(&t.full_address()).unwrap_or(&true))
256 .collect();
257 drop(health);
258
259 if healthy_targets.is_empty() {
260 warn!(
261 total_targets = self.targets.len(),
262 algorithm = "maglev",
263 "No healthy upstream targets available"
264 );
265 return Err(SentinelError::NoHealthyUpstream);
266 }
267
268 let key = self.extract_key(context);
270 let hash = xxh3_64(key.as_bytes()) as usize;
271 let table_index = hash % self.config.table_size;
272
273 let lookup = self.lookup_table.read().await;
275 let target_index = lookup[table_index];
276 drop(lookup);
277
278 let target = if let Some(idx) = target_index {
280 if idx < self.targets.len() {
282 let t = &self.targets[idx];
283 let health = self.health_status.read().await;
284 if *health.get(&t.full_address()).unwrap_or(&true) {
285 t
286 } else {
287 healthy_targets
289 .first()
290 .map(|(_, t)| *t)
291 .ok_or(SentinelError::NoHealthyUpstream)?
292 }
293 } else {
294 healthy_targets
296 .first()
297 .map(|(_, t)| *t)
298 .ok_or(SentinelError::NoHealthyUpstream)?
299 }
300 } else {
301 healthy_targets
303 .first()
304 .map(|(_, t)| *t)
305 .ok_or(SentinelError::NoHealthyUpstream)?
306 };
307
308 trace!(
309 selected_target = %target.full_address(),
310 hash_key = %key,
311 table_index = table_index,
312 healthy_count = healthy_targets.len(),
313 algorithm = "maglev",
314 "Selected target via Maglev consistent hashing"
315 );
316
317 Ok(TargetSelection {
318 address: target.full_address(),
319 weight: target.weight,
320 metadata: HashMap::new(),
321 })
322 }
323
324 async fn report_health(&self, address: &str, healthy: bool) {
325 let prev_health = {
326 let health = self.health_status.read().await;
327 *health.get(address).unwrap_or(&true)
328 };
329
330 if prev_health != healthy {
331 trace!(
332 target = %address,
333 healthy = healthy,
334 algorithm = "maglev",
335 "Target health changed, rebuilding lookup table"
336 );
337
338 self.health_status
339 .write()
340 .await
341 .insert(address.to_string(), healthy);
342
343 self.rebuild_table_for_healthy().await;
345 } else {
346 self.health_status
347 .write()
348 .await
349 .insert(address.to_string(), healthy);
350 }
351 }
352
353 async fn healthy_targets(&self) -> Vec<String> {
354 self.health_status
355 .read()
356 .await
357 .iter()
358 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
359 .collect()
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` targets named `backend-0`, `backend-1`, ... on port 8080.
    fn make_targets(count: usize) -> Vec<UpstreamTarget> {
        (0..count)
            .map(|i| UpstreamTarget::new(format!("backend-{}", i), 8080, 100))
            .collect()
    }

    /// Builds the request context for synthetic request number `i`.
    fn make_context(i: usize) -> RequestContext {
        RequestContext {
            client_ip: Some(format!("192.168.1.{}:12345", i % 256).parse().unwrap()),
            headers: HashMap::new(),
            path: format!("/api/test/{}", i),
            method: "GET".to_string(),
        }
    }

    #[test]
    fn test_build_lookup_table() {
        let targets = make_targets(3);
        let table = MaglevBalancer::build_lookup_table(&targets, 65537);

        // Every slot must be owned by some target.
        assert!(table.iter().all(|entry| entry.is_some()));

        let mut counts = vec![0usize; 3];
        for idx in table.iter().flatten() {
            counts[*idx] += 1;
        }

        // Each target should own roughly a third of the table (within 10%).
        let expected = 65537 / 3;
        for count in counts {
            assert!(
                (count as i64 - expected as i64).abs() < (expected as i64 / 10),
                "Uneven distribution: {} vs expected ~{}",
                count,
                expected
            );
        }
    }

    #[test]
    fn test_permutation_generation() {
        let perm1 = MaglevBalancer::generate_permutation("backend-1", 65537);
        let perm2 = MaglevBalancer::generate_permutation("backend-2", 65537);

        // Different names must yield different preference orders.
        assert_ne!(perm1[0..100], perm2[0..100]);

        // A permutation must visit every slot.
        let mut seen = vec![false; 65537];
        for &idx in &perm1 {
            seen[idx] = true;
        }
        assert!(seen.iter().all(|&s| s));
    }

    #[tokio::test]
    async fn test_consistent_selection() {
        let targets = make_targets(5);
        let balancer = MaglevBalancer::new(targets, MaglevConfig::default());

        let context = RequestContext {
            client_ip: Some("192.168.1.100:12345".parse().unwrap()),
            headers: HashMap::new(),
            path: "/api/test".to_string(),
            method: "GET".to_string(),
        };

        // The same key must always map to the same target.
        let selection1 = balancer.select(Some(&context)).await.unwrap();
        let selection2 = balancer.select(Some(&context)).await.unwrap();
        let selection3 = balancer.select(Some(&context)).await.unwrap();

        assert_eq!(selection1.address, selection2.address);
        assert_eq!(selection2.address, selection3.address);
    }

    #[tokio::test]
    async fn test_minimal_disruption() {
        let targets = make_targets(5);
        let removed = targets[2].full_address();
        let balancer = MaglevBalancer::new(targets, MaglevConfig::default());

        // Record where each request lands while every target is healthy.
        let mut original_selections = Vec::with_capacity(1000);
        for i in 0..1000 {
            let selection = balancer.select(Some(&make_context(i))).await.unwrap();
            original_selections.push(selection.address);
        }

        balancer.report_health(&removed, false).await;

        let mut changed = 0;
        for (i, before) in original_selections.iter().enumerate() {
            let selection = balancer.select(Some(&make_context(i))).await.unwrap();

            assert_ne!(
                selection.address, removed,
                "request {} was routed to an unhealthy target",
                i
            );

            if selection.address != *before {
                changed += 1;
            }
        }

        // NOTE(review): removing one of five targets should ideally move only
        // about a fifth of the keys; this bound is deliberately loose and
        // could be tightened.
        assert!(
            changed < 800,
            "Too many selections changed: {} (expected less than 800 for 1/5 backend removal)",
            changed
        );
    }
}