use std::{
future::Future,
ops::Add,
pin::Pin,
task::{Context, Poll},
};
use axum::extract::FromRequestParts;
use bytes::Bytes;
use chrono::{Duration, Utc};
use headers::{Authorization, HeaderMapExt};
use http::{request::Parts, Request, StatusCode};
use http_body::combinators::UnsyncBoxBody;
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use opentelemetry::global;
use opentelemetry_http::HeaderInjector;
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use strum::EnumMessage;
use tower::{Layer, Service};
use tracing::{error, trace, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use crate::{limits::Limits, models::user::AccountTier};
/// Lifetime, in minutes, of a claim minted by [`Claim::new`]: its `exp` is set to `iat`
/// plus this many minutes.
pub const EXP_MINUTES: i64 = 15;
/// Issuer (`iss`) written into every claim by [`Claim::new`] and required by
/// [`Claim::from_token`] when validating a token.
const ISS: &str = "shuttle";
/// A permission carried in a [`Claim`]'s `scopes` list.
///
/// Variants are serialized in `snake_case` (e.g. `DeploymentPush` is `"deployment_push"`),
/// with the single exception of `ProjectWrite` noted below.
#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq, EnumMessage)]
#[serde(rename_all = "snake_case")]
pub enum Scope {
    // NOTE(review): `EnumMessage` is derived. In recent strum versions its derive exposes
    // variant `///` doc comments at runtime via `EnumMessage::get_documentation`. The notes
    // inside this enum are therefore plain `//` comments, so they do not change that output.
    Deployment,
    DeploymentPush,
    Logs,
    Service,
    ServiceCreate,
    Project,
    // Serialized as "project_create" instead of the `rename_all` default "project_write";
    // presumably kept so existing tokens/clients keep working — confirm before changing.
    #[serde(rename = "project_create")] ProjectWrite,
    ExtraProjects,
    Resources,
    ResourcesWrite,
    Secret,
    SecretWrite,
    User,
    UserCreate,
    AcmeCreate,
    CustomDomainCreate,
    CustomDomainCertificateRenew,
    GatewayCertificateRenew,
    Admin,
}
/// Accumulates a list of [`Scope`]s from predefined groups.
///
/// Each `with_*` call appends its group in a fixed order. Duplicates are not removed, so
/// combining overlapping groups (e.g. [`ScopeBuilder::with_basic`] and
/// [`ScopeBuilder::with_deploy_rights`]) yields repeated scopes.
#[derive(Default)]
pub struct ScopeBuilder(Vec<Scope>);

impl ScopeBuilder {
    /// Start with an empty set of scopes.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Append the administrative group: `User`, `UserCreate`, `AcmeCreate`,
    /// `CustomDomainCreate`, `CustomDomainCertificateRenew`, `GatewayCertificateRenew`
    /// and `Admin`.
    #[must_use]
    pub fn with_admin(mut self) -> Self {
        // An array literal avoids allocating a temporary `Vec` just to extend from it.
        self.0.extend([
            Scope::User,
            Scope::UserCreate,
            Scope::AcmeCreate,
            Scope::CustomDomainCreate,
            Scope::CustomDomainCertificateRenew,
            Scope::GatewayCertificateRenew,
            Scope::Admin,
        ]);
        self
    }

    /// Append the pro-tier group, which is just `ExtraProjects`.
    #[must_use]
    pub fn with_pro(mut self) -> Self {
        self.0.push(Scope::ExtraProjects);
        self
    }

    /// Append the basic group: read and write/create scopes for deployments, logs,
    /// services, projects, resources and secrets.
    #[must_use]
    pub fn with_basic(mut self) -> Self {
        self.0.extend([
            Scope::Deployment,
            Scope::DeploymentPush,
            Scope::Logs,
            Scope::Service,
            Scope::ServiceCreate,
            Scope::Project,
            Scope::ProjectWrite,
            Scope::Resources,
            Scope::ResourcesWrite,
            Scope::Secret,
            Scope::SecretWrite,
        ]);
        self
    }

    /// Append the narrow deploy-only group: `DeploymentPush`, `Resources`, `Service` and
    /// `ResourcesWrite`.
    #[must_use]
    pub fn with_deploy_rights(mut self) -> Self {
        self.0.extend([
            Scope::DeploymentPush,
            Scope::Resources,
            Scope::Service,
            Scope::ResourcesWrite,
        ]);
        self
    }

    /// Finish building and return the scopes in the order they were appended.
    #[must_use]
    pub fn build(self) -> Vec<Scope> {
        self.0
    }
}
impl AccountTier {
    /// Collapse this tier onto one of two tiers: `Pro` maps to `Pro`, and every other
    /// tier (including `Admin`, `Team`, `Employee`, …) maps to `Basic`.
    ///
    /// The match is deliberately exhaustive, so adding a new tier forces a decision here.
    // NOTE(review): "permit" presumably names the external authorization system this tier
    // is reported to — confirm against callers.
    pub fn as_permit_account_tier(&self) -> Self {
        match self {
            Self::Pro => Self::Pro,
            Self::Admin
            | Self::Basic
            | Self::CancelledPro
            | Self::Deployer
            | Self::Employee
            | Self::PendingPaymentPro
            | Self::Team => Self::Basic,
        }
    }
}
impl From<AccountTier> for Vec<Scope> {
    /// Derive the default scopes for an account tier:
    ///
    /// - `Deployer`: only the deploy-rights group;
    /// - `Admin`: the basic group followed by the admin group;
    /// - `Pro`: the basic group followed by `ExtraProjects`;
    /// - any other tier: the basic group only.
    fn from(tier: AccountTier) -> Self {
        let scopes = ScopeBuilder::new();

        let scopes = match tier {
            AccountTier::Deployer => scopes.with_deploy_rights(),
            AccountTier::Admin => scopes.with_basic().with_admin(),
            AccountTier::Pro => scopes.with_basic().with_pro(),
            _ => scopes.with_basic(),
        };

        scopes.build()
    }
}
/// Claims carried in a JWT issued by `"shuttle"`.
///
/// [`Claim::new`] mints a fresh claim. [`Claim::into_token`] and [`Claim::from_token`]
/// convert to and from the signed token string.
// NOTE(review): `Debug` is derived, so `{:?}` output includes `token`, which holds the raw
// bearer JWT for claims produced by `from_token`. Avoid logging a whole `Claim`.
#[derive(Clone, Debug, Default, Deserialize, Serialize, Eq, PartialEq)]
pub struct Claim {
    /// Expiry time (JWT `exp`) as a Unix timestamp in seconds.
    pub exp: usize,
    /// Issued-at time (JWT `iat`) as a Unix timestamp in seconds.
    iat: usize,
    /// Issuer (JWT `iss`); `from_token` only accepts tokens whose issuer is `ISS`.
    iss: String,
    /// Not-before time (JWT `nbf`) as a Unix timestamp in seconds; `new` sets it to `iat`.
    nbf: usize,
    /// Subject (JWT `sub`); the request extractor records it as the user id on spans.
    pub sub: String,
    /// Scopes granted by this claim.
    pub scopes: Vec<Scope>,
    /// The raw JWT this claim was decoded from, if any.
    ///
    /// `from_token` stores its input here, and `into_token` returns it unchanged instead of
    /// re-encoding. serde does not skip this field, so it is also part of the encoded
    /// payload (as `null` for a claim built with `new`).
    pub token: Option<String>,
    /// Limits attached to this claim.
    pub limits: Limits,
    /// Account tier of the subject.
    pub tier: AccountTier,
}
impl Claim {
pub fn new(
sub: String,
scopes: Vec<Scope>,
tier: AccountTier,
limits: impl Into<Limits>,
) -> Self {
let iat = Utc::now();
let exp = iat.add(Duration::minutes(EXP_MINUTES));
Self {
exp: exp.timestamp() as usize,
iat: iat.timestamp() as usize,
iss: ISS.to_string(),
nbf: iat.timestamp() as usize,
sub,
scopes,
token: None,
limits: limits.into(),
tier,
}
}
pub fn into_token(self, encoding_key: &EncodingKey) -> Result<String, StatusCode> {
if let Some(token) = self.token {
Ok(token)
} else {
encode(
&Header::new(jsonwebtoken::Algorithm::EdDSA),
&self,
encoding_key,
)
.map_err(|err| {
error!(
error = &err as &dyn std::error::Error,
"failed to convert claim to token"
);
match err.kind() {
jsonwebtoken::errors::ErrorKind::Json(_) => StatusCode::INTERNAL_SERVER_ERROR,
jsonwebtoken::errors::ErrorKind::Crypto(_) => StatusCode::SERVICE_UNAVAILABLE,
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
})
}
}
pub fn from_token(token: &str, public_key: &[u8]) -> Result<Self, StatusCode> {
let decoding_key = DecodingKey::from_ed_der(public_key);
let mut validation = Validation::new(jsonwebtoken::Algorithm::EdDSA);
validation.set_issuer(&[ISS]);
trace!("converting token to claim");
let mut claim: Self = decode(token, &decoding_key, &validation)
.map_err(|err| {
error!(
error = &err as &dyn std::error::Error,
"failed to convert token to claim"
);
match err.kind() {
jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
StatusCode::from_u16(499).unwrap() }
jsonwebtoken::errors::ErrorKind::InvalidSignature
| jsonwebtoken::errors::ErrorKind::InvalidAlgorithmName
| jsonwebtoken::errors::ErrorKind::InvalidIssuer
| jsonwebtoken::errors::ErrorKind::ImmatureSignature => {
StatusCode::UNAUTHORIZED
}
jsonwebtoken::errors::ErrorKind::InvalidToken
| jsonwebtoken::errors::ErrorKind::InvalidAlgorithm
| jsonwebtoken::errors::ErrorKind::Base64(_)
| jsonwebtoken::errors::ErrorKind::Json(_)
| jsonwebtoken::errors::ErrorKind::Utf8(_) => StatusCode::BAD_REQUEST,
jsonwebtoken::errors::ErrorKind::MissingAlgorithm => {
StatusCode::INTERNAL_SERVER_ERROR
}
jsonwebtoken::errors::ErrorKind::Crypto(_) => StatusCode::SERVICE_UNAVAILABLE,
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
})?
.claims;
claim.token = Some(token.to_string());
Ok(claim)
}
}
#[axum::async_trait]
impl<S> FromRequestParts<S> for Claim {
    type Rejection = StatusCode;

    /// Extract the [`Claim`] that an earlier layer stored in the request extensions.
    ///
    /// The claim's subject is recorded on the current span as both `account.user_id` and
    /// `shuttle.user.id`.
    ///
    /// # Errors
    ///
    /// Rejects with `401 Unauthorized` when no `Claim` extension is present.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let claim = parts
            .extensions
            .get::<Claim>()
            .ok_or(StatusCode::UNAUTHORIZED)?;

        let span = Span::current();
        span.record("account.user_id", &claim.sub);
        span.record("shuttle.user.id", &claim.sub);

        // Do not log `?claim`: the derived `Debug` includes `token`, the raw bearer JWT,
        // and credentials must not end up in logs. Log only the identifying fields.
        trace!(
            sub = %claim.sub,
            tier = ?claim.tier,
            scopes = ?claim.scopes,
            limits = ?claim.limits,
            "got user"
        );

        Ok(claim.clone())
    }
}
/// Future returned by [`ClaimService`] and [`InjectPropagation`].
///
/// It is a transparent wrapper that resolves to exactly what the wrapped inner-service
/// future resolves to.
#[pin_project]
pub struct ResponseFuture<F>(#[pin] pub F);

impl<F, T, E> Future for ResponseFuture<F>
where
    F: Future<Output = Result<T, E>>,
{
    type Output = Result<T, E>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `#[pin]` on the field makes `project` yield a pinned reference to the inner future.
        self.project().0.poll(cx)
    }
}
#[derive(Clone)]
pub struct ClaimLayer;
impl<S> Layer<S> for ClaimLayer {
type Service = ClaimService<S>;
fn layer(&self, inner: S) -> Self::Service {
ClaimService { inner }
}
}
#[derive(Clone)]
pub struct ClaimService<S> {
inner: S,
}
impl<S, RequestError> Service<Request<UnsyncBoxBody<Bytes, RequestError>>> for ClaimService<S>
where
S: Service<Request<UnsyncBoxBody<Bytes, RequestError>>> + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = ResponseFuture<S::Future>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<UnsyncBoxBody<Bytes, RequestError>>) -> Self::Future {
if let Some(claim) = req.extensions().get::<Claim>() {
if let Some(token) = claim.token.clone() {
req.headers_mut()
.typed_insert(Authorization::bearer(&token).expect("to set JWT token"));
}
}
let future = self.inner.call(req);
ResponseFuture(future)
}
}
#[derive(Clone)]
pub struct InjectPropagationLayer;
impl<S> Layer<S> for InjectPropagationLayer {
type Service = InjectPropagation<S>;
fn layer(&self, inner: S) -> Self::Service {
InjectPropagation { inner }
}
}
#[derive(Clone)]
pub struct InjectPropagation<S> {
inner: S,
}
impl<S, Body> Service<Request<Body>> for InjectPropagation<S>
where
S: Service<Request<Body>> + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = ResponseFuture<S::Future>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
let cx = Span::current().context();
global::get_text_map_propagator(|propagator| {
propagator.inject_context(&cx, &mut HeaderInjector(req.headers_mut()))
});
let future = self.inner.call(req);
ResponseFuture(future)
}
}