1use std::collections::HashMap;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use hmac::{Hmac, Mac};
12use sha2::Sha256;
13use tokio::sync::RwLock;
14use tracing::{debug, trace, warn};
15
16use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
17use sentinel_common::errors::{SentinelError, SentinelResult};
18use sentinel_config::upstreams::StickySessionConfig;
19
/// HMAC keyed with SHA-256; used to sign and verify sticky-cookie values.
type HmacSha256 = Hmac<Sha256>;
21
22#[derive(Debug, Clone)]
24pub struct StickySessionRuntimeConfig {
25 pub cookie_name: String,
27 pub cookie_ttl_secs: u64,
29 pub cookie_path: String,
31 pub cookie_secure: bool,
33 pub cookie_same_site: sentinel_config::upstreams::SameSitePolicy,
35 pub hmac_key: [u8; 32],
37}
38
39impl StickySessionRuntimeConfig {
40 pub fn from_config(config: &StickySessionConfig) -> Self {
42 use rand::RngCore;
43
44 let mut hmac_key = [0u8; 32];
46 rand::thread_rng().fill_bytes(&mut hmac_key);
47
48 Self {
49 cookie_name: config.cookie_name.clone(),
50 cookie_ttl_secs: config.cookie_ttl_secs,
51 cookie_path: config.cookie_path.clone(),
52 cookie_secure: config.cookie_secure,
53 cookie_same_site: config.cookie_same_site,
54 hmac_key,
55 }
56 }
57}
58
/// Load balancer that pins a client to one upstream target through a signed
/// cookie, delegating to another balancer when no usable cookie is present.
pub struct StickySessionBalancer {
    /// Cookie attributes and the key used to sign cookie values.
    config: StickySessionRuntimeConfig,
    /// Upstream targets; a cookie stores an index into this list.
    targets: Vec<UpstreamTarget>,
    /// Balancer consulted when affinity cannot be honoured.
    fallback: Arc<dyn LoadBalancer>,
    /// Last reported health per target address (`true` means healthy).
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}
75
76impl StickySessionBalancer {
77 pub fn new(
79 targets: Vec<UpstreamTarget>,
80 config: StickySessionRuntimeConfig,
81 fallback: Arc<dyn LoadBalancer>,
82 ) -> Self {
83 trace!(
84 target_count = targets.len(),
85 cookie_name = %config.cookie_name,
86 cookie_ttl_secs = config.cookie_ttl_secs,
87 "Creating sticky session balancer"
88 );
89
90 let mut health_status = HashMap::new();
91 for target in &targets {
92 health_status.insert(target.full_address(), true);
93 }
94
95 Self {
96 config,
97 targets,
98 fallback,
99 health_status: Arc::new(RwLock::new(health_status)),
100 }
101 }
102
103 fn extract_affinity(&self, context: &RequestContext) -> Option<usize> {
107 let cookie_header = context.headers.get("cookie")?;
109
110 let cookie_value = cookie_header.split(';').find_map(|cookie| {
112 let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
113 if parts.len() == 2 && parts[0] == self.config.cookie_name {
114 Some(parts[1].to_string())
115 } else {
116 None
117 }
118 })?;
119
120 let parts: Vec<&str> = cookie_value.splitn(2, '.').collect();
122 if parts.len() != 2 {
123 trace!(
124 cookie_value = %cookie_value,
125 "Invalid sticky cookie format (missing signature)"
126 );
127 return None;
128 }
129
130 let index: usize = parts[0].parse().ok()?;
131 let signature = parts[1];
132
133 if !self.verify_signature(index, signature) {
135 warn!(
136 cookie_value = %cookie_value,
137 "Invalid sticky cookie signature (possible tampering)"
138 );
139 return None;
140 }
141
142 if index >= self.targets.len() {
144 trace!(
145 index = index,
146 target_count = self.targets.len(),
147 "Sticky cookie index out of bounds"
148 );
149 return None;
150 }
151
152 trace!(
153 cookie_name = %self.config.cookie_name,
154 target_index = index,
155 "Extracted valid sticky session affinity"
156 );
157
158 Some(index)
159 }
160
161 pub fn generate_cookie_value(&self, target_index: usize) -> String {
163 let signature = self.sign_index(target_index);
164 format!("{}.{}", target_index, signature)
165 }
166
167 pub fn generate_set_cookie_header(&self, target_index: usize) -> String {
169 let cookie_value = self.generate_cookie_value(target_index);
170
171 let mut header = format!(
172 "{}={}; Path={}; Max-Age={}",
173 self.config.cookie_name,
174 cookie_value,
175 self.config.cookie_path,
176 self.config.cookie_ttl_secs
177 );
178
179 if self.config.cookie_secure {
180 header.push_str("; HttpOnly; Secure");
181 }
182
183 header.push_str(&format!("; SameSite={}", self.config.cookie_same_site));
184
185 header
186 }
187
188 fn sign_index(&self, index: usize) -> String {
190 let mut mac =
191 HmacSha256::new_from_slice(&self.config.hmac_key).expect("HMAC key length is valid");
192 mac.update(index.to_string().as_bytes());
193 let result = mac.finalize();
194 hex::encode(&result.into_bytes()[..8])
196 }
197
198 fn verify_signature(&self, index: usize, signature: &str) -> bool {
200 let expected = self.sign_index(index);
201 expected == signature
203 }
204
205 async fn is_target_healthy(&self, index: usize) -> bool {
207 if index >= self.targets.len() {
208 return false;
209 }
210
211 let target = &self.targets[index];
212 let health = self.health_status.read().await;
213 *health.get(&target.full_address()).unwrap_or(&true)
214 }
215
216 fn find_target_index(&self, address: &str) -> Option<usize> {
218 self.targets
219 .iter()
220 .position(|t| t.full_address() == address)
221 }
222
223 pub fn cookie_name(&self) -> &str {
225 &self.config.cookie_name
226 }
227
228 pub fn config(&self) -> &StickySessionRuntimeConfig {
230 &self.config
231 }
232}
233
234#[async_trait]
235impl LoadBalancer for StickySessionBalancer {
236 async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
237 trace!(
238 has_context = context.is_some(),
239 cookie_name = %self.config.cookie_name,
240 "Sticky session select called"
241 );
242
243 if let Some(ctx) = context {
245 if let Some(target_index) = self.extract_affinity(ctx) {
246 if self.is_target_healthy(target_index).await {
248 let target = &self.targets[target_index];
249
250 debug!(
251 target = %target.full_address(),
252 target_index = target_index,
253 cookie_name = %self.config.cookie_name,
254 "Sticky session hit - routing to affinity target"
255 );
256
257 return Ok(TargetSelection {
258 address: target.full_address(),
259 weight: target.weight,
260 metadata: {
261 let mut meta = HashMap::new();
262 meta.insert("sticky_session_hit".to_string(), "true".to_string());
263 meta.insert("sticky_target_index".to_string(), target_index.to_string());
264 meta.insert("algorithm".to_string(), "sticky_session".to_string());
265 meta
266 },
267 });
268 }
269
270 debug!(
271 target_index = target_index,
272 cookie_name = %self.config.cookie_name,
273 "Sticky target unhealthy, falling back to load balancer"
274 );
275 }
276 }
277
278 let mut selection = self.fallback.select(context).await?;
280
281 let target_index = self.find_target_index(&selection.address);
283
284 if let Some(index) = target_index {
285 selection
287 .metadata
288 .insert("sticky_session_new".to_string(), "true".to_string());
289 selection
290 .metadata
291 .insert("sticky_target_index".to_string(), index.to_string());
292 selection.metadata.insert(
293 "sticky_cookie_value".to_string(),
294 self.generate_cookie_value(index),
295 );
296 selection.metadata.insert(
297 "sticky_set_cookie_header".to_string(),
298 self.generate_set_cookie_header(index),
299 );
300
301 debug!(
302 target = %selection.address,
303 target_index = index,
304 cookie_name = %self.config.cookie_name,
305 "New sticky session assignment, will set cookie"
306 );
307 }
308
309 selection
310 .metadata
311 .insert("algorithm".to_string(), "sticky_session".to_string());
312
313 Ok(selection)
314 }
315
316 async fn report_health(&self, address: &str, healthy: bool) {
317 trace!(
318 target = %address,
319 healthy = healthy,
320 algorithm = "sticky_session",
321 "Updating target health status"
322 );
323
324 self.health_status
326 .write()
327 .await
328 .insert(address.to_string(), healthy);
329
330 self.fallback.report_health(address, healthy).await;
332 }
333
334 async fn healthy_targets(&self) -> Vec<String> {
335 self.fallback.healthy_targets().await
337 }
338
339 async fn release(&self, selection: &TargetSelection) {
340 self.fallback.release(selection).await;
342 }
343
344 async fn report_result(
345 &self,
346 selection: &TargetSelection,
347 success: bool,
348 latency: Option<std::time::Duration>,
349 ) {
350 self.fallback
352 .report_result(selection, success, latency)
353 .await;
354 }
355
356 async fn report_result_with_latency(
357 &self,
358 address: &str,
359 success: bool,
360 latency: Option<std::time::Duration>,
361 ) {
362 self.fallback
364 .report_result_with_latency(address, success, latency)
365 .await;
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Fallback stub shared by every test.
    ///
    /// `pick` is the address `select` returns. `None` makes `select` panic, so
    /// a test can prove the fallback was never consulted.
    struct StubFallback {
        pick: Option<&'static str>,
    }

    #[async_trait]
    impl LoadBalancer for StubFallback {
        async fn select(
            &self,
            _context: Option<&RequestContext>,
        ) -> SentinelResult<TargetSelection> {
            match self.pick {
                Some(address) => Ok(TargetSelection {
                    address: address.to_string(),
                    weight: 100,
                    metadata: HashMap::new(),
                }),
                None => panic!("Fallback should not be called for sticky hit"),
            }
        }

        async fn report_health(&self, _address: &str, _healthy: bool) {}

        async fn healthy_targets(&self) -> Vec<String> {
            self.pick.into_iter().map(String::from).collect()
        }
    }

    /// Targets `10.0.0.1` .. `10.0.0.<count>`, all on port 8080 with weight 100.
    fn create_test_targets(count: usize) -> Vec<UpstreamTarget> {
        (0..count)
            .map(|i| UpstreamTarget {
                address: format!("10.0.0.{}", i + 1),
                port: 8080,
                weight: 100,
            })
            .collect()
    }

    /// Configuration with a fixed key so signatures are deterministic.
    fn create_test_config() -> StickySessionRuntimeConfig {
        StickySessionRuntimeConfig {
            cookie_name: "SERVERID".to_string(),
            cookie_ttl_secs: 3600,
            cookie_path: "/".to_string(),
            cookie_secure: true,
            cookie_same_site: sentinel_config::upstreams::SameSitePolicy::Lax,
            hmac_key: [42u8; 32],
        }
    }

    /// Three-target balancer whose fallback behaves as described by `pick`.
    fn balancer_with(pick: Option<&'static str>) -> StickySessionBalancer {
        StickySessionBalancer::new(
            create_test_targets(3),
            create_test_config(),
            Arc::new(StubFallback { pick }),
        )
    }

    /// `GET /` request carrying the given `cookie` header, if any.
    fn request_with_cookie(cookie: Option<String>) -> RequestContext {
        let mut headers = HashMap::new();
        if let Some(value) = cookie {
            headers.insert("cookie".to_string(), value);
        }

        RequestContext {
            client_ip: None,
            headers,
            path: "/".to_string(),
            method: "GET".to_string(),
        }
    }

    #[test]
    fn test_cookie_generation_and_validation() {
        let balancer = balancer_with(Some("10.0.0.1:8080"));

        let cookie_value = balancer.generate_cookie_value(1);
        assert!(cookie_value.starts_with("1."));
        // "1." followed by 16 hex characters (8 signature bytes).
        assert_eq!(cookie_value.len(), 2 + 16);

        let (_, signature) = cookie_value
            .split_once('.')
            .expect("cookie value has a signature part");
        assert!(balancer.verify_signature(1, signature));

        assert!(!balancer.verify_signature(1, "invalid"));
        // A signature issued for one index must not validate another.
        assert!(!balancer.verify_signature(2, signature));
    }

    #[test]
    fn test_set_cookie_header_generation() {
        let balancer = balancer_with(None);

        let header = balancer.generate_set_cookie_header(0);
        assert!(header.starts_with("SERVERID=0."));
        assert!(header.contains("Path=/"));
        assert!(header.contains("Max-Age=3600"));
        assert!(header.contains("HttpOnly"));
        assert!(header.contains("Secure"));
        assert!(header.contains("SameSite=Lax"));
    }

    #[tokio::test]
    async fn test_sticky_session_hit() {
        // The fallback panics if called: a sticky hit must not consult it.
        let balancer = balancer_with(None);

        let cookie_value = balancer.generate_cookie_value(1);
        let context = request_with_cookie(Some(format!("SERVERID={}", cookie_value)));

        let selection = balancer.select(Some(&context)).await.unwrap();

        assert_eq!(selection.address, "10.0.0.2:8080");
        assert_eq!(
            selection.metadata.get("sticky_session_hit"),
            Some(&"true".to_string())
        );
        assert_eq!(
            selection.metadata.get("sticky_target_index"),
            Some(&"1".to_string())
        );
    }

    #[tokio::test]
    async fn test_sticky_session_miss_sets_cookie() {
        let balancer = balancer_with(Some("10.0.0.2:8080"));
        let context = request_with_cookie(None);

        let selection = balancer.select(Some(&context)).await.unwrap();

        assert_eq!(selection.address, "10.0.0.2:8080");
        assert_eq!(
            selection.metadata.get("sticky_session_new"),
            Some(&"true".to_string())
        );
        assert!(selection.metadata.get("sticky_cookie_value").is_some());
        assert!(selection.metadata.get("sticky_set_cookie_header").is_some());
    }

    #[tokio::test]
    async fn test_unhealthy_target_falls_back() {
        let balancer = balancer_with(Some("10.0.0.3:8080"));

        // Mark the cookie's target (index 1) as unhealthy.
        balancer.report_health("10.0.0.2:8080", false).await;

        let cookie_value = balancer.generate_cookie_value(1);
        let context = request_with_cookie(Some(format!("SERVERID={}", cookie_value)));

        let selection = balancer.select(Some(&context)).await.unwrap();

        assert_eq!(selection.address, "10.0.0.3:8080");
        assert_eq!(
            selection.metadata.get("sticky_session_new"),
            Some(&"true".to_string())
        );
    }

    #[tokio::test]
    async fn test_tampered_cookie_falls_back() {
        let balancer = balancer_with(Some("10.0.0.1:8080"));

        // Reuse the signature issued for index 1 with a different index.
        let issued = balancer.generate_cookie_value(1);
        let (_, signature) = issued
            .split_once('.')
            .expect("cookie value has a signature part");
        let context = request_with_cookie(Some(format!("SERVERID=2.{}", signature)));

        let selection = balancer.select(Some(&context)).await.unwrap();

        assert_eq!(selection.address, "10.0.0.1:8080");
        assert!(selection.metadata.get("sticky_session_hit").is_none());
        assert_eq!(
            selection.metadata.get("sticky_session_new"),
            Some(&"true".to_string())
        );
    }
}