Skip to main content

snap_control/server/
auth.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! SNAP control plane API authentication middleware.
15
16use std::{
17    fmt::Display,
18    future::Future,
19    pin::Pin,
20    task::{Context, Poll},
21    time::SystemTime,
22};
23
24use axum::body::Body;
25use http::{Request, Response};
26use jsonwebtoken::DecodingKey;
27use scion_sdk_token_validator::validator::{TokenValidator, Validator};
28use snap_tokens::AnyClaims;
29use thiserror::Error;
30use tower::{BoxError, Layer, Service};
31
32#[derive(Clone)]
33pub(crate) struct AuthMiddlewareLayer {
34    validator: Validator<AnyClaims>,
35}
36
37impl AuthMiddlewareLayer {
38    pub(crate) fn new(dec: DecodingKey) -> Self {
39        Self {
40            validator: Validator::new(dec, Some(&["snap"])),
41        }
42    }
43}
44
45impl<S> Layer<S> for AuthMiddlewareLayer {
46    type Service = AuthMiddleware<S>;
47
48    fn layer(&self, inner: S) -> Self::Service {
49        AuthMiddleware::new(inner, self.validator.clone())
50    }
51}
52
53#[derive(Clone)]
54pub(crate) struct AuthMiddleware<S> {
55    inner: S,
56    validator: Validator<AnyClaims>,
57}
58
59impl<S> AuthMiddleware<S> {
60    pub(crate) fn new(inner: S, validator: Validator<AnyClaims>) -> Self {
61        Self { inner, validator }
62    }
63}
64
65impl<S> Service<Request<Body>> for AuthMiddleware<S>
66where
67    S: Service<Request<Body>, Response = Response<Body>> + Send + Clone + 'static,
68    S::Error: Into<BoxError>,
69    S::Future: Send + 'static,
70{
71    type Response = Response<Body>;
72    type Error = BoxError;
73    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
74
75    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
76        self.inner.poll_ready(cx).map_err(Into::into)
77    }
78
79    fn call(&mut self, mut request: Request<Body>) -> Self::Future {
80        let token = match extract_bearer_token(&request) {
81            Ok(token) => token,
82            Err(err) => {
83                tracing::debug!(%err, "Extract bearer token");
84                return Box::pin(async { Ok(build_unauthorized_response(err)) });
85            }
86        };
87
88        match self.validator.validate(SystemTime::now(), token.as_str()) {
89            Ok(token_claims) => {
90                match token_claims {
91                    AnyClaims::V0(claims) => {
92                        request.extensions_mut().insert(claims);
93                    }
94                    AnyClaims::V1(claims) => {
95                        request.extensions_mut().insert(claims);
96                    }
97                }
98                let mut inner = self.inner.clone();
99                Box::pin(async move { inner.call(request).await.map_err(Into::into) })
100            }
101            Err(err) => {
102                tracing::debug!(%err, "Invalid Token");
103                Box::pin(async { Ok(build_unauthorized_response(err)) })
104            }
105        }
106    }
107}
108
109fn build_unauthorized_response<E: Display>(err: E) -> Response<Body> {
110    Response::builder()
111        .status(http::StatusCode::UNAUTHORIZED)
112        .body(Body::from(format!("SNAP Token validation failed: {err}")))
113        .expect("no fail")
114}
115
116/// Extracts the bearer token from the `Authorization` header of the request.
117pub fn extract_bearer_token(req: &Request<Body>) -> Result<String, ExtractBearerTokenError> {
118    let auth_header = match req.headers().get("authorization") {
119        Some(header) => header,
120        None => return Err(ExtractBearerTokenError::AuthHeaderMissing),
121    };
122
123    let auth_str = match auth_header.to_str() {
124        Ok(str) => str,
125        Err(_) => return Err(ExtractBearerTokenError::AuthHeaderInvalidUtf8),
126    };
127
128    match auth_str.strip_prefix("Bearer ") {
129        Some(token) => Ok(token.to_string()),
130        None => Err(ExtractBearerTokenError::AuthHeaderNotBearer),
131    }
132}
133
134/// Bearer token extraction error.
135#[derive(Debug, Error)]
136pub enum ExtractBearerTokenError {
137    /// Authorization header is missing.
138    #[error("authorization header is missing")]
139    AuthHeaderMissing,
140    /// Authorization header is not valid UTF-8.
141    #[error("authorization header is not valid UTF-8")]
142    AuthHeaderInvalidUtf8,
143    /// Authorization header is not a Bearer token.
144    #[error("authorization header is not a bearer token")]
145    AuthHeaderNotBearer,
146}