snap_control/server/
auth.rs

// Copyright 2025 Anapaya Systems
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! SNAP control plane API authentication middleware.

16use std::{
17    fmt::Display,
18    future::Future,
19    pin::Pin,
20    task::{Context, Poll},
21    time::SystemTime,
22};
23
24use axum::body::Body;
25use http::{Request, Response};
26use jsonwebtoken::DecodingKey;
27use scion_sdk_token_validator::validator::{TokenValidator, Validator};
28use snap_tokens::snap_token::SnapTokenClaims;
29use thiserror::Error;
30use tower::{BoxError, Layer, Service};
31use tracing::debug;
32
33#[derive(Clone)]
34pub(crate) struct AuthMiddlewareLayer {
35    validator: Validator<SnapTokenClaims>,
36}
37
38impl AuthMiddlewareLayer {
39    pub(crate) fn new(dec: DecodingKey) -> Self {
40        Self {
41            validator: Validator::new(dec, Some(&["snap"])),
42        }
43    }
44}
45
46impl<S> Layer<S> for AuthMiddlewareLayer {
47    type Service = AuthMiddleware<S>;
48
49    fn layer(&self, inner: S) -> Self::Service {
50        AuthMiddleware::new(inner, self.validator.clone())
51    }
52}
53
54#[derive(Clone)]
55pub(crate) struct AuthMiddleware<S> {
56    inner: S,
57    validator: Validator<SnapTokenClaims>,
58}
59
60impl<S> AuthMiddleware<S> {
61    pub(crate) fn new(inner: S, validator: Validator<SnapTokenClaims>) -> Self {
62        Self { inner, validator }
63    }
64}
65
66impl<S> Service<Request<Body>> for AuthMiddleware<S>
67where
68    S: Service<Request<Body>, Response = Response<Body>> + Send + Clone + 'static,
69    S::Error: Into<BoxError>,
70    S::Future: Send + 'static,
71{
72    type Response = Response<Body>;
73    type Error = BoxError;
74    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
75
76    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
77        self.inner.poll_ready(cx).map_err(Into::into)
78    }
79
80    fn call(&mut self, mut request: Request<Body>) -> Self::Future {
81        let token = match extract_bearer_token(&request) {
82            Ok(token) => token,
83            Err(err) => {
84                debug!(error=%err, "extract bearer token");
85                return Box::pin(async { Ok(build_unauthorized_response(err)) });
86            }
87        };
88
89        match self.validator.validate(SystemTime::now(), token.as_str()) {
90            Ok(token_claims) => {
91                request.extensions_mut().insert(token_claims);
92                let mut inner = self.inner.clone();
93                Box::pin(async move { inner.call(request).await.map_err(Into::into) })
94            }
95            Err(err) => {
96                debug!(error=%err, "Invalid Token");
97                Box::pin(async { Ok(build_unauthorized_response(err)) })
98            }
99        }
100    }
101}
102
103fn build_unauthorized_response<E: Display>(err: E) -> Response<Body> {
104    Response::builder()
105        .status(http::StatusCode::UNAUTHORIZED)
106        .body(Body::from(format!("SNAP Token validation failed: {err}")))
107        .expect("no fail")
108}
109
110/// Extracts the bearer token from the `Authorization` header of the request.
111pub fn extract_bearer_token(req: &Request<Body>) -> Result<String, ExtractBearerTokenError> {
112    let auth_header = match req.headers().get("authorization") {
113        Some(header) => header,
114        None => return Err(ExtractBearerTokenError::AuthHeaderMissing),
115    };
116
117    let auth_str = match auth_header.to_str() {
118        Ok(str) => str,
119        Err(_) => return Err(ExtractBearerTokenError::AuthHeaderInvalidUtf8),
120    };
121
122    match auth_str.strip_prefix("Bearer ") {
123        Some(token) => Ok(token.to_string()),
124        None => Err(ExtractBearerTokenError::AuthHeaderNotBearer),
125    }
126}
127
128/// Bearer token extraction error.
129#[derive(Debug, Error)]
130pub enum ExtractBearerTokenError {
131    /// Authorization header is missing.
132    #[error("authorization header is missing")]
133    AuthHeaderMissing,
134    /// Authorization header is not valid UTF-8.
135    #[error("authorization header is not valid UTF-8")]
136    AuthHeaderInvalidUtf8,
137    /// Authorization header is not a Bearer token.
138    #[error("authorization header is not a bearer token")]
139    AuthHeaderNotBearer,
140}