1use std::sync::Arc;
15use std::task::{Context, Poll};
16
17use futures_util::future::BoxFuture;
18use http::{Request, Response};
19use tonic::Status;
20use tonic::body::BoxBody;
21use tower::{Layer, Service};
22
23use super::{AuthCtx, AuthError, CURRENT_AUTH, TokenExtractor, TokenVerifier};
24
25#[derive(Clone)]
27pub struct AuthLayer {
28 extractor: Arc<dyn TokenExtractor>,
29 verifier: Arc<dyn TokenVerifier>,
30 optional: bool,
32}
33
34impl AuthLayer {
35 pub fn new<E, V>(extractor: E, verifier: V) -> Self
36 where
37 E: TokenExtractor,
38 V: TokenVerifier,
39 {
40 Self {
41 extractor: Arc::new(extractor),
42 verifier: Arc::new(verifier),
43 optional: false,
44 }
45 }
46
47 pub fn optional(mut self) -> Self {
50 self.optional = true;
51 self
52 }
53}
54
55impl<S> Layer<S> for AuthLayer {
56 type Service = AuthService<S>;
57 fn layer(&self, inner: S) -> Self::Service {
58 AuthService {
59 inner,
60 extractor: self.extractor.clone(),
61 verifier: self.verifier.clone(),
62 optional: self.optional,
63 }
64 }
65}
66
67#[derive(Clone)]
68pub struct AuthService<S> {
69 inner: S,
70 extractor: Arc<dyn TokenExtractor>,
71 verifier: Arc<dyn TokenVerifier>,
72 optional: bool,
73}
74
75impl<S> Service<Request<BoxBody>> for AuthService<S>
76where
77 S: Service<Request<BoxBody>, Response = Response<BoxBody>> + Clone + Send + 'static,
78 S::Error: Send + 'static,
79 S::Future: Send + 'static,
80{
81 type Response = Response<BoxBody>;
82 type Error = S::Error;
83 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
84
85 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86 self.inner.poll_ready(cx)
87 }
88
89 fn call(&mut self, mut req: Request<BoxBody>) -> Self::Future {
90 let mut inner = self.inner.clone();
91 let extractor = self.extractor.clone();
92 let verifier = self.verifier.clone();
93 let optional = self.optional;
94
95 Box::pin(async move {
96 let metadata = metadata_from_headers(req.headers());
100
101 let ctx = match extractor.extract(&metadata) {
102 Ok(token) => match verifier.verify(&token).await {
103 Ok(ctx) => ctx,
104 Err(e) => return Ok(error_response(e)),
105 },
106 Err(AuthError::MissingToken) if optional => AuthCtx::anonymous(),
107 Err(e) => return Ok(error_response(e)),
108 };
109
110 req.extensions_mut().insert(ctx.clone());
112
113 CURRENT_AUTH.scope(ctx, inner.call(req)).await
116 })
117 }
118}
119
120fn metadata_from_headers(h: &http::HeaderMap) -> tonic::metadata::MetadataMap {
123 tonic::metadata::MetadataMap::from_headers(h.clone())
124}
125
126fn error_response(e: AuthError) -> Response<BoxBody> {
129 let status: Status = e.into();
130 status.into_http()
131}