swan_common/interceptor/
cache.rs

1use std::any::TypeId;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5/// 拦截器缓存管理器
6/// 
7/// 为每个客户端实例管理拦截器的生命周期,避免重复创建。
8/// 使用 Arc 来实现零成本的拦截器共享。
9pub struct InterceptorCache {
10    /// 方法级拦截器缓存,按类型ID索引
11    /// 使用类型擦除存储以支持不同的状态类型
12    method_interceptors: HashMap<TypeId, Arc<dyn std::any::Any + Send + Sync>>,
13}
14
15impl InterceptorCache {
16    /// 创建新的拦截器缓存
17    pub fn new() -> Self {
18        Self {
19            method_interceptors: HashMap::new(),
20        }
21    }
22
23    /// 获取或创建方法级拦截器
24    /// 
25    /// # 类型参数
26    /// 
27    /// * `T` - 拦截器类型,必须实现 Default + Send + Sync
28    /// 
29    /// # 返回值
30    /// 
31    /// 返回拦截器的 Arc 引用
32    pub fn get_or_create<T>(&mut self) -> Arc<T>
33    where
34        T: Default + Send + Sync + 'static,
35    {
36        let type_id = TypeId::of::<T>();
37        
38        let any_interceptor = self.method_interceptors
39            .entry(type_id)
40            .or_insert_with(|| Arc::new(T::default()) as Arc<dyn std::any::Any + Send + Sync>)
41            .clone();
42            
43        // 安全的向下转型,因为我们知道确切的类型
44        any_interceptor.downcast::<T>().unwrap()
45    }
46
47    /// 预热拦截器缓存
48    /// 
49    /// 在客户端初始化时调用,预先创建常用的拦截器实例
50    pub fn warmup<T>(&mut self)
51    where
52        T: Default + Send + Sync + 'static,
53    {
54        let _ = self.get_or_create::<T>();
55    }
56
57    /// 清空缓存(主要用于测试)
58    #[cfg(test)]
59    pub fn clear(&mut self) {
60        self.method_interceptors.clear();
61    }
62
63    /// 获取缓存大小(用于监控)
64    pub fn size(&self) -> usize {
65        self.method_interceptors.len()
66    }
67}
68
69impl Default for InterceptorCache {
70    fn default() -> Self {
71        Self::new()
72    }
73}
74
75#[cfg(test)]
76mod tests {
77    use super::*;
78    use crate::interceptor::traits::{NoOpInterceptor, SwanInterceptor};
79    use std::sync::atomic::{AtomicUsize, Ordering};
80    use async_trait::async_trait;
81    use std::borrow::Cow;
82
83    static CREATION_COUNT: AtomicUsize = AtomicUsize::new(0);
84
85    #[derive(Default)]
86    struct TestInterceptor;
87
88    #[async_trait]
89    impl SwanInterceptor for TestInterceptor {
90        async fn before_request<'a>(
91            &self,
92            request: reqwest::RequestBuilder,
93            request_body: &'a [u8],
94        ) -> anyhow::Result<(reqwest::RequestBuilder, Cow<'a, [u8]>)> {
95            CREATION_COUNT.fetch_add(1, Ordering::SeqCst);
96            Ok((request, Cow::Borrowed(request_body)))
97        }
98
99        async fn after_response(
100            &self,
101            response: reqwest::Response,
102        ) -> anyhow::Result<reqwest::Response> {
103            Ok(response)
104        }
105    }
106
107    #[test]
108    fn test_cache_reuses_interceptors() {
109        let mut cache = InterceptorCache::new();
110        
111        // 重置计数器
112        CREATION_COUNT.store(0, Ordering::SeqCst);
113        
114        // 多次获取同一类型的拦截器
115        let interceptor1 = cache.get_or_create::<TestInterceptor>();
116        let interceptor2 = cache.get_or_create::<TestInterceptor>();
117        let interceptor3 = cache.get_or_create::<TestInterceptor>();
118        
119        // 验证是同一个实例(Arc 指针相同)
120        assert!(Arc::ptr_eq(&interceptor1, &interceptor2));
121        assert!(Arc::ptr_eq(&interceptor2, &interceptor3));
122        
123        // 验证缓存大小
124        assert_eq!(cache.size(), 1);
125    }
126
127    #[test]
128    fn test_cache_different_types() {
129        let mut cache = InterceptorCache::new();
130        
131        let _interceptor1 = cache.get_or_create::<TestInterceptor>();
132        let _interceptor2 = cache.get_or_create::<NoOpInterceptor>();
133        
134        // 不同类型应该创建不同的实例
135        assert_eq!(cache.size(), 2);
136    }
137
138    #[test]
139    fn test_warmup() {
140        let mut cache = InterceptorCache::new();
141        
142        // 预热缓存
143        cache.warmup::<TestInterceptor>();
144        cache.warmup::<NoOpInterceptor>();
145        
146        assert_eq!(cache.size(), 2);
147        
148        // 后续获取应该直接返回缓存的实例
149        let _interceptor = cache.get_or_create::<TestInterceptor>();
150        assert_eq!(cache.size(), 2); // 大小不变
151    }
152}