snap_control/server/
auth.rs1use std::{
17 fmt::Display,
18 future::Future,
19 pin::Pin,
20 sync::Arc,
21 task::{Context, Poll},
22};
23
24use axum::body::Body;
25use http::{Request, Response};
26use thiserror::Error;
27use tower::{BoxError, Layer, Service};
28
29use crate::server::token_verifier::SnapTokenVerifier;
30
31#[derive(Clone)]
32pub(crate) struct AuthMiddlewareLayer {
33 verifier: Arc<SnapTokenVerifier>,
34}
35
36impl AuthMiddlewareLayer {
37 pub(crate) fn new(verifier: SnapTokenVerifier) -> Self {
38 Self {
39 verifier: Arc::new(verifier),
40 }
41 }
42}
43
44impl<S> Layer<S> for AuthMiddlewareLayer {
45 type Service = AuthMiddleware<S>;
46
47 fn layer(&self, inner: S) -> Self::Service {
48 AuthMiddleware::new(inner, self.verifier.clone())
49 }
50}
51
52#[derive(Clone)]
53pub(crate) struct AuthMiddleware<S> {
54 inner: S,
55 verifier: Arc<SnapTokenVerifier>,
56}
57
58impl<S> AuthMiddleware<S> {
59 pub(crate) fn new(inner: S, verifier: Arc<SnapTokenVerifier>) -> Self {
60 Self { inner, verifier }
61 }
62}
63
64impl<S> Service<Request<Body>> for AuthMiddleware<S>
65where
66 S: Service<Request<Body>, Response = Response<Body>> + Send + Clone + 'static,
67 S::Error: Into<BoxError>,
68 S::Future: Send + 'static,
69{
70 type Response = Response<Body>;
71 type Error = BoxError;
72 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
73
74 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
75 self.inner.poll_ready(cx).map_err(Into::into)
76 }
77
78 fn call(&mut self, mut request: Request<Body>) -> Self::Future {
79 let token = match extract_bearer_token(&request) {
80 Ok(token) => token,
81 Err(err) => {
82 tracing::debug!(%err, "Extract bearer token");
83 return Box::pin(async { Ok(build_unauthorized_response(err)) });
84 }
85 };
86
87 let verifier = self.verifier.clone();
88 let mut inner = self.inner.clone();
89 Box::pin(async move {
90 match verifier.verify(&token).await {
91 Ok(token_claims) => {
92 request.extensions_mut().insert(token_claims);
93 inner.call(request).await.map_err(Into::into)
94 }
95 Err(err) => {
96 tracing::debug!(%err, "Invalid Token");
97 Ok(build_unauthorized_response(err))
98 }
99 }
100 })
101 }
102}
103
104fn build_unauthorized_response<E: Display>(err: E) -> Response<Body> {
105 Response::builder()
106 .status(http::StatusCode::UNAUTHORIZED)
107 .body(Body::from(format!("SNAP Token validation failed: {err}")))
108 .expect("no fail")
109}
110
111pub fn extract_bearer_token(req: &Request<Body>) -> Result<String, ExtractBearerTokenError> {
113 let auth_header = match req.headers().get("authorization") {
114 Some(header) => header,
115 None => return Err(ExtractBearerTokenError::AuthHeaderMissing),
116 };
117
118 let auth_str = match auth_header.to_str() {
119 Ok(str) => str,
120 Err(_) => return Err(ExtractBearerTokenError::AuthHeaderInvalidUtf8),
121 };
122
123 match auth_str.strip_prefix("Bearer ") {
124 Some(token) => Ok(token.to_string()),
125 None => Err(ExtractBearerTokenError::AuthHeaderNotBearer),
126 }
127}
128
129#[derive(Debug, Error)]
131pub enum ExtractBearerTokenError {
132 #[error("authorization header is missing")]
134 AuthHeaderMissing,
135 #[error("authorization header is not valid UTF-8")]
137 AuthHeaderInvalidUtf8,
138 #[error("authorization header is not a bearer token")]
140 AuthHeaderNotBearer,
141}