1pub mod pattern_trie;
2
3use std::{
4 fmt::{self, Debug, Write},
5 rc::Rc,
6 str::from_utf8,
7 time::Instant,
8};
9
10use regex::bytes::Regex;
11use sozu_command::{
12 logging::CachedTags,
13 proto::command::{
14 HeaderPosition, HstsConfig, PathRule as CommandPathRule, PathRuleKind, RedirectPolicy,
15 RedirectScheme, RulePosition,
16 },
17 response::HttpFrontend,
18 state::ClusterId,
19};
20
21use crate::metrics::names;
22use crate::{
23 protocol::{http::editor::HeaderEditMode, http::parser::Method},
24 router::pattern_trie::{TrieMatches, TrieNode, TrieSubMatch},
25 sozu_command::logging::ansi_palette,
26};
27
/// Expands to the colored `ROUTER` prefix prepended to every log line
/// emitted from this module.
macro_rules! log_module_context {
    () => {{
        let palette = ansi_palette();
        format!("{}ROUTER{}\t >>>", palette.0, palette.1)
    }};
}
40
41#[derive(thiserror::Error, Debug, PartialEq)]
42pub enum RouterError {
43 #[error("Could not parse rule from frontend path {0:?}")]
44 InvalidPathRule(String),
45 #[error("parsing hostname {hostname} failed")]
46 InvalidDomain { hostname: String },
47 #[error("Could not parse host rewrite {0:?}")]
48 InvalidHostRewrite(String),
49 #[error("Could not parse path rewrite {0:?}")]
50 InvalidPathRewrite(String),
51 #[error("Could not add route {0}")]
52 AddRoute(String),
53 #[error("Could not remove route {0}")]
54 RemoveRoute(String),
55 #[error("no route for {method} {host} {path}")]
56 RouteNotFound {
57 host: String,
58 path: String,
59 method: Method,
60 },
61}
62
63pub struct Router {
64 pre: Vec<(DomainRule, PathRule, MethodRule, Route)>,
65 pub tree: TrieNode<Vec<(PathRule, MethodRule, Route)>>,
66 post: Vec<(DomainRule, PathRule, MethodRule, Route)>,
67}
68
69impl Default for Router {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
75impl Router {
76 pub fn new() -> Router {
77 Router {
78 pre: Vec::new(),
79 tree: TrieNode::root(),
80 post: Vec::new(),
81 }
82 }
83
84 pub fn lookup(
93 &self,
94 hostname: &str,
95 path: &str,
96 method: &Method,
97 ) -> Result<RouteResult, RouterError> {
98 let hostname_b = hostname.as_bytes();
99 let path_b = path.as_bytes();
100 for (domain_rule, path_rule, method_rule, route) in &self.pre {
101 if domain_rule.matches(hostname_b)
102 && path_rule.matches(path_b) != PathRuleResult::None
103 && method_rule.matches(method) != MethodRuleResult::None
104 {
105 return Ok(RouteResult::new_no_trie(
106 hostname_b,
107 domain_rule,
108 path_b,
109 path_rule,
110 route,
111 ));
112 }
113 }
114
115 let trie_path: TrieMatches<'_, '_> = Vec::with_capacity(16);
116 if let Some(((_, path_rules), trie_matches)) =
117 self.tree.lookup_with_path(hostname_b, true, trie_path)
118 {
119 let mut prefix_length = 0;
120 let mut matched: Option<(&PathRule, &Route)> = None;
121
122 for (rule, method_rule, route) in path_rules {
123 match rule.matches(path_b) {
124 PathRuleResult::Regex | PathRuleResult::Equals => {
125 match method_rule.matches(method) {
126 MethodRuleResult::Equals => {
127 return Ok(RouteResult::new_with_trie(
128 hostname_b,
129 trie_matches,
130 path_b,
131 rule,
132 route,
133 ));
134 }
135 MethodRuleResult::All => {
136 prefix_length = path_b.len();
137 matched = Some((rule, route));
138 }
139 MethodRuleResult::None => {}
140 }
141 }
142 PathRuleResult::Prefix(size) => {
143 if size >= prefix_length {
144 match method_rule.matches(method) {
145 MethodRuleResult::Equals => {
147 prefix_length = size;
148 matched = Some((rule, route));
149 }
150 MethodRuleResult::All => {
151 prefix_length = size;
152 matched = Some((rule, route));
153 }
154 MethodRuleResult::None => {}
155 }
156 }
157 }
158 PathRuleResult::None => {}
159 }
160 }
161
162 if let Some((path_rule, route)) = matched {
163 return Ok(RouteResult::new_with_trie(
164 hostname_b,
165 trie_matches,
166 path_b,
167 path_rule,
168 route,
169 ));
170 }
171 }
172
173 for (domain_rule, path_rule, method_rule, route) in self.post.iter() {
174 if domain_rule.matches(hostname_b)
175 && path_rule.matches(path_b) != PathRuleResult::None
176 && method_rule.matches(method) != MethodRuleResult::None
177 {
178 return Ok(RouteResult::new_no_trie(
179 hostname_b,
180 domain_rule,
181 path_b,
182 path_rule,
183 route,
184 ));
185 }
186 }
187
188 Err(RouterError::RouteNotFound {
189 host: hostname.to_owned(),
190 path: path.to_owned(),
191 method: method.to_owned(),
192 })
193 }
194
195 pub fn add_http_front(&mut self, front: &HttpFrontend) -> Result<(), RouterError> {
202 self.add_http_front_with_hsts_origin(front, HstsOrigin::Explicit)
203 }
204
205 pub fn add_http_front_with_hsts_origin(
212 &mut self,
213 front: &HttpFrontend,
214 hsts_origin: HstsOrigin,
215 ) -> Result<(), RouterError> {
216 let path_rule = PathRule::from_config(front.path.clone())
217 .ok_or(RouterError::InvalidPathRule(front.path.to_string()))?;
218
219 let method_rule = MethodRule::new(front.method.clone());
220
221 let has_policy = front.redirect.is_some()
226 || front.redirect_scheme.is_some()
227 || front.redirect_template.is_some()
228 || front.rewrite_host.is_some()
229 || front.rewrite_path.is_some()
230 || front.rewrite_port.is_some()
231 || front.required_auth.unwrap_or(false)
232 || !front.headers.is_empty()
233 || front.hsts.is_some();
234
235 let domain =
236 front
237 .hostname
238 .parse::<DomainRule>()
239 .map_err(|_| RouterError::InvalidDomain {
240 hostname: front.hostname.clone(),
241 })?;
242
243 let route = if has_policy {
244 let redirect = front
245 .redirect
246 .and_then(|r| RedirectPolicy::try_from(r).ok())
247 .unwrap_or(RedirectPolicy::Forward);
248 let redirect_scheme = front
249 .redirect_scheme
250 .and_then(|s| RedirectScheme::try_from(s).ok())
251 .unwrap_or(RedirectScheme::UseSame);
252 let frontend = Frontend::new(
253 &domain,
254 &path_rule,
255 front,
256 redirect,
257 redirect_scheme,
258 front.redirect_template.clone(),
259 front.rewrite_host.clone(),
260 front.rewrite_path.clone(),
261 front.rewrite_port.and_then(|p| u16::try_from(p).ok()),
262 &front.headers,
263 front.required_auth.unwrap_or(false),
264 hsts_origin,
265 )?;
266 Route::Frontend(Rc::new(frontend))
267 } else {
268 match &front.cluster_id {
269 Some(cluster_id) => Route::ClusterId(cluster_id.clone()),
270 None => Route::Deny,
271 }
272 };
273
274 let success = match front.position {
275 RulePosition::Pre => self.add_pre_rule(&domain, &path_rule, &method_rule, &route),
276 RulePosition::Post => self.add_post_rule(&domain, &path_rule, &method_rule, &route),
277 RulePosition::Tree => {
278 self.add_tree_rule(front.hostname.as_bytes(), &path_rule, &method_rule, &route)
279 }
280 };
281 if !success {
282 return Err(RouterError::AddRoute(format!("{front:?}")));
283 }
284 Ok(())
285 }
286
287 pub fn remove_http_front(&mut self, front: &HttpFrontend) -> Result<(), RouterError> {
288 let path_rule = PathRule::from_config(front.path.clone())
289 .ok_or(RouterError::InvalidPathRule(front.path.to_string()))?;
290
291 let method_rule = MethodRule::new(front.method.clone());
292
293 let remove_success = match front.position {
294 RulePosition::Pre => {
295 let domain = front.hostname.parse::<DomainRule>().map_err(|_| {
296 RouterError::InvalidDomain {
297 hostname: front.hostname.clone(),
298 }
299 })?;
300
301 self.remove_pre_rule(&domain, &path_rule, &method_rule)
302 }
303 RulePosition::Post => {
304 let domain = front.hostname.parse::<DomainRule>().map_err(|_| {
305 RouterError::InvalidDomain {
306 hostname: front.hostname.clone(),
307 }
308 })?;
309
310 self.remove_post_rule(&domain, &path_rule, &method_rule)
311 }
312 RulePosition::Tree => {
313 self.remove_tree_rule(front.hostname.as_bytes(), &path_rule, &method_rule)
314 }
315 };
316 if !remove_success {
317 return Err(RouterError::RemoveRoute(format!("{front:?}")));
318 }
319 Ok(())
320 }
321
322 pub fn add_tree_rule(
323 &mut self,
324 hostname: &[u8],
325 path: &PathRule,
326 method: &MethodRule,
327 cluster: &Route,
328 ) -> bool {
329 let hostname = match from_utf8(hostname) {
330 Err(_) => return false,
331 Ok(h) => h,
332 };
333
334 match ::idna::domain_to_ascii(hostname) {
335 Ok(hostname) => {
336 let mut empty = true;
338 if let Some((_, paths)) = self.tree.domain_lookup_mut(hostname.as_bytes(), false) {
339 empty = false;
340 if !paths.iter().any(|(p, m, _)| p == path && m == method) {
341 paths.push((path.to_owned(), method.to_owned(), cluster.to_owned()));
342 return true;
343 }
344 }
345
346 if empty {
347 self.tree.domain_insert(
348 hostname.into_bytes(),
349 vec![(path.to_owned(), method.to_owned(), cluster.to_owned())],
350 );
351 return true;
352 }
353
354 false
355 }
356 Err(_) => false,
357 }
358 }
359
360 pub fn remove_tree_rule(
361 &mut self,
362 hostname: &[u8],
363 path: &PathRule,
364 method: &MethodRule,
365 ) -> bool {
367 let hostname = match from_utf8(hostname) {
368 Err(_) => return false,
369 Ok(h) => h,
370 };
371
372 match ::idna::domain_to_ascii(hostname) {
373 Ok(hostname) => {
374 let should_delete = {
375 let paths_opt = self.tree.domain_lookup_mut(hostname.as_bytes(), false);
376
377 if let Some((_, paths)) = paths_opt {
378 paths.retain(|(p, m, _)| p != path || m != method);
379 }
380
381 paths_opt
382 .as_ref()
383 .map(|(_, paths)| paths.is_empty())
384 .unwrap_or(false)
385 };
386
387 if should_delete {
388 self.tree.domain_remove(&hostname.into_bytes());
389 }
390
391 true
392 }
393 Err(_) => false,
394 }
395 }
396
397 pub fn refresh_inheriting_hsts(&mut self, new_hsts: Option<&HstsConfig>) -> usize {
441 let mut refreshed = 0usize;
442 let new_edit = build_listener_hsts_edit(new_hsts);
450 let new_edit_ref = new_edit.as_ref();
451 let promote_lightweight = new_edit_ref.is_some();
452 let mut visit = |route: &mut Route| match route {
453 Route::Frontend(rc) => {
454 if rc.inherits_listener_hsts {
455 let new_frontend = rebuild_with_listener_hsts(rc, new_edit_ref);
456 *rc = Rc::new(new_frontend);
457 refreshed += 1;
458 }
459 }
460 Route::ClusterId(id) => {
461 if promote_lightweight {
462 let promoted = rebuild_with_listener_hsts(
463 &Frontend::minimal_forward(id.clone()),
464 new_edit_ref,
465 );
466 *route = Route::Frontend(Rc::new(promoted));
467 refreshed += 1;
468 }
469 }
470 Route::Deny => {
471 if promote_lightweight {
472 let promoted =
473 rebuild_with_listener_hsts(&Frontend::minimal_deny(), new_edit_ref);
474 *route = Route::Frontend(Rc::new(promoted));
475 refreshed += 1;
476 }
477 }
478 };
479
480 for (_, _, _, route) in self.pre.iter_mut() {
481 visit(route);
482 }
483 self.tree.for_each_value_mut(&mut |paths| {
484 for (_, _, route) in paths.iter_mut() {
485 visit(route);
486 }
487 });
488 for (_, _, _, route) in self.post.iter_mut() {
489 visit(route);
490 }
491 refreshed
492 }
493
494 pub fn add_pre_rule(
495 &mut self,
496 domain: &DomainRule,
497 path: &PathRule,
498 method: &MethodRule,
499 cluster_id: &Route,
500 ) -> bool {
501 if !self
502 .pre
503 .iter()
504 .any(|(d, p, m, _)| d == domain && p == path && m == method)
505 {
506 self.pre.push((
507 domain.to_owned(),
508 path.to_owned(),
509 method.to_owned(),
510 cluster_id.to_owned(),
511 ));
512 true
513 } else {
514 false
515 }
516 }
517
518 pub fn add_post_rule(
519 &mut self,
520 domain: &DomainRule,
521 path: &PathRule,
522 method: &MethodRule,
523 cluster_id: &Route,
524 ) -> bool {
525 if !self
526 .post
527 .iter()
528 .any(|(d, p, m, _)| d == domain && p == path && m == method)
529 {
530 self.post.push((
531 domain.to_owned(),
532 path.to_owned(),
533 method.to_owned(),
534 cluster_id.to_owned(),
535 ));
536 true
537 } else {
538 false
539 }
540 }
541
542 pub fn remove_pre_rule(
543 &mut self,
544 domain: &DomainRule,
545 path: &PathRule,
546 method: &MethodRule,
547 ) -> bool {
548 match self
549 .pre
550 .iter()
551 .position(|(d, p, m, _)| d == domain && p == path && m == method)
552 {
553 None => false,
554 Some(index) => {
555 self.pre.remove(index);
556 true
557 }
558 }
559 }
560
561 pub fn remove_post_rule(
562 &mut self,
563 domain: &DomainRule,
564 path: &PathRule,
565 method: &MethodRule,
566 ) -> bool {
567 match self
568 .post
569 .iter()
570 .position(|(d, p, m, _)| d == domain && p == path && m == method)
571 {
572 None => false,
573 Some(index) => {
574 self.post.remove(index);
575 true
576 }
577 }
578 }
579
580 pub fn has_hostname(&self, hostname: &str) -> bool {
585 let hostname_b = hostname.as_bytes();
586
587 for (domain_rule, _, _, _) in &self.pre {
589 if domain_rule.matches(hostname_b) {
590 return true;
591 }
592 }
593
594 if let Ok(ascii_hostname) = ::idna::domain_to_ascii(hostname) {
596 if self
597 .tree
598 .domain_lookup(ascii_hostname.as_bytes(), false)
599 .is_some()
600 {
601 return true;
602 }
603 }
604
605 for (domain_rule, _, _, _) in &self.post {
607 if domain_rule.matches(hostname_b) {
608 return true;
609 }
610 }
611
612 false
613 }
614}
615
616#[derive(Clone, Debug)]
617pub enum DomainRule {
618 Any,
619 Exact(String),
620 Wildcard(String),
626 Regex(Regex),
627}
628
/// Converts a hostname pattern mixing literal labels and `/regex/` segments
/// into an anchored (`\A` … `\z`) regular expression source.
///
/// Segments are separated by `.`. A segment starting with `/` runs up to the
/// next `/` and is copied verbatim as regex syntax. Any other segment runs up
/// to the next `.` and is copied as-is. Each separating `.` becomes `\.`.
///
/// Returns `None` in these cases:
/// - a `/` segment is unterminated;
/// - a regex segment is followed by something other than `.` or the end;
/// - the input is empty or ends with `.` (a trailing dot leaves an empty,
///   unparseable segment).
///
/// Previously, an empty input or a trailing dot indexed past the end of the
/// input and panicked.
fn convert_regex_domain_rule(hostname: &str) -> Option<String> {
    let bytes = hostname.as_bytes();
    let mut result = String::from("\\A");
    let mut index = 0;

    loop {
        // Guard against an empty input or a '.' in last position, which would
        // otherwise index out of bounds.
        let &current = bytes.get(index)?;

        if current == b'/' {
            // Regex segment: everything up to the closing '/'.
            let body_start = index + 1;
            let body_len = bytes[body_start..].iter().position(|&b| b == b'/')?;
            let body_end = body_start + body_len;
            result.push_str(std::str::from_utf8(&bytes[body_start..body_end]).ok()?);
            index = body_end + 1;
        } else {
            // Literal segment: everything up to the next '.' (or the end).
            let start = index;
            index = bytes[start..]
                .iter()
                .position(|&b| b == b'.')
                .map_or(bytes.len(), |offset| start + offset);
            result.push_str(std::str::from_utf8(&bytes[start..index]).ok()?);
        }

        match bytes.get(index) {
            None => {
                result.push_str("\\z");
                return Some(result);
            }
            Some(b'.') => {
                result.push_str("\\.");
                index += 1;
            }
            Some(_) => return None,
        }
    }
}
688
689impl DomainRule {
690 pub fn matches(&self, hostname: &[u8]) -> bool {
691 match self {
692 DomainRule::Any => true,
693 DomainRule::Wildcard(s) => {
694 let suffix = &s.as_bytes()[1..];
695 hostname
696 .strip_suffix(suffix)
697 .is_some_and(|prefix| !prefix.is_empty() && !prefix.contains(&b'.'))
698 }
699 DomainRule::Exact(s) => s.as_bytes() == hostname,
700 DomainRule::Regex(r) => {
701 let start = Instant::now();
702 let is_a_match = r.is_match(hostname);
703 let now = Instant::now();
704 time!(
705 names::event_loop::REGEX_MATCHING_TIME,
706 (now - start).as_millis()
707 );
708 is_a_match
709 }
710 }
711 }
712}
713
714impl std::cmp::PartialEq for DomainRule {
715 fn eq(&self, other: &Self) -> bool {
716 match (self, other) {
717 (DomainRule::Any, DomainRule::Any) => true,
718 (DomainRule::Wildcard(s1), DomainRule::Wildcard(s2)) => s1 == s2,
719 (DomainRule::Exact(s1), DomainRule::Exact(s2)) => s1 == s2,
720 (DomainRule::Regex(r1), DomainRule::Regex(r2)) => r1.as_str() == r2.as_str(),
721 _ => false,
722 }
723 }
724}
725
726impl std::str::FromStr for DomainRule {
727 type Err = ();
728
729 fn from_str(s: &str) -> Result<Self, Self::Err> {
730 Ok(if s == "*" {
731 DomainRule::Any
732 } else if s.contains('/') {
733 match convert_regex_domain_rule(s) {
734 Some(s) => match regex::bytes::Regex::new(&s) {
735 Ok(r) => DomainRule::Regex(r),
736 Err(_) => return Err(()),
737 },
738 None => return Err(()),
739 }
740 } else if s.contains('*') {
741 if s.starts_with('*') {
742 match ::idna::domain_to_ascii(s) {
743 Ok(r) => DomainRule::Wildcard(r),
744 Err(_) => return Err(()),
745 }
746 } else {
747 return Err(());
748 }
749 } else {
750 match ::idna::domain_to_ascii(s) {
751 Ok(r) => DomainRule::Exact(r),
752 Err(_) => return Err(()),
753 }
754 })
755 }
756}
757
758#[derive(Clone, Debug)]
759pub enum PathRule {
760 Prefix(String),
761 Regex(Regex),
762 Equals(String),
763}
764
765#[derive(PartialEq, Eq)]
766pub enum PathRuleResult {
767 Regex,
768 Prefix(usize),
769 Equals,
770 None,
771}
772
773impl PathRule {
774 pub fn matches(&self, path: &[u8]) -> PathRuleResult {
775 match self {
776 PathRule::Prefix(prefix) => {
777 if path.starts_with(prefix.as_bytes()) {
778 PathRuleResult::Prefix(prefix.len())
779 } else {
780 PathRuleResult::None
781 }
782 }
783 PathRule::Regex(regex) => {
784 let start = Instant::now();
785 let is_a_match = regex.is_match(path);
786 let now = Instant::now();
787 time!(
788 names::event_loop::REGEX_MATCHING_TIME,
789 (now - start).as_millis()
790 );
791
792 if is_a_match {
793 PathRuleResult::Regex
794 } else {
795 PathRuleResult::None
796 }
797 }
798 PathRule::Equals(pattern) => {
799 if path == pattern.as_bytes() {
800 PathRuleResult::Equals
801 } else {
802 PathRuleResult::None
803 }
804 }
805 }
806 }
807
808 pub fn from_config(rule: CommandPathRule) -> Option<Self> {
809 match PathRuleKind::try_from(rule.kind) {
810 Ok(PathRuleKind::Prefix) => Some(PathRule::Prefix(rule.value)),
811 Ok(PathRuleKind::Regex) => Regex::new(&rule.value).ok().map(PathRule::Regex),
812 Ok(PathRuleKind::Equals) => Some(PathRule::Equals(rule.value)),
813 Err(_) => None,
814 }
815 }
816}
817
818impl std::cmp::PartialEq for PathRule {
819 fn eq(&self, other: &Self) -> bool {
820 match (self, other) {
821 (PathRule::Prefix(s1), PathRule::Prefix(s2)) => s1 == s2,
822 (PathRule::Regex(r1), PathRule::Regex(r2)) => r1.as_str() == r2.as_str(),
823 _ => false,
824 }
825 }
826}
827
828#[derive(Clone, Debug, PartialEq, Eq)]
829pub struct MethodRule {
830 pub inner: Option<Method>,
831}
832
833#[derive(PartialEq, Eq)]
834pub enum MethodRuleResult {
835 All,
836 Equals,
837 None,
838}
839
840impl MethodRule {
841 pub fn new(method: Option<String>) -> Self {
842 MethodRule {
843 inner: method.map(|s| Method::new(s.as_bytes())),
844 }
845 }
846
847 pub fn matches(&self, method: &Method) -> MethodRuleResult {
848 match self.inner {
849 None => MethodRuleResult::All,
850 Some(ref m) => {
851 if method == m {
852 MethodRuleResult::Equals
853 } else {
854 MethodRuleResult::None
855 }
856 }
857 }
858 }
859}
860
861#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
883pub enum Route {
884 Deny,
886 ClusterId(ClusterId),
888 Frontend(Rc<Frontend>),
893}
894
895fn build_listener_hsts_edit(new_hsts: Option<&HstsConfig>) -> Option<HeaderEdit> {
922 let cfg = new_hsts?;
923 if !matches!(cfg.enabled, Some(true)) {
924 return None;
925 }
926 let rendered = render_hsts(cfg)?;
927 let mode = if matches!(cfg.force_replace_backend, Some(true)) {
928 HeaderEditMode::Set
929 } else {
930 HeaderEditMode::SetIfAbsent
931 };
932 Some(HeaderEdit {
933 key: Rc::from(&b"strict-transport-security"[..]),
934 val: rendered.into_bytes().into(),
935 mode,
936 })
937}
938
939fn rebuild_with_listener_hsts(frontend: &Frontend, new_edit: Option<&HeaderEdit>) -> Frontend {
955 let mut headers_response: Vec<HeaderEdit> = frontend
957 .headers_response
958 .iter()
959 .filter(|edit| !edit.key.eq_ignore_ascii_case(b"strict-transport-security"))
960 .cloned()
961 .collect();
962
963 if let Some(edit) = new_edit {
966 headers_response.push(edit.clone());
967 }
968
969 Frontend {
970 headers_response: headers_response.into(),
971 ..frontend.clone()
973 }
974}
975
976pub fn render_hsts(cfg: &HstsConfig) -> Option<String> {
989 let max_age = cfg.max_age?;
990 let mut s = format!("max-age={max_age}");
991 if matches!(cfg.include_subdomains, Some(true)) {
992 s.push_str("; includeSubDomains");
993 }
994 if matches!(cfg.preload, Some(true)) {
995 s.push_str("; preload");
996 }
997 Some(s)
998}
999
1000#[derive(Clone, PartialEq, Eq)]
1012pub struct HeaderEdit {
1013 pub key: Rc<[u8]>,
1014 pub val: Rc<[u8]>,
1015 pub mode: HeaderEditMode,
1016}
1017
1018impl Debug for HeaderEdit {
1019 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1020 f.write_fmt(format_args!(
1021 "({:?}, {:?}, {:?})",
1022 String::from_utf8_lossy(&self.key),
1023 String::from_utf8_lossy(&self.val),
1024 self.mode,
1025 ))
1026 }
1027}
1028
/// One piece of a parsed rewrite template.
#[derive(Clone, Debug, Eq, PartialEq)]
enum RewritePart {
    /// Literal text copied as-is.
    String(String),
    /// Reference to host capture `$HOST[n]`.
    Host(usize),
    /// Reference to path capture `$PATH[n]`.
    Path(usize),
}

/// A rewrite template split into literal text and capture references.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RewriteParts(Vec<RewritePart>);

impl RewriteParts {
    /// Parses `template`, where `$HOST[n]` / `$PATH[n]` reference captures
    /// and all other text is literal.
    ///
    /// Returns `None` in these cases:
    /// - a `$` is not followed by `HOST[` or `PATH[`;
    /// - the index is missing, overflows `usize`, or is not closed by `]`;
    /// - the index is not below the matching `*_cap_cap`.
    ///
    /// For each accepted reference, `used_index_host` / `used_index_path` are
    /// raised to at least `index + 1`.
    pub fn parse(
        template: &str,
        host_cap_cap: usize,
        path_cap_cap: usize,
        used_index_host: &mut usize,
        used_index_path: &mut usize,
    ) -> Option<Self> {
        let bytes = template.as_bytes();
        let mut parts = Vec::new();
        let mut pos = 0;

        while pos < bytes.len() {
            if bytes[pos] != b'$' {
                // Literal run up to the next '$' (or the end).
                let literal_end = bytes[pos..]
                    .iter()
                    .position(|&b| b == b'$')
                    .map_or(bytes.len(), |offset| pos + offset);
                parts.push(RewritePart::String(template[pos..literal_end].to_owned()));
                pos = literal_end;
                continue;
            }

            let rest = &bytes[pos..];
            let is_host = if rest.starts_with(b"$HOST[") {
                true
            } else if rest.starts_with(b"$PATH[") {
                false
            } else {
                return None;
            };
            pos += 6;

            let digit_count = bytes[pos..]
                .iter()
                .take_while(|b| b.is_ascii_digit())
                .count();
            if digit_count == 0 {
                return None;
            }
            let index = bytes[pos..pos + digit_count]
                .iter()
                .try_fold(0usize, |acc, &digit| {
                    acc.checked_mul(10)?.checked_add(usize::from(digit - b'0'))
                })?;
            pos += digit_count;

            if bytes.get(pos) != Some(&b']') {
                return None;
            }
            pos += 1;

            if is_host {
                if index >= host_cap_cap {
                    return None;
                }
                *used_index_host = (*used_index_host).max(index + 1);
                parts.push(RewritePart::Host(index));
            } else {
                if index >= path_cap_cap {
                    return None;
                }
                *used_index_path = (*used_index_path).max(index + 1);
                parts.push(RewritePart::Path(index));
            }
        }

        Some(Self(parts))
    }

    /// Expands the template. Capture references with no corresponding
    /// capture expand to the empty string.
    pub fn run(&self, host_captures: &[&str], path_captures: &[&str]) -> String {
        /// Text contributed by one part.
        fn piece<'a>(part: &'a RewritePart, host: &[&'a str], path: &[&'a str]) -> &'a str {
            match part {
                RewritePart::String(text) => text,
                RewritePart::Host(i) => host.get(*i).copied().unwrap_or(""),
                RewritePart::Path(i) => path.get(*i).copied().unwrap_or(""),
            }
        }

        let total_len: usize = self
            .0
            .iter()
            .map(|part| piece(part, host_captures, path_captures).len())
            .sum();
        let mut output = String::with_capacity(total_len);
        for part in &self.0 {
            output.push_str(piece(part, host_captures, path_captures));
        }
        output
    }
}
1166
1167#[derive(Debug, Clone)]
1180pub struct Frontend {
1181 pub cluster_id: Option<ClusterId>,
1182 pub redirect: RedirectPolicy,
1183 pub redirect_scheme: RedirectScheme,
1184 pub redirect_template: Option<String>,
1185 pub capture_cap_host: usize,
1189 pub capture_cap_path: usize,
1193 pub rewrite_host: Option<RewriteParts>,
1194 pub rewrite_path: Option<RewriteParts>,
1195 pub rewrite_port: Option<u16>,
1196 pub headers_request: Rc<[HeaderEdit]>,
1197 pub headers_response: Rc<[HeaderEdit]>,
1198 pub required_auth: bool,
1199 pub tags: Option<Rc<CachedTags>>,
1200 pub inherits_listener_hsts: bool,
1209}
1210
1211#[derive(Copy, Clone, Debug, PartialEq, Eq)]
1219pub enum HstsOrigin {
1220 Explicit,
1225 InheritedFromListenerDefault,
1230}
1231
1232impl PartialEq for Frontend {
1233 fn eq(&self, other: &Self) -> bool {
1234 self.cluster_id == other.cluster_id
1238 && self.redirect == other.redirect
1239 && self.redirect_scheme == other.redirect_scheme
1240 && self.redirect_template == other.redirect_template
1241 && self.rewrite_host == other.rewrite_host
1242 && self.rewrite_path == other.rewrite_path
1243 && self.rewrite_port == other.rewrite_port
1244 && self.headers_request == other.headers_request
1245 && self.headers_response == other.headers_response
1246 && self.required_auth == other.required_auth
1247 }
1248}
1249
1250impl Eq for Frontend {}
1251
1252impl std::hash::Hash for Frontend {
1253 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
1254 self.cluster_id.hash(state);
1255 (self.redirect as i32).hash(state);
1258 (self.redirect_scheme as i32).hash(state);
1259 self.redirect_template.hash(state);
1260 self.required_auth.hash(state);
1261 }
1262}
1263
1264impl PartialOrd for Frontend {
1265 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1266 Some(self.cmp(other))
1267 }
1268}
1269
1270impl Ord for Frontend {
1271 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
1272 self.cluster_id
1273 .cmp(&other.cluster_id)
1274 .then_with(|| (self.redirect as i32).cmp(&(other.redirect as i32)))
1275 .then_with(|| (self.redirect_scheme as i32).cmp(&(other.redirect_scheme as i32)))
1276 .then_with(|| self.redirect_template.cmp(&other.redirect_template))
1277 .then_with(|| self.required_auth.cmp(&other.required_auth))
1278 }
1279}
1280
impl Frontend {
    /// Builds the per-frontend policy stored in a `Route::Frontend`.
    ///
    /// Inputs:
    /// - `redirect`, `redirect_scheme` and `rewrite_port` arrive already
    ///   decoded by the caller.
    /// - Cluster id, tags and the HSTS config are read from `front`.
    /// - Empty `redirect_template`, `rewrite_host` or `rewrite_path` strings
    ///   are treated as absent.
    ///
    /// `Unauthorized` policies, and `Forward` without a cluster, produce a deny
    /// frontend. A deny frontend has no rewrites and no request header edits;
    /// only the HSTS response header is kept.
    ///
    /// # Errors
    ///
    /// `RouterError::InvalidHostRewrite` / `RouterError::InvalidPathRewrite`
    /// when a rewrite template is rejected by `RewriteParts::parse`. That
    /// happens for an unknown `$` variable, a malformed index, or a capture
    /// index beyond what the domain/path rule provides.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        domain_rule: &DomainRule,
        path_rule: &PathRule,
        front: &HttpFrontend,
        redirect: RedirectPolicy,
        redirect_scheme: RedirectScheme,
        redirect_template: Option<String>,
        rewrite_host: Option<String>,
        rewrite_path: Option<String>,
        rewrite_port: Option<u16>,
        headers: &[sozu_command::proto::command::Header],
        required_auth: bool,
        hsts_origin: HstsOrigin,
    ) -> Result<Self, RouterError> {
        let hsts = front.hsts.as_ref();
        // Only frontends that actually carry an inherited HSTS config are
        // refreshed later by `Router::refresh_inheriting_hsts`.
        let inherits_listener_hsts =
            matches!(hsts_origin, HstsOrigin::InheritedFromListenerDefault) && hsts.is_some();
        let cluster_id = front.cluster_id.clone();
        let tags = front
            .tags
            .clone()
            .map(|tags| Rc::new(CachedTags::new(tags)));

        // Empty strings are treated the same as "not configured".
        let redirect_template = redirect_template.filter(|s| !s.is_empty());
        let rewrite_host = rewrite_host.filter(|s| !s.is_empty());
        let rewrite_path = rewrite_path.filter(|s| !s.is_empty());

        let deny = match (&cluster_id, redirect) {
            (_, RedirectPolicy::Unauthorized) => true,
            (None, RedirectPolicy::Forward) => {
                warn!(
                    "{} Frontend[domain: {:?}, path: {:?}]: forward on clusterless frontends are unauthorized",
                    log_module_context!(),
                    domain_rule,
                    path_rule,
                );
                true
            }
            _ => false,
        };
        if deny {
            // Deny frontends keep only the HSTS response header (when enabled
            // and renderable); no warning is logged if rendering fails here.
            let mut deny_headers_response: Vec<HeaderEdit> = Vec::new();
            if let Some(cfg) = hsts
                && matches!(cfg.enabled, Some(true))
                && let Some(rendered) = render_hsts(cfg)
            {
                let mode = if matches!(cfg.force_replace_backend, Some(true)) {
                    HeaderEditMode::Set
                } else {
                    HeaderEditMode::SetIfAbsent
                };
                deny_headers_response.push(HeaderEdit {
                    key: Rc::from(&b"strict-transport-security"[..]),
                    val: rendered.into_bytes().into(),
                    mode,
                });
                crate::incr!(names::http::HSTS_FRONTEND_ADDED);
            }

            return Ok(Self {
                cluster_id,
                redirect: RedirectPolicy::Unauthorized,
                redirect_scheme,
                redirect_template: None,
                capture_cap_host: 0,
                capture_cap_path: 0,
                rewrite_host: None,
                rewrite_path: None,
                rewrite_port: None,
                headers_request: Rc::new([]),
                headers_response: deny_headers_response.into(),
                required_auth,
                tags,
                inherits_listener_hsts,
            });
        }

        // Upper bounds on capture indices the rewrite templates may reference.
        // Regex rules use `captures_len()`, which includes the implicit
        // whole-match group. For a `Prefix` path rule, index 0 is the full path
        // and index 1 the tail after the prefix (see
        // `RouteResult::from_frontend`). The Wildcard host layout presumably
        // mirrors this — TODO confirm against the host-capture builder.
        let mut capture_cap_host = match domain_rule {
            DomainRule::Any => 1,
            DomainRule::Exact(_) => 1,
            DomainRule::Wildcard(_) => 2,
            DomainRule::Regex(regex) => regex.captures_len(),
        };
        let mut capture_cap_path = match path_rule {
            PathRule::Equals(_) => 1,
            PathRule::Prefix(_) => 2,
            PathRule::Regex(regex) => regex.captures_len(),
        };
        // Highest capture index + 1 actually referenced by either template.
        let mut used_capture_host = 0usize;
        let mut used_capture_path = 0usize;
        let rewrite_host_parts = if let Some(p) = rewrite_host {
            Some(
                RewriteParts::parse(
                    &p,
                    capture_cap_host,
                    capture_cap_path,
                    &mut used_capture_host,
                    &mut used_capture_path,
                )
                .ok_or(RouterError::InvalidHostRewrite(p))?,
            )
        } else {
            None
        };
        let rewrite_path_parts = if let Some(p) = rewrite_path {
            Some(
                RewriteParts::parse(
                    &p,
                    capture_cap_host,
                    capture_cap_path,
                    &mut used_capture_host,
                    &mut used_capture_path,
                )
                .ok_or(RouterError::InvalidPathRewrite(p))?,
            )
        } else {
            None
        };
        // No template uses these captures: a cap of 0 presumably lets request
        // handling skip computing them (see `RouteResult::from_frontend`).
        if used_capture_host == 0 {
            capture_cap_host = 0;
        }
        if used_capture_path == 0 {
            capture_cap_path = 0;
        }

        // Configured headers become `Append` edits on the request, the
        // response, or both; an unspecified position is dropped with a warning.
        let mut headers_request = Vec::new();
        let mut headers_response = Vec::new();
        for header in headers {
            let edit = HeaderEdit {
                key: header.key.as_bytes().into(),
                val: header.val.as_bytes().into(),
                mode: HeaderEditMode::Append,
            };
            match header.position() {
                HeaderPosition::Request => headers_request.push(edit),
                HeaderPosition::Response => headers_response.push(edit),
                HeaderPosition::Both => {
                    headers_request.push(edit.clone());
                    headers_response.push(edit);
                }
                HeaderPosition::Unspecified => {
                    warn!(
                        "{} dropping Header {{ key: {:?}, val: {:?} }} with HEADER_POSITION_UNSPECIFIED",
                        log_module_context!(),
                        header.key,
                        header.val,
                    );
                }
            }
        }

        // Enabled HSTS appends a strict-transport-security response edit after
        // the configured headers. It uses `Set` when `force_replace_backend`
        // is `Some(true)`, and `SetIfAbsent` otherwise.
        if let Some(cfg) = hsts
            && matches!(cfg.enabled, Some(true))
        {
            if let Some(rendered) = render_hsts(cfg) {
                let mode = if matches!(cfg.force_replace_backend, Some(true)) {
                    HeaderEditMode::Set
                } else {
                    HeaderEditMode::SetIfAbsent
                };
                headers_response.push(HeaderEdit {
                    key: Rc::from(&b"strict-transport-security"[..]),
                    val: rendered.into_bytes().into(),
                    mode,
                });
                crate::incr!(names::http::HSTS_FRONTEND_ADDED);
            } else {
                warn!(
                    "{} HSTS enabled = true on frontend {:?} but render_hsts \
                     returned None (max_age missing). Frontend will not emit \
                     Strict-Transport-Security; the config layer that built \
                     this HstsConfig must substitute DEFAULT_HSTS_MAX_AGE.",
                    log_module_context!(),
                    cluster_id,
                );
                crate::incr!(names::http::HSTS_UNRENDERED);
            }
        }

        Ok(Frontend {
            cluster_id,
            redirect,
            redirect_scheme,
            redirect_template,
            capture_cap_host,
            capture_cap_path,
            rewrite_host: rewrite_host_parts,
            rewrite_path: rewrite_path_parts,
            rewrite_port,
            headers_request: headers_request.into(),
            headers_response: headers_response.into(),
            required_auth,
            tags,
            inherits_listener_hsts,
        })
    }

    /// Policy-free forwarding frontend for `cluster_id`.
    ///
    /// `Router::refresh_inheriting_hsts` uses it to promote a
    /// `Route::ClusterId`. `inherits_listener_hsts` is `true` so later
    /// refreshes keep updating it.
    pub(crate) fn minimal_forward(cluster_id: ClusterId) -> Self {
        Self {
            cluster_id: Some(cluster_id),
            redirect: RedirectPolicy::Forward,
            redirect_scheme: RedirectScheme::UseSame,
            redirect_template: None,
            capture_cap_host: 0,
            capture_cap_path: 0,
            rewrite_host: None,
            rewrite_path: None,
            rewrite_port: None,
            headers_request: Rc::new([]),
            headers_response: Rc::new([]),
            required_auth: false,
            tags: None,
            inherits_listener_hsts: true,
        }
    }

    /// Policy-free denying frontend (`Unauthorized`, no cluster).
    ///
    /// `Router::refresh_inheriting_hsts` uses it to promote a `Route::Deny`.
    /// `inherits_listener_hsts` is `true` so later refreshes keep updating it.
    pub(crate) fn minimal_deny() -> Self {
        Self {
            cluster_id: None,
            redirect: RedirectPolicy::Unauthorized,
            redirect_scheme: RedirectScheme::UseSame,
            redirect_template: None,
            capture_cap_host: 0,
            capture_cap_path: 0,
            rewrite_host: None,
            rewrite_path: None,
            rewrite_port: None,
            headers_request: Rc::new([]),
            headers_response: Rc::new([]),
            required_auth: false,
            tags: None,
            inherits_listener_hsts: true,
        }
    }
}
1605
/// Outcome of a router lookup: where the request goes and how it is altered.
///
/// Built either directly (`RouteResult::forward` / `RouteResult::deny`) or
/// from a matched `Frontend`, in which case host/path rewrites have already
/// been evaluated against the captures taken from the request.
#[derive(Debug, Clone, PartialEq)]
pub struct RouteResult {
    /// Cluster the request is routed to. `None` when no backend is attached
    /// (e.g. `RouteResult::deny(None)`).
    pub cluster_id: Option<ClusterId>,
    /// Routing policy: `Forward`, `Unauthorized` or another redirect policy.
    pub redirect: RedirectPolicy,
    /// Scheme selection copied from the frontend (`UseSame` in the plain
    /// constructors). Presumably only consulted when `redirect` produces a
    /// redirect; confirm at the use site.
    pub redirect_scheme: RedirectScheme,
    /// Optional redirect template copied from the frontend.
    pub redirect_template: Option<String>,
    /// Host produced by the frontend's host rewrite. `None` when no rewrite is
    /// configured or the route is `Unauthorized`.
    pub rewritten_host: Option<String>,
    /// Path produced by the frontend's path rewrite. `None` when no rewrite is
    /// configured or the route is `Unauthorized`.
    pub rewritten_path: Option<String>,
    /// Port override copied verbatim from the frontend.
    pub rewritten_port: Option<u16>,
    /// Header edits for the request side. Always empty for `Unauthorized`
    /// routes.
    pub headers_request: Rc<[HeaderEdit]>,
    /// Header edits for the response side.
    pub headers_response: Rc<[HeaderEdit]>,
    /// Whether the matched frontend requires authentication.
    pub required_auth: bool,
    /// Cached log tags of the matched frontend, if any.
    pub tags: Option<Rc<CachedTags>>,
}
1634
1635impl RouteResult {
1636 pub fn deny(cluster_id: Option<ClusterId>) -> Self {
1638 Self {
1639 cluster_id,
1640 redirect: RedirectPolicy::Unauthorized,
1641 redirect_scheme: RedirectScheme::UseSame,
1642 redirect_template: None,
1643 rewritten_host: None,
1644 rewritten_path: None,
1645 rewritten_port: None,
1646 headers_request: Rc::new([]),
1647 headers_response: Rc::new([]),
1648 required_auth: false,
1649 tags: None,
1650 }
1651 }
1652
1653 pub fn forward(cluster_id: ClusterId) -> Self {
1656 Self {
1657 cluster_id: Some(cluster_id),
1658 redirect: RedirectPolicy::Forward,
1659 redirect_scheme: RedirectScheme::UseSame,
1660 redirect_template: None,
1661 rewritten_host: None,
1662 rewritten_path: None,
1663 rewritten_port: None,
1664 headers_request: Rc::new([]),
1665 headers_response: Rc::new([]),
1666 required_auth: false,
1667 tags: None,
1668 }
1669 }
1670
1671 fn from_frontend(
1674 frontend: &Frontend,
1675 captures_host: Vec<&str>,
1676 path: &[u8],
1677 path_rule: &PathRule,
1678 ) -> Self {
1679 if frontend.redirect == RedirectPolicy::Unauthorized {
1687 return Self {
1688 cluster_id: frontend.cluster_id.clone(),
1689 redirect: RedirectPolicy::Unauthorized,
1690 redirect_scheme: frontend.redirect_scheme,
1691 redirect_template: frontend.redirect_template.clone(),
1692 rewritten_host: None,
1693 rewritten_path: None,
1694 rewritten_port: None,
1695 headers_request: Rc::new([]),
1696 headers_response: frontend.headers_response.clone(),
1697 required_auth: frontend.required_auth,
1698 tags: frontend.tags.clone(),
1699 };
1700 }
1701
1702 let mut captures_path: Vec<&str> = Vec::with_capacity(frontend.capture_cap_path);
1703 if frontend.capture_cap_path > 0 {
1704 captures_path.push(from_utf8(path).unwrap_or_default());
1705 match path_rule {
1706 PathRule::Prefix(prefix) => {
1707 let tail_start = prefix.len().min(path.len());
1708 captures_path.push(from_utf8(&path[tail_start..]).unwrap_or_default());
1709 }
1710 PathRule::Regex(regex) => {
1711 if let Some(caps) = regex.captures(path) {
1712 captures_path.extend(caps.iter().skip(1).map(|c| {
1713 c.map(|m| from_utf8(m.as_bytes()).unwrap_or_default())
1714 .unwrap_or("")
1715 }));
1716 }
1717 }
1718 PathRule::Equals(_) => {}
1719 }
1720 }
1721
1722 Self {
1723 cluster_id: frontend.cluster_id.clone(),
1724 redirect: frontend.redirect,
1725 redirect_scheme: frontend.redirect_scheme,
1726 redirect_template: frontend.redirect_template.clone(),
1727 rewritten_host: frontend
1728 .rewrite_host
1729 .as_ref()
1730 .map(|rewrite| rewrite.run(&captures_host, &captures_path)),
1731 rewritten_path: frontend
1732 .rewrite_path
1733 .as_ref()
1734 .map(|rewrite| rewrite.run(&captures_host, &captures_path)),
1735 rewritten_port: frontend.rewrite_port,
1736 headers_request: frontend.headers_request.clone(),
1737 headers_response: frontend.headers_response.clone(),
1738 required_auth: frontend.required_auth,
1739 tags: frontend.tags.clone(),
1740 }
1741 }
1742
1743 fn new_no_trie<'a>(
1748 domain: &'a [u8],
1749 domain_rule: &DomainRule,
1750 path: &'a [u8],
1751 path_rule: &PathRule,
1752 route: &Route,
1753 ) -> Self {
1754 let frontend = match route {
1755 Route::Frontend(f) => f.clone(),
1756 Route::ClusterId(id) => return Self::forward(id.clone()),
1757 Route::Deny => return Self::deny(None),
1758 };
1759 let mut captures_host: Vec<&str> = Vec::with_capacity(frontend.capture_cap_host);
1760 if frontend.capture_cap_host > 0 {
1761 captures_host.push(from_utf8(domain).unwrap_or_default());
1762 match domain_rule {
1763 DomainRule::Wildcard(suffix) => {
1764 let head_end = domain.len().saturating_sub(suffix.len().saturating_sub(1));
1765 captures_host.push(from_utf8(&domain[..head_end]).unwrap_or_default());
1766 }
1767 DomainRule::Regex(regex) => {
1768 if let Some(caps) = regex.captures(domain) {
1769 captures_host.extend(caps.iter().skip(1).map(|c| {
1770 c.map(|m| from_utf8(m.as_bytes()).unwrap_or_default())
1771 .unwrap_or("")
1772 }));
1773 }
1774 }
1775 DomainRule::Any | DomainRule::Exact(_) => {}
1776 }
1777 }
1778 Self::from_frontend(&frontend, captures_host, path, path_rule)
1779 }
1780
1781 fn new_with_trie<'a, 'b>(
1786 domain: &'a [u8],
1787 domain_submatches: TrieMatches<'a, 'b>,
1788 path: &'a [u8],
1789 path_rule: &PathRule,
1790 route: &Route,
1791 ) -> Self {
1792 let frontend = match route {
1793 Route::Frontend(f) => f.clone(),
1794 Route::ClusterId(id) => return Self::forward(id.clone()),
1795 Route::Deny => return Self::deny(None),
1796 };
1797 let mut captures_host: Vec<&str> = Vec::with_capacity(frontend.capture_cap_host);
1798 if frontend.capture_cap_host > 0 {
1799 captures_host.push(from_utf8(domain).unwrap_or_default());
1800 for submatch in &domain_submatches {
1801 match submatch {
1802 TrieSubMatch::Wildcard(part) => {
1803 captures_host.push(from_utf8(part).unwrap_or_default());
1804 }
1805 TrieSubMatch::Regexp(part, regex) => {
1806 if let Some(caps) = regex.captures(part) {
1807 captures_host.extend(caps.iter().skip(1).map(|c| {
1808 c.map(|m| from_utf8(m.as_bytes()).unwrap_or_default())
1809 .unwrap_or("")
1810 }));
1811 }
1812 }
1813 }
1814 }
1815 }
1816 Self::from_frontend(&frontend, captures_host, path, path_rule)
1817 }
1818}
1819
1820#[cfg(test)]
1821mod tests {
1822 use super::*;
1823
1824 #[test]
1825 fn render_hsts_max_age_only() {
1826 let cfg = HstsConfig {
1827 enabled: Some(true),
1828 max_age: Some(31_536_000),
1829 include_subdomains: None,
1830 preload: None,
1831 force_replace_backend: None,
1832 };
1833 assert_eq!(render_hsts(&cfg), Some("max-age=31536000".to_owned()));
1834 }
1835
1836 #[test]
1837 fn render_hsts_with_include_subdomains() {
1838 let cfg = HstsConfig {
1839 enabled: Some(true),
1840 max_age: Some(31_536_000),
1841 include_subdomains: Some(true),
1842 preload: None,
1843 force_replace_backend: None,
1844 };
1845 assert_eq!(
1846 render_hsts(&cfg),
1847 Some("max-age=31536000; includeSubDomains".to_owned())
1848 );
1849 }
1850
1851 #[test]
1852 fn render_hsts_with_preload_only() {
1853 let cfg = HstsConfig {
1854 enabled: Some(true),
1855 max_age: Some(63_072_000),
1856 include_subdomains: None,
1857 preload: Some(true),
1858 force_replace_backend: None,
1859 };
1860 assert_eq!(
1861 render_hsts(&cfg),
1862 Some("max-age=63072000; preload".to_owned())
1863 );
1864 }
1865
1866 #[test]
1867 fn render_hsts_full() {
1868 let cfg = HstsConfig {
1869 enabled: Some(true),
1870 max_age: Some(31_536_000),
1871 include_subdomains: Some(true),
1872 preload: Some(true),
1873 force_replace_backend: None,
1874 };
1875 assert_eq!(
1876 render_hsts(&cfg),
1877 Some("max-age=31536000; includeSubDomains; preload".to_owned())
1878 );
1879 }
1880
1881 #[test]
1882 fn render_hsts_kill_switch_max_age_zero() {
1883 let cfg = HstsConfig {
1884 enabled: Some(true),
1885 max_age: Some(0),
1886 include_subdomains: Some(true),
1887 preload: None,
1888 force_replace_backend: None,
1889 };
1890 assert_eq!(
1894 render_hsts(&cfg),
1895 Some("max-age=0; includeSubDomains".to_owned())
1896 );
1897 }
1898
1899 #[test]
1900 fn render_hsts_omitted_when_max_age_missing() {
1901 let cfg = HstsConfig {
1902 enabled: Some(true),
1903 max_age: None,
1904 include_subdomains: Some(true),
1905 preload: None,
1906 force_replace_backend: None,
1907 };
1908 assert_eq!(render_hsts(&cfg), None);
1912 }
1913
1914 #[test]
1915 fn rebuild_with_listener_hsts_replaces_existing_entry() {
1916 let frontend = Frontend {
1920 cluster_id: Some("api".to_owned()),
1921 redirect: RedirectPolicy::Forward,
1922 redirect_scheme: RedirectScheme::UseSame,
1923 redirect_template: None,
1924 capture_cap_host: 0,
1925 capture_cap_path: 0,
1926 rewrite_host: None,
1927 rewrite_path: None,
1928 rewrite_port: None,
1929 headers_request: Rc::new([]),
1930 headers_response: Rc::from(vec![
1931 HeaderEdit {
1932 key: Rc::from(&b"x-cache"[..]),
1933 val: Rc::from(&b"hit"[..]),
1934 mode: HeaderEditMode::Append,
1935 },
1936 HeaderEdit {
1937 key: Rc::from(&b"strict-transport-security"[..]),
1938 val: Rc::from(&b"max-age=31536000"[..]),
1939 mode: HeaderEditMode::SetIfAbsent,
1940 },
1941 ]),
1942 required_auth: false,
1943 tags: None,
1944 inherits_listener_hsts: true,
1945 };
1946 let new_hsts = HstsConfig {
1947 enabled: Some(true),
1948 max_age: Some(63_072_000),
1949 include_subdomains: Some(true),
1950 preload: None,
1951 force_replace_backend: None,
1952 };
1953 let new_edit = build_listener_hsts_edit(Some(&new_hsts));
1954 let rebuilt = rebuild_with_listener_hsts(&frontend, new_edit.as_ref());
1955
1956 let response: Vec<_> = rebuilt.headers_response.iter().collect();
1957 assert_eq!(response.len(), 2, "x-cache + new STS, no leftover STS");
1958 assert_eq!(&*response[0].key, b"x-cache");
1959 assert_eq!(&*response[1].key, b"strict-transport-security");
1960 assert_eq!(
1961 &*response[1].val,
1962 b"max-age=63072000; includeSubDomains".as_slice()
1963 );
1964 assert!(rebuilt.inherits_listener_hsts);
1965 }
1966
1967 #[test]
1968 fn rebuild_with_listener_hsts_strips_when_none() {
1969 let frontend = Frontend {
1972 cluster_id: Some("api".to_owned()),
1973 redirect: RedirectPolicy::Forward,
1974 redirect_scheme: RedirectScheme::UseSame,
1975 redirect_template: None,
1976 capture_cap_host: 0,
1977 capture_cap_path: 0,
1978 rewrite_host: None,
1979 rewrite_path: None,
1980 rewrite_port: None,
1981 headers_request: Rc::new([]),
1982 headers_response: Rc::from(vec![
1983 HeaderEdit {
1984 key: Rc::from(&b"x-cache"[..]),
1985 val: Rc::from(&b"hit"[..]),
1986 mode: HeaderEditMode::Append,
1987 },
1988 HeaderEdit {
1989 key: Rc::from(&b"strict-transport-security"[..]),
1990 val: Rc::from(&b"max-age=31536000"[..]),
1991 mode: HeaderEditMode::SetIfAbsent,
1992 },
1993 ]),
1994 required_auth: false,
1995 tags: None,
1996 inherits_listener_hsts: true,
1997 };
1998 let new_edit = build_listener_hsts_edit(None);
1999 let rebuilt = rebuild_with_listener_hsts(&frontend, new_edit.as_ref());
2000 let response: Vec<_> = rebuilt.headers_response.iter().collect();
2001 assert_eq!(response.len(), 1);
2002 assert_eq!(&*response[0].key, b"x-cache");
2003 }
2004
2005 #[test]
2006 fn rebuild_with_listener_hsts_disabled_strips() {
2007 let frontend = Frontend {
2010 cluster_id: Some("api".to_owned()),
2011 redirect: RedirectPolicy::Forward,
2012 redirect_scheme: RedirectScheme::UseSame,
2013 redirect_template: None,
2014 capture_cap_host: 0,
2015 capture_cap_path: 0,
2016 rewrite_host: None,
2017 rewrite_path: None,
2018 rewrite_port: None,
2019 headers_request: Rc::new([]),
2020 headers_response: Rc::from(vec![HeaderEdit {
2021 key: Rc::from(&b"strict-transport-security"[..]),
2022 val: Rc::from(&b"max-age=31536000"[..]),
2023 mode: HeaderEditMode::SetIfAbsent,
2024 }]),
2025 required_auth: false,
2026 tags: None,
2027 inherits_listener_hsts: true,
2028 };
2029 let new_hsts = HstsConfig {
2030 enabled: Some(false),
2031 max_age: None,
2032 include_subdomains: None,
2033 preload: None,
2034 force_replace_backend: None,
2035 };
2036 let new_edit = build_listener_hsts_edit(Some(&new_hsts));
2037 let rebuilt = rebuild_with_listener_hsts(&frontend, new_edit.as_ref());
2038 assert_eq!(rebuilt.headers_response.len(), 0);
2039 }
2040
2041 #[test]
2042 fn refresh_inheriting_hsts_skips_explicit_overrides() {
2043 use crate::router::pattern_trie::TrieNode;
2047 let mut router = Router {
2048 pre: Vec::new(),
2049 tree: TrieNode::root(),
2050 post: Vec::new(),
2051 };
2052 let inheriting = Frontend {
2053 cluster_id: Some("api".to_owned()),
2054 redirect: RedirectPolicy::Forward,
2055 redirect_scheme: RedirectScheme::UseSame,
2056 redirect_template: None,
2057 capture_cap_host: 0,
2058 capture_cap_path: 0,
2059 rewrite_host: None,
2060 rewrite_path: None,
2061 rewrite_port: None,
2062 headers_request: Rc::new([]),
2063 headers_response: Rc::from(vec![HeaderEdit {
2064 key: Rc::from(&b"strict-transport-security"[..]),
2065 val: Rc::from(&b"max-age=31536000"[..]),
2066 mode: HeaderEditMode::SetIfAbsent,
2067 }]),
2068 required_auth: false,
2069 tags: None,
2070 inherits_listener_hsts: true,
2071 };
2072 let explicit = Frontend {
2073 cluster_id: Some("legacy".to_owned()),
2074 redirect: RedirectPolicy::Forward,
2075 redirect_scheme: RedirectScheme::UseSame,
2076 redirect_template: None,
2077 capture_cap_host: 0,
2078 capture_cap_path: 0,
2079 rewrite_host: None,
2080 rewrite_path: None,
2081 rewrite_port: None,
2082 headers_request: Rc::new([]),
2083 headers_response: Rc::from(vec![HeaderEdit {
2084 key: Rc::from(&b"strict-transport-security"[..]),
2085 val: Rc::from(&b"max-age=300"[..]),
2086 mode: HeaderEditMode::SetIfAbsent,
2087 }]),
2088 required_auth: false,
2089 tags: None,
2090 inherits_listener_hsts: false,
2091 };
2092 router.pre.push((
2093 DomainRule::Any,
2094 PathRule::Prefix("/api".to_owned()),
2095 MethodRule::new(None),
2096 Route::Frontend(Rc::new(inheriting)),
2097 ));
2098 router.post.push((
2099 DomainRule::Any,
2100 PathRule::Prefix("/legacy".to_owned()),
2101 MethodRule::new(None),
2102 Route::Frontend(Rc::new(explicit)),
2103 ));
2104
2105 let new_hsts = HstsConfig {
2106 enabled: Some(true),
2107 max_age: Some(63_072_000),
2108 include_subdomains: Some(true),
2109 preload: None,
2110 force_replace_backend: None,
2111 };
2112 let count = router.refresh_inheriting_hsts(Some(&new_hsts));
2113 assert_eq!(count, 1, "only the inheriting frontend should refresh");
2114
2115 if let Route::Frontend(rc) = &router.pre[0].3 {
2116 let response: Vec<_> = rc.headers_response.iter().collect();
2117 assert_eq!(
2118 &*response.last().unwrap().val,
2119 b"max-age=63072000; includeSubDomains".as_slice(),
2120 "inheriting frontend's STS must reflect the new listener default"
2121 );
2122 } else {
2123 panic!("pre[0] should be Route::Frontend");
2124 }
2125 if let Route::Frontend(rc) = &router.post[0].3 {
2126 let response: Vec<_> = rc.headers_response.iter().collect();
2127 assert_eq!(
2128 &*response.last().unwrap().val,
2129 b"max-age=300".as_slice(),
2130 "explicit override must be preserved unchanged"
2131 );
2132 } else {
2133 panic!("post[0] should be Route::Frontend");
2134 }
2135 }
2136
2137 #[test]
2138 fn refresh_inheriting_hsts_promotes_clusterid_on_enable() {
2139 use crate::router::pattern_trie::TrieNode;
2149 let mut router = Router {
2150 pre: Vec::new(),
2151 tree: TrieNode::root(),
2152 post: vec![(
2153 DomainRule::Any,
2154 PathRule::Prefix("/".to_owned()),
2155 MethodRule::new(None),
2156 Route::ClusterId("api".to_owned()),
2157 )],
2158 };
2159
2160 let new_hsts = HstsConfig {
2161 enabled: Some(true),
2162 max_age: Some(31_536_000),
2163 include_subdomains: Some(true),
2164 preload: None,
2165 force_replace_backend: None,
2166 };
2167 let count = router.refresh_inheriting_hsts(Some(&new_hsts));
2168 assert_eq!(count, 1, "the ClusterId entry must be promoted + counted");
2169
2170 let Route::Frontend(rc) = &router.post[0].3 else {
2171 panic!("post[0] should now be Route::Frontend, not the original Route::ClusterId");
2172 };
2173 assert_eq!(rc.cluster_id.as_deref(), Some("api"));
2174 assert_eq!(
2175 rc.redirect,
2176 RedirectPolicy::Forward,
2177 "promoted entry must keep Forward semantics so lookup yields the same backend"
2178 );
2179 assert!(
2180 rc.inherits_listener_hsts,
2181 "promoted entry must mark itself inheriting so the next patch refreshes it"
2182 );
2183 let response: Vec<_> = rc.headers_response.iter().collect();
2184 assert_eq!(
2185 response.len(),
2186 1,
2187 "promoted entry carries exactly one STS edit, no operator headers"
2188 );
2189 assert_eq!(&*response[0].key, b"strict-transport-security");
2190 assert_eq!(
2191 &*response[0].val,
2192 b"max-age=31536000; includeSubDomains".as_slice()
2193 );
2194 }
2195
2196 #[test]
2197 fn refresh_inheriting_hsts_promotes_deny_on_enable() {
2198 use crate::router::pattern_trie::TrieNode;
2204 let mut router = Router {
2205 pre: Vec::new(),
2206 tree: TrieNode::root(),
2207 post: vec![(
2208 DomainRule::Any,
2209 PathRule::Prefix("/forbidden".to_owned()),
2210 MethodRule::new(None),
2211 Route::Deny,
2212 )],
2213 };
2214
2215 let new_hsts = HstsConfig {
2216 enabled: Some(true),
2217 max_age: Some(31_536_000),
2218 include_subdomains: None,
2219 preload: None,
2220 force_replace_backend: None,
2221 };
2222 let count = router.refresh_inheriting_hsts(Some(&new_hsts));
2223 assert_eq!(count, 1);
2224
2225 let Route::Frontend(rc) = &router.post[0].3 else {
2226 panic!("post[0] should now be Route::Frontend, not the original Route::Deny");
2227 };
2228 assert_eq!(rc.cluster_id, None, "promoted Deny stays clusterless");
2229 assert_eq!(
2230 rc.redirect,
2231 RedirectPolicy::Unauthorized,
2232 "promoted Deny must keep Unauthorized so lookup yields a 401"
2233 );
2234 assert!(rc.inherits_listener_hsts);
2235 let response: Vec<_> = rc.headers_response.iter().collect();
2236 assert_eq!(response.len(), 1);
2237 assert_eq!(&*response[0].key, b"strict-transport-security");
2238 assert_eq!(&*response[0].val, b"max-age=31536000".as_slice());
2239 }
2240
2241 #[test]
2242 fn refresh_inheriting_hsts_skips_lightweight_on_disable() {
2243 use crate::router::pattern_trie::TrieNode;
2253 let make_router = || Router {
2254 pre: vec![(
2255 DomainRule::Any,
2256 PathRule::Prefix("/".to_owned()),
2257 MethodRule::new(None),
2258 Route::ClusterId("api".to_owned()),
2259 )],
2260 tree: TrieNode::root(),
2261 post: vec![(
2262 DomainRule::Any,
2263 PathRule::Prefix("/forbidden".to_owned()),
2264 MethodRule::new(None),
2265 Route::Deny,
2266 )],
2267 };
2268
2269 for (label, hsts) in [
2270 ("none", None),
2271 (
2272 "disabled",
2273 Some(HstsConfig {
2274 enabled: Some(false),
2275 max_age: None,
2276 include_subdomains: None,
2277 preload: None,
2278 force_replace_backend: None,
2279 }),
2280 ),
2281 (
2282 "enabled-without-max-age",
2283 Some(HstsConfig {
2284 enabled: Some(true),
2285 max_age: None,
2286 include_subdomains: None,
2287 preload: None,
2288 force_replace_backend: None,
2289 }),
2290 ),
2291 ] {
2292 let mut router = make_router();
2293 let count = router.refresh_inheriting_hsts(hsts.as_ref());
2294 assert_eq!(count, 0, "no promotion expected for {label}");
2295 assert!(
2296 matches!(router.pre[0].3, Route::ClusterId(_)),
2297 "{label}: ClusterId must stay lightweight"
2298 );
2299 assert!(
2300 matches!(router.post[0].3, Route::Deny),
2301 "{label}: Deny must stay lightweight"
2302 );
2303 }
2304 }
2305
2306 #[test]
2307 fn refresh_inheriting_hsts_promoted_entry_refreshes_on_subsequent_patches() {
2308 use crate::router::pattern_trie::TrieNode;
2313 let mut router = Router {
2314 pre: Vec::new(),
2315 tree: TrieNode::root(),
2316 post: vec![(
2317 DomainRule::Any,
2318 PathRule::Prefix("/".to_owned()),
2319 MethodRule::new(None),
2320 Route::ClusterId("api".to_owned()),
2321 )],
2322 };
2323
2324 let first_patch = HstsConfig {
2325 enabled: Some(true),
2326 max_age: Some(31_536_000),
2327 include_subdomains: None,
2328 preload: None,
2329 force_replace_backend: None,
2330 };
2331 assert_eq!(router.refresh_inheriting_hsts(Some(&first_patch)), 1);
2332
2333 let second_patch = HstsConfig {
2334 enabled: Some(true),
2335 max_age: Some(63_072_000),
2336 include_subdomains: Some(true),
2337 preload: None,
2338 force_replace_backend: None,
2339 };
2340 assert_eq!(
2341 router.refresh_inheriting_hsts(Some(&second_patch)),
2342 1,
2343 "the previously promoted entry must be re-counted via the path-1 branch"
2344 );
2345
2346 let Route::Frontend(rc) = &router.post[0].3 else {
2347 panic!("post[0] should still be Route::Frontend after the second patch");
2348 };
2349 let response: Vec<_> = rc.headers_response.iter().collect();
2350 assert_eq!(
2351 response.len(),
2352 1,
2353 "second patch must REPLACE the existing STS edit, not append a duplicate"
2354 );
2355 assert_eq!(
2356 &*response[0].val,
2357 b"max-age=63072000; includeSubDomains".as_slice()
2358 );
2359 }
2360
2361 #[test]
2362 fn refresh_inheriting_hsts_promoted_entry_loses_hsts_on_disable_patch() {
2363 use crate::router::pattern_trie::TrieNode;
2370 let mut router = Router {
2371 pre: vec![(
2372 DomainRule::Any,
2373 PathRule::Prefix("/".to_owned()),
2374 MethodRule::new(None),
2375 Route::ClusterId("api".to_owned()),
2376 )],
2377 tree: TrieNode::root(),
2378 post: Vec::new(),
2379 };
2380
2381 let enable = HstsConfig {
2382 enabled: Some(true),
2383 max_age: Some(31_536_000),
2384 include_subdomains: None,
2385 preload: None,
2386 force_replace_backend: None,
2387 };
2388 assert_eq!(router.refresh_inheriting_hsts(Some(&enable)), 1);
2389
2390 let disable = HstsConfig {
2391 enabled: Some(false),
2392 max_age: None,
2393 include_subdomains: None,
2394 preload: None,
2395 force_replace_backend: None,
2396 };
2397 assert_eq!(
2398 router.refresh_inheriting_hsts(Some(&disable)),
2399 1,
2400 "the promoted entry must still be touched on disable to strip its STS edit"
2401 );
2402
2403 let Route::Frontend(rc) = &router.pre[0].3 else {
2404 panic!("pre[0] should still be Route::Frontend (no demotion)");
2405 };
2406 assert_eq!(rc.cluster_id.as_deref(), Some("api"));
2407 assert_eq!(
2408 rc.headers_response.len(),
2409 0,
2410 "disable patch must strip the STS edit from the promoted entry"
2411 );
2412 }
2413
2414 #[test]
2415 fn refresh_inheriting_hsts_promotes_clusterid_in_trie_on_enable() {
2416 use crate::router::pattern_trie::TrieNode;
2423 let mut router = Router {
2424 pre: Vec::new(),
2425 tree: TrieNode::root(),
2426 post: Vec::new(),
2427 };
2428 let path_rule = PathRule::Prefix("/".to_owned());
2429 let method_rule = MethodRule::new(None);
2430 assert!(router.add_tree_rule(
2431 b"example.com",
2432 &path_rule,
2433 &method_rule,
2434 &Route::ClusterId("api".to_owned()),
2435 ));
2436
2437 let new_hsts = HstsConfig {
2438 enabled: Some(true),
2439 max_age: Some(31_536_000),
2440 include_subdomains: Some(true),
2441 preload: None,
2442 force_replace_backend: None,
2443 };
2444 let count = router.refresh_inheriting_hsts(Some(&new_hsts));
2445 assert_eq!(
2446 count, 1,
2447 "trie-resident ClusterId must be promoted + counted"
2448 );
2449
2450 let (_, paths) = router
2451 .tree
2452 .domain_lookup_mut(b"example.com", false)
2453 .expect("trie leaf still present after refresh");
2454 assert_eq!(paths.len(), 1);
2455 let Route::Frontend(rc) = &paths[0].2 else {
2456 panic!("trie leaf should now be Route::Frontend, not Route::ClusterId");
2457 };
2458 assert_eq!(rc.cluster_id.as_deref(), Some("api"));
2459 assert_eq!(rc.redirect, RedirectPolicy::Forward);
2460 assert!(rc.inherits_listener_hsts);
2461 let response: Vec<_> = rc.headers_response.iter().collect();
2462 assert_eq!(response.len(), 1);
2463 assert_eq!(&*response[0].key, b"strict-transport-security");
2464 assert_eq!(
2465 &*response[0].val,
2466 b"max-age=31536000; includeSubDomains".as_slice()
2467 );
2468 }
2469
2470 #[test]
2471 fn convert_regex() {
2472 assert_eq!(
2475 convert_regex_domain_rule("www.example.com")
2476 .unwrap()
2477 .as_str(),
2478 "\\Awww\\.example\\.com\\z"
2479 );
2480 assert_eq!(
2481 convert_regex_domain_rule("*.example.com").unwrap().as_str(),
2482 "\\A*\\.example\\.com\\z"
2483 );
2484 assert_eq!(
2485 convert_regex_domain_rule("test.*.example.com")
2486 .unwrap()
2487 .as_str(),
2488 "\\Atest\\.*\\.example\\.com\\z"
2489 );
2490 assert_eq!(
2491 convert_regex_domain_rule("css./cdn[a-z0-9]+/.example.com")
2492 .unwrap()
2493 .as_str(),
2494 "\\Acss\\.cdn[a-z0-9]+\\.example\\.com\\z"
2495 );
2496
2497 assert_eq!(
2498 convert_regex_domain_rule("css./cdn[a-z0-9]+.example.com"),
2499 None
2500 );
2501 assert_eq!(
2502 convert_regex_domain_rule("css./cdn[a-z0-9]+/a.example.com"),
2503 None
2504 );
2505 }
2506
2507 #[test]
2512 fn regex_domain_rule_rejects_suffix_and_prefix() {
2513 let rule: DomainRule = "/example\\.com/".parse().unwrap();
2514 assert!(rule.matches(b"example.com"));
2515 assert!(!rule.matches(b"attacker.example.com"));
2516 assert!(!rule.matches(b"example.com.evil.org"));
2517 assert!(!rule.matches(b"prefixexample.com"));
2518 assert!(!rule.matches(b"example.commercial"));
2519 }
2520
2521 #[test]
2527 fn regex_domain_rule_multi_segment_segments_are_isolated() {
2528 let pattern = convert_regex_domain_rule("/seg1/.foo./seg2/.com")
2529 .expect("multi-segment regex hostname must compile");
2530 assert_eq!(pattern.as_str(), "\\Aseg1\\.foo\\.seg2\\.com\\z");
2531 }
2532
2533 #[test]
2534 fn parse_domain_rule() {
2535 assert_eq!("*".parse::<DomainRule>().unwrap(), DomainRule::Any);
2536 assert_eq!(
2537 "www.example.com".parse::<DomainRule>().unwrap(),
2538 DomainRule::Exact("www.example.com".to_string())
2539 );
2540 assert_eq!(
2541 "*.example.com".parse::<DomainRule>().unwrap(),
2542 DomainRule::Wildcard("*.example.com".to_string())
2543 );
2544 assert_eq!("test.*.example.com".parse::<DomainRule>(), Err(()));
2545 assert_eq!(
2546 "/cdn[0-9]+/.example.com".parse::<DomainRule>().unwrap(),
2547 DomainRule::Regex(Regex::new("\\Acdn[0-9]+\\.example\\.com\\z").unwrap())
2548 );
2549 }
2550
2551 #[test]
2552 fn match_domain_rule() {
2553 assert!(DomainRule::Any.matches("www.example.com".as_bytes()));
2554 assert!(
2555 DomainRule::Exact("www.example.com".to_string()).matches("www.example.com".as_bytes())
2556 );
2557 assert!(
2558 DomainRule::Wildcard("*.example.com".to_string()).matches("www.example.com".as_bytes())
2559 );
2560 assert!(
2561 !DomainRule::Wildcard("*.example.com".to_string())
2562 .matches("test.www.example.com".as_bytes())
2563 );
2564 assert!(
2565 "/cdn[0-9]+/.example.com"
2566 .parse::<DomainRule>()
2567 .unwrap()
2568 .matches("cdn1.example.com".as_bytes())
2569 );
2570 assert!(
2571 !"/cdn[0-9]+/.example.com"
2572 .parse::<DomainRule>()
2573 .unwrap()
2574 .matches("www.example.com".as_bytes())
2575 );
2576 assert!(
2577 !"/cdn[0-9]+/.example.com"
2578 .parse::<DomainRule>()
2579 .unwrap()
2580 .matches("cdn10.exampleAcom".as_bytes())
2581 );
2582 }
2583
2584 #[test]
2585 fn match_domain_rule_wildcard_short_hostname_does_not_panic() {
2586 let rule = DomainRule::Wildcard("*.foo.example.com".to_string());
2587
2588 assert!(!rule.matches(b""));
2590
2591 assert!(!rule.matches(b"a.b"));
2593 assert!(!rule.matches(b"x"));
2594
2595 assert!(!rule.matches(b".foo.example.com"));
2598
2599 assert!(!rule.matches(b"y.x.foo.example.com"));
2601
2602 assert!(rule.matches(b"x.foo.example.com"));
2605 }
2606
2607 #[test]
2608 fn router_lookup_wildcard_pre_rule_short_hostname_does_not_panic() {
2609 let mut router = Router::new();
2610
2611 assert!(router.add_pre_rule(
2614 &"*.foo.example.com".parse::<DomainRule>().unwrap(),
2615 &PathRule::Prefix("/".to_string()),
2616 &MethodRule::new(Some("GET".to_string())),
2617 &Route::ClusterId("wildcard".to_string()),
2618 ));
2619
2620 let method = Method::new(&b"GET"[..]);
2621
2622 assert!(router.lookup("", "/", &method).is_err());
2624 assert!(router.lookup("x", "/", &method).is_err());
2625 assert!(router.lookup("a.b", "/", &method).is_err());
2626
2627 assert!(router.lookup(".foo.example.com", "/", &method).is_err());
2629
2630 assert_eq!(
2632 router.lookup("x.foo.example.com", "/", &method),
2633 Ok(RouteResult::forward("wildcard".to_string()))
2634 );
2635 }
2636
2637 #[test]
2638 fn match_path_rule() {
2639 assert!(PathRule::Prefix("".to_string()).matches("/".as_bytes()) != PathRuleResult::None);
2640 assert!(
2641 PathRule::Prefix("".to_string()).matches("/hello".as_bytes()) != PathRuleResult::None
2642 );
2643 assert!(
2644 PathRule::Prefix("/hello".to_string()).matches("/hello".as_bytes())
2645 != PathRuleResult::None
2646 );
2647 assert!(
2648 PathRule::Prefix("/hello".to_string()).matches("/hello/world".as_bytes())
2649 != PathRuleResult::None
2650 );
2651 assert!(
2652 PathRule::Prefix("/hello".to_string()).matches("/".as_bytes()) == PathRuleResult::None
2653 );
2654 }
2655
2656 #[test]
2664 fn multiple_children_on_a_wildcard() {
2665 let mut router = Router::new();
2666
2667 assert!(router.add_tree_rule(
2668 b"*.sozu.io",
2669 &PathRule::Prefix("".to_string()),
2670 &MethodRule::new(Some("GET".to_string())),
2671 &Route::ClusterId("base".to_string())
2672 ));
2673 println!("{:#?}", router.tree);
2674 assert_eq!(
2675 router.lookup("www.sozu.io", "/api", &Method::Get),
2676 Ok(RouteResult::forward("base".to_string()))
2677 );
2678 assert!(router.add_tree_rule(
2679 b"*.sozu.io",
2680 &PathRule::Prefix("/api".to_string()),
2681 &MethodRule::new(Some("GET".to_string())),
2682 &Route::ClusterId("api".to_string())
2683 ));
2684 println!("{:#?}", router.tree);
2685 assert_eq!(
2686 router.lookup("www.sozu.io", "/ap", &Method::Get),
2687 Ok(RouteResult::forward("base".to_string()))
2688 );
2689 assert_eq!(
2690 router.lookup("www.sozu.io", "/api", &Method::Get),
2691 Ok(RouteResult::forward("api".to_string()))
2692 );
2693 }
2694
2695 #[test]
2703 fn multiple_children_including_one_with_wildcard() {
2704 let mut router = Router::new();
2705
2706 assert!(router.add_tree_rule(
2707 b"*.sozu.io",
2708 &PathRule::Prefix("".to_string()),
2709 &MethodRule::new(Some("GET".to_string())),
2710 &Route::ClusterId("base".to_string())
2711 ));
2712 println!("{:#?}", router.tree);
2713 assert_eq!(
2714 router.lookup("www.sozu.io", "/api", &Method::Get),
2715 Ok(RouteResult::forward("base".to_string()))
2716 );
2717 assert!(router.add_tree_rule(
2718 b"api.sozu.io",
2719 &PathRule::Prefix("".to_string()),
2720 &MethodRule::new(Some("GET".to_string())),
2721 &Route::ClusterId("api".to_string())
2722 ));
2723 println!("{:#?}", router.tree);
2724 assert_eq!(
2725 router.lookup("www.sozu.io", "/api", &Method::Get),
2726 Ok(RouteResult::forward("base".to_string()))
2727 );
2728 assert_eq!(
2729 router.lookup("api.sozu.io", "/api", &Method::Get),
2730 Ok(RouteResult::forward("api".to_string()))
2731 );
2732 }
2733
2734 #[test]
2735 fn router_insert_remove_through_regex() {
2736 let mut router = Router::new();
2737
2738 assert!(router.add_tree_rule(
2739 b"www./.*/.io",
2740 &PathRule::Prefix("".to_string()),
2741 &MethodRule::new(Some("GET".to_string())),
2742 &Route::ClusterId("base".to_string())
2743 ));
2744 println!("{:#?}", router.tree);
2745 assert!(router.add_tree_rule(
2746 b"www.doc./.*/.io",
2747 &PathRule::Prefix("".to_string()),
2748 &MethodRule::new(Some("GET".to_string())),
2749 &Route::ClusterId("doc".to_string())
2750 ));
2751 println!("{:#?}", router.tree);
2752 assert_eq!(
2753 router.lookup("www.sozu.io", "/", &Method::Get),
2754 Ok(RouteResult::forward("base".to_string()))
2755 );
2756 assert_eq!(
2757 router.lookup("www.doc.sozu.io", "/", &Method::Get),
2758 Ok(RouteResult::forward("doc".to_string()))
2759 );
2760 assert!(router.remove_tree_rule(
2761 b"www./.*/.io",
2762 &PathRule::Prefix("".to_string()),
2763 &MethodRule::new(Some("GET".to_string()))
2764 ));
2765 println!("{:#?}", router.tree);
2766 assert!(router.lookup("www.sozu.io", "/", &Method::Get).is_err());
2767 assert_eq!(
2768 router.lookup("www.doc.sozu.io", "/", &Method::Get),
2769 Ok(RouteResult::forward("doc".to_string()))
2770 );
2771 }
2772
2773 #[test]
2774 fn match_router() {
2775 let mut router = Router::new();
2776
2777 assert!(router.add_pre_rule(
2778 &"*".parse::<DomainRule>().unwrap(),
2779 &PathRule::Prefix("/.well-known/acme-challenge".to_string()),
2780 &MethodRule::new(Some("GET".to_string())),
2781 &Route::ClusterId("acme".to_string())
2782 ));
2783 assert!(router.add_tree_rule(
2784 "www.example.com".as_bytes(),
2785 &PathRule::Prefix("/".to_string()),
2786 &MethodRule::new(Some("GET".to_string())),
2787 &Route::ClusterId("example".to_string())
2788 ));
2789 assert!(router.add_tree_rule(
2790 "*.test.example.com".as_bytes(),
2791 &PathRule::Regex(Regex::new("/hello[A-Z]+/").unwrap()),
2792 &MethodRule::new(Some("GET".to_string())),
2793 &Route::ClusterId("examplewildcard".to_string())
2794 ));
2795 assert!(router.add_tree_rule(
2796 "/test[0-9]/.example.com".as_bytes(),
2797 &PathRule::Prefix("/".to_string()),
2798 &MethodRule::new(Some("GET".to_string())),
2799 &Route::ClusterId("exampleregex".to_string())
2800 ));
2801
2802 assert_eq!(
2803 router.lookup("www.example.com", "/helloA", &Method::new(&b"GET"[..])),
2804 Ok(RouteResult::forward("example".to_string()))
2805 );
2806 assert_eq!(
2807 router.lookup(
2808 "www.example.com",
2809 "/.well-known/acme-challenge",
2810 &Method::new(&b"GET"[..])
2811 ),
2812 Ok(RouteResult::forward("acme".to_string()))
2813 );
2814 assert!(
2815 router
2816 .lookup("www.test.example.com", "/", &Method::new(&b"GET"[..]))
2817 .is_err()
2818 );
2819 assert_eq!(
2820 router.lookup(
2821 "www.test.example.com",
2822 "/helloAB/",
2823 &Method::new(&b"GET"[..])
2824 ),
2825 Ok(RouteResult::forward("examplewildcard".to_string()))
2826 );
2827 assert_eq!(
2828 router.lookup("test1.example.com", "/helloAB/", &Method::new(&b"GET"[..])),
2829 Ok(RouteResult::forward("exampleregex".to_string()))
2830 );
2831 }
2832
2833 #[test]
2834 fn has_hostname_checks_tree_pre_and_post() {
2835 let mut router = Router::new();
2836
2837 assert!(!router.has_hostname("www.example.com"));
2839
2840 assert!(router.add_tree_rule(
2842 b"www.example.com",
2843 &PathRule::Prefix("/".to_string()),
2844 &MethodRule::new(Some("GET".to_string())),
2845 &Route::ClusterId("cluster1".to_string())
2846 ));
2847 assert!(router.has_hostname("www.example.com"));
2848 assert!(!router.has_hostname("api.example.com"));
2849
2850 assert!(router.remove_tree_rule(
2852 b"www.example.com",
2853 &PathRule::Prefix("/".to_string()),
2854 &MethodRule::new(Some("GET".to_string()))
2855 ));
2856 assert!(!router.has_hostname("www.example.com"));
2857
2858 assert!(router.add_pre_rule(
2860 &DomainRule::Exact("api.example.com".to_string()),
2861 &PathRule::Prefix("/".to_string()),
2862 &MethodRule::new(None),
2863 &Route::ClusterId("cluster2".to_string())
2864 ));
2865 assert!(router.has_hostname("api.example.com"));
2866 assert!(!router.has_hostname("www.example.com"));
2867
2868 assert!(router.add_post_rule(
2870 &DomainRule::Exact("cdn.example.com".to_string()),
2871 &PathRule::Prefix("/".to_string()),
2872 &MethodRule::new(None),
2873 &Route::ClusterId("cluster3".to_string())
2874 ));
2875 assert!(router.has_hostname("cdn.example.com"));
2876
2877 assert!(router.remove_pre_rule(
2879 &DomainRule::Exact("api.example.com".to_string()),
2880 &PathRule::Prefix("/".to_string()),
2881 &MethodRule::new(None),
2882 ));
2883 assert!(!router.has_hostname("api.example.com"));
2884 assert!(router.has_hostname("cdn.example.com"));
2885 }
2886
2887 #[test]
2888 fn has_hostname_false_after_last_route_removed() {
2889 let mut router = Router::new();
2890
2891 assert!(router.add_tree_rule(
2893 b"www.example.com",
2894 &PathRule::Prefix("/".to_string()),
2895 &MethodRule::new(Some("GET".to_string())),
2896 &Route::ClusterId("cluster1".to_string())
2897 ));
2898 assert!(router.add_tree_rule(
2899 b"www.example.com",
2900 &PathRule::Prefix("/api".to_string()),
2901 &MethodRule::new(Some("GET".to_string())),
2902 &Route::ClusterId("cluster2".to_string())
2903 ));
2904 assert!(router.has_hostname("www.example.com"));
2905
2906 assert!(router.remove_tree_rule(
2908 b"www.example.com",
2909 &PathRule::Prefix("/".to_string()),
2910 &MethodRule::new(Some("GET".to_string()))
2911 ));
2912 assert!(router.has_hostname("www.example.com"));
2913
2914 assert!(router.remove_tree_rule(
2916 b"www.example.com",
2917 &PathRule::Prefix("/api".to_string()),
2918 &MethodRule::new(Some("GET".to_string()))
2919 ));
2920 assert!(!router.has_hostname("www.example.com"));
2921 }
2922}