virtual_hosts_module/
handler.rs

1// Copyright 2024 Wladimir Palant
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use async_trait::async_trait;
16use http::header;
17use http::uri::Uri;
18use log::warn;
19use module_utils::{RequestFilter, RequestFilterResult};
20use pingora_core::Error;
21use pingora_proxy::Session;
22use std::collections::HashMap;
23use trie_rs::map::{Trie, TrieBuilder};
24
25use crate::configuration::VirtualHostsConf;
26
/// A request path broken into its non-empty `/`-separated segments.
///
/// Empty segments (caused by leading, trailing or repeated slashes) are discarded, so
/// `//a///b` and `/a/b` yield the same segments. Only the fact that the original path
/// ended with a `/` is remembered separately.
struct Path {
    /// Non-empty path segments in their original order, without any `/` bytes.
    segments: Vec<Vec<u8>>,
    /// Whether the original path ended with a `/`.
    trailing_slash: bool,
}

impl Path {
    /// Splits `path` on `/`, dropping empty segments.
    fn new<T: AsRef<[u8]>>(path: T) -> Self {
        let bytes = path.as_ref();
        let segments = bytes
            .split(|&byte| byte == b'/')
            .filter(|segment| !segment.is_empty())
            .map(<[u8]>::to_vec)
            .collect();
        Self {
            segments,
            trailing_slash: bytes.ends_with(b"/"),
        }
    }

    /// Number of non-empty segments.
    fn len(&self) -> usize {
        self.segments.len()
    }

    /// Builds a trie key: the host name as the first label, followed by every segment.
    fn to_key<T: AsRef<[u8]>>(&self, host: T) -> Vec<Vec<u8>> {
        std::iter::once(host.as_ref().to_vec())
            .chain(self.segments.iter().cloned())
            .collect()
    }

    /// Reassembles the path after skipping the first `strip_segments` segments.
    ///
    /// Each remaining segment is prefixed with `/`. A trailing `/` is appended if the
    /// original path had one, or if nothing remains (so the result is never empty).
    ///
    /// # Panics
    ///
    /// Panics if `strip_segments` exceeds the number of segments.
    fn to_vec(&self, strip_segments: usize) -> Vec<u8> {
        let mut result = Vec::new();
        for segment in &self.segments[strip_segments..] {
            result.push(b'/');
            result.extend_from_slice(segment);
        }
        if result.is_empty() || self.trailing_slash {
            result.push(b'/');
        }
        result
    }
}
73
74fn host_from_uri(uri: &Uri) -> Option<String> {
75    let mut host = uri.host()?.to_owned();
76    if let Some(port) = uri.port() {
77        host.push(':');
78        host.push_str(port.as_str());
79    }
80    Some(host)
81}
82
83fn set_uri_path(uri: &Uri, path: &[u8]) -> Uri {
84    let mut parts = uri.clone().into_parts();
85    let mut path_and_query = String::from_utf8_lossy(path).to_string();
86    let query = parts
87        .path_and_query
88        .as_ref()
89        .and_then(|path_and_query| path_and_query.query());
90    if let Some(query) = query {
91        path_and_query.push('?');
92        path_and_query.push_str(query);
93    }
94    parts.path_and_query = path_and_query.parse().ok();
95    parts.try_into().unwrap_or_else(|_| uri.clone())
96}
97
98/// Handler for Pingora’s `request_filter` phase
99#[derive(Debug)]
100pub struct VirtualHostsHandler<H> {
101    handlers: Trie<Vec<u8>, (bool, H)>,
102    aliases: HashMap<String, String>,
103    default: Option<String>,
104}
105
106impl<H> VirtualHostsHandler<H> {
107    fn best_match<T: AsRef<[u8]>>(&self, host: T, path: &Path) -> Option<(Option<Vec<u8>>, &H)> {
108        self.handlers
109            .common_prefix_search(path.to_key(host))
110            .last()
111            .map(
112                |(prefix, (strip_prefix, handler)): (Vec<Vec<u8>>, &(bool, H))| {
113                    if *strip_prefix && prefix.len() > 1 {
114                        (Some(path.to_vec(prefix.len() - 1)), handler)
115                    } else {
116                        (None, handler)
117                    }
118                },
119            )
120    }
121}
122
123#[async_trait]
124impl<H> RequestFilter for VirtualHostsHandler<H>
125where
126    H: RequestFilter + Sync,
127    H::Conf: Default,
128    H::CTX: Send,
129{
130    type Conf = VirtualHostsConf<H::Conf>;
131
132    type CTX = H::CTX;
133
134    fn new_ctx() -> Self::CTX {
135        H::new_ctx()
136    }
137
138    async fn request_filter(
139        &self,
140        session: &mut Session,
141        ctx: &mut Self::CTX,
142    ) -> Result<RequestFilterResult, Box<Error>> {
143        let host = session
144            .get_header(header::HOST)
145            .and_then(|host| host.to_str().ok())
146            .map(|host| host.to_owned())
147            .or_else(|| host_from_uri(&session.req_header().uri));
148
149        let path = Path::new(session.req_header().uri.path());
150        let handler = host
151            .and_then(|host| {
152                if let Some(handler) = self.best_match(&host, &path) {
153                    Some(handler)
154                } else if let Some(alias) = self.aliases.get(&host) {
155                    self.best_match(alias, &path)
156                } else {
157                    None
158                }
159            })
160            .or_else(|| {
161                self.default
162                    .as_ref()
163                    .and_then(|default| self.best_match(default, &path))
164            });
165
166        if let Some((new_path, handler)) = handler {
167            if let Some(new_path) = new_path {
168                let header = session.req_header_mut();
169                header.set_uri(set_uri_path(&header.uri, &new_path));
170            }
171            handler.request_filter(session, ctx).await
172        } else {
173            Ok(RequestFilterResult::Unhandled)
174        }
175    }
176}
177
178impl<C, H> TryFrom<VirtualHostsConf<C>> for VirtualHostsHandler<H>
179where
180    C: TryInto<H, Error = Box<Error>> + Default,
181{
182    type Error = Box<Error>;
183
184    fn try_from(conf: VirtualHostsConf<C>) -> Result<Self, Box<Error>> {
185        let mut handlers = TrieBuilder::new();
186        let mut aliases = HashMap::new();
187        let mut default = None;
188        for (host, host_conf) in conf.vhosts.into_iter() {
189            for alias in host_conf.host.aliases.into_iter() {
190                aliases.insert(alias, host.clone());
191            }
192            if host_conf.host.default {
193                if let Some(previous) = &default {
194                    warn!("both {previous} and {host} are marked as default virtual host, ignoring the latter");
195                } else {
196                    default = Some(host.clone());
197                }
198            }
199            handlers.push(
200                Path::new(b"").to_key(&host),
201                (false, host_conf.config.try_into()?),
202            );
203
204            // Work-around for https://github.com/laysakura/trie-rs/issues/32, insert shorter paths
205            // first.
206            let mut subdirs = host_conf
207                .host
208                .subdirs
209                .into_iter()
210                .map(|(path, conf)| (Path::new(path), conf))
211                .collect::<Vec<_>>();
212            subdirs.sort_by_key(|(path, _)| path.len());
213            for (path, conf) in subdirs {
214                handlers.push(
215                    path.to_key(&host),
216                    (conf.subdir.strip_prefix, conf.config.try_into()?),
217                );
218            }
219        }
220        let handlers = handlers.build();
221
222        Ok(Self {
223            handlers,
224            aliases,
225            default,
226        })
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::configuration::{SubDirCombined, SubDirConf, VirtualHostCombined, VirtualHostConf};

    use async_trait::async_trait;
    use test_log::test;
    use tokio_test::io::Builder;

    /// Stub handler that ignores the request and returns a fixed `RequestFilterResult`.
    ///
    /// The fixture in `handler()` assigns results per host/subdirectory, so the value returned
    /// by the virtual hosts handler reveals which stub handled the request.
    #[derive(Debug)]
    struct Handler {
        result: RequestFilterResult,
    }

    #[async_trait]
    impl RequestFilter for Handler {
        type Conf = RequestFilterResult;
        type CTX = ();
        fn new_ctx() -> Self::CTX {}
        // Reports the configured result without touching the session.
        async fn request_filter(
            &self,
            _session: &mut Session,
            _ctx: &mut Self::CTX,
        ) -> Result<RequestFilterResult, Box<Error>> {
            Ok(self.result)
        }
    }

    // The stub's configuration is simply the result it should return.
    impl TryFrom<RequestFilterResult> for Handler {
        type Error = Box<Error>;

        fn try_from(result: RequestFilterResult) -> Result<Self, Self::Error> {
            Ok(Self { result })
        }
    }

    /// Builds the virtual hosts handler shared by all tests:
    ///
    /// * `localhost:8080` (aliases `127.0.0.1:8080` and `[::1]:8080`) → `ResponseSent`;
    ///   marked as default host when `add_default` is `true`.
    ///   * `/subdir/` (prefix stripped) → `Unhandled`
    ///   * `/subdir/subsub` (prefix kept) → `Handled`
    /// * `example.com` (alias `example.com:8080`) → `Handled`
    fn handler(add_default: bool) -> VirtualHostsHandler<Handler> {
        let mut vhosts = HashMap::new();

        let mut subdirs = HashMap::new();
        subdirs.insert(
            "/subdir/".to_owned(),
            SubDirCombined::<RequestFilterResult> {
                subdir: SubDirConf { strip_prefix: true },
                config: RequestFilterResult::Unhandled,
            },
        );
        subdirs.insert(
            "/subdir/subsub".to_owned(),
            SubDirCombined::<RequestFilterResult> {
                subdir: SubDirConf {
                    strip_prefix: false,
                },
                config: RequestFilterResult::Handled,
            },
        );

        vhosts.insert(
            "localhost:8080".to_owned(),
            VirtualHostCombined::<RequestFilterResult> {
                host: VirtualHostConf {
                    aliases: vec!["127.0.0.1:8080".to_owned(), "[::1]:8080".to_owned()],
                    default: add_default,
                    subdirs,
                },
                config: RequestFilterResult::ResponseSent,
            },
        );

        vhosts.insert(
            "example.com".to_owned(),
            VirtualHostCombined::<RequestFilterResult> {
                host: VirtualHostConf {
                    aliases: vec!["example.com:8080".to_owned()],
                    default: false,
                    subdirs: HashMap::new(),
                },
                config: RequestFilterResult::Handled,
            },
        );

        VirtualHostsConf::<RequestFilterResult> { vhosts }
            .try_into()
            .unwrap()
    }

    /// Creates an HTTP/1.1 session for a `GET {uri}` request, with an optional `Host` header,
    /// backed by a mock I/O stream.
    async fn make_session(uri: &str, host: Option<&str>) -> Session {
        let mut mock = Builder::new();

        mock.read(format!("GET {uri} HTTP/1.1\r\n").as_bytes());
        if let Some(host) = host {
            mock.read(format!("Host: {host}\r\n").as_bytes());
        }
        mock.read(b"Connection: close\r\n");
        mock.read(b"\r\n");

        let mut session = Session::new_h1(Box::new(mock.build()));
        assert!(session.read_request().await.unwrap());

        // Set URI explicitly, otherwise with a H1 session it will all end up in the path.
        session.req_header_mut().set_uri(uri.try_into().unwrap());

        session
    }

    /// A `Host` header naming a configured host selects that host's handler.
    #[test(tokio::test)]
    async fn host_match() -> Result<(), Box<Error>> {
        let handler = handler(true);
        let mut session = make_session("/", Some("example.com")).await;
        assert_eq!(
            handler.request_filter(&mut session, &mut ()).await?,
            RequestFilterResult::Handled
        );
        Ok(())
    }

    /// A `Host` header naming an alias selects the aliased host's handler.
    #[test(tokio::test)]
    async fn host_alias_match() -> Result<(), Box<Error>> {
        let handler = handler(false);
        let mut session = make_session("/", Some("[::1]:8080")).await;
        assert_eq!(
            handler.request_filter(&mut session, &mut ()).await?,
            RequestFilterResult::ResponseSent
        );
        Ok(())
    }

    /// Without a `Host` header, the host is taken from an absolute request URI.
    #[test(tokio::test)]
    async fn uri_match() -> Result<(), Box<Error>> {
        let handler = handler(false);
        let mut session = make_session("https://example.com/", None).await;
        assert_eq!(
            handler.request_filter(&mut session, &mut ()).await?,
            RequestFilterResult::Handled
        );
        Ok(())
    }

    /// A URI-derived host (including the port) is also resolved through aliases.
    #[test(tokio::test)]
    async fn uri_alias_match() -> Result<(), Box<Error>> {
        let handler = handler(false);
        let mut session = make_session("http://[::1]:8080/", None).await;
        assert_eq!(
            handler.request_filter(&mut session, &mut ()).await?,
            RequestFilterResult::ResponseSent
        );
        Ok(())
    }

    /// When both are present, the `Host` header wins over the URI's authority.
    #[test(tokio::test)]
    async fn host_precedence() -> Result<(), Box<Error>> {
        let handler = handler(false);
        let mut session = make_session("https://localhost:8080/", Some("example.com")).await;
        assert_eq!(
            handler.request_filter(&mut session, &mut ()).await?,
            RequestFilterResult::Handled
        );
        Ok(())
    }

    /// An unknown host falls back to the default host when one is configured.
    #[test(tokio::test)]
    async fn default_fallback() -> Result<(), Box<Error>> {
        let handler = handler(true);
        let mut session = make_session("/", Some("example.net")).await;
        assert_eq!(
            handler.request_filter(&mut session, &mut ()).await?,
            RequestFilterResult::ResponseSent
        );
        Ok(())
    }

    /// An unknown host without a default host leaves the request unhandled.
    #[test(tokio::test)]
    async fn no_default_fallback() -> Result<(), Box<Error>> {
        let handler = handler(false);
        let mut session = make_session("/", Some("example.net")).await;
        assert_eq!(
            handler.request_filter(&mut session, &mut ()).await?,
            RequestFilterResult::Unhandled
        );
        Ok(())
    }

    /// A request for the subdirectory itself selects it and strips the prefix down to `/`.
    #[test(tokio::test)]
    async fn subdir_match() -> Result<(), Box<Error>> {
        let handler = handler(true);
        let mut session = make_session("/subdir/", Some("localhost:8080")).await;
        assert_eq!(
            handler.request_filter(&mut session, &mut ()).await?,
            RequestFilterResult::Unhandled
        );
        assert_eq!(session.req_header().uri, "/");
        Ok(())
    }

    /// The configured trailing slash is not required for a subdirectory match.
    #[test(tokio::test)]
    async fn subdir_match_without_slash() -> Result<(), Box<Error>> {
        let handler = handler(true);
        let mut session = make_session("/subdir", Some("localhost:8080")).await;
        assert_eq!(
            handler.request_filter(&mut session, &mut ()).await?,
            RequestFilterResult::Unhandled
        );
        assert_eq!(session.req_header().uri, "/");
        Ok(())
    }

    /// Stripping the prefix keeps the remaining path and the query string.
    #[test(tokio::test)]
    async fn subdir_match_with_suffix() -> Result<(), Box<Error>> {
        let handler = handler(true);
        let mut session = make_session("/subdir/xyz?abc", Some("localhost:8080")).await;
        assert_eq!(
            handler.request_filter(&mut session, &mut ()).await?,
            RequestFilterResult::Unhandled
        );
        assert_eq!(session.req_header().uri, "/xyz?abc");
        Ok(())
    }

    /// Repeated slashes are collapsed in the rewritten path; a trailing slash is kept.
    #[test(tokio::test)]
    async fn subdir_match_extra_slashes() -> Result<(), Box<Error>> {
        let handler = handler(true);
        let mut session = make_session("//subdir///xyz//", Some("localhost:8080")).await;
        assert_eq!(
            handler.request_filter(&mut session, &mut ()).await?,
            RequestFilterResult::Unhandled
        );
        assert_eq!(session.req_header().uri, "/xyz/");
        Ok(())
    }

    /// Matching is per path segment: `/subdir_xyz` does not match `/subdir/`, so the host
    /// root handler runs and the URI is left unchanged.
    #[test(tokio::test)]
    async fn subdir_no_match() -> Result<(), Box<Error>> {
        let handler = handler(true);
        let mut session = make_session("/subdir_xyz", Some("localhost:8080")).await;
        assert_eq!(
            handler.request_filter(&mut session, &mut ()).await?,
            RequestFilterResult::ResponseSent
        );
        assert_eq!(session.req_header().uri, "/subdir_xyz");
        Ok(())
    }

    /// The longest matching subdirectory wins; since it does not strip its prefix, the URI
    /// stays unchanged.
    #[test(tokio::test)]
    async fn subdir_longer_match() -> Result<(), Box<Error>> {
        let handler = handler(true);
        let mut session = make_session("/subdir/subsub/xyz", Some("localhost:8080")).await;
        assert_eq!(
            handler.request_filter(&mut session, &mut ()).await?,
            RequestFilterResult::Handled
        );
        assert_eq!(session.req_header().uri, "/subdir/subsub/xyz");
        Ok(())
    }
}