1use std::fmt::{self, Debug, Formatter};
3use std::sync::Arc;
4
5use tokio::sync::Notify;
6use tokio::time::Duration;
7use tokio_util::sync::CancellationToken;
8
9use super::{ArcFusewire, FuseEvent, FuseFactory, FuseInfo, Fusewire, async_trait};
10
/// Verdict a [`Guard`] returns after inspecting a fuse event.
///
/// Guards are consulted in order; the first one that answers anything other
/// than [`GuardAction::ToNext`] decides what happens to the event.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum GuardAction {
    /// Trip the fuse immediately: the wire's reject token is cancelled.
    Reject,
    /// No opinion; hand the event to the next guard, and if no guard decides,
    /// to the built-in timeout handling.
    ToNext,
    /// Accept the event and skip every remaining guard and the built-in
    /// timeout handling for it.
    Permit,
}
21pub trait Guard: Sync + Send + 'static {
23 fn check(&self, info: &FuseInfo, event: &FuseEvent) -> GuardAction;
25}
26impl<F> Guard for F
27where
28 F: Fn(&FuseInfo, &FuseEvent) -> GuardAction + Sync + Send + 'static,
29{
30 fn check(&self, info: &FuseInfo, event: &FuseEvent) -> GuardAction {
31 self(info, event)
32 }
33}
34
35#[must_use]
37pub fn skip_quic(info: &FuseInfo, _event: &FuseEvent) -> GuardAction {
38 if info.trans_proto.is_quic() {
39 GuardAction::Permit
40 } else {
41 GuardAction::ToNext
42 }
43}
44
45pub struct FlexFusewire {
47 info: FuseInfo,
48 guards: Arc<Vec<Box<dyn Guard>>>,
49
50 reject_token: CancellationToken,
51
52 tcp_idle_timeout: Duration,
53 tcp_idle_token: CancellationToken,
54 tcp_idle_notify: Arc<Notify>,
55
56 tcp_frame_timeout: Duration,
57 tcp_frame_token: CancellationToken,
58 tcp_frame_notify: Arc<Notify>,
59
60 tls_handshake_timeout: Duration,
61 tls_handshake_token: CancellationToken,
62 tls_handshake_notify: Arc<Notify>,
63}
64
65impl Debug for FlexFusewire {
66 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
67 f.debug_struct("FlexFusewire")
68 .field("info", &self.info)
69 .field("guards.len", &self.guards.len())
70 .field("tcp_idle_timeout", &self.tcp_idle_timeout)
71 .field("tcp_frame_timeout", &self.tcp_frame_timeout)
72 .field("tls_handshake_timeout", &self.tls_handshake_timeout)
73 .finish()
74 }
75}
76
77impl FlexFusewire {
78 #[must_use]
80 pub fn new(info: FuseInfo) -> Self {
81 Self::builder().build(info)
82 }
83
84 #[must_use]
86 pub fn builder() -> FlexFactory {
87 FlexFactory::new()
88 }
89 #[must_use]
91 pub fn tcp_idle_timeout(&self) -> Duration {
92 self.tcp_idle_timeout
93 }
94 #[must_use]
96 pub fn tcp_frame_timeout(&self) -> Duration {
97 self.tcp_frame_timeout
98 }
99 #[must_use]
101 pub fn tls_handshake_timeout(&self) -> Duration {
102 self.tls_handshake_timeout
103 }
104}
105#[async_trait]
106impl Fusewire for FlexFusewire {
107 fn event(&self, event: FuseEvent) {
108 for guard in self.guards.iter() {
109 match guard.check(&self.info, &event) {
110 GuardAction::Permit => {
111 return;
112 }
113 GuardAction::Reject => {
114 self.reject_token.cancel();
115 return;
116 }
117 GuardAction::ToNext => {}
118 }
119 }
120 self.tcp_idle_notify.notify_waiters();
121 match event {
122 FuseEvent::TlsHandshaking => {
123 let tls_handshake_notify = self.tls_handshake_notify.clone();
124 let tls_handshake_timeout = self.tls_handshake_timeout;
125 let tls_handshake_token = self.tls_handshake_token.clone();
126 tokio::spawn(async move {
127 loop {
128 if tokio::time::timeout(
129 tls_handshake_timeout,
130 tls_handshake_notify.notified(),
131 )
132 .await
133 .is_err()
134 {
135 tls_handshake_token.cancel();
136 break;
137 }
138 }
139 });
140 }
141 FuseEvent::TlsHandshaked => {
142 self.tls_handshake_notify.notify_waiters();
143 }
144 FuseEvent::WaitFrame => {
145 let tcp_frame_notify = self.tcp_frame_notify.clone();
146 let tcp_frame_timeout = self.tcp_frame_timeout;
147 let tcp_frame_token = self.tcp_frame_token.clone();
148 tokio::spawn(async move {
149 if tokio::time::timeout(tcp_frame_timeout, tcp_frame_notify.notified())
150 .await
151 .is_err()
152 {
153 tcp_frame_token.cancel();
154 }
155 });
156 }
157 FuseEvent::GainFrame => {
158 self.tcp_frame_notify.notify_waiters();
159 }
160 _ => {}
161 }
162 }
163 async fn fused(&self) {
164 tokio::select! {
165 _ = self.reject_token.cancelled() => {}
166 _ = self.tcp_idle_token.cancelled() => {}
167 _ = self.tcp_frame_token.cancelled() => {}
168 _ = self.tls_handshake_token.cancelled() => {}
169 }
170 }
171}
172
173#[derive(Clone)]
175pub struct FlexFactory {
176 tcp_idle_timeout: Duration,
177 tcp_frame_timeout: Duration,
178 tls_handshake_timeout: Duration,
179
180 guards: Arc<Vec<Box<dyn Guard>>>,
181}
182
183impl Debug for FlexFactory {
184 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
185 f.debug_struct("FlexFactory")
186 .field("tcp_idle_timeout", &self.tcp_idle_timeout)
187 .field("tcp_frame_timeout", &self.tcp_frame_timeout)
188 .field("tls_handshake_timeout", &self.tls_handshake_timeout)
189 .field("guards.len", &self.guards.len())
190 .finish()
191 }
192}
193
194impl Default for FlexFactory {
195 fn default() -> Self {
196 Self::new()
197 }
198}
199
200impl FlexFactory {
201 pub fn new() -> Self {
203 Self {
204 tcp_idle_timeout: Duration::from_secs(30),
205 tcp_frame_timeout: Duration::from_secs(60),
206 tls_handshake_timeout: Duration::from_secs(10),
207 guards: Arc::new(vec![Box::new(skip_quic)]),
208 }
209 }
210
211 #[must_use]
213 pub fn tcp_idle_timeout(mut self, timeout: Duration) -> Self {
214 self.tcp_idle_timeout = timeout;
215 self
216 }
217 #[must_use]
219 pub fn tcp_frame_timeout(mut self, timeout: Duration) -> Self {
220 self.tcp_frame_timeout = timeout;
221 self
222 }
223
224 #[must_use]
226 pub fn guards(mut self, guards: Vec<Box<dyn Guard>>) -> Self {
227 self.guards = Arc::new(guards);
228 self
229 }
230 #[must_use]
232 pub fn add_guard(mut self, guard: impl Guard) -> Self {
233 Arc::get_mut(&mut self.guards)
234 .expect("guards get mut failed")
235 .push(Box::new(guard));
236 self
237 }
238
239 #[must_use]
241 pub fn build(&self, info: FuseInfo) -> FlexFusewire {
242 let Self {
243 tcp_idle_timeout,
244 tcp_frame_timeout,
245 tls_handshake_timeout,
246 guards,
247 } = self.clone();
248
249 let tcp_idle_token = CancellationToken::new();
250 let tcp_idle_notify = Arc::new(Notify::new());
251 tokio::spawn({
252 let tcp_idle_notify = tcp_idle_notify.clone();
253 let tcp_idle_token = tcp_idle_token.clone();
254 async move {
255 loop {
256 if tokio::time::timeout(tcp_idle_timeout, tcp_idle_notify.notified())
257 .await
258 .is_err()
259 {
260 tcp_idle_token.cancel();
261 break;
262 }
263 }
264 }
265 });
266 FlexFusewire {
267 info,
268 guards,
269
270 reject_token: CancellationToken::new(),
271
272 tcp_idle_timeout,
273 tcp_idle_token,
274 tcp_idle_notify,
275
276 tcp_frame_timeout,
277 tcp_frame_token: CancellationToken::new(),
278 tcp_frame_notify: Arc::new(Notify::new()),
279
280 tls_handshake_timeout,
281 tls_handshake_token: CancellationToken::new(),
282 tls_handshake_notify: Arc::new(Notify::new()),
283 }
284 }
285}
286
287impl FuseFactory for FlexFactory {
288 fn create(&self, info: FuseInfo) -> ArcFusewire {
289 Arc::new(self.build(info))
290 }
291}