1use crate::fetch::{FetchKind, Terminator};
2use crate::middleware::MiddlewareKind;
3
4#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
5pub enum Phase {
6 L4Raw,
7 L4Peeked,
8 L7Request,
9 L7Response,
10 Tunnel,
11}
12
13#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
14pub enum PhaseNodeKind {
15 Check,
16 Middleware(MiddlewareKind),
17 Upgrade,
18 Fetch(FetchKind),
19 Terminate(Terminator),
20}
21
22#[derive(Copy, Clone, Eq, PartialEq, Debug)]
23pub enum Transition {
24 PassThrough,
25 Into(Phase),
26 BiOutcome { response: Phase, tunnel: Phase },
27 Terminal,
28}
29
30#[derive(Copy, Clone, Eq, PartialEq, Debug)]
31pub struct PhaseError {
32 pub expected: &'static [Phase],
33 pub got: Phase,
34}
35
36const L4_ANY: &[Phase] = &[Phase::L4Raw, Phase::L4Peeked];
37const L7_REQ: &[Phase] = &[Phase::L7Request];
38const L7_RESP: &[Phase] = &[Phase::L7Response];
39const TUNNEL: &[Phase] = &[Phase::Tunnel];
40const ANY_PHASE: &[Phase] =
41 &[Phase::L4Raw, Phase::L4Peeked, Phase::L7Request, Phase::L7Response, Phase::Tunnel];
42
43#[must_use]
47#[allow(clippy::match_same_arms)]
48pub const fn accepted_in_phases(kind: PhaseNodeKind) -> &'static [Phase] {
49 match kind {
50 PhaseNodeKind::Check => ANY_PHASE,
51 PhaseNodeKind::Middleware(MiddlewareKind::L4Peek) => L4_ANY,
52 PhaseNodeKind::Middleware(MiddlewareKind::L4Bytes) => L4_ANY,
53 PhaseNodeKind::Middleware(MiddlewareKind::L7Request) => L7_REQ,
54 PhaseNodeKind::Middleware(MiddlewareKind::L7Response) => L7_RESP,
55 PhaseNodeKind::Upgrade => L4_ANY,
60 PhaseNodeKind::Fetch(FetchKind::L4Forward) => L4_ANY,
61 PhaseNodeKind::Fetch(FetchKind::HttpProxy) => L7_REQ,
62 PhaseNodeKind::Fetch(FetchKind::HttpSynthesize) => L7_REQ,
63 PhaseNodeKind::Fetch(FetchKind::WebSocketUpgrade) => L7_REQ,
64 PhaseNodeKind::Terminate(Terminator::WriteHttpResponse) => L7_RESP,
65 PhaseNodeKind::Terminate(Terminator::ByteTunnel) => TUNNEL,
66 PhaseNodeKind::Terminate(Terminator::Close) => ANY_PHASE,
69 }
70}
71
72#[allow(clippy::match_same_arms)]
78pub fn transition(kind: PhaseNodeKind, cur: Phase) -> Result<Transition, PhaseError> {
79 let accepted = accepted_in_phases(kind);
80 if !accepted.contains(&cur) {
81 return Err(PhaseError { expected: accepted, got: cur });
82 }
83 Ok(match kind {
84 PhaseNodeKind::Check => Transition::PassThrough,
85 PhaseNodeKind::Middleware(MiddlewareKind::L4Peek) => Transition::Into(Phase::L4Peeked),
86 PhaseNodeKind::Middleware(MiddlewareKind::L4Bytes) => Transition::PassThrough,
87 PhaseNodeKind::Middleware(MiddlewareKind::L7Request) => Transition::Into(Phase::L7Request),
88 PhaseNodeKind::Middleware(MiddlewareKind::L7Response) => Transition::Into(Phase::L7Response),
89 PhaseNodeKind::Upgrade => Transition::Into(Phase::L7Request),
90 PhaseNodeKind::Fetch(FetchKind::L4Forward) => Transition::Into(Phase::Tunnel),
91 PhaseNodeKind::Fetch(FetchKind::HttpProxy) => Transition::Into(Phase::L7Response),
92 PhaseNodeKind::Fetch(FetchKind::HttpSynthesize) => Transition::Into(Phase::L7Response),
93 PhaseNodeKind::Fetch(FetchKind::WebSocketUpgrade) => {
94 Transition::BiOutcome { response: Phase::L7Response, tunnel: Phase::Tunnel }
95 }
96 PhaseNodeKind::Terminate(_) => Transition::Terminal,
97 })
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    const ALL_PHASES: [Phase; 5] =
        [Phase::L4Raw, Phase::L4Peeked, Phase::L7Request, Phase::L7Response, Phase::Tunnel];

    /// Every node kind, for exhaustive cross-checks. Extend this when a new kind is added.
    const ALL_KINDS: [PhaseNodeKind; 13] = [
        PhaseNodeKind::Check,
        PhaseNodeKind::Middleware(MiddlewareKind::L4Peek),
        PhaseNodeKind::Middleware(MiddlewareKind::L4Bytes),
        PhaseNodeKind::Middleware(MiddlewareKind::L7Request),
        PhaseNodeKind::Middleware(MiddlewareKind::L7Response),
        PhaseNodeKind::Upgrade,
        PhaseNodeKind::Fetch(FetchKind::L4Forward),
        PhaseNodeKind::Fetch(FetchKind::HttpProxy),
        PhaseNodeKind::Fetch(FetchKind::HttpSynthesize),
        PhaseNodeKind::Fetch(FetchKind::WebSocketUpgrade),
        PhaseNodeKind::Terminate(Terminator::WriteHttpResponse),
        PhaseNodeKind::Terminate(Terminator::ByteTunnel),
        PhaseNodeKind::Terminate(Terminator::Close),
    ];

    #[test]
    fn phase_serde_round_trip_per_variant() {
        for p in ALL_PHASES {
            let encoded = serde_json::to_string(&p).expect("serialize");
            let decoded: Phase = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, p);
        }
    }

    #[test]
    fn phase_serializes_as_bare_variant_name() {
        // Pins the wire format: renaming a variant must be a deliberate, visible change.
        let cases = [
            (Phase::L4Raw, "\"L4Raw\""),
            (Phase::L4Peeked, "\"L4Peeked\""),
            (Phase::L7Request, "\"L7Request\""),
            (Phase::L7Response, "\"L7Response\""),
            (Phase::Tunnel, "\"Tunnel\""),
        ];
        for (p, want) in cases {
            assert_eq!(serde_json::to_string(&p).expect("serialize"), want);
        }
    }

    #[test]
    fn any_phase_lists_every_phase_once() {
        assert_eq!(ANY_PHASE, &ALL_PHASES[..]);
    }

    #[test]
    fn every_kind_accepts_a_non_empty_duplicate_free_set() {
        for kind in ALL_KINDS {
            let accepted = accepted_in_phases(kind);
            assert!(!accepted.is_empty(), "{:?} accepts no phase", kind);
            for (i, p) in accepted.iter().enumerate() {
                assert!(!accepted[i + 1..].contains(p), "{:?} lists {:?} twice", kind, p);
            }
        }
    }

    #[test]
    fn transition_succeeds_exactly_in_accepted_phases() {
        for kind in ALL_KINDS {
            let accepted = accepted_in_phases(kind);
            for cur in ALL_PHASES {
                match transition(kind, cur) {
                    Ok(_) => assert!(accepted.contains(&cur), "{:?} ran in {:?}", kind, cur),
                    Err(err) => {
                        assert!(!accepted.contains(&cur), "{:?} rejected in {:?}", kind, cur);
                        assert_eq!(err, PhaseError { expected: accepted, got: cur });
                    }
                }
            }
        }
    }

    #[test]
    fn check_accepts_any_phase() {
        assert_eq!(accepted_in_phases(PhaseNodeKind::Check), ANY_PHASE);
    }

    #[test]
    fn l4_peek_accepts_l4_phases_only() {
        assert_eq!(
            accepted_in_phases(PhaseNodeKind::Middleware(MiddlewareKind::L4Peek)),
            &[Phase::L4Raw, Phase::L4Peeked] as &[Phase],
        );
    }

    #[test]
    fn l4_bytes_accepts_l4_phases_only() {
        assert_eq!(
            accepted_in_phases(PhaseNodeKind::Middleware(MiddlewareKind::L4Bytes)),
            &[Phase::L4Raw, Phase::L4Peeked] as &[Phase],
        );
    }

    #[test]
    fn l7_request_middleware_accepts_only_l7_request() {
        assert_eq!(
            accepted_in_phases(PhaseNodeKind::Middleware(MiddlewareKind::L7Request)),
            &[Phase::L7Request] as &[Phase],
        );
    }

    #[test]
    fn l7_response_middleware_accepts_only_l7_response() {
        assert_eq!(
            accepted_in_phases(PhaseNodeKind::Middleware(MiddlewareKind::L7Response)),
            &[Phase::L7Response] as &[Phase],
        );
    }

    #[test]
    fn upgrade_accepts_both_l4_phases() {
        assert_eq!(
            accepted_in_phases(PhaseNodeKind::Upgrade),
            &[Phase::L4Raw, Phase::L4Peeked] as &[Phase],
        );
    }

    #[test]
    fn l4_forward_fetch_accepts_l4_phases() {
        assert_eq!(
            accepted_in_phases(PhaseNodeKind::Fetch(FetchKind::L4Forward)),
            &[Phase::L4Raw, Phase::L4Peeked] as &[Phase],
        );
    }

    #[test]
    fn http_fetches_accept_only_l7_request() {
        for f in [FetchKind::HttpProxy, FetchKind::HttpSynthesize, FetchKind::WebSocketUpgrade] {
            assert_eq!(accepted_in_phases(PhaseNodeKind::Fetch(f)), &[Phase::L7Request] as &[Phase]);
        }
    }

    #[test]
    fn write_http_response_accepts_only_l7_response() {
        assert_eq!(
            accepted_in_phases(PhaseNodeKind::Terminate(Terminator::WriteHttpResponse)),
            &[Phase::L7Response] as &[Phase],
        );
    }

    #[test]
    fn byte_tunnel_accepts_only_tunnel() {
        assert_eq!(
            accepted_in_phases(PhaseNodeKind::Terminate(Terminator::ByteTunnel)),
            &[Phase::Tunnel] as &[Phase],
        );
    }

    #[test]
    fn check_is_pass_through_at_every_phase() {
        for cur in ALL_PHASES {
            assert_eq!(transition(PhaseNodeKind::Check, cur), Ok(Transition::PassThrough));
        }
    }

    #[test]
    fn l4_peek_forces_out_to_l4_peeked() {
        for cur in [Phase::L4Raw, Phase::L4Peeked] {
            assert_eq!(
                transition(PhaseNodeKind::Middleware(MiddlewareKind::L4Peek), cur),
                Ok(Transition::Into(Phase::L4Peeked)),
            );
        }
    }

    #[test]
    fn l4_bytes_is_pass_through_on_l4_phases() {
        for cur in [Phase::L4Raw, Phase::L4Peeked] {
            assert_eq!(
                transition(PhaseNodeKind::Middleware(MiddlewareKind::L4Bytes), cur),
                Ok(Transition::PassThrough),
            );
        }
    }

    #[test]
    fn upgrade_transitions_to_l7_request_from_any_l4_phase() {
        for cur in [Phase::L4Raw, Phase::L4Peeked] {
            assert_eq!(transition(PhaseNodeKind::Upgrade, cur), Ok(Transition::Into(Phase::L7Request)));
        }
    }

    #[test]
    fn l7_request_middleware_stays_in_l7_request() {
        assert_eq!(
            transition(PhaseNodeKind::Middleware(MiddlewareKind::L7Request), Phase::L7Request),
            Ok(Transition::Into(Phase::L7Request)),
        );
    }

    #[test]
    fn l7_response_middleware_stays_in_l7_response() {
        assert_eq!(
            transition(PhaseNodeKind::Middleware(MiddlewareKind::L7Response), Phase::L7Response),
            Ok(Transition::Into(Phase::L7Response)),
        );
    }

    #[test]
    fn l4_forward_fetch_goes_to_tunnel_from_any_l4_phase() {
        for cur in [Phase::L4Raw, Phase::L4Peeked] {
            assert_eq!(
                transition(PhaseNodeKind::Fetch(FetchKind::L4Forward), cur),
                Ok(Transition::Into(Phase::Tunnel)),
            );
        }
    }

    #[test]
    fn http_fetch_variants_go_to_l7_response() {
        for f in [FetchKind::HttpProxy, FetchKind::HttpSynthesize] {
            assert_eq!(
                transition(PhaseNodeKind::Fetch(f), Phase::L7Request),
                Ok(Transition::Into(Phase::L7Response)),
            );
        }
    }

    #[test]
    fn websocket_fetch_is_bi_outcome() {
        assert_eq!(
            transition(PhaseNodeKind::Fetch(FetchKind::WebSocketUpgrade), Phase::L7Request),
            Ok(Transition::BiOutcome { response: Phase::L7Response, tunnel: Phase::Tunnel }),
        );
    }

    #[test]
    fn write_http_response_is_terminal() {
        assert_eq!(
            transition(PhaseNodeKind::Terminate(Terminator::WriteHttpResponse), Phase::L7Response),
            Ok(Transition::Terminal),
        );
    }

    #[test]
    fn byte_tunnel_is_terminal() {
        assert_eq!(
            transition(PhaseNodeKind::Terminate(Terminator::ByteTunnel), Phase::Tunnel),
            Ok(Transition::Terminal),
        );
    }

    #[test]
    fn close_is_terminal_at_every_phase() {
        for p in ALL_PHASES {
            assert_eq!(
                transition(PhaseNodeKind::Terminate(Terminator::Close), p),
                Ok(Transition::Terminal),
            );
        }
    }

    #[test]
    fn close_accepts_any_phase() {
        assert_eq!(accepted_in_phases(PhaseNodeKind::Terminate(Terminator::Close)), ANY_PHASE);
    }

    #[test]
    fn rejects_out_of_phase_attempts() {
        let cases: &[(PhaseNodeKind, Phase)] = &[
            (PhaseNodeKind::Upgrade, Phase::L7Request),
            (PhaseNodeKind::Upgrade, Phase::L7Response),
            (PhaseNodeKind::Middleware(MiddlewareKind::L7Request), Phase::L4Raw),
            (PhaseNodeKind::Middleware(MiddlewareKind::L7Response), Phase::L7Request),
            (PhaseNodeKind::Fetch(FetchKind::HttpProxy), Phase::L7Response),
            (PhaseNodeKind::Fetch(FetchKind::L4Forward), Phase::L7Request),
            (PhaseNodeKind::Terminate(Terminator::WriteHttpResponse), Phase::Tunnel),
            (PhaseNodeKind::Terminate(Terminator::ByteTunnel), Phase::L7Response),
        ];
        for (kind, cur) in cases.iter().copied() {
            let err = transition(kind, cur).expect_err("out-of-phase must error");
            assert_eq!(err.got, cur);
            assert_eq!(err.expected, accepted_in_phases(kind));
        }
    }
}