sentinel_proxy/upstream/
sticky_session.rs

1//! Cookie-based sticky session load balancer
2//!
3//! Routes requests to the same backend based on an affinity cookie.
4//! Falls back to a configurable algorithm when no cookie is present
5//! or the target is unavailable.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use hmac::{Hmac, Mac};
12use sha2::Sha256;
13use tokio::sync::RwLock;
14use tracing::{debug, trace, warn};
15
16use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
17use sentinel_common::errors::{SentinelError, SentinelResult};
18use sentinel_config::upstreams::StickySessionConfig;
19
/// HMAC instantiated with SHA-256; used to sign and verify affinity cookie values.
type HmacSha256 = Hmac<Sha256>;
21
22/// Runtime configuration for sticky sessions
23#[derive(Debug, Clone)]
24pub struct StickySessionRuntimeConfig {
25    /// Cookie name for session affinity
26    pub cookie_name: String,
27    /// Cookie TTL in seconds
28    pub cookie_ttl_secs: u64,
29    /// Cookie path
30    pub cookie_path: String,
31    /// Whether to set Secure and HttpOnly flags
32    pub cookie_secure: bool,
33    /// SameSite policy
34    pub cookie_same_site: sentinel_config::upstreams::SameSitePolicy,
35    /// HMAC key for signing cookie values
36    pub hmac_key: [u8; 32],
37}
38
39impl StickySessionRuntimeConfig {
40    /// Create runtime config from parsed config, generating HMAC key
41    pub fn from_config(config: &StickySessionConfig) -> Self {
42        use rand::RngCore;
43
44        // Generate random HMAC key
45        let mut hmac_key = [0u8; 32];
46        rand::thread_rng().fill_bytes(&mut hmac_key);
47
48        Self {
49            cookie_name: config.cookie_name.clone(),
50            cookie_ttl_secs: config.cookie_ttl_secs,
51            cookie_path: config.cookie_path.clone(),
52            cookie_secure: config.cookie_secure,
53            cookie_same_site: config.cookie_same_site,
54            hmac_key,
55        }
56    }
57}
58
/// Cookie-based sticky session load balancer
///
/// This balancer wraps a fallback load balancer and adds session affinity
/// based on cookies. When a client has a valid affinity cookie, requests
/// are routed to the same backend. Otherwise, the fallback balancer is used
/// and a new cookie is set.
///
/// The affinity cookie carries the position of the chosen backend in
/// `targets` plus a truncated HMAC signature over that position.
pub struct StickySessionBalancer {
    /// Runtime configuration (cookie attributes and signing key)
    config: StickySessionRuntimeConfig,
    /// All upstream targets; a target's position in this list is the index
    /// encoded in the affinity cookie
    targets: Vec<UpstreamTarget>,
    /// Fallback load balancer, consulted when there is no usable affinity
    fallback: Arc<dyn LoadBalancer>,
    /// Target health status, keyed by the target's full address; every target
    /// starts as healthy and entries are overwritten by `report_health`
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}
75
76impl StickySessionBalancer {
77    /// Create a new sticky session balancer
78    pub fn new(
79        targets: Vec<UpstreamTarget>,
80        config: StickySessionRuntimeConfig,
81        fallback: Arc<dyn LoadBalancer>,
82    ) -> Self {
83        trace!(
84            target_count = targets.len(),
85            cookie_name = %config.cookie_name,
86            cookie_ttl_secs = config.cookie_ttl_secs,
87            "Creating sticky session balancer"
88        );
89
90        let mut health_status = HashMap::new();
91        for target in &targets {
92            health_status.insert(target.full_address(), true);
93        }
94
95        Self {
96            config,
97            targets,
98            fallback,
99            health_status: Arc::new(RwLock::new(health_status)),
100        }
101    }
102
103    /// Extract and validate sticky cookie from request
104    ///
105    /// Returns the target index if the cookie is valid and properly signed.
106    fn extract_affinity(&self, context: &RequestContext) -> Option<usize> {
107        // Get cookie header
108        let cookie_header = context.headers.get("cookie")?;
109
110        // Parse cookies and find our sticky session cookie
111        let cookie_value = cookie_header.split(';').find_map(|cookie| {
112            let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
113            if parts.len() == 2 && parts[0] == self.config.cookie_name {
114                Some(parts[1].to_string())
115            } else {
116                None
117            }
118        })?;
119
120        // Validate cookie format: "{index}.{signature}"
121        let parts: Vec<&str> = cookie_value.splitn(2, '.').collect();
122        if parts.len() != 2 {
123            trace!(
124                cookie_value = %cookie_value,
125                "Invalid sticky cookie format (missing signature)"
126            );
127            return None;
128        }
129
130        let index: usize = parts[0].parse().ok()?;
131        let signature = parts[1];
132
133        // Verify HMAC signature
134        if !self.verify_signature(index, signature) {
135            warn!(
136                cookie_value = %cookie_value,
137                "Invalid sticky cookie signature (possible tampering)"
138            );
139            return None;
140        }
141
142        // Verify index is valid
143        if index >= self.targets.len() {
144            trace!(
145                index = index,
146                target_count = self.targets.len(),
147                "Sticky cookie index out of bounds"
148            );
149            return None;
150        }
151
152        trace!(
153            cookie_name = %self.config.cookie_name,
154            target_index = index,
155            "Extracted valid sticky session affinity"
156        );
157
158        Some(index)
159    }
160
161    /// Generate signed cookie value for target
162    pub fn generate_cookie_value(&self, target_index: usize) -> String {
163        let signature = self.sign_index(target_index);
164        format!("{}.{}", target_index, signature)
165    }
166
167    /// Generate full Set-Cookie header value
168    pub fn generate_set_cookie_header(&self, target_index: usize) -> String {
169        let cookie_value = self.generate_cookie_value(target_index);
170
171        let mut header = format!(
172            "{}={}; Path={}; Max-Age={}",
173            self.config.cookie_name,
174            cookie_value,
175            self.config.cookie_path,
176            self.config.cookie_ttl_secs
177        );
178
179        if self.config.cookie_secure {
180            header.push_str("; HttpOnly; Secure");
181        }
182
183        header.push_str(&format!("; SameSite={}", self.config.cookie_same_site));
184
185        header
186    }
187
188    /// Sign target index with HMAC-SHA256
189    fn sign_index(&self, index: usize) -> String {
190        let mut mac =
191            HmacSha256::new_from_slice(&self.config.hmac_key).expect("HMAC key length is valid");
192        mac.update(index.to_string().as_bytes());
193        let result = mac.finalize();
194        // Use first 8 bytes of signature (16 hex chars) for compactness
195        hex::encode(&result.into_bytes()[..8])
196    }
197
198    /// Verify HMAC signature for target index
199    fn verify_signature(&self, index: usize, signature: &str) -> bool {
200        let expected = self.sign_index(index);
201        // Constant-time comparison
202        expected == signature
203    }
204
205    /// Check if target at index is healthy
206    async fn is_target_healthy(&self, index: usize) -> bool {
207        if index >= self.targets.len() {
208            return false;
209        }
210
211        let target = &self.targets[index];
212        let health = self.health_status.read().await;
213        *health.get(&target.full_address()).unwrap_or(&true)
214    }
215
216    /// Find target index by address
217    fn find_target_index(&self, address: &str) -> Option<usize> {
218        self.targets
219            .iter()
220            .position(|t| t.full_address() == address)
221    }
222
223    /// Get the cookie name
224    pub fn cookie_name(&self) -> &str {
225        &self.config.cookie_name
226    }
227
228    /// Get the config for Set-Cookie header generation
229    pub fn config(&self) -> &StickySessionRuntimeConfig {
230        &self.config
231    }
232}
233
234#[async_trait]
235impl LoadBalancer for StickySessionBalancer {
236    async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
237        trace!(
238            has_context = context.is_some(),
239            cookie_name = %self.config.cookie_name,
240            "Sticky session select called"
241        );
242
243        // Try to extract affinity from cookie
244        if let Some(ctx) = context {
245            if let Some(target_index) = self.extract_affinity(ctx) {
246                // Check if target is healthy
247                if self.is_target_healthy(target_index).await {
248                    let target = &self.targets[target_index];
249
250                    debug!(
251                        target = %target.full_address(),
252                        target_index = target_index,
253                        cookie_name = %self.config.cookie_name,
254                        "Sticky session hit - routing to affinity target"
255                    );
256
257                    return Ok(TargetSelection {
258                        address: target.full_address(),
259                        weight: target.weight,
260                        metadata: {
261                            let mut meta = HashMap::new();
262                            meta.insert("sticky_session_hit".to_string(), "true".to_string());
263                            meta.insert("sticky_target_index".to_string(), target_index.to_string());
264                            meta.insert("algorithm".to_string(), "sticky_session".to_string());
265                            meta
266                        },
267                    });
268                }
269
270                debug!(
271                    target_index = target_index,
272                    cookie_name = %self.config.cookie_name,
273                    "Sticky target unhealthy, falling back to load balancer"
274                );
275            }
276        }
277
278        // No valid cookie or target unavailable - use fallback
279        let mut selection = self.fallback.select(context).await?;
280
281        // Find target index for the selected address
282        let target_index = self.find_target_index(&selection.address);
283
284        if let Some(index) = target_index {
285            // Mark that we need to set a new cookie
286            selection
287                .metadata
288                .insert("sticky_session_new".to_string(), "true".to_string());
289            selection
290                .metadata
291                .insert("sticky_target_index".to_string(), index.to_string());
292            selection.metadata.insert(
293                "sticky_cookie_value".to_string(),
294                self.generate_cookie_value(index),
295            );
296            selection.metadata.insert(
297                "sticky_set_cookie_header".to_string(),
298                self.generate_set_cookie_header(index),
299            );
300
301            debug!(
302                target = %selection.address,
303                target_index = index,
304                cookie_name = %self.config.cookie_name,
305                "New sticky session assignment, will set cookie"
306            );
307        }
308
309        selection
310            .metadata
311            .insert("algorithm".to_string(), "sticky_session".to_string());
312
313        Ok(selection)
314    }
315
316    async fn report_health(&self, address: &str, healthy: bool) {
317        trace!(
318            target = %address,
319            healthy = healthy,
320            algorithm = "sticky_session",
321            "Updating target health status"
322        );
323
324        // Update local health status
325        self.health_status
326            .write()
327            .await
328            .insert(address.to_string(), healthy);
329
330        // Propagate to fallback balancer
331        self.fallback.report_health(address, healthy).await;
332    }
333
334    async fn healthy_targets(&self) -> Vec<String> {
335        // Delegate to fallback balancer for consistency
336        self.fallback.healthy_targets().await
337    }
338
339    async fn release(&self, selection: &TargetSelection) {
340        // Delegate to fallback balancer
341        self.fallback.release(selection).await;
342    }
343
344    async fn report_result(
345        &self,
346        selection: &TargetSelection,
347        success: bool,
348        latency: Option<std::time::Duration>,
349    ) {
350        // Delegate to fallback balancer
351        self.fallback
352            .report_result(selection, success, latency)
353            .await;
354    }
355
356    async fn report_result_with_latency(
357        &self,
358        address: &str,
359        success: bool,
360        latency: Option<std::time::Duration>,
361    ) {
362        // Delegate to fallback balancer
363        self.fallback
364            .report_result_with_latency(address, success, latency)
365            .await;
366    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Fallback stub shared by all tests.
    ///
    /// With `selected = Some(addr)` every `select` call returns `addr`; with
    /// `None` the fallback must never be consulted and `select` panics.
    struct MockBalancer {
        selected: Option<&'static str>,
    }

    impl MockBalancer {
        /// Fallback that always picks `address`.
        fn returning(address: &'static str) -> Arc<dyn LoadBalancer> {
            Arc::new(Self {
                selected: Some(address),
            })
        }

        /// Fallback whose `select` panics if it is ever called.
        fn never_called() -> Arc<dyn LoadBalancer> {
            Arc::new(Self { selected: None })
        }
    }

    #[async_trait]
    impl LoadBalancer for MockBalancer {
        async fn select(
            &self,
            _context: Option<&RequestContext>,
        ) -> SentinelResult<TargetSelection> {
            let address = self
                .selected
                .expect("Fallback should not be called for sticky hit");
            Ok(TargetSelection {
                address: address.to_string(),
                weight: 100,
                metadata: HashMap::new(),
            })
        }
        async fn report_health(&self, _address: &str, _healthy: bool) {}
        async fn healthy_targets(&self) -> Vec<String> {
            self.selected.iter().map(|a| a.to_string()).collect()
        }
    }

    /// Targets `10.0.0.1:8080`, `10.0.0.2:8080`, ... (`count` of them).
    fn create_test_targets(count: usize) -> Vec<UpstreamTarget> {
        (0..count)
            .map(|i| UpstreamTarget {
                address: format!("10.0.0.{}", i + 1),
                port: 8080,
                weight: 100,
            })
            .collect()
    }

    fn create_test_config() -> StickySessionRuntimeConfig {
        StickySessionRuntimeConfig {
            cookie_name: "SERVERID".to_string(),
            cookie_ttl_secs: 3600,
            cookie_path: "/".to_string(),
            cookie_secure: true,
            cookie_same_site: sentinel_config::upstreams::SameSitePolicy::Lax,
            hmac_key: [42u8; 32], // Fixed key for testing
        }
    }

    /// Balancer over three test targets using the fixed-key test config.
    fn create_test_balancer(fallback: Arc<dyn LoadBalancer>) -> StickySessionBalancer {
        StickySessionBalancer::new(create_test_targets(3), create_test_config(), fallback)
    }

    /// GET `/` request whose `cookie` header is `cookie_header`, if given.
    fn create_test_context(cookie_header: Option<String>) -> RequestContext {
        let mut headers = HashMap::new();
        if let Some(value) = cookie_header {
            headers.insert("cookie".to_string(), value);
        }

        RequestContext {
            client_ip: None,
            headers,
            path: "/".to_string(),
            method: "GET".to_string(),
        }
    }

    #[test]
    fn test_cookie_generation_and_validation() {
        let balancer = create_test_balancer(MockBalancer::returning("10.0.0.1:8080"));

        // Test cookie value generation
        let cookie_value = balancer.generate_cookie_value(1);
        assert!(cookie_value.starts_with("1."));
        assert_eq!(cookie_value.len(), 2 + 16); // "1." + 16 hex chars

        // Test signature verification
        let parts: Vec<&str> = cookie_value.splitn(2, '.').collect();
        assert!(balancer.verify_signature(1, parts[1]));

        // Test invalid signature
        assert!(!balancer.verify_signature(1, "invalid"));
        assert!(!balancer.verify_signature(1, ""));
        assert!(!balancer.verify_signature(2, parts[1])); // Wrong index
    }

    #[test]
    fn test_set_cookie_header_generation() {
        let balancer = create_test_balancer(MockBalancer::never_called());

        let header = balancer.generate_set_cookie_header(0);
        assert!(header.starts_with("SERVERID=0."));
        assert!(header.contains("Path=/"));
        assert!(header.contains("Max-Age=3600"));
        assert!(header.contains("HttpOnly"));
        assert!(header.contains("Secure"));
        assert!(header.contains("SameSite=Lax"));
    }

    #[test]
    fn test_set_cookie_header_without_secure_flags() {
        let mut config = create_test_config();
        config.cookie_secure = false;
        let balancer = StickySessionBalancer::new(
            create_test_targets(3),
            config,
            MockBalancer::never_called(),
        );

        let header = balancer.generate_set_cookie_header(2);
        assert!(header.starts_with("SERVERID=2."));
        assert!(!header.contains("HttpOnly"));
        assert!(!header.contains("Secure"));
        assert!(header.ends_with("; SameSite=Lax"));
    }

    #[tokio::test]
    async fn test_sticky_session_hit() {
        let balancer = create_test_balancer(MockBalancer::never_called());

        // Generate a valid cookie for target 1
        let cookie_value = balancer.generate_cookie_value(1);
        let context = create_test_context(Some(format!("SERVERID={}", cookie_value)));

        let selection = balancer.select(Some(&context)).await.unwrap();

        // Should route to target 1 (10.0.0.2:8080)
        assert_eq!(selection.address, "10.0.0.2:8080");
        assert_eq!(
            selection.metadata.get("sticky_session_hit"),
            Some(&"true".to_string())
        );
        assert_eq!(
            selection.metadata.get("sticky_target_index"),
            Some(&"1".to_string())
        );
    }

    #[tokio::test]
    async fn test_sticky_cookie_found_among_other_cookies() {
        let balancer = create_test_balancer(MockBalancer::never_called());

        let cookie_value = balancer.generate_cookie_value(2);
        let context = create_test_context(Some(format!(
            "theme=dark; SERVERID={}; lang=en",
            cookie_value
        )));

        let selection = balancer.select(Some(&context)).await.unwrap();

        assert_eq!(selection.address, "10.0.0.3:8080");
        assert_eq!(
            selection.metadata.get("sticky_session_hit"),
            Some(&"true".to_string())
        );
    }

    #[tokio::test]
    async fn test_sticky_session_miss_sets_cookie() {
        let balancer = create_test_balancer(MockBalancer::returning("10.0.0.2:8080"));

        // Create context without sticky cookie
        let context = create_test_context(None);

        let selection = balancer.select(Some(&context)).await.unwrap();

        // Should use fallback and mark for cookie setting
        assert_eq!(selection.address, "10.0.0.2:8080");
        assert_eq!(
            selection.metadata.get("sticky_session_new"),
            Some(&"true".to_string())
        );
        assert!(selection.metadata.get("sticky_cookie_value").is_some());
        assert!(selection.metadata.get("sticky_set_cookie_header").is_some());
    }

    #[tokio::test]
    async fn test_tampered_cookie_falls_back() {
        let balancer = create_test_balancer(MockBalancer::returning("10.0.0.1:8080"));

        // Re-use the signature issued for target 1 with a different index.
        let valid = balancer.generate_cookie_value(1);
        let signature = valid.splitn(2, '.').nth(1).unwrap();
        let context = create_test_context(Some(format!("SERVERID=2.{}", signature)));

        let selection = balancer.select(Some(&context)).await.unwrap();

        assert_eq!(selection.address, "10.0.0.1:8080");
        assert!(selection.metadata.get("sticky_session_hit").is_none());
        assert_eq!(
            selection.metadata.get("sticky_session_new"),
            Some(&"true".to_string())
        );
    }

    #[tokio::test]
    async fn test_out_of_range_index_falls_back() {
        let balancer = create_test_balancer(MockBalancer::returning("10.0.0.1:8080"));

        // Correctly signed, but there are only three targets.
        let cookie_value = balancer.generate_cookie_value(7);
        let context = create_test_context(Some(format!("SERVERID={}", cookie_value)));

        let selection = balancer.select(Some(&context)).await.unwrap();

        assert_eq!(selection.address, "10.0.0.1:8080");
        assert!(selection.metadata.get("sticky_session_hit").is_none());
        assert_eq!(
            selection.metadata.get("sticky_target_index"),
            Some(&"0".to_string())
        );
    }

    #[tokio::test]
    async fn test_unhealthy_target_falls_back() {
        // Fallback picks a different target than the one in the cookie
        let balancer = create_test_balancer(MockBalancer::returning("10.0.0.3:8080"));

        // Mark target 1 as unhealthy
        balancer.report_health("10.0.0.2:8080", false).await;

        // Generate cookie for unhealthy target 1
        let cookie_value = balancer.generate_cookie_value(1);
        let context = create_test_context(Some(format!("SERVERID={}", cookie_value)));

        let selection = balancer.select(Some(&context)).await.unwrap();

        // Should fall back to another target and set new cookie
        assert_eq!(selection.address, "10.0.0.3:8080");
        assert_eq!(
            selection.metadata.get("sticky_session_new"),
            Some(&"true".to_string())
        );
    }
}