use std::{collections::HashMap, fmt::Debug, iter, str};
use regex::bytes::Regex;
/// Raw byte string used as a key in the trie.
pub type Key = Vec<u8>;
/// A key paired with the value stored for it.
pub type KeyValue<K, V> = (K, V);
/// Outcome of inserting a key into a [`TrieNode`].
#[derive(Debug, PartialEq, Eq)]
pub enum InsertResult {
/// The entry was stored.
Ok,
/// An entry already occupies that position; the trie was left untouched.
Existing,
/// The key could not be stored, e.g. an empty key, the key `.`, or a
/// malformed `/regex/` segment.
Failed,
}
/// Outcome of removing a key from a [`TrieNode`].
#[derive(Debug, PartialEq, Eq)]
pub enum RemoveResult {
/// An entry was found for the key and removed.
Ok,
/// No entry matched the key.
NotFound,
}
/// Returns the index of the rightmost `.` byte in `input`, or `None` when
/// there is no dot at all.
fn find_last_dot(input: &[u8]) -> Option<usize> {
    input.iter().rposition(|&byte| byte == b'.')
}
/// Returns the index of the rightmost `/` byte in `input`, or `None` when
/// there is no slash at all.
fn find_last_slash(input: &[u8]) -> Option<usize> {
    input.iter().rposition(|&byte| byte == b'/')
}
/// One node of the trie. Keys are consumed from the right, one `.`-separated
/// label per level.
#[derive(Debug)]
pub struct TrieNode<V> {
/// Entry stored at this node; returned by lookups once the key is fully consumed.
key_value: Option<KeyValue<Key, V>>,
/// Entry stored when the remaining label is `*`; lookups fall back to it
/// when no child matches the last remaining label and wildcards are accepted.
wildcard: Option<KeyValue<Key, V>>,
/// Child nodes keyed by label. Labels split off at a `.` keep that leading
/// dot; the leftmost label of a key is stored without one.
children: HashMap<Key, TrieNode<V>>,
/// Regex labels with their subtrees, tried in insertion order when no
/// literal child (and no applicable wildcard) matches.
regexps: Vec<(Regex, TrieNode<V>)>,
}
/// Structural equality. `Regex` has no `PartialEq`, so regex labels are
/// compared by their source pattern, pairwise and in order.
impl<V: PartialEq> std::cmp::PartialEq for TrieNode<V> {
    fn eq(&self, other: &Self) -> bool {
        if self.key_value != other.key_value
            || self.wildcard != other.wildcard
            || self.children != other.children
        {
            return false;
        }
        if self.regexps.len() != other.regexps.len() {
            return false;
        }
        self.regexps
            .iter()
            .zip(other.regexps.iter())
            .all(|(ours, theirs)| ours.0.as_str() == theirs.0.as_str() && ours.1 == theirs.1)
    }
}
impl<V: Debug + Clone> TrieNode<V> {
pub fn new(key: Key, value: V) -> TrieNode<V> {
TrieNode {
key_value: Some((key, value)),
wildcard: None,
children: HashMap::new(),
regexps: Vec::new(),
}
}
pub fn wildcard(key: Key, value: V) -> TrieNode<V> {
TrieNode {
key_value: None,
wildcard: Some((key, value)),
children: HashMap::new(),
regexps: Vec::new(),
}
}
/// Builds an empty node: no entry, no wildcard, no children, no regex labels.
pub fn root() -> TrieNode<V> {
    let children = HashMap::new();
    let regexps = Vec::new();
    TrieNode {
        key_value: None,
        wildcard: None,
        children,
        regexps,
    }
}
/// Returns `true` when the node stores no entry of its own and has neither
/// children nor regex labels.
pub fn is_empty(&self) -> bool {
    let holds_entry = self.key_value.is_some() || self.wildcard.is_some();
    let has_descendants = !self.regexps.is_empty() || !self.children.is_empty();
    !(holds_entry || has_descendants)
}
/// Inserts `value` under `key`.
///
/// Returns [`InsertResult::Ok`] when the entry was stored,
/// [`InsertResult::Existing`] when that position is already taken, and
/// [`InsertResult::Failed`] when the key cannot be stored: an empty key, a
/// key starting with `.` (its leftmost label would be empty), or a key
/// whose `/regex/` segment is malformed.
pub fn insert(&mut self, key: Key, value: V) -> InsertResult {
    // A leading '.' (which includes the lone "." key) would make the
    // recursion reach an empty remainder, which `insert_recursive` asserts
    // against; reject such keys up front instead of panicking.
    if key.first().map_or(true, |&byte| byte == b'.') {
        return InsertResult::Failed;
    }
    // A malformed pattern is reported to the caller rather than turned into
    // a panic: `insert_recursive` only attaches nodes on success, so a
    // failed insertion leaves the trie unchanged.
    self.insert_recursive(&key, &key, value)
}
/// Inserts `value` for `key`, where `partial_key` is the part of `key` that
/// has not been consumed yet (labels are consumed from the right).
///
/// * A remainder ending in `/` designates a regex label `/pattern/`. It must
///   either start the remainder or be preceded by `.`; otherwise the result
///   is [`InsertResult::Failed`], as it is for a pattern that is not valid
///   UTF-8 or does not compile.
/// * A remainder without `.` is the leftmost label: `*` fills the wildcard
///   slot, anything else becomes a child holding the entry.
/// * Otherwise the part from the last `.` onwards (dot included) selects or
///   creates a child and the rest is inserted below it.
///
/// New nodes are only attached when the nested insertion returns
/// [`InsertResult::Ok`].
///
/// # Panics
///
/// Panics if `partial_key` is empty. The recursion passes an empty remainder
/// when the key starts with `.`.
pub fn insert_recursive(&mut self, partial_key: &[u8], key: &Key, value: V) -> InsertResult {
    assert_ne!(partial_key, &b""[..]);

    if partial_key[partial_key.len() - 1] == b'/' {
        // Everything up to, but excluding, the closing slash.
        let segment = &partial_key[..partial_key.len() - 1];
        let pos = match find_last_slash(segment) {
            Some(pos) => pos,
            None => return InsertResult::Failed,
        };
        if pos > 0 && partial_key[pos - 1] != b'.' {
            return InsertResult::Failed;
        }
        let pattern = match str::from_utf8(&segment[pos + 1..]) {
            Ok(pattern) => pattern,
            Err(_) => return InsertResult::Failed,
        };

        for (regexp, node) in self.regexps.iter_mut() {
            if regexp.as_str() != pattern {
                continue;
            }
            if pos == 0 {
                // The regex is the leftmost label, so the entry belongs on
                // the regex node itself. There is no prefix to recurse into,
                // and `pos - 1` would underflow.
                if node.key_value.is_some() {
                    return InsertResult::Existing;
                }
                node.key_value = Some((key.to_vec(), value));
                return InsertResult::Ok;
            }
            // Skip the '.' that separates the prefix from the regex label.
            return node.insert_recursive(&partial_key[..pos - 1], key, value);
        }

        let regexp = match Regex::new(pattern) {
            Ok(regexp) => regexp,
            Err(_) => return InsertResult::Failed,
        };
        if pos > 0 {
            let mut node = TrieNode::root();
            let res = node.insert_recursive(&partial_key[..pos - 1], key, value);
            if res == InsertResult::Ok {
                self.regexps.push((regexp, node));
            }
            return res;
        }
        self.regexps
            .push((regexp, TrieNode::new(key.to_vec(), value)));
        return InsertResult::Ok;
    }

    match find_last_dot(partial_key) {
        None => {
            if self.children.contains_key(partial_key) {
                InsertResult::Existing
            } else if partial_key == &b"*"[..] {
                if self.wildcard.is_some() {
                    InsertResult::Existing
                } else {
                    self.wildcard = Some((key.to_vec(), value));
                    InsertResult::Ok
                }
            } else {
                let node = TrieNode::new(key.to_vec(), value);
                self.children.insert(partial_key.to_vec(), node);
                InsertResult::Ok
            }
        }
        Some(pos) => {
            // The child label keeps its leading '.'.
            if let Some(child) = self.children.get_mut(&partial_key[pos..]) {
                return child.insert_recursive(&partial_key[..pos], key, value);
            }
            let mut node = TrieNode::root();
            let res = node.insert_recursive(&partial_key[..pos], key, value);
            if res == InsertResult::Ok {
                self.children.insert(partial_key[pos..].to_vec(), node);
            }
            res
        }
    }
}
/// Removes the entry stored under `key`, delegating to `remove_recursive`
/// with the whole key as the remainder.
pub fn remove(&mut self, key: &Key) -> RemoveResult {
    let whole_key: &[u8] = key;
    self.remove_recursive(whole_key)
}
/// Removes the entry addressed by `partial_key`, the part of the key that
/// has not been consumed yet.
///
/// * An empty remainder removes this node's own entry.
/// * `*` removes this node's wildcard entry.
/// * A remainder ending in `/` addresses a regex label `/pattern/`: the
///   entry is removed from the subtree of the regex with the same source
///   pattern. Only that one entry is removed; other keys sharing the regex
///   label are kept.
/// * Otherwise the label from the last `.` onwards selects a child and the
///   rest is removed below it.
///
/// Children and regex nodes that become empty are pruned. Returns
/// [`RemoveResult::NotFound`] when nothing matched.
pub fn remove_recursive(&mut self, partial_key: &[u8]) -> RemoveResult {
    if partial_key.is_empty() {
        return match self.key_value.take() {
            Some(_) => RemoveResult::Ok,
            None => RemoveResult::NotFound,
        };
    }
    if partial_key == &b"*"[..] {
        return match self.wildcard.take() {
            Some(_) => RemoveResult::Ok,
            None => RemoveResult::NotFound,
        };
    }

    if partial_key[partial_key.len() - 1] == b'/' {
        // Everything up to, but excluding, the closing slash.
        let segment = &partial_key[..partial_key.len() - 1];
        let pos = match find_last_slash(segment) {
            Some(pos) => pos,
            None => return RemoveResult::NotFound,
        };
        if pos > 0 && partial_key[pos - 1] != b'.' {
            return RemoveResult::NotFound;
        }
        let pattern = match str::from_utf8(&segment[pos + 1..]) {
            Ok(pattern) => pattern,
            Err(_) => return RemoveResult::NotFound,
        };
        // Labels left of the regex, without the separating '.'. When the
        // regex is the leftmost label the remainder is empty, which makes
        // the nested call clear the regex node's own entry instead of
        // discarding its whole subtree.
        let prefix = if pos > 0 {
            &partial_key[..pos - 1]
        } else {
            &b""[..]
        };

        let mut remove_result = RemoveResult::NotFound;
        for (regexp, node) in self.regexps.iter_mut() {
            if regexp.as_str() == pattern && node.remove_recursive(prefix) == RemoveResult::Ok {
                remove_result = RemoveResult::Ok;
            }
        }
        if remove_result == RemoveResult::Ok {
            // Drop the regex node once nothing is stored below it, so that
            // the parent can become empty as well.
            self.regexps
                .retain(|(regexp, node)| regexp.as_str() != pattern || !node.is_empty());
        }
        return remove_result;
    }

    let (prefix, suffix) = match find_last_dot(partial_key) {
        None => (&b""[..], partial_key),
        Some(pos) => partial_key.split_at(pos),
    };
    let child = match self.children.get_mut(suffix) {
        Some(child) => child,
        None => return RemoveResult::NotFound,
    };
    if child.remove_recursive(prefix) == RemoveResult::NotFound {
        return RemoveResult::NotFound;
    }
    if child.is_empty() {
        self.children.remove(suffix);
    }
    RemoveResult::Ok
}
pub fn lookup(&self, partial_key: &[u8], accept_wildcard: bool) -> Option<&KeyValue<Key, V>> {
if partial_key.is_empty() {
return self.key_value.as_ref();
}
let pos = find_last_dot(partial_key);
let (prefix, suffix) = match pos {
None => (&b""[..], partial_key),
Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
};
match self.children.get(suffix) {
Some(child) => child.lookup(prefix, accept_wildcard),
None => {
if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard {
self.wildcard.as_ref()
} else {
for (ref regexp, ref child) in self.regexps.iter() {
let suffix = if suffix[0] == b'.' {
&suffix[1..]
} else {
suffix
};
if regexp.is_match(suffix) {
return child.lookup(prefix, accept_wildcard);
}
}
None
}
}
}
}
pub fn lookup_mut(
&mut self,
partial_key: &[u8],
accept_wildcard: bool,
) -> Option<&mut KeyValue<Key, V>> {
if partial_key.is_empty() {
return self.key_value.as_mut();
}
if partial_key == &b"*"[..] {
return self.wildcard.as_mut();
}
if partial_key[partial_key.len() - 1] == b'/' {
let pos = find_last_slash(&partial_key[..partial_key.len() - 1]);
if let Some(pos) = pos {
if pos > 0 && partial_key[pos - 1] != b'.' {
return None;
}
if let Ok(s) = str::from_utf8(&partial_key[pos + 1..partial_key.len() - 1]) {
for t in self.regexps.iter_mut() {
if t.0.as_str() == s {
return t.1.lookup_mut(&partial_key[..pos - 1], accept_wildcard);
}
}
}
}
return None;
}
let pos = find_last_dot(partial_key);
let (prefix, suffix) = match pos {
None => (&b""[..], partial_key),
Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
};
match self.children.get_mut(suffix) {
Some(child) => child.lookup_mut(prefix, accept_wildcard),
None => {
if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard {
self.wildcard.as_mut()
} else {
for (ref regexp, ref mut child) in self.regexps.iter_mut() {
let suffix = if suffix[0] == b'.' {
&suffix[1..]
} else {
suffix
};
if regexp.is_match(suffix) {
return child.lookup_mut(prefix, accept_wildcard);
}
}
None
}
}
}
}
/// Dumps the whole trie to stdout, starting at this node with an empty
/// label and no indentation.
pub fn print(&self) {
    let root_label: &[u8] = &[];
    self.print_recursive(root_label, 0);
}
/// Prints this node and, recursively, its children and regex subtrees to
/// stdout.
///
/// Each line is indented by two spaces per level and has the shape
/// `label: entry | wildcard`, where a missing entry or wildcard is printed
/// as `None`. Keys are arbitrary bytes, so they are rendered lossily: byte
/// sequences that are not valid UTF-8 show up as U+FFFD instead of
/// aborting the dump.
pub fn print_recursive(&self, partial_key: &[u8], indent: u8) {
    let padding: String = iter::repeat(' ').take(2 * usize::from(indent)).collect();
    print!("{}{}: ", padding, String::from_utf8_lossy(partial_key));

    match &self.key_value {
        Some((key, value)) => print!("({}, {:?}) | ", String::from_utf8_lossy(key), value),
        None => print!("None | "),
    }
    match &self.wildcard {
        Some((key, value)) => println!("({}, {:?})", String::from_utf8_lossy(key), value),
        None => println!("None"),
    }

    // Saturate rather than overflow on tries deeper than `u8::MAX` levels.
    let child_indent = indent.saturating_add(1);
    for (child_key, child) in self.children.iter() {
        child.print_recursive(child_key, child_indent);
    }
    for (regexp, child) in self.regexps.iter() {
        child.print_recursive(regexp.as_str().as_bytes(), child_indent);
    }
}
pub fn domain_insert(&mut self, key: Key, value: V) -> InsertResult {
self.insert(key, value)
}
pub fn domain_remove(&mut self, key: &Key) -> RemoveResult {
self.remove(key)
}
pub fn domain_lookup(&self, key: &[u8], accept_wildcard: bool) -> Option<&KeyValue<Key, V>> {
self.lookup(key, accept_wildcard)
}
pub fn domain_lookup_mut(
&mut self,
key: &[u8],
accept_wildcard: bool,
) -> Option<&mut KeyValue<Key, V>> {
self.lookup_mut(key, accept_wildcard)
}
/// Rough estimate of the memory taken by this subtree.
///
/// Per node it adds the size of the node type, twice the size of an
/// optional entry, and the length of every child label and regex pattern,
/// then recurses into the child and regex subtrees.
pub fn size(&self) -> usize {
    // NOTE(review): the two `Option<KeyValue>` slots are already part of
    // `size_of::<TrieNode<V>>()`; the extra term is kept as-is so the
    // estimate for tries without regex labels does not change.
    let own = ::std::mem::size_of::<TrieNode<V>>()
        + ::std::mem::size_of::<Option<KeyValue<Key, V>>>() * 2;
    let below_children: usize = self
        .children
        .iter()
        .map(|(label, child)| label.len() + child.size())
        .sum();
    // Regex subtrees hold entries too and must be part of the estimate.
    let below_regexps: usize = self
        .regexps
        .iter()
        .map(|(regexp, child)| regexp.as_str().len() + child.size())
        .sum();
    own + below_children + below_regexps
}
/// Flattens the trie into a map from full key to value, as collected by
/// `to_hashmap_recursive`.
pub fn to_hashmap(&self) -> HashMap<Key, V> {
    let mut entries = HashMap::new();
    self.to_hashmap_recursive(&mut entries);
    entries
}
/// Copies every entry of this subtree into `h`, keyed by the full key the
/// entry was inserted with.
///
/// Covers this node's own entry and wildcard entry, then the subtrees of
/// all children and of all regex labels.
pub fn to_hashmap_recursive(&self, h: &mut HashMap<Key, V>) {
    let own_entries = self.key_value.iter().chain(self.wildcard.iter());
    for (key, value) in own_entries {
        h.insert(key.clone(), value.clone());
    }
    for child in self.children.values() {
        child.to_hashmap_recursive(h);
    }
    // Keys containing a `/regex/` label live below `regexps`, not `children`.
    for (_, child) in self.regexps.iter() {
        child.to_hashmap_recursive(h);
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
// Three plain keys can be inserted and one of them is found again.
#[test]
fn insert() {
    let mut root: TrieNode<u8> = TrieNode::root();
    root.print();

    let entries = [(&b"abcd"[..], 1u8), (&b"abce"[..], 2), (&b"abgh"[..], 3)];
    for (key, value) in entries {
        assert_eq!(root.domain_insert(key.to_vec(), value), InsertResult::Ok);
        root.print();
    }

    let expected = (b"abce".to_vec(), 2u8);
    assert_eq!(root.domain_lookup(b"abce", true), Some(&expected));
}
// Removing keys leaves a trie equal to one built without them.
#[test]
fn remove() {
    let mut root: TrieNode<u8> = TrieNode::root();
    println!("creating root:");
    root.print();

    println!("adding (abcd, 1)");
    assert_eq!(root.insert(b"abcd".to_vec(), 1), InsertResult::Ok);
    root.print();

    println!("adding (abce, 2)");
    assert_eq!(root.insert(b"abce".to_vec(), 2), InsertResult::Ok);
    root.print();

    println!("adding (abgh, 3)");
    assert_eq!(root.insert(b"abgh".to_vec(), 3), InsertResult::Ok);
    root.print();

    // Reference trie that never contained "abce".
    let mut root2: TrieNode<u8> = TrieNode::root();
    assert_eq!(root2.insert(b"abcd".to_vec(), 1), InsertResult::Ok);
    assert_eq!(root2.insert(b"abgh".to_vec(), 3), InsertResult::Ok);

    println!("before remove");
    root.print();
    assert_eq!(root.remove(&b"abce".to_vec()), RemoveResult::Ok);
    println!("after remove");
    root.print();
    println!("expected");
    root2.print();
    assert_eq!(root, root2);

    assert_eq!(root.remove(&b"abgh".to_vec()), RemoveResult::Ok);
    println!("after remove");
    root.print();
    println!("expected");
    // Reference trie that only ever contained "abcd".
    let mut root3: TrieNode<u8> = TrieNode::root();
    assert_eq!(root3.insert(b"abcd".to_vec(), 1), InsertResult::Ok);
    root3.print();
    assert_eq!(root, root3);
}
// Two keys share a regex label; removing one must leave the other reachable.
#[test]
fn insert_remove_through_regex() {
    let short_key = b"www./.*/.com".to_vec();
    let long_key = b"www.doc./.*/.com".to_vec();
    let short_entry = (short_key.clone(), 1u8);
    let long_entry = (long_key.clone(), 2u8);

    let mut root: TrieNode<u8> = TrieNode::root();
    println!("creating root:");
    root.print();

    println!("adding (www./.*/.com, 1)");
    assert_eq!(root.insert(short_key.clone(), 1), InsertResult::Ok);
    root.print();

    println!("adding (www.doc./.*/.com, 2)");
    assert_eq!(root.insert(long_key, 2), InsertResult::Ok);
    root.print();

    assert_eq!(root.domain_lookup(b"www.sozu.com", false), Some(&short_entry));
    assert_eq!(
        root.domain_lookup(b"www.doc.sozu.com", false),
        Some(&long_entry)
    );

    assert_eq!(root.domain_remove(&short_key), RemoveResult::Ok);
    root.print();

    assert_eq!(root.domain_lookup(b"www.sozu.com", false), None);
    assert_eq!(
        root.domain_lookup(b"www.doc.sozu.com", false),
        Some(&long_entry)
    );
}
// Insertion order does not matter, and removing "abc" leaves a trie equal
// to one that only ever held "abcd" and "abce".
#[test]
fn add_child_to_leaf() {
    let mut root1: TrieNode<u8> = TrieNode::root();
    println!("creating root1:");
    root1.print();

    println!("adding (abcd, 1)");
    assert_eq!(root1.insert(b"abcd".to_vec(), 1), InsertResult::Ok);
    root1.print();

    println!("adding (abce, 2)");
    assert_eq!(root1.insert(b"abce".to_vec(), 2), InsertResult::Ok);
    root1.print();

    println!("adding (abc, 3)");
    assert_eq!(root1.insert(b"abc".to_vec(), 3), InsertResult::Ok);
    println!("root1:");
    root1.print();

    let mut root2: TrieNode<u8> = TrieNode::root();
    assert_eq!(root2.insert(b"abc".to_vec(), 3), InsertResult::Ok);
    assert_eq!(root2.insert(b"abcd".to_vec(), 1), InsertResult::Ok);
    assert_eq!(root2.insert(b"abce".to_vec(), 2), InsertResult::Ok);
    println!("root2:");
    root2.print();

    assert_eq!(root2.remove(&b"abc".to_vec()), RemoveResult::Ok);
    println!("root2 after,remove:");
    root2.print();

    let mut expected: TrieNode<u8> = TrieNode::root();
    assert_eq!(expected.insert(b"abcd".to_vec(), 1), InsertResult::Ok);
    assert_eq!(expected.insert(b"abce".to_vec(), 2), InsertResult::Ok);

    println!("root2 after insert");
    root2.print();
    println!("expected");
    expected.print();
    assert_eq!(root2, expected);
}
// Exercises exact, wildcard and regex domains together, before and after
// removing one exact entry.
#[test]
fn domains() {
    // Inserts `key` and requires the insertion to succeed.
    fn add(root: &mut TrieNode<u8>, key: &[u8], value: u8) {
        assert_eq!(root.domain_insert(key.to_vec(), value), InsertResult::Ok);
    }
    // Requires `query` to resolve to the entry `(key, value)`.
    fn hit(root: &TrieNode<u8>, query: &[u8], key: &[u8], value: u8) {
        assert_eq!(
            root.domain_lookup(query, true),
            Some(&(key.to_vec(), value))
        );
    }
    // Requires `query` to resolve to nothing.
    fn miss(root: &TrieNode<u8>, query: &[u8]) {
        assert_eq!(root.domain_lookup(query, true), None);
    }

    let mut root: TrieNode<u8> = TrieNode::root();
    root.print();

    add(&mut root, b"www.example.com", 1);
    root.print();
    add(&mut root, b"test.example.com", 2);
    root.print();
    add(&mut root, b"*.alldomains.org", 3);
    root.print();
    add(&mut root, b"alldomains.org", 4);
    add(&mut root, b"pouet.alldomains.org", 5);
    root.print();
    add(&mut root, b"hello.com", 6);
    add(&mut root, b"*.hello.com", 7);
    add(&mut root, b"images./cdn[0-9]+/.hello.com", 8);
    root.print();
    add(&mut root, b"/test[0-9]+/.www.hello.com", 9);
    root.print();

    miss(&root, b"example.com");
    miss(&root, b"blah.test.example.com");
    hit(&root, b"www.example.com", b"www.example.com", 1);
    hit(&root, b"alldomains.org", b"alldomains.org", 4);
    hit(&root, b"test.hello.com", b"*.hello.com", 7);
    hit(
        &root,
        b"images.cdn10.hello.com",
        b"images./cdn[0-9]+/.hello.com",
        8,
    );
    hit(
        &root,
        b"test42.www.hello.com",
        b"/test[0-9]+/.www.hello.com",
        9,
    );
    hit(&root, b"test.alldomains.org", b"*.alldomains.org", 3);
    hit(&root, b"hello.alldomains.org", b"*.alldomains.org", 3);
    hit(&root, b"pouet.alldomains.org", b"pouet.alldomains.org", 5);
    miss(&root, b"blah.test.alldomains.org");

    assert_eq!(
        root.domain_remove(&b"alldomains.org".to_vec()),
        RemoveResult::Ok
    );
    println!("after remove");
    root.print();

    miss(&root, b"alldomains.org");
    hit(&root, b"test.alldomains.org", b"*.alldomains.org", 3);
    hit(&root, b"hello.alldomains.org", b"*.alldomains.org", 3);
    hit(&root, b"pouet.alldomains.org", b"pouet.alldomains.org", 5);
    hit(&root, b"test.hello.com", b"*.hello.com", 7);
    miss(&root, b"blah.test.alldomains.org");
}
// A wildcard below an exact entry wins over the wildcard one level up.
#[test]
fn wildcard() {
    let mut root: TrieNode<u8> = TrieNode::root();
    root.print();

    // Check every setup insertion, so that a broken fixture is reported as
    // such rather than as a confusing lookup mismatch further down.
    let fixtures = [
        ("*.clever-cloud.com", 2u8),
        ("services.clever-cloud.com", 0u8),
        ("*.services.clever-cloud.com", 1u8),
    ];
    for (domain, value) in fixtures {
        assert_eq!(
            root.domain_insert(domain.as_bytes().to_vec(), value),
            InsertResult::Ok,
            "could not insert ({domain}, {value})"
        );
    }

    let res = root.domain_lookup(b"test.services.clever-cloud.com", true);
    println!("query result: {res:?}");
    assert_eq!(
        root.domain_lookup(b"pgstudio.services.clever-cloud.com", true),
        Some(&("*.services.clever-cloud.com".as_bytes().to_vec(), 1u8))
    );
}
// Inserts every usable key of `h` into a fresh trie, then checks that each
// one is found again with its own key and value. Returns `false` (after
// printing the reason) on the first mismatch.
fn hm_insert(h: std::collections::HashMap<String, u32>) -> bool {
    // Keys the check leaves out: empty, starting with '.', containing '/',
    // or the bare wildcard.
    fn skipped(k: &str) -> bool {
        k.is_empty() || k.starts_with('.') || k.contains('/') || k == "*"
    }

    let mut root: TrieNode<u32> = TrieNode::root();
    for (k, v) in h.iter().filter(|(k, _)| !skipped(k.as_str())) {
        assert_eq!(
            root.insert(k.as_bytes().to_vec(), *v),
            InsertResult::Ok,
            "could not insert ({k}, {v})"
        );
    }

    for (k, v) in h.iter().filter(|(k, _)| !skipped(k.as_str())) {
        let (k1, v1) = match root.lookup(k.as_bytes(), false) {
            Some(found) => found,
            None => {
                println!("did not find key '{k}'");
                return false;
            }
        };
        if k.as_bytes() != &k1[..] || v != v1 {
            println!(
                "request ({}, {}), got ({}, {})",
                k,
                v,
                str::from_utf8(&k1[..]).unwrap(),
                v1
            );
            return false;
        }
    }
    true
}
// Keys that are prefixes of one another must all stay retrievable.
#[test]
fn insert_disappearing_tree() {
    let mut h: std::collections::HashMap<String, u32> = std::collections::HashMap::new();
    h.insert(String::from("\n\u{3}"), 0);
    h.insert(String::from("\n\u{0}"), 1);
    h.insert(String::from("\n"), 2);
    assert!(hm_insert(h));
}
// Pins the size of `TrieNode<u32>` to 136 (presumably bytes) through the
// `assert_size!` macro, which is not defined in the visible part of this file.
#[test]
fn size() {
assert_size!(TrieNode<u32>, 136);
}
}