1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
5#[serde(rename_all = "lowercase")]
6pub enum HaBackend {
7 Redis,
8 Etcd,
9}
10
11impl std::fmt::Display for HaBackend {
12 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
13 match self {
14 Self::Redis => write!(f, "redis"),
15 Self::Etcd => write!(f, "etcd"),
16 }
17 }
18}
19
/// Top-level high-availability configuration: the shared-state backend plus
/// the sub-systems built on it (leader election, distributed rate limiting,
/// audit replication).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HaConfig {
    /// Which coordination backend to use.
    pub backend: HaBackend,
    /// Backend endpoints, e.g. "redis://host:6379" or "http://host:2379".
    pub endpoints: Vec<String>,
    /// Optional backend password. A non-empty AUDEX_HA_PASSWORD environment
    /// variable takes precedence (see `resolve_password`).
    pub password: Option<String>,
    /// Enable TLS to the backend (defaults to false).
    #[serde(default)]
    pub tls: bool,
    /// Prefix for every key written to the backend (defaults to "audex:").
    #[serde(default = "default_prefix")]
    pub prefix: String,
    /// Leader-election tuning.
    #[serde(default)]
    pub leader: LeaderConfig,
    /// Distributed rate-limit settings.
    #[serde(default)]
    pub rate_limit: DistributedRateLimitConfig,
    /// Audit-replication settings.
    #[serde(default)]
    pub audit_replication: AuditReplicationConfig,
}
47
/// Serde default for `HaConfig::prefix`.
fn default_prefix() -> String {
    String::from("audex:")
}
51
52impl HaConfig {
53 pub fn resolve_password(&self) -> Option<String> {
63 std::env::var("AUDEX_HA_PASSWORD")
64 .ok()
65 .filter(|v| !v.is_empty())
66 .or_else(|| self.password.clone())
67 }
68}
69
70impl Default for HaConfig {
71 fn default() -> Self {
72 Self {
73 backend: HaBackend::Redis,
74 endpoints: vec!["redis://localhost:6379".to_string()],
75 password: None,
76 tls: false,
77 prefix: default_prefix(),
78 leader: LeaderConfig::default(),
79 rate_limit: DistributedRateLimitConfig::default(),
80 audit_replication: AuditReplicationConfig::default(),
81 }
82 }
83}
84
/// Leader-election tuning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LeaderConfig {
    /// Optional explicit node identifier; when `None`, one is presumably
    /// generated (see `generate_node_id` in this module) — confirm at the
    /// call site that wires this up.
    pub node_id: Option<String>,
    /// Leader lease TTL in seconds (default 15).
    #[serde(default = "default_lease_ttl")]
    pub lease_ttl: u64,
    /// Lease renewal period in seconds (default 5). Must be strictly less
    /// than `lease_ttl / 3`; enforced by `validate`.
    #[serde(default = "default_renewal_interval")]
    pub renewal_interval: u64,
}
97
/// Serde default for `LeaderConfig::lease_ttl` (seconds).
fn default_lease_ttl() -> u64 { 15 }
101
/// Serde default for `LeaderConfig::renewal_interval` (seconds).
fn default_renewal_interval() -> u64 { 5 }
105
106impl LeaderConfig {
107 pub fn validate(&self) -> crate::error::Result<()> {
113 if self.lease_ttl == 0 {
114 return Err(crate::error::AvError::InvalidPolicy(
115 "HA leader lease_ttl must be > 0".to_string(),
116 ));
117 }
118 if self.renewal_interval == 0 {
119 return Err(crate::error::AvError::InvalidPolicy(
120 "HA leader renewal_interval must be > 0".to_string(),
121 ));
122 }
123 let triple = self.renewal_interval.checked_mul(3).ok_or_else(|| {
131 crate::error::AvError::InvalidPolicy(format!(
132 "HA leader renewal_interval ({}) overflows when multiplied by 3",
133 self.renewal_interval
134 ))
135 })?;
136 if triple >= self.lease_ttl {
137 return Err(crate::error::AvError::InvalidPolicy(format!(
138 "HA leader renewal_interval ({}) must be < lease_ttl/3 ({}/3 = {}). \
139 Otherwise network hiccups can cause lease loss and split-brain.",
140 self.renewal_interval,
141 self.lease_ttl,
142 self.lease_ttl / 3
143 )));
144 }
145 Ok(())
146 }
147}
148
149impl Default for LeaderConfig {
150 fn default() -> Self {
151 Self {
152 node_id: None,
153 lease_ttl: default_lease_ttl(),
154 renewal_interval: default_renewal_interval(),
155 }
156 }
157}
158
/// Distributed (cross-node) rate-limiting settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedRateLimitConfig {
    /// Whether distributed rate limiting is active (default true).
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Sliding-window length in seconds (default 3600). Entries older than
    /// this are pruned by the check-and-record Lua script.
    #[serde(default = "default_window")]
    pub window_seconds: u64,
    /// Backend key expiry in seconds (default 7200); applied via EXPIRE in
    /// the Lua script so idle keys are garbage-collected.
    #[serde(default = "default_expiry")]
    pub expiry_seconds: u64,
}
172
/// Serde default helper for boolean fields that default to enabled.
fn default_true() -> bool { true }
176
/// Serde default for `DistributedRateLimitConfig::window_seconds` (1 hour).
fn default_window() -> u64 {
    60 * 60
}
180
/// Serde default for `DistributedRateLimitConfig::expiry_seconds` (2 hours).
fn default_expiry() -> u64 {
    2 * 60 * 60
}
184
185impl Default for DistributedRateLimitConfig {
186 fn default() -> Self {
187 Self {
188 enabled: true,
189 window_seconds: default_window(),
190 expiry_seconds: default_expiry(),
191 }
192 }
193}
194
/// Audit-event replication over a backend stream.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditReplicationConfig {
    /// Whether audit replication is active (default true).
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Stream key override; when `None`, a default key is presumably derived
    /// elsewhere (e.g. `HaKeys::audit_stream`) — confirm at the call site.
    pub stream: Option<String>,
    /// Maximum stream length (default 10000).
    #[serde(default = "default_max_stream")]
    pub max_stream_length: u64,
    /// Optional consumer-group name for reading the stream.
    pub consumer_group: Option<String>,
}
209
/// Serde default for `AuditReplicationConfig::max_stream_length`.
fn default_max_stream() -> u64 {
    10_000
}
213
214impl Default for AuditReplicationConfig {
215 fn default() -> Self {
216 Self {
217 enabled: true,
218 stream: None,
219 max_stream_length: default_max_stream(),
220 consumer_group: None,
221 }
222 }
223}
224
/// Snapshot of a node's leader-election state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LeaderState {
    /// Identifier of the node this state describes.
    pub node_id: String,
    /// True when this node currently holds the leader lock.
    pub is_leader: bool,
    /// Identifier of the current leader, if known.
    pub leader_node: Option<String>,
    /// Lease expiry timestamp (string-encoded; exact format not enforced here).
    pub lease_expires_at: Option<String>,
    /// Last heartbeat timestamp (string-encoded).
    pub last_heartbeat: Option<String>,
}
234
/// Factory for namespaced backend keys. The prefix is sanitized once in
/// `HaKeys::new`; per-call components are sanitized by each key method.
pub struct HaKeys {
    // Sanitized key prefix, e.g. "audex:".
    prefix: String,
}
239
240impl HaKeys {
241 pub fn new(prefix: &str) -> Self {
242 let sanitized: String = prefix
246 .chars()
247 .filter(|c| {
248 c.is_ascii_alphanumeric() || *c == ':' || *c == '-' || *c == '_' || *c == '.'
249 })
250 .collect();
251 if sanitized != prefix {
252 tracing::warn!(
253 original = prefix,
254 sanitized = %sanitized,
255 "HA key prefix contained invalid characters and was sanitized"
256 );
257 }
258 Self { prefix: sanitized }
259 }
260
261 fn sanitize_component(s: &str) -> String {
264 s.chars()
265 .filter(|c| {
266 c.is_ascii_alphanumeric()
267 || *c == ':'
268 || *c == '-'
269 || *c == '_'
270 || *c == '.'
271 || *c == '@'
272 })
273 .collect()
274 }
275
276 pub fn leader_lock(&self) -> String {
278 format!("{}leader:lock", self.prefix)
279 }
280
281 pub fn leader_heartbeat(&self) -> String {
283 format!("{}leader:heartbeat", self.prefix)
284 }
285
286 pub fn rate_limit(&self, identity: &str) -> String {
288 format!(
289 "{}ratelimit:{}",
290 self.prefix,
291 Self::sanitize_component(identity)
292 )
293 }
294
295 pub fn rate_limit_window(&self, identity: &str) -> String {
297 format!(
298 "{}ratelimit:window:{}",
299 self.prefix,
300 Self::sanitize_component(identity)
301 )
302 }
303
304 pub fn audit_stream(&self) -> String {
306 format!("{}audit:stream", self.prefix)
307 }
308
309 pub fn session_registry(&self) -> String {
311 format!("{}sessions:active", self.prefix)
312 }
313
314 pub fn node_registry(&self) -> String {
316 format!("{}nodes:active", self.prefix)
317 }
318
319 pub fn node_sessions(&self, node_id: &str) -> String {
321 format!(
322 "{}nodes:{}:sessions",
323 self.prefix,
324 Self::sanitize_component(node_id)
325 )
326 }
327}
328
/// Builders for the Redis commands and Lua scripts used in leader election.
pub struct LeaderElectionCommands;

impl LeaderElectionCommands {
    /// `SET key node_id NX EX ttl` — acquires the lock only when it is free,
    /// with an automatic expiry so a dead leader releases it.
    pub fn try_acquire(key: &str, node_id: &str, ttl_secs: u64) -> Vec<String> {
        let ttl = ttl_secs.to_string();
        ["SET", key, node_id, "NX", "EX", ttl.as_str()]
            .iter()
            .map(|s| s.to_string())
            .collect()
    }

    /// Lua: extend the lock's TTL, but only when we still own it
    /// (compare-and-expire, atomic on the server).
    pub fn renew_lua() -> &'static str {
        r#"
if redis.call("GET", KEYS[1]) == ARGV[1] then
    return redis.call("EXPIRE", KEYS[1], ARGV[2])
else
    return 0
end
"#
    }

    /// Lua: delete the lock, but only when we still own it
    /// (compare-and-delete, atomic on the server).
    pub fn release_lua() -> &'static str {
        r#"
if redis.call("GET", KEYS[1]) == ARGV[1] then
    return redis.call("DEL", KEYS[1])
else
    return 0
end
"#
    }

    /// `GET key` — reads the current leader's node id, if any.
    pub fn get_leader(key: &str) -> Vec<String> {
        vec![String::from("GET"), String::from(key)]
    }
}
375
/// Builders for the Redis commands and Lua script backing the distributed
/// sliding-window rate limiter (a sorted set of timestamped members).
pub struct DistributedRateLimitCommands;

impl DistributedRateLimitCommands {
    /// Lua: atomically prune entries older than the window, reject when the
    /// window already holds `limit` entries, otherwise record the new entry
    /// and refresh the key's expiry. Returns `{allowed(0/1), count}`.
    pub fn check_and_record_lua() -> &'static str {
        r#"
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local expiry = tonumber(ARGV[4])
local member = ARGV[5]

-- Remove expired entries
redis.call("ZREMRANGEBYSCORE", key, 0, now - window)

-- Count current entries in window
local count = redis.call("ZCARD", key)

if count >= limit then
    return {0, count}
end

-- Add new entry
redis.call("ZADD", key, now, member)
redis.call("EXPIRE", key, expiry)

return {1, count + 1}
"#
    }

    /// `ZCOUNT key window_start +inf` — counts entries in the current window
    /// without mutating anything.
    pub fn get_count(key: &str, window_start: u64) -> Vec<String> {
        let start = window_start.to_string();
        ["ZCOUNT", key, start.as_str(), "+inf"]
            .iter()
            .map(|s| s.to_string())
            .collect()
    }
}
419
/// Builders for the Redis stream commands used for audit replication.
pub struct AuditReplicationCommands;

impl AuditReplicationCommands {
    /// Field/value pairs published per audit entry: the JSON payload under
    /// the single field "data".
    pub fn publish_fields(entry_json: &str) -> Vec<(&str, &str)> {
        vec![("data", entry_json)]
    }

    /// `XGROUP CREATE stream group 0 MKSTREAM` — create the consumer group,
    /// creating the stream too if it does not exist yet.
    pub fn create_group(stream_key: &str, group: &str) -> Vec<String> {
        ["XGROUP", "CREATE", stream_key, group, "0", "MKSTREAM"]
            .iter()
            .map(|s| s.to_string())
            .collect()
    }

    /// `XREADGROUP GROUP group consumer COUNT n STREAMS stream >` — read up
    /// to `count` entries never delivered to this group.
    pub fn read_new(stream_key: &str, group: &str, consumer: &str, count: u64) -> Vec<String> {
        let count = count.to_string();
        [
            "XREADGROUP",
            "GROUP",
            group,
            consumer,
            "COUNT",
            count.as_str(),
            "STREAMS",
            stream_key,
            ">",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect()
    }

    /// `XACK stream group id...` — acknowledge processed entries.
    pub fn ack(stream_key: &str, group: &str, ids: &[&str]) -> Vec<String> {
        ["XACK", stream_key, group]
            .iter()
            .copied()
            .chain(ids.iter().copied())
            .map(str::to_string)
            .collect()
    }
}
470
471pub struct EtcdOperations;
473
474impl EtcdOperations {
475 pub fn acquire_lock_endpoint() -> &'static str {
478 "/v3/kv/put"
479 }
480
481 pub fn grant_lease_endpoint() -> &'static str {
483 "/v3/lease/grant"
484 }
485
486 pub fn keepalive_endpoint() -> &'static str {
488 "/v3/lease/keepalive"
489 }
490
491 pub fn watch_endpoint() -> &'static str {
493 "/v3/watch"
494 }
495
496 pub fn lease_grant_body(ttl: u64) -> String {
498 format!("{{\"TTL\":{}}}", ttl)
499 }
500
501 pub fn txn_create_body(key: &str, value: &str, lease_id: u64) -> String {
503 let key_b64 = base64_encode(key);
504 let value_b64 = base64_encode(value);
505 format!(
506 r#"{{"compare":[{{"result":"EQUAL","target":"CREATE","key":"{}"}}],"success":[{{"request_put":{{"key":"{}","value":"{}","lease":{}}}}}],"failure":[]}}"#,
507 key_b64, key_b64, value_b64, lease_id
508 )
509 }
510}
511
/// Standard (RFC 4648) base64 alphabet.
const B64_CHARS: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encodes `input` as standard base64 with `=` padding.
///
/// Replaces the previous streaming `Base64Writer`: its `flush()` emitted
/// padding, so a mid-stream flush would silently corrupt later writes, and
/// it went through a `Vec<u8>` -> `String::from_utf8` round trip. Encoding
/// the whole string in 3-byte chunks is simpler, has no misuse hazard, and
/// produces identical output (the writer was only ever used here).
fn base64_encode(input: &str) -> String {
    let bytes = input.as_bytes();
    // 4 output characters per 3 input bytes, rounded up.
    let mut out = String::with_capacity((bytes.len() + 2) / 3 * 4);
    for chunk in bytes.chunks(3) {
        let b0 = chunk[0];
        let b1 = chunk.get(1).copied().unwrap_or(0);
        let b2 = chunk.get(2).copied().unwrap_or(0);
        out.push(B64_CHARS[(b0 >> 2) as usize] as char);
        out.push(B64_CHARS[(((b0 & 0x03) << 4) | (b1 >> 4)) as usize] as char);
        out.push(if chunk.len() > 1 {
            B64_CHARS[(((b1 & 0x0f) << 2) | (b2 >> 6)) as usize] as char
        } else {
            '='
        });
        out.push(if chunk.len() > 2 {
            B64_CHARS[(b2 & 0x3f) as usize] as char
        } else {
            '='
        });
    }
    out
}
588
589pub fn generate_node_id() -> String {
596 let raw_hostname = std::fs::read_to_string("/etc/hostname")
597 .map(|h| h.trim().to_string())
598 .unwrap_or_else(|_| "unknown".to_string());
599 let hostname = HaKeys::sanitize_component(&raw_hostname);
600 let hostname = if hostname.is_empty() {
602 "unknown".to_string()
603 } else {
604 hostname
605 };
606 let id = uuid::Uuid::new_v4();
607 format!("{}-{}", hostname, id)
608}
609
/// Cluster-wide status snapshot as reported by one node (serializable, e.g.
/// for a status endpoint).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterStatus {
    /// Identifier of the reporting node.
    pub node_id: String,
    /// Whether the reporting node currently holds leadership.
    pub is_leader: bool,
    /// Identifier of the current leader, if known.
    pub leader_node: Option<String>,
    /// Nodes currently registered as active.
    pub active_nodes: Vec<NodeInfo>,
    /// Total session count across the cluster.
    pub total_sessions: u64,
    /// Backend name (e.g. "redis" or "etcd").
    pub backend: String,
}
620
/// Per-node membership record within `ClusterStatus`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Node identifier.
    pub node_id: String,
    /// Network address of the node (host:port, per the test fixtures).
    pub address: String,
    /// Number of sessions currently handled by this node.
    pub sessions: u64,
    /// Last-seen timestamp (string-encoded; exact format not enforced here).
    pub last_seen: String,
}
629
// Unit tests: config defaults and TOML parsing, key construction/sanitizing,
// Redis command/Lua builders, etcd bodies, base64, and validation rules.
#[cfg(test)]
mod tests {
    use super::*;

    // Default config: single local Redis endpoint, stock "audex:" prefix.
    #[test]
    fn test_ha_config_default() {
        let config = HaConfig::default();
        assert_eq!(config.backend, HaBackend::Redis);
        assert_eq!(config.endpoints.len(), 1);
        assert_eq!(config.prefix, "audex:");
        assert!(!config.tls);
    }

    // Full Redis config round-trip through the TOML deserializer, covering
    // every sub-section.
    #[test]
    fn test_ha_config_deserialize_redis() {
        let toml_str = r#"
backend = "redis"
endpoints = ["redis://redis-1:6379", "redis://redis-2:6379"]
password = "secret"
tls = true
prefix = "myapp:audex:"

[leader]
node_id = "server-1"
lease_ttl = 30
renewal_interval = 10

[rate_limit]
enabled = true
window_seconds = 1800

[audit_replication]
enabled = true
max_stream_length = 50000
"#;
        let config: HaConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.backend, HaBackend::Redis);
        assert_eq!(config.endpoints.len(), 2);
        assert_eq!(config.password.as_deref(), Some("secret"));
        assert!(config.tls);
        assert_eq!(config.leader.lease_ttl, 30);
        assert_eq!(config.rate_limit.window_seconds, 1800);
        assert_eq!(config.audit_replication.max_stream_length, 50000);
    }

    // Minimal etcd config: omitted fields must fall back to serde defaults.
    #[test]
    fn test_ha_config_deserialize_etcd() {
        let toml_str = r#"
backend = "etcd"
endpoints = ["http://etcd-1:2379", "http://etcd-2:2379", "http://etcd-3:2379"]

[leader]
lease_ttl = 10
"#;
        let config: HaConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.backend, HaBackend::Etcd);
        assert_eq!(config.endpoints.len(), 3);
        assert_eq!(config.leader.lease_ttl, 10);
    }

    // Key construction with the default prefix; note '@' survives in
    // identity components.
    #[test]
    fn test_ha_keys() {
        let keys = HaKeys::new("audex:");
        assert_eq!(keys.leader_lock(), "audex:leader:lock");
        assert_eq!(keys.leader_heartbeat(), "audex:leader:heartbeat");
        assert_eq!(
            keys.rate_limit("user@corp.com"),
            "audex:ratelimit:user@corp.com"
        );
        assert_eq!(keys.audit_stream(), "audex:audit:stream");
        assert_eq!(keys.session_registry(), "audex:sessions:active");
        assert_eq!(keys.node_registry(), "audex:nodes:active");
    }

    #[test]
    fn test_ha_keys_custom_prefix() {
        let keys = HaKeys::new("prod:audex:");
        assert_eq!(keys.leader_lock(), "prod:audex:leader:lock");
        assert_eq!(keys.rate_limit("alice"), "prod:audex:ratelimit:alice");
    }

    // SET ... NX EX ttl argument ordering.
    #[test]
    fn test_leader_election_commands() {
        let cmd = LeaderElectionCommands::try_acquire("audex:leader:lock", "node-1", 15);
        assert_eq!(cmd[0], "SET");
        assert_eq!(cmd[2], "node-1");
        assert_eq!(cmd[3], "NX");
        assert_eq!(cmd[4], "EX");
        assert_eq!(cmd[5], "15");
    }

    // The Lua scripts must compare ownership (GET) before EXPIRE/DEL.
    #[test]
    fn test_leader_election_lua_scripts() {
        let renew = LeaderElectionCommands::renew_lua();
        assert!(renew.contains("EXPIRE"));
        assert!(renew.contains("GET"));

        let release = LeaderElectionCommands::release_lua();
        assert!(release.contains("DEL"));
    }

    // The rate-limit script must prune, count, record and set expiry.
    #[test]
    fn test_rate_limit_lua_script() {
        let lua = DistributedRateLimitCommands::check_and_record_lua();
        assert!(lua.contains("ZREMRANGEBYSCORE"));
        assert!(lua.contains("ZADD"));
        assert!(lua.contains("ZCARD"));
        assert!(lua.contains("EXPIRE"));
    }

    #[test]
    fn test_audit_replication_commands() {
        let fields = AuditReplicationCommands::publish_fields("{\"event\":\"test\"}");
        assert_eq!(fields[0].0, "data");

        let group = AuditReplicationCommands::create_group("audex:audit:stream", "node-1-group");
        assert_eq!(group[0], "XGROUP");
        assert_eq!(group[5], "MKSTREAM");

        // XACK + stream + group + 2 ids = 5 arguments.
        let ack = AuditReplicationCommands::ack("stream", "group", &["1-1", "1-2"]);
        assert_eq!(ack[0], "XACK");
        assert_eq!(ack.len(), 5);
    }

    #[test]
    fn test_etcd_operations() {
        let body = EtcdOperations::lease_grant_body(15);
        assert!(body.contains("\"TTL\":15"));

        let txn = EtcdOperations::txn_create_body("mykey", "myval", 12345);
        assert!(txn.contains("request_put"));
        assert!(txn.contains("12345"));
    }

    // RFC 4648 vectors covering 1 and 2 padding chars and no padding.
    #[test]
    fn test_base64_encode() {
        assert_eq!(base64_encode("hello"), "aGVsbG8=");
        assert_eq!(base64_encode("ab"), "YWI=");
        assert_eq!(base64_encode("abc"), "YWJj");
    }

    // Node ids are "<hostname>-<uuid>", so always non-empty with a dash.
    #[test]
    fn test_generate_node_id() {
        let id = generate_node_id();
        assert!(!id.is_empty());
        assert!(id.contains('-'));
    }

    #[test]
    fn test_cluster_status_serialization() {
        let status = ClusterStatus {
            node_id: "node-1".to_string(),
            is_leader: true,
            leader_node: Some("node-1".to_string()),
            active_nodes: vec![NodeInfo {
                node_id: "node-1".to_string(),
                address: "10.0.0.1:8080".to_string(),
                sessions: 5,
                last_seen: "2026-01-01T00:00:00Z".to_string(),
            }],
            total_sessions: 5,
            backend: "redis".to_string(),
        };
        let json = serde_json::to_string(&status).unwrap();
        assert!(json.contains("node-1"));
        assert!(json.contains("is_leader"));
    }

    #[test]
    fn test_ha_backend_display() {
        assert_eq!(HaBackend::Redis.to_string(), "redis");
        assert_eq!(HaBackend::Etcd.to_string(), "etcd");
    }

    // Control characters and whitespace are stripped; safe punctuation kept.
    #[test]
    fn test_sanitize_component_strips_unsafe_chars() {
        assert_eq!(HaKeys::sanitize_component("host\nname"), "hostname");
        assert_eq!(HaKeys::sanitize_component("host\rname"), "hostname");
        assert_eq!(HaKeys::sanitize_component("host name"), "hostname");
        assert_eq!(HaKeys::sanitize_component("host\0name"), "hostname");
        assert_eq!(
            HaKeys::sanitize_component("good-host.local"),
            "good-host.local"
        );
        assert_eq!(HaKeys::sanitize_component(""), "");
    }

    // renewal_interval * 3 overflowing u64 must be an error, not a panic.
    #[test]
    fn test_leader_config_validate_overflow() {
        let cfg = LeaderConfig {
            node_id: None,
            lease_ttl: 15,
            renewal_interval: u64::MAX / 2,
        };
        let err = cfg.validate().unwrap_err();
        let msg = format!("{}", err);
        assert!(
            msg.contains("overflows"),
            "expected overflow error, got: {}",
            msg
        );
    }

    // Boundary: 4 * 3 < 15 is accepted, 5 * 3 >= 15 is rejected.
    #[test]
    fn test_leader_config_validate_normal() {
        let cfg = LeaderConfig {
            node_id: None,
            lease_ttl: 15,
            renewal_interval: 4,
        };
        assert!(cfg.validate().is_ok());

        let cfg = LeaderConfig {
            node_id: None,
            lease_ttl: 15,
            renewal_interval: 5,
        };
        assert!(cfg.validate().is_err());
    }
}