torii_axum/
extractors.rs

1use std::net::SocketAddr;
2
3use axum::{
4    Extension, RequestPartsExt,
5    extract::{ConnectInfo, FromRequestParts},
6    http::{StatusCode, request::Parts},
7};
8use axum_extra::{TypedHeader, extract::CookieJar, headers::UserAgent};
9use torii::{SessionToken, User};
10
11use crate::{error::AuthError, types::ConnectionInfo};
12
13impl<S> FromRequestParts<S> for ConnectionInfo
14where
15    S: Send + Sync,
16{
17    type Rejection = (StatusCode, &'static str);
18
19    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
20        let user_agent = parts
21            .extract::<Option<TypedHeader<UserAgent>>>()
22            .await
23            .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid user agent header"))?
24            .map(|ua| ua.to_string());
25
26        let ip = parts
27            .extract::<ConnectInfo<SocketAddr>>()
28            .await
29            .ok()
30            .map(|addr| addr.ip().to_string());
31
32        Ok(ConnectionInfo { ip, user_agent })
33    }
34}
35
36pub struct AuthUser(pub User);
37
38impl<S> FromRequestParts<S> for AuthUser
39where
40    S: Send + Sync,
41{
42    type Rejection = AuthError;
43
44    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
45        let Extension(user): Extension<User> =
46            parts.extract().await.map_err(|_| AuthError::Unauthorized)?;
47
48        Ok(AuthUser(user))
49    }
50}
51
52pub struct OptionalAuthUser(pub Option<User>);
53
54impl<S> FromRequestParts<S> for OptionalAuthUser
55where
56    S: Send + Sync,
57{
58    type Rejection = (StatusCode, &'static str);
59
60    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
61        let user = parts.extensions.get::<User>().cloned();
62
63        Ok(OptionalAuthUser(user))
64    }
65}
66
67pub struct SessionTokenFromCookie(pub Option<SessionToken>);
68
69impl<S> FromRequestParts<S> for SessionTokenFromCookie
70where
71    S: Send + Sync,
72{
73    type Rejection = (StatusCode, &'static str);
74
75    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
76        let jar = parts
77            .extract::<CookieJar>()
78            .await
79            .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid cookie header"))?;
80
81        let session_token = jar
82            .get("session_id")
83            .and_then(|cookie| cookie.value().parse::<String>().ok())
84            .map(|token| SessionToken::new(&token));
85
86        Ok(SessionTokenFromCookie(session_token))
87    }
88}
89
90pub struct SessionTokenFromBearer(pub Option<SessionToken>);
91
92impl<S> FromRequestParts<S> for SessionTokenFromBearer
93where
94    S: Send + Sync,
95{
96    type Rejection = (StatusCode, &'static str);
97
98    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
99        let session_token = parts
100            .headers
101            .get("Authorization")
102            .and_then(|header| header.to_str().ok())
103            .and_then(|header| header.strip_prefix("Bearer "))
104            .map(SessionToken::new);
105
106        Ok(SessionTokenFromBearer(session_token))
107    }
108}
109
110pub struct SessionTokenFromRequest(pub Option<SessionToken>);
111
112impl<S> FromRequestParts<S> for SessionTokenFromRequest
113where
114    S: Send + Sync,
115{
116    type Rejection = (StatusCode, &'static str);
117
118    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
119        // Try Bearer token first, then fall back to cookie
120        if let Some(token) = parts
121            .headers
122            .get("Authorization")
123            .and_then(|header| header.to_str().ok())
124            .and_then(|header| header.strip_prefix("Bearer "))
125        {
126            return Ok(SessionTokenFromRequest(Some(SessionToken::new(token))));
127        }
128
129        // Fall back to cookie
130        let jar = parts
131            .extract::<CookieJar>()
132            .await
133            .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid cookie header"))?;
134
135        let session_token = jar
136            .get("session_id")
137            .and_then(|cookie| cookie.value().parse::<String>().ok())
138            .map(|token| SessionToken::new(&token));
139
140        Ok(SessionTokenFromRequest(session_token))
141    }
142}