salvo_core/fuse/
flex.rs

1//! A flexible fusewire.
2use std::fmt::{self, Debug, Formatter};
3use std::sync::Arc;
4
5use tokio::sync::Notify;
6use tokio::time::Duration;
7use tokio_util::sync::CancellationToken;
8
9use super::{ArcFusewire, FuseEvent, FuseFactory, FuseInfo, Fusewire, async_trait};
10
/// The verdict a [`Guard`] returns for a connection event.
///
/// `FlexFusewire` evaluates its guards in order; the first verdict other than
/// [`GuardAction::ToNext`] ends the evaluation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GuardAction {
    /// Reject the connection: `FlexFusewire` fuses it immediately and consults
    /// no further guards.
    Reject,
    /// No decision: pass the event on to the next guard.
    ToNext,
    /// Accept the event and skip all remaining guards.
    Permit,
}
21/// A guard.
22pub trait Guard: Sync + Send + 'static {
23    /// Check the event.
24    fn check(&self, info: &FuseInfo, event: &FuseEvent) -> GuardAction;
25}
26impl<F> Guard for F
27where
28    F: Fn(&FuseInfo, &FuseEvent) -> GuardAction + Sync + Send + 'static,
29{
30    fn check(&self, info: &FuseInfo, event: &FuseEvent) -> GuardAction {
31        self(info, event)
32    }
33}
34
35/// Skip the quic connection.
36#[must_use]
37pub fn skip_quic(info: &FuseInfo, _event: &FuseEvent) -> GuardAction {
38    if info.trans_proto.is_quic() {
39        GuardAction::Permit
40    } else {
41        GuardAction::ToNext
42    }
43}
44
/// A flexible fusewire driven by a list of [`Guard`]s plus TCP idle, TCP frame
/// and TLS handshake timeouts.
///
/// Usually created through [`FlexFactory`].
pub struct FlexFusewire {
    /// Connection information handed to every guard.
    info: FuseInfo,
    /// Guards consulted, in order, for every event.
    guards: Arc<Vec<Box<dyn Guard>>>,

    /// Cancelled when a guard returns [`GuardAction::Reject`].
    reject_token: CancellationToken,

    /// Idle period after which the connection is fused; the idle timer restarts
    /// whenever an event gets past all guards.
    tcp_idle_timeout: Duration,
    /// Cancelled when the idle timeout expires.
    tcp_idle_token: CancellationToken,
    /// Signalled for every event that gets past all guards, resetting the idle timer.
    tcp_idle_notify: Arc<Notify>,

    /// How long to wait for `GainFrame` after a `WaitFrame` event before fusing.
    tcp_frame_timeout: Duration,
    /// Cancelled when a frame wait times out.
    tcp_frame_token: CancellationToken,
    /// Signalled on `GainFrame` to stop pending frame timers.
    tcp_frame_notify: Arc<Notify>,

    /// How long to wait for `TlsHandshaked` after `TlsHandshaking` before fusing.
    tls_handshake_timeout: Duration,
    /// Cancelled when the TLS handshake times out.
    tls_handshake_token: CancellationToken,
    /// Signalled on `TlsHandshaked` to stop the pending handshake timer.
    tls_handshake_notify: Arc<Notify>,
}
64
65impl Debug for FlexFusewire {
66    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
67        f.debug_struct("FlexFusewire")
68            .field("info", &self.info)
69            .field("guards.len", &self.guards.len())
70            .field("tcp_idle_timeout", &self.tcp_idle_timeout)
71            .field("tcp_frame_timeout", &self.tcp_frame_timeout)
72            .field("tls_handshake_timeout", &self.tls_handshake_timeout)
73            .finish()
74    }
75}
76
77impl FlexFusewire {
78    /// Create a new `FlexFusewire`.
79    #[must_use]
80    pub fn new(info: FuseInfo) -> Self {
81        Self::builder().build(info)
82    }
83
84    /// Create a new `FlexFactory`.
85    #[must_use]
86    pub fn builder() -> FlexFactory {
87        FlexFactory::new()
88    }
89    /// Get the timeout for close the idle tcp connection.
90    #[must_use]
91    pub fn tcp_idle_timeout(&self) -> Duration {
92        self.tcp_idle_timeout
93    }
94    /// Get the timeout for close the connection if frame can not be received.
95    #[must_use]
96    pub fn tcp_frame_timeout(&self) -> Duration {
97        self.tcp_frame_timeout
98    }
99    /// Set the timeout for close the connection if handshake not finished.
100    #[must_use]
101    pub fn tls_handshake_timeout(&self) -> Duration {
102        self.tls_handshake_timeout
103    }
104}
105#[async_trait]
106impl Fusewire for FlexFusewire {
107    fn event(&self, event: FuseEvent) {
108        for guard in self.guards.iter() {
109            match guard.check(&self.info, &event) {
110                GuardAction::Permit => {
111                    return;
112                }
113                GuardAction::Reject => {
114                    self.reject_token.cancel();
115                    return;
116                }
117                GuardAction::ToNext => {}
118            }
119        }
120        self.tcp_idle_notify.notify_waiters();
121        match event {
122            FuseEvent::TlsHandshaking => {
123                let tls_handshake_notify = self.tls_handshake_notify.clone();
124                let tls_handshake_timeout = self.tls_handshake_timeout;
125                let tls_handshake_token = self.tls_handshake_token.clone();
126                tokio::spawn(async move {
127                    loop {
128                        if tokio::time::timeout(
129                            tls_handshake_timeout,
130                            tls_handshake_notify.notified(),
131                        )
132                        .await
133                        .is_err()
134                        {
135                            tls_handshake_token.cancel();
136                            break;
137                        }
138                    }
139                });
140            }
141            FuseEvent::TlsHandshaked => {
142                self.tls_handshake_notify.notify_waiters();
143            }
144            FuseEvent::WaitFrame => {
145                let tcp_frame_notify = self.tcp_frame_notify.clone();
146                let tcp_frame_timeout = self.tcp_frame_timeout;
147                let tcp_frame_token = self.tcp_frame_token.clone();
148                tokio::spawn(async move {
149                    if tokio::time::timeout(tcp_frame_timeout, tcp_frame_notify.notified())
150                        .await
151                        .is_err()
152                    {
153                        tcp_frame_token.cancel();
154                    }
155                });
156            }
157            FuseEvent::GainFrame => {
158                self.tcp_frame_notify.notify_waiters();
159            }
160            _ => {}
161        }
162    }
163    async fn fused(&self) {
164        tokio::select! {
165            _ = self.reject_token.cancelled() => {}
166            _ = self.tcp_idle_token.cancelled() => {}
167            _ = self.tcp_frame_token.cancelled() => {}
168            _ = self.tls_handshake_token.cancelled() => {}
169        }
170    }
171}
172
/// A [`FlexFusewire`] builder.
///
/// Holds the timeouts and guards applied to every fusewire it builds. Cloning is
/// cheap: the timeouts are copied and the guard list is shared through an [`Arc`].
#[derive(Clone)]
pub struct FlexFactory {
    /// Idle timeout given to every built fusewire.
    tcp_idle_timeout: Duration,
    /// Frame timeout given to every built fusewire.
    tcp_frame_timeout: Duration,
    /// TLS handshake timeout given to every built fusewire.
    tls_handshake_timeout: Duration,

    /// Guards shared (via `Arc`) with every fusewire built from this factory.
    guards: Arc<Vec<Box<dyn Guard>>>,
}
182
183impl Debug for FlexFactory {
184    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
185        f.debug_struct("FlexFactory")
186            .field("tcp_idle_timeout", &self.tcp_idle_timeout)
187            .field("tcp_frame_timeout", &self.tcp_frame_timeout)
188            .field("tls_handshake_timeout", &self.tls_handshake_timeout)
189            .field("guards.len", &self.guards.len())
190            .finish()
191    }
192}
193
194impl Default for FlexFactory {
195    fn default() -> Self {
196        Self::new()
197    }
198}
199
200impl FlexFactory {
201    /// Create a new `FlexFactory`.
202    pub fn new() -> Self {
203        Self {
204            tcp_idle_timeout: Duration::from_secs(30),
205            tcp_frame_timeout: Duration::from_secs(60),
206            tls_handshake_timeout: Duration::from_secs(10),
207            guards: Arc::new(vec![Box::new(skip_quic)]),
208        }
209    }
210
211    /// Set the timeout for close the idle tcp connection.
212    #[must_use]
213    pub fn tcp_idle_timeout(mut self, timeout: Duration) -> Self {
214        self.tcp_idle_timeout = timeout;
215        self
216    }
217    /// Set the timeout for close the connection if frame can not be received.
218    #[must_use]
219    pub fn tcp_frame_timeout(mut self, timeout: Duration) -> Self {
220        self.tcp_frame_timeout = timeout;
221        self
222    }
223
224    /// Set guards to new value.
225    #[must_use]
226    pub fn guards(mut self, guards: Vec<Box<dyn Guard>>) -> Self {
227        self.guards = Arc::new(guards);
228        self
229    }
230    /// Add a guard.
231    #[must_use]
232    pub fn add_guard(mut self, guard: impl Guard) -> Self {
233        Arc::get_mut(&mut self.guards)
234            .expect("guards get mut failed")
235            .push(Box::new(guard));
236        self
237    }
238
239    /// Build a `FlexFusewire`.
240    #[must_use]
241    pub fn build(&self, info: FuseInfo) -> FlexFusewire {
242        let Self {
243            tcp_idle_timeout,
244            tcp_frame_timeout,
245            tls_handshake_timeout,
246            guards,
247        } = self.clone();
248
249        let tcp_idle_token = CancellationToken::new();
250        let tcp_idle_notify = Arc::new(Notify::new());
251        tokio::spawn({
252            let tcp_idle_notify = tcp_idle_notify.clone();
253            let tcp_idle_token = tcp_idle_token.clone();
254            async move {
255                loop {
256                    if tokio::time::timeout(tcp_idle_timeout, tcp_idle_notify.notified())
257                        .await
258                        .is_err()
259                    {
260                        tcp_idle_token.cancel();
261                        break;
262                    }
263                }
264            }
265        });
266        FlexFusewire {
267            info,
268            guards,
269
270            reject_token: CancellationToken::new(),
271
272            tcp_idle_timeout,
273            tcp_idle_token,
274            tcp_idle_notify,
275
276            tcp_frame_timeout,
277            tcp_frame_token: CancellationToken::new(),
278            tcp_frame_notify: Arc::new(Notify::new()),
279
280            tls_handshake_timeout,
281            tls_handshake_token: CancellationToken::new(),
282            tls_handshake_notify: Arc::new(Notify::new()),
283        }
284    }
285}
286
287impl FuseFactory for FlexFactory {
288    fn create(&self, info: FuseInfo) -> ArcFusewire {
289        Arc::new(self.build(info))
290    }
291}