Skip to main content

wae_distributed/
lib.rs

1//! WAE Distributed - 分布式能力抽象层
2//!
3//! 提供统一的分布式系统能力抽象,包括分布式锁、功能开关、分布式ID生成等。
4//!
5//! 深度融合 tokio 运行时,所有 API 都是异步优先设计。
6//! 微服务架构友好,支持高可用、高并发场景。
7
8#![warn(missing_docs)]
9
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, fmt, sync::Arc, time::Duration};

pub use feature_flag::{FeatureFlag, FeatureFlagManager, FlagDefinition, Strategy, evaluate};
pub use id_generator::{IdGenerator, SnowflakeGenerator, UuidGenerator, UuidVersion};
pub use lock::{
    DistributedLock, InMemoryLock, InMemoryLockManager, LockError, LockOptions, LockResult,
};
16
17mod lock {
18    use super::*;
19
/// Errors produced by distributed lock operations.
///
/// Every variant carries a free-form detail message that is appended to a fixed,
/// variant-specific prefix when the error is displayed.
#[derive(Debug)]
pub enum LockError {
    /// The lock could not be acquired.
    AcquireFailed(String),

    /// The lock has expired.
    Expired(String),

    /// The lock does not exist.
    NotFound(String),

    /// The lock could not be released.
    ReleaseFailed(String),

    /// Timed out while waiting to acquire the lock.
    WaitTimeout(String),

    /// Internal error.
    Internal(String),
}

/// Result type used by lock operations.
pub type LockResult<T> = Result<T, LockError>;

impl std::error::Error for LockError {}

impl fmt::Display for LockError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::AcquireFailed(detail) => write!(f, "Failed to acquire lock: {}", detail),
            Self::Expired(detail) => write!(f, "Lock expired: {}", detail),
            Self::NotFound(detail) => write!(f, "Lock not found: {}", detail),
            Self::ReleaseFailed(detail) => write!(f, "Failed to release lock: {}", detail),
            Self::WaitTimeout(detail) => write!(f, "Wait timeout: {}", detail),
            Self::Internal(detail) => write!(f, "Lock internal error: {}", detail),
        }
    }
}
59
/// Options controlling how a lock is acquired.
///
/// Defaults: a 30 second `ttl` and a 10 second `wait_timeout`.
#[derive(Debug, Clone)]
pub struct LockOptions {
    /// Time-to-live of the lock.
    pub ttl: Duration,
    /// How long to wait while trying to acquire the lock.
    pub wait_timeout: Duration,
}

impl Default for LockOptions {
    fn default() -> Self {
        let ttl = Duration::from_secs(30);
        let wait_timeout = Duration::from_secs(10);
        Self { ttl, wait_timeout }
    }
}

impl LockOptions {
    /// Creates options with the default values (same as [`LockOptions::default`]).
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns these options with the time-to-live replaced by `ttl`.
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self { ttl, ..self }
    }

    /// Returns these options with the acquisition wait timeout replaced by `timeout`.
    pub fn with_wait_timeout(self, timeout: Duration) -> Self {
        Self { wait_timeout: timeout, ..self }
    }
}
93
94    /// 分布式锁 trait (dyn 兼容)
95    #[allow(async_fn_in_trait)]
96    pub trait DistributedLock: Send + Sync {
97        /// 获取锁 (阻塞直到获取成功或超时)
98        async fn lock(&self) -> LockResult<()>;
99
100        /// 尝试获取锁 (非阻塞)
101        async fn try_lock(&self) -> LockResult<bool>;
102
103        /// 释放锁
104        async fn unlock(&self) -> LockResult<()>;
105
106        /// 带超时的获取锁
107        async fn lock_with_timeout(&self, timeout: Duration) -> LockResult<()>;
108
109        /// 获取锁的键名
110        fn key(&self) -> &str;
111
112        /// 检查锁是否被持有
113        async fn is_locked(&self) -> bool;
114    }
115
116    /// 内存锁实现 (用于单机测试)
117    pub struct InMemoryLock {
118        key: String,
119        manager: Arc<InMemoryLockManager>,
120    }
121
122    impl InMemoryLock {
123        /// 创建新的内存锁
124        pub fn new(key: impl Into<String>, manager: Arc<InMemoryLockManager>) -> Self {
125            Self { key: key.into(), manager }
126        }
127    }
128
129    impl DistributedLock for InMemoryLock {
130        async fn lock(&self) -> LockResult<()> {
131            self.lock_with_timeout(Duration::from_secs(30)).await
132        }
133
134        async fn try_lock(&self) -> LockResult<bool> {
135            self.manager.acquire_lock(&self.key, Duration::from_secs(30)).await
136        }
137
138        async fn unlock(&self) -> LockResult<()> {
139            self.manager.release_lock(&self.key).await
140        }
141
142        async fn lock_with_timeout(&self, timeout: Duration) -> LockResult<()> {
143            let start = std::time::Instant::now();
144            loop {
145                if self.manager.acquire_lock(&self.key, Duration::from_secs(30)).await? {
146                    return Ok(());
147                }
148                if start.elapsed() >= timeout {
149                    return Err(LockError::WaitTimeout(format!("Lock key: {}", self.key)));
150                }
151                tokio::time::sleep(Duration::from_millis(50)).await;
152            }
153        }
154
155        fn key(&self) -> &str {
156            &self.key
157        }
158
159        async fn is_locked(&self) -> bool {
160            self.manager.is_locked(&self.key).await
161        }
162    }
163
164    /// 内存锁管理器
165    pub struct InMemoryLockManager {
166        locks: parking_lot::RwLock<HashSet<String>>,
167    }
168
169    impl InMemoryLockManager {
170        /// 创建新的内存锁管理器
171        pub fn new() -> Self {
172            Self { locks: parking_lot::RwLock::new(HashSet::new()) }
173        }
174
175        /// 创建锁实例
176        pub fn create_lock(&self, key: impl Into<String>) -> InMemoryLock {
177            InMemoryLock::new(key, Arc::new(self.clone()))
178        }
179
180        async fn acquire_lock(&self, key: &str, _ttl: Duration) -> LockResult<bool> {
181            let mut locks = self.locks.write();
182            if locks.contains(key) {
183                return Ok(false);
184            }
185            locks.insert(key.to_string());
186            Ok(true)
187        }
188
189        async fn release_lock(&self, key: &str) -> LockResult<()> {
190            let mut locks = self.locks.write();
191            if locks.remove(key) { Ok(()) } else { Err(LockError::NotFound(key.to_string())) }
192        }
193
194        async fn is_locked(&self, key: &str) -> bool {
195            self.locks.read().contains(key)
196        }
197    }
198
199    impl Default for InMemoryLockManager {
200        fn default() -> Self {
201            Self::new()
202        }
203    }
204
205    impl Clone for InMemoryLockManager {
206        fn clone(&self) -> Self {
207            Self { locks: parking_lot::RwLock::new(self.locks.read().clone()) }
208        }
209    }
210}
211
212mod feature_flag {
213    use super::*;
214
215    /// 功能开关策略
216    #[derive(Debug, Clone, Default, Serialize, Deserialize)]
217    pub enum Strategy {
218        /// 始终开启
219        On,
220        /// 始终关闭
221        #[default]
222        Off,
223        /// 百分比灰度
224        Percentage(u32),
225        /// 用户白名单
226        UserList(Vec<String>),
227    }
228
229    /// 功能开关定义
230    #[derive(Debug, Clone, Serialize, Deserialize)]
231    pub struct FlagDefinition {
232        /// 开关名称
233        pub name: String,
234        /// 开关描述
235        pub description: String,
236        /// 开关策略
237        pub strategy: Strategy,
238        /// 是否启用
239        pub enabled: bool,
240    }
241
242    impl FlagDefinition {
243        /// 创建新的功能开关定义
244        pub fn new(name: impl Into<String>) -> Self {
245            Self { name: name.into(), description: String::new(), strategy: Strategy::default(), enabled: false }
246        }
247
248        /// 设置描述
249        pub fn with_description(mut self, description: impl Into<String>) -> Self {
250            self.description = description.into();
251            self
252        }
253
254        /// 设置策略
255        pub fn with_strategy(mut self, strategy: Strategy) -> Self {
256            self.strategy = strategy;
257            self
258        }
259
260        /// 设置启用状态
261        pub fn with_enabled(mut self, enabled: bool) -> Self {
262            self.enabled = enabled;
263            self
264        }
265    }
266
/// Feature flag lookup.
///
/// The methods are declared as `async fn`, so this trait is **not** dyn-compatible
/// (object-safe): `dyn FeatureFlag` / `Box<dyn FeatureFlag>` do not compile.
/// Use it through generics instead (`impl FeatureFlag` or `T: FeatureFlag`).
// `async_fn_in_trait` is allowed deliberately: the lint warns that auto-trait bounds
// such as `Send` cannot be required of the futures these methods return.
#[allow(async_fn_in_trait)]
pub trait FeatureFlag: Send + Sync {
    /// Returns whether the flag `key` is enabled, without user context.
    async fn is_enabled(&self, key: &str) -> bool;

    /// Returns whether the flag `key` is enabled for the user `user_id`.
    async fn is_enabled_for_user(&self, key: &str, user_id: &str) -> bool;

    /// Returns the variant of the flag `key`, if any.
    async fn get_variant(&self, key: &str) -> Option<String>;
}
279
280    /// 功能开关管理器
281    pub struct FeatureFlagManager {
282        flags: parking_lot::RwLock<std::collections::HashMap<String, FlagDefinition>>,
283    }
284
285    impl FeatureFlagManager {
286        /// 创建新的功能开关管理器
287        pub fn new() -> Self {
288            Self { flags: parking_lot::RwLock::new(std::collections::HashMap::new()) }
289        }
290
291        /// 注册功能开关
292        pub fn register(&self, flag: FlagDefinition) {
293            let mut flags = self.flags.write();
294            flags.insert(flag.name.clone(), flag);
295        }
296
297        /// 注销功能开关
298        pub fn unregister(&self, name: &str) -> bool {
299            let mut flags = self.flags.write();
300            flags.remove(name).is_some()
301        }
302
303        /// 获取功能开关定义
304        pub fn get(&self, name: &str) -> Option<FlagDefinition> {
305            let flags = self.flags.read();
306            flags.get(name).cloned()
307        }
308
309        /// 更新功能开关
310        pub fn update(&self, name: &str, enabled: bool) -> bool {
311            let mut flags = self.flags.write();
312            if let Some(flag) = flags.get_mut(name) {
313                flag.enabled = enabled;
314                return true;
315            }
316            false
317        }
318
319        /// 列出所有功能开关
320        pub fn list(&self) -> Vec<FlagDefinition> {
321            let flags = self.flags.read();
322            flags.values().cloned().collect()
323        }
324    }
325
326    impl Default for FeatureFlagManager {
327        fn default() -> Self {
328            Self::new()
329        }
330    }
331
332    impl FeatureFlag for FeatureFlagManager {
333        async fn is_enabled(&self, key: &str) -> bool {
334            let flags = self.flags.read();
335            if let Some(flag) = flags.get(key) {
336                return flag.enabled && matches!(flag.strategy, Strategy::On);
337            }
338            false
339        }
340
341        async fn is_enabled_for_user(&self, key: &str, user_id: &str) -> bool {
342            let flags = self.flags.read();
343            if let Some(flag) = flags.get(key) {
344                if !flag.enabled {
345                    return false;
346                }
347                return evaluate(&flag.strategy, user_id);
348            }
349            false
350        }
351
352        async fn get_variant(&self, key: &str) -> Option<String> {
353            let flags = self.flags.read();
354            flags.get(key).and_then(|f| if f.enabled { Some(f.name.clone()) } else { None })
355        }
356    }
357
358    /// 评估开关状态
359    pub fn evaluate(strategy: &Strategy, user_id: &str) -> bool {
360        match strategy {
361            Strategy::On => true,
362            Strategy::Off => false,
363            Strategy::Percentage(pct) => {
364                let hash = calculate_hash(user_id);
365                let bucket = hash % 100;
366                bucket < *pct as u64
367            }
368            Strategy::UserList(users) => users.contains(&user_id.to_string()),
369        }
370    }
371
372    fn calculate_hash(s: &str) -> u64 {
373        let mut hash: u64 = 0;
374        for c in s.chars() {
375            hash = hash.wrapping_mul(31).wrapping_add(c as u64);
376        }
377        hash
378    }
379}
380
381mod id_generator {
382    use parking_lot::Mutex;
383    use std::time::{SystemTime, UNIX_EPOCH};
384
/// Asynchronous ID generator.
///
/// The methods are declared as `async fn`, so this trait is **not** dyn-compatible
/// (object-safe): `dyn IdGenerator` / `Box<dyn IdGenerator>` do not compile.
/// Use it through generics instead (`impl IdGenerator` or `T: IdGenerator`).
// `async_fn_in_trait` is allowed deliberately: the lint warns that auto-trait bounds
// such as `Send` cannot be required of the futures these methods return.
#[allow(async_fn_in_trait)]
pub trait IdGenerator: Send + Sync {
    /// Generates a single ID.
    async fn generate(&self) -> String;

    /// Generates `count` IDs.
    async fn generate_batch(&self, count: usize) -> Vec<String>;
}
394
395    /// 雪花算法 ID 生成器
396    pub struct SnowflakeGenerator {
397        worker_id: u64,
398        datacenter_id: u64,
399        sequence: Mutex<u64>,
400        last_timestamp: Mutex<u64>,
401    }
402
403    impl SnowflakeGenerator {
404        const EPOCH: u64 = 1704067200000;
405        const WORKER_ID_BITS: u64 = 5;
406        const DATACENTER_ID_BITS: u64 = 5;
407        const SEQUENCE_BITS: u64 = 12;
408
409        const MAX_WORKER_ID: u64 = (1 << Self::WORKER_ID_BITS) - 1;
410        const MAX_DATACENTER_ID: u64 = (1 << Self::DATACENTER_ID_BITS) - 1;
411        const SEQUENCE_MASK: u64 = (1 << Self::SEQUENCE_BITS) - 1;
412
413        const WORKER_ID_SHIFT: u64 = Self::SEQUENCE_BITS;
414        const DATACENTER_ID_SHIFT: u64 = Self::SEQUENCE_BITS + Self::WORKER_ID_BITS;
415        const TIMESTAMP_SHIFT: u64 = Self::SEQUENCE_BITS + Self::WORKER_ID_BITS + Self::DATACENTER_ID_BITS;
416
417        /// 创建新的雪花算法 ID 生成器
418        pub fn new(worker_id: u64, datacenter_id: u64) -> Result<Self, String> {
419            if worker_id > Self::MAX_WORKER_ID {
420                return Err(format!("Worker ID must be between 0 and {}", Self::MAX_WORKER_ID));
421            }
422            if datacenter_id > Self::MAX_DATACENTER_ID {
423                return Err(format!("Datacenter ID must be between 0 and {}", Self::MAX_DATACENTER_ID));
424            }
425            Ok(Self { worker_id, datacenter_id, sequence: Mutex::new(0), last_timestamp: Mutex::new(0) })
426        }
427
428        fn current_timestamp() -> u64 {
429            SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as u64
430        }
431
432        fn til_next_millis(last_timestamp: u64) -> u64 {
433            let mut timestamp = Self::current_timestamp();
434            while timestamp <= last_timestamp {
435                timestamp = Self::current_timestamp();
436            }
437            timestamp
438        }
439
440        fn generate_id(&self) -> u64 {
441            let mut sequence = self.sequence.lock();
442            let mut last_timestamp = self.last_timestamp.lock();
443
444            let timestamp = Self::current_timestamp();
445
446            if timestamp < *last_timestamp {
447                panic!("Clock moved backwards!");
448            }
449
450            if timestamp == *last_timestamp {
451                *sequence = (*sequence + 1) & Self::SEQUENCE_MASK;
452                if *sequence == 0 {
453                    *last_timestamp = Self::til_next_millis(*last_timestamp);
454                }
455            }
456            else {
457                *sequence = 0;
458            }
459
460            *last_timestamp = timestamp;
461
462            ((timestamp - Self::EPOCH) << Self::TIMESTAMP_SHIFT)
463                | (self.datacenter_id << Self::DATACENTER_ID_SHIFT)
464                | (self.worker_id << Self::WORKER_ID_SHIFT)
465                | *sequence
466        }
467    }
468
469    impl IdGenerator for SnowflakeGenerator {
470        async fn generate(&self) -> String {
471            self.generate_id().to_string()
472        }
473
474        async fn generate_batch(&self, count: usize) -> Vec<String> {
475            (0..count).map(|_| self.generate_id().to_string()).collect()
476        }
477    }
478
/// UUID version produced by a [`UuidGenerator`].
#[derive(Debug, Clone, Copy, Default)]
pub enum UuidVersion {
    /// Version 4: random UUID (the default).
    #[default]
    V4,
    /// Version 7: time-ordered UUID.
    V7,
}

/// UUID-based ID generator.
pub struct UuidGenerator {
    version: UuidVersion,
}

impl UuidGenerator {
    /// Creates a generator that produces UUIDs of the given `version`.
    pub fn new(version: UuidVersion) -> Self {
        Self { version }
    }

    /// Creates a generator of version-4 (random) UUIDs.
    pub fn v4() -> Self {
        Self { version: UuidVersion::V4 }
    }

    /// Creates a generator of version-7 (time-ordered) UUIDs.
    pub fn v7() -> Self {
        Self { version: UuidVersion::V7 }
    }
}

impl Default for UuidGenerator {
    /// Equivalent to [`UuidGenerator::v4`].
    fn default() -> Self {
        Self::new(UuidVersion::default())
    }
}
516
517    impl IdGenerator for UuidGenerator {
518        async fn generate(&self) -> String {
519            match self.version {
520                UuidVersion::V4 => uuid::Uuid::new_v4().to_string(),
521                UuidVersion::V7 => {
522                    let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as u64;
523                    let random_bytes: [u8; 10] = {
524                        let mut bytes = [0u8; 10];
525                        for byte in &mut bytes {
526                            *byte = rand_byte();
527                        }
528                        bytes
529                    };
530
531                    let mut uuid_bytes = [0u8; 16];
532                    uuid_bytes[0..6].copy_from_slice(&now.to_be_bytes()[2..8]);
533                    uuid_bytes[6..16].copy_from_slice(&random_bytes);
534
535                    uuid_bytes[6] = (uuid_bytes[6] & 0x0F) | 0x70;
536                    uuid_bytes[8] = (uuid_bytes[8] & 0x3F) | 0x80;
537
538                    uuid::Uuid::from_bytes(uuid_bytes).to_string()
539                }
540            }
541        }
542
543        async fn generate_batch(&self, count: usize) -> Vec<String> {
544            let mut result = Vec::with_capacity(count);
545            for _ in 0..count {
546                result.push(self.generate().await);
547            }
548            result
549        }
550    }
551
552    fn rand_byte() -> u8 {
553        use std::{
554            collections::hash_map::RandomState,
555            hash::{BuildHasher, Hasher},
556        };
557        let state = RandomState::new();
558        let mut hasher = state.build_hasher();
559        hasher.write_u64(SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_nanos() as u64);
560        (hasher.finish() & 0xFF) as u8
561    }
562}