1use std::collections::HashMap;
24use std::sync::Arc;
25use std::time::{Duration, Instant};
26
27use serde::{Deserialize, Serialize};
28use tokio::sync::RwLock;
29use uuid::Uuid;
30
/// Default lifetime of a domain → proxy binding, in seconds (five minutes).
const DEFAULT_TTL_SECS: u64 = 5 * 60;
33
34#[derive(Debug, Clone, Default, Serialize, Deserialize)]
48#[serde(rename_all = "snake_case", tag = "mode")]
49#[non_exhaustive]
50pub enum StickyPolicy {
51 #[default]
53 Disabled,
54 Domain {
56 #[serde(with = "serde_duration_secs")]
58 ttl: Duration,
59 },
60}
61
62impl StickyPolicy {
63 pub const fn domain(ttl: Duration) -> Self {
65 Self::Domain { ttl }
66 }
67
68 pub fn domain_default() -> Self {
70 Self::Domain {
71 ttl: Duration::from_secs(DEFAULT_TTL_SECS),
72 }
73 }
74
75 pub const fn is_disabled(&self) -> bool {
77 matches!(self, Self::Disabled)
78 }
79}
80
/// One key → proxy binding together with its own expiry.
#[derive(Debug, Clone)]
struct ProxySession {
    /// Proxy the key is pinned to.
    proxy_id: Uuid,
    /// When the binding was created; expiry is measured from this instant.
    bound_at: Instant,
    /// How long the binding remains valid after `bound_at`.
    ttl: Duration,
}
93
94impl ProxySession {
95 fn is_expired(&self) -> bool {
97 self.bound_at.elapsed() >= self.ttl
98 }
99}
100
101#[derive(Debug, Clone)]
120pub struct SessionMap {
121 inner: Arc<RwLock<HashMap<String, ProxySession>>>,
122}
123
124impl Default for SessionMap {
125 fn default() -> Self {
126 Self::new()
127 }
128}
129
130impl SessionMap {
131 pub fn new() -> Self {
133 Self {
134 inner: Arc::new(RwLock::new(HashMap::new())),
135 }
136 }
137
138 pub fn lookup(&self, key: &str) -> Option<Uuid> {
144 let guard = self.inner.try_read().ok()?;
146 guard
147 .get(key)
148 .filter(|s| !s.is_expired())
149 .map(|s| s.proxy_id)
150 }
151
152 pub fn bind(&self, key: &str, proxy_id: Uuid, ttl: Duration) {
155 let session = ProxySession {
156 proxy_id,
157 bound_at: Instant::now(),
158 ttl,
159 };
160 if let Ok(mut guard) = self.inner.try_write() {
161 guard.insert(key.to_string(), session);
162 }
163 }
164
165 pub fn purge_expired(&self) -> usize {
167 let Ok(mut guard) = self.inner.try_write() else {
168 return 0;
169 };
170 let before = guard.len();
171 guard.retain(|_, s| !s.is_expired());
172 before - guard.len()
173 }
174
175 pub fn unbind(&self, key: &str) {
177 if let Ok(mut guard) = self.inner.try_write() {
178 guard.remove(key);
179 }
180 }
181
182 pub fn active_count(&self) -> usize {
184 let Ok(guard) = self.inner.try_read() else {
185 return 0;
186 };
187 guard.values().filter(|s| !s.is_expired()).count()
188 }
189}
190
191mod serde_duration_secs {
194 use serde::{Deserialize, Deserializer, Serialize, Serializer};
195 use std::time::Duration;
196
197 pub fn serialize<S: Serializer>(d: &Duration, s: S) -> Result<S::Ok, S::Error> {
198 d.as_secs().serialize(s)
199 }
200
201 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
202 Ok(Duration::from_secs(u64::deserialize(d)?))
203 }
204}
205
#[cfg(test)]
// Panicking via `unwrap` is the intended failure mode inside tests.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn same_domain_returns_same_proxy() {
        let map = SessionMap::new();
        let id = Uuid::new_v4();
        map.bind("example.com", id, Duration::from_secs(60));
        assert_eq!(map.lookup("example.com"), Some(id));
        // Lookups are read-only, so a repeated lookup sees the same binding.
        assert_eq!(map.lookup("example.com"), Some(id));
    }

    #[test]
    fn different_domains_independent() {
        let map = SessionMap::new();
        let id_a = Uuid::new_v4();
        let id_b = Uuid::new_v4();
        map.bind("a.com", id_a, Duration::from_secs(60));
        map.bind("b.com", id_b, Duration::from_secs(60));
        assert_eq!(map.lookup("a.com"), Some(id_a));
        assert_eq!(map.lookup("b.com"), Some(id_b));
    }

    #[test]
    fn expired_session_returns_none() {
        let map = SessionMap::new();
        // Expiry is `elapsed >= ttl`, so a zero TTL is expired the moment it is
        // bound; no sleep is needed and the test cannot race the clock.
        map.bind("example.com", Uuid::new_v4(), Duration::ZERO);
        assert_eq!(map.lookup("example.com"), None);
    }

    #[test]
    fn purge_removes_expired() {
        let map = SessionMap::new();
        map.bind("expired.com", Uuid::new_v4(), Duration::ZERO);
        map.bind("active.com", Uuid::new_v4(), Duration::from_secs(300));

        assert_eq!(map.purge_expired(), 1);
        assert_eq!(map.active_count(), 1);
        // A second pass finds nothing left to remove.
        assert_eq!(map.purge_expired(), 0);
    }

    #[test]
    fn active_count_skips_expired_without_purge() {
        let map = SessionMap::new();
        map.bind("expired.com", Uuid::new_v4(), Duration::ZERO);
        map.bind("active.com", Uuid::new_v4(), Duration::from_secs(300));
        assert_eq!(map.active_count(), 1);
    }

    #[test]
    fn unbind_removes_session() {
        let map = SessionMap::new();
        map.bind("example.com", Uuid::new_v4(), Duration::from_secs(60));
        map.unbind("example.com");
        assert_eq!(map.lookup("example.com"), None);
    }

    #[test]
    fn unbind_missing_key_is_noop() {
        let map = SessionMap::new();
        let id = Uuid::new_v4();
        map.bind("kept.com", id, Duration::from_secs(60));
        map.unbind("missing.com");
        assert_eq!(map.lookup("kept.com"), Some(id));
    }

    #[test]
    fn clones_share_state() {
        let map = SessionMap::new();
        let clone = map.clone();
        let id = Uuid::new_v4();
        clone.bind("example.com", id, Duration::from_secs(60));
        assert_eq!(map.lookup("example.com"), Some(id));
    }

    #[test]
    fn rebind_overwrites_previous() {
        let map = SessionMap::new();
        let old_id = Uuid::new_v4();
        let new_id = Uuid::new_v4();
        map.bind("example.com", old_id, Duration::from_secs(60));
        map.bind("example.com", new_id, Duration::from_secs(60));
        assert_eq!(map.lookup("example.com"), Some(new_id));
    }

    #[test]
    fn policy_domain_default_ttl() {
        let policy = StickyPolicy::domain_default();
        match policy {
            StickyPolicy::Domain { ttl } => {
                assert_eq!(ttl, Duration::from_secs(300));
            }
            _ => panic!("expected Domain variant"),
        }
    }

    #[test]
    fn policy_disabled_by_default() {
        let policy = StickyPolicy::default();
        assert!(policy.is_disabled());
    }

    #[test]
    fn policy_serde_roundtrip() {
        let policy = StickyPolicy::domain(Duration::from_secs(120));
        let json = serde_json::to_string(&policy).unwrap();
        let back: StickyPolicy = serde_json::from_str(&json).unwrap();
        match back {
            StickyPolicy::Domain { ttl } => assert_eq!(ttl, Duration::from_secs(120)),
            _ => panic!("expected Domain variant"),
        }
    }

    #[test]
    fn policy_serde_wire_format() {
        let domain =
            serde_json::to_string(&StickyPolicy::domain(Duration::from_secs(120))).unwrap();
        assert_eq!(domain, r#"{"mode":"domain","ttl":120}"#);

        let disabled = serde_json::to_string(&StickyPolicy::Disabled).unwrap();
        assert_eq!(disabled, r#"{"mode":"disabled"}"#);

        let back: StickyPolicy = serde_json::from_str(r#"{"mode":"disabled"}"#).unwrap();
        assert!(back.is_disabled());
    }

    #[test]
    fn policy_serde_truncates_sub_second_ttl() {
        let json =
            serde_json::to_string(&StickyPolicy::domain(Duration::from_millis(1_500))).unwrap();
        assert_eq!(json, r#"{"mode":"domain","ttl":1}"#);
    }
}