// swan_common/interceptor/cache.rs

use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
/// Type-erased, shared handle to one cached interceptor instance.
type SharedInterceptor = Arc<dyn Any + Send + Sync>;

/// Cache holding at most one shared instance per interceptor type.
///
/// Instances are keyed by their [`TypeId`] and created lazily with
/// [`Default::default`] the first time a type is requested. Every later request
/// for the same type receives a clone of the same [`Arc`].
///
/// Creating entries requires `&mut self`, so callers that share one cache across
/// threads must supply their own synchronization.
#[derive(Debug)]
pub struct InterceptorCache {
    /// One type-erased instance per interceptor type.
    ///
    /// Invariant: the value stored under `TypeId::of::<T>()` is always an
    /// `Arc<T>`. Only `get_or_create` inserts, and it upholds this invariant.
    method_interceptors: HashMap<TypeId, SharedInterceptor>,
}

impl InterceptorCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            method_interceptors: HashMap::new(),
        }
    }

    /// Returns the shared instance of `T`, creating it with `T::default()` on first use.
    ///
    /// Repeated calls for the same `T` return clones of the same `Arc`, so
    /// `Arc::ptr_eq` holds between them and `T::default()` runs only once per
    /// cache (until the cache is cleared).
    ///
    /// # Panics
    ///
    /// Only if the internal invariant (entries keyed by `TypeId::of::<T>()` hold
    /// an `Arc<T>`) were violated. No code path in this type can do that.
    pub fn get_or_create<T>(&mut self) -> Arc<T>
    where
        T: Default + Send + Sync + 'static,
    {
        let entry: &SharedInterceptor = self
            .method_interceptors
            .entry(TypeId::of::<T>())
            .or_insert_with(|| Arc::new(T::default()) as SharedInterceptor);

        Arc::clone(entry)
            .downcast::<T>()
            .expect("interceptor cache entry stored under TypeId::of::<T>() must be an Arc<T>")
    }

    /// Ensures an instance of `T` is cached, without returning it.
    ///
    /// Useful for paying the construction cost of `T::default()` up front.
    pub fn warmup<T>(&mut self)
    where
        T: Default + Send + Sync + 'static,
    {
        let _ = self.get_or_create::<T>();
    }

    /// Removes every cached instance (test builds only).
    ///
    /// `Arc`s already handed out stay valid. The next `get_or_create` call
    /// constructs a fresh instance.
    #[cfg(test)]
    pub fn clear(&mut self) {
        self.method_interceptors.clear();
    }

    /// Returns the number of distinct interceptor types currently cached.
    #[must_use]
    pub fn size(&self) -> usize {
        self.method_interceptors.len()
    }
}

impl Default for InterceptorCache {
    /// Equivalent to [`InterceptorCache::new`]: an empty cache.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::interceptor::traits::{NoOpInterceptor, SwanInterceptor};
    use async_trait::async_trait;
    use std::borrow::Cow;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Number of times `CountingInterceptor::default()` has run.
    ///
    /// Only `test_cache_reuses_interceptors` constructs `CountingInterceptor`,
    /// so tests running in parallel cannot race on this counter.
    static CREATION_COUNT: AtomicUsize = AtomicUsize::new(0);

    /// Minimal interceptor that passes requests and responses through unchanged.
    #[derive(Default)]
    struct TestInterceptor;

    #[async_trait]
    impl SwanInterceptor for TestInterceptor {
        async fn before_request<'a>(
            &self,
            request: reqwest::RequestBuilder,
            request_body: &'a [u8],
        ) -> anyhow::Result<(reqwest::RequestBuilder, Cow<'a, [u8]>)> {
            Ok((request, Cow::Borrowed(request_body)))
        }

        async fn after_response(
            &self,
            response: reqwest::Response,
        ) -> anyhow::Result<reqwest::Response> {
            Ok(response)
        }
    }

    /// Cacheable type whose constructor records each call in `CREATION_COUNT`.
    struct CountingInterceptor;

    impl Default for CountingInterceptor {
        fn default() -> Self {
            CREATION_COUNT.fetch_add(1, Ordering::SeqCst);
            Self
        }
    }

    #[test]
    fn test_cache_reuses_interceptors() {
        let mut cache = InterceptorCache::new();
        CREATION_COUNT.store(0, Ordering::SeqCst);

        let interceptor1 = cache.get_or_create::<CountingInterceptor>();
        let interceptor2 = cache.get_or_create::<CountingInterceptor>();
        let interceptor3 = cache.get_or_create::<CountingInterceptor>();

        assert!(Arc::ptr_eq(&interceptor1, &interceptor2));
        assert!(Arc::ptr_eq(&interceptor2, &interceptor3));
        // The constructor must run exactly once; later calls reuse the cached Arc.
        assert_eq!(CREATION_COUNT.load(Ordering::SeqCst), 1);
        assert_eq!(cache.size(), 1);
    }

    #[test]
    fn test_cache_different_types() {
        let mut cache = InterceptorCache::new();

        let _interceptor1 = cache.get_or_create::<TestInterceptor>();
        let _interceptor2 = cache.get_or_create::<NoOpInterceptor>();

        assert_eq!(cache.size(), 2);
    }

    #[test]
    fn test_warmup() {
        let mut cache = InterceptorCache::new();

        cache.warmup::<TestInterceptor>();
        cache.warmup::<NoOpInterceptor>();
        assert_eq!(cache.size(), 2);

        // Requesting an already warmed-up type must not add a new entry.
        let _interceptor = cache.get_or_create::<TestInterceptor>();
        assert_eq!(cache.size(), 2);
    }

    #[test]
    fn test_clear_empties_cache() {
        let mut cache = InterceptorCache::new();
        cache.warmup::<TestInterceptor>();
        cache.warmup::<NoOpInterceptor>();
        assert_eq!(cache.size(), 2);

        cache.clear();
        assert_eq!(cache.size(), 0);
    }
}