1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
5#[serde(rename_all = "lowercase")]
6pub enum HaBackend {
7 Redis,
8 Etcd,
9}
10
11impl std::fmt::Display for HaBackend {
12 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
13 match self {
14 Self::Redis => write!(f, "redis"),
15 Self::Etcd => write!(f, "etcd"),
16 }
17 }
18}
19
/// Top-level high-availability configuration (leader election, distributed
/// rate limiting, audit replication) loaded from the application config.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HaConfig {
    /// Which coordination backend to use.
    pub backend: HaBackend,
    /// Backend endpoints, e.g. "redis://host:6379" or "http://etcd:2379".
    pub endpoints: Vec<String>,
    /// Optional password for the backend; `None` means no authentication.
    pub password: Option<String>,
    /// Enable TLS for backend connections (defaults to `false`).
    #[serde(default)]
    pub tls: bool,
    /// Key namespace prepended to every HA key; defaults to "audex:".
    #[serde(default = "default_prefix")]
    pub prefix: String,
    /// Leader-election settings.
    #[serde(default)]
    pub leader: LeaderConfig,
    /// Distributed rate-limiting settings.
    #[serde(default)]
    pub rate_limit: DistributedRateLimitConfig,
    /// Audit-event replication settings.
    #[serde(default)]
    pub audit_replication: AuditReplicationConfig,
}
47
/// Serde default for `HaConfig::prefix`: the "audex:" key namespace.
fn default_prefix() -> String {
    String::from("audex:")
}
51
52impl Default for HaConfig {
53 fn default() -> Self {
54 Self {
55 backend: HaBackend::Redis,
56 endpoints: vec!["redis://localhost:6379".to_string()],
57 password: None,
58 tls: false,
59 prefix: default_prefix(),
60 leader: LeaderConfig::default(),
61 rate_limit: DistributedRateLimitConfig::default(),
62 audit_replication: AuditReplicationConfig::default(),
63 }
64 }
65}
66
/// Leader-election settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LeaderConfig {
    /// Optional explicit node identifier; when `None`, one is presumably
    /// generated by the caller (see [`generate_node_id`]) — confirm at call site.
    pub node_id: Option<String>,
    /// Leader lease time-to-live in seconds (default 15).
    #[serde(default = "default_lease_ttl")]
    pub lease_ttl: u64,
    /// How often the leader renews its lease, in seconds (default 5).
    #[serde(default = "default_renewal_interval")]
    pub renewal_interval: u64,
}
79
/// Serde default for `LeaderConfig::lease_ttl` (seconds).
fn default_lease_ttl() -> u64 {
    15
}
83
/// Serde default for `LeaderConfig::renewal_interval` (seconds).
fn default_renewal_interval() -> u64 {
    5
}
87
88impl Default for LeaderConfig {
89 fn default() -> Self {
90 Self {
91 node_id: None,
92 lease_ttl: default_lease_ttl(),
93 renewal_interval: default_renewal_interval(),
94 }
95 }
96}
97
/// Distributed rate-limiting settings (backed by a shared backend so all
/// nodes count against the same window).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedRateLimitConfig {
    /// Whether distributed rate limiting is active (default true).
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Sliding-window length in seconds (default 3600).
    #[serde(default = "default_window")]
    pub window_seconds: u64,
    /// TTL applied to rate-limit keys, in seconds (default 7200).
    #[serde(default = "default_expiry")]
    pub expiry_seconds: u64,
}
111
/// Serde default helper for boolean fields that default to `true`.
fn default_true() -> bool {
    true
}
115
/// Serde default for `DistributedRateLimitConfig::window_seconds` (1 hour).
fn default_window() -> u64 {
    3600
}
119
/// Serde default for `DistributedRateLimitConfig::expiry_seconds` (2 hours).
fn default_expiry() -> u64 {
    7200
}
123
124impl Default for DistributedRateLimitConfig {
125 fn default() -> Self {
126 Self {
127 enabled: true,
128 window_seconds: default_window(),
129 expiry_seconds: default_expiry(),
130 }
131 }
132}
133
/// Audit-event replication settings (events published to a backend stream).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditReplicationConfig {
    /// Whether audit replication is active (default true).
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Optional explicit stream name; `None` presumably falls back to the
    /// prefixed default (see `HaKeys::audit_stream`) — confirm at call site.
    pub stream: Option<String>,
    /// Maximum stream length before trimming (default 10000).
    #[serde(default = "default_max_stream")]
    pub max_stream_length: u64,
    /// Optional consumer-group name for reading the stream.
    pub consumer_group: Option<String>,
}
148
/// Serde default for `AuditReplicationConfig::max_stream_length`.
fn default_max_stream() -> u64 {
    10000
}
152
153impl Default for AuditReplicationConfig {
154 fn default() -> Self {
155 Self {
156 enabled: true,
157 stream: None,
158 max_stream_length: default_max_stream(),
159 consumer_group: None,
160 }
161 }
162}
163
/// Snapshot of this node's view of the leader election.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LeaderState {
    /// This node's identifier.
    pub node_id: String,
    /// Whether this node currently holds the leader lock.
    pub is_leader: bool,
    /// Identifier of the current leader, if known.
    pub leader_node: Option<String>,
    /// Lease expiry as a string timestamp (format set by the producer —
    /// not visible here; presumably RFC 3339).
    pub lease_expires_at: Option<String>,
    /// Last heartbeat as a string timestamp (same caveat as above).
    pub last_heartbeat: Option<String>,
}
173
/// Builds namespaced backend keys from a configured prefix so that every
/// HA-related key lives under the same namespace.
pub struct HaKeys {
    prefix: String,
}

impl HaKeys {
    /// Creates a key builder that prepends `prefix` to every key it produces.
    pub fn new(prefix: &str) -> Self {
        Self {
            prefix: prefix.to_string(),
        }
    }

    /// Joins the configured prefix with a key suffix.
    fn key(&self, suffix: &str) -> String {
        format!("{}{}", self.prefix, suffix)
    }

    /// Key holding the leader lock.
    pub fn leader_lock(&self) -> String {
        self.key("leader:lock")
    }

    /// Key holding the leader heartbeat.
    pub fn leader_heartbeat(&self) -> String {
        self.key("leader:heartbeat")
    }

    /// Per-identity rate-limit key.
    pub fn rate_limit(&self, identity: &str) -> String {
        self.key(&format!("ratelimit:{}", identity))
    }

    /// Per-identity rate-limit window key.
    pub fn rate_limit_window(&self, identity: &str) -> String {
        self.key(&format!("ratelimit:window:{}", identity))
    }

    /// Key of the audit replication stream.
    pub fn audit_stream(&self) -> String {
        self.key("audit:stream")
    }

    /// Key of the cluster-wide active-session registry.
    pub fn session_registry(&self) -> String {
        self.key("sessions:active")
    }

    /// Key of the active-node registry.
    pub fn node_registry(&self) -> String {
        self.key("nodes:active")
    }

    /// Key listing the sessions owned by a specific node.
    pub fn node_sessions(&self, node_id: &str) -> String {
        self.key(&format!("nodes:{}:sessions", node_id))
    }
}
226
/// Builders for the Redis commands and Lua scripts behind leader election.
pub struct LeaderElectionCommands;

impl LeaderElectionCommands {
    /// `SET key node_id NX EX ttl_secs` — acquires the leader lock only if
    /// nobody holds it, with an automatic expiry as a crash safety net.
    pub fn try_acquire(key: &str, node_id: &str, ttl_secs: u64) -> Vec<String> {
        let ttl = ttl_secs.to_string();
        ["SET", key, node_id, "NX", "EX", ttl.as_str()]
            .iter()
            .map(|part| part.to_string())
            .collect()
    }

    /// Lua script: extend the lock's TTL only if we still own it
    /// (compare-and-expire, atomic on the server).
    pub fn renew_lua() -> &'static str {
        r#"
if redis.call("GET", KEYS[1]) == ARGV[1] then
    return redis.call("PEXPIRE", KEYS[1], ARGV[2])
else
    return 0
end
"#
    }

    /// Lua script: delete the lock only if we still own it, so a node never
    /// releases a lock that has since passed to someone else.
    pub fn release_lua() -> &'static str {
        r#"
if redis.call("GET", KEYS[1]) == ARGV[1] then
    return redis.call("DEL", KEYS[1])
else
    return 0
end
"#
    }

    /// `GET key` — read the current lock holder, if any.
    pub fn get_leader(key: &str) -> Vec<String> {
        vec![String::from("GET"), key.to_owned()]
    }
}
272
/// Builders for the Redis commands and Lua script behind distributed
/// sliding-window rate limiting (sorted set keyed by timestamp).
pub struct DistributedRateLimitCommands;

impl DistributedRateLimitCommands {
    /// Lua script that atomically prunes expired entries, checks the window
    /// count against the limit, and records the new request when allowed.
    /// Returns `{allowed (0/1), count}`.
    pub fn check_and_record_lua() -> &'static str {
        r#"
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local expiry = tonumber(ARGV[4])
local member = ARGV[5]

-- Remove expired entries
redis.call("ZREMRANGEBYSCORE", key, 0, now - window)

-- Count current entries in window
local count = redis.call("ZCARD", key)

if count >= limit then
    return {0, count}
end

-- Add new entry
redis.call("ZADD", key, now, member)
redis.call("EXPIRE", key, expiry)

return {1, count + 1}
"#
    }

    /// `ZCOUNT key window_start +inf` — count of entries recorded at or
    /// after `window_start`.
    pub fn get_count(key: &str, window_start: u64) -> Vec<String> {
        let mut cmd: Vec<String> = Vec::with_capacity(4);
        cmd.push("ZCOUNT".to_string());
        cmd.push(key.to_string());
        cmd.push(window_start.to_string());
        cmd.push("+inf".to_string());
        cmd
    }
}
316
/// Builders for the Redis Stream commands behind audit replication.
pub struct AuditReplicationCommands;

impl AuditReplicationCommands {
    /// Field list for publishing one audit entry: a single "data" field
    /// holding the serialized JSON.
    pub fn publish_fields(entry_json: &str) -> Vec<(&str, &str)> {
        vec![("data", entry_json)]
    }

    /// `XGROUP CREATE stream group 0 MKSTREAM` — create the consumer group,
    /// creating the stream as well if it does not exist yet.
    pub fn create_group(stream_key: &str, group: &str) -> Vec<String> {
        ["XGROUP", "CREATE", stream_key, group, "0", "MKSTREAM"]
            .iter()
            .map(|part| part.to_string())
            .collect()
    }

    /// `XREADGROUP GROUP group consumer COUNT count STREAMS stream >` —
    /// read up to `count` entries not yet delivered to this group.
    pub fn read_new(stream_key: &str, group: &str, consumer: &str, count: u64) -> Vec<String> {
        let count_arg = count.to_string();
        let parts = [
            "XREADGROUP",
            "GROUP",
            group,
            consumer,
            "COUNT",
            count_arg.as_str(),
            "STREAMS",
            stream_key,
            ">",
        ];
        parts.iter().map(|part| part.to_string()).collect()
    }

    /// `XACK stream group id...` — acknowledge processed entries.
    pub fn ack(stream_key: &str, group: &str, ids: &[&str]) -> Vec<String> {
        ["XACK", stream_key, group]
            .iter()
            .copied()
            .chain(ids.iter().copied())
            .map(|part| part.to_string())
            .collect()
    }
}
367
368pub struct EtcdOperations;
370
371impl EtcdOperations {
372 pub fn acquire_lock_endpoint() -> &'static str {
375 "/v3/kv/put"
376 }
377
378 pub fn grant_lease_endpoint() -> &'static str {
380 "/v3/lease/grant"
381 }
382
383 pub fn keepalive_endpoint() -> &'static str {
385 "/v3/lease/keepalive"
386 }
387
388 pub fn watch_endpoint() -> &'static str {
390 "/v3/watch"
391 }
392
393 pub fn lease_grant_body(ttl: u64) -> String {
395 format!("{{\"TTL\":{}}}", ttl)
396 }
397
398 pub fn txn_create_body(key: &str, value: &str, lease_id: u64) -> String {
400 let key_b64 = base64_encode(key);
401 let value_b64 = base64_encode(value);
402 format!(
403 r#"{{"compare":[{{"result":"EQUAL","target":"CREATE","key":"{}"}}],"success":[{{"request_put":{{"key":"{}","value":"{}","lease":{}}}}}],"failure":[]}}"#,
404 key_b64, key_b64, value_b64, lease_id
405 )
406 }
407}
408
/// Encodes `input` as standard (RFC 4648) base64 with `=` padding.
///
/// Rewritten as a direct chunked encoder: the previous version routed the
/// bytes through an internal `std::io::Write` adapter whose result was
/// silently discarded with `.ok()` and whose final padding depended on the
/// adapter's `Drop` running — correct, but needlessly indirect and fragile
/// for an infallible in-memory transform. Output is unchanged.
fn base64_encode(input: &str) -> String {
    // Standard (non-URL-safe) base64 alphabet.
    const TABLE: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let bytes = input.as_bytes();
    // Every 3 input bytes expand to exactly 4 output characters.
    let mut out = String::with_capacity((bytes.len() + 2) / 3 * 4);
    for group in bytes.chunks(3) {
        let b0 = group[0];
        // Missing bytes in a short final group are treated as zero bits.
        let b1 = group.get(1).copied().unwrap_or(0);
        let b2 = group.get(2).copied().unwrap_or(0);
        out.push(TABLE[(b0 >> 2) as usize] as char);
        out.push(TABLE[(((b0 & 0x03) << 4) | (b1 >> 4)) as usize] as char);
        // Positions 3 and 4 become '=' padding when the group is short.
        out.push(if group.len() > 1 {
            TABLE[(((b1 & 0x0f) << 2) | (b2 >> 6)) as usize] as char
        } else {
            '='
        });
        out.push(if group.len() > 2 {
            TABLE[(b2 & 0x3f) as usize] as char
        } else {
            '='
        });
    }
    out
}
419
/// Streaming base64 encoder that appends encoded ASCII to a borrowed buffer.
/// Accumulates input in 3-byte groups and emits 4 output characters per group.
struct Base64Writer<'a> {
    // Destination for the encoded output.
    out: &'a mut Vec<u8>,
    // Pending input bytes not yet encoded (at most one full group).
    buf: [u8; 3],
    // Number of valid bytes currently in `buf` (0..=3).
    len: usize,
}
426
/// Standard (non-URL-safe) base64 alphabet, RFC 4648.
const B64_CHARS: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
428
impl<'a> Base64Writer<'a> {
    /// Wraps `out` with an empty pending group.
    fn new(out: &'a mut Vec<u8>) -> Self {
        Self {
            out,
            buf: [0; 3],
            len: 0,
        }
    }

    /// Encodes the pending group (`len` bytes of `buf`) into 4 base64 chars,
    /// padding with '=' when the group is short, then resets the buffer.
    /// No-op when the buffer is empty, so repeated calls are safe (e.g. an
    /// explicit `flush` followed by `Drop`).
    fn flush_buf(&mut self) {
        if self.len == 0 {
            return;
        }
        let b = self.buf;
        // Char 1: top 6 bits of byte 0.
        self.out.push(B64_CHARS[(b[0] >> 2) as usize]);
        // Char 2: low 2 bits of byte 0 + top 4 bits of byte 1
        // (byte 1 is still zero when the group holds only one byte).
        self.out
            .push(B64_CHARS[((b[0] & 0x03) << 4 | b[1] >> 4) as usize]);
        if self.len > 1 {
            // Char 3: low 4 bits of byte 1 + top 2 bits of byte 2.
            self.out
                .push(B64_CHARS[((b[1] & 0x0f) << 2 | b[2] >> 6) as usize]);
        } else {
            self.out.push(b'=');
        }
        if self.len > 2 {
            // Char 4: low 6 bits of byte 2.
            self.out.push(B64_CHARS[(b[2] & 0x3f) as usize]);
        } else {
            self.out.push(b'=');
        }
        self.buf = [0; 3];
        self.len = 0;
    }
}
461
impl<'a> std::io::Write for Base64Writer<'a> {
    /// Buffers input bytes, emitting one 4-char output group for every
    /// 3 bytes received. Always reports the full slice as written and
    /// never returns an error.
    fn write(&mut self, data: &[u8]) -> std::io::Result<usize> {
        for &byte in data {
            self.buf[self.len] = byte;
            self.len += 1;
            // A complete 3-byte group is encoded immediately.
            if self.len == 3 {
                self.flush_buf();
            }
        }
        Ok(data.len())
    }

    /// Encodes any partial trailing group (with '=' padding); infallible.
    fn flush(&mut self) -> std::io::Result<()> {
        self.flush_buf();
        Ok(())
    }
}
479
impl<'a> Drop for Base64Writer<'a> {
    // Safety net: encode any trailing partial group even if the caller never
    // called `flush` (flush_buf is a no-op when nothing is pending).
    fn drop(&mut self) {
        self.flush_buf();
    }
}
485
/// Builds a best-effort unique node identifier of the form
/// `<hostname>-<pid>-<millis % 100000>`.
///
/// The hostname is read from `/etc/hostname` and falls back to "unknown"
/// when that file is unavailable (e.g. non-Linux hosts).
pub fn generate_node_id() -> String {
    let hostname = match std::fs::read_to_string("/etc/hostname") {
        Ok(name) => name.trim().to_string(),
        Err(_) => "unknown".to_string(),
    };
    let millis = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis(),
        Err(_) => 0,
    };
    format!("{}-{}-{}", hostname, std::process::id(), millis % 100000)
}
498
/// Cluster-wide status report as seen from one node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterStatus {
    /// Identifier of the reporting node.
    pub node_id: String,
    /// Whether the reporting node is the leader.
    pub is_leader: bool,
    /// Identifier of the current leader, if known.
    pub leader_node: Option<String>,
    /// All nodes currently registered as active.
    pub active_nodes: Vec<NodeInfo>,
    /// Total sessions across the cluster.
    pub total_sessions: u64,
    /// Backend name, e.g. "redis" or "etcd" (matches `HaBackend`'s Display).
    pub backend: String,
}
509
/// Per-node entry in a [`ClusterStatus`] report.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// The node's identifier.
    pub node_id: String,
    /// The node's network address (host:port).
    pub address: String,
    /// Number of sessions owned by this node.
    pub sessions: u64,
    /// Last-seen time as a string timestamp (format set by the producer —
    /// not visible here; presumably RFC 3339).
    pub last_seen: String,
}
518
#[cfg(test)]
mod tests {
    use super::*;

    // Default config: single-node Redis with the stock "audex:" prefix.
    #[test]
    fn test_ha_config_default() {
        let config = HaConfig::default();
        assert_eq!(config.backend, HaBackend::Redis);
        assert_eq!(config.endpoints.len(), 1);
        assert_eq!(config.prefix, "audex:");
        assert!(!config.tls);
    }

    // TOML deserialization with every section overridden (Redis backend).
    #[test]
    fn test_ha_config_deserialize_redis() {
        let toml_str = r#"
backend = "redis"
endpoints = ["redis://redis-1:6379", "redis://redis-2:6379"]
password = "secret"
tls = true
prefix = "myapp:audex:"

[leader]
node_id = "server-1"
lease_ttl = 30
renewal_interval = 10

[rate_limit]
enabled = true
window_seconds = 1800

[audit_replication]
enabled = true
max_stream_length = 50000
"#;
        let config: HaConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.backend, HaBackend::Redis);
        assert_eq!(config.endpoints.len(), 2);
        assert_eq!(config.password.as_deref(), Some("secret"));
        assert!(config.tls);
        assert_eq!(config.leader.lease_ttl, 30);
        assert_eq!(config.rate_limit.window_seconds, 1800);
        assert_eq!(config.audit_replication.max_stream_length, 50000);
    }

    // Minimal etcd config: omitted fields must take their serde defaults.
    #[test]
    fn test_ha_config_deserialize_etcd() {
        let toml_str = r#"
backend = "etcd"
endpoints = ["http://etcd-1:2379", "http://etcd-2:2379", "http://etcd-3:2379"]

[leader]
lease_ttl = 10
"#;
        let config: HaConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.backend, HaBackend::Etcd);
        assert_eq!(config.endpoints.len(), 3);
        assert_eq!(config.leader.lease_ttl, 10);
    }

    // Key builder output with the default prefix.
    #[test]
    fn test_ha_keys() {
        let keys = HaKeys::new("audex:");
        assert_eq!(keys.leader_lock(), "audex:leader:lock");
        assert_eq!(keys.leader_heartbeat(), "audex:leader:heartbeat");
        assert_eq!(
            keys.rate_limit("user@corp.com"),
            "audex:ratelimit:user@corp.com"
        );
        assert_eq!(keys.audit_stream(), "audex:audit:stream");
        assert_eq!(keys.session_registry(), "audex:sessions:active");
        assert_eq!(keys.node_registry(), "audex:nodes:active");
    }

    // Key builder output with a non-default prefix.
    #[test]
    fn test_ha_keys_custom_prefix() {
        let keys = HaKeys::new("prod:audex:");
        assert_eq!(keys.leader_lock(), "prod:audex:leader:lock");
        assert_eq!(keys.rate_limit("alice"), "prod:audex:ratelimit:alice");
    }

    // SET NX EX argument layout for lock acquisition.
    #[test]
    fn test_leader_election_commands() {
        let cmd = LeaderElectionCommands::try_acquire("audex:leader:lock", "node-1", 15);
        assert_eq!(cmd[0], "SET");
        assert_eq!(cmd[2], "node-1");
        assert_eq!(cmd[3], "NX");
        assert_eq!(cmd[4], "EX");
        assert_eq!(cmd[5], "15");
    }

    // Lua scripts must carry the expected compare-and-act commands.
    #[test]
    fn test_leader_election_lua_scripts() {
        let renew = LeaderElectionCommands::renew_lua();
        assert!(renew.contains("PEXPIRE"));
        assert!(renew.contains("GET"));

        let release = LeaderElectionCommands::release_lua();
        assert!(release.contains("DEL"));
    }

    // Sliding-window script must prune, count, record and set expiry.
    #[test]
    fn test_rate_limit_lua_script() {
        let lua = DistributedRateLimitCommands::check_and_record_lua();
        assert!(lua.contains("ZREMRANGEBYSCORE"));
        assert!(lua.contains("ZADD"));
        assert!(lua.contains("ZCARD"));
        assert!(lua.contains("EXPIRE"));
    }

    // Stream publish/create-group/ack argument layouts.
    #[test]
    fn test_audit_replication_commands() {
        let fields = AuditReplicationCommands::publish_fields("{\"event\":\"test\"}");
        assert_eq!(fields[0].0, "data");

        let group = AuditReplicationCommands::create_group("audex:audit:stream", "node-1-group");
        assert_eq!(group[0], "XGROUP");
        assert_eq!(group[5], "MKSTREAM");

        let ack = AuditReplicationCommands::ack("stream", "group", &["1-1", "1-2"]);
        assert_eq!(ack[0], "XACK");
        assert_eq!(ack.len(), 5);
    }

    // etcd JSON request-body builders.
    #[test]
    fn test_etcd_operations() {
        let body = EtcdOperations::lease_grant_body(15);
        assert!(body.contains("\"TTL\":15"));

        let txn = EtcdOperations::txn_create_body("mykey", "myval", 12345);
        assert!(txn.contains("request_put"));
        assert!(txn.contains("12345"));
    }

    // Padding behavior: 1 and 2 trailing bytes pad with '='; 3 bytes don't.
    #[test]
    fn test_base64_encode() {
        assert_eq!(base64_encode("hello"), "aGVsbG8=");
        assert_eq!(base64_encode("ab"), "YWI=");
        assert_eq!(base64_encode("abc"), "YWJj");
    }

    // Generated id is non-empty and uses '-' separators.
    #[test]
    fn test_generate_node_id() {
        let id = generate_node_id();
        assert!(!id.is_empty());
        assert!(id.contains('-'));
    }

    // ClusterStatus must serialize to JSON with its field names intact.
    #[test]
    fn test_cluster_status_serialization() {
        let status = ClusterStatus {
            node_id: "node-1".to_string(),
            is_leader: true,
            leader_node: Some("node-1".to_string()),
            active_nodes: vec![NodeInfo {
                node_id: "node-1".to_string(),
                address: "10.0.0.1:8080".to_string(),
                sessions: 5,
                last_seen: "2026-01-01T00:00:00Z".to_string(),
            }],
            total_sessions: 5,
            backend: "redis".to_string(),
        };
        let json = serde_json::to_string(&status).unwrap();
        assert!(json.contains("node-1"));
        assert!(json.contains("is_leader"));
    }

    // Display must match the lowercase serde representation.
    #[test]
    fn test_ha_backend_display() {
        assert_eq!(HaBackend::Redis.to_string(), "redis");
        assert_eq!(HaBackend::Etcd.to_string(), "etcd");
    }
}