use crate::context::Push;
use futures::future::FutureExt;
use hyper::header::AUTHORIZATION;
use hyper::service::Service;
use hyper::{HeaderMap, Request};
pub use hyper_old_types::header::Authorization as Header;
use hyper_old_types::header::Header as HeaderTrait;
pub use hyper_old_types::header::{Basic, Bearer};
use hyper_old_types::header::{Raw, Scheme};
use std::collections::BTreeSet;
use std::marker::PhantomData;
use std::string::ToString;
use std::task::Context;
use std::task::Poll;
/// The scopes granted to an authenticated caller.
///
/// Derives `Eq` in addition to `PartialEq`: every payload (`BTreeSet<String>`)
/// has total equality, so the stronger guarantee is free and lets the type
/// satisfy `Eq` bounds.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Scopes {
    /// Only the listed scopes are granted.
    Some(BTreeSet<String>),
    /// No scope restriction; this is what the allow-all authenticator in this
    /// module attaches to every request.
    All,
}

/// The result of authenticating a request, as placed into the request context.
///
/// Derives `Eq` for the same reason as [`Scopes`]: all fields have total
/// equality.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Authorization {
    /// Identity of the authenticated caller.
    pub subject: String,
    /// Scopes granted to `subject`.
    pub scopes: Scopes,
    /// Issuer of the credentials, if known (presumably e.g. a token issuer —
    /// confirm against producers); the allow-all authenticator leaves it `None`.
    pub issuer: Option<String>,
}
/// Credentials supplied with a request, in one of the supported forms.
#[derive(Clone, Debug, PartialEq)]
pub enum AuthData {
    /// HTTP Basic credentials (username and optional password).
    Basic(Basic),
    /// HTTP Bearer token.
    Bearer(Bearer),
    /// A raw API key string.
    ApiKey(String),
}
impl AuthData {
    /// Builds [`AuthData::Basic`] from a username and password.
    ///
    /// The password is always stored as `Some(..)`, even when it is empty.
    pub fn basic(username: &str, password: &str) -> Self {
        let credentials = Basic {
            username: String::from(username),
            password: Some(String::from(password)),
        };
        Self::Basic(credentials)
    }

    /// Builds [`AuthData::Bearer`] from a bearer token.
    pub fn bearer(token: &str) -> Self {
        let token = String::from(token);
        Self::Bearer(Bearer { token })
    }

    /// Builds [`AuthData::ApiKey`] from an API key string.
    pub fn apikey(apikey: &str) -> Self {
        let key = String::from(apikey);
        Self::ApiKey(key)
    }
}
/// Requirements on a request-context type used by the authenticators here:
/// an `Option<Authorization>` can be pushed onto it, and it is `Send + 'static`.
///
/// This is a trait alias in all but name; it is implemented automatically for
/// every type meeting those bounds.
pub trait RcBound: Push<Option<Authorization>> + Send + 'static {}

impl<C: Push<Option<Authorization>> + Send + 'static> RcBound for C {}
/// Service factory that wraps another factory so that every service it
/// produces is an [`AllowAllAuthenticator`] for a fixed subject.
///
/// `RC` is the request-context type accepted by the produced services; it is
/// tracked only at the type level via `PhantomData`.
#[derive(Debug)]
pub struct MakeAllowAllAuthenticator<T, RC>
where
    RC: RcBound,
    RC::Result: Send + 'static,
{
    // The wrapped service factory.
    inner: T,
    // Subject copied into every produced `AllowAllAuthenticator`.
    subject: String,
    marker: PhantomData<RC>,
}
impl<T, RC> MakeAllowAllAuthenticator<T, RC>
where
    RC: RcBound,
    RC::Result: Send + 'static,
{
    /// Wraps the service factory `inner`.
    ///
    /// Every service produced through this factory attaches an `Authorization`
    /// for `subject` to each request it handles.
    pub fn new<U: Into<String>>(inner: T, subject: U) -> Self {
        let subject: String = subject.into();
        Self {
            inner,
            subject,
            marker: PhantomData,
        }
    }
}
impl<Inner, RC, Target> Service<Target> for MakeAllowAllAuthenticator<Inner, RC>
where
    RC: RcBound,
    RC::Result: Send + 'static,
    Inner: Service<Target>,
    Inner::Future: Send + 'static,
{
    type Error = Inner::Error;
    type Response = AllowAllAuthenticator<Inner::Response, RC>;
    type Future = futures::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Readiness is delegated unchanged to the wrapped factory.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Asks the wrapped factory for a service and, on success, wraps it in an
    /// `AllowAllAuthenticator` carrying this factory's subject. Errors from
    /// the wrapped factory are passed through untouched.
    fn call(&mut self, target: Target) -> Self::Future {
        let subject = self.subject.clone();
        let pending = self.inner.call(target);
        let wrapped = pending.map(move |outcome| {
            outcome.map(|service| AllowAllAuthenticator::<_, RC>::new(service, subject))
        });
        Box::pin(wrapped)
    }
}
/// Middleware service that treats every request as authenticated.
///
/// For each `(request, context)` pair it pushes `Some(Authorization)` onto the
/// context, with the configured subject, `Scopes::All` and no issuer. It then
/// forwards the pair to `inner`. The request itself is never inspected, so no
/// credentials are checked.
#[derive(Debug)]
pub struct AllowAllAuthenticator<T, RC>
where
    RC: RcBound,
    RC::Result: Send + 'static,
{
    // The wrapped service that receives the augmented context.
    inner: T,
    // Subject recorded in every pushed `Authorization`.
    subject: String,
    marker: PhantomData<RC>,
}
impl<T, RC> AllowAllAuthenticator<T, RC>
where
    RC: RcBound,
    RC::Result: Send + 'static,
{
    /// Wraps `inner` so that every request it receives is marked as
    /// authenticated for `subject` with all scopes.
    pub fn new<U: Into<String>>(inner: T, subject: U) -> Self {
        let subject: String = subject.into();
        Self {
            inner,
            subject,
            marker: PhantomData,
        }
    }
}
/// Hand-written so that only `T: Clone` is required.
///
/// `#[derive(Clone)]` would also demand `RC: Clone`, even though `RC` is only
/// held as `PhantomData`.
impl<T, RC> Clone for AllowAllAuthenticator<T, RC>
where
    T: Clone,
    RC: RcBound,
    RC::Result: Send + 'static,
{
    fn clone(&self) -> Self {
        // Arguments are evaluated left to right: inner first, then subject.
        Self::new(self.inner.clone(), self.subject.clone())
    }
}
impl<T, B, RC> Service<(Request<B>, RC)> for AllowAllAuthenticator<T, RC>
where
    RC: RcBound,
    RC::Result: Send + 'static,
    T: Service<(Request<B>, RC::Result)>,
{
    type Response = T::Response;
    type Error = T::Error;
    type Future = T::Future;

    /// Readiness is delegated unchanged to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Pushes an all-scopes `Authorization` for the configured subject onto the
    /// context and hands the request and augmented context to the inner service.
    fn call(&mut self, (request, context): (Request<B>, RC)) -> Self::Future {
        let granted = Authorization {
            subject: self.subject.clone(),
            scopes: Scopes::All,
            issuer: None,
        };
        let augmented = context.push(Some(granted));
        self.inner.call((request, augmented))
    }
}
/// Extracts credentials of scheme `S` (e.g. [`Basic`] or [`Bearer`]) from the
/// `Authorization` header.
///
/// Returns `None` in any of these cases:
/// - the header is absent;
/// - the header is not valid visible-ASCII text;
/// - the header fails to parse as `S`, including when it uses another scheme.
pub fn from_headers<S: Scheme>(headers: &HeaderMap) -> Option<S>
where
    S: std::str::FromStr + 'static,
    S::Err: 'static,
{
    let text = headers.get(AUTHORIZATION)?.to_str().ok()?;
    let parsed = Header::<S>::parse_header(&Raw::from(text)).ok()?;
    Some(parsed.0)
}
/// Reads an API key from the header named `header`.
///
/// Returns `None` if the header is absent or its value is not valid
/// visible-ASCII text.
pub fn api_key_from_header(headers: &HeaderMap, header: &str) -> Option<String> {
    let raw = headers.get(header)?;
    raw.to_str().ok().map(ToString::to_string)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{ContextBuilder, Has};
    use crate::EmptyContext;
    use hyper::header::HeaderValue;
    use hyper::service::Service;
    use hyper::{Body, Response};

    /// Service factory that hands out a `TestService` for every target.
    struct MakeTestService;

    /// Request shape seen by `TestService`: the HTTP request plus an
    /// `EmptyContext` with an `Option<Authorization>` pushed onto it.
    type ReqWithAuth = (
        Request<Body>,
        ContextBuilder<Option<Authorization>, EmptyContext>,
    );

    impl<Target> Service<Target> for MakeTestService {
        type Response = TestService;
        type Error = ();
        type Future = futures::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _target: Target) -> Self::Future {
            futures::future::ok(TestService)
        }
    }

    /// Inner service that succeeds only if the context carries the all-scopes,
    /// issuer-less authorization for subject `"foo"`; otherwise it errors with
    /// both values formatted.
    struct TestService;

    impl Service<ReqWithAuth> for TestService {
        type Response = Response<Body>;
        type Error = String;
        type Future = futures::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: ReqWithAuth) -> Self::Future {
            Box::pin(async move {
                let auth: &Option<Authorization> = req.1.get();
                let expected = Some(Authorization {
                    subject: "foo".to_string(),
                    scopes: Scopes::All,
                    issuer: None,
                });
                if *auth == expected {
                    Ok(Response::new(Body::empty()))
                } else {
                    Err(format!("{:?} != {:?}", auth, expected))
                }
            })
        }
    }

    /// End to end: the factory wraps the inner service, and the produced
    /// authenticator injects the expected `Authorization` into the context.
    #[tokio::test]
    async fn test_make_service() {
        let make_svc = MakeTestService;
        let mut a: MakeAllowAllAuthenticator<_, EmptyContext> =
            MakeAllowAllAuthenticator::new(make_svc, "foo");
        let mut service = a.call(&()).await.unwrap();
        let response = service
            .call((
                Request::get("http://localhost")
                    .body(Body::empty())
                    .unwrap(),
                EmptyContext::default(),
            ))
            .await;
        response.unwrap();
    }

    #[test]
    fn auth_data_constructors_build_expected_variants() {
        assert_eq!(
            AuthData::basic("user", "pass"),
            AuthData::Basic(Basic {
                username: "user".to_string(),
                password: Some("pass".to_string()),
            })
        );
        assert_eq!(
            AuthData::bearer("tok"),
            AuthData::Bearer(Bearer {
                token: "tok".to_string(),
            })
        );
        assert_eq!(AuthData::apikey("k"), AuthData::ApiKey("k".to_string()));
    }

    #[test]
    fn from_headers_parses_bearer_token() {
        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer abc123"));
        let bearer: Option<Bearer> = from_headers(&headers);
        assert_eq!(
            bearer,
            Some(Bearer {
                token: "abc123".to_string(),
            })
        );
    }

    #[test]
    fn from_headers_parses_basic_credentials() {
        let mut headers = HeaderMap::new();
        // base64("user:pass")
        headers.insert(AUTHORIZATION, HeaderValue::from_static("Basic dXNlcjpwYXNz"));
        let basic: Option<Basic> = from_headers(&headers);
        assert_eq!(
            basic,
            Some(Basic {
                username: "user".to_string(),
                password: Some("pass".to_string()),
            })
        );
    }

    #[test]
    fn from_headers_missing_or_wrong_scheme_is_none() {
        let empty = HeaderMap::new();
        assert_eq!(from_headers::<Bearer>(&empty), None);

        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer abc123"));
        assert_eq!(from_headers::<Basic>(&headers), None);
    }

    #[test]
    fn api_key_from_header_reads_named_header_only() {
        let mut headers = HeaderMap::new();
        headers.insert("x-api-key", HeaderValue::from_static("secret"));
        assert_eq!(
            api_key_from_header(&headers, "x-api-key"),
            Some("secret".to_string())
        );
        assert_eq!(api_key_from_header(&headers, "x-other-key"), None);
    }
}