Skip to main content

tryaudex_core/
ha.rs

1use serde::{Deserialize, Serialize};
2
3/// High availability backend type.
4#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
5#[serde(rename_all = "lowercase")]
6pub enum HaBackend {
7    Redis,
8    Etcd,
9}
10
11impl std::fmt::Display for HaBackend {
12    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
13        match self {
14            Self::Redis => write!(f, "redis"),
15            Self::Etcd => write!(f, "etcd"),
16        }
17    }
18}
19
20/// Configuration for high availability mode.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct HaConfig {
23    /// Backend: "redis" or "etcd"
24    pub backend: HaBackend,
25    /// Connection URL(s).
26    /// - Redis: "redis://localhost:6379" or "redis+sentinel://host:26379/mymaster"
27    /// - Etcd: "http://localhost:2379" (comma-separated for cluster)
28    pub endpoints: Vec<String>,
29    /// Password for backend authentication
30    pub password: Option<String>,
31    /// TLS enabled
32    #[serde(default)]
33    pub tls: bool,
34    /// Key prefix for all HA keys (default: "audex:")
35    #[serde(default = "default_prefix")]
36    pub prefix: String,
37    /// Leader election configuration
38    #[serde(default)]
39    pub leader: LeaderConfig,
40    /// Distributed rate limiting configuration
41    #[serde(default)]
42    pub rate_limit: DistributedRateLimitConfig,
43    /// Audit replication configuration
44    #[serde(default)]
45    pub audit_replication: AuditReplicationConfig,
46}
47
/// Serde default for `HaConfig::prefix`: every HA key starts with `audex:`.
fn default_prefix() -> String {
    String::from("audex:")
}
51
52impl Default for HaConfig {
53    fn default() -> Self {
54        Self {
55            backend: HaBackend::Redis,
56            endpoints: vec!["redis://localhost:6379".to_string()],
57            password: None,
58            tls: false,
59            prefix: default_prefix(),
60            leader: LeaderConfig::default(),
61            rate_limit: DistributedRateLimitConfig::default(),
62            audit_replication: AuditReplicationConfig::default(),
63        }
64    }
65}
66
67/// Leader election configuration.
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct LeaderConfig {
70    /// Unique node ID for this instance (auto-generated if not set)
71    pub node_id: Option<String>,
72    /// Lease TTL in seconds (default: 15)
73    #[serde(default = "default_lease_ttl")]
74    pub lease_ttl: u64,
75    /// Renewal interval in seconds (default: 5, should be < lease_ttl/3)
76    #[serde(default = "default_renewal_interval")]
77    pub renewal_interval: u64,
78}
79
/// Serde default for `LeaderConfig::lease_ttl`, in seconds.
fn default_lease_ttl() -> u64 {
    15_u64
}

/// Serde default for `LeaderConfig::renewal_interval`, in seconds
/// (one third of the default lease TTL).
fn default_renewal_interval() -> u64 {
    5_u64
}
87
88impl Default for LeaderConfig {
89    fn default() -> Self {
90        Self {
91            node_id: None,
92            lease_ttl: default_lease_ttl(),
93            renewal_interval: default_renewal_interval(),
94        }
95    }
96}
97
/// Distributed rate limiting configuration.
///
/// Every field has a serde default, so any subset of the keys may be given.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedRateLimitConfig {
    /// Enable distributed rate limiting (default: true)
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Sliding window size in seconds (default: 3600 = 1 hour)
    #[serde(default = "default_window")]
    pub window_seconds: u64,
    /// Key expiry in seconds (default: 7200 = 2 hours)
    // NOTE(review): presumably passed as the EXPIRE argument of the rate-limit
    // Lua script. Keeping it >= window_seconds avoids expiring entries that are
    // still inside the window — confirm against the caller.
    #[serde(default = "default_expiry")]
    pub expiry_seconds: u64,
}
111
/// Serde default for boolean "enabled" switches: on.
fn default_true() -> bool {
    !false
}

/// Serde default for `DistributedRateLimitConfig::window_seconds`: one hour.
fn default_window() -> u64 {
    60 * 60
}

/// Serde default for `DistributedRateLimitConfig::expiry_seconds`: two hours.
fn default_expiry() -> u64 {
    2 * 60 * 60
}
123
124impl Default for DistributedRateLimitConfig {
125    fn default() -> Self {
126        Self {
127            enabled: true,
128            window_seconds: default_window(),
129            expiry_seconds: default_expiry(),
130        }
131    }
132}
133
/// Audit replication configuration.
///
/// `enabled` and `max_stream_length` have serde defaults and the `Option`
/// fields are `None` when absent, so any subset of the keys may be given.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditReplicationConfig {
    /// Enable audit replication across nodes (default: true)
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Stream/channel name for audit events (default: `None`).
    // NOTE(review): presumably callers fall back to `HaKeys::audit_stream()` when
    // unset, i.e. `<prefix>audit:stream` ("audex:audit:stream" with the default
    // prefix) — confirm in the caller.
    pub stream: Option<String>,
    /// Max stream length before trimming (default: 10000)
    // presumably used as the `XADD ... MAXLEN ~` bound — confirm in the caller.
    #[serde(default = "default_max_stream")]
    pub max_stream_length: u64,
    /// Consumer group name (default: `None`).
    // NOTE(review): expected to be auto-generated from the node ID when unset;
    // that logic is not in this type — confirm in the caller.
    pub consumer_group: Option<String>,
}
148
/// Serde default for `AuditReplicationConfig::max_stream_length`.
fn default_max_stream() -> u64 {
    10_000
}
152
153impl Default for AuditReplicationConfig {
154    fn default() -> Self {
155        Self {
156            enabled: true,
157            stream: None,
158            max_stream_length: default_max_stream(),
159            consumer_group: None,
160        }
161    }
162}
163
/// Leader election state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LeaderState {
    /// ID of the node this state describes.
    pub node_id: String,
    /// Whether `node_id` is currently the leader.
    pub is_leader: bool,
    /// ID of the current leader, if known.
    pub leader_node: Option<String>,
    /// When the current leader lease expires.
    // NOTE(review): stored as a plain String, so the timestamp format is not
    // enforced here — presumably RFC 3339; confirm with the producer.
    pub lease_expires_at: Option<String>,
    /// Time of the most recent heartbeat (same format caveat as above).
    pub last_heartbeat: Option<String>,
}
173
/// Redis key generators for HA operations.
///
/// Every key is `<prefix><suffix>`. The prefix is used verbatim, so it should
/// normally end with a separator such as `:`.
pub struct HaKeys {
    prefix: String,
}

impl HaKeys {
    /// Create a key generator that prepends `prefix` to every key.
    pub fn new(prefix: &str) -> Self {
        let prefix = prefix.to_owned();
        Self { prefix }
    }

    /// Concatenate the prefix and `parts` into a single key.
    fn compose(&self, parts: &[&str]) -> String {
        let mut key = self.prefix.clone();
        parts.iter().for_each(|part| key.push_str(part));
        key
    }

    /// Leader election lock key: `<prefix>leader:lock`.
    pub fn leader_lock(&self) -> String {
        self.compose(&["leader:lock"])
    }

    /// Leader heartbeat key: `<prefix>leader:heartbeat`.
    pub fn leader_heartbeat(&self) -> String {
        self.compose(&["leader:heartbeat"])
    }

    /// Rate limit counter key for an identity: `<prefix>ratelimit:<identity>`.
    ///
    /// NOTE(review): `identity` is inserted unescaped, so an identity of the
    /// form `window:<x>` produces the same key as `rate_limit_window("<x>")`.
    pub fn rate_limit(&self, identity: &str) -> String {
        self.compose(&["ratelimit:", identity])
    }

    /// Rate limit sorted-set key for the sliding window:
    /// `<prefix>ratelimit:window:<identity>`.
    pub fn rate_limit_window(&self, identity: &str) -> String {
        self.compose(&["ratelimit:window:", identity])
    }

    /// Audit replication stream key: `<prefix>audit:stream`.
    pub fn audit_stream(&self) -> String {
        self.compose(&["audit:stream"])
    }

    /// Session registry key (tracks active sessions across nodes):
    /// `<prefix>sessions:active`.
    pub fn session_registry(&self) -> String {
        self.compose(&["sessions:active"])
    }

    /// Node registry key: `<prefix>nodes:active`.
    pub fn node_registry(&self) -> String {
        self.compose(&["nodes:active"])
    }

    /// Per-node session count key: `<prefix>nodes:<node_id>:sessions`.
    pub fn node_sessions(&self, node_id: &str) -> String {
        self.compose(&["nodes:", node_id, ":sessions"])
    }
}
226
/// Redis commands for leader election using SET NX EX (atomic lock).
///
/// Mind the TTL units: [`try_acquire`](Self::try_acquire) sets the lease with
/// `EX` (seconds), whereas [`renew_lua`](Self::renew_lua) refreshes it with
/// `PEXPIRE` (milliseconds). Build the renew arguments with
/// [`renew_args`](Self::renew_args). Passing seconds where milliseconds are
/// expected would shrink a 15 s lease to 15 ms.
pub struct LeaderElectionCommands;

impl LeaderElectionCommands {
    /// Build the command that attempts to acquire leadership.
    ///
    /// Produces `SET key node_id NX EX ttl_secs`. The key is written only if it
    /// does not already exist, and it expires after `ttl_secs` **seconds**.
    /// Redis rejects an expire time of 0, so `ttl_secs` should be at least 1.
    pub fn try_acquire(key: &str, node_id: &str, ttl_secs: u64) -> Vec<String> {
        vec![
            "SET".to_string(),
            key.to_string(),
            node_id.to_string(),
            "NX".to_string(),
            "EX".to_string(),
            ttl_secs.to_string(),
        ]
    }

    /// Lua script that renews the leadership lease only if we still hold it.
    ///
    /// - `KEYS[1]`: the lock key.
    /// - `ARGV[1]`: the node ID that must currently own the lock.
    /// - `ARGV[2]`: the new TTL in **milliseconds** (passed to `PEXPIRE`).
    ///
    /// Returns the `PEXPIRE` reply when the lock is ours, otherwise 0. The
    /// ownership check and the expire run in one script, so Redis executes
    /// them atomically.
    pub fn renew_lua() -> &'static str {
        r#"
if redis.call("GET", KEYS[1]) == ARGV[1] then
    return redis.call("PEXPIRE", KEYS[1], ARGV[2])
else
    return 0
end
"#
    }

    /// Build `ARGV` for [`renew_lua`](Self::renew_lua) from a TTL in seconds.
    ///
    /// Returns `[node_id, ttl_millis]`. `ttl_secs` is converted to the
    /// milliseconds `PEXPIRE` expects, saturating instead of overflowing.
    pub fn renew_args(node_id: &str, ttl_secs: u64) -> Vec<String> {
        vec![node_id.to_string(), ttl_secs.saturating_mul(1000).to_string()]
    }

    /// Lua script that releases leadership only if we hold it.
    ///
    /// `KEYS[1]` is the lock key and `ARGV[1]` our node ID. Deleting only on a
    /// matching value avoids removing a lock that another node acquired after
    /// ours expired. Returns the `DEL` reply on release, otherwise 0.
    pub fn release_lua() -> &'static str {
        r#"
if redis.call("GET", KEYS[1]) == ARGV[1] then
    return redis.call("DEL", KEYS[1])
else
    return 0
end
"#
    }

    /// Build `GET key`. The reply is the node ID stored by `try_acquire`, or nil.
    pub fn get_leader(key: &str) -> Vec<String> {
        vec!["GET".to_string(), key.to_string()]
    }
}
272
/// Redis commands for distributed rate limiting using sliding window.
pub struct DistributedRateLimitCommands;

impl DistributedRateLimitCommands {
    /// Lua script that records a session and checks the rate limit in one step.
    ///
    /// It runs as a single Lua script (e.g. via `EVAL`), which Redis executes
    /// atomically. That atomicity, not pipelining, keeps the trim, count and add
    /// race-free across nodes.
    ///
    /// Arguments:
    /// - `KEYS[1]`: the sorted-set key (presumably `HaKeys::rate_limit_window`).
    /// - `ARGV[1]` `now`: current timestamp, used as the new entry's score.
    /// - `ARGV[2]` `window`: window length, in the same unit as `now`.
    /// - `ARGV[3]` `limit`: maximum number of entries allowed in the window.
    /// - `ARGV[4]` `expiry`: key TTL in seconds (`EXPIRE`).
    /// - `ARGV[5]` `member`: must be unique per request. Re-adding an existing
    ///   member only updates its score, yet the script still reports `count + 1`.
    ///
    /// Entries with score `<= now - window` are removed first. Returns `{1, count}`
    /// when the entry was recorded (`count` includes it), or `{0, count}` when
    /// the limit was already reached and nothing was recorded.
    pub fn check_and_record_lua() -> &'static str {
        r#"
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local expiry = tonumber(ARGV[4])
local member = ARGV[5]

-- Remove expired entries
redis.call("ZREMRANGEBYSCORE", key, 0, now - window)

-- Count current entries in window
local count = redis.call("ZCARD", key)

if count >= limit then
    return {0, count}
end

-- Add new entry
redis.call("ZADD", key, now, member)
redis.call("EXPIRE", key, expiry)

return {1, count + 1}
"#
    }

    /// Build `ZCOUNT key window_start +inf` to read the current count for an identity.
    ///
    /// The lower bound is inclusive, so scores `>= window_start` are counted.
    /// The Lua script above treats a score exactly equal to `now - window` as
    /// expired. With `window_start = now - window`, this read can therefore
    /// count one boundary entry that the script would already have dropped.
    pub fn get_count(key: &str, window_start: u64) -> Vec<String> {
        vec![
            "ZCOUNT".to_string(),
            key.to_string(),
            window_start.to_string(),
            "+inf".to_string(),
        ]
    }
}
316
/// Redis commands for audit stream replication.
pub struct AuditReplicationCommands;

impl AuditReplicationCommands {
    /// Field/value pairs for one audit entry in the stream.
    ///
    /// The entry is stored under a single `data` field holding the JSON text.
    /// [`publish`](Self::publish) builds the full `XADD` command around these
    /// fields.
    pub fn publish_fields(entry_json: &str) -> Vec<(&str, &str)> {
        vec![("data", entry_json)]
    }

    /// Build `XADD stream_key MAXLEN ~ max_length * data entry_json`.
    ///
    /// `MAXLEN ~` asks Redis to trim the stream to roughly `max_length` entries.
    /// Approximate trimming is cheaper than exact trimming. `*` lets Redis
    /// assign the entry ID.
    pub fn publish(stream_key: &str, max_length: u64, entry_json: &str) -> Vec<String> {
        let mut cmd = vec![
            "XADD".to_string(),
            stream_key.to_string(),
            "MAXLEN".to_string(),
            "~".to_string(),
            max_length.to_string(),
            "*".to_string(),
        ];
        for (field, value) in Self::publish_fields(entry_json) {
            cmd.push(field.to_string());
            cmd.push(value.to_string());
        }
        cmd
    }

    /// Build `XGROUP CREATE stream_key group 0 MKSTREAM`.
    ///
    /// The group starts at ID `0`, so it delivers the stream's existing history.
    /// `MKSTREAM` creates the stream if it does not exist yet. This command is
    /// *not* idempotent: if the group already exists, Redis replies with a
    /// `BUSYGROUP` error, which callers should treat as success.
    pub fn create_group(stream_key: &str, group: &str) -> Vec<String> {
        vec![
            "XGROUP".to_string(),
            "CREATE".to_string(),
            stream_key.to_string(),
            group.to_string(),
            "0".to_string(),
            "MKSTREAM".to_string(),
        ]
    }

    /// Build `XREADGROUP GROUP group consumer COUNT count STREAMS stream_key >`.
    ///
    /// The special ID `>` requests only entries never delivered to any
    /// consumer of the group.
    pub fn read_new(stream_key: &str, group: &str, consumer: &str, count: u64) -> Vec<String> {
        vec![
            "XREADGROUP".to_string(),
            "GROUP".to_string(),
            group.to_string(),
            consumer.to_string(),
            "COUNT".to_string(),
            count.to_string(),
            "STREAMS".to_string(),
            stream_key.to_string(),
            ">".to_string(),
        ]
    }

    /// Build `XACK stream_key group id...` for the processed entry IDs.
    ///
    /// Redis requires at least one ID, so callers should skip sending the
    /// command when `ids` is empty.
    pub fn ack(stream_key: &str, group: &str, ids: &[&str]) -> Vec<String> {
        ["XACK", stream_key, group]
            .iter()
            .chain(ids)
            .map(|&part| part.to_owned())
            .collect()
    }
}
367
368/// Etcd key-value operations for HA (alternative to Redis).
369pub struct EtcdOperations;
370
371impl EtcdOperations {
372    /// Leader election uses etcd leases with keep-alive.
373    /// PUT key value --lease=LEASE_ID
374    pub fn acquire_lock_endpoint() -> &'static str {
375        "/v3/kv/put"
376    }
377
378    /// Create a lease.
379    pub fn grant_lease_endpoint() -> &'static str {
380        "/v3/lease/grant"
381    }
382
383    /// Keep lease alive.
384    pub fn keepalive_endpoint() -> &'static str {
385        "/v3/lease/keepalive"
386    }
387
388    /// Watch for leader changes.
389    pub fn watch_endpoint() -> &'static str {
390        "/v3/watch"
391    }
392
393    /// Format a lease grant request body.
394    pub fn lease_grant_body(ttl: u64) -> String {
395        format!("{{\"TTL\":{}}}", ttl)
396    }
397
398    /// Format a put-if-not-exists request (using transactions).
399    pub fn txn_create_body(key: &str, value: &str, lease_id: u64) -> String {
400        let key_b64 = base64_encode(key);
401        let value_b64 = base64_encode(value);
402        format!(
403            r#"{{"compare":[{{"result":"EQUAL","target":"CREATE","key":"{}"}}],"success":[{{"request_put":{{"key":"{}","value":"{}","lease":{}}}}}],"failure":[]}}"#,
404            key_b64, key_b64, value_b64, lease_id
405        )
406    }
407}
408
/// Standard base64 encoding (RFC 4648 alphabet, `=`-padded) for the etcd API.
fn base64_encode(input: &str) -> String {
    use std::io::Write;
    let mut buf = Vec::new();
    {
        // Scoped so the encoder is dropped before `buf` is read below. Dropping
        // it writes the final, padded partial group.
        let mut encoder = Base64Writer::new(&mut buf);
        // Cannot fail: `Base64Writer::write` always returns `Ok`.
        encoder.write_all(input.as_bytes()).ok();
    }
    // Output is pure ASCII (alphabet plus '='), so the fallback is never taken.
    String::from_utf8(buf).unwrap_or_default()
}

/// Minimal streaming base64 encoder writing into a `Vec<u8>` (avoids an
/// external dependency).
///
/// Input is buffered in groups of three bytes. Each complete group becomes
/// four output characters. A trailing partial group is padded and written
/// when the encoder is dropped.
struct Base64Writer<'a> {
    out: &'a mut Vec<u8>,
    /// Pending input bytes that do not yet form a full 3-byte group.
    buf: [u8; 3],
    /// Number of valid bytes in `buf` (0..=2 between calls to `write`).
    len: usize,
}

/// Standard base64 alphabet (RFC 4648, section 4).
const B64_CHARS: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

impl<'a> Base64Writer<'a> {
    fn new(out: &'a mut Vec<u8>) -> Self {
        Self {
            out,
            buf: [0; 3],
            len: 0,
        }
    }

    /// Encode the pending 1–3 bytes as four characters, padding with `=` for
    /// missing bytes, then reset the buffer. No-op when nothing is pending.
    fn flush_buf(&mut self) {
        if self.len == 0 {
            return;
        }
        let b = self.buf;
        self.out.push(B64_CHARS[(b[0] >> 2) as usize]);
        self.out
            .push(B64_CHARS[((b[0] & 0x03) << 4 | b[1] >> 4) as usize]);
        if self.len > 1 {
            self.out
                .push(B64_CHARS[((b[1] & 0x0f) << 2 | b[2] >> 6) as usize]);
        } else {
            self.out.push(b'=');
        }
        if self.len > 2 {
            self.out.push(B64_CHARS[(b[2] & 0x3f) as usize]);
        } else {
            self.out.push(b'=');
        }
        self.buf = [0; 3];
        self.len = 0;
    }
}

impl std::io::Write for Base64Writer<'_> {
    /// Buffer `data`, emitting four characters for every complete 3-byte group.
    fn write(&mut self, data: &[u8]) -> std::io::Result<usize> {
        for &byte in data {
            self.buf[self.len] = byte;
            self.len += 1;
            if self.len == 3 {
                self.flush_buf();
            }
        }
        Ok(data.len())
    }

    /// Deliberately does not emit a pending partial group.
    ///
    /// Output already goes straight into the `Vec`. Encoding a 1–2 byte tail
    /// now would require `=` padding in the middle of the stream, which
    /// corrupts the result if more data is written afterwards. The tail is
    /// finalized on drop instead.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}

impl Drop for Base64Writer<'_> {
    /// Finalize the encoding: write the padded trailing group, if any.
    fn drop(&mut self) {
        self.flush_buf();
    }
}
485
/// Generate a unique node ID for this instance.
///
/// Format: `<hostname>-<pid>-<millis>`. The hostname is taken from
/// `/etc/hostname`, then the `HOSTNAME` or `COMPUTERNAME` environment variables
/// (for platforms without that file), and finally falls back to `"unknown"`.
/// Blank values are skipped, so the ID never starts with `-`. `<millis>` is the
/// Unix time in milliseconds modulo 100 000, so it wraps every 100 seconds;
/// uniqueness relies mainly on hostname + PID.
pub fn generate_node_id() -> String {
    // Trim and reject empty/whitespace-only values.
    let non_blank = |raw: String| {
        let trimmed = raw.trim();
        (!trimmed.is_empty()).then(|| trimmed.to_string())
    };
    let hostname = std::fs::read_to_string("/etc/hostname")
        .ok()
        .and_then(non_blank)
        .or_else(|| std::env::var("HOSTNAME").ok().and_then(non_blank))
        .or_else(|| std::env::var("COMPUTERNAME").ok().and_then(non_blank))
        .unwrap_or_else(|| "unknown".to_string());
    let pid = std::process::id();
    let ts = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_millis())
        .unwrap_or(0);
    format!("{}-{}-{}", hostname, pid, ts % 100000)
}
498
/// Cluster status information.
///
/// Plain serializable summary of the cluster as seen from one node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterStatus {
    /// ID of the node reporting this status.
    pub node_id: String,
    /// Whether the reporting node is the leader.
    pub is_leader: bool,
    /// ID of the current leader, if known.
    pub leader_node: Option<String>,
    /// Nodes currently considered active.
    pub active_nodes: Vec<NodeInfo>,
    /// Total session count across the cluster.
    // NOTE(review): presumably the sum of `NodeInfo::sessions` — not enforced here; confirm.
    pub total_sessions: u64,
    /// Backend name as a string.
    // presumably the `Display` form of `HaBackend` (e.g. "redis") — confirm with the producer.
    pub backend: String,
}

/// Information about an active node in the cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// The node's ID.
    pub node_id: String,
    /// Address the node is reachable at (e.g. "10.0.0.1:8080"; format not enforced).
    pub address: String,
    /// Number of sessions on this node.
    pub sessions: u64,
    /// When the node was last seen.
    // NOTE(review): plain String; presumably an RFC 3339 timestamp — confirm.
    pub last_seen: String,
}
518
#[cfg(test)]
mod tests {
    //! Unit tests for HA config parsing and the Redis/etcd command builders.
    //!
    //! The builders are pure functions, so these tests pin exact command
    //! vectors and script text instead of talking to a live Redis or etcd.
    use super::*;

    #[test]
    fn test_ha_config_default() {
        let config = HaConfig::default();
        assert_eq!(config.backend, HaBackend::Redis);
        assert_eq!(config.endpoints.len(), 1);
        assert_eq!(config.prefix, "audex:");
        assert!(!config.tls);
        assert!(config.password.is_none());
        // Sub-sections must match the documented defaults.
        assert_eq!(config.leader.lease_ttl, 15);
        assert_eq!(config.leader.renewal_interval, 5);
        assert!(config.rate_limit.enabled);
        assert_eq!(config.rate_limit.window_seconds, 3600);
        assert_eq!(config.rate_limit.expiry_seconds, 7200);
        assert!(config.audit_replication.enabled);
        assert!(config.audit_replication.stream.is_none());
        assert_eq!(config.audit_replication.max_stream_length, 10000);
    }

    #[test]
    fn test_ha_config_minimal_toml_uses_defaults() {
        // Only the required keys: everything else must fall back to the same
        // values `HaConfig::default()` uses.
        let toml_str = r#"
backend = "redis"
endpoints = ["redis://localhost:6379"]
"#;
        let config: HaConfig = toml::from_str(toml_str).unwrap();
        let defaults = HaConfig::default();
        assert_eq!(config.endpoints, defaults.endpoints);
        assert_eq!(config.prefix, defaults.prefix);
        assert_eq!(config.tls, defaults.tls);
        assert_eq!(config.password, defaults.password);
        assert_eq!(config.leader.lease_ttl, defaults.leader.lease_ttl);
        assert_eq!(
            config.leader.renewal_interval,
            defaults.leader.renewal_interval
        );
        assert_eq!(config.rate_limit.enabled, defaults.rate_limit.enabled);
        assert_eq!(
            config.rate_limit.window_seconds,
            defaults.rate_limit.window_seconds
        );
        assert_eq!(
            config.rate_limit.expiry_seconds,
            defaults.rate_limit.expiry_seconds
        );
        assert_eq!(
            config.audit_replication.enabled,
            defaults.audit_replication.enabled
        );
        assert_eq!(
            config.audit_replication.max_stream_length,
            defaults.audit_replication.max_stream_length
        );
    }

    #[test]
    fn test_ha_config_deserialize_redis() {
        let toml_str = r#"
backend = "redis"
endpoints = ["redis://redis-1:6379", "redis://redis-2:6379"]
password = "secret"
tls = true
prefix = "myapp:audex:"

[leader]
node_id = "server-1"
lease_ttl = 30
renewal_interval = 10

[rate_limit]
enabled = true
window_seconds = 1800

[audit_replication]
enabled = true
max_stream_length = 50000
"#;
        let config: HaConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.backend, HaBackend::Redis);
        assert_eq!(config.endpoints.len(), 2);
        assert_eq!(config.password.as_deref(), Some("secret"));
        assert!(config.tls);
        assert_eq!(config.prefix, "myapp:audex:");
        assert_eq!(config.leader.node_id.as_deref(), Some("server-1"));
        assert_eq!(config.leader.lease_ttl, 30);
        assert_eq!(config.leader.renewal_interval, 10);
        assert_eq!(config.rate_limit.window_seconds, 1800);
        // Not given in the table, so the serde default applies.
        assert_eq!(config.rate_limit.expiry_seconds, 7200);
        assert_eq!(config.audit_replication.max_stream_length, 50000);
    }

    #[test]
    fn test_ha_config_deserialize_etcd() {
        let toml_str = r#"
backend = "etcd"
endpoints = ["http://etcd-1:2379", "http://etcd-2:2379", "http://etcd-3:2379"]

[leader]
lease_ttl = 10
"#;
        let config: HaConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.backend, HaBackend::Etcd);
        assert_eq!(config.endpoints.len(), 3);
        assert_eq!(config.leader.lease_ttl, 10);
        assert_eq!(config.leader.renewal_interval, 5);
        assert!(config.password.is_none());
        assert!(!config.tls);
    }

    #[test]
    fn test_ha_keys() {
        let keys = HaKeys::new("audex:");
        assert_eq!(keys.leader_lock(), "audex:leader:lock");
        assert_eq!(keys.leader_heartbeat(), "audex:leader:heartbeat");
        assert_eq!(
            keys.rate_limit("user@corp.com"),
            "audex:ratelimit:user@corp.com"
        );
        assert_eq!(
            keys.rate_limit_window("user@corp.com"),
            "audex:ratelimit:window:user@corp.com"
        );
        assert_eq!(keys.audit_stream(), "audex:audit:stream");
        assert_eq!(keys.session_registry(), "audex:sessions:active");
        assert_eq!(keys.node_registry(), "audex:nodes:active");
        assert_eq!(keys.node_sessions("node-1"), "audex:nodes:node-1:sessions");
    }

    #[test]
    fn test_ha_keys_custom_prefix() {
        let keys = HaKeys::new("prod:audex:");
        assert_eq!(keys.leader_lock(), "prod:audex:leader:lock");
        assert_eq!(keys.rate_limit("alice"), "prod:audex:ratelimit:alice");
    }

    #[test]
    fn test_leader_election_commands() {
        let cmd = LeaderElectionCommands::try_acquire("audex:leader:lock", "node-1", 15);
        assert_eq!(cmd.len(), 6);
        assert_eq!(cmd[0], "SET");
        assert_eq!(cmd[1], "audex:leader:lock");
        assert_eq!(cmd[2], "node-1");
        assert_eq!(cmd[3], "NX");
        assert_eq!(cmd[4], "EX");
        assert_eq!(cmd[5], "15");

        let get = LeaderElectionCommands::get_leader("audex:leader:lock");
        assert_eq!(get, vec!["GET", "audex:leader:lock"]);
    }

    #[test]
    fn test_leader_election_lua_scripts() {
        let renew = LeaderElectionCommands::renew_lua();
        assert!(renew.contains("PEXPIRE"));
        assert!(renew.contains("GET"));

        let release = LeaderElectionCommands::release_lua();
        assert!(release.contains("DEL"));
    }

    #[test]
    fn test_rate_limit_lua_script() {
        let lua = DistributedRateLimitCommands::check_and_record_lua();
        assert!(lua.contains("ZREMRANGEBYSCORE"));
        assert!(lua.contains("ZADD"));
        assert!(lua.contains("ZCARD"));
        assert!(lua.contains("EXPIRE"));

        let count = DistributedRateLimitCommands::get_count("k", 100);
        assert_eq!(count, vec!["ZCOUNT", "k", "100", "+inf"]);
    }

    #[test]
    fn test_audit_replication_commands() {
        let fields = AuditReplicationCommands::publish_fields("{\"event\":\"test\"}");
        assert_eq!(fields[0].0, "data");
        assert_eq!(fields[0].1, "{\"event\":\"test\"}");

        let group = AuditReplicationCommands::create_group("audex:audit:stream", "node-1-group");
        assert_eq!(group[0], "XGROUP");
        assert_eq!(group[5], "MKSTREAM");

        let read = AuditReplicationCommands::read_new("s", "g", "c", 10);
        assert_eq!(
            read,
            vec!["XREADGROUP", "GROUP", "g", "c", "COUNT", "10", "STREAMS", "s", ">"]
        );

        let ack = AuditReplicationCommands::ack("stream", "group", &["1-1", "1-2"]);
        assert_eq!(ack[0], "XACK");
        assert_eq!(ack.len(), 5);
        assert_eq!(ack[3..], ["1-1", "1-2"]);
    }

    #[test]
    fn test_etcd_operations() {
        let body = EtcdOperations::lease_grant_body(15);
        assert!(body.contains("\"TTL\":15"));
        assert_eq!(body, "{\"TTL\":15}");

        let txn = EtcdOperations::txn_create_body("mykey", "myval", 12345);
        assert!(txn.contains("request_put"));
        assert!(txn.contains("12345"));
        // Keys and values travel base64-encoded.
        assert!(txn.contains("\"bXlrZXk=\""));
        assert!(txn.contains("\"bXl2YWw=\""));

        assert!(EtcdOperations::acquire_lock_endpoint().starts_with("/v3/"));
        assert!(EtcdOperations::grant_lease_endpoint().starts_with("/v3/"));
    }

    #[test]
    fn test_base64_encode() {
        assert_eq!(base64_encode(""), "");
        assert_eq!(base64_encode("a"), "YQ==");
        assert_eq!(base64_encode("hello"), "aGVsbG8=");
        assert_eq!(base64_encode("ab"), "YWI=");
        assert_eq!(base64_encode("abc"), "YWJj");
    }

    #[test]
    fn test_generate_node_id() {
        let id = generate_node_id();
        assert!(!id.is_empty());
        // Should contain a hyphen-separated structure
        assert!(id.contains('-'));
        assert!(id.contains(&std::process::id().to_string()));
    }

    #[test]
    fn test_cluster_status_serialization() {
        let status = ClusterStatus {
            node_id: "node-1".to_string(),
            is_leader: true,
            leader_node: Some("node-1".to_string()),
            active_nodes: vec![NodeInfo {
                node_id: "node-1".to_string(),
                address: "10.0.0.1:8080".to_string(),
                sessions: 5,
                last_seen: "2026-01-01T00:00:00Z".to_string(),
            }],
            total_sessions: 5,
            backend: "redis".to_string(),
        };
        let json = serde_json::to_string(&status).unwrap();
        assert!(json.contains("node-1"));
        assert!(json.contains("is_leader"));

        // Round-trip must preserve the nested node list.
        let back: ClusterStatus = serde_json::from_str(&json).unwrap();
        assert!(back.is_leader);
        assert_eq!(back.active_nodes.len(), 1);
        assert_eq!(back.active_nodes[0].address, "10.0.0.1:8080");
        assert_eq!(back.total_sessions, 5);
    }

    #[test]
    fn test_ha_backend_display() {
        assert_eq!(HaBackend::Redis.to_string(), "redis");
        assert_eq!(HaBackend::Etcd.to_string(), "etcd");
    }

    #[test]
    fn test_ha_backend_serde_matches_display() {
        for backend in [HaBackend::Redis, HaBackend::Etcd] {
            let json = serde_json::to_string(&backend).unwrap();
            assert_eq!(json, format!("\"{}\"", backend));
            let back: HaBackend = serde_json::from_str(&json).unwrap();
            assert_eq!(back, backend);
        }
    }
}