use hyper::{
http::{HeaderName, HeaderValue},
Uri,
};
use regex::Regex;
use crate::{utils::query_kv::QueryKvIter, BoxError, Request, SgBody};
/// A rule for matching a request path that can also carry a replacement for rewriting.
///
/// In every variant the first field is the pattern the path is matched against and the
/// second field is the optional replacement, which `replace_with` sets.
#[derive(Debug, Clone)]
pub enum HttpPathMatchRewrite {
/// Matches the whole request path against the given string.
Exact(String, Option<String>),
/// Matches the leading `/`-separated segments of the request path (ASCII case-insensitive).
Prefix(String, Option<String>),
/// Matches the request path with a regular expression.
RegExp(Regex, Option<String>),
}
impl HttpPathMatchRewrite {
    /// Builds a prefix rule for `s` with no replacement configured.
    pub fn prefix<S: Into<String>>(s: S) -> Self {
        Self::Prefix(s.into(), None)
    }

    /// Builds an exact rule for `s` with no replacement configured.
    pub fn exact<S: Into<String>>(s: S) -> Self {
        Self::Exact(s.into(), None)
    }

    /// Builds a regular-expression rule for `re` with no replacement configured.
    pub fn regex(re: Regex) -> Self {
        Self::RegExp(re, None)
    }

    /// Returns the same rule with its replacement set to `replace`.
    ///
    /// Any replacement configured earlier is overwritten.
    pub fn replace_with(self, replace: impl Into<String>) -> Self {
        let replace = Some(replace.into());
        match self {
            Self::Exact(path, _) => Self::Exact(path, replace),
            Self::Prefix(path, _) => Self::Prefix(path, replace),
            Self::RegExp(re, _) => Self::RegExp(re, replace),
        }
    }

    /// Computes the rewritten form of `path`.
    ///
    /// Returns `None` when no replacement is configured, or when an `Exact` / `Prefix`
    /// rule does not match `path`.
    ///
    /// * `Exact`: when `path` equals the configured pattern (ASCII case-insensitive), the
    ///   result is the replacement as-is.
    /// * `Prefix`: the matched leading segments are swapped for the replacement; the
    ///   remaining segments are kept, and a trailing `/` on `path` is preserved when
    ///   segments remain after the prefix. The result always starts with `/`.
    /// * `RegExp`: the result is `Regex::replace` applied to `path`, so it is `Some`
    ///   even when the expression does not match (the path is then returned unchanged).
    pub fn rewrite(&self, path: &str) -> Option<String> {
        match self {
            // Compare the incoming path with the *match pattern*; comparing it with the
            // replacement would only ever "rewrite" paths that were already rewritten.
            Self::Exact(exact, Some(replace)) => exact.eq_ignore_ascii_case(path).then(|| replace.clone()),
            Self::Prefix(prefix, Some(replace)) => {
                fn not_empty(s: &&str) -> bool {
                    !s.is_empty()
                }
                // Compare segment by segment so that duplicate or trailing slashes are ignored.
                let mut path_segments = path.split('/').filter(not_empty);
                let mut prefix_segments = prefix.split('/').filter(not_empty);
                loop {
                    match (path_segments.next(), prefix_segments.next()) {
                        (Some(path_seg), Some(prefix_seg)) => {
                            if !path_seg.eq_ignore_ascii_case(prefix_seg) {
                                return None;
                            }
                        }
                        // The path is exactly the prefix: the result is the replacement alone.
                        (None, None) => {
                            let mut new_path = String::from("/");
                            new_path.push_str(replace.trim_start_matches('/'));
                            return Some(new_path);
                        }
                        // The prefix is exhausted: append what is left of the path.
                        (Some(rest_path), None) => {
                            let mut new_path = String::from("/");
                            let replace_value = replace.trim_matches('/');
                            new_path.push_str(replace_value);
                            // An empty replacement must not produce a leading "//".
                            if !replace_value.is_empty() {
                                new_path.push('/');
                            }
                            new_path.push_str(rest_path);
                            for seg in path_segments {
                                new_path.push('/');
                                new_path.push_str(seg);
                            }
                            if path.ends_with('/') {
                                new_path.push('/');
                            }
                            return Some(new_path);
                        }
                        // The path is shorter than the prefix: no match.
                        (None, Some(_)) => return None,
                    }
                }
            }
            Self::RegExp(re, Some(replace)) => Some(re.replace(path, replace).into_owned()),
            Self::Exact(_, None) | Self::Prefix(_, None) | Self::RegExp(_, None) => None,
        }
    }
}
/// How a single header value is matched and, optionally, what it is rewritten to.
///
/// The second field of each variant is the replacement; the `regex` and `exact`
/// constructors of `SgHttpHeaderMatchRewrite` leave it as `None`.
#[derive(Debug, Clone)]
pub enum SgHttpHeaderMatchRewritePolicy {
/// Matches when the header value equals the given value.
Exact(HeaderValue, Option<HeaderValue>),
/// Matches when the regular expression matches the header value read as a string.
Regular(Regex, Option<String>),
}
/// A match (and optional rewrite) rule bound to one named request header.
#[derive(Debug, Clone)]
pub struct SgHttpHeaderMatchRewrite {
/// Name of the header this rule looks up on the request.
pub header_name: HeaderName,
/// How the header value is matched and what it may be replaced with.
pub policy: SgHttpHeaderMatchRewritePolicy,
}
impl SgHttpHeaderMatchRewrite {
    /// Builds a rule matching header `name` with the regular expression `re`.
    ///
    /// No replacement is configured.
    pub fn regex(name: impl Into<HeaderName>, re: Regex) -> Self {
        Self {
            header_name: name.into(),
            policy: SgHttpHeaderMatchRewritePolicy::Regular(re, None),
        }
    }

    /// Builds a rule matching header `name` against exactly `value`.
    ///
    /// No replacement is configured.
    pub fn exact(name: impl Into<HeaderName>, value: impl Into<HeaderValue>) -> Self {
        Self {
            header_name: name.into(),
            policy: SgHttpHeaderMatchRewritePolicy::Exact(value.into(), None),
        }
    }

    /// Computes the replacement value for this rule's header on `req`.
    ///
    /// Only the value returned by `HeaderMap::get` for the header is considered.
    /// Returns `None` when the header is absent, when no replacement is configured,
    /// when the current value does not match the rule, or (regex rules only) when the
    /// current value is not readable as a string or the replacement is not a valid
    /// header value.
    pub fn rewrite(&self, req: &Request<SgBody>) -> Option<HeaderValue> {
        let header_value = req.headers().get(&self.header_name)?;
        match &self.policy {
            // Compare the current value with the configured *match* value, the same
            // comparison `match_request` performs. Comparing with the replacement
            // would never rewrite a header that actually matched.
            SgHttpHeaderMatchRewritePolicy::Exact(expected, Some(replace)) => (header_value == expected).then(|| replace.clone()),
            SgHttpHeaderMatchRewritePolicy::Regular(re, Some(replace)) => {
                // The regex can only be applied to a value readable as a string.
                let s = header_value.to_str().ok()?;
                if re.is_match(s) {
                    HeaderValue::from_str(replace).ok()
                } else {
                    None
                }
            }
            SgHttpHeaderMatchRewritePolicy::Exact(_, None) | SgHttpHeaderMatchRewritePolicy::Regular(_, None) => None,
        }
    }
}
/// How the value of a query parameter is matched.
#[derive(Debug, Clone)]
pub enum SgHttpQueryMatchPolicy {
/// Matches when the parameter value equals the given string.
Exact(String),
/// Matches when the regular expression matches the parameter value.
Regular(Regex),
}
/// A match rule for one named query parameter.
#[derive(Debug, Clone)]
pub struct HttpQueryMatch {
/// Name of the query parameter to look for.
pub name: String,
/// How the value of that parameter is matched.
pub policy: SgHttpQueryMatchPolicy,
}
/// An HTTP method name; it is compared with the request method ignoring ASCII case.
#[derive(Default, Debug, Clone)]
pub struct HttpMethodMatch(pub String);
/// The combined match conditions of a route.
///
/// A request matches when every field that is set matches. A field left as `None`
/// places no constraint on the request, and a list matches when any one of its
/// entries matches.
#[derive(Default, Debug, Clone)]
pub struct HttpRouteMatch {
/// Path condition, which may also carry a path rewrite.
pub path: Option<HttpPathMatchRewrite>,
/// Header conditions, which may also carry header rewrites.
pub header: Option<Vec<SgHttpHeaderMatchRewrite>>,
/// Query parameter conditions.
pub query: Option<Vec<HttpQueryMatch>>,
/// Accepted HTTP methods.
pub method: Option<Vec<HttpMethodMatch>>,
}
impl HttpRouteMatch {
    /// Applies the configured header and path rewrites to `req` in place.
    ///
    /// Each header rule that yields a replacement overwrites the value currently stored
    /// under its header name. When the path rule yields a new path, the request URI is
    /// rebuilt with that path and the original query string.
    ///
    /// # Errors
    ///
    /// Returns an error when the new path and query cannot be parsed, or when the URI
    /// cannot be reassembled from its parts.
    pub fn rewrite(&self, req: &mut Request<SgBody>) -> Result<(), BoxError> {
        for header_match in self.header.iter().flatten() {
            let replacement = header_match.rewrite(req);
            let current = req.headers_mut().get_mut(&header_match.header_name);
            if let (Some(replace), Some(value)) = (replacement, current) {
                *value = replace;
            }
        }

        // Nothing left to do without both a path on the URI and a path rule.
        let (Some(pq), Some(path_match)) = (req.uri().path_and_query(), self.path.as_ref()) else {
            return Ok(());
        };
        let old_path = pq.path();
        let Some(new_path) = path_match.rewrite(old_path) else {
            return Ok(());
        };
        tracing::debug!("[Sg.Rewrite] rewrite path from {} to {}", old_path, new_path);

        // Carry the original query string over to the rewritten path.
        let new_pq = match pq.query() {
            Some(query) => format!("{new_path}?{query}"),
            None => new_path,
        };
        let mut uri_part = req.uri().clone().into_parts();
        uri_part.path_and_query = Some(hyper::http::uri::PathAndQuery::from_maybe_shared(new_pq)?);
        *req.uri_mut() = Uri::from_parts(uri_part)?;
        Ok(())
    }
}
/// A condition that can be tested against an incoming request.
pub trait MatchRequest {
/// Returns `true` when `req` satisfies this condition.
fn match_request(&self, req: &Request<SgBody>) -> bool;
}
impl MatchRequest for HttpQueryMatch {
    /// Returns `true` when some pair in the request's query string has this rule's
    /// name and a value accepted by its policy.
    ///
    /// A request without a query string never matches, and neither does a pair that
    /// has no value.
    fn match_request(&self, req: &Request<SgBody>) -> bool {
        let Some(raw_query) = req.uri().query() else {
            return false;
        };
        let mut pairs = QueryKvIter::new(raw_query);
        match &self.policy {
            SgHttpQueryMatchPolicy::Exact(expected) => pairs.any(|(k, v)| k == self.name && v == Some(expected)),
            SgHttpQueryMatchPolicy::Regular(re) => pairs.any(|(k, v)| k == self.name && v.map_or(false, |v| re.is_match(v))),
        }
    }
}
impl From<HttpPathMatchRewrite> for HttpRouteMatch {
fn from(val: HttpPathMatchRewrite) -> Self {
HttpRouteMatch {
path: Some(val),
header: None,
query: None,
method: None,
}
}
}
impl From<SgHttpHeaderMatchRewrite> for HttpRouteMatch {
fn from(value: SgHttpHeaderMatchRewrite) -> Self {
HttpRouteMatch {
path: None,
header: Some(vec![value]),
query: None,
method: None,
}
}
}
impl From<HttpQueryMatch> for HttpRouteMatch {
fn from(value: HttpQueryMatch) -> Self {
HttpRouteMatch {
path: None,
header: None,
query: Some(vec![value]),
method: None,
}
}
}
impl From<HttpMethodMatch> for HttpRouteMatch {
fn from(value: HttpMethodMatch) -> Self {
HttpRouteMatch {
path: None,
header: None,
query: None,
method: Some(vec![value]),
}
}
}
impl MatchRequest for HttpPathMatchRewrite {
    /// Tests the request path against this rule.
    ///
    /// * `Exact` requires the path to equal the pattern.
    /// * `Prefix` requires every non-empty `/`-separated segment of the pattern to equal
    ///   the corresponding request segment, ignoring ASCII case; extra request segments
    ///   are allowed.
    /// * `RegExp` requires the expression to match the path.
    fn match_request(&self, req: &Request<SgBody>) -> bool {
        let request_path = req.uri().path();
        match self {
            Self::Exact(expected, _) => request_path == expected,
            Self::Prefix(prefix, _) => {
                let mut request_segments = request_path.split('/').filter(|s| !s.is_empty());
                // Each prefix segment consumes one request segment; running out of
                // request segments first means the path is shorter than the prefix.
                prefix
                    .split('/')
                    .filter(|s| !s.is_empty())
                    .all(|prefix_seg| matches!(request_segments.next(), Some(seg) if seg.eq_ignore_ascii_case(prefix_seg)))
            }
            Self::RegExp(re, _) => re.is_match(request_path),
        }
    }
}
impl MatchRequest for SgHttpHeaderMatchRewrite {
    /// Tests the request headers against this rule.
    ///
    /// An exact rule compares the value returned by `HeaderMap::get` for the header
    /// name with the configured value. A regex rule matches when any value stored under
    /// the header name is readable as a string and matches the expression.
    fn match_request(&self, req: &Request<SgBody>) -> bool {
        let headers = req.headers();
        match &self.policy {
            SgHttpHeaderMatchRewritePolicy::Exact(expected, _) => match headers.get(&self.header_name) {
                Some(actual) => actual == expected,
                None => false,
            },
            SgHttpHeaderMatchRewritePolicy::Regular(re, _) => headers
                .get_all(&self.header_name)
                .iter()
                .any(|value| value.to_str().map_or(false, |text| re.is_match(text))),
        }
    }
}
impl MatchRequest for HttpMethodMatch {
    /// Returns `true` when the request method equals the stored name, ignoring ASCII case.
    fn match_request(&self, req: &Request<SgBody>) -> bool {
        let Self(expected) = self;
        let actual = req.method().as_str();
        expected.eq_ignore_ascii_case(actual)
    }
}
impl MatchRequest for HttpRouteMatch {
    /// Returns `true` when the path, header, query and method conditions all accept `req`.
    ///
    /// The conditions are checked in that order and checking stops at the first failure.
    fn match_request(&self, req: &Request<SgBody>) -> bool {
        let Self { path, header, query, method } = self;
        path.match_request(req) && header.match_request(req) && query.match_request(req) && method.match_request(req)
    }
}
impl<T> MatchRequest for Option<T>
where
    T: MatchRequest,
{
    /// An absent condition accepts every request; a present one decides by itself.
    fn match_request(&self, req: &Request<SgBody>) -> bool {
        match self {
            Some(inner) => T::match_request(inner, req),
            None => true,
        }
    }
}
impl<T> MatchRequest for Vec<T>
where
    T: MatchRequest,
{
    /// A list accepts the request when at least one entry does; an empty list accepts nothing.
    fn match_request(&self, req: &Request<SgBody>) -> bool {
        for candidate in self {
            if candidate.match_request(req) {
                return true;
            }
        }
        false
    }
}
/// A prefix rule equal to the full request path must match that request.
#[test]
fn test_match_path() {
    let matcher = HttpPathMatchRewrite::prefix("/child/subApp");
    let req = Request::builder()
        .uri("https://localhost:8080/child/subApp")
        .body(SgBody::empty())
        .expect("invalid request");
    assert!(matcher.match_request(&req));
}