Skip to main content

typeway_server/
auth.rs

//! Type-level authentication for endpoints.
//!
//! [`Protected<Auth, E>`] wraps an endpoint type to declare that it requires
//! authentication. The compiler enforces that the handler's first argument
//! is the auth extractor type, and a `Protected` endpoint can only be bound
//! with `bind_auth!()` — binding it with plain `bind!()` does not compile.
//!
//! # Example
//!
//! ```ignore
//! use typeway_server::auth::Protected;
//!
//! // Tag endpoints as protected in the API type
//! type API = (
//!     GetEndpoint<TagsPath, TagsResponse>,                           // public
//!     Protected<AuthUser, GetEndpoint<UserPath, UserResponse>>,      // auth required
//! );
//!
//! // Handlers for protected endpoints MUST accept AuthUser as first arg.
//! async fn get_user(auth: AuthUser, state: State<Db>) -> Json<User> { ... }
//!
//! // Wire up with bind_auth!():
//! Server::<API>::new((
//!     bind!(get_tags),             // public
//!     bind_auth!(get_user),        // protected — AuthUser enforced
//! ));
//! ```

28use std::future::Future;
29use std::marker::PhantomData;
30use std::pin::Pin;
31
32use typeway_core::ApiSpec;
33
34use crate::body::BoxBody;
35use crate::extract::{FromRequest, FromRequestParts};
36use crate::handler::BoxedHandler;
37use crate::handler_for::{BindableEndpoint, BoundHandler};
38use crate::response::IntoResponse;
39
40/// An endpoint that requires authentication.
41///
42/// `Auth` is the authentication extractor type (e.g., `AuthUser`).
43/// `E` is the underlying endpoint type.
44///
45/// Handlers bound to `Protected` endpoints via `bind_auth!()` must accept
46/// `Auth` as their first argument. This is enforced at compile time by
47/// the `AuthHandler` trait — using `bind!()` (without auth) for a
48/// `Protected` endpoint produces a type mismatch.
49pub struct Protected<Auth, E> {
50    _marker: PhantomData<(Auth, E)>,
51}
52
53impl<Auth, E: ApiSpec> ApiSpec for Protected<Auth, E> {}
54
55// Protected<Auth, E> delegates AllProvided to the inner endpoint E.
56// This allows EffectfulServer to work with APIs containing Protected endpoints.
57impl<Auth, E, Provided, Idx> typeway_core::effects::AllProvided<Provided, Idx>
58    for Protected<Auth, E>
59where
60    E: typeway_core::effects::AllProvided<Provided, Idx>,
61{
62}
63
64// NOTE: BindableEndpoint is intentionally NOT implemented for Protected.
65// This means bind!() cannot be used with Protected endpoints — only
66// bind_auth!() works. This is the compile-time enforcement mechanism.
67
68// ---------------------------------------------------------------------------
69// AuthHandler trait — enforces Auth as first argument
70// ---------------------------------------------------------------------------
71
72/// A handler that takes `Auth` as its first argument.
73///
74/// This is separate from `Handler<Args>` to ensure that `Protected`
75/// endpoints can only be bound with handlers that accept the auth type.
76/// The trait is implemented for async functions where the first argument
77/// is `Auth: FromRequestParts`.
78pub trait AuthHandler<Auth, Args>: Clone + Send + Sync + 'static {
79    fn call(
80        self,
81        parts: http::request::Parts,
82        body: bytes::Bytes,
83    ) -> Pin<Box<dyn Future<Output = http::Response<BoxBody>> + Send>>;
84}
85
86// Auth + no other args
87impl<F, Fut, Res, Auth> AuthHandler<Auth, ()> for F
88where
89    F: FnOnce(Auth) -> Fut + Clone + Send + Sync + 'static,
90    Fut: Future<Output = Res> + Send,
91    Res: IntoResponse,
92    Auth: FromRequestParts + 'static,
93{
94    fn call(
95        self,
96        parts: http::request::Parts,
97        _body: bytes::Bytes,
98    ) -> Pin<Box<dyn Future<Output = http::Response<BoxBody>> + Send>> {
99        Box::pin(async move {
100            let auth = match Auth::from_request_parts(&parts) {
101                Ok(v) => v,
102                Err(e) => return e.into_response(),
103            };
104            self(auth).await.into_response()
105        })
106    }
107}
108
109// Generate impls for Auth + N FromRequestParts args
110macro_rules! impl_auth_handler_parts {
111    ([$($T:ident),+], [$($t:ident),+]) => {
112        #[allow(non_snake_case)]
113        impl<F, Fut, Res, Auth, $($T,)+> AuthHandler<Auth, ($($T,)+)> for F
114        where
115            F: FnOnce(Auth, $($T,)+) -> Fut + Clone + Send + Sync + 'static,
116            Fut: Future<Output = Res> + Send,
117            Res: IntoResponse,
118            Auth: FromRequestParts + 'static,
119            $($T: FromRequestParts + 'static,)+
120        {
121            fn call(
122                self,
123                parts: http::request::Parts,
124                _body: bytes::Bytes,
125            ) -> Pin<Box<dyn Future<Output = http::Response<BoxBody>> + Send>> {
126                Box::pin(async move {
127                    let auth = match Auth::from_request_parts(&parts) {
128                        Ok(v) => v,
129                        Err(e) => return e.into_response(),
130                    };
131                    $(
132                        let $t = match $T::from_request_parts(&parts) {
133                            Ok(v) => v,
134                            Err(e) => return e.into_response(),
135                        };
136                    )+
137                    self(auth, $($t,)+).await.into_response()
138                })
139            }
140        }
141    };
142}
143
144impl_auth_handler_parts!([T1], [t1]);
145impl_auth_handler_parts!([T1, T2], [t1, t2]);
146impl_auth_handler_parts!([T1, T2, T3], [t1, t2, t3]);
147impl_auth_handler_parts!([T1, T2, T3, T4], [t1, t2, t3, t4]);
148impl_auth_handler_parts!([T1, T2, T3, T4, T5], [t1, t2, t3, t4, t5]);
149impl_auth_handler_parts!([T1, T2, T3, T4, T5, T6], [t1, t2, t3, t4, t5, t6]);
150
151// Generate impls for Auth + N FromRequestParts args + body extractor (last arg)
152macro_rules! impl_auth_handler_with_body {
153    ([], []) => {
154        impl<F, Fut, Res, Auth, B> AuthHandler<Auth, AuthWithBody<(), B>> for F
155        where
156            F: FnOnce(Auth, B) -> Fut + Clone + Send + Sync + 'static,
157            Fut: Future<Output = Res> + Send,
158            Res: IntoResponse,
159            Auth: FromRequestParts + 'static,
160            B: FromRequest + 'static,
161        {
162            fn call(
163                self,
164                parts: http::request::Parts,
165                body: bytes::Bytes,
166            ) -> Pin<Box<dyn Future<Output = http::Response<BoxBody>> + Send>> {
167                Box::pin(async move {
168                    let auth = match Auth::from_request_parts(&parts) {
169                        Ok(v) => v,
170                        Err(e) => return e.into_response(),
171                    };
172                    let b = match B::from_request(&parts, body).await {
173                        Ok(v) => v,
174                        Err(e) => return e.into_response(),
175                    };
176                    self(auth, b).await.into_response()
177                })
178            }
179        }
180    };
181    ([$($T:ident),+], [$($t:ident),+]) => {
182        #[allow(non_snake_case)]
183        impl<F, Fut, Res, Auth, $($T,)+ B> AuthHandler<Auth, AuthWithBody<($($T,)+), B>> for F
184        where
185            F: FnOnce(Auth, $($T,)+ B) -> Fut + Clone + Send + Sync + 'static,
186            Fut: Future<Output = Res> + Send,
187            Res: IntoResponse,
188            Auth: FromRequestParts + 'static,
189            $($T: FromRequestParts + 'static,)+
190            B: FromRequest + 'static,
191        {
192            fn call(
193                self,
194                parts: http::request::Parts,
195                body: bytes::Bytes,
196            ) -> Pin<Box<dyn Future<Output = http::Response<BoxBody>> + Send>> {
197                Box::pin(async move {
198                    let auth = match Auth::from_request_parts(&parts) {
199                        Ok(v) => v,
200                        Err(e) => return e.into_response(),
201                    };
202                    $(
203                        let $t = match $T::from_request_parts(&parts) {
204                            Ok(v) => v,
205                            Err(e) => return e.into_response(),
206                        };
207                    )+
208                    let b = match B::from_request(&parts, body).await {
209                        Ok(v) => v,
210                        Err(e) => return e.into_response(),
211                    };
212                    self(auth, $($t,)+ b).await.into_response()
213                })
214            }
215        }
216    };
217}
218
219/// Marker for auth handlers with a body extractor as last arg.
220pub struct AuthWithBody<Parts, Body>(PhantomData<(Parts, Body)>);
221
222impl_auth_handler_with_body!([], []);
223impl_auth_handler_with_body!([T1], [t1]);
224impl_auth_handler_with_body!([T1, T2], [t1, t2]);
225impl_auth_handler_with_body!([T1, T2, T3], [t1, t2, t3]);
226impl_auth_handler_with_body!([T1, T2, T3, T4], [t1, t2, t3, t4]);
227impl_auth_handler_with_body!([T1, T2, T3, T4, T5], [t1, t2, t3, t4, t5]);
228
229// ---------------------------------------------------------------------------
230// bind_protected — uses AuthHandler instead of Handler
231// ---------------------------------------------------------------------------
232
233/// Bind a handler to a `Protected<Auth, E>` endpoint.
234///
235/// The handler's first argument MUST be `Auth`. This is enforced by the
236/// `AuthHandler<Auth, Args>` trait — the compiler rejects handlers that
237/// don't take `Auth` as their first argument.
238/// Trait to extract binding info from the inner endpoint of a Protected type.
239pub trait ProtectedEndpoint {
240    type Auth;
241    type Inner: BindableEndpoint;
242}
243
244impl<Auth, E: BindableEndpoint> ProtectedEndpoint for Protected<Auth, E> {
245    type Auth = Auth;
246    type Inner = E;
247}
248
249pub fn bind_protected<P, H, Args>(handler: H) -> BoundHandler<P>
250where
251    P: ProtectedEndpoint,
252    P::Auth: FromRequestParts + 'static,
253    P::Inner: BindableEndpoint,
254    H: AuthHandler<P::Auth, Args>,
255    Args: 'static,
256{
257    let method = P::Inner::method();
258    let pattern = P::Inner::pattern();
259    let match_fn = P::Inner::match_fn();
260
261    // Type-erase via AuthHandler::call
262    let boxed: BoxedHandler = std::sync::Arc::new(move |parts, body| {
263        let h = handler.clone();
264        h.call(parts, body)
265    });
266
267    BoundHandler::new(method, pattern, match_fn, boxed)
268}
269
/// Binds a handler to a [`Protected`](crate::auth::Protected) endpoint.
///
/// This is shorthand for [`bind_protected`](crate::auth::bind_protected). The
/// endpoint, handler and argument types are all left to inference. The handler
/// must take the endpoint's `Auth` extractor as its first argument.
///
/// ```ignore
/// Server::<API>::new((
///     bind!(get_tags),        // public endpoint
///     bind_auth!(get_user),   // Protected<AuthUser, _> endpoint
/// ));
/// ```
#[macro_export]
macro_rules! bind_auth {
    ($handler:expr) => {
        $crate::auth::bind_protected($handler)
    };
}